package functions

import (
	"math"
	"testing"
	"time"

	"github.com/rulego/streamsql/utils/cast"
	"github.com/stretchr/testify/assert"
)

// TestBasicFunctionRegistry checks that a representative set of built-in
// functions is registered under its canonical name with the expected type,
// and that an unknown name does not resolve.
func TestBasicFunctionRegistry(t *testing.T) {
	tests := []struct {
		name         string
		functionName string
		expectedType FunctionType
	}{
		{"abs function", "abs", TypeMath},
		{"concat function", "concat", TypeString},
		{"sqrt function", "sqrt", TypeMath},
		{"upper function", "upper", TypeString},
		{"cast function", "cast", TypeConversion},
		{"now function", "now", TypeDateTime},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			assert.True(t, exists, "%s should be registered", tt.functionName)
			// Stop if the lookup failed: calling methods on a nil function
			// would panic and abort every remaining test in the package.
			if !assert.NotNil(t, fn) {
				return
			}
			assert.Equal(t, tt.functionName, fn.GetName())
			assert.Equal(t, tt.expectedType, fn.GetType())
		})
	}

	// An unregistered name must not be found.
	_, exists := Get("nonexistent")
	assert.False(t, exists, "nonexistent function should not be found")
}

// TestNewMathFunctions exercises the logarithm, modulo, rounding, sign and
// trigonometric functions, including the domain errors they report.
func TestNewMathFunctions(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}

	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expected     interface{}
		expectError  bool
		errorMsg     string
		delta        float64 // tolerance for float comparison; 0 means exact equality
	}{
		// log
		{"log valid", "log", []interface{}{10.0}, 1.0, false, "", 1e-10},
		{"log negative", "log", []interface{}{-1}, nil, true, "value must be positive", 0},
		{"log zero", "log", []interface{}{0}, nil, true, "value must be positive", 0},
		// log10
		{"log10 100", "log10", []interface{}{100}, 2.0, false, "", 1e-10},
		{"log10 10", "log10", []interface{}{10}, 1.0, false, "", 1e-10},
		// log2
		{"log2 8", "log2", []interface{}{8}, 3.0, false, "", 1e-10},
		{"log2 2", "log2", []interface{}{2}, 1.0, false, "", 1e-10},
		// mod
		{"mod 10,3", "mod", []interface{}{10, 3}, 1.0, false, "", 1e-10},
		{"mod 7.5,2.5", "mod", []interface{}{7.5, 2.5}, 0.0, false, "", 1e-10},
		{"mod division by zero", "mod", []interface{}{10, 0}, nil, true, "division by zero", 0},
		// round
		{"round 3.7", "round", []interface{}{3.7}, 4.0, false, "", 1e-10},
		{"round 3.2", "round", []interface{}{3.2}, 3.0, false, "", 1e-10},
		{"round with precision", "round", []interface{}{3.14159, 2}, 3.14, false, "", 1e-10},
		// sign
		{"sign positive", "sign", []interface{}{5.5}, 1, false, "", 0},
		{"sign negative", "sign", []interface{}{-3.2}, -1, false, "", 0},
		{"sign zero", "sign", []interface{}{0}, 0, false, "", 0},
		// trigonometric / hyperbolic
		{"sin 0", "sin", []interface{}{0}, 0.0, false, "", 1e-10},
		{"sin π/2", "sin", []interface{}{math.Pi / 2}, 1.0, false, "", 1e-10},
		{"sinh 0", "sinh", []interface{}{0}, 0.0, false, "", 1e-10},
		{"tan 0", "tan", []interface{}{0}, 0.0, false, "", 1e-10},
		{"tanh 0", "tanh", []interface{}{0}, 0.0, false, "", 1e-10},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			if !assert.True(t, exists, "Function %s should be registered", tt.functionName) {
				return
			}

			result, err := fn.Execute(ctx, tt.args)
			if tt.expectError {
				// Only inspect the message when an error was actually
				// returned; err.Error() on a nil error would panic.
				if assert.Error(t, err) && tt.errorMsg != "" {
					assert.Contains(t, err.Error(), tt.errorMsg)
				}
				return
			}

			assert.NoError(t, err)
			if tt.delta > 0 {
				assert.InDelta(t, tt.expected, result, tt.delta)
			} else {
				assert.Equal(t, tt.expected, result)
			}
		})
	}

	// rand is checked separately because its result is random: only the
	// float64 type and the [0, 1) range are asserted.
	t.Run("rand function", func(t *testing.T) {
		fn, exists := Get("rand")
		if !assert.True(t, exists) {
			return
		}
		result, err := fn.Execute(ctx, []interface{}{})
		assert.NoError(t, err)
		val, ok := result.(float64)
		assert.True(t, ok)
		assert.GreaterOrEqual(t, val, 0.0)
		assert.Less(t, val, 1.0)
	})
}

