Files
streamsql/functions/aggregator_interface_test.go
T
2025-08-30 00:56:48 +08:00

362 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package functions
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockParameterizedAggregator is a test double that carries the method sets
// of Function, AggregatorFunction and ParameterizedFunction. It records how it
// was driven so tests can inspect the calls afterwards.
type mockParameterizedAggregator struct {
	name        string        // returned by GetName
	args        []interface{} // arguments captured by Init
	initialized bool          // set to true once Init has run
	addedValues []interface{} // every value passed to Add, in call order
	result      interface{}   // returned by Execute; when non-nil also returned by Result
}

// Function methods.

// GetName returns the name the mock was constructed with.
func (m *mockParameterizedAggregator) GetName() string {
	return m.name
}

// GetMinArgs reports the fixed minimum argument count of 1.
func (m *mockParameterizedAggregator) GetMinArgs() int {
	return 1
}

// GetMaxArgs reports the fixed maximum argument count of 3.
func (m *mockParameterizedAggregator) GetMaxArgs() int {
	return 3
}

// GetType always returns TypeAggregation.
func (m *mockParameterizedAggregator) GetType() FunctionType {
	return TypeAggregation
}

// GetCategory always returns "test".
func (m *mockParameterizedAggregator) GetCategory() string {
	return "test"
}

// GetAliases returns an empty, non-nil alias list.
func (m *mockParameterizedAggregator) GetAliases() []string {
	return []string{}
}

// Validate accepts any argument list and never returns an error.
func (m *mockParameterizedAggregator) Validate(args []interface{}) error {
	return nil
}

// Execute ignores its inputs and returns the stored result with a nil error.
func (m *mockParameterizedAggregator) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	return m.result, nil
}

// GetDescription returns a fixed description string.
func (m *mockParameterizedAggregator) GetDescription() string {
	return "Mock parameterized aggregator"
}

// AggregatorFunction methods.

// New returns a fresh mock that shares only the name with the receiver; args,
// initialized and result start at their zero values.
func (m *mockParameterizedAggregator) New() AggregatorFunction {
	return &mockParameterizedAggregator{
		name:        m.name,
		addedValues: make([]interface{}, 0),
	}
}

// Add records value at the end of addedValues.
func (m *mockParameterizedAggregator) Add(value interface{}) {
	m.addedValues = append(m.addedValues, value)
}

// Result returns the stored result when one is set; otherwise it returns the
// number of values added so far.
func (m *mockParameterizedAggregator) Result() interface{} {
	if m.result == nil {
		// No fixed result: fall back to the count of added values.
		return len(m.addedValues)
	}
	return m.result
}

// Reset drops all added values and clears the stored result.
func (m *mockParameterizedAggregator) Reset() {
	m.addedValues = make([]interface{}, 0)
	m.result = nil
}

// Clone returns a copy of the receiver. The added values are copied into a
// new slice so that later Add calls on either instance do not affect the
// other; args is shared with the receiver.
func (m *mockParameterizedAggregator) Clone() AggregatorFunction {
	values := make([]interface{}, len(m.addedValues))
	// Without this copy the clone would hold only nil placeholders.
	copy(values, m.addedValues)
	return &mockParameterizedAggregator{
		name:        m.name,
		args:        m.args,
		initialized: m.initialized,
		addedValues: values,
		result:      m.result,
	}
}

// ParameterizedFunction methods.

// Init stores args, marks the mock as initialized and, when the first
// argument is an int, sets result to that value times 10. It never fails.
func (m *mockParameterizedAggregator) Init(args []interface{}) error {
	m.args = args
	m.initialized = true
	if len(args) == 0 {
		return nil
	}
	// Derive a predictable result from the first argument so tests can
	// verify that Init really received the arguments.
	if first, ok := args[0].(int); ok {
		m.result = first * 10
	}
	return nil
}
// mockSimpleAggregator is a test double that carries the Function and
// AggregatorFunction method sets but deliberately has no Init method, so it
// does not provide the ParameterizedFunction method set.
type mockSimpleAggregator struct {
	name   string        // returned by GetName
	values []interface{} // every value passed to Add, in call order
}

// Function methods.

// GetName returns the name the mock was constructed with.
func (m *mockSimpleAggregator) GetName() string {
	return m.name
}

// GetMinArgs reports the fixed minimum argument count of 1.
func (m *mockSimpleAggregator) GetMinArgs() int {
	return 1
}

// GetMaxArgs reports the fixed maximum argument count of 1.
func (m *mockSimpleAggregator) GetMaxArgs() int {
	return 1
}

// GetType always returns TypeAggregation.
func (m *mockSimpleAggregator) GetType() FunctionType {
	return TypeAggregation
}

