Files
streamsql/functions/functions_conditional.go
2025-06-11 18:45:39 +08:00

255 lines
5.7 KiB
Go

package functions
import (
"fmt"
"reflect"
"github.com/rulego/streamsql/utils/cast"
)
// IfNullFunction implements the if_null conditional function, which yields
// its second argument when the first one is nil.
type IfNullFunction struct {
	*BaseFunction
}

// NewIfNullFunction creates the if_null function.
//
// NOTE(review): the trailing 2, 2 passed to NewBaseFunction are presumably the
// minimum and maximum argument counts — confirm against NewBaseFunction.
func NewIfNullFunction() *IfNullFunction {
	base := NewBaseFunction("if_null", TypeString, "条件函数", "如果第一个参数为NULL则返回第二个参数", 2, 2)
	return &IfNullFunction{BaseFunction: base}
}

// Validate returns whatever the argument-count check reports for args.
func (f *IfNullFunction) Validate(args []interface{}) error {
	if err := f.ValidateArgCount(args); err != nil {
		return err
	}
	return nil
}
// Execute returns args[1] when args[0] is nil, and args[0] otherwise.
//
// ctx is not used. If fewer than two arguments are supplied an error is
// returned instead of panicking on the index accesses.
func (f *IfNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	// Guard the indexing below in case Execute is reached without a prior
	// successful Validate call.
	if len(args) < 2 {
		return nil, fmt.Errorf("if_null requires 2 arguments, got %d", len(args))
	}
	if args[0] == nil {
		return args[1], nil
	}
	return args[0], nil
}
// CoalesceFunction implements the coalesce conditional function, which yields
// the first of its arguments that is not nil.
type CoalesceFunction struct {
	*BaseFunction
}

// NewCoalesceFunction creates the coalesce function.
//
// NOTE(review): the trailing 1, -1 passed to NewBaseFunction are presumably
// the minimum argument count and an "unbounded" maximum — confirm against
// NewBaseFunction.
func NewCoalesceFunction() *CoalesceFunction {
	base := NewBaseFunction("coalesce", TypeString, "条件函数", "返回第一个非NULL值", 1, -1)
	return &CoalesceFunction{BaseFunction: base}
}

// Validate returns whatever the argument-count check reports for args.
func (f *CoalesceFunction) Validate(args []interface{}) error {
	if err := f.ValidateArgCount(args); err != nil {
		return err
	}
	return nil
}
// Execute scans args from left to right and returns the first element that is
// not nil. It returns nil when args is empty or contains only nil values.
//
// ctx is not used and the returned error is always nil.
func (f *CoalesceFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	for i := 0; i < len(args); i++ {
		if args[i] == nil {
			continue
		}
		return args[i], nil
	}
	return nil, nil
}
// NullIfFunction implements the null_if conditional function, which yields
// nil when its two arguments are equal.
type NullIfFunction struct {
	*BaseFunction
}

// NewNullIfFunction creates the null_if function.
//
// NOTE(review): the trailing 2, 2 passed to NewBaseFunction are presumably the
// minimum and maximum argument counts — confirm against NewBaseFunction.
func NewNullIfFunction() *NullIfFunction {
	base := NewBaseFunction("null_if", TypeString, "条件函数", "如果两个值相等则返回NULL", 2, 2)
	return &NullIfFunction{BaseFunction: base}
}

// Validate returns whatever the argument-count check reports for args.
func (f *NullIfFunction) Validate(args []interface{}) error {
	if err := f.ValidateArgCount(args); err != nil {
		return err
	}
	return nil
}
// Execute returns nil when args[0] and args[1] are equal, and args[0]
// otherwise.
//
// Integer and floating-point operands are compared by numeric value regardless
// of their concrete Go type, so int(1), int64(1) and float64(1) are all equal
// to one another. Every other pair of operands is compared with
// reflect.DeepEqual; in particular a string is never equal to a number.
//
// ctx is not used. If fewer than two arguments are supplied an error is
// returned instead of panicking on the index accesses.
func (f *NullIfFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	if len(args) < 2 {
		return nil, fmt.Errorf("null_if requires 2 arguments, got %d", len(args))
	}
	if nullIfOperandsEqual(args[0], args[1]) {
		return nil, nil
	}
	return args[0], nil
}

// Numeric classes used by nullIfOperandsEqual to pick a comparison strategy.
const (
	nullIfNotNumeric = iota
	nullIfSigned
	nullIfUnsigned
	nullIfFloat
)

