package streamsql

import (
	"encoding/json"
	"fmt"
	"math"
	"net"
	"testing"
	"time"

	"github.com/rulego/streamsql/aggregator"
	"github.com/rulego/streamsql/expr"
	"github.com/rulego/streamsql/functions"
	"github.com/rulego/streamsql/rsql"
	"github.com/rulego/streamsql/utils/cast"
	"github.com/stretchr/testify/assert"
)

// TestCustomMathFunctions registers two custom math functions ("square" and
// "distance") and checks that they work inside aggregate expressions of a
// windowed SQL query.
func TestCustomMathFunctions(t *testing.T) {
	// square(x) = x * x
	err := functions.RegisterCustomFunction(
		"square",
		functions.TypeMath,
		"数学函数",
		"计算平方",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToFloat64(args[0])
			return val * val, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("square")

	// distance(x1, y1, x2, y2) = Euclidean distance between the two points.
	err = functions.RegisterCustomFunction(
		"distance",
		functions.TypeMath,
		"几何数学",
		"计算两点间距离",
		4, 4,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			x1 := cast.ToFloat64(args[0])
			y1 := cast.ToFloat64(args[1])
			x2 := cast.ToFloat64(args[2])
			y2 := cast.ToFloat64(args[3])
			distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))
			return distance, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("distance")

	// Use both functions from SQL.
	streamsql := New()
	defer streamsql.Stop()

	sql := `
		SELECT device,
			AVG(square(value)) as squared_value,
			AVG(distance(x1, y1, x2, y2)) as calculated_distance
		FROM stream
		GROUP BY device, TumblingWindow('1s')
	`
	err = streamsql.Execute(sql)
	// Without a successfully executed query there is no stream to attach to.
	if !assert.NoError(t, err) {
		return
	}

	// Collect everything the sink emits.
	resultChan := make(chan interface{}, 10)
	streamsql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	testData := map[string]interface{}{
		"device": "sensor1",
		"value":  5.0,
		"x1":     0.0,
		"y1":     0.0,
		"x2":     3.0,
		"y2":     4.0, // the distance should be 5
	}
	streamsql.AddData(testData)

	// Wait for the window, force a trigger, then give the sink time to run.
	time.Sleep(1 * time.Second)
	streamsql.Stream().Window.Trigger()
	time.Sleep(500 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		// Stop instead of indexing into a nil or short slice and panicking.
		if !assert.True(t, ok) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "sensor1", item["device"])
		assert.Equal(t, 25.0, item["squared_value"])      // 5^2 = 25
		assert.Equal(t, 5.0, item["calculated_distance"]) // sqrt((3-0)^2 + (4-0)^2) = 5
	case <-time.After(2 * time.Second):
		t.Fatal("测试超时")
	}

	fmt.Println("✅ 自定义数学函数测试通过")
}

// TestCustomStringFunctions registers two custom string functions
// ("reverse_str" and "json_get") and checks them in a non-windowed query.
func TestCustomStringFunctions(t *testing.T) {
	// reverse_str reverses a string rune by rune, so multi-byte characters
	// stay intact.
	err := functions.RegisterCustomFunction(
		"reverse_str",
		functions.TypeString,
		"字符串函数",
		"反转字符串",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			str := cast.ToString(args[0])
			runes := []rune(str)
			for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
				runes[i], runes[j] = runes[j], runes[i]
			}
			return string(runes), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("reverse_str")

	// json_get(json, key) returns the top-level field `key` of a JSON object
	// string. It returns nil when the key is absent and an error for malformed
	// JSON.
	err = functions.RegisterCustomFunction(
		"json_get",
		functions.TypeString,
		"JSON处理",
		"从JSON字符串中提取字段值",
		2, 2,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			jsonStr := cast.ToString(args[0])
			key := cast.ToString(args[1])

			var data map[string]interface{}
			if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
				// %w keeps the decode error inspectable via errors.Is/As.
				return nil, fmt.Errorf("invalid JSON: %w", err)
			}
			// A missing key yields the zero value, nil.
			return data[key], nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("json_get")

	// Use both functions from SQL.
	streamsql := New()
	defer streamsql.Stop()

	sql := `
		SELECT device,
			reverse_str(device) as reversed_device,
			json_get(metadata, 'version') as version
		FROM stream
	`
	err = streamsql.Execute(sql)
	if !assert.NoError(t, err) {
		return
	}

	// Collect everything the sink emits.
	resultChan := make(chan interface{}, 10)
	streamsql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	testData := map[string]interface{}{
		"device":   "sensor1",
		"metadata": `{"version":"1.0","type":"temperature"}`,
	}
	streamsql.AddData(testData)
	time.Sleep(200 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		if !assert.True(t, ok) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "sensor1", item["device"])
		assert.Equal(t, "1rosnes", item["reversed_device"]) // "sensor1" reversed
		assert.Equal(t, "1.0", item["version"])
	case <-time.After(2 * time.Second):
		t.Fatal("测试超时")
	}

	fmt.Println("✅ 自定义字符串函数测试通过")
}

