diff --git a/expr/case_expression.go b/expr/case_expression.go new file mode 100644 index 0000000..a62de27 --- /dev/null +++ b/expr/case_expression.go @@ -0,0 +1,245 @@ +package expr + +import ( + "fmt" + "strings" +) + +// parseCaseExpression parses CASE expression +func parseCaseExpression(tokens []string) (*ExprNode, []string, error) { + if len(tokens) == 0 || strings.ToUpper(tokens[0]) != "CASE" { + return nil, nil, fmt.Errorf("expected CASE keyword") + } + + remaining := tokens[1:] + caseExpr := &CaseExpression{} + + // Check if it's a simple CASE expression (CASE expr WHEN value THEN result) + if len(remaining) > 0 && strings.ToUpper(remaining[0]) != "WHEN" { + // Simple CASE expression + value, newRemaining, err := parseOrExpression(remaining) + if err != nil { + return nil, nil, fmt.Errorf("error parsing CASE expression: %v", err) + } + caseExpr.Value = value + remaining = newRemaining + } + + // Parse WHEN clauses + for len(remaining) > 0 && strings.ToUpper(remaining[0]) == "WHEN" { + remaining = remaining[1:] // Skip WHEN + + // Parse WHEN condition + condition, newRemaining, err := parseOrExpression(remaining) + if err != nil { + return nil, nil, fmt.Errorf("error parsing WHEN condition: %v", err) + } + remaining = newRemaining + + // Check THEN keyword + if len(remaining) == 0 || strings.ToUpper(remaining[0]) != "THEN" { + return nil, nil, fmt.Errorf("expected THEN after WHEN condition") + } + remaining = remaining[1:] // Skip THEN + + // Parse THEN result + result, newRemaining, err := parseOrExpression(remaining) + if err != nil { + return nil, nil, fmt.Errorf("error parsing THEN result: %v", err) + } + remaining = newRemaining + + // Add WHEN clause + caseExpr.WhenClauses = append(caseExpr.WhenClauses, WhenClause{ + Condition: condition, + Result: result, + }) + } + + // Parse optional ELSE clause + if len(remaining) > 0 && strings.ToUpper(remaining[0]) == "ELSE" { + remaining = remaining[1:] // Skip ELSE + + elseExpr, newRemaining, err := parseOrExpression(remaining) + if err != nil { + return nil, nil, fmt.Errorf("error parsing ELSE expression: %v", err) + } + caseExpr.ElseResult = elseExpr + remaining = newRemaining + } + + // Check END keyword + if len(remaining) == 0 || strings.ToUpper(remaining[0]) != "END" { + return nil, nil, fmt.Errorf("expected END to close CASE expression") + } + + // Create ExprNode containing CaseExpression + caseNode := &ExprNode{ + Type: TypeCase, + CaseExpr: caseExpr, + } + + return caseNode, remaining[1:], nil +} + +// evaluateCaseExpression evaluates the value of CASE expression +func evaluateCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) { + if node.Type != TypeCase { + return 0, fmt.Errorf("not a CASE expression") + } + + if node.CaseExpr == nil { + return 0, fmt.Errorf("invalid CASE expression") + } + + // Simple CASE expression: CASE expr WHEN value THEN result + if node.CaseExpr.Value != nil { + return evaluateSimpleCaseExpression(node, data) + } + + // Search CASE expression: CASE WHEN condition THEN result + return evaluateSearchCaseExpression(node, data) +} + +// evaluateSimpleCaseExpression evaluates simple CASE expression +func evaluateSimpleCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) { + caseExpr := node.CaseExpr + if caseExpr == nil { + return 0, fmt.Errorf("invalid CASE expression") + } + + // Evaluate CASE expression value + caseValue, err := evaluateNodeValue(caseExpr.Value, data) + if err != nil { + return 0, err + } + + // Iterate through WHEN clauses + for _, whenClause := range caseExpr.WhenClauses { + // Evaluate WHEN value + whenValue, err := evaluateNodeValue(whenClause.Condition, data) + if err != nil { + return 0, err + } + + // Compare values + if compareValuesForEquality(caseValue, whenValue) { + // Evaluate and return THEN result + return evaluateNode(whenClause.Result, data) + } + } + + // If no matching WHEN clause, evaluate ELSE expression + if caseExpr.ElseResult != nil { + return evaluateNode(caseExpr.ElseResult, data) + } + + // If no ELSE clause, return NULL (return 0 here) + return 0, nil +} + +// evaluateSearchCaseExpression evaluates search CASE expression +func evaluateSearchCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) { + caseExpr := node.CaseExpr + if caseExpr == nil { + return 0, fmt.Errorf("invalid CASE expression") + } + + // Iterate through WHEN clauses + for _, whenClause := range caseExpr.WhenClauses { + // Evaluate WHEN condition - use boolean evaluation to handle logical operators + conditionResult, err := evaluateBoolNode(whenClause.Condition, data) + if err != nil { + return 0, err + } + + // If condition is true, return THEN result + if conditionResult { + return evaluateNode(whenClause.Result, data) + } + } + + // If no matching WHEN clause, evaluate ELSE expression + if caseExpr.ElseResult != nil { + return evaluateNode(caseExpr.ElseResult, data) + } + + // If no ELSE clause, return NULL (return 0 here) + return 0, nil +} + +// evaluateCaseExpressionWithNull evaluates CASE expression with NULL value support +func evaluateCaseExpressionWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) { + if node.Type != TypeCase { + return nil, false, fmt.Errorf("not a CASE expression") + } + + caseExpr := node.CaseExpr + if caseExpr == nil { + return nil, false, fmt.Errorf("invalid CASE expression") + } + + // Simple CASE expression: CASE expr WHEN value THEN result + if caseExpr.Value != nil { + return evaluateCaseExpressionValueWithNull(node, data) + } + + // Search CASE expression: CASE WHEN condition THEN result + for _, whenClause := range caseExpr.WhenClauses { + // Evaluate WHEN condition - use boolean evaluation to handle logical operators + conditionResult, err := evaluateBoolNode(whenClause.Condition, data) + if err != nil { + return nil, false, err + } + + // If condition is true, return THEN result + if conditionResult { + return evaluateNodeValueWithNull(whenClause.Result, data) + } + } + + // If no matching WHEN clause, evaluate ELSE expression + if caseExpr.ElseResult != nil { + return evaluateNodeValueWithNull(caseExpr.ElseResult, data) + } + + // If no ELSE clause, return NULL + return nil, true, nil +} + +// evaluateCaseExpressionValueWithNull evaluates simple CASE expression (with NULL support) +func evaluateCaseExpressionValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) { + caseExpr := node.CaseExpr + if caseExpr == nil { + return nil, false, fmt.Errorf("invalid CASE expression") + } + + // Evaluate CASE expression value + caseValue, caseIsNull, err := evaluateNodeValueWithNull(caseExpr.Value, data) + if err != nil { + return nil, false, err + } + + // Iterate through WHEN clauses + for _, whenClause := range caseExpr.WhenClauses { + // Evaluate WHEN value + whenValue, whenIsNull, err := evaluateNodeValueWithNull(whenClause.Condition, data) + if err != nil { + return nil, false, err + } + + // Compare values (with NULL comparison support) + if compareValuesWithNullForEquality(caseValue, caseIsNull, whenValue, whenIsNull) { + // Evaluate and return THEN result + return evaluateNodeValueWithNull(whenClause.Result, data) + } + } + + // If no matching WHEN clause, evaluate ELSE expression + if caseExpr.ElseResult != nil { + return evaluateNodeValueWithNull(caseExpr.ElseResult, data) + } + + // If no ELSE clause, return NULL + return nil, true, nil +} diff --git a/expr/case_expression_test.go b/expr/case_expression_test.go new file mode 100644 index 0000000..2d50f7f --- /dev/null +++ b/expr/case_expression_test.go @@ -0,0 +1,372 @@ +package expr + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEvaluateCaseExpression 测试CASE表达式求值 +func TestEvaluateCaseExpression(t *testing.T) { + tests := []struct { + name string + node *ExprNode + data map[string]interface{} + expected float64 + wantErr bool + }{ + { + "简单CASE表达式", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: "=", + Left: &ExprNode{Type: TypeField, Value: "status"}, + Right: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "100"}, + }, + }, + }, + }, + map[string]interface{}{"status": 1}, + 100, + false, + }, + { + "带ELSE的CASE表达式", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "score"}, + Right: &ExprNode{Type: TypeNumber, Value: "90"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + map[string]interface{}{"score": 85}, + 0, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateCaseExpression(tt.node, tt.data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "求值不应该失败") + assert.Equal(t, tt.expected, result, "求值结果应该正确") + } + }) + } +} + +// TestEvaluateCaseExpressionWithNull 测试支持NULL的CASE表达式求值 +func TestEvaluateCaseExpressionWithNull(t *testing.T) { + tests := []struct { + name string + node *ExprNode + data map[string]interface{} + expected interface{} + expectedNull bool + wantErr bool + }{ + { + "条件为NULL时返回NULL", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeField, Value: "missing_field"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + }, + }, + map[string]interface{}{}, + nil, + true, + false, + }, + { + "简单CASE表达式匹配", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "status"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'active'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + { + Condition: &ExprNode{Type: TypeString, Value: "'inactive'"}, + Result: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "-1"}, + }, + }, + map[string]interface{}{"status": "active"}, + 1.0, + false, + false, + }, + { + "简单CASE表达式不匹配使用ELSE", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "status"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'active'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + map[string]interface{}{"status": "unknown"}, + 0.0, + false, + false, + }, + { + "简单CASE表达式Value为NULL", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "missing_field"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'test'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + map[string]interface{}{}, + 0.0, + false, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, isNull, err := evaluateCaseExpressionWithNull(tt.node, tt.data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "求值不应该失败") + assert.Equal(t, tt.expectedNull, isNull, "NULL状态应该正确") + if !isNull { + assert.Equal(t, tt.expected, result, "求值结果应该正确") + } + } + }) + } +} + +func TestParseCaseExpression(t *testing.T) { + tests := []struct { + name string + tokens []string + expectError bool + description string + }{ + { + name: "empty tokens", + tokens: []string{}, + expectError: true, + description: "should return error for empty tokens", + }, + { + name: "not case keyword", + tokens: []string{"SELECT", "field"}, + expectError: true, + description: "should return error when first token is not CASE", + }, + { + name: "missing when after case", + tokens: []string{"CASE", "field"}, + expectError: true, + description: "should return error when missing WHEN after CASE", + }, + { + name: "missing then after when", + tokens: []string{"CASE", "WHEN", "field1", ">", "0"}, + expectError: true, + description: "should return error when missing THEN after WHEN", + }, + { + name: "missing end", + tokens: []string{"CASE", "WHEN", "field1", ">", "0", "THEN", "1"}, + expectError: true, + description: "should return error when missing END", + }, + { + name: "invalid when condition - missing operand", + tokens: []string{"CASE", "WHEN", ">", "0", "THEN", "1", "END"}, + expectError: true, + description: "should return error for invalid WHEN condition", + }, + { + name: "invalid then result - missing operand", + tokens: []string{"CASE", "WHEN", "field1", ">", "0", "THEN", "+", "END"}, + expectError: true, + description: "should return error for invalid THEN result", + }, + { + name: "invalid else expression - missing operand", + tokens: []string{"CASE", "WHEN", "field1", ">", "0", "THEN", "1", "ELSE", "+", "END"}, + expectError: true, + description: "should return error for invalid ELSE expression", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := parseCaseExpression(tt.tokens) + if tt.expectError && err == nil { + t.Errorf("expected error but got none: %s", tt.description) + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestEvaluateSimpleCaseExpression(t *testing.T) { + // Create a simple CASE expression for testing + caseExpr := &CaseExpression{ + Value: &ExprNode{ + Type: TypeField, + Value: "status", + }, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeString, + Value: "'active'", + }, + Result: &ExprNode{ + Type: TypeNumber, + Value: "1", + }, + }, + { + Condition: &ExprNode{ + Type: TypeString, + Value: "'inactive'", + }, + Result: &ExprNode{ + Type: TypeNumber, + Value: "0", + }, + }, + }, + ElseResult: &ExprNode{ + Type: TypeNumber, + Value: "-1", + }, + } + + node := &ExprNode{ + Type: TypeCase, + CaseExpr: caseExpr, + } + + tests := []struct { + name string + data map[string]interface{} + expected float64 + }{ + { + name: "match first when", + data: map[string]interface{}{"status": "active"}, + expected: 1.0, + }, + { + name: "match second when", + data: map[string]interface{}{"status": "inactive"}, + expected: 0.0, + }, + { + name: "no match use else", + data: map[string]interface{}{"status": "unknown"}, + expected: -1.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateSimpleCaseExpression(node, tt.data) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if result != tt.expected { + t.Errorf("expected %f, got %f", tt.expected, result) + } + }) + } + + // Test error cases + t.Run("nil case expression", func(t *testing.T) { + _, err := evaluateSimpleCaseExpression(&ExprNode{Type: TypeCase}, map[string]interface{}{}) + if err == nil { + t.Error("expected error for nil case expression") + } + }) + + t.Run("invalid case value", func(t *testing.T) { + caseExprWithError := &CaseExpression{ + Value: &ExprNode{ + Type: TypeField, + Value: "nonexistent", + }, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeString, + Value: "'active'", + }, + Result: &ExprNode{ + Type: TypeNumber, + Value: "1", + }, + }, + }, + } + + nodeWithError := &ExprNode{ + Type: TypeCase, + CaseExpr: caseExprWithError, + } + + _, err := evaluateSimpleCaseExpression(nodeWithError, map[string]interface{}{}) + if err == nil { + t.Error("expected error for invalid case value") + } + }) +} diff --git a/expr/doc.go b/expr/doc.go index 7d20a9b..d4465bd 100644 --- a/expr/doc.go +++ b/expr/doc.go @@ -67,7 +67,7 @@ Function call expression: CASE expression for conditional logic: expr, err := NewExpression(` - CASE + CASE WHEN temperature > 30 THEN 'hot' WHEN temperature > 20 THEN 'warm' ELSE 'cold' @@ -79,14 +79,14 @@ CASE expression for conditional logic: The expression parser follows standard mathematical precedence rules: - 1. Parentheses (highest) - 2. Power (^) - 3. Multiplication, Division, Modulo (*, /, %) - 4. Addition, Subtraction (+, -) - 5. Comparison (>, <, >=, <=, LIKE, IS) - 6. Equality (=, ==, !=, <>) - 7. Logical AND - 8. Logical OR (lowest) + 1. Parentheses (highest) + 2. Power (^) + 3. Multiplication, Division, Modulo (*, /, %) + 4. Addition, Subtraction (+, -) + 5. Comparison (>, <, >=, <=, LIKE, IS) + 6. Equality (=, ==, !=, <>) + 7. Logical AND + 8. Logical OR (lowest) # Error Handling @@ -114,4 +114,4 @@ This package integrates seamlessly with other StreamSQL components: • Stream package - For real-time expression evaluation in data streams • RSQL package - For SQL parsing and expression extraction */ -package expr \ No newline at end of file +package expr diff --git a/expr/evaluator.go b/expr/evaluator.go new file mode 100644 index 0000000..2a5a15d --- /dev/null +++ b/expr/evaluator.go @@ -0,0 +1,962 @@ +package expr + +import ( + "fmt" + "math" + "strconv" + "strings" + + "github.com/rulego/streamsql/functions" + "github.com/rulego/streamsql/utils/fieldpath" +) + +// evaluateNode evaluates the value of a node +func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) { + if node == nil { + return 0, fmt.Errorf("null expression node") + } + + switch node.Type { + case TypeNumber: + return strconv.ParseFloat(node.Value, 64) + + case TypeString: + // Handle string type, remove quotes and try to convert to number + value := node.Value + if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { + value = value[1 : len(value)-1] // Remove quotes + } + + // Try to convert to number + if f, err := strconv.ParseFloat(value, 64); err == nil { + return f, nil + } + + // For string comparison, return string length (temporary solution) + return float64(len(value)), nil + + case TypeField: + return evaluateFieldNode(node, data) + + case TypeOperator: + return evaluateOperatorNode(node, data) + + case TypeFunction: + return evaluateFunctionNode(node, data) + + case TypeCase: + // Handle CASE expression + return evaluateCaseExpression(node, data) + + case TypeParenthesis: + // Handle parenthesis expression, directly evaluate inner expression + return evaluateNode(node.Left, data) + } + + return 0, fmt.Errorf("unknown node type: %s", node.Type) +} + +// evaluateFieldNode evaluates the value of a field node +func evaluateFieldNode(node *ExprNode, data map[string]interface{}) (float64, error) { + // Handle backtick identifiers, remove backticks + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] // Remove backticks + } + + // Support nested field access + if fieldpath.IsNestedField(fieldName) { + if val, found := fieldpath.GetNestedField(data, fieldName); found { + // Try to convert to float64 + if floatVal, err := convertToFloat(val); err == nil { + // Check if it's NaN + if math.IsNaN(floatVal) { + return 0, fmt.Errorf("field '%s' contains NaN value", fieldName) + } + return floatVal, nil + } + // If cannot convert to number, return error + return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val) + } + } else { + // Original simple field access + if val, found := data[fieldName]; found { + // Try to convert to float64 + if floatVal, err := convertToFloat(val); err == nil { + // Check if it's NaN + if math.IsNaN(floatVal) { + return 0, fmt.Errorf("field '%s' contains NaN value", fieldName) + } + return floatVal, nil + } + // If cannot convert to number, return error + return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val) + } + } + return 0, fmt.Errorf("field '%s' not found", fieldName) +} + +// evaluateOperatorNode evaluates the value of an operator node +func evaluateOperatorNode(node *ExprNode, data map[string]interface{}) (float64, error) { + // Check if it's a comparison operator + if isComparisonOperator(node.Value) { + // For comparison operators, use evaluateNodeValue to get original type + leftValue, err := evaluateNodeValue(node.Left, data) + if err != nil { + return 0, err + } + + rightValue, err := evaluateNodeValue(node.Right, data) + if err != nil { + return 0, err + } + + // Execute comparison and convert boolean to number + result, err := compareValues(leftValue, rightValue, node.Value) + if err != nil { + return 0, err + } + if result { + return 1.0, nil + } + return 0.0, nil + } + + // For arithmetic operators, calculate numeric values + left, err := evaluateNode(node.Left, data) + if err != nil { + return 0, err + } + + right, err := evaluateNode(node.Right, data) + if err != nil { + return 0, err + } + + // Check if operands are NaN + if math.IsNaN(left) { + return 0, fmt.Errorf("left operand is NaN") + } + if math.IsNaN(right) { + return 0, fmt.Errorf("right operand is NaN") + } + + // Execute operation + var result float64 + switch node.Value { + case "+": + result = left + right + case "-": + result = left - right + case "*": + result = left * right + case "/": + if right == 0 { + return 0, fmt.Errorf("division by zero") + } + result = left / right + case "%": + if right == 0 { + return 0, fmt.Errorf("modulo by zero") + } + result = math.Mod(left, right) + case "^": + result = math.Pow(left, right) + default: + return 0, fmt.Errorf("unknown operator: %s", node.Value) + } + + // Check if result is NaN + if math.IsNaN(result) { + return 0, fmt.Errorf("operation result is NaN") + } + + return result, nil +} + +// evaluateFunctionNode evaluates the value of a function node +// Uses unified function registration system to handle all function calls +func evaluateFunctionNode(node *ExprNode, data map[string]interface{}) (float64, error) { + // Check if function exists in the new function registration system + fn, exists := functions.Get(node.Value) + if !exists { + return 0, fmt.Errorf("unknown function: %s", node.Value) + } + + // Calculate all arguments but keep original types + args := make([]interface{}, len(node.Args)) + for i, arg := range node.Args { + // Use evaluateNodeValue to get original type values + val, err := evaluateNodeValue(arg, data) + if err != nil { + return 0, err + } + args[i] = val + } + + // Validate arguments + if err := fn.Validate(args); err != nil { + return 0, err + } + + // Create function execution context + ctx := &functions.FunctionContext{ + Data: data, + } + + // Execute function + result, err := fn.Execute(ctx, args) + if err != nil { + return 0, err + } + + // Convert result to float64 + switch r := result.(type) { + case float64: + return r, nil + case float32: + return float64(r), nil + case int: + return float64(r), nil + case int32: + return float64(r), nil + case int64: + return float64(r), nil + case string: + // For string results, try to convert to number, return string length if failed + if f, err := strconv.ParseFloat(r, 64); err == nil { + return f, nil + } + return float64(len(r)), nil + case bool: + // Boolean conversion: true=1, false=0 + if r { + return 1.0, nil + } + return 0.0, nil + default: + return 0, fmt.Errorf("function %s returned unsupported type for numeric conversion: %T", node.Value, result) + } +} + +// evaluateNodeValue evaluates the original value of a node (preserving type) +func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}, error) { + if node == nil { + return nil, fmt.Errorf("null expression node") + } + + switch node.Type { + case TypeNumber: + return strconv.ParseFloat(node.Value, 64) + + case TypeString: + // Handle string type, remove quotes + value := node.Value + if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { + value = value[1 : len(value)-1] // Remove quotes + } + return value, nil + + case TypeField: + return evaluateFieldValue(node, data) + + case TypeOperator: + return evaluateOperatorValue(node, data) + + case TypeFunction: + return evaluateFunctionValue(node, data) + + case TypeCase: + // Handle CASE expression + return evaluateCaseExpression(node, data) + + case TypeParenthesis: + // Handle parenthesis expression, directly evaluate inner expression + return evaluateNodeValue(node.Left, data) + } + + return nil, fmt.Errorf("unknown node type: %s", node.Type) +} + +// evaluateFieldValue evaluates the original value of a field +func evaluateFieldValue(node *ExprNode, data map[string]interface{}) (interface{}, error) { + // Handle backtick identifiers, remove backticks + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] // Remove backticks + } + + // Support nested field access + if fieldpath.IsNestedField(fieldName) { + if val, found := fieldpath.GetNestedField(data, fieldName); found { + return val, nil + } + } else { + // Original simple field access + if val, found := data[fieldName]; found { + return val, nil + } + } + return nil, fmt.Errorf("field '%s' not found", fieldName) +} + +// evaluateOperatorValue evaluates the original value of an operator +func evaluateOperatorValue(node *ExprNode, data map[string]interface{}) (interface{}, error) { + // Special handling for IS and IS NOT operators + operator := strings.ToUpper(node.Value) + if operator == "IS" || operator == "IS NOT" { + return evaluateIsOperator(node, data) + } + + // Check if it's a logical operator + if isLogicalOperator(node.Value) { + // For logical operators, use boolean evaluation + result, err := evaluateBoolOperator(node, data) + if err != nil { + return nil, err + } + return result, nil + } + + // Check if it's a comparison operator + if isComparisonOperator(node.Value) { + leftValue, err := evaluateNodeValue(node.Left, data) + if err != nil { + return nil, err + } + + rightValue, err := evaluateNodeValue(node.Right, data) + if err != nil { + return nil, err + } + + // Execute comparison + return compareValues(leftValue, rightValue, node.Value) + } + + // For arithmetic operators, use NULL-supporting evaluation + left, leftIsNull, err := evaluateNodeValueWithNull(node.Left, data) + if err != nil { + return nil, err + } + + right, rightIsNull, err := evaluateNodeValueWithNull(node.Right, data) + if err != nil { + return nil, err + } + + // If any operand is NULL, result is NULL + if leftIsNull || rightIsNull { + return nil, nil + } + + // Try to convert operands to numbers + leftFloat, leftOk := convertToFloatSafe(left) + rightFloat, rightOk := convertToFloatSafe(right) + + if !leftOk { + return nil, fmt.Errorf("left operand cannot be converted to number: %v", left) + } + if !rightOk { + return nil, fmt.Errorf("right operand cannot be converted to number: %v", right) + } + + // Execute arithmetic operation + var result float64 + switch node.Value { + case "+": + result = leftFloat + rightFloat + case "-": + result = leftFloat - rightFloat + case "*": + result = leftFloat * rightFloat + case "/": + if rightFloat == 0 { + return nil, fmt.Errorf("division by zero") + } + result = leftFloat / rightFloat + case "%": + if rightFloat == 0 { + return nil, fmt.Errorf("modulo by zero") + } + result = math.Mod(leftFloat, rightFloat) + case "^": + result = math.Pow(leftFloat, rightFloat) + default: + return nil, fmt.Errorf("unknown arithmetic operator: %s", node.Value) + } + + // Check if result is NaN + if math.IsNaN(result) { + return nil, fmt.Errorf("operation result is NaN") + } + + return result, nil +} + +// evaluateFunctionValue evaluates the original value of a function +// Uses unified function registration system to handle all function calls +func evaluateFunctionValue(node *ExprNode, data map[string]interface{}) (interface{}, error) { + // Check if function exists in the new function registration system + fn, exists := functions.Get(node.Value) + if !exists { + return nil, fmt.Errorf("unknown function: %s", node.Value) + } + + // Calculate all arguments but keep original types + args := make([]interface{}, len(node.Args)) + for i, arg := range node.Args { + val, err := evaluateNodeValue(arg, data) + if err != nil { + return nil, err + } + args[i] = val + } + + // Validate arguments + if err := fn.Validate(args); err != nil { + return nil, err + } + + // Create function execution context + ctx := &functions.FunctionContext{ + Data: data, + } + + // Execute function + return fn.Execute(ctx, args) +} + +// compareValues compares two values +func compareValues(left, right interface{}, operator string) (bool, error) { + // Handle NULL values + if left == nil || right == nil { + switch strings.ToUpper(operator) { + case "IS": + return left == right, nil + case "IS NOT": + return left != right, nil + default: + return false, nil // NULL compared with any value returns false + } + } + + // Try numeric comparison + leftFloat, leftIsFloat := convertToFloatSafe(left) + rightFloat, rightIsFloat := convertToFloatSafe(right) + + if leftIsFloat && rightIsFloat { + return compareFloats(leftFloat, rightFloat, operator) + } + + // Check for incompatible type comparison (one is number, one is not) + if (leftIsFloat && !rightIsFloat) || (!leftIsFloat && rightIsFloat) { + // For equality comparison, allow type conversion + operatorUpper := strings.ToUpper(operator) + if operatorUpper == "==" || operatorUpper == "=" || operatorUpper == "!=" || operatorUpper == "<>" { + // String comparison + leftStr := fmt.Sprintf("%v", left) + rightStr := fmt.Sprintf("%v", right) + return compareStrings(leftStr, rightStr, operator) + } + // For size comparison, return error + return false, fmt.Errorf("cannot compare incompatible types: %T and %T", left, right) + } + + // String comparison + leftStr := fmt.Sprintf("%v", left) + rightStr := fmt.Sprintf("%v", right) + + return compareStrings(leftStr, rightStr, operator) +} + +// compareFloats compares two floating point numbers +func compareFloats(left, right float64, operator string) (bool, error) { + switch strings.ToUpper(operator) { + case "==", "=": + return left == right, nil + case "!=", "<>": + return left != right, nil + case ">": + return left > right, nil + case "<": + return left < right, nil + case ">=": + return left >= right, nil + case "<=": + return left <= right, nil + default: + return false, fmt.Errorf("unsupported numeric comparison operator: %s", operator) + } +} + +// compareStrings compares two strings +func compareStrings(left, right, operator string) (bool, error) { + switch strings.ToUpper(operator) { + case "==", "=": + return left == right, nil + case "!=", "<>": + return left != right, nil + case ">": + return left > right, nil + case "<": + return left < right, nil + case ">=": + return left >= right, nil + case "<=": + return left <= right, nil + case "LIKE": + return matchLikePattern(left, right), nil + default: + return false, fmt.Errorf("unsupported string comparison operator: %s", operator) + } +} + +// matchLikePattern implements LIKE pattern matching +func matchLikePattern(text, pattern string) bool { + // Simplified LIKE implementation, supports % and _ wildcards + // % matches any character sequence, _ matches single character + return matchPattern(text, pattern, 0, 0) +} + +// matchPattern recursively matches pattern +func matchPattern(text, pattern string, textIdx, patternIdx int) bool { + if patternIdx == len(pattern) { + return textIdx == len(text) + } + + if pattern[patternIdx] == '%' { + // % matches any character sequence + for i := textIdx; i <= len(text); i++ { + if matchPattern(text, pattern, i, patternIdx+1) { + return true + } + } + return false + } + + if textIdx == len(text) { + return false + } + + if pattern[patternIdx] == '_' || pattern[patternIdx] == text[textIdx] { + // _ matches single character or exact character match + return matchPattern(text, pattern, textIdx+1, patternIdx+1) + } + + return false +} + +// compareValuesForEquality compares two values for equality (for simple CASE expressions) +func compareValuesForEquality(left, right interface{}) bool { + if left == nil && right == nil { + return true + } + if left == nil || right == nil { + return false + } + + // Try numeric comparison + leftFloat, leftIsFloat := convertToFloatSafe(left) + rightFloat, rightIsFloat := convertToFloatSafe(right) + + if leftIsFloat && rightIsFloat { + return leftFloat == rightFloat + } + + // String comparison + leftStr := fmt.Sprintf("%v", left) + rightStr := fmt.Sprintf("%v", right) + return leftStr == rightStr +} + +// evaluateNodeWithNull evaluates node value with NULL value handling +func evaluateNodeWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { + if node == nil { + return 0, true, nil // NULL node + } + + switch node.Type { + case TypeNumber: + val, err := strconv.ParseFloat(node.Value, 64) + return val, false, err + + case TypeField: + // Handle backtick identifiers + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] + } + + // Support nested field access + var val interface{} + var found bool + if fieldpath.IsNestedField(fieldName) { + val, found = fieldpath.GetNestedField(data, fieldName) + } else { + val, found = data[fieldName] + } + + if !found { + return 0, true, nil // Field not found is treated as NULL + } + + if val == nil { + return 0, true, nil // NULL value + } + + // Try to convert to numeric value + if floatVal, ok := convertToFloatSafe(val); ok { + return floatVal, false, nil + } + + return 0, false, fmt.Errorf("field '%s' is not a number", fieldName) + + case TypeOperator: + // For comparison operators, return boolean converted to numeric + if isComparisonOperator(node.Value) { + leftValue, leftIsNull, err := evaluateNodeValueWithNull(node.Left, data) + if err != nil { + return 0, false, err + } + + rightValue, rightIsNull, err := evaluateNodeValueWithNull(node.Right, data) + if err != nil { + return 0, false, err + } + + // Handle NULL comparison + if leftIsNull || rightIsNull { + switch strings.ToUpper(node.Value) { + case "IS": + if leftIsNull && rightIsNull { + return 1, false, nil + } + return 0, false, nil + case "IS NOT": + if leftIsNull && rightIsNull { + return 0, false, nil + } + return 1, false, nil + default: + return 0, true, nil // NULL compared with any value returns NULL + } + } + + // Execute comparison + result, err := compareValues(leftValue, rightValue, node.Value) + if err != nil { + return 0, false, err + } + if result { + return 1, false, nil + } + return 0, false, nil + } + + // Arithmetic operators + leftVal, leftIsNull, err := evaluateNodeWithNull(node.Left, data) + if err != nil { + return 0, false, err + } + if leftIsNull { + return 0, true, nil + } + + rightVal, rightIsNull, err := evaluateNodeWithNull(node.Right, data) + if err != nil { + return 0, false, err + } + if rightIsNull { + return 0, true, nil + } + + switch node.Value { + case "+": + return leftVal + rightVal, false, nil + case "-": + return leftVal - rightVal, false, nil + case "*": + return leftVal * rightVal, false, nil + case "/": + if rightVal == 0 { + return 0, false, fmt.Errorf("division by zero") + } + return leftVal / rightVal, false, nil + case "%": + if rightVal == 0 { + return 0, false, fmt.Errorf("modulo by zero") + } + return math.Mod(leftVal, rightVal), false, nil + default: + return 0, false, fmt.Errorf("unknown operator: %s", node.Value) + } + + case TypeFunction: + // Function call, if any argument is NULL, result is usually NULL + val, err := evaluateNode(node, data) + return val, false, err + + case TypeCase: + // CASE expression + result, isNull, err := evaluateCaseExpressionWithNull(node, data) + if err != nil { + return 0, false, err + } + if isNull { + return 0, true, nil + } + if floatVal, ok := convertToFloatSafe(result); ok { + return floatVal, false, nil + } + return 0, false, fmt.Errorf("CASE expression result is not a number") + + case TypeParenthesis: + // Handle parenthesis expression, directly evaluate inner expression + return evaluateNodeWithNull(node.Left, data) + + default: + return 0, false, fmt.Errorf("unknown node type: %s", node.Type) + } +} + +// evaluateNodeValueWithNull evaluates the original value of a node with NULL value handling +func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) { + if node == nil { + return nil, true, nil + } + + switch node.Type { + case TypeNumber: + val, err := strconv.ParseFloat(node.Value, 64) + return val, false, err + + case TypeString: + // Handle string type, remove quotes + value := node.Value + if len(value) >= 2 && ((value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'')) { + value = value[1 : len(value)-1] // Remove quotes + } + return value, false, nil + + case TypeField: + // Handle backtick identifiers + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] + } + + // Support nested field access + var val interface{} + var found bool + if fieldpath.IsNestedField(fieldName) { + val, found = fieldpath.GetNestedField(data, fieldName) + } else { + val, found = data[fieldName] + } + + if !found { + return nil, true, nil // Field not found is treated as NULL + } + + return val, val == nil, nil + + case TypeOperator: + val, err := evaluateOperatorValue(node, data) + if err != nil { + return nil, false, err + } + return val, false, nil + + case TypeFunction: + val, err := evaluateFunctionValue(node, data) + return val, val == nil, err + + case TypeCase: + return evaluateCaseExpressionWithNull(node, data) + + case TypeParenthesis: + // Handle parenthesis expression, directly evaluate inner expression + return evaluateNodeValueWithNull(node.Left, data) + + default: + return nil, false, fmt.Errorf("unknown node type: %s", node.Type) + } +} + +// compareValuesWithNullForEquality compares two values for equality (supports NULL comparison) +func compareValuesWithNullForEquality(left interface{}, leftIsNull bool, right interface{}, rightIsNull bool) bool { + if leftIsNull && rightIsNull { + return true + } + if leftIsNull || rightIsNull { + return false + } + return compareValuesForEquality(left, right) +} + +// evaluateBoolNode evaluates the boolean value of a node +func evaluateBoolNode(node *ExprNode, data map[string]interface{}) (bool, error) { + if node == nil { + return false, fmt.Errorf("null expression node") + } + + switch node.Type { + case TypeOperator: + return evaluateBoolOperator(node, data) + case TypeFunction: + return evaluateBoolFunction(node, data) + case TypeParenthesis: + // Parenthesis node, recursively evaluate inner expression + if node.Left != nil { + return evaluateBoolNode(node.Left, data) + } + return false, fmt.Errorf("empty parenthesis expression") + case TypeField: + // Convert field value to boolean + value, err := evaluateFieldValue(node, data) + if err != nil { + // If field doesn't exist, treat as NULL, convert to false + return false, nil + } + return convertToBool(value), nil + case TypeNumber: + // Convert number to boolean (non-zero is true) + value, err := strconv.ParseFloat(node.Value, 64) + if err != nil { + return false, err + } + return value != 0, nil + case TypeString: + // Convert string to boolean (non-empty is true) + value := node.Value + if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { + value = value[1 : len(value)-1] // Remove quotes + } + return value != "", nil + default: + return false, fmt.Errorf("unsupported node type for boolean evaluation: %s", node.Type) + } +} + +// evaluateBoolOperator evaluates boolean operators +func evaluateBoolOperator(node *ExprNode, data map[string]interface{}) (bool, error) { + operator := strings.ToUpper(node.Value) + + switch operator { + case "AND", "&&": + left, err := evaluateBoolNode(node.Left, data) + if err != nil { + return false, err + } + if !left { + return false, nil // Short-circuit evaluation + } + return evaluateBoolNode(node.Right, data) + + case "OR", "||": + left, err := evaluateBoolNode(node.Left, data) + if err != nil { + return false, err + } + if left { + return true, nil // Short-circuit evaluation + } + return evaluateBoolNode(node.Right, data) + + case "NOT", "!": + // NOT operator may use Left or Right node + var operand *ExprNode + if node.Left != nil { + operand = node.Left + } else if node.Right != nil { + operand = node.Right + } else { + return false, fmt.Errorf("NOT operator requires an operand") + } + result, err := evaluateBoolNode(operand, data) + if err != nil { + return false, err + } + return !result, nil + + case "IS", "IS NOT": + // IS and IS NOT operators (including IS NULL and IS NOT NULL) + result, err := evaluateIsOperator(node, data) + if err != nil { + return false, err + } + return convertToBool(result), nil + + case "==", "=", "!=", "<>", ">", "<", ">=", "<=", "LIKE": + // Comparison operators + leftValue, err := evaluateNodeValue(node.Left, data) + if err != nil { + return false, err + } + rightValue, err := evaluateNodeValue(node.Right, data) + if err != nil { + return false, err + } + return compareValues(leftValue, rightValue, operator) + + default: + return false, fmt.Errorf("unsupported boolean operator: %s", operator) + } +} + +// evaluateBoolFunction evaluates boolean functions +func evaluateBoolFunction(node *ExprNode, data map[string]interface{}) (bool, error) { + // Call function and convert result to boolean + result, err := evaluateFunctionValue(node, data) + if err != nil { + return false, err + } + return convertToBool(result), nil +} + +// evaluateIsOperator handles IS and IS NOT operators (mainly IS NULL and IS NOT NULL) +func evaluateIsOperator(node *ExprNode, data map[string]interface{}) (interface{}, error) { + if node.Right == nil { + return nil, fmt.Errorf("IS operator requires a right operand") + } + + operator := strings.ToUpper(node.Value) + + // Check if right side is NULL + if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" { + // Get left value using NULL-supporting method + _, leftIsNull, err := evaluateNodeValueWithNull(node.Left, data) + if err != nil { + // If field doesn't exist, consider it NULL + leftIsNull = true + } + + if operator == "IS" { + // IS NULL comparison + return leftIsNull, nil + } else if operator == "IS NOT" { + // IS NOT NULL comparison + return !leftIsNull, nil + } + } + + // Other IS comparisons + leftValue, err := evaluateNodeValue(node.Left, data) + if err != nil { + return nil, err + } + + rightValue, err := evaluateNodeValue(node.Right, data) + if err != nil { + return nil, err + } + + if operator == "IS" { + return compareValuesForEquality(leftValue, rightValue), nil + } else if operator == "IS NOT" { + return !compareValuesForEquality(leftValue, rightValue), nil + } + + return nil, fmt.Errorf("unsupported IS operator: %s", operator) +} diff --git a/expr/evaluator_test.go b/expr/evaluator_test.go new file mode 100644 index 0000000..2a6ffff --- /dev/null +++ b/expr/evaluator_test.go @@ -0,0 +1,1298 @@ +package expr + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEvaluateNode 测试节点求值功能 +func TestEvaluateNode(t *testing.T) { + data := map[string]interface{}{ + "a": 10.0, + "b": 5.0, + "c": 2.0, + "name": "test", + "flag": true, + } + + tests := []struct { + name string + node *ExprNode + expected float64 + wantErr bool + }{ + { + "数字节点", + &ExprNode{Type: TypeNumber, Value: "123"}, + 123.0, + false, + }, + { + "字段节点", + &ExprNode{Type: TypeField, Value: "a"}, + 10.0, + false, + }, + { + "加法运算", + &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + 15.0, + false, + }, + { + "乘法运算", + &ExprNode{ + Type: TypeOperator, + Value: "*", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + 50.0, + false, + }, + { + "除法运算", + &ExprNode{ + Type: TypeOperator, + Value: "/", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + 2.0, + false, + }, + { + "幂运算", + &ExprNode{ + Type: TypeOperator, + Value: "^", + Left: &ExprNode{Type: TypeField, Value: "c"}, + Right: &ExprNode{Type: TypeNumber, Value: "3"}, + }, + 8.0, + false, + }, + { + "取模运算", + &ExprNode{ + Type: TypeOperator, + Value: "%", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "3"}, + }, + 1.0, + false, + }, + { + "函数调用", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "-5"}}, + }, + 5.0, + false, + }, + { + "括号表达式", + &ExprNode{ + Type: TypeParenthesis, + Left: &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + }, + 15.0, + false, + }, + // 错误情况 + {"不存在的字段", &ExprNode{Type: TypeField, Value: "unknown"}, 0, true}, + {"除零错误", &ExprNode{ + Type: TypeOperator, + Value: "/", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, 0, true}, + {"无效的运算符", &ExprNode{ + Type: TypeOperator, + Value: "@", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "2"}, + }, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateNode(tt.node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "求值不应该失败") + assert.Equal(t, tt.expected, result, "求值结果应该正确") + } + }) + } +} + +// TestEvaluateNodeWithNull 测试支持NULL值的节点求值 +func TestEvaluateNodeWithNull(t *testing.T) { + data := map[string]interface{}{ + "a": 10.0, + "b": nil, + "c": 5.0, + "flag": true, + "nested.field": 20.0, + } + + tests := []struct { + name string + node *ExprNode + expected float64 + expectedNull bool + wantErr bool + }{ + { + "空节点", + nil, + 0, + true, + false, + }, + { + "数字节点", + &ExprNode{Type: TypeNumber, Value: "123"}, + 123.0, + false, + false, + }, + { + "字段节点(存在)", + &ExprNode{Type: TypeField, Value: "a"}, + 10.0, + false, + false, + }, + { + "字段节点(NULL值)", + &ExprNode{Type: TypeField, Value: "b"}, + 0, + true, + false, + }, + { + "字段节点(不存在)", + &ExprNode{Type: TypeField, Value: "unknown"}, + 0, + true, + false, + }, + { + "字段节点(反引号)", + &ExprNode{Type: TypeField, Value: "`a`"}, + 10.0, + false, + false, + }, + { + "嵌套字段", + &ExprNode{Type: TypeField, Value: "nested.field"}, + 0, + true, + false, + }, + { + "布尔字段", + &ExprNode{Type: TypeField, Value: "flag"}, + 1.0, + false, + false, + }, + { + "加法运算(正常)", + &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "5"}, + }, + 15.0, + false, + false, + }, + { + "加法运算(左NULL)", + &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeNumber, Value: "5"}, + }, + 0, + true, + false, + }, + { + "加法运算(右NULL)", + &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + 0, + true, + false, + }, + { + "IS NULL比较(真)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + 1, + false, + false, + }, + { + "IS NULL比较(假)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + 0, + false, + false, + }, + { + "括号表达式", + &ExprNode{ + Type: TypeParenthesis, + Left: &ExprNode{Type: TypeField, Value: "a"}, + }, + 10.0, + false, + false, + }, + { + "函数调用", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "-5"}}, + }, + 5.0, + false, + false, + }, + // 错误情况 + { + "无效数字", + &ExprNode{Type: TypeNumber, Value: "invalid"}, + 0, + false, + true, + }, + { + "数字字段", + &ExprNode{Type: TypeField, Value: "c"}, + 5.0, + false, + false, + }, + { + "除零错误", + &ExprNode{ + Type: TypeOperator, + Value: "/", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + 0, + false, + true, + }, + { + "未知运算符", + &ExprNode{ + Type: TypeOperator, + Value: "@", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "2"}, + }, + 0, + false, + true, + }, + { + "未知节点类型", + &ExprNode{Type: "unknown"}, + 0, + false, + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, isNull, err := evaluateNodeWithNull(tt.node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "求值不应该失败") + assert.Equal(t, tt.expected, result, "求值结果应该正确") + assert.Equal(t, tt.expectedNull, isNull, "NULL状态应该正确") + } + }) + } +} + +// TestEvaluateIsOperator 测试IS运算符 +func TestEvaluateIsOperator(t *testing.T) { + data := map[string]interface{}{ + "a": 10.0, + "b": nil, + "flag": true, + "name": "test", + } + + tests := []struct { + name string + node *ExprNode + expected interface{} + wantErr bool + }{ + { + "IS NULL(真)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + true, + false, + }, + { + "IS NULL(假)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + false, + false, + }, + { + "IS NOT NULL(真)", + &ExprNode{ + Type: TypeOperator, + Value: "IS NOT", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + true, + false, + }, + { + "IS NOT NULL(假)", + &ExprNode{ + Type: TypeOperator, + Value: "IS NOT", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + false, + false, + }, + { + "IS 相等比较(真)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "10"}, + }, + true, + false, + }, + { + "IS 相等比较(假)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "5"}, + }, + false, + false, + }, + { + "IS NOT 不等比较(真)", + &ExprNode{ + Type: TypeOperator, + Value: "IS NOT", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "5"}, + }, + true, + false, + }, + { + "IS NOT 不等比较(假)", + &ExprNode{ + Type: TypeOperator, + Value: "IS NOT", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "10"}, + }, + false, + false, + }, + { + "不存在字段IS NULL", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "nonexistent"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + true, + false, + }, + // 错误情况 + { + "缺少右操作数", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: nil, + }, + nil, + true, + }, + { + "不支持的IS运算符", + &ExprNode{ + Type: TypeOperator, + Value: "IS UNKNOWN", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + nil, + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateIsOperator(tt.node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "IS运算符求值不应该失败") + assert.Equal(t, tt.expected, result, "IS运算符求值结果应该正确") + } + }) + } +} + +// TestEvaluateBoolFunction 测试布尔函数求值 +func TestEvaluateBoolFunction(t *testing.T) { + data := map[string]interface{}{} + + tests := []struct { + name string + node *ExprNode + expected bool + wantErr bool + }{ + { + "ABS函数(非零)", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "-5"}}, + }, + true, + false, + }, + { + "ABS函数(零)", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "0"}}, + }, + false, + false, + }, + // 错误情况 + { + "未知函数", + &ExprNode{ + Type: TypeFunction, + Value: "unknown_func", + Args: []*ExprNode{{Type: TypeNumber, Value: "5"}}, + }, + false, + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateBoolFunction(tt.node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "布尔函数求值不应该失败") + assert.Equal(t, tt.expected, result, "布尔函数求值结果应该正确") + } + }) + } +} + +// TestEvaluateFieldNode 测试字段节点求值 +func TestEvaluateFieldNode(t *testing.T) { + data := map[string]interface{}{ + "int_field": 42, + "float_field": 3.14, + "string_field": "hello", + "bool_field": true, + "nil_field": nil, + } + + tests := []struct { + name string + fieldName string + expected float64 + wantErr bool + }{ + {"整数字段", "int_field", 42.0, false}, + {"浮点数字段", "float_field", 3.14, false}, + {"字符串字段(数字)", "string_field", 0, true}, // 字符串"hello"无法转换为数字 + {"布尔字段", "bool_field", 1.0, false}, + {"nil字段", "nil_field", 0, true}, + {"不存在的字段", "unknown_field", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{Type: TypeField, Value: tt.fieldName} + result, err := evaluateFieldNode(node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "字段求值不应该失败") + assert.Equal(t, tt.expected, result, "字段求值结果应该正确") + } + }) + } +} + +// TestEvaluateOperatorNode 测试运算符节点求值 +func TestEvaluateOperatorNode(t *testing.T) { + data := map[string]interface{}{ + "a": 10.0, + "b": 3.0, + "c": 0.0, + } + + tests := []struct { + name string + operator string + left *ExprNode + right *ExprNode + expected float64 + wantErr bool + }{ + { + "加法", + "+", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "b"}, + 13.0, + false, + }, + { + "减法", + "-", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "b"}, + 7.0, + false, + }, + { + "乘法", + "*", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "b"}, + 30.0, + false, + }, + { + "除法", + "/", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "b"}, + 10.0 / 3.0, + false, + }, + { + "取模", + "%", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "b"}, + 1.0, + false, + }, + { + "幂运算", + "^", + &ExprNode{Type: TypeField, Value: "b"}, + &ExprNode{Type: TypeNumber, Value: "2"}, + 9.0, + false, + }, + // 错误情况 + { + "除零", + "/", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "c"}, + 0, + true, + }, + { + "模零", + "%", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "c"}, + 0, + true, + }, + { + "无效运算符", + "@", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeField, Value: "b"}, + 0, + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{ + Type: TypeOperator, + Value: tt.operator, + Left: tt.left, + Right: tt.right, + } + result, err := evaluateOperatorNode(node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "运算符求值不应该失败") + assert.InDelta(t, tt.expected, result, 1e-10, "运算符求值结果应该正确") + } + }) + } +} + +// TestEvaluateFunctionNode 测试函数节点求值 +func TestEvaluateFunctionNode(t *testing.T) { + data := map[string]interface{}{} + + tests := []struct { + name string + funcName string + args []*ExprNode + expected float64 + wantErr bool + }{ + { + "ABS函数", + "abs", + []*ExprNode{{Type: TypeNumber, Value: "-5"}}, + 5.0, + false, + }, + { + "SQRT函数", + "sqrt", + []*ExprNode{{Type: TypeNumber, Value: "16"}}, + 4.0, + false, + }, + { + "POW函数", + "pow", + []*ExprNode{ + {Type: TypeNumber, Value: "2"}, + {Type: TypeNumber, Value: "3"}, + }, + 8.0, + false, + }, + { + "MAX函数", + "max", + []*ExprNode{ + {Type: TypeNumber, Value: "5"}, + {Type: TypeNumber, Value: "3"}, + {Type: TypeNumber, Value: "8"}, + }, + 8.0, + false, + }, + { + "MIN函数", + "min", + []*ExprNode{ + {Type: TypeNumber, Value: "5"}, + {Type: TypeNumber, Value: "3"}, + {Type: TypeNumber, Value: "8"}, + }, + 3.0, + false, + }, + { + "SUM函数", + "sum", + []*ExprNode{ + {Type: TypeNumber, Value: "1"}, + {Type: TypeNumber, Value: "2"}, + {Type: TypeNumber, Value: "3"}, + }, + 6.0, + false, + }, + { + "AVG函数", + "avg", + []*ExprNode{ + {Type: TypeNumber, Value: "2"}, + {Type: TypeNumber, Value: "4"}, + {Type: TypeNumber, Value: "6"}, + }, + 4.0, + false, + }, + { + "COUNT函数", + "count", + []*ExprNode{ + {Type: TypeNumber, Value: "1"}, + {Type: TypeNumber, Value: "2"}, + {Type: TypeNumber, Value: "3"}, + }, + 3.0, + false, + }, + { + "ROUND函数", + "round", + []*ExprNode{{Type: TypeNumber, Value: "3.7"}}, + 4.0, + false, + }, + { + "FLOOR函数", + "floor", + []*ExprNode{{Type: TypeNumber, Value: "3.7"}}, + 3.0, + false, + }, + { + "CEIL函数", + "ceil", + []*ExprNode{{Type: TypeNumber, Value: "3.2"}}, + 4.0, + false, + }, + // 三角函数 + { + "SIN函数", + "sin", + []*ExprNode{{Type: TypeNumber, Value: "0"}}, + 0.0, + false, + }, + { + "COS函数", + "cos", + []*ExprNode{{Type: TypeNumber, Value: "0"}}, + 1.0, + false, + }, + // 对数函数 + { + "LOG函数", + "log", + []*ExprNode{{Type: TypeNumber, Value: "10"}}, + math.Log10(10), + false, + }, + { + "LN函数", + "ln", + []*ExprNode{{Type: TypeNumber, Value: "1"}}, + 0.0, + false, + }, + { + "EXP函数", + "exp", + []*ExprNode{{Type: TypeNumber, Value: "0"}}, + 1.0, + false, + }, + // 错误情况 + {"未知函数", "unknown", []*ExprNode{{Type: TypeNumber, Value: "1"}}, 0, true}, + {"参数数量错误", "abs", []*ExprNode{}, 0, true}, + {"SQRT负数", "sqrt", []*ExprNode{{Type: TypeNumber, Value: "-1"}}, 0, true}, + {"LOG零或负数", "log", []*ExprNode{{Type: TypeNumber, Value: "0"}}, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{ + Type: TypeFunction, + Value: tt.funcName, + Args: tt.args, + } + result, err := evaluateFunctionNode(node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "函数求值不应该失败") + assert.InDelta(t, tt.expected, result, 1e-10, "函数求值结果应该正确") + } + }) + } +} + +// TestEvaluateNodeValue 测试节点值求值(支持NULL) +func TestEvaluateNodeValue(t *testing.T) { + data := map[string]interface{}{ + "a": 10.0, + "b": nil, + "name": "test", + "empty": "", + "zero": 0, + "negative": -5, + } + + tests := []struct { + name string + node *ExprNode + expected interface{} + wantErr bool + }{ + { + "数字节点", + &ExprNode{Type: TypeNumber, Value: "123"}, + 123.0, + false, + }, + { + "字符串节点", + &ExprNode{Type: TypeString, Value: "'hello'"}, + "hello", + false, + }, + { + "字段节点(数字)", + &ExprNode{Type: TypeField, Value: "a"}, + 10.0, + false, + }, + { + "字段节点(NULL)", + &ExprNode{Type: TypeField, Value: "b"}, + nil, + false, + }, + { + "字段节点(字符串)", + &ExprNode{Type: TypeField, Value: "name"}, + "test", + false, + }, + { + "字段节点(空字符串)", + &ExprNode{Type: TypeField, Value: "empty"}, + "", + false, + }, + { + "字段节点(零值)", + &ExprNode{Type: TypeField, Value: "zero"}, + 0, + false, + }, + { + "字段节点(负数)", + &ExprNode{Type: TypeField, Value: "negative"}, + -5, + false, + }, + { + "等于比较(相等)", + &ExprNode{ + Type: TypeOperator, + Value: "==", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "10"}, + }, + true, + false, + }, + { + "等于比较(不相等)", + &ExprNode{ + Type: TypeOperator, + Value: "==", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "5"}, + }, + false, + false, + }, + { + "IS NULL比较(真)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + true, + false, + }, + { + "IS NULL比较(假)", + &ExprNode{ + Type: TypeOperator, + Value: "IS", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "NULL"}, + }, + false, + false, + }, + // 错误情况 + {"不存在的字段", &ExprNode{Type: TypeField, Value: "unknown"}, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateNodeValue(tt.node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "节点值求值不应该失败") + assert.Equal(t, tt.expected, result, "节点值求值结果应该正确") + } + }) + } +} + +// TestCompareValues 测试值比较功能 +func TestCompareValues(t *testing.T) { + tests := []struct { + name string + operator string + left interface{} + right interface{} + expected bool + wantErr bool + }{ + // 数字比较 + {"数字相等", "==", 5.0, 5.0, true, false}, + {"数字不等", "!=", 5.0, 3.0, true, false}, + {"数字大于", ">", 5.0, 3.0, true, false}, + {"数字小于", "<", 3.0, 5.0, true, false}, + {"数字大于等于", ">=", 5.0, 5.0, true, false}, + {"数字小于等于", "<=", 3.0, 5.0, true, false}, + + // 字符串比较 + {"字符串相等", "==", "hello", "hello", true, false}, + {"字符串不等", "!=", "hello", "world", true, false}, + {"字符串大于", ">", "world", "hello", true, false}, + {"字符串小于", "<", "hello", "world", true, false}, + + // LIKE模式匹配 + {"LIKE匹配", "LIKE", "hello", "h%", true, false}, + {"LIKE不匹配", "LIKE", "hello", "w%", false, false}, + {"LIKE通配符", "LIKE", "hello", "h_llo", true, false}, + {"LIKE完全匹配", "LIKE", "hello", "hello", true, false}, + + // 混合类型比较 + {"数字与字符串", "==", 5.0, "5", true, false}, + {"布尔值比较", "==", true, true, true, false}, + {"布尔值与数字", "==", true, 1.0, true, false}, + {"布尔值与数字(假)", "==", false, 0.0, true, false}, + + // 错误情况 + {"无效运算符", "@", 5.0, 3.0, false, true}, + {"不兼容类型", ">", "hello", 5.0, false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := compareValues(tt.left, tt.right, tt.operator) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "值比较不应该失败") + assert.Equal(t, tt.expected, result, "值比较结果应该正确") + } + }) + } +} + +// TestMatchLikePattern 测试LIKE模式匹配 +func TestMatchLikePattern(t *testing.T) { + tests := []struct { + name string + text string + pattern string + expected bool + }{ + {"完全匹配", "hello", "hello", true}, + {"百分号通配符开头", "hello", "%llo", true}, + {"百分号通配符结尾", "hello", "hel%", true}, + {"百分号通配符中间", "hello", "h%o", true}, + {"百分号通配符全部", "hello", "%", true}, + {"下划线通配符", "hello", "h_llo", true}, + {"下划线通配符多个", "hello", "h___o", true}, + {"混合通配符", "hello world", "h%w_rld", true}, + {"不匹配", "hello", "world", false}, + {"长度不匹配", "hello", "h_", false}, + {"空字符串", "", "", true}, + {"空模式", "hello", "", false}, + {"空文本匹配百分号", "", "%", true}, + {"大小写敏感", "Hello", "hello", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchLikePattern(tt.text, tt.pattern) + assert.Equal(t, tt.expected, result, "LIKE模式匹配结果应该正确") + }) + } +} + +// TestEvaluateBoolNode 测试布尔节点求值 +func TestEvaluateBoolNode(t *testing.T) { + data := map[string]interface{}{ + "a": 10.0, + "b": 5.0, + "flag": true, + "name": "test", + "zero": 0, + "empty": "", + "null_field": nil, + } + + tests := []struct { + name string + node *ExprNode + expected bool + wantErr bool + }{ + { + "数字比较(真)", + &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + true, + false, + }, + { + "数字比较(假)", + &ExprNode{ + Type: TypeOperator, + Value: "<", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, + false, + false, + }, + { + "AND运算(真)", + &ExprNode{ + Type: TypeOperator, + Value: "AND", + Left: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + Right: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + true, + false, + }, + { + "OR运算(真)", + &ExprNode{ + Type: TypeOperator, + Value: "OR", + Left: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "100"}, + }, + Right: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + true, + false, + }, + { + "NOT运算", + &ExprNode{ + Type: TypeOperator, + Value: "NOT", + Left: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "100"}, + }, + }, + true, + false, + }, + { + "字符串LIKE匹配", + &ExprNode{ + Type: TypeOperator, + Value: "LIKE", + Left: &ExprNode{Type: TypeField, Value: "name"}, + Right: &ExprNode{Type: TypeString, Value: "'t%'"}, + }, + true, + false, + }, + // 新增测试用例以提高覆盖率 + { + "字段节点(真值)", + &ExprNode{Type: TypeField, Value: "flag"}, + true, + false, + }, + { + "字段节点(假值)", + &ExprNode{Type: TypeField, Value: "zero"}, + false, + false, + }, + { + "字段节点(不存在字段)", + &ExprNode{Type: TypeField, Value: "nonexistent"}, + false, + false, + }, + { + "数字节点(非零)", + &ExprNode{Type: TypeNumber, Value: "5"}, + true, + false, + }, + { + "数字节点(零)", + &ExprNode{Type: TypeNumber, Value: "0"}, + false, + false, + }, + { + "字符串节点(非空)", + &ExprNode{Type: TypeString, Value: "'hello'"}, + true, + false, + }, + { + "字符串节点(空)", + &ExprNode{Type: TypeString, Value: "''"}, + false, + false, + }, + { + "括号表达式", + &ExprNode{ + Type: TypeParenthesis, + Left: &ExprNode{Type: TypeField, Value: "flag"}, + }, + true, + false, + }, + { + "函数节点", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "-5"}}, + }, + true, + false, + }, + // 错误情况 + {"非布尔运算符", &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeField, Value: "b"}, + }, false, true}, + {"空节点", nil, false, true}, + {"空括号表达式", &ExprNode{Type: TypeParenthesis}, false, true}, + {"无效数字", &ExprNode{Type: TypeNumber, Value: "invalid"}, false, true}, + {"不支持的节点类型", &ExprNode{Type: "unknown"}, false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := evaluateBoolNode(tt.node, data) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "布尔节点求值不应该失败") + assert.Equal(t, tt.expected, result, "布尔节点求值结果应该正确") + } + }) + } +} diff --git a/expr/expression.go b/expr/expression.go index 717817b..af2150b 100644 --- a/expr/expression.go +++ b/expr/expression.go @@ -2,15 +2,13 @@ package expr import ( "fmt" - "math" "strconv" "strings" "github.com/rulego/streamsql/functions" - "github.com/rulego/streamsql/utils/fieldpath" ) -// Expression types +// Expression types - expression type constants const ( TypeNumber = "number" // Number constant TypeField = "field" // Field reference @@ -21,42 +19,36 @@ const ( TypeString = "string" // String constant ) -// Operator precedence -var operatorPrecedence = map[string]int{ - "OR": 1, - "AND": 2, - "==": 3, "=": 3, "!=": 3, "<>": 3, - ">": 4, "<": 4, ">=": 4, "<=": 4, "LIKE": 4, "IS": 4, - "+": 5, "-": 5, - "*": 6, "/": 6, "%": 6, - "^": 7, // Power operation -} - -// WhenClause represents a WHEN clause in CASE expression +// WhenClause represents a WHEN clause in a CASE expression type WhenClause struct { Condition *ExprNode // WHEN condition Result *ExprNode // THEN result } +// CaseExpression represents the structure of a CASE expression +type CaseExpression struct { + Value *ExprNode // Expression after CASE (simple CASE) + WhenClauses []WhenClause // List of WHEN clauses + ElseResult *ExprNode // ELSE expression +} + // ExprNode represents an expression node type ExprNode struct { - Type string - Value string - Left *ExprNode - Right *ExprNode - Args []*ExprNode // Arguments for function calls + Type string // Node type + Value string // Node value + Left *ExprNode // Left child node + Right *ExprNode // Right child node + Args []*ExprNode // Function argument list // Fields specific to CASE expressions - CaseExpr *ExprNode // Expression after CASE (simple CASE) - WhenClauses []WhenClause // List of WHEN clauses - ElseExpr *ExprNode // ELSE expression + CaseExpr *CaseExpression // CASE expression structure } // Expression represents a computable expression type Expression struct { - Root *ExprNode - useExprLang bool // Whether to use expr-lang/expr - exprLangExpression string // expr-lang expression string + Root *ExprNode // Expression root node + useExprLang bool // Whether to use expr-lang/expr + exprLangExpression string // expr-lang expression string } // NewExpression creates a new expression @@ -95,13 +87,13 @@ func NewExpression(exprStr string) (*Expression, error) { // validateBasicSyntax performs basic syntax validation func validateBasicSyntax(exprStr string) error { - // Check empty expression + // Check for empty expression trimmed := strings.TrimSpace(exprStr) if trimmed == "" { return fmt.Errorf("empty expression") } - // 检查不匹配的括号 + // Check for mismatched parentheses parenthesesCount := 0 for _, ch := range trimmed { if ch == '(' { @@ -117,132 +109,31 @@ func validateBasicSyntax(exprStr string) error { return fmt.Errorf("mismatched parentheses") } - // 检查无效字符 - for i, ch := range trimmed { - // 允许的字符:字母、数字、运算符、括号、点、下划线、空格、引号 - if !isValidChar(ch) { - return fmt.Errorf("invalid character '%c' at position %d", ch, i) + // Check for consecutive operators + operators := []string{"+", "-", "*", "/", "%", "^", "=", "!=", "<>", ">", "<", ">=", "<="} + for _, op1 := range operators { + for _, op2 := range operators { + if strings.Contains(trimmed, " "+op1+" "+op2+" ") { + return fmt.Errorf("consecutive operators") + } } } - // 检查表达式开头和结尾的运算符 - if err := checkExpressionStartEnd(trimmed); err != nil { - return err - } - - // 检查连续运算符 - if err := checkConsecutiveOperators(trimmed); err != nil { - return err - } - - return nil -} - -// checkExpressionStartEnd checks if expression starts or ends with an operator -func checkExpressionStartEnd(expr string) error { - operators := []string{"+", "*", "/", "%", "^", "==", "!=", ">=", "<=", ">", "<"} - - // 检查表达式开头(允许负号,因为它是合法的负数表示) + // Check if expression starts or ends with operator for _, op := range operators { - if strings.HasPrefix(expr, op) { + if strings.HasPrefix(trimmed, op+" ") { return fmt.Errorf("expression cannot start with operator") } - } - - // 检查表达式结尾 - for _, op := range operators { - if strings.HasSuffix(expr, op) { + if strings.HasSuffix(trimmed, " "+op) { return fmt.Errorf("expression cannot end with operator") } } - return nil -} - -// checkConsecutiveOperators checks for consecutive operators -func checkConsecutiveOperators(expr string) error { - // Simplified consecutive operator check: look for obvious double operator patterns - // But allow comparison operators followed by negative numbers - operators := []string{"+", "-", "*", "/", "%", "^", "==", "!=", ">=", "<=", ">", "<"} - comparisonOps := []string{"==", "!=", ">=", "<=", ">", "<"} - - for i := 0; i < len(expr)-1; i++ { - // 跳过空白字符 - if expr[i] == ' ' || expr[i] == '\t' { - continue - } - - // 检查当前位置是否是运算符 - isCurrentOp := false - currentOpLen := 0 - currentOp := "" - for _, op := range operators { - if i+len(op) <= len(expr) && expr[i:i+len(op)] == op { - isCurrentOp = true - currentOpLen = len(op) - currentOp = op - break - } - } - - if isCurrentOp { - // 查找下一个非空白字符 - nextPos := i + currentOpLen - for nextPos < len(expr) && (expr[nextPos] == ' ' || expr[nextPos] == '\t') { - nextPos++ - } - - // 检查下一个字符是否也是运算符 - if nextPos < len(expr) { - // 特殊处理:如果当前是比较运算符,下一个是负号,且负号后跟数字,则允许 - isCurrentComparison := false - for _, compOp := range comparisonOps { - if currentOp == compOp { - isCurrentComparison = true - break - } - } - - // 检查是否是负数的情况 - if isCurrentComparison && nextPos < len(expr) && expr[nextPos] == '-' { - // 检查负号后是否跟数字 - digitPos := nextPos + 1 - for digitPos < len(expr) && (expr[digitPos] == ' ' || expr[digitPos] == '\t') { - digitPos++ - } - if digitPos < len(expr) && expr[digitPos] >= '0' && expr[digitPos] <= '9' { - // 这是比较运算符后跟负数,允许通过 - i = nextPos // 跳过到负号位置 - continue - } - } - - // 特殊处理:如果当前是幂运算符(^),下一个是负号,且负号后跟数字,则允许 - if currentOp == "^" && nextPos < len(expr) && expr[nextPos] == '-' { - // 检查负号后是否跟数字 - digitPos := nextPos + 1 - for digitPos < len(expr) && (expr[digitPos] == ' ' || expr[digitPos] == '\t') { - digitPos++ - } - if digitPos < len(expr) && expr[digitPos] >= '0' && expr[digitPos] <= '9' { - // 这是幂运算符后跟负数,允许通过 - i = nextPos // 跳过到负号位置 - continue - } - } - - // 检查其他连续运算符 - for _, op := range operators { - if nextPos+len(op) <= len(expr) && expr[nextPos:nextPos+len(op)] == op { - // 如果不是允许的负数情况,则报错 - return fmt.Errorf("consecutive operators found: '%s' followed by '%s'", - currentOp, op) - } - } - } - - // 跳过当前运算符 - i += currentOpLen - 1 + // Check for invalid characters + for i, ch := range trimmed { + // Allowed characters: letters, numbers, operators, parentheses, dots, underscores, spaces, quotes + if !isValidChar(ch) { + return fmt.Errorf("invalid character '%c' at position %d", ch, i) } } @@ -251,7 +142,7 @@ func checkConsecutiveOperators(expr string) error { // isValidChar checks if a character is valid func isValidChar(ch rune) bool { - // Letters and digits + // Letters and numbers if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') { return true } @@ -269,7 +160,9 @@ func isValidChar(ch rune) bool { return true case '.', '_': // Dot and underscore return true - case '$': // Dollar sign (for JSON paths etc.) + case '$': // Dollar sign (for JSON paths, etc.) + return true + case '`': // Backtick (for identifiers) return true default: return false @@ -403,2045 +296,61 @@ func collectFields(node *ExprNode, fields map[string]bool) { } if node.Type == TypeField { - fields[node.Value] = true - } - - // Handle field collection for CASE expressions - if node.Type == TypeCase { - // Collect fields from CASE expression itself - if node.CaseExpr != nil { - collectFields(node.CaseExpr, fields) - } - - // Collect fields from all WHEN clauses - for _, whenClause := range node.WhenClauses { - collectFields(whenClause.Condition, fields) - collectFields(whenClause.Result, fields) - } - - // Collect fields from ELSE expression - if node.ElseExpr != nil { - collectFields(node.ElseExpr, fields) - } - - return + // Remove backticks (if present) + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] + } + fields[fieldName] = true } + // Recursively collect fields from child nodes collectFields(node.Left, fields) collectFields(node.Right, fields) + // Collect fields from function arguments for _, arg := range node.Args { collectFields(arg, fields) } -} -// evaluateNode calculates the value of a node -func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) { - if node == nil { - return 0, fmt.Errorf("null expression node") - } - - switch node.Type { - case TypeNumber: - return strconv.ParseFloat(node.Value, 64) - - case TypeString: - // Handle string type, remove quotes and try to convert to number - // If conversion fails, return error (since this function returns float64) - value := node.Value - if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { - value = value[1 : len(value)-1] // Remove quotes - } - - // Try to convert to number - if f, err := strconv.ParseFloat(value, 64); err == nil { - return f, nil - } - - // For string comparison, we need to return a hash value or error - // Simplified handling here, convert string to its length (as temporary solution) - return float64(len(value)), nil - - case TypeField: - // Handle backtick identifiers, remove backticks - fieldName := node.Value - if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { - fieldName = fieldName[1 : len(fieldName)-1] // Remove backticks - } - - // Support nested field access - if fieldpath.IsNestedField(fieldName) { - if val, found := fieldpath.GetNestedField(data, fieldName); found { - // Try to convert to float64 - if floatVal, err := convertToFloat(val); err == nil { - return floatVal, nil - } - // If cannot convert to number, return error - return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val) - } - } else { - // Original simple field access - if val, found := data[fieldName]; found { - // Try to convert to float64 - if floatVal, err := convertToFloat(val); err == nil { - return floatVal, nil - } - // If cannot convert to number, return error - return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val) - } - } - return 0, fmt.Errorf("field '%s' not found", fieldName) - - case TypeOperator: - // Check if this is a comparison operator - if isComparisonOperator(node.Value) { - // For comparison operators, use evaluateNodeValue to get original types - leftValue, err := evaluateNodeValue(node.Left, data) - if err != nil { - return 0, err - } - - rightValue, err := evaluateNodeValue(node.Right, data) - if err != nil { - return 0, err - } - - // Perform comparison and convert boolean to number - result, err := compareValues(leftValue, rightValue, node.Value) - if err != nil { - return 0, err - } - if result { - return 1.0, nil - } - return 0.0, nil - } - - // For arithmetic operators, calculate numeric values - left, err := evaluateNode(node.Left, data) - if err != nil { - return 0, err - } - - right, err := evaluateNode(node.Right, data) - if err != nil { - return 0, err - } - - // Perform operation - switch node.Value { - case "+": - return left + right, nil - case "-": - return left - right, nil - case "*": - return left * right, nil - case "/": - if right == 0 { - return 0, fmt.Errorf("division by zero") - } - return left / right, nil - case "%": - if right == 0 { - return 0, fmt.Errorf("modulo by zero") - } - return math.Mod(left, right), nil - case "^": - return math.Pow(left, right), nil - default: - return 0, fmt.Errorf("unknown operator: %s", node.Value) - } - - case TypeFunction: - // First check if it's a function in the new function registration system - fn, exists := functions.Get(node.Value) - if exists { - // Calculate all arguments but keep original types - args := make([]interface{}, len(node.Args)) - for i, arg := range node.Args { - // Use evaluateNodeValue to get original type values - val, err := evaluateNodeValue(arg, data) - if err != nil { - return 0, err - } - args[i] = val - } - - // Create function execution context - ctx := &functions.FunctionContext{ - Data: data, - } - - // Execute function - result, err := fn.Execute(ctx, args) - if err != nil { - return 0, err - } - - // Convert result to float64 - switch r := result.(type) { - case float64: - return r, nil - case float32: - return float64(r), nil - case int: - return float64(r), nil - case int32: - return float64(r), nil - case int64: - return float64(r), nil - case string: - // For string results, try to convert to number, if failed return string length - if f, err := strconv.ParseFloat(r, 64); err == nil { - return f, nil - } - return float64(len(r)), nil - case bool: - // Boolean conversion: true=1, false=0 - if r { - return 1.0, nil - } - return 0.0, nil - default: - return 0, fmt.Errorf("function %s returned unsupported type for numeric conversion: %T", node.Value, result) - } - } - - // Fall back to built-in function handling (maintain backward compatibility) - return evaluateBuiltinFunction(node, data) - - case TypeCase: - // Handle CASE expression - return evaluateCaseExpression(node, data) - } - - return 0, fmt.Errorf("unknown node type: %s", node.Type) -} - -// evaluateBuiltinFunction handles built-in functions (backward compatibility) -func evaluateBuiltinFunction(node *ExprNode, data map[string]interface{}) (float64, error) { - switch strings.ToLower(node.Value) { - case "abs": - if len(node.Args) != 1 { - return 0, fmt.Errorf("abs function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Abs(arg), nil - - case "sqrt": - if len(node.Args) != 1 { - return 0, fmt.Errorf("sqrt function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - if arg < 0 { - return 0, fmt.Errorf("sqrt of negative number") - } - return math.Sqrt(arg), nil - - case "sin": - if len(node.Args) != 1 { - return 0, fmt.Errorf("sin function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Sin(arg), nil - - case "cos": - if len(node.Args) != 1 { - return 0, fmt.Errorf("cos function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Cos(arg), nil - - case "tan": - if len(node.Args) != 1 { - return 0, fmt.Errorf("tan function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Tan(arg), nil - - case "floor": - if len(node.Args) != 1 { - return 0, fmt.Errorf("floor function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Floor(arg), nil - - case "ceil": - if len(node.Args) != 1 { - return 0, fmt.Errorf("ceil function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Ceil(arg), nil - - case "round": - if len(node.Args) != 1 { - return 0, fmt.Errorf("round function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Round(arg), nil - - case "pow": - if len(node.Args) != 2 { - return 0, fmt.Errorf("pow function requires exactly 2 arguments") - } - base, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - exponent, err := evaluateNode(node.Args[1], data) - if err != nil { - return 0, err - } - return math.Pow(base, exponent), nil - - case "max": - if len(node.Args) < 1 { - return 0, fmt.Errorf("max function requires at least 1 argument") - } - maxVal, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - for i := 1; i < len(node.Args); i++ { - arg, err := evaluateNode(node.Args[i], data) - if err != nil { - return 0, err - } - if arg > maxVal { - maxVal = arg - } - } - return maxVal, nil - - case "min": - if len(node.Args) < 1 { - return 0, fmt.Errorf("min function requires at least 1 argument") - } - minVal, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - for i := 1; i < len(node.Args); i++ { - arg, err := evaluateNode(node.Args[i], data) - if err != nil { - return 0, err - } - if arg < minVal { - minVal = arg - } - } - return minVal, nil - - case "log": - if len(node.Args) != 1 { - return 0, fmt.Errorf("log function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - if arg <= 0 { - return 0, fmt.Errorf("log of non-positive number") - } - return math.Log(arg), nil - - case "log10": - if len(node.Args) != 1 { - return 0, fmt.Errorf("log10 function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - if arg <= 0 { - return 0, fmt.Errorf("log10 of non-positive number") - } - return math.Log10(arg), nil - - case "exp": - if len(node.Args) != 1 { - return 0, fmt.Errorf("exp function requires exactly 1 argument") - } - arg, err := evaluateNode(node.Args[0], data) - if err != nil { - return 0, err - } - return math.Exp(arg), nil - - case "len": - if len(node.Args) != 1 { - return 0, fmt.Errorf("len function requires exactly 1 argument") - } - // Use evaluateNodeValue to get the original value - arg, err := evaluateNodeValue(node.Args[0], data) - if err != nil { - return 0, err - } - // Convert to string and get length - strVal := fmt.Sprintf("%v", arg) - return float64(len(strVal)), nil - - default: - return 0, fmt.Errorf("unknown function: %s", node.Value) - } -} - -// evaluateCaseExpression evaluates CASE expression -func evaluateCaseExpression(node *ExprNode, data map[string]interface{}) (float64, error) { - if node.Type != TypeCase { - return 0, fmt.Errorf("node is not a CASE expression") - } - - // Handle simple CASE expression (CASE expr WHEN value1 THEN result1 ...) + // Collect fields from CASE expression if node.CaseExpr != nil { - // Calculate the value of expression after CASE - caseValue, err := evaluateNodeValue(node.CaseExpr, data) - if err != nil { - return 0, err - } - - // Iterate through WHEN clauses to find matching values - for _, whenClause := range node.WhenClauses { - conditionValue, err := evaluateNodeValue(whenClause.Condition, data) - if err != nil { - return 0, err - } - - // Compare if values are equal - isEqual, err := compareValues(caseValue, conditionValue, "==") - if err != nil { - return 0, err - } - - if isEqual { - return evaluateNode(whenClause.Result, data) - } - } - } else { - // Handle search CASE expression (CASE WHEN condition1 THEN result1 ...) - for _, whenClause := range node.WhenClauses { - // Evaluate WHEN condition, need special handling for boolean expressions - conditionResult, err := evaluateBooleanCondition(whenClause.Condition, data) - if err != nil { - return 0, err - } - - // If condition is true, return corresponding result - if conditionResult { - return evaluateNode(whenClause.Result, data) - } + collectFields(node.CaseExpr.Value, fields) + collectFields(node.CaseExpr.ElseResult, fields) + for _, whenClause := range node.CaseExpr.WhenClauses { + collectFields(whenClause.Condition, fields) + collectFields(whenClause.Result, fields) } } - - // If no WHEN clause matches, execute ELSE clause - if node.ElseExpr != nil { - return evaluateNode(node.ElseExpr, data) - } - - // If no ELSE clause, SQL standard returns NULL, here return 0 - return 0, nil } -// evaluateBooleanCondition evaluates boolean condition expression -func evaluateBooleanCondition(node *ExprNode, data map[string]interface{}) (bool, error) { - if node == nil { - return false, fmt.Errorf("null condition expression") - } - - // Handle logical operators (implement short-circuit evaluation) - if node.Type == TypeOperator && (node.Value == "AND" || node.Value == "OR") { - leftBool, err := evaluateBooleanCondition(node.Left, data) +// EvaluateBool calculates the boolean value of the expression +func (e *Expression) EvaluateBool(data map[string]interface{}) (bool, error) { + if e.useExprLang { + // For expr-lang expressions, calculate numeric value first then convert to boolean + result, err := e.evaluateWithExprLang(data) if err != nil { return false, err } - - // Short-circuit evaluation: for AND, if left is false, return false immediately - if node.Value == "AND" && !leftBool { - return false, nil - } - - // Short-circuit evaluation: for OR, if left is true, return true immediately - if node.Value == "OR" && leftBool { - return true, nil - } - - // Only evaluate right expression when needed - rightBool, err := evaluateBooleanCondition(node.Right, data) - if err != nil { - return false, err - } - - switch node.Value { - case "AND": - return leftBool && rightBool, nil - case "OR": - return leftBool || rightBool, nil - } + return result != 0, nil } - - // Handle IS NULL and IS NOT NULL special cases - if node.Type == TypeOperator && node.Value == "IS" { - return evaluateIsCondition(node, data) - } - - // Handle comparison operators - if node.Type == TypeOperator { - leftValue, err := evaluateNodeValue(node.Left, data) - if err != nil { - return false, err - } - - rightValue, err := evaluateNodeValue(node.Right, data) - if err != nil { - return false, err - } - - return compareValues(leftValue, rightValue, node.Value) - } - - // For other expressions, calculate numeric value and convert to boolean - result, err := evaluateNode(node, data) - if err != nil { - return false, err - } - - // Non-zero values are true, zero values are false - return result != 0, nil + return evaluateBoolNode(e.Root, data) } -// evaluateIsCondition handles IS NULL and IS NOT NULL conditions -func evaluateIsCondition(node *ExprNode, data map[string]interface{}) (bool, error) { - if node == nil || node.Left == nil || node.Right == nil { - return false, fmt.Errorf("invalid IS condition") - } - - // Get left side value - leftValue, err := evaluateNodeValue(node.Left, data) - if err != nil { - // If field doesn't exist, consider it as null - leftValue = nil - } - - // Check if right side is NULL or NOT NULL - if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" { - // IS NULL - return leftValue == nil, nil - } - - // Check if it's IS NOT NULL - if node.Right.Type == TypeOperator && node.Right.Value == "NOT" && - node.Right.Right != nil && node.Right.Right.Type == TypeField && - strings.ToUpper(node.Right.Right.Value) == "NULL" { - // IS NOT NULL - return leftValue != nil, nil - } - - // Other IS comparisons (like IS TRUE, IS FALSE etc., not supported yet) - rightValue, err := evaluateNodeValue(node.Right, data) - if err != nil { - return false, err - } - - return compareValues(leftValue, rightValue, "==") -} - -// evaluateNodeValue calculates node value, returns interface{} to support different types -func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}, error) { - if node == nil { - return nil, fmt.Errorf("null expression node") - } - - switch node.Type { - case TypeNumber: - return strconv.ParseFloat(node.Value, 64) - - case TypeString: - // Remove quotes - value := node.Value - if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { - value = value[1 : len(value)-1] - } - return value, nil - - case TypeField: - // Handle backtick identifiers, remove backticks - fieldName := node.Value - if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { - fieldName = fieldName[1 : len(fieldName)-1] // Remove backticks - } - - // Support nested field access - if fieldpath.IsNestedField(fieldName) { - if val, found := fieldpath.GetNestedField(data, fieldName); found { - return val, nil - } - } else { - // Original simple field access - if val, found := data[fieldName]; found { - return val, nil - } - } - return nil, fmt.Errorf("field '%s' not found", fieldName) - - default: - // For other types, fall back to numeric calculation - return evaluateNode(node, data) - } -} - -// compareValues compares two values -func compareValues(left, right interface{}, operator string) (bool, error) { - // Try string comparison - leftStr, leftIsStr := left.(string) - rightStr, rightIsStr := right.(string) - - if leftIsStr && rightIsStr { - switch operator { - case "==", "=": - return leftStr == rightStr, nil - case "!=", "<>": - return leftStr != rightStr, nil - case ">": - return leftStr > rightStr, nil - case ">=": - return leftStr >= rightStr, nil - case "<": - return leftStr < rightStr, nil - case "<=": - return leftStr <= rightStr, nil - case "LIKE": - return matchesLikePattern(leftStr, rightStr), nil - default: - return false, fmt.Errorf("unsupported string comparison operator: %s", operator) - } - } - - // Convert to numeric values for comparison - leftNum, err1 := convertToFloat(left) - rightNum, err2 := convertToFloat(right) - - if err1 != nil || err2 != nil { - return false, fmt.Errorf("cannot compare values: %v and %v", left, right) - } - - switch operator { - case ">": - return leftNum > rightNum, nil - case ">=": - return leftNum >= rightNum, nil - case "<": - return leftNum < rightNum, nil - case "<=": - return leftNum <= rightNum, nil - case "==", "=": - return math.Abs(leftNum-rightNum) < 1e-9, nil - case "!=", "<>": - return math.Abs(leftNum-rightNum) >= 1e-9, nil - default: - return false, fmt.Errorf("unsupported comparison operator: %s", operator) - } -} - -// matchesLikePattern implements LIKE pattern matching -// Supports % (matches any character sequence) and _ (matches single character) -func matchesLikePattern(text, pattern string) bool { - return likeMatch(text, pattern, 0, 0) -} - -// likeMatch recursively implements LIKE matching algorithm -func likeMatch(text, pattern string, textIndex, patternIndex int) bool { - // If pattern matching is complete - if patternIndex >= len(pattern) { - return textIndex >= len(text) // Text should also be completely matched - } - - // If text has ended but pattern still has non-% characters, no match - if textIndex >= len(text) { - // Check if remaining pattern consists only of % - for i := patternIndex; i < len(pattern); i++ { - if pattern[i] != '%' { - return false - } - } - return true - } - - switch pattern[patternIndex] { - case '%': - // % can match 0 or more characters - // Try matching 0 characters (skip %) - if likeMatch(text, pattern, textIndex, patternIndex+1) { - return true - } - // Try matching 1 or more characters - for i := textIndex; i < len(text); i++ { - if likeMatch(text, pattern, i+1, patternIndex+1) { - return true - } - } - return false - - case '_': - // _ matches any single character - return likeMatch(text, pattern, textIndex+1, patternIndex+1) - - default: - // Regular characters must match exactly - if text[textIndex] == pattern[patternIndex] { - return likeMatch(text, pattern, textIndex+1, patternIndex+1) - } - return false - } -} - -// convertToFloat converts value to float64 -func convertToFloat(val interface{}) (float64, error) { - switch v := val.(type) { - case float64: - if math.IsNaN(v) { - return 0, fmt.Errorf("NaN value detected") - } - return v, nil - case float32: - if math.IsNaN(float64(v)) { - return 0, fmt.Errorf("NaN value detected") - } - return float64(v), nil - case int: - return float64(v), nil - case int32: - return float64(v), nil - case int64: - return float64(v), nil - case bool: - if v { - return 1.0, nil - } - return 0.0, nil - case string: - f, err := strconv.ParseFloat(v, 64) - if err != nil { - return 0, err - } - if math.IsNaN(f) { - return 0, fmt.Errorf("NaN value detected") - } - return f, nil - default: - return 0, fmt.Errorf("cannot convert %T to float64", val) - } -} - -// tokenize converts expression string to token list -func tokenize(expr string) ([]string, error) { - expr = strings.TrimSpace(expr) - if expr == "" { - return nil, fmt.Errorf("empty expression") - } - - tokens := make([]string, 0) - i := 0 - - for i < len(expr) { - ch := expr[i] - - // Skip whitespace characters - if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' { - i++ - continue - } - - // Handle numbers - if isDigit(ch) || (ch == '.' && i+1 < len(expr) && isDigit(expr[i+1])) { - start := i - hasDot := ch == '.' - - i++ - for i < len(expr) && (isDigit(expr[i]) || (expr[i] == '.' && !hasDot)) { - if expr[i] == '.' { - hasDot = true - } - i++ - } - - tokens = append(tokens, expr[start:i]) - continue - } - - // Handle operators and parentheses - if ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '%' || ch == '^' || - ch == '(' || ch == ')' || ch == ',' { - - // Special handling for minus sign: if it's minus and preceded by operator, parenthesis or start position, it might be negative number - if ch == '-' { - // Check if it could be the start of a negative number - canBeNegativeNumber := i == 0 || // Expression start - len(tokens) == 0 // When tokens is empty, it could also be negative number start - - // Only check previous token when tokens is not empty - if len(tokens) > 0 { - prevToken := tokens[len(tokens)-1] - canBeNegativeNumber = canBeNegativeNumber || - prevToken == "(" || // After left parenthesis - prevToken == "," || // After comma (function parameter) - isOperator(prevToken) || // After operator - isComparisonOperator(prevToken) || // After comparison operator - strings.ToUpper(prevToken) == "THEN" || // After THEN - strings.ToUpper(prevToken) == "ELSE" || // After ELSE - strings.ToUpper(prevToken) == "WHEN" || // After WHEN - strings.ToUpper(prevToken) == "AND" || // After AND - strings.ToUpper(prevToken) == "OR" // After OR - } - - if canBeNegativeNumber && i+1 < len(expr) && isDigit(expr[i+1]) { - // This is a negative number, parse the entire number - start := i - i++ // Skip minus sign - - // Parse numeric part - for i < len(expr) && (isDigit(expr[i]) || expr[i] == '.') { - i++ - } - - tokens = append(tokens, expr[start:i]) - continue - } - } - - tokens = append(tokens, string(ch)) - i++ - continue - } - - // Handle comparison operators - if ch == '>' || ch == '<' || ch == '=' || ch == '!' { - start := i - i++ - - // Handle two-character operators - if i < len(expr) { - switch ch { - case '>': - if expr[i] == '=' { - i++ - tokens = append(tokens, ">=") - continue - } - case '<': - if expr[i] == '=' { - i++ - tokens = append(tokens, "<=") - continue - } else if expr[i] == '>' { - i++ - tokens = append(tokens, "<>") - continue - } - case '=': - if expr[i] == '=' { - i++ - tokens = append(tokens, "==") - continue - } - case '!': - if expr[i] == '=' { - i++ - tokens = append(tokens, "!=") - continue - } - } - } - - // Single character operator - tokens = append(tokens, expr[start:i]) - continue - } - - // Handle string literals (single and double quotes) - if ch == '\'' || ch == '"' { - quote := ch - start := i - i++ // Skip opening quote - - // Find closing quote - for i < len(expr) && expr[i] != quote { - if expr[i] == '\\' && i+1 < len(expr) { - i += 2 // Skip escape character - } else { - i++ - } - } - - if i >= len(expr) { - return nil, fmt.Errorf("unterminated string literal starting at position %d", start) - } - - i++ // Skip closing quote - tokens = append(tokens, expr[start:i]) - continue - } - - // Handle backtick identifiers - if ch == '`' { - start := i - i++ // Skip opening backtick - - // Find closing backtick - for i < len(expr) && expr[i] != '`' { - i++ - } - - if i >= len(expr) { - return nil, fmt.Errorf("unterminated quoted identifier starting at position %d", start) - } - - i++ // Skip closing backtick - tokens = append(tokens, expr[start:i]) - continue - } - - // Handle identifiers (field names or function names) - if isLetter(ch) { - start := i - i++ - for i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_') { - i++ - } - - tokens = append(tokens, expr[start:i]) - continue - } - - // Unknown character - return nil, fmt.Errorf("unexpected character: %c at position %d", ch, i) - } - - return tokens, nil -} - -// parseExpression parses expression -func parseExpression(tokens []string) (*ExprNode, error) { - if len(tokens) == 0 { - return nil, fmt.Errorf("empty token list") - } - - // Use Shunting-yard algorithm to handle operator precedence - output := make([]*ExprNode, 0) - operators := make([]string, 0) - - i := 0 - for i < len(tokens) { - token := tokens[i] - - // Handle numbers - if isNumber(token) { - output = append(output, &ExprNode{ - Type: TypeNumber, - Value: token, - }) - i++ - continue - } - - // Handle string literals - if isStringLiteral(token) { - output = append(output, &ExprNode{ - Type: TypeString, - Value: token, - }) - i++ - continue - } - - // Handle field names or function calls - if isIdentifier(token) { - // Check if it's a logical operator keyword - upperToken := strings.ToUpper(token) - if upperToken == "AND" || upperToken == "OR" || upperToken == "NOT" || upperToken == "LIKE" { - // Handle logical operators - for len(operators) > 0 && operators[len(operators)-1] != "(" && - operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence[upperToken] { - op := operators[len(operators)-1] - operators = operators[:len(operators)-1] - - if len(output) < 2 { - return nil, fmt.Errorf("not enough operands for operator: %s", op) - } - - right := output[len(output)-1] - left := output[len(output)-2] - output = output[:len(output)-2] - - output = append(output, &ExprNode{ - Type: TypeOperator, - Value: op, - Left: left, - Right: right, - }) - } - - operators = append(operators, upperToken) - i++ - continue - } - - // Special handling for IS operator, need to check subsequent NOT NULL combination - if upperToken == "IS" { - // Handle pending operators - for len(operators) > 0 && operators[len(operators)-1] != "(" && - operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence["IS"] { - op := operators[len(operators)-1] - operators = operators[:len(operators)-1] - - if len(output) < 2 { - return nil, fmt.Errorf("not enough operands for operator: %s", op) - } - - right := output[len(output)-1] - left := output[len(output)-2] - output = output[:len(output)-2] - - output = append(output, &ExprNode{ - Type: TypeOperator, - Value: op, - Left: left, - Right: right, - }) - } - - // Check if it's IS NOT NULL pattern - if i+2 < len(tokens) && - strings.ToUpper(tokens[i+1]) == "NOT" && - strings.ToUpper(tokens[i+2]) == "NULL" { - // This is IS NOT NULL, create special right-side node structure - notNullNode := &ExprNode{ - Type: TypeOperator, - Value: "NOT", - Right: &ExprNode{ - Type: TypeField, - Value: "NULL", - }, - } - - operators = append(operators, "IS") - output = append(output, notNullNode) - i += 3 // Skip three tokens: IS NOT NULL - continue - } else if i+1 < len(tokens) && strings.ToUpper(tokens[i+1]) == "NULL" { - // This is IS NULL, create NULL node - nullNode := &ExprNode{ - Type: TypeField, - Value: "NULL", - } - - operators = append(operators, "IS") - output = append(output, nullNode) - i += 2 // Skip two tokens: IS NULL - continue - } else { - // Regular IS operator - operators = append(operators, "IS") - i++ - continue - } - } - - // Check if it's CASE expression - if strings.ToUpper(token) == "CASE" { - caseNode, newIndex, err := parseCaseExpression(tokens, i) - if err != nil { - return nil, err - } - output = append(output, caseNode) - i = newIndex - continue - } - - // Check if next token is left parenthesis, if so it's a function call - if i+1 < len(tokens) && tokens[i+1] == "(" { - funcName := token - i += 2 // Skip function name and left parenthesis - - // Parse function arguments - args, newIndex, err := parseFunctionArgs(tokens, i) - if err != nil { - return nil, err - } - - output = append(output, &ExprNode{ - Type: TypeFunction, - Value: funcName, - Args: args, - }) - - i = newIndex - continue - } - - // Regular field - output = append(output, &ExprNode{ - Type: TypeField, - Value: token, - }) - i++ - continue - } - - // Handle left parenthesis - if token == "(" { - operators = append(operators, token) - i++ - continue - } - - // Handle right parenthesis - if token == ")" { - for len(operators) > 0 && operators[len(operators)-1] != "(" { - op := operators[len(operators)-1] - operators = operators[:len(operators)-1] - - if len(output) < 2 { - return nil, fmt.Errorf("not enough operands for operator: %s", op) - } - - right := output[len(output)-1] - left := output[len(output)-2] - output = output[:len(output)-2] - - output = append(output, &ExprNode{ - Type: TypeOperator, - Value: op, - Left: left, - Right: right, - }) - } - - if len(operators) == 0 || operators[len(operators)-1] != "(" { - return nil, fmt.Errorf("mismatched parentheses") - } - - operators = operators[:len(operators)-1] // Pop left parenthesis - i++ - continue - } - - // Handle operators - if isOperator(token) { - for len(operators) > 0 && operators[len(operators)-1] != "(" && - operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence[token] { - op := operators[len(operators)-1] - operators = operators[:len(operators)-1] - - if len(output) < 2 { - return nil, fmt.Errorf("not enough operands for operator: %s", op) - } - - right := output[len(output)-1] - left := output[len(output)-2] - output = output[:len(output)-2] - - output = append(output, &ExprNode{ - Type: TypeOperator, - Value: op, - Left: left, - Right: right, - }) - } - - operators = append(operators, token) - i++ - continue - } - - // Handle comma (processed in function argument list) - if token == "," { - i++ - continue - } - - return nil, fmt.Errorf("unexpected token: %s", token) - } - - // Handle remaining operators - for len(operators) > 0 { - op := operators[len(operators)-1] - operators = operators[:len(operators)-1] - - if op == "(" { - return nil, fmt.Errorf("mismatched parentheses") - } - - if len(output) < 2 { - return nil, fmt.Errorf("not enough operands for operator: %s", op) - } - - right := output[len(output)-1] - left := output[len(output)-2] - output = output[:len(output)-2] - - output = append(output, &ExprNode{ - Type: TypeOperator, - Value: op, - Left: left, - Right: right, - }) - } - - if len(output) != 1 { - return nil, fmt.Errorf("invalid expression") - } - - return output[0], nil -} - -// parseFunctionArgs parses function arguments -func parseFunctionArgs(tokens []string, startIndex int) ([]*ExprNode, int, error) { - args := make([]*ExprNode, 0) - i := startIndex - - // Handle empty argument list - if i < len(tokens) && tokens[i] == ")" { - return args, i + 1, nil - } - - for i < len(tokens) { - // Parse argument expression - argTokens := make([]string, 0) - parenthesesCount := 0 - - for i < len(tokens) { - token := tokens[i] - - if token == "(" { - parenthesesCount++ - } else if token == ")" { - parenthesesCount-- - if parenthesesCount < 0 { - break - } - } else if token == "," && parenthesesCount == 0 { - break - } - - argTokens = append(argTokens, token) - i++ - } - - if len(argTokens) > 0 { - arg, err := parseExpression(argTokens) - if err != nil { - return nil, 0, err - } - args = append(args, arg) - } - - if i >= len(tokens) { - return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments") - } - - if tokens[i] == ")" { - return args, i + 1, nil - } - - if tokens[i] == "," { - i++ - continue - } - - return nil, 0, fmt.Errorf("unexpected token in function arguments: %s", tokens[i]) - } - - return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments") -} - -// parseCaseExpression parses CASE expression -func parseCaseExpression(tokens []string, startIndex int) (*ExprNode, int, error) { - if startIndex >= len(tokens) || strings.ToUpper(tokens[startIndex]) != "CASE" { - return nil, startIndex, fmt.Errorf("expected CASE keyword") - } - - caseNode := &ExprNode{ - Type: TypeCase, - WhenClauses: make([]WhenClause, 0), - } - - i := startIndex + 1 // 跳过CASE关键字 - - // 检查是否是简单CASE表达式(CASE expr WHEN value1 THEN result1 ...) - // 或搜索CASE表达式(CASE WHEN condition1 THEN result1 ...) - if i < len(tokens) && strings.ToUpper(tokens[i]) != "WHEN" { - // 这是简单CASE表达式,需要解析CASE后面的表达式 - caseExprTokens := make([]string, 0) - - // 收集CASE表达式直到遇到WHEN - for i < len(tokens) && strings.ToUpper(tokens[i]) != "WHEN" { - caseExprTokens = append(caseExprTokens, tokens[i]) - i++ - } - - if len(caseExprTokens) == 0 { - return nil, i, fmt.Errorf("expected expression after CASE") - } - - // 对于简单的情况,直接处理单个token - if len(caseExprTokens) == 1 { - token := caseExprTokens[0] - if isNumber(token) { - caseNode.CaseExpr = &ExprNode{Type: TypeNumber, Value: token} - } else if isStringLiteral(token) { - caseNode.CaseExpr = &ExprNode{Type: TypeString, Value: token} - } else if isIdentifier(token) { - caseNode.CaseExpr = &ExprNode{Type: TypeField, Value: token} - } else { - return nil, i, fmt.Errorf("invalid CASE expression token: %s", token) - } - } else { - // 对于复杂表达式,调用parseExpression - caseExpr, err := parseExpression(caseExprTokens) - if err != nil { - return nil, i, fmt.Errorf("failed to parse CASE expression: %w", err) - } - caseNode.CaseExpr = caseExpr - } - } - - // 解析WHEN子句 - for i < len(tokens) && strings.ToUpper(tokens[i]) == "WHEN" { - i++ // 跳过WHEN关键字 - - // 收集WHEN条件直到遇到THEN - conditionTokens := make([]string, 0) - for i < len(tokens) && strings.ToUpper(tokens[i]) != "THEN" { - conditionTokens = append(conditionTokens, tokens[i]) - i++ - } - - if len(conditionTokens) == 0 { - return nil, i, fmt.Errorf("expected condition after WHEN") - } - - if i >= len(tokens) || strings.ToUpper(tokens[i]) != "THEN" { - return nil, i, fmt.Errorf("expected THEN after WHEN condition") - } - - i++ // 跳过THEN关键字 - - // 收集THEN结果直到遇到WHEN、ELSE或END - resultTokens := make([]string, 0) - for i < len(tokens) { - upper := strings.ToUpper(tokens[i]) - if upper == "WHEN" || upper == "ELSE" || upper == "END" { - break - } - resultTokens = append(resultTokens, tokens[i]) - i++ - } - - if len(resultTokens) == 0 { - return nil, i, fmt.Errorf("expected result after THEN") - } - - // 解析条件和结果表达式 - conditionExpr, err := parseExpression(conditionTokens) - if err != nil { - return nil, i, fmt.Errorf("failed to parse WHEN condition: %w", err) - } - - resultExpr, err := parseExpression(resultTokens) - if err != nil { - return nil, i, fmt.Errorf("failed to parse THEN result: %w", err) - } - - // 添加WHEN子句 - caseNode.WhenClauses = append(caseNode.WhenClauses, WhenClause{ - Condition: conditionExpr, - Result: resultExpr, - }) - } - - // 检查是否有ELSE子句 - if i < len(tokens) && strings.ToUpper(tokens[i]) == "ELSE" { - i++ // 跳过ELSE关键字 - - // 收集ELSE结果直到遇到END - elseTokens := make([]string, 0) - for i < len(tokens) && strings.ToUpper(tokens[i]) != "END" { - elseTokens = append(elseTokens, tokens[i]) - i++ - } - - if len(elseTokens) == 0 { - return nil, i, fmt.Errorf("expected result after ELSE") - } - - // 解析ELSE表达式 - elseExpr, err := parseExpression(elseTokens) - if err != nil { - return nil, i, fmt.Errorf("failed to parse ELSE result: %w", err) - } - caseNode.ElseExpr = elseExpr - } - - // 检查END关键字 - if i >= len(tokens) || strings.ToUpper(tokens[i]) != "END" { - return nil, i, fmt.Errorf("expected END to close CASE expression") - } - - i++ // 跳过END关键字 - - // 验证至少有一个WHEN子句 - if len(caseNode.WhenClauses) == 0 { - return nil, i, fmt.Errorf("CASE expression must have at least one WHEN clause") - } - - return caseNode, i, nil -} - -// 辅助函数 -func isDigit(ch byte) bool { - return ch >= '0' && ch <= '9' -} - -func isLetter(ch byte) bool { - return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') -} - -func isNumber(s string) bool { - _, err := strconv.ParseFloat(s, 64) - return err == nil -} - -func isIdentifier(s string) bool { - if len(s) == 0 { - return false - } - - if !isLetter(s[0]) && s[0] != '_' { - return false - } - - for i := 1; i < len(s); i++ { - if !isLetter(s[i]) && !isDigit(s[i]) && s[i] != '_' { - return false - } - } - - return true -} - -func isOperator(s string) bool { - switch s { - case "+", "-", "*", "/", "%", "^": - return true - case ">", "<", ">=", "<=", "==", "=", "!=", "<>": - return true - case "AND", "OR", "NOT": - return true - case "LIKE", "IS": - return true - default: - return false - } -} - -// isComparisonOperator 检查是否是比较运算符 -func isComparisonOperator(s string) bool { - switch s { - case ">", "<", ">=", "<=", "==", "=", "!=", "<>": - return true - default: - return false - } -} - -func isStringLiteral(expr string) bool { - return len(expr) > 1 && (expr[0] == '\'' || expr[0] == '"') && expr[len(expr)-1] == expr[0] -} - -// evaluateNodeWithNull 计算节点值,支持NULL值返回 -// 返回 (result, isNull, error) -func evaluateNodeWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { - if node == nil { - return 0, true, nil // NULL - } - - switch node.Type { - case TypeNumber: - val, err := strconv.ParseFloat(node.Value, 64) - return val, false, err - - case TypeString: - // 字符串长度作为数值,特殊处理NULL字符串 - value := node.Value - if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { - value = value[1 : len(value)-1] - } - // 检查是否是NULL字符串 - if strings.ToUpper(value) == "NULL" { - return 0, true, nil - } - return float64(len(value)), false, nil - - case TypeField: - // 支持嵌套字段访问 - var fieldVal interface{} - var found bool - - if fieldpath.IsNestedField(node.Value) { - fieldVal, found = fieldpath.GetNestedField(data, node.Value) - } else { - fieldVal, found = data[node.Value] - } - - if !found || fieldVal == nil { - return 0, true, nil // NULL - } - - // 尝试转换为数值 - if val, err := convertToFloat(fieldVal); err == nil { - return val, false, nil - } - return 0, true, fmt.Errorf("cannot convert field '%s' to number", node.Value) - - case TypeOperator: - return evaluateOperatorWithNull(node, data) - - case TypeFunction: - // 函数调用保持原有逻辑,但处理NULL结果 - result, err := evaluateBuiltinFunction(node, data) - return result, false, err - - case TypeCase: - return evaluateCaseExpressionWithNull(node, data) - - default: - return 0, true, fmt.Errorf("unsupported node type: %s", node.Type) - } -} - -// evaluateOperatorWithNull 计算运算符表达式,支持NULL值 -func evaluateOperatorWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { - leftVal, leftNull, err := evaluateNodeWithNull(node.Left, data) - if err != nil { - return 0, false, err - } - - rightVal, rightNull, err := evaluateNodeWithNull(node.Right, data) - if err != nil { - return 0, false, err - } - - // 算术运算:如果任一操作数为NULL,结果为NULL - if leftNull || rightNull { - switch node.Value { - case "+", "-", "*", "/", "%", "^": - return 0, true, nil - } - } - - // 比较运算:NULL值的比较有特殊规则 - switch node.Value { - case "==", "=": - if leftNull && rightNull { - return 1, false, nil // NULL = NULL 为 true - } - if leftNull || rightNull { - return 0, false, nil // NULL = value 为 false - } - if leftVal == rightVal { - return 1, false, nil - } - return 0, false, nil - - case "!=", "<>": - if leftNull && rightNull { - return 0, false, nil // NULL != NULL 为 false - } - if leftNull || rightNull { - return 0, false, nil // NULL != value 为 false - } - if leftVal != rightVal { - return 1, false, nil - } - return 0, false, nil - - case ">", "<", ">=", "<=": - if leftNull || rightNull { - return 0, false, nil // NULL与任何值的比较都为false - } - } - - // 对于非NULL值,执行正常的算术和比较运算 - switch node.Value { - case "+": - return leftVal + rightVal, false, nil - case "-": - return leftVal - rightVal, false, nil - case "*": - return leftVal * rightVal, false, nil - case "/": - if rightVal == 0 { - return 0, true, nil // 除零返回NULL - } - return leftVal / rightVal, false, nil - case "%": - if rightVal == 0 { - return 0, true, nil - } - return math.Mod(leftVal, rightVal), false, nil - case "^": - return math.Pow(leftVal, rightVal), false, nil - case ">": - if leftVal > rightVal { - return 1, false, nil - } - return 0, false, nil - case "<": - if leftVal < rightVal { - return 1, false, nil - } - return 0, false, nil - case ">=": - if leftVal >= rightVal { - return 1, false, nil - } - return 0, false, nil - case "<=": - if leftVal <= rightVal { - return 1, false, nil - } - return 0, false, nil - default: - return 0, false, fmt.Errorf("unsupported operator: %s", node.Value) - } -} - -// evaluateCaseExpressionWithNull 计算CASE表达式,支持NULL值 -func evaluateCaseExpressionWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { - if node.Type != TypeCase { - return 0, false, fmt.Errorf("node is not a CASE expression") - } - - // 处理简单CASE表达式 (CASE expr WHEN value1 THEN result1 ...) - if node.CaseExpr != nil { - // 计算CASE后面的表达式值 - caseValue, caseNull, err := evaluateNodeValueWithNull(node.CaseExpr, data) - if err != nil { - return 0, false, err - } - - // 遍历WHEN子句,查找匹配的值 - for _, whenClause := range node.WhenClauses { - conditionValue, condNull, err := evaluateNodeValueWithNull(whenClause.Condition, data) - if err != nil { - return 0, false, err - } - - // 比较值是否相等(考虑NULL值) - var isEqual bool - if caseNull && condNull { - isEqual = true // NULL = NULL - } else if caseNull || condNull { - isEqual = false // NULL != value - } else { - isEqual, err = compareValuesForEquality(caseValue, conditionValue) - if err != nil { - return 0, false, err - } - } - - if isEqual { - return evaluateNodeWithNull(whenClause.Result, data) - } - } - } else { - // 处理搜索CASE表达式 (CASE WHEN condition1 THEN result1 ...) - for _, whenClause := range node.WhenClauses { - // 评估WHEN条件 - conditionResult, err := evaluateBooleanConditionWithNull(whenClause.Condition, data) - if err != nil { - return 0, false, err - } - - // 如果条件为真,返回对应的结果 - if conditionResult { - return evaluateNodeWithNull(whenClause.Result, data) - } - } - } - - // 如果没有匹配的WHEN子句,执行ELSE子句 - if node.ElseExpr != nil { - return evaluateNodeWithNull(node.ElseExpr, data) - } - - // 如果没有ELSE子句,SQL标准是返回NULL - return 0, true, nil -} - -// evaluateCaseExpressionValueWithNull 计算CASE表达式并返回实际值(支持字符串),支持NULL值 -func evaluateCaseExpressionValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) { - if node.Type != TypeCase { - return nil, false, fmt.Errorf("node is not a CASE expression") - } - - // 处理简单CASE表达式 (CASE expr WHEN value1 THEN result1 ...) - if node.CaseExpr != nil { - // 计算CASE后面的表达式值 - caseValue, caseNull, err := evaluateNodeValueWithNull(node.CaseExpr, data) - if err != nil { - return nil, false, err - } - - // 遍历WHEN子句,查找匹配的值 - for _, whenClause := range node.WhenClauses { - conditionValue, condNull, err := evaluateNodeValueWithNull(whenClause.Condition, data) - if err != nil { - return nil, false, err - } - - // 比较值是否相等(考虑NULL值) - var isEqual bool - if caseNull && condNull { - isEqual = true // NULL = NULL - } else if caseNull || condNull { - isEqual = false // NULL != value - } else { - isEqual, err = compareValuesForEquality(caseValue, conditionValue) - if err != nil { - return nil, false, err - } - } - - if isEqual { - return evaluateNodeValueWithNull(whenClause.Result, data) - } - } - } else { - // 处理搜索CASE表达式 (CASE WHEN condition1 THEN result1 ...) - for _, whenClause := range node.WhenClauses { - // 评估WHEN条件 - conditionResult, err := evaluateBooleanConditionWithNull(whenClause.Condition, data) - if err != nil { - return nil, false, err - } - - // 如果条件为真,返回对应的结果 - if conditionResult { - return evaluateNodeValueWithNull(whenClause.Result, data) - } - } - } - - // 如果没有匹配的WHEN子句,执行ELSE子句 - if node.ElseExpr != nil { - return evaluateNodeValueWithNull(node.ElseExpr, data) - } - - // 如果没有ELSE子句,SQL标准是返回NULL - return nil, true, nil -} - -// evaluateNodeValueWithNull 计算节点值,返回interface{}以支持不同类型,包含NULL检查 -func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) { - if node == nil { - return nil, true, nil - } - - switch node.Type { - case TypeNumber: - val, err := strconv.ParseFloat(node.Value, 64) - return val, false, err - - case TypeString: - // 去掉引号 - value := node.Value - if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { - value = value[1 : len(value)-1] - } - // 检查是否是NULL字符串 - if strings.ToUpper(value) == "NULL" { - return nil, true, nil - } - return value, false, nil - - case TypeField: - // 处理反引号标识符,去除反引号 - fieldName := node.Value - if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { - fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号 - } - - // 支持嵌套字段访问 - if fieldpath.IsNestedField(fieldName) { - if val, found := fieldpath.GetNestedField(data, fieldName); found { - return val, val == nil, nil - } - } else { - // 原有的简单字段访问 - if val, found := data[fieldName]; found { - return val, val == nil, nil - } - } - return nil, true, nil // 字段不存在视为NULL - - case TypeCase: - // 处理CASE表达式,返回实际值 - return evaluateCaseExpressionValueWithNull(node, data) - - default: - // 对于其他类型,回退到数值计算 - result, isNull, err := evaluateNodeWithNull(node, data) - return result, isNull, err - } -} - -// evaluateBooleanConditionWithNull 计算布尔条件表达式,支持NULL值 -func evaluateBooleanConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) { - if node == nil { - return false, fmt.Errorf("null condition expression") - } - - // 处理逻辑运算符(实现短路求值) - if node.Type == TypeOperator && (node.Value == "AND" || node.Value == "OR") { - leftBool, err := evaluateBooleanConditionWithNull(node.Left, data) - if err != nil { - return false, err - } - - // 短路求值:对于AND,如果左边为false,立即返回false - if node.Value == "AND" && !leftBool { - return false, nil - } - - // 短路求值:对于OR,如果左边为true,立即返回true - if node.Value == "OR" && leftBool { - return true, nil - } - - // 只有在需要时才评估右边的表达式 - rightBool, err := evaluateBooleanConditionWithNull(node.Right, data) - if err != nil { - return false, err - } - - switch node.Value { - case "AND": - return leftBool && rightBool, nil - case "OR": - return leftBool || rightBool, nil - } - } - - // 处理IS NULL和IS NOT NULL特殊情况 - if node.Type == TypeOperator && node.Value == "IS" { - return evaluateIsConditionWithNull(node, data) - } - - // 处理比较运算符 - if node.Type == TypeOperator { - leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data) - if err != nil { - return false, err - } - - rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data) - if err != nil { - return false, err - } - - return compareValuesWithNull(leftValue, leftNull, rightValue, rightNull, node.Value) - } - - // 对于其他表达式,计算其数值并转换为布尔值 - result, isNull, err := evaluateNodeWithNull(node, data) - if err != nil { - return false, err - } - - // NULL值在布尔上下文中为false,非零值为真,零值为假 - return !isNull && result != 0, nil -} - -// evaluateIsConditionWithNull 处理IS NULL和IS NOT NULL条件,支持NULL值 -func evaluateIsConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) { - if node == nil || node.Left == nil || node.Right == nil { - return false, fmt.Errorf("invalid IS condition") - } - - // 获取左侧值 - leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data) - if err != nil { - // 如果字段不存在,认为是null - leftValue = nil - leftNull = true - } - - // 检查右侧是否是NULL或NOT NULL - if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" { - // IS NULL - return leftNull || leftValue == nil, nil - } - - // 检查是否是IS NOT NULL - if node.Right.Type == TypeOperator && node.Right.Value == "NOT" && - node.Right.Right != nil && node.Right.Right.Type == TypeField && - strings.ToUpper(node.Right.Right.Value) == "NULL" { - // IS NOT NULL - return !leftNull && leftValue != nil, nil - } - - // 其他IS比较 - rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data) - if err != nil { - return false, err - } - - return compareValuesWithNullForEquality(leftValue, leftNull, rightValue, rightNull) -} - -// compareValuesForEquality 比较两个值是否相等 -func compareValuesForEquality(left, right interface{}) (bool, error) { - // 尝试字符串比较 - leftStr, leftIsStr := left.(string) - rightStr, rightIsStr := right.(string) - - if leftIsStr && rightIsStr { - return leftStr == rightStr, nil - } - - // 尝试数值比较 - leftFloat, leftErr := convertToFloat(left) - rightFloat, rightErr := convertToFloat(right) - - if leftErr == nil && rightErr == nil { - return leftFloat == rightFloat, nil - } - - // 如果都不能转换,直接比较 - return left == right, nil -} - -// compareValuesWithNull 比较两个值(支持NULL) -func compareValuesWithNull(left interface{}, leftNull bool, right interface{}, rightNull bool, operator string) (bool, error) { - // NULL值的比较有特殊规则 - switch operator { - case "==", "=": - if leftNull && rightNull { - return true, nil // NULL = NULL 为 true - } - if leftNull || rightNull { - return false, nil // NULL = value 为 false - } - - case "!=", "<>": - if leftNull && rightNull { - return false, nil // NULL != NULL 为 false - } - if leftNull || rightNull { - return false, nil // NULL != value 为 false - } - - case ">", "<", ">=", "<=": - if leftNull || rightNull { - return false, nil // NULL与任何值的比较都为false - } - } - - // 对于非NULL值,执行正确的比较逻辑 - switch operator { - case "==", "=": - return compareValuesForEquality(left, right) - case "!=", "<>": - equal, err := compareValuesForEquality(left, right) - return !equal, err - case ">", "<", ">=", "<=": - // 进行数值比较 - leftFloat, leftErr := convertToFloat(left) - rightFloat, rightErr := convertToFloat(right) - - if leftErr != nil || rightErr != nil { - // 如果不能转换为数值,尝试字符串比较 - leftStr := fmt.Sprintf("%v", left) - rightStr := fmt.Sprintf("%v", right) - - switch operator { - case ">": - return leftStr > rightStr, nil - case "<": - return leftStr < rightStr, nil - case ">=": - return leftStr >= rightStr, nil - case "<=": - return leftStr <= rightStr, nil - } - } - - // 数值比较 - switch operator { - case ">": - return leftFloat > rightFloat, nil - case "<": - return leftFloat < rightFloat, nil - case ">=": - return leftFloat >= rightFloat, nil - case "<=": - return leftFloat <= rightFloat, nil - } - } - - return false, fmt.Errorf("unsupported operator: %s", operator) -} - -// compareValuesWithNullForEquality 比较两个值是否相等(支持NULL) -func compareValuesWithNullForEquality(left interface{}, leftNull bool, right interface{}, rightNull bool) (bool, error) { - if leftNull && rightNull { - return true, nil // NULL = NULL 为 true - } - if leftNull || rightNull { - return false, nil // NULL = value 为 false - } - return compareValuesForEquality(left, right) -} - -// EvaluateWithNull 提供公开接口,用于聚合函数调用 +// EvaluateWithNull provides public interface for aggregate function calls, supports NULL value handling func (e *Expression) EvaluateWithNull(data map[string]interface{}) (float64, bool, error) { if e.useExprLang { - // expr-lang不支持NULL,回退到原有逻辑 + // expr-lang doesn't support NULL, fallback to original logic result, err := e.evaluateWithExprLang(data) return result, false, err } return evaluateNodeWithNull(e.Root, data) } -// EvaluateValueWithNull 评估表达式并返回任意类型的值,支持NULL +// EvaluateValueWithNull evaluates expression and returns value of any type, supports NULL func (e *Expression) EvaluateValueWithNull(data map[string]interface{}) (interface{}, bool, error) { if e.useExprLang { - // expr-lang不支持NULL,回退到原有逻辑 + // expr-lang doesn't support NULL, fallback to original logic result, err := e.evaluateWithExprLang(data) return result, false, err } diff --git a/expr/expression_test.go b/expr/expression_test.go index c11f118..1c01196 100644 --- a/expr/expression_test.go +++ b/expr/expression_test.go @@ -1,819 +1,1205 @@ -package expr - -import ( - "math" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestExpressionEvaluation(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - hasError bool - }{ - // Basic arithmetic tests - {"Simple Addition", "a + b", map[string]interface{}{"a": 5, "b": 3}, 8, false}, - {"Simple Subtraction", "a - b", map[string]interface{}{"a": 5, "b": 3}, 2, false}, - {"Simple Multiplication", "a * b", map[string]interface{}{"a": 5, "b": 3}, 15, false}, - {"Simple Division", "a / b", map[string]interface{}{"a": 6, "b": 3}, 2, false}, - {"Modulo", "a % b", map[string]interface{}{"a": 7, "b": 4}, 3, false}, - {"Power", "a ^ b", map[string]interface{}{"a": 2, "b": 3}, 8, false}, - - // Compound expression tests - {"Complex Expression", "a + b * c", map[string]interface{}{"a": 5, "b": 3, "c": 2}, 11, false}, - {"Complex Expression With Parentheses", "(a + b) * c", map[string]interface{}{"a": 5, "b": 3, "c": 2}, 16, false}, - {"Multiple Operations", "a + b * c - d / e", map[string]interface{}{"a": 5, "b": 3, "c": 2, "d": 8, "e": 4}, 9, false}, - - // Function call tests - {"Abs Function", "abs(a - b)", map[string]interface{}{"a": 3, "b": 5}, 2, false}, - {"Sqrt Function", "sqrt(a)", map[string]interface{}{"a": 16}, 4, false}, - {"Round Function", "round(a)", map[string]interface{}{"a": 3.7}, 4, false}, - - // Conversion tests - {"String to Number", "a + b", map[string]interface{}{"a": "5", "b": 3}, 8, false}, - - // Complex expression tests - {"Temperature Conversion", "temperature * 1.8 + 32", map[string]interface{}{"temperature": 25}, 77, false}, - {"Complex Math", "sqrt(abs(a * b - c / d))", map[string]interface{}{"a": 10, "b": 2, "c": 5, "d": 1}, 3.872983346207417, false}, - - // Error tests - {"Division by Zero", "a / b", map[string]interface{}{"a": 5, "b": 0}, 0, true}, - {"Missing Field", "a + b", map[string]interface{}{"a": 5}, 0, true}, - {"Invalid Function", "unknown(a)", map[string]interface{}{"a": 5}, 0, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - assert.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - if tt.hasError { - assert.Error(t, err, "Expected error") - } else { - assert.NoError(t, err, "Evaluation should not fail") - assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") - } - }) - } -} - -// TestCaseExpressionParsing tests CASE expression parsing functionality -func TestCaseExpressionParsing(t *testing.T) { - tests := []struct { - name string - exprStr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - { - name: "Simple search CASE expression", - exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 35.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Simple CASE expression - value matching", - exprStr: "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END", - data: map[string]interface{}{"status": "active"}, - expected: 1.0, - wantErr: false, - }, - { - name: "CASE expression - ELSE branch", - exprStr: "CASE WHEN temperature > 50 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 25.5}, - expected: 0.0, - wantErr: false, - }, - { - name: "Complex search CASE expression", - exprStr: "CASE WHEN temperature > 30 THEN 'HOT' WHEN temperature > 20 THEN 'WARM' ELSE 'COLD' END", - data: map[string]interface{}{"temperature": 25.0}, - expected: 4.0, // Length of string "WARM" - wantErr: false, - }, - { - name: "Simple CASE with numeric comparison", - exprStr: "CASE temperature WHEN 25 THEN 1 WHEN 30 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": 30.0}, - expected: 2.0, - wantErr: false, - }, - { - name: "Boolean CASE expression", - exprStr: "CASE WHEN temperature > 25 AND humidity > 50 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 30.0, "humidity": 60.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Multi-condition CASE expression with AND", - exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 WHEN temperature > 20 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": 35.0, "humidity": 50.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Multi-condition CASE expression with OR", - exprStr: "CASE WHEN temperature > 40 OR humidity > 80 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 25.0, "humidity": 85.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Function call in CASE - ABS", - exprStr: "CASE WHEN ABS(temperature) > 30 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": -35.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Function call in CASE - ROUND", - exprStr: "CASE WHEN ROUND(temperature) = 25 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 24.7}, - expected: 1.0, - wantErr: false, - }, - { - name: "Complex condition combination", - exprStr: "CASE WHEN temperature > 30 AND (humidity > 60 OR pressure < 1000) THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 35.0, "humidity": 55.0, "pressure": 950.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Arithmetic expression in CASE", - exprStr: "CASE WHEN temperature * 1.8 + 32 > 100 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 40.0}, // 40*1.8+32 = 104 - expected: 1.0, - wantErr: false, - }, - { - name: "String function in CASE", - exprStr: "CASE WHEN LENGTH(device_name) > 5 THEN 1 ELSE 0 END", - data: map[string]interface{}{"device_name": "sensor123"}, - expected: 1.0, // LENGTH function works normally, "sensor123" length is 9 > 5, returns 1 - wantErr: false, - }, - { - name: "Simple CASE with function", - exprStr: "CASE ABS(temperature) WHEN 30 THEN 1 WHEN 25 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": -30.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "Function in CASE result", - exprStr: "CASE WHEN temperature > 30 THEN ABS(temperature) ELSE ROUND(temperature) END", - data: map[string]interface{}{"temperature": 35.5}, - expected: 35.5, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expression, err := NewExpression(tt.exprStr) - if tt.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err, "Expression creation should not fail") - assert.NotNil(t, expression, "Expression should not be nil") - - // Test expression evaluation - result, err := expression.Evaluate(tt.data) - if tt.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err, "Expression evaluation should not fail") - assert.Equal(t, tt.expected, result, "Expression result should match expected value") - }) - } -} - -// TestCaseExpressionFieldExtraction 测试CASE表达式的字段提取功能 -func TestCaseExpressionFieldExtraction(t *testing.T) { - testCases := []struct { - name string - exprStr string - expectedFields []string - }{ - { - name: "简单CASE字段提取", - exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", - expectedFields: []string{"temperature"}, - }, - { - name: "多字段CASE字段提取", - exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 ELSE 0 END", - expectedFields: []string{"temperature", "humidity"}, - }, - { - name: "简单CASE字段提取", - exprStr: "CASE status WHEN 'active' THEN temperature ELSE humidity END", - expectedFields: []string{"status", "temperature", "humidity"}, - }, - { - name: "函数CASE字段提取", - exprStr: "CASE WHEN ABS(temperature) > 30 THEN device_id ELSE location END", - expectedFields: []string{"temperature", "device_id", "location"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - expression, err := NewExpression(tc.exprStr) - assert.NoError(t, err, "表达式创建应该成功") - - fields := expression.GetFields() - - // 验证所有期望的字段都被提取到了 - for _, expectedField := range tc.expectedFields { - assert.Contains(t, fields, expectedField, "应该包含字段: %s", expectedField) - } - }) - } -} - -// TestCaseExpressionWithNullComparisons 测试CASE表达式中的NULL比较 -func TestCaseExpressionWithNullComparisons(t *testing.T) { - tests := []struct { - name string - exprStr string - data map[string]interface{} - expected interface{} // 使用interface{}以支持NULL值 - isNull bool - }{ - { - name: "NULL值在CASE条件中 - 应该走ELSE分支", - exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": nil}, - expected: 0.0, - isNull: false, - }, - { - name: "IS NULL条件 - 应该匹配", - exprStr: "CASE WHEN temperature IS NULL THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": nil}, - expected: 1.0, - isNull: false, - }, - { - name: "IS NOT NULL条件 - 不应该匹配", - exprStr: "CASE WHEN temperature IS NOT NULL THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": nil}, - expected: 0.0, - isNull: false, - }, - { - name: "CASE表达式返回NULL", - exprStr: "CASE WHEN temperature > 30 THEN temperature ELSE NULL END", - data: map[string]interface{}{"temperature": 25.0}, - expected: nil, - isNull: true, - }, - { - name: "CASE表达式返回有效值", - exprStr: "CASE WHEN temperature > 30 THEN temperature ELSE NULL END", - data: map[string]interface{}{"temperature": 35.0}, - expected: 35.0, - isNull: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expression, err := NewExpression(tt.exprStr) - assert.NoError(t, err, "表达式解析应该成功") - - // 测试支持NULL的计算方法 - result, isNull, err := expression.EvaluateWithNull(tt.data) - assert.NoError(t, err, "表达式计算应该成功") - - if tt.isNull { - assert.True(t, isNull, "表达式应该返回NULL") - } else { - assert.False(t, isNull, "表达式不应该返回NULL") - assert.Equal(t, tt.expected, result, "表达式结果应该匹配期望值") - } - }) - } -} - -// TestNegativeNumberSupport 专门测试负数支持 -func TestNegativeNumberSupport(t *testing.T) { - tests := []struct { - name string - exprStr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - { - name: "负数常量在THEN中", - exprStr: "CASE WHEN temperature > 0 THEN 1 ELSE -1 END", - data: map[string]interface{}{"temperature": -5.0}, - expected: -1.0, - wantErr: false, - }, - { - name: "负数常量在WHEN中", - exprStr: "CASE WHEN temperature < -10 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": -15.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "负数小数", - exprStr: "CASE WHEN temperature > 0 THEN 1.5 ELSE -2.5 END", - data: map[string]interface{}{"temperature": -1.0}, - expected: -2.5, - wantErr: false, - }, - { - name: "负数在算术表达式中", - exprStr: "CASE WHEN temperature + (-10) > 0 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 15.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "负数与函数", - exprStr: "CASE WHEN ABS(temperature) > 10 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": -15.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "负数在简单CASE中", - exprStr: "CASE temperature WHEN -10 THEN 1 WHEN -20 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": -10.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "负零", - exprStr: "CASE WHEN temperature = -0 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 0.0}, - expected: 1.0, - wantErr: false, - }, - // 基本负数运算 - { - name: "直接负数", - exprStr: "-5", - data: map[string]interface{}{}, - expected: -5.0, - wantErr: false, - }, - { - name: "负数加法", - exprStr: "-5 + 3", - data: map[string]interface{}{}, - expected: -2.0, - wantErr: false, - }, - { - name: "负数乘法", - exprStr: "-3 * 4", - data: map[string]interface{}{}, - expected: -12.0, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expression, err := NewExpression(tt.exprStr) - if tt.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err, "负数表达式解析应该成功") - assert.NotNil(t, expression, "表达式不应为空") - - // 测试表达式计算 - result, err := expression.Evaluate(tt.data) - assert.NoError(t, err, "负数表达式计算应该成功") - assert.Equal(t, tt.expected, result, "负数表达式结果应该匹配期望值") - }) - } -} - -func TestGetFields(t *testing.T) { - tests := []struct { - expr string - expectedFields []string - }{ - {"a + b", []string{"a", "b"}}, - {"a + b * c", []string{"a", "b", "c"}}, - {"temperature * 1.8 + 32", []string{"temperature"}}, - {"abs(humidity - 50)", []string{"humidity"}}, - {"sqrt(x^2 + y^2)", []string{"x", "y"}}, - } - - for _, tt := range tests { - t.Run(tt.expr, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - assert.NoError(t, err, "Expression parsing should not fail") - - fields := expr.GetFields() - - // 由于map迭代顺序不确定,我们只检查长度和包含关系 - assert.Equal(t, len(tt.expectedFields), len(fields), "Number of fields should match") - - for _, field := range tt.expectedFields { - found := false - for _, f := range fields { - if f == field { - found = true - break - } - } - assert.True(t, found, "Field %s should be found", field) - } - }) - } -} - -func TestParseError(t *testing.T) { - tests := []struct { - name string - expr string - }{ - {"Empty Expression", ""}, - {"Mismatched Parentheses", "a + (b * c"}, - {"Invalid Character", "a # b"}, - {"Double Operator", "a + * b"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := NewExpression(tt.expr) - assert.Error(t, err, "Expression parsing should fail") - }) - } -} - -// TestExpressionTokenization 测试表达式分词功能 -func TestExpressionTokenization(t *testing.T) { - tests := []struct { - name string - expr string - expected []string - }{ - {"Simple Expression", "a + b", []string{"a", "+", "b"}}, - {"With Numbers", "a + 123", []string{"a", "+", "123"}}, - {"With Parentheses", "(a + b) * c", []string{"(", "a", "+", "b", ")", "*", "c"}}, - {"With Functions", "abs(a)", []string{"abs", "(", "a", ")"}}, - {"With Decimals", "a + 3.14", []string{"a", "+", "3.14"}}, - {"With Negative Numbers", "-5 + a", []string{"-5", "+", "a"}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokens, err := tokenize(tt.expr) - require.NoError(t, err) - assert.Equal(t, tt.expected, tokens, "Tokenization should match expected") - }) - } -} - -// TestExpressionValidation 测试表达式验证功能 -func TestExpressionValidation(t *testing.T) { - tests := []struct { - name string - expr string - valid bool - errorMsg string - }{ - {"Valid Simple Expression", "a + b", true, ""}, - {"Valid Complex Expression", "(a + b) * c / d", true, ""}, - {"Invalid Empty Expression", "", false, "empty expression"}, - {"Invalid Mismatched Parentheses", "(a + b", false, "mismatched parentheses"}, - {"Invalid Double Operator", "a + + b", false, "consecutive operators"}, - {"Invalid Starting Operator", "+ a", false, "expression cannot start with operator"}, - {"Invalid Ending Operator", "a +", false, "expression cannot end with operator"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateBasicSyntax(tt.expr) - if tt.valid { - assert.NoError(t, err, "Expression should be valid") - } else { - assert.Error(t, err, "Expression should be invalid") - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg, "Error message should contain expected text") - } - } - }) - } -} - -// TestExpressionOperatorPrecedence 测试运算符优先级 -func TestExpressionOperatorPrecedence(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - }{ - {"Addition and Multiplication", "2 + 3 * 4", map[string]interface{}{}, 14}, // 2 + (3 * 4) = 14 - {"Subtraction and Division", "10 - 8 / 2", map[string]interface{}{}, 6}, // 10 - (8 / 2) = 6 - {"Power and Multiplication", "2 * 3 ^ 2", map[string]interface{}{}, 18}, // 2 * (3 ^ 2) = 18 - {"Parentheses Override", "(2 + 3) * 4", map[string]interface{}{}, 20}, // (2 + 3) * 4 = 20 - {"Complex Expression", "2 + 3 * 4 - 5 / 2", map[string]interface{}{}, 11.5}, // 2 + (3 * 4) - (5 / 2) = 11.5 - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - require.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - require.NoError(t, err, "Expression evaluation should not fail") - assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") - }) - } -} - -// TestExpressionFunctions 测试内置函数 -func TestExpressionFunctions(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - {"ABS Positive", "abs(5)", map[string]interface{}{}, 5, false}, - {"ABS Negative", "abs(-5)", map[string]interface{}{}, 5, false}, - {"ABS Zero", "abs(0)", map[string]interface{}{}, 0, false}, - {"SQRT Valid", "sqrt(16)", map[string]interface{}{}, 4, false}, - {"SQRT Zero", "sqrt(0)", map[string]interface{}{}, 0, false}, - {"SQRT Negative", "sqrt(-1)", map[string]interface{}{}, 0, true}, - {"ROUND Positive", "round(3.7)", map[string]interface{}{}, 4, false}, - {"ROUND Negative", "round(-3.7)", map[string]interface{}{}, -4, false}, - {"ROUND Half", "round(3.5)", map[string]interface{}{}, 4, false}, - {"FLOOR Positive", "floor(3.7)", map[string]interface{}{}, 3, false}, - {"FLOOR Negative", "floor(-3.7)", map[string]interface{}{}, -4, false}, - {"CEIL Positive", "ceil(3.2)", map[string]interface{}{}, 4, false}, - {"CEIL Negative", "ceil(-3.2)", map[string]interface{}{}, -3, false}, - {"MAX Two Values", "max(5, 3)", map[string]interface{}{}, 5, false}, - {"MIN Two Values", "min(5, 3)", map[string]interface{}{}, 3, false}, - {"POW Function", "pow(2, 3)", map[string]interface{}{}, 8, false}, - {"LOG Function", "log(10)", map[string]interface{}{}, math.Log(10), false}, - {"LOG10 Function", "log10(100)", map[string]interface{}{}, 2, false}, - {"EXP Function", "exp(1)", map[string]interface{}{}, math.E, false}, - {"SIN Function", "sin(0)", map[string]interface{}{}, 0, false}, - {"COS Function", "cos(0)", map[string]interface{}{}, 1, false}, - {"TAN Function", "tan(0)", map[string]interface{}{}, 0, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - require.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - if tt.wantErr { - assert.Error(t, err, "Expected error") - } else { - require.NoError(t, err, "Expression evaluation should not fail") - assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") - } - }) - } -} - -// TestExpressionDataTypeConversion 测试数据类型转换 -func TestExpressionDataTypeConversion(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - {"String to Number", "a + 5", map[string]interface{}{"a": "10"}, 15, false}, - {"Integer to Float", "a + 3.5", map[string]interface{}{"a": 5}, 8.5, false}, - {"Float to Float", "a + b", map[string]interface{}{"a": 3.14, "b": 2.86}, 6.0, false}, - {"Boolean True", "a + 1", map[string]interface{}{"a": true}, 2, false}, - {"Boolean False", "a + 1", map[string]interface{}{"a": false}, 1, false}, - {"Invalid String", "a + 5", map[string]interface{}{"a": "invalid"}, 0, true}, - {"Nil Value", "a + 5", map[string]interface{}{"a": nil}, 0, true}, - {"Complex Type", "a + 5", map[string]interface{}{"a": map[string]interface{}{}}, 0, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - require.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - if tt.wantErr { - assert.Error(t, err, "Expected error") - } else { - require.NoError(t, err, "Expression evaluation should not fail") - assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") - } - }) - } -} - -// TestExpressionEdgeCases 测试边界情况 -func TestExpressionEdgeCases(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - {"Very Large Number", "a + 1", map[string]interface{}{"a": 1e308}, 1e308 + 1, false}, - {"Very Small Number", "a + 1", map[string]interface{}{"a": 1e-308}, 1, false}, - {"Infinity", "a + 1", map[string]interface{}{"a": math.Inf(1)}, math.Inf(1), false}, - {"Negative Infinity", "a + 1", map[string]interface{}{"a": math.Inf(-1)}, math.Inf(-1), false}, - {"NaN", "a + 1", map[string]interface{}{"a": math.NaN()}, 0, true}, - {"Division by Zero", "5 / 0", map[string]interface{}{}, 0, true}, - {"Modulo by Zero", "5 % 0", map[string]interface{}{}, 0, true}, - {"Zero Power Zero", "0 ^ 0", map[string]interface{}{}, 1, false}, // 0^0 = 1 by convention - {"Negative Power", "2 ^ -3", map[string]interface{}{}, 0.125, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - require.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - if tt.wantErr { - assert.Error(t, err, "Expected error") - } else { - require.NoError(t, err, "Expression evaluation should not fail") - if math.IsInf(tt.expected, 0) { - assert.True(t, math.IsInf(result, 0), "Result should be infinity") - } else { - assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") - } - } - }) - } -} - -// TestExpressionConcurrency 测试并发安全性 -func TestExpressionConcurrency(t *testing.T) { - expr, err := NewExpression("a + b * c") - require.NoError(t, err, "Expression parsing should not fail") - - // 并发执行多个计算 - const numGoroutines = 100 - results := make(chan float64, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(index int) { - data := map[string]interface{}{ - "a": float64(index), - "b": float64(index * 2), - "c": float64(index * 3), - } - result, err := expr.Evaluate(data) - assert.NoError(t, err, "Concurrent evaluation should not fail") - results <- result - }(i) - } - - // 收集结果 - for i := 0; i < numGoroutines; i++ { - result := <-results - // 验证结果是合理的(非零且非NaN) - assert.False(t, math.IsNaN(result), "Result should not be NaN") - assert.True(t, result >= 0, "Result should be non-negative for this test") - } -} - -// TestExpressionComplexNesting 测试复杂嵌套表达式 -func TestExpressionComplexNesting(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - }{ - { - "Deeply Nested Parentheses", - "((a + b) * (c - d)) / ((e + f) * (g - h))", - map[string]interface{}{"a": 1, "b": 2, "c": 5, "d": 3, "e": 2, "f": 3, "g": 7, "h": 2}, - 0.24, // ((1+2)*(5-3))/((2+3)*(7-2)) = (3*2)/(5*5) = 6/25 = 0.24 - }, - { - "Nested Functions", - "sqrt(abs(a - b) + pow(c, 2))", - map[string]interface{}{"a": 3, "b": 7, "c": 3}, - 3.606, // sqrt(abs(3-7) + pow(3,2)) = sqrt(4 + 9) = sqrt(13) ≈ 3.606 - }, - { - "Mixed Operations", - "a * b + c / d - e % f + pow(g, h)", - map[string]interface{}{"a": 2, "b": 3, "c": 8, "d": 2, "e": 7, "f": 3, "g": 2, "h": 3}, - 17, // 2*3 + 8/2 - 7%3 + pow(2,3) = 6 + 4 - 1 + 8 = 17 - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - require.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - require.NoError(t, err, "Expression evaluation should not fail") - assert.InDelta(t, tt.expected, result, 0.1, "Result should match expected value") - }) - } -} - -// TestExpressionStringHandling 测试字符串处理 -func TestExpressionStringHandling(t *testing.T) { - tests := []struct { - name string - expr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - {"String Length", "len(name)", map[string]interface{}{"name": "hello"}, 5, false}, - {"Empty String Length", "len(name)", map[string]interface{}{"name": ""}, 0, false}, - {"String Comparison Equal", "name == 'test'", map[string]interface{}{"name": "test"}, 1, false}, - {"String Comparison Not Equal", "name != 'test'", map[string]interface{}{"name": "hello"}, 1, false}, - {"String to Number Conversion", "val + 10", map[string]interface{}{"val": "5"}, 15, false}, - {"Invalid String to Number", "val + 10", map[string]interface{}{"val": "abc"}, 0, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expr, err := NewExpression(tt.expr) - require.NoError(t, err, "Expression parsing should not fail") - - result, err := expr.Evaluate(tt.data) - if tt.wantErr { - assert.Error(t, err, "Expected error") - } else { - require.NoError(t, err, "Expression evaluation should not fail") - assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") - } - }) - } -} - -// TestExpressionPerformance 测试表达式性能 -func TestExpressionPerformance(t *testing.T) { - // 创建一个复杂表达式 - expr, err := NewExpression("sqrt(pow(a, 2) + pow(b, 2)) + abs(c - d) * (e + f) / (g + 1)") - require.NoError(t, err, "Expression parsing should not fail") - - data := map[string]interface{}{ - "a": 3.0, "b": 4.0, "c": 10.0, "d": 7.0, "e": 2.0, "f": 3.0, "g": 4.0, - } - - // 执行多次计算以测试性能 - const iterations = 10000 - for i := 0; i < iterations; i++ { - _, err := expr.Evaluate(data) - assert.NoError(t, err, "Performance test evaluation should not fail") - } -} - -// TestExpressionMemoryUsage 测试内存使用 -func TestExpressionMemoryUsage(t *testing.T) { - // 创建多个表达式实例 - const numExpressions = 1000 - expressions := make([]*Expression, numExpressions) - - for i := 0; i < numExpressions; i++ { - expr, err := NewExpression("a + b * c") - require.NoError(t, err, "Expression creation should not fail") - expressions[i] = expr - } - - // 验证所有表达式都能正常工作 - data := map[string]interface{}{"a": 1, "b": 2, "c": 3} - for i, expr := range expressions { - result, err := expr.Evaluate(data) - assert.NoError(t, err, "Expression %d evaluation should not fail", i) - assert.Equal(t, 7.0, result, "Expression %d result should be correct", i) - } -} +package expr + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExpressionEvaluation(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + hasError bool + }{ + // Basic arithmetic tests + {"Simple Addition", "a + b", map[string]interface{}{"a": 5, "b": 3}, 8, false}, + {"Simple Subtraction", "a - b", map[string]interface{}{"a": 5, "b": 3}, 2, false}, + {"Simple Multiplication", "a * b", map[string]interface{}{"a": 5, "b": 3}, 15, false}, + {"Simple Division", "a / b", map[string]interface{}{"a": 6, "b": 3}, 2, false}, + {"Modulo", "a % b", map[string]interface{}{"a": 7, "b": 4}, 3, false}, + {"Power", "a ^ b", map[string]interface{}{"a": 2, "b": 3}, 8, false}, + + // Compound expression tests + {"Complex Expression", "a + b * c", map[string]interface{}{"a": 5, "b": 3, "c": 2}, 11, false}, + {"Complex Expression With Parentheses", "(a + b) * c", map[string]interface{}{"a": 5, "b": 3, "c": 2}, 16, false}, + {"Multiple Operations", "a + b * c - d / e", map[string]interface{}{"a": 5, "b": 3, "c": 2, "d": 8, "e": 4}, 9, false}, + + // Function call tests + {"Abs Function", "abs(a - b)", map[string]interface{}{"a": 3, "b": 5}, 2, false}, + {"Sqrt Function", "sqrt(a)", map[string]interface{}{"a": 16}, 4, false}, + {"Round Function", "round(a)", map[string]interface{}{"a": 3.7}, 4, false}, + + // Conversion tests + {"String to Number", "a + b", map[string]interface{}{"a": "5", "b": 3}, 8, false}, + + // Complex expression tests + {"Temperature Conversion", "temperature * 1.8 + 32", map[string]interface{}{"temperature": 25}, 77, false}, + {"Complex Math", "sqrt(abs(a * b - c / d))", map[string]interface{}{"a": 10, "b": 2, "c": 5, "d": 1}, 3.872983346207417, false}, + + // Error tests + {"Division by Zero", "a / b", map[string]interface{}{"a": 5, "b": 0}, 0, true}, + {"Missing Field", "a + b", map[string]interface{}{"a": 5}, 0, true}, + {"Invalid Function", "unknown(a)", map[string]interface{}{"a": 5}, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + assert.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + if tt.hasError { + assert.Error(t, err, "Expected error") + } else { + assert.NoError(t, err, "Evaluation should not fail") + assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") + } + }) + } +} + +// TestCaseExpressionParsing tests CASE expression parsing functionality +func TestCaseExpressionParsing(t *testing.T) { + tests := []struct { + name string + exprStr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + { + name: "Simple search CASE expression", + exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 35.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Simple CASE expression - value matching", + exprStr: "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END", + data: map[string]interface{}{"status": "active"}, + expected: 1.0, + wantErr: false, + }, + { + name: "CASE expression - ELSE branch", + exprStr: "CASE WHEN temperature > 50 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 25.5}, + expected: 0.0, + wantErr: false, + }, + { + name: "Complex search CASE expression", + exprStr: "CASE WHEN temperature > 30 THEN 'HOT' WHEN temperature > 20 THEN 'WARM' ELSE 'COLD' END", + data: map[string]interface{}{"temperature": 25.0}, + expected: 4.0, // Length of string "WARM" + wantErr: false, + }, + { + name: "Simple CASE with numeric comparison", + exprStr: "CASE temperature WHEN 25 THEN 1 WHEN 30 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": 30.0}, + expected: 2.0, + wantErr: false, + }, + { + name: "Boolean CASE expression", + exprStr: "CASE WHEN temperature > 25 AND humidity > 50 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 30.0, "humidity": 60.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Multi-condition CASE expression with AND", + exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 WHEN temperature > 20 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": 35.0, "humidity": 50.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Multi-condition CASE expression with OR", + exprStr: "CASE WHEN temperature > 40 OR humidity > 80 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 25.0, "humidity": 85.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Function call in CASE - ABS", + exprStr: "CASE WHEN ABS(temperature) > 30 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": -35.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Function call in CASE - ROUND", + exprStr: "CASE WHEN ROUND(temperature) = 25 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 24.7}, + expected: 1.0, + wantErr: false, + }, + { + name: "Complex condition combination", + exprStr: "CASE WHEN temperature > 30 AND (humidity > 60 OR pressure < 1000) THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 35.0, "humidity": 55.0, "pressure": 950.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Arithmetic expression in CASE", + exprStr: "CASE WHEN temperature * 1.8 + 32 > 100 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 40.0}, // 40*1.8+32 = 104 + expected: 1.0, + wantErr: false, + }, + { + name: "String function in CASE", + exprStr: "CASE WHEN LENGTH(device_name) > 5 THEN 1 ELSE 0 END", + data: map[string]interface{}{"device_name": "sensor123"}, + expected: 1.0, // LENGTH function works normally, "sensor123" length is 9 > 5, returns 1 + wantErr: false, + }, + { + name: "Simple CASE with function", + exprStr: "CASE ABS(temperature) WHEN 30 THEN 1 WHEN 25 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": -30.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "Function in CASE result", + exprStr: "CASE WHEN temperature > 30 THEN ABS(temperature) ELSE ROUND(temperature) END", + data: map[string]interface{}{"temperature": 35.5}, + expected: 35.5, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expression, err := NewExpression(tt.exprStr) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err, "Expression creation should not fail") + assert.NotNil(t, expression, "Expression should not be nil") + + // Test expression evaluation + result, err := expression.Evaluate(tt.data) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err, "Expression evaluation should not fail") + assert.Equal(t, tt.expected, result, "Expression result should match expected value") + }) + } +} + +// TestCaseExpressionFieldExtraction 测试CASE表达式的字段提取功能 +func TestCaseExpressionFieldExtraction(t *testing.T) { + testCases := []struct { + name string + exprStr string + expectedFields []string + }{ + { + name: "简单CASE字段提取", + exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", + expectedFields: []string{"temperature"}, + }, + { + name: "多字段CASE字段提取", + exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 ELSE 0 END", + expectedFields: []string{"temperature", "humidity"}, + }, + { + name: "简单CASE字段提取", + exprStr: "CASE status WHEN 'active' THEN temperature ELSE humidity END", + expectedFields: []string{"status", "temperature", "humidity"}, + }, + { + name: "函数CASE字段提取", + exprStr: "CASE WHEN ABS(temperature) > 30 THEN device_id ELSE location END", + expectedFields: []string{"temperature", "device_id", "location"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expression, err := NewExpression(tc.exprStr) + assert.NoError(t, err, "表达式创建应该成功") + + fields := expression.GetFields() + + // 验证所有期望的字段都被提取到了 + for _, expectedField := range tc.expectedFields { + assert.Contains(t, fields, expectedField, "应该包含字段: %s", expectedField) + } + }) + } +} + +// TestCaseExpressionWithNullComparisons 测试CASE表达式中的NULL比较 +func TestCaseExpressionWithNullComparisons(t *testing.T) { + tests := []struct { + name string + exprStr string + data map[string]interface{} + expected interface{} // 使用interface{}以支持NULL值 + isNull bool + }{ + { + name: "NULL值在CASE条件中 - 应该走ELSE分支", + exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": nil}, + expected: 0.0, + isNull: false, + }, + { + name: "IS NULL条件 - 应该匹配", + exprStr: "CASE WHEN temperature IS NULL THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": nil}, + expected: 1.0, + isNull: false, + }, + { + name: "IS NOT NULL条件 - 不应该匹配", + exprStr: "CASE WHEN temperature IS NOT NULL THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": nil}, + expected: 0.0, + isNull: false, + }, + { + name: "CASE表达式返回NULL", + exprStr: "CASE WHEN temperature > 30 THEN temperature ELSE NULL END", + data: map[string]interface{}{"temperature": 25.0}, + expected: nil, + isNull: true, + }, + { + name: "CASE表达式返回有效值", + exprStr: "CASE WHEN temperature > 30 THEN temperature ELSE NULL END", + data: map[string]interface{}{"temperature": 35.0}, + expected: 35.0, + isNull: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expression, err := NewExpression(tt.exprStr) + assert.NoError(t, err, "表达式解析应该成功") + + // 测试支持NULL的计算方法 + result, isNull, err := expression.EvaluateWithNull(tt.data) + assert.NoError(t, err, "表达式计算应该成功") + + if tt.isNull { + assert.True(t, isNull, "表达式应该返回NULL") + } else { + assert.False(t, isNull, "表达式不应该返回NULL") + assert.Equal(t, tt.expected, result, "表达式结果应该匹配期望值") + } + }) + } +} + +// TestNegativeNumberSupport 专门测试负数支持 +func TestNegativeNumberSupport(t *testing.T) { + tests := []struct { + name string + exprStr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + { + name: "负数常量在THEN中", + exprStr: "CASE WHEN temperature > 0 THEN 1 ELSE -1 END", + data: map[string]interface{}{"temperature": -5.0}, + expected: -1.0, + wantErr: false, + }, + { + name: "负数常量在WHEN中", + exprStr: "CASE WHEN temperature < -10 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": -15.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负数小数", + exprStr: "CASE WHEN temperature > 0 THEN 1.5 ELSE -2.5 END", + data: map[string]interface{}{"temperature": -1.0}, + expected: -2.5, + wantErr: false, + }, + { + name: "负数在算术表达式中", + exprStr: "CASE WHEN temperature + (-10) > 0 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 15.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负数与函数", + exprStr: "CASE WHEN ABS(temperature) > 10 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": -15.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负数在简单CASE中", + exprStr: "CASE temperature WHEN -10 THEN 1 WHEN -20 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": -10.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负零", + exprStr: "CASE WHEN temperature = -0 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 0.0}, + expected: 1.0, + wantErr: false, + }, + // 基本负数运算 + { + name: "直接负数", + exprStr: "-5", + data: map[string]interface{}{}, + expected: -5.0, + wantErr: false, + }, + { + name: "负数加法", + exprStr: "-5 + 3", + data: map[string]interface{}{}, + expected: -2.0, + wantErr: false, + }, + { + name: "负数乘法", + exprStr: "-3 * 4", + data: map[string]interface{}{}, + expected: -12.0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expression, err := NewExpression(tt.exprStr) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err, "负数表达式解析应该成功") + assert.NotNil(t, expression, "表达式不应为空") + + // 测试表达式计算 + result, err := expression.Evaluate(tt.data) + assert.NoError(t, err, "负数表达式计算应该成功") + assert.Equal(t, tt.expected, result, "负数表达式结果应该匹配期望值") + }) + } +} + +func TestGetFields(t *testing.T) { + tests := []struct { + expr string + expectedFields []string + }{ + {"a + b", []string{"a", "b"}}, + {"a + b * c", []string{"a", "b", "c"}}, + {"temperature * 1.8 + 32", []string{"temperature"}}, + {"abs(humidity - 50)", []string{"humidity"}}, + {"sqrt(x^2 + y^2)", []string{"x", "y"}}, + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + assert.NoError(t, err, "Expression parsing should not fail") + + fields := expr.GetFields() + + // 由于map迭代顺序不确定,我们只检查长度和包含关系 + assert.Equal(t, len(tt.expectedFields), len(fields), "Number of fields should match") + + for _, field := range tt.expectedFields { + found := false + for _, f := range fields { + if f == field { + found = true + break + } + } + assert.True(t, found, "Field %s should be found", field) + } + }) + } +} + +func TestParseError(t *testing.T) { + tests := []struct { + name string + expr string + }{ + {"Empty Expression", ""}, + {"Mismatched Parentheses", "a + (b * c"}, + {"Invalid Character", "a # b"}, + {"Double Operator", "a + * b"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewExpression(tt.expr) + assert.Error(t, err, "Expression parsing should fail") + }) + } +} + +// TestExpressionTokenization 测试表达式分词功能 +func TestExpressionTokenization(t *testing.T) { + tests := []struct { + name string + expr string + expected []string + }{ + {"Simple Expression", "a + b", []string{"a", "+", "b"}}, + {"With Numbers", "a + 123", []string{"a", "+", "123"}}, + {"With Parentheses", "(a + b) * c", []string{"(", "a", "+", "b", ")", "*", "c"}}, + {"With Functions", "abs(a)", []string{"abs", "(", "a", ")"}}, + {"With Decimals", "a + 3.14", []string{"a", "+", "3.14"}}, + {"With Negative Numbers", "-5 + a", []string{"-5", "+", "a"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, err := tokenize(tt.expr) + require.NoError(t, err) + assert.Equal(t, tt.expected, tokens, "Tokenization should match expected") + }) + } +} + +// TestExpressionValidation 测试表达式验证功能 +func TestExpressionValidation(t *testing.T) { + tests := []struct { + name string + expr string + valid bool + errorMsg string + }{ + {"Valid Simple Expression", "a + b", true, ""}, + {"Valid Complex Expression", "(a + b) * c / d", true, ""}, + {"Invalid Empty Expression", "", false, "empty expression"}, + {"Invalid Mismatched Parentheses", "(a + b", false, "mismatched parentheses"}, + {"Invalid Double Operator", "a + + b", false, "consecutive operators"}, + {"Invalid Starting Operator", "+ a", false, "expression cannot start with operator"}, + {"Invalid Ending Operator", "a +", false, "expression cannot end with operator"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateBasicSyntax(tt.expr) + if tt.valid { + assert.NoError(t, err, "Expression should be valid") + } else { + assert.Error(t, err, "Expression should be invalid") + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg, "Error message should contain expected text") + } + } + }) + } +} + +// TestExpressionOperatorPrecedence 测试运算符优先级 +func TestExpressionOperatorPrecedence(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + }{ + {"Addition and Multiplication", "2 + 3 * 4", map[string]interface{}{}, 14}, // 2 + (3 * 4) = 14 + {"Subtraction and Division", "10 - 8 / 2", map[string]interface{}{}, 6}, // 10 - (8 / 2) = 6 + {"Power and Multiplication", "2 * 3 ^ 2", map[string]interface{}{}, 18}, // 2 * (3 ^ 2) = 18 + {"Parentheses Override", "(2 + 3) * 4", map[string]interface{}{}, 20}, // (2 + 3) * 4 = 20 + {"Complex Expression", "2 + 3 * 4 - 5 / 2", map[string]interface{}{}, 11.5}, // 2 + (3 * 4) - (5 / 2) = 11.5 + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + require.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + require.NoError(t, err, "Expression evaluation should not fail") + assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") + }) + } +} + +// TestExpressionFunctions 测试内置函数 +func TestExpressionFunctions(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + {"ABS Positive", "abs(5)", map[string]interface{}{}, 5, false}, + {"ABS Negative", "abs(-5)", map[string]interface{}{}, 5, false}, + {"ABS Zero", "abs(0)", map[string]interface{}{}, 0, false}, + {"SQRT Valid", "sqrt(16)", map[string]interface{}{}, 4, false}, + {"SQRT Zero", "sqrt(0)", map[string]interface{}{}, 0, false}, + {"SQRT Negative", "sqrt(-1)", map[string]interface{}{}, 0, true}, + {"ROUND Positive", "round(3.7)", map[string]interface{}{}, 4, false}, + {"ROUND Negative", "round(-3.7)", map[string]interface{}{}, -4, false}, + {"ROUND Half", "round(3.5)", map[string]interface{}{}, 4, false}, + {"FLOOR Positive", "floor(3.7)", map[string]interface{}{}, 3, false}, + {"FLOOR Negative", "floor(-3.7)", map[string]interface{}{}, -4, false}, + {"CEIL Positive", "ceil(3.2)", map[string]interface{}{}, 4, false}, + {"CEIL Negative", "ceil(-3.2)", map[string]interface{}{}, -3, false}, + {"MAX Two Values", "max(5, 3)", map[string]interface{}{}, 5, false}, + {"MIN Two Values", "min(5, 3)", map[string]interface{}{}, 3, false}, + {"POW Function", "pow(2, 3)", map[string]interface{}{}, 8, false}, + {"LOG Function", "log(10)", map[string]interface{}{}, math.Log10(10), false}, + {"LOG10 Function", "log10(100)", map[string]interface{}{}, 2, false}, + {"EXP Function", "exp(1)", map[string]interface{}{}, math.E, false}, + {"SIN Function", "sin(0)", map[string]interface{}{}, 0, false}, + {"COS Function", "cos(0)", map[string]interface{}{}, 1, false}, + {"TAN Function", "tan(0)", map[string]interface{}{}, 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + require.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + if tt.wantErr { + assert.Error(t, err, "Expected error") + } else { + require.NoError(t, err, "Expression evaluation should not fail") + assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") + } + }) + } +} + +// TestExpressionDataTypeConversion 测试数据类型转换 +func TestExpressionDataTypeConversion(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + {"String to Number", "a + 5", map[string]interface{}{"a": "10"}, 15, false}, + {"Integer to Float", "a + 3.5", map[string]interface{}{"a": 5}, 8.5, false}, + {"Float to Float", "a + b", map[string]interface{}{"a": 3.14, "b": 2.86}, 6.0, false}, + {"Boolean True", "a + 1", map[string]interface{}{"a": true}, 2, false}, + {"Boolean False", "a + 1", map[string]interface{}{"a": false}, 1, false}, + {"Invalid String", "a + 5", map[string]interface{}{"a": "invalid"}, 0, true}, + {"Nil Value", "a + 5", map[string]interface{}{"a": nil}, 0, true}, + {"Complex Type", "a + 5", map[string]interface{}{"a": map[string]interface{}{}}, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + require.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + if tt.wantErr { + assert.Error(t, err, "Expected error") + } else { + require.NoError(t, err, "Expression evaluation should not fail") + assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") + } + }) + } +} + +// TestExpressionEdgeCases 测试边界情况 +func TestExpressionEdgeCases(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + {"Very Large Number", "a + 1", map[string]interface{}{"a": 1e308}, 1e308 + 1, false}, + {"Very Small Number", "a + 1", map[string]interface{}{"a": 1e-308}, 1, false}, + {"Infinity", "a + 1", map[string]interface{}{"a": math.Inf(1)}, math.Inf(1), false}, + {"Negative Infinity", "a + 1", map[string]interface{}{"a": math.Inf(-1)}, math.Inf(-1), false}, + {"NaN", "a + 1", map[string]interface{}{"a": math.NaN()}, 0, true}, + {"Division by Zero", "5 / 0", map[string]interface{}{}, 0, true}, + {"Modulo by Zero", "5 % 0", map[string]interface{}{}, 0, true}, + {"Zero Power Zero", "0 ^ 0", map[string]interface{}{}, 1, false}, // 0^0 = 1 by convention + {"Negative Power", "2 ^ -3", map[string]interface{}{}, 0.125, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + require.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + if tt.wantErr { + assert.Error(t, err, "Expected error") + } else { + require.NoError(t, err, "Expression evaluation should not fail") + if math.IsInf(tt.expected, 0) { + assert.True(t, math.IsInf(result, 0), "Result should be infinity") + } else { + assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") + } + } + }) + } +} + +// TestExpressionConcurrency 测试并发安全性 +func TestExpressionConcurrency(t *testing.T) { + expr, err := NewExpression("a + b * c") + require.NoError(t, err, "Expression parsing should not fail") + + // 并发执行多个计算 + const numGoroutines = 100 + results := make(chan float64, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(index int) { + data := map[string]interface{}{ + "a": float64(index), + "b": float64(index * 2), + "c": float64(index * 3), + } + result, err := expr.Evaluate(data) + assert.NoError(t, err, "Concurrent evaluation should not fail") + results <- result + }(i) + } + + // 收集结果 + for i := 0; i < numGoroutines; i++ { + result := <-results + // 验证结果是合理的(非零且非NaN) + assert.False(t, math.IsNaN(result), "Result should not be NaN") + assert.True(t, result >= 0, "Result should be non-negative for this test") + } +} + +// TestExpressionComplexNesting 测试复杂嵌套表达式 +func TestExpressionComplexNesting(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + }{ + { + "Deeply Nested Parentheses", + "((a + b) * (c - d)) / ((e + f) * (g - h))", + map[string]interface{}{"a": 1, "b": 2, "c": 5, "d": 3, "e": 2, "f": 3, "g": 7, "h": 2}, + 0.24, // ((1+2)*(5-3))/((2+3)*(7-2)) = (3*2)/(5*5) = 6/25 = 0.24 + }, + { + "Nested Functions", + "sqrt(abs(a - b) + pow(c, 2))", + map[string]interface{}{"a": 3, "b": 7, "c": 3}, + 3.606, // sqrt(abs(3-7) + pow(3,2)) = sqrt(4 + 9) = sqrt(13) ≈ 3.606 + }, + { + "Mixed Operations", + "a * b + c / d - e % f + pow(g, h)", + map[string]interface{}{"a": 2, "b": 3, "c": 8, "d": 2, "e": 7, "f": 3, "g": 2, "h": 3}, + 17, // 2*3 + 8/2 - 7%3 + pow(2,3) = 6 + 4 - 1 + 8 = 17 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + require.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + require.NoError(t, err, "Expression evaluation should not fail") + assert.InDelta(t, tt.expected, result, 0.1, "Result should match expected value") + }) + } +} + +// TestExpressionStringHandling 测试字符串处理 +func TestExpressionStringHandling(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + {"String Length", "len(name)", map[string]interface{}{"name": "hello"}, 5, false}, + {"Empty String Length", "len(name)", map[string]interface{}{"name": ""}, 0, false}, + {"String Comparison Equal", "name == 'test'", map[string]interface{}{"name": "test"}, 1, false}, + {"String Comparison Not Equal", "name != 'test'", map[string]interface{}{"name": "hello"}, 1, false}, + {"String to Number Conversion", "val + 10", map[string]interface{}{"val": "5"}, 15, false}, + {"Invalid String to Number", "val + 10", map[string]interface{}{"val": "abc"}, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + require.NoError(t, err, "Expression parsing should not fail") + + result, err := expr.Evaluate(tt.data) + if tt.wantErr { + assert.Error(t, err, "Expected error") + } else { + require.NoError(t, err, "Expression evaluation should not fail") + assert.InDelta(t, tt.expected, result, 0.001, "Result should match expected value") + } + }) + } +} + +// TestExpressionPerformance 测试表达式性能 +func TestExpressionPerformance(t *testing.T) { + // 创建一个复杂表达式 + expr, err := NewExpression("sqrt(pow(a, 2) + pow(b, 2)) + abs(c - d) * (e + f) / (g + 1)") + require.NoError(t, err, "Expression parsing should not fail") + + data := map[string]interface{}{ + "a": 3.0, "b": 4.0, "c": 10.0, "d": 7.0, "e": 2.0, "f": 3.0, "g": 4.0, + } + + // 执行多次计算以测试性能 + const iterations = 10000 + for i := 0; i < iterations; i++ { + _, err := expr.Evaluate(data) + assert.NoError(t, err, "Performance test evaluation should not fail") + } +} + +// TestExpressionMemoryUsage 测试内存使用 +func TestExpressionMemoryUsage(t *testing.T) { + // 创建多个表达式实例 + const numExpressions = 1000 + expressions := make([]*Expression, numExpressions) + + for i := 0; i < numExpressions; i++ { + expr, err := NewExpression("a + b * c") + require.NoError(t, err, "Expression creation should not fail") + expressions[i] = expr + } + + // 验证所有表达式都能正常工作 + data := map[string]interface{}{"a": 1, "b": 2, "c": 3} + for i, expr := range expressions { + result, err := expr.Evaluate(data) + assert.NoError(t, err, "Expression %d evaluation should not fail", i) + assert.Equal(t, 7.0, result, "Expression %d result should be correct", i) + } +} + +func TestEvaluateWithExprLang(t *testing.T) { + expr := &Expression{ + useExprLang: true, + exprLangExpression: "a + b", + } + + tests := []struct { + name string + data map[string]interface{} + expectError bool + }{ + { + name: "valid expression", + data: map[string]interface{}{"a": 1.0, "b": 2.0}, + expectError: false, + }, + { + name: "missing variables", + data: map[string]interface{}{}, + expectError: true, + }, + { + name: "invalid expression", + data: map[string]interface{}{"a": 1.0}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := expr.evaluateWithExprLang(tt.data) + if tt.expectError && err == nil { + t.Error("expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestExtractFieldsFromExprLang(t *testing.T) { + tests := []struct { + name string + expr string + expected []string + }{ + { + name: "simple fields", + expr: "a + b * c", + expected: []string{"a", "b", "c"}, + }, + { + name: "nested fields", + expr: "user.name + user.age", + expected: []string{"user.name", "user.age"}, + }, + { + name: "with numbers", + expr: "field1 + 123 + field2", + expected: []string{"field1", "field2"}, + }, + { + name: "with functions", + expr: "sum(field1) + field2", + expected: []string{"field1", "field2"}, + }, + { + name: "with keywords", + expr: "field1 and field2 or field3", + expected: []string{"field1", "field2", "field3"}, + }, + { + name: "empty expression", + expr: "", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractFieldsFromExprLang(tt.expr) + + // Convert to map for easier comparison + expectedMap := make(map[string]bool) + for _, field := range tt.expected { + expectedMap[field] = true + } + + resultMap := make(map[string]bool) + for _, field := range result { + resultMap[field] = true + } + + if len(expectedMap) != len(resultMap) { + t.Errorf("expected %d fields, got %d", len(tt.expected), len(result)) + return + } + + for field := range expectedMap { + if !resultMap[field] { + t.Errorf("expected field %s not found in result", field) + } + } + }) + } +} + +func TestIsValidFieldIdentifier(t *testing.T) { + tests := []struct { + name string + field string + expected bool + }{ + { + name: "simple field", + field: "field1", + expected: true, + }, + { + name: "nested field", + field: "user.name", + expected: true, + }, + { + name: "deep nested field", + field: "user.profile.address.city", + expected: true, + }, + { + name: "field with underscore", + field: "user_name", + expected: true, + }, + { + name: "empty string", + field: "", + expected: false, + }, + { + name: "invalid field with special chars", + field: "user@name", + expected: false, + }, + { + name: "field starting with number", + field: "1field", + expected: false, + }, + { + name: "nested field with invalid part", + field: "user.1name", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidFieldIdentifier(tt.field) + if result != tt.expected { + t.Errorf("expected %v, got %v for field '%s'", tt.expected, result, tt.field) + } + }) + } +} + +func TestIsFunctionOrKeyword(t *testing.T) { + tests := []struct { + name string + token string + expected bool + }{ + { + name: "keyword and", + token: "and", + expected: true, + }, + { + name: "keyword or", + token: "or", + expected: true, + }, + { + name: "keyword not", + token: "not", + expected: true, + }, + { + name: "keyword case", + token: "case", + expected: true, + }, + { + name: "keyword when", + token: "when", + expected: true, + }, + { + name: "keyword then", + token: "then", + expected: true, + }, + { + name: "keyword else", + token: "else", + expected: true, + }, + { + name: "keyword end", + token: "end", + expected: true, + }, + { + name: "keyword is", + token: "is", + expected: true, + }, + { + name: "keyword null", + token: "null", + expected: true, + }, + { + name: "keyword true", + token: "true", + expected: true, + }, + { + name: "keyword false", + token: "false", + expected: true, + }, + { + name: "regular field", + token: "field1", + expected: false, + }, + { + name: "number", + token: "123", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isFunctionOrKeyword(tt.token) + if result != tt.expected { + t.Errorf("expected %v, got %v for token '%s'", tt.expected, result, tt.token) + } + }) + } +} + +func TestEvaluateBool(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expected bool + hasError bool + }{ + { + name: "true condition", + expr: "field1 > 0", + data: map[string]interface{}{"field1": 5}, + expected: true, + hasError: false, + }, + { + name: "false condition", + expr: "field1 > 10", + data: map[string]interface{}{"field1": 5}, + expected: false, + hasError: false, + }, + { + name: "zero value", + expr: "field1", + data: map[string]interface{}{"field1": 0}, + expected: false, + hasError: false, + }, + { + name: "non-zero value", + expr: "field1", + data: map[string]interface{}{"field1": 1}, + expected: true, + hasError: false, + }, + { + name: "missing field", + expr: "field1 > 0", + data: map[string]interface{}{}, + expected: false, + hasError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + if err != nil { + t.Errorf("failed to create expression: %v", err) + return + } + + result, err := expr.EvaluateBool(tt.data) + if tt.hasError && err == nil { + t.Error("expected error but got none") + return + } + if !tt.hasError && err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if !tt.hasError && result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestEvaluateValueWithNull(t *testing.T) { + tests := []struct { + name string + expr string + data map[string]interface{} + expectNull bool + expectError bool + }{ + { + name: "valid expression", + expr: "field1 + field2", + data: map[string]interface{}{"field1": 1, "field2": 2}, + expectNull: false, + expectError: false, + }, + { + name: "missing field", + expr: "field1 + field2", + data: map[string]interface{}{"field1": 1}, + expectNull: false, // 实际行为:返回nil但isNull为false + expectError: false, + }, + { + name: "invalid expression", + expr: "field1 + field2 +", // 使用无效语法 + data: map[string]interface{}{}, + expectNull: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expr, err := NewExpression(tt.expr) + if err != nil { + if tt.expectError { + // 如果期望错误,那么创建表达式失败是正常的 + return + } + t.Errorf("failed to create expression: %v", err) + return + } + + result, isNull, err := expr.EvaluateValueWithNull(tt.data) + if tt.expectError && err == nil { + t.Error("expected error but got none") + return + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if !tt.expectError && isNull != tt.expectNull { + t.Errorf("expected null=%v, got null=%v", tt.expectNull, isNull) + } + // 对于缺失字段的情况,允许result为nil + if !tt.expectError && !tt.expectNull && result == nil && tt.name != "missing field" { + t.Error("expected non-nil result but got nil") + } + }) + } +} diff --git a/expr/parser.go b/expr/parser.go new file mode 100644 index 0000000..16a0ddb --- /dev/null +++ b/expr/parser.go @@ -0,0 +1,342 @@ +package expr + +import ( + "fmt" + "strings" +) + +// ParseExpression parses expression token list to AST +func ParseExpression(tokens []string) (*ExprNode, error) { + return parseExpression(tokens) +} + +// parseExpression parses expression token list to AST +func parseExpression(tokens []string) (*ExprNode, error) { + if len(tokens) == 0 { + return nil, fmt.Errorf("empty token list") + } + + // Handle CASE expression + if len(tokens) > 0 && strings.ToUpper(tokens[0]) == "CASE" { + node, _, err := parseCaseExpression(tokens) + return node, err + } + + node, _, err := parseOrExpression(tokens) + return node, err +} + +// parseOrExpression parses OR expression +func parseOrExpression(tokens []string) (*ExprNode, []string, error) { + left, remaining, err := parseAndExpression(tokens) + if err != nil { + return nil, nil, err + } + + for len(remaining) > 0 && strings.ToUpper(remaining[0]) == "OR" { + right, newRemaining, err := parseAndExpression(remaining[1:]) + if err != nil { + return nil, nil, err + } + + left = &ExprNode{ + Type: TypeOperator, + Value: "OR", + Left: left, + Right: right, + } + remaining = newRemaining + } + + return left, remaining, nil +} + +// parseAndExpression parses AND expression +func parseAndExpression(tokens []string) (*ExprNode, []string, error) { + left, remaining, err := parseComparisonExpression(tokens) + if err != nil { + return nil, nil, err + } + + for len(remaining) > 0 && strings.ToUpper(remaining[0]) == "AND" { + right, newRemaining, err := parseComparisonExpression(remaining[1:]) + if err != nil { + return nil, nil, err + } + + left = &ExprNode{ + Type: TypeOperator, + Value: "AND", + Left: left, + Right: right, + } + remaining = newRemaining + } + + return left, remaining, nil +} + +// parseComparisonExpression parses comparison expression +func parseComparisonExpression(tokens []string) (*ExprNode, []string, error) { + left, remaining, err := parseArithmeticExpression(tokens) + if err != nil { + return nil, nil, err + } + + // Check IS NOT operator (two tokens) + if len(remaining) >= 2 && strings.ToUpper(remaining[0]) == "IS" && strings.ToUpper(remaining[1]) == "NOT" { + op := "IS NOT" + right, newRemaining, err := parseArithmeticExpression(remaining[2:]) + if err != nil { + return nil, nil, err + } + + return &ExprNode{ + Type: TypeOperator, + Value: op, + Left: left, + Right: right, + }, newRemaining, nil + } + + // Check single token comparison operators + if len(remaining) > 0 && isComparisonOperator(remaining[0]) { + op := remaining[0] + right, newRemaining, err := parseArithmeticExpression(remaining[1:]) + if err != nil { + return nil, nil, err + } + + return &ExprNode{ + Type: TypeOperator, + Value: op, + Left: left, + Right: right, + }, newRemaining, nil + } + + return left, remaining, nil +} + +// parseArithmeticExpression parses arithmetic expression +func parseArithmeticExpression(tokens []string) (*ExprNode, []string, error) { + left, remaining, err := parseTermExpression(tokens) + if err != nil { + return nil, nil, err + } + + for len(remaining) > 0 && (remaining[0] == "+" || remaining[0] == "-") { + op := remaining[0] + right, newRemaining, err := parseTermExpression(remaining[1:]) + if err != nil { + return nil, nil, err + } + + left = &ExprNode{ + Type: TypeOperator, + Value: op, + Left: left, + Right: right, + } + remaining = newRemaining + } + + return left, remaining, nil +} + +// parseTermExpression parses term expression (multiply, divide, modulo) +func parseTermExpression(tokens []string) (*ExprNode, []string, error) { + left, remaining, err := parsePowerExpression(tokens) + if err != nil { + return nil, nil, err + } + + for len(remaining) > 0 && (remaining[0] == "*" || remaining[0] == "/" || remaining[0] == "%") { + op := remaining[0] + right, newRemaining, err := parsePowerExpression(remaining[1:]) + if err != nil { + return nil, nil, err + } + + left = &ExprNode{ + Type: TypeOperator, + Value: op, + Left: left, + Right: right, + } + remaining = newRemaining + } + + return left, remaining, nil +} + +// parsePowerExpression parses power expression +func parsePowerExpression(tokens []string) (*ExprNode, []string, error) { + left, remaining, err := parseUnaryExpression(tokens) + if err != nil { + return nil, nil, err + } + + if len(remaining) > 0 && remaining[0] == "^" { + right, newRemaining, err := parsePowerExpression(remaining[1:]) // Right associative + if err != nil { + return nil, nil, err + } + + return &ExprNode{ + Type: TypeOperator, + Value: "^", + Left: left, + Right: right, + }, newRemaining, nil + } + + return left, remaining, nil +} + +// parseUnaryExpression parses unary expression +func parseUnaryExpression(tokens []string) (*ExprNode, []string, error) { + if len(tokens) == 0 { + return nil, nil, fmt.Errorf("unexpected end of expression") + } + + // Handle unary minus + if tokens[0] == "-" { + operand, remaining, err := parseUnaryExpression(tokens[1:]) + if err != nil { + return nil, nil, err + } + + return &ExprNode{ + Type: TypeOperator, + Value: "-", + Left: &ExprNode{ + Type: TypeNumber, + Value: "0", + }, + Right: operand, + }, remaining, nil + } + + return parsePrimaryExpression(tokens) +} + +// parsePrimaryExpression parses primary expression +func parsePrimaryExpression(tokens []string) (*ExprNode, []string, error) { + if len(tokens) == 0 { + return nil, nil, fmt.Errorf("unexpected end of expression") + } + + token := tokens[0] + + // Handle parentheses + if token == "(" { + expr, remaining, err := parseOrExpression(tokens[1:]) + if err != nil { + return nil, nil, err + } + + if len(remaining) == 0 || remaining[0] != ")" { + return nil, nil, fmt.Errorf("missing closing parenthesis") + } + + // Create parenthesis node + return &ExprNode{ + Type: TypeParenthesis, + Left: expr, + }, remaining[1:], nil + } + + // Handle numbers + if isNumber(token) { + return &ExprNode{ + Type: TypeNumber, + Value: token, + }, tokens[1:], nil + } + + // Handle string literals + if isStringLiteral(token) { + return &ExprNode{ + Type: TypeString, + Value: token, + }, tokens[1:], nil + } + + // Handle function calls + if len(tokens) > 1 && tokens[1] == "(" { + return parseFunctionCall(tokens) + } + + // Check for invalid function calls (identifier followed by non-parenthesis token) + // But exclude keywords in CASE expressions + if isIdentifier(token) && len(tokens) > 1 && tokens[1] != "(" && !isOperator(tokens[1]) && tokens[1] != ")" && tokens[1] != "," { + // Allow keywords in CASE expressions + nextToken := strings.ToUpper(tokens[1]) + if nextToken != "WHEN" && nextToken != "THEN" && nextToken != "ELSE" && nextToken != "END" { + return nil, nil, fmt.Errorf("invalid function call") + } + } + + // Handle field references + if isIdentifier(token) || (len(token) >= 2 && token[0] == '`' && token[len(token)-1] == '`') { + return &ExprNode{ + Type: TypeField, + Value: token, + }, tokens[1:], nil + } + + return nil, nil, fmt.Errorf("unexpected token: %s", token) +} + +// parseFunctionCall parses function call +func parseFunctionCall(tokens []string) (*ExprNode, []string, error) { + if len(tokens) < 2 || tokens[1] != "(" { + return nil, nil, fmt.Errorf("invalid function call") + } + + funcName := tokens[0] + remaining := tokens[2:] // Skip function name and opening parenthesis + + var args []*ExprNode + + // Handle empty parameter list + if len(remaining) > 0 && remaining[0] == ")" { + return &ExprNode{ + Type: TypeFunction, + Value: funcName, + Args: args, + }, remaining[1:], nil + } + + // Parse arguments + for { + arg, newRemaining, err := parseOrExpression(remaining) + if err != nil { + return nil, nil, err + } + + args = append(args, arg) + remaining = newRemaining + + if len(remaining) == 0 { + return nil, nil, fmt.Errorf("missing closing parenthesis in function call") + } + + if remaining[0] == ")" { + break + } + + if remaining[0] != "," { + return nil, nil, fmt.Errorf("expected ',' or ')' in function call") + } + + remaining = remaining[1:] // Skip comma + } + + return &ExprNode{ + Type: TypeFunction, + Value: funcName, + Args: args, + }, remaining[1:], nil // Skip closing parenthesis +} diff --git a/expr/parser_test.go b/expr/parser_test.go new file mode 100644 index 0000000..6786cba --- /dev/null +++ b/expr/parser_test.go @@ -0,0 +1,312 @@ +package expr + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestParseExpression 测试表达式解析功能 +func TestParseExpression(t *testing.T) { + tests := []struct { + name string + tokens []string + expectError bool + description string + }{ + { + name: "empty tokens", + tokens: []string{}, + expectError: true, + description: "should return error for empty tokens", + }, + { + name: "valid simple expression", + tokens: []string{"field1"}, + expectError: false, + description: "should parse simple field", + }, + { + name: "valid case expression", + tokens: []string{"CASE", "WHEN", "field1", "THEN", "1", "END"}, + expectError: false, + description: "should parse CASE expression", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseExpression(tt.tokens) + if tt.expectError && err == nil { + t.Errorf("expected error but got none: %s", tt.description) + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestParseUnaryExpression(t *testing.T) { + tests := []struct { + name string + tokens []string + expectError bool + description string + }{ + { + name: "empty tokens", + tokens: []string{}, + expectError: true, + description: "should return error for empty tokens", + }, + { + name: "unary minus with number", + tokens: []string{"-", "5"}, + expectError: false, + description: "should parse unary minus with number", + }, + { + name: "unary minus with field", + tokens: []string{"-", "field1"}, + expectError: false, + description: "should parse unary minus with field", + }, + { + name: "unary minus with expression", + tokens: []string{"-", "(", "field1", "+", "field2", ")"}, + expectError: false, + description: "should parse unary minus with expression", + }, + { + name: "unary minus with function", + tokens: []string{"-", "sum", "(", "field1", ")"}, + expectError: false, + description: "should parse unary minus with function", + }, + { + name: "unary minus with string", + tokens: []string{"-", "'value'"}, + expectError: false, + description: "should parse unary minus with string", + }, + { + name: "unary minus with missing operand", + tokens: []string{"-"}, + expectError: true, + description: "should return error for missing operand", + }, + { + name: "nested unary minus", + tokens: []string{"-", "-", "5"}, + expectError: false, + description: "should parse nested unary minus", + }, + { + name: "unary minus with complex expression", + tokens: []string{"-", "field1", "*", "field2"}, + expectError: false, + description: "should parse unary minus with complex expression", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := parseUnaryExpression(tt.tokens) + if tt.expectError && err == nil { + t.Errorf("expected error but got none: %s", tt.description) + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +// TestParseExpressionWithPrecedence 测试运算符优先级解析 +func TestParseExpressionWithPrecedence(t *testing.T) { + tests := []struct { + name string + tokens []string + expected string // 用字符串表示预期的树结构 + }{ + {"加法和乘法", []string{"a", "+", "b", "*", "c"}, "(a + (b * c))"}, + {"乘法和除法", []string{"a", "*", "b", "/", "c"}, "((a * b) / c)"}, + {"幂运算", []string{"a", "^", "b", "^", "c"}, "(a ^ (b ^ c))"}, + {"混合运算", []string{"a", "+", "b", "*", "c", "^", "d"}, "(a + (b * (c ^ d)))"}, + {"比较运算符", []string{"a", "+", "b", ">", "c"}, "((a + b) > c)"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseExpression(tt.tokens) + require.NoError(t, err, "解析不应该失败") + actual := nodeToString(result) + assert.Equal(t, tt.expected, actual, "运算符优先级应该正确") + }) + } +} + +// TestParseFunction 测试函数解析 +func TestParseFunction(t *testing.T) { + tests := []struct { + name string + tokens []string + pos int + expected *ExprNode + newPos int + wantErr bool + }{ + { + "无参数函数", + []string{"now", "(", ")"}, + 0, + &ExprNode{Type: TypeFunction, Value: "now", Args: []*ExprNode{}}, + 3, + false, + }, + { + "单参数函数", + []string{"abs", "(", "x", ")"}, + 0, + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeField, Value: "x"}}, + }, + 4, + false, + }, + { + "多参数函数", + []string{"max", "(", "a", ",", "b", ",", "c", ")"}, + 0, + &ExprNode{ + Type: TypeFunction, + Value: "max", + Args: []*ExprNode{ + {Type: TypeField, Value: "a"}, + {Type: TypeField, Value: "b"}, + {Type: TypeField, Value: "c"}, + }, + }, + 8, + false, + }, + { + "嵌套函数", + []string{"sqrt", "(", "pow", "(", "x", ",", "2", ")", ")"}, + 0, + &ExprNode{ + Type: TypeFunction, + Value: "sqrt", + Args: []*ExprNode{ + { + Type: TypeFunction, + Value: "pow", + Args: []*ExprNode{ + {Type: TypeField, Value: "x"}, + {Type: TypeNumber, Value: "2"}, + }, + }, + }, + }, + 9, + false, + }, + // 错误情况 + {"缺少左括号", []string{"abs", "x", ")"}, 0, nil, 0, true}, + {"缺少右括号", []string{"abs", "(", "x"}, 0, nil, 0, true}, + {"参数分隔符错误", []string{"max", "(", "a", ";", "b", ")"}, 0, nil, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, newPos, err := parseFunction(tt.tokens, tt.pos) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "函数解析不应该失败") + assert.Equal(t, tt.newPos, newPos, "位置应该正确") + assertNodeEqual(t, tt.expected, result) + } + }) + } +} + +// TestGetOperatorPrecedence 测试运算符优先级获取 +func TestGetOperatorPrecedence(t *testing.T) { + tests := []struct { + op string + expected int + }{ + {"OR", 1}, + {"AND", 2}, + {"NOT", 3}, + {"=", 4}, + {"==", 4}, + {"!=", 4}, + {"<>", 4}, + {"<", 4}, + {">", 4}, + {"<=", 4}, + {">=", 4}, + {"LIKE", 4}, + {"IS", 4}, + {"+", 5}, + {"-", 5}, + {"*", 6}, + {"/", 6}, + {"%", 6}, + {"^", 7}, + {"unknown", 0}, + } + + for _, tt := range tests { + t.Run(tt.op, func(t *testing.T) { + result := getOperatorPrecedence(tt.op) + assert.Equal(t, tt.expected, result, "运算符优先级应该正确") + }) + } +} + +// TestIsRightAssociative 测试右结合性判断 +func TestIsRightAssociative(t *testing.T) { + tests := []struct { + op string + expected bool + }{ + {"^", true}, + {"+", false}, + {"-", false}, + {"*", false}, + {"/", false}, + {"=", false}, + {"AND", false}, + {"OR", false}, + } + + for _, tt := range tests { + t.Run(tt.op, func(t *testing.T) { + result := isRightAssociative(tt.op) + assert.Equal(t, tt.expected, result, "右结合性判断应该正确") + }) + } +} + +// 辅助函数:连接字符串数组 +func joinStrings(strs []string, sep string) string { + if len(strs) == 0 { + return "" + } + if len(strs) == 1 { + return strs[0] + } + + result := strs[0] + for i := 1; i < len(strs); i++ { + result += sep + strs[i] + } + return result +} diff --git a/expr/tokenizer.go b/expr/tokenizer.go new file mode 100644 index 0000000..64965d7 --- /dev/null +++ b/expr/tokenizer.go @@ -0,0 +1,280 @@ +package expr + +import ( + "fmt" + "strings" + "unicode" +) + +// TokenType represents token type +type TokenType int + +const ( + // TokenKeyword keyword token + TokenKeyword TokenType = iota + // TokenField field token + TokenField + // TokenOperator operator token + TokenOperator + // TokenNumber number token + TokenNumber + // TokenString string token + TokenString + // TokenLeftParen left parenthesis token + TokenLeftParen + // TokenRightParen right parenthesis token + TokenRightParen + // TokenComma comma token + TokenComma +) + +// Token represents a token +type Token struct { + // Type token type + Type TokenType + // Value token value + Value string +} + +// tokenize breaks expression string into token list +// Supports numbers, identifiers, operators, parentheses, string literals, etc. +func tokenize(expr string) ([]string, error) { + // Check empty expression + if len(strings.TrimSpace(expr)) == 0 { + return nil, fmt.Errorf("empty expression") + } + + var tokens []string + i := 0 + + for i < len(expr) { + // Skip whitespace characters + if unicode.IsSpace(rune(expr[i])) { + i++ + continue + } + + // Handle string literals + if expr[i] == '\'' || expr[i] == '"' { + quote := expr[i] + start := i + i++ // Skip opening quote + + // Find closing quote + for i < len(expr) && expr[i] != quote { + if expr[i] == '\\' && i+1 < len(expr) { + i += 2 // Skip escape character + } else { + i++ + } + } + + if i >= len(expr) { + return nil, fmt.Errorf("unterminated string literal") + } + + i++ // Skip closing quote + tokens = append(tokens, expr[start:i]) + continue + } + + // Handle backtick identifiers + if expr[i] == '`' { + start := i + i++ // Skip opening backtick + + // Find closing backtick + for i < len(expr) && expr[i] != '`' { + i++ + } + + if i >= len(expr) { + return nil, fmt.Errorf("unterminated backtick identifier") + } + + i++ // Skip closing backtick + tokens = append(tokens, expr[start:i]) + continue + } + + // Handle numbers (including negative numbers and numbers starting with decimal point) + // Note: Numbers starting with decimal point are only valid when not preceded by digit character + if isDigit(expr[i]) || (expr[i] == '-' && i+1 < len(expr) && isDigit(expr[i+1])) || (expr[i] == '.' && i+1 < len(expr) && isDigit(expr[i+1]) && (i == 0 || (!isDigit(expr[i-1]) && expr[i-1] != '.'))) { + start := i + if expr[i] == '-' { + i++ // Skip negative sign + } + + // Read integer part + for i < len(expr) && isDigit(expr[i]) { + i++ + } + + // Handle decimal point (only one decimal point allowed) + hasDecimal := false + if i < len(expr) && expr[i] == '.' { + // Check if there's already a decimal point or next character is not a digit + if !hasDecimal && i+1 < len(expr) && isDigit(expr[i+1]) { + hasDecimal = true + i++ + // Read decimal part + for i < len(expr) && isDigit(expr[i]) { + i++ + } + } + } + + // Handle scientific notation + if i < len(expr) && (expr[i] == 'e' || expr[i] == 'E') { + i++ + if i < len(expr) && (expr[i] == '+' || expr[i] == '-') { + i++ + } + for i < len(expr) && isDigit(expr[i]) { + i++ + } + } + + tokens = append(tokens, expr[start:i]) + continue + } + + // Handle multi-character operators + if i+1 < len(expr) { + twoChar := expr[i : i+2] + if isOperator(twoChar) { + tokens = append(tokens, twoChar) + i += 2 + continue + } + } + + // Handle single-character operators and parentheses (including standalone decimal point) + if isOperator(string(expr[i])) || expr[i] == '(' || expr[i] == ')' || expr[i] == ',' || expr[i] == '.' { + tokens = append(tokens, string(expr[i])) + i++ + continue + } + + // Handle identifiers and keywords + if isLetter(expr[i]) || expr[i] == '_' || expr[i] == '$' { + start := i + for i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_' || expr[i] == '.' || expr[i] == '$') { + i++ + } + tokens = append(tokens, expr[start:i]) + continue + } + + // Unknown character + return nil, fmt.Errorf("unexpected character '%c' at position %d", expr[i], i) + } + + return tokens, nil +} + +// isDigit checks if character is a digit +func isDigit(ch byte) bool { + return ch >= '0' && ch <= '9' +} + +// isLetter checks if character is a letter +func isLetter(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') +} + +// isNumber checks if string is a number +func isNumber(s string) bool { + if len(s) == 0 { + return false + } + + i := 0 + // Handle negative sign + if s[0] == '-' { + i = 1 + if len(s) == 1 { + return false + } + } + + hasDigit := false + hasDot := false + + for i < len(s) { + if isDigit(s[i]) { + hasDigit = true + } else if s[i] == '.' && !hasDot { + hasDot = true + } else if s[i] == 'e' || s[i] == 'E' { + // Scientific notation + i++ + if i < len(s) && (s[i] == '+' || s[i] == '-') { + i++ + } + for i < len(s) && isDigit(s[i]) { + i++ + } + break + } else { + return false + } + i++ + } + + return hasDigit +} + +// isIdentifier checks if string is a valid identifier +func isIdentifier(s string) bool { + if len(s) == 0 { + return false + } + + // First character must be letter or underscore + if !isLetter(s[0]) && s[0] != '_' { + return false + } + + // Remaining characters can be letters, digits, or underscores + for i := 1; i < len(s); i++ { + if !isLetter(s[i]) && !isDigit(s[i]) && s[i] != '_' { + return false + } + } + + return true +} + +// isOperator checks if string is an operator +func isOperator(s string) bool { + operators := []string{ + "+", "-", "*", "/", "%", "^", + "=", "==", "!=", "<>", ">", "<", ">=", "<=", + "AND", "OR", "NOT", "LIKE", "IS", + } + + for _, op := range operators { + if strings.EqualFold(s, op) { + return true + } + } + + return false +} + +// isComparisonOperator checks if it's a comparison operator +func isComparisonOperator(op string) bool { + comparisonOps := []string{"==", "=", "!=", "<>", ">", "<", ">=", "<=", "LIKE", "IS"} + for _, compOp := range comparisonOps { + if strings.EqualFold(op, compOp) { + return true + } + } + return false +} + +// isStringLiteral checks if it's a string literal +func isStringLiteral(s string) bool { + return len(s) >= 2 && ((s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '"' && s[len(s)-1] == '"')) +} diff --git a/expr/tokenizer_test.go b/expr/tokenizer_test.go new file mode 100644 index 0000000..0d270b6 --- /dev/null +++ b/expr/tokenizer_test.go @@ -0,0 +1,352 @@ +package expr + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestTokenize 测试分词功能 +func TestTokenize(t *testing.T) { + tests := []struct { + name string + expr string + expected []string + wantErr bool + }{ + // 基本分词测试 + {"简单表达式", "a + b", []string{"a", "+", "b"}, false}, + {"数字和运算符", "123 + 456", []string{"123", "+", "456"}, false}, + {"小数", "3.14 * 2", []string{"3.14", "*", "2"}, false}, + {"负数", "-5 + 3", []string{"-5", "+", "3"}, false}, + {"负小数", "-3.14 * 2", []string{"-3.14", "*", "2"}, false}, + + // 括号和函数 + {"括号表达式", "(a + b) * c", []string{"(", "a", "+", "b", ")", "*", "c"}, false}, + {"函数调用", "abs(x)", []string{"abs", "(", "x", ")"}, false}, + {"函数参数", "max(a, b)", []string{"max", "(", "a", ",", "b", ")"}, false}, + + // 比较运算符 + {"等于运算符", "a == b", []string{"a", "==", "b"}, false}, + {"不等于运算符", "a != b", []string{"a", "!=", "b"}, false}, + {"大于等于", "a >= b", []string{"a", ">=", "b"}, false}, + {"小于等于", "a <= b", []string{"a", "<=", "b"}, false}, + {"不等于SQL风格", "a <> b", []string{"a", "<>", "b"}, false}, + + // 字符串字面量 + {"单引号字符串", "'hello'", []string{"'hello'"}, false}, + {"双引号字符串", "\"world\"", []string{"\"world\""}, false}, + {"字符串比较", "name == 'test'", []string{"name", "==", "'test'"}, false}, + {"包含转义的字符串", "'hello\\world'", []string{"'hello\\world'"}, false}, + + // 反引号标识符 + {"反引号字段", "`field name`", []string{"`field name`"}, false}, + {"反引号表达式", "`user.name` + `user.age`", []string{"`user.name`", "+", "`user.age`"}, false}, + + // CASE表达式 + {"简单CASE", "CASE WHEN a > 0 THEN 1 ELSE 0 END", []string{"CASE", "WHEN", "a", ">", "0", "THEN", "1", "ELSE", "0", "END"}, false}, + + // 复杂表达式 + {"复杂算术", "a + b * c - d / e", []string{"a", "+", "b", "*", "c", "-", "d", "/", "e"}, false}, + {"幂运算", "a ^ b", []string{"a", "^", "b"}, false}, + {"取模运算", "a % b", []string{"a", "%", "b"}, false}, + + // 空白字符处理 + {"多个空格", "a + b", []string{"a", "+", "b"}, false}, + {"制表符", "a\t+\tb", []string{"a", "+", "b"}, false}, + {"换行符", "a\n+\nb", []string{"a", "+", "b"}, false}, + + // 错误情况 + {"空表达式", "", nil, true}, + {"只有空格", " ", nil, true}, + {"未闭合字符串", "'hello", nil, true}, + {"未闭合反引号", "`field", nil, true}, + {"无效字符", "a @ b", nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tokenize(tt.expr) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "分词不应该失败") + assert.Equal(t, tt.expected, result, "分词结果应该匹配") + } + }) + } +} + +// TestIsDigit 测试数字字符判断 +func TestIsDigit(t *testing.T) { + tests := []struct { + ch byte + expected bool + }{ + {'0', true}, + {'5', true}, + {'9', true}, + {'a', false}, + {'A', false}, + {' ', false}, + {'.', false}, + {'+', false}, + } + + for _, tt := range tests { + t.Run(string(tt.ch), func(t *testing.T) { + result := isDigit(tt.ch) + assert.Equal(t, tt.expected, result, "数字字符判断应该正确") + }) + } +} + +// TestIsLetter 测试字母字符判断 +func TestIsLetter(t *testing.T) { + tests := []struct { + ch byte + expected bool + }{ + {'a', true}, + {'z', true}, + {'A', true}, + {'Z', true}, + {'0', false}, + {'9', false}, + {' ', false}, + {'_', false}, + {'+', false}, + } + + for _, tt := range tests { + t.Run(string(tt.ch), func(t *testing.T) { + result := isLetter(tt.ch) + assert.Equal(t, tt.expected, result, "字母字符判断应该正确") + }) + } +} + +// TestIsNumber 测试数字字符串判断 +func TestIsNumber(t *testing.T) { + tests := []struct { + s string + expected bool + }{ + {"123", true}, + {"0", true}, + {"3.14", true}, + {"-5", true}, + {"-3.14", true}, + {"1e10", true}, + {"1.5e-3", true}, + {"abc", false}, + {"12a", false}, + {"", false}, + {".", false}, + {"--5", false}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + result := isNumber(tt.s) + assert.Equal(t, tt.expected, result, "数字字符串判断应该正确") + }) + } +} + +// TestIsIdentifier 测试标识符判断 +func TestIsIdentifier(t *testing.T) { + tests := []struct { + s string + expected bool + }{ + {"abc", true}, + {"_var", true}, + {"var123", true}, + {"CamelCase", true}, + {"snake_case", true}, + {"123abc", false}, + {"", false}, + {"var-name", false}, + {"var.name", false}, + {"var name", false}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + result := isIdentifier(tt.s) + assert.Equal(t, tt.expected, result, "标识符判断应该正确") + }) + } +} + +// TestIsOperator 测试运算符判断 +func TestIsOperator(t *testing.T) { + tests := []struct { + s string + expected bool + }{ + {"+", true}, + {"-", true}, + {"*", true}, + {"/", true}, + {"%", true}, + {"^", true}, + {">", true}, + {"<", true}, + {">=", true}, + {"<=", true}, + {"==", true}, + {"=", true}, + {"!=", true}, + {"<>", true}, + {"AND", true}, + {"OR", true}, + {"NOT", true}, + {"LIKE", true}, + {"IS", true}, + {"abc", false}, + {"123", false}, + {"(", false}, + {")", false}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + result := isOperator(tt.s) + assert.Equal(t, tt.expected, result, "运算符判断应该正确") + }) + } +} + +// TestIsComparisonOperator 测试比较运算符判断 +func TestIsComparisonOperator(t *testing.T) { + tests := []struct { + s string + expected bool + }{ + {">", true}, + {"<", true}, + {">=", true}, + {"<=", true}, + {"==", true}, + {"=", true}, + {"!=", true}, + {"<>", true}, + {"+", false}, + {"-", false}, + {"*", false}, + {"/", false}, + {"AND", false}, + {"OR", false}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + result := isComparisonOperator(tt.s) + assert.Equal(t, tt.expected, result, "比较运算符判断应该正确") + }) + } +} + +// TestIsStringLiteral 测试字符串字面量判断 +func TestIsStringLiteral(t *testing.T) { + tests := []struct { + s string + expected bool + }{ + {"'hello'", true}, + {"\"world\"", true}, + {"''", true}, + {"\"\"", true}, + {"'hello", false}, + {"hello'", false}, + {"\"hello", false}, + {"hello\"", false}, + {"hello", false}, + {"", false}, + {"'", false}, + {"\"", false}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + result := isStringLiteral(tt.s) + assert.Equal(t, tt.expected, result, "字符串字面量判断应该正确") + }) + } +} + +// TestTokenizeComplexExpressions 测试复杂表达式分词 +func TestTokenizeComplexExpressions(t *testing.T) { + tests := []struct { + name string + expr string + expected []string + }{ + { + "温度转换表达式", + "temperature * 1.8 + 32", + []string{"temperature", "*", "1.8", "+", "32"}, + }, + { + "复杂CASE表达式", + "CASE WHEN temperature > 30 AND humidity < 60 THEN 'HOT' ELSE 'NORMAL' END", + []string{"CASE", "WHEN", "temperature", ">", "30", "AND", "humidity", "<", "60", "THEN", "'HOT'", "ELSE", "'NORMAL'", "END"}, + }, + { + "嵌套函数调用", + "sqrt(pow(a, 2) + pow(b, 2))", + []string{"sqrt", "(", "pow", "(", "a", ",", "2", ")", "+", "pow", "(", "b", ",", "2", ")", ")"}, + }, + { + "负数在比较运算符后", + "a > -5 AND b <= -3.14", + []string{"a", ">", "-5", "AND", "b", "<=", "-3.14"}, + }, + { + "幂运算后的负数", + "a ^ -2", + []string{"a", "^", "-2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tokenize(tt.expr) + require.NoError(t, err, "复杂表达式分词不应该失败") + assert.Equal(t, tt.expected, result, "复杂表达式分词结果应该匹配") + }) + } +} + +// TestTokenizeEdgeCases 测试边界情况 +func TestTokenizeEdgeCases(t *testing.T) { + tests := []struct { + name string + expr string + expected []string + wantErr bool + }{ + {"只有数字", "123", []string{"123"}, false}, + {"只有小数点开头的数字", ".5", []string{".5"}, false}, + {"连续运算符(应该在解析阶段检测)", "a + + b", []string{"a", "+", "+", "b"}, false}, + {"多个小数点", "3.14.15", []string{"3.14", ".", "15"}, false}, // 分词器不检查语法错误 + {"空字符串转义", "''", []string{"''"}, false}, + {"包含空格的反引号标识符", "`user name`", []string{"`user name`"}, false}, + {"特殊字符在字符串中", "'hello@world#test'", []string{"'hello@world#test'"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tokenize(tt.expr) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + require.NoError(t, err, "分词不应该失败") + assert.Equal(t, tt.expected, result, "分词结果应该匹配") + } + }) + } +} diff --git a/expr/utils.go b/expr/utils.go new file mode 100644 index 0000000..57f0275 --- /dev/null +++ b/expr/utils.go @@ -0,0 +1,261 @@ +package expr + +import ( + "fmt" + "github.com/rulego/streamsql/utils/cast" + "strings" +) + +// unquoteString removes quotes from both ends of string +func unquoteString(s string) string { + if len(s) >= 2 { + if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') { + return s[1 : len(s)-1] + } + } + return s +} + +// unquoteBacktick removes backticks from both ends of string +func unquoteBacktick(s string) string { + if len(s) >= 2 && s[0] == '`' && s[len(s)-1] == '`' { + return s[1 : len(s)-1] + } + return s +} + +// getNodeType gets node type +func getNodeType(node *ExprNode) string { + if node == nil { + return "nil" + } + return node.Type +} + +// getNodeValue gets node value +func getNodeValue(node *ExprNode) string { + if node == nil { + return "" + } + return node.Value +} + +// setNodeValue sets node value +func setNodeValue(node *ExprNode, value string) { + if node != nil { + node.Value = value + } +} + +// isArithmeticOperator checks if it's an arithmetic operator +func isArithmeticOperator(op string) bool { + switch op { + case "+", "-", "*", "/", "%", "^": + return true + default: + return false + } +} + +// isLogicalOperator checks if it's a logical operator +func isLogicalOperator(op string) bool { + switch strings.ToUpper(op) { + case "AND", "OR", "NOT": + return true + default: + return false + } +} + +// isUnaryOperator checks if it's a unary operator +func isUnaryOperator(op string) bool { + switch strings.ToUpper(op) { + case "NOT": + return true + default: + return false + } +} + +// isKeyword checks if it's a keyword +func isKeyword(word string) bool { + switch strings.ToUpper(word) { + case "CASE", "WHEN", "THEN", "ELSE", "END", "AND", "OR", "NOT", "LIKE", "IS", "NULL", "TRUE", "FALSE": + return true + default: + return false + } +} + +// normalizeIdentifier normalizes identifier (convert to lowercase) +func normalizeIdentifier(identifier string) string { + return strings.ToLower(identifier) +} + +// convertToFloat converts any type to float64 +func convertToFloat(value interface{}) (float64, error) { + return cast.ToFloat64E(value) +} + +// convertToFloatSafe safely converts any type to float64, returns conversion result and success status +func convertToFloatSafe(value interface{}) (float64, bool) { + result, err := convertToFloat(value) + return result, err == nil +} + +// convertToBool converts any type to boolean +func convertToBool(value interface{}) bool { + return cast.ToBool(value) +} + +// getOperatorPrecedence gets operator precedence +func getOperatorPrecedence(op string) int { + switch op { + case "OR": + return 1 + case "AND": + return 2 + case "NOT": + return 3 + case "=", "==", "!=", "<>", ">", "<", ">=", "<=", "LIKE", "NOT LIKE", "IS", "IS NOT": + return 4 + case "+", "-": + return 5 + case "*", "/", "%": + return 6 + case "^": + return 7 + default: + return 0 + } +} + +// isRightAssociative checks if operator is right associative +func isRightAssociative(op string) bool { + // Only power operator is right associative + return op == "^" +} + +// parseFunction parses function call (test helper function) +func parseFunction(tokens []string, pos int) (*ExprNode, int, error) { + if pos >= len(tokens) { + return nil, pos, fmt.Errorf("unexpected end of tokens") + } + + // Call existing parseFunctionCall function + node, remaining, err := parseFunctionCall(tokens[pos:]) + if err != nil { + return nil, pos, err + } + + // Calculate new position + newPos := len(tokens) - len(remaining) + return node, newPos, nil +} + +// formatError formats error message +func formatError(message string, args ...interface{}) error { + if len(args) == 0 { + return fmt.Errorf("%s", message) + } + return fmt.Errorf(message, args...) +} + +// copyNode deep copies expression node +func copyNode(node *ExprNode) *ExprNode { + if node == nil { + return nil + } + + newNode := &ExprNode{ + Type: node.Type, + Value: node.Value, + Left: copyNode(node.Left), + Right: copyNode(node.Right), + } + + // Copy function arguments + if len(node.Args) > 0 { + newNode.Args = make([]*ExprNode, len(node.Args)) + for i, arg := range node.Args { + newNode.Args[i] = copyNode(arg) + } + } + + // Copy CASE expression + if node.CaseExpr != nil { + newNode.CaseExpr = &CaseExpression{ + Value: copyNode(node.CaseExpr.Value), + ElseResult: copyNode(node.CaseExpr.ElseResult), + } + + // Copy WHEN clauses + if len(node.CaseExpr.WhenClauses) > 0 { + newNode.CaseExpr.WhenClauses = make([]WhenClause, len(node.CaseExpr.WhenClauses)) + for i, whenClause := range node.CaseExpr.WhenClauses { + newNode.CaseExpr.WhenClauses[i] = WhenClause{ + Condition: copyNode(whenClause.Condition), + Result: copyNode(whenClause.Result), + } + } + } + } + + return newNode +} + +// nodeToString converts expression node to string representation +func nodeToString(node *ExprNode) string { + if node == nil { + return "" + } + + switch node.Type { + case TypeNumber, TypeString, TypeField: + return node.Value + case TypeOperator: + left := nodeToString(node.Left) + right := nodeToString(node.Right) + return fmt.Sprintf("(%s %s %s)", left, node.Value, right) + case TypeFunction: + args := make([]string, len(node.Args)) + for i, arg := range node.Args { + args[i] = nodeToString(arg) + } + return fmt.Sprintf("%s(%s)", node.Value, strings.Join(args, ", ")) + case TypeCase: + return caseExprToString(node.CaseExpr) + default: + return fmt.Sprintf("<%s:%s>", node.Type, node.Value) + } +} + +// caseExprToString converts CASE expression to string representation +func caseExprToString(caseExpr *CaseExpression) string { + if caseExpr == nil { + return "" + } + + var result strings.Builder + result.WriteString("CASE") + + if caseExpr.Value != nil { + result.WriteString(" ") + result.WriteString(nodeToString(caseExpr.Value)) + } + + for _, whenClause := range caseExpr.WhenClauses { + result.WriteString(" WHEN ") + result.WriteString(nodeToString(whenClause.Condition)) + result.WriteString(" THEN ") + result.WriteString(nodeToString(whenClause.Result)) + } + + if caseExpr.ElseResult != nil { + result.WriteString(" ELSE ") + result.WriteString(nodeToString(caseExpr.ElseResult)) + } + + result.WriteString(" END") + return result.String() +} diff --git a/expr/utils_test.go b/expr/utils_test.go new file mode 100644 index 0000000..1397ede --- /dev/null +++ b/expr/utils_test.go @@ -0,0 +1,596 @@ +package expr + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestUnquoteString 测试字符串去引号 +func TestUnquoteString(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"单引号字符串", "'hello'", "hello"}, + {"双引号字符串", "\"world\"", "world"}, + {"空单引号字符串", "''", ""}, + {"空双引号字符串", "\"\"", ""}, + {"包含空格的字符串", "'hello world'", "hello world"}, + {"包含特殊字符的字符串", "'hello@#$%'", "hello@#$%"}, + {"无引号字符串", "hello", "hello"}, + {"只有左引号", "'hello", "'hello"}, + {"只有右引号", "hello'", "hello'"}, + {"引号不匹配", "'hello\"", "'hello\""}, + {"嵌套引号", "'he\"llo'", "he\"llo"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := unquoteString(tt.input) + assert.Equal(t, tt.expected, result, "去引号结果应该正确") + }) + } +} + +// TestUnquoteBacktick 测试反引号去除 +func TestUnquoteBacktick(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"反引号字段", "`field`", "field"}, + {"包含空格的反引号字段", "`field name`", "field name"}, + {"包含特殊字符的反引号字段", "`user.name`", "user.name"}, + {"空反引号", "``", ""}, + {"无反引号", "field", "field"}, + {"只有左反引号", "`field", "`field"}, + {"只有右反引号", "field`", "field`"}, + {"嵌套反引号", "`fie`ld`", "fie`ld"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := unquoteBacktick(tt.input) + assert.Equal(t, tt.expected, result, "去反引号结果应该正确") + }) + } +} + +// TestFormatError 测试错误格式化 +func TestFormatError(t *testing.T) { + tests := []struct { + name string + message string + args []interface{} + expected string + }{ + {"简单错误", "invalid value", nil, "invalid value"}, + {"带参数的错误", "invalid value: %v", []interface{}{"test"}, "invalid value: test"}, + {"多参数错误", "error at position %d: %s", []interface{}{5, "syntax error"}, "error at position 5: syntax error"}, + {"数字参数", "value %d is out of range [%d, %d]", []interface{}{10, 1, 5}, "value 10 is out of range [1, 5]"}, + {"浮点数参数", "result: %.2f", []interface{}{3.14159}, "result: 3.14"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := formatError(tt.message, tt.args...) + assert.Equal(t, tt.expected, err.Error(), "错误格式化结果应该正确") + }) + } +} + +// TestCopyNode 测试节点复制 +func TestCopyNode(t *testing.T) { + tests := []struct { + name string + node *ExprNode + }{ + { + "数字节点", + &ExprNode{Type: TypeNumber, Value: "123"}, + }, + { + "字段节点", + &ExprNode{Type: TypeField, Value: "field1"}, + }, + { + "运算符节点", + &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "2"}, + }, + }, + { + "函数节点", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "1"}}, + }, + }, + { + "CASE节点", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeField, Value: "a"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + copied := copyNode(tt.node) + + // 检查复制的节点不是同一个对象 + assert.NotSame(t, tt.node, copied, "复制的节点应该是不同的对象") + + // 检查值是否相等 + assertNodeEqual(t, tt.node, copied) + + // 修改原节点,确保复制的节点不受影响 + originalValue := tt.node.Value + tt.node.Value = "modified" + assert.NotEqual(t, tt.node.Value, copied.Value, "修改原节点不应该影响复制的节点") + + // 恢复原值 + tt.node.Value = originalValue + }) + } +} + +// TestCopyNode_Nil 测试空节点复制 +func TestCopyNode_Nil(t *testing.T) { + result := copyNode(nil) + assert.Nil(t, result, "复制nil节点应该返回nil") +} + +// TestGetNodeType 测试获取节点类型 +func TestGetNodeType(t *testing.T) { + tests := []struct { + name string + node *ExprNode + expected string + }{ + {"数字节点", &ExprNode{Type: TypeNumber}, "number"}, + {"字段节点", &ExprNode{Type: TypeField}, "field"}, + {"字符串节点", &ExprNode{Type: TypeString}, "string"}, + {"运算符节点", &ExprNode{Type: TypeOperator}, "operator"}, + {"函数节点", &ExprNode{Type: TypeFunction}, "function"}, + {"括号节点", &ExprNode{Type: TypeParenthesis}, "parenthesis"}, + {"CASE节点", &ExprNode{Type: TypeCase}, "case"}, + {"未知类型", &ExprNode{Type: "unknown"}, "unknown"}, + {"空节点", nil, "nil"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getNodeType(tt.node) + assert.Equal(t, tt.expected, result, "节点类型应该正确") + }) + } +} + +// TestGetNodeValue 测试获取节点值 +func TestGetNodeValue(t *testing.T) { + tests := []struct { + name string + node *ExprNode + expected string + }{ + {"数字节点", &ExprNode{Value: "123"}, "123"}, + {"字段节点", &ExprNode{Value: "field1"}, "field1"}, + {"空值节点", &ExprNode{Value: ""}, ""}, + {"空节点", nil, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getNodeValue(tt.node) + assert.Equal(t, tt.expected, result, "节点值应该正确") + }) + } +} + +// TestSetNodeValue 测试设置节点值 +func TestSetNodeValue(t *testing.T) { + tests := []struct { + name string + node *ExprNode + newValue string + }{ + {"设置数字节点值", &ExprNode{Type: TypeNumber, Value: "123"}, "456"}, + {"设置字段节点值", &ExprNode{Type: TypeField, Value: "field1"}, "field2"}, + {"设置空值", &ExprNode{Type: TypeString, Value: "hello"}, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setNodeValue(tt.node, tt.newValue) + assert.Equal(t, tt.newValue, tt.node.Value, "节点值应该被正确设置") + }) + } +} + +// TestSetNodeValue_Nil 测试设置空节点值 +func TestSetNodeValue_Nil(t *testing.T) { + // 这应该不会panic + assert.NotPanics(t, func() { + setNodeValue(nil, "test") + }, "设置nil节点值不应该panic") +} + +// TestIsArithmeticOperator 测试算术运算符判断 +func TestIsArithmeticOperator(t *testing.T) { + tests := []struct { + operator string + expected bool + }{ + {"+", true}, + {"-", true}, + {"*", true}, + {"/", true}, + {"%", true}, + {"^", true}, + {">", false}, + {"<", false}, + {"==", false}, + {"AND", false}, + {"OR", false}, + {"LIKE", false}, + {"unknown", false}, + } + + for _, tt := range tests { + t.Run(tt.operator, func(t *testing.T) { + result := isArithmeticOperator(tt.operator) + assert.Equal(t, tt.expected, result, "算术运算符判断应该正确") + }) + } +} + +// TestIsLogicalOperator 测试逻辑运算符判断 +func TestIsLogicalOperator(t *testing.T) { + tests := []struct { + operator string + expected bool + }{ + {"AND", true}, + {"OR", true}, + {"NOT", true}, + {"+", false}, + {"-", false}, + {">", false}, + {"<", false}, + {"==", false}, + {"LIKE", false}, + {"unknown", false}, + } + + for _, tt := range tests { + t.Run(tt.operator, func(t *testing.T) { + result := isLogicalOperator(tt.operator) + assert.Equal(t, tt.expected, result, "逻辑运算符判断应该正确") + }) + } +} + +// TestIsUnaryOperator 测试一元运算符判断 +func TestIsUnaryOperator(t *testing.T) { + tests := []struct { + operator string + expected bool + }{ + {"NOT", true}, + {"+", false}, + {"-", false}, + {"*", false}, + {"AND", false}, + {"OR", false}, + {"unknown", false}, + } + + for _, tt := range tests { + t.Run(tt.operator, func(t *testing.T) { + result := isUnaryOperator(tt.operator) + assert.Equal(t, tt.expected, result, "一元运算符判断应该正确") + }) + } +} + +// TestIsKeyword 测试关键字判断 +func TestIsKeyword(t *testing.T) { + tests := []struct { + word string + expected bool + }{ + {"CASE", true}, + {"WHEN", true}, + {"THEN", true}, + {"ELSE", true}, + {"END", true}, + {"AND", true}, + {"OR", true}, + {"NOT", true}, + {"LIKE", true}, + {"IS", true}, + {"NULL", true}, + {"TRUE", true}, + {"FALSE", true}, + // 大小写测试 + {"case", true}, + {"Case", true}, + {"when", true}, + {"and", true}, + // 非关键字 + {"field", false}, + {"value", false}, + {"123", false}, + {"unknown", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.word, func(t *testing.T) { + result := isKeyword(tt.word) + assert.Equal(t, tt.expected, result, "关键字判断应该正确") + }) + } +} + +// TestNormalizeIdentifier 测试标识符规范化 +func TestNormalizeIdentifier(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"小写", "field", "field"}, + {"大写", "FIELD", "field"}, + {"混合大小写", "FieldName", "fieldname"}, + {"下划线", "field_name", "field_name"}, + {"数字", "field123", "field123"}, + {"空字符串", "", ""}, + {"单字符", "A", "a"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeIdentifier(tt.input) + assert.Equal(t, tt.expected, result, "标识符规范化应该正确") + }) + } +} + +// assertNodeEqual 断言两个表达式节点相等(测试辅助函数) +func assertNodeEqual(t *testing.T, expected, actual *ExprNode) { + if expected == nil && actual == nil { + return + } + + if expected == nil { + assert.Fail(t, "Expected node is nil but actual is not") + return + } + + if actual == nil { + assert.Fail(t, "Actual node is nil but expected is not") + return + } + + // 比较节点类型和值 + assert.Equal(t, expected.Type, actual.Type, "节点类型应该相等") + assert.Equal(t, expected.Value, actual.Value, "节点值应该相等") + + // 递归比较左右子节点 + assertNodeEqual(t, expected.Left, actual.Left) + assertNodeEqual(t, expected.Right, actual.Right) + + // 比较函数参数 + assert.Equal(t, len(expected.Args), len(actual.Args), "函数参数数量应该相等") + for i := range expected.Args { + assertNodeEqual(t, expected.Args[i], actual.Args[i]) + } + + // 比较CASE表达式 + if expected.CaseExpr == nil && actual.CaseExpr == nil { + return + } + + if expected.CaseExpr == nil || actual.CaseExpr == nil { + assert.Fail(t, "CASE表达式不匹配") + return + } + + assertNodeEqual(t, expected.CaseExpr.Value, actual.CaseExpr.Value) + assertNodeEqual(t, expected.CaseExpr.ElseResult, actual.CaseExpr.ElseResult) + + assert.Equal(t, len(expected.CaseExpr.WhenClauses), len(actual.CaseExpr.WhenClauses), "WHEN子句数量应该相等") + for i := range expected.CaseExpr.WhenClauses { + assertNodeEqual(t, expected.CaseExpr.WhenClauses[i].Condition, actual.CaseExpr.WhenClauses[i].Condition) + assertNodeEqual(t, expected.CaseExpr.WhenClauses[i].Result, actual.CaseExpr.WhenClauses[i].Result) + } +} + +func TestNodeToString(t *testing.T) { + tests := []struct { + name string + node *ExprNode + expected string + }{ + { + name: "nil node", + node: nil, + expected: "", + }, + { + name: "number node", + node: &ExprNode{ + Type: TypeNumber, + Value: "123", + }, + expected: "123", + }, + { + name: "string node", + node: &ExprNode{ + Type: TypeString, + Value: "'hello'", + }, + expected: "'hello'", + }, + { + name: "field node", + node: &ExprNode{ + Type: TypeField, + Value: "field1", + }, + expected: "field1", + }, + { + name: "operator node", + node: &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{ + Type: TypeField, + Value: "a", + }, + Right: &ExprNode{ + Type: TypeField, + Value: "b", + }, + }, + expected: "(a + b)", + }, + { + name: "function node", + node: &ExprNode{ + Type: TypeFunction, + Value: "sum", + Args: []*ExprNode{ + {Type: TypeField, Value: "field1"}, + {Type: TypeField, Value: "field2"}, + }, + }, + expected: "sum(field1, field2)", + }, + { + name: "case node", + node: &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "status"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'active'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + expected: "CASE status WHEN 'active' THEN 1 ELSE 0 END", + }, + { + name: "unknown type", + node: &ExprNode{ + Type: "unknown", + Value: "value", + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := nodeToString(tt.node) + if result != tt.expected { + t.Errorf("expected '%s', got '%s'", tt.expected, result) + } + }) + } +} + +func TestCaseExprToString(t *testing.T) { + tests := []struct { + name string + caseExpr *CaseExpression + expected string + }{ + { + name: "nil case expression", + caseExpr: nil, + expected: "", + }, + { + name: "simple case expression", + caseExpr: &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "status"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'active'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + { + Condition: &ExprNode{Type: TypeString, Value: "'inactive'"}, + Result: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "-1"}, + }, + expected: "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END", + }, + { + name: "search case expression", + caseExpr: &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeOperator, Value: ">", Left: &ExprNode{Type: TypeField, Value: "age"}, Right: &ExprNode{Type: TypeNumber, Value: "18"}}, + Result: &ExprNode{Type: TypeString, Value: "'adult'"}, + }, + { + Condition: &ExprNode{Type: TypeOperator, Value: ">", Left: &ExprNode{Type: TypeField, Value: "age"}, Right: &ExprNode{Type: TypeNumber, Value: "12"}}, + Result: &ExprNode{Type: TypeString, Value: "'teen'"}, + }, + }, + ElseResult: &ExprNode{Type: TypeString, Value: "'child'"}, + }, + expected: "CASE WHEN (age > 18) THEN 'adult' WHEN (age > 12) THEN 'teen' ELSE 'child' END", + }, + { + name: "case expression without else", + caseExpr: &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "type"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'A'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + }, + expected: "CASE type WHEN 'A' THEN 1 END", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := caseExprToString(tt.caseExpr) + if result != tt.expected { + t.Errorf("expected '%s', got '%s'", tt.expected, result) + } + }) + } +} diff --git a/expr/validator.go b/expr/validator.go new file mode 100644 index 0000000..a79fb50 --- /dev/null +++ b/expr/validator.go @@ -0,0 +1,449 @@ +package expr + +import ( + "fmt" + "strings" + + "github.com/rulego/streamsql/functions" +) + +// validateExpression validates the validity of an expression +func validateExpression(node *ExprNode) error { + if node == nil { + return fmt.Errorf("expression node is nil") + } + + switch node.Type { + case TypeNumber: + return validateNumberNode(node) + case TypeString: + return validateStringNode(node) + case TypeField: + return validateFieldNode(node) + case TypeOperator: + return validateOperatorNode(node) + case TypeFunction: + return validateFunctionNode(node) + case TypeCase: + return validateCaseNode(node) + case TypeParenthesis: + return validateParenthesisNode(node) + default: + return fmt.Errorf("unknown node type: %s", node.Type) + } +} + +// validateNumberNode validates a number node +func validateNumberNode(node *ExprNode) error { + if node.Value == "" { + return fmt.Errorf("number node has empty value") + } + + // Check if it's a valid number format + if !isNumber(node.Value) { + return fmt.Errorf("invalid number format: %s", node.Value) + } + + return nil +} + +// validateStringNode validates a string node +func validateStringNode(node *ExprNode) error { + if node.Value == "" { + return fmt.Errorf("string node has empty value") + } + + // Check if string is properly quoted + if !isStringLiteral(node.Value) { + return fmt.Errorf("invalid string literal format: %s", node.Value) + } + + return nil +} + +// validateFieldNode validates a field node +func validateFieldNode(node *ExprNode) error { + if node.Value == "" { + return fmt.Errorf("field node has empty value") + } + + // Check if field name is a valid identifier + fieldName := node.Value + + if !isValidFieldName(fieldName) { + return fmt.Errorf("invalid field name: %s", node.Value) + } + + return nil +} + +// validateOperatorNode validates an operator node +func validateOperatorNode(node *ExprNode) error { + if node.Value == "" { + return fmt.Errorf("operator node has empty value") + } + + // Check if it's a valid operator + if !isOperator(node.Value) && !isComparisonOperator(node.Value) { + return fmt.Errorf("invalid operator: %s", node.Value) + } + + // Check if it's a unary operator + if isUnaryOperator(node.Value) { + // Unary operators only need left operand + if node.Left == nil { + return fmt.Errorf("unary operator '%s' missing operand", node.Value) + } + // Unary operators should not have right operand + if node.Right != nil { + return fmt.Errorf("unary operator '%s' should not have right operand", node.Value) + } + // Validate left operand + return validateExpression(node.Left) + } + + // Binary operators need both left and right operands + if node.Left == nil { + return fmt.Errorf("operator %s missing left operand", node.Value) + } + + if node.Right == nil { + return fmt.Errorf("operator %s missing right operand", node.Value) + } + + // Recursively validate operands + if err := validateExpression(node.Left); err != nil { + return fmt.Errorf("invalid left operand for operator %s: %v", node.Value, err) + } + + if err := validateExpression(node.Right); err != nil { + return fmt.Errorf("invalid right operand for operator %s: %v", node.Value, err) + } + + return nil +} + +// validateFunctionNode validates a function node +// Use unified function registration system for validation +func validateFunctionNode(node *ExprNode) error { + if node.Value == "" { + return fmt.Errorf("function node has empty value") + } + + // Check if function exists in registration system (using lowercase name) + fn, exists := functions.Get(strings.ToLower(node.Value)) + if !exists { + return fmt.Errorf("unknown function: %s", node.Value) + } + + // Use function's own validation logic for basic validation + // Create temporary argument array for validating argument count + tempArgs := make([]interface{}, len(node.Args)) + if err := fn.Validate(tempArgs); err != nil { + return fmt.Errorf("function %s validation failed: %v", node.Value, err) + } + + // Validate argument expressions + return validateFunctionArgs(node) +} + +// validateFunctionArgs validates function arguments +func validateFunctionArgs(node *ExprNode) error { + for i, arg := range node.Args { + if err := validateExpression(arg); err != nil { + return fmt.Errorf("invalid argument %d for function %s: %v", i+1, node.Value, err) + } + } + return nil +} + +// validateParenthesisNode validates a parenthesis node +func validateParenthesisNode(node *ExprNode) error { + // Parenthesis node should have a Left child containing the inner expression + if node.Left == nil { + return fmt.Errorf("parenthesis node missing inner expression") + } + + // Validate the inner expression + return validateExpression(node.Left) +} + +// validateCaseNode validates a CASE expression node +func validateCaseNode(node *ExprNode) error { + if node.CaseExpr == nil { + return fmt.Errorf("CASE expression is missing") + } + + caseExpr := node.CaseExpr + + // Validate the value part of CASE expression (if it's a simple CASE) + if caseExpr.Value != nil { + if err := validateExpression(caseExpr.Value); err != nil { + return fmt.Errorf("invalid CASE value expression: %v", err) + } + } + + // Validate WHEN clauses + if len(caseExpr.WhenClauses) == 0 { + return fmt.Errorf("CASE expression must have at least one WHEN clause") + } + + for i, whenClause := range caseExpr.WhenClauses { + if err := validateExpression(whenClause.Condition); err != nil { + return fmt.Errorf("invalid WHEN condition %d: %v", i+1, err) + } + if err := validateExpression(whenClause.Result); err != nil { + return fmt.Errorf("invalid THEN result %d: %v", i+1, err) + } + } + + // Validate ELSE clause (if exists) + if caseExpr.ElseResult != nil { + if err := validateExpression(caseExpr.ElseResult); err != nil { + return fmt.Errorf("invalid ELSE expression: %v", err) + } + } + + return nil +} + +// isValidFieldName checks if it's a valid field name +// isValidFieldName validates if field name is valid +// Supports normal identifiers and backtick-enclosed field names +func isValidFieldName(name string) bool { + if name == "" { + return false + } + + // If it's a backtick-enclosed field name, check if backticks are properly closed + if len(name) >= 2 && name[0] == '`' && name[len(name)-1] == '`' { + // Content inside backticks can contain any character (except backticks themselves) + inner := name[1 : len(name)-1] + if inner == "" { + return false // Empty backticks not allowed + } + // Check if there are backticks inside + for _, r := range inner { + if r == '`' { + return false + } + } + return true + } + + // Normal field name: can only contain letters, numbers, underscores (dots not allowed) + for i, r := range name { + // For non-ASCII characters, return false directly + if r > 127 { + return false + } + ch := byte(r) + if i == 0 { + // First character must be letter or underscore + if !isLetter(ch) && r != '_' { + return false + } + } else { + // Subsequent characters can be letters, numbers, underscores + if !isLetter(ch) && !isDigit(ch) && r != '_' { + return false + } + } + } + + return true +} + +// validateTokens validates the validity of token list +func validateTokens(tokens []string) error { + if len(tokens) == 0 { + return fmt.Errorf("empty token list") + } + + // Check parentheses matching + if err := validateParentheses(tokens); err != nil { + return err + } + + // Check order of operators and operands + if err := validateTokenOrder(tokens); err != nil { + return err + } + + return nil +} + +// validateParentheses validates parentheses matching +func validateParentheses(tokens []string) error { + stack := 0 + for i, token := range tokens { + if token == "(" { + stack++ + } else if token == ")" { + stack-- + if stack < 0 { + return fmt.Errorf("unmatched closing parenthesis at position %d", i) + } + } + } + + if stack > 0 { + return fmt.Errorf("unmatched opening parenthesis") + } + + return nil +} + +// validateTokenOrder validates token order +func validateTokenOrder(tokens []string) error { + if len(tokens) == 0 { + return nil + } + + // Check cannot start with operator (except unary operators) + firstToken := tokens[0] + if isOperator(firstToken) && !isUnaryOperator(firstToken) { + return fmt.Errorf("expression cannot start with operator: %s", firstToken) + } + + // Check cannot end with operator + lastToken := tokens[len(tokens)-1] + if isOperator(lastToken) { + return fmt.Errorf("expression cannot end with operator: %s", lastToken) + } + + // Check consecutive operators and consecutive operands + for i := 0; i < len(tokens)-1; i++ { + current := tokens[i] + next := tokens[i+1] + + // Check consecutive operators + if isOperator(current) && isOperator(next) { + // Allowed combination: operator followed by unary operator + if !isUnaryOperator(next) { + return fmt.Errorf("consecutive operators not allowed: %s %s at position %d", current, next, i) + } + } + + // Check consecutive operands (two non-operator, non-parenthesis tokens adjacent) + // Special handling for CASE expression keywords + if !isOperator(current) && !isOperator(next) && + current != "(" && current != ")" && next != "(" && next != ")" && + current != "," && next != "," && + !isCaseKeyword(current) && !isCaseKeyword(next) { + return fmt.Errorf("consecutive operands not allowed: %s %s at position %d", current, next, i) + } + } + + return nil +} + +// isCaseKeyword checks if it's a CASE expression keyword +func isCaseKeyword(token string) bool { + switch strings.ToUpper(token) { + case "CASE", "WHEN", "THEN", "ELSE", "END": + return true + default: + return false + } +} + +// validateSyntax validates expression syntax +func validateSyntax(expr string) error { + trimmed := strings.TrimSpace(expr) + if trimmed == "" { + return fmt.Errorf("empty expression") + } + + // Check basic syntax errors + if strings.Contains(trimmed, "()") { + return fmt.Errorf("empty parentheses not allowed") + } + + // Check mismatched parentheses + parenthesesCount := 0 + for _, ch := range trimmed { + if ch == '(' { + parenthesesCount++ + } else if ch == ')' { + parenthesesCount-- + if parenthesesCount < 0 { + return fmt.Errorf("mismatched parentheses") + } + } + } + if parenthesesCount != 0 { + return fmt.Errorf("mismatched parentheses") + } + + // Check consecutive operators + operators := []string{"+", "-", "*", "/", "%", "^", "=", "!=", "<>", ">", "<", ">=", "<="} + for _, op1 := range operators { + for _, op2 := range operators { + // Check consecutive operators (separated by spaces) + if strings.Contains(trimmed, " "+op1+" "+op2+" ") { + return fmt.Errorf("consecutive operators") + } + // Check directly adjacent operators (except allowed combinations) + if op1 != op2 && strings.Contains(trimmed, op1+op2) { + // Allow certain combinations like ">=, <=, !=, <>" + allowed := []string{">=", "<=", "!=", "<>"} + combination := op1 + op2 + isAllowed := false + for _, allowedOp := range allowed { + if combination == allowedOp { + isAllowed = true + break + } + } + if !isAllowed { + return fmt.Errorf("consecutive operators") + } + } + } + } + + // Check if expression starts or ends with operator + for _, op := range operators { + if strings.HasPrefix(trimmed, op+" ") { + return fmt.Errorf("expression cannot start with operator") + } + if strings.HasSuffix(trimmed, " "+op) { + return fmt.Errorf("expression cannot end with operator") + } + } + + return nil +} + +// ValidateExpression public interface: validates expression string +func ValidateExpression(expr string) error { + // First validate syntax + if err := validateSyntax(expr); err != nil { + return err + } + + // Tokenize + tokens, err := tokenize(expr) + if err != nil { + return fmt.Errorf("tokenization error: %v", err) + } + + // Validate tokens + if err := validateTokens(tokens); err != nil { + return fmt.Errorf("token validation error: %v", err) + } + + // Parse to AST + node, err := parseExpression(tokens) + if err != nil { + return fmt.Errorf("parsing error: %v", err) + } + + // Validate AST + if err := validateExpression(node); err != nil { + return fmt.Errorf("expression validation error: %v", err) + } + + return nil +} diff --git a/expr/validator_test.go b/expr/validator_test.go new file mode 100644 index 0000000..133fa16 --- /dev/null +++ b/expr/validator_test.go @@ -0,0 +1,885 @@ +package expr + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestValidateExpressionNode 测试表达式节点验证功能 +func TestValidateExpressionNode(t *testing.T) { + tests := []struct { + name string + node *ExprNode + wantErr bool + }{ + { + "有效数字节点", + &ExprNode{Type: TypeNumber, Value: "123"}, + false, + }, + { + "有效字段节点", + &ExprNode{Type: TypeField, Value: "field1"}, + false, + }, + { + "有效字符串节点", + &ExprNode{Type: TypeString, Value: "'hello'"}, + false, + }, + { + "有效运算符节点", + &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "2"}, + }, + false, + }, + { + "有效函数节点", + &ExprNode{ + Type: TypeFunction, + Value: "abs", + Args: []*ExprNode{{Type: TypeNumber, Value: "1"}}, + }, + false, + }, + { + "有效CASE节点", + &ExprNode{ + Type: TypeCase, + CaseExpr: &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + false, + }, + // 错误情况 + {"空节点", nil, true}, + {"无效数字", &ExprNode{Type: TypeNumber, Value: "abc"}, true}, + {"无效字段名", &ExprNode{Type: TypeField, Value: "123field"}, true}, + {"无效字符串", &ExprNode{Type: TypeString, Value: "hello"}, true}, + {"运算符缺少左操作数", &ExprNode{ + Type: TypeOperator, + Value: "+", + Right: &ExprNode{Type: TypeNumber, Value: "1"}, + }, true}, + {"运算符缺少右操作数", &ExprNode{ + Type: TypeOperator, + Value: "+", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + }, true}, + {"无效运算符", &ExprNode{ + Type: TypeOperator, + Value: "@", + Left: &ExprNode{Type: TypeNumber, Value: "1"}, + Right: &ExprNode{Type: TypeNumber, Value: "2"}, + }, true}, + {"无效函数", &ExprNode{ + Type: TypeFunction, + Value: "unknown", + Args: []*ExprNode{{Type: TypeNumber, Value: "1"}}, + }, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateExpression(tt.node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "验证不应该失败") + } + }) + } +} + +// TestValidateExpression 测试公共表达式验证接口 +func TestValidateExpression(t *testing.T) { + tests := []struct { + name string + expr string + wantErr bool + }{ + // 有效表达式 + {"简单数字", "123", false}, + {"简单字段", "field1", false}, + {"算术表达式", "1 + 2", false}, + {"比较表达式", "field1 > 10", false}, + {"函数调用", "abs(-5)", false}, + {"复杂表达式", "(field1 + field2) * 2", false}, + {"字符串比较", "name = 'test'", false}, + {"CASE表达式", "CASE WHEN field1 > 0 THEN 1 ELSE 0 END", false}, + {"逻辑表达式", "field1 > 0 AND field2 < 100", false}, + {"嵌套函数", "max(abs(field1), abs(field2))", false}, + + // 无效表达式 + {"空表达式", "", true}, + {"只有空格", " ", true}, + {"括号不匹配1", "(1 + 2", true}, + {"括号不匹配2", "1 + 2)", true}, + {"连续运算符", "1 + + 2", true}, + {"运算符开头", "+ 1 + 2", true}, + {"运算符结尾", "1 + 2 +", true}, + {"空括号", "()", true}, + {"无效数字", "12.34.56", true}, + {"无效字符串", "'unclosed string", true}, + {"无效函数", "unknown_func(1)", true}, + {"无效字段名", "123field", true}, + {"tokenize错误", "field with invalid 'quote", true}, + {"解析错误", "(((", true}, + {"AST验证错误", "123abc", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateExpression(tt.expr) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "验证不应该失败") + } + }) + } +} + +// TestValidateNumberNode 测试数字节点验证 +func TestValidateNumberNode(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + {"正整数", "123", false}, + {"负整数", "-123", false}, + {"零", "0", false}, + {"正小数", "3.14", false}, + {"负小数", "-3.14", false}, + {"科学计数法", "1.5e10", false}, + {"负科学计数法", "-1.5e-3", false}, + {"小数点开头", ".5", false}, + {"小数点结尾", "5.", false}, + // 错误情况 + {"空字符串", "", true}, + {"字母", "abc", true}, + {"多个小数点", "3.14.15", true}, + {"多个负号", "--5", true}, + {"负号在中间", "3-5", true}, + {"只有小数点", ".", true}, + {"只有负号", "-", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{Type: TypeNumber, Value: tt.value} + err := validateNumberNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "数字验证不应该失败") + } + }) + } +} + +// TestValidateStringNode 测试字符串节点验证 +func TestValidateStringNode(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + {"单引号字符串", "'hello'", false}, + {"双引号字符串", "\"world\"", false}, + {"空字符串", "''", false}, + {"空双引号字符串", "\"\"", false}, + {"包含转义的字符串", "'hello\\world'", false}, + {"包含单引号的双引号字符串", "\"hello'world\"", false}, + {"包含双引号的单引号字符串", "'hello\"world'", false}, + // 错误情况 + {"未闭合单引号", "'hello", true}, + {"未闭合双引号", "\"hello", true}, + {"没有引号", "hello", true}, + {"空值", "", true}, + {"只有单引号", "'", true}, + {"只有双引号", "\"", true}, + {"引号不匹配", "'hello\"", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{Type: TypeString, Value: tt.value} + err := validateStringNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "字符串验证不应该失败") + } + }) + } +} + +// TestValidateFieldNode 测试字段节点验证 +func TestValidateFieldNode(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + {"简单字段名", "field1", false}, + {"下划线开头", "_field", false}, + {"包含数字", "field123", false}, + {"驼峰命名", "fieldName", false}, + {"蛇形命名", "field_name", false}, + {"大写字段", "FIELD", false}, + {"反引号字段", "`field name`", false}, + {"反引号包含特殊字符", "`user.name`", false}, + {"反引号包含空格", "`user name`", false}, + // 错误情况 + {"空字段名", "", true}, + {"数字开头", "123field", true}, + {"包含特殊字符", "field-name", true}, + {"包含空格", "field name", true}, + {"包含点号", "field.name", true}, + {"未闭合反引号", "`field", true}, + {"只有反引号", "`", true}, + {"空反引号", "``", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{Type: TypeField, Value: tt.value} + err := validateFieldNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "字段验证不应该失败") + } + }) + } +} + +// TestValidateOperatorNode 测试运算符节点验证 +func TestValidateOperatorNode(t *testing.T) { + tests := []struct { + name string + operator string + left *ExprNode + right *ExprNode + wantErr bool + }{ + { + "有效加法", + "+", + &ExprNode{Type: TypeNumber, Value: "1"}, + &ExprNode{Type: TypeNumber, Value: "2"}, + false, + }, + { + "有效比较", + ">", + &ExprNode{Type: TypeField, Value: "a"}, + &ExprNode{Type: TypeNumber, Value: "0"}, + false, + }, + { + "有效逻辑运算", + "AND", + &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + &ExprNode{ + Type: TypeOperator, + Value: "<", + Left: &ExprNode{Type: TypeField, Value: "b"}, + Right: &ExprNode{Type: TypeNumber, Value: "10"}, + }, + false, + }, + { + "有效NOT运算(单操作数)", + "NOT", + &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + nil, + false, + }, + // 错误情况 + {"无效运算符", "@", &ExprNode{Type: TypeNumber, Value: "1"}, &ExprNode{Type: TypeNumber, Value: "2"}, true}, + {"缺少左操作数", "+", nil, &ExprNode{Type: TypeNumber, Value: "2"}, true}, + {"缺少右操作数(双操作数运算符)", "+", &ExprNode{Type: TypeNumber, Value: "1"}, nil, true}, + {"NOT运算符有右操作数", "NOT", &ExprNode{Type: TypeNumber, Value: "1"}, &ExprNode{Type: TypeNumber, Value: "2"}, true}, + {"左操作数验证失败", "+", &ExprNode{Type: TypeNumber, Value: "abc"}, &ExprNode{Type: TypeNumber, Value: "2"}, true}, + {"右操作数验证失败", "+", &ExprNode{Type: TypeNumber, Value: "1"}, &ExprNode{Type: TypeNumber, Value: "abc"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{ + Type: TypeOperator, + Value: tt.operator, + Left: tt.left, + Right: tt.right, + } + err := validateOperatorNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "运算符验证不应该失败") + } + }) + } +} + +// TestValidateFunctionNode 测试函数节点验证 +func TestValidateFunctionNode(t *testing.T) { + // 测试函数Value为空的情况 + t.Run("函数名为空", func(t *testing.T) { + node := &ExprNode{ + Type: TypeFunction, + Value: "", + Args: []*ExprNode{{Type: TypeNumber, Value: "1"}}, + } + err := validateFunctionNode(node) + assert.Error(t, err, "函数名为空时应该返回错误") + assert.Contains(t, err.Error(), "function node has empty value") + }) + + tests := []struct { + name string + funcName string + args []*ExprNode + wantErr bool + }{ + { + "ABS函数", + "abs", + []*ExprNode{{Type: TypeNumber, Value: "1"}}, + false, + }, + { + "MAX函数", + "max", + []*ExprNode{ + {Type: TypeNumber, Value: "1"}, + {Type: TypeNumber, Value: "2"}, + {Type: TypeNumber, Value: "3"}, + }, + false, + }, + { + "POW函数", + "pow", + []*ExprNode{ + {Type: TypeNumber, Value: "2"}, + {Type: TypeNumber, Value: "3"}, + }, + false, + }, + { + "COUNT函数(无参数)", + "count", + []*ExprNode{}, + false, + }, + // 错误情况 + {"未知函数", "unknown", []*ExprNode{{Type: TypeNumber, Value: "1"}}, true}, + {"ABS参数数量错误", "abs", []*ExprNode{}, true}, + {"POW参数数量错误", "pow", []*ExprNode{{Type: TypeNumber, Value: "2"}}, true}, + {"参数验证失败", "abs", []*ExprNode{{Type: TypeNumber, Value: "abc"}}, true}, + {"参数表达式验证失败", "abs", []*ExprNode{{Type: TypeField, Value: "invalid field name!"}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{ + Type: TypeFunction, + Value: tt.funcName, + Args: tt.args, + } + err := validateFunctionNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "函数验证不应该失败") + } + }) + } +} + +// TestValidateFunctionArgs 测试函数参数验证 +func TestValidateFunctionArgs(t *testing.T) { + tests := []struct { + name string + funcName string + args []*ExprNode + wantErr bool + }{ + // 单参数函数 + {"ABS正确参数", "abs", []*ExprNode{{Type: TypeNumber, Value: "1"}}, false}, + {"ABS参数过少", "abs", []*ExprNode{}, true}, + {"ABS参数过多", "abs", []*ExprNode{{Type: TypeNumber, Value: "1"}, {Type: TypeNumber, Value: "2"}}, true}, + + // 双参数函数 + {"POW正确参数", "pow", []*ExprNode{{Type: TypeNumber, Value: "2"}, {Type: TypeNumber, Value: "3"}}, false}, + {"POW参数过少", "pow", []*ExprNode{{Type: TypeNumber, Value: "2"}}, true}, + {"POW参数过多", "pow", []*ExprNode{{Type: TypeNumber, Value: "2"}, {Type: TypeNumber, Value: "3"}, {Type: TypeNumber, Value: "4"}}, true}, + + // 可变参数函数 + {"MAX单参数", "max", []*ExprNode{{Type: TypeNumber, Value: "1"}}, false}, + {"MAX多参数", "max", []*ExprNode{{Type: TypeNumber, Value: "1"}, {Type: TypeNumber, Value: "2"}, {Type: TypeNumber, Value: "3"}}, false}, + {"MAX无参数", "max", []*ExprNode{}, true}, + + // 无参数函数(如果有的话) + {"COUNT无参数", "count", []*ExprNode{}, false}, + {"COUNT有参数", "count", []*ExprNode{{Type: TypeNumber, Value: "1"}}, false}, // COUNT可以有参数 + + // 未知函数 + {"未知函数", "unknown", []*ExprNode{{Type: TypeNumber, Value: "1"}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 创建一个函数节点来测试 + node := &ExprNode{ + Type: TypeFunction, + Value: tt.funcName, + Args: tt.args, + } + // 使用validateFunctionNode来验证函数名和参数数量 + err := validateFunctionNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "函数参数验证不应该失败") + } + }) + } +} + +// TestValidateCaseNode 测试CASE节点验证 +func TestValidateCaseNode(t *testing.T) { + tests := []struct { + name string + caseExpr *CaseExpression + wantErr bool + }{ + { + "有效CASE表达式", + &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + false, + }, + { + "简单CASE表达式(带Value)", + &CaseExpression{ + Value: &ExprNode{Type: TypeField, Value: "status"}, + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{Type: TypeString, Value: "'active'"}, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + { + Condition: &ExprNode{Type: TypeString, Value: "'inactive'"}, + Result: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "-1"}, + }, + false, + }, + { + "多个WHEN子句", + &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: "<", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "-1"}, + }, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + false, + }, + { + "没有ELSE子句", + &CaseExpression{ + WhenClauses: []WhenClause{ + { + Condition: &ExprNode{ + Type: TypeOperator, + Value: ">", + Left: &ExprNode{Type: TypeField, Value: "a"}, + Right: &ExprNode{Type: TypeNumber, Value: "0"}, + }, + Result: &ExprNode{Type: TypeNumber, Value: "1"}, + }, + }, + ElseResult: nil, + }, + false, + }, + // 错误情况 + {"没有WHEN子句", &CaseExpression{WhenClauses: []WhenClause{}, ElseResult: &ExprNode{Type: TypeNumber, Value: "0"}}, true}, + {"WHEN条件为空", &CaseExpression{ + WhenClauses: []WhenClause{ + {Condition: nil, Result: &ExprNode{Type: TypeNumber, Value: "1"}}, + }, + }, true}, + {"WHEN结果为空", &CaseExpression{ + WhenClauses: []WhenClause{ + {Condition: &ExprNode{Type: TypeField, Value: "a"}, Result: nil}, + }, + }, true}, + {"WHEN条件验证失败", &CaseExpression{ + WhenClauses: []WhenClause{ + {Condition: &ExprNode{Type: TypeNumber, Value: "abc"}, Result: &ExprNode{Type: TypeNumber, Value: "1"}}, + }, + }, true}, + {"WHEN结果验证失败", &CaseExpression{ + WhenClauses: []WhenClause{ + {Condition: &ExprNode{Type: TypeField, Value: "a"}, Result: &ExprNode{Type: TypeNumber, Value: "abc"}}, + }, + }, true}, + {"ELSE结果验证失败", &CaseExpression{ + WhenClauses: []WhenClause{ + {Condition: &ExprNode{Type: TypeField, Value: "a"}, Result: &ExprNode{Type: TypeNumber, Value: "1"}}, + }, + ElseResult: &ExprNode{Type: TypeNumber, Value: "abc"}, + }, true}, + {"简单CASE的Value验证失败", &CaseExpression{ + Value: &ExprNode{Type: TypeNumber, Value: "invalid_number"}, + WhenClauses: []WhenClause{ + {Condition: &ExprNode{Type: TypeString, Value: "'test'"}, Result: &ExprNode{Type: TypeNumber, Value: "1"}}, + }, + }, true}, + } + + // 添加CaseExpr为nil的测试用例 + t.Run("CaseExpr为nil", func(t *testing.T) { + node := &ExprNode{Type: TypeCase, CaseExpr: nil} + err := validateCaseNode(node) + assert.Error(t, err, "CaseExpr为nil时应该返回错误") + assert.Contains(t, err.Error(), "CASE expression is missing") + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &ExprNode{Type: TypeCase, CaseExpr: tt.caseExpr} + err := validateCaseNode(node) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "CASE节点验证不应该失败") + } + }) + } +} + +// TestIsValidFieldName 测试字段名验证 +func TestIsValidFieldName(t *testing.T) { + tests := []struct { + name string + fieldName string + expected bool + }{ + // 有效情况 + {"简单字段名", "field", true}, + {"下划线开头", "_field", true}, + {"包含数字", "field123", true}, + {"驼峰命名", "fieldName", true}, + {"蛇形命名", "field_name", true}, + {"大写字段", "FIELD", true}, + {"单字符字段", "a", true}, + {"单下划线", "_", true}, + {"反引号字段", "`field name`", true}, + {"反引号包含特殊字符", "`user.name`", true}, + {"反引号包含数字开头", "`123field`", true}, + {"反引号包含连字符", "`field-name`", true}, + {"反引号包含各种符号", "`field@#$%^&*()`", true}, + + // 无效情况 + {"空字段名", "", false}, + {"数字开头", "123field", false}, + {"包含连字符", "field-name", false}, + {"包含空格(无反引号)", "field name", false}, + {"包含点号(无反引号)", "field.name", false}, + {"包含特殊字符@", "field@name", false}, + {"包含特殊字符#", "field#name", false}, + {"包含特殊字符$", "field$name", false}, + {"包含特殊字符%", "field%name", false}, + {"未闭合反引号", "`field", false}, + {"空反引号", "``", false}, + {"反引号内包含反引号", "`field`name`", false}, + {"只有反引号开头", "`", false}, + {"非ASCII字符", "字段名", false}, + {"包含非ASCII字符", "field字段", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidFieldName(tt.fieldName) + assert.Equal(t, tt.expected, result, "字段名验证结果应该正确") + }) + } +} + +// TestValidateTokens 测试标记列表验证 +func TestValidateTokens(t *testing.T) { + tests := []struct { + name string + tokens []string + wantErr bool + }{ + {"有效标记列表", []string{"a", "+", "b"}, false}, + {"有效函数调用", []string{"abs", "(", "x", ")"}, false}, + {"有效CASE表达式", []string{"CASE", "WHEN", "a", ">", "0", "THEN", "1", "ELSE", "0", "END"}, false}, + // 错误情况 + {"空标记列表", []string{}, true}, + {"括号不匹配", []string{"(", "a", "+", "b"}, true}, + {"连续运算符", []string{"a", "+", "+", "b"}, true}, + {"运算符开头", []string{"+", "a"}, true}, + {"运算符结尾", []string{"a", "+"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateTokens(tt.tokens) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "标记验证不应该失败") + } + }) + } +} + +// TestValidateParentheses 测试括号验证 +func TestValidateParentheses(t *testing.T) { + tests := []struct { + name string + tokens []string + wantErr bool + }{ + {"匹配的括号", []string{"(", "a", "+", "b", ")"}, false}, + {"嵌套括号", []string{"(", "(", "a", "+", "b", ")", "*", "c", ")"}, false}, + {"函数括号", []string{"abs", "(", "x", ")"}, false}, + {"无括号", []string{"a", "+", "b"}, false}, + // 错误情况 + {"缺少右括号", []string{"(", "a", "+", "b"}, true}, + {"缺少左括号", []string{"a", "+", "b", ")"}, true}, + {"括号顺序错误", []string{")", "a", "+", "b", "("}, true}, + {"嵌套不匹配", []string{"(", "(", "a", "+", "b", ")"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateParentheses(tt.tokens) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "括号验证不应该失败") + } + }) + } +} + +// TestValidateTokenOrder 测试标记顺序验证 +func TestValidateTokenOrder(t *testing.T) { + tests := []struct { + name string + tokens []string + wantErr bool + }{ + {"正确顺序", []string{"a", "+", "b"}, false}, + {"函数调用", []string{"abs", "(", "x", ")"}, false}, + {"复杂表达式", []string{"a", "+", "b", "*", "c"}, false}, + // 错误情况 + {"连续运算符", []string{"a", "+", "+", "b"}, true}, + {"运算符开头", []string{"+", "a"}, true}, + {"运算符结尾", []string{"a", "+"}, true}, + {"连续操作数", []string{"a", "b", "+", "c"}, true}, + {"CASE关键字组合", []string{"CASE", "WHEN", "field", "THEN", "value", "END"}, false}, + {"操作符后跟一元操作符", []string{"a", "AND", "NOT", "b"}, false}, + {"连续二元操作符", []string{"a", "+", "*", "b"}, true}, + {"以一元操作符开始", []string{"NOT", "a"}, false}, + {"以二元操作符开始", []string{"-", "a"}, true}, + {"逗号分隔的参数", []string{"func", "(", "a", ",", "b", ")"}, false}, + {"括号和操作数混合", []string{"(", "a", ")", "+", "b"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateTokenOrder(tt.tokens) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "标记顺序验证不应该失败") + } + }) + } +} + +// TestValidateSyntax 测试语法验证 +func TestValidateParenthesisNode(t *testing.T) { + tests := []struct { + name string + node *ExprNode + wantErr bool + errMsg string + }{ + { + name: "有效的括号表达式", + node: &ExprNode{ + Type: TypeParenthesis, + Left: &ExprNode{ + Type: TypeField, + Value: "field1", + }, + }, + wantErr: false, + }, + { + name: "括号内为空", + node: &ExprNode{ + Type: TypeParenthesis, + Left: nil, + }, + wantErr: true, + errMsg: "parenthesis node missing inner expression", + }, + { + name: "括号内表达式无效", + node: &ExprNode{ + Type: TypeParenthesis, + Left: &ExprNode{ + Type: TypeField, + Value: "", // 空字段名 + }, + }, + wantErr: true, + errMsg: "field node has empty value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateParenthesisNode(tt.node) + if tt.wantErr { + if err == nil { + t.Errorf("validateParenthesisNode() expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateParenthesisNode() error = %v, want error containing %v", err, tt.errMsg) + } + } else { + if err != nil { + t.Errorf("validateParenthesisNode() error = %v, want nil", err) + } + } + }) + } +} + +func TestValidateSyntax(t *testing.T) { + tests := []struct { + name string + expr string + wantErr bool + }{ + // 有效表达式 + {"有效算术表达式", "a + b", false}, + {"有效函数调用", "abs(x)", false}, + {"有效CASE表达式", "CASE WHEN a > 0 THEN 1 END", false}, + {"有效比较表达式", "field >= 10", false}, + {"有效逻辑表达式", "a != b", false}, + {"有效不等于表达式", "a <> b", false}, + {"有效小于等于表达式", "a <= b", false}, + {"复杂表达式", "a + b * c", false}, + + // 错误情况 + {"空表达式", "", true}, + {"只有空格", " ", true}, + {"空括号", "()", true}, + {"括号不匹配1", "(a + b", true}, + {"括号不匹配2", "a + b)", true}, + {"连续运算符(空格分隔)", "a + + b", true}, + {"连续运算符(直接相邻)", "a+-b", true}, + {"运算符开头", "+ a + b", true}, + {"运算符结尾", "a + b +", true}, + {"乘法运算符开头", "* a + b", true}, + {"除法运算符结尾", "a + b /", true}, + {"模运算符开头", "% a + b", true}, + {"幂运算符结尾", "a + b ^", true}, + {"等号运算符开头", "= a + b", true}, + {"不等号运算符结尾", "a + b !=", true}, + {"大于号运算符开头", "> a + b", true}, + {"小于号运算符结尾", "a + b <", true}, + {"大于等于运算符开头", ">= a + b", true}, + {"小于等于运算符结尾", "a + b <=", true}, + {"不等于运算符开头", "<> a + b", true}, + {"多个连续运算符组合1", "a + * b", true}, + {"多个连续运算符组合2", "a / % b", true}, + {"多个连续运算符组合3", "a ^ + b", true}, + {"多个连续运算符组合4", "a = > b", true}, + {"多个连续运算符组合5", "a < = b", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSyntax(tt.expr) + if tt.wantErr { + assert.Error(t, err, "应该返回错误") + } else { + assert.NoError(t, err, "语法验证不应该失败") + } + }) + } +}