// Package expr implements a small expression engine: a tokenizer, a
// shunting-yard parser producing an ExprNode tree, and an evaluator that
// yields float64 results. Expressions the built-in parser cannot handle are
// delegated to the expr-lang bridge exposed by the functions package.
package expr

import (
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"

	"github.com/rulego/streamsql/functions"
)

// Expression node types.
const (
	TypeNumber      = "number"      // numeric constant
	TypeField       = "field"       // field reference
	TypeOperator    = "operator"    // binary operator
	TypeFunction    = "function"    // function call
	TypeParenthesis = "parenthesis" // parenthesis
	TypeCase        = "case"        // CASE expression
	TypeString      = "string"      // string constant
)

// operatorPrecedence maps each binary operator to its binding strength; a
// larger value binds tighter. Operators missing from the map (e.g. "NOT")
// have precedence 0. All operators are parsed as left-associative.
var operatorPrecedence = map[string]int{
	"OR":  1,
	"AND": 2,
	"==":  3, "=": 3, "!=": 3, "<>": 3,
	">": 4, "<": 4, ">=": 4, "<=": 4,
	"+": 5, "-": 5,
	"*": 6, "/": 6, "%": 6,
	"^": 7, // power
}

// builtinUnaryFuncs holds the legacy single-argument math functions that are
// used when a function is not found in the functions registry.
var builtinUnaryFuncs = map[string]func(float64) float64{
	"abs":   math.Abs,
	"sqrt":  math.Sqrt,
	"sin":   math.Sin,
	"cos":   math.Cos,
	"tan":   math.Tan,
	"floor": math.Floor,
	"ceil":  math.Ceil,
	"round": math.Round,
}

// WhenClause is a single WHEN ... THEN ... branch of a CASE expression.
type WhenClause struct {
	Condition *ExprNode // WHEN condition (or comparison value for a simple CASE)
	Result    *ExprNode // THEN result
}

// ExprNode is a node of the parsed expression tree.
type ExprNode struct {
	Type  string
	Value string
	Left  *ExprNode
	Right *ExprNode
	Args  []*ExprNode // arguments of a function call

	// Fields used only by CASE expressions.
	CaseExpr    *ExprNode    // expression following CASE (simple CASE form)
	WhenClauses []WhenClause // WHEN clauses in source order
	ElseExpr    *ExprNode    // ELSE expression, nil when absent
}

// Expression is an evaluable expression.
type Expression struct {
	Root               *ExprNode
	useExprLang        bool   // evaluate through expr-lang instead of Root
	exprLangExpression string // original text handed to expr-lang
}

// NewExpression parses exprStr with the built-in tokenizer and parser. When
// either step fails, the returned Expression is flagged to be evaluated by
// expr-lang instead, so the returned error is currently always nil.
func NewExpression(exprStr string) (*Expression, error) {
	tokens, err := tokenize(exprStr)
	if err == nil {
		var root *ExprNode
		if root, err = parseExpression(tokens); err == nil {
			return &Expression{Root: root}, nil
		}
	}

	// The built-in parser could not handle the input: fall back to expr-lang.
	return &Expression{
		useExprLang:        true,
		exprLangExpression: exprStr,
	}, nil
}

// Evaluate computes the value of the expression against data.
func (e *Expression) Evaluate(data map[string]interface{}) (float64, error) {
	if e.useExprLang {
		return e.evaluateWithExprLang(data)
	}
	return evaluateNode(e.Root, data)
}

// evaluateWithExprLang evaluates the expression through the expr-lang bridge
// and converts the result to float64. Boolean results map to 1 (true) and
// 0 (false); numeric strings are parsed.
func (e *Expression) evaluateWithExprLang(data map[string]interface{}) (float64, error) {
	bridge := functions.GetExprBridge()
	result, err := bridge.EvaluateExpression(e.exprLangExpression, data)
	if err != nil {
		return 0, err
	}

	switch v := result.(type) {
	case bool:
		if v {
			return 1, nil
		}
		return 0, nil
	case string:
		if f, err := strconv.ParseFloat(v, 64); err == nil {
			return f, nil
		}
		return 0, fmt.Errorf("cannot convert string result '%s' to float64", v)
	default:
		if f, ok := numericToFloat(result); ok {
			return f, nil
		}
		return 0, fmt.Errorf("expression result type %T is not convertible to float64", result)
	}
}