// TestFunctionExecution runs a broad table of functions across all
// categories and compares their results. Float results are compared with a
// tolerance; time-dependent functions are checked by type/format only.
func TestFunctionExecution(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}

	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expected     interface{}
		expectError  bool
	}{
		// Math functions
		{"abs with positive", "abs", []interface{}{5.5}, 5.5, false},
		{"abs with negative", "abs", []interface{}{-5.5}, 5.5, false},
		{"abs with zero", "abs", []interface{}{0}, 0.0, false},
		{"sqrt with perfect square", "sqrt", []interface{}{16.0}, 4.0, false},
		{"sqrt with decimal", "sqrt", []interface{}{2.0}, 1.4142135623730951, false},
		{"sqrt with zero", "sqrt", []interface{}{0}, 0.0, false},
		{"sqrt with negative", "sqrt", []interface{}{-1}, nil, true},
		// Date/time functions (values change over time; see checks below)
		{"now basic", "now", []interface{}{}, time.Now().Unix(), false},
		{"current_time basic", "current_time", []interface{}{}, time.Now().Format("15:04:05"), false},
		{"current_date basic", "current_date", []interface{}{}, time.Now().Format("2006-01-02"), false},
		// Additional math functions
		{"acos valid", "acos", []interface{}{0.5}, math.Acos(0.5), false},
		{"acos invalid", "acos", []interface{}{2.0}, nil, true},
		{"asin valid", "asin", []interface{}{0.5}, math.Asin(0.5), false},
		{"asin invalid", "asin", []interface{}{2.0}, nil, true},
		{"atan valid", "atan", []interface{}{1.0}, math.Atan(1.0), false},
		{"atan2 valid", "atan2", []interface{}{1.0, 1.0}, math.Atan2(1.0, 1.0), false},
		{"bitand valid", "bitand", []interface{}{5, 3}, int64(1), false},
		{"bitor valid", "bitor", []interface{}{5, 3}, int64(7), false},
		{"bitxor valid", "bitxor", []interface{}{5, 3}, int64(6), false},
		{"bitnot valid", "bitnot", []interface{}{5}, int64(-6), false},
		{"ceiling positive", "ceiling", []interface{}{3.7}, 4.0, false},
		{"ceiling negative", "ceiling", []interface{}{-3.7}, -3.0, false},
		{"cos valid", "cos", []interface{}{0.0}, 1.0, false},
		{"cosh valid", "cosh", []interface{}{0.0}, 1.0, false},
		{"exp valid", "exp", []interface{}{1.0}, math.E, false},
		{"floor positive", "floor", []interface{}{3.7}, 3.0, false},
		{"floor negative", "floor", []interface{}{-3.7}, -4.0, false},
		{"ln valid", "ln", []interface{}{math.E}, 1.0, false},
		{"ln invalid", "ln", []interface{}{-1.0}, nil, true},
		{"power valid", "power", []interface{}{2.0, 3.0}, 8.0, false},
		// String functions
		{"concat basic", "concat", []interface{}{"hello", " ", "world"}, "hello world", false},
		{"concat single", "concat", []interface{}{"hello"}, "hello", false},
		{"concat numbers", "concat", []interface{}{1, 2, 3}, "123", false},
		{"length basic", "length", []interface{}{"hello"}, int(5), false},
		{"length empty", "length", []interface{}{""}, int(0), false},
		{"upper basic", "upper", []interface{}{"hello"}, "HELLO", false},
		{"upper mixed", "upper", []interface{}{"Hello World"}, "HELLO WORLD", false},
		{"lower basic", "lower", []interface{}{"HELLO"}, "hello", false},
		{"lower mixed", "lower", []interface{}{"Hello World"}, "hello world", false},
		// Conversion functions
		{"cast to int64", "cast", []interface{}{"123", "int64"}, int64(123), false},
		{"cast to float64", "cast", []interface{}{"123.45", "float64"}, 123.45, false},
		{"cast to string", "cast", []interface{}{123, "string"}, "123", false},
		{"hex2dec basic", "hex2dec", []interface{}{"ff"}, int64(255), false},
		{"hex2dec upper", "hex2dec", []interface{}{"FF"}, int64(255), false},
		{"hex2dec mixed digits", "hex2dec", []interface{}{"a0"}, int64(160), false},
		{"dec2hex basic", "dec2hex", []interface{}{255}, "ff", false},
		{"dec2hex zero", "dec2hex", []interface{}{0}, "0", false},
		{"dec2hex large", "dec2hex", []interface{}{4095}, "fff", false},
		{"encode base64", "encode", []interface{}{"hello", "base64"}, "aGVsbG8=", false},
		{"encode hex", "encode", []interface{}{"hello", "hex"}, "68656c6c6f", false},
		{"encode url", "encode", []interface{}{"hello world", "url"}, "hello+world", false},
		{"encode invalid format", "encode", []interface{}{"hello", "invalid"}, nil, true},
		{"encode invalid input", "encode", []interface{}{123, "base64"}, nil, true},
		{"decode base64", "decode", []interface{}{"aGVsbG8=", "base64"}, "hello", false},
		{"decode hex", "decode", []interface{}{"68656c6c6f", "hex"}, "hello", false},
		{"decode url", "decode", []interface{}{"hello+world", "url"}, "hello world", false},
		{"decode invalid format", "decode", []interface{}{"hello", "invalid"}, nil, true},
		{"decode invalid base64", "decode", []interface{}{"invalid!", "base64"}, nil, true},
		{"decode invalid hex", "decode", []interface{}{"invalid!", "hex"}, nil, true},
		// Aggregation functions
		{"sum basic", "sum", []interface{}{1, 2, 3}, 6.0, false},
		{"sum float", "sum", []interface{}{1.5, 2.5}, 4.0, false},
		{"avg basic", "avg", []interface{}{1, 2, 3}, 2.0, false},
		{"min basic", "min", []interface{}{3, 1, 2}, 1.0, false},
		{"max basic", "max", []interface{}{3, 1, 2}, 3.0, false},
		{"count basic", "count", []interface{}{1, 2, 3, 4, 5}, int64(5), false},
		// Error cases
		{"hex2dec invalid", "hex2dec", []interface{}{"xyz"}, nil, true},
		// More string functions
		{"trim basic", "trim", []interface{}{" hello world "}, "hello world", false},
		{"trim empty", "trim", []interface{}{""}, "", false},
		{"format number 2 decimals", "format", []interface{}{123.456, "0.00"}, "123.46", false},
		{"format number 0 decimals", "format", []interface{}{123.456, "0"}, "123", false},
		{"format string only", "format", []interface{}{"hello"}, "hello", false},
		// More aggregation functions
		{"collect basic", "collect", []interface{}{1, 2, 3}, []interface{}{1, 2, 3}, false},
		{"last_value basic", "last_value", []interface{}{1, 2, 3, 4}, 4, false},
		{"merge_agg basic", "merge_agg", []interface{}{"a", "b", "c"}, "a,b,c", false},
		{"stddevs basic", "stddevs", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 1.5811388300841898, false},
		{"deduplicate basic", "deduplicate", []interface{}{1, 2, 2, 3, 3, 3}, []interface{}{1, 2, 3}, false},
		{"var basic", "var", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 2.0, false},
		{"vars basic", "vars", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 2.5, false},
		// Window functions
		{"row_number basic", "row_number", []interface{}{}, int64(1), false},
		// Analytical functions
		{"latest basic", "latest", []interface{}{"hello"}, "hello", false},
		{"had_changed first", "had_changed", []interface{}{"value1"}, true, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			assert.True(t, exists, "function %s should exist", tt.functionName)
			if !exists || fn == nil {
				t.Errorf("Function %s not found or is nil", tt.functionName)
				return
			}

			result, err := fn.Execute(ctx, tt.args)
			if tt.expectError {
				assert.Error(t, err, "expected error for %s", tt.name)
				return
			}

			assert.NoError(t, err, "no error expected for %s", tt.name)
			if tt.expected == nil {
				return
			}

			switch expected := tt.expected.(type) {
			case float64:
				// Float results are compared with a small tolerance.
				resultFloat, ok := result.(float64)
				if !ok {
					t.Errorf("Expected float64 but got %T for %s", result, tt.name)
					return
				}
				assert.InDelta(t, expected, resultFloat, 0.0001, "result should match for %s", tt.name)
			case int64:
				if tt.functionName == "now" {
					// The concrete value of now changes over time, so only
					// the result type is checked.
					_, ok := result.(int64)
					assert.True(t, ok, "now function should return int64")
					return
				}
				assert.Equal(t, expected, result, "result should match for %s", tt.name)
			case string:
				switch tt.functionName {
				case "current_time", "current_date":
					// Time/date values change over time, so only the
					// output format is verified.
					resultStr, ok := result.(string)
					assert.True(t, ok, "%s function should return string", tt.functionName)
					if tt.functionName == "current_time" {
						_, parseErr := time.Parse("15:04:05", resultStr)
						assert.NoError(t, parseErr, "current_time should return valid time format")
					} else {
						_, parseErr := time.Parse("2006-01-02", resultStr)
						assert.NoError(t, parseErr, "current_date should return valid date format")
					}
				default:
					assert.Equal(t, expected, result, "result should match for %s", tt.name)
				}
			default:
				assert.Equal(t, expected, result, "result should match for %s", tt.name)
			}
		})
	}
}