// TestCustomConversionFunctions registers two conversion functions
// ("ip_to_num" and "format_bytes") and calls them directly through the
// function registry rather than through SQL.
func TestCustomConversionFunctions(t *testing.T) {
	// ip_to_num packs a dotted IPv4 address into an int64
	// (a<<24 | b<<16 | c<<8 | d). Invalid or IPv6 input is an error.
	err := functions.RegisterCustomFunction(
		"ip_to_num",
		functions.TypeConversion,
		"网络转换",
		"将IP地址转换为整数",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			ipStr := cast.ToString(args[0])
			ip := net.ParseIP(ipStr)
			if ip == nil {
				return nil, fmt.Errorf("invalid IP address: %s", ipStr)
			}
			ip = ip.To4()
			if ip == nil {
				return nil, fmt.Errorf("not an IPv4 address: %s", ipStr)
			}
			return int64(ip[0])<<24 + int64(ip[1])<<16 + int64(ip[2])<<8 + int64(ip[3]), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("ip_to_num")

	// format_bytes renders a byte count with a 1024-based unit, from B up to
	// TB, using two decimals.
	err = functions.RegisterCustomFunction(
		"format_bytes",
		functions.TypeConversion,
		"数据格式化",
		"格式化字节大小为人类可读格式",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			bytes := cast.ToFloat64(args[0])
			units := []string{"B", "KB", "MB", "GB", "TB"}
			i := 0
			for bytes >= 1024 && i < len(units)-1 {
				bytes /= 1024
				i++
			}
			return fmt.Sprintf("%.2f %s", bytes, units[i]), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("format_bytes")

	// Call the functions directly.
	ctx := &functions.FunctionContext{Data: make(map[string]interface{})}

	// IP conversion. Stop early rather than call Execute on a missing function.
	ipFunc, exists := functions.Get("ip_to_num")
	if !assert.True(t, exists) {
		return
	}
	result, err := ipFunc.Execute(ctx, []interface{}{"192.168.1.100"})
	assert.NoError(t, err)
	expectedIP := int64(192)<<24 + int64(168)<<16 + int64(1)<<8 + int64(100)
	assert.Equal(t, expectedIP, result)

	// Byte formatting.
	bytesFunc, exists := functions.Get("format_bytes")
	if !assert.True(t, exists) {
		return
	}
	result, err = bytesFunc.Execute(ctx, []interface{}{1073741824}) // 1GB
	assert.NoError(t, err)
	assert.Equal(t, "1.00 GB", result)

	fmt.Println("✅ 自定义转换函数测试通过")
}

// TestCustomAggregateFunctions registers two user-defined aggregates
// (geometric_mean and mode_value) and checks them in a windowed query.
// Each aggregate is registered twice under the same name: once in the
// function registry, whose Execute is a stub, and once as an aggregator
// factory that carries the actual logic.
func TestCustomAggregateFunctions(t *testing.T) {
	// Geometric mean aggregate.
	functions.Register(NewGeometricMeanFunction())
	aggregator.Register("geometric_mean", func() aggregator.AggregatorFunction {
		return &GeometricMeanAggregator{}
	})
	defer functions.Unregister("geometric_mean")

	// Mode (most frequent value) aggregate.
	functions.Register(NewModeFunction())
	aggregator.Register("mode_value", func() aggregator.AggregatorFunction {
		return &ModeAggregator{}
	})
	defer functions.Unregister("mode_value")

	// Use both aggregates from SQL.
	streamsql := New()
	defer streamsql.Stop()

	sql := `
		SELECT device,
			geometric_mean(value) as geo_mean,
			mode_value(category) as most_common
		FROM stream
		GROUP BY device, TumblingWindow('1s')
	`
	err := streamsql.Execute(sql)
	if !assert.NoError(t, err) {
		return
	}

	// Collect everything the sink emits.
	resultChan := make(chan interface{}, 10)
	streamsql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	testData := []interface{}{
		map[string]interface{}{"device": "sensor1", "value": 2.0, "category": "A"},
		map[string]interface{}{"device": "sensor1", "value": 8.0, "category": "A"},
		map[string]interface{}{"device": "sensor1", "value": 32.0, "category": "B"},
		map[string]interface{}{"device": "sensor1", "value": 128.0, "category": "A"},
	}
	for _, data := range testData {
		streamsql.AddData(data)
	}

	time.Sleep(1 * time.Second)
	streamsql.Stream().Window.Trigger()
	time.Sleep(500 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		if !assert.True(t, ok) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "sensor1", item["device"])

		// Geometric mean: (2 * 8 * 32 * 128) ^ (1/4) = 16
		geoMean, ok := item["geo_mean"].(float64)
		assert.True(t, ok)
		assert.InEpsilon(t, 16.0, geoMean, 0.01)

		// Mode: A appears 3 times and B once, so the mode is A.
		mode := item["most_common"]
		assert.Equal(t, "A", mode)
	case <-time.After(3 * time.Second):
		t.Fatal("测试超时")
	}

	fmt.Println("✅ 自定义聚合函数测试通过")
}

// GeometricMeanFunction is the function-registry entry for "geometric_mean".
// Its Execute is a stub; the computation lives in GeometricMeanAggregator.
type GeometricMeanFunction struct {
	*functions.BaseFunction
}

// NewGeometricMeanFunction builds the "geometric_mean" registry entry as an
// aggregation function with argument bounds (1, -1). -1 presumably means "no
// upper bound"; confirm against functions.NewBaseFunction.
func NewGeometricMeanFunction() *GeometricMeanFunction {
	return &GeometricMeanFunction{
		BaseFunction: functions.NewBaseFunction(
			"geometric_mean",
			functions.TypeAggregation,
			"统计聚合",
			"计算几何平均数",
			1, -1,
		),
	}
}

// Validate delegates to BaseFunction.ValidateArgCount.
func (f *GeometricMeanFunction) Validate(args []interface{}) error {
	return f.ValidateArgCount(args)
}

// Execute is intentionally a no-op; the actual logic lives in the aggregator.
func (f *GeometricMeanFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
	return nil, nil
}

// GeometricMeanAggregator accumulates positive values and reports their
// geometric mean.
type GeometricMeanAggregator struct {
	values []float64
}

// New returns a fresh, empty aggregator.
func (g *GeometricMeanAggregator) New() aggregator.AggregatorFunction {
	return &GeometricMeanAggregator{values: make([]float64, 0)}
}

// Add converts value with cast.ToFloat64 and records it only if it is
// strictly positive. The geometric mean and its logarithm are defined only
// for positive numbers; this also skips NaN.
func (g *GeometricMeanAggregator) Add(value interface{}) {
	if val := cast.ToFloat64(value); val > 0 {
		g.values = append(g.values, val)
	}
}

// Result returns the geometric mean of the recorded values, or 0.0 if none
// were recorded. It computes exp(mean(ln v)) rather than the n-th root of the
// running product. The product overflows to +Inf for many large values, or
// underflows to 0 for many small ones, long before the mean itself is out of
// range.
func (g *GeometricMeanAggregator) Result() interface{} {
	if len(g.values) == 0 {
		return 0.0
	}
	logSum := 0.0
	for _, v := range g.values {
		logSum += math.Log(v)
	}
	return math.Exp(logSum / float64(len(g.values)))
}

// ModeFunction is the function-registry entry for "mode_value".
// Its Execute is a stub; the computation lives in ModeAggregator.
type ModeFunction struct {
	*functions.BaseFunction
}

// NewModeFunction builds the "mode_value" registry entry as an aggregation
// function with argument bounds (1, -1).
func NewModeFunction() *ModeFunction {
	return &ModeFunction{
		BaseFunction: functions.NewBaseFunction(
			"mode_value",
			functions.TypeAggregation,
			"统计聚合",
			"计算众数",
			1, -1,
		),
	}
}

// Validate delegates to BaseFunction.ValidateArgCount.
func (f *ModeFunction) Validate(args []interface{}) error {
	return f.ValidateArgCount(args)
}

// Execute is intentionally a no-op; the actual logic lives in the aggregator.
func (f *ModeFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
	return nil, nil
}

// ModeAggregator counts values by their "%v" string form and reports the most
// frequent one.
type ModeAggregator struct {
	counts map[string]int
}

// New returns a fresh, empty aggregator.
func (m *ModeAggregator) New() aggregator.AggregatorFunction {
	return &ModeAggregator{counts: make(map[string]int)}
}

// Add counts one occurrence of value, keyed by fmt's "%v" rendering.
func (m *ModeAggregator) Add(value interface{}) {
	// The factory registered above returns &ModeAggregator{}, whose map is nil.
	// Writing to a nil map panics, so allocate it on first use.
	if m.counts == nil {
		m.counts = make(map[string]int)
	}
	key := fmt.Sprintf("%v", value)
	m.counts[key]++
}

// Result returns the most frequent key as a string, or nil if nothing was
// added. Ties are broken by the lexicographically smallest key. Go randomizes
// map iteration order, so "first maximum seen" would make the result vary
// between runs.
func (m *ModeAggregator) Result() interface{} {
	if len(m.counts) == 0 {
		return nil
	}
	maxCount := 0
	var mode string
	for key, count := range m.counts {
		if count > maxCount || (count == maxCount && key < mode) {
			maxCount = count
			mode = key
		}
	}
	return mode
}

// TestFunctionManagement exercises the registry API: lookup by name, listing,
// filtering by type, and unregistering.
func TestFunctionManagement(t *testing.T) {
	err := functions.RegisterCustomFunction(
		"test_func",
		functions.TypeCustom,
		"测试函数",
		"用于测试的函数",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			return args[0], nil
		},
	)
	assert.NoError(t, err)
	// Cleanup in case the test returns before the explicit Unregister below;
	// the result of this deferred call is deliberately ignored.
	defer functions.Unregister("test_func")

	// Lookup by name.
	fn, exists := functions.Get("test_func")
	if !assert.True(t, exists) {
		return
	}
	assert.Equal(t, "test_func", fn.GetName())
	assert.Equal(t, functions.TypeCustom, fn.GetType())

	// Listing.
	allFunctions := functions.ListAll()
	assert.Contains(t, allFunctions, "test_func")

	// Lookup by type.
	customFunctions := functions.GetByType(functions.TypeCustom)
	found := false
	for _, f := range customFunctions {
		if f.GetName() == "test_func" {
			found = true
			break
		}
	}
	assert.True(t, found)

	// Unregister.
	success := functions.Unregister("test_func")
	assert.True(t, success)

	// The function must no longer be resolvable.
	_, exists = functions.Get("test_func")
	assert.False(t, exists)

	fmt.Println("✅ 函数管理功能测试通过")
}