// GetFields returns the names of all fields referenced by the expression,
// without duplicates and in ascending order.
func (e *Expression) GetFields() []string {
	if e.useExprLang {
		// Simplified extraction; a real implementation would walk the
		// expr-lang AST.
		return extractFieldsFromExprLang(e.exprLangExpression)
	}

	fields := make(map[string]bool)
	collectFields(e.Root, fields)
	return sortedFieldNames(fields)
}

// sortedFieldNames returns the keys of fields in ascending order so callers
// get a deterministic result regardless of map iteration order.
func sortedFieldNames(fields map[string]bool) []string {
	names := make([]string, 0, len(fields))
	for name := range fields {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// extractFieldsFromExprLang extracts field references from an expr-lang
// expression with a simple lexical scan: every identifier-like word that is
// neither a keyword nor a known function counts as a field. Text inside
// single- or double-quoted string literals is skipped so literal contents
// are not reported as fields.
func extractFieldsFromExprLang(expression string) []string {
	fields := make(map[string]bool)
	isWordByte := func(c byte) bool {
		return isLetter(c) || isDigit(c) || c == '_'
	}

	for i := 0; i < len(expression); {
		c := expression[i]
		switch {
		case c == '\'' || c == '"':
			// Skip the whole literal, honouring backslash escapes.
			i++
			for i < len(expression) && expression[i] != c {
				if expression[i] == '\\' {
					i++
				}
				i++
			}
			i++ // closing quote (or past the end when unterminated)
		case isWordByte(c):
			start := i
			for i < len(expression) && isWordByte(expression[i]) {
				i++
			}
			// isIdentifier also rejects words starting with a digit,
			// i.e. numeric literals.
			word := expression[start:i]
			if isIdentifier(word) && !isFunctionOrKeyword(word) {
				fields[word] = true
			}
		default:
			i++
		}
	}

	return sortedFieldNames(fields)
}

// isFunctionOrKeyword reports whether token is a reserved word
// (case-insensitive) or a function known to the expr bridge.
func isFunctionOrKeyword(token string) bool {
	switch strings.ToLower(token) {
	case "and", "or", "not", "true", "false", "nil", "null",
		"if", "else", "then", "in", "contains", "matches",
		// CASE expression keywords
		"case", "when", "end":
		return true
	}

	// Only existence matters here; the remaining results of ResolveFunction
	// are deliberately ignored.
	bridge := functions.GetExprBridge()
	_, exists, _ := bridge.ResolveFunction(token)
	return exists
}

// collectFields records in fields the name of every field referenced in the
// subtree rooted at node.
func collectFields(node *ExprNode, fields map[string]bool) {
	if node == nil {
		return
	}

	if node.Type == TypeField {
		fields[node.Value] = true
	}

	// CASE nodes keep their children in dedicated members.
	if node.Type == TypeCase {
		collectFields(node.CaseExpr, fields)
		for _, clause := range node.WhenClauses {
			collectFields(clause.Condition, fields)
			collectFields(clause.Result, fields)
		}
		collectFields(node.ElseExpr, fields)
		return
	}

	collectFields(node.Left, fields)
	collectFields(node.Right, fields)
	for _, arg := range node.Args {
		collectFields(arg, fields)
	}
}

// evaluateNode computes the numeric value of node against data.
//
// Comparison operators and AND/OR evaluate to 1 (true) or 0 (false).
// String constants are parsed as numbers when possible; otherwise their
// length is returned as a stop-gap, because this function can only return
// float64.
func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) {
	if node == nil {
		return 0, fmt.Errorf("null expression node")
	}

	switch node.Type {
	case TypeNumber:
		return strconv.ParseFloat(node.Value, 64)

	case TypeString:
		value := node.Value
		if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
			value = value[1 : len(value)-1] // strip the quotes
		}
		if f, err := strconv.ParseFloat(value, 64); err == nil {
			return f, nil
		}
		// Temporary simplification: a non-numeric string evaluates to its
		// length.
		return float64(len(value)), nil

	case TypeField:
		val, ok := data[node.Value]
		if !ok {
			return 0, fmt.Errorf("field %s not found in data", node.Value)
		}
		if f, ok := numericToFloat(val); ok {
			return f, nil
		}
		if s, ok := val.(string); ok {
			if f, err := strconv.ParseFloat(s, 64); err == nil {
				return f, nil
			}
		}
		return 0, fmt.Errorf("cannot convert field %s value to number", node.Value)

	case TypeOperator:
		// Comparison and logical operators may have non-numeric operands,
		// so they are resolved by the boolean evaluator before any numeric
		// evaluation of the operands is attempted.
		if isComparisonOperator(node.Value) || node.Value == "AND" || node.Value == "OR" {
			ok, err := evaluateBooleanCondition(node, data)
			if err != nil {
				return 0, err
			}
			if ok {
				return 1, nil
			}
			return 0, nil
		}

		left, err := evaluateNode(node.Left, data)
		if err != nil {
			return 0, err
		}
		right, err := evaluateNode(node.Right, data)
		if err != nil {
			return 0, err
		}

		switch node.Value {
		case "+":
			return left + right, nil
		case "-":
			return left - right, nil
		case "*":
			return left * right, nil
		case "/":
			if right == 0 {
				return 0, fmt.Errorf("division by zero")
			}
			return left / right, nil
		case "%":
			if right == 0 {
				return 0, fmt.Errorf("modulo by zero")
			}
			return math.Mod(left, right), nil
		case "^":
			return math.Pow(left, right), nil
		default:
			return 0, fmt.Errorf("unknown operator: %s", node.Value)
		}

	case TypeFunction:
		// Prefer the function registry; fall back to the legacy built-ins.
		fn, exists := functions.Get(node.Value)
		if !exists {
			return evaluateBuiltinFunction(node, data)
		}

		// Arguments are evaluated numerically before being passed on.
		args := make([]interface{}, len(node.Args))
		for i, arg := range node.Args {
			val, err := evaluateNode(arg, data)
			if err != nil {
				return 0, err
			}
			args[i] = val
		}

		ctx := &functions.FunctionContext{
			Data: data,
		}
		result, err := fn.Execute(ctx, args)
		if err != nil {
			return 0, err
		}
		if f, ok := numericToFloat(result); ok {
			return f, nil
		}
		return 0, fmt.Errorf("function %s returned non-numeric value", node.Value)

	case TypeCase:
		return evaluateCaseExpression(node, data)
	}

	return 0, fmt.Errorf("unknown node type: %s", node.Type)
}

