Files
streamsql/docs/FUNCTION_QUICK_START.md
T
2025-05-25 18:02:37 +08:00

12 KiB
Raw Blame History

StreamSQL 自定义函数快速入门

🚀 概述

StreamSQL 提供了强大的自定义函数系统,让你可以轻松扩展框架功能。本指南将帮你快速上手,创建和使用自定义函数。

📋 快速开始

1. 注册简单函数

最简单的方式是使用 RegisterCustomFunction 方法:

import "github.com/rulego/streamsql/functions"

// 注册一个平方函数
err := functions.RegisterCustomFunction(
    "square",                    // 函数名
    functions.TypeMath,          // 函数类型
    "数学函数",                   // 分类
    "计算数值的平方",             // 描述
    1,                          // 最少参数数量
    1,                          // 最多参数数量
    func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
        // 参数转换
        val, err := functions.ConvertToFloat64(args[0])
        if err != nil {
            return nil, err
        }
        // 业务逻辑
        return val * val, nil
    },
)

2. 在SQL中使用

SELECT device, square(value) as squared_value FROM stream

🎯 函数类型

数学函数 (TypeMath)

// 距离计算函数
functions.RegisterCustomFunction(
    "distance",
    functions.TypeMath,
    "几何数学",
    "计算两点间的欧几里得距离",
    4, 4,
    func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
        x1, _ := functions.ConvertToFloat64(args[0])
        y1, _ := functions.ConvertToFloat64(args[1])
        x2, _ := functions.ConvertToFloat64(args[2])
        y2, _ := functions.ConvertToFloat64(args[3])
        
        return math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2)), nil
    },
)

// SQL使用
// SELECT device, distance(lat1, lon1, lat2, lon2) as dist FROM stream

字符串函数 (TypeString)

// 字符串反转函数
functions.RegisterCustomFunction(
    "reverse",
    functions.TypeString,
    "字符串处理",
    "反转字符串",
    1, 1,
    func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
        str, err := functions.ConvertToString(args[0])
        if err != nil {
            return nil, err
        }
        
        runes := []rune(str)
        for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
            runes[i], runes[j] = runes[j], runes[i]
        }
        
        return string(runes), nil
    },
)

// SQL使用
// SELECT device, reverse(device_name) as reversed_name FROM stream

转换函数 (TypeConversion)

// IP地址转整数
functions.RegisterCustomFunction(
    "ip_to_int",
    functions.TypeConversion,
    "网络转换",
    "将IPv4地址转换为32位整数",
    1, 1,
    func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
        ipStr, err := functions.ConvertToString(args[0])
        if err != nil {
            return nil, err
        }
        
        ip := net.ParseIP(ipStr).To4()
        if ip == nil {
            return nil, fmt.Errorf("invalid IPv4: %s", ipStr)
        }
        
        return int64(ip[0])<<24 + int64(ip[1])<<16 + int64(ip[2])<<8 + int64(ip[3]), nil
    },
)

// SQL使用
// SELECT device, ip_to_int(client_ip) as ip_num FROM stream

时间日期函数 (TypeDateTime)

// 时间格式化函数
functions.RegisterCustomFunction(
    "format_time",
    functions.TypeDateTime,
    "时间格式化",
    "格式化Unix时间戳",
    2, 2,
    func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
        timestamp, err := functions.ConvertToInt64(args[0])
        if err != nil {
            return nil, err
        }
        
        format, err := functions.ConvertToString(args[1])
        if err != nil {
            return nil, err
        }
        
        t := time.Unix(timestamp, 0)
        return t.Format(format), nil
    },
)

// SQL使用
// SELECT device, format_time(timestamp, '2006-01-02 15:04:05') as formatted_time FROM stream

🏗️ 复杂函数实现

对于复杂函数,建议使用结构体方式:

// 1. 定义函数结构
type StatefulFunction struct {
    *functions.BaseFunction
    counter int64
    mutex   sync.Mutex
}

// 2. 构造函数
func NewStatefulFunction() *StatefulFunction {
    return &StatefulFunction{
        BaseFunction: functions.NewBaseFunction(
            "counter",
            functions.TypeCustom,
            "状态函数",
            "递增计数器",
            0, 0,
        ),
        counter: 0,
    }
}

// 3. 验证参数(可选)
func (f *StatefulFunction) Validate(args []interface{}) error {
    return f.ValidateArgCount(args)
}

// 4. 执行函数
func (f *StatefulFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
    f.mutex.Lock()
    defer f.mutex.Unlock()
    
    f.counter++
    return f.counter, nil
}

// 5. 注册函数
func init() {
    functions.Register(NewStatefulFunction())
}

📊 聚合函数

聚合函数需要同时实现函数和聚合器:

// 1. 实现聚合函数
type GeometricMeanFunction struct {
    *functions.BaseFunction
}

func NewGeometricMeanFunction() *GeometricMeanFunction {
    return &GeometricMeanFunction{
        BaseFunction: functions.NewBaseFunction(
            "geometric_mean",
            functions.TypeAggregation,
            "统计聚合",
            "计算几何平均数",
            1, -1,
        ),
    }
}

func (f *GeometricMeanFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
    return nil, nil // 逻辑在聚合器中
}

// 2. 实现聚合器
type GeometricMeanAggregator struct {
    values []float64
}

func (g *GeometricMeanAggregator) New() aggregator.AggregatorFunction {
    return &GeometricMeanAggregator{values: make([]float64, 0)}
}

func (g *GeometricMeanAggregator) Add(value interface{}) {
    if val, err := functions.ConvertToFloat64(value); err == nil && val > 0 {
        g.values = append(g.values, val)
    }
}