// TestFunctionValidation checks argument-count and argument-type validation
// performed by each function's Validate method.
func TestFunctionValidation(t *testing.T) {
	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expectError  bool
		description  string
	}{
		// abs takes exactly 1 argument
		{"abs no args", "abs", []interface{}{}, true, "abs requires 1 argument"},
		{"abs too many args", "abs", []interface{}{1.0, 2.0}, true, "abs accepts only 1 argument"},
		{"abs correct args", "abs", []interface{}{1.0}, false, "abs should accept 1 argument"},
		// Date/time functions take no arguments
		{"current_time with args", "current_time", []interface{}{1}, true, "current_time should not accept arguments"},
		{"current_date with args", "current_date", []interface{}{1}, true, "current_date should not accept arguments"},
		// concat takes at least 1 argument
		{"concat no args", "concat", []interface{}{}, true, "concat requires at least 1 argument"},
		{"concat one arg", "concat", []interface{}{"hello"}, false, "concat should accept 1 argument"},
		{"concat multiple args", "concat", []interface{}{"a", "b", "c"}, false, "concat should accept multiple arguments"},
		// cast takes exactly 2 arguments
		{"cast no args", "cast", []interface{}{}, true, "cast requires 2 arguments"},
		{"cast one arg", "cast", []interface{}{"123"}, true, "cast requires 2 arguments"},
		{"cast correct args", "cast", []interface{}{"123", "int64"}, false, "cast should accept 2 arguments"},
		{"cast too many args", "cast", []interface{}{"123", "int64", "extra"}, true, "cast accepts only 2 arguments"},
		// now takes no arguments
		{"now no args", "now", []interface{}{}, false, "now should accept no arguments"},
		{"now with args", "now", []interface{}{1}, true, "now should not accept arguments"},
		// Additional math functions
		{"acos no args", "acos", []interface{}{}, true, "acos requires 1 argument"},
		{"acos too many args", "acos", []interface{}{1.0, 2.0}, true, "acos accepts only 1 argument"},
		{"atan2 no args", "atan2", []interface{}{}, true, "atan2 requires 2 arguments"},
		{"atan2 one arg", "atan2", []interface{}{1.0}, true, "atan2 requires 2 arguments"},
		{"atan2 too many args", "atan2", []interface{}{1.0, 2.0, 3.0}, true, "atan2 accepts only 2 arguments"},
		{"bitand no args", "bitand", []interface{}{}, true, "bitand requires 2 arguments"},
		{"bitand one arg", "bitand", []interface{}{1}, true, "bitand requires 2 arguments"},
		{"bitand too many args", "bitand", []interface{}{1, 2, 3}, true, "bitand accepts only 2 arguments"},
		{"bitnot no args", "bitnot", []interface{}{}, true, "bitnot requires 1 argument"},
		{"bitnot too many args", "bitnot", []interface{}{1, 2}, true, "bitnot accepts only 1 argument"},
		{"power no args", "power", []interface{}{}, true, "power requires 2 arguments"},
		{"power one arg", "power", []interface{}{2.0}, true, "power requires 2 arguments"},
		{"power too many args", "power", []interface{}{2.0, 3.0, 4.0}, true, "power accepts only 2 arguments"},
		// Conversion functions
		{"encode no args", "encode", []interface{}{}, true, "encode requires 2 arguments"},
		{"encode one arg", "encode", []interface{}{"hello"}, true, "encode requires 2 arguments"},
		{"encode three args", "encode", []interface{}{"hello", "base64", "extra"}, true, "encode requires exactly 2 arguments"},
		{"encode invalid format type", "encode", []interface{}{"hello", 123}, true, "encode format must be a string"},
		{"decode no args", "decode", []interface{}{}, true, "decode requires 2 arguments"},
		{"decode one arg", "decode", []interface{}{"aGVsbG8="}, true, "decode requires 2 arguments"},
		{"decode three args", "decode", []interface{}{"aGVsbG8=", "base64", "extra"}, true, "decode requires exactly 2 arguments"},
		{"decode invalid input type", "decode", []interface{}{123, "base64"}, true, "decode input must be a string"},
		{"decode invalid format type", "decode", []interface{}{"aGVsbG8=", 123}, true, "decode format must be a string"},
		// Newer functions
		{"trim no args", "trim", []interface{}{}, true, "function trim requires at least 1 arguments"},
		{"trim too many args", "trim", []interface{}{"hello", "world"}, true, "function trim accepts at most 1 arguments"},
		{"format too many args", "format", []interface{}{"hello", "pattern", "locale", "extra"}, true, "function format accepts at most 3 arguments"},
		{"collect no args", "collect", []interface{}{}, true, "function collect requires at least 1 arguments"},
		{"row_number with args", "row_number", []interface{}{"invalid"}, true, "function row_number accepts at most 0 arguments"},
		{"latest no args", "latest", []interface{}{}, true, "function latest requires at least 1 arguments"},
		{"had_changed no args", "had_changed", []interface{}{}, true, "function had_changed requires at least 1 arguments"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			if !assert.True(t, exists, "function %s should exist", tt.functionName) {
				return
			}

			err := fn.Validate(tt.args)
			if tt.expectError {
				assert.Error(t, err, tt.description)
			} else {
				assert.NoError(t, err, tt.description)
			}
		})
	}
}