// evaluateBuiltinFunction evaluates the legacy built-in math functions
// (abs, sqrt, sin, cos, tan, floor, ceil, round), matched case-insensitively.
// Each takes exactly one argument; sqrt rejects negative input.
func evaluateBuiltinFunction(node *ExprNode, data map[string]interface{}) (float64, error) {
	name := strings.ToLower(node.Value)
	fn, ok := builtinUnaryFuncs[name]
	if !ok {
		return 0, fmt.Errorf("unknown function: %s", node.Value)
	}
	if len(node.Args) != 1 {
		return 0, fmt.Errorf("%s function requires exactly 1 argument", name)
	}

	arg, err := evaluateNode(node.Args[0], data)
	if err != nil {
		return 0, err
	}
	if name == "sqrt" && arg < 0 {
		return 0, fmt.Errorf("sqrt of negative number")
	}
	return fn(arg), nil
}

// evaluateCaseExpression evaluates a CASE node. For the simple form
// (CASE expr WHEN value THEN ...) the CASE value is compared for equality
// with each WHEN value; for the searched form (CASE WHEN cond THEN ...) each
// condition is evaluated as a boolean. The first matching branch wins. With
// no match the ELSE expression is used, or 0 when there is no ELSE (SQL
// would return NULL).
func evaluateCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) {
	if node.Type != TypeCase {
		return 0, fmt.Errorf("node is not a CASE expression")
	}

	if node.CaseExpr != nil {
		// Simple CASE.
		caseValue, err := evaluateNodeValue(node.CaseExpr, data)
		if err != nil {
			return 0, err
		}
		for _, clause := range node.WhenClauses {
			whenValue, err := evaluateNodeValue(clause.Condition, data)
			if err != nil {
				return 0, err
			}
			matched, err := compareValues(caseValue, whenValue, "==")
			if err != nil {
				return 0, err
			}
			if matched {
				return evaluateNode(clause.Result, data)
			}
		}
	} else {
		// Searched CASE.
		for _, clause := range node.WhenClauses {
			matched, err := evaluateBooleanCondition(clause.Condition, data)
			if err != nil {
				return 0, err
			}
			if matched {
				return evaluateNode(clause.Result, data)
			}
		}
	}

	if node.ElseExpr != nil {
		return evaluateNode(node.ElseExpr, data)
	}
	return 0, nil
}

