Files
streamsql/expr/expression.go
T
2025-08-29 17:30:58 +08:00

368 lines
10 KiB
Go

package expr
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/rulego/streamsql/functions"
)
// Expression node type tags.
//
// These string constants label the kind of an ExprNode; collectFields, for
// example, compares ExprNode.Type against TypeField to find field references.
const (
TypeNumber = "number" // Number constant
TypeField = "field" // Field reference
TypeOperator = "operator" // Operator
TypeFunction = "function" // Function call
TypeParenthesis = "parenthesis" // Parenthesis
TypeCase = "case" // CASE expression
TypeString = "string" // String constant
)
// WhenClause represents a single WHEN ... THEN ... branch of a CASE
// expression. Both members are expression sub-trees; collectFields walks each
// of them when gathering field references.
type WhenClause struct {
Condition *ExprNode // WHEN condition
Result *ExprNode // THEN result
}
// CaseExpression represents the structure of a CASE expression.
//
// Value, ElseResult and every WhenClause are expression sub-trees. Value and
// ElseResult may be nil (collectFields tolerates nil nodes); presumably Value
// is nil for the searched form "CASE WHEN ..." - confirm against the parser.
type CaseExpression struct {
Value *ExprNode // Expression after CASE (simple CASE)
WhenClauses []WhenClause // List of WHEN clauses
ElseResult *ExprNode // ELSE expression
}
// ExprNode represents one node of a parsed expression tree.
//
// Type holds one of the Type* constants and decides which of the remaining
// members are meaningful. For TypeField nodes, Value is the field name, which
// may still be wrapped in backticks (collectFields strips them).
type ExprNode struct {
Type string // Node type
Value string // Node value
Left *ExprNode // Left child node
Right *ExprNode // Right child node
Args []*ExprNode // Function argument list
// Fields specific to CASE expressions
CaseExpr *CaseExpression // CASE expression structure
}
// Expression represents a computable expression.
//
// An Expression is in one of two modes, chosen by NewExpression:
//   - parsed by the package's own parser: Root is the tree and useExprLang is
//     false;
//   - delegated to expr-lang: Root is nil, useExprLang is true and
//     exprLangExpression holds the original expression text.
type Expression struct {
Root *ExprNode // Expression root node
useExprLang bool // Whether to use expr-lang/expr
exprLangExpression string // expr-lang expression string
}
// NewExpression builds an Expression from its textual form.
//
// The text must first pass validateBasicSyntax; a validation failure is
// returned to the caller. The text is then run through the package's own
// tokenizer and parser. If either of those stages fails, the failure is not
// reported: the returned Expression has a nil Root, keeps the original text
// and is flagged for evaluation through expr-lang instead.
func NewExpression(exprStr string) (*Expression, error) {
	if err := validateBasicSyntax(exprStr); err != nil {
		return nil, err
	}

	// Preferred path: the built-in tokenizer followed by the built-in parser.
	if tokens, tokenizeErr := tokenize(exprStr); tokenizeErr == nil {
		if root, parseErr := parseExpression(tokens); parseErr == nil {
			return &Expression{Root: root}, nil
		}
	}

	// The built-in pipeline could not handle the text; hand it to expr-lang.
	return &Expression{
		useExprLang:        true,
		exprLangExpression: exprStr,
	}, nil
}
// validateBasicSyntax runs cheap textual sanity checks on an expression and
// returns the first problem found, or nil when every check passes.
//
// The checks run on the whitespace-trimmed text, in this order:
//  1. the text is not empty;
//  2. parentheses are balanced and never close before they open;
//  3. no two operators appear back to back, each surrounded by single spaces;
//  4. the text neither starts with "<op> " nor ends with " <op>";
//  5. every character is accepted by isValidChar.
//
// NOTE(review): the checks are purely textual and do not distinguish quoted
// string literals from the rest of the expression.
func validateBasicSyntax(exprStr string) error {
	expr := strings.TrimSpace(exprStr)
	if len(expr) == 0 {
		return fmt.Errorf("empty expression")
	}

	// Track nesting depth; a negative depth means ')' came before its '('.
	depth := 0
	for _, r := range expr {
		switch r {
		case '(':
			depth++
		case ')':
			if depth--; depth < 0 {
				return fmt.Errorf("mismatched parentheses")
			}
		}
	}
	if depth != 0 {
		return fmt.Errorf("mismatched parentheses")
	}

	ops := [...]string{"+", "-", "*", "/", "%", "^", "=", "!=", "<>", ">", "<", ">=", "<="}

	// Look for every ordered pair " <op1> <op2> ".
	for i := range ops {
		lead := " " + ops[i] + " "
		for j := range ops {
			if strings.Contains(expr, lead+ops[j]+" ") {
				return fmt.Errorf("consecutive operators")
			}
		}
	}

	// Operators are tried in list order, so for a given operator the leading
	// check wins over the trailing one, but an earlier operator at the end is
	// reported before a later operator at the start.
	for _, op := range ops {
		switch {
		case strings.HasPrefix(expr, op+" "):
			return fmt.Errorf("expression cannot start with operator")
		case strings.HasSuffix(expr, " "+op):
			return fmt.Errorf("expression cannot end with operator")
		}
	}

	// pos is the byte offset of the offending character in the trimmed text.
	for pos, r := range expr {
		if isValidChar(r) {
			continue
		}
		return fmt.Errorf("invalid character '%c' at position %d", r, pos)
	}
	return nil
}
// isValidChar reports whether ch is allowed to appear in an expression.
//
// Accepted characters are ASCII letters and digits, whitespace (space, tab,
// newline, carriage return), the arithmetic operators + - * / % ^, the
// comparison characters > < = !, parentheses and comma, single and double
// quotes, dot, underscore, dollar sign and backtick. Everything else,
// including all non-ASCII characters, is rejected.
func isValidChar(ch rune) bool {
	switch {
	case 'a' <= ch && ch <= 'z', 'A' <= ch && ch <= 'Z', '0' <= ch && ch <= '9':
		return true
	}

	const allowed = " \t\n\r" + // whitespace
		"+-*/%^" + // arithmetic operators
		"><=!" + // comparison operators
		"()," + // parentheses and comma
		"'\"" + // quotes
		"._" + // dot and underscore
		"$" + // dollar sign (for JSON paths, etc.)
		"`" // backtick (for identifiers)
	return strings.ContainsRune(allowed, ch)
}
// Evaluate computes the numeric value of the expression for the given data.
//
// Expressions parsed by the built-in parser are evaluated by walking Root via
// evaluateNode; expressions delegated to expr-lang go through
// evaluateWithExprLang.
func (e *Expression) Evaluate(data map[string]interface{}) (float64, error) {
	if !e.useExprLang {
		return evaluateNode(e.Root, data)
	}
	return e.evaluateWithExprLang(data)
}
// evaluateWithExprLang evaluates the stored expr-lang expression through the
// function bridge and coerces the result to float64.
//
// float64, float32, int, int32 and int64 results are converted directly. A
// string result is accepted only if it parses as a float. Any other result
// type, and any evaluation error, yields 0 and an error.
func (e *Expression) evaluateWithExprLang(data map[string]interface{}) (float64, error) {
	raw, err := functions.GetExprBridge().EvaluateExpression(e.exprLangExpression, data)
	if err != nil {
		return 0, err
	}

	switch val := raw.(type) {
	case string:
		parsed, parseErr := strconv.ParseFloat(val, 64)
		if parseErr != nil {
			return 0, fmt.Errorf("cannot convert string result '%s' to float64", val)
		}
		return parsed, nil
	case int:
		return float64(val), nil
	case int32:
		return float64(val), nil
	case int64:
		return float64(val), nil
	case float32:
		return float64(val), nil
	case float64:
		return val, nil
	}
	return 0, fmt.Errorf("expression result type %T is not convertible to float64", raw)
}
// GetFields returns the names of all fields referenced by the expression,
// without duplicates and sorted ascending so the result is stable across
// calls.
//
// For expressions delegated to expr-lang the names come from the textual
// heuristic in extractFieldsFromExprLang; otherwise they are gathered by
// walking the parsed tree with collectFields.
func (e *Expression) GetFields() []string {
	if e.useExprLang {
		return extractFieldsFromExprLang(e.exprLangExpression)
	}

	seen := map[string]bool{}
	collectFields(e.Root, seen)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	// Map iteration order is random; sort for a deterministic result.
	sort.Strings(names)
	return names
}
// extractFieldsFromExprLang extracts field references from an expr-lang
// expression using a textual heuristic (no AST is built).
//
// Single- and double-quoted string literals are masked out first so that
// their contents are not mistaken for field names. The remaining text is
// split on every character that is not an ASCII letter, digit, underscore or
// dot, which keeps dot-separated nested field paths together. A token is
// reported when it is a valid field identifier, is not a number and is not a
// keyword or registered function.
//
// The result contains no duplicates and is sorted ascending so that it is
// stable across calls. It is never nil.
func extractFieldsFromExprLang(expression string) []string {
	isSeparator := func(c rune) bool {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9', c == '_', c == '.':
			return false
		}
		return true
	}

	fields := make(map[string]bool)
	for _, token := range strings.FieldsFunc(maskQuotedLiterals(expression), isSeparator) {
		if isValidFieldIdentifier(token) && !isNumber(token) && !isFunctionOrKeyword(token) {
			fields[token] = true
		}
	}

	result := make([]string, 0, len(fields))
	for field := range fields {
		result = append(result, field)
	}
	// Map iteration order is random; sort for a deterministic result.
	sort.Strings(result)
	return result
}