// TestCustomFunctionWithAggregation checks that a custom scalar function can
// be nested inside built-in aggregates (AVG and MAX).
func TestCustomFunctionWithAggregation(t *testing.T) {
	// celsius_to_fahrenheit(c) = c*9/5 + 32
	err := functions.RegisterCustomFunction(
		"celsius_to_fahrenheit",
		functions.TypeConversion,
		"温度转换",
		"摄氏度转华氏度",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			celsius := cast.ToFloat64(args[0])
			fahrenheit := celsius*9/5 + 32
			return fahrenheit, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("celsius_to_fahrenheit")

	// Use the function inside aggregate SQL.
	streamsql := New()
	defer streamsql.Stop()

	sql := `
		SELECT device,
			AVG(celsius_to_fahrenheit(temperature)) as avg_fahrenheit,
			MAX(celsius_to_fahrenheit(temperature)) as max_fahrenheit
		FROM stream
		GROUP BY device, TumblingWindow('1s')
	`
	err = streamsql.Execute(sql)
	if !assert.NoError(t, err) {
		return
	}

	// Collect everything the sink emits.
	resultChan := make(chan interface{}, 10)
	streamsql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	// Input temperatures are in Celsius.
	testData := []interface{}{
		map[string]interface{}{"device": "thermometer", "temperature": 0.0},   // 32°F
		map[string]interface{}{"device": "thermometer", "temperature": 100.0}, // 212°F
	}
	for _, data := range testData {
		streamsql.AddData(data)
	}

	time.Sleep(1 * time.Second)
	streamsql.Stream().Window.Trigger()
	time.Sleep(500 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		if !assert.True(t, ok) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "thermometer", item["device"])

		// Average Fahrenheit: (32 + 212) / 2 = 122
		avgFahrenheit, ok := item["avg_fahrenheit"].(float64)
		assert.True(t, ok)
		assert.InEpsilon(t, 122.0, avgFahrenheit, 0.01)

		// Maximum Fahrenheit: 212
		maxFahrenheit, ok := item["max_fahrenheit"].(float64)
		assert.True(t, ok)
		assert.InEpsilon(t, 212.0, maxFahrenheit, 0.01)
	case <-time.After(3 * time.Second):
		t.Fatal("测试超时")
	}

	fmt.Println("✅ 自定义函数与聚合函数结合使用测试通过")
}