// TestFunctionTypes checks that GetByType returns at least the listed
// functions for each function category.
func TestFunctionTypes(t *testing.T) {
	tests := []struct {
		functionType FunctionType
		functions    []string
	}{
		{TypeMath, []string{
			"abs", "sqrt", "acos", "asin", "atan", "atan2", "bitand", "bitor", "bitxor", "bitnot",
			"ceiling", "cos", "cosh", "exp", "floor", "ln", "power",
		}},
		{TypeString, []string{"concat", "length", "upper", "lower", "trim", "format"}},
		{TypeConversion, []string{"cast", "hex2dec", "dec2hex", "encode", "decode"}},
		{TypeDateTime, []string{"now", "current_time", "current_date"}},
		{TypeAggregation, []string{
			"sum", "avg", "min", "max", "count", "stddev", "median", "collect", "last_value",
			"merge_agg", "stddevs", "deduplicate", "var", "vars",
		}},
		{TypeWindow, []string{"row_number"}},
		{TypeAnalytical, []string{"lag", "latest", "changed_col", "had_changed"}},
	}

	for _, tt := range tests {
		t.Run(string(tt.functionType), func(t *testing.T) {
			functions := GetByType(tt.functionType)
			assert.GreaterOrEqual(t, len(functions), len(tt.functions),
				"should have at least %d functions of type %s", len(tt.functions), tt.functionType)

			// Index the returned functions by name, then check each expected one.
			functionNames := make(map[string]bool, len(functions))
			for _, fn := range functions {
				functionNames[fn.GetName()] = true
			}
			for _, expectedFn := range tt.functions {
				assert.True(t, functionNames[expectedFn], "function %s should be of type %s", expectedFn, tt.functionType)
			}
		})
	}
}