// evaluateBooleanCondition evaluates node as a boolean. AND/OR combine the
// boolean values of both operands (both sides are always evaluated),
// comparison operators compare their operand values, and any other
// expression is true when its numeric value is non-zero.
func evaluateBooleanCondition(node *ExprNode, data map[string]interface{}) (bool, error) {
	if node == nil {
		return false, fmt.Errorf("null condition expression")
	}

	if node.Type == TypeOperator {
		switch {
		case node.Value == "AND" || node.Value == "OR":
			left, err := evaluateBooleanCondition(node.Left, data)
			if err != nil {
				return false, err
			}
			right, err := evaluateBooleanCondition(node.Right, data)
			if err != nil {
				return false, err
			}
			if node.Value == "AND" {
				return left && right, nil
			}
			return left || right, nil

		case isComparisonOperator(node.Value):
			left, err := evaluateNodeValue(node.Left, data)
			if err != nil {
				return false, err
			}
			right, err := evaluateNodeValue(node.Right, data)
			if err != nil {
				return false, err
			}
			return compareValues(left, right, node.Value)
		}
	}

	// Arithmetic and every other expression: non-zero means true.
	result, err := evaluateNode(node, data)
	if err != nil {
		return false, err
	}
	return result != 0, nil
}

// evaluateNodeValue evaluates node without forcing the result to a number:
// string constants yield their unquoted text and fields yield the raw value
// stored in data. Every other node type is evaluated numerically.
func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
	if node == nil {
		return nil, fmt.Errorf("null expression node")
	}

	switch node.Type {
	case TypeNumber:
		return strconv.ParseFloat(node.Value, 64)
	case TypeString:
		value := node.Value
		if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
			value = value[1 : len(value)-1] // strip the quotes
		}
		return value, nil
	case TypeField:
		val, ok := data[node.Value]
		if !ok {
			return nil, fmt.Errorf("field %s not found in data", node.Value)
		}
		return val, nil
	default:
		return evaluateNode(node, data)
	}
}

// compareValues applies the comparison operator to left and right. Two
// strings are compared lexicographically; otherwise both values are
// converted to float64, with equality tested using a tolerance of 1e-9.
func compareValues(left, right interface{}, operator string) (bool, error) {
	leftStr, leftIsStr := left.(string)
	rightStr, rightIsStr := right.(string)
	if leftIsStr && rightIsStr {
		switch operator {
		case "==", "=":
			return leftStr == rightStr, nil
		case "!=", "<>":
			return leftStr != rightStr, nil
		case ">":
			return leftStr > rightStr, nil
		case ">=":
			return leftStr >= rightStr, nil
		case "<":
			return leftStr < rightStr, nil
		case "<=":
			return leftStr <= rightStr, nil
		default:
			return false, fmt.Errorf("unsupported string comparison operator: %s", operator)
		}
	}

	leftNum, leftErr := convertToFloat(left)
	rightNum, rightErr := convertToFloat(right)
	if leftErr != nil || rightErr != nil {
		return false, fmt.Errorf("cannot compare values: %v and %v", left, right)
	}

	switch operator {
	case ">":
		return leftNum > rightNum, nil
	case ">=":
		return leftNum >= rightNum, nil
	case "<":
		return leftNum < rightNum, nil
	case "<=":
		return leftNum <= rightNum, nil
	case "==", "=":
		return math.Abs(leftNum-rightNum) < 1e-9, nil
	case "!=", "<>":
		return math.Abs(leftNum-rightNum) >= 1e-9, nil
	default:
		return false, fmt.Errorf("unsupported comparison operator: %s", operator)
	}
}

