package streamsql

import (
	"fmt"
	"testing"
	"time"

	"github.com/rulego/streamsql/functions"
	"github.com/rulego/streamsql/utils/cast"
	"github.com/stretchr/testify/assert"
)

// awaitCustomFuncRow waits up to timeout for one emission on results and
// returns the single row it carries.
//
// The test fails immediately (t.Fatal) when nothing arrives before the
// timeout (reported with timeoutMsg), when the emission is not a
// []map[string]interface{}, or when it does not hold exactly one row. Failing
// fatally here keeps callers from indexing into an empty or mistyped result.
func awaitCustomFuncRow(t *testing.T, results <-chan interface{}, timeout time.Duration, timeoutMsg string) map[string]interface{} {
	t.Helper()

	select {
	case result := <-results:
		rows, ok := result.([]map[string]interface{})
		if !ok {
			t.Fatalf("unexpected result type %T, want []map[string]interface{}", result)
		}
		if len(rows) != 1 {
			t.Fatalf("got %d result rows, want 1", len(rows))
		}
		return rows[0]
	case <-time.After(timeout):
		t.Fatal(timeoutMsg)
	}
	return nil
}

// TestPluginStyleCustomFunctions registers three custom functions at runtime
// (string, conversion and math typed) and runs one query against each of
// them. All three functions are unregistered when the test returns.
func TestPluginStyleCustomFunctions(t *testing.T) {
	fmt.Println("🔌 测试插件式自定义函数系统")

	// 1. String function: masks the middle four characters of an
	// 11-character phone number; any other length is returned unchanged.
	err := functions.RegisterCustomFunction(
		"mask_phone",
		functions.TypeString,
		"数据脱敏",
		"手机号脱敏",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			phone := cast.ToString(args[0])
			if len(phone) != 11 {
				return phone, nil
			}
			return phone[:3] + "****" + phone[7:], nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("mask_phone")

	// 2. Conversion function: prefixes the argument with "ID_".
	err = functions.RegisterCustomFunction(
		"format_id",
		functions.TypeConversion,
		"格式化",
		"格式化ID",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			id := cast.ToString(args[0])
			return "ID_" + id, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("format_id")

	// 3. Math function: sales * rate / 100, used inside AVG() below.
	err = functions.RegisterCustomFunction(
		"calculate_commission",
		functions.TypeMath,
		"业务计算",
		"计算销售佣金",
		2, 2,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			sales := cast.ToFloat64(args[0])
			rate := cast.ToFloat64(args[1])
			return sales * rate / 100, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("calculate_commission")

	// Case 1: string function in a query without a window.
	testStringFunctionsOnly(t)

	// Case 2: conversion function in a query without a window.
	testConversionFunctionsOnly(t)

	// Case 3: math function nested in an aggregate over a tumbling window.
	testMathFunctionsInAggregate(t)

	fmt.Println("✅ 插件式自定义函数测试完成")
}

// testStringFunctionsOnly runs mask_phone in a plain SELECT and checks the
// masked value of the single emitted row. It expects mask_phone to be
// registered by the caller.
func testStringFunctionsOnly(t *testing.T) {
	fmt.Println("\n📝 测试纯字符串函数(直接处理模式)...")

	ssql := New()
	defer ssql.Stop()

	query := `
		SELECT
			employee_id,
			mask_phone(phone) as masked_phone
		FROM stream
	`
	if err := ssql.Execute(query); !assert.NoError(t, err) {
		// Without a successfully executed query there is nothing to attach a sink to.
		return
	}

	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	// Feed one record through the stream.
	ssql.AddData(map[string]interface{}{
		"employee_id": "E001",
		"phone":       "13812345678",
	})

	item := awaitCustomFuncRow(t, resultChan, 2*time.Second, "字符串函数测试超时")
	assert.Equal(t, "E001", item["employee_id"])
	assert.Equal(t, "138****5678", item["masked_phone"]) // masked phone number
	fmt.Printf("  📊 字符串函数结果: %v\n", item)
}

// testConversionFunctionsOnly runs format_id in a plain SELECT and checks the
// formatted value of the single emitted row. It expects format_id to be
// registered by the caller.
func testConversionFunctionsOnly(t *testing.T) {
	fmt.Println("\n🔄 测试转换函数(直接处理模式)...")

	ssql := New()
	defer ssql.Stop()

	query := `
		SELECT
			user_id,
			format_id(user_id) as formatted_id
		FROM stream
	`
	if err := ssql.Execute(query); !assert.NoError(t, err) {
		return
	}

	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	// Feed one record through the stream.
	ssql.AddData(map[string]interface{}{
		"user_id": "12345",
	})

	item := awaitCustomFuncRow(t, resultChan, 2*time.Second, "转换函数测试超时")
	assert.Equal(t, "12345", item["user_id"])
	assert.Equal(t, "ID_12345", item["formatted_id"])
	fmt.Printf("  📊 转换函数结果: %v\n", item)
}

// testMathFunctionsInAggregate runs AVG(calculate_commission(...)) grouped by
// department over TumblingWindow('1s') and checks the averaged commission of
// the single emitted row. It expects calculate_commission to be registered by
// the caller.
func testMathFunctionsInAggregate(t *testing.T) {
	fmt.Println("\n📈 测试数学函数在聚合中使用(窗口模式)...")

	ssql := New()
	defer ssql.Stop()

	query := `
		SELECT
			department,
			AVG(calculate_commission(sales, commission_rate)) as avg_commission
		FROM stream
		GROUP BY department, TumblingWindow('1s')
	`
	if err := ssql.Execute(query); !assert.NoError(t, err) {
		return
	}

	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	// Two records for the same department.
	testData := []interface{}{
		map[string]interface{}{
			"department":      "sales",
			"sales":           8000.0,
			"commission_rate": 3.0,
		},
		map[string]interface{}{
			"department":      "sales",
			"sales":           12000.0,
			"commission_rate": 4.0,
		},
	}
	for _, data := range testData {
		ssql.AddData(data)
	}

	// Wait for the 1s window length used in the query, then trigger the
	// window explicitly.
	time.Sleep(1 * time.Second)
	ssql.Stream().Window.Trigger()

	item := awaitCustomFuncRow(t, resultChan, 3*time.Second, "聚合数学函数测试超时")
	assert.Equal(t, "sales", item["department"])

	// Verify the aggregated value: (240 + 480) / 2 = 360.
	const expectedAvg = (8000.0*3/100 + 12000.0*4/100) / 2
	if avgCommission, ok := item["avg_commission"].(float64); assert.True(t, ok) {
		assert.InEpsilon(t, expectedAvg, avgCommission, 0.01)
	}

	fmt.Printf("  📊 聚合数学函数结果: %v\n", item)
}

// TestRuntimeFunctionManagement registers a function at runtime, looks it up,
// uses it in a query, unregisters it and verifies it can no longer be found.
func TestRuntimeFunctionManagement(t *testing.T) {
	fmt.Println("\n🔧 测试运行时函数管理...")

	// Register the function dynamically (string typed).
	err := functions.RegisterCustomFunction(
		"temp_function",
		functions.TypeString,
		"临时函数",
		"临时测试函数",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToString(args[0])
			return "TEMP_" + val, nil
		},
	)
	assert.NoError(t, err)
	// Safety net: if the test aborts before the explicit Unregister below,
	// make sure the function does not stay registered for other tests.
	t.Cleanup(func() { functions.Unregister("temp_function") })

	// The function must now be discoverable by name.
	if fn, exists := functions.Get("temp_function"); assert.True(t, exists) {
		assert.Equal(t, "temp_function", fn.GetName())
	}

	// Use it in SQL.
	ssql := New()
	defer ssql.Stop()

	query := `SELECT temp_function(value) as result FROM stream`
	if err = ssql.Execute(query); !assert.NoError(t, err) {
		return
	}

	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	ssql.AddData(map[string]interface{}{"value": "test"})

	item := awaitCustomFuncRow(t, resultChan, 2*time.Second, "运行时函数管理测试超时")
	assert.Equal(t, "TEMP_test", item["result"])

	// Unregister at runtime.
	success := functions.Unregister("temp_function")
	assert.True(t, success)

	// The function must no longer be discoverable.
	_, exists := functions.Get("temp_function")
	assert.False(t, exists)

	fmt.Println("✅ 运行时函数管理测试完成")
}

// TestFunctionPluginDiscovery registers one math and one string function and
// verifies that they are returned by GetByType and ListAll.
func TestFunctionPluginDiscovery(t *testing.T) {
	fmt.Println("\n🔍 测试函数插件发现机制...")

	// Both plugin functions simply return their first argument.
	passthrough := func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
		return args[0], nil
	}

	// Register functions of different types; a failed registration is
	// reported here instead of surfacing later as a confusing lookup failure.
	err := functions.RegisterCustomFunction("plugin_math", functions.TypeMath, "插件", "数学插件", 1, 1, passthrough)
	assert.NoError(t, err)
	defer functions.Unregister("plugin_math")

	err = functions.RegisterCustomFunction("plugin_string", functions.TypeString, "插件", "字符串插件", 1, 1, passthrough)
	assert.NoError(t, err)
	defer functions.Unregister("plugin_string")

	// Discovery by type.
	mathFunctions := functions.GetByType(functions.TypeMath)
	assert.Greater(t, len(mathFunctions), 0)

	// The newly registered math function must be among them.
	found := false
	for _, fn := range mathFunctions {
		if fn.GetName() == "plugin_math" {
			found = true
			break
		}
	}
	assert.True(t, found, "新注册的数学函数应该被发现")

	// Discovery across all registered functions.
	allFunctions := functions.ListAll()
	assert.Contains(t, allFunctions, "plugin_math")
	assert.Contains(t, allFunctions, "plugin_string")

	fmt.Println("✅ 函数插件发现机制测试完成")
}

// TestCompleteSQLIntegration registers a two-argument business function and
// verifies its output when used in a plain SELECT.
func TestCompleteSQLIntegration(t *testing.T) {
	fmt.Println("\n🎯 测试完整SQL集成...")

	// business_metric(category, value) formats "<category>:<value*multiplier>"
	// with two decimals; the multiplier is 1.5 for "premium", 1.0 for
	// "standard" and 0.8 for anything else.
	err := functions.RegisterCustomFunction(
		"business_metric",
		functions.TypeString,
		"业务指标",
		"计算业务指标",
		2, 2,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			category := cast.ToString(args[0])
			value := cast.ToFloat64(args[1])

			var multiplier float64
			switch category {
			case "premium":
				multiplier = 1.5
			case "standard":
				multiplier = 1.0
			default:
				multiplier = 0.8
			}

			return fmt.Sprintf("%s:%.2f", category, value*multiplier), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("business_metric")

	ssql := New()
	defer ssql.Stop()

	// Use the new function in SQL.
	query := `
		SELECT
			customer_id,
			business_metric(tier, amount) as metric
		FROM stream
	`
	if err = ssql.Execute(query); !assert.NoError(t, err) {
		return
	}

	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	ssql.AddData(map[string]interface{}{
		"customer_id": "C001",
		"tier":        "premium",
		"amount":      100.0,
	})

	item := awaitCustomFuncRow(t, resultChan, 2*time.Second, "完整SQL集成测试超时")
	assert.Equal(t, "C001", item["customer_id"])
	assert.Equal(t, "premium:150.00", item["metric"])

	fmt.Println("✅ 完整SQL集成测试完成")
}