Files
streamsql/expr/expression.go
T
2025-08-05 11:25:49 +08:00

2450 lines
62 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package expr
import (
"fmt"
"math"
"strconv"
"strings"
"github.com/rulego/streamsql/functions"
"github.com/rulego/streamsql/utils/fieldpath"
)
// Expression node types. The value is stored in ExprNode.Type and selects
// how the evaluator interprets the node.
const (
	TypeNumber      = "number"      // Numeric literal; Value is parsed with strconv.ParseFloat
	TypeField       = "field"       // Field reference resolved against the input data map
	TypeOperator    = "operator"    // Operator node; operands are in Left and Right
	TypeFunction    = "function"    // Function call; arguments are in Args
	TypeParenthesis = "parenthesis" // Parenthesis
	TypeCase        = "case"        // CASE expression (see CaseExpr, WhenClauses, ElseExpr)
	TypeString      = "string"      // String literal; Value still carries its surrounding quotes
)
// operatorPrecedence maps an operator token to its binding strength; a larger
// number binds tighter. parseExpression compares these values to decide when a
// pending operator must be reduced before a new one is pushed.
// NOTE(review): "NOT" has no entry, so a lookup for it yields 0 — confirm this
// is intended by the parser.
var operatorPrecedence = map[string]int{
	"OR":  1,
	"AND": 2,
	"==":  3, "=": 3, "!=": 3, "<>": 3, // Equality operators
	">": 4, "<": 4, ">=": 4, "<=": 4, "LIKE": 4, "IS": 4, // Relational operators, LIKE and IS
	"+": 5, "-": 5, // Additive operators
	"*": 6, "/": 6, "%": 6, // Multiplicative operators
	"^": 7, // Power operation
}
// WhenClause represents one WHEN ... THEN ... branch of a CASE expression.
// In a simple CASE the Condition is a value compared for equality with the
// CASE operand; in a searched CASE it is evaluated as a boolean condition.
type WhenClause struct {
	Condition *ExprNode // WHEN condition
	Result    *ExprNode // THEN result
}
// ExprNode represents a node of the parsed expression tree. Which fields are
// meaningful depends on Type.
type ExprNode struct {
	Type  string      // One of the Type* constants
	Value string      // Literal text, field name, operator symbol or function name
	Left  *ExprNode   // Left operand of an operator node
	Right *ExprNode   // Right operand of an operator node
	Args  []*ExprNode // Arguments for function calls
	// Fields specific to CASE expressions
	CaseExpr    *ExprNode    // Expression after CASE (simple CASE); nil for a searched CASE
	WhenClauses []WhenClause // List of WHEN clauses, evaluated in order
	ElseExpr    *ExprNode    // ELSE expression; nil when the CASE has no ELSE
}
// Expression represents a computable expression. It is evaluated either by
// walking Root (built-in parser) or, when the built-in tokenizer or parser
// could not handle the input, by handing the raw text to expr-lang/expr.
type Expression struct {
	Root               *ExprNode // Parsed tree; nil when useExprLang is true
	useExprLang        bool      // Whether to use expr-lang/expr
	exprLangExpression string    // expr-lang expression string (the original input text)
}
// NewExpression validates exprStr and compiles it into an Expression.
//
// An error is returned only when the basic syntax validation fails. If the
// built-in tokenizer or parser rejects the input, no error is reported;
// instead the returned Expression is flagged to be evaluated through
// expr-lang/expr using the original text.
func NewExpression(exprStr string) (*Expression, error) {
	if err := validateBasicSyntax(exprStr); err != nil {
		return nil, err
	}

	// Preferred path: the built-in tokenizer followed by the built-in parser.
	if tokens, tokenizeErr := tokenize(exprStr); tokenizeErr == nil {
		if root, parseErr := parseExpression(tokens); parseErr == nil {
			return &Expression{Root: root}, nil
		}
	}

	// Either stage failed: defer the whole expression to expr-lang.
	return &Expression{
		useExprLang:        true,
		exprLangExpression: exprStr,
	}, nil
}
// validateBasicSyntax runs cheap, purely textual sanity checks on exprStr
// before any tokenizing takes place. It rejects, in this order: an empty
// expression, unbalanced parentheses, characters outside the accepted set,
// a leading or trailing operator, and consecutive operators.
func validateBasicSyntax(exprStr string) error {
	trimmed := strings.TrimSpace(exprStr)
	if trimmed == "" {
		return fmt.Errorf("empty expression")
	}

	// Parentheses must balance and a ')' may never appear before its '('.
	depth := 0
	for _, r := range trimmed {
		switch r {
		case '(':
			depth++
		case ')':
			depth--
			if depth < 0 {
				return fmt.Errorf("mismatched parentheses")
			}
		}
	}
	if depth != 0 {
		return fmt.Errorf("mismatched parentheses")
	}

	// Only letters, digits, operators, parentheses, dot, underscore,
	// whitespace and quotes (see isValidChar) are accepted. The reported
	// position is the byte offset within the trimmed expression.
	for offset, r := range trimmed {
		if !isValidChar(r) {
			return fmt.Errorf("invalid character '%c' at position %d", r, offset)
		}
	}

	// The expression must not begin or end with an operator.
	if err := checkExpressionStartEnd(trimmed); err != nil {
		return err
	}

	// Finally, look for operators that directly follow one another.
	return checkConsecutiveOperators(trimmed)
}
// checkExpressionStartEnd reports an error when expr begins or ends with a
// binary operator.
//
// A leading '-' is accepted because it introduces a negative number. A
// trailing '-' or '=' is always a dangling operator and is rejected, as is a
// leading '='. Surrounding whitespace is ignored and an empty expression
// yields nil (emptiness is reported by the caller).
func checkExpressionStartEnd(expr string) error {
	expr = strings.TrimSpace(expr)
	if expr == "" {
		return nil
	}

	// Every rejected leading operator except "!=" is identified by its first
	// byte; a lone '!' and '-' are deliberately not part of the set.
	const leadingOperatorBytes = "+*/%^=<>"
	if strings.HasPrefix(expr, "!=") || strings.IndexByte(leadingOperatorBytes, expr[0]) >= 0 {
		return fmt.Errorf("expression cannot start with operator")
	}

	// Every operator, including "==", "!=", ">=", "<=" and "<>", ends in one
	// of these bytes, so checking the last byte covers all of them.
	const trailingOperatorBytes = "+-*/%^=<>"
	if strings.IndexByte(trailingOperatorBytes, expr[len(expr)-1]) >= 0 {
		return fmt.Errorf("expression cannot end with operator")
	}
	return nil
}
// checkConsecutiveOperators reports an error when two operators follow each
// other with nothing but blanks (spaces or tabs) in between.
//
// Exceptions and details:
//   - A comparison operator or '^' may be followed by a negative number,
//     i.e. a '-' that is followed (optionally after blanks) by a digit.
//   - "<>" is recognised as a single operator rather than '<' followed by '>'.
//   - Text inside single-quoted or double-quoted string literals and inside
//     backtick-quoted identifiers is skipped, so operator characters that are
//     part of data (for example the '%' wildcards of a LIKE pattern) are not
//     mistaken for operators. An unterminated literal ends the scan without
//     an error; reporting it is left to the tokenizer.
func checkConsecutiveOperators(expr string) error {
	// Two-character operators are listed before the single characters they
	// start with, so the first match is always the longest one.
	operators := []string{"+", "-", "*", "/", "%", "^", "==", "!=", ">=", "<=", "<>", ">", "<"}

	// operatorAt returns the operator that starts at byte offset pos, or "".
	operatorAt := func(pos int) string {
		for _, op := range operators {
			if strings.HasPrefix(expr[pos:], op) {
				return op
			}
		}
		return ""
	}

	isBlank := func(c byte) bool { return c == ' ' || c == '\t' }

	// skipBlanks returns the offset of the first non-blank byte at or after pos.
	skipBlanks := func(pos int) int {
		for pos < len(expr) && isBlank(expr[pos]) {
			pos++
		}
		return pos
	}

	// negativeNumberAt reports whether pos holds a '-' followed by a digit.
	negativeNumberAt := func(pos int) bool {
		if pos >= len(expr) || expr[pos] != '-' {
			return false
		}
		digit := skipBlanks(pos + 1)
		return digit < len(expr) && expr[digit] >= '0' && expr[digit] <= '9'
	}

	// quotedEnd returns the offset of the delimiter closing the literal that
	// opens at pos, or -1 when it is unterminated. Backslash escapes are
	// honoured for string literals but not for backtick identifiers, which
	// mirrors how the tokenizer scans them.
	quotedEnd := func(pos int) int {
		quote := expr[pos]
		for j := pos + 1; j < len(expr); j++ {
			switch {
			case expr[j] == '\\' && quote != '`':
				j++ // skip the escaped byte
			case expr[j] == quote:
				return j
			}
		}
		return -1
	}

	allowsNegativeOperand := func(op string) bool {
		switch op {
		case "==", "!=", ">=", "<=", "<>", ">", "<", "^":
			return true
		}
		return false
	}

	for i := 0; i < len(expr)-1; i++ {
		c := expr[i]
		if isBlank(c) {
			continue
		}

		// Quoted text is data, not syntax.
		if c == '\'' || c == '"' || c == '`' {
			end := quotedEnd(i)
			if end < 0 {
				return nil
			}
			i = end
			continue
		}

		current := operatorAt(i)
		if current == "" {
			continue
		}

		next := skipBlanks(i + len(current))
		if next < len(expr) {
			if allowsNegativeOperand(current) && negativeNumberAt(next) {
				i = next // resume after the sign
				continue
			}
			if following := operatorAt(next); following != "" {
				return fmt.Errorf("consecutive operators found: '%s' followed by '%s'",
					current, following)
			}
		}

		// Step over the remaining bytes of a multi-character operator.
		i += len(current) - 1
	}
	return nil
}
// isValidChar reports whether ch may appear in an expression handed to the
// basic syntax validation. Accepted are ASCII letters and digits, whitespace,
// arithmetic and comparison operator characters, parentheses, comma, quotes,
// dot, underscore, dollar sign and the backtick. The backtick is included
// because the tokenizer and the evaluator support backtick-quoted identifiers.
func isValidChar(ch rune) bool {
	// Letters and digits
	switch {
	case ch >= 'a' && ch <= 'z', ch >= 'A' && ch <= 'Z', ch >= '0' && ch <= '9':
		return true
	}
	// Special characters
	switch ch {
	case ' ', '\t', '\n', '\r': // Whitespace characters
		return true
	case '+', '-', '*', '/', '%', '^': // Arithmetic operators
		return true
	case '(', ')', ',': // Parentheses and comma
		return true
	case '>', '<', '=', '!': // Comparison operators
		return true
	case '\'', '"': // Quotes
		return true
	case '`': // Backtick (quoted identifiers)
		return true
	case '.', '_': // Dot and underscore
		return true
	case '$': // Dollar sign (for JSON paths etc.)
		return true
	default:
		return false
	}
}
// Evaluate computes the numeric value of the expression for the given data.
// Expressions compiled by the built-in parser are evaluated by walking the
// tree; all others are delegated to expr-lang/expr.
func (e *Expression) Evaluate(data map[string]interface{}) (float64, error) {
	if !e.useExprLang {
		return evaluateNode(e.Root, data)
	}
	return e.evaluateWithExprLang(data)
}
// evaluateWithExprLang evaluates the stored expression text through the
// expr-lang bridge and converts the result to float64.
//
// Numeric results are widened to float64, numeric strings are parsed, and
// boolean results map to 1 (true) or 0 (false), matching how the built-in
// evaluator reports comparison and function results. Any other result type,
// or a string that is not a number, yields an error.
func (e *Expression) evaluateWithExprLang(data map[string]interface{}) (float64, error) {
	// Use bridge to evaluate expression
	bridge := functions.GetExprBridge()
	result, err := bridge.EvaluateExpression(e.exprLangExpression, data)
	if err != nil {
		return 0, err
	}

	// Try to convert result to float64
	switch v := result.(type) {
	case float64:
		return v, nil
	case float32:
		return float64(v), nil
	case int:
		return float64(v), nil
	case int32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	case bool:
		// Boolean conversion: true=1, false=0
		if v {
			return 1.0, nil
		}
		return 0.0, nil
	case string:
		if f, err := strconv.ParseFloat(v, 64); err == nil {
			return f, nil
		}
		return 0, fmt.Errorf("cannot convert string result '%s' to float64", v)
	default:
		return 0, fmt.Errorf("expression result type %T is not convertible to float64", result)
	}
}
// GetFields returns the distinct field names referenced by the expression.
// The order of the returned names is unspecified. For expressions delegated
// to expr-lang the names come from a simplified textual scan rather than
// from a parsed tree.
func (e *Expression) GetFields() []string {
	if e.useExprLang {
		return extractFieldsFromExprLang(e.exprLangExpression)
	}

	seen := map[string]bool{}
	collectFields(e.Root, seen)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	return names
}
// extractFieldsFromExprLang returns the distinct field names found in an
// expr-lang expression, in unspecified order.
//
// This is a simplified, purely textual scan (no AST): the text is cut into
// runs of ASCII letters, digits, '_' and '.', and every run that is a valid
// (possibly dot-separated) identifier, is not a number and is not a keyword
// or registered function is reported as a field.
func extractFieldsFromExprLang(expression string) []string {
	// isSeparator is true for every rune that cannot be part of a field path.
	isSeparator := func(c rune) bool {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9':
			return false
		case c == '_', c == '.':
			return false
		}
		return true
	}

	seen := make(map[string]bool)
	for _, candidate := range strings.FieldsFunc(expression, isSeparator) {
		if !isValidFieldIdentifier(candidate) {
			continue
		}
		if isNumber(candidate) || isFunctionOrKeyword(candidate) {
			continue
		}
		seen[candidate] = true
	}

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	return names
}
// isValidFieldIdentifier reports whether s is a field name: one identifier,
// or several identifiers joined by dots for nested fields. The empty string
// is not a field name, and neither is any name with an empty or otherwise
// invalid segment.
func isValidFieldIdentifier(s string) bool {
	if s == "" {
		return false
	}
	for _, segment := range strings.Split(s, ".") {
		if !isIdentifier(segment) {
			return false
		}
	}
	return true
}
// isFunctionOrKeyword reports whether token is a reserved word (matched
// case-insensitively) or the name of a function known to the expr bridge.
func isFunctionOrKeyword(token string) bool {
	switch strings.ToLower(token) {
	case "and", "or", "not", "true", "false", "nil", "null", "is",
		"if", "else", "then", "in", "contains", "matches",
		// CASE expression keywords
		"case", "when", "end":
		return true
	}

	// Not a keyword: ask the bridge whether a function of that name exists.
	_, exists, _ := functions.GetExprBridge().ResolveFunction(token)
	return exists
}
// collectFields walks the tree rooted at node and records the Value of every
// field node as a key in fields. A nil node is ignored.
//
// For a CASE node only the CASE operand, the WHEN conditions and results,
// and the ELSE expression are visited; Left, Right and Args are not.
func collectFields(node *ExprNode, fields map[string]bool) {
	if node == nil {
		return
	}

	switch node.Type {
	case TypeField:
		fields[node.Value] = true
	case TypeCase:
		// The recursive call is nil-safe, so absent parts need no guard.
		collectFields(node.CaseExpr, fields)
		for i := range node.WhenClauses {
			collectFields(node.WhenClauses[i].Condition, fields)
			collectFields(node.WhenClauses[i].Result, fields)
		}
		collectFields(node.ElseExpr, fields)
		return
	}

	collectFields(node.Left, fields)
	collectFields(node.Right, fields)
	for _, arg := range node.Args {
		collectFields(arg, fields)
	}
}
// evaluateNode computes the numeric value of node for the given data.
//
//   - Number: the literal parsed as float64.
//   - String: quotes are stripped; a numeric string yields its value, any
//     other string yields its length in bytes (a simplified placeholder).
//   - Field: the field (plain or nested, optionally backtick-quoted) is looked
//     up in data and converted to float64; a missing or non-numeric field is
//     an error.
//   - Operator: comparison operators yield 1 or 0; arithmetic operators are
//     applied to the numeric operands, with division and modulo by zero
//     reported as errors.
//   - Function: a registered function is executed with the original argument
//     values and its result converted to float64; otherwise the built-in
//     functions are tried.
//   - Case: delegated to evaluateCaseExpression.
//
// A nil node or an unknown node type is an error.
func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) {
	if node == nil {
		return 0, fmt.Errorf("null expression node")
	}

	switch node.Type {
	case TypeNumber:
		return strconv.ParseFloat(node.Value, 64)

	case TypeString:
		text := node.Value
		if len(text) >= 2 && (text[0] == '\'' || text[0] == '"') {
			text = text[1 : len(text)-1] // drop the surrounding quotes
		}
		if num, err := strconv.ParseFloat(text, 64); err == nil {
			return num, nil
		}
		// Temporary simplification: a non-numeric string counts as its length.
		return float64(len(text)), nil

	case TypeField:
		name := node.Value
		if n := len(name); n >= 2 && name[0] == '`' && name[n-1] == '`' {
			name = name[1 : n-1] // drop the backticks of a quoted identifier
		}

		var (
			raw   interface{}
			found bool
		)
		if fieldpath.IsNestedField(name) {
			raw, found = fieldpath.GetNestedField(data, name)
		} else {
			raw, found = data[name]
		}
		if !found {
			return 0, fmt.Errorf("field '%s' not found", name)
		}
		num, err := convertToFloat(raw)
		if err != nil {
			return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", name, raw)
		}
		return num, nil

	case TypeOperator:
		if isComparisonOperator(node.Value) {
			// Comparisons work on the original (possibly string) values.
			lhs, err := evaluateNodeValue(node.Left, data)
			if err != nil {
				return 0, err
			}
			rhs, err := evaluateNodeValue(node.Right, data)
			if err != nil {
				return 0, err
			}
			holds, err := compareValues(lhs, rhs, node.Value)
			if err != nil {
				return 0, err
			}
			if holds {
				return 1.0, nil
			}
			return 0.0, nil
		}

		// Arithmetic works on numeric operands.
		lhs, err := evaluateNode(node.Left, data)
		if err != nil {
			return 0, err
		}
		rhs, err := evaluateNode(node.Right, data)
		if err != nil {
			return 0, err
		}
		switch node.Value {
		case "+":
			return lhs + rhs, nil
		case "-":
			return lhs - rhs, nil
		case "*":
			return lhs * rhs, nil
		case "/":
			if rhs == 0 {
				return 0, fmt.Errorf("division by zero")
			}
			return lhs / rhs, nil
		case "%":
			if rhs == 0 {
				return 0, fmt.Errorf("modulo by zero")
			}
			return math.Mod(lhs, rhs), nil
		case "^":
			return math.Pow(lhs, rhs), nil
		default:
			return 0, fmt.Errorf("unknown operator: %s", node.Value)
		}

	case TypeFunction:
		fn, registered := functions.Get(node.Value)
		if !registered {
			// Fall back to the built-in functions (backward compatibility).
			return evaluateBuiltinFunction(node, data)
		}

		// Arguments keep their original types for registered functions.
		args := make([]interface{}, 0, len(node.Args))
		for _, argNode := range node.Args {
			value, err := evaluateNodeValue(argNode, data)
			if err != nil {
				return 0, err
			}
			args = append(args, value)
		}

		result, err := fn.Execute(&functions.FunctionContext{Data: data}, args)
		if err != nil {
			return 0, err
		}

		switch r := result.(type) {
		case float64:
			return r, nil
		case float32:
			return float64(r), nil
		case int:
			return float64(r), nil
		case int32:
			return float64(r), nil
		case int64:
			return float64(r), nil
		case string:
			// Numeric strings yield their value, other strings their length.
			if num, err := strconv.ParseFloat(r, 64); err == nil {
				return num, nil
			}
			return float64(len(r)), nil
		case bool:
			if r {
				return 1.0, nil
			}
			return 0.0, nil
		default:
			return 0, fmt.Errorf("function %s returned unsupported type for numeric conversion: %T", node.Value, result)
		}

	case TypeCase:
		return evaluateCaseExpression(node, data)

	default:
		return 0, fmt.Errorf("unknown node type: %s", node.Type)
	}
}
// evaluateBuiltinFunction evaluates the legacy built-in functions, matched by
// case-insensitive name: abs, sqrt, sin, cos, tan, floor, ceil, round, pow,
// max, min, log, log10, exp and len. It is the fallback used when a function
// is not found in the function registry.
//
// The argument count is checked before any argument is evaluated. sqrt
// rejects negative input and log/log10 reject non-positive input. len
// formats its argument's original value with %v and returns the byte length
// of that text. An unknown name is an error.
func evaluateBuiltinFunction(node *ExprNode, data map[string]interface{}) (float64, error) {
	name := strings.ToLower(node.Value)

	// single checks that exactly one argument was given and evaluates it.
	// name is the lower-cased function name, so the message reads e.g.
	// "abs function requires exactly 1 argument".
	single := func() (float64, error) {
		if len(node.Args) != 1 {
			return 0, fmt.Errorf("%s function requires exactly 1 argument", name)
		}
		return evaluateNode(node.Args[0], data)
	}

	// apply runs a total one-argument math function on the sole argument.
	apply := func(f func(float64) float64) (float64, error) {
		x, err := single()
		if err != nil {
			return 0, err
		}
		return f(x), nil
	}

	// extreme evaluates the arguments left to right and keeps the one that
	// beats the current best according to better.
	extreme := func(better func(candidate, best float64) bool) (float64, error) {
		if len(node.Args) < 1 {
			return 0, fmt.Errorf("%s function requires at least 1 argument", name)
		}
		best, err := evaluateNode(node.Args[0], data)
		if err != nil {
			return 0, err
		}
		for _, argNode := range node.Args[1:] {
			candidate, err := evaluateNode(argNode, data)
			if err != nil {
				return 0, err
			}
			if better(candidate, best) {
				best = candidate
			}
		}
		return best, nil
	}

	switch name {
	case "abs":
		return apply(math.Abs)
	case "sin":
		return apply(math.Sin)
	case "cos":
		return apply(math.Cos)
	case "tan":
		return apply(math.Tan)
	case "floor":
		return apply(math.Floor)
	case "ceil":
		return apply(math.Ceil)
	case "round":
		return apply(math.Round)
	case "exp":
		return apply(math.Exp)

	case "sqrt":
		x, err := single()
		if err != nil {
			return 0, err
		}
		if x < 0 {
			return 0, fmt.Errorf("sqrt of negative number")
		}
		return math.Sqrt(x), nil

	case "log":
		x, err := single()
		if err != nil {
			return 0, err
		}
		if x <= 0 {
			return 0, fmt.Errorf("log of non-positive number")
		}
		return math.Log(x), nil

	case "log10":
		x, err := single()
		if err != nil {
			return 0, err
		}
		if x <= 0 {
			return 0, fmt.Errorf("log10 of non-positive number")
		}
		return math.Log10(x), nil

	case "pow":
		if len(node.Args) != 2 {
			return 0, fmt.Errorf("pow function requires exactly 2 arguments")
		}
		base, err := evaluateNode(node.Args[0], data)
		if err != nil {
			return 0, err
		}
		exponent, err := evaluateNode(node.Args[1], data)
		if err != nil {
			return 0, err
		}
		return math.Pow(base, exponent), nil

	case "max":
		return extreme(func(candidate, best float64) bool { return candidate > best })
	case "min":
		return extreme(func(candidate, best float64) bool { return candidate < best })

	case "len":
		if len(node.Args) != 1 {
			return 0, fmt.Errorf("len function requires exactly 1 argument")
		}
		// The original (non-numeric) value is needed here.
		value, err := evaluateNodeValue(node.Args[0], data)
		if err != nil {
			return 0, err
		}
		return float64(len(fmt.Sprintf("%v", value))), nil

	default:
		return 0, fmt.Errorf("unknown function: %s", node.Value)
	}
}
// evaluateCaseExpression evaluates a CASE node and returns the numeric value
// of the first matching THEN branch.
//
// In a simple CASE (CASE expr WHEN value THEN ...) a branch matches when its
// WHEN value equals the CASE operand. In a searched CASE (CASE WHEN cond THEN
// ...) a branch matches when its condition is true. When nothing matches the
// ELSE expression is used; without an ELSE the result is 0 (SQL would yield
// NULL here).
func evaluateCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) {
	if node.Type != TypeCase {
		return 0, fmt.Errorf("node is not a CASE expression")
	}

	simple := node.CaseExpr != nil

	// The operand of a simple CASE is evaluated once, up front.
	var operand interface{}
	if simple {
		value, err := evaluateNodeValue(node.CaseExpr, data)
		if err != nil {
			return 0, err
		}
		operand = value
	}

	for i := range node.WhenClauses {
		clause := &node.WhenClauses[i]

		var (
			matched bool
			err     error
		)
		if simple {
			var candidate interface{}
			candidate, err = evaluateNodeValue(clause.Condition, data)
			if err == nil {
				matched, err = compareValues(operand, candidate, "==")
			}
		} else {
			matched, err = evaluateBooleanCondition(clause.Condition, data)
		}
		if err != nil {
			return 0, err
		}
		if matched {
			return evaluateNode(clause.Result, data)
		}
	}

	if node.ElseExpr != nil {
		return evaluateNode(node.ElseExpr, data)
	}
	return 0, nil
}
// evaluateBooleanCondition evaluates node as a truth value.
//
// AND and OR are evaluated with short-circuiting, IS is delegated to
// evaluateIsCondition, and every other operator is handed to compareValues
// with the original operand values. A non-operator node is evaluated
// numerically and is true when the result is non-zero. A nil node is an error.
func evaluateBooleanCondition(node *ExprNode, data map[string]interface{}) (bool, error) {
	if node == nil {
		return false, fmt.Errorf("null condition expression")
	}

	if node.Type != TypeOperator {
		number, err := evaluateNode(node, data)
		if err != nil {
			return false, err
		}
		return number != 0, nil
	}

	switch node.Value {
	case "AND", "OR":
		left, err := evaluateBooleanCondition(node.Left, data)
		if err != nil {
			return false, err
		}
		// Short circuit: a false left side decides AND, a true one decides OR.
		if (node.Value == "AND" && !left) || (node.Value == "OR" && left) {
			return left, nil
		}
		// Otherwise the right side alone determines the outcome.
		return evaluateBooleanCondition(node.Right, data)

	case "IS":
		return evaluateIsCondition(node, data)
	}

	lhs, err := evaluateNodeValue(node.Left, data)
	if err != nil {
		return false, err
	}
	rhs, err := evaluateNodeValue(node.Right, data)
	if err != nil {
		return false, err
	}
	return compareValues(lhs, rhs, node.Value)
}
// evaluateIsCondition evaluates an IS operator node.
//
// "x IS NULL" is true when x evaluates to nil and "x IS NOT NULL" when it
// does not; a left side that cannot be evaluated (for example a missing
// field) is treated as nil. Any other right-hand side is compared with the
// left value for equality. A node without both operands is an error.
func evaluateIsCondition(node *ExprNode, data map[string]interface{}) (bool, error) {
	if node == nil || node.Left == nil || node.Right == nil {
		return false, fmt.Errorf("invalid IS condition")
	}

	subject, err := evaluateNodeValue(node.Left, data)
	if err != nil {
		// An unresolvable left side counts as null rather than as a failure.
		subject = nil
	}

	// isNullKeyword recognises the NULL keyword, which is parsed as a field.
	isNullKeyword := func(n *ExprNode) bool {
		return n != nil && n.Type == TypeField && strings.ToUpper(n.Value) == "NULL"
	}

	rhs := node.Right
	switch {
	case isNullKeyword(rhs): // IS NULL
		return subject == nil, nil
	case rhs.Type == TypeOperator && rhs.Value == "NOT" && isNullKeyword(rhs.Right): // IS NOT NULL
		return subject != nil, nil
	}

	// Other forms (IS TRUE, IS FALSE, ...) are not specially supported and
	// are treated as an equality test.
	other, err := evaluateNodeValue(rhs, data)
	if err != nil {
		return false, err
	}
	return compareValues(subject, other, "==")
}
// evaluateNodeValue evaluates node while keeping the original type of the
// value: numbers yield float64, string literals yield their text without the
// surrounding quotes, and fields yield whatever is stored in data. All other
// node types are evaluated numerically through evaluateNode. A nil node or a
// missing field is an error.
func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
	if node == nil {
		return nil, fmt.Errorf("null expression node")
	}

	if node.Type == TypeNumber {
		return strconv.ParseFloat(node.Value, 64)
	}

	if node.Type == TypeString {
		text := node.Value
		if len(text) >= 2 && (text[0] == '\'' || text[0] == '"') {
			text = text[1 : len(text)-1] // drop the surrounding quotes
		}
		return text, nil
	}

	if node.Type != TypeField {
		// Everything else falls back to numeric evaluation.
		return evaluateNode(node, data)
	}

	name := node.Value
	if n := len(name); n >= 2 && name[0] == '`' && name[n-1] == '`' {
		name = name[1 : n-1] // drop the backticks of a quoted identifier
	}

	var (
		value interface{}
		found bool
	)
	if fieldpath.IsNestedField(name) {
		value, found = fieldpath.GetNestedField(data, name)
	} else {
		value, found = data[name]
	}
	if !found {
		return nil, fmt.Errorf("field '%s' not found", name)
	}
	return value, nil
}
// compareValues applies the comparison operator to left and right.
//
// When both operands are strings they are compared as strings (byte-wise
// ordering) and LIKE is available. Otherwise both operands are converted to
// float64; equality and inequality then use an absolute tolerance of 1e-9.
// Operands that cannot be converted, or an operator that is not supported for
// the operand kind, yield an error.
func compareValues(left, right interface{}, operator string) (bool, error) {
	leftText, leftIsText := left.(string)
	rightText, rightIsText := right.(string)

	if leftIsText && rightIsText {
		switch operator {
		case "LIKE":
			return matchesLikePattern(leftText, rightText), nil
		case "==", "=":
			return leftText == rightText, nil
		case "!=", "<>":
			return leftText != rightText, nil
		case "<":
			return leftText < rightText, nil
		case "<=":
			return leftText <= rightText, nil
		case ">":
			return leftText > rightText, nil
		case ">=":
			return leftText >= rightText, nil
		}
		return false, fmt.Errorf("unsupported string comparison operator: %s", operator)
	}

	leftNum, leftErr := convertToFloat(left)
	rightNum, rightErr := convertToFloat(right)
	if leftErr != nil || rightErr != nil {
		return false, fmt.Errorf("cannot compare values: %v and %v", left, right)
	}

	const tolerance = 1e-9
	switch operator {
	case "==", "=":
		return math.Abs(leftNum-rightNum) < tolerance, nil
	case "!=", "<>":
		// Deliberately not written as the negation of the equality test: the
		// two differ when the difference is not a number.
		return math.Abs(leftNum-rightNum) >= tolerance, nil
	case "<":
		return leftNum < rightNum, nil
	case "<=":
		return leftNum <= rightNum, nil
	case ">":
		return leftNum > rightNum, nil
	case ">=":
		return leftNum >= rightNum, nil
	}
	return false, fmt.Errorf("unsupported comparison operator: %s", operator)
}
// matchesLikePattern reports whether text matches the SQL LIKE pattern.
// In the pattern, '%' matches any sequence of characters (including none)
// and '_' matches exactly one character; every other character matches
// itself. The whole text must be matched.
func matchesLikePattern(text, pattern string) bool {
	return likeMatch(text, pattern, 0, 0)
}