// numericToFloat converts any built-in integer or floating-point value to
// float64. The boolean result is false for every other type.
func numericToFloat(val interface{}) (float64, bool) {
	switch v := val.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case int:
		return float64(v), true
	case int8:
		return float64(v), true
	case int16:
		return float64(v), true
	case int32:
		return float64(v), true
	case int64:
		return float64(v), true
	case uint:
		return float64(v), true
	case uint8:
		return float64(v), true
	case uint16:
		return float64(v), true
	case uint32:
		return float64(v), true
	case uint64:
		return float64(v), true
	default:
		return 0, false
	}
}

// convertToFloat converts a numeric value or a numeric string to float64.
func convertToFloat(val interface{}) (float64, error) {
	if f, ok := numericToFloat(val); ok {
		return f, nil
	}
	if s, ok := val.(string); ok {
		return strconv.ParseFloat(s, 64)
	}
	return 0, fmt.Errorf("cannot convert %T to float64", val)
}

// scanNumber returns the index just past the run of digits, containing at
// most one '.', that starts at expr[start].
func scanNumber(expr string, start int) int {
	seenDot := false
	i := start
	for i < len(expr) {
		c := expr[i]
		if c == '.' {
			if seenDot {
				break
			}
			seenDot = true
		} else if !isDigit(c) {
			break
		}
		i++
	}
	return i
}

// startsOperand reports whether the next token begins a new operand, judged
// by the tokens read so far: at the start of the expression, or after "(",
// ",", an operator, or one of the keywords WHEN/THEN/ELSE (case-insensitive).
// A '-' in such a position is a sign rather than a subtraction.
func startsOperand(tokens []string) bool {
	if len(tokens) == 0 {
		return true
	}
	// Upper-case so that lowercase "and"/"or" are recognised as operators.
	prev := strings.ToUpper(tokens[len(tokens)-1])
	switch prev {
	case "(", ",", "WHEN", "THEN", "ELSE":
		return true
	}
	return isOperator(prev)
}

// tokenize splits an expression into tokens: numbers (optionally negative),
// operators, parentheses, commas, quoted string literals (quotes retained)
// and identifiers. It fails on empty input, on an unterminated string
// literal and on any character it does not recognise.
func tokenize(expr string) ([]string, error) {
	expr = strings.TrimSpace(expr)
	if expr == "" {
		return nil, fmt.Errorf("empty expression")
	}

	var tokens []string
	for i := 0; i < len(expr); {
		ch := expr[i]
		switch {
		case ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r':
			i++

		case isDigit(ch) || (ch == '.' && i+1 < len(expr) && isDigit(expr[i+1])):
			end := scanNumber(expr, i)
			tokens = append(tokens, expr[i:end])
			i = end

		case ch == '-' && i+1 < len(expr) && isDigit(expr[i+1]) && startsOperand(tokens):
			// Negative numeric literal: keep the sign in the token.
			end := scanNumber(expr, i+1)
			tokens = append(tokens, expr[i:end])
			i = end

		case strings.IndexByte("+-*/%^(),", ch) >= 0:
			tokens = append(tokens, string(ch))
			i++

		case ch == '>' || ch == '<' || ch == '=' || ch == '!':
			op := string(ch)
			if i+1 < len(expr) {
				switch two := expr[i : i+2]; two {
				case ">=", "<=", "<>", "==", "!=":
					op = two
				}
			}
			tokens = append(tokens, op)
			i += len(op)

		case ch == '\'' || ch == '"':
			start := i
			i++ // opening quote
			for i < len(expr) && expr[i] != ch {
				if expr[i] == '\\' && i+1 < len(expr) {
					i += 2 // skip the escaped character
				} else {
					i++
				}
			}
			if i >= len(expr) {
				return nil, fmt.Errorf("unterminated string literal starting at position %d", start)
			}
			i++ // closing quote
			tokens = append(tokens, expr[start:i])

		case isLetter(ch) || ch == '_':
			start := i
			for i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_') {
				i++
			}
			tokens = append(tokens, expr[start:i])

		default:
			return nil, fmt.Errorf("unexpected character: %c at position %d", ch, i)
		}
	}

	return tokens, nil
}