// maskQuotedLiterals returns expression with every single- or double-quoted
// string literal replaced by one space. A backslash inside a literal escapes
// the next character, so an escaped quote does not end the literal. An
// unterminated literal swallows the rest of the text. Backticks are left
// untouched because this package also uses them to quote identifiers.
func maskQuotedLiterals(expression string) string {
	var out strings.Builder
	out.Grow(len(expression))

	var quote rune // opening quote of the current literal, 0 when outside one
	escaped := false
	for _, r := range expression {
		switch {
		case quote == 0:
			if r == '\'' || r == '"' {
				quote = r
				// Keep a separator so neighbouring tokens cannot merge.
				out.WriteByte(' ')
			} else {
				out.WriteRune(r)
			}
		case escaped:
			escaped = false
		case r == '\\':
			escaped = true
		case r == quote:
			quote = 0
		}
	}
	return out.String()
}
// isValidFieldIdentifier checks if it's a valid field identifier (supports dot-separated nested fields)
func isValidFieldIdentifier(s string) bool {
if len(s) == 0 {
return false
}
// Split dot-separated fields
parts := strings.Split(s, ".")
for _, part := range parts {
if !isIdentifier(part) {
return false
}
}
return true
}
// isFunctionOrKeyword reports whether token should be excluded from field
// extraction because it is a reserved word or a function name.
//
// Reserved words are matched case-insensitively. Any other token is looked up,
// with its original casing, through the function bridge; only the bridge's
// "exists" result is consulted.
func isFunctionOrKeyword(token string) bool {
	switch strings.ToLower(token) {
	case "and", "or", "not", "true", "false", "nil", "null", "is",
		"if", "else", "then", "in", "contains", "matches",
		// CASE expression keywords
		"case", "when", "end":
		return true
	}

	_, registered, _ := functions.GetExprBridge().ResolveFunction(token)
	return registered
}
// collectFields walks the expression tree rooted at node and records the name
// of every field reference as a key in fields. A nil node is ignored.
//
// A field name wrapped in backticks is recorded without them. The walk covers
// the left and right children, function arguments and, for CASE expressions,
// the CASE operand, the ELSE result and every WHEN condition and THEN result.
func collectFields(node *ExprNode, fields map[string]bool) {
	if node == nil {
		return
	}

	if node.Type == TypeField {
		name := node.Value
		// Strip one pair of surrounding backticks, if present.
		if n := len(name); n >= 2 && name[0] == '`' && name[n-1] == '`' {
			name = name[1 : n-1]
		}
		fields[name] = true
	}

	visit := func(child *ExprNode) { collectFields(child, fields) }

	visit(node.Left)
	visit(node.Right)
	for i := range node.Args {
		visit(node.Args[i])
	}

	caseExpr := node.CaseExpr
	if caseExpr == nil {
		return
	}
	visit(caseExpr.Value)
	visit(caseExpr.ElseResult)
	for i := range caseExpr.WhenClauses {
		visit(caseExpr.WhenClauses[i].Condition)
		visit(caseExpr.WhenClauses[i].Result)
	}
}
// EvaluateBool evaluates the expression against data and reports its truth
// value.
//
// Expressions parsed by the built-in parser are delegated to evaluateBoolNode.
// Expressions delegated to expr-lang are evaluated once through the function
// bridge and the result is interpreted as follows:
//   - a bool result is returned as-is;
//   - a float64, float32, int, int32 or int64 result is true when non-zero;
//   - a string result must parse as a float and is true when non-zero;
//   - any other result type is an error.
//
// The bool case is handled here because evaluateWithExprLang converts to
// float64 only and rejects bool, so routing a comparison or logical expr-lang
// expression through it would always fail.
func (e *Expression) EvaluateBool(data map[string]interface{}) (bool, error) {
	if !e.useExprLang {
		return evaluateBoolNode(e.Root, data)
	}

	result, err := functions.GetExprBridge().EvaluateExpression(e.exprLangExpression, data)
	if err != nil {
		return false, err
	}

	switch v := result.(type) {
	case bool:
		return v, nil
	case float64:
		return v != 0, nil
	case float32:
		return v != 0, nil
	case int:
		return v != 0, nil
	case int32:
		return v != 0, nil
	case int64:
		return v != 0, nil
	case string:
		f, parseErr := strconv.ParseFloat(v, 64)
		if parseErr != nil {
			// Message kept identical to the numeric conversion path.
			return false, fmt.Errorf("cannot convert string result '%s' to float64", v)
		}
		return f != 0, nil
	default:
		return false, fmt.Errorf("expression result type %T is not convertible to float64", result)
	}
}
// EvaluateWithNull is the NULL-aware numeric evaluation entry point intended
// for aggregate function callers.
//
// Expressions parsed by the built-in parser are delegated to
// evaluateNodeWithNull. Expressions delegated to expr-lang have no NULL
// support, so they are evaluated numerically and the boolean result is always
// false (presumably the "is NULL" flag - confirm against
// evaluateNodeWithNull).
func (e *Expression) EvaluateWithNull(data map[string]interface{}) (float64, bool, error) {
	if !e.useExprLang {
		return evaluateNodeWithNull(e.Root, data)
	}

	value, err := e.evaluateWithExprLang(data)
	return value, false, err
}
// EvaluateValueWithNull evaluates the expression and returns its value as an
// untyped interface value together with a boolean flag and an error.
//
// Expressions parsed by the built-in parser are delegated to
// evaluateNodeValueWithNull. Expressions delegated to expr-lang have no NULL
// support: they go through evaluateWithExprLang, so the returned value is
// always a float64 and the flag is always false.
//
// NOTE(review): on the expr-lang path non-numeric results (for example
// non-numeric strings) surface as an error rather than as a value.
func (e *Expression) EvaluateValueWithNull(data map[string]interface{}) (interface{}, bool, error) {
	if !e.useExprLang {
		return evaluateNodeValueWithNull(e.Root, data)
	}

	value, err := e.evaluateWithExprLang(data)
	return value, false, err
}