package streamsql import ( "context" "testing" "time" "github.com/rulego/streamsql/functions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // TestFunctionIntegrationNonAggregation 测试非聚合函数在SQL中的集成 func TestFunctionIntegrationNonAggregation(t *testing.T) { t.Run("MathFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试多个数学函数:abs, sqrt, round rsql := "SELECT device, abs(temperature) as abs_temp, sqrt(humidity) as sqrt_humidity, round(temperature) as rounded_temp FROM stream" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := map[string]interface{}{ "device": "test-device", "temperature": -25.5, "humidity": 64.0, } strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) require.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "test-device", item["device"]) // 验证 abs(-25.5) = 25.5 assert.InEpsilon(t, 25.5, item["abs_temp"], 0.001) // 验证 sqrt(64) = 8 assert.InEpsilon(t, 8.0, item["sqrt_humidity"], 0.001) // 验证 round(-25.5) = -26 assert.InEpsilon(t, -26.0, item["rounded_temp"], 0.001) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("StringFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试字符串函数:upper, lower, concat, length rsql := "SELECT upper(device) as upper_device, lower(location) as lower_location, concat(device, '-', location) as combined, length(device) as device_len FROM stream" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := map[string]interface{}{ "device": "sensor01", "location": "ROOM_A", } strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) require.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "SENSOR01", item["upper_device"]) assert.Equal(t, "room_a", item["lower_location"]) assert.Equal(t, "sensor01-ROOM_A", item["combined"]) assert.Equal(t, 8, item["device_len"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("ConversionFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试转换函数:cast rsql := "SELECT device, cast(temperature, 'int') as temp_int, cast(humidity, 'string') as humidity_str FROM stream" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := map[string]interface{}{ "device": "test-device", "temperature": 25.7, "humidity": 65.0, } strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) require.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "test-device", item["device"]) assert.Equal(t, int32(25), item["temp_int"]) assert.Equal(t, "65", item["humidity_str"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("DateTimeFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试日期时间函数:now, year, month, day rsql := "SELECT device, now() as current_time, year(timestamp) as ts_year, month(timestamp) as ts_month FROM stream" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 //testTime := time.Date(2025, 4, 15, 10, 30, 0, 0, time.UTC) testData := map[string]interface{}{ "device": "test-device", "timestamp": "2025-08-25", } strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) require.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "test-device", item["device"]) assert.NotNil(t, item["current_time"]) assert.Equal(t, 2025, item["ts_year"]) assert.Equal(t, 8, item["ts_month"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("JSONFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试JSON函数:json_extract, json_valid rsql := "SELECT device, json_extract(metadata, '$.type') as device_type, json_valid(metadata) as is_valid_json FROM stream" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := map[string]interface{}{ "device": "test-device", "metadata": `{"type": "temperature_sensor", "version": "1.0"}`, } strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) require.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "test-device", item["device"]) assert.Equal(t, "temperature_sensor", item["device_type"]) assert.Equal(t, true, item["is_valid_json"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) } // TestFunctionIntegrationAggregation 测试聚合函数在SQL中的集成 func TestFunctionIntegrationAggregation(t *testing.T) { t.Run("BasicAggregationFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试基本聚合函数:sum, avg, min, max, count rsql := "SELECT device, sum(temperature) as total_temp, avg(temperature) as avg_temp, min(temperature) as min_temp, max(temperature) as max_temp, count(temperature) as temp_count FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 20.0}, {"device": "sensor1", "temperature": 25.0}, {"device": "sensor1", "temperature": 30.0}, {"device": "sensor2", "temperature": 15.0}, {"device": "sensor2", "temperature": 18.0}, } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 2) // 验证sensor1的聚合结果 for _, item := range resultSlice { device := item["device"].(string) if device == "sensor1" { assert.InEpsilon(t, 75.0, item["total_temp"].(float64), 0.001) assert.InEpsilon(t, 25.0, item["avg_temp"].(float64), 0.001) assert.InEpsilon(t, 20.0, item["min_temp"].(float64), 0.001) assert.InEpsilon(t, 30.0, item["max_temp"].(float64), 0.001) assert.Equal(t, 3.0, item["temp_count"].(float64)) } else if device == "sensor2" { assert.InEpsilon(t, 33.0, item["total_temp"].(float64), 0.001) assert.InEpsilon(t, 16.5, item["avg_temp"].(float64), 0.001) assert.InEpsilon(t, 15.0, item["min_temp"].(float64), 0.001) assert.InEpsilon(t, 18.0, item["max_temp"].(float64), 0.001) assert.Equal(t, 2.0, item["temp_count"].(float64)) } } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("StatisticalAggregationFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试统计聚合函数:stddev, median, percentile rsql := "SELECT device, stddev(temperature) as temp_stddev, median(temperature) as temp_median FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 10.0}, {"device": "sensor1", "temperature": 20.0}, {"device": "sensor1", "temperature": 30.0}, {"device": "sensor1", "temperature": 40.0}, {"device": "sensor1", "temperature": 50.0}, } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) // 标准差应该约为15.81 assert.InEpsilon(t, 15.81, item["temp_stddev"].(float64), 0.1) // 中位数应该为30.0 assert.InEpsilon(t, 30.0, item["temp_median"].(float64), 0.001) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("CollectionAggregationFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试集合聚合函数:collect, first_value, last_value rsql := "SELECT device, collect(temperature) as temp_array, first_value(temperature) as first_temp, last_value(temperature) as last_temp FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 20.0}, {"device": "sensor1", "temperature": 25.0}, {"device": "sensor1", "temperature": 30.0}, } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) // 验证collect函数返回的数组 tempArray, ok := item["temp_array"].([]interface{}) assert.True(t, ok) assert.Len(t, tempArray, 3) assert.Contains(t, tempArray, 20.0) assert.Contains(t, tempArray, 25.0) assert.Contains(t, tempArray, 30.0) // 验证first_value和last_value assert.Equal(t, 20.0, item["first_temp"]) assert.Equal(t, 30.0, item["last_temp"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) } // TestFunctionIntegrationMixed 测试混合函数场景 func TestFunctionIntegrationMixed(t *testing.T) { t.Run("AggregationWithNonAggregationFunctions", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试聚合函数与非聚合函数混合使用 rsql := "SELECT device, upper(device) as device_upper, avg(temperature) as avg_temp, round(avg(temperature), 2) as rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 20.567}, {"device": "sensor1", "temperature": 25.234}, {"device": "sensor1", "temperature": 30.123}, } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) assert.Equal(t, "SENSOR1", item["device_upper"]) // 验证平均值 if avgTemp, exists := item["avg_temp"]; exists && avgTemp != nil { assert.InEpsilon(t, 25.308, avgTemp.(float64), 0.001) } else { t.Errorf("avg_temp is missing or nil: %v", avgTemp) } // 验证四舍五入的平均值 if roundedAvg, exists := item["rounded_avg"]; exists { if roundedAvg == nil { t.Errorf("rounded_avg exists but is nil - this indicates the round(avg()) expression failed") } else if val, ok := roundedAvg.(float64); ok { // 验证结果在合理范围内 assert.True(t, val >= 25.0 && val <= 25.5, "rounded_avg should be between 25.0 and 25.5, got %v", val) } else { t.Errorf("rounded_avg is not a float64: %v (type: %T)", roundedAvg, roundedAvg) } } else { t.Errorf("rounded_avg field is missing from result") } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("NestedFunctionCalls", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试嵌套函数调用 rsql := "SELECT device, upper(concat(device, '_', cast(round(temperature, 0), 'string'))) as device_temp_label FROM stream" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := map[string]interface{}{ "device": "sensor1", "temperature": 25.7, } strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) require.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) // round(25.7, 0) = 26, cast(26, 'string') = "26", concat("sensor1", "_", "26") = "sensor1_26", upper("sensor1_26") = "SENSOR1_26" assert.Equal(t, "SENSOR1_26", item["device_temp_label"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("WindowFunctionsWithAggregation", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试窗口函数与聚合函数结合 rsql := "SELECT device, avg(temperature) as avg_temp, window_start() as start_time, window_end() as end_time FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 20.0}, {"device": "sensor1", "temperature": 30.0}, } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) assert.InEpsilon(t, 25.0, item["avg_temp"].(float64), 0.001) assert.NotNil(t, item["start_time"]) assert.NotNil(t, item["end_time"]) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) } // TestNestedFunctionSupport 测试嵌套函数支持 func TestNestedFunctionSupport(t *testing.T) { t.Run("NormalFunctionNestingAggregation", func(t *testing.T) { // 测试普通函数嵌套聚合函数:round(avg(temperature), 2) streamsql := New() defer streamsql.Stop() // 执行包含 round(avg(temperature), 2) 的查询 query := "SELECT device, round(avg(temperature), 2) as rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 20.567}, {"device": "sensor1", "temperature": 25.234}, {"device": "sensor1", "temperature": 30.123}, } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) // 验证四舍五入的平均值 if roundedAvg, exists := item["rounded_avg"]; exists { if roundedAvg == nil { t.Errorf("rounded_avg exists but is nil - this indicates the round(avg()) expression failed") } else if val, ok := roundedAvg.(float64); ok { // 平均值应该是 (20.567 + 25.234 + 30.123) / 3 = 25.308 // round(25.308, 2) = 25.31 assert.InEpsilon(t, 25.31, val, 0.01) } else { t.Errorf("rounded_avg is not a float64: %v (type: %T)", roundedAvg, roundedAvg) } } else { t.Errorf("rounded_avg field is missing from result") } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("AggregationNestingNormalFunction", func(t *testing.T) { // 测试聚合函数嵌套普通函数:avg(round(temperature, 2)) streamsql := New() defer streamsql.Stop() // 执行包含 avg(round(temperature, 2)) 的查询 query := "SELECT device, avg(round(temperature, 2)) as avg_rounded FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 testData := []map[string]interface{}{ {"device": "sensor1", "temperature": 20.567}, // round(20.567, 2) = 20.57 {"device": "sensor1", "temperature": 25.234}, // round(25.234, 2) = 25.23 {"device": "sensor1", "temperature": 30.123}, // round(30.123, 2) = 30.12 } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) // 验证聚合函数嵌套普通函数的结果 if avgRounded, exists := item["avg_rounded"]; exists { if avgRounded == nil { t.Errorf("avg_rounded exists but is nil - this indicates the avg(round()) expression failed") } else if val, ok := avgRounded.(float64); ok { // 期望值:avg(20.57, 25.23, 30.12) = (20.57 + 25.23 + 30.12) / 3 = 25.31 assert.InEpsilon(t, 25.31, val, 0.01) } else { t.Errorf("avg_rounded is not a float64: %v (type: %T)", avgRounded, avgRounded) } } else { t.Errorf("avg_rounded field is missing from result") } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("ComplexNestedFunctions", func(t *testing.T) { // 测试更复杂的嵌套函数:round(avg(abs(temperature)), 1) streamsql := New() defer streamsql.Stop() // 执行包含 round(avg(abs(temperature)), 1) 的查询 query := "SELECT device, round(avg(abs(temperature)), 1) as complex_result FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据(包含负数) testData := []map[string]interface{}{ {"device": "sensor1", "temperature": -20.567}, // abs(-20.567) = 20.567 {"device": "sensor1", "temperature": 25.234}, // abs(25.234) = 25.234 {"device": "sensor1", "temperature": -30.123}, // abs(-30.123) = 30.123 } for _, data := range testData { strm.Emit(data) } // 等待窗口初始化 time.Sleep(1 * time.Second) strm.Window.Trigger() // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] assert.Equal(t, "sensor1", item["device"]) // 验证复杂嵌套函数的结果 if complexResult, exists := item["complex_result"]; exists { if complexResult == nil { t.Errorf("complex_result exists but is nil - this indicates the round(avg(abs())) expression failed") } else if val, ok := complexResult.(float64); ok { // 期望值:avg(20.567, 25.234, 30.123) = 25.308, round(25.308, 1) = 25.3 assert.InEpsilon(t, 25.3, val, 0.01) } else { t.Errorf("complex_result is not a float64: %v (type: %T)", complexResult, complexResult) } } else { t.Errorf("complex_result field is missing from result") } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) } // TestNestedFunctionExecutionOrder 测试嵌套函数的执行顺序和不同类型函数的组合 func TestNestedFunctionExecutionOrder(t *testing.T) { // 测试1: 字符串函数嵌套数学函数 t.Run("StringFunctionNestingMathFunction", func(t *testing.T) { // 测试 upper(concat("temp_", round(temperature, 1))) streamsql := New() defer streamsql.Stop() query := "SELECT device, upper(concat('temp_', round(temperature, 1))) as formatted_temp FROM stream" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 strm.Emit(map[string]interface{}{"device": "sensor1", "temperature": 25.67}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] // 验证执行顺序:round(25.67, 1) -> 25.7, concat('temp_', '25.7') -> 'temp_25.7', upper('temp_25.7') -> 'TEMP_25.7' assert.Equal(t, "TEMP_25.7", item["formatted_temp"]) case <-ctx.Done(): t.Fatal("测试超时") } }) // 测试2: 数学函数嵌套字符串函数 t.Run("MathFunctionNestingStringFunction", func(t *testing.T) { // 测试 round(len(upper(device)), 0) streamsql := New() defer streamsql.Stop() query := "SELECT device, round(len(upper(device)), 0) as device_length FROM stream" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 strm.Emit(map[string]interface{}{"device": "sensor1"}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] // 验证执行顺序:upper('sensor1') -> 'SENSOR1', len('SENSOR1') -> 7, round(7, 0) -> 7 assert.Equal(t, float64(7), item["device_length"]) case <-ctx.Done(): t.Fatal("测试超时") } }) // 测试3: 多层嵌套函数(3层) t.Run("ThreeLevelNestedFunctions", func(t *testing.T) { // 测试 abs(round(sqrt(temperature), 2)) streamsql := New() defer streamsql.Stop() query := "SELECT device, abs(round(sqrt(temperature), 2)) as processed_temp FROM stream" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 strm.Emit(map[string]interface{}{"device": "sensor1", "temperature": 16.0}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] // 验证执行顺序:sqrt(16) -> 4, round(4, 2) -> 4.00, abs(4.00) -> 4.00 assert.Equal(t, float64(4), item["processed_temp"]) case <-ctx.Done(): t.Fatal("测试超时") } }) // 测试6: 复杂的聚合函数嵌套 - 应该报错 t.Run("ComplexAggregationNesting", func(t *testing.T) { // 测试 max(round(avg(temperature), 1)) - 这是嵌套聚合函数,应该报错 streamsql := New() defer streamsql.Stop() query := "SELECT device, max(round(avg(temperature), 1)) as max_rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')" err := streamsql.Execute(query) // 应该返回嵌套聚合函数错误 assert.NotNil(t, err) assert.Contains(t, err.Error(), "aggregate function calls cannot be nested") }) // 测试7: 其他类型的嵌套聚合函数检测 t.Run("NestedAggregationDetection", func(t *testing.T) { streamsql := New() defer streamsql.Stop() // 测试 sum(count(*)) - 聚合函数嵌套聚合函数 query1 := "SELECT sum(count(*)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')" err1 := streamsql.Execute(query1) assert.NotNil(t, err1) assert.Contains(t, err1.Error(), "aggregate function calls cannot be nested") // 测试 avg(min(temperature)) - 聚合函数嵌套聚合函数 query2 := "SELECT avg(min(temperature)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')" err2 := streamsql.Execute(query2) assert.NotNil(t, err2) assert.Contains(t, err2.Error(), "aggregate function calls cannot be nested") // 测试 round(avg(temperature), 1) - 正常函数嵌套聚合函数,应该正常 query3 := "SELECT round(avg(temperature), 1) as normal_nesting FROM stream GROUP BY device, TumblingWindow('1s')" err3 := streamsql.Execute(query3) assert.Nil(t, err3) // 这种嵌套应该是允许的 }) // 测试7: 日期时间函数嵌套 t.Run("DateTimeFunctionNesting", func(t *testing.T) { // 测试 year(date_add(created_at, 1, 'years')) streamsql := New() defer streamsql.Stop() query := "SELECT device, year(date_add(created_at, 1, 'years')) as next_year FROM stream" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据 strm.Emit(map[string]interface{}{"device": "sensor1", "created_at": "2023-12-25 15:30:45"}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] // 验证执行顺序:date_add('2023-12-25 15:30:45', 1, 'years') -> '2024-12-25 15:30:45', year('2024-12-25 15:30:45') -> 2024 assert.Equal(t, 2024, item["next_year"]) case <-ctx.Done(): t.Fatal("测试超时") } }) // 测试8: 错误的嵌套函数执行顺序 t.Run("ErrorHandlingInNestedFunctions", func(t *testing.T) { // 测试 sqrt(len(invalid_field)) - 应该处理错误 streamsql := New() defer streamsql.Stop() query := "SELECT device, sqrt(len(invalid_field)) as error_result FROM stream" err := streamsql.Execute(query) assert.Nil(t, err) strm := streamsql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 添加测试数据(不包含invalid_field) strm.Emit(map[string]interface{}{"device": "sensor1", "temperature": 25.0}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() select { case result := <-resultChan: resultSlice, ok := result.([]map[string]interface{}) require.True(t, ok) assert.Len(t, resultSlice, 1) item := resultSlice[0] // 验证错误处理:invalid_field不存在,应该返回nil或默认值 _, exists := item["error_result"] assert.True(t, exists) case <-ctx.Done(): t.Fatal("测试超时") } }) } // flattenUnnestRows 将可能包含 unnest 结果的批次结果展开为多行,便于断言 // 兼容两种形态: // 1) 当前实现:返回单行,其中 alias 字段为 []interface{}(需要在测试侧展开) // 2) 未来实现:引擎直接返回多行(此时原样返回) func flattenUnnestRows(result []map[string]interface{}, alias string) []map[string]interface{} { // 如果已经是多行,直接返回 if len(result) > 1 { return result } if len(result) == 0 { return result } // 形如:[{ alias: []interface{}{...} , ...}] if v, ok := result[0][alias]; ok { if functions.IsUnnestResult(v) { // 使用ProcessUnnestResultWithFieldName保留字段名,并合并其他字段 expandedRows := functions.ProcessUnnestResultWithFieldName(v, alias) if len(expandedRows) == 0 { return result } // 将其他字段合并到每一行中 results := make([]map[string]interface{}, len(expandedRows)) for i, unnestRow := range expandedRows { newRow := make(map[string]interface{}, len(result[0])+len(unnestRow)) // 复制原始行的其他字段(除了unnest字段) for k, v := range result[0] { if k != alias { newRow[k] = v } } // 添加unnest展开的字段 for k, v := range unnestRow { newRow[k] = v } results[i] = newRow } return results } } return result } // TestUnnestFunctionIntegration 验证 unnest(array) 是否按预期将数组展开为多行 // 该用例集成到完整 SQL 执行路径: // - 语法: unnest(array) // - 描述: 将数组展开为多行 // - 示例: SELECT unnest(tags) as tag FROM stream func TestUnnestFunctionIntegration(t *testing.T) { t.Run("PrimitiveArray", func(t *testing.T) { ssql := New() defer ssql.Stop() sql := "SELECT unnest(tags) as tag FROM stream" err := ssql.Execute(sql) require.NoError(t, err) strm := ssql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 输入为普通字符串数组 input := map[string]interface{}{ "tags": []string{"a", "b", "c"}, } strm.Emit(input) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case raw := <-resultChan: batch, ok := raw.([]map[string]interface{}) require.True(t, ok) // 按两种形态规范化为多行 rows := flattenUnnestRows(batch, "tag") require.Len(t, rows, 3) expected := []string{"a", "b", "c"} for i, exp := range expected { row := rows[i] // 兼容两种字段命名:引擎直接展开可能使用别名(tag),函数侧展开为默认字段(value) var got interface{} if v, ok := row["tag"]; ok { got = v } else if v, ok := row["value"]; ok { got = v } else { t.Fatalf("row %d does not contain expected field 'tag' or 'value': %v", i, row) } assert.Equal(t, exp, got) } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("CombinedColumns", func(t *testing.T) { ssql := New() defer ssql.Stop() // 测试组合列:SELECT id,unnest(tags) as tag FROM events sql := "SELECT id, unnest(tags) as tag FROM stream" err := ssql.Execute(sql) require.NoError(t, err) strm := ssql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 输入包含id字段和tags数组 input := map[string]interface{}{ "id": 100, "tags": []string{"a", "b", "c"}, } strm.Emit(input) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case raw := <-resultChan: batch, ok := raw.([]map[string]interface{}) require.True(t, ok) // 展开unnest结果 rows := flattenUnnestRows(batch, "tag") require.Len(t, rows, 3) // 验证每行都包含id字段和tag字段 expectedTags := []string{"a", "b", "c"} for i, expectedTag := range expectedTags { row := rows[i] // 验证id字段保持不变 assert.Equal(t, 100, row["id"], "row %d should have id=100", i) // 验证tag字段 var gotTag interface{} if v, ok := row["tag"]; ok { gotTag = v } else if v, ok := row["value"]; ok { gotTag = v } else { t.Fatalf("row %d does not contain expected field 'tag' or 'value': %v", i, row) } assert.Equal(t, expectedTag, gotTag, "row %d should have tag=%s", i, expectedTag) } case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("ObjectArray", func(t *testing.T) { ssql := New() defer ssql.Stop() sql := "SELECT unnest(props) as prop FROM stream" err := ssql.Execute(sql) require.NoError(t, err) strm := ssql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 输入为对象数组 input := map[string]interface{}{ "props": []map[string]interface{}{ {"k": "x", "v": 1}, {"k": "y", "v": 2}, }, } strm.Emit(input) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case raw := <-resultChan: batch, ok := raw.([]map[string]interface{}) require.True(t, ok) rows := flattenUnnestRows(batch, "prop") require.Len(t, rows, 2) // 校验每一行包含对象内的字段 assert.Equal(t, "x", firstOf(rows[0], "k", "prop", "k")) assert.Equal(t, 1, firstOf(rows[0], "v", "prop", "v")) assert.Equal(t, "y", firstOf(rows[1], "k", "prop", "k")) assert.Equal(t, 2, firstOf(rows[1], "v", "prop", "v")) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("EmptyArray", func(t *testing.T) { ssql := New() defer ssql.Stop() sql := "SELECT unnest(tags) as tag FROM stream" err := ssql.Execute(sql) require.NoError(t, err) strm := ssql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // 空数组 input := map[string]interface{}{ "tags": []string{}, } strm.Emit(input) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case raw := <-resultChan: batch, ok := raw.([]map[string]interface{}) require.True(t, ok) rows := flattenUnnestRows(batch, "tag") assert.Len(t, rows, 0) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) t.Run("NilArray", func(t *testing.T) { ssql := New() defer ssql.Stop() sql := "SELECT unnest(tags) as tag FROM stream" err := ssql.Execute(sql) require.NoError(t, err) strm := ssql.stream resultChan := make(chan interface{}, 10) strm.AddSink(func(result []map[string]interface{}) { resultChan <- result }) // nil 值 input := map[string]interface{}{ "tags": nil, } strm.Emit(input) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() select { case raw := <-resultChan: batch, ok := raw.([]map[string]interface{}) require.True(t, ok) rows := flattenUnnestRows(batch, "tag") assert.Len(t, rows, 0) case <-ctx.Done(): t.Fatal("测试超时,未收到结果") } }) } // firstOf 辅助从行中读取字段值,兼容 prop 为对象的形态 // 优先按 top-level 字段取值,若不存在则尝试从嵌套对象(如 prop[k])获取 func firstOf(row map[string]interface{}, topLevelKey string, nestedObjKey string, nestedField string) interface{} { if v, ok := row[topLevelKey]; ok { return v } if m, ok := row[nestedObjKey].(map[string]interface{}); ok { return m[nestedField] } return nil }