// applyOperator pops the two topmost operands from output and pushes an
// operator node that combines them.
func applyOperator(output []*ExprNode, op string) ([]*ExprNode, error) {
	n := len(output)
	if n < 2 {
		return nil, fmt.Errorf("not enough operands for operator: %s", op)
	}
	node := &ExprNode{
		Type:  TypeOperator,
		Value: op,
		Left:  output[n-2],
		Right: output[n-1],
	}
	return append(output[:n-2], node), nil
}

// pushOperator reduces every stacked operator whose precedence is greater
// than or equal to that of op (stopping at a left parenthesis) and then
// pushes op onto the operator stack.
func pushOperator(output []*ExprNode, operators []string, op string) ([]*ExprNode, []string, error) {
	for len(operators) > 0 {
		top := operators[len(operators)-1]
		if top == "(" || operatorPrecedence[top] < operatorPrecedence[op] {
			break
		}
		operators = operators[:len(operators)-1]

		var err error
		if output, err = applyOperator(output, top); err != nil {
			return nil, nil, err
		}
	}
	return output, append(operators, op), nil
}

// parseExpression builds an expression tree from tokens using the
// shunting-yard algorithm. Every operator, including NOT, is treated as
// binary, so unary forms are rejected with a "not enough operands" error.
func parseExpression(tokens []string) (*ExprNode, error) {
	if len(tokens) == 0 {
		return nil, fmt.Errorf("empty token list")
	}

	var (
		output    []*ExprNode
		operators []string
		err       error
	)

	for i := 0; i < len(tokens); {
		token := tokens[i]
		switch {
		case isNumber(token):
			output = append(output, &ExprNode{Type: TypeNumber, Value: token})
			i++

		case isStringLiteral(token):
			output = append(output, &ExprNode{Type: TypeString, Value: token})
			i++

		case isIdentifier(token):
			upper := strings.ToUpper(token)
			switch {
			case upper == "AND" || upper == "OR" || upper == "NOT":
				// Logical keywords are stored upper-case on the stack.
				if output, operators, err = pushOperator(output, operators, upper); err != nil {
					return nil, err
				}
				i++

			case upper == "CASE":
				caseNode, next, caseErr := parseCaseExpression(tokens, i)
				if caseErr != nil {
					return nil, caseErr
				}
				output = append(output, caseNode)
				i = next

			case i+1 < len(tokens) && tokens[i+1] == "(":
				// Function call: arguments start after the name and "(".
				args, next, argErr := parseFunctionArgs(tokens, i+2)
				if argErr != nil {
					return nil, argErr
				}
				output = append(output, &ExprNode{
					Type:  TypeFunction,
					Value: token,
					Args:  args,
				})
				i = next

			default:
				output = append(output, &ExprNode{Type: TypeField, Value: token})
				i++
			}

		case token == "(":
			operators = append(operators, token)
			i++

		case token == ")":
			for len(operators) > 0 && operators[len(operators)-1] != "(" {
				op := operators[len(operators)-1]
				operators = operators[:len(operators)-1]
				if output, err = applyOperator(output, op); err != nil {
					return nil, err
				}
			}
			if len(operators) == 0 {
				return nil, fmt.Errorf("mismatched parentheses")
			}
			operators = operators[:len(operators)-1] // drop the "("
			i++

		case isOperator(token):
			if output, operators, err = pushOperator(output, operators, token); err != nil {
				return nil, err
			}
			i++

		case token == ",":
			// Commas are consumed by parseFunctionArgs; skip stray ones.
			i++

		default:
			return nil, fmt.Errorf("unexpected token: %s", token)
		}
	}

	// Reduce whatever is left on the operator stack.
	for len(operators) > 0 {
		op := operators[len(operators)-1]
		operators = operators[:len(operators)-1]
		if op == "(" {
			return nil, fmt.Errorf("mismatched parentheses")
		}
		if output, err = applyOperator(output, op); err != nil {
			return nil, err
		}
	}

	if len(output) != 1 {
		return nil, fmt.Errorf("invalid expression")
	}
	return output[0], nil
}

