Files
streamsql/expr/expression.go
T

2190 lines
55 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package expr
import (
"fmt"
"math"
"strconv"
"strings"
"github.com/rulego/streamsql/functions"
"github.com/rulego/streamsql/utils/fieldpath"
)
// 表达式类型
const (
TypeNumber = "number" // 数字常量
TypeField = "field" // 字段引用
TypeOperator = "operator" // 运算符
TypeFunction = "function" // 函数调用
TypeParenthesis = "parenthesis" // 括号
TypeCase = "case" // CASE表达式
TypeString = "string" // 字符串常量
)
// 操作符优先级
var operatorPrecedence = map[string]int{
"OR": 1,
"AND": 2,
"==": 3, "=": 3, "!=": 3, "<>": 3,
">": 4, "<": 4, ">=": 4, "<=": 4, "LIKE": 4, "IS": 4,
"+": 5, "-": 5,
"*": 6, "/": 6, "%": 6,
"^": 7, // 幂运算
}
// CASE表达式的WHEN子句
type WhenClause struct {
Condition *ExprNode // WHEN条件
Result *ExprNode // THEN结果
}
// 表达式节点
type ExprNode struct {
Type string
Value string
Left *ExprNode
Right *ExprNode
Args []*ExprNode // 用于函数调用的参数
// CASE表达式专用字段
CaseExpr *ExprNode // CASE后面的表达式简单CASE
WhenClauses []WhenClause // WHEN子句列表
ElseExpr *ExprNode // ELSE表达式
}
// Expression 表示一个可计算的表达式
type Expression struct {
Root *ExprNode
useExprLang bool // 是否使用expr-lang/expr
exprLangExpression string // expr-lang表达式字符串
}
// NewExpression 创建一个新的表达式
func NewExpression(exprStr string) (*Expression, error) {
// 进行基本的语法验证
if err := validateBasicSyntax(exprStr); err != nil {
return nil, err
}
// 首先尝试使用自定义解析器
tokens, err := tokenize(exprStr)
if err != nil {
// 如果自定义解析失败标记为使用expr-lang
return &Expression{
Root: nil,
useExprLang: true,
exprLangExpression: exprStr,
}, nil
}
root, err := parseExpression(tokens)
if err != nil {
// 如果自定义解析失败标记为使用expr-lang
return &Expression{
Root: nil,
useExprLang: true,
exprLangExpression: exprStr,
}, nil
}
return &Expression{
Root: root,
useExprLang: false,
}, nil
}
// validateBasicSyntax 进行基本的语法验证
func validateBasicSyntax(exprStr string) error {
// 检查空表达式
trimmed := strings.TrimSpace(exprStr)
if trimmed == "" {
return fmt.Errorf("empty expression")
}
// 检查不匹配的括号
parenthesesCount := 0
for _, ch := range trimmed {
if ch == '(' {
parenthesesCount++
} else if ch == ')' {
parenthesesCount--
if parenthesesCount < 0 {
return fmt.Errorf("mismatched parentheses")
}
}
}
if parenthesesCount != 0 {
return fmt.Errorf("mismatched parentheses")
}
// 检查无效字符
for i, ch := range trimmed {
// 允许的字符:字母、数字、运算符、括号、点、下划线、空格、引号
if !isValidChar(ch) {
return fmt.Errorf("invalid character '%c' at position %d", ch, i)
}
}
// 检查连续运算符
if err := checkConsecutiveOperators(trimmed); err != nil {
return err
}
return nil
}
// checkConsecutiveOperators 检查连续运算符
func checkConsecutiveOperators(expr string) error {
// 简化的连续运算符检查:查找明显的双运算符模式
// 但要允许比较运算符后跟负数的情况
operators := []string{"+", "-", "*", "/", "%", "^", "==", "!=", ">=", "<=", ">", "<"}
comparisonOps := []string{"==", "!=", ">=", "<=", ">", "<"}
for i := 0; i < len(expr)-1; i++ {
// 跳过空白字符
if expr[i] == ' ' || expr[i] == '\t' {
continue
}
// 检查当前位置是否是运算符
isCurrentOp := false
currentOpLen := 0
currentOp := ""
for _, op := range operators {
if i+len(op) <= len(expr) && expr[i:i+len(op)] == op {
isCurrentOp = true
currentOpLen = len(op)
currentOp = op
break
}
}
if isCurrentOp {
// 查找下一个非空白字符
nextPos := i + currentOpLen
for nextPos < len(expr) && (expr[nextPos] == ' ' || expr[nextPos] == '\t') {
nextPos++
}
// 检查下一个字符是否也是运算符
if nextPos < len(expr) {
// 特殊处理:如果当前是比较运算符,下一个是负号,且负号后跟数字,则允许
isCurrentComparison := false
for _, compOp := range comparisonOps {
if currentOp == compOp {
isCurrentComparison = true
break
}
}
// 检查是否是负数的情况
if isCurrentComparison && nextPos < len(expr) && expr[nextPos] == '-' {
// 检查负号后是否跟数字
digitPos := nextPos + 1
for digitPos < len(expr) && (expr[digitPos] == ' ' || expr[digitPos] == '\t') {
digitPos++
}
if digitPos < len(expr) && expr[digitPos] >= '0' && expr[digitPos] <= '9' {
// 这是比较运算符后跟负数,允许通过
i = nextPos // 跳过到负号位置
continue
}
}
// 检查其他连续运算符
for _, op := range operators {
if nextPos+len(op) <= len(expr) && expr[nextPos:nextPos+len(op)] == op {
// 如果不是允许的负数情况,则报错
return fmt.Errorf("consecutive operators found: '%s' followed by '%s'",
currentOp, op)
}
}
}
// 跳过当前运算符
i += currentOpLen - 1
}
}
return nil
}
// isValidChar 检查字符是否有效
func isValidChar(ch rune) bool {
// 字母和数字
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') {
return true
}
// 特殊字符
switch ch {
case ' ', '\t', '\n', '\r': // 空白字符
return true
case '+', '-', '*', '/', '%', '^': // 算术运算符
return true
case '(', ')', ',': // 括号和逗号
return true
case '>', '<', '=', '!': // 比较运算符
return true
case '\'', '"': // 引号
return true
case '.', '_': // 点和下划线
return true
case '$': // 美元符号用于JSON路径等
return true
default:
return false
}
}
// Evaluate 计算表达式的值
func (e *Expression) Evaluate(data map[string]interface{}) (float64, error) {
if e.useExprLang {
return e.evaluateWithExprLang(data)
}
return evaluateNode(e.Root, data)
}
// evaluateWithExprLang 使用expr-lang/expr评估表达式
func (e *Expression) evaluateWithExprLang(data map[string]interface{}) (float64, error) {
// 使用桥接器评估表达式
bridge := functions.GetExprBridge()
result, err := bridge.EvaluateExpression(e.exprLangExpression, data)
if err != nil {
return 0, err
}
// 尝试转换结果为float64
switch v := result.(type) {
case float64:
return v, nil
case float32:
return float64(v), nil
case int:
return float64(v), nil
case int32:
return float64(v), nil
case int64:
return float64(v), nil
case string:
if f, err := strconv.ParseFloat(v, 64); err == nil {
return f, nil
}
return 0, fmt.Errorf("cannot convert string result '%s' to float64", v)
default:
return 0, fmt.Errorf("expression result type %T is not convertible to float64", result)
}
}
// GetFields 获取表达式中引用的所有字段
func (e *Expression) GetFields() []string {
if e.useExprLang {
// 对于expr-lang表达式需要解析字段引用
// 这里简化处理实际应该使用AST分析
return extractFieldsFromExprLang(e.exprLangExpression)
}
fields := make(map[string]bool)
collectFields(e.Root, fields)
result := make([]string, 0, len(fields))
for field := range fields {
result = append(result, field)
}
return result
}
// extractFieldsFromExprLang 从expr-lang表达式中提取字段引用简化版本
func extractFieldsFromExprLang(expression string) []string {
// 这是一个简化的实现实际应该使用AST解析
// 暂时使用正则表达式或简单的字符串解析
fields := make(map[string]bool)
// 简单的字段提取:查找标识符模式,支持点号分隔的嵌套字段
tokens := strings.FieldsFunc(expression, func(c rune) bool {
return !(c >= 'a' && c <= 'z') && !(c >= 'A' && c <= 'Z') && !(c >= '0' && c <= '9') && c != '_' && c != '.'
})
for _, token := range tokens {
if isValidFieldIdentifier(token) && !isNumber(token) && !isFunctionOrKeyword(token) {
fields[token] = true
}
}
result := make([]string, 0, len(fields))
for field := range fields {
result = append(result, field)
}
return result
}
// isValidFieldIdentifier 检查是否是有效的字段标识符(支持点号分隔的嵌套字段)
func isValidFieldIdentifier(s string) bool {
if len(s) == 0 {
return false
}
// 分割点号分隔的字段
parts := strings.Split(s, ".")
for _, part := range parts {
if !isIdentifier(part) {
return false
}
}
return true
}
// isFunctionOrKeyword 检查是否是函数名或关键字
func isFunctionOrKeyword(token string) bool {
// 检查是否是已知函数或关键字
keywords := []string{
"and", "or", "not", "true", "false", "nil", "null", "is",
"if", "else", "then", "in", "contains", "matches",
// CASE表达式关键字
"case", "when", "then", "else", "end",
}
for _, keyword := range keywords {
if strings.ToLower(token) == keyword {
return true
}
}
// 检查是否是注册的函数
bridge := functions.GetExprBridge()
_, exists, _ := bridge.ResolveFunction(token)
return exists
}
// collectFields 收集表达式中所有字段
func collectFields(node *ExprNode, fields map[string]bool) {
if node == nil {
return
}
if node.Type == TypeField {
fields[node.Value] = true
}
// 处理CASE表达式的字段收集
if node.Type == TypeCase {
// 收集CASE表达式本身的字段
if node.CaseExpr != nil {
collectFields(node.CaseExpr, fields)
}
// 收集所有WHEN子句中的字段
for _, whenClause := range node.WhenClauses {
collectFields(whenClause.Condition, fields)
collectFields(whenClause.Result, fields)
}
// 收集ELSE表达式中的字段
if node.ElseExpr != nil {
collectFields(node.ElseExpr, fields)
}
return
}
collectFields(node.Left, fields)
collectFields(node.Right, fields)
for _, arg := range node.Args {
collectFields(arg, fields)
}
}
// evaluateNode 计算节点的值
func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) {
if node == nil {
return 0, fmt.Errorf("null expression node")
}
switch node.Type {
case TypeNumber:
return strconv.ParseFloat(node.Value, 64)
case TypeString:
// 处理字符串类型,去掉引号并尝试转换为数字
// 如果无法转换返回错误因为这个函数返回float64
value := node.Value
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
value = value[1 : len(value)-1] // 去掉引号
}
// 尝试转换为数字
if f, err := strconv.ParseFloat(value, 64); err == nil {
return f, nil
}
// 对于字符串比较,我们需要返回一个哈希值或者错误
// 这里简化处理,将字符串转换为其长度(作为临时解决方案)
return float64(len(value)), nil
case TypeField:
// 处理反引号标识符,去除反引号
fieldName := node.Value
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号
}
// 支持嵌套字段访问
if fieldpath.IsNestedField(fieldName) {
if val, found := fieldpath.GetNestedField(data, fieldName); found {
// 尝试转换为float64
if floatVal, err := convertToFloat(val); err == nil {
return floatVal, nil
}
// 如果不能转换为数字,返回错误
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val)
}
} else {
// 原有的简单字段访问
if val, found := data[fieldName]; found {
// 尝试转换为float64
if floatVal, err := convertToFloat(val); err == nil {
return floatVal, nil
}
// 如果不能转换为数字,返回错误
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val)
}
}
return 0, fmt.Errorf("field '%s' not found", fieldName)
case TypeOperator:
// 计算左右子表达式的值
left, err := evaluateNode(node.Left, data)
if err != nil {
return 0, err
}
right, err := evaluateNode(node.Right, data)
if err != nil {
return 0, err
}
// 执行运算
switch node.Value {
case "+":
return left + right, nil
case "-":
return left - right, nil
case "*":
return left * right, nil
case "/":
if right == 0 {
return 0, fmt.Errorf("division by zero")
}
return left / right, nil
case "%":
if right == 0 {
return 0, fmt.Errorf("modulo by zero")
}
return math.Mod(left, right), nil
case "^":
return math.Pow(left, right), nil
default:
return 0, fmt.Errorf("unknown operator: %s", node.Value)
}
case TypeFunction:
// 首先检查是否是新的函数注册系统中的函数
fn, exists := functions.Get(node.Value)
if exists {
// 计算所有参数,但保持原始类型
args := make([]interface{}, len(node.Args))
for i, arg := range node.Args {
// 使用evaluateNodeValue获取原始类型的值
val, err := evaluateNodeValue(arg, data)
if err != nil {
return 0, err
}
args[i] = val
}
// 创建函数执行上下文
ctx := &functions.FunctionContext{
Data: data,
}
// 执行函数
result, err := fn.Execute(ctx, args)
if err != nil {
return 0, err
}
// 转换结果为 float64
switch r := result.(type) {
case float64:
return r, nil
case float32:
return float64(r), nil
case int:
return float64(r), nil
case int32:
return float64(r), nil
case int64:
return float64(r), nil
case string:
// 对于字符串结果,尝试转换为数字,如果失败则返回字符串长度
if f, err := strconv.ParseFloat(r, 64); err == nil {
return f, nil
}
return float64(len(r)), nil
case bool:
// 布尔值转换true=1, false=0
if r {
return 1.0, nil
}
return 0.0, nil
default:
return 0, fmt.Errorf("function %s returned unsupported type for numeric conversion: %T", node.Value, result)
}
}
// 回退到内置函数处理(保持向后兼容)
return evaluateBuiltinFunction(node, data)
case TypeCase:
// 处理CASE表达式
return evaluateCaseExpression(node, data)
}
return 0, fmt.Errorf("unknown node type: %s", node.Type)
}
// evaluateBuiltinFunction 处理内置函数(向后兼容)
func evaluateBuiltinFunction(node *ExprNode, data map[string]interface{}) (float64, error) {
switch strings.ToLower(node.Value) {
case "abs":
if len(node.Args) != 1 {
return 0, fmt.Errorf("abs function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Abs(arg), nil
case "sqrt":
if len(node.Args) != 1 {
return 0, fmt.Errorf("sqrt function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
if arg < 0 {
return 0, fmt.Errorf("sqrt of negative number")
}
return math.Sqrt(arg), nil
case "sin":
if len(node.Args) != 1 {
return 0, fmt.Errorf("sin function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Sin(arg), nil
case "cos":
if len(node.Args) != 1 {
return 0, fmt.Errorf("cos function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Cos(arg), nil
case "tan":
if len(node.Args) != 1 {
return 0, fmt.Errorf("tan function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Tan(arg), nil
case "floor":
if len(node.Args) != 1 {
return 0, fmt.Errorf("floor function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Floor(arg), nil
case "ceil":
if len(node.Args) != 1 {
return 0, fmt.Errorf("ceil function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Ceil(arg), nil
case "round":
if len(node.Args) != 1 {
return 0, fmt.Errorf("round function requires exactly 1 argument")
}
arg, err := evaluateNode(node.Args[0], data)
if err != nil {
return 0, err
}
return math.Round(arg), nil
default:
return 0, fmt.Errorf("unknown function: %s", node.Value)
}
}
// evaluateCaseExpression 计算CASE表达式
func evaluateCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) {
if node.Type != TypeCase {
return 0, fmt.Errorf("node is not a CASE expression")
}
// 处理简单CASE表达式 (CASE expr WHEN value1 THEN result1 ...)
if node.CaseExpr != nil {
// 计算CASE后面的表达式值
caseValue, err := evaluateNodeValue(node.CaseExpr, data)
if err != nil {
return 0, err
}
// 遍历WHEN子句查找匹配的值
for _, whenClause := range node.WhenClauses {
conditionValue, err := evaluateNodeValue(whenClause.Condition, data)
if err != nil {
return 0, err
}
// 比较值是否相等
isEqual, err := compareValues(caseValue, conditionValue, "==")
if err != nil {
return 0, err
}
if isEqual {
return evaluateNode(whenClause.Result, data)
}
}
} else {
// 处理搜索CASE表达式 (CASE WHEN condition1 THEN result1 ...)
for _, whenClause := range node.WhenClauses {
// 评估WHEN条件这里需要特殊处理布尔表达式
conditionResult, err := evaluateBooleanCondition(whenClause.Condition, data)
if err != nil {
return 0, err
}
// 如果条件为真,返回对应的结果
if conditionResult {
return evaluateNode(whenClause.Result, data)
}
}
}
// 如果没有匹配的WHEN子句执行ELSE子句
if node.ElseExpr != nil {
return evaluateNode(node.ElseExpr, data)
}
// 如果没有ELSE子句SQL标准是返回NULL这里返回0
return 0, nil
}
// evaluateBooleanCondition 计算布尔条件表达式
func evaluateBooleanCondition(node *ExprNode, data map[string]interface{}) (bool, error) {
if node == nil {
return false, fmt.Errorf("null condition expression")
}
// 处理逻辑运算符(实现短路求值)
if node.Type == TypeOperator && (node.Value == "AND" || node.Value == "OR") {
leftBool, err := evaluateBooleanCondition(node.Left, data)
if err != nil {
return false, err
}
// 短路求值对于AND如果左边为false立即返回false
if node.Value == "AND" && !leftBool {
return false, nil
}
// 短路求值对于OR如果左边为true立即返回true
if node.Value == "OR" && leftBool {
return true, nil
}
// 只有在需要时才评估右边的表达式
rightBool, err := evaluateBooleanCondition(node.Right, data)
if err != nil {
return false, err
}
switch node.Value {
case "AND":
return leftBool && rightBool, nil
case "OR":
return leftBool || rightBool, nil
}
}
// 处理IS NULL和IS NOT NULL特殊情况
if node.Type == TypeOperator && node.Value == "IS" {
return evaluateIsCondition(node, data)
}
// 处理比较运算符
if node.Type == TypeOperator {
leftValue, err := evaluateNodeValue(node.Left, data)
if err != nil {
return false, err
}
rightValue, err := evaluateNodeValue(node.Right, data)
if err != nil {
return false, err
}
return compareValues(leftValue, rightValue, node.Value)
}
// 对于其他表达式,计算其数值并转换为布尔值
result, err := evaluateNode(node, data)
if err != nil {
return false, err
}
// 非零值为真,零值为假
return result != 0, nil
}
// evaluateIsCondition 处理IS NULL和IS NOT NULL条件
func evaluateIsCondition(node *ExprNode, data map[string]interface{}) (bool, error) {
if node == nil || node.Left == nil || node.Right == nil {
return false, fmt.Errorf("invalid IS condition")
}
// 获取左侧值
leftValue, err := evaluateNodeValue(node.Left, data)
if err != nil {
// 如果字段不存在认为是null
leftValue = nil
}
// 检查右侧是否是NULL或NOT NULL
if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" {
// IS NULL
return leftValue == nil, nil
}
// 检查是否是IS NOT NULL
if node.Right.Type == TypeOperator && node.Right.Value == "NOT" &&
node.Right.Right != nil && node.Right.Right.Type == TypeField &&
strings.ToUpper(node.Right.Right.Value) == "NULL" {
// IS NOT NULL
return leftValue != nil, nil
}
// 其他IS比较如IS TRUE, IS FALSE等暂不支持
rightValue, err := evaluateNodeValue(node.Right, data)
if err != nil {
return false, err
}
return compareValues(leftValue, rightValue, "==")
}
// evaluateNodeValue 计算节点值返回interface{}以支持不同类型
func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
if node == nil {
return nil, fmt.Errorf("null expression node")
}
switch node.Type {
case TypeNumber:
return strconv.ParseFloat(node.Value, 64)
case TypeString:
// 去掉引号
value := node.Value
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
value = value[1 : len(value)-1]
}
return value, nil
case TypeField:
// 处理反引号标识符,去除反引号
fieldName := node.Value
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号
}
// 支持嵌套字段访问
if fieldpath.IsNestedField(fieldName) {
if val, found := fieldpath.GetNestedField(data, fieldName); found {
return val, nil
}
} else {
// 原有的简单字段访问
if val, found := data[fieldName]; found {
return val, nil
}
}
return nil, fmt.Errorf("field '%s' not found", fieldName)
default:
// 对于其他类型,回退到数值计算
return evaluateNode(node, data)
}
}
// compareValues 比较两个值
func compareValues(left, right interface{}, operator string) (bool, error) {
// 尝试字符串比较
leftStr, leftIsStr := left.(string)
rightStr, rightIsStr := right.(string)
if leftIsStr && rightIsStr {
switch operator {
case "==", "=":
return leftStr == rightStr, nil
case "!=", "<>":
return leftStr != rightStr, nil
case ">":
return leftStr > rightStr, nil
case ">=":
return leftStr >= rightStr, nil
case "<":
return leftStr < rightStr, nil
case "<=":
return leftStr <= rightStr, nil
case "LIKE":
return matchesLikePattern(leftStr, rightStr), nil
default:
return false, fmt.Errorf("unsupported string comparison operator: %s", operator)
}
}
// 转换为数值进行比较
leftNum, err1 := convertToFloat(left)
rightNum, err2 := convertToFloat(right)
if err1 != nil || err2 != nil {
return false, fmt.Errorf("cannot compare values: %v and %v", left, right)
}
switch operator {
case ">":
return leftNum > rightNum, nil
case ">=":
return leftNum >= rightNum, nil
case "<":
return leftNum < rightNum, nil
case "<=":
return leftNum <= rightNum, nil
case "==", "=":
return math.Abs(leftNum-rightNum) < 1e-9, nil
case "!=", "<>":
return math.Abs(leftNum-rightNum) >= 1e-9, nil
default:
return false, fmt.Errorf("unsupported comparison operator: %s", operator)
}
}
// matchesLikePattern 实现LIKE模式匹配
// 支持%匹配任意字符序列和_匹配单个字符
func matchesLikePattern(text, pattern string) bool {
return likeMatch(text, pattern, 0, 0)
}
// likeMatch 递归实现LIKE匹配算法
func likeMatch(text, pattern string, textIndex, patternIndex int) bool {
// 如果模式已经匹配完成
if patternIndex >= len(pattern) {
return textIndex >= len(text) // 文本也应该匹配完成
}
// 如果文本已经结束,但模式还有非%字符,则不匹配
if textIndex >= len(text) {
// 检查剩余的模式是否都是%
for i := patternIndex; i < len(pattern); i++ {
if pattern[i] != '%' {
return false
}
}
return true
}
switch pattern[patternIndex] {
case '%':
// %可以匹配0个或多个字符
// 尝试匹配0个字符跳过%
if likeMatch(text, pattern, textIndex, patternIndex+1) {
return true
}
// 尝试匹配1个或多个字符
for i := textIndex; i < len(text); i++ {
if likeMatch(text, pattern, i+1, patternIndex+1) {
return true
}
}
return false
case '_':
// _匹配任意单个字符
return likeMatch(text, pattern, textIndex+1, patternIndex+1)
default:
// 普通字符必须精确匹配
if text[textIndex] == pattern[patternIndex] {
return likeMatch(text, pattern, textIndex+1, patternIndex+1)
}
return false
}
}
// convertToFloat 将值转换为float64
func convertToFloat(val interface{}) (float64, error) {
switch v := val.(type) {
case float64:
return v, nil
case float32:
return float64(v), nil
case int:
return float64(v), nil
case int32:
return float64(v), nil
case int64:
return float64(v), nil
case string:
return strconv.ParseFloat(v, 64)
default:
return 0, fmt.Errorf("cannot convert %T to float64", val)
}
}
// tokenize 将表达式字符串转换为token列表
func tokenize(expr string) ([]string, error) {
expr = strings.TrimSpace(expr)
if expr == "" {
return nil, fmt.Errorf("empty expression")
}
tokens := make([]string, 0)
i := 0
for i < len(expr) {
ch := expr[i]
// 跳过空白字符
if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
i++
continue
}
// 处理数字
if isDigit(ch) || (ch == '.' && i+1 < len(expr) && isDigit(expr[i+1])) {
start := i
hasDot := ch == '.'
i++
for i < len(expr) && (isDigit(expr[i]) || (expr[i] == '.' && !hasDot)) {
if expr[i] == '.' {
hasDot = true
}
i++
}
tokens = append(tokens, expr[start:i])
continue
}
// 处理运算符和括号
if ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '%' || ch == '^' ||
ch == '(' || ch == ')' || ch == ',' {
// 特殊处理负号:如果是负号且前面是运算符、括号或开始位置,则可能是负数
if ch == '-' {
// 检查是否可能是负数的开始
canBeNegativeNumber := i == 0 || // 表达式开始
len(tokens) == 0 // tokens为空时也可能是负数开始
// 只有当tokens不为空时才检查前一个token
if len(tokens) > 0 {
prevToken := tokens[len(tokens)-1]
canBeNegativeNumber = canBeNegativeNumber ||
prevToken == "(" || // 左括号后
prevToken == "," || // 逗号后(函数参数)
isOperator(prevToken) || // 运算符后
isComparisonOperator(prevToken) || // 比较运算符后
strings.ToUpper(prevToken) == "THEN" || // THEN后
strings.ToUpper(prevToken) == "ELSE" || // ELSE后
strings.ToUpper(prevToken) == "WHEN" || // WHEN后
strings.ToUpper(prevToken) == "AND" || // AND后
strings.ToUpper(prevToken) == "OR" // OR后
}
if canBeNegativeNumber && i+1 < len(expr) && isDigit(expr[i+1]) {
// 这是一个负数,解析整个数字
start := i
i++ // 跳过负号
// 解析数字部分
for i < len(expr) && (isDigit(expr[i]) || expr[i] == '.') {
i++
}
tokens = append(tokens, expr[start:i])
continue
}
}
tokens = append(tokens, string(ch))
i++
continue
}
// 处理比较运算符
if ch == '>' || ch == '<' || ch == '=' || ch == '!' {
start := i
i++
// 处理双字符运算符
if i < len(expr) {
switch ch {
case '>':
if expr[i] == '=' {
i++
tokens = append(tokens, ">=")
continue
}
case '<':
if expr[i] == '=' {
i++
tokens = append(tokens, "<=")
continue
} else if expr[i] == '>' {
i++
tokens = append(tokens, "<>")
continue
}
case '=':
if expr[i] == '=' {
i++
tokens = append(tokens, "==")
continue
}
case '!':
if expr[i] == '=' {
i++
tokens = append(tokens, "!=")
continue
}
}
}
// 单字符运算符
tokens = append(tokens, expr[start:i])
continue
}
// 处理字符串字面量(单引号和双引号)
if ch == '\'' || ch == '"' {
quote := ch
start := i
i++ // 跳过开始引号
// 寻找结束引号
for i < len(expr) && expr[i] != quote {
if expr[i] == '\\' && i+1 < len(expr) {
i += 2 // 跳过转义字符
} else {
i++
}
}
if i >= len(expr) {
return nil, fmt.Errorf("unterminated string literal starting at position %d", start)
}
i++ // 跳过结束引号
tokens = append(tokens, expr[start:i])
continue
}
// 处理反引号标识符
if ch == '`' {
start := i
i++ // 跳过开始反引号
// 寻找结束反引号
for i < len(expr) && expr[i] != '`' {
i++
}
if i >= len(expr) {
return nil, fmt.Errorf("unterminated quoted identifier starting at position %d", start)
}
i++ // 跳过结束反引号
tokens = append(tokens, expr[start:i])
continue
}
// 处理标识符(字段名或函数名)
if isLetter(ch) {
start := i
i++
for i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_') {
i++
}
tokens = append(tokens, expr[start:i])
continue
}
// 未知字符
return nil, fmt.Errorf("unexpected character: %c at position %d", ch, i)
}
return tokens, nil
}
// parseExpression 解析表达式
func parseExpression(tokens []string) (*ExprNode, error) {
if len(tokens) == 0 {
return nil, fmt.Errorf("empty token list")
}
// 使用Shunting-yard算法处理运算符优先级
output := make([]*ExprNode, 0)
operators := make([]string, 0)
i := 0
for i < len(tokens) {
token := tokens[i]
// 处理数字
if isNumber(token) {
output = append(output, &ExprNode{
Type: TypeNumber,
Value: token,
})
i++
continue
}
// 处理字符串字面量
if isStringLiteral(token) {
output = append(output, &ExprNode{
Type: TypeString,
Value: token,
})
i++
continue
}
// 处理字段名或函数调用
if isIdentifier(token) {
// 检查是否是逻辑运算符关键字
upperToken := strings.ToUpper(token)
if upperToken == "AND" || upperToken == "OR" || upperToken == "NOT" || upperToken == "LIKE" {
// 处理逻辑运算符
for len(operators) > 0 && operators[len(operators)-1] != "(" &&
operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence[upperToken] {
op := operators[len(operators)-1]
operators = operators[:len(operators)-1]
if len(output) < 2 {
return nil, fmt.Errorf("not enough operands for operator: %s", op)
}
right := output[len(output)-1]
left := output[len(output)-2]
output = output[:len(output)-2]
output = append(output, &ExprNode{
Type: TypeOperator,
Value: op,
Left: left,
Right: right,
})
}
operators = append(operators, upperToken)
i++
continue
}
// 特殊处理IS运算符需要检查后续的NOT NULL组合
if upperToken == "IS" {
// 处理待处理的运算符
for len(operators) > 0 && operators[len(operators)-1] != "(" &&
operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence["IS"] {
op := operators[len(operators)-1]
operators = operators[:len(operators)-1]
if len(output) < 2 {
return nil, fmt.Errorf("not enough operands for operator: %s", op)
}
right := output[len(output)-1]
left := output[len(output)-2]
output = output[:len(output)-2]
output = append(output, &ExprNode{
Type: TypeOperator,
Value: op,
Left: left,
Right: right,
})
}
// 检查是否是IS NOT NULL模式
if i+2 < len(tokens) &&
strings.ToUpper(tokens[i+1]) == "NOT" &&
strings.ToUpper(tokens[i+2]) == "NULL" {
// 这是IS NOT NULL创建特殊的右侧节点结构
notNullNode := &ExprNode{
Type: TypeOperator,
Value: "NOT",
Right: &ExprNode{
Type: TypeField,
Value: "NULL",
},
}
operators = append(operators, "IS")
output = append(output, notNullNode)
i += 3 // 跳过IS NOT NULL三个token
continue
} else if i+1 < len(tokens) && strings.ToUpper(tokens[i+1]) == "NULL" {
// 这是IS NULL创建NULL节点
nullNode := &ExprNode{
Type: TypeField,
Value: "NULL",
}
operators = append(operators, "IS")
output = append(output, nullNode)
i += 2 // 跳过IS NULL两个token
continue
} else {
// 普通的IS运算符
operators = append(operators, "IS")
i++
continue
}
}
// 检查是否是CASE表达式
if strings.ToUpper(token) == "CASE" {
caseNode, newIndex, err := parseCaseExpression(tokens, i)
if err != nil {
return nil, err
}
output = append(output, caseNode)
i = newIndex
continue
}
// 检查下一个token是否是左括号如果是则为函数调用
if i+1 < len(tokens) && tokens[i+1] == "(" {
funcName := token
i += 2 // 跳过函数名和左括号
// 解析函数参数
args, newIndex, err := parseFunctionArgs(tokens, i)
if err != nil {
return nil, err
}
output = append(output, &ExprNode{
Type: TypeFunction,
Value: funcName,
Args: args,
})
i = newIndex
continue
}
// 普通字段
output = append(output, &ExprNode{
Type: TypeField,
Value: token,
})
i++
continue
}
// 处理左括号
if token == "(" {
operators = append(operators, token)
i++
continue
}
// 处理右括号
if token == ")" {
for len(operators) > 0 && operators[len(operators)-1] != "(" {
op := operators[len(operators)-1]
operators = operators[:len(operators)-1]
if len(output) < 2 {
return nil, fmt.Errorf("not enough operands for operator: %s", op)
}
right := output[len(output)-1]
left := output[len(output)-2]
output = output[:len(output)-2]
output = append(output, &ExprNode{
Type: TypeOperator,
Value: op,
Left: left,
Right: right,
})
}
if len(operators) == 0 || operators[len(operators)-1] != "(" {
return nil, fmt.Errorf("mismatched parentheses")
}
operators = operators[:len(operators)-1] // 弹出左括号
i++
continue
}
// 处理运算符
if isOperator(token) {
for len(operators) > 0 && operators[len(operators)-1] != "(" &&
operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence[token] {
op := operators[len(operators)-1]
operators = operators[:len(operators)-1]
if len(output) < 2 {
return nil, fmt.Errorf("not enough operands for operator: %s", op)
}
right := output[len(output)-1]
left := output[len(output)-2]
output = output[:len(output)-2]
output = append(output, &ExprNode{
Type: TypeOperator,
Value: op,
Left: left,
Right: right,
})
}
operators = append(operators, token)
i++
continue
}
// 处理逗号(在函数参数列表中处理)
if token == "," {
i++
continue
}
return nil, fmt.Errorf("unexpected token: %s", token)
}
// 处理剩余的运算符
for len(operators) > 0 {
op := operators[len(operators)-1]
operators = operators[:len(operators)-1]
if op == "(" {
return nil, fmt.Errorf("mismatched parentheses")
}
if len(output) < 2 {
return nil, fmt.Errorf("not enough operands for operator: %s", op)
}
right := output[len(output)-1]
left := output[len(output)-2]
output = output[:len(output)-2]
output = append(output, &ExprNode{
Type: TypeOperator,
Value: op,
Left: left,
Right: right,
})
}
if len(output) != 1 {
return nil, fmt.Errorf("invalid expression")
}
return output[0], nil
}
// parseFunctionArgs 解析函数参数
func parseFunctionArgs(tokens []string, startIndex int) ([]*ExprNode, int, error) {
args := make([]*ExprNode, 0)
i := startIndex
// 处理空参数列表
if i < len(tokens) && tokens[i] == ")" {
return args, i + 1, nil
}
for i < len(tokens) {
// 解析参数表达式
argTokens := make([]string, 0)
parenthesesCount := 0
for i < len(tokens) {
token := tokens[i]
if token == "(" {
parenthesesCount++
} else if token == ")" {
parenthesesCount--
if parenthesesCount < 0 {
break
}
} else if token == "," && parenthesesCount == 0 {
break
}
argTokens = append(argTokens, token)
i++
}
if len(argTokens) > 0 {
arg, err := parseExpression(argTokens)
if err != nil {
return nil, 0, err
}
args = append(args, arg)
}
if i >= len(tokens) {
return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
}
if tokens[i] == ")" {
return args, i + 1, nil
}
if tokens[i] == "," {
i++
continue
}
return nil, 0, fmt.Errorf("unexpected token in function arguments: %s", tokens[i])
}
return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
}
// parseCaseExpression 解析CASE表达式
func parseCaseExpression(tokens []string, startIndex int) (*ExprNode, int, error) {
if startIndex >= len(tokens) || strings.ToUpper(tokens[startIndex]) != "CASE" {
return nil, startIndex, fmt.Errorf("expected CASE keyword")
}
caseNode := &ExprNode{
Type: TypeCase,
WhenClauses: make([]WhenClause, 0),
}
i := startIndex + 1 // 跳过CASE关键字
// 检查是否是简单CASE表达式CASE expr WHEN value1 THEN result1 ...
// 或搜索CASE表达式CASE WHEN condition1 THEN result1 ...
if i < len(tokens) && strings.ToUpper(tokens[i]) != "WHEN" {
// 这是简单CASE表达式需要解析CASE后面的表达式
caseExprTokens := make([]string, 0)
// 收集CASE表达式直到遇到WHEN
for i < len(tokens) && strings.ToUpper(tokens[i]) != "WHEN" {
caseExprTokens = append(caseExprTokens, tokens[i])
i++
}
if len(caseExprTokens) == 0 {
return nil, i, fmt.Errorf("expected expression after CASE")
}
// 对于简单的情况直接处理单个token
if len(caseExprTokens) == 1 {
token := caseExprTokens[0]
if isNumber(token) {
caseNode.CaseExpr = &ExprNode{Type: TypeNumber, Value: token}
} else if isStringLiteral(token) {
caseNode.CaseExpr = &ExprNode{Type: TypeString, Value: token}
} else if isIdentifier(token) {
caseNode.CaseExpr = &ExprNode{Type: TypeField, Value: token}
} else {
return nil, i, fmt.Errorf("invalid CASE expression token: %s", token)
}
} else {
// 对于复杂表达式调用parseExpression
caseExpr, err := parseExpression(caseExprTokens)
if err != nil {
return nil, i, fmt.Errorf("failed to parse CASE expression: %w", err)
}
caseNode.CaseExpr = caseExpr
}
}
// 解析WHEN子句
for i < len(tokens) && strings.ToUpper(tokens[i]) == "WHEN" {
i++ // 跳过WHEN关键字
// 收集WHEN条件直到遇到THEN
conditionTokens := make([]string, 0)
for i < len(tokens) && strings.ToUpper(tokens[i]) != "THEN" {
conditionTokens = append(conditionTokens, tokens[i])
i++
}
if len(conditionTokens) == 0 {
return nil, i, fmt.Errorf("expected condition after WHEN")
}
if i >= len(tokens) || strings.ToUpper(tokens[i]) != "THEN" {
return nil, i, fmt.Errorf("expected THEN after WHEN condition")
}
i++ // 跳过THEN关键字
// 收集THEN结果直到遇到WHEN、ELSE或END
resultTokens := make([]string, 0)
for i < len(tokens) {
upper := strings.ToUpper(tokens[i])
if upper == "WHEN" || upper == "ELSE" || upper == "END" {
break
}
resultTokens = append(resultTokens, tokens[i])
i++
}
if len(resultTokens) == 0 {
return nil, i, fmt.Errorf("expected result after THEN")
}
// 解析条件和结果表达式
conditionExpr, err := parseExpression(conditionTokens)
if err != nil {
return nil, i, fmt.Errorf("failed to parse WHEN condition: %w", err)
}
resultExpr, err := parseExpression(resultTokens)
if err != nil {
return nil, i, fmt.Errorf("failed to parse THEN result: %w", err)
}
// 添加WHEN子句
caseNode.WhenClauses = append(caseNode.WhenClauses, WhenClause{
Condition: conditionExpr,
Result: resultExpr,
})
}
// 检查是否有ELSE子句
if i < len(tokens) && strings.ToUpper(tokens[i]) == "ELSE" {
i++ // 跳过ELSE关键字
// 收集ELSE结果直到遇到END
elseTokens := make([]string, 0)
for i < len(tokens) && strings.ToUpper(tokens[i]) != "END" {
elseTokens = append(elseTokens, tokens[i])
i++
}
if len(elseTokens) == 0 {
return nil, i, fmt.Errorf("expected result after ELSE")
}
// 解析ELSE表达式
elseExpr, err := parseExpression(elseTokens)
if err != nil {
return nil, i, fmt.Errorf("failed to parse ELSE result: %w", err)
}
caseNode.ElseExpr = elseExpr
}
// 检查END关键字
if i >= len(tokens) || strings.ToUpper(tokens[i]) != "END" {
return nil, i, fmt.Errorf("expected END to close CASE expression")
}
i++ // 跳过END关键字
// 验证至少有一个WHEN子句
if len(caseNode.WhenClauses) == 0 {
return nil, i, fmt.Errorf("CASE expression must have at least one WHEN clause")
}
return caseNode, i, nil
}
// 辅助函数
func isDigit(ch byte) bool {
return ch >= '0' && ch <= '9'
}
func isLetter(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')
}
func isNumber(s string) bool {
_, err := strconv.ParseFloat(s, 64)
return err == nil
}
func isIdentifier(s string) bool {
if len(s) == 0 {
return false
}
if !isLetter(s[0]) && s[0] != '_' {
return false
}
for i := 1; i < len(s); i++ {
if !isLetter(s[i]) && !isDigit(s[i]) && s[i] != '_' {
return false
}
}
return true
}
func isOperator(s string) bool {
switch s {
case "+", "-", "*", "/", "%", "^":
return true
case ">", "<", ">=", "<=", "==", "=", "!=", "<>":
return true
case "AND", "OR", "NOT":
return true
case "LIKE", "IS":
return true
default:
return false
}
}
// isComparisonOperator 检查是否是比较运算符
func isComparisonOperator(s string) bool {
switch s {
case ">", "<", ">=", "<=", "==", "=", "!=", "<>":
return true
default:
return false
}
}
func isStringLiteral(expr string) bool {
return len(expr) > 1 && (expr[0] == '\'' || expr[0] == '"') && expr[len(expr)-1] == expr[0]
}
// evaluateNodeWithNull 计算节点值支持NULL值返回
// 返回 (result, isNull, error)
func evaluateNodeWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
if node == nil {
return 0, true, nil // NULL
}
switch node.Type {
case TypeNumber:
val, err := strconv.ParseFloat(node.Value, 64)
return val, false, err
case TypeString:
// 字符串长度作为数值特殊处理NULL字符串
value := node.Value
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
value = value[1 : len(value)-1]
}
// 检查是否是NULL字符串
if strings.ToUpper(value) == "NULL" {
return 0, true, nil
}
return float64(len(value)), false, nil
case TypeField:
// 支持嵌套字段访问
var fieldVal interface{}
var found bool
if fieldpath.IsNestedField(node.Value) {
fieldVal, found = fieldpath.GetNestedField(data, node.Value)
} else {
fieldVal, found = data[node.Value]
}
if !found || fieldVal == nil {
return 0, true, nil // NULL
}
// 尝试转换为数值
if val, err := convertToFloat(fieldVal); err == nil {
return val, false, nil
}
return 0, true, fmt.Errorf("cannot convert field '%s' to number", node.Value)
case TypeOperator:
return evaluateOperatorWithNull(node, data)
case TypeFunction:
// 函数调用保持原有逻辑但处理NULL结果
result, err := evaluateBuiltinFunction(node, data)
return result, false, err
case TypeCase:
return evaluateCaseExpressionWithNull(node, data)
default:
return 0, true, fmt.Errorf("unsupported node type: %s", node.Type)
}
}
// evaluateOperatorWithNull 计算运算符表达式支持NULL值
func evaluateOperatorWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
leftVal, leftNull, err := evaluateNodeWithNull(node.Left, data)
if err != nil {
return 0, false, err
}
rightVal, rightNull, err := evaluateNodeWithNull(node.Right, data)
if err != nil {
return 0, false, err
}
// 算术运算如果任一操作数为NULL结果为NULL
if leftNull || rightNull {
switch node.Value {
case "+", "-", "*", "/", "%", "^":
return 0, true, nil
}
}
// 比较运算NULL值的比较有特殊规则
switch node.Value {
case "==", "=":
if leftNull && rightNull {
return 1, false, nil // NULL = NULL 为 true
}
if leftNull || rightNull {
return 0, false, nil // NULL = value 为 false
}
if leftVal == rightVal {
return 1, false, nil
}
return 0, false, nil
case "!=", "<>":
if leftNull && rightNull {
return 0, false, nil // NULL != NULL 为 false
}
if leftNull || rightNull {
return 0, false, nil // NULL != value 为 false
}
if leftVal != rightVal {
return 1, false, nil
}
return 0, false, nil
case ">", "<", ">=", "<=":
if leftNull || rightNull {
return 0, false, nil // NULL与任何值的比较都为false
}
}
// 对于非NULL值执行正常的算术和比较运算
switch node.Value {
case "+":
return leftVal + rightVal, false, nil
case "-":
return leftVal - rightVal, false, nil
case "*":
return leftVal * rightVal, false, nil
case "/":
if rightVal == 0 {
return 0, true, nil // 除零返回NULL
}
return leftVal / rightVal, false, nil
case "%":
if rightVal == 0 {
return 0, true, nil
}
return math.Mod(leftVal, rightVal), false, nil
case "^":
return math.Pow(leftVal, rightVal), false, nil
case ">":
if leftVal > rightVal {
return 1, false, nil
}
return 0, false, nil
case "<":
if leftVal < rightVal {
return 1, false, nil
}
return 0, false, nil
case ">=":
if leftVal >= rightVal {
return 1, false, nil
}
return 0, false, nil
case "<=":
if leftVal <= rightVal {
return 1, false, nil
}
return 0, false, nil
default:
return 0, false, fmt.Errorf("unsupported operator: %s", node.Value)
}
}
// evaluateCaseExpressionWithNull 计算CASE表达式支持NULL值
func evaluateCaseExpressionWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
if node.Type != TypeCase {
return 0, false, fmt.Errorf("node is not a CASE expression")
}
// 处理简单CASE表达式 (CASE expr WHEN value1 THEN result1 ...)
if node.CaseExpr != nil {
// 计算CASE后面的表达式值
caseValue, caseNull, err := evaluateNodeValueWithNull(node.CaseExpr, data)
if err != nil {
return 0, false, err
}
// 遍历WHEN子句查找匹配的值
for _, whenClause := range node.WhenClauses {
conditionValue, condNull, err := evaluateNodeValueWithNull(whenClause.Condition, data)
if err != nil {
return 0, false, err
}
// 比较值是否相等考虑NULL值
var isEqual bool
if caseNull && condNull {
isEqual = true // NULL = NULL
} else if caseNull || condNull {
isEqual = false // NULL != value
} else {
isEqual, err = compareValuesForEquality(caseValue, conditionValue)
if err != nil {
return 0, false, err
}
}
if isEqual {
return evaluateNodeWithNull(whenClause.Result, data)
}
}
} else {
// 处理搜索CASE表达式 (CASE WHEN condition1 THEN result1 ...)
for _, whenClause := range node.WhenClauses {
// 评估WHEN条件
conditionResult, err := evaluateBooleanConditionWithNull(whenClause.Condition, data)
if err != nil {
return 0, false, err
}
// 如果条件为真,返回对应的结果
if conditionResult {
return evaluateNodeWithNull(whenClause.Result, data)
}
}
}
// 如果没有匹配的WHEN子句执行ELSE子句
if node.ElseExpr != nil {
return evaluateNodeWithNull(node.ElseExpr, data)
}
// 如果没有ELSE子句SQL标准是返回NULL
return 0, true, nil
}
// evaluateNodeValueWithNull 计算节点值返回interface{}以支持不同类型包含NULL检查
func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) {
if node == nil {
return nil, true, nil
}
switch node.Type {
case TypeNumber:
val, err := strconv.ParseFloat(node.Value, 64)
return val, false, err
case TypeString:
// 去掉引号
value := node.Value
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
value = value[1 : len(value)-1]
}
// 检查是否是NULL字符串
if strings.ToUpper(value) == "NULL" {
return nil, true, nil
}
return value, false, nil
case TypeField:
// 处理反引号标识符,去除反引号
fieldName := node.Value
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号
}
// 支持嵌套字段访问
if fieldpath.IsNestedField(fieldName) {
if val, found := fieldpath.GetNestedField(data, fieldName); found {
return val, val == nil, nil
}
} else {
// 原有的简单字段访问
if val, found := data[fieldName]; found {
return val, val == nil, nil
}
}
return nil, true, nil // 字段不存在视为NULL
default:
// 对于其他类型,回退到数值计算
result, isNull, err := evaluateNodeWithNull(node, data)
return result, isNull, err
}
}
// evaluateBooleanConditionWithNull 计算布尔条件表达式支持NULL值
func evaluateBooleanConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) {
if node == nil {
return false, fmt.Errorf("null condition expression")
}
// 处理逻辑运算符(实现短路求值)
if node.Type == TypeOperator && (node.Value == "AND" || node.Value == "OR") {
leftBool, err := evaluateBooleanConditionWithNull(node.Left, data)
if err != nil {
return false, err
}
// 短路求值对于AND如果左边为false立即返回false
if node.Value == "AND" && !leftBool {
return false, nil
}
// 短路求值对于OR如果左边为true立即返回true
if node.Value == "OR" && leftBool {
return true, nil
}
// 只有在需要时才评估右边的表达式
rightBool, err := evaluateBooleanConditionWithNull(node.Right, data)
if err != nil {
return false, err
}
switch node.Value {
case "AND":
return leftBool && rightBool, nil
case "OR":
return leftBool || rightBool, nil
}
}
// 处理IS NULL和IS NOT NULL特殊情况
if node.Type == TypeOperator && node.Value == "IS" {
return evaluateIsConditionWithNull(node, data)
}
// 处理比较运算符
if node.Type == TypeOperator {
leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data)
if err != nil {
return false, err
}
rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data)
if err != nil {
return false, err
}
return compareValuesWithNull(leftValue, leftNull, rightValue, rightNull, node.Value)
}
// 对于其他表达式,计算其数值并转换为布尔值
result, isNull, err := evaluateNodeWithNull(node, data)
if err != nil {
return false, err
}
// NULL值在布尔上下文中为false非零值为真零值为假
return !isNull && result != 0, nil
}
// evaluateIsConditionWithNull 处理IS NULL和IS NOT NULL条件支持NULL值
func evaluateIsConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) {
if node == nil || node.Left == nil || node.Right == nil {
return false, fmt.Errorf("invalid IS condition")
}
// 获取左侧值
leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data)
if err != nil {
// 如果字段不存在认为是null
leftValue = nil
leftNull = true
}
// 检查右侧是否是NULL或NOT NULL
if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" {
// IS NULL
return leftNull || leftValue == nil, nil
}
// 检查是否是IS NOT NULL
if node.Right.Type == TypeOperator && node.Right.Value == "NOT" &&
node.Right.Right != nil && node.Right.Right.Type == TypeField &&
strings.ToUpper(node.Right.Right.Value) == "NULL" {
// IS NOT NULL
return !leftNull && leftValue != nil, nil
}
// 其他IS比较
rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data)
if err != nil {
return false, err
}
return compareValuesWithNullForEquality(leftValue, leftNull, rightValue, rightNull)
}
// compareValuesForEquality 比较两个值是否相等
func compareValuesForEquality(left, right interface{}) (bool, error) {
// 尝试字符串比较
leftStr, leftIsStr := left.(string)
rightStr, rightIsStr := right.(string)
if leftIsStr && rightIsStr {
return leftStr == rightStr, nil
}
// 尝试数值比较
leftFloat, leftErr := convertToFloat(left)
rightFloat, rightErr := convertToFloat(right)
if leftErr == nil && rightErr == nil {
return leftFloat == rightFloat, nil
}
// 如果都不能转换,直接比较
return left == right, nil
}
// compareValuesWithNull 比较两个值支持NULL
func compareValuesWithNull(left interface{}, leftNull bool, right interface{}, rightNull bool, operator string) (bool, error) {
// NULL值的比较有特殊规则
switch operator {
case "==", "=":
if leftNull && rightNull {
return true, nil // NULL = NULL 为 true
}
if leftNull || rightNull {
return false, nil // NULL = value 为 false
}
case "!=", "<>":
if leftNull && rightNull {
return false, nil // NULL != NULL 为 false
}
if leftNull || rightNull {
return false, nil // NULL != value 为 false
}
case ">", "<", ">=", "<=":
if leftNull || rightNull {
return false, nil // NULL与任何值的比较都为false
}
}
// 对于非NULL值执行正确的比较逻辑
switch operator {
case "==", "=":
return compareValuesForEquality(left, right)
case "!=", "<>":
equal, err := compareValuesForEquality(left, right)
return !equal, err
case ">", "<", ">=", "<=":
// 进行数值比较
leftFloat, leftErr := convertToFloat(left)
rightFloat, rightErr := convertToFloat(right)
if leftErr != nil || rightErr != nil {
// 如果不能转换为数值,尝试字符串比较
leftStr := fmt.Sprintf("%v", left)
rightStr := fmt.Sprintf("%v", right)
switch operator {
case ">":
return leftStr > rightStr, nil
case "<":
return leftStr < rightStr, nil
case ">=":
return leftStr >= rightStr, nil
case "<=":
return leftStr <= rightStr, nil
}
}
// 数值比较
switch operator {
case ">":
return leftFloat > rightFloat, nil
case "<":
return leftFloat < rightFloat, nil
case ">=":
return leftFloat >= rightFloat, nil
case "<=":
return leftFloat <= rightFloat, nil
}
}
return false, fmt.Errorf("unsupported operator: %s", operator)
}
// compareValuesWithNullForEquality 比较两个值是否相等支持NULL
func compareValuesWithNullForEquality(left interface{}, leftNull bool, right interface{}, rightNull bool) (bool, error) {
if leftNull && rightNull {
return true, nil // NULL = NULL 为 true
}
if leftNull || rightNull {
return false, nil // NULL = value 为 false
}
return compareValuesForEquality(left, right)
}
// EvaluateWithNull 提供公开接口,用于聚合函数调用
func (e *Expression) EvaluateWithNull(data map[string]interface{}) (float64, bool, error) {
if e.useExprLang {
// expr-lang不支持NULL回退到原有逻辑
result, err := e.evaluateWithExprLang(data)
return result, false, err
}
return evaluateNodeWithNull(e.Root, data)
}