// likeMatch reports whether text[textIndex:] matches pattern[patternIndex:]
// under LIKE rules. Both indexes are byte offsets; an offset beyond the end
// of its string is treated as the end.
//
// Matching is done on characters (runes) rather than bytes, so '_' consumes
// one full character even when it is encoded as several bytes. The algorithm
// is iterative: it remembers the most recent '%' and, on a mismatch, lets
// that '%' absorb one more character and retries. This bounds the work by
// len(text) * len(pattern) instead of growing exponentially with the number
// of '%' wildcards.
func likeMatch(text, pattern string, textIndex, patternIndex int) bool {
	if textIndex > len(text) {
		textIndex = len(text)
	}
	if patternIndex > len(pattern) {
		patternIndex = len(pattern)
	}
	t := []rune(text[textIndex:])
	p := []rune(pattern[patternIndex:])

	ti, pi := 0, 0
	lastPercent := -1 // position in p of the most recent '%', or -1
	resumeAt := 0     // position in t where that '%' currently stops matching

	for ti < len(t) {
		switch {
		case pi < len(p) && p[pi] == '%':
			// Start by letting '%' match nothing; extend on later mismatch.
			lastPercent = pi
			resumeAt = ti
			pi++
		case pi < len(p) && (p[pi] == '_' || p[pi] == t[ti]):
			ti++
			pi++
		case lastPercent >= 0:
			// Mismatch: let the last '%' swallow one more character.
			resumeAt++
			ti = resumeAt
			pi = lastPercent + 1
		default:
			return false
		}
	}

	// Text is exhausted; only trailing '%' may remain in the pattern.
	for pi < len(p) && p[pi] == '%' {
		pi++
	}
	return pi == len(p)
}
// convertToFloat converts val to float64.
//
// Supported inputs are float32/float64, every signed and unsigned integer
// type, bool (true is 1, false is 0) and strings holding a number in the
// syntax accepted by strconv.ParseFloat. A NaN value, whether given as a
// float or produced by parsing a string, is rejected with an error, as is
// any unsupported type (including nil).
func convertToFloat(val interface{}) (float64, error) {
	var f float64

	switch v := val.(type) {
	case float64:
		f = v
	case float32:
		f = float64(v)
	case int:
		return float64(v), nil
	case int8:
		return float64(v), nil
	case int16:
		return float64(v), nil
	case int32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	case uint:
		return float64(v), nil
	case uint8:
		return float64(v), nil
	case uint16:
		return float64(v), nil
	case uint32:
		return float64(v), nil
	case uint64:
		return float64(v), nil
	case bool:
		if v {
			return 1.0, nil
		}
		return 0.0, nil
	case string:
		parsed, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return 0, err
		}
		f = parsed
	default:
		return 0, fmt.Errorf("cannot convert %T to float64", val)
	}

	// Only the floating-point and string paths can carry NaN.
	if math.IsNaN(f) {
		return 0, fmt.Errorf("NaN value detected")
	}
	return f, nil
}
// tokenize converts expression string to token list
func tokenize(expr string) ([]string, error) {
expr = strings.TrimSpace(expr)
if expr == "" {
return nil, fmt.Errorf("empty expression")
}
tokens := make([]string, 0)
i := 0
for i < len(expr) {
ch := expr[i]
// Skip whitespace characters
if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
i++
continue
}
// Handle numbers
if isDigit(ch) || (ch == '.' && i+1 < len(expr) && isDigit(expr[i+1])) {
start := i
hasDot := ch == '.'
i++
for i < len(expr) && (isDigit(expr[i]) || (expr[i] == '.' && !hasDot)) {
if expr[i] == '.' {
hasDot = true
}
i++
}
tokens = append(tokens, expr[start:i])
continue
}
// Handle operators and parentheses
if ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '%' || ch == '^' ||
ch == '(' || ch == ')' || ch == ',' {
// Special handling for minus sign: if it's minus and preceded by operator, parenthesis or start position, it might be negative number
if ch == '-' {
// Check if it could be the start of a negative number
canBeNegativeNumber := i == 0 || // Expression start
len(tokens) == 0 // When tokens is empty, it could also be negative number start
// Only check previous token when tokens is not empty
if len(tokens) > 0 {
prevToken := tokens[len(tokens)-1]
canBeNegativeNumber = canBeNegativeNumber ||
prevToken == "(" || // After left parenthesis
prevToken == "," || // After comma (function parameter)
isOperator(prevToken) || // After operator
isComparisonOperator(prevToken) || // After comparison operator
strings.ToUpper(prevToken) == "THEN" || // After THEN
strings.ToUpper(prevToken) == "ELSE" || // After ELSE
strings.ToUpper(prevToken) == "WHEN" || // After WHEN
strings.ToUpper(prevToken) == "AND" || // After AND
strings.ToUpper(prevToken) == "OR" // After OR
}
if canBeNegativeNumber && i+1 < len(expr) && isDigit(expr[i+1]) {
// This is a negative number, parse the entire number
start := i
i++ // Skip minus sign
// Parse numeric part
for i < len(expr) && (isDigit(expr[i]) || expr[i] == '.') {
i++
}
tokens = append(tokens, expr[start:i])
continue
}
}
tokens = append(tokens, string(ch))
i++
continue
}
// Handle comparison operators
if ch == '>' || ch == '<' || ch == '=' || ch == '!' {
start := i
i++
// Handle two-character operators
if i < len(expr) {
switch ch {
case '>':
if expr[i] == '=' {
i++
tokens = append(tokens, ">=")
continue
}
case '<':
if expr[i] == '=' {
i++
tokens = append(tokens, "<=")
continue
} else if expr[i] == '>' {
i++
tokens = append(tokens, "<>")
continue
}
case '=':
if expr[i] == '=' {
i++
tokens = append(tokens, "==")
continue
}
case '!':
if expr[i] == '=' {
i++
tokens = append(tokens, "!=")
continue
}
}
}
// Single character operator
tokens = append(tokens, expr[start:i])
continue
}
// Handle string literals (single and double quotes)
if ch == '\'' || ch == '"' {
quote := ch
start := i
i++ // Skip opening quote
// Find closing quote
for i < len(expr) && expr[i] != quote {
if expr[i] == '\\' && i+1 < len(expr) {
i += 2 // Skip escape character
} else {
i++
}
}
if i >= len(expr) {
return nil, fmt.Errorf("unterminated string literal starting at position %d", start)
}
i++ // Skip closing quote
tokens = append(tokens, expr[start:i])
continue
}
// Handle backtick identifiers
if ch == '`' {
start := i
i++ // Skip opening backtick
// Find closing backtick
for i < len(expr) && expr[i] != '`' {
i++
}
if i >= len(expr) {
return nil, fmt.Errorf("unterminated quoted identifier starting at position %d", start)
}
i++ // Skip closing backtick
tokens = append(tokens, expr[start:i])
continue
}
// Handle identifiers (field names or function names)
if isLetter(ch) {
start := i
i++
for i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_') {
i++
}
tokens = append(tokens, expr[start:i])
continue
}
// Unknown character
return nil, fmt.Errorf("unexpected character: %c at position %d", ch, i)
}
return tokens, nil
}
// parseExpression turns a flat token list into an expression tree.
//
// It follows the shunting-yard scheme: operands accumulate on one stack,
// operators wait on another, and a waiting operator is folded together with
// the two topmost operands into a TypeOperator node as soon as an incoming
// operator does not bind more tightly. Precedence is looked up in
// operatorPrecedence; operators of equal precedence fold left to right.
//
// An error is returned for an empty token list, unbalanced parentheses, an
// operator that lacks two operands, an unrecognized token, or tokens that do
// not reduce to exactly one tree.
func parseExpression(tokens []string) (*ExprNode, error) {
	if len(tokens) == 0 {
		return nil, fmt.Errorf("empty token list")
	}

	var (
		operands []*ExprNode // operand stack (partial trees)
		pending  []string    // operator stack, including "(" markers
	)

	// fold pops the top pending operator and replaces the two topmost
	// operands with a single operator node.
	fold := func() error {
		op := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		n := len(operands)
		if n < 2 {
			return fmt.Errorf("not enough operands for operator: %s", op)
		}
		lhs, rhs := operands[n-2], operands[n-1]
		operands = append(operands[:n-2], &ExprNode{
			Type:  TypeOperator,
			Value: op,
			Left:  lhs,
			Right: rhs,
		})
		return nil
	}

	// foldFor folds every pending operator that binds at least as tightly as
	// incoming, stopping at an open parenthesis.
	foldFor := func(incoming string) error {
		for len(pending) > 0 {
			top := pending[len(pending)-1]
			if top == "(" || operatorPrecedence[top] < operatorPrecedence[incoming] {
				break
			}
			if err := fold(); err != nil {
				return err
			}
		}
		return nil
	}

	for pos := 0; pos < len(tokens); {
		tok := tokens[pos]

		switch {
		case isNumber(tok):
			operands = append(operands, &ExprNode{Type: TypeNumber, Value: tok})
			pos++

		case isStringLiteral(tok):
			operands = append(operands, &ExprNode{Type: TypeString, Value: tok})
			pos++

		case isIdentifier(tok):
			keyword := strings.ToUpper(tok)

			switch {
			case keyword == "AND" || keyword == "OR" || keyword == "NOT" || keyword == "LIKE":
				// Keyword operators are stored in upper case.
				if err := foldFor(keyword); err != nil {
					return nil, err
				}
				pending = append(pending, keyword)
				pos++

			case keyword == "IS":
				if err := foldFor("IS"); err != nil {
					return nil, err
				}
				pending = append(pending, "IS")

				switch {
				case pos+2 < len(tokens) &&
					strings.ToUpper(tokens[pos+1]) == "NOT" &&
					strings.ToUpper(tokens[pos+2]) == "NULL":
					// IS NOT NULL: the right operand is a NOT node wrapping NULL.
					operands = append(operands, &ExprNode{
						Type:  TypeOperator,
						Value: "NOT",
						Right: &ExprNode{Type: TypeField, Value: "NULL"},
					})
					pos += 3
				case pos+1 < len(tokens) && strings.ToUpper(tokens[pos+1]) == "NULL":
					// IS NULL: the right operand is a NULL field node.
					operands = append(operands, &ExprNode{Type: TypeField, Value: "NULL"})
					pos += 2
				default:
					// Plain IS: the right operand comes from the following tokens.
					pos++
				}

			case keyword == "CASE":
				caseNode, next, err := parseCaseExpression(tokens, pos)
				if err != nil {
					return nil, err
				}
				operands = append(operands, caseNode)
				pos = next

			case pos+1 < len(tokens) && tokens[pos+1] == "(":
				// Identifier directly followed by "(" is a function call;
				// argument parsing starts after the parenthesis.
				args, next, err := parseFunctionArgs(tokens, pos+2)
				if err != nil {
					return nil, err
				}
				operands = append(operands, &ExprNode{
					Type:  TypeFunction,
					Value: tok,
					Args:  args,
				})
				pos = next

			default:
				operands = append(operands, &ExprNode{Type: TypeField, Value: tok})
				pos++
			}

		case tok == "(":
			pending = append(pending, tok)
			pos++

		case tok == ")":
			for len(pending) > 0 && pending[len(pending)-1] != "(" {
				if err := fold(); err != nil {
					return nil, err
				}
			}
			if len(pending) == 0 || pending[len(pending)-1] != "(" {
				return nil, fmt.Errorf("mismatched parentheses")
			}
			pending = pending[:len(pending)-1] // drop the matching "("
			pos++

		case isOperator(tok):
			if err := foldFor(tok); err != nil {
				return nil, err
			}
			pending = append(pending, tok)
			pos++

		case tok == ",":
			// Commas are consumed by parseFunctionArgs; one seen here is skipped.
			pos++

		default:
			return nil, fmt.Errorf("unexpected token: %s", tok)
		}
	}

	// Drain whatever operators are still waiting.
	for len(pending) > 0 {
		if pending[len(pending)-1] == "(" {
			return nil, fmt.Errorf("mismatched parentheses")
		}
		if err := fold(); err != nil {
			return nil, err
		}
	}

	if len(operands) != 1 {
		return nil, fmt.Errorf("invalid expression")
	}
	return operands[0], nil
}
// parseFunctionArgs parses a comma-separated argument list.
//
// startIndex must point at the first token after the opening parenthesis.
// Each argument is the run of tokens up to the next top-level comma or the
// closing parenthesis of the call and is parsed with parseExpression; empty
// runs are skipped. On success it returns the parsed arguments and the index
// of the token following the closing parenthesis. If the token list ends
// before the closing parenthesis an error is returned.
func parseFunctionArgs(tokens []string, startIndex int) ([]*ExprNode, int, error) {
	args := []*ExprNode{}
	pos := startIndex

	// "f()" has no arguments at all.
	if pos < len(tokens) && tokens[pos] == ")" {
		return args, pos + 1, nil
	}

	for pos < len(tokens) {
		begin := pos
		depth := 0 // parentheses opened inside the current argument

	scan:
		for ; pos < len(tokens); pos++ {
			switch tokens[pos] {
			case "(":
				depth++
			case ")":
				if depth == 0 {
					break scan // closing parenthesis of the call itself
				}
				depth--
			case ",":
				if depth == 0 {
					break scan // separator between arguments
				}
			}
		}

		if pos > begin {
			arg, err := parseExpression(tokens[begin:pos])
			if err != nil {
				return nil, 0, err
			}
			args = append(args, arg)
		}

		if pos >= len(tokens) {
			return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
		}

		switch tokens[pos] {
		case ")":
			return args, pos + 1, nil
		case ",":
			pos++
		default:
			return nil, 0, fmt.Errorf("unexpected token in function arguments: %s", tokens[pos])
		}
	}

	return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
}
// parseCaseExpression parses a CASE expression starting at tokens[startIndex],
// which must be the CASE keyword (matched case-insensitively).
//
// Both forms are supported:
//
//	CASE expr WHEN value THEN result ... [ELSE result] END   (simple)
//	CASE WHEN condition THEN result ... [ELSE result] END    (searched)
//
// Operands may themselves contain complete CASE ... END expressions: while
// collecting an operand, the keywords of a nested CASE are not treated as
// delimiters of the enclosing one.
//
// It returns the TypeCase node and the index of the token after END. An error
// is returned when a keyword or operand is missing, when an operand fails to
// parse, or when there is no WHEN clause.
func parseCaseExpression(tokens []string, startIndex int) (*ExprNode, int, error) {
	if startIndex >= len(tokens) || strings.ToUpper(tokens[startIndex]) != "CASE" {
		return nil, startIndex, fmt.Errorf("expected CASE keyword")
	}

	caseNode := &ExprNode{
		Type:        TypeCase,
		WhenClauses: make([]WhenClause, 0),
	}

	i := startIndex + 1 // skip the CASE keyword

	// collectUntil gathers tokens starting at i until one of the stop keywords
	// is met outside any nested CASE ... END, leaving i on that keyword (or at
	// len(tokens) if none is found).
	collectUntil := func(stops ...string) []string {
		var collected []string
		depth := 0 // number of nested CASE expressions currently open
		for ; i < len(tokens); i++ {
			upper := strings.ToUpper(tokens[i])
			if depth == 0 {
				stop := false
				for _, kw := range stops {
					if upper == kw {
						stop = true
						break
					}
				}
				if stop {
					break
				}
			}
			switch upper {
			case "CASE":
				depth++
			case "END":
				if depth > 0 {
					depth--
				}
			}
			collected = append(collected, tokens[i])
		}
		return collected
	}

	// Anything between CASE and the first WHEN is the subject of a simple CASE.
	if i < len(tokens) && strings.ToUpper(tokens[i]) != "WHEN" {
		caseExprTokens := collectUntil("WHEN")

		if len(caseExprTokens) == 0 {
			return nil, i, fmt.Errorf("expected expression after CASE")
		}

		if len(caseExprTokens) == 1 {
			// A single token is classified directly.
			token := caseExprTokens[0]
			if isNumber(token) {
				caseNode.CaseExpr = &ExprNode{Type: TypeNumber, Value: token}
			} else if isStringLiteral(token) {
				caseNode.CaseExpr = &ExprNode{Type: TypeString, Value: token}
			} else if isIdentifier(token) {
				caseNode.CaseExpr = &ExprNode{Type: TypeField, Value: token}
			} else {
				return nil, i, fmt.Errorf("invalid CASE expression token: %s", token)
			}
		} else {
			caseExpr, err := parseExpression(caseExprTokens)
			if err != nil {
				return nil, i, fmt.Errorf("failed to parse CASE expression: %w", err)
			}
			caseNode.CaseExpr = caseExpr
		}
	}

	// WHEN ... THEN ... clauses.
	for i < len(tokens) && strings.ToUpper(tokens[i]) == "WHEN" {
		i++ // skip WHEN

		conditionTokens := collectUntil("THEN")
		if len(conditionTokens) == 0 {
			return nil, i, fmt.Errorf("expected condition after WHEN")
		}

		if i >= len(tokens) || strings.ToUpper(tokens[i]) != "THEN" {
			return nil, i, fmt.Errorf("expected THEN after WHEN condition")
		}
		i++ // skip THEN

		resultTokens := collectUntil("WHEN", "ELSE", "END")
		if len(resultTokens) == 0 {
			return nil, i, fmt.Errorf("expected result after THEN")
		}

		conditionExpr, err := parseExpression(conditionTokens)
		if err != nil {
			return nil, i, fmt.Errorf("failed to parse WHEN condition: %w", err)
		}

		resultExpr, err := parseExpression(resultTokens)
		if err != nil {
			return nil, i, fmt.Errorf("failed to parse THEN result: %w", err)
		}

		caseNode.WhenClauses = append(caseNode.WhenClauses, WhenClause{
			Condition: conditionExpr,
			Result:    resultExpr,
		})
	}

	// Optional ELSE clause.
	if i < len(tokens) && strings.ToUpper(tokens[i]) == "ELSE" {
		i++ // skip ELSE

		elseTokens := collectUntil("END")
		if len(elseTokens) == 0 {
			return nil, i, fmt.Errorf("expected result after ELSE")
		}

		elseExpr, err := parseExpression(elseTokens)
		if err != nil {
			return nil, i, fmt.Errorf("failed to parse ELSE result: %w", err)
		}
		caseNode.ElseExpr = elseExpr
	}

	// Closing END keyword.
	if i >= len(tokens) || strings.ToUpper(tokens[i]) != "END" {
		return nil, i, fmt.Errorf("expected END to close CASE expression")
	}
	i++ // skip END

	if len(caseNode.WhenClauses) == 0 {
		return nil, i, fmt.Errorf("CASE expression must have at least one WHEN clause")
	}

	return caseNode, i, nil
}
// Helper functions

