feat:增加函数实现

This commit is contained in:
rulego-team
2025-06-11 18:45:39 +08:00
parent a5e4967021
commit c9b0486564
24 changed files with 1151 additions and 735 deletions
+10 -4
View File
@@ -65,8 +65,8 @@ func registerBuiltinFunctions() {
_ = Register(NewToSecondsFunction())
_ = Register(NewChrFunction())
_ = Register(NewTruncFunction())
_ = Register(NewCompressFunction())
_ = Register(NewDecompressFunction())
_ = Register(NewUrlEncodeFunction())
_ = Register(NewUrlDecodeFunction())
// Time-Date functions
_ = Register(NewNowFunction())
@@ -79,8 +79,8 @@ func registerBuiltinFunctions() {
_ = Register(NewMinFunction())
_ = Register(NewMaxFunction())
_ = Register(NewCountFunction())
_ = Register(NewStdDevFunction())
_ = Register(NewMedianFunction())
_ = Register(NewStdDevAggregatorFunction())
_ = Register(NewMedianAggregatorFunction())
_ = Register(NewPercentileFunction())
_ = Register(NewCollectFunction())
_ = Register(NewLastValueFunction())
@@ -109,6 +109,7 @@ func registerBuiltinFunctions() {
// Expression functions
_ = Register(NewExpressionFunction())
_ = Register(NewExprFunction())
_ = Register(NewExpressionAggregatorFunction())
// JSON functions
_ = Register(NewToJsonFunction())
@@ -144,10 +145,15 @@ func registerBuiltinFunctions() {
_ = Register(NewIsObjectFunction())
// Conditional functions
_ = Register(NewIfNullFunction())
_ = Register(NewCoalesceFunction())
_ = Register(NewNullIfFunction())
_ = Register(NewGreatestFunction())
_ = Register(NewLeastFunction())
_ = Register(NewCaseWhenFunction())
// Multi-row functions
_ = Register(NewUnnestFunction())
// User-defined functions (placeholder for future extension)
// Example: _=Register(NewMyUserDefinedFunction())
+5 -3
View File
@@ -99,9 +99,11 @@ func (f *LagFunction) Result() interface{} {
return f.DefaultValue
}
// 返回当前值之前第Offset个值
// 对于数组[first, second, third],当前位置是最后一个元素
// offset=1时返回second倒数第2个),offset=2时返回first(倒数第3个)
return f.PreviousValues[len(f.PreviousValues)-f.Offset-1]
// 对于数组[first, second, third],当前位置是最后一个元素third(索引2
// offset=1时应该返回second索引1),计算:len-1-offset = 3-1-1 = 1
// offset=2时应该返回first(索引0),计算:len-1-offset = 3-1-2 = 0
// 索引计算:len-1-offset,即从最后一个元素往前数offset个位置
return f.PreviousValues[len(f.PreviousValues)-1-f.Offset]
}
func (f *LagFunction) Clone() AggregatorFunction {
+25 -25
View File
@@ -47,12 +47,12 @@ func (f *ArrayContainsFunction) Validate(args []interface{}) error {
func (f *ArrayContainsFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array := args[0]
value := args[1]
v := reflect.ValueOf(array)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("array_contains requires array input")
}
for i := 0; i < v.Len(); i++ {
if reflect.DeepEqual(v.Index(i).Interface(), value) {
return true, nil
@@ -79,12 +79,12 @@ func (f *ArrayPositionFunction) Validate(args []interface{}) error {
func (f *ArrayPositionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array := args[0]
value := args[1]
v := reflect.ValueOf(array)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("array_position requires array input")
}
for i := 0; i < v.Len(); i++ {
if reflect.DeepEqual(v.Index(i).Interface(), value) {
return i + 1, nil // 返回1基索引
@@ -111,12 +111,12 @@ func (f *ArrayRemoveFunction) Validate(args []interface{}) error {
func (f *ArrayRemoveFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array := args[0]
value := args[1]
v := reflect.ValueOf(array)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("array_remove requires array input")
}
var result []interface{}
for i := 0; i < v.Len(); i++ {
elem := v.Index(i).Interface()
@@ -144,15 +144,15 @@ func (f *ArrayDistinctFunction) Validate(args []interface{}) error {
func (f *ArrayDistinctFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array := args[0]
v := reflect.ValueOf(array)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("array_distinct requires array input")
}
seen := make(map[interface{}]bool)
var result []interface{}
for i := 0; i < v.Len(); i++ {
elem := v.Index(i).Interface()
if !seen[elem] {
@@ -181,27 +181,27 @@ func (f *ArrayIntersectFunction) Validate(args []interface{}) error {
func (f *ArrayIntersectFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array1 := args[0]
array2 := args[1]
v1 := reflect.ValueOf(array1)
v2 := reflect.ValueOf(array2)
if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array {
return nil, fmt.Errorf("array_intersect requires array input for first argument")
}
if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array {
return nil, fmt.Errorf("array_intersect requires array input for second argument")
}
// 创建第二个数组的元素集合
set2 := make(map[interface{}]bool)
for i := 0; i < v2.Len(); i++ {
set2[v2.Index(i).Interface()] = true
}
// 找交集
seen := make(map[interface{}]bool)
var result []interface{}
for i := 0; i < v1.Len(); i++ {
elem := v1.Index(i).Interface()
if set2[elem] && !seen[elem] {
@@ -230,20 +230,20 @@ func (f *ArrayUnionFunction) Validate(args []interface{}) error {
func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array1 := args[0]
array2 := args[1]
v1 := reflect.ValueOf(array1)
v2 := reflect.ValueOf(array2)
if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array {
return nil, fmt.Errorf("array_union requires array input for first argument")
}
if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array {
return nil, fmt.Errorf("array_union requires array input for second argument")
}
seen := make(map[interface{}]bool)
var result []interface{}
// 添加第一个数组的元素
for i := 0; i < v1.Len(); i++ {
elem := v1.Index(i).Interface()
@@ -252,7 +252,7 @@ func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) (
result = append(result, elem)
}
}
// 添加第二个数组的元素
for i := 0; i < v2.Len(); i++ {
elem := v2.Index(i).Interface()
@@ -282,27 +282,27 @@ func (f *ArrayExceptFunction) Validate(args []interface{}) error {
func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
array1 := args[0]
array2 := args[1]
v1 := reflect.ValueOf(array1)
v2 := reflect.ValueOf(array2)
if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array {
return nil, fmt.Errorf("array_except requires array input for first argument")
}
if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array {
return nil, fmt.Errorf("array_except requires array input for second argument")
}
// 创建第二个数组的元素集合
set2 := make(map[interface{}]bool)
for i := 0; i < v2.Len(); i++ {
set2[v2.Index(i).Interface()] = true
}
// 找差集
seen := make(map[interface{}]bool)
var result []interface{}
for i := 0; i < v1.Len(); i++ {
elem := v1.Index(i).Interface()
if !set2[elem] && !seen[elem] {
@@ -311,4 +311,4 @@ func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{})
}
}
return result, nil
}
}
+65
View File
@@ -0,0 +1,65 @@
package functions
import (
"testing"
)
// 测试数组函数
func TestArrayFunctions(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
expected interface{}
}{
{
name: "array_length basic",
funcName: "array_length",
args: []interface{}{[]interface{}{1, 2, 3}},
expected: 3,
},
{
name: "array_contains true",
funcName: "array_contains",
args: []interface{}{[]interface{}{1, 2, 3}, 2},
expected: true,
},
{
name: "array_contains false",
funcName: "array_contains",
args: []interface{}{[]interface{}{1, 2, 3}, 4},
expected: false,
},
{
name: "array_position found",
funcName: "array_position",
args: []interface{}{[]interface{}{1, 2, 3}, 2},
expected: 2,
},
{
name: "array_position not found",
funcName: "array_position",
args: []interface{}{[]interface{}{1, 2, 3}, 4},
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(&FunctionContext{}, tt.args)
if err != nil {
t.Errorf("Execute() error = %v", err)
return
}
if result != tt.expected {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
})
}
}
-141
View File
@@ -1,141 +0,0 @@
package functions
import (
"testing"
)
func TestCompressionFunctions(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
wantErr bool
}{
// Compress function tests
{
name: "compress_gzip_valid",
funcName: "compress",
args: []interface{}{"hello world", "gzip"},
wantErr: false,
},
{
name: "compress_zlib_valid",
funcName: "compress",
args: []interface{}{"hello world", "zlib"},
wantErr: false,
},
{
name: "compress_invalid_algorithm",
funcName: "compress",
args: []interface{}{"hello world", "invalid"},
wantErr: true,
},
{
name: "compress_empty_string",
funcName: "compress",
args: []interface{}{"", "gzip"},
wantErr: false,
},
{
name: "compress_wrong_arg_count",
funcName: "compress",
args: []interface{}{"hello"},
wantErr: true,
},
// Decompress function tests
{
name: "decompress_invalid_base64",
funcName: "decompress",
args: []interface{}{"invalid_base64", "gzip"},
wantErr: true,
},
{
name: "decompress_invalid_algorithm",
funcName: "decompress",
args: []interface{}{"SGVsbG8gV29ybGQ=", "invalid"},
wantErr: true,
},
{
name: "decompress_wrong_arg_count",
funcName: "decompress",
args: []interface{}{"SGVsbG8gV29ybGQ="},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("%s function not found", tt.funcName)
}
// 执行函数
_, err := fn.Execute(nil, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestCompressionDecompressionRoundTrip(t *testing.T) {
tests := []struct {
name string
algorithm string
input string
}{
{
name: "gzip_round_trip",
algorithm: "gzip",
input: "Hello, World! This is a test string for compression.",
},
{
name: "zlib_round_trip",
algorithm: "zlib",
input: "Hello, World! This is a test string for compression.",
},
{
name: "gzip_empty_string",
algorithm: "gzip",
input: "",
},
{
name: "zlib_unicode",
algorithm: "zlib",
input: "你好世界!这是一个测试字符串。",
},
}
compressFn, exists := Get("compress")
if !exists {
t.Fatal("compress function not found")
}
decompressFn, exists := Get("decompress")
if !exists {
t.Fatal("decompress function not found")
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 压缩
compressed, err := compressFn.Execute(nil, []interface{}{tt.input, tt.algorithm})
if err != nil {
t.Fatalf("Compress failed: %v", err)
}
// 解压缩
decompressed, err := decompressFn.Execute(nil, []interface{}{compressed, tt.algorithm})
if err != nil {
t.Fatalf("Decompress failed: %v", err)
}
// 验证结果
if decompressed != tt.input {
t.Errorf("Round trip failed: expected %q, got %q", tt.input, decompressed)
}
})
}
}
+110 -10
View File
@@ -7,6 +7,28 @@ import (
"github.com/rulego/streamsql/utils/cast"
)
// IfNullFunction 如果第一个参数为NULL则返回第二个参数
type IfNullFunction struct {
*BaseFunction
}
func NewIfNullFunction() *IfNullFunction {
return &IfNullFunction{
BaseFunction: NewBaseFunction("if_null", TypeString, "条件函数", "如果第一个参数为NULL则返回第二个参数", 2, 2),
}
}
func (f *IfNullFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *IfNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if args[0] == nil {
return args[1], nil
}
return args[0], nil
}
// CoalesceFunction 返回第一个非NULL值
type CoalesceFunction struct {
*BaseFunction
@@ -38,7 +60,7 @@ type NullIfFunction struct {
func NewNullIfFunction() *NullIfFunction {
return &NullIfFunction{
BaseFunction: NewBaseFunction("nullif", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2),
BaseFunction: NewBaseFunction("null_if", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2),
}
}
@@ -72,21 +94,21 @@ func (f *GreatestFunction) Execute(ctx *FunctionContext, args []interface{}) (in
if len(args) == 0 {
return nil, nil
}
max := args[0]
if max == nil {
return nil, nil
}
for i := 1; i < len(args); i++ {
if args[i] == nil {
return nil, nil
}
// 尝试转换为数字进行比较
maxVal, err1 := cast.ToFloat64E(max)
currVal, err2 := cast.ToFloat64E(args[i])
if err1 == nil && err2 == nil {
if currVal > maxVal {
max = args[i]
@@ -122,21 +144,21 @@ func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
if len(args) == 0 {
return nil, nil
}
min := args[0]
if min == nil {
return nil, nil
}
for i := 1; i < len(args); i++ {
if args[i] == nil {
return nil, nil
}
// 尝试转换为数字进行比较
minVal, err1 := cast.ToFloat64E(min)
currVal, err2 := cast.ToFloat64E(args[i])
if err1 == nil && err2 == nil {
if currVal < minVal {
min = args[i]
@@ -151,4 +173,82 @@ func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
}
}
return min, nil
}
}
// CaseWhenFunction CASE WHEN表达式
type CaseWhenFunction struct {
*BaseFunction
}
func NewCaseWhenFunction() *CaseWhenFunction {
return &CaseWhenFunction{
BaseFunction: NewBaseFunction("case_when", TypeString, "条件函数", "CASE WHEN表达式", 2, -1),
}
}
func (f *CaseWhenFunction) Validate(args []interface{}) error {
if len(args) < 2 {
return fmt.Errorf("case_when requires at least 2 arguments")
}
// 参数必须是偶数个(条件-值对)或奇数个(最后一个是默认值)
if len(args)%2 == 0 {
// 偶数个参数,必须都是条件-值对
for i := 0; i < len(args); i += 2 {
// 条件应该是布尔值或可以转换为布尔值的表达式
}
} else {
// 奇数个参数,最后一个是默认值
for i := 0; i < len(args)-1; i += 2 {
// 条件应该是布尔值或可以转换为布尔值的表达式
}
}
return nil
}
func (f *CaseWhenFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
// 处理条件-值对
for i := 0; i < len(args)-1; i += 2 {
condition := args[i]
value := args[i+1]
// 将条件转换为布尔值
condBool, err := cast.ToBoolE(condition)
if err != nil {
// 如果无法转换为布尔值,检查是否为非零/非空值
if condition == nil {
condBool = false
} else {
switch v := condition.(type) {
case string:
condBool = v != ""
case int, int32, int64:
num, _ := cast.ToInt64E(v)
condBool = num != 0
case float32, float64:
num, _ := cast.ToFloat64E(v)
condBool = num != 0.0
default:
condBool = true
}
}
}
if condBool {
return value, nil
}
}
// 如果没有条件匹配,返回默认值(如果有)
if len(args)%2 == 1 {
return args[len(args)-1], nil
}
// 没有默认值,返回 nil
return nil, nil
}
+105
View File
@@ -0,0 +1,105 @@
package functions
import (
"testing"
)
// 测试条件函数
func TestConditionalFunctions(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
expected interface{}
wantErr bool
}{
{
name: "if_null with null",
funcName: "if_null",
args: []interface{}{nil, "default"},
expected: "default",
},
{
name: "if_null with value",
funcName: "if_null",
args: []interface{}{"value", "default"},
expected: "value",
},
{
name: "null_if equal",
funcName: "null_if",
args: []interface{}{"test", "test"},
expected: nil,
},
{
name: "null_if not equal",
funcName: "null_if",
args: []interface{}{"test", "other"},
expected: "test",
},
{
name: "greatest basic",
funcName: "greatest",
args: []interface{}{1, 3, 2},
expected: 3,
},
{
name: "least basic",
funcName: "least",
args: []interface{}{1, 3, 2},
expected: 1,
},
// case_when 函数测试
{
name: "case_when simple",
funcName: "case_when",
args: []interface{}{true, "result1", false, "result2", "default"},
expected: "result1",
},
{
name: "case_when second condition",
funcName: "case_when",
args: []interface{}{false, "result1", true, "result2", "default"},
expected: "result2",
},
{
name: "case_when default",
funcName: "case_when",
args: []interface{}{false, "result1", false, "result2", "default"},
expected: "default",
},
{
name: "case_when no default",
funcName: "case_when",
args: []interface{}{false, "result1", false, "result2"},
expected: nil,
},
{
name: "case_when invalid args",
funcName: "case_when",
args: []interface{}{true},
expected: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(&FunctionContext{}, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && result != tt.expected {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
})
}
}
+88 -127
View File
@@ -1,18 +1,15 @@
package functions
import (
"bytes"
"compress/gzip"
"compress/zlib"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/rulego/streamsql/utils/cast"
"io"
"math"
"net/url"
"strconv"
"time"
"github.com/rulego/streamsql/utils/cast"
)
// CastFunction 类型转换函数
@@ -265,19 +262,19 @@ func (f *ConvertTzFunction) Execute(ctx *FunctionContext, args []interface{}) (i
default:
return nil, fmt.Errorf("time value must be time.Time or string")
}
// 获取目标时区
timezone, err := cast.ToStringE(args[1])
if err != nil {
return nil, err
}
// 加载时区
loc, err := time.LoadLocation(timezone)
if err != nil {
return nil, fmt.Errorf("invalid timezone: %s", timezone)
}
// 转换时区
return t.In(loc), nil
}
@@ -324,7 +321,7 @@ func (f *ToSecondsFunction) Execute(ctx *FunctionContext, args []interface{}) (i
default:
return nil, fmt.Errorf("time value must be time.Time or string")
}
return t.Unix(), nil
}
@@ -348,151 +345,115 @@ func (f *ChrFunction) Execute(ctx *FunctionContext, args []interface{}) (interfa
if err != nil {
return nil, err
}
if code < 0 || code > 127 {
return nil, fmt.Errorf("ASCII code must be between 0 and 127, got %d", code)
}
return string(rune(code)), nil
}
// UrlEncodeFunction URL编码函数
type UrlEncodeFunction struct {
*BaseFunction
}
func NewUrlEncodeFunction() *UrlEncodeFunction {
return &UrlEncodeFunction{
BaseFunction: NewBaseFunction("url_encode", TypeConversion, "转换函数", "URL编码", 1, 1),
}
}
func (f *UrlEncodeFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *UrlEncodeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
if args[0] == nil {
return nil, fmt.Errorf("url_encode: input cannot be nil")
}
input := cast.ToString(args[0])
return url.QueryEscape(input), nil
}
// UrlDecodeFunction URL解码函数
type UrlDecodeFunction struct {
*BaseFunction
}
func NewUrlDecodeFunction() *UrlDecodeFunction {
return &UrlDecodeFunction{
BaseFunction: NewBaseFunction("url_decode", TypeConversion, "转换函数", "URL解码", 1, 1),
}
}
func (f *UrlDecodeFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *UrlDecodeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
if args[0] == nil {
return nil, fmt.Errorf("url_decode: input cannot be nil")
}
input := cast.ToString(args[0])
result, err := url.QueryUnescape(input)
if err != nil {
return nil, fmt.Errorf("URL decode failed: %v", err)
}
return result, nil
}
// TruncFunction 截断小数位数
type TruncFunction struct {
*BaseFunction
}
// NewTruncFunction 创建新的 trunc 函数
func NewTruncFunction() *TruncFunction {
return &TruncFunction{
BaseFunction: NewBaseFunction("trunc", TypeConversion, "转换函数", "截断小数位数", 2, 2),
}
}
// Validate 验证参数
func (f *TruncFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
// Execute 执行函数
func (f *TruncFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
val, err := cast.ToFloat64E(args[0])
if err != nil {
if err := f.Validate(args); err != nil {
return nil, err
}
precision, err := cast.ToIntE(args[1])
if err != nil {
return nil, err
}
// 转换第一个参数为浮点数
num := cast.ToFloat64(args[0])
// 转换第二个参数为整数(精度)
precision := cast.ToInt(args[1])
// 精度不能为负数
if precision < 0 {
return nil, fmt.Errorf("precision must be non-negative, got %d", precision)
return nil, fmt.Errorf("trunc precision cannot be negative")
}
// 计算截断
multiplier := math.Pow(10, float64(precision))
return math.Trunc(val*multiplier) / multiplier, nil
}
// CompressFunction 压缩函数
type CompressFunction struct {
*BaseFunction
}
func NewCompressFunction() *CompressFunction {
return &CompressFunction{
BaseFunction: NewBaseFunction("compress", TypeConversion, "转换函数", "压缩字符串或二进制值", 2, 2),
if num >= 0 {
return math.Floor(num*multiplier) / multiplier, nil
} else {
return math.Ceil(num*multiplier) / multiplier, nil
}
}
func (f *CompressFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *CompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
input := cast.ToString(args[0])
algorithm := cast.ToString(args[1])
var buf bytes.Buffer
var writer io.WriteCloser
switch algorithm {
case "gzip":
writer = gzip.NewWriter(&buf)
case "zlib":
writer = zlib.NewWriter(&buf)
default:
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
}
_, err := writer.Write([]byte(input))
if err != nil {
return nil, fmt.Errorf("compression failed: %v", err)
}
err = writer.Close()
if err != nil {
return nil, fmt.Errorf("compression failed: %v", err)
}
// 返回base64编码的压缩数据
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// DecompressFunction 解压缩函数
type DecompressFunction struct {
*BaseFunction
}
func NewDecompressFunction() *DecompressFunction {
return &DecompressFunction{
BaseFunction: NewBaseFunction("decompress", TypeConversion, "转换函数", "解压缩字符串或二进制值", 2, 2),
}
}
func (f *DecompressFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *DecompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
input := cast.ToString(args[0])
algorithm := cast.ToString(args[1])
// 解码base64数据
compressedData, err := base64.StdEncoding.DecodeString(input)
if err != nil {
return nil, fmt.Errorf("invalid base64 input: %v", err)
}
buf := bytes.NewReader(compressedData)
var reader io.ReadCloser
switch algorithm {
case "gzip":
reader, err = gzip.NewReader(buf)
if err != nil {
return nil, fmt.Errorf("gzip decompression failed: %v", err)
}
case "zlib":
reader, err = zlib.NewReader(buf)
if err != nil {
return nil, fmt.Errorf("zlib decompression failed: %v", err)
}
default:
return nil, fmt.Errorf("unsupported decompression algorithm: %s", algorithm)
}
defer reader.Close()
result, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("decompression failed: %v", err)
}
return string(result), nil
}
+71 -7
View File
@@ -39,7 +39,7 @@ func TestNewConversionFunctions(t *testing.T) {
args: []interface{}{"invalid-time", "UTC"},
wantErr: true,
},
// to_seconds 函数测试
{
name: "to_seconds with time.Time",
@@ -61,7 +61,7 @@ func TestNewConversionFunctions(t *testing.T) {
args: []interface{}{"invalid-time"},
wantErr: true,
},
// chr 函数测试
{
name: "chr valid ASCII code",
@@ -89,7 +89,7 @@ func TestNewConversionFunctions(t *testing.T) {
args: []interface{}{128},
wantErr: true,
},
// trunc 函数测试
{
name: "trunc positive number",
@@ -98,6 +98,70 @@ func TestNewConversionFunctions(t *testing.T) {
want: 3.14,
wantErr: false,
},
// url_encode 函数测试
{
name: "url_encode basic",
funcName: "url_encode",
args: []interface{}{"hello world"},
want: "hello+world",
wantErr: false,
},
{
name: "url_encode special chars",
funcName: "url_encode",
args: []interface{}{"hello@world.com"},
want: "hello%40world.com",
wantErr: false,
},
{
name: "url_encode empty",
funcName: "url_encode",
args: []interface{}{""},
want: "",
wantErr: false,
},
{
name: "url_encode nil",
funcName: "url_encode",
args: []interface{}{nil},
wantErr: true,
},
// url_decode 函数测试
{
name: "url_decode basic",
funcName: "url_decode",
args: []interface{}{"hello+world"},
want: "hello world",
wantErr: false,
},
{
name: "url_decode special chars",
funcName: "url_decode",
args: []interface{}{"hello%40world.com"},
want: "hello@world.com",
wantErr: false,
},
{
name: "url_decode empty",
funcName: "url_decode",
args: []interface{}{""},
want: "",
wantErr: false,
},
{
name: "url_decode nil",
funcName: "url_decode",
args: []interface{}{nil},
wantErr: true,
},
{
name: "url_decode invalid",
funcName: "url_decode",
args: []interface{}{"hello%ZZ"},
wantErr: true,
},
{
name: "trunc negative number",
funcName: "trunc",
@@ -119,20 +183,20 @@ func TestNewConversionFunctions(t *testing.T) {
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(nil, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
// 对于时间类型,需要特殊处理比较
if tt.funcName == "convert_tz" {
@@ -157,4 +221,4 @@ func TestNewConversionFunctions(t *testing.T) {
}
})
}
}
}
+14 -2
View File
@@ -479,6 +479,12 @@ func (f *YearFunction) Validate(args []interface{}) error {
}
func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
// 首先检查是否是 time.Time 类型
if t, ok := args[0].(time.Time); ok {
return float64(t.Year()), nil
}
// 如果不是 time.Time,尝试转换为字符串并解析
dateStr, err := cast.ToStringE(args[0])
if err != nil {
return nil, fmt.Errorf("invalid date: %v", err)
@@ -491,7 +497,7 @@ func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
}
}
return t.Year(), nil
return float64(t.Year()), nil
}
// MonthFunction 提取月份函数
@@ -510,6 +516,12 @@ func (f *MonthFunction) Validate(args []interface{}) error {
}
func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
// 首先检查是否是 time.Time 类型
if t, ok := args[0].(time.Time); ok {
return float64(t.Month()), nil
}
// 如果不是 time.Time,尝试转换为字符串并解析
dateStr, err := cast.ToStringE(args[0])
if err != nil {
return nil, fmt.Errorf("invalid date: %v", err)
@@ -522,7 +534,7 @@ func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
}
}
return int(t.Month()), nil
return float64(t.Month()), nil
}
// DayFunction 提取日期函数
+5 -5
View File
@@ -63,7 +63,7 @@ func TestDateTimeFunctions(t *testing.T) {
name: "year extraction",
function: NewYearFunction(),
args: []interface{}{"2023-12-25 15:30:45"},
expected: 2023,
expected: float64(2023),
wantErr: false,
},
// MonthFunction 测试
@@ -71,7 +71,7 @@ func TestDateTimeFunctions(t *testing.T) {
name: "month extraction",
function: NewMonthFunction(),
args: []interface{}{"2023-12-25 15:30:45"},
expected: 12,
expected: float64(12),
wantErr: false,
},
// DayFunction 测试
@@ -164,14 +164,14 @@ func TestDateTimeFunctions(t *testing.T) {
}
return
}
// 执行函数
result, err := tt.function.Execute(nil, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && result != tt.expected {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
@@ -235,4 +235,4 @@ func TestDateFormatConversion(t *testing.T) {
}
})
}
}
}
+5 -5
View File
@@ -28,7 +28,7 @@ func (f *Md5Function) Execute(ctx *FunctionContext, args []interface{}) (interfa
if !ok {
return nil, fmt.Errorf("md5 requires string input")
}
hash := md5.Sum([]byte(str))
return fmt.Sprintf("%x", hash), nil
}
@@ -53,7 +53,7 @@ func (f *Sha1Function) Execute(ctx *FunctionContext, args []interface{}) (interf
if !ok {
return nil, fmt.Errorf("sha1 requires string input")
}
hash := sha1.Sum([]byte(str))
return fmt.Sprintf("%x", hash), nil
}
@@ -78,7 +78,7 @@ func (f *Sha256Function) Execute(ctx *FunctionContext, args []interface{}) (inte
if !ok {
return nil, fmt.Errorf("sha256 requires string input")
}
hash := sha256.Sum256([]byte(str))
return fmt.Sprintf("%x", hash), nil
}
@@ -103,7 +103,7 @@ func (f *Sha512Function) Execute(ctx *FunctionContext, args []interface{}) (inte
if !ok {
return nil, fmt.Errorf("sha512 requires string input")
}
hash := sha512.Sum512([]byte(str))
return fmt.Sprintf("%x", hash), nil
}
}
+53
View File
@@ -0,0 +1,53 @@
package functions
import (
"testing"
)
// 测试哈希函数
func TestHashFunctions(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
expected interface{}
}{
{
name: "md5 basic",
funcName: "md5",
args: []interface{}{"hello"},
expected: "5d41402abc4b2a76b9719d911017c592",
},
{
name: "sha1 basic",
funcName: "sha1",
args: []interface{}{"hello"},
expected: "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d",
},
{
name: "sha256 basic",
funcName: "sha256",
args: []interface{}{"hello"},
expected: "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(&FunctionContext{}, tt.args)
if err != nil {
t.Errorf("Execute() error = %v", err)
return
}
if result != tt.expected {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
})
}
}
+11 -11
View File
@@ -50,7 +50,7 @@ func (f *FromJsonFunction) Execute(ctx *FunctionContext, args []interface{}) (in
if !ok {
return nil, fmt.Errorf("from_json requires string input")
}
var result interface{}
err := json.Unmarshal([]byte(jsonStr), &result)
if err != nil {
@@ -79,18 +79,18 @@ func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{})
if !ok {
return nil, fmt.Errorf("json_extract requires string input")
}
path, ok := args[1].(string)
if !ok {
return nil, fmt.Errorf("json_extract path must be string")
}
var data interface{}
err := json.Unmarshal([]byte(jsonStr), &data)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON: %v", err)
}
// 简单的路径提取,支持 $.field 格式
if strings.HasPrefix(path, "$.") {
field := path[2:]
@@ -98,7 +98,7 @@ func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{})
return dataMap[field], nil
}
}
return nil, fmt.Errorf("invalid JSON path or data structure")
}
@@ -122,7 +122,7 @@ func (f *JsonValidFunction) Execute(ctx *FunctionContext, args []interface{}) (i
if !ok {
return false, nil
}
var temp interface{}
err := json.Unmarshal([]byte(jsonStr), &temp)
return err == nil, nil
@@ -148,13 +148,13 @@ func (f *JsonTypeFunction) Execute(ctx *FunctionContext, args []interface{}) (in
if !ok {
return "unknown", nil
}
var data interface{}
err := json.Unmarshal([]byte(jsonStr), &data)
if err != nil {
return "invalid", nil
}
switch data.(type) {
case nil:
return "null", nil
@@ -193,13 +193,13 @@ func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) (
if !ok {
return nil, fmt.Errorf("json_length requires string input")
}
var data interface{}
err := json.Unmarshal([]byte(jsonStr), &data)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON: %v", err)
}
switch v := data.(type) {
case []interface{}:
return len(v), nil
@@ -208,4 +208,4 @@ func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) (
default:
return nil, fmt.Errorf("JSON value is not an array or object")
}
}
}
+105
View File
@@ -0,0 +1,105 @@
package functions
import (
"testing"
)
// 测试JSON函数
func TestJsonFunctions(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
expected interface{}
wantErr bool
}{
{
name: "to_json basic",
funcName: "to_json",
args: []interface{}{map[string]interface{}{"name": "test", "value": 123}},
expected: `{"name":"test","value":123}`,
},
{
name: "from_json basic",
funcName: "from_json",
args: []interface{}{`{"name":"test","value":123}`},
expected: map[string]interface{}{"name": "test", "value": float64(123)},
},
{
name: "json_extract basic",
funcName: "json_extract",
args: []interface{}{`{"name":"test","value":123}`, "$.name"},
expected: "test",
},
{
name: "json_valid true",
funcName: "json_valid",
args: []interface{}{`{"name":"test"}`},
expected: true,
},
{
name: "json_valid false",
funcName: "json_valid",
args: []interface{}{`{"name":"test"`},
expected: false,
},
{
name: "json_type object",
funcName: "json_type",
args: []interface{}{`{"name":"test"}`},
expected: "object",
},
{
name: "json_length array",
funcName: "json_length",
args: []interface{}{`[1,2,3]`},
expected: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(&FunctionContext{}, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !compareResults(result, tt.expected) {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
})
}
}
// 辅助函数:比较结果
func compareResults(a, b interface{}) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
// 对于map类型的特殊处理
if mapA, okA := a.(map[string]interface{}); okA {
if mapB, okB := b.(map[string]interface{}); okB {
if len(mapA) != len(mapB) {
return false
}
for k, v := range mapA {
if mapB[k] != v {
return false
}
}
return true
}
}
return a == b
}
+109
View File
@@ -0,0 +1,109 @@
package functions
import (
"fmt"
"reflect"
)
// UnnestFunction 将数组展开为多行
type UnnestFunction struct {
*BaseFunction
}
func NewUnnestFunction() *UnnestFunction {
return &UnnestFunction{
BaseFunction: NewBaseFunction("unnest", TypeString, "多行函数", "将数组展开为多行", 1, 1),
}
}
func (f *UnnestFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *UnnestFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
array := args[0]
if array == nil {
return []interface{}{}, nil
}
// 使用反射检查是否为数组或切片
v := reflect.ValueOf(array)
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("unnest requires an array or slice, got %T", array)
}
// 转换为 []interface{}
result := make([]interface{}, v.Len())
for i := 0; i < v.Len(); i++ {
elem := v.Index(i).Interface()
// 如果数组元素是对象(map),则展开为列
if elemMap, ok := elem.(map[string]interface{}); ok {
// 对于对象,我们返回一个特殊的结构来表示需要展开为列
result[i] = map[string]interface{}{
"__unnest_object__": true,
"__data__": elemMap,
}
} else {
result[i] = elem
}
}
return result, nil
}
// UnnestResult 表示 unnest 函数的结果
type UnnestResult struct {
Rows []map[string]interface{}
}
// IsUnnestResult 检查是否为 unnest 结果
func IsUnnestResult(value interface{}) bool {
if slice, ok := value.([]interface{}); ok {
for _, item := range slice {
if itemMap, ok := item.(map[string]interface{}); ok {
if unnest, exists := itemMap["__unnest_object__"]; exists {
if unnestBool, ok := unnest.(bool); ok && unnestBool {
return true
}
}
}
}
}
return false
}
// ProcessUnnestResult 处理 unnest 结果,将其转换为多行
func ProcessUnnestResult(value interface{}) []map[string]interface{} {
slice, ok := value.([]interface{})
if !ok {
return nil
}
var rows []map[string]interface{}
for _, item := range slice {
if itemMap, ok := item.(map[string]interface{}); ok {
if unnest, exists := itemMap["__unnest_object__"]; exists {
if unnestBool, ok := unnest.(bool); ok && unnestBool {
if data, exists := itemMap["__data__"]; exists {
if dataMap, ok := data.(map[string]interface{}); ok {
rows = append(rows, dataMap)
}
}
continue
}
}
}
// 对于非对象元素,创建一个包含单个值的行
row := map[string]interface{}{
"value": item,
}
rows = append(rows, row)
}
return rows
}
+119
View File
@@ -0,0 +1,119 @@
package functions
import (
"reflect"
"testing"
)
// TestUnnestFunction 测试unnest函数
func TestUnnestFunction(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
expected interface{}
wantErr bool
}{
// unnest 函数测试
{
name: "unnest array",
funcName: "unnest",
args: []interface{}{[]interface{}{1, 2, 3}},
expected: []interface{}{1, 2, 3},
},
{
name: "unnest empty array",
funcName: "unnest",
args: []interface{}{[]interface{}{}},
expected: []interface{}{},
},
{
name: "unnest nil",
funcName: "unnest",
args: []interface{}{nil},
expected: []interface{}{},
},
{
name: "unnest non-array",
funcName: "unnest",
args: []interface{}{"not an array"},
expected: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(&FunctionContext{}, tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
}
})
}
}
// TestUnnestWithObjects 测试 unnest 函数处理对象数组
func TestUnnestWithObjects(t *testing.T) {
fn, exists := Get("unnest")
if !exists {
t.Fatal("Function unnest not found")
}
// 测试对象数组
objectArray := []interface{}{
map[string]interface{}{"name": "Alice", "age": 30},
map[string]interface{}{"name": "Bob", "age": 25},
}
result, err := fn.Execute(&FunctionContext{}, []interface{}{objectArray})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
// 检查结果是否包含特殊标记
resultSlice, ok := result.([]interface{})
if !ok {
t.Fatalf("Expected []interface{}, got %T", result)
}
if len(resultSlice) != 2 {
t.Fatalf("Expected 2 items, got %d", len(resultSlice))
}
// 检查第一个对象是否有特殊标记
firstItem, ok := resultSlice[0].(map[string]interface{})
if !ok {
t.Fatalf("Expected map[string]interface{}, got %T", resultSlice[0])
}
if unnestFlag, exists := firstItem["__unnest_object__"]; !exists || unnestFlag != true {
t.Error("Expected __unnest_object__ flag to be true")
}
if data, exists := firstItem["__data__"]; !exists {
t.Error("Expected __data__ field to exist")
} else {
dataMap, ok := data.(map[string]interface{})
if !ok {
t.Errorf("Expected __data__ to be map[string]interface{}, got %T", data)
} else {
if dataMap["name"] != "Alice" || dataMap["age"] != 30 {
t.Errorf("Unexpected data: %v", dataMap)
}
}
}
}
File diff suppressed because it is too large Load Diff
+17 -16
View File
@@ -2,9 +2,10 @@ package functions
import (
"fmt"
"github.com/rulego/streamsql/utils/cast"
"regexp"
"strings"
"github.com/rulego/streamsql/utils/cast"
)
// ConcatFunction 字符串连接函数
@@ -292,26 +293,26 @@ func (f *SubstringFunction) Execute(ctx *FunctionContext, args []interface{}) (i
if err != nil {
return nil, err
}
strLen := int64(len(str))
if start < 0 || start >= strLen {
return "", nil
}
if len(args) == 2 {
return str[start:], nil
}
length, err := cast.ToInt64E(args[2])
if err != nil {
return nil, err
}
end := start + length
if end > strLen {
end = strLen
}
return str[start:end], nil
}
@@ -397,7 +398,7 @@ func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
if err != nil {
return nil, err
}
pad := " "
if len(args) == 3 {
pad, err = cast.ToStringE(args[2])
@@ -405,12 +406,12 @@ func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
return nil, err
}
}
strLen := int64(len(str))
if strLen >= length {
return str, nil
}
padLen := length - strLen
padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1))
return padStr[:padLen] + str, nil
@@ -440,7 +441,7 @@ func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
if err != nil {
return nil, err
}
pad := " "
if len(args) == 3 {
pad, err = cast.ToStringE(args[2])
@@ -448,12 +449,12 @@ func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
return nil, err
}
}
strLen := int64(len(str))
if strLen >= length {
return str, nil
}
padLen := length - strLen
padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1))
return str + padStr[:padLen], nil
@@ -533,7 +534,7 @@ func (f *RegexpMatchesFunction) Execute(ctx *FunctionContext, args []interface{}
if err != nil {
return nil, err
}
matched, err := regexp.MatchString(pattern, str)
if err != nil {
return nil, err
@@ -569,7 +570,7 @@ func (f *RegexpReplaceFunction) Execute(ctx *FunctionContext, args []interface{}
if err != nil {
return nil, err
}
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
@@ -601,12 +602,12 @@ func (f *RegexpSubstringFunction) Execute(ctx *FunctionContext, args []interface
if err != nil {
return nil, err
}
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
match := re.FindString(str)
return match, nil
}
+18 -18
View File
@@ -16,88 +16,88 @@ func TestNewStringFunctions(t *testing.T) {
{"endswith_true", "endswith", []interface{}{"hello world", "world"}, true, false},
{"endswith_false", "endswith", []interface{}{"hello world", "hello"}, false, false},
{"endswith_empty", "endswith", []interface{}{"hello", ""}, true, false},
// startswith tests
{"startswith_true", "startswith", []interface{}{"hello world", "hello"}, true, false},
{"startswith_false", "startswith", []interface{}{"hello world", "world"}, false, false},
{"startswith_empty", "startswith", []interface{}{"hello", ""}, true, false},
// indexof tests
{"indexof_found", "indexof", []interface{}{"hello world", "world"}, int64(6), false},
{"indexof_not_found", "indexof", []interface{}{"hello world", "xyz"}, int64(-1), false},
{"indexof_first_char", "indexof", []interface{}{"hello", "h"}, int64(0), false},
// substring tests
{"substring_start_only", "substring", []interface{}{"hello world", int64(6)}, "world", false},
{"substring_start_length", "substring", []interface{}{"hello world", int64(0), int64(5)}, "hello", false},
{"substring_out_of_bounds", "substring", []interface{}{"hello", int64(10)}, "", false},
// replace tests
{"replace_simple", "replace", []interface{}{"hello world", "world", "Go"}, "hello Go", false},
{"replace_multiple", "replace", []interface{}{"hello hello", "hello", "hi"}, "hi hi", false},
{"replace_not_found", "replace", []interface{}{"hello world", "xyz", "abc"}, "hello world", false},
// split tests
{"split_comma", "split", []interface{}{"a,b,c", ","}, []string{"a", "b", "c"}, false},
{"split_space", "split", []interface{}{"hello world", " "}, []string{"hello", "world"}, false},
{"split_not_found", "split", []interface{}{"hello", ","}, []string{"hello"}, false},
// lpad tests
{"lpad_default", "lpad", []interface{}{"hello", int64(10)}, " hello", false},
{"lpad_custom", "lpad", []interface{}{"hello", int64(8), "*"}, "***hello", false},
{"lpad_no_padding", "lpad", []interface{}{"hello", int64(3)}, "hello", false},
// rpad tests
{"rpad_default", "rpad", []interface{}{"hello", int64(10)}, "hello ", false},
{"rpad_custom", "rpad", []interface{}{"hello", int64(8), "*"}, "hello***", false},
{"rpad_no_padding", "rpad", []interface{}{"hello", int64(3)}, "hello", false},
// ltrim tests
{"ltrim_spaces", "ltrim", []interface{}{" hello world "}, "hello world ", false},
{"ltrim_tabs", "ltrim", []interface{}{"\t\nhello"}, "hello", false},
{"ltrim_no_whitespace", "ltrim", []interface{}{"hello"}, "hello", false},
// rtrim tests
{"rtrim_spaces", "rtrim", []interface{}{" hello world "}, " hello world", false},
{"rtrim_tabs", "rtrim", []interface{}{"hello\t\n"}, "hello", false},
{"rtrim_no_whitespace", "rtrim", []interface{}{"hello"}, "hello", false},
// regexp_matches tests
{"regexp_matches_true", "regexp_matches", []interface{}{"hello123", "[0-9]+"}, true, false},
{"regexp_matches_false", "regexp_matches", []interface{}{"hello", "[0-9]+"}, false, false},
{"regexp_matches_email", "regexp_matches", []interface{}{"test@example.com", "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"}, true, false},
// regexp_replace tests
{"regexp_replace_digits", "regexp_replace", []interface{}{"hello123world456", "[0-9]+", "X"}, "helloXworldX", false},
{"regexp_replace_no_match", "regexp_replace", []interface{}{"hello", "[0-9]+", "X"}, "hello", false},
// regexp_substring tests
{"regexp_substring_found", "regexp_substring", []interface{}{"hello123world", "[0-9]+"}, "123", false},
{"regexp_substring_not_found", "regexp_substring", []interface{}{"hello", "[0-9]+"}, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
ctx := &FunctionContext{}
result, err := fn.Execute(ctx, tt.args)
if tt.wantErr {
if err == nil {
t.Errorf("Expected error for %s, got nil", tt.name)
}
return
}
if err != nil {
t.Errorf("Unexpected error for %s: %v", tt.name, err)
return
}
// 特殊处理 split 函数的结果比较
if tt.funcName == "split" {
expectedSlice, ok := tt.expected.([]string)
@@ -127,4 +127,4 @@ func TestNewStringFunctions(t *testing.T) {
}
})
}
}
}
+6 -6
View File
@@ -61,7 +61,7 @@ func (f *IsNumericFunction) Execute(ctx *FunctionContext, args []interface{}) (i
if args[0] == nil {
return false, nil
}
v := reflect.ValueOf(args[0])
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
@@ -92,7 +92,7 @@ func (f *IsStringFunction) Execute(ctx *FunctionContext, args []interface{}) (in
if args[0] == nil {
return false, nil
}
_, ok := args[0].(string)
return ok, nil
}
@@ -116,7 +116,7 @@ func (f *IsBoolFunction) Execute(ctx *FunctionContext, args []interface{}) (inte
if args[0] == nil {
return false, nil
}
_, ok := args[0].(bool)
return ok, nil
}
@@ -140,7 +140,7 @@ func (f *IsArrayFunction) Execute(ctx *FunctionContext, args []interface{}) (int
if args[0] == nil {
return false, nil
}
v := reflect.ValueOf(args[0])
return v.Kind() == reflect.Slice || v.Kind() == reflect.Array, nil
}
@@ -164,7 +164,7 @@ func (f *IsObjectFunction) Execute(ctx *FunctionContext, args []interface{}) (in
if args[0] == nil {
return false, nil
}
v := reflect.ValueOf(args[0])
return v.Kind() == reflect.Map || v.Kind() == reflect.Struct, nil
}
}
+95
View File
@@ -0,0 +1,95 @@
package functions
import (
"testing"
)
// 测试类型检查函数
func TestTypeFunctions(t *testing.T) {
tests := []struct {
name string
funcName string
args []interface{}
expected interface{}
}{
{
name: "is_null true",
funcName: "is_null",
args: []interface{}{nil},
expected: true,
},
{
name: "is_null false",
funcName: "is_null",
args: []interface{}{"test"},
expected: false,
},
{
name: "is_not_null true",
funcName: "is_not_null",
args: []interface{}{"test"},
expected: true,
},
{
name: "is_not_null false",
funcName: "is_not_null",
args: []interface{}{nil},
expected: false,
},
{
name: "is_numeric true",
funcName: "is_numeric",
args: []interface{}{123},
expected: true,
},
{
name: "is_numeric false",
funcName: "is_numeric",
args: []interface{}{"test"},
expected: false,
},
{
name: "is_string true",
funcName: "is_string",
args: []interface{}{"test"},
expected: true,
},
{
name: "is_string false",
funcName: "is_string",
args: []interface{}{123},
expected: false,
},
{
name: "is_bool true",
funcName: "is_bool",
args: []interface{}{true},
expected: true,
},
{
name: "is_bool false",
funcName: "is_bool",
args: []interface{}{"test"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fn, exists := Get(tt.funcName)
if !exists {
t.Fatalf("Function %s not found", tt.funcName)
}
result, err := fn.Execute(&FunctionContext{}, tt.args)
if err != nil {
t.Errorf("Execute() error = %v", err)
return
}
if result != tt.expected {
t.Errorf("Execute() = %v, want %v", result, tt.expected)
}
})
}
}
+83 -14
View File
@@ -303,34 +303,69 @@ func (f *FirstValueFunction) Clone() AggregatorFunction {
// LeadFunction 返回当前行之后第N行的值
type LeadFunction struct {
*BaseFunction
values []interface{}
values []interface{}
offset int
defaultValue interface{}
hasDefault bool
}
func NewLeadFunction() *LeadFunction {
return &LeadFunction{
BaseFunction: NewBaseFunction("lead", TypeWindow, "窗口函数", "返回当前行之后第N行的值", 1, 3),
values: make([]interface{}, 0),
offset: 1, // 默认偏移量为1
}
}
func (f *LeadFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
if err := f.ValidateArgCount(args); err != nil {
return err
}
// 验证第二个参数(offset)是否为整数
if len(args) >= 2 {
if offset, ok := args[1].(int); ok {
f.offset = offset
} else {
return fmt.Errorf("offset must be an integer")
}
}
// 设置默认值
if len(args) >= 3 {
f.defaultValue = args[2]
f.hasDefault = true
}
return nil
}
func (f *LeadFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
// 获取默认值
var defaultValue interface{}
if len(args) >= 3 {
defaultValue = args[2]
// 获取偏移量
if len(args) >= 2 {
if offset, ok := args[1].(int); ok {
f.offset = offset
} else {
return nil, fmt.Errorf("offset must be an integer")
}
}
// 获取默认值
if len(args) >= 3 {
f.defaultValue = args[2]
f.hasDefault = true
}
// Lead函数需要在窗口处理完成后才能确定值
// 这里返回默认值,实际实现需要在窗口引擎中处理
return defaultValue, nil
if f.hasDefault {
return f.defaultValue, nil
}
return nil, nil
}
// 实现AggregatorFunction接口
@@ -338,6 +373,9 @@ func (f *LeadFunction) New() AggregatorFunction {
return &LeadFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, 0),
offset: f.offset,
defaultValue: f.defaultValue,
hasDefault: f.hasDefault,
}
}
@@ -347,18 +385,28 @@ func (f *LeadFunction) Add(value interface{}) {
func (f *LeadFunction) Result() interface{} {
// Lead函数的结果需要在所有数据添加完成后计算
// 如果没有足够的数据,返回默认值
if len(f.values) == 0 && f.hasDefault {
return f.defaultValue
}
// 这里简化实现,返回nil
return nil
}
func (f *LeadFunction) Reset() {
f.values = make([]interface{}, 0)
f.offset = 1
f.defaultValue = nil
f.hasDefault = false
}
func (f *LeadFunction) Clone() AggregatorFunction {
clone := &LeadFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, len(f.values)),
offset: f.offset,
defaultValue: f.defaultValue,
hasDefault: f.hasDefault,
}
copy(clone.values, f.values)
return clone
@@ -380,14 +428,35 @@ func NewNthValueFunction() *NthValueFunction {
}
func (f *NthValueFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
if err := f.ValidateArgCount(args); err != nil {
return err
}
// 验证N值
n := 1
if nVal, ok := args[1].(int); ok {
n = nVal
} else if nVal, ok := args[1].(int64); ok {
n = int(nVal)
} else {
return fmt.Errorf("nth_value n must be an integer")
}
if n <= 0 {
return fmt.Errorf("nth_value n must be positive, got %d", n)
}
// 设置n值
f.n = n
return nil
}
func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
// 获取N值
n := 1
if nVal, ok := args[1].(int); ok {
@@ -397,16 +466,16 @@ func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (in
} else {
return nil, fmt.Errorf("nth_value n must be an integer")
}
if n <= 0 {
return nil, fmt.Errorf("nth_value n must be positive, got %d", n)
}
// 返回第N个值(1-based索引)
if len(f.values) >= n {
return f.values[n-1], nil
}
return nil, nil
}
+32 -2
View File
@@ -4,6 +4,20 @@ import (
"testing"
)
// isWindowFunction 判断是否为窗口函数
func isWindowFunction(funcName string) bool {
windowFunctions := map[string]bool{
"row_number": true,
"window_start": true,
"window_end": true,
"lead": true,
"lag": true,
"first_value": true,
"last_value": true,
}
return windowFunctions[funcName]
}
func TestNewWindowFunctions(t *testing.T) {
tests := []struct {
name string
@@ -222,8 +236,19 @@ func TestNewWindowFunctions(t *testing.T) {
tt.setup(aggInstance)
}
// 执行函数
_, err = fn.Execute(nil, tt.args)
// 对于窗口函数测试,不需要调用Execute方法
// Execute方法主要用于流式处理,这里我们直接测试聚合器的Result方法
// 如果需要测试Execute方法,应该在原始函数实例上调用
if !isWindowFunction(tt.funcName) {
// 对于非窗口函数,在聚合器实例上执行
if aggFunc, ok := aggInstance.(Function); ok {
_, err = aggFunc.Execute(nil, tt.args)
} else {
// 执行函数
_, err = fn.Execute(nil, tt.args)
}
}
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -249,6 +274,11 @@ func TestWindowFunctionBasics(t *testing.T) {
t.Fatal("row_number function not found")
}
// 重置函数状态
if rowNum, ok := rowNumFunc.(*RowNumberFunction); ok {
rowNum.Reset()
}
// 测试行号递增
result1, err := rowNumFunc.Execute(nil, []interface{}{})
if err != nil {