package functions import ( "fmt" "math" "testing" "github.com/rulego/streamsql/utils/cast" ) func TestAggregatorFunctionInterface(t *testing.T) { // Test Sum aggregation function sumFunc := NewSumFunction() // Test creating new instance aggInstance := sumFunc.New() if aggInstance == nil { t.Fatal("Failed to create new aggregator instance") } // Test incremental calculation aggInstance.Add(10.0) aggInstance.Add(20.0) aggInstance.Add(30.0) result := aggInstance.Result() if result != 60.0 { t.Errorf("Expected 60.0, got %v", result) } // Test reset aggInstance.Reset() result = aggInstance.Result() if result != nil { t.Errorf("Expected nil after reset (SQL standard: SUM with no rows returns NULL), got %v", result) } // Test clone aggInstance.Add(15.0) cloned := aggInstance.Clone() cloned.Add(25.0) originalResult := aggInstance.Result() clonedResult := cloned.Result() if originalResult != 15.0 { t.Errorf("Expected original result 15.0, got %v", originalResult) } if clonedResult != 40.0 { t.Errorf("Expected cloned result 40.0, got %v", clonedResult) } } func TestAnalyticalFunctionInterface(t *testing.T) { // Test Lag analytical function lagFunc := NewLagFunction() ctx := &FunctionContext{ Data: make(map[string]interface{}), } // Test first value (should return default value nil) result, err := lagFunc.Execute(ctx, []interface{}{10}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if result != nil { t.Errorf("Expected nil for first value, got %v", result) } // Test second value (should return first value) result, err = lagFunc.Execute(ctx, []interface{}{20}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if result != 10 { t.Errorf("Expected 10, got %v", result) } // Test clone cloned := lagFunc.Clone() if cloned == nil { t.Fatal("Failed to clone analytical function") } // Test reset lagFunc.Reset() result, err = lagFunc.Execute(ctx, []interface{}{30}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if result != nil { t.Errorf("Expected nil after reset, got %v", result) } } func TestCreateAggregator(t *testing.T) { // Test creating registered aggregator aggFunc, err := CreateAggregator("sum") if err != nil { t.Fatalf("Failed to create sum aggregator: %v", err) } if aggFunc == nil { t.Fatal("Created aggregator is nil") } // Test creating non-existent aggregator _, err = CreateAggregator("nonexistent") if err == nil { t.Error("Expected error for nonexistent aggregator") } } func TestCreateAnalytical(t *testing.T) { // Test creating registered analytical function analFunc, err := CreateAnalytical("lag") if err != nil { t.Fatalf("Failed to create lag analytical function: %v", err) } if analFunc == nil { t.Fatal("Created analytical function is nil") } // Test creating non-existent analytical function _, err = CreateAnalytical("nonexistent") if err == nil { t.Error("Expected error for nonexistent analytical function") } } func TestAggregatorAdapter(t *testing.T) { // 测试聚合器适配器 adapter, err := NewAggregatorAdapter("sum") if err != nil { t.Fatalf("Failed to create aggregator adapter: %v", err) } // 测试创建新实例 newInstance := adapter.New() if newInstance == nil { t.Fatal("Failed to create new adapter instance") } newAdapter, ok := newInstance.(*AggregatorAdapter) if !ok { t.Fatal("New instance is not an AggregatorAdapter") } // 测试添加值和获取结果 newAdapter.Add(5.0) newAdapter.Add(10.0) result := newAdapter.Result() if result != 15.0 { t.Errorf("Expected 15.0, got %v", result) } } func TestAnalyticalAdapter(t *testing.T) { // 测试分析函数适配器 adapter, err := NewAnalyticalAdapter("latest") if err != nil { t.Fatalf("Failed to create analytical adapter: %v", err) } ctx := &FunctionContext{ Data: make(map[string]interface{}), } // 测试执行 result, err := adapter.Execute(ctx, []interface{}{"test_value"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if result != "test_value" { t.Errorf("Expected 'test_value', got %v", result) } // 测试克隆 cloned := adapter.Clone() if cloned == nil { t.Fatal("Failed to clone analytical adapter") } // 测试重置 adapter.Reset() } func TestCustomFunctionRegistration(t *testing.T) { // 注册自定义函数示例 RegisterCustomFunctions() // 测试自定义聚合函数 productFunc, exists := Get("product") if !exists { t.Fatal("Custom product function not registered") } if productFunc.GetType() != TypeAggregation { t.Error("Product function should be aggregation type") } // 测试自定义分析函数 movingAvgFunc, exists := Get("moving_avg") if !exists { t.Fatal("Custom moving average function not registered") } if movingAvgFunc.GetType() != TypeAnalytical { t.Error("Moving average function should be analytical type") } // 测试简单自定义函数 doubleFunc, exists := Get("double") if !exists { t.Fatal("Custom double function not registered") } ctx := &FunctionContext{ Data: make(map[string]interface{}), } result, err := doubleFunc.Execute(ctx, []interface{}{5.0}) if err != nil { t.Fatalf("Unexpected error: %v", err) } if result != 10.0 { t.Errorf("Expected 10.0, got %v", result) } } func TestFunctionRegistry(t *testing.T) { // 测试函数注册表 allFunctions := ListAll() if len(allFunctions) == 0 { t.Error("No functions registered") } // 测试按类型获取函数 aggFunctions := GetByType(TypeAggregation) if len(aggFunctions) == 0 { t.Error("No aggregation functions found") } analFunctions := GetByType(TypeAnalytical) if len(analFunctions) == 0 { t.Error("No analytical functions found") } // 验证一些内置函数存在 expectedFunctions := []string{"sum", "avg", "min", "max", "count", "lag", "latest"} for _, funcName := range expectedFunctions { if _, exists := Get(funcName); !exists { t.Errorf("Expected function %s not found", funcName) } } } func BenchmarkAggregatorIncremental(b *testing.B) { sumFunc := NewSumFunction() aggInstance := sumFunc.New() b.ResetTimer() for i := 0; i < b.N; i++ { aggInstance.Add(float64(i)) } _ = aggInstance.Result() } func BenchmarkAggregatorBatch(b *testing.B) { sumFunc := NewSumFunction() ctx := &FunctionContext{ Data: make(map[string]interface{}), } // 准备测试数据 args := make([]interface{}, b.N) for i := 0; i < b.N; i++ { args[i] = float64(i) } b.ResetTimer() _, _ = sumFunc.Execute(ctx, args) } // CustomProductFunction 自定义乘积聚合函数示例 type CustomProductFunction struct { *BaseFunction product float64 first bool } func NewCustomProductFunction() *CustomProductFunction { return &CustomProductFunction{ BaseFunction: NewBaseFunction("product", TypeAggregation, "自定义聚合函数", "计算数值乘积", 1, -1), product: 1.0, first: true, } } func (f *CustomProductFunction) Validate(args []interface{}) error { return f.ValidateArgCount(args) } func (f *CustomProductFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { product := 1.0 for _, arg := range args { val, err := cast.ToFloat64E(arg) if err != nil { return nil, err } product *= val } return product, nil } // 实现AggregatorFunction接口 func (f *CustomProductFunction) New() AggregatorFunction { return &CustomProductFunction{ BaseFunction: f.BaseFunction, product: 1.0, first: true, } } func (f *CustomProductFunction) Add(value interface{}) { if val, err := cast.ToFloat64E(value); err == nil { if f.first { f.product = val f.first = false } else { f.product *= val } } } func (f *CustomProductFunction) Result() interface{} { if f.first { return 0.0 } return f.product } func (f *CustomProductFunction) Reset() { f.product = 1.0 f.first = true } func (f *CustomProductFunction) Clone() AggregatorFunction { return &CustomProductFunction{ BaseFunction: f.BaseFunction, product: f.product, first: f.first, } } // CustomMovingAverageFunction 自定义移动平均分析函数示例 type CustomMovingAverageFunction struct { *BaseFunction values []float64 windowSize int } func NewCustomMovingAverageFunction(windowSize int) *CustomMovingAverageFunction { return &CustomMovingAverageFunction{ BaseFunction: NewBaseFunction("moving_avg", TypeAnalytical, "自定义分析函数", fmt.Sprintf("计算窗口大小为%d的移动平均", windowSize), 1, 1), windowSize: windowSize, values: make([]float64, 0), } } func (f *CustomMovingAverageFunction) Validate(args []interface{}) error { return f.ValidateArgCount(args) } func (f *CustomMovingAverageFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { val, err := cast.ToFloat64E(args[0]) if err != nil { return nil, err } // 添加新值 f.values = append(f.values, val) // 保持窗口大小 if len(f.values) > f.windowSize { f.values = f.values[1:] } // 计算移动平均 sum := 0.0 for _, v := range f.values { sum += v } return sum / float64(len(f.values)), nil } // 实现AnalyticalFunction接口 func (f *CustomMovingAverageFunction) Reset() { f.values = make([]float64, 0) } // 实现AggregatorFunction接口 - 增量计算支持 func (f *CustomMovingAverageFunction) New() AggregatorFunction { return &CustomMovingAverageFunction{ BaseFunction: f.BaseFunction, windowSize: f.windowSize, values: make([]float64, 0), } } func (f *CustomMovingAverageFunction) Add(value interface{}) { if val, err := cast.ToFloat64E(value); err == nil { // 添加新值 f.values = append(f.values, val) // 保持窗口大小 if len(f.values) > f.windowSize { f.values = f.values[1:] } } } func (f *CustomMovingAverageFunction) Result() interface{} { if len(f.values) == 0 { return 0.0 } // 计算移动平均 sum := 0.0 for _, v := range f.values { sum += v } return sum / float64(len(f.values)) } func (f *CustomMovingAverageFunction) Clone() AggregatorFunction { clone := &CustomMovingAverageFunction{ BaseFunction: f.BaseFunction, windowSize: f.windowSize, values: make([]float64, len(f.values)), } copy(clone.values, f.values) return clone } // CustomGeometricMeanFunction 自定义几何平均聚合函数示例 type CustomGeometricMeanFunction struct { *BaseFunction product float64 count int } func NewCustomGeometricMeanFunction() *CustomGeometricMeanFunction { return &CustomGeometricMeanFunction{ BaseFunction: NewBaseFunction("geomean", TypeAggregation, "自定义聚合函数", "计算几何平均数", 1, -1), product: 1.0, count: 0, } } func (f *CustomGeometricMeanFunction) Validate(args []interface{}) error { return f.ValidateArgCount(args) } func (f *CustomGeometricMeanFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { product := 1.0 for _, arg := range args { val, err := cast.ToFloat64E(arg) if err != nil { return nil, err } if val <= 0 { return nil, fmt.Errorf("geometric mean requires positive values") } product *= val } return math.Pow(product, 1.0/float64(len(args))), nil } // 实现AggregatorFunction接口 func (f *CustomGeometricMeanFunction) New() AggregatorFunction { return &CustomGeometricMeanFunction{ BaseFunction: f.BaseFunction, product: 1.0, count: 0, } } func (f *CustomGeometricMeanFunction) Add(value interface{}) { if val, err := cast.ToFloat64E(value); err == nil && val > 0 { f.product *= val f.count++ } } func (f *CustomGeometricMeanFunction) Result() interface{} { if f.count == 0 { return 0.0 } return math.Pow(f.product, 1.0/float64(f.count)) } func (f *CustomGeometricMeanFunction) Reset() { f.product = 1.0 f.count = 0 } func (f *CustomGeometricMeanFunction) Clone() AggregatorFunction { return &CustomGeometricMeanFunction{ BaseFunction: f.BaseFunction, product: f.product, count: f.count, } } // RegisterCustomFunctions 注册自定义函数的示例 func RegisterCustomFunctions() { // 注册自定义聚合函数 _ = Register(NewCustomProductFunction()) _ = Register(NewCustomGeometricMeanFunction()) // 注册自定义分析函数 _ = Register(NewCustomMovingAverageFunction(5)) // 5个值的移动平均 // 注册适配器 RegisterAggregatorAdapter("product") RegisterAggregatorAdapter("geomean") RegisterAnalyticalAdapter("moving_avg") // 使用RegisterCustomFunction的方式注册简单函数 RegisterCustomFunction("double", TypeAggregation, "自定义函数", "将值乘以2", 1, 1, func(ctx *FunctionContext, args []interface{}) (interface{}, error) { val, err := cast.ToFloat64E(args[0]) if err != nil { return nil, err } return val * 2, nil }) }