mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-09 17:21:32 +00:00
328 lines
7.7 KiB
Go
328 lines
7.7 KiB
Go
package functions
|
||
|
||
import (
|
||
"testing"
|
||
)
|
||
|
||
// isWindowFunction reports whether funcName is one of the window-function
// names that TestNewWindowFunctions checks through the aggregator's Result
// method only, without calling Execute.
//
// The match is exact and case-sensitive; any other name, including the empty
// string, yields false.
func isWindowFunction(funcName string) bool {
	// A switch over the constant names avoids allocating and populating a
	// map on every call.
	switch funcName {
	case "row_number",
		"window_start",
		"window_end",
		"lead",
		"lag",
		"first_value",
		"last_value":
		return true
	default:
		return false
	}
}
|
||
|
||
func TestNewWindowFunctions(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
funcName string
|
||
args []interface{}
|
||
want interface{}
|
||
wantErr bool
|
||
setup func(fn AggregatorFunction)
|
||
}{
|
||
// first_value 函数测试
|
||
{
|
||
name: "first_value basic",
|
||
funcName: "first_value",
|
||
args: []interface{}{"test"},
|
||
want: "first",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "first_value empty",
|
||
funcName: "first_value",
|
||
args: []interface{}{"test"},
|
||
want: nil,
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
|
||
// last_value 函数测试
|
||
{
|
||
name: "last_value basic",
|
||
funcName: "last_value",
|
||
args: []interface{}{"test"},
|
||
want: "third",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "last_value empty",
|
||
funcName: "last_value",
|
||
args: []interface{}{"test"},
|
||
want: nil,
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
|
||
// lag 函数测试
|
||
{
|
||
name: "lag default offset",
|
||
funcName: "lag",
|
||
args: []interface{}{"test"},
|
||
want: "second",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "lag with offset 2",
|
||
funcName: "lag",
|
||
args: []interface{}{"test", 2},
|
||
want: "first",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "lag with default value",
|
||
funcName: "lag",
|
||
args: []interface{}{"test", 5, "default"},
|
||
want: "default",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
},
|
||
},
|
||
{
|
||
name: "lag invalid offset type",
|
||
funcName: "lag",
|
||
args: []interface{}{"test", "invalid"},
|
||
wantErr: true,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
|
||
// lead 函数测试
|
||
{
|
||
name: "lead default offset",
|
||
funcName: "lead",
|
||
args: []interface{}{"test"},
|
||
want: nil, // Lead函数简化实现返回nil
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "lead with default value",
|
||
funcName: "lead",
|
||
args: []interface{}{"test", 1, "default"},
|
||
want: "default",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
{
|
||
name: "lead invalid offset type",
|
||
funcName: "lead",
|
||
args: []interface{}{"test", "invalid"},
|
||
wantErr: true,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
|
||
// nth_value 函数测试
|
||
{
|
||
name: "nth_value first",
|
||
funcName: "nth_value",
|
||
args: []interface{}{"test", 1},
|
||
want: "first",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "nth_value second",
|
||
funcName: "nth_value",
|
||
args: []interface{}{"test", 2},
|
||
want: "second",
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
fn.Add("third")
|
||
},
|
||
},
|
||
{
|
||
name: "nth_value out of range",
|
||
funcName: "nth_value",
|
||
args: []interface{}{"test", 5},
|
||
want: nil,
|
||
wantErr: false,
|
||
setup: func(fn AggregatorFunction) {
|
||
fn.Add("first")
|
||
fn.Add("second")
|
||
},
|
||
},
|
||
{
|
||
name: "nth_value invalid n type",
|
||
funcName: "nth_value",
|
||
args: []interface{}{"test", "invalid"},
|
||
wantErr: true,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
{
|
||
name: "nth_value zero n",
|
||
funcName: "nth_value",
|
||
args: []interface{}{"test", 0},
|
||
wantErr: true,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
{
|
||
name: "nth_value negative n",
|
||
funcName: "nth_value",
|
||
args: []interface{}{"test", -1},
|
||
wantErr: true,
|
||
setup: func(fn AggregatorFunction) {},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
fn, exists := Get(tt.funcName)
|
||
if !exists {
|
||
t.Fatalf("Function %s not found", tt.funcName)
|
||
}
|
||
|
||
// 检查函数是否实现了AggregatorFunction接口
|
||
aggFn, ok := fn.(AggregatorFunction)
|
||
if !ok {
|
||
t.Fatalf("Function %s does not implement AggregatorFunction", tt.funcName)
|
||
}
|
||
|
||
// 先执行函数的Validate方法来设置参数
|
||
err := fn.Validate(tt.args)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||
return
|
||
}
|
||
|
||
// 如果期望错误且Validate已经失败,则测试通过
|
||
if tt.wantErr && err != nil {
|
||
return
|
||
}
|
||
|
||
// 创建新的聚合器实例
|
||
aggInstance := aggFn.New()
|
||
|
||
// 执行setup函数添加测试数据
|
||
if tt.setup != nil {
|
||
tt.setup(aggInstance)
|
||
}
|
||
|
||
// 对于窗口函数测试,不需要调用Execute方法
|
||
// Execute方法主要用于流式处理,这里我们直接测试聚合器的Result方法
|
||
// 如果需要测试Execute方法,应该在原始函数实例上调用
|
||
if !isWindowFunction(tt.funcName) {
|
||
// 对于非窗口函数,在聚合器实例上执行
|
||
if aggFunc, ok := aggInstance.(Function); ok {
|
||
_, err = aggFunc.Execute(nil, tt.args)
|
||
} else {
|
||
// 执行函数
|
||
_, err = fn.Execute(nil, tt.args)
|
||
}
|
||
}
|
||
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||
return
|
||
}
|
||
|
||
if !tt.wantErr {
|
||
// 对于窗口函数,我们主要测试聚合器的Result方法
|
||
aggResult := aggInstance.Result()
|
||
if tt.want != nil && aggResult != tt.want {
|
||
t.Errorf("AggregatorFunction.Result() = %v, want %v", aggResult, tt.want)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// 测试窗口函数的基本功能
|
||
func TestWindowFunctionBasics(t *testing.T) {
|
||
// 测试row_number函数
|
||
t.Run("RowNumberFunction", func(t *testing.T) {
|
||
rowNumFunc, exists := Get("row_number")
|
||
if !exists {
|
||
t.Fatal("row_number function not found")
|
||
}
|
||
|
||
// 重置函数状态
|
||
if rowNum, ok := rowNumFunc.(*RowNumberFunction); ok {
|
||
rowNum.Reset()
|
||
}
|
||
|
||
// 测试行号递增
|
||
result1, err := rowNumFunc.Execute(nil, []interface{}{})
|
||
if err != nil {
|
||
t.Errorf("Execute() error = %v", err)
|
||
}
|
||
if result1 != int64(1) {
|
||
t.Errorf("First call should return 1, got %v", result1)
|
||
}
|
||
|
||
result2, err := rowNumFunc.Execute(nil, []interface{}{})
|
||
if err != nil {
|
||
t.Errorf("Execute() error = %v", err)
|
||
}
|
||
if result2 != int64(2) {
|
||
t.Errorf("Second call should return 2, got %v", result2)
|
||
}
|
||
})
|
||
|
||
// 测试window_start和window_end函数
|
||
t.Run("WindowStartEndFunctions", func(t *testing.T) {
|
||
windowStartFunc, exists := Get("window_start")
|
||
if !exists {
|
||
t.Fatal("window_start function not found")
|
||
}
|
||
|
||
windowEndFunc, exists := Get("window_end")
|
||
if !exists {
|
||
t.Fatal("window_end function not found")
|
||
}
|
||
|
||
// 测试无窗口信息时的行为
|
||
ctx := &FunctionContext{
|
||
Data: map[string]interface{}{},
|
||
}
|
||
_, err := windowStartFunc.Execute(ctx, []interface{}{})
|
||
if err != nil {
|
||
t.Errorf("Execute() error = %v", err)
|
||
}
|
||
// 无窗口信息时应该返回nil或默认值
|
||
|
||
_, err = windowEndFunc.Execute(ctx, []interface{}{})
|
||
if err != nil {
|
||
t.Errorf("Execute() error = %v", err)
|
||
}
|
||
// 无窗口信息时应该返回nil或默认值
|
||
})
|
||
} |