Files
streamsql/streamsql_custom_functions_test.go

777 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package streamsql
import (
"encoding/json"
"fmt"
"math"
"net"
"testing"
"time"
"github.com/rulego/streamsql/utils/cast"
"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/expr"
"github.com/rulego/streamsql/functions"
"github.com/rulego/streamsql/rsql"
"github.com/stretchr/testify/assert"
)
// TestCustomMathFunctions verifies that scalar custom math functions
// registered via functions.RegisterCustomFunction can be used as arguments of
// the built-in AVG aggregate inside a windowed GROUP BY query.
func TestCustomMathFunctions(t *testing.T) {
	// square(x) = x*x
	err := functions.RegisterCustomFunction(
		"square",
		functions.TypeMath,
		"数学函数",
		"计算平方",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToFloat64(args[0])
			return val * val, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("square")

	// distance(x1, y1, x2, y2): Euclidean distance between (x1,y1) and (x2,y2).
	err = functions.RegisterCustomFunction(
		"distance",
		functions.TypeMath,
		"几何数学",
		"计算两点间距离",
		4, 4,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			x1 := cast.ToFloat64(args[0])
			y1 := cast.ToFloat64(args[1])
			x2 := cast.ToFloat64(args[2])
			y2 := cast.ToFloat64(args[3])
			distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))
			return distance, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("distance")

	// Use both functions inside aggregates in SQL.
	ssql := New()
	defer ssql.Stop()
	sql := `
SELECT
device,
AVG(square(value)) as squared_value,
AVG(distance(x1, y1, x2, y2)) as calculated_distance
FROM stream
GROUP BY device, TumblingWindow('1s')
`
	// Stop here on failure instead of touching Stream() on an instance whose
	// Execute did not succeed.
	if !assert.NoError(t, ssql.Execute(sql)) {
		return
	}

	// Non-blocking send: one result is expected, and dropping surplus results
	// is preferable to blocking whatever calls the sink once the test has
	// stopped reading.
	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		select {
		case resultChan <- result:
		default:
		}
	})

	testData := map[string]interface{}{
		"device": "sensor1",
		"value":  5.0,
		"x1":     0.0,
		"y1":     0.0,
		"x2":     3.0,
		"y2":     4.0, // distance should be 5
	}
	ssql.AddData(testData)

	// Let the window elapse, then force it to emit.
	time.Sleep(1 * time.Second)
	ssql.Stream().Window.Trigger()
	time.Sleep(500 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		// assert.* does not stop the test; bail out before indexing a nil or
		// empty slice.
		if !assert.True(t, ok, "unexpected result type %T", result) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "sensor1", item["device"])
		assert.Equal(t, 25.0, item["squared_value"])      // 5^2 = 25
		assert.Equal(t, 5.0, item["calculated_distance"]) // sqrt((3-0)^2 + (4-0)^2) = 5
	case <-time.After(2 * time.Second):
		t.Fatal("测试超时")
	}
	t.Log("✅ 自定义数学函数测试通过")
}
// TestCustomStringFunctions verifies that scalar custom string functions can
// be used directly in a non-windowed SELECT.
func TestCustomStringFunctions(t *testing.T) {
	// reverse_str(s): reverses s rune by rune so multi-byte characters stay intact.
	err := functions.RegisterCustomFunction(
		"reverse_str",
		functions.TypeString,
		"字符串函数",
		"反转字符串",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			runes := []rune(cast.ToString(args[0]))
			for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
				runes[i], runes[j] = runes[j], runes[i]
			}
			return string(runes), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("reverse_str")

	// json_get(doc, key): returns the top-level field `key` of the JSON object
	// `doc`, nil when the key is absent, or an error for malformed JSON.
	err = functions.RegisterCustomFunction(
		"json_get",
		functions.TypeString,
		"JSON处理",
		"从JSON字符串中提取字段值",
		2, 2,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			jsonStr := cast.ToString(args[0])
			key := cast.ToString(args[1])
			var data map[string]interface{}
			if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
				return nil, fmt.Errorf("invalid JSON: %w", err)
			}
			value, exists := data[key]
			if !exists {
				return nil, nil
			}
			return value, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("json_get")

	ssql := New()
	defer ssql.Stop()
	sql := `
SELECT
device,
reverse_str(device) as reversed_device,
json_get(metadata, 'version') as version
FROM stream
`
	// Do not touch Stream() if Execute failed.
	if !assert.NoError(t, ssql.Execute(sql)) {
		return
	}

	// Non-blocking send so a surplus result can never block the sink caller.
	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		select {
		case resultChan <- result:
		default:
		}
	})

	testData := map[string]interface{}{
		"device":   "sensor1",
		"metadata": `{"version":"1.0","type":"temperature"}`,
	}
	ssql.AddData(testData)

	// The select waits for the result (bounded by the timeout), so no fixed
	// sleep is needed beforehand.
	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		// assert.* does not stop the test; bail out before indexing.
		if !assert.True(t, ok, "unexpected result type %T", result) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "sensor1", item["device"])
		assert.Equal(t, "1rosnes", item["reversed_device"]) // "sensor1" reversed
		assert.Equal(t, "1.0", item["version"])
	case <-time.After(2 * time.Second):
		t.Fatal("测试超时")
	}
	t.Log("✅ 自定义字符串函数测试通过")
}
// TestCustomConversionFunctions registers two conversion functions and calls
// them directly through the function registry (no SQL involved).
func TestCustomConversionFunctions(t *testing.T) {
	// ip_to_num(ip): dotted IPv4 string -> int64, first octet most significant.
	err := functions.RegisterCustomFunction(
		"ip_to_num",
		functions.TypeConversion,
		"网络转换",
		"将IP地址转换为整数",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			ipStr := cast.ToString(args[0])
			ip := net.ParseIP(ipStr)
			if ip == nil {
				return nil, fmt.Errorf("invalid IP address: %s", ipStr)
			}
			// To4 returns nil for addresses that are not IPv4 (or IPv4-mapped).
			ip = ip.To4()
			if ip == nil {
				return nil, fmt.Errorf("not an IPv4 address: %s", ipStr)
			}
			return int64(ip[0])<<24 + int64(ip[1])<<16 + int64(ip[2])<<8 + int64(ip[3]), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("ip_to_num")

	// format_bytes(n): human-readable size using 1024-based units, capped at TB.
	err = functions.RegisterCustomFunction(
		"format_bytes",
		functions.TypeConversion,
		"数据格式化",
		"格式化字节大小为人类可读格式",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			bytes := cast.ToFloat64(args[0])
			units := []string{"B", "KB", "MB", "GB", "TB"}
			i := 0
			for bytes >= 1024 && i < len(units)-1 {
				bytes /= 1024
				i++
			}
			return fmt.Sprintf("%.2f %s", bytes, units[i]), nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("format_bytes")

	ctx := &functions.FunctionContext{Data: make(map[string]interface{})}

	// Only use the looked-up function when the lookup succeeded; assert.True
	// does not stop the test, so an unguarded call would panic instead of
	// reporting a clean failure.
	if ipFunc, exists := functions.Get("ip_to_num"); assert.True(t, exists) {
		result, err := ipFunc.Execute(ctx, []interface{}{"192.168.1.100"})
		assert.NoError(t, err)
		expectedIP := int64(192)<<24 + int64(168)<<16 + int64(1)<<8 + int64(100)
		assert.Equal(t, expectedIP, result)

		// Malformed and IPv6 inputs are rejected by the function above.
		_, err = ipFunc.Execute(ctx, []interface{}{"not-an-ip"})
		assert.Error(t, err)
		_, err = ipFunc.Execute(ctx, []interface{}{"::1"})
		assert.Error(t, err)
	}

	if bytesFunc, exists := functions.Get("format_bytes"); assert.True(t, exists) {
		cases := []struct {
			in   int
			want string
		}{
			{0, "0.00 B"},
			{1536, "1.50 KB"},
			{1073741824, "1.00 GB"}, // 1 GiB
		}
		for _, tc := range cases {
			result, err := bytesFunc.Execute(ctx, []interface{}{tc.in})
			assert.NoError(t, err)
			assert.Equal(t, tc.want, result, "format_bytes(%d)", tc.in)
		}
	}
	t.Log("✅ 自定义转换函数测试通过")
}
// TestCustomAggregateFunctions verifies user-defined aggregates. Each one is
// registered twice under the same name: a descriptor in the function registry
// (functions.Register) and an aggregator factory (aggregator.Register).
func TestCustomAggregateFunctions(t *testing.T) {
	// NOTE(review): the aggregator.Register entries are never removed; no
	// unregister counterpart is visible here.
	functions.Register(NewGeometricMeanFunction())
	aggregator.Register("geometric_mean", func() aggregator.AggregatorFunction {
		return &GeometricMeanAggregator{}
	})
	defer functions.Unregister("geometric_mean")

	functions.Register(NewModeFunction())
	aggregator.Register("mode_value", func() aggregator.AggregatorFunction {
		return &ModeAggregator{}
	})
	defer functions.Unregister("mode_value")

	ssql := New()
	defer ssql.Stop()
	sql := `
SELECT
device,
geometric_mean(value) as geo_mean,
mode_value(category) as most_common
FROM stream
GROUP BY device, TumblingWindow('1s')
`
	// Do not touch Stream() if Execute failed.
	if !assert.NoError(t, ssql.Execute(sql)) {
		return
	}

	// Non-blocking send so a surplus result can never block the sink caller.
	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		select {
		case resultChan <- result:
		default:
		}
	})

	testData := []interface{}{
		map[string]interface{}{"device": "sensor1", "value": 2.0, "category": "A"},
		map[string]interface{}{"device": "sensor1", "value": 8.0, "category": "A"},
		map[string]interface{}{"device": "sensor1", "value": 32.0, "category": "B"},
		map[string]interface{}{"device": "sensor1", "value": 128.0, "category": "A"},
	}
	for _, data := range testData {
		ssql.AddData(data)
	}

	// Let the window elapse, then force it to emit.
	time.Sleep(1 * time.Second)
	ssql.Stream().Window.Trigger()
	time.Sleep(500 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		// assert.* does not stop the test; bail out before indexing.
		if !assert.True(t, ok, "unexpected result type %T", result) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "sensor1", item["device"])

		// Geometric mean: (2 * 8 * 32 * 128) ^ (1/4) = 16
		geoMean, ok := item["geo_mean"].(float64)
		if assert.True(t, ok, "geo_mean has type %T", item["geo_mean"]) {
			assert.InEpsilon(t, 16.0, geoMean, 0.01)
		}

		// Mode: A appears 3 times and B once, so the mode is A.
		assert.Equal(t, "A", item["most_common"])
	case <-time.After(3 * time.Second):
		t.Fatal("测试超时")
	}
	t.Log("✅ 自定义聚合函数测试通过")
}
// GeometricMeanFunction is the function-registry descriptor for the
// geometric_mean aggregate. It carries only metadata (name, type, arity);
// the per-row computation lives in GeometricMeanAggregator.
type GeometricMeanFunction struct {
	*functions.BaseFunction
}