// GetCategory always returns "test".
func (m *mockSimpleAggregator) GetCategory() string {
	return "test"
}

// GetAliases returns an empty, non-nil alias list.
func (m *mockSimpleAggregator) GetAliases() []string {
	return []string{}
}

// Validate accepts any argument list and never returns an error.
func (m *mockSimpleAggregator) Validate(args []interface{}) error {
	return nil
}

// Execute ignores its inputs and returns the number of values added so far.
func (m *mockSimpleAggregator) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	return len(m.values), nil
}

// GetDescription returns a fixed description string.
func (m *mockSimpleAggregator) GetDescription() string {
	return "Mock simple aggregator"
}

// AggregatorFunction methods.

// New returns a fresh mock with the receiver's name and no values.
func (m *mockSimpleAggregator) New() AggregatorFunction {
	return &mockSimpleAggregator{
		name:   m.name,
		values: make([]interface{}, 0),
	}
}

// Add records value at the end of values.
func (m *mockSimpleAggregator) Add(value interface{}) {
	m.values = append(m.values, value)
}

// Result returns the number of values added so far.
func (m *mockSimpleAggregator) Result() interface{} {
	return len(m.values)
}

// Reset drops all added values.
func (m *mockSimpleAggregator) Reset() {
	m.values = make([]interface{}, 0)
}

// Clone returns a copy of the receiver. The values are copied into a new
// slice so that later Add calls on either instance do not affect the other.
func (m *mockSimpleAggregator) Clone() AggregatorFunction {
	values := make([]interface{}, len(m.values))
	// Without this copy the clone would hold only nil placeholders.
	copy(values, m.values)
	return &mockSimpleAggregator{
		name:   m.name,
		values: values,
	}
}
// TestCreateParameterizedAggregator covers CreateParameterizedAggregator:
// creation with Init, the path for aggregators that have no Init, unknown
// names, a failing Init, an empty argument list, and use of the created
// aggregator.
func TestCreateParameterizedAggregator(t *testing.T) {
	t.Run("测试参数化聚合器创建和初始化", func(t *testing.T) {
		// Register the parameterized mock used by this subtest.
		mockParamAgg := &mockParameterizedAggregator{name: "test_param_agg"}
		err := Register(mockParamAgg)
		require.NoError(t, err)
		// Cleanup runs even when a require below aborts the subtest, so the
		// registration never outlives it.
		t.Cleanup(func() { Unregister("test_param_agg") })

		args := []interface{}{5, "test", 3.14}
		aggregator, err := CreateParameterizedAggregator("test_param_agg", args)
		require.NoError(t, err)
		require.NotNil(t, aggregator)

		// The returned aggregator must be the mock type and must have been
		// through Init with exactly these arguments.
		paramAgg, ok := aggregator.(*mockParameterizedAggregator)
		require.True(t, ok)
		assert.True(t, paramAgg.initialized)
		assert.Equal(t, args, paramAgg.args)
		assert.Equal(t, 50, paramAgg.result) // mock Init stores args[0] * 10: 5 * 10 = 50
	})
	t.Run("测试非参数化聚合器的回退处理", func(t *testing.T) {
		// Register an aggregator that has no Init method.
		mockSimpleAgg := &mockSimpleAggregator{name: "test_simple_agg"}
		err := Register(mockSimpleAgg)
		require.NoError(t, err)
		t.Cleanup(func() { Unregister("test_simple_agg") })

		// Creation is expected to succeed even though the arguments cannot be
		// handed to the aggregator.
		args := []interface{}{1, 2, 3}
		aggregator, err := CreateParameterizedAggregator("test_simple_agg", args)
		require.NoError(t, err)
		require.NotNil(t, aggregator)

		simpleAgg, ok := aggregator.(*mockSimpleAggregator)
		require.True(t, ok)
		assert.Equal(t, "test_simple_agg", simpleAgg.name)
	})
	t.Run("测试不存在的聚合器函数", func(t *testing.T) {
		// Nothing is registered under this name.
		args := []interface{}{1, 2, 3}
		aggregator, err := CreateParameterizedAggregator("non_existent_agg", args)
		// require rather than assert: err.Error() below would panic on nil.
		require.Error(t, err)
		assert.Nil(t, aggregator)
		assert.Contains(t, err.Error(), "aggregator function non_existent_agg not found")
	})
	t.Run("测试参数化聚合器初始化失败", func(t *testing.T) {
		// Register a mock whose Init always returns an error.
		failingAgg := &mockParameterizedAggregatorWithFailingInit{
			mockParameterizedAggregator: mockParameterizedAggregator{name: "failing_agg"},
		}
		err := Register(failingAgg)
		require.NoError(t, err)
		t.Cleanup(func() { Unregister("failing_agg") })

		// The Init error must surface from creation.
		args := []interface{}{"invalid"}
		aggregator, err := CreateParameterizedAggregator("failing_agg", args)
		// require rather than assert: err.Error() below would panic on nil.
		require.Error(t, err)
		assert.Nil(t, aggregator)
		assert.Contains(t, err.Error(), "failed to initialize parameterized function")
	})
	t.Run("测试空参数列表", func(t *testing.T) {
		mockParamAgg := &mockParameterizedAggregator{name: "test_empty_args"}
		err := Register(mockParamAgg)
		require.NoError(t, err)
		t.Cleanup(func() { Unregister("test_empty_args") })

		// An empty, non-nil argument list must still reach Init.
		args := []interface{}{}
		aggregator, err := CreateParameterizedAggregator("test_empty_args", args)
		require.NoError(t, err)
		require.NotNil(t, aggregator)

		paramAgg, ok := aggregator.(*mockParameterizedAggregator)
		require.True(t, ok)
		assert.True(t, paramAgg.initialized)
		assert.Equal(t, args, paramAgg.args)
	})
	t.Run("测试聚合器功能", func(t *testing.T) {
		mockParamAgg := &mockParameterizedAggregator{name: "test_functionality"}
		err := Register(mockParamAgg)
		require.NoError(t, err)
		t.Cleanup(func() { Unregister("test_functionality") })

		args := []interface{}{3}
		aggregator, err := CreateParameterizedAggregator("test_functionality", args)
		require.NoError(t, err)
		require.NotNil(t, aggregator)

		// Drive the aggregator through its Add/Result methods.
		aggregator.Add(10)
		aggregator.Add(20)
		aggregator.Add(30)

		// The mock reports the result fixed by Init (3 * 10 = 30), not a
		// value derived from the added inputs.
		result := aggregator.Result()
		assert.Equal(t, 30, result)
	})
}
// mockParameterizedAggregatorWithFailingInit wraps mockParameterizedAggregator
// and overrides Init so that initialization always fails.
type mockParameterizedAggregatorWithFailingInit struct {
	mockParameterizedAggregator
}