// TestCustomFunction registers a custom "double2" function and checks that
// it executes correctly for numeric and numeric-string inputs.
func TestCustomFunction(t *testing.T) {
	err := RegisterCustomFunction("double2", TypeCustom, "自定义函数", "将数值乘以2", 1, 1,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToFloat64(args[0])
			return val * 2, nil
		})
	if !assert.NoError(t, err) {
		return
	}
	// Unregister via t.Cleanup so the global registry is restored even when
	// the test returns early.
	t.Cleanup(func() { Unregister("double2") })

	tests := []struct {
		name     string
		args     []interface{}
		expected interface{}
	}{
		{"double positive", []interface{}{5.0}, 10.0},
		{"double negative", []interface{}{-3.0}, -6.0},
		{"double zero", []interface{}{0}, 0.0},
		{"double string number", []interface{}{"2.5"}, 5.0},
	}

	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}

	doubleFunc, exists := Get("double2")
	if !assert.True(t, exists) {
		return
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := doubleFunc.Execute(ctx, tt.args)
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestComplexFunctionCombinations chains several functions, feeding the
// output of one into the next.
func TestComplexFunctionCombinations(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}

	// run looks up name and executes it with args. If name is not registered,
	// it fails the current subtest instead of panicking on a nil function.
	run := func(t *testing.T, name string, args ...interface{}) (interface{}, error) {
		fn, exists := Get(name)
		if !exists {
			t.Fatalf("function %s not found", name)
		}
		return fn.Execute(ctx, args)
	}

	tests := []struct {
		name        string
		description string
		operations  func(t *testing.T) (interface{}, error)
		expected    interface{}
	}{
		{
			name:        "abs of negative sum",
			description: "计算负数之和的绝对值",
			operations: func(t *testing.T) (interface{}, error) {
				sum, err := run(t, "sum", -1, -2, -3)
				if err != nil {
					return nil, err
				}
				return run(t, "abs", sum)
			},
			expected: 6.0,
		},
		{
			name:        "concat and upper",
			description: "连接字符串后转大写",
			operations: func(t *testing.T) (interface{}, error) {
				concat, err := run(t, "concat", "hello", " ", "world")
				if err != nil {
					return nil, err
				}
				return run(t, "upper", concat)
			},
			expected: "HELLO WORLD",
		},
		{
			name:        "hex conversion round trip",
			description: "十进制转十六进制再转回十进制",
			operations: func(t *testing.T) (interface{}, error) {
				hex, err := run(t, "dec2hex", 255)
				if err != nil {
					return nil, err
				}
				return run(t, "hex2dec", hex)
			},
			expected: int64(255),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := tt.operations(t)
			assert.NoError(t, err, tt.description)
			assert.Equal(t, tt.expected, result, tt.description)
		})
	}
}