// parseFunctionArgs parses a comma-separated argument list. startIndex
// points at the first token after the opening parenthesis. It returns the
// parsed arguments and the index of the token following the closing
// parenthesis. Empty arguments (e.g. between two commas) are skipped.
func parseFunctionArgs(tokens []string, startIndex int) ([]*ExprNode, int, error) {
	args := make([]*ExprNode, 0)
	i := startIndex

	// Empty argument list.
	if i < len(tokens) && tokens[i] == ")" {
		return args, i + 1, nil
	}

	for i < len(tokens) {
		// Collect one argument: up to a top-level "," or the ")" that
		// closes the call.
		start := i
		depth := 0
	scan:
		for i < len(tokens) {
			switch tokens[i] {
			case "(":
				depth++
			case ")":
				depth--
				if depth < 0 {
					break scan
				}
			case ",":
				if depth == 0 {
					break scan
				}
			}
			i++
		}

		if i > start {
			arg, err := parseExpression(tokens[start:i])
			if err != nil {
				return nil, 0, err
			}
			args = append(args, arg)
		}

		if i >= len(tokens) {
			return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
		}
		if tokens[i] == ")" {
			return args, i + 1, nil
		}
		// The scan only stops on ")" or a top-level ",".
		i++
	}

	return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
}

// collectUntilKeyword returns the tokens from index i up to (not including)
// the first token matching one of keywords case-insensitively, together
// with the index at which scanning stopped. keywords must be upper-case.
func collectUntilKeyword(tokens []string, i int, keywords ...string) ([]string, int) {
	start := i
	for i < len(tokens) {
		upper := strings.ToUpper(tokens[i])
		for _, kw := range keywords {
			if upper == kw {
				return tokens[start:i], i
			}
		}
		i++
	}
	return tokens[start:i], i
}

// parseCaseExpression parses a CASE expression whose CASE keyword is at
// tokens[startIndex]. It supports the simple form (CASE expr WHEN ...) and
// the searched form (CASE WHEN ...), returning the node and the index of
// the token following END. Sub-expressions are delimited by scanning for
// the next WHEN/THEN/ELSE/END keyword, so a CASE nested inside another CASE
// is not supported.
func parseCaseExpression(tokens []string, startIndex int) (*ExprNode, int, error) {
	if startIndex >= len(tokens) || strings.ToUpper(tokens[startIndex]) != "CASE" {
		return nil, startIndex, fmt.Errorf("expected CASE keyword")
	}

	caseNode := &ExprNode{
		Type:        TypeCase,
		WhenClauses: make([]WhenClause, 0),
	}
	i := startIndex + 1 // skip CASE

	// Anything between CASE and the first WHEN is the simple-CASE operand.
	if i < len(tokens) && strings.ToUpper(tokens[i]) != "WHEN" {
		var caseExprTokens []string
		caseExprTokens, i = collectUntilKeyword(tokens, i, "WHEN")
		if len(caseExprTokens) == 0 {
			return nil, i, fmt.Errorf("expected expression after CASE")
		}

		if len(caseExprTokens) == 1 {
			// A single token is mapped directly to a leaf node.
			token := caseExprTokens[0]
			switch {
			case isNumber(token):
				caseNode.CaseExpr = &ExprNode{Type: TypeNumber, Value: token}
			case isStringLiteral(token):
				caseNode.CaseExpr = &ExprNode{Type: TypeString, Value: token}
			case isIdentifier(token):
				caseNode.CaseExpr = &ExprNode{Type: TypeField, Value: token}
			default:
				return nil, i, fmt.Errorf("invalid CASE expression token: %s", token)
			}
		} else {
			caseExpr, err := parseExpression(caseExprTokens)
			if err != nil {
				return nil, i, fmt.Errorf("failed to parse CASE expression: %w", err)
			}
			caseNode.CaseExpr = caseExpr
		}
	}

	// WHEN ... THEN ... clauses.
	for i < len(tokens) && strings.ToUpper(tokens[i]) == "WHEN" {
		i++ // skip WHEN

		var conditionTokens []string
		conditionTokens, i = collectUntilKeyword(tokens, i, "THEN")
		if len(conditionTokens) == 0 {
			return nil, i, fmt.Errorf("expected condition after WHEN")
		}
		if i >= len(tokens) || strings.ToUpper(tokens[i]) != "THEN" {
			return nil, i, fmt.Errorf("expected THEN after WHEN condition")
		}
		i++ // skip THEN

		var resultTokens []string
		resultTokens, i = collectUntilKeyword(tokens, i, "WHEN", "ELSE", "END")
		if len(resultTokens) == 0 {
			return nil, i, fmt.Errorf("expected result after THEN")
		}

		conditionExpr, err := parseExpression(conditionTokens)
		if err != nil {
			return nil, i, fmt.Errorf("failed to parse WHEN condition: %w", err)
		}
		resultExpr, err := parseExpression(resultTokens)
		if err != nil {
			return nil, i, fmt.Errorf("failed to parse THEN result: %w", err)
		}

		caseNode.WhenClauses = append(caseNode.WhenClauses, WhenClause{
			Condition: conditionExpr,
			Result:    resultExpr,
		})
	}

	// Optional ELSE clause.
	if i < len(tokens) && strings.ToUpper(tokens[i]) == "ELSE" {
		i++ // skip ELSE

		var elseTokens []string
		elseTokens, i = collectUntilKeyword(tokens, i, "END")
		if len(elseTokens) == 0 {
			return nil, i, fmt.Errorf("expected result after ELSE")
		}

		elseExpr, err := parseExpression(elseTokens)
		if err != nil {
			return nil, i, fmt.Errorf("failed to parse ELSE result: %w", err)
		}
		caseNode.ElseExpr = elseExpr
	}

	if i >= len(tokens) || strings.ToUpper(tokens[i]) != "END" {
		return nil, i, fmt.Errorf("expected END to close CASE expression")
	}
	i++ // skip END

	if len(caseNode.WhenClauses) == 0 {
		return nil, i, fmt.Errorf("CASE expression must have at least one WHEN clause")
	}

	return caseNode, i, nil
}