// TestDebugCustomFunctions is a diagnostic test. It walks a single-argument
// custom function through registry lookup, expression parsing and evaluation,
// and SQL-to-config conversion, and prints the intermediate state.
func TestDebugCustomFunctions(t *testing.T) {
	err := functions.RegisterCustomFunction(
		"square",
		functions.TypeMath,
		"数学函数",
		"计算平方",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToFloat64(args[0])
			fmt.Printf("Square function called with: %v, result: %v\n", val, val*val)
			return val * val, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("square")

	// The function must be resolvable by name.
	fn, exists := functions.Get("square")
	if !assert.True(t, exists) {
		return
	}
	fmt.Printf("Function found: %s, type: %s\n", fn.GetName(), fn.GetType())

	// Expression parsing. The variable is named e so the imported expr
	// package is not shadowed.
	e, err := expr.NewExpression("square(value)")
	if !assert.NoError(t, err) {
		return
	}

	fields := e.GetFields()
	fmt.Printf("Expression fields: %v\n", fields)

	// Expression evaluation.
	data := map[string]interface{}{"value": 5.0}
	result, err := e.Evaluate(data)
	assert.NoError(t, err)
	fmt.Printf("Expression result: %v\n", result)
	assert.Equal(t, 25.0, result)

	// SQL parsing and config conversion.
	parser := rsql.NewParser("SELECT square(value) as squared FROM stream")
	stmt, err := parser.Parse()
	if !assert.NoError(t, err) {
		return
	}

	config, _, err := stmt.ToStreamConfig()
	if !assert.NoError(t, err) {
		return
	}

	fmt.Printf("SQL Config - SelectFields: %v\n", config.SelectFields)
	fmt.Printf("SQL Config - FieldAlias: %v\n", config.FieldAlias)
	fmt.Printf("SQL Config - FieldExpressions: %v\n", config.FieldExpressions)
	fmt.Printf("SQL Config - NeedWindow: %v\n", config.NeedWindow)

	fmt.Println("✅ 调试测试完成")
}

// TestDebugMultiParameterFunction is a diagnostic test for a four-argument
// custom function. It covers expression evaluation and conversion of a
// windowed aggregate SQL statement.
func TestDebugMultiParameterFunction(t *testing.T) {
	err := functions.RegisterCustomFunction(
		"distance",
		functions.TypeMath,
		"几何数学",
		"计算两点间距离",
		4, 4,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			x1 := cast.ToFloat64(args[0])
			y1 := cast.ToFloat64(args[1])
			x2 := cast.ToFloat64(args[2])
			y2 := cast.ToFloat64(args[3])
			distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))
			fmt.Printf("Distance function called with: (%v,%v) to (%v,%v), result: %v\n", x1, y1, x2, y2, distance)
			return distance, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("distance")

	// Expression parsing (e, not expr, to keep the package name usable).
	e, err := expr.NewExpression("distance(x1, y1, x2, y2)")
	if !assert.NoError(t, err) {
		return
	}

	fields := e.GetFields()
	fmt.Printf("Distance expression fields: %v\n", fields)

	// Expression evaluation: (0,0) to (3,4) is 5.
	data := map[string]interface{}{
		"x1": 0.0,
		"y1": 0.0,
		"x2": 3.0,
		"y2": 4.0,
	}
	result, err := e.Evaluate(data)
	assert.NoError(t, err)
	fmt.Printf("Distance expression result: %v\n", result)
	assert.Equal(t, 5.0, result)

	// SQL parsing and config conversion.
	parser := rsql.NewParser("SELECT AVG(distance(x1, y1, x2, y2)) as avg_distance FROM stream GROUP BY device, TumblingWindow('1s')")
	stmt, err := parser.Parse()
	if !assert.NoError(t, err) {
		return
	}

	config, _, err := stmt.ToStreamConfig()
	if !assert.NoError(t, err) {
		return
	}

	fmt.Printf("Distance SQL Config - SelectFields: %v\n", config.SelectFields)
	fmt.Printf("Distance SQL Config - FieldAlias: %v\n", config.FieldAlias)
	fmt.Printf("Distance SQL Config - FieldExpressions: %v\n", config.FieldExpressions)

	fmt.Println("✅ 多参数函数调试测试完成")
}