// nullIfOperandsEqual reports whether a and b count as equal for null_if.
func nullIfOperandsEqual(a, b interface{}) bool {
	if reflect.DeepEqual(a, b) {
		return true
	}

	// DeepEqual is type-strict, so fall back to a by-value comparison when
	// both operands are numbers of (possibly) different Go types.
	av, bv := reflect.ValueOf(a), reflect.ValueOf(b)
	ac, bc := nullIfNumericClass(av.Kind()), nullIfNumericClass(bv.Kind())
	if ac == nullIfNotNumeric || bc == nullIfNotNumeric {
		return false
	}

	switch {
	case ac == nullIfSigned && bc == nullIfSigned:
		return av.Int() == bv.Int()
	case ac == nullIfUnsigned && bc == nullIfUnsigned:
		return av.Uint() == bv.Uint()
	case ac == nullIfSigned && bc == nullIfUnsigned:
		// A negative value can never equal an unsigned one.
		return av.Int() >= 0 && uint64(av.Int()) == bv.Uint()
	case ac == nullIfUnsigned && bc == nullIfSigned:
		return bv.Int() >= 0 && uint64(bv.Int()) == av.Uint()
	default:
		// At least one operand is a float: compare in float64.
		return nullIfAsFloat(av, ac) == nullIfAsFloat(bv, bc)
	}
}

// nullIfNumericClass maps a reflect.Kind to one of the nullIf* numeric classes.
func nullIfNumericClass(k reflect.Kind) int {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return nullIfSigned
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return nullIfUnsigned
	case reflect.Float32, reflect.Float64:
		return nullIfFloat
	default:
		return nullIfNotNumeric
	}
}

// nullIfAsFloat converts a numeric reflect.Value of the given class to float64.
func nullIfAsFloat(v reflect.Value, class int) float64 {
	switch class {
	case nullIfSigned:
		return float64(v.Int())
	case nullIfUnsigned:
		return float64(v.Uint())
	default:
		return v.Float()
	}
}
// GreatestFunction implements the greatest conditional function, which yields
// the largest of its arguments.
type GreatestFunction struct {
	*BaseFunction
}

// NewGreatestFunction creates the greatest function.
//
// NOTE(review): the trailing 1, -1 passed to NewBaseFunction are presumably
// the minimum argument count and an "unbounded" maximum — confirm against
// NewBaseFunction.
func NewGreatestFunction() *GreatestFunction {
	base := NewBaseFunction("greatest", TypeMath, "条件函数", "返回最大值", 1, -1)
	return &GreatestFunction{BaseFunction: base}
}

// Validate returns whatever the argument-count check reports for args.
func (f *GreatestFunction) Validate(args []interface{}) error {
	if err := f.ValidateArgCount(args); err != nil {
		return err
	}
	return nil
}
// Execute returns the largest element of args, or nil when args is empty or
// the first element is nil. A nil element met later while scanning also makes
// the result nil.
//
// Each candidate is compared with the current winner numerically when
// cast.ToFloat64E succeeds for both values; otherwise the two values are
// compared as their fmt "%v" strings. Only a strictly greater candidate
// replaces the winner, so on ties the earliest argument is kept.
//
// ctx is not used and the returned error is always nil.
func (f *GreatestFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	if len(args) == 0 || args[0] == nil {
		return nil, nil
	}

	best := args[0]
	for _, candidate := range args[1:] {
		if candidate == nil {
			return nil, nil
		}

		bestNum, bestErr := cast.ToFloat64E(best)
		candNum, candErr := cast.ToFloat64E(candidate)

		var larger bool
		switch {
		case bestErr != nil || candErr != nil:
			// At least one side is not numeric: compare textually.
			bestText := fmt.Sprintf("%v", best)
			candText := fmt.Sprintf("%v", candidate)
			larger = candText > bestText
		default:
			larger = candNum > bestNum
		}

		if larger {
			best = candidate
		}
	}
	return best, nil
}
// LeastFunction implements the least conditional function, which yields the
// smallest of its arguments.
type LeastFunction struct {
	*BaseFunction
}

// NewLeastFunction creates the least function.
//
// NOTE(review): the trailing 1, -1 passed to NewBaseFunction are presumably
// the minimum argument count and an "unbounded" maximum — confirm against
// NewBaseFunction.
func NewLeastFunction() *LeastFunction {
	base := NewBaseFunction("least", TypeMath, "条件函数", "返回最小值", 1, -1)
	return &LeastFunction{BaseFunction: base}
}