// isDigit reports whether ch is an ASCII decimal digit.
func isDigit(ch byte) bool {
	if ch < '0' {
		return false
	}
	return ch <= '9'
}
// isLetter reports whether ch is an ASCII letter (a-z or A-Z).
func isLetter(ch byte) bool {
	// Setting bit 0x20 maps 'A'-'Z' onto 'a'-'z' and maps no other byte into
	// that range, so a single range test covers both cases.
	lower := ch | 0x20
	return 'a' <= lower && lower <= 'z'
}
// isNumber reports whether s is a numeric literal.
//
// strconv.ParseFloat also accepts the words "inf", "infinity" and "nan" (in
// any letter case, optionally signed). Those spellings are valid identifiers,
// so accepting them would turn a field with such a name into a number
// constant. A literal is therefore required to contain at least one decimal
// digit before ParseFloat is consulted.
func isNumber(s string) bool {
	if !strings.ContainsAny(s, "0123456789") {
		return false
	}
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}
// isIdentifier reports whether s is a non-empty name made of ASCII letters,
// digits and underscores that does not start with a digit.
func isIdentifier(s string) bool {
	for idx := 0; idx < len(s); idx++ {
		c := s[idx]
		switch {
		case isLetter(c), c == '_':
			// allowed in every position
		case idx > 0 && isDigit(c):
			// digits are allowed everywhere except the first position
		default:
			return false
		}
	}
	return len(s) > 0
}
// isOperator reports whether s is an arithmetic, comparison or keyword
// operator token. Keyword operators match in upper case only.
func isOperator(s string) bool {
	switch s {
	case "+", "-", "*", "/", "%", "^",
		">", "<", ">=", "<=", "==", "=", "!=", "<>",
		"AND", "OR", "NOT",
		"LIKE", "IS":
		return true
	}
	return false
}
// isComparisonOperator reports whether s is one of the comparison operators
// >, <, >=, <=, ==, =, != or <>.
func isComparisonOperator(s string) bool {
	for _, op := range [...]string{">", "<", ">=", "<=", "==", "=", "!=", "<>"} {
		if s == op {
			return true
		}
	}
	return false
}
// isStringLiteral reports whether expr is at least two bytes long and both
// starts and ends with the same quote character (single or double).
func isStringLiteral(expr string) bool {
	if len(expr) < 2 {
		return false
	}
	opening, closing := expr[0], expr[len(expr)-1]
	if opening != '\'' && opening != '"' {
		return false
	}
	return opening == closing
}
// evaluateNodeWithNull evaluates node to a number while tracking NULL.
//
// It returns (value, isNull, err):
//   - a nil node is NULL;
//   - a number constant is parsed with strconv.ParseFloat;
//   - a string constant yields its length in bytes after the surrounding
//     quotes are removed, except that the text NULL (any letter case) is NULL;
//   - a field is looked up in data (nested paths go through fieldpath); a
//     missing or nil field is NULL, and a value that cannot be converted to a
//     number is an error;
//   - operator, function and CASE nodes are delegated to their evaluators.
//
// Backtick-quoted field names have their backticks removed before the lookup,
// matching evaluateNodeValueWithNull, so both evaluators resolve the same
// field for the same node.
func evaluateNodeWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
	if node == nil {
		return 0, true, nil // NULL
	}

	switch node.Type {
	case TypeNumber:
		val, err := strconv.ParseFloat(node.Value, 64)
		return val, false, err

	case TypeString:
		// Strip the surrounding quotes.
		value := node.Value
		if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
			value = value[1 : len(value)-1]
		}
		// The text NULL is treated as the NULL value.
		if strings.ToUpper(value) == "NULL" {
			return 0, true, nil
		}
		// A string used numerically is represented by its length.
		return float64(len(value)), false, nil

	case TypeField:
		// Remove backticks from a quoted identifier.
		fieldName := node.Value
		if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
			fieldName = fieldName[1 : len(fieldName)-1]
		}

		var fieldVal interface{}
		var found bool
		if fieldpath.IsNestedField(fieldName) {
			fieldVal, found = fieldpath.GetNestedField(data, fieldName)
		} else {
			fieldVal, found = data[fieldName]
		}

		if !found || fieldVal == nil {
			return 0, true, nil // NULL
		}

		if val, err := convertToFloat(fieldVal); err == nil {
			return val, false, nil
		}
		return 0, true, fmt.Errorf("cannot convert field '%s' to number", node.Value)

	case TypeOperator:
		return evaluateOperatorWithNull(node, data)

	case TypeFunction:
		// Function results are never reported as NULL here.
		result, err := evaluateBuiltinFunction(node, data)
		return result, false, err

	case TypeCase:
		return evaluateCaseExpressionWithNull(node, data)

	default:
		return 0, true, fmt.Errorf("unsupported node type: %s", node.Type)
	}
}
// evaluateOperatorWithNull evaluates a binary operator node numerically with
// NULL tracking and returns (value, isNull, err).
//
// Both operands are always evaluated, left first. Arithmetic with a NULL
// operand is NULL, as is division or modulo by zero. Comparisons yield 1 or
// 0: NULL equals only NULL, and every other comparison involving NULL is 0.
// Any other operator is an error.
func evaluateOperatorWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
	lhs, lhsNull, err := evaluateNodeWithNull(node.Left, data)
	if err != nil {
		return 0, false, err
	}
	rhs, rhsNull, err := evaluateNodeWithNull(node.Right, data)
	if err != nil {
		return 0, false, err
	}

	anyNull := lhsNull || rhsNull

	// truth encodes a boolean as 1 or 0.
	truth := func(ok bool) float64 {
		if ok {
			return 1
		}
		return 0
	}

	switch op := node.Value; op {
	case "+", "-", "*", "/", "%", "^":
		if anyNull {
			return 0, true, nil
		}
		switch op {
		case "+":
			return lhs + rhs, false, nil
		case "-":
			return lhs - rhs, false, nil
		case "*":
			return lhs * rhs, false, nil
		case "/":
			if rhs == 0 {
				return 0, true, nil // division by zero is NULL
			}
			return lhs / rhs, false, nil
		case "%":
			if rhs == 0 {
				return 0, true, nil
			}
			return math.Mod(lhs, rhs), false, nil
		default: // "^"
			return math.Pow(lhs, rhs), false, nil
		}

	case "==", "=":
		if anyNull {
			// NULL = NULL is true; NULL = value is false.
			return truth(lhsNull && rhsNull), false, nil
		}
		return truth(lhs == rhs), false, nil

	case "!=", "<>":
		if anyNull {
			// Inequality involving NULL is always false.
			return 0, false, nil
		}
		return truth(lhs != rhs), false, nil

	case ">":
		return truth(!anyNull && lhs > rhs), false, nil
	case "<":
		return truth(!anyNull && lhs < rhs), false, nil
	case ">=":
		return truth(!anyNull && lhs >= rhs), false, nil
	case "<=":
		return truth(!anyNull && lhs <= rhs), false, nil

	default:
		return 0, false, fmt.Errorf("unsupported operator: %s", op)
	}
}
// evaluateCaseExpressionWithNull evaluates a CASE node to a number with NULL
// tracking and returns (value, isNull, err).
//
// For a simple CASE (CaseExpr set) the subject is compared for equality with
// each WHEN value in order, NULL matching only NULL. For a searched CASE each
// WHEN condition is evaluated as a boolean. The result of the first matching
// clause is returned; otherwise the ELSE expression, or NULL when there is no
// ELSE.
func evaluateCaseExpressionWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
	if node.Type != TypeCase {
		return 0, false, fmt.Errorf("node is not a CASE expression")
	}

	if node.CaseExpr == nil {
		// Searched form: CASE WHEN condition THEN result ...
		for idx := range node.WhenClauses {
			clause := &node.WhenClauses[idx]
			matched, err := evaluateBooleanConditionWithNull(clause.Condition, data)
			if err != nil {
				return 0, false, err
			}
			if matched {
				return evaluateNodeWithNull(clause.Result, data)
			}
		}
	} else {
		// Simple form: CASE subject WHEN value THEN result ...
		subject, subjectNull, err := evaluateNodeValueWithNull(node.CaseExpr, data)
		if err != nil {
			return 0, false, err
		}

		for idx := range node.WhenClauses {
			clause := &node.WhenClauses[idx]
			candidate, candidateNull, err := evaluateNodeValueWithNull(clause.Condition, data)
			if err != nil {
				return 0, false, err
			}

			// NULL matches NULL; NULL never matches a value.
			matched := subjectNull && candidateNull
			if !subjectNull && !candidateNull {
				matched, err = compareValuesForEquality(subject, candidate)
				if err != nil {
					return 0, false, err
				}
			}

			if matched {
				return evaluateNodeWithNull(clause.Result, data)
			}
		}
	}

	// No clause matched: use ELSE, or NULL without one.
	if node.ElseExpr == nil {
		return 0, true, nil
	}
	return evaluateNodeWithNull(node.ElseExpr, data)
}
// evaluateCaseExpressionValueWithNull evaluates a CASE node and returns the
// selected result as an untyped value (so string results are kept as strings)
// together with a NULL flag: (value, isNull, err).
//
// Clause selection works as in evaluateCaseExpressionWithNull: a simple CASE
// compares the subject with each WHEN value (NULL matching only NULL), a
// searched CASE evaluates each WHEN condition as a boolean. Without a match
// the ELSE expression is used, or NULL when there is no ELSE.
func evaluateCaseExpressionValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) {
	if node.Type != TypeCase {
		return nil, false, fmt.Errorf("node is not a CASE expression")
	}

	if node.CaseExpr == nil {
		// Searched form: CASE WHEN condition THEN result ...
		for idx := range node.WhenClauses {
			clause := &node.WhenClauses[idx]
			matched, err := evaluateBooleanConditionWithNull(clause.Condition, data)
			if err != nil {
				return nil, false, err
			}
			if matched {
				return evaluateNodeValueWithNull(clause.Result, data)
			}
		}
	} else {
		// Simple form: CASE subject WHEN value THEN result ...
		subject, subjectNull, err := evaluateNodeValueWithNull(node.CaseExpr, data)
		if err != nil {
			return nil, false, err
		}

		for idx := range node.WhenClauses {
			clause := &node.WhenClauses[idx]
			candidate, candidateNull, err := evaluateNodeValueWithNull(clause.Condition, data)
			if err != nil {
				return nil, false, err
			}

			// NULL matches NULL; NULL never matches a value.
			matched := subjectNull && candidateNull
			if !subjectNull && !candidateNull {
				matched, err = compareValuesForEquality(subject, candidate)
				if err != nil {
					return nil, false, err
				}
			}

			if matched {
				return evaluateNodeValueWithNull(clause.Result, data)
			}
		}
	}

	// No clause matched: use ELSE, or NULL without one.
	if node.ElseExpr == nil {
		return nil, true, nil
	}
	return evaluateNodeValueWithNull(node.ElseExpr, data)
}
// evaluateNodeValueWithNull evaluates node and returns its value without
// forcing it to a number, together with a NULL flag: (value, isNull, err).
//
// A nil node is NULL. Number constants are parsed as float64. String
// constants are returned without their quotes, except that the text NULL (any
// letter case) is NULL. Fields are looked up in data after removing
// surrounding backticks, with nested paths resolved through fieldpath; a
// missing or nil field is NULL. CASE nodes keep their result type. Every
// other node type falls back to numeric evaluation.
func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) {
	if node == nil {
		return nil, true, nil
	}

	switch node.Type {
	case TypeNumber:
		parsed, err := strconv.ParseFloat(node.Value, 64)
		return parsed, false, err

	case TypeString:
		text := node.Value
		if n := len(text); n >= 2 && (text[0] == '\'' || text[0] == '"') {
			text = text[1 : n-1] // drop the quotes
		}
		if strings.ToUpper(text) == "NULL" {
			return nil, true, nil
		}
		return text, false, nil

	case TypeField:
		name := node.Value
		if n := len(name); n >= 2 && name[0] == '`' && name[n-1] == '`' {
			name = name[1 : n-1] // drop the backticks
		}

		var (
			val   interface{}
			found bool
		)
		if fieldpath.IsNestedField(name) {
			val, found = fieldpath.GetNestedField(data, name)
		} else {
			val, found = data[name]
		}
		if !found {
			return nil, true, nil // a missing field counts as NULL
		}
		return val, val == nil, nil

	case TypeCase:
		return evaluateCaseExpressionValueWithNull(node, data)

	default:
		// Remaining node types are evaluated numerically.
		num, numNull, err := evaluateNodeWithNull(node, data)
		return num, numNull, err
	}
}
// evaluateBooleanConditionWithNull evaluates node as a boolean condition with
// NULL handling.
//
//   - AND / OR evaluate the left side first and short-circuit.
//   - IS is delegated to evaluateIsConditionWithNull.
//   - Comparison operators (see isComparisonOperator) compare the operand
//     values through compareValuesWithNull, so strings compare as strings.
//   - Everything else, including arithmetic operator nodes, function calls,
//     fields and constants, is evaluated numerically: the condition holds when
//     the result is neither NULL nor zero.
//
// Only comparison operators are routed to compareValuesWithNull; sending
// every operator node there made an arithmetic condition such as "a - b" fail
// with "unsupported operator" instead of reaching the numeric fallback.
//
// A nil node is an error.
func evaluateBooleanConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) {
	if node == nil {
		return false, fmt.Errorf("null condition expression")
	}

	if node.Type == TypeOperator {
		switch {
		case node.Value == "AND" || node.Value == "OR":
			leftBool, err := evaluateBooleanConditionWithNull(node.Left, data)
			if err != nil {
				return false, err
			}
			// Short circuit: a false left side decides AND.
			if node.Value == "AND" && !leftBool {
				return false, nil
			}
			// Short circuit: a true left side decides OR.
			if node.Value == "OR" && leftBool {
				return true, nil
			}
			// Otherwise the right side alone decides the outcome.
			return evaluateBooleanConditionWithNull(node.Right, data)

		case node.Value == "IS":
			return evaluateIsConditionWithNull(node, data)

		case isComparisonOperator(node.Value):
			leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data)
			if err != nil {
				return false, err
			}
			rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data)
			if err != nil {
				return false, err
			}
			return compareValuesWithNull(leftValue, leftNull, rightValue, rightNull, node.Value)
		}
	}

	// Numeric truthiness: NULL and zero are false, anything else is true.
	result, isNull, err := evaluateNodeWithNull(node, data)
	if err != nil {
		return false, err
	}
	return !isNull && result != 0, nil
}
// evaluateIsConditionWithNull evaluates an IS node: "x IS NULL",
// "x IS NOT NULL", or "x IS y" (NULL-aware equality).
//
// The right operand takes the shapes produced by parseExpression: a NULL
// field node for IS NULL, or a NOT operator node wrapping a NULL field node
// for IS NOT NULL. A failure to evaluate the left operand is not reported;
// the operand is treated as NULL instead. The node and both of its operands
// must be non-nil.
func evaluateIsConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) {
	if node == nil || node.Left == nil || node.Right == nil {
		return false, fmt.Errorf("invalid IS condition")
	}

	subject, subjectNull, err := evaluateNodeValueWithNull(node.Left, data)
	if err != nil {
		// An operand that cannot be evaluated counts as NULL.
		subject, subjectNull = nil, true
	}
	subjectIsNull := subjectNull || subject == nil

	rhs := node.Right
	switch {
	case rhs.Type == TypeField && strings.ToUpper(rhs.Value) == "NULL":
		// IS NULL
		return subjectIsNull, nil

	case rhs.Type == TypeOperator && rhs.Value == "NOT" &&
		rhs.Right != nil && rhs.Right.Type == TypeField &&
		strings.ToUpper(rhs.Right.Value) == "NULL":
		// IS NOT NULL
		return !subjectIsNull, nil
	}

	// Any other right operand: compare for equality.
	other, otherNull, err := evaluateNodeValueWithNull(rhs, data)
	if err != nil {
		return false, err
	}
	return compareValuesWithNullForEquality(subject, subjectNull, other, otherNull)
}
// compareValuesForEquality reports whether left and right are equal.
//
// Two strings are compared as strings. Otherwise, if both values convert to
// float64 via convertToFloat they are compared numerically. As a last resort
// the values are compared with ==.
//
// That last comparison panics at run time when both values have the same
// dynamic type and the type is not comparable (for example two slices or two
// maps). The panic is recovered and reported as an error so that such input
// cannot crash the caller.
func compareValuesForEquality(left, right interface{}) (equal bool, err error) {
	// String comparison.
	leftStr, leftIsStr := left.(string)
	rightStr, rightIsStr := right.(string)
	if leftIsStr && rightIsStr {
		return leftStr == rightStr, nil
	}

	// Numeric comparison.
	leftFloat, leftErr := convertToFloat(left)
	rightFloat, rightErr := convertToFloat(right)
	if leftErr == nil && rightErr == nil {
		return leftFloat == rightFloat, nil
	}

	// Registered only here so that it guards nothing but the interface
	// comparison below.
	defer func() {
		if r := recover(); r != nil {
			equal = false
			err = fmt.Errorf("cannot compare values of type %T and %T", left, right)
		}
	}()

	// Direct comparison of the remaining values.
	return left == right, nil
}
// compareValuesWithNull applies a comparison operator to two values, each
// accompanied by a flag saying whether it is NULL.
//
// Equality (==, =) holds between two NULLs and fails between NULL and a
// value; inequality (!=, <>) and the ordering operators are false whenever
// either side is NULL. Non-NULL values are compared for equality with
// compareValuesForEquality. Ordering is numeric when both values convert to
// float64 and otherwise compares their "%v" renderings as strings. Any other
// operator is an error.
func compareValuesWithNull(left interface{}, leftNull bool, right interface{}, rightNull bool, operator string) (bool, error) {
	switch operator {
	case "==", "=", "!=", "<>":
		wantEqual := operator == "==" || operator == "="
		if leftNull || rightNull {
			// Only "NULL = NULL" is true; everything else with NULL is false.
			return wantEqual && leftNull && rightNull, nil
		}
		same, err := compareValuesForEquality(left, right)
		if wantEqual {
			return same, err
		}
		return !same, err

	case ">", "<", ">=", "<=":
		if leftNull || rightNull {
			return false, nil // ordering against NULL is always false
		}

		leftNum, leftErr := convertToFloat(left)
		rightNum, rightErr := convertToFloat(right)

		if leftErr != nil || rightErr != nil {
			// At least one side is not numeric: compare textual renderings.
			leftText := fmt.Sprintf("%v", left)
			rightText := fmt.Sprintf("%v", right)
			switch operator {
			case ">":
				return leftText > rightText, nil
			case "<":
				return leftText < rightText, nil
			case ">=":
				return leftText >= rightText, nil
			case "<=":
				return leftText <= rightText, nil
			}
		}

		switch operator {
		case ">":
			return leftNum > rightNum, nil
		case "<":
			return leftNum < rightNum, nil
		case ">=":
			return leftNum >= rightNum, nil
		case "<=":
			return leftNum <= rightNum, nil
		}
	}

	return false, fmt.Errorf("unsupported operator: %s", operator)
}
// compareValuesWithNullForEquality reports whether two possibly-NULL values
// are equal: two NULLs are equal, NULL never equals a value, and two values
// are compared with compareValuesForEquality.
func compareValuesWithNullForEquality(left interface{}, leftNull bool, right interface{}, rightNull bool) (bool, error) {
	switch {
	case leftNull && rightNull:
		return true, nil
	case leftNull || rightNull:
		return false, nil
	default:
		return compareValuesForEquality(left, right)
	}
}
// EvaluateWithNull evaluates the expression against data and returns
// (value, isNull, err). It is the exported entry point intended for aggregate
// functions.
//
// When the expression is handled by expr-lang the result is never reported as
// NULL, because that path has no NULL support.
func (e *Expression) EvaluateWithNull(data map[string]interface{}) (float64, bool, error) {
	if !e.useExprLang {
		return evaluateNodeWithNull(e.Root, data)
	}

	value, err := e.evaluateWithExprLang(data)
	return value, false, err
}
// EvaluateValueWithNull evaluates the expression against data and returns the
// result as an untyped value together with a NULL flag: (value, isNull, err).
//
// When the expression is handled by expr-lang the result is never reported as
// NULL, because that path has no NULL support.
func (e *Expression) EvaluateValueWithNull(data map[string]interface{}) (interface{}, bool, error) {
	if !e.useExprLang {
		return evaluateNodeValueWithNull(e.Root, data)
	}

	value, err := e.evaluateWithExprLang(data)
	return value, false, err
}