// Helper predicates.

// isDigit reports whether ch is an ASCII decimal digit.
func isDigit(ch byte) bool {
	return ch >= '0' && ch <= '9'
}

// isLetter reports whether ch is an ASCII letter.
func isLetter(ch byte) bool {
	return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')
}

// isNumber reports whether s is a numeric literal. s must start with a
// digit or '.', optionally preceded by a sign: strconv.ParseFloat alone
// would also accept words such as "inf", "infinity" and "nan", which must
// remain usable as field names.
func isNumber(s string) bool {
	digits := s
	if len(digits) > 0 && (digits[0] == '+' || digits[0] == '-') {
		digits = digits[1:]
	}
	if len(digits) == 0 || !(isDigit(digits[0]) || digits[0] == '.') {
		return false
	}
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}

// isIdentifier reports whether s starts with an ASCII letter or '_' and
// continues with ASCII letters, digits or '_'.
func isIdentifier(s string) bool {
	if len(s) == 0 {
		return false
	}
	if !isLetter(s[0]) && s[0] != '_' {
		return false
	}
	for i := 1; i < len(s); i++ {
		if !isLetter(s[i]) && !isDigit(s[i]) && s[i] != '_' {
			return false
		}
	}
	return true
}

// isComparisonOperator reports whether op is a comparison operator.
func isComparisonOperator(op string) bool {
	switch op {
	case ">", "<", ">=", "<=", "==", "=", "!=", "<>":
		return true
	default:
		return false
	}
}

// isOperator reports whether s is an arithmetic, comparison or (upper-case)
// logical operator.
func isOperator(s string) bool {
	switch s {
	case "+", "-", "*", "/", "%", "^":
		return true
	case "AND", "OR", "NOT":
		return true
	default:
		return isComparisonOperator(s)
	}
}

// isStringLiteral reports whether expr is enclosed in a matching pair of
// single or double quotes.
func isStringLiteral(expr string) bool {
	return len(expr) > 1 && (expr[0] == '\'' || expr[0] == '"') && expr[len(expr)-1] == expr[0]
}