// NewGeometricMeanFunction builds the geometric_mean descriptor with
// aggregation type, a minimum of 1 argument and a maximum of -1 (presumably
// "unbounded"; confirm against functions.NewBaseFunction).
func NewGeometricMeanFunction() *GeometricMeanFunction {
	base := functions.NewBaseFunction(
		"geometric_mean",
		functions.TypeAggregation,
		"统计聚合",
		"计算几何平均数",
		1, -1,
	)
	return &GeometricMeanFunction{BaseFunction: base}
}

// Validate checks only the argument count, via the embedded BaseFunction.
func (gm *GeometricMeanFunction) Validate(args []interface{}) error {
	return gm.BaseFunction.ValidateArgCount(args)
}

// Execute is a placeholder that always returns (nil, nil); the real work is
// done by the registered aggregator.
func (gm *GeometricMeanFunction) Execute(_ *functions.FunctionContext, _ []interface{}) (interface{}, error) {
	return nil, nil
}
// GeometricMeanAggregator collects the inputs of a geometric_mean(...)
// aggregate and reports their geometric mean.
//
// Only values whose cast.ToFloat64 conversion is > 0 are kept; zero, negative
// and NaN values are ignored, since the geometric mean is only defined for
// positive numbers.
type GeometricMeanAggregator struct {
	values []float64
}

