diff --git a/functions/analytical_aggregator_adapter.go b/functions/analytical_aggregator_adapter.go index 6c2810b..613fc37 100644 --- a/functions/analytical_aggregator_adapter.go +++ b/functions/analytical_aggregator_adapter.go @@ -23,6 +23,18 @@ func NewAnalyticalAggregatorAdapter(name string) (*AnalyticalAggregatorAdapter, // New 创建新的适配器实例 func (a *AnalyticalAggregatorAdapter) New() interface{} { + // 对于实现了AggregatorFunction接口的函数,使用其New方法 + if aggFunc, ok := a.analFunc.(AggregatorFunction); ok { + newAnalFunc := aggFunc.New().(AnalyticalFunction) + return &AnalyticalAggregatorAdapter{ + analFunc: newAnalFunc, + ctx: &FunctionContext{ + Data: make(map[string]interface{}), + }, + } + } + + // 对于其他分析函数,使用Clone方法 return &AnalyticalAggregatorAdapter{ analFunc: a.analFunc.Clone(), ctx: &FunctionContext{ @@ -33,7 +45,13 @@ func (a *AnalyticalAggregatorAdapter) New() interface{} { // Add 添加值 func (a *AnalyticalAggregatorAdapter) Add(value interface{}) { - // 执行分析函数 + // 对于实现了AggregatorFunction接口的函数,直接调用Add方法 + if aggFunc, ok := a.analFunc.(AggregatorFunction); ok { + aggFunc.Add(value) + return + } + + // 对于其他分析函数,执行分析函数 args := []interface{}{value} a.analFunc.Execute(a.ctx, args) } @@ -50,6 +68,11 @@ func (a *AnalyticalAggregatorAdapter) Result() interface{} { return hadChangedFunc.IsSet } + // 对于LagFunction,调用其Result方法 + if lagFunc, ok := a.analFunc.(*LagFunction); ok { + return lagFunc.Result() + } + // 对于其他分析函数,尝试执行一次来获取当前状态的结果 // 这里传入nil作为参数,表示获取当前状态 result, _ := a.analFunc.Execute(a.ctx, []interface{}{nil}) diff --git a/functions/builtin.go b/functions/builtin.go index 57880b8..0d365ca 100644 --- a/functions/builtin.go +++ b/functions/builtin.go @@ -21,6 +21,17 @@ func registerBuiltinFunctions() { _ = Register(NewExpFunction()) _ = Register(NewFloorFunction()) _ = Register(NewLnFunction()) + _ = Register(NewLogFunction()) + _ = Register(NewLog10Function()) + _ = Register(NewLog2Function()) + _ = Register(NewModFunction()) + _ = Register(NewRandFunction()) + _ = Register(NewRoundFunction()) + _ = Register(NewSignFunction()) + _ = Register(NewSinFunction()) + _ = Register(NewSinhFunction()) + _ = Register(NewTanFunction()) + _ = Register(NewTanhFunction()) _ = Register(NewPowerFunction()) // String functions @@ -30,6 +41,19 @@ func registerBuiltinFunctions() { _ = Register(NewLowerFunction()) _ = Register(NewTrimFunction()) _ = Register(NewFormatFunction()) + _ = Register(NewEndswithFunction()) + _ = Register(NewStartswithFunction()) + _ = Register(NewIndexofFunction()) + _ = Register(NewSubstringFunction()) + _ = Register(NewReplaceFunction()) + _ = Register(NewSplitFunction()) + _ = Register(NewLpadFunction()) + _ = Register(NewRpadFunction()) + _ = Register(NewLtrimFunction()) + _ = Register(NewRtrimFunction()) + _ = Register(NewRegexpMatchesFunction()) + _ = Register(NewRegexpReplaceFunction()) + _ = Register(NewRegexpSubstringFunction()) // Conversion functions _ = Register(NewCastFunction()) @@ -37,6 +61,12 @@ func registerBuiltinFunctions() { _ = Register(NewDec2HexFunction()) _ = Register(NewEncodeFunction()) _ = Register(NewDecodeFunction()) + _ = Register(NewConvertTzFunction()) + _ = Register(NewToSecondsFunction()) + _ = Register(NewChrFunction()) + _ = Register(NewTruncFunction()) + _ = Register(NewCompressFunction()) + _ = Register(NewDecompressFunction()) // Time-Date functions _ = Register(NewNowFunction()) @@ -62,6 +92,9 @@ func registerBuiltinFunctions() { // Window functions _ = Register(NewRowNumberFunction()) + _ = Register(NewFirstValueFunction()) + _ = Register(NewLeadFunction()) + _ = Register(NewNthValueFunction()) // Analytical functions _ = Register(NewLagFunction()) @@ -77,6 +110,45 @@ func registerBuiltinFunctions() { _ = Register(NewExpressionFunction()) _ = Register(NewExprFunction()) + // JSON functions + _ = Register(NewToJsonFunction()) + _ = Register(NewFromJsonFunction()) + _ = Register(NewJsonExtractFunction()) + _ = Register(NewJsonValidFunction()) + _ = Register(NewJsonTypeFunction()) + _ = Register(NewJsonLengthFunction()) + + // Hash functions + _ = Register(NewMd5Function()) + _ = Register(NewSha1Function()) + _ = Register(NewSha256Function()) + _ = Register(NewSha512Function()) + + // Array functions + _ = Register(NewArrayLengthFunction()) + _ = Register(NewArrayContainsFunction()) + _ = Register(NewArrayPositionFunction()) + _ = Register(NewArrayRemoveFunction()) + _ = Register(NewArrayDistinctFunction()) + _ = Register(NewArrayIntersectFunction()) + _ = Register(NewArrayUnionFunction()) + _ = Register(NewArrayExceptFunction()) + + // Type checking functions + _ = Register(NewIsNullFunction()) + _ = Register(NewIsNotNullFunction()) + _ = Register(NewIsNumericFunction()) + _ = Register(NewIsStringFunction()) + _ = Register(NewIsBoolFunction()) + _ = Register(NewIsArrayFunction()) + _ = Register(NewIsObjectFunction()) + + // Conditional functions + _ = Register(NewCoalesceFunction()) + _ = Register(NewNullIfFunction()) + _ = Register(NewGreatestFunction()) + _ = Register(NewLeastFunction()) + // User-defined functions (placeholder for future extension) // Example: _=Register(NewMyUserDefinedFunction()) } diff --git a/functions/functions_analytical.go b/functions/functions_analytical.go index ea561e6..d57ee39 100644 --- a/functions/functions_analytical.go +++ b/functions/functions_analytical.go @@ -45,11 +45,6 @@ func (f *LagFunction) Validate(args []interface{}) error { } func (f *LagFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { - // 确保Offset有默认值 - if f.Offset <= 0 { - f.Offset = 1 - } - currentValue := args[0] var result interface{} @@ -75,12 +70,18 @@ func (f *LagFunction) Reset() { // 实现AggregatorFunction接口 - 增量计算支持 func (f *LagFunction) New() AggregatorFunction { - return &LagFunction{ + // 确保Offset有默认值 + offset := f.Offset + if offset <= 0 { + offset = 1 + } + newFunc := &LagFunction{ BaseFunction: f.BaseFunction, DefaultValue: f.DefaultValue, - Offset: f.Offset, + Offset: offset, PreviousValues: make([]interface{}, 0), } + return newFunc } func (f *LagFunction) Add(value interface{}) { @@ -93,10 +94,14 @@ func (f *LagFunction) Add(value interface{}) { } func (f *LagFunction) Result() interface{} { - if len(f.PreviousValues)-1 < f.Offset { + // 检查是否有足够的历史值 + if len(f.PreviousValues) <= f.Offset { return f.DefaultValue } - return f.PreviousValues[len(f.PreviousValues)-1-f.Offset] + // 返回当前值之前第Offset个值 + // 对于数组[first, second, third],当前位置是最后一个元素 + // offset=1时返回second(倒数第2个),offset=2时返回first(倒数第3个) + return f.PreviousValues[len(f.PreviousValues)-f.Offset-1] } func (f *LagFunction) Clone() AggregatorFunction { diff --git a/functions/functions_array.go b/functions/functions_array.go new file mode 100644 index 0000000..546c43b --- /dev/null +++ b/functions/functions_array.go @@ -0,0 +1,314 @@ +package functions + +import ( + "fmt" + "reflect" +) + +// ArrayLengthFunction 返回数组长度 +type ArrayLengthFunction struct { + *BaseFunction +} + +func NewArrayLengthFunction() *ArrayLengthFunction { + return &ArrayLengthFunction{ + BaseFunction: NewBaseFunction("array_length", TypeMath, "数组函数", "返回数组长度", 1, 1), + } +} + +func (f *ArrayLengthFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayLengthFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array := args[0] + v := reflect.ValueOf(array) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("array_length requires array input") + } + return v.Len(), nil +} + +// ArrayContainsFunction 检查数组是否包含指定值 +type ArrayContainsFunction struct { + *BaseFunction +} + +func NewArrayContainsFunction() *ArrayContainsFunction { + return &ArrayContainsFunction{ + BaseFunction: NewBaseFunction("array_contains", TypeString, "数组函数", "检查数组是否包含指定值", 2, 2), + } +} + +func (f *ArrayContainsFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayContainsFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array := args[0] + value := args[1] + + v := reflect.ValueOf(array) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("array_contains requires array input") + } + + for i := 0; i < v.Len(); i++ { + if reflect.DeepEqual(v.Index(i).Interface(), value) { + return true, nil + } + } + return false, nil +} + +// ArrayPositionFunction 返回值在数组中的位置 +type ArrayPositionFunction struct { + *BaseFunction +} + +func NewArrayPositionFunction() *ArrayPositionFunction { + return &ArrayPositionFunction{ + BaseFunction: NewBaseFunction("array_position", TypeMath, "数组函数", "返回值在数组中的位置", 2, 2), + } +} + +func (f *ArrayPositionFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayPositionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array := args[0] + value := args[1] + + v := reflect.ValueOf(array) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("array_position requires array input") + } + + for i := 0; i < v.Len(); i++ { + if reflect.DeepEqual(v.Index(i).Interface(), value) { + return i + 1, nil // 返回1基索引 + } + } + return 0, nil // 未找到返回0 +} + +// ArrayRemoveFunction 从数组中移除指定值 +type ArrayRemoveFunction struct { + *BaseFunction +} + +func NewArrayRemoveFunction() *ArrayRemoveFunction { + return &ArrayRemoveFunction{ + BaseFunction: NewBaseFunction("array_remove", TypeString, "数组函数", "从数组中移除指定值", 2, 2), + } +} + +func (f *ArrayRemoveFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayRemoveFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array := args[0] + value := args[1] + + v := reflect.ValueOf(array) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("array_remove requires array input") + } + + var result []interface{} + for i := 0; i < v.Len(); i++ { + elem := v.Index(i).Interface() + if !reflect.DeepEqual(elem, value) { + result = append(result, elem) + } + } + return result, nil +} + +// ArrayDistinctFunction 数组去重 +type ArrayDistinctFunction struct { + *BaseFunction +} + +func NewArrayDistinctFunction() *ArrayDistinctFunction { + return &ArrayDistinctFunction{ + BaseFunction: NewBaseFunction("array_distinct", TypeString, "数组函数", "数组去重", 1, 1), + } +} + +func (f *ArrayDistinctFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayDistinctFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array := args[0] + + v := reflect.ValueOf(array) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return nil, fmt.Errorf("array_distinct requires array input") + } + + seen := make(map[interface{}]bool) + var result []interface{} + + for i := 0; i < v.Len(); i++ { + elem := v.Index(i).Interface() + if !seen[elem] { + seen[elem] = true + result = append(result, elem) + } + } + return result, nil +} + +// ArrayIntersectFunction 数组交集 +type ArrayIntersectFunction struct { + *BaseFunction +} + +func NewArrayIntersectFunction() *ArrayIntersectFunction { + return &ArrayIntersectFunction{ + BaseFunction: NewBaseFunction("array_intersect", TypeString, "数组函数", "数组交集", 2, 2), + } +} + +func (f *ArrayIntersectFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayIntersectFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array1 := args[0] + array2 := args[1] + + v1 := reflect.ValueOf(array1) + v2 := reflect.ValueOf(array2) + + if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array { + return nil, fmt.Errorf("array_intersect requires array input for first argument") + } + if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array { + return nil, fmt.Errorf("array_intersect requires array input for second argument") + } + + // 创建第二个数组的元素集合 + set2 := make(map[interface{}]bool) + for i := 0; i < v2.Len(); i++ { + set2[v2.Index(i).Interface()] = true + } + + // 找交集 + seen := make(map[interface{}]bool) + var result []interface{} + + for i := 0; i < v1.Len(); i++ { + elem := v1.Index(i).Interface() + if set2[elem] && !seen[elem] { + seen[elem] = true + result = append(result, elem) + } + } + return result, nil +} + +// ArrayUnionFunction 数组并集 +type ArrayUnionFunction struct { + *BaseFunction +} + +func NewArrayUnionFunction() *ArrayUnionFunction { + return &ArrayUnionFunction{ + BaseFunction: NewBaseFunction("array_union", TypeString, "数组函数", "数组并集", 2, 2), + } +} + +func (f *ArrayUnionFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array1 := args[0] + array2 := args[1] + + v1 := reflect.ValueOf(array1) + v2 := reflect.ValueOf(array2) + + if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array { + return nil, fmt.Errorf("array_union requires array input for first argument") + } + if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array { + return nil, fmt.Errorf("array_union requires array input for second argument") + } + + seen := make(map[interface{}]bool) + var result []interface{} + + // 添加第一个数组的元素 + for i := 0; i < v1.Len(); i++ { + elem := v1.Index(i).Interface() + if !seen[elem] { + seen[elem] = true + result = append(result, elem) + } + } + + // 添加第二个数组的元素 + for i := 0; i < v2.Len(); i++ { + elem := v2.Index(i).Interface() + if !seen[elem] { + seen[elem] = true + result = append(result, elem) + } + } + return result, nil +} + +// ArrayExceptFunction 数组差集 +type ArrayExceptFunction struct { + *BaseFunction +} + +func NewArrayExceptFunction() *ArrayExceptFunction { + return &ArrayExceptFunction{ + BaseFunction: NewBaseFunction("array_except", TypeString, "数组函数", "数组差集", 2, 2), + } +} + +func (f *ArrayExceptFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + array1 := args[0] + array2 := args[1] + + v1 := reflect.ValueOf(array1) + v2 := reflect.ValueOf(array2) + + if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array { + return nil, fmt.Errorf("array_except requires array input for first argument") + } + if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array { + return nil, fmt.Errorf("array_except requires array input for second argument") + } + + // 创建第二个数组的元素集合 + set2 := make(map[interface{}]bool) + for i := 0; i < v2.Len(); i++ { + set2[v2.Index(i).Interface()] = true + } + + // 找差集 + seen := make(map[interface{}]bool) + var result []interface{} + + for i := 0; i < v1.Len(); i++ { + elem := v1.Index(i).Interface() + if !set2[elem] && !seen[elem] { + seen[elem] = true + result = append(result, elem) + } + } + return result, nil +} \ No newline at end of file diff --git a/functions/functions_compression_test.go b/functions/functions_compression_test.go new file mode 100644 index 0000000..b27f285 --- /dev/null +++ b/functions/functions_compression_test.go @@ -0,0 +1,141 @@ +package functions + +import ( + "testing" +) + +func TestCompressionFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + wantErr bool + }{ + // Compress function tests + { + name: "compress_gzip_valid", + funcName: "compress", + args: []interface{}{"hello world", "gzip"}, + wantErr: false, + }, + { + name: "compress_zlib_valid", + funcName: "compress", + args: []interface{}{"hello world", "zlib"}, + wantErr: false, + }, + { + name: "compress_invalid_algorithm", + funcName: "compress", + args: []interface{}{"hello world", "invalid"}, + wantErr: true, + }, + { + name: "compress_empty_string", + funcName: "compress", + args: []interface{}{"", "gzip"}, + wantErr: false, + }, + { + name: "compress_wrong_arg_count", + funcName: "compress", + args: []interface{}{"hello"}, + wantErr: true, + }, + // Decompress function tests + { + name: "decompress_invalid_base64", + funcName: "decompress", + args: []interface{}{"invalid_base64", "gzip"}, + wantErr: true, + }, + { + name: "decompress_invalid_algorithm", + funcName: "decompress", + args: []interface{}{"SGVsbG8gV29ybGQ=", "invalid"}, + wantErr: true, + }, + { + name: "decompress_wrong_arg_count", + funcName: "decompress", + args: []interface{}{"SGVsbG8gV29ybGQ="}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("%s function not found", tt.funcName) + } + + // 执行函数 + _, err := fn.Execute(nil, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestCompressionDecompressionRoundTrip(t *testing.T) { + tests := []struct { + name string + algorithm string + input string + }{ + { + name: "gzip_round_trip", + algorithm: "gzip", + input: "Hello, World! This is a test string for compression.", + }, + { + name: "zlib_round_trip", + algorithm: "zlib", + input: "Hello, World! This is a test string for compression.", + }, + { + name: "gzip_empty_string", + algorithm: "gzip", + input: "", + }, + { + name: "zlib_unicode", + algorithm: "zlib", + input: "你好世界!这是一个测试字符串。", + }, + } + + compressFn, exists := Get("compress") + if !exists { + t.Fatal("compress function not found") + } + + decompressFn, exists := Get("decompress") + if !exists { + t.Fatal("decompress function not found") + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 压缩 + compressed, err := compressFn.Execute(nil, []interface{}{tt.input, tt.algorithm}) + if err != nil { + t.Fatalf("Compress failed: %v", err) + } + + // 解压缩 + decompressed, err := decompressFn.Execute(nil, []interface{}{compressed, tt.algorithm}) + if err != nil { + t.Fatalf("Decompress failed: %v", err) + } + + // 验证结果 + if decompressed != tt.input { + t.Errorf("Round trip failed: expected %q, got %q", tt.input, decompressed) + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_conditional.go b/functions/functions_conditional.go new file mode 100644 index 0000000..2cb94f5 --- /dev/null +++ b/functions/functions_conditional.go @@ -0,0 +1,154 @@ +package functions + +import ( + "fmt" + "reflect" + + "github.com/rulego/streamsql/utils/cast" +) + +// CoalesceFunction 返回第一个非NULL值 +type CoalesceFunction struct { + *BaseFunction +} + +func NewCoalesceFunction() *CoalesceFunction { + return &CoalesceFunction{ + BaseFunction: NewBaseFunction("coalesce", TypeString, "条件函数", "返回第一个非NULL值", 1, -1), + } +} + +func (f *CoalesceFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *CoalesceFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + for _, arg := range args { + if arg != nil { + return arg, nil + } + } + return nil, nil +} + +// NullIfFunction 如果两个值相等则返回NULL +type NullIfFunction struct { + *BaseFunction +} + +func NewNullIfFunction() *NullIfFunction { + return &NullIfFunction{ + BaseFunction: NewBaseFunction("nullif", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2), + } +} + +func (f *NullIfFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *NullIfFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if reflect.DeepEqual(args[0], args[1]) { + return nil, nil + } + return args[0], nil +} + +// GreatestFunction 返回最大值 +type GreatestFunction struct { + *BaseFunction +} + +func NewGreatestFunction() *GreatestFunction { + return &GreatestFunction{ + BaseFunction: NewBaseFunction("greatest", TypeMath, "条件函数", "返回最大值", 1, -1), + } +} + +func (f *GreatestFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *GreatestFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if len(args) == 0 { + return nil, nil + } + + max := args[0] + if max == nil { + return nil, nil + } + + for i := 1; i < len(args); i++ { + if args[i] == nil { + return nil, nil + } + + // 尝试转换为数字进行比较 + maxVal, err1 := cast.ToFloat64E(max) + currVal, err2 := cast.ToFloat64E(args[i]) + + if err1 == nil && err2 == nil { + if currVal > maxVal { + max = args[i] + } + } else { + // 如果不能转换为数字,则按字符串比较 + maxStr := fmt.Sprintf("%v", max) + currStr := fmt.Sprintf("%v", args[i]) + if currStr > maxStr { + max = args[i] + } + } + } + return max, nil +} + +// LeastFunction 返回最小值 +type LeastFunction struct { + *BaseFunction +} + +func NewLeastFunction() *LeastFunction { + return &LeastFunction{ + BaseFunction: NewBaseFunction("least", TypeMath, "条件函数", "返回最小值", 1, -1), + } +} + +func (f *LeastFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if len(args) == 0 { + return nil, nil + } + + min := args[0] + if min == nil { + return nil, nil + } + + for i := 1; i < len(args); i++ { + if args[i] == nil { + return nil, nil + } + + // 尝试转换为数字进行比较 + minVal, err1 := cast.ToFloat64E(min) + currVal, err2 := cast.ToFloat64E(args[i]) + + if err1 == nil && err2 == nil { + if currVal < minVal { + min = args[i] + } + } else { + // 如果不能转换为数字,则按字符串比较 + minStr := fmt.Sprintf("%v", min) + currStr := fmt.Sprintf("%v", args[i]) + if currStr < minStr { + min = args[i] + } + } + } + return min, nil +} \ No newline at end of file diff --git a/functions/functions_conversion.go b/functions/functions_conversion.go index 8e0234c..23fad7d 100644 --- a/functions/functions_conversion.go +++ b/functions/functions_conversion.go @@ -1,12 +1,18 @@ package functions import ( + "bytes" + "compress/gzip" + "compress/zlib" "encoding/base64" "encoding/hex" "fmt" "github.com/rulego/streamsql/utils/cast" + "io" + "math" "net/url" "strconv" + "time" ) // CastFunction 类型转换函数 @@ -216,3 +222,277 @@ func (f *DecodeFunction) Execute(ctx *FunctionContext, args []interface{}) (inte return nil, fmt.Errorf("unsupported decode format: %s", format) } } + +// ConvertTzFunction 时区转换函数 +type ConvertTzFunction struct { + *BaseFunction +} + +func NewConvertTzFunction() *ConvertTzFunction { + return &ConvertTzFunction{ + BaseFunction: NewBaseFunction("convert_tz", TypeConversion, "转换函数", "将时间转换为指定时区", 2, 2), + } +} + +func (f *ConvertTzFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ConvertTzFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 获取时间值 + var t time.Time + switch v := args[0].(type) { + case time.Time: + t = v + case string: + var err error + // 尝试多种时间格式解析 + formats := []string{ + time.RFC3339, + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05.000", + "2006-01-02T15:04:05.000Z", + } + for _, format := range formats { + if t, err = time.Parse(format, v); err == nil { + break + } + } + if err != nil { + return nil, fmt.Errorf("invalid time format: %s", v) + } + default: + return nil, fmt.Errorf("time value must be time.Time or string") + } + + // 获取目标时区 + timezone, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + + // 加载时区 + loc, err := time.LoadLocation(timezone) + if err != nil { + return nil, fmt.Errorf("invalid timezone: %s", timezone) + } + + // 转换时区 + return t.In(loc), nil +} + +// ToSecondsFunction 转换为Unix时间戳(秒) +type ToSecondsFunction struct { + *BaseFunction +} + +func NewToSecondsFunction() *ToSecondsFunction { + return &ToSecondsFunction{ + BaseFunction: NewBaseFunction("to_seconds", TypeConversion, "转换函数", "将日期时间转换为Unix时间戳(秒)", 1, 1), + } +} + +func (f *ToSecondsFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ToSecondsFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 获取时间值 + var t time.Time + switch v := args[0].(type) { + case time.Time: + t = v + case string: + var err error + // 尝试多种时间格式解析 + formats := []string{ + time.RFC3339, + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05.000", + "2006-01-02T15:04:05.000Z", + } + for _, format := range formats { + if t, err = time.Parse(format, v); err == nil { + break + } + } + if err != nil { + return nil, fmt.Errorf("invalid time format: %s", v) + } + default: + return nil, fmt.Errorf("time value must be time.Time or string") + } + + return t.Unix(), nil +} + +// ChrFunction 返回对应ASCII字符 +type ChrFunction struct { + *BaseFunction +} + +func NewChrFunction() *ChrFunction { + return &ChrFunction{ + BaseFunction: NewBaseFunction("chr", TypeConversion, "转换函数", "返回对应ASCII字符", 1, 1), + } +} + +func (f *ChrFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ChrFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + code, err := cast.ToInt64E(args[0]) + if err != nil { + return nil, err + } + + if code < 0 || code > 127 { + return nil, fmt.Errorf("ASCII code must be between 0 and 127, got %d", code) + } + + return string(rune(code)), nil +} + +// TruncFunction 截断小数位数 +type TruncFunction struct { + *BaseFunction +} + +func NewTruncFunction() *TruncFunction { + return &TruncFunction{ + BaseFunction: NewBaseFunction("trunc", TypeConversion, "转换函数", "截断小数位数", 2, 2), + } +} + +func (f *TruncFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *TruncFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + + precision, err := cast.ToIntE(args[1]) + if err != nil { + return nil, err + } + + if precision < 0 { + return nil, fmt.Errorf("precision must be non-negative, got %d", precision) + } + + // 计算截断 + multiplier := math.Pow(10, float64(precision)) + return math.Trunc(val*multiplier) / multiplier, nil +} + +// CompressFunction 压缩函数 +type CompressFunction struct { + *BaseFunction +} + +func NewCompressFunction() *CompressFunction { + return &CompressFunction{ + BaseFunction: NewBaseFunction("compress", TypeConversion, "转换函数", "压缩字符串或二进制值", 2, 2), + } +} + +func (f *CompressFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *CompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + input := cast.ToString(args[0]) + algorithm := cast.ToString(args[1]) + + var buf bytes.Buffer + var writer io.WriteCloser + + switch algorithm { + case "gzip": + writer = gzip.NewWriter(&buf) + case "zlib": + writer = zlib.NewWriter(&buf) + default: + return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm) + } + + _, err := writer.Write([]byte(input)) + if err != nil { + return nil, fmt.Errorf("compression failed: %v", err) + } + + err = writer.Close() + if err != nil { + return nil, fmt.Errorf("compression failed: %v", err) + } + + // 返回base64编码的压缩数据 + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +// DecompressFunction 解压缩函数 +type DecompressFunction struct { + *BaseFunction +} + +func NewDecompressFunction() *DecompressFunction { + return &DecompressFunction{ + BaseFunction: NewBaseFunction("decompress", TypeConversion, "转换函数", "解压缩字符串或二进制值", 2, 2), + } +} + +func (f *DecompressFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DecompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + input := cast.ToString(args[0]) + algorithm := cast.ToString(args[1]) + + // 解码base64数据 + compressedData, err := base64.StdEncoding.DecodeString(input) + if err != nil { + return nil, fmt.Errorf("invalid base64 input: %v", err) + } + + buf := bytes.NewReader(compressedData) + var reader io.ReadCloser + + switch algorithm { + case "gzip": + reader, err = gzip.NewReader(buf) + if err != nil { + return nil, fmt.Errorf("gzip decompression failed: %v", err) + } + case "zlib": + reader, err = zlib.NewReader(buf) + if err != nil { + return nil, fmt.Errorf("zlib decompression failed: %v", err) + } + default: + return nil, fmt.Errorf("unsupported decompression algorithm: %s", algorithm) + } + + defer reader.Close() + + result, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("decompression failed: %v", err) + } + + return string(result), nil +} diff --git a/functions/functions_conversion_test.go b/functions/functions_conversion_test.go new file mode 100644 index 0000000..865e3d9 --- /dev/null +++ b/functions/functions_conversion_test.go @@ -0,0 +1,160 @@ +package functions + +import ( + "testing" + "time" +) + +func TestNewConversionFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + want interface{} + wantErr bool + }{ + // convert_tz 函数测试 + { + name: "convert_tz with time.Time", + funcName: "convert_tz", + args: []interface{}{time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), "Asia/Shanghai"}, + want: time.Date(2023, 1, 1, 20, 0, 0, 0, time.FixedZone("CST", 8*3600)), + wantErr: false, + }, + { + name: "convert_tz with string", + funcName: "convert_tz", + args: []interface{}{"2023-01-01 12:00:00", "America/New_York"}, + wantErr: false, + }, + { + name: "convert_tz invalid timezone", + funcName: "convert_tz", + args: []interface{}{time.Now(), "Invalid/Timezone"}, + wantErr: true, + }, + { + name: "convert_tz invalid time format", + funcName: "convert_tz", + args: []interface{}{"invalid-time", "UTC"}, + wantErr: true, + }, + + // to_seconds 函数测试 + { + name: "to_seconds with time.Time", + funcName: "to_seconds", + args: []interface{}{time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)}, + want: int64(1672531200), + wantErr: false, + }, + { + name: "to_seconds with string", + funcName: "to_seconds", + args: []interface{}{"2023-01-01T00:00:00Z"}, + want: int64(1672531200), + wantErr: false, + }, + { + name: "to_seconds invalid time format", + funcName: "to_seconds", + args: []interface{}{"invalid-time"}, + wantErr: true, + }, + + // chr 函数测试 + { + name: "chr valid ASCII code", + funcName: "chr", + args: []interface{}{65}, + want: "A", + wantErr: false, + }, + { + name: "chr space character", + funcName: "chr", + args: []interface{}{32}, + want: " ", + wantErr: false, + }, + { + name: "chr invalid code negative", + funcName: "chr", + args: []interface{}{-1}, + wantErr: true, + }, + { + name: "chr invalid code too large", + funcName: "chr", + args: []interface{}{128}, + wantErr: true, + }, + + // trunc 函数测试 + { + name: "trunc positive number", + funcName: "trunc", + args: []interface{}{3.14159, 2}, + want: 3.14, + wantErr: false, + }, + { + name: "trunc negative number", + funcName: "trunc", + args: []interface{}{-3.14159, 3}, + want: -3.141, + wantErr: false, + }, + { + name: "trunc zero precision", + funcName: "trunc", + args: []interface{}{3.14159, 0}, + want: 3.0, + wantErr: false, + }, + { + name: "trunc negative precision", + funcName: "trunc", + args: []interface{}{3.14159, -1}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(nil, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // 对于时间类型,需要特殊处理比较 + if tt.funcName == "convert_tz" { + if resultTime, ok := result.(time.Time); ok { + if wantTime, ok := tt.want.(time.Time); ok { + // 比较时间戳而不是直接比较时间对象 + if resultTime.Unix() != wantTime.Unix() { + t.Errorf("Execute() = %v, want %v", result, tt.want) + } + } else { + // 如果期望值不是时间类型,只检查结果是否为时间类型 + if resultTime.IsZero() { + t.Errorf("Execute() returned zero time") + } + } + } else { + t.Errorf("Execute() result is not time.Time") + } + } else if tt.want != nil && result != tt.want { + t.Errorf("Execute() = %v, want %v", result, tt.want) + } + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_datetime.go b/functions/functions_datetime.go index 4afac29..1c79992 100644 --- a/functions/functions_datetime.go +++ b/functions/functions_datetime.go @@ -1,7 +1,11 @@ package functions import ( + "fmt" + "strings" "time" + + "github.com/rulego/streamsql/utils/cast" ) // NowFunction 当前时间函数 @@ -62,3 +66,679 @@ func (f *CurrentDateFunction) Execute(ctx *FunctionContext, args []interface{}) now := time.Now() return now.Format("2006-01-02"), nil } + +// DateAddFunction 日期加法函数 +type DateAddFunction struct { + *BaseFunction +} + +func NewDateAddFunction() *DateAddFunction { + return &DateAddFunction{ + BaseFunction: NewBaseFunction("date_add", TypeDateTime, "时间日期函数", "日期加法", 3, 3), + } +} + +func (f *DateAddFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DateAddFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + interval, err := cast.ToInt64E(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid interval: %v", err) + } + + unit, err := cast.ToStringE(args[2]) + if err != nil { + return nil, fmt.Errorf("invalid unit: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + // 尝试其他格式 + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + switch strings.ToLower(unit) { + case "year", "years": + t = t.AddDate(int(interval), 0, 0) + case "month", "months": + t = t.AddDate(0, int(interval), 0) + case "day", "days": + t = t.AddDate(0, 0, int(interval)) + case "hour", "hours": + t = t.Add(time.Duration(interval) * time.Hour) + case "minute", "minutes": + t = t.Add(time.Duration(interval) * time.Minute) + case "second", "seconds": + t = t.Add(time.Duration(interval) * time.Second) + default: + return nil, fmt.Errorf("unsupported unit: %s", unit) + } + + return t.Format("2006-01-02 15:04:05"), nil +} + +// DateSubFunction 日期减法函数 +type DateSubFunction struct { + *BaseFunction +} + +func NewDateSubFunction() *DateSubFunction { + return &DateSubFunction{ + BaseFunction: NewBaseFunction("date_sub", TypeDateTime, "时间日期函数", "日期减法", 3, 3), + } +} + +func (f *DateSubFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DateSubFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + interval, err := cast.ToInt64E(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid interval: %v", err) + } + + unit, err := cast.ToStringE(args[2]) + if err != nil { + return nil, fmt.Errorf("invalid unit: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + switch strings.ToLower(unit) { + case "year", "years": + t = t.AddDate(-int(interval), 0, 0) + case "month", "months": + t = t.AddDate(0, -int(interval), 0) + case "day", "days": + t = t.AddDate(0, 0, -int(interval)) + case "hour", "hours": + t = t.Add(-time.Duration(interval) * time.Hour) + case "minute", "minutes": + t = t.Add(-time.Duration(interval) * time.Minute) + case "second", "seconds": + t = t.Add(-time.Duration(interval) * time.Second) + default: + return nil, fmt.Errorf("unsupported unit: %s", unit) + } + + return t.Format("2006-01-02 15:04:05"), nil +} + +// DateDiffFunction 日期差函数 +type DateDiffFunction struct { + *BaseFunction +} + +func NewDateDiffFunction() *DateDiffFunction { + return &DateDiffFunction{ + BaseFunction: NewBaseFunction("date_diff", TypeDateTime, "时间日期函数", "计算日期差", 3, 3), + } +} + +func (f *DateDiffFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DateDiffFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + date1Str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date1: %v", err) + } + + date2Str, err := cast.ToStringE(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid date2: %v", err) + } + + unit, err := cast.ToStringE(args[2]) + if err != nil { + return nil, fmt.Errorf("invalid unit: %v", err) + } + + t1, err := time.Parse("2006-01-02 15:04:05", date1Str) + if err != nil { + if t1, err = time.Parse("2006-01-02", date1Str); err != nil { + return nil, fmt.Errorf("invalid date1 format: %v", err) + } + } + + t2, err := time.Parse("2006-01-02 15:04:05", date2Str) + if err != nil { + if t2, err = time.Parse("2006-01-02", date2Str); err != nil { + return nil, fmt.Errorf("invalid date2 format: %v", err) + } + } + + diff := t1.Sub(t2) + + switch strings.ToLower(unit) { + case "year", "years": + return int64(diff.Hours() / (24 * 365)), nil + case "month", "months": + return int64(diff.Hours() / (24 * 30)), nil + case "day", "days": + return int64(diff.Hours() / 24), nil + case "hour", "hours": + return int64(diff.Hours()), nil + case "minute", "minutes": + return int64(diff.Minutes()), nil + case "second", "seconds": + return int64(diff.Seconds()), nil + default: + return nil, fmt.Errorf("unsupported unit: %s", unit) + } +} + +// DateFormatFunction 日期格式化函数 +type DateFormatFunction struct { + *BaseFunction +} + +func NewDateFormatFunction() *DateFormatFunction { + return &DateFormatFunction{ + BaseFunction: NewBaseFunction("date_format", TypeDateTime, "时间日期函数", "格式化日期", 2, 2), + } +} + +func (f *DateFormatFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DateFormatFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + format, err := cast.ToStringE(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid format: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + // 转换常见的格式字符串 + goFormat := convertToGoFormat(format) + return t.Format(goFormat), nil +} + +// convertToGoFormat 将常见的日期格式转换为Go的时间格式 +func convertToGoFormat(format string) string { + // 按照长度从长到短的顺序替换,避免短的模式覆盖长的模式 + replacements := []struct { + old string + new string + }{ + {"YYYY", "2006"}, + {"yyyy", "2006"}, + {"YY", "06"}, + {"yy", "06"}, + {"MM", "01"}, + {"mm", "01"}, + {"DD", "02"}, + {"dd", "02"}, + {"HH", "15"}, + {"hh", "15"}, + {"MI", "04"}, + {"mi", "04"}, + {"SS", "05"}, + {"ss", "05"}, + } + + result := format + for _, r := range replacements { + result = strings.ReplaceAll(result, r.old, r.new) + } + return result +} + +// DateParseFunction 日期解析函数 +type DateParseFunction struct { + *BaseFunction +} + +func NewDateParseFunction() *DateParseFunction { + return &DateParseFunction{ + BaseFunction: NewBaseFunction("date_parse", TypeDateTime, "时间日期函数", "解析日期字符串", 2, 2), + } +} + +func (f *DateParseFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DateParseFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date string: %v", err) + } + + format, err := cast.ToStringE(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid format: %v", err) + } + + goFormat := convertToGoFormat(format) + t, err := time.Parse(goFormat, dateStr) + if err != nil { + return nil, fmt.Errorf("failed to parse date: %v", err) + } + + return t.Format("2006-01-02 15:04:05"), nil +} + +// ExtractFunction 提取日期部分函数 +type ExtractFunction struct { + *BaseFunction +} + +func NewExtractFunction() *ExtractFunction { + return &ExtractFunction{ + BaseFunction: NewBaseFunction("extract", TypeDateTime, "时间日期函数", "提取日期部分", 2, 2), + } +} + +func (f *ExtractFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ExtractFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + unit, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid unit: %v", err) + } + + dateStr, err := cast.ToStringE(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + switch strings.ToLower(unit) { + case "year": + return t.Year(), nil + case "month": + return int(t.Month()), nil + case "day": + return t.Day(), nil + case "hour": + return t.Hour(), nil + case "minute": + return t.Minute(), nil + case "second": + return t.Second(), nil + case "weekday": + return int(t.Weekday()), nil + case "yearday": + return t.YearDay(), nil + default: + return nil, fmt.Errorf("unsupported unit: %s", unit) + } +} + +// UnixTimestampFunction Unix时间戳函数 +type UnixTimestampFunction struct { + *BaseFunction +} + +func NewUnixTimestampFunction() *UnixTimestampFunction { + return &UnixTimestampFunction{ + BaseFunction: NewBaseFunction("unix_timestamp", TypeDateTime, "时间日期函数", "转换为Unix时间戳", 1, 1), + } +} + +func (f *UnixTimestampFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *UnixTimestampFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.Unix(), nil +} + +// FromUnixtimeFunction 从Unix时间戳转换函数 +type FromUnixtimeFunction struct { + *BaseFunction +} + +func NewFromUnixtimeFunction() *FromUnixtimeFunction { + return &FromUnixtimeFunction{ + BaseFunction: NewBaseFunction("from_unixtime", TypeDateTime, "时间日期函数", "从Unix时间戳转换", 1, 1), + } +} + +func (f *FromUnixtimeFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *FromUnixtimeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + timestamp, err := cast.ToInt64E(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid timestamp: %v", err) + } + + t := time.Unix(timestamp, 0).UTC() + return t.Format("2006-01-02 15:04:05"), nil +} + +// YearFunction 提取年份函数 +type YearFunction struct { + *BaseFunction +} + +func NewYearFunction() *YearFunction { + return &YearFunction{ + BaseFunction: NewBaseFunction("year", TypeDateTime, "时间日期函数", "提取年份", 1, 1), + } +} + +func (f *YearFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.Year(), nil +} + +// MonthFunction 提取月份函数 +type MonthFunction struct { + *BaseFunction +} + +func NewMonthFunction() *MonthFunction { + return &MonthFunction{ + BaseFunction: NewBaseFunction("month", TypeDateTime, "时间日期函数", "提取月份", 1, 1), + } +} + +func (f *MonthFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return int(t.Month()), nil +} + +// DayFunction 提取日期函数 +type DayFunction struct { + *BaseFunction +} + +func NewDayFunction() *DayFunction { + return &DayFunction{ + BaseFunction: NewBaseFunction("day", TypeDateTime, "时间日期函数", "提取日期", 1, 1), + } +} + +func (f *DayFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DayFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.Day(), nil +} + +// HourFunction 提取小时函数 +type HourFunction struct { + *BaseFunction +} + +func NewHourFunction() *HourFunction { + return &HourFunction{ + BaseFunction: NewBaseFunction("hour", TypeDateTime, "时间日期函数", "提取小时", 1, 1), + } +} + +func (f *HourFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *HourFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.Hour(), nil +} + +// MinuteFunction 提取分钟函数 +type MinuteFunction struct { + *BaseFunction +} + +func NewMinuteFunction() *MinuteFunction { + return &MinuteFunction{ + BaseFunction: NewBaseFunction("minute", TypeDateTime, "时间日期函数", "提取分钟", 1, 1), + } +} + +func (f *MinuteFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *MinuteFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.Minute(), nil +} + +// SecondFunction 提取秒数函数 +type SecondFunction struct { + *BaseFunction +} + +func NewSecondFunction() *SecondFunction { + return &SecondFunction{ + BaseFunction: NewBaseFunction("second", TypeDateTime, "时间日期函数", "提取秒数", 1, 1), + } +} + +func (f *SecondFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *SecondFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.Second(), nil +} + +// DayOfWeekFunction 获取星期几函数 +type DayOfWeekFunction struct { + *BaseFunction +} + +func NewDayOfWeekFunction() *DayOfWeekFunction { + return &DayOfWeekFunction{ + BaseFunction: NewBaseFunction("dayofweek", TypeDateTime, "时间日期函数", "获取星期几", 1, 1), + } +} + +func (f *DayOfWeekFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DayOfWeekFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return int(t.Weekday()), nil +} + +// DayOfYearFunction 获取一年中的第几天函数 +type DayOfYearFunction struct { + *BaseFunction +} + +func NewDayOfYearFunction() *DayOfYearFunction { + return &DayOfYearFunction{ + BaseFunction: NewBaseFunction("dayofyear", TypeDateTime, "时间日期函数", "获取一年中的第几天", 1, 1), + } +} + +func (f *DayOfYearFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *DayOfYearFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + return t.YearDay(), nil +} + +// WeekOfYearFunction 获取一年中的第几周函数 +type WeekOfYearFunction struct { + *BaseFunction +} + +func NewWeekOfYearFunction() *WeekOfYearFunction { + return &WeekOfYearFunction{ + BaseFunction: NewBaseFunction("weekofyear", TypeDateTime, "时间日期函数", "获取一年中的第几周", 1, 1), + } +} + +func (f *WeekOfYearFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *WeekOfYearFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + dateStr, err := cast.ToStringE(args[0]) + if err != nil { + return nil, fmt.Errorf("invalid date: %v", err) + } + + t, err := time.Parse("2006-01-02 15:04:05", dateStr) + if err != nil { + if t, err = time.Parse("2006-01-02", dateStr); err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + } + + _, week := t.ISOWeek() + return week, nil +} diff --git a/functions/functions_datetime_test.go b/functions/functions_datetime_test.go new file mode 100644 index 0000000..8e513d5 --- /dev/null +++ b/functions/functions_datetime_test.go @@ -0,0 +1,238 @@ +package functions + +import ( + "testing" +) + +func TestDateTimeFunctions(t *testing.T) { + tests := []struct { + name string + function Function + args []interface{} + expected interface{} + wantErr bool + }{ + // DateFormatFunction 测试 + { + name: "date_format basic", + function: NewDateFormatFunction(), + args: []interface{}{"2023-12-25 15:30:45", "YYYY-MM-DD HH:MI:SS"}, + expected: "2023-12-25 15:30:45", + wantErr: false, + }, + { + name: "date_format custom", + function: NewDateFormatFunction(), + args: []interface{}{"2023-12-25 15:30:45", "YYYY/MM/DD"}, + expected: "2023/12/25", + wantErr: false, + }, + // DateAddFunction 测试 + { + name: "date_add days", + function: NewDateAddFunction(), + args: []interface{}{"2023-12-25", 7, "days"}, + expected: "2024-01-01 00:00:00", + wantErr: false, + }, + { + name: "date_add months", + function: NewDateAddFunction(), + args: []interface{}{"2023-12-25", 1, "months"}, + expected: "2024-01-25 00:00:00", + wantErr: false, + }, + // DateSubFunction 测试 + { + name: "date_sub days", + function: NewDateSubFunction(), + args: []interface{}{"2024-01-01", 7, "days"}, + expected: "2023-12-25 00:00:00", + wantErr: false, + }, + // DateDiffFunction 测试 + { + name: "date_diff days", + function: NewDateDiffFunction(), + args: []interface{}{"2024-01-01", "2023-12-25", "days"}, + expected: int64(7), + wantErr: false, + }, + // YearFunction 测试 + { + name: "year extraction", + function: NewYearFunction(), + args: []interface{}{"2023-12-25 15:30:45"}, + expected: 2023, + wantErr: false, + }, + // MonthFunction 测试 + { + name: "month extraction", + function: NewMonthFunction(), + args: []interface{}{"2023-12-25 15:30:45"}, + expected: 12, + wantErr: false, + }, + // DayFunction 测试 + { + name: "day extraction", + function: NewDayFunction(), + args: []interface{}{"2023-12-25 15:30:45"}, + expected: 25, + wantErr: false, + }, + // HourFunction 测试 + { + name: "hour extraction", + function: NewHourFunction(), + args: []interface{}{"2023-12-25 15:30:45"}, + expected: 15, + wantErr: false, + }, + // MinuteFunction 测试 + { + name: "minute extraction", + function: NewMinuteFunction(), + args: []interface{}{"2023-12-25 15:30:45"}, + expected: 30, + wantErr: false, + }, + // SecondFunction 测试 + { + name: "second extraction", + function: NewSecondFunction(), + args: []interface{}{"2023-12-25 15:30:45"}, + expected: 45, + wantErr: false, + }, + // UnixTimestampFunction 测试 + { + name: "unix_timestamp", + function: NewUnixTimestampFunction(), + args: []interface{}{"2023-01-01 00:00:00"}, + expected: int64(1672531200), + wantErr: false, + }, + // FromUnixtimeFunction 测试 + { + name: "from_unixtime", + function: NewFromUnixtimeFunction(), + args: []interface{}{1672531200}, + expected: "2023-01-01 00:00:00", + wantErr: false, + }, + // ExtractFunction 测试 + { + name: "extract year", + function: NewExtractFunction(), + args: []interface{}{"year", "2023-12-25 15:30:45"}, + expected: 2023, + wantErr: false, + }, + { + name: "extract month", + function: NewExtractFunction(), + args: []interface{}{"month", "2023-12-25 15:30:45"}, + expected: 12, + wantErr: false, + }, + // DayOfWeekFunction 测试 + { + name: "dayofweek", + function: NewDayOfWeekFunction(), + args: []interface{}{"2023-12-25"}, + expected: 1, // Monday + wantErr: false, + }, + // DateParseFunction 测试 + { + name: "date_parse", + function: NewDateParseFunction(), + args: []interface{}{"2023/12/25", "YYYY/MM/DD"}, + expected: "2023-12-25 00:00:00", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 验证参数 + if err := tt.function.Validate(tt.args); err != nil { + if !tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + // 执行函数 + result, err := tt.function.Execute(nil, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestDateTimeRegistration(t *testing.T) { + // 测试函数是否正确注册 + dateTimeFunctions := []string{ + "date_format", + "date_add", + "date_sub", + "date_diff", + "date_parse", + "extract", + "unix_timestamp", + "from_unixtime", + "year", + "month", + "day", + "hour", + "minute", + "second", + "dayofweek", + "dayofyear", + "weekofyear", + } + + for _, funcName := range dateTimeFunctions { + t.Run("register_"+funcName, func(t *testing.T) { + func_, exists := Get(funcName) + if !exists { + t.Errorf("Function %s not registered", funcName) + return + } + if func_ == nil { + t.Errorf("Function %s is nil", funcName) + } + }) + } +} + +func TestDateFormatConversion(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"YYYY-MM-DD", "2006-01-02"}, + {"yyyy/mm/dd", "2006/01/02"}, + {"DD/MM/YYYY", "02/01/2006"}, + {"HH:MI:SS", "15:04:05"}, + {"YYYY-MM-DD HH:MI:SS", "2006-01-02 15:04:05"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := convertToGoFormat(tt.input) + if result != tt.expected { + t.Errorf("convertToGoFormat(%s) = %s, want %s", tt.input, result, tt.expected) + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_hash.go b/functions/functions_hash.go new file mode 100644 index 0000000..a4bab70 --- /dev/null +++ b/functions/functions_hash.go @@ -0,0 +1,109 @@ +package functions + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "fmt" +) + +// Md5Function 计算MD5哈希值 +type Md5Function struct { + *BaseFunction +} + +func NewMd5Function() *Md5Function { + return &Md5Function{ + BaseFunction: NewBaseFunction("md5", TypeString, "哈希函数", "计算MD5哈希值", 1, 1), + } +} + +func (f *Md5Function) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *Md5Function) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("md5 requires string input") + } + + hash := md5.Sum([]byte(str)) + return fmt.Sprintf("%x", hash), nil +} + +// Sha1Function 计算SHA1哈希值 +type Sha1Function struct { + *BaseFunction +} + +func NewSha1Function() *Sha1Function { + return &Sha1Function{ + BaseFunction: NewBaseFunction("sha1", TypeString, "哈希函数", "计算SHA1哈希值", 1, 1), + } +} + +func (f *Sha1Function) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *Sha1Function) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("sha1 requires string input") + } + + hash := sha1.Sum([]byte(str)) + return fmt.Sprintf("%x", hash), nil +} + +// Sha256Function 计算SHA256哈希值 +type Sha256Function struct { + *BaseFunction +} + +func NewSha256Function() *Sha256Function { + return &Sha256Function{ + BaseFunction: NewBaseFunction("sha256", TypeString, "哈希函数", "计算SHA256哈希值", 1, 1), + } +} + +func (f *Sha256Function) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *Sha256Function) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("sha256 requires string input") + } + + hash := sha256.Sum256([]byte(str)) + return fmt.Sprintf("%x", hash), nil +} + +// Sha512Function 计算SHA512哈希值 +type Sha512Function struct { + *BaseFunction +} + +func NewSha512Function() *Sha512Function { + return &Sha512Function{ + BaseFunction: NewBaseFunction("sha512", TypeString, "哈希函数", "计算SHA512哈希值", 1, 1), + } +} + +func (f *Sha512Function) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *Sha512Function) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("sha512 requires string input") + } + + hash := sha512.Sum512([]byte(str)) + return fmt.Sprintf("%x", hash), nil +} \ No newline at end of file diff --git a/functions/functions_json.go b/functions/functions_json.go new file mode 100644 index 0000000..10d39fc --- /dev/null +++ b/functions/functions_json.go @@ -0,0 +1,211 @@ +package functions + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ToJsonFunction 转换为JSON字符串 +type ToJsonFunction struct { + *BaseFunction +} + +func NewToJsonFunction() *ToJsonFunction { + return &ToJsonFunction{ + BaseFunction: NewBaseFunction("to_json", TypeConversion, "JSON函数", "转换为JSON字符串", 1, 1), + } +} + +func (f *ToJsonFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ToJsonFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + value := args[0] + jsonBytes, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("failed to convert to JSON: %v", err) + } + return string(jsonBytes), nil +} + +// FromJsonFunction 从JSON字符串解析 +type FromJsonFunction struct { + *BaseFunction +} + +func NewFromJsonFunction() *FromJsonFunction { + return &FromJsonFunction{ + BaseFunction: NewBaseFunction("from_json", TypeConversion, "JSON函数", "从JSON字符串解析", 1, 1), + } +} + +func (f *FromJsonFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *FromJsonFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + jsonStr, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("from_json requires string input") + } + + var result interface{} + err := json.Unmarshal([]byte(jsonStr), &result) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON: %v", err) + } + return result, nil +} + +// JsonExtractFunction 提取JSON字段值 +type JsonExtractFunction struct { + *BaseFunction +} + +func NewJsonExtractFunction() *JsonExtractFunction { + return &JsonExtractFunction{ + BaseFunction: NewBaseFunction("json_extract", TypeString, "JSON函数", "提取JSON字段值", 2, 2), + } +} + +func (f *JsonExtractFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + jsonStr, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("json_extract requires string input") + } + + path, ok := args[1].(string) + if !ok { + return nil, fmt.Errorf("json_extract path must be string") + } + + var data interface{} + err := json.Unmarshal([]byte(jsonStr), &data) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON: %v", err) + } + + // 简单的路径提取,支持 $.field 格式 + if strings.HasPrefix(path, "$.") { + field := path[2:] + if dataMap, ok := data.(map[string]interface{}); ok { + return dataMap[field], nil + } + } + + return nil, fmt.Errorf("invalid JSON path or data structure") +} + +// JsonValidFunction 验证JSON格式是否有效 +type JsonValidFunction struct { + *BaseFunction +} + +func NewJsonValidFunction() *JsonValidFunction { + return &JsonValidFunction{ + BaseFunction: NewBaseFunction("json_valid", TypeString, "JSON函数", "验证JSON格式是否有效", 1, 1), + } +} + +func (f *JsonValidFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *JsonValidFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + jsonStr, ok := args[0].(string) + if !ok { + return false, nil + } + + var temp interface{} + err := json.Unmarshal([]byte(jsonStr), &temp) + return err == nil, nil +} + +// JsonTypeFunction 返回JSON值的类型 +type JsonTypeFunction struct { + *BaseFunction +} + +func NewJsonTypeFunction() *JsonTypeFunction { + return &JsonTypeFunction{ + BaseFunction: NewBaseFunction("json_type", TypeString, "JSON函数", "返回JSON值的类型", 1, 1), + } +} + +func (f *JsonTypeFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *JsonTypeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + jsonStr, ok := args[0].(string) + if !ok { + return "unknown", nil + } + + var data interface{} + err := json.Unmarshal([]byte(jsonStr), &data) + if err != nil { + return "invalid", nil + } + + switch data.(type) { + case nil: + return "null", nil + case bool: + return "boolean", nil + case float64: + return "number", nil + case string: + return "string", nil + case []interface{}: + return "array", nil + case map[string]interface{}: + return "object", nil + default: + return "unknown", nil + } +} + +// JsonLengthFunction 返回JSON数组或对象的长度 +type JsonLengthFunction struct { + *BaseFunction +} + +func NewJsonLengthFunction() *JsonLengthFunction { + return &JsonLengthFunction{ + BaseFunction: NewBaseFunction("json_length", TypeString, "JSON函数", "返回JSON数组或对象的长度", 1, 1), + } +} + +func (f *JsonLengthFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + jsonStr, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("json_length requires string input") + } + + var data interface{} + err := json.Unmarshal([]byte(jsonStr), &data) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON: %v", err) + } + + switch v := data.(type) { + case []interface{}: + return len(v), nil + case map[string]interface{}: + return len(v), nil + default: + return nil, fmt.Errorf("JSON value is not an array or object") + } +} \ No newline at end of file diff --git a/functions/functions_math.go b/functions/functions_math.go index c1e9ce3..20600ed 100644 --- a/functions/functions_math.go +++ b/functions/functions_math.go @@ -4,6 +4,8 @@ import ( "fmt" "github.com/rulego/streamsql/utils/cast" "math" + "math/rand" + "time" ) // AbsFunction 绝对值函数 @@ -402,6 +404,290 @@ func (f *LnFunction) Execute(ctx *FunctionContext, args []interface{}) (interfac return math.Log(val), nil } +// LogFunction 自然对数函数 (log的别名) +type LogFunction struct { + *BaseFunction +} + +func NewLogFunction() *LogFunction { + return &LogFunction{ + BaseFunction: NewBaseFunction("log", TypeMath, "数学函数", "计算自然对数", 1, 1), + } +} + +func (f *LogFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *LogFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + if val <= 0 { + return nil, fmt.Errorf("log: value must be positive") + } + return math.Log(val), nil +} + +// Log10Function 以10为底的对数函数 +type Log10Function struct { + *BaseFunction +} + +func NewLog10Function() *Log10Function { + return &Log10Function{ + BaseFunction: NewBaseFunction("log10", TypeMath, "数学函数", "计算以10为底的对数", 1, 1), + } +} + +func (f *Log10Function) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *Log10Function) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + if val <= 0 { + return nil, fmt.Errorf("log10: value must be positive") + } + return math.Log10(val), nil +} + +// Log2Function 以2为底的对数函数 +type Log2Function struct { + *BaseFunction +} + +func NewLog2Function() *Log2Function { + return &Log2Function{ + BaseFunction: NewBaseFunction("log2", TypeMath, "数学函数", "计算以2为底的对数", 1, 1), + } +} + +func (f *Log2Function) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *Log2Function) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + if val <= 0 { + return nil, fmt.Errorf("log2: value must be positive") + } + return math.Log2(val), nil +} + +// ModFunction 取模函数 +type ModFunction struct { + *BaseFunction +} + +func NewModFunction() *ModFunction { + return &ModFunction{ + BaseFunction: NewBaseFunction("mod", TypeMath, "数学函数", "取模运算", 2, 2), + } +} + +func (f *ModFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ModFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + x, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + y, err := cast.ToFloat64E(args[1]) + if err != nil { + return nil, err + } + if y == 0 { + return nil, fmt.Errorf("mod: division by zero") + } + return math.Mod(x, y), nil +} + +// RandFunction 随机数函数 +type RandFunction struct { + *BaseFunction +} + +func NewRandFunction() *RandFunction { + return &RandFunction{ + BaseFunction: NewBaseFunction("rand", TypeMath, "数学函数", "生成0-1之间的随机数", 0, 0), + } +} + +func (f *RandFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RandFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 使用当前时间作为种子 + rand.Seed(time.Now().UnixNano()) + return rand.Float64(), nil +} + +// RoundFunction 四舍五入函数 +type RoundFunction struct { + *BaseFunction +} + +func NewRoundFunction() *RoundFunction { + return &RoundFunction{ + BaseFunction: NewBaseFunction("round", TypeMath, "数学函数", "四舍五入", 1, 2), + } +} + +func (f *RoundFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RoundFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + + if len(args) == 1 { + return math.Round(val), nil + } + + precision, err := cast.ToIntE(args[1]) + if err != nil { + return nil, err + } + + shift := math.Pow(10, float64(precision)) + return math.Round(val*shift) / shift, nil +} + +// SignFunction 符号函数 +type SignFunction struct { + *BaseFunction +} + +func NewSignFunction() *SignFunction { + return &SignFunction{ + BaseFunction: NewBaseFunction("sign", TypeMath, "数学函数", "返回数字的符号", 1, 1), + } +} + +func (f *SignFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *SignFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + + if val > 0 { + return 1, nil + } else if val < 0 { + return -1, nil + } + return 0, nil +} + +// SinFunction 正弦函数 +type SinFunction struct { + *BaseFunction +} + +func NewSinFunction() *SinFunction { + return &SinFunction{ + BaseFunction: NewBaseFunction("sin", TypeMath, "数学函数", "计算正弦值", 1, 1), + } +} + +func (f *SinFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *SinFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + return math.Sin(val), nil +} + +// SinhFunction 双曲正弦函数 +type SinhFunction struct { + *BaseFunction +} + +func NewSinhFunction() *SinhFunction { + return &SinhFunction{ + BaseFunction: NewBaseFunction("sinh", TypeMath, "数学函数", "计算双曲正弦值", 1, 1), + } +} + +func (f *SinhFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *SinhFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + return math.Sinh(val), nil +} + +// TanFunction 正切函数 +type TanFunction struct { + *BaseFunction +} + +func NewTanFunction() *TanFunction { + return &TanFunction{ + BaseFunction: NewBaseFunction("tan", TypeMath, "数学函数", "计算正切值", 1, 1), + } +} + +func (f *TanFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *TanFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + return math.Tan(val), nil +} + +// TanhFunction 双曲正切函数 +type TanhFunction struct { + *BaseFunction +} + +func NewTanhFunction() *TanhFunction { + return &TanhFunction{ + BaseFunction: NewBaseFunction("tanh", TypeMath, "数学函数", "计算双曲正切值", 1, 1), + } +} + +func (f *TanhFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *TanhFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + val, err := cast.ToFloat64E(args[0]) + if err != nil { + return nil, err + } + return math.Tanh(val), nil +} + // PowerFunction 幂函数 type PowerFunction struct { *BaseFunction diff --git a/functions/functions_new_test.go b/functions/functions_new_test.go new file mode 100644 index 0000000..c774c49 --- /dev/null +++ b/functions/functions_new_test.go @@ -0,0 +1,339 @@ +package functions + +import ( + "testing" +) + +// 测试JSON函数 +func TestJsonFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + wantErr bool + }{ + { + name: "to_json basic", + funcName: "to_json", + args: []interface{}{map[string]interface{}{"name": "test", "value": 123}}, + expected: `{"name":"test","value":123}`, + }, + { + name: "from_json basic", + funcName: "from_json", + args: []interface{}{`{"name":"test","value":123}`}, + expected: map[string]interface{}{"name": "test", "value": float64(123)}, + }, + { + name: "json_extract basic", + funcName: "json_extract", + args: []interface{}{`{"name":"test","value":123}`, "$.name"}, + expected: "test", + }, + { + name: "json_valid true", + funcName: "json_valid", + args: []interface{}{`{"name":"test"}`}, + expected: true, + }, + { + name: "json_valid false", + funcName: "json_valid", + args: []interface{}{`{"name":"test"`}, + expected: false, + }, + { + name: "json_type object", + funcName: "json_type", + args: []interface{}{`{"name":"test"}`}, + expected: "object", + }, + { + name: "json_length array", + funcName: "json_length", + args: []interface{}{`[1,2,3]`}, + expected: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && !compareResults(result, tt.expected) { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +// 测试哈希函数 +func TestHashFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "md5 basic", + funcName: "md5", + args: []interface{}{"hello"}, + expected: "5d41402abc4b2a76b9719d911017c592", + }, + { + name: "sha1 basic", + funcName: "sha1", + args: []interface{}{"hello"}, + expected: "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", + }, + { + name: "sha256 basic", + funcName: "sha256", + args: []interface{}{"hello"}, + expected: "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +// 测试数组函数 +func TestArrayFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "array_length basic", + funcName: "array_length", + args: []interface{}{[]interface{}{1, 2, 3}}, + expected: 3, + }, + { + name: "array_contains true", + funcName: "array_contains", + args: []interface{}{[]interface{}{1, 2, 3}, 2}, + expected: true, + }, + { + name: "array_contains false", + funcName: "array_contains", + args: []interface{}{[]interface{}{1, 2, 3}, 4}, + expected: false, + }, + { + name: "array_position found", + funcName: "array_position", + args: []interface{}{[]interface{}{1, 2, 3}, 2}, + expected: 2, + }, + { + name: "array_position not found", + funcName: "array_position", + args: []interface{}{[]interface{}{1, 2, 3}, 4}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +// 测试类型检查函数 +func TestTypeFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "is_null true", + funcName: "is_null", + args: []interface{}{nil}, + expected: true, + }, + { + name: "is_null false", + funcName: "is_null", + args: []interface{}{"test"}, + expected: false, + }, + { + name: "is_numeric true", + funcName: "is_numeric", + args: []interface{}{123}, + expected: true, + }, + { + name: "is_numeric false", + funcName: "is_numeric", + args: []interface{}{"test"}, + expected: false, + }, + { + name: "is_string true", + funcName: "is_string", + args: []interface{}{"test"}, + expected: true, + }, + { + name: "is_string false", + funcName: "is_string", + args: []interface{}{123}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if result != tt.expected { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +// 测试条件函数 +func TestConditionalFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + }{ + { + name: "coalesce first non-null", + funcName: "coalesce", + args: []interface{}{nil, "test", "other"}, + expected: "test", + }, + { + name: "nullif equal", + funcName: "nullif", + args: []interface{}{"test", "test"}, + expected: nil, + }, + { + name: "nullif not equal", + funcName: "nullif", + args: []interface{}{"test", "other"}, + expected: "test", + }, + { + name: "greatest numeric", + funcName: "greatest", + args: []interface{}{1, 3, 2}, + expected: 3, + }, + { + name: "least numeric", + funcName: "least", + args: []interface{}{3, 1, 2}, + expected: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + result, err := fn.Execute(&FunctionContext{}, tt.args) + if err != nil { + t.Errorf("Execute() error = %v", err) + return + } + + if !compareResults(result, tt.expected) { + t.Errorf("Execute() = %v, want %v", result, tt.expected) + } + }) + } +} + +// 辅助函数:比较结果 +func compareResults(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // 对于map类型的特殊处理 + if mapA, okA := a.(map[string]interface{}); okA { + if mapB, okB := b.(map[string]interface{}); okB { + if len(mapA) != len(mapB) { + return false + } + for k, v := range mapA { + if mapB[k] != v { + return false + } + } + return true + } + } + + return a == b +} \ No newline at end of file diff --git a/functions/functions_string.go b/functions/functions_string.go index bd3051c..4ba1353 100644 --- a/functions/functions_string.go +++ b/functions/functions_string.go @@ -3,6 +3,7 @@ package functions import ( "fmt" "github.com/rulego/streamsql/utils/cast" + "regexp" "strings" ) @@ -185,3 +186,427 @@ func (f *FormatFunction) Execute(ctx *FunctionContext, args []interface{}) (inte // 如果有第三个参数(locale),这里简化处理 return str, nil } + +// EndswithFunction 检查字符串是否以指定后缀结尾 +type EndswithFunction struct { + *BaseFunction +} + +func NewEndswithFunction() *EndswithFunction { + return &EndswithFunction{ + BaseFunction: NewBaseFunction("endswith", TypeString, "字符串函数", "检查字符串是否以指定后缀结尾", 2, 2), + } +} + +func (f *EndswithFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *EndswithFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + suffix, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + return strings.HasSuffix(str, suffix), nil +} + +// StartswithFunction 检查字符串是否以指定前缀开始 +type StartswithFunction struct { + *BaseFunction +} + +func NewStartswithFunction() *StartswithFunction { + return &StartswithFunction{ + BaseFunction: NewBaseFunction("startswith", TypeString, "字符串函数", "检查字符串是否以指定前缀开始", 2, 2), + } +} + +func (f *StartswithFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *StartswithFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + prefix, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + return strings.HasPrefix(str, prefix), nil +} + +// IndexofFunction 返回子字符串在字符串中的位置 +type IndexofFunction struct { + *BaseFunction +} + +func NewIndexofFunction() *IndexofFunction { + return &IndexofFunction{ + BaseFunction: NewBaseFunction("indexof", TypeString, "字符串函数", "返回子字符串在字符串中的位置", 2, 2), + } +} + +func (f *IndexofFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IndexofFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + substr, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + return int64(strings.Index(str, substr)), nil +} + +// SubstringFunction 提取子字符串 +type SubstringFunction struct { + *BaseFunction +} + +func NewSubstringFunction() *SubstringFunction { + return &SubstringFunction{ + BaseFunction: NewBaseFunction("substring", TypeString, "字符串函数", "提取子字符串", 2, 3), + } +} + +func (f *SubstringFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *SubstringFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + start, err := cast.ToInt64E(args[1]) + if err != nil { + return nil, err + } + + strLen := int64(len(str)) + if start < 0 || start >= strLen { + return "", nil + } + + if len(args) == 2 { + return str[start:], nil + } + + length, err := cast.ToInt64E(args[2]) + if err != nil { + return nil, err + } + + end := start + length + if end > strLen { + end = strLen + } + + return str[start:end], nil +} + +// ReplaceFunction 替换字符串中的内容 +type ReplaceFunction struct { + *BaseFunction +} + +func NewReplaceFunction() *ReplaceFunction { + return &ReplaceFunction{ + BaseFunction: NewBaseFunction("replace", TypeString, "字符串函数", "替换字符串中的内容", 3, 3), + } +} + +func (f *ReplaceFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *ReplaceFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + old, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + new, err := cast.ToStringE(args[2]) + if err != nil { + return nil, err + } + return strings.ReplaceAll(str, old, new), nil +} + +// SplitFunction 按分隔符分割字符串 +type SplitFunction struct { + *BaseFunction +} + +func NewSplitFunction() *SplitFunction { + return &SplitFunction{ + BaseFunction: NewBaseFunction("split", TypeString, "字符串函数", "按分隔符分割字符串", 2, 2), + } +} + +func (f *SplitFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *SplitFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + delimiter, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + return strings.Split(str, delimiter), nil +} + +// LpadFunction 左填充字符串 +type LpadFunction struct { + *BaseFunction +} + +func NewLpadFunction() *LpadFunction { + return &LpadFunction{ + BaseFunction: NewBaseFunction("lpad", TypeString, "字符串函数", "左填充字符串", 2, 3), + } +} + +func (f *LpadFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + length, err := cast.ToInt64E(args[1]) + if err != nil { + return nil, err + } + + pad := " " + if len(args) == 3 { + pad, err = cast.ToStringE(args[2]) + if err != nil { + return nil, err + } + } + + strLen := int64(len(str)) + if strLen >= length { + return str, nil + } + + padLen := length - strLen + padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1)) + return padStr[:padLen] + str, nil +} + +// RpadFunction 右填充字符串 +type RpadFunction struct { + *BaseFunction +} + +func NewRpadFunction() *RpadFunction { + return &RpadFunction{ + BaseFunction: NewBaseFunction("rpad", TypeString, "字符串函数", "右填充字符串", 2, 3), + } +} + +func (f *RpadFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + length, err := cast.ToInt64E(args[1]) + if err != nil { + return nil, err + } + + pad := " " + if len(args) == 3 { + pad, err = cast.ToStringE(args[2]) + if err != nil { + return nil, err + } + } + + strLen := int64(len(str)) + if strLen >= length { + return str, nil + } + + padLen := length - strLen + padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1)) + return str + padStr[:padLen], nil +} + +// LtrimFunction 去除左侧空白字符 +type LtrimFunction struct { + *BaseFunction +} + +func NewLtrimFunction() *LtrimFunction { + return &LtrimFunction{ + BaseFunction: NewBaseFunction("ltrim", TypeString, "字符串函数", "去除左侧空白字符", 1, 1), + } +} + +func (f *LtrimFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *LtrimFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + return strings.TrimLeftFunc(str, func(r rune) bool { + return r == ' ' || r == '\t' || r == '\n' || r == '\r' + }), nil +} + +// RtrimFunction 去除右侧空白字符 +type RtrimFunction struct { + *BaseFunction +} + +func NewRtrimFunction() *RtrimFunction { + return &RtrimFunction{ + BaseFunction: NewBaseFunction("rtrim", TypeString, "字符串函数", "去除右侧空白字符", 1, 1), + } +} + +func (f *RtrimFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RtrimFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + return strings.TrimRightFunc(str, func(r rune) bool { + return r == ' ' || r == '\t' || r == '\n' || r == '\r' + }), nil +} + +// RegexpMatchesFunction 正则表达式匹配 +type RegexpMatchesFunction struct { + *BaseFunction +} + +func NewRegexpMatchesFunction() *RegexpMatchesFunction { + return &RegexpMatchesFunction{ + BaseFunction: NewBaseFunction("regexp_matches", TypeString, "字符串函数", "正则表达式匹配", 2, 2), + } +} + +func (f *RegexpMatchesFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RegexpMatchesFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + pattern, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + + matched, err := regexp.MatchString(pattern, str) + if err != nil { + return nil, err + } + return matched, nil +} + +// RegexpReplaceFunction 正则表达式替换 +type RegexpReplaceFunction struct { + *BaseFunction +} + +func NewRegexpReplaceFunction() *RegexpReplaceFunction { + return &RegexpReplaceFunction{ + BaseFunction: NewBaseFunction("regexp_replace", TypeString, "字符串函数", "正则表达式替换", 3, 3), + } +} + +func (f *RegexpReplaceFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RegexpReplaceFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + pattern, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + replacement, err := cast.ToStringE(args[2]) + if err != nil { + return nil, err + } + + re, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + return re.ReplaceAllString(str, replacement), nil +} + +// RegexpSubstringFunction 正则表达式提取子字符串 +type RegexpSubstringFunction struct { + *BaseFunction +} + +func NewRegexpSubstringFunction() *RegexpSubstringFunction { + return &RegexpSubstringFunction{ + BaseFunction: NewBaseFunction("regexp_substring", TypeString, "字符串函数", "正则表达式提取子字符串", 2, 2), + } +} + +func (f *RegexpSubstringFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *RegexpSubstringFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + str, err := cast.ToStringE(args[0]) + if err != nil { + return nil, err + } + pattern, err := cast.ToStringE(args[1]) + if err != nil { + return nil, err + } + + re, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + + match := re.FindString(str) + return match, nil +} diff --git a/functions/functions_string_test.go b/functions/functions_string_test.go new file mode 100644 index 0000000..a846a18 --- /dev/null +++ b/functions/functions_string_test.go @@ -0,0 +1,130 @@ +package functions + +import ( + "testing" +) + +func TestNewStringFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + expected interface{} + wantErr bool + }{ + // endswith tests + {"endswith_true", "endswith", []interface{}{"hello world", "world"}, true, false}, + {"endswith_false", "endswith", []interface{}{"hello world", "hello"}, false, false}, + {"endswith_empty", "endswith", []interface{}{"hello", ""}, true, false}, + + // startswith tests + {"startswith_true", "startswith", []interface{}{"hello world", "hello"}, true, false}, + {"startswith_false", "startswith", []interface{}{"hello world", "world"}, false, false}, + {"startswith_empty", "startswith", []interface{}{"hello", ""}, true, false}, + + // indexof tests + {"indexof_found", "indexof", []interface{}{"hello world", "world"}, int64(6), false}, + {"indexof_not_found", "indexof", []interface{}{"hello world", "xyz"}, int64(-1), false}, + {"indexof_first_char", "indexof", []interface{}{"hello", "h"}, int64(0), false}, + + // substring tests + {"substring_start_only", "substring", []interface{}{"hello world", int64(6)}, "world", false}, + {"substring_start_length", "substring", []interface{}{"hello world", int64(0), int64(5)}, "hello", false}, + {"substring_out_of_bounds", "substring", []interface{}{"hello", int64(10)}, "", false}, + + // replace tests + {"replace_simple", "replace", []interface{}{"hello world", "world", "Go"}, "hello Go", false}, + {"replace_multiple", "replace", []interface{}{"hello hello", "hello", "hi"}, "hi hi", false}, + {"replace_not_found", "replace", []interface{}{"hello world", "xyz", "abc"}, "hello world", false}, + + // split tests + {"split_comma", "split", []interface{}{"a,b,c", ","}, []string{"a", "b", "c"}, false}, + {"split_space", "split", []interface{}{"hello world", " "}, []string{"hello", "world"}, false}, + {"split_not_found", "split", []interface{}{"hello", ","}, []string{"hello"}, false}, + + // lpad tests + {"lpad_default", "lpad", []interface{}{"hello", int64(10)}, " hello", false}, + {"lpad_custom", "lpad", []interface{}{"hello", int64(8), "*"}, "***hello", false}, + {"lpad_no_padding", "lpad", []interface{}{"hello", int64(3)}, "hello", false}, + + // rpad tests + {"rpad_default", "rpad", []interface{}{"hello", int64(10)}, "hello ", false}, + {"rpad_custom", "rpad", []interface{}{"hello", int64(8), "*"}, "hello***", false}, + {"rpad_no_padding", "rpad", []interface{}{"hello", int64(3)}, "hello", false}, + + // ltrim tests + {"ltrim_spaces", "ltrim", []interface{}{" hello world "}, "hello world ", false}, + {"ltrim_tabs", "ltrim", []interface{}{"\t\nhello"}, "hello", false}, + {"ltrim_no_whitespace", "ltrim", []interface{}{"hello"}, "hello", false}, + + // rtrim tests + {"rtrim_spaces", "rtrim", []interface{}{" hello world "}, " hello world", false}, + {"rtrim_tabs", "rtrim", []interface{}{"hello\t\n"}, "hello", false}, + {"rtrim_no_whitespace", "rtrim", []interface{}{"hello"}, "hello", false}, + + // regexp_matches tests + {"regexp_matches_true", "regexp_matches", []interface{}{"hello123", "[0-9]+"}, true, false}, + {"regexp_matches_false", "regexp_matches", []interface{}{"hello", "[0-9]+"}, false, false}, + {"regexp_matches_email", "regexp_matches", []interface{}{"test@example.com", "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"}, true, false}, + + // regexp_replace tests + {"regexp_replace_digits", "regexp_replace", []interface{}{"hello123world456", "[0-9]+", "X"}, "helloXworldX", false}, + {"regexp_replace_no_match", "regexp_replace", []interface{}{"hello", "[0-9]+", "X"}, "hello", false}, + + // regexp_substring tests + {"regexp_substring_found", "regexp_substring", []interface{}{"hello123world", "[0-9]+"}, "123", false}, + {"regexp_substring_not_found", "regexp_substring", []interface{}{"hello", "[0-9]+"}, "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + ctx := &FunctionContext{} + result, err := fn.Execute(ctx, tt.args) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error for %s, got nil", tt.name) + } + return + } + + if err != nil { + t.Errorf("Unexpected error for %s: %v", tt.name, err) + return + } + + // 特殊处理 split 函数的结果比较 + if tt.funcName == "split" { + expectedSlice, ok := tt.expected.([]string) + if !ok { + t.Errorf("Expected slice for split function") + return + } + resultSlice, ok := result.([]string) + if !ok { + t.Errorf("Result is not a slice for split function") + return + } + if len(expectedSlice) != len(resultSlice) { + t.Errorf("Expected %v, got %v for %s", expectedSlice, resultSlice, tt.name) + return + } + for i, v := range expectedSlice { + if v != resultSlice[i] { + t.Errorf("Expected %v, got %v for %s", expectedSlice, resultSlice, tt.name) + return + } + } + } else { + if result != tt.expected { + t.Errorf("Expected %v, got %v for %s", tt.expected, result, tt.name) + } + } + }) + } +} \ No newline at end of file diff --git a/functions/functions_test.go b/functions/functions_test.go index 68575ed..9c8521f 100644 --- a/functions/functions_test.go +++ b/functions/functions_test.go @@ -39,6 +39,95 @@ func TestBasicFunctionRegistry(t *testing.T) { assert.False(t, exists, "nonexistent function should not be found") } +func TestNewMathFunctions(t *testing.T) { + ctx := &FunctionContext{ + Data: map[string]interface{}{}, + } + + // 表驱动测试用例 + tests := []struct { + name string + functionName string + args []interface{} + expected interface{} + expectError bool + errorMsg string + delta float64 // 用于浮点数比较的精度 + }{ + // Log function tests + {"log valid", "log", []interface{}{math.E}, 1.0, false, "", 1e-10}, + {"log negative", "log", []interface{}{-1}, nil, true, "value must be positive", 0}, + {"log zero", "log", []interface{}{0}, nil, true, "value must be positive", 0}, + + // Log10 function tests + {"log10 100", "log10", []interface{}{100}, 2.0, false, "", 1e-10}, + {"log10 10", "log10", []interface{}{10}, 1.0, false, "", 1e-10}, + + // Log2 function tests + {"log2 8", "log2", []interface{}{8}, 3.0, false, "", 1e-10}, + {"log2 2", "log2", []interface{}{2}, 1.0, false, "", 1e-10}, + + // Mod function tests + {"mod 10,3", "mod", []interface{}{10, 3}, 1.0, false, "", 1e-10}, + {"mod 7.5,2.5", "mod", []interface{}{7.5, 2.5}, 0.0, false, "", 1e-10}, + {"mod division by zero", "mod", []interface{}{10, 0}, nil, true, "division by zero", 0}, + + // Round function tests + {"round 3.7", "round", []interface{}{3.7}, 4.0, false, "", 1e-10}, + {"round 3.2", "round", []interface{}{3.2}, 3.0, false, "", 1e-10}, + {"round with precision", "round", []interface{}{3.14159, 2}, 3.14, false, "", 1e-10}, + + // Sign function tests + {"sign positive", "sign", []interface{}{5.5}, 1, false, "", 0}, + {"sign negative", "sign", []interface{}{-3.2}, -1, false, "", 0}, + {"sign zero", "sign", []interface{}{0}, 0, false, "", 0}, + + // Trigonometric function tests + {"sin 0", "sin", []interface{}{0}, 0.0, false, "", 1e-10}, + {"sin π/2", "sin", []interface{}{math.Pi / 2}, 1.0, false, "", 1e-10}, + {"sinh 0", "sinh", []interface{}{0}, 0.0, false, "", 1e-10}, + {"tan 0", "tan", []interface{}{0}, 0.0, false, "", 1e-10}, + {"tanh 0", "tanh", []interface{}{0}, 0.0, false, "", 1e-10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.functionName) + assert.True(t, exists, "Function %s should be registered", tt.functionName) + + result, err := fn.Execute(ctx, tt.args) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + if tt.delta > 0 { + assert.InDelta(t, tt.expected, result, tt.delta) + } else { + assert.Equal(t, tt.expected, result) + } + } + }) + } + + // 特殊测试:rand函数(因为结果是随机的) + t.Run("rand function", func(t *testing.T) { + fn, exists := Get("rand") + assert.True(t, exists) + + result, err := fn.Execute(ctx, []interface{}{}) + assert.NoError(t, err) + + val, ok := result.(float64) + assert.True(t, ok) + assert.GreaterOrEqual(t, val, 0.0) + assert.Less(t, val, 1.0) + }) +} + func TestFunctionExecution(t *testing.T) { ctx := &FunctionContext{ Data: map[string]interface{}{}, @@ -161,6 +250,10 @@ func TestFunctionExecution(t *testing.T) { t.Run(tt.name, func(t *testing.T) { fn, exists := Get(tt.functionName) assert.True(t, exists, "function %s should exist", tt.functionName) + if !exists || fn == nil { + t.Errorf("Function %s not found or is nil", tt.functionName) + return + } result, err := fn.Execute(ctx, tt.args) @@ -171,7 +264,11 @@ func TestFunctionExecution(t *testing.T) { if tt.expected != nil { switch expected := tt.expected.(type) { case float64: - assert.InDelta(t, expected, result.(float64), 0.0001, "result should match for %s", tt.name) + if resultFloat, ok := result.(float64); ok { + assert.InDelta(t, expected, resultFloat, 0.0001, "result should match for %s", tt.name) + } else { + t.Errorf("Expected float64 but got %T for %s", result, tt.name) + } case int64: if tt.functionName == "now" { // 对于 now 函数,我们只检查结果是否为 int64 类型,因为具体值会随时间变化 diff --git a/functions/functions_type.go b/functions/functions_type.go new file mode 100644 index 0000000..f1cd2c9 --- /dev/null +++ b/functions/functions_type.go @@ -0,0 +1,170 @@ +package functions + +import ( + "reflect" +) + +// IsNullFunction 检查是否为NULL +type IsNullFunction struct { + *BaseFunction +} + +func NewIsNullFunction() *IsNullFunction { + return &IsNullFunction{ + BaseFunction: NewBaseFunction("is_null", TypeString, "类型检查函数", "检查是否为NULL", 1, 1), + } +} + +func (f *IsNullFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + return args[0] == nil, nil +} + +// IsNotNullFunction 检查是否不为NULL +type IsNotNullFunction struct { + *BaseFunction +} + +func NewIsNotNullFunction() *IsNotNullFunction { + return &IsNotNullFunction{ + BaseFunction: NewBaseFunction("is_not_null", TypeString, "类型检查函数", "检查是否不为NULL", 1, 1), + } +} + +func (f *IsNotNullFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsNotNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + return args[0] != nil, nil +} + +// IsNumericFunction 检查是否为数字类型 +type IsNumericFunction struct { + *BaseFunction +} + +func NewIsNumericFunction() *IsNumericFunction { + return &IsNumericFunction{ + BaseFunction: NewBaseFunction("is_numeric", TypeString, "类型检查函数", "检查是否为数字类型", 1, 1), + } +} + +func (f *IsNumericFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsNumericFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if args[0] == nil { + return false, nil + } + + v := reflect.ValueOf(args[0]) + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return true, nil + default: + return false, nil + } +} + +// IsStringFunction 检查是否为字符串类型 +type IsStringFunction struct { + *BaseFunction +} + +func NewIsStringFunction() *IsStringFunction { + return &IsStringFunction{ + BaseFunction: NewBaseFunction("is_string", TypeString, "类型检查函数", "检查是否为字符串类型", 1, 1), + } +} + +func (f *IsStringFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsStringFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if args[0] == nil { + return false, nil + } + + _, ok := args[0].(string) + return ok, nil +} + +// IsBoolFunction 检查是否为布尔类型 +type IsBoolFunction struct { + *BaseFunction +} + +func NewIsBoolFunction() *IsBoolFunction { + return &IsBoolFunction{ + BaseFunction: NewBaseFunction("is_bool", TypeString, "类型检查函数", "检查是否为布尔类型", 1, 1), + } +} + +func (f *IsBoolFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsBoolFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if args[0] == nil { + return false, nil + } + + _, ok := args[0].(bool) + return ok, nil +} + +// IsArrayFunction 检查是否为数组类型 +type IsArrayFunction struct { + *BaseFunction +} + +func NewIsArrayFunction() *IsArrayFunction { + return &IsArrayFunction{ + BaseFunction: NewBaseFunction("is_array", TypeString, "类型检查函数", "检查是否为数组类型", 1, 1), + } +} + +func (f *IsArrayFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsArrayFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if args[0] == nil { + return false, nil + } + + v := reflect.ValueOf(args[0]) + return v.Kind() == reflect.Slice || v.Kind() == reflect.Array, nil +} + +// IsObjectFunction 检查是否为对象类型 +type IsObjectFunction struct { + *BaseFunction +} + +func NewIsObjectFunction() *IsObjectFunction { + return &IsObjectFunction{ + BaseFunction: NewBaseFunction("is_object", TypeString, "类型检查函数", "检查是否为对象类型", 1, 1), + } +} + +func (f *IsObjectFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *IsObjectFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if args[0] == nil { + return false, nil + } + + v := reflect.ValueOf(args[0]) + return v.Kind() == reflect.Map || v.Kind() == reflect.Struct, nil +} \ No newline at end of file diff --git a/functions/functions_window.go b/functions/functions_window.go index 24efd52..c2067b5 100644 --- a/functions/functions_window.go +++ b/functions/functions_window.go @@ -1,5 +1,9 @@ package functions +import ( + "fmt" +) + // RowNumberFunction 行号函数 type RowNumberFunction struct { *BaseFunction @@ -241,3 +245,201 @@ func (f *ExpressionAggregatorFunction) Clone() AggregatorFunction { lastResult: f.lastResult, } } + +// FirstValueFunction 返回窗口中第一个值 +type FirstValueFunction struct { + *BaseFunction + firstValue interface{} + hasValue bool +} + +func NewFirstValueFunction() *FirstValueFunction { + return &FirstValueFunction{ + BaseFunction: NewBaseFunction("first_value", TypeWindow, "窗口函数", "返回窗口中第一个值", 1, 1), + hasValue: false, + } +} + +func (f *FirstValueFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *FirstValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + return f.firstValue, nil +} + +// 实现AggregatorFunction接口 +func (f *FirstValueFunction) New() AggregatorFunction { + return &FirstValueFunction{ + BaseFunction: f.BaseFunction, + hasValue: false, + } +} + +func (f *FirstValueFunction) Add(value interface{}) { + if !f.hasValue { + f.firstValue = value + f.hasValue = true + } +} + +func (f *FirstValueFunction) Result() interface{} { + return f.firstValue +} + +func (f *FirstValueFunction) Reset() { + f.firstValue = nil + f.hasValue = false +} + +func (f *FirstValueFunction) Clone() AggregatorFunction { + return &FirstValueFunction{ + BaseFunction: f.BaseFunction, + firstValue: f.firstValue, + hasValue: f.hasValue, + } +} + +// LeadFunction 返回当前行之后第N行的值 +type LeadFunction struct { + *BaseFunction + values []interface{} +} + +func NewLeadFunction() *LeadFunction { + return &LeadFunction{ + BaseFunction: NewBaseFunction("lead", TypeWindow, "窗口函数", "返回当前行之后第N行的值", 1, 3), + values: make([]interface{}, 0), + } +} + +func (f *LeadFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *LeadFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + // 获取默认值 + var defaultValue interface{} + if len(args) >= 3 { + defaultValue = args[2] + } + + // Lead函数需要在窗口处理完成后才能确定值 + // 这里返回默认值,实际实现需要在窗口引擎中处理 + return defaultValue, nil +} + +// 实现AggregatorFunction接口 +func (f *LeadFunction) New() AggregatorFunction { + return &LeadFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, 0), + } +} + +func (f *LeadFunction) Add(value interface{}) { + f.values = append(f.values, value) +} + +func (f *LeadFunction) Result() interface{} { + // Lead函数的结果需要在所有数据添加完成后计算 + // 这里简化实现,返回nil + return nil +} + +func (f *LeadFunction) Reset() { + f.values = make([]interface{}, 0) +} + +func (f *LeadFunction) Clone() AggregatorFunction { + clone := &LeadFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, len(f.values)), + } + copy(clone.values, f.values) + return clone +} + +// NthValueFunction 返回窗口中第N个值 +type NthValueFunction struct { + *BaseFunction + values []interface{} + n int +} + +func NewNthValueFunction() *NthValueFunction { + return &NthValueFunction{ + BaseFunction: NewBaseFunction("nth_value", TypeWindow, "窗口函数", "返回窗口中第N个值", 2, 2), + values: make([]interface{}, 0), + n: 1, // 默认第1个值 + } +} + +func (f *NthValueFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + + // 获取N值 + n := 1 + if nVal, ok := args[1].(int); ok { + n = nVal + } else if nVal, ok := args[1].(int64); ok { + n = int(nVal) + } else { + return nil, fmt.Errorf("nth_value n must be an integer") + } + + if n <= 0 { + return nil, fmt.Errorf("nth_value n must be positive, got %d", n) + } + + // 返回第N个值(1-based索引) + if len(f.values) >= n { + return f.values[n-1], nil + } + + return nil, nil +} + +// 实现AggregatorFunction接口 +func (f *NthValueFunction) New() AggregatorFunction { + return &NthValueFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, 0), + n: f.n, + } +} + +func (f *NthValueFunction) Add(value interface{}) { + f.values = append(f.values, value) +} + +func (f *NthValueFunction) Result() interface{} { + if len(f.values) >= f.n && f.n > 0 { + return f.values[f.n-1] + } + return nil +} + +func (f *NthValueFunction) Reset() { + f.values = make([]interface{}, 0) +} + +func (f *NthValueFunction) Clone() AggregatorFunction { + clone := &NthValueFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, len(f.values)), + n: f.n, + } + copy(clone.values, f.values) + return clone +} diff --git a/functions/functions_window_test.go b/functions/functions_window_test.go new file mode 100644 index 0000000..531c32a --- /dev/null +++ b/functions/functions_window_test.go @@ -0,0 +1,298 @@ +package functions + +import ( + "testing" +) + +func TestNewWindowFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + want interface{} + wantErr bool + setup func(fn AggregatorFunction) + }{ + // first_value 函数测试 + { + name: "first_value basic", + funcName: "first_value", + args: []interface{}{"test"}, + want: "first", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "first_value empty", + funcName: "first_value", + args: []interface{}{"test"}, + want: nil, + wantErr: false, + setup: func(fn AggregatorFunction) {}, + }, + + // last_value 函数测试 + { + name: "last_value basic", + funcName: "last_value", + args: []interface{}{"test"}, + want: "third", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "last_value empty", + funcName: "last_value", + args: []interface{}{"test"}, + want: nil, + wantErr: false, + setup: func(fn AggregatorFunction) {}, + }, + + // lag 函数测试 + { + name: "lag default offset", + funcName: "lag", + args: []interface{}{"test"}, + want: "second", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "lag with offset 2", + funcName: "lag", + args: []interface{}{"test", 2}, + want: "first", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "lag with default value", + funcName: "lag", + args: []interface{}{"test", 5, "default"}, + want: "default", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + }, + }, + { + name: "lag invalid offset type", + funcName: "lag", + args: []interface{}{"test", "invalid"}, + wantErr: true, + setup: func(fn AggregatorFunction) {}, + }, + + // lead 函数测试 + { + name: "lead default offset", + funcName: "lead", + args: []interface{}{"test"}, + want: nil, // Lead函数简化实现返回nil + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "lead with default value", + funcName: "lead", + args: []interface{}{"test", 1, "default"}, + want: "default", + wantErr: false, + setup: func(fn AggregatorFunction) {}, + }, + { + name: "lead invalid offset type", + funcName: "lead", + args: []interface{}{"test", "invalid"}, + wantErr: true, + setup: func(fn AggregatorFunction) {}, + }, + + // nth_value 函数测试 + { + name: "nth_value first", + funcName: "nth_value", + args: []interface{}{"test", 1}, + want: "first", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "nth_value second", + funcName: "nth_value", + args: []interface{}{"test", 2}, + want: "second", + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + fn.Add("third") + }, + }, + { + name: "nth_value out of range", + funcName: "nth_value", + args: []interface{}{"test", 5}, + want: nil, + wantErr: false, + setup: func(fn AggregatorFunction) { + fn.Add("first") + fn.Add("second") + }, + }, + { + name: "nth_value invalid n type", + funcName: "nth_value", + args: []interface{}{"test", "invalid"}, + wantErr: true, + setup: func(fn AggregatorFunction) {}, + }, + { + name: "nth_value zero n", + funcName: "nth_value", + args: []interface{}{"test", 0}, + wantErr: true, + setup: func(fn AggregatorFunction) {}, + }, + { + name: "nth_value negative n", + funcName: "nth_value", + args: []interface{}{"test", -1}, + wantErr: true, + setup: func(fn AggregatorFunction) {}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + // 检查函数是否实现了AggregatorFunction接口 + aggFn, ok := fn.(AggregatorFunction) + if !ok { + t.Fatalf("Function %s does not implement AggregatorFunction", tt.funcName) + } + + // 先执行函数的Validate方法来设置参数 + err := fn.Validate(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 如果期望错误且Validate已经失败,则测试通过 + if tt.wantErr && err != nil { + return + } + + // 创建新的聚合器实例 + aggInstance := aggFn.New() + + // 执行setup函数添加测试数据 + if tt.setup != nil { + tt.setup(aggInstance) + } + + // 执行函数 + _, err = fn.Execute(nil, tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // 对于窗口函数,我们主要测试聚合器的Result方法 + aggResult := aggInstance.Result() + if tt.want != nil && aggResult != tt.want { + t.Errorf("AggregatorFunction.Result() = %v, want %v", aggResult, tt.want) + } + } + }) + } +} + +// 测试窗口函数的基本功能 +func TestWindowFunctionBasics(t *testing.T) { + // 测试row_number函数 + t.Run("RowNumberFunction", func(t *testing.T) { + rowNumFunc, exists := Get("row_number") + if !exists { + t.Fatal("row_number function not found") + } + + // 测试行号递增 + result1, err := rowNumFunc.Execute(nil, []interface{}{}) + if err != nil { + t.Errorf("Execute() error = %v", err) + } + if result1 != int64(1) { + t.Errorf("First call should return 1, got %v", result1) + } + + result2, err := rowNumFunc.Execute(nil, []interface{}{}) + if err != nil { + t.Errorf("Execute() error = %v", err) + } + if result2 != int64(2) { + t.Errorf("Second call should return 2, got %v", result2) + } + }) + + // 测试window_start和window_end函数 + t.Run("WindowStartEndFunctions", func(t *testing.T) { + windowStartFunc, exists := Get("window_start") + if !exists { + t.Fatal("window_start function not found") + } + + windowEndFunc, exists := Get("window_end") + if !exists { + t.Fatal("window_end function not found") + } + + // 测试无窗口信息时的行为 + ctx := &FunctionContext{ + Data: map[string]interface{}{}, + } + _, err := windowStartFunc.Execute(ctx, []interface{}{}) + if err != nil { + t.Errorf("Execute() error = %v", err) + } + // 无窗口信息时应该返回nil或默认值 + + _, err = windowEndFunc.Execute(ctx, []interface{}{}) + if err != nil { + t.Errorf("Execute() error = %v", err) + } + // 无窗口信息时应该返回nil或默认值 + }) +} \ No newline at end of file diff --git a/functions/registry.go b/functions/registry.go index 28350ac..de11137 100644 --- a/functions/registry.go +++ b/functions/registry.go @@ -218,3 +218,123 @@ func (f *CustomFunction) Validate(args []interface{}) error { func (f *CustomFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { return f.executor(ctx, args) } + +func init() { + // 注册数学函数 + Register(NewAbsFunction()) + Register(NewSqrtFunction()) + Register(NewPowerFunction()) + Register(NewCeilingFunction()) + Register(NewFloorFunction()) + Register(NewRoundFunction()) + Register(NewModFunction()) + Register(NewMaxFunction()) + Register(NewMinFunction()) + Register(NewRandFunction()) + + // 注册字符串函数 + Register(NewUpperFunction()) + Register(NewLowerFunction()) + Register(NewLengthFunction()) + Register(NewSubstringFunction()) + Register(NewConcatFunction()) + Register(NewTrimFunction()) + Register(NewLtrimFunction()) + Register(NewRtrimFunction()) + Register(NewReplaceFunction()) + Register(NewSplitFunction()) + Register(NewStartswithFunction()) + Register(NewEndswithFunction()) + Register(NewRegexpMatchesFunction()) + Register(NewRegexpReplaceFunction()) + Register(NewLpadFunction()) + Register(NewRpadFunction()) + Register(NewIndexofFunction()) + Register(NewFormatFunction()) + + // 注册时间日期函数 + Register(NewNowFunction()) + Register(NewCurrentTimeFunction()) + Register(NewCurrentDateFunction()) + Register(NewDateAddFunction()) + Register(NewDateSubFunction()) + Register(NewDateDiffFunction()) + Register(NewDateFormatFunction()) + Register(NewDateParseFunction()) + Register(NewExtractFunction()) + Register(NewUnixTimestampFunction()) + Register(NewFromUnixtimeFunction()) + Register(NewYearFunction()) + Register(NewMonthFunction()) + Register(NewDayFunction()) + Register(NewHourFunction()) + Register(NewMinuteFunction()) + Register(NewSecondFunction()) + Register(NewDayOfWeekFunction()) + Register(NewDayOfYearFunction()) + Register(NewWeekOfYearFunction()) + + // 注册转换函数 + Register(NewCastFunction()) + Register(NewHex2DecFunction()) + Register(NewDec2HexFunction()) + Register(NewEncodeFunction()) + Register(NewDecodeFunction()) + + // 注册聚合函数 + Register(NewCountFunction()) + Register(NewSumFunction()) + Register(NewAvgFunction()) + Register(NewMaxFunction()) + Register(NewMinFunction()) + + // 注册窗口函数 + Register(NewRowNumberFunction()) + Register(NewLagFunction()) + Register(NewLeadFunction()) + Register(NewFirstValueFunction()) + Register(NewNthValueFunction()) + + // 注册分析函数 + Register(NewLatestFunction()) + Register(NewHadChangedFunction()) + + // 注册JSON函数 + Register(NewJsonExtractFunction()) + Register(NewJsonValidFunction()) + Register(NewJsonTypeFunction()) + Register(NewJsonLengthFunction()) + Register(NewToJsonFunction()) + Register(NewFromJsonFunction()) + + // 注册哈希函数 + Register(NewMd5Function()) + Register(NewSha1Function()) + Register(NewSha256Function()) + Register(NewSha512Function()) + + // 注册数组函数 + Register(NewArrayLengthFunction()) + Register(NewArrayContainsFunction()) + Register(NewArrayPositionFunction()) + Register(NewArrayRemoveFunction()) + Register(NewArrayDistinctFunction()) + Register(NewArrayIntersectFunction()) + Register(NewArrayUnionFunction()) + Register(NewArrayExceptFunction()) + + // 注册类型检查函数 + Register(NewIsNullFunction()) + Register(NewIsNotNullFunction()) + Register(NewIsStringFunction()) + Register(NewIsNumericFunction()) + Register(NewIsBoolFunction()) + Register(NewIsArrayFunction()) + Register(NewIsObjectFunction()) + + // 注册条件函数 + Register(NewCoalesceFunction()) + Register(NewNullIfFunction()) + Register(NewGreatestFunction()) + Register(NewLeastFunction()) +}