mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-03-13 13:57:29 +00:00
963 lines
25 KiB
Go
963 lines
25 KiB
Go
package expr
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/rulego/streamsql/functions"
|
|
"github.com/rulego/streamsql/utils/fieldpath"
|
|
)
|
|
|
|
// evaluateNode evaluates the value of a node
|
|
func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) {
|
|
if node == nil {
|
|
return 0, fmt.Errorf("null expression node")
|
|
}
|
|
|
|
switch node.Type {
|
|
case TypeNumber:
|
|
return strconv.ParseFloat(node.Value, 64)
|
|
|
|
case TypeString:
|
|
// Handle string type, remove quotes and try to convert to number
|
|
value := node.Value
|
|
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
|
|
value = value[1 : len(value)-1] // Remove quotes
|
|
}
|
|
|
|
// Try to convert to number
|
|
if f, err := strconv.ParseFloat(value, 64); err == nil {
|
|
return f, nil
|
|
}
|
|
|
|
// For string comparison, return string length (temporary solution)
|
|
return float64(len(value)), nil
|
|
|
|
case TypeField:
|
|
return evaluateFieldNode(node, data)
|
|
|
|
case TypeOperator:
|
|
return evaluateOperatorNode(node, data)
|
|
|
|
case TypeFunction:
|
|
return evaluateFunctionNode(node, data)
|
|
|
|
case TypeCase:
|
|
// Handle CASE expression
|
|
return evaluateCaseExpression(node, data)
|
|
|
|
case TypeParenthesis:
|
|
// Handle parenthesis expression, directly evaluate inner expression
|
|
return evaluateNode(node.Left, data)
|
|
}
|
|
|
|
return 0, fmt.Errorf("unknown node type: %s", node.Type)
|
|
}
|
|
|
|
// evaluateFieldNode evaluates the value of a field node
|
|
func evaluateFieldNode(node *ExprNode, data map[string]interface{}) (float64, error) {
|
|
// Handle backtick identifiers, remove backticks
|
|
fieldName := node.Value
|
|
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
|
|
fieldName = fieldName[1 : len(fieldName)-1] // Remove backticks
|
|
}
|
|
|
|
// Support nested field access
|
|
if fieldpath.IsNestedField(fieldName) {
|
|
if val, found := fieldpath.GetNestedField(data, fieldName); found {
|
|
// Try to convert to float64
|
|
if floatVal, err := convertToFloat(val); err == nil {
|
|
// Check if it's NaN
|
|
if math.IsNaN(floatVal) {
|
|
return 0, fmt.Errorf("field '%s' contains NaN value", fieldName)
|
|
}
|
|
return floatVal, nil
|
|
}
|
|
// If cannot convert to number, return error
|
|
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val)
|
|
}
|
|
} else {
|
|
// Original simple field access
|
|
if val, found := data[fieldName]; found {
|
|
// Try to convert to float64
|
|
if floatVal, err := convertToFloat(val); err == nil {
|
|
// Check if it's NaN
|
|
if math.IsNaN(floatVal) {
|
|
return 0, fmt.Errorf("field '%s' contains NaN value", fieldName)
|
|
}
|
|
return floatVal, nil
|
|
}
|
|
// If cannot convert to number, return error
|
|
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val)
|
|
}
|
|
}
|
|
return 0, fmt.Errorf("field '%s' not found", fieldName)
|
|
}
|
|
|
|
// evaluateOperatorNode evaluates the value of an operator node
|
|
func evaluateOperatorNode(node *ExprNode, data map[string]interface{}) (float64, error) {
|
|
// Check if it's a comparison operator
|
|
if isComparisonOperator(node.Value) {
|
|
// For comparison operators, use evaluateNodeValue to get original type
|
|
leftValue, err := evaluateNodeValue(node.Left, data)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
rightValue, err := evaluateNodeValue(node.Right, data)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Execute comparison and convert boolean to number
|
|
result, err := compareValues(leftValue, rightValue, node.Value)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if result {
|
|
return 1.0, nil
|
|
}
|
|
return 0.0, nil
|
|
}
|
|
|
|
// For arithmetic operators, calculate numeric values
|
|
left, err := evaluateNode(node.Left, data)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
right, err := evaluateNode(node.Right, data)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Check if operands are NaN
|
|
if math.IsNaN(left) {
|
|
return 0, fmt.Errorf("left operand is NaN")
|
|
}
|
|
if math.IsNaN(right) {
|
|
return 0, fmt.Errorf("right operand is NaN")
|
|
}
|
|
|
|
// Execute operation
|
|
var result float64
|
|
switch node.Value {
|
|
case "+":
|
|
result = left + right
|
|
case "-":
|
|
result = left - right
|
|
case "*":
|
|
result = left * right
|
|
case "/":
|
|
if right == 0 {
|
|
return 0, fmt.Errorf("division by zero")
|
|
}
|
|
result = left / right
|
|
case "%":
|
|
if right == 0 {
|
|
return 0, fmt.Errorf("modulo by zero")
|
|
}
|
|
result = math.Mod(left, right)
|
|
case "^":
|
|
result = math.Pow(left, right)
|
|
default:
|
|
return 0, fmt.Errorf("unknown operator: %s", node.Value)
|
|
}
|
|
|
|
// Check if result is NaN
|
|
if math.IsNaN(result) {
|
|
return 0, fmt.Errorf("operation result is NaN")
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// evaluateFunctionNode evaluates the value of a function node
|
|
// Uses unified function registration system to handle all function calls
|
|
func evaluateFunctionNode(node *ExprNode, data map[string]interface{}) (float64, error) {
|
|
// Check if function exists in the new function registration system
|
|
fn, exists := functions.Get(node.Value)
|
|
if !exists {
|
|
return 0, fmt.Errorf("unknown function: %s", node.Value)
|
|
}
|
|
|
|
// Calculate all arguments but keep original types
|
|
args := make([]interface{}, len(node.Args))
|
|
for i, arg := range node.Args {
|
|
// Use evaluateNodeValue to get original type values
|
|
val, err := evaluateNodeValue(arg, data)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
args[i] = val
|
|
}
|
|
|
|
// Validate arguments
|
|
if err := fn.Validate(args); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Create function execution context
|
|
ctx := &functions.FunctionContext{
|
|
Data: data,
|
|
}
|
|
|
|
// Execute function
|
|
result, err := fn.Execute(ctx, args)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Convert result to float64
|
|
switch r := result.(type) {
|
|
case float64:
|
|
return r, nil
|
|
case float32:
|
|
return float64(r), nil
|
|
case int:
|
|
return float64(r), nil
|
|
case int32:
|
|
return float64(r), nil
|
|
case int64:
|
|
return float64(r), nil
|
|
case string:
|
|
// For string results, try to convert to number, return string length if failed
|
|
if f, err := strconv.ParseFloat(r, 64); err == nil {
|
|
return f, nil
|
|
}
|
|
return float64(len(r)), nil
|
|
case bool:
|
|
// Boolean conversion: true=1, false=0
|
|
if r {
|
|
return 1.0, nil
|
|
}
|
|
return 0.0, nil
|
|
default:
|
|
return 0, fmt.Errorf("function %s returned unsupported type for numeric conversion: %T", node.Value, result)
|
|
}
|
|
}
|
|
|
|
// evaluateNodeValue evaluates the original value of a node (preserving type)
|
|
func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
|
|
if node == nil {
|
|
return nil, fmt.Errorf("null expression node")
|
|
}
|
|
|
|
switch node.Type {
|
|
case TypeNumber:
|
|
return strconv.ParseFloat(node.Value, 64)
|
|
|
|
case TypeString:
|
|
// Handle string type, remove quotes
|
|
value := node.Value
|
|
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
|
|
value = value[1 : len(value)-1] // Remove quotes
|
|
}
|
|
return value, nil
|
|
|
|
case TypeField:
|
|
return evaluateFieldValue(node, data)
|
|
|
|
case TypeOperator:
|
|
return evaluateOperatorValue(node, data)
|
|
|
|
case TypeFunction:
|
|
return evaluateFunctionValue(node, data)
|
|
|
|
case TypeCase:
|
|
// Handle CASE expression
|
|
return evaluateCaseExpression(node, data)
|
|
|
|
case TypeParenthesis:
|
|
// Handle parenthesis expression, directly evaluate inner expression
|
|
return evaluateNodeValue(node.Left, data)
|
|
}
|
|
|
|
return nil, fmt.Errorf("unknown node type: %s", node.Type)
|
|
}
|
|
|
|
// evaluateFieldValue evaluates the original value of a field
|
|
func evaluateFieldValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
|
|
// Handle backtick identifiers, remove backticks
|
|
fieldName := node.Value
|
|
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
|
|
fieldName = fieldName[1 : len(fieldName)-1] // Remove backticks
|
|
}
|
|
|
|
// Support nested field access
|
|
if fieldpath.IsNestedField(fieldName) {
|
|
if val, found := fieldpath.GetNestedField(data, fieldName); found {
|
|
return val, nil
|
|
}
|
|
} else {
|
|
// Original simple field access
|
|
if val, found := data[fieldName]; found {
|
|
return val, nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("field '%s' not found", fieldName)
|
|
}
|
|
|
|
// evaluateOperatorValue evaluates the original value of an operator
|
|
func evaluateOperatorValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
|
|
// Special handling for IS and IS NOT operators
|
|
operator := strings.ToUpper(node.Value)
|
|
if operator == "IS" || operator == "IS NOT" {
|
|
return evaluateIsOperator(node, data)
|
|
}
|
|
|
|
// Check if it's a logical operator
|
|
if isLogicalOperator(node.Value) {
|
|
// For logical operators, use boolean evaluation
|
|
result, err := evaluateBoolOperator(node, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// Check if it's a comparison operator
|
|
if isComparisonOperator(node.Value) {
|
|
leftValue, err := evaluateNodeValue(node.Left, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rightValue, err := evaluateNodeValue(node.Right, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Execute comparison
|
|
return compareValues(leftValue, rightValue, node.Value)
|
|
}
|
|
|
|
// For arithmetic operators, use NULL-supporting evaluation
|
|
left, leftIsNull, err := evaluateNodeValueWithNull(node.Left, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
right, rightIsNull, err := evaluateNodeValueWithNull(node.Right, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// If any operand is NULL, result is NULL
|
|
if leftIsNull || rightIsNull {
|
|
return nil, nil
|
|
}
|
|
|
|
// Try to convert operands to numbers
|
|
leftFloat, leftOk := convertToFloatSafe(left)
|
|
rightFloat, rightOk := convertToFloatSafe(right)
|
|
|
|
if !leftOk {
|
|
return nil, fmt.Errorf("left operand cannot be converted to number: %v", left)
|
|
}
|
|
if !rightOk {
|
|
return nil, fmt.Errorf("right operand cannot be converted to number: %v", right)
|
|
}
|
|
|
|
// Execute arithmetic operation
|
|
var result float64
|
|
switch node.Value {
|
|
case "+":
|
|
result = leftFloat + rightFloat
|
|
case "-":
|
|
result = leftFloat - rightFloat
|
|
case "*":
|
|
result = leftFloat * rightFloat
|
|
case "/":
|
|
if rightFloat == 0 {
|
|
return nil, fmt.Errorf("division by zero")
|
|
}
|
|
result = leftFloat / rightFloat
|
|
case "%":
|
|
if rightFloat == 0 {
|
|
return nil, fmt.Errorf("modulo by zero")
|
|
}
|
|
result = math.Mod(leftFloat, rightFloat)
|
|
case "^":
|
|
result = math.Pow(leftFloat, rightFloat)
|
|
default:
|
|
return nil, fmt.Errorf("unknown arithmetic operator: %s", node.Value)
|
|
}
|
|
|
|
// Check if result is NaN
|
|
if math.IsNaN(result) {
|
|
return nil, fmt.Errorf("operation result is NaN")
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// evaluateFunctionValue evaluates the original value of a function
|
|
// Uses unified function registration system to handle all function calls
|
|
func evaluateFunctionValue(node *ExprNode, data map[string]interface{}) (interface{}, error) {
|
|
// Check if function exists in the new function registration system
|
|
fn, exists := functions.Get(node.Value)
|
|
if !exists {
|
|
return nil, fmt.Errorf("unknown function: %s", node.Value)
|
|
}
|
|
|
|
// Calculate all arguments but keep original types
|
|
args := make([]interface{}, len(node.Args))
|
|
for i, arg := range node.Args {
|
|
val, err := evaluateNodeValue(arg, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
args[i] = val
|
|
}
|
|
|
|
// Validate arguments
|
|
if err := fn.Validate(args); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create function execution context
|
|
ctx := &functions.FunctionContext{
|
|
Data: data,
|
|
}
|
|
|
|
// Execute function
|
|
return fn.Execute(ctx, args)
|
|
}
|
|
|
|
// compareValues compares two values
|
|
func compareValues(left, right interface{}, operator string) (bool, error) {
|
|
// Handle NULL values
|
|
if left == nil || right == nil {
|
|
switch strings.ToUpper(operator) {
|
|
case "IS":
|
|
return left == right, nil
|
|
case "IS NOT":
|
|
return left != right, nil
|
|
default:
|
|
return false, nil // NULL compared with any value returns false
|
|
}
|
|
}
|
|
|
|
// Try numeric comparison
|
|
leftFloat, leftIsFloat := convertToFloatSafe(left)
|
|
rightFloat, rightIsFloat := convertToFloatSafe(right)
|
|
|
|
if leftIsFloat && rightIsFloat {
|
|
return compareFloats(leftFloat, rightFloat, operator)
|
|
}
|
|
|
|
// Check for incompatible type comparison (one is number, one is not)
|
|
if (leftIsFloat && !rightIsFloat) || (!leftIsFloat && rightIsFloat) {
|
|
// For equality comparison, allow type conversion
|
|
operatorUpper := strings.ToUpper(operator)
|
|
if operatorUpper == "==" || operatorUpper == "=" || operatorUpper == "!=" || operatorUpper == "<>" {
|
|
// String comparison
|
|
leftStr := fmt.Sprintf("%v", left)
|
|
rightStr := fmt.Sprintf("%v", right)
|
|
return compareStrings(leftStr, rightStr, operator)
|
|
}
|
|
// For size comparison, return error
|
|
return false, fmt.Errorf("cannot compare incompatible types: %T and %T", left, right)
|
|
}
|
|
|
|
// String comparison
|
|
leftStr := fmt.Sprintf("%v", left)
|
|
rightStr := fmt.Sprintf("%v", right)
|
|
|
|
return compareStrings(leftStr, rightStr, operator)
|
|
}
|
|
|
|
// compareFloats compares two floating point numbers
|
|
func compareFloats(left, right float64, operator string) (bool, error) {
|
|
switch strings.ToUpper(operator) {
|
|
case "==", "=":
|
|
return left == right, nil
|
|
case "!=", "<>":
|
|
return left != right, nil
|
|
case ">":
|
|
return left > right, nil
|
|
case "<":
|
|
return left < right, nil
|
|
case ">=":
|
|
return left >= right, nil
|
|
case "<=":
|
|
return left <= right, nil
|
|
default:
|
|
return false, fmt.Errorf("unsupported numeric comparison operator: %s", operator)
|
|
}
|
|
}
|
|
|
|
// compareStrings compares two strings
|
|
func compareStrings(left, right, operator string) (bool, error) {
|
|
switch strings.ToUpper(operator) {
|
|
case "==", "=":
|
|
return left == right, nil
|
|
case "!=", "<>":
|
|
return left != right, nil
|
|
case ">":
|
|
return left > right, nil
|
|
case "<":
|
|
return left < right, nil
|
|
case ">=":
|
|
return left >= right, nil
|
|
case "<=":
|
|
return left <= right, nil
|
|
case "LIKE":
|
|
return matchLikePattern(left, right), nil
|
|
default:
|
|
return false, fmt.Errorf("unsupported string comparison operator: %s", operator)
|
|
}
|
|
}
|
|
|
|
// matchLikePattern implements LIKE pattern matching
|
|
func matchLikePattern(text, pattern string) bool {
|
|
// Simplified LIKE implementation, supports % and _ wildcards
|
|
// % matches any character sequence, _ matches single character
|
|
return matchPattern(text, pattern, 0, 0)
|
|
}
|
|
|
|
// matchPattern recursively matches pattern
|
|
func matchPattern(text, pattern string, textIdx, patternIdx int) bool {
|
|
if patternIdx == len(pattern) {
|
|
return textIdx == len(text)
|
|
}
|
|
|
|
if pattern[patternIdx] == '%' {
|
|
// % matches any character sequence
|
|
for i := textIdx; i <= len(text); i++ {
|
|
if matchPattern(text, pattern, i, patternIdx+1) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
if textIdx == len(text) {
|
|
return false
|
|
}
|
|
|
|
if pattern[patternIdx] == '_' || pattern[patternIdx] == text[textIdx] {
|
|
// _ matches single character or exact character match
|
|
return matchPattern(text, pattern, textIdx+1, patternIdx+1)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// compareValuesForEquality compares two values for equality (for simple CASE expressions)
|
|
func compareValuesForEquality(left, right interface{}) bool {
|
|
if left == nil && right == nil {
|
|
return true
|
|
}
|
|
if left == nil || right == nil {
|
|
return false
|
|
}
|
|
|
|
// Try numeric comparison
|
|
leftFloat, leftIsFloat := convertToFloatSafe(left)
|
|
rightFloat, rightIsFloat := convertToFloatSafe(right)
|
|
|
|
if leftIsFloat && rightIsFloat {
|
|
return leftFloat == rightFloat
|
|
}
|
|
|
|
// String comparison
|
|
leftStr := fmt.Sprintf("%v", left)
|
|
rightStr := fmt.Sprintf("%v", right)
|
|
return leftStr == rightStr
|
|
}
|
|
|
|
// evaluateNodeWithNull evaluates node value with NULL value handling
|
|
func evaluateNodeWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) {
|
|
if node == nil {
|
|
return 0, true, nil // NULL node
|
|
}
|
|
|
|
switch node.Type {
|
|
case TypeNumber:
|
|
val, err := strconv.ParseFloat(node.Value, 64)
|
|
return val, false, err
|
|
|
|
case TypeField:
|
|
// Handle backtick identifiers
|
|
fieldName := node.Value
|
|
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
|
|
fieldName = fieldName[1 : len(fieldName)-1]
|
|
}
|
|
|
|
// Support nested field access
|
|
var val interface{}
|
|
var found bool
|
|
if fieldpath.IsNestedField(fieldName) {
|
|
val, found = fieldpath.GetNestedField(data, fieldName)
|
|
} else {
|
|
val, found = data[fieldName]
|
|
}
|
|
|
|
if !found {
|
|
return 0, true, nil // Field not found is treated as NULL
|
|
}
|
|
|
|
if val == nil {
|
|
return 0, true, nil // NULL value
|
|
}
|
|
|
|
// Try to convert to numeric value
|
|
if floatVal, ok := convertToFloatSafe(val); ok {
|
|
return floatVal, false, nil
|
|
}
|
|
|
|
return 0, false, fmt.Errorf("field '%s' is not a number", fieldName)
|
|
|
|
case TypeOperator:
|
|
// For comparison operators, return boolean converted to numeric
|
|
if isComparisonOperator(node.Value) {
|
|
leftValue, leftIsNull, err := evaluateNodeValueWithNull(node.Left, data)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
|
|
rightValue, rightIsNull, err := evaluateNodeValueWithNull(node.Right, data)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
|
|
// Handle NULL comparison
|
|
if leftIsNull || rightIsNull {
|
|
switch strings.ToUpper(node.Value) {
|
|
case "IS":
|
|
if leftIsNull && rightIsNull {
|
|
return 1, false, nil
|
|
}
|
|
return 0, false, nil
|
|
case "IS NOT":
|
|
if leftIsNull && rightIsNull {
|
|
return 0, false, nil
|
|
}
|
|
return 1, false, nil
|
|
default:
|
|
return 0, true, nil // NULL compared with any value returns NULL
|
|
}
|
|
}
|
|
|
|
// Execute comparison
|
|
result, err := compareValues(leftValue, rightValue, node.Value)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
if result {
|
|
return 1, false, nil
|
|
}
|
|
return 0, false, nil
|
|
}
|
|
|
|
// Arithmetic operators
|
|
leftVal, leftIsNull, err := evaluateNodeWithNull(node.Left, data)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
if leftIsNull {
|
|
return 0, true, nil
|
|
}
|
|
|
|
rightVal, rightIsNull, err := evaluateNodeWithNull(node.Right, data)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
if rightIsNull {
|
|
return 0, true, nil
|
|
}
|
|
|
|
switch node.Value {
|
|
case "+":
|
|
return leftVal + rightVal, false, nil
|
|
case "-":
|
|
return leftVal - rightVal, false, nil
|
|
case "*":
|
|
return leftVal * rightVal, false, nil
|
|
case "/":
|
|
if rightVal == 0 {
|
|
return 0, false, fmt.Errorf("division by zero")
|
|
}
|
|
return leftVal / rightVal, false, nil
|
|
case "%":
|
|
if rightVal == 0 {
|
|
return 0, false, fmt.Errorf("modulo by zero")
|
|
}
|
|
return math.Mod(leftVal, rightVal), false, nil
|
|
default:
|
|
return 0, false, fmt.Errorf("unknown operator: %s", node.Value)
|
|
}
|
|
|
|
case TypeFunction:
|
|
// Function call, if any argument is NULL, result is usually NULL
|
|
val, err := evaluateNode(node, data)
|
|
return val, false, err
|
|
|
|
case TypeCase:
|
|
// CASE expression
|
|
result, isNull, err := evaluateCaseExpressionWithNull(node, data)
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
if isNull {
|
|
return 0, true, nil
|
|
}
|
|
if floatVal, ok := convertToFloatSafe(result); ok {
|
|
return floatVal, false, nil
|
|
}
|
|
return 0, false, fmt.Errorf("CASE expression result is not a number")
|
|
|
|
case TypeParenthesis:
|
|
// Handle parenthesis expression, directly evaluate inner expression
|
|
return evaluateNodeWithNull(node.Left, data)
|
|
|
|
default:
|
|
return 0, false, fmt.Errorf("unknown node type: %s", node.Type)
|
|
}
|
|
}
|
|
|
|
// evaluateNodeValueWithNull evaluates the original value of a node with NULL value handling
|
|
func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) {
|
|
if node == nil {
|
|
return nil, true, nil
|
|
}
|
|
|
|
switch node.Type {
|
|
case TypeNumber:
|
|
val, err := strconv.ParseFloat(node.Value, 64)
|
|
return val, false, err
|
|
|
|
case TypeString:
|
|
// Handle string type, remove quotes
|
|
value := node.Value
|
|
if len(value) >= 2 && ((value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'')) {
|
|
value = value[1 : len(value)-1] // Remove quotes
|
|
}
|
|
return value, false, nil
|
|
|
|
case TypeField:
|
|
// Handle backtick identifiers
|
|
fieldName := node.Value
|
|
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
|
|
fieldName = fieldName[1 : len(fieldName)-1]
|
|
}
|
|
|
|
// Support nested field access
|
|
var val interface{}
|
|
var found bool
|
|
if fieldpath.IsNestedField(fieldName) {
|
|
val, found = fieldpath.GetNestedField(data, fieldName)
|
|
} else {
|
|
val, found = data[fieldName]
|
|
}
|
|
|
|
if !found {
|
|
return nil, true, nil // Field not found is treated as NULL
|
|
}
|
|
|
|
return val, val == nil, nil
|
|
|
|
case TypeOperator:
|
|
val, err := evaluateOperatorValue(node, data)
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
return val, false, nil
|
|
|
|
case TypeFunction:
|
|
val, err := evaluateFunctionValue(node, data)
|
|
return val, val == nil, err
|
|
|
|
case TypeCase:
|
|
return evaluateCaseExpressionWithNull(node, data)
|
|
|
|
case TypeParenthesis:
|
|
// Handle parenthesis expression, directly evaluate inner expression
|
|
return evaluateNodeValueWithNull(node.Left, data)
|
|
|
|
default:
|
|
return nil, false, fmt.Errorf("unknown node type: %s", node.Type)
|
|
}
|
|
}
|
|
|
|
// compareValuesWithNullForEquality compares two values for equality (supports NULL comparison)
|
|
func compareValuesWithNullForEquality(left interface{}, leftIsNull bool, right interface{}, rightIsNull bool) bool {
|
|
if leftIsNull && rightIsNull {
|
|
return true
|
|
}
|
|
if leftIsNull || rightIsNull {
|
|
return false
|
|
}
|
|
return compareValuesForEquality(left, right)
|
|
}
|
|
|
|
// evaluateBoolNode evaluates the boolean value of a node
|
|
func evaluateBoolNode(node *ExprNode, data map[string]interface{}) (bool, error) {
|
|
if node == nil {
|
|
return false, fmt.Errorf("null expression node")
|
|
}
|
|
|
|
switch node.Type {
|
|
case TypeOperator:
|
|
return evaluateBoolOperator(node, data)
|
|
case TypeFunction:
|
|
return evaluateBoolFunction(node, data)
|
|
case TypeParenthesis:
|
|
// Parenthesis node, recursively evaluate inner expression
|
|
if node.Left != nil {
|
|
return evaluateBoolNode(node.Left, data)
|
|
}
|
|
return false, fmt.Errorf("empty parenthesis expression")
|
|
case TypeField:
|
|
// Convert field value to boolean
|
|
value, err := evaluateFieldValue(node, data)
|
|
if err != nil {
|
|
// If field doesn't exist, treat as NULL, convert to false
|
|
return false, nil
|
|
}
|
|
return convertToBool(value), nil
|
|
case TypeNumber:
|
|
// Convert number to boolean (non-zero is true)
|
|
value, err := strconv.ParseFloat(node.Value, 64)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return value != 0, nil
|
|
case TypeString:
|
|
// Convert string to boolean (non-empty is true)
|
|
value := node.Value
|
|
if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') {
|
|
value = value[1 : len(value)-1] // Remove quotes
|
|
}
|
|
return value != "", nil
|
|
default:
|
|
return false, fmt.Errorf("unsupported node type for boolean evaluation: %s", node.Type)
|
|
}
|
|
}
|
|
|
|
// evaluateBoolOperator evaluates boolean operators
|
|
func evaluateBoolOperator(node *ExprNode, data map[string]interface{}) (bool, error) {
|
|
operator := strings.ToUpper(node.Value)
|
|
|
|
switch operator {
|
|
case "AND", "&&":
|
|
left, err := evaluateBoolNode(node.Left, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if !left {
|
|
return false, nil // Short-circuit evaluation
|
|
}
|
|
return evaluateBoolNode(node.Right, data)
|
|
|
|
case "OR", "||":
|
|
left, err := evaluateBoolNode(node.Left, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if left {
|
|
return true, nil // Short-circuit evaluation
|
|
}
|
|
return evaluateBoolNode(node.Right, data)
|
|
|
|
case "NOT", "!":
|
|
// NOT operator may use Left or Right node
|
|
var operand *ExprNode
|
|
if node.Left != nil {
|
|
operand = node.Left
|
|
} else if node.Right != nil {
|
|
operand = node.Right
|
|
} else {
|
|
return false, fmt.Errorf("NOT operator requires an operand")
|
|
}
|
|
result, err := evaluateBoolNode(operand, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return !result, nil
|
|
|
|
case "IS", "IS NOT":
|
|
// IS and IS NOT operators (including IS NULL and IS NOT NULL)
|
|
result, err := evaluateIsOperator(node, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return convertToBool(result), nil
|
|
|
|
case "==", "=", "!=", "<>", ">", "<", ">=", "<=", "LIKE":
|
|
// Comparison operators
|
|
leftValue, err := evaluateNodeValue(node.Left, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
rightValue, err := evaluateNodeValue(node.Right, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return compareValues(leftValue, rightValue, operator)
|
|
|
|
default:
|
|
return false, fmt.Errorf("unsupported boolean operator: %s", operator)
|
|
}
|
|
}
|
|
|
|
// evaluateBoolFunction evaluates boolean functions
|
|
func evaluateBoolFunction(node *ExprNode, data map[string]interface{}) (bool, error) {
|
|
// Call function and convert result to boolean
|
|
result, err := evaluateFunctionValue(node, data)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return convertToBool(result), nil
|
|
}
|
|
|
|
// evaluateIsOperator handles IS and IS NOT operators (mainly IS NULL and IS NOT NULL)
|
|
func evaluateIsOperator(node *ExprNode, data map[string]interface{}) (interface{}, error) {
|
|
if node.Right == nil {
|
|
return nil, fmt.Errorf("IS operator requires a right operand")
|
|
}
|
|
|
|
operator := strings.ToUpper(node.Value)
|
|
|
|
// Check if right side is NULL
|
|
if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" {
|
|
// Get left value using NULL-supporting method
|
|
_, leftIsNull, err := evaluateNodeValueWithNull(node.Left, data)
|
|
if err != nil {
|
|
// If field doesn't exist, consider it NULL
|
|
leftIsNull = true
|
|
}
|
|
|
|
if operator == "IS" {
|
|
// IS NULL comparison
|
|
return leftIsNull, nil
|
|
} else if operator == "IS NOT" {
|
|
// IS NOT NULL comparison
|
|
return !leftIsNull, nil
|
|
}
|
|
}
|
|
|
|
// Other IS comparisons
|
|
leftValue, err := evaluateNodeValue(node.Left, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rightValue, err := evaluateNodeValue(node.Right, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if operator == "IS" {
|
|
return compareValuesForEquality(leftValue, rightValue), nil
|
|
} else if operator == "IS NOT" {
|
|
return !compareValuesForEquality(leftValue, rightValue), nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("unsupported IS operator: %s", operator)
|
|
}
|