// Validate returns whatever the argument-count check reports for args.
func (f *LeastFunction) Validate(args []interface{}) error {
	if err := f.ValidateArgCount(args); err != nil {
		return err
	}
	return nil
}
// Execute returns the smallest element of args, or nil when args is empty or
// the first element is nil. A nil element met later while scanning also makes
// the result nil.
//
// Each candidate is compared with the current winner numerically when
// cast.ToFloat64E succeeds for both values; otherwise the two values are
// compared as their fmt "%v" strings. Only a strictly smaller candidate
// replaces the winner, so on ties the earliest argument is kept.
//
// ctx is not used and the returned error is always nil.
func (f *LeastFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	if len(args) == 0 || args[0] == nil {
		return nil, nil
	}

	best := args[0]
	for _, candidate := range args[1:] {
		if candidate == nil {
			return nil, nil
		}

		bestNum, bestErr := cast.ToFloat64E(best)
		candNum, candErr := cast.ToFloat64E(candidate)

		var smaller bool
		switch {
		case bestErr != nil || candErr != nil:
			// At least one side is not numeric: compare textually.
			bestText := fmt.Sprintf("%v", best)
			candText := fmt.Sprintf("%v", candidate)
			smaller = candText < bestText
		default:
			smaller = candNum < bestNum
		}

		if smaller {
			best = candidate
		}
	}
	return best, nil
}
// CaseWhenFunction implements the case_when conditional function, a CASE WHEN
// expression evaluated over condition/value argument pairs.
type CaseWhenFunction struct {
	*BaseFunction
}

// NewCaseWhenFunction creates the case_when function.
//
// NOTE(review): the trailing 2, -1 passed to NewBaseFunction are presumably
// the minimum argument count and an "unbounded" maximum — confirm against
// NewBaseFunction.
func NewCaseWhenFunction() *CaseWhenFunction {
	base := NewBaseFunction("case_when", TypeString, "条件函数", "CASE WHEN表达式", 2, -1)
	return &CaseWhenFunction{BaseFunction: base}
}
// Validate checks that case_when received at least two arguments, i.e. at
// least one condition/value pair.
//
// Arguments are read as condition/value pairs, optionally followed by a single
// trailing default value, so both even and odd argument counts are accepted.
// Conditions are not type-checked here: Execute coerces each one to a boolean
// when it evaluates the pairs.
func (f *CaseWhenFunction) Validate(args []interface{}) error {
	if len(args) < 2 {
		return fmt.Errorf("case_when requires at least 2 arguments")
	}
	return nil
}
// Execute evaluates args as condition/value pairs and returns the value of the
// first pair whose condition is true.
//
// When no condition matches, the trailing default value is returned if the
// argument count is odd; otherwise the result is nil. Each condition is
// converted with cast.ToBoolE; if that conversion fails, caseWhenTruthy decides
// whether the condition counts as true.
//
// ctx is not used. The only error returned is the one reported by Validate.
func (f *CaseWhenFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	if err := f.Validate(args); err != nil {
		return nil, err
	}

	// Walk the condition/value pairs. With an odd argument count the last
	// element has no partner and is left for the default handling below.
	for i := 0; i+1 < len(args); i += 2 {
		matched, err := cast.ToBoolE(args[i])
		if err != nil {
			matched = caseWhenTruthy(args[i])
		}
		if matched {
			return args[i+1], nil
		}
	}

	// No condition matched: fall back to the default value when present.
	if len(args)%2 == 1 {
		return args[len(args)-1], nil
	}
	return nil, nil
}

// caseWhenTruthy reports whether a condition that could not be converted to a
// bool should still count as true: nil, empty strings and numeric zero of any
// built-in integer or float type are false, a bool is its own value, and every
// other value is true.
func caseWhenTruthy(condition interface{}) bool {
	switch v := condition.(type) {
	case nil:
		return false
	case bool:
		return v
	case string:
		return v != ""
	case int:
		return v != 0
	case int8:
		return v != 0
	case int16:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case uint:
		return v != 0
	case uint8:
		return v != 0
	case uint16:
		return v != 0
	case uint32:
		return v != 0
	case uint64:
		return v != 0
	case float32:
		return v != 0
	case float64:
		return v != 0
	default:
		return true
	}
}