Files
streamsql/functions/functions_test.go
T
2025-08-08 09:00:02 +08:00

908 lines
32 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package functions
import (
"math"
"testing"
"time"
"github.com/rulego/streamsql/utils/cast"
"github.com/stretchr/testify/assert"
)
// TestBasicFunctionRegistry checks that a representative set of built-in
// functions is registered under its canonical name with the expected
// FunctionType, and that looking up an unknown name reports absence.
func TestBasicFunctionRegistry(t *testing.T) {
	// Basic registration checks.
	tests := []struct {
		name         string
		functionName string
		expectedType FunctionType
	}{
		{"abs function", "abs", TypeMath},
		{"concat function", "concat", TypeString},
		{"sqrt function", "sqrt", TypeMath},
		{"upper function", "upper", TypeString},
		{"cast function", "cast", TypeConversion},
		{"now function", "now", TypeDateTime},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			// Stop this subtest on a failed lookup: calling methods on a nil
			// function value would panic and abort the whole test binary.
			if !assert.True(t, exists, "%s should be registered", tt.functionName) || !assert.NotNil(t, fn) {
				return
			}
			assert.Equal(t, tt.functionName, fn.GetName())
			assert.Equal(t, tt.expectedType, fn.GetType())
		})
	}
	// An unregistered name must not resolve.
	_, exists := Get("nonexistent")
	assert.False(t, exists, "nonexistent function should not be found")
}
// TestNewMathFunctions exercises the logarithmic, modulo, rounding, sign and
// trigonometric functions against a table of inputs, plus a range check for
// rand, whose output is random.
func TestNewMathFunctions(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}
	// Table-driven cases.
	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expected     interface{}
		expectError  bool
		errorMsg     string
		delta        float64 // tolerance for float comparison; 0 means exact equality
	}{
		// Log function tests
		{"log valid", "log", []interface{}{10.0}, 1.0, false, "", 1e-10},
		{"log negative", "log", []interface{}{-1}, nil, true, "value must be positive", 0},
		{"log zero", "log", []interface{}{0}, nil, true, "value must be positive", 0},
		// Log10 function tests
		{"log10 100", "log10", []interface{}{100}, 2.0, false, "", 1e-10},
		{"log10 10", "log10", []interface{}{10}, 1.0, false, "", 1e-10},
		// Log2 function tests
		{"log2 8", "log2", []interface{}{8}, 3.0, false, "", 1e-10},
		{"log2 2", "log2", []interface{}{2}, 1.0, false, "", 1e-10},
		// Mod function tests
		{"mod 10,3", "mod", []interface{}{10, 3}, 1.0, false, "", 1e-10},
		{"mod 7.5,2.5", "mod", []interface{}{7.5, 2.5}, 0.0, false, "", 1e-10},
		{"mod division by zero", "mod", []interface{}{10, 0}, nil, true, "division by zero", 0},
		// Round function tests
		{"round 3.7", "round", []interface{}{3.7}, 4.0, false, "", 1e-10},
		{"round 3.2", "round", []interface{}{3.2}, 3.0, false, "", 1e-10},
		{"round with precision", "round", []interface{}{3.14159, 2}, 3.14, false, "", 1e-10},
		// Sign function tests
		{"sign positive", "sign", []interface{}{5.5}, 1, false, "", 0},
		{"sign negative", "sign", []interface{}{-3.2}, -1, false, "", 0},
		{"sign zero", "sign", []interface{}{0}, 0, false, "", 0},
		// Trigonometric function tests
		{"sin 0", "sin", []interface{}{0}, 0.0, false, "", 1e-10},
		{"sin π/2", "sin", []interface{}{math.Pi / 2}, 1.0, false, "", 1e-10},
		{"sinh 0", "sinh", []interface{}{0}, 0.0, false, "", 1e-10},
		{"tan 0", "tan", []interface{}{0}, 0.0, false, "", 1e-10},
		{"tanh 0", "tanh", []interface{}{0}, 0.0, false, "", 1e-10},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			// A nil function value would panic on Execute, so stop here.
			if !assert.True(t, exists, "Function %s should be registered", tt.functionName) {
				return
			}
			result, err := fn.Execute(ctx, tt.args)
			if tt.expectError {
				// Only inspect the message when an error was actually
				// returned; err.Error() on a nil error would panic.
				if assert.Error(t, err) && tt.errorMsg != "" {
					assert.Contains(t, err.Error(), tt.errorMsg)
				}
				return
			}
			assert.NoError(t, err)
			if tt.delta > 0 {
				assert.InDelta(t, tt.expected, result, tt.delta)
			} else {
				assert.Equal(t, tt.expected, result)
			}
		})
	}
	// rand is checked separately because its result is random: only the
	// type and the [0, 1) range are asserted.
	t.Run("rand function", func(t *testing.T) {
		fn, exists := Get("rand")
		if !assert.True(t, exists) {
			return
		}
		result, err := fn.Execute(ctx, []interface{}{})
		assert.NoError(t, err)
		val, ok := result.(float64)
		assert.True(t, ok)
		assert.GreaterOrEqual(t, val, 0.0)
		assert.Less(t, val, 1.0)
	})
}
// TestFunctionExecution runs registered functions against a table of inputs
// and checks either the returned value or that an error is raised.
// Float results are compared with a tolerance of 0.0001. Time-dependent
// functions (now, current_time, current_date) are only checked for result
// type/format, because their values change between calls.
func TestFunctionExecution(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}
	// Function execution cases.
	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expected     interface{}
		expectError  bool
	}{
		// Math functions
		{"abs with positive", "abs", []interface{}{5.5}, 5.5, false},
		{"abs with negative", "abs", []interface{}{-5.5}, 5.5, false},
		{"abs with zero", "abs", []interface{}{0}, 0.0, false},
		{"sqrt with perfect square", "sqrt", []interface{}{16.0}, 4.0, false},
		{"sqrt with decimal", "sqrt", []interface{}{2.0}, 1.4142135623730951, false},
		{"sqrt with zero", "sqrt", []interface{}{0}, 0.0, false},
		{"sqrt with negative", "sqrt", []interface{}{-1}, nil, true},
		// Date/time functions
		{"now basic", "now", []interface{}{}, time.Now().Unix(), false},
		{"current_time basic", "current_time", []interface{}{}, time.Now().Format("15:04:05"), false},
		{"current_date basic", "current_date", []interface{}{}, time.Now().Format("2006-01-02"), false},
		// Additional math functions
		{"acos valid", "acos", []interface{}{0.5}, math.Acos(0.5), false},
		{"acos invalid", "acos", []interface{}{2.0}, nil, true},
		{"asin valid", "asin", []interface{}{0.5}, math.Asin(0.5), false},
		{"asin invalid", "asin", []interface{}{2.0}, nil, true},
		{"atan valid", "atan", []interface{}{1.0}, math.Atan(1.0), false},
		{"atan2 valid", "atan2", []interface{}{1.0, 1.0}, math.Atan2(1.0, 1.0), false},
		{"bitand valid", "bitand", []interface{}{5, 3}, int64(1), false},
		{"bitor valid", "bitor", []interface{}{5, 3}, int64(7), false},
		{"bitxor valid", "bitxor", []interface{}{5, 3}, int64(6), false},
		{"bitnot valid", "bitnot", []interface{}{5}, int64(-6), false},
		{"ceiling positive", "ceiling", []interface{}{3.7}, 4.0, false},
		{"ceiling negative", "ceiling", []interface{}{-3.7}, -3.0, false},
		{"cos valid", "cos", []interface{}{0.0}, 1.0, false},
		{"cosh valid", "cosh", []interface{}{0.0}, 1.0, false},
		{"exp valid", "exp", []interface{}{1.0}, math.E, false},
		{"floor positive", "floor", []interface{}{3.7}, 3.0, false},
		{"floor negative", "floor", []interface{}{-3.7}, -4.0, false},
		{"ln valid", "ln", []interface{}{math.E}, 1.0, false},
		{"ln invalid", "ln", []interface{}{-1.0}, nil, true},
		{"power valid", "power", []interface{}{2.0, 3.0}, 8.0, false},
		// String functions
		{"concat basic", "concat", []interface{}{"hello", " ", "world"}, "hello world", false},
		{"concat single", "concat", []interface{}{"hello"}, "hello", false},
		{"concat numbers", "concat", []interface{}{1, 2, 3}, "123", false},
		{"length basic", "length", []interface{}{"hello"}, int(5), false},
		{"length empty", "length", []interface{}{""}, int(0), false},
		{"upper basic", "upper", []interface{}{"hello"}, "HELLO", false},
		{"upper mixed", "upper", []interface{}{"Hello World"}, "HELLO WORLD", false},
		{"lower basic", "lower", []interface{}{"HELLO"}, "hello", false},
		{"lower mixed", "lower", []interface{}{"Hello World"}, "hello world", false},
		// Conversion functions
		{"cast to int64", "cast", []interface{}{"123", "int64"}, int64(123), false},
		{"cast to float64", "cast", []interface{}{"123.45", "float64"}, 123.45, false},
		{"cast to string", "cast", []interface{}{123, "string"}, "123", false},
		{"hex2dec basic", "hex2dec", []interface{}{"ff"}, int64(255), false},
		{"hex2dec upper", "hex2dec", []interface{}{"FF"}, int64(255), false},
		{"hex2dec with prefix", "hex2dec", []interface{}{"a0"}, int64(160), false},
		{"dec2hex basic", "dec2hex", []interface{}{255}, "ff", false},
		{"dec2hex zero", "dec2hex", []interface{}{0}, "0", false},
		{"dec2hex large", "dec2hex", []interface{}{4095}, "fff", false},
		{"encode base64", "encode", []interface{}{"hello", "base64"}, "aGVsbG8=", false},
		{"encode hex", "encode", []interface{}{"hello", "hex"}, "68656c6c6f", false},
		{"encode url", "encode", []interface{}{"hello world", "url"}, "hello+world", false},
		{"encode invalid format", "encode", []interface{}{"hello", "invalid"}, nil, true},
		{"encode invalid input", "encode", []interface{}{123, "base64"}, nil, true},
		{"decode base64", "decode", []interface{}{"aGVsbG8=", "base64"}, "hello", false},
		{"decode hex", "decode", []interface{}{"68656c6c6f", "hex"}, "hello", false},
		{"decode url", "decode", []interface{}{"hello+world", "url"}, "hello world", false},
		{"decode invalid format", "decode", []interface{}{"hello", "invalid"}, nil, true},
		{"decode invalid base64", "decode", []interface{}{"invalid!", "base64"}, nil, true},
		{"decode invalid hex", "decode", []interface{}{"invalid!", "hex"}, nil, true},
		// Aggregation functions
		{"sum basic", "sum", []interface{}{1, 2, 3}, 6.0, false},
		{"sum float", "sum", []interface{}{1.5, 2.5}, 4.0, false},
		{"avg basic", "avg", []interface{}{1, 2, 3}, 2.0, false},
		{"min basic", "min", []interface{}{3, 1, 2}, 1.0, false},
		{"max basic", "max", []interface{}{3, 1, 2}, 3.0, false},
		{"count basic", "count", []interface{}{1, 2, 3, 4, 5}, int64(5), false},
		// Error cases
		{"hex2dec invalid", "hex2dec", []interface{}{"xyz"}, nil, true},
		// String functions (continued)
		{"trim basic", "trim", []interface{}{" hello world "}, "hello world", false},
		{"trim empty", "trim", []interface{}{""}, "", false},
		{"format number 2 decimals", "format", []interface{}{123.456, "0.00"}, "123.46", false},
		{"format number 0 decimals", "format", []interface{}{123.456, "0"}, "123", false},
		{"format string only", "format", []interface{}{"hello"}, "hello", false},
		// Additional aggregation functions
		{"collect basic", "collect", []interface{}{1, 2, 3}, []interface{}{1, 2, 3}, false},
		{"last_value basic", "last_value", []interface{}{1, 2, 3, 4}, 4, false},
		{"merge_agg basic", "merge_agg", []interface{}{"a", "b", "c"}, "a,b,c", false},
		{"stddevs basic", "stddevs", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 1.5811388300841898, false},
		{"deduplicate basic", "deduplicate", []interface{}{1, 2, 2, 3, 3, 3}, []interface{}{1, 2, 3}, false},
		{"var basic", "var", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 2.0, false},
		{"vars basic", "vars", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 2.5, false},
		// Window functions
		{"row_number basic", "row_number", []interface{}{}, int64(1), false},
		// Analytical functions
		{"latest basic", "latest", []interface{}{"hello"}, "hello", false},
		{"had_changed first", "had_changed", []interface{}{"value1"}, true, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			// Report a failed lookup once and stop before a nil function
			// value can be dereferenced.
			if !assert.True(t, exists, "function %s should exist", tt.functionName) ||
				!assert.NotNil(t, fn, "function %s should not be nil", tt.functionName) {
				return
			}
			result, err := fn.Execute(ctx, tt.args)
			if tt.expectError {
				assert.Error(t, err, "expected error for %s", tt.name)
				return
			}
			assert.NoError(t, err, "no error expected for %s", tt.name)
			if tt.expected == nil {
				return
			}
			switch expected := tt.expected.(type) {
			case float64:
				// Floating-point results are compared with a small tolerance.
				resultFloat, ok := result.(float64)
				if !ok {
					t.Errorf("Expected float64 but got %T for %s", result, tt.name)
					return
				}
				assert.InDelta(t, expected, resultFloat, 0.0001, "result should match for %s", tt.name)
			case int64:
				if tt.functionName == "now" {
					// now varies over time, so only the result type is checked.
					_, ok := result.(int64)
					assert.True(t, ok, "now function should return int64")
					return
				}
				assert.Equal(t, expected, result, "result should match for %s", tt.name)
			case string:
				// current_time / current_date vary over time, so only the
				// format of the returned string is checked.
				switch tt.functionName {
				case "current_time":
					resultStr, ok := result.(string)
					assert.True(t, ok, "%s function should return string", tt.functionName)
					_, err := time.Parse("15:04:05", resultStr)
					assert.NoError(t, err, "current_time should return valid time format")
				case "current_date":
					resultStr, ok := result.(string)
					assert.True(t, ok, "%s function should return string", tt.functionName)
					_, err := time.Parse("2006-01-02", resultStr)
					assert.NoError(t, err, "current_date should return valid date format")
				default:
					assert.Equal(t, expected, result, "result should match for %s", tt.name)
				}
			default:
				assert.Equal(t, expected, result, "result should match for %s", tt.name)
			}
		})
	}
}
// TestFunctionValidation checks each function's Validate method against
// argument lists of the wrong length or type, and against valid lists.
func TestFunctionValidation(t *testing.T) {
	// Argument validation cases.
	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expectError  bool
		description  string
	}{
		// abs takes exactly 1 argument
		{"abs no args", "abs", []interface{}{}, true, "abs requires 1 argument"},
		{"abs too many args", "abs", []interface{}{1.0, 2.0}, true, "abs accepts only 1 argument"},
		{"abs correct args", "abs", []interface{}{1.0}, false, "abs should accept 1 argument"},
		// Date/time functions take no arguments
		{"current_time with args", "current_time", []interface{}{1}, true, "current_time should not accept arguments"},
		{"current_date with args", "current_date", []interface{}{1}, true, "current_date should not accept arguments"},
		// concat takes at least 1 argument
		{"concat no args", "concat", []interface{}{}, true, "concat requires at least 1 argument"},
		{"concat one arg", "concat", []interface{}{"hello"}, false, "concat should accept 1 argument"},
		{"concat multiple args", "concat", []interface{}{"a", "b", "c"}, false, "concat should accept multiple arguments"},
		// cast takes exactly 2 arguments
		{"cast no args", "cast", []interface{}{}, true, "cast requires 2 arguments"},
		{"cast one arg", "cast", []interface{}{"123"}, true, "cast requires 2 arguments"},
		{"cast correct args", "cast", []interface{}{"123", "int64"}, false, "cast should accept 2 arguments"},
		{"cast too many args", "cast", []interface{}{"123", "int64", "extra"}, true, "cast accepts only 2 arguments"},
		// now takes no arguments
		{"now no args", "now", []interface{}{}, false, "now should accept no arguments"},
		{"now with args", "now", []interface{}{1}, true, "now should not accept arguments"},
		// Additional math functions
		{"acos no args", "acos", []interface{}{}, true, "acos requires 1 argument"},
		{"acos too many args", "acos", []interface{}{1.0, 2.0}, true, "acos accepts only 1 argument"},
		{"atan2 no args", "atan2", []interface{}{}, true, "atan2 requires 2 arguments"},
		{"atan2 one arg", "atan2", []interface{}{1.0}, true, "atan2 requires 2 arguments"},
		{"atan2 too many args", "atan2", []interface{}{1.0, 2.0, 3.0}, true, "atan2 accepts only 2 arguments"},
		{"bitand no args", "bitand", []interface{}{}, true, "bitand requires 2 arguments"},
		{"bitand one arg", "bitand", []interface{}{1}, true, "bitand requires 2 arguments"},
		{"bitand too many args", "bitand", []interface{}{1, 2, 3}, true, "bitand accepts only 2 arguments"},
		{"bitnot no args", "bitnot", []interface{}{}, true, "bitnot requires 1 argument"},
		{"bitnot too many args", "bitnot", []interface{}{1, 2}, true, "bitnot accepts only 1 argument"},
		{"power no args", "power", []interface{}{}, true, "power requires 2 arguments"},
		{"power one arg", "power", []interface{}{2.0}, true, "power requires 2 arguments"},
		{"power too many args", "power", []interface{}{2.0, 3.0, 4.0}, true, "power accepts only 2 arguments"},
		// Conversion functions
		{"encode no args", "encode", []interface{}{}, true, "encode requires 2 arguments"},
		{"encode one arg", "encode", []interface{}{"hello"}, true, "encode requires 2 arguments"},
		{"encode three args", "encode", []interface{}{"hello", "base64", "extra"}, true, "encode requires exactly 2 arguments"},
		{"encode invalid format type", "encode", []interface{}{"hello", 123}, true, "encode format must be a string"},
		{"decode no args", "decode", []interface{}{}, true, "decode requires 2 arguments"},
		{"decode one arg", "decode", []interface{}{"aGVsbG8="}, true, "decode requires 2 arguments"},
		{"decode three args", "decode", []interface{}{"aGVsbG8=", "base64", "extra"}, true, "decode requires exactly 2 arguments"},
		{"decode invalid input type", "decode", []interface{}{123, "base64"}, true, "decode input must be a string"},
		{"decode invalid format type", "decode", []interface{}{"aGVsbG8=", 123}, true, "decode format must be a string"},
		// Newer functions
		{"trim no args", "trim", []interface{}{}, true, "function trim requires at least 1 arguments"},
		{"trim too many args", "trim", []interface{}{"hello", "world"}, true, "function trim accepts at most 1 arguments"},
		{"format too many args", "format", []interface{}{"hello", "pattern", "locale", "extra"}, true, "function format accepts at most 3 arguments"},
		{"collect no args", "collect", []interface{}{}, true, "function collect requires at least 1 arguments"},
		{"row_number with args", "row_number", []interface{}{"invalid"}, true, "function row_number accepts at most 0 arguments"},
		{"latest no args", "latest", []interface{}{}, true, "function latest requires at least 1 arguments"},
		{"had_changed no args", "had_changed", []interface{}{}, true, "function had_changed requires at least 1 arguments"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			// A nil function value would panic on Validate, so stop here.
			if !assert.True(t, exists, "function %s should exist", tt.functionName) {
				return
			}
			err := fn.Validate(tt.args)
			if tt.expectError {
				assert.Error(t, err, tt.description)
			} else {
				assert.NoError(t, err, tt.description)
			}
		})
	}
}
// TestFunctionTypes checks that GetByType returns at least the listed
// functions for every function category.
func TestFunctionTypes(t *testing.T) {
	expectations := []struct {
		functionType FunctionType
		functions    []string
	}{
		{TypeMath, []string{
			"abs", "sqrt", "acos", "asin", "atan", "atan2",
			"bitand", "bitor", "bitxor", "bitnot",
			"ceiling", "cos", "cosh", "exp", "floor", "ln", "power",
		}},
		{TypeString, []string{"concat", "length", "upper", "lower", "trim", "format"}},
		{TypeConversion, []string{"cast", "hex2dec", "dec2hex", "encode", "decode"}},
		{TypeDateTime, []string{"now", "current_time", "current_date"}},
		{TypeAggregation, []string{"sum", "avg", "min", "max", "count", "stddev", "median", "collect", "last_value", "merge_agg", "stddevs", "deduplicate", "var", "vars"}},
		{TypeWindow, []string{"row_number"}},
		{TypeAnalytical, []string{"lag", "latest", "changed_col", "had_changed"}},
	}
	for _, exp := range expectations {
		t.Run(string(exp.functionType), func(t *testing.T) {
			registered := GetByType(exp.functionType)
			want := len(exp.functions)
			assert.GreaterOrEqual(t, len(registered), want,
				"should have at least %d functions of type %s", want, exp.functionType)

			// Index the returned functions by name so each expected one can
			// be looked up directly.
			byName := make(map[string]struct{}, len(registered))
			for _, f := range registered {
				byName[f.GetName()] = struct{}{}
			}
			for _, name := range exp.functions {
				_, present := byName[name]
				assert.True(t, present,
					"function %s should be of type %s", name, exp.functionType)
			}
		})
	}
}
// TestCustomFunction registers a custom "double2" function that multiplies
// its single argument (converted with cast.ToFloat64) by two, then executes
// it on numeric and numeric-string inputs.
func TestCustomFunction(t *testing.T) {
	// Register the custom function.
	err := RegisterCustomFunction("double2", TypeCustom, "自定义函数", "将数值乘以2", 1, 1,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToFloat64(args[0])
			return val * 2, nil
		})
	if !assert.NoError(t, err) {
		// Nothing was registered, so there is nothing to clean up.
		return
	}
	// Deferred so the registry is cleaned up even if the test stops early.
	defer Unregister("double2")

	tests := []struct {
		name     string
		args     []interface{}
		expected interface{}
	}{
		{"double positive", []interface{}{5.0}, 10.0},
		{"double negative", []interface{}{-3.0}, -6.0},
		{"double zero", []interface{}{0}, 0.0},
		{"double string number", []interface{}{"2.5"}, 5.0},
	}
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}
	doubleFunc, exists := Get("double2")
	// A nil function value would panic on Execute in every subtest.
	if !assert.True(t, exists) {
		return
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := doubleFunc.Execute(ctx, tt.args)
			assert.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// TestComplexFunctionCombinations feeds the output of one function into
// another (sum→abs, concat→upper, dec2hex→hex2dec) and checks the final
// result.
func TestComplexFunctionCombinations(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{},
	}
	// run looks up a registered function and executes it with args. A missing
	// function aborts the calling subtest via t.Fatalf, instead of panicking
	// on a nil function value and taking down the whole test binary.
	run := func(t *testing.T, name string, args ...interface{}) (interface{}, error) {
		t.Helper()
		fn, exists := Get(name)
		if !exists || fn == nil {
			t.Fatalf("function %s should be registered", name)
		}
		return fn.Execute(ctx, args)
	}
	tests := []struct {
		name        string
		description string
		operations  func(t *testing.T) (interface{}, error)
		expected    interface{}
	}{
		{
			name:        "abs of negative sum",
			description: "计算负数之和的绝对值",
			operations: func(t *testing.T) (interface{}, error) {
				sum, err := run(t, "sum", -1, -2, -3)
				if err != nil {
					return nil, err
				}
				return run(t, "abs", sum)
			},
			expected: 6.0,
		},
		{
			name:        "concat and upper",
			description: "连接字符串后转大写",
			operations: func(t *testing.T) (interface{}, error) {
				concat, err := run(t, "concat", "hello", " ", "world")
				if err != nil {
					return nil, err
				}
				return run(t, "upper", concat)
			},
			expected: "HELLO WORLD",
		},
		{
			name:        "hex conversion round trip",
			description: "十进制转十六进制再转回十进制",
			operations: func(t *testing.T) (interface{}, error) {
				hex, err := run(t, "dec2hex", 255)
				if err != nil {
					return nil, err
				}
				return run(t, "hex2dec", hex)
			},
			expected: int64(255),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := tt.operations(t)
			assert.NoError(t, err, tt.description)
			assert.Equal(t, tt.expected, result, tt.description)
		})
	}
}
// TestCaseInsensitiveFunctions verifies that function lookup ignores the
// letter case of the requested name, while an unknown name still misses.
func TestCaseInsensitiveFunctions(t *testing.T) {
	lookups := []struct {
		label string
		query string
		found bool
	}{
		{"小写concat", "concat", true},
		{"大写CONCAT", "CONCAT", true},
		{"混合大小写Concat", "Concat", true},
		{"混合大小写cOnCaT", "cOnCaT", true},
		{"小写upper", "upper", true},
		{"大写UPPER", "UPPER", true},
		{"混合大小写Upper", "Upper", true},
		{"小写lower", "lower", true},
		{"大写LOWER", "LOWER", true},
		{"混合大小写Lower", "Lower", true},
		{"小写length", "length", true},
		{"大写LENGTH", "LENGTH", true},
		{"混合大小写Length", "Length", true},
		{"小写trim", "trim", true},
		{"大写TRIM", "TRIM", true},
		{"混合大小写Trim", "Trim", true},
		{"不存在的函数", "nonexistent", false},
	}
	for _, lk := range lookups {
		t.Run(lk.label, func(t *testing.T) {
			_, ok := Get(lk.query)
			assert.Equal(t, lk.found, ok, "函数 %s 的查找结果应该是 %v", lk.query, lk.found)
		})
	}
}
// TestConcatFunctionCaseInsensitive verifies that concat executes correctly
// when looked up under differently-cased names.
func TestConcatFunctionCaseInsensitive(t *testing.T) {
	tests := []struct {
		name         string
		functionName string
		args         []interface{}
		expected     string
	}{
		{"小写concat", "concat", []interface{}{"hello", " ", "world"}, "hello world"},
		{"大写CONCAT", "CONCAT", []interface{}{"hello", " ", "world"}, "hello world"},
		{"混合大小写Concat", "Concat", []interface{}{"hello", " ", "world"}, "hello world"},
		{"混合大小写cOnCaT", "cOnCaT", []interface{}{"hello", " ", "world"}, "hello world"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fn, exists := Get(tt.functionName)
			// A nil function value would panic on Execute, so stop here.
			if !assert.True(t, exists, "函数 %s 应该存在", tt.functionName) {
				return
			}
			ctx := &FunctionContext{
				Data: make(map[string]interface{}),
			}
			result, err := fn.Execute(ctx, tt.args)
			assert.NoError(t, err, "函数 %s 执行不应该出错", tt.functionName)
			assert.Equal(t, tt.expected, result, "函数 %s 的执行结果应该正确", tt.functionName)
		})
	}
}
// TestStringFunctionsCaseInsensitive verifies that upper, lower, length and
// trim execute correctly under every casing of their names.
func TestStringFunctionsCaseInsensitive(t *testing.T) {
	ctx := &FunctionContext{
		Data: make(map[string]interface{}),
	}
	// One subtest per function: the spellings to try, the single input
	// argument, and the expected result.
	cases := []struct {
		subtest string
		names   []string
		input   string
		want    interface{}
	}{
		{"UPPER函数大小写不敏感", []string{"upper", "UPPER", "Upper", "uPpEr"}, "hello", "HELLO"},
		{"LOWER函数大小写不敏感", []string{"lower", "LOWER", "Lower", "lOwEr"}, "HELLO", "hello"},
		{"LENGTH函数大小写不敏感", []string{"length", "LENGTH", "Length", "lEnGtH"}, "hello", int(5)},
		{"TRIM函数大小写不敏感", []string{"trim", "TRIM", "Trim", "tRiM"}, " hello ", "hello"},
	}
	for _, tc := range cases {
		t.Run(tc.subtest, func(t *testing.T) {
			for _, name := range tc.names {
				fn, exists := Get(name)
				// Skip execution for a missing name: fn would be nil and panic.
				if !assert.True(t, exists, "函数 %s 应该存在", name) {
					continue
				}
				result, err := fn.Execute(ctx, []interface{}{tc.input})
				assert.NoError(t, err, "函数 %s 执行不应该出错", name)
				assert.Equal(t, tc.want, result, "函数 %s 的执行结果应该正确", name)
			}
		})
	}
}
// TestMathFunctionsCaseInsensitive verifies that abs and sqrt execute
// correctly under every casing of their names.
func TestMathFunctionsCaseInsensitive(t *testing.T) {
	ctx := &FunctionContext{
		Data: make(map[string]interface{}),
	}
	// One subtest per function: the spellings to try, the single input
	// argument, and the expected result.
	cases := []struct {
		subtest string
		names   []string
		input   float64
		want    float64
	}{
		{"ABS函数大小写不敏感", []string{"abs", "ABS", "Abs", "aBs"}, -5.5, 5.5},
		{"SQRT函数大小写不敏感", []string{"sqrt", "SQRT", "Sqrt", "sQrT"}, 9.0, 3.0},
	}
	for _, tc := range cases {
		t.Run(tc.subtest, func(t *testing.T) {
			for _, name := range tc.names {
				fn, exists := Get(name)
				// Skip execution for a missing name: fn would be nil and panic.
				if !assert.True(t, exists, "函数 %s 应该存在", name) {
					continue
				}
				result, err := fn.Execute(ctx, []interface{}{tc.input})
				assert.NoError(t, err, "函数 %s 执行不应该出错", name)
				assert.Equal(t, tc.want, result, "函数 %s 的执行结果应该正确", name)
			}
		})
	}
}
// TestAggregationFunctionsCaseInsensitive verifies that sum, avg and count
// resolve under every casing of their names and report TypeAggregation.
func TestAggregationFunctionsCaseInsensitive(t *testing.T) {
	cases := []struct {
		subtest string
		names   []string
	}{
		{"SUM函数大小写不敏感", []string{"sum", "SUM", "Sum", "sUm"}},
		{"AVG函数大小写不敏感", []string{"avg", "AVG", "Avg", "aVg"}},
		{"COUNT函数大小写不敏感", []string{"count", "COUNT", "Count", "cOuNt"}},
	}
	for _, tc := range cases {
		t.Run(tc.subtest, func(t *testing.T) {
			for _, name := range tc.names {
				fn, exists := Get(name)
				// Skip the type check for a missing name: fn would be nil.
				if !assert.True(t, exists, "函数 %s 应该存在", name) {
					continue
				}
				assert.Equal(t, TypeAggregation, fn.GetType(), "函数 %s 应该是聚合函数", name)
			}
		})
	}
}
// TestFunctionAliases verifies that each alias (pow, len, ceil) resolves to
// the very same function instance as its canonical name, and that each
// canonical function reports exactly that one alias.
func TestFunctionAliases(t *testing.T) {
	pairs := []struct {
		canonical string
		alias     string
	}{
		{"power", "pow"},
		{"length", "len"},
		{"ceiling", "ceil"},
	}

	// First pass: both names must resolve, and to the same instance.
	for _, p := range pairs {
		primary, found := Get(p.canonical)
		if !found {
			t.Fatalf("%s function not found", p.canonical)
		}
		aliased, found := Get(p.alias)
		if !found {
			t.Fatalf("%s alias not found", p.alias)
		}
		if primary != aliased {
			t.Errorf("%s alias should point to the same function as %s", p.alias, p.canonical)
		}
	}

	// Second pass: every canonical function advertises exactly its alias.
	// Existence was already enforced above, so the lookup flag is ignored.
	for _, p := range pairs {
		primary, _ := Get(p.canonical)
		got := primary.GetAliases()
		if len(got) != 1 || got[0] != p.alias {
			t.Errorf("Expected aliases [%s], got %v", p.alias, got)
		}
	}
}
// TestExecuteFunction covers the package-level Execute helper: a successful
// call, an unknown function name, and argument-count validation failures.
func TestExecuteFunction(t *testing.T) {
	ctx := &FunctionContext{
		Data: map[string]interface{}{
			"x": 10,
			"y": 20,
		},
	}
	// expectFailure runs name with args and asserts that Execute returns a
	// nil result and an error whose message contains wantMsg. The message is
	// only inspected when an error was actually returned, because calling
	// Error() on a nil error would panic.
	expectFailure := func(name string, args []interface{}, wantMsg string) {
		t.Helper()
		result, err := Execute(name, ctx, args)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), wantMsg)
		}
		assert.Nil(t, result)
	}

	// Normal execution.
	result, err := Execute("abs", ctx, []interface{}{-5})
	assert.NoError(t, err)
	assert.Equal(t, 5.0, result)

	// Unknown function.
	expectFailure("nonexistent_function", []interface{}{1}, "function nonexistent_function not found")
	// Too few arguments fails validation.
	expectFailure("abs", []interface{}{}, "validation failed")
	// Too many arguments fails validation.
	expectFailure("abs", []interface{}{1, 2, 3}, "validation failed")
}
// TestCustomFunctionValidate registers a custom function with argument
// bounds 1 and 2, then checks that Validate accepts one or two arguments and
// rejects zero or three.
func TestCustomFunctionValidate(t *testing.T) {
	err := RegisterCustomFunction("test_custom", TypeMath, "test", "test function", 1, 2,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			return args[0], nil
		})
	if !assert.NoError(t, err) {
		// Nothing was registered, so there is nothing to validate or clean up.
		return
	}
	// Deferred so the registry is cleaned up even if the test stops early.
	defer Unregister("test_custom")

	fn, exists := Get("test_custom")
	// A nil function value would panic on Validate, so stop here.
	if !assert.True(t, exists) {
		return
	}
	// Accepted argument counts.
	assert.NoError(t, fn.Validate([]interface{}{1}))
	assert.NoError(t, fn.Validate([]interface{}{1, 2}))
	// Too few arguments.
	assert.Error(t, fn.Validate([]interface{}{}))
	// Too many arguments.
	assert.Error(t, fn.Validate([]interface{}{1, 2, 3}))
}
// TestRegisterCustomFunctionErrors checks that RegisterCustomFunction
// rejects an empty name and accepts a well-formed registration.
func TestRegisterCustomFunctionErrors(t *testing.T) {
	// An empty name must be rejected. The message is only inspected when an
	// error was actually returned, because err.Error() on nil would panic.
	err := RegisterCustomFunction("", TypeMath, "test", "test function", 1, 2,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			return nil, nil
		})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "function name cannot be empty")
	}

	// A well-formed registration succeeds; only clean up what this test
	// actually registered.
	err = RegisterCustomFunction("valid_custom", TypeMath, "test", "test function", 1, 2,
		func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
			return args[0], nil
		})
	if assert.NoError(t, err) {
		Unregister("valid_custom")
	}
}
// TestFunctionAliasExecution verifies that functions can be executed through
// their aliases (pow, len, ceil) via the package-level Execute helper.
func TestFunctionAliasExecution(t *testing.T) {
	ctx := &FunctionContext{}
	calls := []struct {
		alias    string
		args     []interface{}
		want     interface{}
		wantText string // expected value as spelled in the failure message
	}{
		{"pow", []interface{}{2.0, 3.0}, 8.0, "8.0"},
		{"len", []interface{}{"hello"}, 5, "5"},
		{"ceil", []interface{}{3.2}, 4.0, "4.0"},
	}
	for _, c := range calls {
		got, err := Execute(c.alias, ctx, c.args)
		if err != nil {
			t.Fatalf("%s execution failed: %v", c.alias, err)
		}
		if got != c.want {
			t.Errorf("Expected %s, got %v", c.wantText, got)
		}
	}
}