// TestDebugSQLParsing is a diagnostic test that prints how several SQL forms
// using a custom function are parsed and converted. Parse and conversion
// errors are printed and skipped rather than failing the test.
func TestDebugSQLParsing(t *testing.T) {
	err := functions.RegisterCustomFunction(
		"distance",
		functions.TypeMath,
		"几何数学",
		"计算两点间距离",
		4, 4,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			x1 := cast.ToFloat64(args[0])
			y1 := cast.ToFloat64(args[1])
			x2 := cast.ToFloat64(args[2])
			y2 := cast.ToFloat64(args[3])
			distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))
			return distance, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("distance")

	// Plain projection, aggregate without a window, and aggregate with a
	// window.
	testCases := []string{
		"SELECT distance(x1, y1, x2, y2) as calc_distance FROM stream",
		"SELECT AVG(distance(x1, y1, x2, y2)) as avg_distance FROM stream",
		"SELECT AVG(distance(x1, y1, x2, y2)) as avg_distance FROM stream GROUP BY device, TumblingWindow('1s')",
	}

	for i, sql := range testCases {
		fmt.Printf("\n=== 测试SQL %d: %s ===\n", i+1, sql)

		parser := rsql.NewParser(sql)
		stmt, err := parser.Parse()
		if err != nil {
			fmt.Printf("SQL解析错误: %v\n", err)
			continue
		}

		// Dump the parsed fields.
		fmt.Printf("解析到的字段数量: %d\n", len(stmt.Fields))
		for j, field := range stmt.Fields {
			fmt.Printf("字段 %d: Expression='%s', Alias='%s'\n", j, field.Expression, field.Alias)
		}

		config, condition, err := stmt.ToStreamConfig()
		if err != nil {
			fmt.Printf("转换配置错误: %v\n", err)
			continue
		}

		fmt.Printf("转换后配置:\n")
		fmt.Printf(" SelectFields: %v\n", config.SelectFields)
		fmt.Printf(" FieldAlias: %v\n", config.FieldAlias)
		fmt.Printf(" FieldExpressions: %v\n", config.FieldExpressions)
		fmt.Printf(" NeedWindow: %v\n", config.NeedWindow)
		fmt.Printf(" Condition: %s\n", condition)
	}

	fmt.Println("✅ SQL解析调试测试完成")
}