func (g *GeometricMeanAggregator) Result() interface{} {
    if len(g.values) == 0 {
        return 0.0
    }
    
    product := 1.0
    for _, v := range g.values {
        product *= v
    }
    
    return math.Pow(product, 1.0/float64(len(g.values)))
}

// 3. 注册
func init() {
    functions.Register(NewGeometricMeanFunction())
    aggregator.Register("geometric_mean", func() aggregator.AggregatorFunction {
        return &GeometricMeanAggregator{}
    })
}

// SQL使用
// SELECT device, geometric_mean(value) as geo_mean FROM stream GROUP BY device

🔧 函数管理

查看注册的函数

// 列出所有函数
allFunctions := functions.ListAll()
for name, fn := range allFunctions {
    fmt.Printf("函数: %s (%s) - %s\n", name, fn.GetType(), fn.GetDescription())
}

// 按类型查看
mathFunctions := functions.GetByType(functions.TypeMath)
for _, fn := range mathFunctions {
    fmt.Printf("数学函数: %s\n", fn.GetName())
}

// 查找特定函数
if fn, exists := functions.Get("square"); exists {
    fmt.Printf("找到函数: %s\n", fn.GetDescription())
}

注销函数

// 注销函数
success := functions.Unregister("my_function")
if success {
    fmt.Println("函数注销成功")
}

🎯 完整示例

创建温度转换函数

package main

import (
    "fmt"
    "time"
    "github.com/rulego/streamsql"
    "github.com/rulego/streamsql/functions"
)

func main() {
    // 1. 注册自定义函数
    registerCustomFunctions()
    
    // 2. 创建StreamSQL实例
    ssql := streamsql.New()
    defer ssql.Stop()
    
    // 3. 执行SQL
    sql := `
        SELECT 
            device,
            celsius_to_fahrenheit(temperature) as temp_f,
            format_temperature(temperature, 'C') as formatted_temp
        FROM stream
    `
    
    err := ssql.Execute(sql)
    if err != nil {
        panic(err)
    }
    
    // 4. 添加结果监听
    ssql.Stream().AddSink(func(result interface{}) {
        fmt.Printf("结果: %v\n", result)
    })
    
    // 5. 添加数据
    ssql.AddData(map[string]interface{}{
        "device": "thermometer1",
        "temperature": 25.0,
    })
    
    time.Sleep(time.Second)
}

func registerCustomFunctions() {
    // 摄氏度转华氏度
    functions.RegisterCustomFunction(
        "celsius_to_fahrenheit",
        functions.TypeMath,
        "温度转换",
        "摄氏度转华氏度",
        1, 1,
        func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
            celsius, err := functions.ConvertToFloat64(args[0])
            if err != nil {
                return nil, err
            }
            return celsius*9/5 + 32, nil
        },
    )
    
    // 温度格式化
    functions.RegisterCustomFunction(
        "format_temperature",
        functions.TypeString,
        "格式化函数",
        "格式化温度显示",
        2, 2,
        func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
            temp, err := functions.ConvertToFloat64(args[0])
            if err != nil {
                return nil, err
            }
            
            unit, err := functions.ConvertToString(args[1])
            if err != nil {
                return nil, err
            }
            
            return fmt.Sprintf("%.1f°%s", temp, unit), nil
        },
    )
}

🚨 最佳实践

1. 错误处理

func (f *MyFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
    // 参数数量检查
    if len(args) == 0 {
        return nil, fmt.Errorf("至少需要1个参数")
    }
    
    // 类型转换
    val, err := functions.ConvertToFloat64(args[0])
    if err != nil {
        return nil, fmt.Errorf("参数类型错误: %v", err)
    }
    
    // 业务逻辑验证
    if val < 0 {
        return nil, fmt.Errorf("参数必须为非负数")
    }
    
    return math.Sqrt(val), nil
}

2. 性能优化

type CachedFunction struct {
    *functions.BaseFunction
    cache map[string]interface{}
    mutex sync.RWMutex
}

func (f *CachedFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
    key := fmt.Sprintf("%v", args)
    
    // 检查缓存
    f.mutex.RLock()
    if result, exists := f.cache[key]; exists {
        f.mutex.RUnlock()
        return result, nil
    }
    f.mutex.RUnlock()
    
    // 计算结果
    result := f.calculate(args)
    
    // 缓存结果
    f.mutex.Lock()
    f.cache[key] = result
    f.mutex.Unlock()
    
    return result, nil
}

3. 线程安全

type ThreadSafeFunction struct {
    *functions.BaseFunction
    state map[string]interface{}
    mutex sync.RWMutex
}

func (f *ThreadSafeFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
    f.mutex.Lock()
    defer f.mutex.Unlock()
    
    // 安全地修改状态
    f.state["counter"] = f.state["counter"].(int) + 1
    
    return f.state["counter"], nil
}

📝 测试你的函数

func TestMyCustomFunction(t *testing.T) {
    // 注册函数
    err := functions.RegisterCustomFunction("test_func", functions.TypeMath, "测试", "测试函数", 1, 1,
        func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
            val, err := functions.ConvertToFloat64(args[0])
            return val * 2, err
        })
    assert.NoError(t, err)
    defer functions.Unregister("test_func")
    
    // 获取并测试函数
    fn, exists := functions.Get("test_func")
    assert.True(t, exists)
    
    ctx := &functions.FunctionContext{Data: make(map[string]interface{})}
    result, err := fn.Execute(ctx, []interface{}{5.0})
    
    assert.NoError(t, err)
    assert.Equal(t, 10.0, result)
}

通过这个快速入门指南你已经掌握了StreamSQL自定义函数的基本用法。现在可以开始创建自己的函数来扩展框架功能