Files
streamsql/streamsql_function_integration_test.go

1280 lines
37 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package streamsql
import (
"context"
"testing"
"time"
"github.com/rulego/streamsql/functions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFunctionIntegrationNonAggregation verifies that non-aggregate (scalar)
// functions are evaluated per row when used in a query without a window.
// It covers math, string, conversion, date/time and JSON functions.
func TestFunctionIntegrationNonAggregation(t *testing.T) {
	// runSingleRow executes rsql on a fresh StreamSQL instance, emits a single
	// input row and returns the only row of the first batch delivered to the
	// sink. The instance is stopped when the calling (sub)test finishes.
	runSingleRow := func(t *testing.T, rsql string, input map[string]interface{}) map[string]interface{} {
		t.Helper()
		ssql := New()
		t.Cleanup(func() { ssql.Stop() })
		// Stop on a failed Execute: the stream may not be initialised, and the
		// AddSink/Emit calls below must not run against it.
		require.NoError(t, ssql.Execute(rsql))
		strm := ssql.stream
		resultChan := make(chan []map[string]interface{}, 10)
		strm.AddSink(func(result []map[string]interface{}) {
			resultChan <- result
		})
		strm.Emit(input)
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		select {
		case result := <-resultChan:
			require.Len(t, result, 1)
			return result[0]
		case <-ctx.Done():
			t.Fatal("测试超时,未收到结果")
			return nil
		}
	}

	t.Run("MathFunctions", func(t *testing.T) {
		// abs, sqrt and round applied to numeric fields.
		item := runSingleRow(t,
			"SELECT device, abs(temperature) as abs_temp, sqrt(humidity) as sqrt_humidity, round(temperature) as rounded_temp FROM stream",
			map[string]interface{}{"device": "test-device", "temperature": -25.5, "humidity": 64.0})
		assert.Equal(t, "test-device", item["device"])
		// abs(-25.5) = 25.5
		assert.InEpsilon(t, 25.5, item["abs_temp"], 0.001)
		// sqrt(64) = 8
		assert.InEpsilon(t, 8.0, item["sqrt_humidity"], 0.001)
		// round(-25.5) = -26
		assert.InEpsilon(t, -26.0, item["rounded_temp"], 0.001)
	})

	t.Run("StringFunctions", func(t *testing.T) {
		// upper, lower, concat and length applied to string fields.
		item := runSingleRow(t,
			"SELECT upper(device) as upper_device, lower(location) as lower_location, concat(device, '-', location) as combined, length(device) as device_len FROM stream",
			map[string]interface{}{"device": "sensor01", "location": "ROOM_A"})
		assert.Equal(t, "SENSOR01", item["upper_device"])
		assert.Equal(t, "room_a", item["lower_location"])
		assert.Equal(t, "sensor01-ROOM_A", item["combined"])
		assert.Equal(t, 8, item["device_len"])
	})

	t.Run("ConversionFunctions", func(t *testing.T) {
		// cast to 'int' truncates 25.7; cast to 'string' formats 65.0 as "65".
		item := runSingleRow(t,
			"SELECT device, cast(temperature, 'int') as temp_int, cast(humidity, 'string') as humidity_str FROM stream",
			map[string]interface{}{"device": "test-device", "temperature": 25.7, "humidity": 65.0})
		assert.Equal(t, "test-device", item["device"])
		assert.Equal(t, int32(25), item["temp_int"])
		assert.Equal(t, "65", item["humidity_str"])
	})

	t.Run("DateTimeFunctions", func(t *testing.T) {
		// now, year and month; the timestamp field is supplied as a date string.
		item := runSingleRow(t,
			"SELECT device, now() as current_time, year(timestamp) as ts_year, month(timestamp) as ts_month FROM stream",
			map[string]interface{}{"device": "test-device", "timestamp": "2025-08-25"})
		assert.Equal(t, "test-device", item["device"])
		assert.NotNil(t, item["current_time"])
		assert.Equal(t, 2025, item["ts_year"])
		assert.Equal(t, 8, item["ts_month"])
	})

	t.Run("JSONFunctions", func(t *testing.T) {
		// json_extract and json_valid applied to a JSON string field.
		item := runSingleRow(t,
			"SELECT device, json_extract(metadata, '$.type') as device_type, json_valid(metadata) as is_valid_json FROM stream",
			map[string]interface{}{
				"device":   "test-device",
				"metadata": `{"type": "temperature_sensor", "version": "1.0"}`,
			})
		assert.Equal(t, "test-device", item["device"])
		assert.Equal(t, "temperature_sensor", item["device_type"])
		assert.Equal(t, true, item["is_valid_json"])
	})
}
// TestFunctionIntegrationAggregation verifies aggregate functions computed per
// device over a one-second tumbling window: basic aggregates
// (sum/avg/min/max/count), statistical aggregates (stddev/median) and
// collection aggregates (collect/first_value/last_value).
func TestFunctionIntegrationAggregation(t *testing.T) {
	// runWindowed executes rsql on a fresh StreamSQL instance, emits all
	// inputs, forces the window to fire and returns the first batch delivered
	// to the sink. The instance is stopped when the calling (sub)test finishes.
	runWindowed := func(t *testing.T, rsql string, inputs []map[string]interface{}) []map[string]interface{} {
		t.Helper()
		ssql := New()
		t.Cleanup(func() { ssql.Stop() })
		// Stop on a failed Execute: the stream may not be initialised.
		require.NoError(t, ssql.Execute(rsql))
		strm := ssql.stream
		resultChan := make(chan []map[string]interface{}, 10)
		strm.AddSink(func(result []map[string]interface{}) {
			resultChan <- result
		})
		for _, data := range inputs {
			strm.Emit(data)
		}
		// Wait for the window to initialise, then trigger it explicitly.
		time.Sleep(1 * time.Second)
		strm.Window.Trigger()
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		select {
		case result := <-resultChan:
			return result
		case <-ctx.Done():
			t.Fatal("测试超时,未收到结果")
			return nil
		}
	}

	t.Run("BasicAggregationFunctions", func(t *testing.T) {
		rows := runWindowed(t,
			"SELECT device, sum(temperature) as total_temp, avg(temperature) as avg_temp, min(temperature) as min_temp, max(temperature) as max_temp, count(temperature) as temp_count FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 20.0},
				{"device": "sensor1", "temperature": 25.0},
				{"device": "sensor1", "temperature": 30.0},
				{"device": "sensor2", "temperature": 15.0},
				{"device": "sensor2", "temperature": 18.0},
			})
		require.Len(t, rows, 2)

		// Index the rows by device so that a missing group fails the test
		// instead of being skipped silently.
		byDevice := make(map[string]map[string]interface{}, len(rows))
		for _, row := range rows {
			device, ok := row["device"].(string)
			require.True(t, ok, "device should be a string, got %T", row["device"])
			byDevice[device] = row
		}

		sensor1, ok := byDevice["sensor1"]
		require.True(t, ok, "missing result for sensor1")
		assert.InEpsilon(t, 75.0, sensor1["total_temp"], 0.001)
		assert.InEpsilon(t, 25.0, sensor1["avg_temp"], 0.001)
		assert.InEpsilon(t, 20.0, sensor1["min_temp"], 0.001)
		assert.InEpsilon(t, 30.0, sensor1["max_temp"], 0.001)
		assert.Equal(t, 3.0, sensor1["temp_count"])

		sensor2, ok := byDevice["sensor2"]
		require.True(t, ok, "missing result for sensor2")
		assert.InEpsilon(t, 33.0, sensor2["total_temp"], 0.001)
		assert.InEpsilon(t, 16.5, sensor2["avg_temp"], 0.001)
		assert.InEpsilon(t, 15.0, sensor2["min_temp"], 0.001)
		assert.InEpsilon(t, 18.0, sensor2["max_temp"], 0.001)
		assert.Equal(t, 2.0, sensor2["temp_count"])
	})

	t.Run("StatisticalAggregationFunctions", func(t *testing.T) {
		rows := runWindowed(t,
			"SELECT device, stddev(temperature) as temp_stddev, median(temperature) as temp_median FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 10.0},
				{"device": "sensor1", "temperature": 20.0},
				{"device": "sensor1", "temperature": 30.0},
				{"device": "sensor1", "temperature": 40.0},
				{"device": "sensor1", "temperature": 50.0},
			})
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		// Squared deviations from the mean 30 sum to 1000, and sqrt(1000/4) ≈ 15.811
		// (the n-1 form). The population form sqrt(1000/5) ≈ 14.14 must not pass.
		assert.InEpsilon(t, 15.81, item["temp_stddev"], 0.001)
		// The median of 10..50 is 30.
		assert.InEpsilon(t, 30.0, item["temp_median"], 0.001)
	})

	t.Run("CollectionAggregationFunctions", func(t *testing.T) {
		rows := runWindowed(t,
			"SELECT device, collect(temperature) as temp_array, first_value(temperature) as first_temp, last_value(temperature) as last_temp FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 20.0},
				{"device": "sensor1", "temperature": 25.0},
				{"device": "sensor1", "temperature": 30.0},
			})
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])

		// collect returns every value in the window.
		tempArray, ok := item["temp_array"].([]interface{})
		require.True(t, ok, "temp_array should be []interface{}, got %T", item["temp_array"])
		assert.Len(t, tempArray, 3)
		assert.Contains(t, tempArray, 20.0)
		assert.Contains(t, tempArray, 25.0)
		assert.Contains(t, tempArray, 30.0)

		// first_value/last_value follow emission order.
		assert.Equal(t, 20.0, item["first_temp"])
		assert.Equal(t, 30.0, item["last_temp"])
	})
}
// TestFunctionIntegrationMixed covers queries that mix function kinds:
// scalar and aggregate functions in one SELECT list, nested scalar calls on a
// non-windowed stream, and window_start()/window_end() next to an aggregate.
func TestFunctionIntegrationMixed(t *testing.T) {
	// runQuery executes rsql on a fresh StreamSQL instance, emits inputs and
	// returns the first batch delivered to the sink. When windowed is true it
	// waits for the window to initialise and then triggers it explicitly. The
	// instance is stopped when the calling (sub)test finishes.
	runQuery := func(t *testing.T, rsql string, inputs []map[string]interface{}, windowed bool) []map[string]interface{} {
		t.Helper()
		ssql := New()
		t.Cleanup(func() { ssql.Stop() })
		// Stop on a failed Execute: the stream may not be initialised.
		require.NoError(t, ssql.Execute(rsql))
		strm := ssql.stream
		resultChan := make(chan []map[string]interface{}, 10)
		strm.AddSink(func(result []map[string]interface{}) {
			resultChan <- result
		})
		for _, data := range inputs {
			strm.Emit(data)
		}
		if windowed {
			time.Sleep(1 * time.Second)
			strm.Window.Trigger()
		}
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		select {
		case result := <-resultChan:
			return result
		case <-ctx.Done():
			t.Fatal("测试超时,未收到结果")
			return nil
		}
	}

	// requireFloat returns item[key] as a float64. It fails the test if the
	// field is missing, nil or of another type.
	requireFloat := func(t *testing.T, item map[string]interface{}, key string) float64 {
		t.Helper()
		v, exists := item[key]
		require.True(t, exists, "%s field is missing from result", key)
		require.NotNil(t, v, "%s exists but is nil - the expression producing it failed", key)
		f, ok := v.(float64)
		require.True(t, ok, "%s is not a float64: %v (type: %T)", key, v, v)
		return f
	}

	t.Run("AggregationWithNonAggregationFunctions", func(t *testing.T) {
		rows := runQuery(t,
			"SELECT device, upper(device) as device_upper, avg(temperature) as avg_temp, round(avg(temperature), 2) as rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 20.567},
				{"device": "sensor1", "temperature": 25.234},
				{"device": "sensor1", "temperature": 30.123},
			}, true)
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		assert.Equal(t, "SENSOR1", item["device_upper"])
		// avg = (20.567 + 25.234 + 30.123) / 3 = 25.308
		assert.InEpsilon(t, 25.308, requireFloat(t, item, "avg_temp"), 0.001)
		// round(25.308, 2) = 25.31. The tight delta rejects an unrounded 25.308.
		assert.InDelta(t, 25.31, requireFloat(t, item, "rounded_avg"), 1e-6)
	})

	t.Run("NestedFunctionCalls", func(t *testing.T) {
		rows := runQuery(t,
			"SELECT device, upper(concat(device, '_', cast(round(temperature, 0), 'string'))) as device_temp_label FROM stream",
			[]map[string]interface{}{{"device": "sensor1", "temperature": 25.7}}, false)
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		// round(25.7, 0) = 26, cast(26, 'string') = "26",
		// concat("sensor1", "_", "26") = "sensor1_26", upper(...) = "SENSOR1_26"
		assert.Equal(t, "SENSOR1_26", item["device_temp_label"])
	})

	t.Run("WindowFunctionsWithAggregation", func(t *testing.T) {
		rows := runQuery(t,
			"SELECT device, avg(temperature) as avg_temp, window_start() as start_time, window_end() as end_time FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 20.0},
				{"device": "sensor1", "temperature": 30.0},
			}, true)
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		assert.InEpsilon(t, 25.0, item["avg_temp"], 0.001)
		assert.NotNil(t, item["start_time"])
		assert.NotNil(t, item["end_time"])
	})
}
// TestNestedFunctionSupport checks nesting between scalar and aggregate
// functions inside a one-second tumbling window: a scalar wrapping an
// aggregate, an aggregate wrapping a scalar, and a three-level combination.
func TestNestedFunctionSupport(t *testing.T) {
	// runWindowed executes query on a fresh StreamSQL instance, emits all
	// inputs, forces the window to fire and returns the first batch delivered
	// to the sink. The instance is stopped when the calling (sub)test finishes.
	runWindowed := func(t *testing.T, query string, inputs []map[string]interface{}) []map[string]interface{} {
		t.Helper()
		ssql := New()
		t.Cleanup(func() { ssql.Stop() })
		// Stop on a failed Execute: the stream may not be initialised.
		require.NoError(t, ssql.Execute(query))
		strm := ssql.stream
		resultChan := make(chan []map[string]interface{}, 10)
		strm.AddSink(func(result []map[string]interface{}) {
			resultChan <- result
		})
		for _, data := range inputs {
			strm.Emit(data)
		}
		// Wait for the window to initialise, then trigger it explicitly.
		time.Sleep(1 * time.Second)
		strm.Window.Trigger()
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		select {
		case result := <-resultChan:
			return result
		case <-ctx.Done():
			t.Fatal("测试超时,未收到结果")
			return nil
		}
	}

	// requireFloat returns item[key] as a float64. It fails the test if the
	// field is missing, nil or of another type.
	requireFloat := func(t *testing.T, item map[string]interface{}, key string) float64 {
		t.Helper()
		v, exists := item[key]
		require.True(t, exists, "%s field is missing from result", key)
		require.NotNil(t, v, "%s exists but is nil - the nested expression failed", key)
		f, ok := v.(float64)
		require.True(t, ok, "%s is not a float64: %v (type: %T)", key, v, v)
		return f
	}

	t.Run("NormalFunctionNestingAggregation", func(t *testing.T) {
		// Scalar function wrapping an aggregate: round(avg(temperature), 2).
		rows := runWindowed(t,
			"SELECT device, round(avg(temperature), 2) as rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 20.567},
				{"device": "sensor1", "temperature": 25.234},
				{"device": "sensor1", "temperature": 30.123},
			})
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		// avg = (20.567 + 25.234 + 30.123) / 3 = 25.308, so round(25.308, 2) = 25.31.
		// A tight delta is needed: a 1% relative tolerance would also accept 25.308.
		assert.InDelta(t, 25.31, requireFloat(t, item, "rounded_avg"), 1e-6)
	})

	t.Run("AggregationNestingNormalFunction", func(t *testing.T) {
		// Aggregate wrapping a scalar function: avg(round(temperature, 2)).
		rows := runWindowed(t,
			"SELECT device, avg(round(temperature, 2)) as avg_rounded FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": 20.567}, // round(20.567, 2) = 20.57
				{"device": "sensor1", "temperature": 25.234}, // round(25.234, 2) = 25.23
				{"device": "sensor1", "temperature": 30.123}, // round(30.123, 2) = 30.12
			})
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		// avg(20.57, 25.23, 30.12) = 75.92 / 3 ≈ 25.3067. The average of the
		// unrounded inputs would be 25.308, which this delta rejects.
		const wantAvg = (20.57 + 25.23 + 30.12) / 3
		assert.InDelta(t, wantAvg, requireFloat(t, item, "avg_rounded"), 1e-6)
	})

	t.Run("ComplexNestedFunctions", func(t *testing.T) {
		// Three-level nesting: round(avg(abs(temperature)), 1).
		rows := runWindowed(t,
			"SELECT device, round(avg(abs(temperature)), 1) as complex_result FROM stream GROUP BY device, TumblingWindow('1s')",
			[]map[string]interface{}{
				{"device": "sensor1", "temperature": -20.567}, // abs(-20.567) = 20.567
				{"device": "sensor1", "temperature": 25.234},  // abs(25.234) = 25.234
				{"device": "sensor1", "temperature": -30.123}, // abs(-30.123) = 30.123
			})
		require.Len(t, rows, 1)
		item := rows[0]
		assert.Equal(t, "sensor1", item["device"])
		// avg(20.567, 25.234, 30.123) = 25.308, so round(25.308, 1) = 25.3.
		assert.InDelta(t, 25.3, requireFloat(t, item, "complex_result"), 1e-6)
	})
}
// TestNestedFunctionExecutionOrder checks that nested function calls are
// evaluated inside-out across function families (string, math, date/time). It
// also checks that nesting one aggregate inside another is rejected by Execute,
// and that a nested call over a missing field still yields its output column.
func TestNestedFunctionExecutionOrder(t *testing.T) {
	// runSingleRow executes query on a fresh StreamSQL instance, emits one
	// input row and returns the only row of the first batch delivered to the
	// sink. The instance is stopped when the calling (sub)test finishes.
	runSingleRow := func(t *testing.T, query string, input map[string]interface{}) map[string]interface{} {
		t.Helper()
		ssql := New()
		t.Cleanup(func() { ssql.Stop() })
		// Stop on a failed Execute: the stream may not be initialised.
		require.NoError(t, ssql.Execute(query))
		strm := ssql.stream
		resultChan := make(chan []map[string]interface{}, 10)
		strm.AddSink(func(result []map[string]interface{}) {
			resultChan <- result
		})
		strm.Emit(input)
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		select {
		case result := <-resultChan:
			// require (not assert) so an empty batch fails instead of panicking
			// on the index below.
			require.Len(t, result, 1)
			return result[0]
		case <-ctx.Done():
			t.Fatal("测试超时")
			return nil
		}
	}

	// Case 1: string functions wrapping a math function.
	t.Run("StringFunctionNestingMathFunction", func(t *testing.T) {
		item := runSingleRow(t,
			"SELECT device, upper(concat('temp_', round(temperature, 1))) as formatted_temp FROM stream",
			map[string]interface{}{"device": "sensor1", "temperature": 25.67})
		// round(25.67, 1) -> 25.7, concat('temp_', '25.7') -> 'temp_25.7', upper(...) -> 'TEMP_25.7'
		assert.Equal(t, "TEMP_25.7", item["formatted_temp"])
	})

	// Case 2: a math function wrapping string functions.
	t.Run("MathFunctionNestingStringFunction", func(t *testing.T) {
		item := runSingleRow(t,
			"SELECT device, round(len(upper(device)), 0) as device_length FROM stream",
			map[string]interface{}{"device": "sensor1"})
		// upper('sensor1') -> 'SENSOR1', len('SENSOR1') -> 7, round(7, 0) -> 7
		assert.Equal(t, float64(7), item["device_length"])
	})

	// Case 3: three levels of nesting.
	t.Run("ThreeLevelNestedFunctions", func(t *testing.T) {
		item := runSingleRow(t,
			"SELECT device, abs(round(sqrt(temperature), 2)) as processed_temp FROM stream",
			map[string]interface{}{"device": "sensor1", "temperature": 16.0})
		// sqrt(16) -> 4, round(4, 2) -> 4.00, abs(4.00) -> 4.00
		assert.Equal(t, float64(4), item["processed_temp"])
	})

	// Case 4: an aggregate nested (through a scalar) inside another aggregate
	// must be rejected.
	t.Run("ComplexAggregationNesting", func(t *testing.T) {
		ssql := New()
		defer ssql.Stop()
		err := ssql.Execute("SELECT device, max(round(avg(temperature), 1)) as max_rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')")
		// require.Error stops here when err is nil; err.Error() would panic otherwise.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "aggregate function calls cannot be nested")
	})

	// Case 5: more nested-aggregate detection, plus an allowed nesting.
	t.Run("NestedAggregationDetection", func(t *testing.T) {
		ssql := New()
		defer ssql.Stop()

		// sum(count(*)): aggregate directly inside an aggregate.
		err1 := ssql.Execute("SELECT sum(count(*)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')")
		require.Error(t, err1)
		assert.Contains(t, err1.Error(), "aggregate function calls cannot be nested")

		// avg(min(temperature)): aggregate directly inside an aggregate.
		err2 := ssql.Execute("SELECT avg(min(temperature)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')")
		require.Error(t, err2)
		assert.Contains(t, err2.Error(), "aggregate function calls cannot be nested")

		// round(avg(temperature), 1): a scalar wrapping an aggregate is allowed.
		err3 := ssql.Execute("SELECT round(avg(temperature), 1) as normal_nesting FROM stream GROUP BY device, TumblingWindow('1s')")
		assert.NoError(t, err3)
	})

	// Case 6: date/time function nesting.
	t.Run("DateTimeFunctionNesting", func(t *testing.T) {
		item := runSingleRow(t,
			"SELECT device, year(date_add(created_at, 1, 'years')) as next_year FROM stream",
			map[string]interface{}{"device": "sensor1", "created_at": "2023-12-25 15:30:45"})
		// date_add('2023-12-25 15:30:45', 1, 'years') -> '2024-12-25 15:30:45', year(...) -> 2024
		assert.Equal(t, 2024, item["next_year"])
	})

	// Case 7: a nested call over a field that is absent from the input.
	t.Run("ErrorHandlingInNestedFunctions", func(t *testing.T) {
		item := runSingleRow(t,
			"SELECT device, sqrt(len(invalid_field)) as error_result FROM stream",
			map[string]interface{}{"device": "sensor1", "temperature": 25.0})
		// invalid_field is missing, so the column must still be present in the
		// output. Its value (nil or a default) is not checked here.
		_, exists := item["error_result"]
		assert.True(t, exists)
	})
}
// flattenUnnestRows expands a result batch that may carry an unnest result
// into one row per unnested element, so tests can assert on individual rows.
// Two shapes are accepted:
//  1. A single row whose alias field holds the unnest result. It is expanded
//     here, on the test side.
//  2. A batch the engine has already expanded into several rows, which is
//     returned unchanged.
//
// Any other input is returned as-is: an empty batch, a missing alias field, a
// value that is not an unnest result, or an expansion yielding no rows.
func flattenUnnestRows(result []map[string]interface{}, alias string) []map[string]interface{} {
	// Only a single-row batch can still hold an unexpanded unnest result.
	if len(result) != 1 {
		return result
	}
	base := result[0]
	raw, present := base[alias]
	if !present || !functions.IsUnnestResult(raw) {
		return result
	}

	// Expand while keeping alias as the field name of each element.
	expanded := functions.ProcessUnnestResultWithFieldName(raw, alias)
	if len(expanded) == 0 {
		return result
	}

	flattened := make([]map[string]interface{}, 0, len(expanded))
	for _, unnestRow := range expanded {
		merged := make(map[string]interface{}, len(base)+len(unnestRow))
		// Carry over every sibling column except the unnest column itself...
		for key, val := range base {
			if key == alias {
				continue
			}
			merged[key] = val
		}
		// ...then overlay the expanded element's fields.
		for key, val := range unnestRow {
			merged[key] = val
		}
		flattened = append(flattened, merged)
	}
	return flattened
}
// TestUnnestFunctionIntegration checks that unnest(array) expands an array
// into multiple rows when run through the full SQL execution path, e.g.
//
//	SELECT unnest(tags) as tag FROM stream
//
// Each subtest normalises the batch with flattenUnnestRows. It therefore passes
// whether the engine returns one row holding the unnest result or rows that
// are already expanded.
func TestUnnestFunctionIntegration(t *testing.T) {
	// collectBatch runs sql on a fresh instance, emits input and returns the
	// first batch handed to the sink. The instance is stopped when the calling
	// subtest finishes.
	collectBatch := func(t *testing.T, sql string, input map[string]interface{}) []map[string]interface{} {
		t.Helper()
		ssql := New()
		t.Cleanup(func() { ssql.Stop() })
		require.NoError(t, ssql.Execute(sql))
		batches := make(chan []map[string]interface{}, 10)
		ssql.stream.AddSink(func(result []map[string]interface{}) {
			batches <- result
		})
		ssql.stream.Emit(input)
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		select {
		case batch := <-batches:
			return batch
		case <-ctx.Done():
			t.Fatal("测试超时,未收到结果")
			return nil
		}
	}

	// tagOf reads the unnested value of row i. It accepts the alias field
	// ("tag"), used when the engine expands rows itself, or the default field
	// ("value"), used when the expansion happens on the function side.
	tagOf := func(t *testing.T, i int, row map[string]interface{}) interface{} {
		t.Helper()
		if v, ok := row["tag"]; ok {
			return v
		}
		if v, ok := row["value"]; ok {
			return v
		}
		t.Fatalf("row %d does not contain expected field 'tag' or 'value': %v", i, row)
		return nil
	}

	t.Run("PrimitiveArray", func(t *testing.T) {
		// A plain string array expands to one row per element.
		batch := collectBatch(t, "SELECT unnest(tags) as tag FROM stream",
			map[string]interface{}{"tags": []string{"a", "b", "c"}})
		rows := flattenUnnestRows(batch, "tag")
		require.Len(t, rows, 3)
		for i, want := range []string{"a", "b", "c"} {
			assert.Equal(t, want, tagOf(t, i, rows[i]))
		}
	})

	t.Run("CombinedColumns", func(t *testing.T) {
		// A regular column next to unnest: SELECT id, unnest(tags) as tag.
		batch := collectBatch(t, "SELECT id, unnest(tags) as tag FROM stream",
			map[string]interface{}{"id": 100, "tags": []string{"a", "b", "c"}})
		rows := flattenUnnestRows(batch, "tag")
		require.Len(t, rows, 3)
		// Every expanded row keeps the id and carries its own tag.
		for i, want := range []string{"a", "b", "c"} {
			row := rows[i]
			assert.Equal(t, 100, row["id"], "row %d should have id=100", i)
			assert.Equal(t, want, tagOf(t, i, row), "row %d should have tag=%s", i, want)
		}
	})

	t.Run("ObjectArray", func(t *testing.T) {
		// An array of objects; the object fields may be flattened to the top
		// level or stay nested under the alias.
		batch := collectBatch(t, "SELECT unnest(props) as prop FROM stream",
			map[string]interface{}{
				"props": []map[string]interface{}{
					{"k": "x", "v": 1},
					{"k": "y", "v": 2},
				},
			})
		rows := flattenUnnestRows(batch, "prop")
		require.Len(t, rows, 2)
		assert.Equal(t, "x", firstOf(rows[0], "k", "prop", "k"))
		assert.Equal(t, 1, firstOf(rows[0], "v", "prop", "v"))
		assert.Equal(t, "y", firstOf(rows[1], "k", "prop", "k"))
		assert.Equal(t, 2, firstOf(rows[1], "v", "prop", "v"))
	})

	t.Run("EmptyArray", func(t *testing.T) {
		// An empty array yields no rows.
		batch := collectBatch(t, "SELECT unnest(tags) as tag FROM stream",
			map[string]interface{}{"tags": []string{}})
		assert.Len(t, flattenUnnestRows(batch, "tag"), 0)
	})

	t.Run("NilArray", func(t *testing.T) {
		// A nil value yields no rows.
		batch := collectBatch(t, "SELECT unnest(tags) as tag FROM stream",
			map[string]interface{}{"tags": nil})
		assert.Len(t, flattenUnnestRows(batch, "tag"), 0)
	})
}
// firstOf reads a field from a result row, tolerating two row shapes.
// If row has topLevelKey, that value is returned, even when it is nil.
// Otherwise, if row[nestedObjKey] is a map[string]interface{}, its
// nestedField entry is returned. In every other case the result is nil.
func firstOf(row map[string]interface{}, topLevelKey string, nestedObjKey string, nestedField string) interface{} {
	direct, found := row[topLevelKey]
	if found {
		return direct
	}
	nested, isMap := row[nestedObjKey].(map[string]interface{})
	if !isMap {
		return nil
	}
	return nested[nestedField]
}