From 049295b599eeff6db0432502bed6a9028a8395ef Mon Sep 17 00:00:00 2001 From: rulego-team Date: Fri, 8 Aug 2025 09:40:26 +0800 Subject: [PATCH] test:add test cases --- functions/aggregator_types_test.go | 338 ++++++++++++++++++ .../analytical_aggregator_adapter_test.go | 297 +++++++++++++++ functions/functions_expr_test.go | 151 ++++++++ functions/functions_json_test.go | 326 ++++++++++++++++- functions/functions_multirow_test.go | 126 +++++++ 5 files changed, 1237 insertions(+), 1 deletion(-) create mode 100644 functions/aggregator_types_test.go create mode 100644 functions/analytical_aggregator_adapter_test.go diff --git a/functions/aggregator_types_test.go b/functions/aggregator_types_test.go new file mode 100644 index 0000000..9181367 --- /dev/null +++ b/functions/aggregator_types_test.go @@ -0,0 +1,338 @@ +package functions + +import ( + "testing" +) + +// TestAggregateTypeConstants 测试聚合类型常量 +func TestAggregateTypeConstants(t *testing.T) { + tests := []struct { + name string + aggType AggregateType + expected string + }{ + {"Sum", Sum, "sum"}, + {"Count", Count, "count"}, + {"Avg", Avg, "avg"}, + {"Max", Max, "max"}, + {"Min", Min, "min"}, + {"Median", Median, "median"}, + {"Percentile", Percentile, "percentile"}, + {"WindowStart", WindowStart, "window_start"}, + {"WindowEnd", WindowEnd, "window_end"}, + {"Collect", Collect, "collect"}, + {"LastValue", LastValue, "last_value"}, + {"MergeAgg", MergeAgg, "merge_agg"}, + {"StdDev", StdDev, "stddev"}, + {"StdDevS", StdDevS, "stddevs"}, + {"Deduplicate", Deduplicate, "deduplicate"}, + {"Var", Var, "var"}, + {"VarS", VarS, "vars"}, + {"Lag", Lag, "lag"}, + {"Latest", Latest, "latest"}, + {"ChangedCol", ChangedCol, "changed_col"}, + {"HadChanged", HadChanged, "had_changed"}, + {"Expression", Expression, "expression"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if string(tt.aggType) != tt.expected { + t.Errorf("AggregateType %s = %s, want %s", tt.name, string(tt.aggType), tt.expected) + } + }) + } +} + +// TestStringConstants 测试字符串常量 +func TestStringConstants(t *testing.T) { + tests := []struct { + name string + constant string + expected string + }{ + {"SumStr", SumStr, "sum"}, + {"CountStr", CountStr, "count"}, + {"AvgStr", AvgStr, "avg"}, + {"MaxStr", MaxStr, "max"}, + {"MinStr", MinStr, "min"}, + {"MedianStr", MedianStr, "median"}, + {"PercentileStr", PercentileStr, "percentile"}, + {"WindowStartStr", WindowStartStr, "window_start"}, + {"WindowEndStr", WindowEndStr, "window_end"}, + {"CollectStr", CollectStr, "collect"}, + {"LastValueStr", LastValueStr, "last_value"}, + {"MergeAggStr", MergeAggStr, "merge_agg"}, + {"StdDevStr", StdDevStr, "stddev"}, + {"StdDevSStr", StdDevSStr, "stddevs"}, + {"DeduplicateStr", DeduplicateStr, "deduplicate"}, + {"VarStr", VarStr, "var"}, + {"VarSStr", VarSStr, "vars"}, + {"LagStr", LagStr, "lag"}, + {"LatestStr", LatestStr, "latest"}, + {"ChangedColStr", ChangedColStr, "changed_col"}, + {"HadChangedStr", HadChangedStr, "had_changed"}, + {"ExpressionStr", ExpressionStr, "expression"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.constant != tt.expected { + t.Errorf("Constant %s = %s, want %s", tt.name, tt.constant, tt.expected) + } + }) + } +} + +// TestRegisterLegacyAggregator 测试注册遗留聚合器 +func TestRegisterLegacyAggregator(t *testing.T) { + // 创建一个测试聚合器构造函数 + constructor := func() LegacyAggregatorFunction { + return &TestLegacyAggregator{} + } + + // 注册聚合器 + RegisterLegacyAggregator("test_agg", constructor) + + // 验证注册成功 + legacyRegistryMutex.RLock() + _, exists := legacyAggregatorRegistry["test_agg"] + legacyRegistryMutex.RUnlock() + + if !exists { + t.Error("Failed to register legacy aggregator") + } + + // 测试创建聚合器 + createdAgg := CreateLegacyAggregator("test_agg") + if createdAgg == nil { + t.Error("Failed to create legacy aggregator") + } + + // 测试聚合器功能 + createdAgg.Add(10) + createdAgg.Add(20) + result := createdAgg.Result() + if result != 30 { + t.Errorf("Expected result 30, got %v", result) + } +} + +// TestCreateLegacyAggregatorPanic 测试创建不存在的聚合器时的panic +func TestCreateLegacyAggregatorPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for unsupported aggregator type") + } + }() + + CreateLegacyAggregator("nonexistent_aggregator") +} + +// TestFunctionAggregatorWrapper 测试函数聚合器包装器 +func TestFunctionAggregatorWrapper(t *testing.T) { + // 创建一个测试聚合器函数 + testAgg := &TestAggregatorFunction{} + + // 创建一个测试适配器 + adapter := &AggregatorAdapter{ + aggFunc: testAgg, + } + wrapper := &FunctionAggregatorWrapper{adapter: adapter} + + // 测试New方法 + newWrapper := wrapper.New() + if newWrapper == nil { + t.Error("New() should return a new wrapper") + } + + // 测试GetContextKey方法 + contextKey := wrapper.GetContextKey() + if contextKey != "" { + t.Logf("Context key: %s", contextKey) + } +} + +// TestAnalyticalAggregatorWrapper 测试分析聚合器包装器 +func TestAnalyticalAggregatorWrapper(t *testing.T) { + // 创建一个测试分析函数 + testAnalFunc := &TestAnalyticalFunction{} + + // 创建一个测试适配器 + adapter := &AnalyticalAggregatorAdapter{ + analFunc: testAnalFunc, + ctx: &FunctionContext{ + Data: make(map[string]interface{}), + }, + } + wrapper := &AnalyticalAggregatorWrapper{adapter: adapter} + + // 测试New方法 + newWrapper := wrapper.New() + if newWrapper == nil { + t.Error("New() should return a new wrapper") + } + + // 测试Add和Result方法 + wrapper.Add("test") + result := wrapper.Result() + t.Logf("Result: %v", result) +} + +// TestLegacyAggregator 测试用的遗留聚合器实现 +type TestLegacyAggregator struct { + sum int +} + +// New 创建新的聚合器实例 +func (t *TestLegacyAggregator) New() LegacyAggregatorFunction { + return &TestLegacyAggregator{} +} + +// Add 添加值 +func (t *TestLegacyAggregator) Add(value interface{}) { + if v, ok := value.(int); ok { + t.sum += v + } +} + +// Result 返回聚合结果 +func (t *TestLegacyAggregator) Result() interface{} { + return t.sum +} + +// TestAggregatorFunction 测试用的聚合器函数实现 +type TestAggregatorFunction struct { + sum int +} + +// New 创建新的聚合器实例 +func (t *TestAggregatorFunction) New() AggregatorFunction { + return &TestAggregatorFunction{} +} + +// Add 添加值 +func (t *TestAggregatorFunction) Add(value interface{}) { + if v, ok := value.(int); ok { + t.sum += v + } +} + +// Result 返回聚合结果 +func (t *TestAggregatorFunction) Result() interface{} { + return t.sum +} + +// Reset 重置聚合器状态 +func (t *TestAggregatorFunction) Reset() { + t.sum = 0 +} + +// Clone 克隆聚合器 +func (t *TestAggregatorFunction) Clone() AggregatorFunction { + return &TestAggregatorFunction{sum: t.sum} +} + +// GetName 返回函数名称 +func (t *TestAggregatorFunction) GetName() string { + return "test_aggregator" +} + +// GetType 返回函数类型 +func (t *TestAggregatorFunction) GetType() FunctionType { + return TypeAggregation +} + +// GetCategory 返回函数分类 +func (t *TestAggregatorFunction) GetCategory() string { + return "test" +} + +// GetDescription 返回函数描述 +func (t *TestAggregatorFunction) GetDescription() string { + return "Test aggregator function" +} + +// GetAliases 返回函数别名 +func (t *TestAggregatorFunction) GetAliases() []string { + return []string{} +} + +// Validate 验证参数 +func (t *TestAggregatorFunction) Validate(args []interface{}) error { + return nil +} + +// Execute 执行函数 +func (t *TestAggregatorFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + return t.Result(), nil +} + +// TestAnalyticalFunction 测试用的分析函数实现 +type TestAnalyticalFunction struct { + values []interface{} +} + +// New 创建新的分析函数实例 +func (t *TestAnalyticalFunction) New() AggregatorFunction { + return &TestAnalyticalFunction{ + values: make([]interface{}, 0), + } +} + +// Add 添加值 +func (t *TestAnalyticalFunction) Add(value interface{}) { + t.values = append(t.values, value) +} + +// Result 返回分析结果 +func (t *TestAnalyticalFunction) Result() interface{} { + return len(t.values) +} + +// Reset 重置分析函数状态 +func (t *TestAnalyticalFunction) Reset() { + t.values = make([]interface{}, 0) +} + +// Clone 克隆分析函数 +func (t *TestAnalyticalFunction) Clone() AggregatorFunction { + newValues := make([]interface{}, len(t.values)) + copy(newValues, t.values) + return &TestAnalyticalFunction{values: newValues} +} + +// GetName 返回函数名称 +func (t *TestAnalyticalFunction) GetName() string { + return "test_analytical" +} + +// GetType 返回函数类型 +func (t *TestAnalyticalFunction) GetType() FunctionType { + return TypeAnalytical +} + +// GetCategory 返回函数分类 +func (t *TestAnalyticalFunction) GetCategory() string { + return "test" +} + +// GetDescription 返回函数描述 +func (t *TestAnalyticalFunction) GetDescription() string { + return "Test analytical function" +} + +// GetAliases 返回函数别名 +func (t *TestAnalyticalFunction) GetAliases() []string { + return []string{} +} + +// Validate 验证参数 +func (t *TestAnalyticalFunction) Validate(args []interface{}) error { + return nil +} + +// Execute 执行函数 +func (t *TestAnalyticalFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + return t.Result(), nil +} \ No newline at end of file diff --git a/functions/analytical_aggregator_adapter_test.go b/functions/analytical_aggregator_adapter_test.go new file mode 100644 index 0000000..7efc0e8 --- /dev/null +++ b/functions/analytical_aggregator_adapter_test.go @@ -0,0 +1,297 @@ +package functions + +import ( + "testing" +) + +// TestNewAnalyticalAggregatorAdapter 测试创建分析聚合器适配器 +func TestNewAnalyticalAggregatorAdapter(t *testing.T) { + tests := []struct { + name string + funcName string + expectError bool + }{ + { + name: "valid analytical function", + funcName: "lag", + expectError: false, + }, + { + name: "invalid function name", + funcName: "nonexistent_function", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + adapter, err := NewAnalyticalAggregatorAdapter(tt.funcName) + if tt.expectError { + if err == nil { + t.Error("Expected error but got none") + } + if adapter != nil { + t.Error("Expected nil adapter but got one") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if adapter == nil { + t.Error("Expected adapter but got nil") + } + if adapter != nil { + if adapter.analFunc == nil { + t.Error("Expected analytical function but got nil") + } + if adapter.ctx == nil { + t.Error("Expected context but got nil") + } + } + } + }) + } +} + +// TestAnalyticalAggregatorAdapterNew 测试New方法 +func TestAnalyticalAggregatorAdapterNew(t *testing.T) { + // 创建一个测试适配器 + adapter, err := NewAnalyticalAggregatorAdapter("lag") + if err != nil { + t.Fatalf("Failed to create adapter: %v", err) + } + + // 测试New方法 + newAdapter := adapter.New() + if newAdapter == nil { + t.Error("New() should return a new adapter") + } + + // 验证返回的是正确的类型 + if _, ok := newAdapter.(*AnalyticalAggregatorAdapter); !ok { + t.Error("New() should return AnalyticalAggregatorAdapter") + } + + // 验证新适配器有独立的上下文 + newAdapterTyped := newAdapter.(*AnalyticalAggregatorAdapter) + if newAdapterTyped.ctx == adapter.ctx { + t.Error("New adapter should have independent context") + } +} + +// TestAnalyticalAggregatorAdapterAdd 测试Add方法 +func TestAnalyticalAggregatorAdapterAdd(t *testing.T) { + adapter, err := NewAnalyticalAggregatorAdapter("lag") + if err != nil { + t.Fatalf("Failed to create adapter: %v", err) + } + + // 测试Add方法 + adapter.Add(10) + adapter.Add(20) + adapter.Add(30) + + // 验证没有panic + t.Log("Add method executed successfully") +} + +// TestAnalyticalAggregatorAdapterResult 测试Result方法 +func TestAnalyticalAggregatorAdapterResult(t *testing.T) { + tests := []struct { + name string + funcName string + }{ + {"lag function", "lag"}, + {"latest function", "latest"}, + {"had_changed function", "had_changed"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + adapter, err := NewAnalyticalAggregatorAdapter(tt.funcName) + if err != nil { + t.Skipf("Function %s not available: %v", tt.funcName, err) + return + } + + // 添加一些数据 + adapter.Add(10) + adapter.Add(20) + + // 获取结果 + result := adapter.Result() + t.Logf("Result for %s: %v", tt.funcName, result) + }) + } +} + +// TestCreateAnalyticalAggregatorFromFunctions 测试从函数模块创建分析聚合器 +func TestCreateAnalyticalAggregatorFromFunctions(t *testing.T) { + tests := []struct { + name string + funcType string + expected bool // 是否期望创建成功 + }{ + { + name: "valid analytical function", + funcType: "lag", + expected: true, + }, + { + name: "another valid function", + funcType: "latest", + expected: true, + }, + { + name: "invalid function", + funcType: "nonexistent", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CreateAnalyticalAggregatorFromFunctions(tt.funcType) + if tt.expected { + if result == nil { + t.Errorf("Expected to create aggregator for %s but got nil", tt.funcType) + } else { + // 验证返回的是正确的类型 + if _, ok := result.(*AnalyticalAggregatorAdapter); !ok { + t.Errorf("Expected AnalyticalAggregatorAdapter but got %T", result) + } + } + } else { + if result != nil { + t.Errorf("Expected nil for %s but got %v", tt.funcType, result) + } + } + }) + } +} + +// TestAnalyticalAggregatorAdapterWithMockFunction 测试使用模拟函数的适配器 +func TestAnalyticalAggregatorAdapterWithMockFunction(t *testing.T) { + // 创建模拟分析函数 + mockFunc := &MockAnalyticalFunction{ + name: "mock_analytical", + values: []interface{}{}, + } + + // 创建适配器 + adapter := &AnalyticalAggregatorAdapter{ + analFunc: mockFunc, + ctx: &FunctionContext{ + Data: make(map[string]interface{}), + }, + } + + // 测试Add方法 + adapter.Add("test1") + adapter.Add("test2") + + // 验证值被添加 + if len(mockFunc.values) != 2 { + t.Errorf("Expected 2 values, got %d", len(mockFunc.values)) + } + + // 测试Result方法 + result := adapter.Result() + if result != "mock_result" { + t.Errorf("Expected 'mock_result', got %v", result) + } + + // 测试New方法 + newAdapter := adapter.New() + if newAdapter == nil { + t.Error("New() should return a new adapter") + } +} + +// MockAnalyticalFunction 模拟分析函数用于测试 +type MockAnalyticalFunction struct { + name string + values []interface{} +} + +// GetName 返回函数名称 +func (m *MockAnalyticalFunction) GetName() string { + return m.name +} + +// GetType 返回函数类型 +func (m *MockAnalyticalFunction) GetType() FunctionType { + return TypeAnalytical +} + +// GetCategory 返回函数分类 +func (m *MockAnalyticalFunction) GetCategory() string { + return "mock" +} + +// GetDescription 返回函数描述 +func (m *MockAnalyticalFunction) GetDescription() string { + return "Mock analytical function for testing" +} + +// GetAliases 返回函数别名 +func (m *MockAnalyticalFunction) GetAliases() []string { + return []string{} +} + +// GetMinArgs 返回最小参数数量 +func (m *MockAnalyticalFunction) GetMinArgs() int { + return 1 +} + +// GetMaxArgs 返回最大参数数量 +func (m *MockAnalyticalFunction) GetMaxArgs() int { + return 1 +} + +// Validate 验证参数 +func (m *MockAnalyticalFunction) Validate(args []interface{}) error { + return nil +} + +// Execute 执行函数 +func (m *MockAnalyticalFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if len(args) > 0 { + m.values = append(m.values, args[0]) + } + return "mock_result", nil +} + +// Add 添加值到聚合器 +func (m *MockAnalyticalFunction) Add(value interface{}) { + m.values = append(m.values, value) +} + +// Result 返回聚合结果 +func (m *MockAnalyticalFunction) Result() interface{} { + return len(m.values) +} + +// Reset 重置聚合器 +func (m *MockAnalyticalFunction) Reset() { + m.values = make([]interface{}, 0) +} + +// New 创建新的聚合器实例 +func (m *MockAnalyticalFunction) New() AggregatorFunction { + newMock := &MockAnalyticalFunction{ + name: m.name, + values: make([]interface{}, 0), + } + return newMock +} + +// Clone 克隆函数 - 返回AggregatorFunction以满足接口要求 +func (m *MockAnalyticalFunction) Clone() AggregatorFunction { + newMock := &MockAnalyticalFunction{ + name: m.name, + values: make([]interface{}, len(m.values)), + } + copy(newMock.values, m.values) + return newMock +} \ No newline at end of file diff --git a/functions/functions_expr_test.go b/functions/functions_expr_test.go index 7abd4d1..7dbd997 100644 --- a/functions/functions_expr_test.go +++ b/functions/functions_expr_test.go @@ -53,4 +53,155 @@ func TestExprFunctionEdgeCases(t *testing.T) { if err == nil { t.Error("ExprFunction.Execute should fail for empty args") } + + // 测试非字符串参数(现在应该成功) + ctx := &FunctionContext{Data: map[string]interface{}{}} + _, err = fn.Execute(ctx, []interface{}{123}) + if err != nil { + t.Errorf("ExprFunction.Execute should accept non-string argument: %v", err) + } + + // 测试无效表达式 + _, err = fn.Execute(ctx, []interface{}{"invalid expression +++"}) + if err == nil { + t.Error("ExprFunction.Execute should fail for invalid expression") + } +} + +// TestExprFunctionCreation 测试ExprFunction的创建和属性 +func TestExprFunctionCreation(t *testing.T) { + fn := NewExprFunction() + if fn == nil { + t.Error("NewExprFunction should not return nil") + } + + if fn.GetName() != "expr" { + t.Errorf("Expected name 'expr', got %s", fn.GetName()) + } + + if fn.GetType() != TypeString { + t.Errorf("Expected type %s, got %s", TypeString, fn.GetType()) + } + + // BaseFunction doesn't expose GetMinArgs/GetMaxArgs methods + // We can only test through Validate method + err := fn.Validate([]interface{}{"test"}) + if err != nil { + t.Errorf("Validate should accept 1 argument: %v", err) + } + + err = fn.Validate([]interface{}{}) + if err == nil { + t.Error("Validate should reject 0 arguments") + } + + err = fn.Validate([]interface{}{"arg1", "arg2"}) + if err == nil { + t.Error("Validate should reject 2 arguments") + } + + if fn.GetCategory() == "" { + t.Error("Function category should not be empty") + } + + if fn.GetDescription() == "" { + t.Error("Function description should not be empty") + } +} + +// TestExprFunctionWithDifferentExpressions 测试不同类型的表达式 +func TestExprFunctionWithDifferentExpressions(t *testing.T) { + fn := NewExprFunction() + ctx := &FunctionContext{ + Data: map[string]interface{}{ + "x": 10, + "y": 5, + "name": "John", + "active": true, + }, + } + + // 测试数学表达式 + result, err := fn.Execute(ctx, []interface{}{"x + y"}) + if err != nil { + t.Errorf("Math expression failed: %v", err) + } + if result != 15 { + t.Errorf("Expected 15, got %v", result) + } + + // 测试比较表达式 + result, err = fn.Execute(ctx, []interface{}{"x > y"}) + if err != nil { + t.Errorf("Comparison expression failed: %v", err) + } + if result != true { + t.Errorf("Expected true, got %v", result) + } + + // 测试字符串表达式 + result, err = fn.Execute(ctx, []interface{}{"name + ' Doe'"}) + if err != nil { + t.Errorf("String expression failed: %v", err) + } + if result != "John Doe" { + t.Errorf("Expected 'John Doe', got %v", result) + } + + // 测试布尔表达式 + result, err = fn.Execute(ctx, []interface{}{"active && true"}) + if err != nil { + t.Errorf("Boolean expression failed: %v", err) + } + if result != true { + t.Errorf("Expected true, got %v", result) + } + + // 测试复杂表达式 + result, err = fn.Execute(ctx, []interface{}{"(x + y) * 2"}) + if err != nil { + t.Errorf("Complex expression failed: %v", err) + } + if result != 30 { + t.Errorf("Expected 30, got %v", result) + } +} + +// TestExprFunctionWithFunctionCalls 测试函数调用表达式 +func TestExprFunctionWithFunctionCalls(t *testing.T) { + fn := NewExprFunction() + ctx := &FunctionContext{ + Data: map[string]interface{}{ + "text": "Hello World", + "num": -42, + }, + } + + // 测试abs函数调用 + result, err := fn.Execute(ctx, []interface{}{"abs(-10)"}) + if err != nil { + t.Errorf("Function call expression failed: %v", err) + } + if result != float64(10) { + t.Errorf("Expected 10, got %v", result) + } + + // 测试length函数调用 + result, err = fn.Execute(ctx, []interface{}{"length(text)"}) + if err != nil { + t.Errorf("Length function call failed: %v", err) + } + if result != 11 { + t.Errorf("Expected 11, got %v", result) + } + + // 测试组合函数调用 + result, err = fn.Execute(ctx, []interface{}{"abs(num) + length(text)"}) + if err != nil { + t.Errorf("Combined function calls failed: %v", err) + } + expected := float64(53) // 42 + 11 + if result != expected { + t.Errorf("Expected %v, got %v", expected, result) + } } diff --git a/functions/functions_json_test.go b/functions/functions_json_test.go index 37364c0..64b4c1c 100644 --- a/functions/functions_json_test.go +++ b/functions/functions_json_test.go @@ -19,18 +19,84 @@ func TestJsonFunctions(t *testing.T) { args: []interface{}{map[string]interface{}{"name": "test", "value": 123}}, expected: `{"name":"test","value":123}`, }, + { + name: "to_json array", + funcName: "to_json", + args: []interface{}{[]interface{}{1, 2, 3}}, + expected: `[1,2,3]`, + }, + { + name: "to_json string", + funcName: "to_json", + args: []interface{}{"hello"}, + expected: `"hello"`, + }, { name: "from_json basic", funcName: "from_json", args: []interface{}{`{"name":"test","value":123}`}, expected: map[string]interface{}{"name": "test", "value": float64(123)}, }, + { + name: "from_json array", + funcName: "from_json", + args: []interface{}{`[1,2,3]`}, + expected: []interface{}{float64(1), float64(2), float64(3)}, + }, + { + name: "from_json invalid", + funcName: "from_json", + args: []interface{}{`{"name":"test"`}, + wantErr: true, + }, + { + name: "from_json non-string", + funcName: "from_json", + args: []interface{}{123}, + wantErr: true, + }, { name: "json_extract basic", funcName: "json_extract", args: []interface{}{`{"name":"test","value":123}`, "$.name"}, expected: "test", }, + { + name: "json_extract number", + funcName: "json_extract", + args: []interface{}{`{"name":"test","value":123}`, "$.value"}, + expected: float64(123), + }, + { + name: "json_extract invalid json", + funcName: "json_extract", + args: []interface{}{`{"name":"test"`, "$.name"}, + wantErr: true, + }, + { + name: "json_extract non-string json", + funcName: "json_extract", + args: []interface{}{123, "$.name"}, + wantErr: true, + }, + { + name: "json_extract non-string path", + funcName: "json_extract", + args: []interface{}{`{"name":"test"}`, 123}, + wantErr: true, + }, + { + name: "json_extract invalid path", + funcName: "json_extract", + args: []interface{}{`{"name":"test"}`, "invalid_path"}, + wantErr: true, + }, + { + name: "json_extract non-object", + funcName: "json_extract", + args: []interface{}{`[1,2,3]`, "$.name"}, + wantErr: true, + }, { name: "json_valid true", funcName: "json_valid", @@ -43,18 +109,102 @@ func TestJsonFunctions(t *testing.T) { args: []interface{}{`{"name":"test"`}, expected: false, }, + { + name: "json_valid non-string", + funcName: "json_valid", + args: []interface{}{123}, + expected: false, + }, { name: "json_type object", funcName: "json_type", args: []interface{}{`{"name":"test"}`}, expected: "object", }, + { + name: "json_type array", + funcName: "json_type", + args: []interface{}{`[1,2,3]`}, + expected: "array", + }, + { + name: "json_type string", + funcName: "json_type", + args: []interface{}{`"hello"`}, + expected: "string", + }, + { + name: "json_type number", + funcName: "json_type", + args: []interface{}{`123`}, + expected: "number", + }, + { + name: "json_type boolean", + funcName: "json_type", + args: []interface{}{`true`}, + expected: "boolean", + }, + { + name: "json_type null", + funcName: "json_type", + args: []interface{}{`null`}, + expected: "null", + }, + { + name: "json_type invalid", + funcName: "json_type", + args: []interface{}{`{"name":"test"`}, + expected: "invalid", + }, + { + name: "json_type non-string", + funcName: "json_type", + args: []interface{}{123}, + expected: "unknown", + }, { name: "json_length array", funcName: "json_length", args: []interface{}{`[1,2,3]`}, expected: 3, }, + { + name: "json_length object", + funcName: "json_length", + args: []interface{}{`{"a":1,"b":2}`}, + expected: 2, + }, + { + name: "json_length empty array", + funcName: "json_length", + args: []interface{}{`[]`}, + expected: 0, + }, + { + name: "json_length empty object", + funcName: "json_length", + args: []interface{}{`{}`}, + expected: 0, + }, + { + name: "json_length invalid json", + funcName: "json_length", + args: []interface{}{`{"name":"test"`}, + wantErr: true, + }, + { + name: "json_length non-string", + funcName: "json_length", + args: []interface{}{123}, + wantErr: true, + }, + { + name: "json_length string value", + funcName: "json_length", + args: []interface{}{`"hello"`}, + wantErr: true, + }, } for _, tt := range tests { @@ -77,6 +227,165 @@ func TestJsonFunctions(t *testing.T) { } } +// TestJsonFunctionValidation 测试JSON函数参数验证 +func TestJsonFunctionValidation(t *testing.T) { + tests := []struct { + name string + funcName string + args []interface{} + wantErr bool + }{ + { + name: "to_json no args", + funcName: "to_json", + args: []interface{}{}, + wantErr: true, + }, + { + name: "to_json too many args", + funcName: "to_json", + args: []interface{}{"test", "extra"}, + wantErr: true, + }, + { + name: "from_json no args", + funcName: "from_json", + args: []interface{}{}, + wantErr: true, + }, + { + name: "from_json too many args", + funcName: "from_json", + args: []interface{}{"test", "extra"}, + wantErr: true, + }, + { + name: "json_extract one arg", + funcName: "json_extract", + args: []interface{}{"test"}, + wantErr: true, + }, + { + name: "json_extract too many args", + funcName: "json_extract", + args: []interface{}{"test", "path", "extra"}, + wantErr: true, + }, + { + name: "json_valid no args", + funcName: "json_valid", + args: []interface{}{}, + wantErr: true, + }, + { + name: "json_type no args", + funcName: "json_type", + args: []interface{}{}, + wantErr: true, + }, + { + name: "json_length no args", + funcName: "json_length", + args: []interface{}{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, exists := Get(tt.funcName) + if !exists { + t.Fatalf("Function %s not found", tt.funcName) + } + + err := fn.Validate(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// TestJsonFunctionCreation 测试JSON函数创建 +func TestJsonFunctionCreation(t *testing.T) { + tests := []struct { + name string + constructor func() Function + expectedName string + }{ + { + name: "ToJsonFunction", + constructor: func() Function { return NewToJsonFunction() }, + expectedName: "to_json", + }, + { + name: "FromJsonFunction", + constructor: func() Function { return NewFromJsonFunction() }, + expectedName: "from_json", + }, + { + name: "JsonExtractFunction", + constructor: func() Function { return NewJsonExtractFunction() }, + expectedName: "json_extract", + }, + { + name: "JsonValidFunction", + constructor: func() Function { return NewJsonValidFunction() }, + expectedName: "json_valid", + }, + { + name: "JsonTypeFunction", + constructor: func() Function { return NewJsonTypeFunction() }, + expectedName: "json_type", + }, + { + name: "JsonLengthFunction", + constructor: func() Function { return NewJsonLengthFunction() }, + expectedName: "json_length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn := tt.constructor() + if fn == nil { + t.Error("Constructor returned nil") + return + } + + if fn.GetName() != tt.expectedName { + t.Errorf("Expected name %s, got %s", tt.expectedName, fn.GetName()) + } + + if fn.GetType() == "" { + t.Error("Function type should not be empty") + } + + if fn.GetCategory() == "" { + t.Error("Function category should not be empty") + } + + if fn.GetDescription() == "" { + t.Error("Function description should not be empty") + } + + // Test argument validation through Validate method + // Most JSON functions require exactly 1 argument, except json_extract which needs 2 + if tt.expectedName == "json_extract" { + err := fn.Validate([]interface{}{"test", "$.path"}) + if err != nil { + t.Errorf("Function %s should accept 2 arguments: %v", tt.expectedName, err) + } + } else if tt.expectedName != "json_length" { // json_length might have different requirements + err := fn.Validate([]interface{}{"test"}) + if err != nil { + t.Errorf("Function %s should accept 1 argument: %v", tt.expectedName, err) + } + } + }) + } +} + // 辅助函数:比较结果 func compareResults(a, b interface{}) bool { if a == nil && b == nil { @@ -93,7 +402,22 @@ func compareResults(a, b interface{}) bool { return false } for k, v := range mapA { - if mapB[k] != v { + if !compareResults(v, mapB[k]) { + return false + } + } + return true + } + } + + // 对于slice类型的特殊处理 + if sliceA, okA := a.([]interface{}); okA { + if sliceB, okB := b.([]interface{}); okB { + if len(sliceA) != len(sliceB) { + return false + } + for i, v := range sliceA { + if !compareResults(v, sliceB[i]) { return false } } diff --git a/functions/functions_multirow_test.go b/functions/functions_multirow_test.go index 06142c2..cd795de 100644 --- a/functions/functions_multirow_test.go +++ b/functions/functions_multirow_test.go @@ -5,6 +5,132 @@ import ( "testing" ) +func TestUnnestFunction(t *testing.T) { + fn := NewUnnestFunction() + ctx := &FunctionContext{} + + // 测试基本unnest功能 + args := []interface{}{[]interface{}{"a", "b", "c"}} + result, err := fn.Execute(ctx, args) + if err != nil { + t.Errorf("UnnestFunction should not return error: %v", err) + } + expected := []interface{}{"a", "b", "c"} + if !reflect.DeepEqual(result, expected) { + t.Errorf("UnnestFunction = %v, want %v", result, expected) + } + + // 测试对象数组unnest + args = []interface{}{ + []interface{}{ + map[string]interface{}{"name": "Alice", "age": 25}, + map[string]interface{}{"name": "Bob", "age": 30}, + }, + } + result, err = fn.Execute(ctx, args) + if err != nil { + t.Errorf("UnnestFunction should not return error: %v", err) + } + expected = []interface{}{ + map[string]interface{}{ + "__unnest_object__": true, + "__data__": map[string]interface{}{"name": "Alice", "age": 25}, + }, + map[string]interface{}{ + "__unnest_object__": true, + "__data__": map[string]interface{}{"name": "Bob", "age": 30}, + }, + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("UnnestFunction = %v, want %v", result, expected) + } + + // 测试空数组 + args = []interface{}{[]interface{}{}} + result, err = fn.Execute(ctx, args) + if err != nil { + t.Errorf("UnnestFunction should not return error for empty array: %v", err) + } + if len(result.([]interface{})) != 0 { + t.Errorf("UnnestFunction should return empty array for empty input") + } + + // 测试nil参数 + args = []interface{}{nil} + result, err = fn.Execute(ctx, args) + if err != nil { + t.Errorf("UnnestFunction should not return error for nil: %v", err) + } + if len(result.([]interface{})) != 0 { + t.Errorf("UnnestFunction should return empty array for nil input") + } + + // 测试错误参数数量 + args = []interface{}{} + err = fn.Validate(args) + if err == nil { + t.Errorf("UnnestFunction should return error for no arguments") + } + + // 测试非数组参数 + args = []interface{}{"not an array"} + _, err = fn.Execute(ctx, args) + if err == nil { + t.Errorf("UnnestFunction should return error for non-array argument") + } + + // 测试数组类型 + args = []interface{}{[3]string{"x", "y", "z"}} + result, err = fn.Execute(ctx, args) + if err != nil { + t.Errorf("UnnestFunction should handle arrays: %v", err) + } + expected = []interface{}{"x", "y", "z"} + if !reflect.DeepEqual(result, expected) { + t.Errorf("UnnestFunction array = %v, want %v", result, expected) + } +} + +// TestUnnestFunctionCreation 测试UnnestFunction创建 +func TestUnnestFunctionCreation(t *testing.T) { + fn := NewUnnestFunction() + if fn == nil { + t.Error("NewUnnestFunction should not return nil") + } + + if fn.GetName() != "unnest" { + t.Errorf("Expected name 'unnest', got %s", fn.GetName()) + } + + // Test argument validation through Validate method + err := fn.Validate([]interface{}{"test"}) + if err != nil { + t.Errorf("Validate should accept 1 argument: %v", err) + } + + err = fn.Validate([]interface{}{}) + if err == nil { + t.Error("Validate should reject 0 arguments") + } + + err = fn.Validate([]interface{}{"arg1", "arg2"}) + if err == nil { + t.Error("Validate should reject 2 arguments") + } + + if fn.GetType() == "" { + t.Error("Function type should not be empty") + } + + if fn.GetCategory() == "" { + t.Error("Function category should not be empty") + } + + if fn.GetDescription() == "" { + t.Error("Function description should not be empty") + } +} + func TestIsUnnestResult(t *testing.T) { // 测试非unnest结果 normalSlice := []interface{}{"a", "b", "c"}