diff --git a/functions/builtin.go b/functions/builtin.go index 0d365ca..9f71288 100644 --- a/functions/builtin.go +++ b/functions/builtin.go @@ -65,8 +65,8 @@ func registerBuiltinFunctions() { _ = Register(NewToSecondsFunction()) _ = Register(NewChrFunction()) _ = Register(NewTruncFunction()) - _ = Register(NewCompressFunction()) - _ = Register(NewDecompressFunction()) + _ = Register(NewUrlEncodeFunction()) + _ = Register(NewUrlDecodeFunction()) // Time-Date functions _ = Register(NewNowFunction()) @@ -79,8 +79,8 @@ func registerBuiltinFunctions() { _ = Register(NewMinFunction()) _ = Register(NewMaxFunction()) _ = Register(NewCountFunction()) - _ = Register(NewStdDevFunction()) - _ = Register(NewMedianFunction()) + _ = Register(NewStdDevAggregatorFunction()) + _ = Register(NewMedianAggregatorFunction()) _ = Register(NewPercentileFunction()) _ = Register(NewCollectFunction()) _ = Register(NewLastValueFunction()) @@ -109,6 +109,7 @@ func registerBuiltinFunctions() { // Expression functions _ = Register(NewExpressionFunction()) _ = Register(NewExprFunction()) + _ = Register(NewExpressionAggregatorFunction()) // JSON functions _ = Register(NewToJsonFunction()) @@ -144,10 +145,15 @@ func registerBuiltinFunctions() { _ = Register(NewIsObjectFunction()) // Conditional functions + _ = Register(NewIfNullFunction()) _ = Register(NewCoalesceFunction()) _ = Register(NewNullIfFunction()) _ = Register(NewGreatestFunction()) _ = Register(NewLeastFunction()) + _ = Register(NewCaseWhenFunction()) + + // Multi-row functions + _ = Register(NewUnnestFunction()) // User-defined functions (placeholder for future extension) // Example: _=Register(NewMyUserDefinedFunction()) diff --git a/functions/functions_analytical.go b/functions/functions_analytical.go index d57ee39..150946c 100644 --- a/functions/functions_analytical.go +++ b/functions/functions_analytical.go @@ -99,9 +99,11 @@ func (f *LagFunction) Result() interface{} { return f.DefaultValue } // 返回当前值之前第Offset个值 - // 对于数组[first, second, third],当前位置是最后一个元素 - // offset=1时返回second(倒数第2个),offset=2时返回first(倒数第3个) - return f.PreviousValues[len(f.PreviousValues)-f.Offset-1] + // 对于数组[first, second, third],当前位置是最后一个元素third(索引2) + // offset=1时应该返回second(索引1),计算:len-1-offset = 3-1-1 = 1 + // offset=2时应该返回first(索引0),计算:len-1-offset = 3-1-2 = 0 + // 索引计算:len-1-offset,即从最后一个元素往前数offset个位置 + return f.PreviousValues[len(f.PreviousValues)-1-f.Offset] } func (f *LagFunction) Clone() AggregatorFunction { diff --git a/functions/functions_array.go b/functions/functions_array.go index 546c43b..b1653df 100644 --- a/functions/functions_array.go +++ b/functions/functions_array.go @@ -47,12 +47,12 @@ func (f *ArrayContainsFunction) Validate(args []interface{}) error { func (f *ArrayContainsFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array := args[0] value := args[1] - + v := reflect.ValueOf(array) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { return nil, fmt.Errorf("array_contains requires array input") } - + for i := 0; i < v.Len(); i++ { if reflect.DeepEqual(v.Index(i).Interface(), value) { return true, nil @@ -79,12 +79,12 @@ func (f *ArrayPositionFunction) Validate(args []interface{}) error { func (f *ArrayPositionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array := args[0] value := args[1] - + v := reflect.ValueOf(array) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { return nil, fmt.Errorf("array_position requires array input") } - + for i := 0; i < v.Len(); i++ { if reflect.DeepEqual(v.Index(i).Interface(), value) { return i + 1, nil // 返回1基索引 @@ -111,12 +111,12 @@ func (f *ArrayRemoveFunction) Validate(args []interface{}) error { func (f *ArrayRemoveFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array := args[0] value := args[1] - + v := reflect.ValueOf(array) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { return nil, fmt.Errorf("array_remove requires array input") } - + var result []interface{} for i := 0; i < v.Len(); i++ { elem := v.Index(i).Interface() @@ -144,15 +144,15 @@ func (f *ArrayDistinctFunction) Validate(args []interface{}) error { func (f *ArrayDistinctFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array := args[0] - + v := reflect.ValueOf(array) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { return nil, fmt.Errorf("array_distinct requires array input") } - + seen := make(map[interface{}]bool) var result []interface{} - + for i := 0; i < v.Len(); i++ { elem := v.Index(i).Interface() if !seen[elem] { @@ -181,27 +181,27 @@ func (f *ArrayIntersectFunction) Validate(args []interface{}) error { func (f *ArrayIntersectFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array1 := args[0] array2 := args[1] - + v1 := reflect.ValueOf(array1) v2 := reflect.ValueOf(array2) - + if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array { return nil, fmt.Errorf("array_intersect requires array input for first argument") } if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array { return nil, fmt.Errorf("array_intersect requires array input for second argument") } - + // 创建第二个数组的元素集合 set2 := make(map[interface{}]bool) for i := 0; i < v2.Len(); i++ { set2[v2.Index(i).Interface()] = true } - + // 找交集 seen := make(map[interface{}]bool) var result []interface{} - + for i := 0; i < v1.Len(); i++ { elem := v1.Index(i).Interface() if set2[elem] && !seen[elem] { @@ -230,20 +230,20 @@ func (f *ArrayUnionFunction) Validate(args []interface{}) error { func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array1 := args[0] array2 := args[1] - + v1 := reflect.ValueOf(array1) v2 := reflect.ValueOf(array2) - + if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array { return nil, fmt.Errorf("array_union requires array input for first argument") } if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array { return nil, fmt.Errorf("array_union requires array input for second argument") } - + seen := make(map[interface{}]bool) var result []interface{} - + // 添加第一个数组的元素 for i := 0; i < v1.Len(); i++ { elem := v1.Index(i).Interface() @@ -252,7 +252,7 @@ func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) ( result = append(result, elem) } } - + // 添加第二个数组的元素 for i := 0; i < v2.Len(); i++ { elem := v2.Index(i).Interface() @@ -282,27 +282,27 @@ func (f *ArrayExceptFunction) Validate(args []interface{}) error { func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { array1 := args[0] array2 := args[1] - + v1 := reflect.ValueOf(array1) v2 := reflect.ValueOf(array2) - + if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array { return nil, fmt.Errorf("array_except requires array input for first argument") } if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array { return nil, fmt.Errorf("array_except requires array input for second argument") } - + // 创建第二个数组的元素集合 set2 := make(map[interface{}]bool) for i := 0; i < v2.Len(); i++ { set2[v2.Index(i).Interface()] = true } - + // 找差集 seen := make(map[interface{}]bool) var result []interface{} - + for i := 0; i < v1.Len(); i++ { elem := v1.Index(i).Interface() if !set2[elem] && !seen[elem] { @@ -311,4 +311,4 @@ func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{}) } } return result, nil -} \ No newline at end of file +} diff --git a/functions/functions_array_test.go b/functions/functions_array_test.go new file mode 100644 index 0000000..79bdcfb --- /dev/null +++ b/functions/functions_array_test.go @@ -0,0 +1,65 @@ +package functions + +import ( + "testing" +) + +// 测试数组函数 +func TestArrayFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "array_length basic", + funcName: "array_length", + args: []interface{}{[]interface{}{1, 2, 3}}, + expected: 3, + }, + { + name: "array_contains true", + funcName: "array_contains", + args: []interface{}{[]interface{}{1, 2, 3}, 2}, + expected: true, + }, + { + name: "array_contains false", + funcName: "array_contains", + args: []interface{}{[]interface{}{1, 2, 3}, 4}, + expected: false, + }, + { + name: "array_position found", + funcName: "array_position", + args: []interface{}{[]interface{}{1, 2, 3}, 2}, + expected: 2, + }, + { + name: "array_position not found", + funcName: "array_position", + args: []interface{}{[]interface{}{1, 2, 3}, 4}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_compression_test.go b/functions/functions_compression_test.go deleted file mode 100644 index b27f285..0000000 --- a/functions/functions_compression_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package functions - -import ( - "testing" -) - -func TestCompressionFunctions(t *testing.T) { - tests := []struct { - name string - funcName string - args []interface{} - wantErr bool - }{ - // Compress function tests - { - name: "compress_gzip_valid", - funcName: "compress", - args: []interface{}{"hello world", "gzip"}, - wantErr: false, - }, - { - name: "compress_zlib_valid", - funcName: "compress", - args: []interface{}{"hello world", "zlib"}, - wantErr: false, - }, - { - name: "compress_invalid_algorithm", - funcName: "compress", - args: []interface{}{"hello world", "invalid"}, - wantErr: true, - }, - { - name: "compress_empty_string", - funcName: "compress", - args: []interface{}{"", "gzip"}, - wantErr: false, - }, - { - name: "compress_wrong_arg_count", - funcName: "compress", - args: []interface{}{"hello"}, - wantErr: true, - }, - // Decompress function tests - { - name: "decompress_invalid_base64", - funcName: "decompress", - args: []interface{}{"invalid_base64", "gzip"}, - wantErr: true, - }, - { - name: "decompress_invalid_algorithm", - funcName: "decompress", - args: []interface{}{"SGVsbG8gV29ybGQ=", "invalid"}, - wantErr: true, - }, - { - name: "decompress_wrong_arg_count", - funcName: "decompress", - args: []interface{}{"SGVsbG8gV29ybGQ="}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fn, exists := Get(tt.funcName) - if !exists { - t.Fatalf("%s function not found", tt.funcName) - } - - // 执行函数 - _, err := fn.Execute(nil, tt.args) - if (err != nil) != tt.wantErr { - t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func TestCompressionDecompressionRoundTrip(t *testing.T) { - tests := []struct { - name string - algorithm string - input string - }{ - { - name: "gzip_round_trip", - algorithm: "gzip", - input: "Hello, World! This is a test string for compression.", - }, - { - name: "zlib_round_trip", - algorithm: "zlib", - input: "Hello, World! This is a test string for compression.", - }, - { - name: "gzip_empty_string", - algorithm: "gzip", - input: "", - }, - { - name: "zlib_unicode", - algorithm: "zlib", - input: "你好世界!这是一个测试字符串。", - }, - } - - compressFn, exists := Get("compress") - if !exists { - t.Fatal("compress function not found") - } - - decompressFn, exists := Get("decompress") - if !exists { - t.Fatal("decompress function not found") - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 压缩 - compressed, err := compressFn.Execute(nil, []interface{}{tt.input, tt.algorithm}) - if err != nil { - t.Fatalf("Compress failed: %v", err) - } - - // 解压缩 - decompressed, err := decompressFn.Execute(nil, []interface{}{compressed, tt.algorithm}) - if err != nil { - t.Fatalf("Decompress failed: %v", err) - } - - // 验证结果 - if decompressed != tt.input { - t.Errorf("Round trip failed: expected %q, got %q", tt.input, decompressed) - } - }) - } -} \ No newline at end of file diff --git a/functions/functions_conditional.go b/functions/functions_conditional.go index 2cb94f5..cf44322 100644 --- a/functions/functions_conditional.go +++ b/functions/functions_conditional.go @@ -7,6 +7,28 @@ import ( "github.com/rulego/streamsql/utils/cast" ) +// IfNullFunction 如果第一个参数为NULL则返回第二个参数 +type IfNullFunction struct { + *BaseFunction +} + +func NewIfNullFunction() *IfNullFunction { + return &IfNullFunction{ + BaseFunction: NewBaseFunction("if_null", TypeString, "条件函数", "如果第一个参数为NULL则返回第二个参数", 2, 2), + } +} + +func (f *IfNullFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IfNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if args[0] == nil { + return args[1], nil + } + return args[0], nil +} + // CoalesceFunction 返回第一个非NULL值 type CoalesceFunction struct { *BaseFunction @@ -38,7 +60,7 @@ type NullIfFunction struct { func NewNullIfFunction() *NullIfFunction { return &NullIfFunction{ - BaseFunction: NewBaseFunction("nullif", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2), + BaseFunction: NewBaseFunction("null_if", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2), } } @@ -72,21 +94,21 @@ func (f *GreatestFunction) Execute(ctx *FunctionContext, args []interface{}) (in if len(args) == 0 { return nil, nil } - + max := args[0] if max == nil { return nil, nil } - + for i := 1; i < len(args); i++ { if args[i] == nil { return nil, nil } - + // 尝试转换为数字进行比较 maxVal, err1 := cast.ToFloat64E(max) currVal, err2 := cast.ToFloat64E(args[i]) - + if err1 == nil && err2 == nil { if currVal > maxVal { max = args[i] @@ -122,21 +144,21 @@ func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (inter if len(args) == 0 { return nil, nil } - + min := args[0] if min == nil { return nil, nil } - + for i := 1; i < len(args); i++ { if args[i] == nil { return nil, nil } - + // 尝试转换为数字进行比较 minVal, err1 := cast.ToFloat64E(min) currVal, err2 := cast.ToFloat64E(args[i]) - + if err1 == nil && err2 == nil { if currVal < minVal { min = args[i] @@ -151,4 +173,82 @@ func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (inter } } return min, nil -} \ No newline at end of file +} + +// CaseWhenFunction CASE WHEN表达式 +type CaseWhenFunction struct { + *BaseFunction +} + +func NewCaseWhenFunction() *CaseWhenFunction { + return &CaseWhenFunction{ + BaseFunction: NewBaseFunction("case_when", TypeString, "条件函数", "CASE WHEN表达式", 2, -1), + } +} + +func (f *CaseWhenFunction) Validate(args []interface{}) error { + if len(args) < 2 { + return fmt.Errorf("case_when requires at least 2 arguments") + } + + // 参数必须是偶数个(条件-值对)或奇数个(最后一个是默认值) + if len(args)%2 == 0 { + // 偶数个参数,必须都是条件-值对 + for i := 0; i < len(args); i += 2 { + // 条件应该是布尔值或可以转换为布尔值的表达式 + } + } else { + // 奇数个参数,最后一个是默认值 + for i := 0; i < len(args)-1; i += 2 { + // 条件应该是布尔值或可以转换为布尔值的表达式 + } + } + + return nil +} + +func (f *CaseWhenFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + // 处理条件-值对 + for i := 0; i < len(args)-1; i += 2 { + condition := args[i] + value := args[i+1] + + // 将条件转换为布尔值 + condBool, err := cast.ToBoolE(condition) + if err != nil { + // 如果无法转换为布尔值,检查是否为非零/非空值 + if condition == nil { + condBool = false + } else { + switch v := condition.(type) { + case string: + condBool = v != "" + case int, int32, int64: + num, _ := cast.ToInt64E(v) + condBool = num != 0 + case float32, float64: + num, _ := cast.ToFloat64E(v) + condBool = num != 0.0 + default: + condBool = true + } + } + } + + if condBool { + return value, nil + } + } + + // 如果没有条件匹配,返回默认值(如果有) + if len(args)%2 == 1 { + return args[len(args)-1], nil + } + + // 没有默认值,返回 nil + return nil, nil +} diff --git a/functions/functions_conditional_test.go b/functions/functions_conditional_test.go new file mode 100644 index 0000000..8df5a8e --- /dev/null +++ b/functions/functions_conditional_test.go @@ -0,0 +1,105 @@ +package functions + +import ( + "testing" +) + +// 测试条件函数 +func TestConditionalFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + wantErr bool + }{ + { + name: "if_null with null", + funcName: "if_null", + args: []interface{}{nil, "default"}, + expected: "default", + }, + { + name: "if_null with value", + funcName: "if_null", + args: []interface{}{"value", "default"}, + expected: "value", + }, + { + name: "null_if equal", + funcName: "null_if", + args: []interface{}{"test", "test"}, + expected: nil, + }, + { + name: "null_if not equal", + funcName: "null_if", + args: []interface{}{"test", "other"}, + expected: "test", + }, + { + name: "greatest basic", + funcName: "greatest", + args: []interface{}{1, 3, 2}, + expected: 3, + }, + { + name: "least basic", + funcName: "least", + args: []interface{}{1, 3, 2}, + expected: 1, + }, + + // case_when 函数测试 + { + name: "case_when simple", + funcName: "case_when", + args: []interface{}{true, "result1", false, "result2", "default"}, + expected: "result1", + }, + { + name: "case_when second condition", + funcName: "case_when", + args: []interface{}{false, "result1", true, "result2", "default"}, + expected: "result2", + }, + { + name: "case_when default", + funcName: "case_when", + args: []interface{}{false, "result1", false, "result2", "default"}, + expected: "default", + }, + { + name: "case_when no default", + funcName: "case_when", + args: []interface{}{false, "result1", false, "result2"}, + expected: nil, + }, + { + name: "case_when invalid args", + funcName: "case_when", + args: []interface{}{true}, + expected: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_conversion.go b/functions/functions_conversion.go index 23fad7d..ae74673 100644 --- a/functions/functions_conversion.go +++ b/functions/functions_conversion.go @@ -1,18 +1,15 @@ package functions import ( - "bytes" - "compress/gzip" - "compress/zlib" "encoding/base64" "encoding/hex" "fmt" - "github.com/rulego/streamsql/utils/cast" - "io" "math" "net/url" "strconv" "time" + + "github.com/rulego/streamsql/utils/cast" ) // CastFunction 类型转换函数 @@ -265,19 +262,19 @@ func (f *ConvertTzFunction) Execute(ctx *FunctionContext, args []interface{}) (i default: return nil, fmt.Errorf("time value must be time.Time or string") } - + // 获取目标时区 timezone, err := cast.ToStringE(args[1]) if err != nil { return nil, err } - + // 加载时区 loc, err := time.LoadLocation(timezone) if err != nil { return nil, fmt.Errorf("invalid timezone: %s", timezone) } - + // 转换时区 return t.In(loc), nil } @@ -324,7 +321,7 @@ func (f *ToSecondsFunction) Execute(ctx *FunctionContext, args []interface{}) (i default: return nil, fmt.Errorf("time value must be time.Time or string") } - + return t.Unix(), nil } @@ -348,151 +345,115 @@ func (f *ChrFunction) Execute(ctx *FunctionContext, args []interface{}) (interfa if err != nil { return nil, err } - + if code < 0 || code > 127 { return nil, fmt.Errorf("ASCII code must be between 0 and 127, got %d", code) } - + return string(rune(code)), nil } + + +// UrlEncodeFunction URL编码函数 +type UrlEncodeFunction struct { + *BaseFunction +} + +func NewUrlEncodeFunction() *UrlEncodeFunction { + return &UrlEncodeFunction{ + BaseFunction: NewBaseFunction("url_encode", TypeConversion, "转换函数", "URL编码", 1, 1), + } +} + +func (f *UrlEncodeFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *UrlEncodeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + if args[0] == nil { + return nil, fmt.Errorf("url_encode: input cannot be nil") + } + + input := cast.ToString(args[0]) + return url.QueryEscape(input), nil +} + +// UrlDecodeFunction URL解码函数 +type UrlDecodeFunction struct { + *BaseFunction +} + +func NewUrlDecodeFunction() *UrlDecodeFunction { + return &UrlDecodeFunction{ + BaseFunction: NewBaseFunction("url_decode", TypeConversion, "转换函数", "URL解码", 1, 1), + } +} + +func (f *UrlDecodeFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *UrlDecodeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + if args[0] == nil { + return nil, fmt.Errorf("url_decode: input cannot be nil") + } + + input := cast.ToString(args[0]) + result, err := url.QueryUnescape(input) + if err != nil { + return nil, fmt.Errorf("URL decode failed: %v", err) + } + return result, nil +} + // TruncFunction 截断小数位数 type TruncFunction struct { *BaseFunction } +// NewTruncFunction 创建新的 trunc 函数 func NewTruncFunction() *TruncFunction { return &TruncFunction{ BaseFunction: NewBaseFunction("trunc", TypeConversion, "转换函数", "截断小数位数", 2, 2), } } +// Validate 验证参数 func (f *TruncFunction) Validate(args []interface{}) error { return f.ValidateArgCount(args) } +// Execute 执行函数 func (f *TruncFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { - val, err := cast.ToFloat64E(args[0]) - if err != nil { + if err := f.Validate(args); err != nil { return nil, err } - - precision, err := cast.ToIntE(args[1]) - if err != nil { - return nil, err - } - + + // 转换第一个参数为浮点数 + num := cast.ToFloat64(args[0]) + + // 转换第二个参数为整数(精度) + precision := cast.ToInt(args[1]) + + // 精度不能为负数 if precision < 0 { - return nil, fmt.Errorf("precision must be non-negative, got %d", precision) + return nil, fmt.Errorf("trunc precision cannot be negative") } - + // 计算截断 multiplier := math.Pow(10, float64(precision)) - return math.Trunc(val*multiplier) / multiplier, nil -} - -// CompressFunction 压缩函数 -type CompressFunction struct { - *BaseFunction -} - -func NewCompressFunction() *CompressFunction { - return &CompressFunction{ - BaseFunction: NewBaseFunction("compress", TypeConversion, "转换函数", "压缩字符串或二进制值", 2, 2), + if num >= 0 { + return math.Floor(num*multiplier) / multiplier, nil + } else { + return math.Ceil(num*multiplier) / multiplier, nil } } - -func (f *CompressFunction) Validate(args []interface{}) error { - return f.ValidateArgCount(args) -} - -func (f *CompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { - if err := f.Validate(args); err != nil { - return nil, err - } - - input := cast.ToString(args[0]) - algorithm := cast.ToString(args[1]) - - var buf bytes.Buffer - var writer io.WriteCloser - - switch algorithm { - case "gzip": - writer = gzip.NewWriter(&buf) - case "zlib": - writer = zlib.NewWriter(&buf) - default: - return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm) - } - - _, err := writer.Write([]byte(input)) - if err != nil { - return nil, fmt.Errorf("compression failed: %v", err) - } - - err = writer.Close() - if err != nil { - return nil, fmt.Errorf("compression failed: %v", err) - } - - // 返回base64编码的压缩数据 - return base64.StdEncoding.EncodeToString(buf.Bytes()), nil -} - -// DecompressFunction 解压缩函数 -type DecompressFunction struct { - *BaseFunction -} - -func NewDecompressFunction() *DecompressFunction { - return &DecompressFunction{ - BaseFunction: NewBaseFunction("decompress", TypeConversion, "转换函数", "解压缩字符串或二进制值", 2, 2), - } -} - -func (f *DecompressFunction) Validate(args []interface{}) error { - return f.ValidateArgCount(args) -} - -func (f *DecompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { - if err := f.Validate(args); err != nil { - return nil, err - } - - input := cast.ToString(args[0]) - algorithm := cast.ToString(args[1]) - - // 解码base64数据 - compressedData, err := base64.StdEncoding.DecodeString(input) - if err != nil { - return nil, fmt.Errorf("invalid base64 input: %v", err) - } - - buf := bytes.NewReader(compressedData) - var reader io.ReadCloser - - switch algorithm { - case "gzip": - reader, err = gzip.NewReader(buf) - if err != nil { - return nil, fmt.Errorf("gzip decompression failed: %v", err) - } - case "zlib": - reader, err = zlib.NewReader(buf) - if err != nil { - return nil, fmt.Errorf("zlib decompression failed: %v", err) - } - default: - return nil, fmt.Errorf("unsupported decompression algorithm: %s", algorithm) - } - - defer reader.Close() - - result, err := io.ReadAll(reader) - if err != nil { - return nil, fmt.Errorf("decompression failed: %v", err) - } - - return string(result), nil -} diff --git a/functions/functions_conversion_test.go b/functions/functions_conversion_test.go index 865e3d9..b0a367b 100644 --- a/functions/functions_conversion_test.go +++ b/functions/functions_conversion_test.go @@ -39,7 +39,7 @@ func TestNewConversionFunctions(t *testing.T) { args: []interface{}{"invalid-time", "UTC"}, wantErr: true, }, - + // to_seconds 函数测试 { name: "to_seconds with time.Time", @@ -61,7 +61,7 @@ func TestNewConversionFunctions(t *testing.T) { args: []interface{}{"invalid-time"}, wantErr: true, }, - + // chr 函数测试 { name: "chr valid ASCII code", @@ -89,7 +89,7 @@ func TestNewConversionFunctions(t *testing.T) { args: []interface{}{128}, wantErr: true, }, - + // trunc 函数测试 { name: "trunc positive number", @@ -98,6 +98,70 @@ func TestNewConversionFunctions(t *testing.T) { want: 3.14, wantErr: false, }, + + // url_encode 函数测试 + { + name: "url_encode basic", + funcName: "url_encode", + args: []interface{}{"hello world"}, + want: "hello+world", + wantErr: false, + }, + { + name: "url_encode special chars", + funcName: "url_encode", + args: []interface{}{"hello@world.com"}, + want: "hello%40world.com", + wantErr: false, + }, + { + name: "url_encode empty", + funcName: "url_encode", + args: []interface{}{""}, + want: "", + wantErr: false, + }, + { + name: "url_encode nil", + funcName: "url_encode", + args: []interface{}{nil}, + wantErr: true, + }, + + // url_decode 函数测试 + { + name: "url_decode basic", + funcName: "url_decode", + args: []interface{}{"hello+world"}, + want: "hello world", + wantErr: false, + }, + { + name: "url_decode special chars", + funcName: "url_decode", + args: []interface{}{"hello%40world.com"}, + want: "hello@world.com", + wantErr: false, + }, + { + name: "url_decode empty", + funcName: "url_decode", + args: []interface{}{""}, + want: "", + wantErr: false, + }, + { + name: "url_decode nil", + funcName: "url_decode", + args: []interface{}{nil}, + wantErr: true, + }, + { + name: "url_decode invalid", + funcName: "url_decode", + args: []interface{}{"hello%ZZ"}, + wantErr: true, + }, { name: "trunc negative number", funcName: "trunc", @@ -119,20 +183,20 @@ func TestNewConversionFunctions(t *testing.T) { wantErr: true, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fn, exists := Get(tt.funcName) if !exists { t.Fatalf("Function %s not found", tt.funcName) } - + result, err := fn.Execute(nil, tt.args) if (err != nil) != tt.wantErr { t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) return } - + if !tt.wantErr { // 对于时间类型,需要特殊处理比较 if tt.funcName == "convert_tz" { @@ -157,4 +221,4 @@ func TestNewConversionFunctions(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/functions/functions_datetime.go b/functions/functions_datetime.go index 1c79992..a269ead 100644 --- a/functions/functions_datetime.go +++ b/functions/functions_datetime.go @@ -479,6 +479,12 @@ func (f *YearFunction) Validate(args []interface{}) error { } func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 首先检查是否是 time.Time 类型 + if t, ok := args[0].(time.Time); ok { + return float64(t.Year()), nil + } + + // 如果不是 time.Time,尝试转换为字符串并解析 dateStr, err := cast.ToStringE(args[0]) if err != nil { return nil, fmt.Errorf("invalid date: %v", err) @@ -491,7 +497,7 @@ func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interf } } - return t.Year(), nil + return float64(t.Year()), nil } // MonthFunction 提取月份函数 @@ -510,6 +516,12 @@ func (f *MonthFunction) Validate(args []interface{}) error { } func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 首先检查是否是 time.Time 类型 + if t, ok := args[0].(time.Time); ok { + return float64(t.Month()), nil + } + + // 如果不是 time.Time,尝试转换为字符串并解析 dateStr, err := cast.ToStringE(args[0]) if err != nil { return nil, fmt.Errorf("invalid date: %v", err) @@ -522,7 +534,7 @@ func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (inter } } - return int(t.Month()), nil + return float64(t.Month()), nil } // DayFunction 提取日期函数 diff --git a/functions/functions_datetime_test.go b/functions/functions_datetime_test.go index 8e513d5..e78c345 100644 --- a/functions/functions_datetime_test.go +++ b/functions/functions_datetime_test.go @@ -63,7 +63,7 @@ func TestDateTimeFunctions(t *testing.T) { name: "year extraction", function: NewYearFunction(), args: []interface{}{"2023-12-25 15:30:45"}, - expected: 2023, + expected: float64(2023), wantErr: false, }, // MonthFunction 测试 @@ -71,7 +71,7 @@ func TestDateTimeFunctions(t *testing.T) { name: "month extraction", function: NewMonthFunction(), args: []interface{}{"2023-12-25 15:30:45"}, - expected: 12, + expected: float64(12), wantErr: false, }, // DayFunction 测试 @@ -164,14 +164,14 @@ func TestDateTimeFunctions(t *testing.T) { } return } - + // 执行函数 result, err := tt.function.Execute(nil, tt.args) if (err != nil) != tt.wantErr { t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) return } - + if !tt.wantErr && result != tt.expected { t.Errorf("Execute() = %v, want %v", result, tt.expected) } @@ -235,4 +235,4 @@ func TestDateFormatConversion(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/functions/functions_hash.go b/functions/functions_hash.go index a4bab70..3aee7c9 100644 --- a/functions/functions_hash.go +++ b/functions/functions_hash.go @@ -28,7 +28,7 @@ func (f *Md5Function) Execute(ctx *FunctionContext, args []interface{}) (interfa if !ok { return nil, fmt.Errorf("md5 requires string input") } - + hash := md5.Sum([]byte(str)) return fmt.Sprintf("%x", hash), nil } @@ -53,7 +53,7 @@ func (f *Sha1Function) Execute(ctx *FunctionContext, args []interface{}) (interf if !ok { return nil, fmt.Errorf("sha1 requires string input") } - + hash := sha1.Sum([]byte(str)) return fmt.Sprintf("%x", hash), nil } @@ -78,7 +78,7 @@ func (f *Sha256Function) Execute(ctx *FunctionContext, args []interface{}) (inte if !ok { return nil, fmt.Errorf("sha256 requires string input") } - + hash := sha256.Sum256([]byte(str)) return fmt.Sprintf("%x", hash), nil } @@ -103,7 +103,7 @@ func (f *Sha512Function) Execute(ctx *FunctionContext, args []interface{}) (inte if !ok { return nil, fmt.Errorf("sha512 requires string input") } - + hash := sha512.Sum512([]byte(str)) return fmt.Sprintf("%x", hash), nil -} \ No newline at end of file +} diff --git a/functions/functions_hash_test.go b/functions/functions_hash_test.go new file mode 100644 index 0000000..67bcc40 --- /dev/null +++ b/functions/functions_hash_test.go @@ -0,0 +1,53 @@ +package functions + +import ( + "testing" +) + +// 测试哈希函数 +func TestHashFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "md5 basic", + funcName: "md5", + args: []interface{}{"hello"}, + expected: "5d41402abc4b2a76b9719d911017c592", + }, + { + name: "sha1 basic", + funcName: "sha1", + args: []interface{}{"hello"}, + expected: "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", + }, + { + name: "sha256 basic", + funcName: "sha256", + args: []interface{}{"hello"}, + expected: "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_json.go b/functions/functions_json.go index 10d39fc..35bfa05 100644 --- a/functions/functions_json.go +++ b/functions/functions_json.go @@ -50,7 +50,7 @@ func (f *FromJsonFunction) Execute(ctx *FunctionContext, args []interface{}) (in if !ok { return nil, fmt.Errorf("from_json requires string input") } - + var result interface{} err := json.Unmarshal([]byte(jsonStr), &result) if err != nil { @@ -79,18 +79,18 @@ func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{}) if !ok { return nil, fmt.Errorf("json_extract requires string input") } - + path, ok := args[1].(string) if !ok { return nil, fmt.Errorf("json_extract path must be string") } - + var data interface{} err := json.Unmarshal([]byte(jsonStr), &data) if err != nil { return nil, fmt.Errorf("failed to parse JSON: %v", err) } - + // 简单的路径提取,支持 $.field 格式 if strings.HasPrefix(path, "$.") { field := path[2:] @@ -98,7 +98,7 @@ func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{}) return dataMap[field], nil } } - + return nil, fmt.Errorf("invalid JSON path or data structure") } @@ -122,7 +122,7 @@ func (f *JsonValidFunction) Execute(ctx *FunctionContext, args []interface{}) (i if !ok { return false, nil } - + var temp interface{} err := json.Unmarshal([]byte(jsonStr), &temp) return err == nil, nil @@ -148,13 +148,13 @@ func (f *JsonTypeFunction) Execute(ctx *FunctionContext, args []interface{}) (in if !ok { return "unknown", nil } - + var data interface{} err := json.Unmarshal([]byte(jsonStr), &data) if err != nil { return "invalid", nil } - + switch data.(type) { case nil: return "null", nil @@ -193,13 +193,13 @@ func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) ( if !ok { return nil, fmt.Errorf("json_length requires string input") } - + var data interface{} err := json.Unmarshal([]byte(jsonStr), &data) if err != nil { return nil, fmt.Errorf("failed to parse JSON: %v", err) } - + switch v := data.(type) { case []interface{}: return len(v), nil @@ -208,4 +208,4 @@ func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) ( default: return nil, fmt.Errorf("JSON value is not an array or object") } -} \ No newline at end of file +} diff --git a/functions/functions_json_test.go b/functions/functions_json_test.go new file mode 100644 index 0000000..ee34c08 --- /dev/null +++ b/functions/functions_json_test.go @@ -0,0 +1,105 @@ +package functions + +import ( + "testing" +) + +// 测试JSON函数 +func TestJsonFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + wantErr bool + }{ + { + name: "to_json basic", + funcName: "to_json", + args: []interface{}{map[string]interface{}{"name": "test", "value": 123}}, + expected: `{"name":"test","value":123}`, + }, + { + name: "from_json basic", + funcName: "from_json", + args: []interface{}{`{"name":"test","value":123}`}, + expected: map[string]interface{}{"name": "test", "value": float64(123)}, + }, + { + name: "json_extract basic", + funcName: "json_extract", + args: []interface{}{`{"name":"test","value":123}`, "$.name"}, + expected: "test", + }, + { + name: "json_valid true", + funcName: "json_valid", + args: []interface{}{`{"name":"test"}`}, + expected: true, + }, + { + name: "json_valid false", + funcName: "json_valid", + args: []interface{}{`{"name":"test"`}, + expected: false, + }, + { + name: "json_type object", + funcName: "json_type", + args: []interface{}{`{"name":"test"}`}, + expected: "object", + }, + { + name: "json_length array", + funcName: "json_length", + args: []interface{}{`[1,2,3]`}, + expected: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && !compareResults(result, tt.expected) { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +// 辅助函数:比较结果 +func compareResults(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // 对于map类型的特殊处理 + if mapA, okA := a.(map[string]interface{}); okA { + if mapB, okB := b.(map[string]interface{}); okB { + if len(mapA) != len(mapB) { + return false + } + for k, v := range mapA { + if mapB[k] != v { + return false + } + } + return true + } + } + + return a == b +} \ No newline at end of file diff --git a/functions/functions_multirow.go b/functions/functions_multirow.go new file mode 100644 index 0000000..fe323ae --- /dev/null +++ b/functions/functions_multirow.go @@ -0,0 +1,109 @@ +package functions + +import ( + "fmt" + "reflect" +) + +// UnnestFunction 将数组展开为多行 +type UnnestFunction struct { + *BaseFunction +} + +func NewUnnestFunction() *UnnestFunction { + return &UnnestFunction{ + BaseFunction: NewBaseFunction("unnest", TypeString, "多行函数", "将数组展开为多行", 1, 1), + } +} + +func (f *UnnestFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *UnnestFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + array := args[0] + if array == nil { + return []interface{}{}, nil + } + + // 使用反射检查是否为数组或切片 + v := reflect.ValueOf(array) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("unnest requires an array or slice, got %T", array) + } + + // 转换为 []interface{} + result := make([]interface{}, v.Len()) + for i := 0; i < v.Len(); i++ { + elem := v.Index(i).Interface() + + // 如果数组元素是对象(map),则展开为列 + if elemMap, ok := elem.(map[string]interface{}); ok { + // 对于对象,我们返回一个特殊的结构来表示需要展开为列 + result[i] = map[string]interface{}{ + "__unnest_object__": true, + "__data__": elemMap, + } + } else { + result[i] = elem + } + } + + return result, nil +} + +// UnnestResult 表示 unnest 函数的结果 +type UnnestResult struct { + Rows []map[string]interface{} +} + +// IsUnnestResult 检查是否为 unnest 结果 +func IsUnnestResult(value interface{}) bool { + if slice, ok := value.([]interface{}); ok { + for _, item := range slice { + if itemMap, ok := item.(map[string]interface{}); ok { + if unnest, exists := itemMap["__unnest_object__"]; exists { + if unnestBool, ok := unnest.(bool); ok && unnestBool { + return true + } + } + } + } + } + return false +} + +// ProcessUnnestResult 处理 unnest 结果,将其转换为多行 +func ProcessUnnestResult(value interface{}) []map[string]interface{} { + slice, ok := value.([]interface{}) + if !ok { + return nil + } + + var rows []map[string]interface{} + for _, item := range slice { + if itemMap, ok := item.(map[string]interface{}); ok { + if unnest, exists := itemMap["__unnest_object__"]; exists { + if unnestBool, ok := unnest.(bool); ok && unnestBool { + if data, exists := itemMap["__data__"]; exists { + if dataMap, ok := data.(map[string]interface{}); ok { + rows = append(rows, dataMap) + } + } + continue + } + } + } + // 对于非对象元素,创建一个包含单个值的行 + row := map[string]interface{}{ + "value": item, + } + rows = append(rows, row) + } + + return rows +} \ No newline at end of file diff --git a/functions/functions_multirow_test.go b/functions/functions_multirow_test.go new file mode 100644 index 0000000..83f14b6 --- /dev/null +++ b/functions/functions_multirow_test.go @@ -0,0 +1,119 @@ +package functions + +import ( + "reflect" + "testing" +) + +// TestUnnestFunction 测试unnest函数 +func TestUnnestFunction(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + wantErr bool + }{ + // unnest 函数测试 + { + name: "unnest array", + funcName: "unnest", + args: []interface{}{[]interface{}{1, 2, 3}}, + expected: []interface{}{1, 2, 3}, + }, + { + name: "unnest empty array", + funcName: "unnest", + args: []interface{}{[]interface{}{}}, + expected: []interface{}{}, + }, + { + name: "unnest nil", + funcName: "unnest", + args: []interface{}{nil}, + expected: []interface{}{}, + }, + { + name: "unnest non-array", + funcName: "unnest", + args: []interface{}{"not an array"}, + expected: nil, + wantErr: true, + }, + + + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + } + }) + } +} + +// TestUnnestWithObjects 测试 unnest 函数处理对象数组 +func TestUnnestWithObjects(t *testing.T) { + fn, exists := Get("unnest") + if !exists { + t.Fatal("Function unnest not found") + } + + // 测试对象数组 + objectArray := []interface{}{ + map[string]interface{}{"name": "Alice", "age": 30}, + map[string]interface{}{"name": "Bob", "age": 25}, + } + + result, err := fn.Execute(&FunctionContext{}, []interface{}{objectArray}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + // 检查结果是否包含特殊标记 + resultSlice, ok := result.([]interface{}) + if !ok { + t.Fatalf("Expected []interface{}, got %T", result) + } + + if len(resultSlice) != 2 { + t.Fatalf("Expected 2 items, got %d", len(resultSlice)) + } + + // 检查第一个对象是否有特殊标记 + firstItem, ok := resultSlice[0].(map[string]interface{}) + if !ok { + t.Fatalf("Expected map[string]interface{}, got %T", resultSlice[0]) + } + + if unnestFlag, exists := firstItem["__unnest_object__"]; !exists || unnestFlag != true { + t.Error("Expected __unnest_object__ flag to be true") + } + + if data, exists := firstItem["__data__"]; !exists { + t.Error("Expected __data__ field to exist") + } else { + dataMap, ok := data.(map[string]interface{}) + if !ok { + t.Errorf("Expected __data__ to be map[string]interface{}, got %T", data) + } else { + if dataMap["name"] != "Alice" || dataMap["age"] != 30 { + t.Errorf("Unexpected data: %v", dataMap) + } + } + } +} \ No newline at end of file diff --git a/functions/functions_new_test.go b/functions/functions_new_test.go deleted file mode 100644 index c774c49..0000000 --- a/functions/functions_new_test.go +++ /dev/null @@ -1,339 +0,0 @@ -package functions - -import ( - "testing" -) - -// 测试JSON函数 -func TestJsonFunctions(t *testing.T) { - tests := []struct { - name string - funcName string - args []interface{} - expected interface{} - wantErr bool - }{ - { - name: "to_json basic", - funcName: "to_json", - args: []interface{}{map[string]interface{}{"name": "test", "value": 123}}, - expected: `{"name":"test","value":123}`, - }, - { - name: "from_json basic", - funcName: "from_json", - args: []interface{}{`{"name":"test","value":123}`}, - expected: map[string]interface{}{"name": "test", "value": float64(123)}, - }, - { - name: "json_extract basic", - funcName: "json_extract", - args: []interface{}{`{"name":"test","value":123}`, "$.name"}, - expected: "test", - }, - { - name: "json_valid true", - funcName: "json_valid", - args: []interface{}{`{"name":"test"}`}, - expected: true, - }, - { - name: "json_valid false", - funcName: "json_valid", - args: []interface{}{`{"name":"test"`}, - expected: false, - }, - { - name: "json_type object", - funcName: "json_type", - args: []interface{}{`{"name":"test"}`}, - expected: "object", - }, - { - name: "json_length array", - funcName: "json_length", - args: []interface{}{`[1,2,3]`}, - expected: 3, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fn, exists := Get(tt.funcName) - if !exists { - t.Fatalf("Function %s not found", tt.funcName) - } - - result, err := fn.Execute(&FunctionContext{}, tt.args) - if (err != nil) != tt.wantErr { - t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && !compareResults(result, tt.expected) { - t.Errorf("Execute() = %v, want %v", result, tt.expected) - } - }) - } -} - -// 测试哈希函数 -func TestHashFunctions(t *testing.T) { - tests := []struct { - name string - funcName string - args []interface{} - expected interface{} - }{ - { - name: "md5 basic", - funcName: "md5", - args: []interface{}{"hello"}, - expected: "5d41402abc4b2a76b9719d911017c592", - }, - { - name: "sha1 basic", - funcName: "sha1", - args: []interface{}{"hello"}, - expected: "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", - }, - { - name: "sha256 basic", - funcName: "sha256", - args: []interface{}{"hello"}, - expected: "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fn, exists := Get(tt.funcName) - if !exists { - t.Fatalf("Function %s not found", tt.funcName) - } - - result, err := fn.Execute(&FunctionContext{}, tt.args) - if err != nil { - t.Errorf("Execute() error = %v", err) - return - } - - if result != tt.expected { - t.Errorf("Execute() = %v, want %v", result, tt.expected) - } - }) - } -} - -// 测试数组函数 -func TestArrayFunctions(t *testing.T) { - tests := []struct { - name string - funcName string - args []interface{} - expected interface{} - }{ - { - name: "array_length basic", - funcName: "array_length", - args: []interface{}{[]interface{}{1, 2, 3}}, - expected: 3, - }, - { - name: "array_contains true", - funcName: "array_contains", - args: []interface{}{[]interface{}{1, 2, 3}, 2}, - expected: true, - }, - { - name: "array_contains false", - funcName: "array_contains", - args: []interface{}{[]interface{}{1, 2, 3}, 4}, - expected: false, - }, - { - name: "array_position found", - funcName: "array_position", - args: []interface{}{[]interface{}{1, 2, 3}, 2}, - expected: 2, - }, - { - name: "array_position not found", - funcName: "array_position", - args: []interface{}{[]interface{}{1, 2, 3}, 4}, - expected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fn, exists := Get(tt.funcName) - if !exists { - t.Fatalf("Function %s not found", tt.funcName) - } - - result, err := fn.Execute(&FunctionContext{}, tt.args) - if err != nil { - t.Errorf("Execute() error = %v", err) - return - } - - if result != tt.expected { - t.Errorf("Execute() = %v, want %v", result, tt.expected) - } - }) - } -} - -// 测试类型检查函数 -func TestTypeFunctions(t *testing.T) { - tests := []struct { - name string - funcName string - args []interface{} - expected interface{} - }{ - { - name: "is_null true", - funcName: "is_null", - args: []interface{}{nil}, - expected: true, - }, - { - name: "is_null false", - funcName: "is_null", - args: []interface{}{"test"}, - expected: false, - }, - { - name: "is_numeric true", - funcName: "is_numeric", - args: []interface{}{123}, - expected: true, - }, - { - name: "is_numeric false", - funcName: "is_numeric", - args: []interface{}{"test"}, - expected: false, - }, - { - name: "is_string true", - funcName: "is_string", - args: []interface{}{"test"}, - expected: true, - }, - { - name: "is_string false", - funcName: "is_string", - args: []interface{}{123}, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fn, exists := Get(tt.funcName) - if !exists { - t.Fatalf("Function %s not found", tt.funcName) - } - - result, err := fn.Execute(&FunctionContext{}, tt.args) - if err != nil { - t.Errorf("Execute() error = %v", err) - return - } - - if result != tt.expected { - t.Errorf("Execute() = %v, want %v", result, tt.expected) - } - }) - } -} - -// 测试条件函数 -func TestConditionalFunctions(t *testing.T) { - tests := []struct { - name string - funcName string - args []interface{} - expected interface{} - }{ - { - name: "coalesce first non-null", - funcName: "coalesce", - args: []interface{}{nil, "test", "other"}, - expected: "test", - }, - { - name: "nullif equal", - funcName: "nullif", - args: []interface{}{"test", "test"}, - expected: nil, - }, - { - name: "nullif not equal", - funcName: "nullif", - args: []interface{}{"test", "other"}, - expected: "test", - }, - { - name: "greatest numeric", - funcName: "greatest", - args: []interface{}{1, 3, 2}, - expected: 3, - }, - { - name: "least numeric", - funcName: "least", - args: []interface{}{3, 1, 2}, - expected: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fn, exists := Get(tt.funcName) - if !exists { - t.Fatalf("Function %s not found", tt.funcName) - } - - result, err := fn.Execute(&FunctionContext{}, tt.args) - if err != nil { - t.Errorf("Execute() error = %v", err) - return - } - - if !compareResults(result, tt.expected) { - t.Errorf("Execute() = %v, want %v", result, tt.expected) - } - }) - } -} - -// 辅助函数:比较结果 -func compareResults(a, b interface{}) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - - // 对于map类型的特殊处理 - if mapA, okA := a.(map[string]interface{}); okA { - if mapB, okB := b.(map[string]interface{}); okB { - if len(mapA) != len(mapB) { - return false - } - for k, v := range mapA { - if mapB[k] != v { - return false - } - } - return true - } - } - - return a == b -} \ No newline at end of file diff --git a/functions/functions_string.go b/functions/functions_string.go index 4ba1353..4a55f03 100644 --- a/functions/functions_string.go +++ b/functions/functions_string.go @@ -2,9 +2,10 @@ package functions import ( "fmt" - "github.com/rulego/streamsql/utils/cast" "regexp" "strings" + + "github.com/rulego/streamsql/utils/cast" ) // ConcatFunction 字符串连接函数 @@ -292,26 +293,26 @@ func (f *SubstringFunction) Execute(ctx *FunctionContext, args []interface{}) (i if err != nil { return nil, err } - + strLen := int64(len(str)) if start < 0 || start >= strLen { return "", nil } - + if len(args) == 2 { return str[start:], nil } - + length, err := cast.ToInt64E(args[2]) if err != nil { return nil, err } - + end := start + length if end > strLen { end = strLen } - + return str[start:end], nil } @@ -397,7 +398,7 @@ func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf if err != nil { return nil, err } - + pad := " " if len(args) == 3 { pad, err = cast.ToStringE(args[2]) @@ -405,12 +406,12 @@ func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf return nil, err } } - + strLen := int64(len(str)) if strLen >= length { return str, nil } - + padLen := length - strLen padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1)) return padStr[:padLen] + str, nil @@ -440,7 +441,7 @@ func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf if err != nil { return nil, err } - + pad := " " if len(args) == 3 { pad, err = cast.ToStringE(args[2]) @@ -448,12 +449,12 @@ func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf return nil, err } } - + strLen := int64(len(str)) if strLen >= length { return str, nil } - + padLen := length - strLen padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1)) return str + padStr[:padLen], nil @@ -533,7 +534,7 @@ func (f *RegexpMatchesFunction) Execute(ctx *FunctionContext, args []interface{} if err != nil { return nil, err } - + matched, err := regexp.MatchString(pattern, str) if err != nil { return nil, err @@ -569,7 +570,7 @@ func (f *RegexpReplaceFunction) Execute(ctx *FunctionContext, args []interface{} if err != nil { return nil, err } - + re, err := regexp.Compile(pattern) if err != nil { return nil, err @@ -601,12 +602,12 @@ func (f *RegexpSubstringFunction) Execute(ctx *FunctionContext, args []interface if err != nil { return nil, err } - + re, err := regexp.Compile(pattern) if err != nil { return nil, err } - + match := re.FindString(str) return match, nil } diff --git a/functions/functions_string_test.go b/functions/functions_string_test.go index a846a18..55859a4 100644 --- a/functions/functions_string_test.go +++ b/functions/functions_string_test.go @@ -16,88 +16,88 @@ func TestNewStringFunctions(t *testing.T) { {"endswith_true", "endswith", []interface{}{"hello world", "world"}, true, false}, {"endswith_false", "endswith", []interface{}{"hello world", "hello"}, false, false}, {"endswith_empty", "endswith", []interface{}{"hello", ""}, true, false}, - + // startswith tests {"startswith_true", "startswith", []interface{}{"hello world", "hello"}, true, false}, {"startswith_false", "startswith", []interface{}{"hello world", "world"}, false, false}, {"startswith_empty", "startswith", []interface{}{"hello", ""}, true, false}, - + // indexof tests {"indexof_found", "indexof", []interface{}{"hello world", "world"}, int64(6), false}, {"indexof_not_found", "indexof", []interface{}{"hello world", "xyz"}, int64(-1), false}, {"indexof_first_char", "indexof", []interface{}{"hello", "h"}, int64(0), false}, - + // substring tests {"substring_start_only", "substring", []interface{}{"hello world", int64(6)}, "world", false}, {"substring_start_length", "substring", []interface{}{"hello world", int64(0), int64(5)}, "hello", false}, {"substring_out_of_bounds", "substring", []interface{}{"hello", int64(10)}, "", false}, - + // replace tests {"replace_simple", "replace", []interface{}{"hello world", "world", "Go"}, "hello Go", false}, {"replace_multiple", "replace", []interface{}{"hello hello", "hello", "hi"}, "hi hi", false}, {"replace_not_found", "replace", []interface{}{"hello world", "xyz", "abc"}, "hello world", false}, - + // split tests {"split_comma", "split", []interface{}{"a,b,c", ","}, []string{"a", "b", "c"}, false}, {"split_space", "split", []interface{}{"hello world", " "}, []string{"hello", "world"}, false}, {"split_not_found", "split", []interface{}{"hello", ","}, []string{"hello"}, false}, - + // lpad tests {"lpad_default", "lpad", []interface{}{"hello", int64(10)}, " hello", false}, {"lpad_custom", "lpad", []interface{}{"hello", int64(8), "*"}, "***hello", false}, {"lpad_no_padding", "lpad", []interface{}{"hello", int64(3)}, "hello", false}, - + // rpad tests {"rpad_default", "rpad", []interface{}{"hello", int64(10)}, "hello ", false}, {"rpad_custom", "rpad", []interface{}{"hello", int64(8), "*"}, "hello***", false}, {"rpad_no_padding", "rpad", []interface{}{"hello", int64(3)}, "hello", false}, - + // ltrim tests {"ltrim_spaces", "ltrim", []interface{}{" hello world "}, "hello world ", false}, {"ltrim_tabs", "ltrim", []interface{}{"\t\nhello"}, "hello", false}, {"ltrim_no_whitespace", "ltrim", []interface{}{"hello"}, "hello", false}, - + // rtrim tests {"rtrim_spaces", "rtrim", []interface{}{" hello world "}, " hello world", false}, {"rtrim_tabs", "rtrim", []interface{}{"hello\t\n"}, "hello", false}, {"rtrim_no_whitespace", "rtrim", []interface{}{"hello"}, "hello", false}, - + // regexp_matches tests {"regexp_matches_true", "regexp_matches", []interface{}{"hello123", "[0-9]+"}, true, false}, {"regexp_matches_false", "regexp_matches", []interface{}{"hello", "[0-9]+"}, false, false}, {"regexp_matches_email", "regexp_matches", []interface{}{"test@example.com", "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"}, true, false}, - + // regexp_replace tests {"regexp_replace_digits", "regexp_replace", []interface{}{"hello123world456", "[0-9]+", "X"}, "helloXworldX", false}, {"regexp_replace_no_match", "regexp_replace", []interface{}{"hello", "[0-9]+", "X"}, "hello", false}, - + // regexp_substring tests {"regexp_substring_found", "regexp_substring", []interface{}{"hello123world", "[0-9]+"}, "123", false}, {"regexp_substring_not_found", "regexp_substring", []interface{}{"hello", "[0-9]+"}, "", false}, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fn, exists := Get(tt.funcName) if !exists { t.Fatalf("Function %s not found", tt.funcName) } - + ctx := &FunctionContext{} result, err := fn.Execute(ctx, tt.args) - + if tt.wantErr { if err == nil { t.Errorf("Expected error for %s, got nil", tt.name) } return } - + if err != nil { t.Errorf("Unexpected error for %s: %v", tt.name, err) return } - + // 特殊处理 split 函数的结果比较 if tt.funcName == "split" { expectedSlice, ok := tt.expected.([]string) @@ -127,4 +127,4 @@ func TestNewStringFunctions(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/functions/functions_type.go b/functions/functions_type.go index f1cd2c9..d9a6cb9 100644 --- a/functions/functions_type.go +++ b/functions/functions_type.go @@ -61,7 +61,7 @@ func (f *IsNumericFunction) Execute(ctx *FunctionContext, args []interface{}) (i if args[0] == nil { return false, nil } - + v := reflect.ValueOf(args[0]) switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, @@ -92,7 +92,7 @@ func (f *IsStringFunction) Execute(ctx *FunctionContext, args []interface{}) (in if args[0] == nil { return false, nil } - + _, ok := args[0].(string) return ok, nil } @@ -116,7 +116,7 @@ func (f *IsBoolFunction) Execute(ctx *FunctionContext, args []interface{}) (inte if args[0] == nil { return false, nil } - + _, ok := args[0].(bool) return ok, nil } @@ -140,7 +140,7 @@ func (f *IsArrayFunction) Execute(ctx *FunctionContext, args []interface{}) (int if args[0] == nil { return false, nil } - + v := reflect.ValueOf(args[0]) return v.Kind() == reflect.Slice || v.Kind() == reflect.Array, nil } @@ -164,7 +164,7 @@ func (f *IsObjectFunction) Execute(ctx *FunctionContext, args []interface{}) (in if args[0] == nil { return false, nil } - + v := reflect.ValueOf(args[0]) return v.Kind() == reflect.Map || v.Kind() == reflect.Struct, nil -} \ No newline at end of file +} diff --git a/functions/functions_type_test.go b/functions/functions_type_test.go new file mode 100644 index 0000000..31f0b04 --- /dev/null +++ b/functions/functions_type_test.go @@ -0,0 +1,95 @@ +package functions + +import ( + "testing" +) + +// 测试类型检查函数 +func TestTypeFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "is_null true", + funcName: "is_null", + args: []interface{}{nil}, + expected: true, + }, + { + name: "is_null false", + funcName: "is_null", + args: []interface{}{"test"}, + expected: false, + }, + { + name: "is_not_null true", + funcName: "is_not_null", + args: []interface{}{"test"}, + expected: true, + }, + { + name: "is_not_null false", + funcName: "is_not_null", + args: []interface{}{nil}, + expected: false, + }, + { + name: "is_numeric true", + funcName: "is_numeric", + args: []interface{}{123}, + expected: true, + }, + { + name: "is_numeric false", + funcName: "is_numeric", + args: []interface{}{"test"}, + expected: false, + }, + { + name: "is_string true", + funcName: "is_string", + args: []interface{}{"test"}, + expected: true, + }, + { + name: "is_string false", + funcName: "is_string", + args: []interface{}{123}, + expected: false, + }, + { + name: "is_bool true", + funcName: "is_bool", + args: []interface{}{true}, + expected: true, + }, + { + name: "is_bool false", + funcName: "is_bool", + args: []interface{}{"test"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_window.go b/functions/functions_window.go index c2067b5..c87b067 100644 --- a/functions/functions_window.go +++ b/functions/functions_window.go @@ -303,34 +303,69 @@ func (f *FirstValueFunction) Clone() AggregatorFunction { // LeadFunction 返回当前行之后第N行的值 type LeadFunction struct { *BaseFunction - values []interface{} + values []interface{} + offset int + defaultValue interface{} + hasDefault bool } func NewLeadFunction() *LeadFunction { return &LeadFunction{ BaseFunction: NewBaseFunction("lead", TypeWindow, "窗口函数", "返回当前行之后第N行的值", 1, 3), values: make([]interface{}, 0), + offset: 1, // 默认偏移量为1 } } func (f *LeadFunction) Validate(args []interface{}) error { - return f.ValidateArgCount(args) + if err := f.ValidateArgCount(args); err != nil { + return err + } + + // 验证第二个参数(offset)是否为整数 + if len(args) >= 2 { + if offset, ok := args[1].(int); ok { + f.offset = offset + } else { + return fmt.Errorf("offset must be an integer") + } + } + + // 设置默认值 + if len(args) >= 3 { + f.defaultValue = args[2] + f.hasDefault = true + } + + return nil } func (f *LeadFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { if err := f.Validate(args); err != nil { return nil, err } - - // 获取默认值 - var defaultValue interface{} - if len(args) >= 3 { - defaultValue = args[2] + + // 获取偏移量 + if len(args) >= 2 { + if offset, ok := args[1].(int); ok { + f.offset = offset + } else { + return nil, fmt.Errorf("offset must be an integer") + } } - + + // 获取默认值 + if len(args) >= 3 { + f.defaultValue = args[2] + f.hasDefault = true + } + // Lead函数需要在窗口处理完成后才能确定值 // 这里返回默认值,实际实现需要在窗口引擎中处理 - return defaultValue, nil + if f.hasDefault { + return f.defaultValue, nil + } + return nil, nil } // 实现AggregatorFunction接口 @@ -338,6 +373,9 @@ func (f *LeadFunction) New() AggregatorFunction { return &LeadFunction{ BaseFunction: f.BaseFunction, values: make([]interface{}, 0), + offset: f.offset, + defaultValue: f.defaultValue, + hasDefault: f.hasDefault, } } @@ -347,18 +385,28 @@ func (f *LeadFunction) Add(value interface{}) { func (f *LeadFunction) Result() interface{} { // Lead函数的结果需要在所有数据添加完成后计算 + // 如果没有足够的数据,返回默认值 + if len(f.values) == 0 && f.hasDefault { + return f.defaultValue + } // 这里简化实现,返回nil return nil } func (f *LeadFunction) Reset() { f.values = make([]interface{}, 0) + f.offset = 1 + f.defaultValue = nil + f.hasDefault = false } func (f *LeadFunction) Clone() AggregatorFunction { clone := &LeadFunction{ BaseFunction: f.BaseFunction, values: make([]interface{}, len(f.values)), + offset: f.offset, + defaultValue: f.defaultValue, + hasDefault: f.hasDefault, } copy(clone.values, f.values) return clone @@ -380,14 +428,35 @@ func NewNthValueFunction() *NthValueFunction { } func (f *NthValueFunction) Validate(args []interface{}) error { - return f.ValidateArgCount(args) + if err := f.ValidateArgCount(args); err != nil { + return err + } + + // 验证N值 + n := 1 + if nVal, ok := args[1].(int); ok { + n = nVal + } else if nVal, ok := args[1].(int64); ok { + n = int(nVal) + } else { + return fmt.Errorf("nth_value n must be an integer") + } + + if n <= 0 { + return fmt.Errorf("nth_value n must be positive, got %d", n) + } + + // 设置n值 + f.n = n + + return nil } func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { if err := f.Validate(args); err != nil { return nil, err } - + // 获取N值 n := 1 if nVal, ok := args[1].(int); ok { @@ -397,16 +466,16 @@ func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (in } else { return nil, fmt.Errorf("nth_value n must be an integer") } - + if n <= 0 { return nil, fmt.Errorf("nth_value n must be positive, got %d", n) } - + // 返回第N个值(1-based索引) if len(f.values) >= n { return f.values[n-1], nil } - + return nil, nil } diff --git a/functions/functions_window_test.go b/functions/functions_window_test.go index 531c32a..451e2e9 100644 --- a/functions/functions_window_test.go +++ b/functions/functions_window_test.go @@ -4,6 +4,20 @@ import ( "testing" ) +// isWindowFunction 判断是否为窗口函数 +func isWindowFunction(funcName string) bool { + windowFunctions := map[string]bool{ + "row_number": true, + "window_start": true, + "window_end": true, + "lead": true, + "lag": true, + "first_value": true, + "last_value": true, + } + return windowFunctions[funcName] +} + func TestNewWindowFunctions(t *testing.T) { tests := []struct { name string @@ -222,8 +236,19 @@ func TestNewWindowFunctions(t *testing.T) { tt.setup(aggInstance) } - // 执行函数 - _, err = fn.Execute(nil, tt.args) + // 对于窗口函数测试,不需要调用Execute方法 + // Execute方法主要用于流式处理,这里我们直接测试聚合器的Result方法 + // 如果需要测试Execute方法,应该在原始函数实例上调用 + if !isWindowFunction(tt.funcName) { + // 对于非窗口函数,在聚合器实例上执行 + if aggFunc, ok := aggInstance.(Function); ok { + _, err = aggFunc.Execute(nil, tt.args) + } else { + // 执行函数 + _, err = fn.Execute(nil, tt.args) + } + } + if (err != nil) != tt.wantErr { t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) return @@ -249,6 +274,11 @@ func TestWindowFunctionBasics(t *testing.T) { t.Fatal("row_number function not found") } + // 重置函数状态 + if rowNum, ok := rowNumFunc.(*RowNumberFunction); ok { + rowNum.Reset() + } + // 测试行号递增 result1, err := rowNumFunc.Execute(nil, []interface{}{}) if err != nil {