mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-04-08 13:33:20 +00:00
533 lines
21 KiB
Go
533 lines
21 KiB
Go
package functions
|
|
|
|
import (
|
|
"github.com/rulego/streamsql/utils/cast"
|
|
"math"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestBasicFunctionRegistry(t *testing.T) {
|
|
// 测试基本函数注册
|
|
tests := []struct {
|
|
name string
|
|
functionName string
|
|
expectedType FunctionType
|
|
}{
|
|
{"abs function", "abs", TypeMath},
|
|
{"concat function", "concat", TypeString},
|
|
{"sqrt function", "sqrt", TypeMath},
|
|
{"upper function", "upper", TypeString},
|
|
{"cast function", "cast", TypeConversion},
|
|
{"now function", "now", TypeDateTime},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
fn, exists := Get(tt.functionName)
|
|
assert.True(t, exists, "%s should be registered", tt.functionName)
|
|
assert.NotNil(t, fn)
|
|
assert.Equal(t, tt.functionName, fn.GetName())
|
|
assert.Equal(t, tt.expectedType, fn.GetType())
|
|
})
|
|
}
|
|
|
|
// 测试不存在的函数
|
|
_, exists := Get("nonexistent")
|
|
assert.False(t, exists, "nonexistent function should not be found")
|
|
}
|
|
|
|
func TestNewMathFunctions(t *testing.T) {
|
|
ctx := &FunctionContext{
|
|
Data: map[string]interface{}{},
|
|
}
|
|
|
|
// 表驱动测试用例
|
|
tests := []struct {
|
|
name string
|
|
functionName string
|
|
args []interface{}
|
|
expected interface{}
|
|
expectError bool
|
|
errorMsg string
|
|
delta float64 // 用于浮点数比较的精度
|
|
}{
|
|
// Log function tests
|
|
{"log valid", "log", []interface{}{math.E}, 1.0, false, "", 1e-10},
|
|
{"log negative", "log", []interface{}{-1}, nil, true, "value must be positive", 0},
|
|
{"log zero", "log", []interface{}{0}, nil, true, "value must be positive", 0},
|
|
|
|
// Log10 function tests
|
|
{"log10 100", "log10", []interface{}{100}, 2.0, false, "", 1e-10},
|
|
{"log10 10", "log10", []interface{}{10}, 1.0, false, "", 1e-10},
|
|
|
|
// Log2 function tests
|
|
{"log2 8", "log2", []interface{}{8}, 3.0, false, "", 1e-10},
|
|
{"log2 2", "log2", []interface{}{2}, 1.0, false, "", 1e-10},
|
|
|
|
// Mod function tests
|
|
{"mod 10,3", "mod", []interface{}{10, 3}, 1.0, false, "", 1e-10},
|
|
{"mod 7.5,2.5", "mod", []interface{}{7.5, 2.5}, 0.0, false, "", 1e-10},
|
|
{"mod division by zero", "mod", []interface{}{10, 0}, nil, true, "division by zero", 0},
|
|
|
|
// Round function tests
|
|
{"round 3.7", "round", []interface{}{3.7}, 4.0, false, "", 1e-10},
|
|
{"round 3.2", "round", []interface{}{3.2}, 3.0, false, "", 1e-10},
|
|
{"round with precision", "round", []interface{}{3.14159, 2}, 3.14, false, "", 1e-10},
|
|
|
|
// Sign function tests
|
|
{"sign positive", "sign", []interface{}{5.5}, 1, false, "", 0},
|
|
{"sign negative", "sign", []interface{}{-3.2}, -1, false, "", 0},
|
|
{"sign zero", "sign", []interface{}{0}, 0, false, "", 0},
|
|
|
|
// Trigonometric function tests
|
|
{"sin 0", "sin", []interface{}{0}, 0.0, false, "", 1e-10},
|
|
{"sin π/2", "sin", []interface{}{math.Pi / 2}, 1.0, false, "", 1e-10},
|
|
{"sinh 0", "sinh", []interface{}{0}, 0.0, false, "", 1e-10},
|
|
{"tan 0", "tan", []interface{}{0}, 0.0, false, "", 1e-10},
|
|
{"tanh 0", "tanh", []interface{}{0}, 0.0, false, "", 1e-10},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
fn, exists := Get(tt.functionName)
|
|
assert.True(t, exists, "Function %s should be registered", tt.functionName)
|
|
|
|
result, err := fn.Execute(ctx, tt.args)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
if tt.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
}
|
|
} else {
|
|
assert.NoError(t, err)
|
|
if tt.delta > 0 {
|
|
assert.InDelta(t, tt.expected, result, tt.delta)
|
|
} else {
|
|
assert.Equal(t, tt.expected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// 特殊测试:rand函数(因为结果是随机的)
|
|
t.Run("rand function", func(t *testing.T) {
|
|
fn, exists := Get("rand")
|
|
assert.True(t, exists)
|
|
|
|
result, err := fn.Execute(ctx, []interface{}{})
|
|
assert.NoError(t, err)
|
|
|
|
val, ok := result.(float64)
|
|
assert.True(t, ok)
|
|
assert.GreaterOrEqual(t, val, 0.0)
|
|
assert.Less(t, val, 1.0)
|
|
})
|
|
}
|
|
|
|
func TestFunctionExecution(t *testing.T) {
|
|
ctx := &FunctionContext{
|
|
Data: map[string]interface{}{},
|
|
}
|
|
|
|
// 函数执行测试用例
|
|
tests := []struct {
|
|
name string
|
|
functionName string
|
|
args []interface{}
|
|
expected interface{}
|
|
expectError bool
|
|
}{
|
|
// 数学函数
|
|
{"abs with positive", "abs", []interface{}{5.5}, 5.5, false},
|
|
{"abs with negative", "abs", []interface{}{-5.5}, 5.5, false},
|
|
{"abs with zero", "abs", []interface{}{0}, 0.0, false},
|
|
{"sqrt with perfect square", "sqrt", []interface{}{16.0}, 4.0, false},
|
|
{"sqrt with decimal", "sqrt", []interface{}{2.0}, 1.4142135623730951, false},
|
|
{"sqrt with zero", "sqrt", []interface{}{0}, 0.0, false},
|
|
{"sqrt with negative", "sqrt", []interface{}{-1}, nil, true},
|
|
|
|
// 时间日期函数
|
|
{"now basic", "now", []interface{}{}, time.Now().Unix(), false},
|
|
{"current_time basic", "current_time", []interface{}{}, time.Now().Format("15:04:05"), false},
|
|
{"current_date basic", "current_date", []interface{}{}, time.Now().Format("2006-01-02"), false},
|
|
|
|
// 新增数学函数测试
|
|
{"acos valid", "acos", []interface{}{0.5}, math.Acos(0.5), false},
|
|
{"acos invalid", "acos", []interface{}{2.0}, nil, true},
|
|
{"asin valid", "asin", []interface{}{0.5}, math.Asin(0.5), false},
|
|
{"asin invalid", "asin", []interface{}{2.0}, nil, true},
|
|
{"atan valid", "atan", []interface{}{1.0}, math.Atan(1.0), false},
|
|
{"atan2 valid", "atan2", []interface{}{1.0, 1.0}, math.Atan2(1.0, 1.0), false},
|
|
{"bitand valid", "bitand", []interface{}{5, 3}, int64(1), false},
|
|
{"bitor valid", "bitor", []interface{}{5, 3}, int64(7), false},
|
|
{"bitxor valid", "bitxor", []interface{}{5, 3}, int64(6), false},
|
|
{"bitnot valid", "bitnot", []interface{}{5}, int64(-6), false},
|
|
{"ceiling positive", "ceiling", []interface{}{3.7}, 4.0, false},
|
|
{"ceiling negative", "ceiling", []interface{}{-3.7}, -3.0, false},
|
|
{"cos valid", "cos", []interface{}{0.0}, 1.0, false},
|
|
{"cosh valid", "cosh", []interface{}{0.0}, 1.0, false},
|
|
{"exp valid", "exp", []interface{}{1.0}, math.E, false},
|
|
{"floor positive", "floor", []interface{}{3.7}, 3.0, false},
|
|
{"floor negative", "floor", []interface{}{-3.7}, -4.0, false},
|
|
{"ln valid", "ln", []interface{}{math.E}, 1.0, false},
|
|
{"ln invalid", "ln", []interface{}{-1.0}, nil, true},
|
|
{"power valid", "power", []interface{}{2.0, 3.0}, 8.0, false},
|
|
|
|
// 字符串函数
|
|
{"concat basic", "concat", []interface{}{"hello", " ", "world"}, "hello world", false},
|
|
{"concat single", "concat", []interface{}{"hello"}, "hello", false},
|
|
{"concat numbers", "concat", []interface{}{1, 2, 3}, "123", false},
|
|
{"length basic", "length", []interface{}{"hello"}, int64(5), false},
|
|
{"length empty", "length", []interface{}{""}, int64(0), false},
|
|
{"upper basic", "upper", []interface{}{"hello"}, "HELLO", false},
|
|
{"upper mixed", "upper", []interface{}{"Hello World"}, "HELLO WORLD", false},
|
|
{"lower basic", "lower", []interface{}{"HELLO"}, "hello", false},
|
|
{"lower mixed", "lower", []interface{}{"Hello World"}, "hello world", false},
|
|
|
|
// 转换函数
|
|
{"cast to int64", "cast", []interface{}{"123", "int64"}, int64(123), false},
|
|
{"cast to float64", "cast", []interface{}{"123.45", "float64"}, 123.45, false},
|
|
{"cast to string", "cast", []interface{}{123, "string"}, "123", false},
|
|
{"hex2dec basic", "hex2dec", []interface{}{"ff"}, int64(255), false},
|
|
{"hex2dec upper", "hex2dec", []interface{}{"FF"}, int64(255), false},
|
|
{"hex2dec with prefix", "hex2dec", []interface{}{"a0"}, int64(160), false},
|
|
{"dec2hex basic", "dec2hex", []interface{}{255}, "ff", false},
|
|
{"dec2hex zero", "dec2hex", []interface{}{0}, "0", false},
|
|
{"dec2hex large", "dec2hex", []interface{}{4095}, "fff", false},
|
|
{"encode base64", "encode", []interface{}{"hello", "base64"}, "aGVsbG8=", false},
|
|
{"encode hex", "encode", []interface{}{"hello", "hex"}, "68656c6c6f", false},
|
|
{"encode url", "encode", []interface{}{"hello world", "url"}, "hello+world", false},
|
|
{"encode invalid format", "encode", []interface{}{"hello", "invalid"}, nil, true},
|
|
{"encode invalid input", "encode", []interface{}{123, "base64"}, nil, true},
|
|
|
|
{"decode base64", "decode", []interface{}{"aGVsbG8=", "base64"}, "hello", false},
|
|
{"decode hex", "decode", []interface{}{"68656c6c6f", "hex"}, "hello", false},
|
|
{"decode url", "decode", []interface{}{"hello+world", "url"}, "hello world", false},
|
|
{"decode invalid format", "decode", []interface{}{"hello", "invalid"}, nil, true},
|
|
{"decode invalid base64", "decode", []interface{}{"invalid!", "base64"}, nil, true},
|
|
{"decode invalid hex", "decode", []interface{}{"invalid!", "hex"}, nil, true},
|
|
|
|
// 聚合函数
|
|
{"sum basic", "sum", []interface{}{1, 2, 3}, 6.0, false},
|
|
{"sum float", "sum", []interface{}{1.5, 2.5}, 4.0, false},
|
|
{"avg basic", "avg", []interface{}{1, 2, 3}, 2.0, false},
|
|
{"min basic", "min", []interface{}{3, 1, 2}, 1.0, false},
|
|
{"max basic", "max", []interface{}{3, 1, 2}, 3.0, false},
|
|
{"count basic", "count", []interface{}{1, 2, 3, 4, 5}, int64(5), false},
|
|
|
|
// 错误情况
|
|
{"hex2dec invalid", "hex2dec", []interface{}{"xyz"}, nil, true},
|
|
|
|
// 新增的字符串函数
|
|
{"trim basic", "trim", []interface{}{" hello world "}, "hello world", false},
|
|
{"trim empty", "trim", []interface{}{""}, "", false},
|
|
{"format number 2 decimals", "format", []interface{}{123.456, "0.00"}, "123.46", false},
|
|
{"format number 0 decimals", "format", []interface{}{123.456, "0"}, "123", false},
|
|
{"format string only", "format", []interface{}{"hello"}, "hello", false},
|
|
|
|
// 新增的聚合函数
|
|
{"collect basic", "collect", []interface{}{1, 2, 3}, []interface{}{1, 2, 3}, false},
|
|
{"last_value basic", "last_value", []interface{}{1, 2, 3, 4}, 4, false},
|
|
{"merge_agg basic", "merge_agg", []interface{}{"a", "b", "c"}, "a,b,c", false},
|
|
{"stddevs basic", "stddevs", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 1.5811388300841898, false},
|
|
{"deduplicate basic", "deduplicate", []interface{}{1, 2, 2, 3, 3, 3}, []interface{}{1, 2, 3}, false},
|
|
{"var basic", "var", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 2.0, false},
|
|
{"vars basic", "vars", []interface{}{1.0, 2.0, 3.0, 4.0, 5.0}, 2.5, false},
|
|
|
|
// 窗口函数
|
|
{"row_number basic", "row_number", []interface{}{}, int64(1), false},
|
|
|
|
// 分析函数
|
|
{"latest basic", "latest", []interface{}{"hello"}, "hello", false},
|
|
{"had_changed first", "had_changed", []interface{}{"value1"}, true, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
fn, exists := Get(tt.functionName)
|
|
assert.True(t, exists, "function %s should exist", tt.functionName)
|
|
if !exists || fn == nil {
|
|
t.Errorf("Function %s not found or is nil", tt.functionName)
|
|
return
|
|
}
|
|
|
|
result, err := fn.Execute(ctx, tt.args)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err, "expected error for %s", tt.name)
|
|
} else {
|
|
assert.NoError(t, err, "no error expected for %s", tt.name)
|
|
if tt.expected != nil {
|
|
switch expected := tt.expected.(type) {
|
|
case float64:
|
|
if resultFloat, ok := result.(float64); ok {
|
|
assert.InDelta(t, expected, resultFloat, 0.0001, "result should match for %s", tt.name)
|
|
} else {
|
|
t.Errorf("Expected float64 but got %T for %s", result, tt.name)
|
|
}
|
|
case int64:
|
|
if tt.functionName == "now" {
|
|
// 对于 now 函数,我们只检查结果是否为 int64 类型,因为具体值会随时间变化
|
|
_, ok := result.(int64)
|
|
assert.True(t, ok, "now function should return int64")
|
|
} else {
|
|
assert.Equal(t, expected, result, "result should match for %s", tt.name)
|
|
}
|
|
case string:
|
|
if tt.functionName == "current_time" || tt.functionName == "current_date" {
|
|
// 对于时间日期函数,我们只检查格式是否正确
|
|
resultStr, ok := result.(string)
|
|
assert.True(t, ok, "%s function should return string", tt.functionName)
|
|
if tt.functionName == "current_time" {
|
|
_, err := time.Parse("15:04:05", resultStr)
|
|
assert.NoError(t, err, "current_time should return valid time format")
|
|
} else if tt.functionName == "current_date" {
|
|
_, err := time.Parse("2006-01-02", resultStr)
|
|
assert.NoError(t, err, "current_date should return valid date format")
|
|
}
|
|
} else {
|
|
assert.Equal(t, expected, result, "result should match for %s", tt.name)
|
|
}
|
|
default:
|
|
assert.Equal(t, expected, result, "result should match for %s", tt.name)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunctionValidation(t *testing.T) {
|
|
// 参数验证测试用例
|
|
tests := []struct {
|
|
name string
|
|
functionName string
|
|
args []interface{}
|
|
expectError bool
|
|
description string
|
|
}{
|
|
// abs 函数 - 需要1个参数
|
|
{"abs no args", "abs", []interface{}{}, true, "abs requires 1 argument"},
|
|
{"abs too many args", "abs", []interface{}{1.0, 2.0}, true, "abs accepts only 1 argument"},
|
|
{"abs correct args", "abs", []interface{}{1.0}, false, "abs should accept 1 argument"},
|
|
|
|
// 时间日期函数参数验证
|
|
{"current_time with args", "current_time", []interface{}{1}, true, "current_time should not accept arguments"},
|
|
{"current_date with args", "current_date", []interface{}{1}, true, "current_date should not accept arguments"},
|
|
|
|
// concat 函数 - 需要至少1个参数
|
|
{"concat no args", "concat", []interface{}{}, true, "concat requires at least 1 argument"},
|
|
{"concat one arg", "concat", []interface{}{"hello"}, false, "concat should accept 1 argument"},
|
|
{"concat multiple args", "concat", []interface{}{"a", "b", "c"}, false, "concat should accept multiple arguments"},
|
|
|
|
// cast 函数 - 需要恰好2个参数
|
|
{"cast no args", "cast", []interface{}{}, true, "cast requires 2 arguments"},
|
|
{"cast one arg", "cast", []interface{}{"123"}, true, "cast requires 2 arguments"},
|
|
{"cast correct args", "cast", []interface{}{"123", "int64"}, false, "cast should accept 2 arguments"},
|
|
{"cast too many args", "cast", []interface{}{"123", "int64", "extra"}, true, "cast accepts only 2 arguments"},
|
|
|
|
// now 函数 - 不需要参数
|
|
{"now no args", "now", []interface{}{}, false, "now should accept no arguments"},
|
|
{"now with args", "now", []interface{}{1}, true, "now should not accept arguments"},
|
|
|
|
// 新增数学函数参数验证
|
|
{"acos no args", "acos", []interface{}{}, true, "acos requires 1 argument"},
|
|
{"acos too many args", "acos", []interface{}{1.0, 2.0}, true, "acos accepts only 1 argument"},
|
|
{"atan2 no args", "atan2", []interface{}{}, true, "atan2 requires 2 arguments"},
|
|
{"atan2 one arg", "atan2", []interface{}{1.0}, true, "atan2 requires 2 arguments"},
|
|
{"atan2 too many args", "atan2", []interface{}{1.0, 2.0, 3.0}, true, "atan2 accepts only 2 arguments"},
|
|
{"bitand no args", "bitand", []interface{}{}, true, "bitand requires 2 arguments"},
|
|
{"bitand one arg", "bitand", []interface{}{1}, true, "bitand requires 2 arguments"},
|
|
{"bitand too many args", "bitand", []interface{}{1, 2, 3}, true, "bitand accepts only 2 arguments"},
|
|
{"bitnot no args", "bitnot", []interface{}{}, true, "bitnot requires 1 argument"},
|
|
{"bitnot too many args", "bitnot", []interface{}{1, 2}, true, "bitnot accepts only 1 argument"},
|
|
{"power no args", "power", []interface{}{}, true, "power requires 2 arguments"},
|
|
{"power one arg", "power", []interface{}{2.0}, true, "power requires 2 arguments"},
|
|
{"power too many args", "power", []interface{}{2.0, 3.0, 4.0}, true, "power accepts only 2 arguments"},
|
|
|
|
// 转换函数参数验证
|
|
{"encode no args", "encode", []interface{}{}, true, "encode requires 2 arguments"},
|
|
{"encode one arg", "encode", []interface{}{"hello"}, true, "encode requires 2 arguments"},
|
|
{"encode three args", "encode", []interface{}{"hello", "base64", "extra"}, true, "encode requires exactly 2 arguments"},
|
|
{"encode invalid format type", "encode", []interface{}{"hello", 123}, true, "encode format must be a string"},
|
|
|
|
{"decode no args", "decode", []interface{}{}, true, "decode requires 2 arguments"},
|
|
{"decode one arg", "decode", []interface{}{"aGVsbG8="}, true, "decode requires 2 arguments"},
|
|
{"decode three args", "decode", []interface{}{"aGVsbG8=", "base64", "extra"}, true, "decode requires exactly 2 arguments"},
|
|
{"decode invalid input type", "decode", []interface{}{123, "base64"}, true, "decode input must be a string"},
|
|
{"decode invalid format type", "decode", []interface{}{"aGVsbG8=", 123}, true, "decode format must be a string"},
|
|
|
|
// 新增函数的验证测试
|
|
{"trim no args", "trim", []interface{}{}, true, "function trim requires at least 1 arguments"},
|
|
{"trim too many args", "trim", []interface{}{"hello", "world"}, true, "function trim accepts at most 1 arguments"},
|
|
{"format too many args", "format", []interface{}{"hello", "pattern", "locale", "extra"}, true, "function format accepts at most 3 arguments"},
|
|
{"collect no args", "collect", []interface{}{}, true, "function collect requires at least 1 arguments"},
|
|
{"row_number with args", "row_number", []interface{}{"invalid"}, true, "function row_number accepts at most 0 arguments"},
|
|
{"latest no args", "latest", []interface{}{}, true, "function latest requires at least 1 arguments"},
|
|
{"had_changed no args", "had_changed", []interface{}{}, true, "function had_changed requires at least 1 arguments"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
fn, exists := Get(tt.functionName)
|
|
assert.True(t, exists, "function %s should exist", tt.functionName)
|
|
|
|
err := fn.Validate(tt.args)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err, tt.description)
|
|
} else {
|
|
assert.NoError(t, err, tt.description)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunctionTypes(t *testing.T) {
|
|
// 函数类型分类测试
|
|
tests := []struct {
|
|
functionType FunctionType
|
|
functions []string
|
|
}{
|
|
{TypeMath, []string{
|
|
"abs", "sqrt", "acos", "asin", "atan", "atan2",
|
|
"bitand", "bitor", "bitxor", "bitnot",
|
|
"ceiling", "cos", "cosh", "exp", "floor", "ln", "power",
|
|
}},
|
|
{TypeString, []string{"concat", "length", "upper", "lower", "trim", "format"}},
|
|
{TypeConversion, []string{"cast", "hex2dec", "dec2hex", "encode", "decode"}},
|
|
{TypeDateTime, []string{"now", "current_time", "current_date"}},
|
|
{TypeAggregation, []string{"sum", "avg", "min", "max", "count", "stddev", "median", "collect", "last_value", "merge_agg", "stddevs", "deduplicate", "var", "vars"}},
|
|
{TypeWindow, []string{"row_number"}},
|
|
{TypeAnalytical, []string{"lag", "latest", "changed_col", "had_changed"}},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(string(tt.functionType), func(t *testing.T) {
|
|
functions := GetByType(tt.functionType)
|
|
assert.GreaterOrEqual(t, len(functions), len(tt.functions),
|
|
"should have at least %d functions of type %s", len(tt.functions), tt.functionType)
|
|
|
|
// 验证特定函数存在
|
|
functionNames := make(map[string]bool)
|
|
for _, fn := range functions {
|
|
functionNames[fn.GetName()] = true
|
|
}
|
|
|
|
for _, expectedFn := range tt.functions {
|
|
assert.True(t, functionNames[expectedFn],
|
|
"function %s should be of type %s", expectedFn, tt.functionType)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCustomFunction(t *testing.T) {
|
|
// 注册自定义函数
|
|
err := RegisterCustomFunction("double2", TypeCustom, "自定义函数", "将数值乘以2", 1, 1,
|
|
func(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
|
val := cast.ToFloat64(args[0])
|
|
return val * 2, nil
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
// 测试自定义函数
|
|
tests := []struct {
|
|
name string
|
|
args []interface{}
|
|
expected interface{}
|
|
}{
|
|
{"double positive", []interface{}{5.0}, 10.0},
|
|
{"double negative", []interface{}{-3.0}, -6.0},
|
|
{"double zero", []interface{}{0}, 0.0},
|
|
{"double string number", []interface{}{"2.5"}, 5.0},
|
|
}
|
|
|
|
ctx := &FunctionContext{
|
|
Data: map[string]interface{}{},
|
|
}
|
|
|
|
doubleFunc, exists := Get("double2")
|
|
assert.True(t, exists)
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := doubleFunc.Execute(ctx, tt.args)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
|
|
// 清理
|
|
Unregister("double2")
|
|
}
|
|
|
|
func TestComplexFunctionCombinations(t *testing.T) {
|
|
ctx := &FunctionContext{
|
|
Data: map[string]interface{}{},
|
|
}
|
|
|
|
// 测试复杂函数组合
|
|
tests := []struct {
|
|
name string
|
|
description string
|
|
operations func() (interface{}, error)
|
|
expected interface{}
|
|
}{
|
|
{
|
|
name: "abs of negative sum",
|
|
description: "计算负数之和的绝对值",
|
|
operations: func() (interface{}, error) {
|
|
sumFn, _ := Get("sum")
|
|
sum, err := sumFn.Execute(ctx, []interface{}{-1, -2, -3})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
absFn, _ := Get("abs")
|
|
return absFn.Execute(ctx, []interface{}{sum})
|
|
},
|
|
expected: 6.0,
|
|
},
|
|
{
|
|
name: "concat and upper",
|
|
description: "连接字符串后转大写",
|
|
operations: func() (interface{}, error) {
|
|
concatFn, _ := Get("concat")
|
|
concat, err := concatFn.Execute(ctx, []interface{}{"hello", " ", "world"})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
upperFn, _ := Get("upper")
|
|
return upperFn.Execute(ctx, []interface{}{concat})
|
|
},
|
|
expected: "HELLO WORLD",
|
|
},
|
|
{
|
|
name: "hex conversion round trip",
|
|
description: "十进制转十六进制再转回十进制",
|
|
operations: func() (interface{}, error) {
|
|
dec2hexFn, _ := Get("dec2hex")
|
|
hex, err := dec2hexFn.Execute(ctx, []interface{}{255})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hex2decFn, _ := Get("hex2dec")
|
|
return hex2decFn.Execute(ctx, []interface{}{hex})
|
|
},
|
|
expected: int64(255),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := tt.operations()
|
|
assert.NoError(t, err, tt.description)
|
|
assert.Equal(t, tt.expected, result, tt.description)
|
|
})
|
|
}
|
|
}
|