// TestCaseInsensitiveFunctions checks that function lookup ignores case.
func TestCaseInsensitiveFunctions(t *testing.T) {
	tests := []struct {
		name         string
		functionName string
		expected     bool
	}{
		{"小写concat", "concat", true},
		{"大写CONCAT", "CONCAT", true},
		{"混合大小写Concat", "Concat", true},
		{"混合大小写cOnCaT", "cOnCaT", true},
		{"小写upper", "upper", true},
		{"大写UPPER", "UPPER", true},
		{"混合大小写Upper", "Upper", true},
		{"小写lower", "lower", true},
		{"大写LOWER", "LOWER", true},
		{"混合大小写Lower", "Lower", true},
		{"小写length", "length", true},
		{"大写LENGTH", "LENGTH", true},
		{"混合大小写Length", "Length", true},
		{"小写trim", "trim", true},
		{"大写TRIM", "TRIM", true},
		{"混合大小写Trim", "Trim", true},
		{"不存在的函数", "nonexistent", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, exists := Get(tt.functionName)
			assert.Equal(t, tt.expected, exists, "函数 %s 的查找结果应该是 %v", tt.functionName, tt.expected)
		})
	}
}

// TestConcatFunctionCaseInsensitive checks that CONCAT executes correctly
// regardless of how its name is cased.
func TestConcatFunctionCaseInsensitive(t *testing.T) {
	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expected     string
	}{
		{"小写concat", "concat", []interface{}{"hello", " ", "world"}, "hello world"},
		{"大写CONCAT", "CONCAT", []interface{}{"hello", " ", "world"}, "hello world"},
		{"混合大小写Concat", "Concat", []interface{}{"hello", " ", "world"}, "hello world"},
		{"混合大小写cOnCaT", "cOnCaT", []interface{}{"hello", " ", "world"}, "hello world"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			if !assert.True(t, exists, "函数 %s 应该存在", tt.functionName) {
				return
			}

			ctx := &FunctionContext{
				Data: make(map[string]interface{}),
			}
			result, err := fn.Execute(ctx, tt.args)
			assert.NoError(t, err, "函数 %s 执行不应该出错", tt.functionName)
			assert.Equal(t, tt.expected, result, "函数 %s 的执行结果应该正确", tt.functionName)
		})
	}
}