// New returns a fresh, empty aggregator instance.
func (g *GeometricMeanAggregator) New() aggregator.AggregatorFunction {
	return &GeometricMeanAggregator{values: make([]float64, 0)}
}

// Add records value if it converts to a positive float64. It is safe on a
// zero-value aggregator because append works on a nil slice.
func (g *GeometricMeanAggregator) Add(value interface{}) {
	if val := cast.ToFloat64(value); val > 0 {
		g.values = append(g.values, val)
	}
}

// Result returns the geometric mean of the recorded values, or 0.0 when no
// positive value has been added.
//
// The mean is computed in the log domain as exp(mean(ln v)) rather than
// (v1*v2*...*vn)^(1/n). The running product overflows to +Inf (or underflows
// to 0) after only a few hundred large (or tiny) inputs, while the sum of
// logarithms stays comfortably within float64 range.
func (g *GeometricMeanAggregator) Result() interface{} {
	if len(g.values) == 0 {
		return 0.0
	}
	var logSum float64
	for _, v := range g.values {
		logSum += math.Log(v)
	}
	return math.Exp(logSum / float64(len(g.values)))
}
// ModeFunction is the function-registry descriptor for the mode_value
// aggregate. Counting is performed by ModeAggregator; this type only
// describes the function.
type ModeFunction struct {
	*functions.BaseFunction
}