// Init ignores args and always reports an initialization error.
func (m *mockParameterizedAggregatorWithFailingInit) Init(args []interface{}) error {
	initErr := errors.New("initialization failed")
	return initErr
}

// New builds a fresh failing mock that carries over the receiver's name and
// result and starts with an empty, non-nil list of added values.
func (m *mockParameterizedAggregatorWithFailingInit) New() AggregatorFunction {
	fresh := new(mockParameterizedAggregatorWithFailingInit)
	// name and result are reached through the embedded mock's promoted fields.
	fresh.name = m.name
	fresh.result = m.result
	fresh.addedValues = []interface{}{}
	return fresh
}
// TestCreateAggregatorInterface covers CreateAggregator for a registered
// aggregator name and for a name that was never registered.
func TestCreateAggregatorInterface(t *testing.T) {
	t.Run("测试创建存在的聚合器", func(t *testing.T) {
		// Register the mock aggregator used by this subtest.
		mockAgg := &mockSimpleAggregator{name: "test_create_agg"}
		err := Register(mockAgg)
		require.NoError(t, err)
		// Cleanup runs even when a require below aborts the subtest, so the
		// registration never outlives it.
		t.Cleanup(func() { Unregister("test_create_agg") })

		// Creation must succeed and hand back the mock's concrete type.
		aggregator, err := CreateAggregator("test_create_agg")
		require.NoError(t, err)
		require.NotNil(t, aggregator)
		assert.IsType(t, &mockSimpleAggregator{}, aggregator)
	})
	t.Run("测试创建不存在的聚合器", func(t *testing.T) {
		aggregator, err := CreateAggregator("non_existent")
		// require rather than assert: err.Error() below would panic on nil.
		require.Error(t, err)
		assert.Nil(t, aggregator)
		assert.Contains(t, err.Error(), "aggregator function non_existent not found")
	})
}
// TestIsAggregatorFunction 测试IsAggregatorFunction函数
func TestIsAggregatorFunction(t *testing.T) {
t.Run("测试已注册的聚合器函数", func(t *testing.T) {
// 注册测试聚合器
mockAgg := &mockSimpleAggregator{name: "test_is_agg"}
err := Register(mockAgg)
require.NoError(t, err)
// 测试IsAggregatorFunction
isAgg := IsAggregatorFunction("test_is_agg")
assert.True(t, isAgg)
// 清理
Unregister("test_is_agg")
})
t.Run("测试不存在的函数", func(t *testing.T) {
isAgg := IsAggregatorFunction("non_existent_function")
assert.False(t, isAgg)
})
}