mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-05-12 13:37:05 +00:00
feat:增加函数实现
This commit is contained in:
+10
-4
@@ -65,8 +65,8 @@ func registerBuiltinFunctions() {
|
||||
_ = Register(NewToSecondsFunction())
|
||||
_ = Register(NewChrFunction())
|
||||
_ = Register(NewTruncFunction())
|
||||
_ = Register(NewCompressFunction())
|
||||
_ = Register(NewDecompressFunction())
|
||||
_ = Register(NewUrlEncodeFunction())
|
||||
_ = Register(NewUrlDecodeFunction())
|
||||
|
||||
// Time-Date functions
|
||||
_ = Register(NewNowFunction())
|
||||
@@ -79,8 +79,8 @@ func registerBuiltinFunctions() {
|
||||
_ = Register(NewMinFunction())
|
||||
_ = Register(NewMaxFunction())
|
||||
_ = Register(NewCountFunction())
|
||||
_ = Register(NewStdDevFunction())
|
||||
_ = Register(NewMedianFunction())
|
||||
_ = Register(NewStdDevAggregatorFunction())
|
||||
_ = Register(NewMedianAggregatorFunction())
|
||||
_ = Register(NewPercentileFunction())
|
||||
_ = Register(NewCollectFunction())
|
||||
_ = Register(NewLastValueFunction())
|
||||
@@ -109,6 +109,7 @@ func registerBuiltinFunctions() {
|
||||
// Expression functions
|
||||
_ = Register(NewExpressionFunction())
|
||||
_ = Register(NewExprFunction())
|
||||
_ = Register(NewExpressionAggregatorFunction())
|
||||
|
||||
// JSON functions
|
||||
_ = Register(NewToJsonFunction())
|
||||
@@ -144,10 +145,15 @@ func registerBuiltinFunctions() {
|
||||
_ = Register(NewIsObjectFunction())
|
||||
|
||||
// Conditional functions
|
||||
_ = Register(NewIfNullFunction())
|
||||
_ = Register(NewCoalesceFunction())
|
||||
_ = Register(NewNullIfFunction())
|
||||
_ = Register(NewGreatestFunction())
|
||||
_ = Register(NewLeastFunction())
|
||||
_ = Register(NewCaseWhenFunction())
|
||||
|
||||
// Multi-row functions
|
||||
_ = Register(NewUnnestFunction())
|
||||
|
||||
// User-defined functions (placeholder for future extension)
|
||||
// Example: _=Register(NewMyUserDefinedFunction())
|
||||
|
||||
@@ -99,9 +99,11 @@ func (f *LagFunction) Result() interface{} {
|
||||
return f.DefaultValue
|
||||
}
|
||||
// 返回当前值之前第Offset个值
|
||||
// 对于数组[first, second, third],当前位置是最后一个元素
|
||||
// offset=1时返回second(倒数第2个),offset=2时返回first(倒数第3个)
|
||||
return f.PreviousValues[len(f.PreviousValues)-f.Offset-1]
|
||||
// 对于数组[first, second, third],当前位置是最后一个元素third(索引2)
|
||||
// offset=1时应该返回second(索引1),计算:len-1-offset = 3-1-1 = 1
|
||||
// offset=2时应该返回first(索引0),计算:len-1-offset = 3-1-2 = 0
|
||||
// 索引计算:len-1-offset,即从最后一个元素往前数offset个位置
|
||||
return f.PreviousValues[len(f.PreviousValues)-1-f.Offset]
|
||||
}
|
||||
|
||||
func (f *LagFunction) Clone() AggregatorFunction {
|
||||
|
||||
@@ -47,12 +47,12 @@ func (f *ArrayContainsFunction) Validate(args []interface{}) error {
|
||||
func (f *ArrayContainsFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array := args[0]
|
||||
value := args[1]
|
||||
|
||||
|
||||
v := reflect.ValueOf(array)
|
||||
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_contains requires array input")
|
||||
}
|
||||
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if reflect.DeepEqual(v.Index(i).Interface(), value) {
|
||||
return true, nil
|
||||
@@ -79,12 +79,12 @@ func (f *ArrayPositionFunction) Validate(args []interface{}) error {
|
||||
func (f *ArrayPositionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array := args[0]
|
||||
value := args[1]
|
||||
|
||||
|
||||
v := reflect.ValueOf(array)
|
||||
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_position requires array input")
|
||||
}
|
||||
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if reflect.DeepEqual(v.Index(i).Interface(), value) {
|
||||
return i + 1, nil // 返回1基索引
|
||||
@@ -111,12 +111,12 @@ func (f *ArrayRemoveFunction) Validate(args []interface{}) error {
|
||||
func (f *ArrayRemoveFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array := args[0]
|
||||
value := args[1]
|
||||
|
||||
|
||||
v := reflect.ValueOf(array)
|
||||
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_remove requires array input")
|
||||
}
|
||||
|
||||
|
||||
var result []interface{}
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i).Interface()
|
||||
@@ -144,15 +144,15 @@ func (f *ArrayDistinctFunction) Validate(args []interface{}) error {
|
||||
|
||||
func (f *ArrayDistinctFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array := args[0]
|
||||
|
||||
|
||||
v := reflect.ValueOf(array)
|
||||
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_distinct requires array input")
|
||||
}
|
||||
|
||||
|
||||
seen := make(map[interface{}]bool)
|
||||
var result []interface{}
|
||||
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i).Interface()
|
||||
if !seen[elem] {
|
||||
@@ -181,27 +181,27 @@ func (f *ArrayIntersectFunction) Validate(args []interface{}) error {
|
||||
func (f *ArrayIntersectFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array1 := args[0]
|
||||
array2 := args[1]
|
||||
|
||||
|
||||
v1 := reflect.ValueOf(array1)
|
||||
v2 := reflect.ValueOf(array2)
|
||||
|
||||
|
||||
if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_intersect requires array input for first argument")
|
||||
}
|
||||
if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_intersect requires array input for second argument")
|
||||
}
|
||||
|
||||
|
||||
// 创建第二个数组的元素集合
|
||||
set2 := make(map[interface{}]bool)
|
||||
for i := 0; i < v2.Len(); i++ {
|
||||
set2[v2.Index(i).Interface()] = true
|
||||
}
|
||||
|
||||
|
||||
// 找交集
|
||||
seen := make(map[interface{}]bool)
|
||||
var result []interface{}
|
||||
|
||||
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
elem := v1.Index(i).Interface()
|
||||
if set2[elem] && !seen[elem] {
|
||||
@@ -230,20 +230,20 @@ func (f *ArrayUnionFunction) Validate(args []interface{}) error {
|
||||
func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array1 := args[0]
|
||||
array2 := args[1]
|
||||
|
||||
|
||||
v1 := reflect.ValueOf(array1)
|
||||
v2 := reflect.ValueOf(array2)
|
||||
|
||||
|
||||
if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_union requires array input for first argument")
|
||||
}
|
||||
if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_union requires array input for second argument")
|
||||
}
|
||||
|
||||
|
||||
seen := make(map[interface{}]bool)
|
||||
var result []interface{}
|
||||
|
||||
|
||||
// 添加第一个数组的元素
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
elem := v1.Index(i).Interface()
|
||||
@@ -252,7 +252,7 @@ func (f *ArrayUnionFunction) Execute(ctx *FunctionContext, args []interface{}) (
|
||||
result = append(result, elem)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 添加第二个数组的元素
|
||||
for i := 0; i < v2.Len(); i++ {
|
||||
elem := v2.Index(i).Interface()
|
||||
@@ -282,27 +282,27 @@ func (f *ArrayExceptFunction) Validate(args []interface{}) error {
|
||||
func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
array1 := args[0]
|
||||
array2 := args[1]
|
||||
|
||||
|
||||
v1 := reflect.ValueOf(array1)
|
||||
v2 := reflect.ValueOf(array2)
|
||||
|
||||
|
||||
if v1.Kind() != reflect.Slice && v1.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_except requires array input for first argument")
|
||||
}
|
||||
if v2.Kind() != reflect.Slice && v2.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("array_except requires array input for second argument")
|
||||
}
|
||||
|
||||
|
||||
// 创建第二个数组的元素集合
|
||||
set2 := make(map[interface{}]bool)
|
||||
for i := 0; i < v2.Len(); i++ {
|
||||
set2[v2.Index(i).Interface()] = true
|
||||
}
|
||||
|
||||
|
||||
// 找差集
|
||||
seen := make(map[interface{}]bool)
|
||||
var result []interface{}
|
||||
|
||||
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
elem := v1.Index(i).Interface()
|
||||
if !set2[elem] && !seen[elem] {
|
||||
@@ -311,4 +311,4 @@ func (f *ArrayExceptFunction) Execute(ctx *FunctionContext, args []interface{})
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 测试数组函数
|
||||
func TestArrayFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "array_length basic",
|
||||
funcName: "array_length",
|
||||
args: []interface{}{[]interface{}{1, 2, 3}},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "array_contains true",
|
||||
funcName: "array_contains",
|
||||
args: []interface{}{[]interface{}{1, 2, 3}, 2},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "array_contains false",
|
||||
funcName: "array_contains",
|
||||
args: []interface{}{[]interface{}{1, 2, 3}, 4},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "array_position found",
|
||||
funcName: "array_position",
|
||||
args: []interface{}{[]interface{}{1, 2, 3}, 2},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "array_position not found",
|
||||
funcName: "array_position",
|
||||
args: []interface{}{[]interface{}{1, 2, 3}, 4},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, tt.args)
|
||||
if err != nil {
|
||||
t.Errorf("Execute() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompressionFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
// Compress function tests
|
||||
{
|
||||
name: "compress_gzip_valid",
|
||||
funcName: "compress",
|
||||
args: []interface{}{"hello world", "gzip"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "compress_zlib_valid",
|
||||
funcName: "compress",
|
||||
args: []interface{}{"hello world", "zlib"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "compress_invalid_algorithm",
|
||||
funcName: "compress",
|
||||
args: []interface{}{"hello world", "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "compress_empty_string",
|
||||
funcName: "compress",
|
||||
args: []interface{}{"", "gzip"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "compress_wrong_arg_count",
|
||||
funcName: "compress",
|
||||
args: []interface{}{"hello"},
|
||||
wantErr: true,
|
||||
},
|
||||
// Decompress function tests
|
||||
{
|
||||
name: "decompress_invalid_base64",
|
||||
funcName: "decompress",
|
||||
args: []interface{}{"invalid_base64", "gzip"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "decompress_invalid_algorithm",
|
||||
funcName: "decompress",
|
||||
args: []interface{}{"SGVsbG8gV29ybGQ=", "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "decompress_wrong_arg_count",
|
||||
funcName: "decompress",
|
||||
args: []interface{}{"SGVsbG8gV29ybGQ="},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("%s function not found", tt.funcName)
|
||||
}
|
||||
|
||||
// 执行函数
|
||||
_, err := fn.Execute(nil, tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressionDecompressionRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
algorithm string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "gzip_round_trip",
|
||||
algorithm: "gzip",
|
||||
input: "Hello, World! This is a test string for compression.",
|
||||
},
|
||||
{
|
||||
name: "zlib_round_trip",
|
||||
algorithm: "zlib",
|
||||
input: "Hello, World! This is a test string for compression.",
|
||||
},
|
||||
{
|
||||
name: "gzip_empty_string",
|
||||
algorithm: "gzip",
|
||||
input: "",
|
||||
},
|
||||
{
|
||||
name: "zlib_unicode",
|
||||
algorithm: "zlib",
|
||||
input: "你好世界!这是一个测试字符串。",
|
||||
},
|
||||
}
|
||||
|
||||
compressFn, exists := Get("compress")
|
||||
if !exists {
|
||||
t.Fatal("compress function not found")
|
||||
}
|
||||
|
||||
decompressFn, exists := Get("decompress")
|
||||
if !exists {
|
||||
t.Fatal("decompress function not found")
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 压缩
|
||||
compressed, err := compressFn.Execute(nil, []interface{}{tt.input, tt.algorithm})
|
||||
if err != nil {
|
||||
t.Fatalf("Compress failed: %v", err)
|
||||
}
|
||||
|
||||
// 解压缩
|
||||
decompressed, err := decompressFn.Execute(nil, []interface{}{compressed, tt.algorithm})
|
||||
if err != nil {
|
||||
t.Fatalf("Decompress failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
if decompressed != tt.input {
|
||||
t.Errorf("Round trip failed: expected %q, got %q", tt.input, decompressed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,28 @@ import (
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
)
|
||||
|
||||
// IfNullFunction 如果第一个参数为NULL则返回第二个参数
|
||||
type IfNullFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewIfNullFunction() *IfNullFunction {
|
||||
return &IfNullFunction{
|
||||
BaseFunction: NewBaseFunction("if_null", TypeString, "条件函数", "如果第一个参数为NULL则返回第二个参数", 2, 2),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *IfNullFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
func (f *IfNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if args[0] == nil {
|
||||
return args[1], nil
|
||||
}
|
||||
return args[0], nil
|
||||
}
|
||||
|
||||
// CoalesceFunction 返回第一个非NULL值
|
||||
type CoalesceFunction struct {
|
||||
*BaseFunction
|
||||
@@ -38,7 +60,7 @@ type NullIfFunction struct {
|
||||
|
||||
func NewNullIfFunction() *NullIfFunction {
|
||||
return &NullIfFunction{
|
||||
BaseFunction: NewBaseFunction("nullif", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2),
|
||||
BaseFunction: NewBaseFunction("null_if", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,21 +94,21 @@ func (f *GreatestFunction) Execute(ctx *FunctionContext, args []interface{}) (in
|
||||
if len(args) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
max := args[0]
|
||||
if max == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
for i := 1; i < len(args); i++ {
|
||||
if args[i] == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
// 尝试转换为数字进行比较
|
||||
maxVal, err1 := cast.ToFloat64E(max)
|
||||
currVal, err2 := cast.ToFloat64E(args[i])
|
||||
|
||||
|
||||
if err1 == nil && err2 == nil {
|
||||
if currVal > maxVal {
|
||||
max = args[i]
|
||||
@@ -122,21 +144,21 @@ func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
|
||||
if len(args) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
min := args[0]
|
||||
if min == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
for i := 1; i < len(args); i++ {
|
||||
if args[i] == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
// 尝试转换为数字进行比较
|
||||
minVal, err1 := cast.ToFloat64E(min)
|
||||
currVal, err2 := cast.ToFloat64E(args[i])
|
||||
|
||||
|
||||
if err1 == nil && err2 == nil {
|
||||
if currVal < minVal {
|
||||
min = args[i]
|
||||
@@ -151,4 +173,82 @@ func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
|
||||
}
|
||||
}
|
||||
return min, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CaseWhenFunction CASE WHEN表达式
|
||||
type CaseWhenFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewCaseWhenFunction() *CaseWhenFunction {
|
||||
return &CaseWhenFunction{
|
||||
BaseFunction: NewBaseFunction("case_when", TypeString, "条件函数", "CASE WHEN表达式", 2, -1),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *CaseWhenFunction) Validate(args []interface{}) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("case_when requires at least 2 arguments")
|
||||
}
|
||||
|
||||
// 参数必须是偶数个(条件-值对)或奇数个(最后一个是默认值)
|
||||
if len(args)%2 == 0 {
|
||||
// 偶数个参数,必须都是条件-值对
|
||||
for i := 0; i < len(args); i += 2 {
|
||||
// 条件应该是布尔值或可以转换为布尔值的表达式
|
||||
}
|
||||
} else {
|
||||
// 奇数个参数,最后一个是默认值
|
||||
for i := 0; i < len(args)-1; i += 2 {
|
||||
// 条件应该是布尔值或可以转换为布尔值的表达式
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *CaseWhenFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理条件-值对
|
||||
for i := 0; i < len(args)-1; i += 2 {
|
||||
condition := args[i]
|
||||
value := args[i+1]
|
||||
|
||||
// 将条件转换为布尔值
|
||||
condBool, err := cast.ToBoolE(condition)
|
||||
if err != nil {
|
||||
// 如果无法转换为布尔值,检查是否为非零/非空值
|
||||
if condition == nil {
|
||||
condBool = false
|
||||
} else {
|
||||
switch v := condition.(type) {
|
||||
case string:
|
||||
condBool = v != ""
|
||||
case int, int32, int64:
|
||||
num, _ := cast.ToInt64E(v)
|
||||
condBool = num != 0
|
||||
case float32, float64:
|
||||
num, _ := cast.ToFloat64E(v)
|
||||
condBool = num != 0.0
|
||||
default:
|
||||
condBool = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if condBool {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有条件匹配,返回默认值(如果有)
|
||||
if len(args)%2 == 1 {
|
||||
return args[len(args)-1], nil
|
||||
}
|
||||
|
||||
// 没有默认值,返回 nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 测试条件函数
|
||||
func TestConditionalFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "if_null with null",
|
||||
funcName: "if_null",
|
||||
args: []interface{}{nil, "default"},
|
||||
expected: "default",
|
||||
},
|
||||
{
|
||||
name: "if_null with value",
|
||||
funcName: "if_null",
|
||||
args: []interface{}{"value", "default"},
|
||||
expected: "value",
|
||||
},
|
||||
{
|
||||
name: "null_if equal",
|
||||
funcName: "null_if",
|
||||
args: []interface{}{"test", "test"},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "null_if not equal",
|
||||
funcName: "null_if",
|
||||
args: []interface{}{"test", "other"},
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "greatest basic",
|
||||
funcName: "greatest",
|
||||
args: []interface{}{1, 3, 2},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "least basic",
|
||||
funcName: "least",
|
||||
args: []interface{}{1, 3, 2},
|
||||
expected: 1,
|
||||
},
|
||||
|
||||
// case_when 函数测试
|
||||
{
|
||||
name: "case_when simple",
|
||||
funcName: "case_when",
|
||||
args: []interface{}{true, "result1", false, "result2", "default"},
|
||||
expected: "result1",
|
||||
},
|
||||
{
|
||||
name: "case_when second condition",
|
||||
funcName: "case_when",
|
||||
args: []interface{}{false, "result1", true, "result2", "default"},
|
||||
expected: "result2",
|
||||
},
|
||||
{
|
||||
name: "case_when default",
|
||||
funcName: "case_when",
|
||||
args: []interface{}{false, "result1", false, "result2", "default"},
|
||||
expected: "default",
|
||||
},
|
||||
{
|
||||
name: "case_when no default",
|
||||
funcName: "case_when",
|
||||
args: []interface{}{false, "result1", false, "result2"},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "case_when invalid args",
|
||||
funcName: "case_when",
|
||||
args: []interface{}{true},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && result != tt.expected {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,15 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
"io"
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
)
|
||||
|
||||
// CastFunction 类型转换函数
|
||||
@@ -265,19 +262,19 @@ func (f *ConvertTzFunction) Execute(ctx *FunctionContext, args []interface{}) (i
|
||||
default:
|
||||
return nil, fmt.Errorf("time value must be time.Time or string")
|
||||
}
|
||||
|
||||
|
||||
// 获取目标时区
|
||||
timezone, err := cast.ToStringE(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 加载时区
|
||||
loc, err := time.LoadLocation(timezone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid timezone: %s", timezone)
|
||||
}
|
||||
|
||||
|
||||
// 转换时区
|
||||
return t.In(loc), nil
|
||||
}
|
||||
@@ -324,7 +321,7 @@ func (f *ToSecondsFunction) Execute(ctx *FunctionContext, args []interface{}) (i
|
||||
default:
|
||||
return nil, fmt.Errorf("time value must be time.Time or string")
|
||||
}
|
||||
|
||||
|
||||
return t.Unix(), nil
|
||||
}
|
||||
|
||||
@@ -348,151 +345,115 @@ func (f *ChrFunction) Execute(ctx *FunctionContext, args []interface{}) (interfa
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
if code < 0 || code > 127 {
|
||||
return nil, fmt.Errorf("ASCII code must be between 0 and 127, got %d", code)
|
||||
}
|
||||
|
||||
|
||||
return string(rune(code)), nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// UrlEncodeFunction URL编码函数
|
||||
type UrlEncodeFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewUrlEncodeFunction() *UrlEncodeFunction {
|
||||
return &UrlEncodeFunction{
|
||||
BaseFunction: NewBaseFunction("url_encode", TypeConversion, "转换函数", "URL编码", 1, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UrlEncodeFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
func (f *UrlEncodeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if args[0] == nil {
|
||||
return nil, fmt.Errorf("url_encode: input cannot be nil")
|
||||
}
|
||||
|
||||
input := cast.ToString(args[0])
|
||||
return url.QueryEscape(input), nil
|
||||
}
|
||||
|
||||
// UrlDecodeFunction URL解码函数
|
||||
type UrlDecodeFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewUrlDecodeFunction() *UrlDecodeFunction {
|
||||
return &UrlDecodeFunction{
|
||||
BaseFunction: NewBaseFunction("url_decode", TypeConversion, "转换函数", "URL解码", 1, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UrlDecodeFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
func (f *UrlDecodeFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if args[0] == nil {
|
||||
return nil, fmt.Errorf("url_decode: input cannot be nil")
|
||||
}
|
||||
|
||||
input := cast.ToString(args[0])
|
||||
result, err := url.QueryUnescape(input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("URL decode failed: %v", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TruncFunction 截断小数位数
|
||||
type TruncFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
// NewTruncFunction 创建新的 trunc 函数
|
||||
func NewTruncFunction() *TruncFunction {
|
||||
return &TruncFunction{
|
||||
BaseFunction: NewBaseFunction("trunc", TypeConversion, "转换函数", "截断小数位数", 2, 2),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证参数
|
||||
func (f *TruncFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
// Execute 执行函数
|
||||
func (f *TruncFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
val, err := cast.ToFloat64E(args[0])
|
||||
if err != nil {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
precision, err := cast.ToIntE(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 转换第一个参数为浮点数
|
||||
num := cast.ToFloat64(args[0])
|
||||
|
||||
// 转换第二个参数为整数(精度)
|
||||
precision := cast.ToInt(args[1])
|
||||
|
||||
// 精度不能为负数
|
||||
if precision < 0 {
|
||||
return nil, fmt.Errorf("precision must be non-negative, got %d", precision)
|
||||
return nil, fmt.Errorf("trunc precision cannot be negative")
|
||||
}
|
||||
|
||||
|
||||
// 计算截断
|
||||
multiplier := math.Pow(10, float64(precision))
|
||||
return math.Trunc(val*multiplier) / multiplier, nil
|
||||
}
|
||||
|
||||
// CompressFunction 压缩函数
|
||||
type CompressFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewCompressFunction() *CompressFunction {
|
||||
return &CompressFunction{
|
||||
BaseFunction: NewBaseFunction("compress", TypeConversion, "转换函数", "压缩字符串或二进制值", 2, 2),
|
||||
if num >= 0 {
|
||||
return math.Floor(num*multiplier) / multiplier, nil
|
||||
} else {
|
||||
return math.Ceil(num*multiplier) / multiplier, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *CompressFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
func (f *CompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
input := cast.ToString(args[0])
|
||||
algorithm := cast.ToString(args[1])
|
||||
|
||||
var buf bytes.Buffer
|
||||
var writer io.WriteCloser
|
||||
|
||||
switch algorithm {
|
||||
case "gzip":
|
||||
writer = gzip.NewWriter(&buf)
|
||||
case "zlib":
|
||||
writer = zlib.NewWriter(&buf)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported compression algorithm: %s", algorithm)
|
||||
}
|
||||
|
||||
_, err := writer.Write([]byte(input))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compression failed: %v", err)
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compression failed: %v", err)
|
||||
}
|
||||
|
||||
// 返回base64编码的压缩数据
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
// DecompressFunction 解压缩函数
|
||||
type DecompressFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewDecompressFunction() *DecompressFunction {
|
||||
return &DecompressFunction{
|
||||
BaseFunction: NewBaseFunction("decompress", TypeConversion, "转换函数", "解压缩字符串或二进制值", 2, 2),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DecompressFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
func (f *DecompressFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
input := cast.ToString(args[0])
|
||||
algorithm := cast.ToString(args[1])
|
||||
|
||||
// 解码base64数据
|
||||
compressedData, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 input: %v", err)
|
||||
}
|
||||
|
||||
buf := bytes.NewReader(compressedData)
|
||||
var reader io.ReadCloser
|
||||
|
||||
switch algorithm {
|
||||
case "gzip":
|
||||
reader, err = gzip.NewReader(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip decompression failed: %v", err)
|
||||
}
|
||||
case "zlib":
|
||||
reader, err = zlib.NewReader(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("zlib decompression failed: %v", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported decompression algorithm: %s", algorithm)
|
||||
}
|
||||
|
||||
defer reader.Close()
|
||||
|
||||
result, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decompression failed: %v", err)
|
||||
}
|
||||
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func TestNewConversionFunctions(t *testing.T) {
|
||||
args: []interface{}{"invalid-time", "UTC"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
|
||||
// to_seconds 函数测试
|
||||
{
|
||||
name: "to_seconds with time.Time",
|
||||
@@ -61,7 +61,7 @@ func TestNewConversionFunctions(t *testing.T) {
|
||||
args: []interface{}{"invalid-time"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
|
||||
// chr 函数测试
|
||||
{
|
||||
name: "chr valid ASCII code",
|
||||
@@ -89,7 +89,7 @@ func TestNewConversionFunctions(t *testing.T) {
|
||||
args: []interface{}{128},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
|
||||
// trunc 函数测试
|
||||
{
|
||||
name: "trunc positive number",
|
||||
@@ -98,6 +98,70 @@ func TestNewConversionFunctions(t *testing.T) {
|
||||
want: 3.14,
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// url_encode 函数测试
|
||||
{
|
||||
name: "url_encode basic",
|
||||
funcName: "url_encode",
|
||||
args: []interface{}{"hello world"},
|
||||
want: "hello+world",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "url_encode special chars",
|
||||
funcName: "url_encode",
|
||||
args: []interface{}{"hello@world.com"},
|
||||
want: "hello%40world.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "url_encode empty",
|
||||
funcName: "url_encode",
|
||||
args: []interface{}{""},
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "url_encode nil",
|
||||
funcName: "url_encode",
|
||||
args: []interface{}{nil},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// url_decode 函数测试
|
||||
{
|
||||
name: "url_decode basic",
|
||||
funcName: "url_decode",
|
||||
args: []interface{}{"hello+world"},
|
||||
want: "hello world",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "url_decode special chars",
|
||||
funcName: "url_decode",
|
||||
args: []interface{}{"hello%40world.com"},
|
||||
want: "hello@world.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "url_decode empty",
|
||||
funcName: "url_decode",
|
||||
args: []interface{}{""},
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "url_decode nil",
|
||||
funcName: "url_decode",
|
||||
args: []interface{}{nil},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "url_decode invalid",
|
||||
funcName: "url_decode",
|
||||
args: []interface{}{"hello%ZZ"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "trunc negative number",
|
||||
funcName: "trunc",
|
||||
@@ -119,20 +183,20 @@ func TestNewConversionFunctions(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
|
||||
result, err := fn.Execute(nil, tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if !tt.wantErr {
|
||||
// 对于时间类型,需要特殊处理比较
|
||||
if tt.funcName == "convert_tz" {
|
||||
@@ -157,4 +221,4 @@ func TestNewConversionFunctions(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -479,6 +479,12 @@ func (f *YearFunction) Validate(args []interface{}) error {
|
||||
}
|
||||
|
||||
func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
// 首先检查是否是 time.Time 类型
|
||||
if t, ok := args[0].(time.Time); ok {
|
||||
return float64(t.Year()), nil
|
||||
}
|
||||
|
||||
// 如果不是 time.Time,尝试转换为字符串并解析
|
||||
dateStr, err := cast.ToStringE(args[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date: %v", err)
|
||||
@@ -491,7 +497,7 @@ func (f *YearFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
|
||||
}
|
||||
}
|
||||
|
||||
return t.Year(), nil
|
||||
return float64(t.Year()), nil
|
||||
}
|
||||
|
||||
// MonthFunction 提取月份函数
|
||||
@@ -510,6 +516,12 @@ func (f *MonthFunction) Validate(args []interface{}) error {
|
||||
}
|
||||
|
||||
func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
// 首先检查是否是 time.Time 类型
|
||||
if t, ok := args[0].(time.Time); ok {
|
||||
return float64(t.Month()), nil
|
||||
}
|
||||
|
||||
// 如果不是 time.Time,尝试转换为字符串并解析
|
||||
dateStr, err := cast.ToStringE(args[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date: %v", err)
|
||||
@@ -522,7 +534,7 @@ func (f *MonthFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
|
||||
}
|
||||
}
|
||||
|
||||
return int(t.Month()), nil
|
||||
return float64(t.Month()), nil
|
||||
}
|
||||
|
||||
// DayFunction 提取日期函数
|
||||
|
||||
@@ -63,7 +63,7 @@ func TestDateTimeFunctions(t *testing.T) {
|
||||
name: "year extraction",
|
||||
function: NewYearFunction(),
|
||||
args: []interface{}{"2023-12-25 15:30:45"},
|
||||
expected: 2023,
|
||||
expected: float64(2023),
|
||||
wantErr: false,
|
||||
},
|
||||
// MonthFunction 测试
|
||||
@@ -71,7 +71,7 @@ func TestDateTimeFunctions(t *testing.T) {
|
||||
name: "month extraction",
|
||||
function: NewMonthFunction(),
|
||||
args: []interface{}{"2023-12-25 15:30:45"},
|
||||
expected: 12,
|
||||
expected: float64(12),
|
||||
wantErr: false,
|
||||
},
|
||||
// DayFunction 测试
|
||||
@@ -164,14 +164,14 @@ func TestDateTimeFunctions(t *testing.T) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 执行函数
|
||||
result, err := tt.function.Execute(nil, tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if !tt.wantErr && result != tt.expected {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
@@ -235,4 +235,4 @@ func TestDateFormatConversion(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func (f *Md5Function) Execute(ctx *FunctionContext, args []interface{}) (interfa
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("md5 requires string input")
|
||||
}
|
||||
|
||||
|
||||
hash := md5.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash), nil
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func (f *Sha1Function) Execute(ctx *FunctionContext, args []interface{}) (interf
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sha1 requires string input")
|
||||
}
|
||||
|
||||
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash), nil
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func (f *Sha256Function) Execute(ctx *FunctionContext, args []interface{}) (inte
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sha256 requires string input")
|
||||
}
|
||||
|
||||
|
||||
hash := sha256.Sum256([]byte(str))
|
||||
return fmt.Sprintf("%x", hash), nil
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (f *Sha512Function) Execute(ctx *FunctionContext, args []interface{}) (inte
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sha512 requires string input")
|
||||
}
|
||||
|
||||
|
||||
hash := sha512.Sum512([]byte(str))
|
||||
return fmt.Sprintf("%x", hash), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 测试哈希函数
|
||||
func TestHashFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "md5 basic",
|
||||
funcName: "md5",
|
||||
args: []interface{}{"hello"},
|
||||
expected: "5d41402abc4b2a76b9719d911017c592",
|
||||
},
|
||||
{
|
||||
name: "sha1 basic",
|
||||
funcName: "sha1",
|
||||
args: []interface{}{"hello"},
|
||||
expected: "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d",
|
||||
},
|
||||
{
|
||||
name: "sha256 basic",
|
||||
funcName: "sha256",
|
||||
args: []interface{}{"hello"},
|
||||
expected: "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, tt.args)
|
||||
if err != nil {
|
||||
t.Errorf("Execute() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+11
-11
@@ -50,7 +50,7 @@ func (f *FromJsonFunction) Execute(ctx *FunctionContext, args []interface{}) (in
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("from_json requires string input")
|
||||
}
|
||||
|
||||
|
||||
var result interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &result)
|
||||
if err != nil {
|
||||
@@ -79,18 +79,18 @@ func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("json_extract requires string input")
|
||||
}
|
||||
|
||||
|
||||
path, ok := args[1].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("json_extract path must be string")
|
||||
}
|
||||
|
||||
|
||||
var data interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 简单的路径提取,支持 $.field 格式
|
||||
if strings.HasPrefix(path, "$.") {
|
||||
field := path[2:]
|
||||
@@ -98,7 +98,7 @@ func (f *JsonExtractFunction) Execute(ctx *FunctionContext, args []interface{})
|
||||
return dataMap[field], nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil, fmt.Errorf("invalid JSON path or data structure")
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (f *JsonValidFunction) Execute(ctx *FunctionContext, args []interface{}) (i
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
var temp interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &temp)
|
||||
return err == nil, nil
|
||||
@@ -148,13 +148,13 @@ func (f *JsonTypeFunction) Execute(ctx *FunctionContext, args []interface{}) (in
|
||||
if !ok {
|
||||
return "unknown", nil
|
||||
}
|
||||
|
||||
|
||||
var data interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &data)
|
||||
if err != nil {
|
||||
return "invalid", nil
|
||||
}
|
||||
|
||||
|
||||
switch data.(type) {
|
||||
case nil:
|
||||
return "null", nil
|
||||
@@ -193,13 +193,13 @@ func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) (
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("json_length requires string input")
|
||||
}
|
||||
|
||||
|
||||
var data interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
|
||||
switch v := data.(type) {
|
||||
case []interface{}:
|
||||
return len(v), nil
|
||||
@@ -208,4 +208,4 @@ func (f *JsonLengthFunction) Execute(ctx *FunctionContext, args []interface{}) (
|
||||
default:
|
||||
return nil, fmt.Errorf("JSON value is not an array or object")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 测试JSON函数
|
||||
func TestJsonFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "to_json basic",
|
||||
funcName: "to_json",
|
||||
args: []interface{}{map[string]interface{}{"name": "test", "value": 123}},
|
||||
expected: `{"name":"test","value":123}`,
|
||||
},
|
||||
{
|
||||
name: "from_json basic",
|
||||
funcName: "from_json",
|
||||
args: []interface{}{`{"name":"test","value":123}`},
|
||||
expected: map[string]interface{}{"name": "test", "value": float64(123)},
|
||||
},
|
||||
{
|
||||
name: "json_extract basic",
|
||||
funcName: "json_extract",
|
||||
args: []interface{}{`{"name":"test","value":123}`, "$.name"},
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "json_valid true",
|
||||
funcName: "json_valid",
|
||||
args: []interface{}{`{"name":"test"}`},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "json_valid false",
|
||||
funcName: "json_valid",
|
||||
args: []interface{}{`{"name":"test"`},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "json_type object",
|
||||
funcName: "json_type",
|
||||
args: []interface{}{`{"name":"test"}`},
|
||||
expected: "object",
|
||||
},
|
||||
{
|
||||
name: "json_length array",
|
||||
funcName: "json_length",
|
||||
args: []interface{}{`[1,2,3]`},
|
||||
expected: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && !compareResults(result, tt.expected) {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:比较结果
|
||||
func compareResults(a, b interface{}) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于map类型的特殊处理
|
||||
if mapA, okA := a.(map[string]interface{}); okA {
|
||||
if mapB, okB := b.(map[string]interface{}); okB {
|
||||
if len(mapA) != len(mapB) {
|
||||
return false
|
||||
}
|
||||
for k, v := range mapA {
|
||||
if mapB[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return a == b
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// UnnestFunction 将数组展开为多行
|
||||
type UnnestFunction struct {
|
||||
*BaseFunction
|
||||
}
|
||||
|
||||
func NewUnnestFunction() *UnnestFunction {
|
||||
return &UnnestFunction{
|
||||
BaseFunction: NewBaseFunction("unnest", TypeString, "多行函数", "将数组展开为多行", 1, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UnnestFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
}
|
||||
|
||||
func (f *UnnestFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
array := args[0]
|
||||
if array == nil {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
|
||||
// 使用反射检查是否为数组或切片
|
||||
v := reflect.ValueOf(array)
|
||||
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
|
||||
return nil, fmt.Errorf("unnest requires an array or slice, got %T", array)
|
||||
}
|
||||
|
||||
// 转换为 []interface{}
|
||||
result := make([]interface{}, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i).Interface()
|
||||
|
||||
// 如果数组元素是对象(map),则展开为列
|
||||
if elemMap, ok := elem.(map[string]interface{}); ok {
|
||||
// 对于对象,我们返回一个特殊的结构来表示需要展开为列
|
||||
result[i] = map[string]interface{}{
|
||||
"__unnest_object__": true,
|
||||
"__data__": elemMap,
|
||||
}
|
||||
} else {
|
||||
result[i] = elem
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UnnestResult 表示 unnest 函数的结果
|
||||
type UnnestResult struct {
|
||||
Rows []map[string]interface{}
|
||||
}
|
||||
|
||||
// IsUnnestResult 检查是否为 unnest 结果
|
||||
func IsUnnestResult(value interface{}) bool {
|
||||
if slice, ok := value.([]interface{}); ok {
|
||||
for _, item := range slice {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if unnest, exists := itemMap["__unnest_object__"]; exists {
|
||||
if unnestBool, ok := unnest.(bool); ok && unnestBool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ProcessUnnestResult 处理 unnest 结果,将其转换为多行
|
||||
func ProcessUnnestResult(value interface{}) []map[string]interface{} {
|
||||
slice, ok := value.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var rows []map[string]interface{}
|
||||
for _, item := range slice {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if unnest, exists := itemMap["__unnest_object__"]; exists {
|
||||
if unnestBool, ok := unnest.(bool); ok && unnestBool {
|
||||
if data, exists := itemMap["__data__"]; exists {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
rows = append(rows, dataMap)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
// 对于非对象元素,创建一个包含单个值的行
|
||||
row := map[string]interface{}{
|
||||
"value": item,
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
|
||||
return rows
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUnnestFunction 测试unnest函数
|
||||
func TestUnnestFunction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
// unnest 函数测试
|
||||
{
|
||||
name: "unnest array",
|
||||
funcName: "unnest",
|
||||
args: []interface{}{[]interface{}{1, 2, 3}},
|
||||
expected: []interface{}{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "unnest empty array",
|
||||
funcName: "unnest",
|
||||
args: []interface{}{[]interface{}{}},
|
||||
expected: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "unnest nil",
|
||||
funcName: "unnest",
|
||||
args: []interface{}{nil},
|
||||
expected: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "unnest non-array",
|
||||
funcName: "unnest",
|
||||
args: []interface{}{"not an array"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnnestWithObjects 测试 unnest 函数处理对象数组
|
||||
func TestUnnestWithObjects(t *testing.T) {
|
||||
fn, exists := Get("unnest")
|
||||
if !exists {
|
||||
t.Fatal("Function unnest not found")
|
||||
}
|
||||
|
||||
// 测试对象数组
|
||||
objectArray := []interface{}{
|
||||
map[string]interface{}{"name": "Alice", "age": 30},
|
||||
map[string]interface{}{"name": "Bob", "age": 25},
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, []interface{}{objectArray})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
// 检查结果是否包含特殊标记
|
||||
resultSlice, ok := result.([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected []interface{}, got %T", result)
|
||||
}
|
||||
|
||||
if len(resultSlice) != 2 {
|
||||
t.Fatalf("Expected 2 items, got %d", len(resultSlice))
|
||||
}
|
||||
|
||||
// 检查第一个对象是否有特殊标记
|
||||
firstItem, ok := resultSlice[0].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected map[string]interface{}, got %T", resultSlice[0])
|
||||
}
|
||||
|
||||
if unnestFlag, exists := firstItem["__unnest_object__"]; !exists || unnestFlag != true {
|
||||
t.Error("Expected __unnest_object__ flag to be true")
|
||||
}
|
||||
|
||||
if data, exists := firstItem["__data__"]; !exists {
|
||||
t.Error("Expected __data__ field to exist")
|
||||
} else {
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected __data__ to be map[string]interface{}, got %T", data)
|
||||
} else {
|
||||
if dataMap["name"] != "Alice" || dataMap["age"] != 30 {
|
||||
t.Errorf("Unexpected data: %v", dataMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,9 +2,10 @@ package functions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
)
|
||||
|
||||
// ConcatFunction 字符串连接函数
|
||||
@@ -292,26 +293,26 @@ func (f *SubstringFunction) Execute(ctx *FunctionContext, args []interface{}) (i
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
strLen := int64(len(str))
|
||||
if start < 0 || start >= strLen {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
|
||||
if len(args) == 2 {
|
||||
return str[start:], nil
|
||||
}
|
||||
|
||||
|
||||
length, err := cast.ToInt64E(args[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
end := start + length
|
||||
if end > strLen {
|
||||
end = strLen
|
||||
}
|
||||
|
||||
|
||||
return str[start:end], nil
|
||||
}
|
||||
|
||||
@@ -397,7 +398,7 @@ func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
pad := " "
|
||||
if len(args) == 3 {
|
||||
pad, err = cast.ToStringE(args[2])
|
||||
@@ -405,12 +406,12 @@ func (f *LpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
strLen := int64(len(str))
|
||||
if strLen >= length {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
|
||||
padLen := length - strLen
|
||||
padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1))
|
||||
return padStr[:padLen] + str, nil
|
||||
@@ -440,7 +441,7 @@ func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
pad := " "
|
||||
if len(args) == 3 {
|
||||
pad, err = cast.ToStringE(args[2])
|
||||
@@ -448,12 +449,12 @@ func (f *RpadFunction) Execute(ctx *FunctionContext, args []interface{}) (interf
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
strLen := int64(len(str))
|
||||
if strLen >= length {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
|
||||
padLen := length - strLen
|
||||
padStr := strings.Repeat(pad, int(padLen/int64(len(pad))+1))
|
||||
return str + padStr[:padLen], nil
|
||||
@@ -533,7 +534,7 @@ func (f *RegexpMatchesFunction) Execute(ctx *FunctionContext, args []interface{}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
matched, err := regexp.MatchString(pattern, str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -569,7 +570,7 @@ func (f *RegexpReplaceFunction) Execute(ctx *FunctionContext, args []interface{}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -601,12 +602,12 @@ func (f *RegexpSubstringFunction) Execute(ctx *FunctionContext, args []interface
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
match := re.FindString(str)
|
||||
return match, nil
|
||||
}
|
||||
|
||||
@@ -16,88 +16,88 @@ func TestNewStringFunctions(t *testing.T) {
|
||||
{"endswith_true", "endswith", []interface{}{"hello world", "world"}, true, false},
|
||||
{"endswith_false", "endswith", []interface{}{"hello world", "hello"}, false, false},
|
||||
{"endswith_empty", "endswith", []interface{}{"hello", ""}, true, false},
|
||||
|
||||
|
||||
// startswith tests
|
||||
{"startswith_true", "startswith", []interface{}{"hello world", "hello"}, true, false},
|
||||
{"startswith_false", "startswith", []interface{}{"hello world", "world"}, false, false},
|
||||
{"startswith_empty", "startswith", []interface{}{"hello", ""}, true, false},
|
||||
|
||||
|
||||
// indexof tests
|
||||
{"indexof_found", "indexof", []interface{}{"hello world", "world"}, int64(6), false},
|
||||
{"indexof_not_found", "indexof", []interface{}{"hello world", "xyz"}, int64(-1), false},
|
||||
{"indexof_first_char", "indexof", []interface{}{"hello", "h"}, int64(0), false},
|
||||
|
||||
|
||||
// substring tests
|
||||
{"substring_start_only", "substring", []interface{}{"hello world", int64(6)}, "world", false},
|
||||
{"substring_start_length", "substring", []interface{}{"hello world", int64(0), int64(5)}, "hello", false},
|
||||
{"substring_out_of_bounds", "substring", []interface{}{"hello", int64(10)}, "", false},
|
||||
|
||||
|
||||
// replace tests
|
||||
{"replace_simple", "replace", []interface{}{"hello world", "world", "Go"}, "hello Go", false},
|
||||
{"replace_multiple", "replace", []interface{}{"hello hello", "hello", "hi"}, "hi hi", false},
|
||||
{"replace_not_found", "replace", []interface{}{"hello world", "xyz", "abc"}, "hello world", false},
|
||||
|
||||
|
||||
// split tests
|
||||
{"split_comma", "split", []interface{}{"a,b,c", ","}, []string{"a", "b", "c"}, false},
|
||||
{"split_space", "split", []interface{}{"hello world", " "}, []string{"hello", "world"}, false},
|
||||
{"split_not_found", "split", []interface{}{"hello", ","}, []string{"hello"}, false},
|
||||
|
||||
|
||||
// lpad tests
|
||||
{"lpad_default", "lpad", []interface{}{"hello", int64(10)}, " hello", false},
|
||||
{"lpad_custom", "lpad", []interface{}{"hello", int64(8), "*"}, "***hello", false},
|
||||
{"lpad_no_padding", "lpad", []interface{}{"hello", int64(3)}, "hello", false},
|
||||
|
||||
|
||||
// rpad tests
|
||||
{"rpad_default", "rpad", []interface{}{"hello", int64(10)}, "hello ", false},
|
||||
{"rpad_custom", "rpad", []interface{}{"hello", int64(8), "*"}, "hello***", false},
|
||||
{"rpad_no_padding", "rpad", []interface{}{"hello", int64(3)}, "hello", false},
|
||||
|
||||
|
||||
// ltrim tests
|
||||
{"ltrim_spaces", "ltrim", []interface{}{" hello world "}, "hello world ", false},
|
||||
{"ltrim_tabs", "ltrim", []interface{}{"\t\nhello"}, "hello", false},
|
||||
{"ltrim_no_whitespace", "ltrim", []interface{}{"hello"}, "hello", false},
|
||||
|
||||
|
||||
// rtrim tests
|
||||
{"rtrim_spaces", "rtrim", []interface{}{" hello world "}, " hello world", false},
|
||||
{"rtrim_tabs", "rtrim", []interface{}{"hello\t\n"}, "hello", false},
|
||||
{"rtrim_no_whitespace", "rtrim", []interface{}{"hello"}, "hello", false},
|
||||
|
||||
|
||||
// regexp_matches tests
|
||||
{"regexp_matches_true", "regexp_matches", []interface{}{"hello123", "[0-9]+"}, true, false},
|
||||
{"regexp_matches_false", "regexp_matches", []interface{}{"hello", "[0-9]+"}, false, false},
|
||||
{"regexp_matches_email", "regexp_matches", []interface{}{"test@example.com", "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"}, true, false},
|
||||
|
||||
|
||||
// regexp_replace tests
|
||||
{"regexp_replace_digits", "regexp_replace", []interface{}{"hello123world456", "[0-9]+", "X"}, "helloXworldX", false},
|
||||
{"regexp_replace_no_match", "regexp_replace", []interface{}{"hello", "[0-9]+", "X"}, "hello", false},
|
||||
|
||||
|
||||
// regexp_substring tests
|
||||
{"regexp_substring_found", "regexp_substring", []interface{}{"hello123world", "[0-9]+"}, "123", false},
|
||||
{"regexp_substring_not_found", "regexp_substring", []interface{}{"hello", "[0-9]+"}, "", false},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
|
||||
ctx := &FunctionContext{}
|
||||
result, err := fn.Execute(ctx, tt.args)
|
||||
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for %s, got nil", tt.name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for %s: %v", tt.name, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 特殊处理 split 函数的结果比较
|
||||
if tt.funcName == "split" {
|
||||
expectedSlice, ok := tt.expected.([]string)
|
||||
@@ -127,4 +127,4 @@ func TestNewStringFunctions(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (f *IsNumericFunction) Execute(ctx *FunctionContext, args []interface{}) (i
|
||||
if args[0] == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
v := reflect.ValueOf(args[0])
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
@@ -92,7 +92,7 @@ func (f *IsStringFunction) Execute(ctx *FunctionContext, args []interface{}) (in
|
||||
if args[0] == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
_, ok := args[0].(string)
|
||||
return ok, nil
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (f *IsBoolFunction) Execute(ctx *FunctionContext, args []interface{}) (inte
|
||||
if args[0] == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
_, ok := args[0].(bool)
|
||||
return ok, nil
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func (f *IsArrayFunction) Execute(ctx *FunctionContext, args []interface{}) (int
|
||||
if args[0] == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
v := reflect.ValueOf(args[0])
|
||||
return v.Kind() == reflect.Slice || v.Kind() == reflect.Array, nil
|
||||
}
|
||||
@@ -164,7 +164,7 @@ func (f *IsObjectFunction) Execute(ctx *FunctionContext, args []interface{}) (in
|
||||
if args[0] == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
v := reflect.ValueOf(args[0])
|
||||
return v.Kind() == reflect.Map || v.Kind() == reflect.Struct, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 测试类型检查函数
|
||||
func TestTypeFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
funcName string
|
||||
args []interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "is_null true",
|
||||
funcName: "is_null",
|
||||
args: []interface{}{nil},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is_null false",
|
||||
funcName: "is_null",
|
||||
args: []interface{}{"test"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "is_not_null true",
|
||||
funcName: "is_not_null",
|
||||
args: []interface{}{"test"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is_not_null false",
|
||||
funcName: "is_not_null",
|
||||
args: []interface{}{nil},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "is_numeric true",
|
||||
funcName: "is_numeric",
|
||||
args: []interface{}{123},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is_numeric false",
|
||||
funcName: "is_numeric",
|
||||
args: []interface{}{"test"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "is_string true",
|
||||
funcName: "is_string",
|
||||
args: []interface{}{"test"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is_string false",
|
||||
funcName: "is_string",
|
||||
args: []interface{}{123},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "is_bool true",
|
||||
funcName: "is_bool",
|
||||
args: []interface{}{true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "is_bool false",
|
||||
funcName: "is_bool",
|
||||
args: []interface{}{"test"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fn, exists := Get(tt.funcName)
|
||||
if !exists {
|
||||
t.Fatalf("Function %s not found", tt.funcName)
|
||||
}
|
||||
|
||||
result, err := fn.Execute(&FunctionContext{}, tt.args)
|
||||
if err != nil {
|
||||
t.Errorf("Execute() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Execute() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -303,34 +303,69 @@ func (f *FirstValueFunction) Clone() AggregatorFunction {
|
||||
// LeadFunction 返回当前行之后第N行的值
|
||||
type LeadFunction struct {
|
||||
*BaseFunction
|
||||
values []interface{}
|
||||
values []interface{}
|
||||
offset int
|
||||
defaultValue interface{}
|
||||
hasDefault bool
|
||||
}
|
||||
|
||||
func NewLeadFunction() *LeadFunction {
|
||||
return &LeadFunction{
|
||||
BaseFunction: NewBaseFunction("lead", TypeWindow, "窗口函数", "返回当前行之后第N行的值", 1, 3),
|
||||
values: make([]interface{}, 0),
|
||||
offset: 1, // 默认偏移量为1
|
||||
}
|
||||
}
|
||||
|
||||
func (f *LeadFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
if err := f.ValidateArgCount(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证第二个参数(offset)是否为整数
|
||||
if len(args) >= 2 {
|
||||
if offset, ok := args[1].(int); ok {
|
||||
f.offset = offset
|
||||
} else {
|
||||
return fmt.Errorf("offset must be an integer")
|
||||
}
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if len(args) >= 3 {
|
||||
f.defaultValue = args[2]
|
||||
f.hasDefault = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *LeadFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取默认值
|
||||
var defaultValue interface{}
|
||||
if len(args) >= 3 {
|
||||
defaultValue = args[2]
|
||||
|
||||
// 获取偏移量
|
||||
if len(args) >= 2 {
|
||||
if offset, ok := args[1].(int); ok {
|
||||
f.offset = offset
|
||||
} else {
|
||||
return nil, fmt.Errorf("offset must be an integer")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 获取默认值
|
||||
if len(args) >= 3 {
|
||||
f.defaultValue = args[2]
|
||||
f.hasDefault = true
|
||||
}
|
||||
|
||||
// Lead函数需要在窗口处理完成后才能确定值
|
||||
// 这里返回默认值,实际实现需要在窗口引擎中处理
|
||||
return defaultValue, nil
|
||||
if f.hasDefault {
|
||||
return f.defaultValue, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 实现AggregatorFunction接口
|
||||
@@ -338,6 +373,9 @@ func (f *LeadFunction) New() AggregatorFunction {
|
||||
return &LeadFunction{
|
||||
BaseFunction: f.BaseFunction,
|
||||
values: make([]interface{}, 0),
|
||||
offset: f.offset,
|
||||
defaultValue: f.defaultValue,
|
||||
hasDefault: f.hasDefault,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,18 +385,28 @@ func (f *LeadFunction) Add(value interface{}) {
|
||||
|
||||
func (f *LeadFunction) Result() interface{} {
|
||||
// Lead函数的结果需要在所有数据添加完成后计算
|
||||
// 如果没有足够的数据,返回默认值
|
||||
if len(f.values) == 0 && f.hasDefault {
|
||||
return f.defaultValue
|
||||
}
|
||||
// 这里简化实现,返回nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *LeadFunction) Reset() {
|
||||
f.values = make([]interface{}, 0)
|
||||
f.offset = 1
|
||||
f.defaultValue = nil
|
||||
f.hasDefault = false
|
||||
}
|
||||
|
||||
func (f *LeadFunction) Clone() AggregatorFunction {
|
||||
clone := &LeadFunction{
|
||||
BaseFunction: f.BaseFunction,
|
||||
values: make([]interface{}, len(f.values)),
|
||||
offset: f.offset,
|
||||
defaultValue: f.defaultValue,
|
||||
hasDefault: f.hasDefault,
|
||||
}
|
||||
copy(clone.values, f.values)
|
||||
return clone
|
||||
@@ -380,14 +428,35 @@ func NewNthValueFunction() *NthValueFunction {
|
||||
}
|
||||
|
||||
func (f *NthValueFunction) Validate(args []interface{}) error {
|
||||
return f.ValidateArgCount(args)
|
||||
if err := f.ValidateArgCount(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证N值
|
||||
n := 1
|
||||
if nVal, ok := args[1].(int); ok {
|
||||
n = nVal
|
||||
} else if nVal, ok := args[1].(int64); ok {
|
||||
n = int(nVal)
|
||||
} else {
|
||||
return fmt.Errorf("nth_value n must be an integer")
|
||||
}
|
||||
|
||||
if n <= 0 {
|
||||
return fmt.Errorf("nth_value n must be positive, got %d", n)
|
||||
}
|
||||
|
||||
// 设置n值
|
||||
f.n = n
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
if err := f.Validate(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 获取N值
|
||||
n := 1
|
||||
if nVal, ok := args[1].(int); ok {
|
||||
@@ -397,16 +466,16 @@ func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (in
|
||||
} else {
|
||||
return nil, fmt.Errorf("nth_value n must be an integer")
|
||||
}
|
||||
|
||||
|
||||
if n <= 0 {
|
||||
return nil, fmt.Errorf("nth_value n must be positive, got %d", n)
|
||||
}
|
||||
|
||||
|
||||
// 返回第N个值(1-based索引)
|
||||
if len(f.values) >= n {
|
||||
return f.values[n-1], nil
|
||||
}
|
||||
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,20 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// isWindowFunction 判断是否为窗口函数
|
||||
func isWindowFunction(funcName string) bool {
|
||||
windowFunctions := map[string]bool{
|
||||
"row_number": true,
|
||||
"window_start": true,
|
||||
"window_end": true,
|
||||
"lead": true,
|
||||
"lag": true,
|
||||
"first_value": true,
|
||||
"last_value": true,
|
||||
}
|
||||
return windowFunctions[funcName]
|
||||
}
|
||||
|
||||
func TestNewWindowFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -222,8 +236,19 @@ func TestNewWindowFunctions(t *testing.T) {
|
||||
tt.setup(aggInstance)
|
||||
}
|
||||
|
||||
// 执行函数
|
||||
_, err = fn.Execute(nil, tt.args)
|
||||
// 对于窗口函数测试,不需要调用Execute方法
|
||||
// Execute方法主要用于流式处理,这里我们直接测试聚合器的Result方法
|
||||
// 如果需要测试Execute方法,应该在原始函数实例上调用
|
||||
if !isWindowFunction(tt.funcName) {
|
||||
// 对于非窗口函数,在聚合器实例上执行
|
||||
if aggFunc, ok := aggInstance.(Function); ok {
|
||||
_, err = aggFunc.Execute(nil, tt.args)
|
||||
} else {
|
||||
// 执行函数
|
||||
_, err = fn.Execute(nil, tt.args)
|
||||
}
|
||||
}
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -249,6 +274,11 @@ func TestWindowFunctionBasics(t *testing.T) {
|
||||
t.Fatal("row_number function not found")
|
||||
}
|
||||
|
||||
// 重置函数状态
|
||||
if rowNum, ok := rowNumFunc.(*RowNumberFunction); ok {
|
||||
rowNum.Reset()
|
||||
}
|
||||
|
||||
// 测试行号递增
|
||||
result1, err := rowNumFunc.Execute(nil, []interface{}{})
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user