// NewModeFunction builds the mode_value descriptor with aggregation type, a
// minimum of 1 argument and a maximum of -1 (presumably "unbounded"; confirm
// against functions.NewBaseFunction).
func NewModeFunction() *ModeFunction {
	fn := new(ModeFunction)
	fn.BaseFunction = functions.NewBaseFunction(
		"mode_value",
		functions.TypeAggregation,
		"统计聚合",
		"计算众数",
		1, -1,
	)
	return fn
}

// Validate checks only the argument count, via the embedded BaseFunction.
func (m *ModeFunction) Validate(args []interface{}) error {
	return m.BaseFunction.ValidateArgCount(args)
}

// Execute is a placeholder that always returns (nil, nil); the actual logic
// lives in the registered aggregator.
func (m *ModeFunction) Execute(_ *functions.FunctionContext, _ []interface{}) (interface{}, error) {
	return nil, nil
}
// ModeAggregator counts occurrences of each distinct input of a
// mode_value(...) aggregate and reports the most frequent one.
//
// Inputs are keyed by their fmt "%v" rendering, so the result is always that
// string form, and values that print identically (e.g. 1 and "1") are
// counted together.
type ModeAggregator struct {
	counts map[string]int
}

// New returns a fresh, empty aggregator instance.
func (m *ModeAggregator) New() aggregator.AggregatorFunction {
	return &ModeAggregator{counts: make(map[string]int)}
}

// Add counts one occurrence of value. The map is created lazily so that a
// zero-value ModeAggregator does not panic on a nil-map write. The factory
// registered for mode_value in this file returns exactly such a value
// (&ModeAggregator{}).
func (m *ModeAggregator) Add(value interface{}) {
	if m.counts == nil {
		m.counts = make(map[string]int)
	}
	m.counts[fmt.Sprintf("%v", value)]++
}