// TestStringFunctionsCaseInsensitive checks that the string functions
// execute correctly under differently-cased names.
func TestStringFunctionsCaseInsensitive(t *testing.T) {
	ctx := &FunctionContext{
		Data: make(map[string]interface{}),
	}

	tests := []struct {
		name          string
		functionNames []string
		input         interface{}
		expected      interface{}
	}{
		{"UPPER函数大小写不敏感", []string{"upper", "UPPER", "Upper", "uPpEr"}, "hello", "HELLO"},
		{"LOWER函数大小写不敏感", []string{"lower", "LOWER", "Lower", "lOwEr"}, "HELLO", "hello"},
		{"LENGTH函数大小写不敏感", []string{"length", "LENGTH", "Length", "lEnGtH"}, "hello", int(5)},
		{"TRIM函数大小写不敏感", []string{"trim", "TRIM", "Trim", "tRiM"}, " hello ", "hello"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			for _, name := range tt.functionNames {
				fn, exists := Get(name)
				if !assert.True(t, exists, "函数 %s 应该存在", name) {
					continue
				}
				result, err := fn.Execute(ctx, []interface{}{tt.input})
				assert.NoError(t, err, "函数 %s 执行不应该出错", name)
				assert.Equal(t, tt.expected, result, "函数 %s 的执行结果应该正确", name)
			}
		})
	}
}

// TestMathFunctionsCaseInsensitive checks that the math functions execute
// correctly under differently-cased names.
func TestMathFunctionsCaseInsensitive(t *testing.T) {
	ctx := &FunctionContext{
		Data: make(map[string]interface{}),
	}

	tests := []struct {
		name          string
		functionNames []string
		input         interface{}
		expected      interface{}
	}{
		{"ABS函数大小写不敏感", []string{"abs", "ABS", "Abs", "aBs"}, -5.5, 5.5},
		{"SQRT函数大小写不敏感", []string{"sqrt", "SQRT", "Sqrt", "sQrT"}, 9.0, 3.0},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			for _, name := range tt.functionNames {
				fn, exists := Get(name)
				if !assert.True(t, exists, "函数 %s 应该存在", name) {
					continue
				}
				result, err := fn.Execute(ctx, []interface{}{tt.input})
				assert.NoError(t, err, "函数 %s 执行不应该出错", name)
				assert.Equal(t, tt.expected, result, "函数 %s 的执行结果应该正确", name)
			}
		})
	}
}

// TestAggregationFunctionsCaseInsensitive checks that aggregation functions
// resolve under differently-cased names and report the aggregation type.
func TestAggregationFunctionsCaseInsensitive(t *testing.T) {
	tests := []struct {
		name          string
		functionNames []string
	}{
		{"SUM函数大小写不敏感", []string{"sum", "SUM", "Sum", "sUm"}},
		{"AVG函数大小写不敏感", []string{"avg", "AVG", "Avg", "aVg"}},
		{"COUNT函数大小写不敏感", []string{"count", "COUNT", "Count", "cOuNt"}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			for _, name := range tt.functionNames {
				fn, exists := Get(name)
				if !assert.True(t, exists, "函数 %s 应该存在", name) {
					continue
				}
				assert.Equal(t, TypeAggregation, fn.GetType(), "函数 %s 应该是聚合函数", name)
			}
		})
	}
}

// TestFunctionAliases checks that aliases resolve to the same function
// instance as their canonical name, and that GetAliases lists them.
func TestFunctionAliases(t *testing.T) {
	// power / pow
	powerFunc, exists := Get("power")
	if !exists {
		t.Fatal("power function not found")
	}
	powFunc, exists := Get("pow")
	if !exists {
		t.Fatal("pow alias not found")
	}
	if powerFunc != powFunc {
		t.Error("pow alias should point to the same function as power")
	}

	// length / len
	lengthFunc, exists := Get("length")
	if !exists {
		t.Fatal("length function not found")
	}
	lenFunc, exists := Get("len")
	if !exists {
		t.Fatal("len alias not found")
	}
	if lengthFunc != lenFunc {
		t.Error("len alias should point to the same function as length")
	}

	// ceiling / ceil
	ceilingFunc, exists := Get("ceiling")
	if !exists {
		t.Fatal("ceiling function not found")
	}
	ceilFunc, exists := Get("ceil")
	if !exists {
		t.Fatal("ceil alias not found")
	}
	if ceilingFunc != ceilFunc {
		t.Error("ceil alias should point to the same function as ceiling")
	}

	// Each function reports exactly its one alias.
	aliases := powerFunc.GetAliases()
	if len(aliases) != 1 || aliases[0] != "pow" {
		t.Errorf("Expected aliases [pow], got %v", aliases)
	}
	aliases = lengthFunc.GetAliases()
	if len(aliases) != 1 || aliases[0] != "len" {
		t.Errorf("Expected aliases [len], got %v", aliases)
	}
	aliases = ceilingFunc.GetAliases()
	if len(aliases) != 1 || aliases[0] != "ceil" {
		t.Errorf("Expected aliases [ceil], got %v", aliases)
	}
}

// TestExecuteFunction checks the package-level Execute helper. It covers
// successful execution, unknown function names, and argument validation
// failures.
func TestExecuteFunction(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{
			"x": 10,
			"y": 20,
		},
	}

	// Normal execution
	result, err := Execute("abs", ctx, []interface{}{-5})
	assert.NoError(t, err)
	assert.Equal(t, 5.0, result)

	// Unknown function. Each message check is guarded so that a missing
	// error fails the assertion instead of panicking on err.Error().
	result, err = Execute("nonexistent_function", ctx, []interface{}{1})
	assert.Nil(t, result)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "function nonexistent_function not found")
	}

	// Too few arguments fails validation
	result, err = Execute("abs", ctx, []interface{}{})
	assert.Nil(t, result)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "validation failed")
	}

	// Too many arguments fails validation
	result, err = Execute("abs", ctx, []interface{}{1, 2, 3})
	assert.Nil(t, result)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "validation failed")
	}
}

// TestCustomFunctionValidate checks the argument-count bounds (min 1,
// max 2) enforced by a registered custom function's Validate method.
func TestCustomFunctionValidate(t *testing.T) {
	err := RegisterCustomFunction("test_custom", TypeMath, "test", "test function", 1, 2,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			return args[0], nil
		})
	if !assert.NoError(t, err) {
		return
	}
	t.Cleanup(func() { Unregister("test_custom") })

	fn, exists := Get("test_custom")
	if !assert.True(t, exists) {
		return
	}

	// Argument counts within [1, 2] are accepted.
	err = fn.Validate([]interface{}{1})
	assert.NoError(t, err)
	err = fn.Validate([]interface{}{1, 2})
	assert.NoError(t, err)

	// Too few arguments
	err = fn.Validate([]interface{}{})
	assert.Error(t, err)

	// Too many arguments
	err = fn.Validate([]interface{}{1, 2, 3})
	assert.Error(t, err)
}

// TestRegisterCustomFunctionErrors checks that RegisterCustomFunction rejects
// an empty name and accepts a valid registration.
func TestRegisterCustomFunctionErrors(t *testing.T) {
	// Empty function name is rejected
	err := RegisterCustomFunction("", TypeMath, "test", "test function", 1, 2,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			return nil, nil
		})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "function name cannot be empty")
	}

	// Valid registration succeeds
	err = RegisterCustomFunction("valid_custom", TypeMath, "test", "test function", 1, 2,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			return args[0], nil
		})
	if assert.NoError(t, err) {
		t.Cleanup(func() { Unregister("valid_custom") })
	}
}

// TestFunctionAliasExecution checks that functions can be executed through
// their aliases (pow, len, ceil) via the package-level Execute helper.
func TestFunctionAliasExecution(t *testing.T) {
	ctx := &FunctionContext{}

	// pow -> power
	result, err := Execute("pow", ctx, []interface{}{2.0, 3.0})
	if err != nil {
		t.Fatalf("pow execution failed: %v", err)
	}
	if result != 8.0 {
		t.Errorf("Expected 8.0, got %v", result)
	}

	// len -> length
	result, err = Execute("len", ctx, []interface{}{"hello"})
	if err != nil {
		t.Fatalf("len execution failed: %v", err)
	}
	if result != 5 {
		t.Errorf("Expected 5, got %v", result)
	}

	// ceil -> ceiling
	result, err = Execute("ceil", ctx, []interface{}{3.2})
	if err != nil {
		t.Fatalf("ceil execution failed: %v", err)
	}
	if result != 4.0 {
		t.Errorf("Expected 4.0, got %v", result)
	}
}