// Result returns the string form of the most frequent value, or nil if
// nothing has been added. Ties go to the lexicographically smallest key, so
// the result does not depend on Go's randomized map iteration order.
func (m *ModeAggregator) Result() interface{} {
	if len(m.counts) == 0 {
		return nil
	}
	var (
		mode     string
		maxCount int
	)
	for key, count := range m.counts {
		if count > maxCount || (count == maxCount && key < mode) {
			mode, maxCount = key, count
		}
	}
	return mode
}
// TestFunctionManagement exercises the registry API: registration, lookup by
// name, ListAll, GetByType and Unregister.
func TestFunctionManagement(t *testing.T) {
	err := functions.RegisterCustomFunction(
		"test_func",
		functions.TypeCustom,
		"测试函数",
		"用于测试的函数",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			return args[0], nil
		},
	)
	if !assert.NoError(t, err) {
		return
	}
	// Safety net: if the test stops early, test_func must not leak into other
	// tests. After the explicit Unregister below, this call's result is ignored.
	defer functions.Unregister("test_func")

	// Lookup by name; guard so a failed lookup is reported, not dereferenced.
	if fn, exists := functions.Get("test_func"); assert.True(t, exists) {
		assert.Equal(t, "test_func", fn.GetName())
		assert.Equal(t, functions.TypeCustom, fn.GetType())
	}

	allFunctions := functions.ListAll()
	assert.Contains(t, allFunctions, "test_func")

	found := false
	for _, f := range functions.GetByType(functions.TypeCustom) {
		if f.GetName() == "test_func" {
			found = true
			break
		}
	}
	assert.True(t, found, "test_func not returned by GetByType(TypeCustom)")

	success := functions.Unregister("test_func")
	assert.True(t, success)

	// The function must no longer be resolvable.
	_, exists := functions.Get("test_func")
	assert.False(t, exists)
	t.Log("✅ 函数管理功能测试通过")
}
// TestCustomFunctionWithAggregation verifies that a custom scalar conversion
// function can be nested inside built-in aggregates (AVG, MAX) in a windowed
// query.
func TestCustomFunctionWithAggregation(t *testing.T) {
	// celsius_to_fahrenheit(c) = c*9/5 + 32
	err := functions.RegisterCustomFunction(
		"celsius_to_fahrenheit",
		functions.TypeConversion,
		"温度转换",
		"摄氏度转华氏度",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			celsius := cast.ToFloat64(args[0])
			fahrenheit := celsius*9/5 + 32
			return fahrenheit, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("celsius_to_fahrenheit")

	ssql := New()
	defer ssql.Stop()
	sql := `
SELECT
device,
AVG(celsius_to_fahrenheit(temperature)) as avg_fahrenheit,
MAX(celsius_to_fahrenheit(temperature)) as max_fahrenheit
FROM stream
GROUP BY device, TumblingWindow('1s')
`
	// Do not touch Stream() if Execute failed.
	if !assert.NoError(t, ssql.Execute(sql)) {
		return
	}

	// Non-blocking send so a surplus result can never block the sink caller.
	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		select {
		case resultChan <- result:
		default:
		}
	})

	// Input temperatures in Celsius.
	testData := []interface{}{
		map[string]interface{}{"device": "thermometer", "temperature": 0.0},   // 32°F
		map[string]interface{}{"device": "thermometer", "temperature": 100.0}, // 212°F
	}
	for _, data := range testData {
		ssql.AddData(data)
	}

	// Let the window elapse, then force it to emit.
	time.Sleep(1 * time.Second)
	ssql.Stream().Window.Trigger()
	time.Sleep(500 * time.Millisecond)

	select {
	case result := <-resultChan:
		resultSlice, ok := result.([]map[string]interface{})
		// assert.* does not stop the test; bail out before indexing.
		if !assert.True(t, ok, "unexpected result type %T", result) || !assert.Len(t, resultSlice, 1) {
			return
		}
		item := resultSlice[0]
		assert.Equal(t, "thermometer", item["device"])

		// Average Fahrenheit: (32 + 212) / 2 = 122
		avgFahrenheit, ok := item["avg_fahrenheit"].(float64)
		if assert.True(t, ok, "avg_fahrenheit has type %T", item["avg_fahrenheit"]) {
			assert.InEpsilon(t, 122.0, avgFahrenheit, 0.01)
		}

		// Maximum Fahrenheit: 212
		maxFahrenheit, ok := item["max_fahrenheit"].(float64)
		if assert.True(t, ok, "max_fahrenheit has type %T", item["max_fahrenheit"]) {
			assert.InEpsilon(t, 212.0, maxFahrenheit, 0.01)
		}
	case <-time.After(3 * time.Second):
		t.Fatal("测试超时")
	}
	t.Log("✅ 自定义函数与聚合函数结合使用测试通过")
}
// TestDebugCustomFunctions is a diagnostic test for a single-argument custom
// function. It checks registry lookup, expression parsing and evaluation, and
// prints the stream config derived from a SELECT that uses the function.
func TestDebugCustomFunctions(t *testing.T) {
	// square(x) = x*x, printing each invocation for debugging.
	err := functions.RegisterCustomFunction(
		"square",
		functions.TypeMath,
		"数学函数",
		"计算平方",
		1, 1,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			val := cast.ToFloat64(args[0])
			fmt.Printf("Square function called with: %v, result: %v\n", val, val*val)
			return val * val, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("square")

	// The function must be resolvable by name; don't dereference fn otherwise.
	fn, exists := functions.Get("square")
	if !assert.True(t, exists) {
		return
	}
	fmt.Printf("Function found: %s, type: %s\n", fn.GetName(), fn.GetType())

	// Named squareExpr rather than expr so the imported expr package is not
	// shadowed for the rest of the function.
	squareExpr, err := expr.NewExpression("square(value)")
	if !assert.NoError(t, err) {
		return
	}
	fields := squareExpr.GetFields()
	fmt.Printf("Expression fields: %v\n", fields)

	data := map[string]interface{}{"value": 5.0}
	result, err := squareExpr.Evaluate(data)
	assert.NoError(t, err)
	fmt.Printf("Expression result: %v\n", result)
	assert.Equal(t, 25.0, result)

	// SQL parsing and conversion to a stream config.
	parser := rsql.NewParser("SELECT square(value) as squared FROM stream")
	stmt, err := parser.Parse()
	if !assert.NoError(t, err) {
		return
	}
	config, _, err := stmt.ToStreamConfig()
	if !assert.NoError(t, err) {
		return
	}
	fmt.Printf("SQL Config - SelectFields: %v\n", config.SelectFields)
	fmt.Printf("SQL Config - FieldAlias: %v\n", config.FieldAlias)
	fmt.Printf("SQL Config - FieldExpressions: %v\n", config.FieldExpressions)
	fmt.Printf("SQL Config - NeedWindow: %v\n", config.NeedWindow)
	fmt.Println("✅ 调试测试完成")
}
// TestDebugMultiParameterFunction is a diagnostic test for a four-argument
// custom function. It checks expression parsing and evaluation, then prints
// the stream config for an aggregate query that wraps the function.
func TestDebugMultiParameterFunction(t *testing.T) {
	// distance(x1, y1, x2, y2): Euclidean distance, printing each invocation.
	err := functions.RegisterCustomFunction(
		"distance",
		functions.TypeMath,
		"几何数学",
		"计算两点间距离",
		4, 4,
		func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
			x1 := cast.ToFloat64(args[0])
			y1 := cast.ToFloat64(args[1])
			x2 := cast.ToFloat64(args[2])
			y2 := cast.ToFloat64(args[3])
			distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))
			fmt.Printf("Distance function called with: (%v,%v) to (%v,%v), result: %v\n", x1, y1, x2, y2, distance)
			return distance, nil
		},
	)
	assert.NoError(t, err)
	defer functions.Unregister("distance")

	// Named distanceExpr rather than expr so the imported expr package is not
	// shadowed for the rest of the function.
	distanceExpr, err := expr.NewExpression("distance(x1, y1, x2, y2)")
	if !assert.NoError(t, err) {
		return
	}
	fields := distanceExpr.GetFields()
	fmt.Printf("Distance expression fields: %v\n", fields)

	data := map[string]interface{}{
		"x1": 0.0,
		"y1": 0.0,
		"x2": 3.0,
		"y2": 4.0,
	}
	result, err := distanceExpr.Evaluate(data)
	assert.NoError(t, err)
	fmt.Printf("Distance expression result: %v\n", result)
	assert.Equal(t, 5.0, result)

	// SQL parsing and conversion to a stream config.
	parser := rsql.NewParser("SELECT AVG(distance(x1, y1, x2, y2)) as avg_distance FROM stream GROUP BY device, TumblingWindow('1s')")
	stmt, err := parser.Parse()
	if !assert.NoError(t, err) {
		return
	}
	config, _, err := stmt.ToStreamConfig()
	if !assert.NoError(t, err) {
		return
	}
	fmt.Printf("Distance SQL Config - SelectFields: %v\n", config.SelectFields)
	fmt.Printf("Distance SQL Config - FieldAlias: %v\n", config.FieldAlias)
	fmt.Printf("Distance SQL Config - FieldExpressions: %v\n", config.FieldExpressions)
	fmt.Println("✅ 多参数函数调试测试完成")
}
// TestDebugSQLParsing prints how several SQL variants that call a custom
// four-argument function are parsed and converted into a stream config.
// It is a diagnostic aid: parse or conversion errors are printed and that
// query is skipped; they do not fail the test.
func TestDebugSQLParsing(t *testing.T) {
	// distance(x1, y1, x2, y2): Euclidean distance between two points.
	distanceImpl := func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
		x1, y1 := cast.ToFloat64(args[0]), cast.ToFloat64(args[1])
		x2, y2 := cast.ToFloat64(args[2]), cast.ToFloat64(args[3])
		return math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2)), nil
	}
	err := functions.RegisterCustomFunction(
		"distance",
		functions.TypeMath,
		"几何数学",
		"计算两点间距离",
		4, 4,
		distanceImpl,
	)
	assert.NoError(t, err)
	defer functions.Unregister("distance")

	// dump prints the parsed fields and derived config for one query,
	// returning early (after printing the error) if any stage fails.
	dump := func(ordinal int, query string) {
		fmt.Printf("\n=== 测试SQL %d: %s ===\n", ordinal, query)

		parser := rsql.NewParser(query)
		stmt, parseErr := parser.Parse()
		if parseErr != nil {
			fmt.Printf("SQL解析错误: %v\n", parseErr)
			return
		}

		fmt.Printf("解析到的字段数量: %d\n", len(stmt.Fields))
		for idx, fld := range stmt.Fields {
			fmt.Printf("字段 %d: Expression='%s', Alias='%s'\n", idx, fld.Expression, fld.Alias)
		}

		cfg, cond, convErr := stmt.ToStreamConfig()
		if convErr != nil {
			fmt.Printf("转换配置错误: %v\n", convErr)
			return
		}
		fmt.Printf("转换后配置:\n")
		fmt.Printf(" SelectFields: %v\n", cfg.SelectFields)
		fmt.Printf(" FieldAlias: %v\n", cfg.FieldAlias)
		fmt.Printf(" FieldExpressions: %v\n", cfg.FieldExpressions)
		fmt.Printf(" NeedWindow: %v\n", cfg.NeedWindow)
		fmt.Printf(" Condition: %s\n", cond)
	}

	queries := [...]string{
		"SELECT distance(x1, y1, x2, y2) as calc_distance FROM stream",
		"SELECT AVG(distance(x1, y1, x2, y2)) as avg_distance FROM stream",
		"SELECT AVG(distance(x1, y1, x2, y2)) as avg_distance FROM stream GROUP BY device, TumblingWindow('1s')",
	}
	for n := range queries {
		dump(n+1, queries[n])
	}
	fmt.Println("✅ SQL解析调试测试完成")
}