mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-03-14 06:17:29 +00:00
450 lines
12 KiB
Go
450 lines
12 KiB
Go
package expr
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/rulego/streamsql/functions"
|
|
)
|
|
|
|
// validateExpression validates the validity of an expression
|
|
func validateExpression(node *ExprNode) error {
|
|
if node == nil {
|
|
return fmt.Errorf("expression node is nil")
|
|
}
|
|
|
|
switch node.Type {
|
|
case TypeNumber:
|
|
return validateNumberNode(node)
|
|
case TypeString:
|
|
return validateStringNode(node)
|
|
case TypeField:
|
|
return validateFieldNode(node)
|
|
case TypeOperator:
|
|
return validateOperatorNode(node)
|
|
case TypeFunction:
|
|
return validateFunctionNode(node)
|
|
case TypeCase:
|
|
return validateCaseNode(node)
|
|
case TypeParenthesis:
|
|
return validateParenthesisNode(node)
|
|
default:
|
|
return fmt.Errorf("unknown node type: %s", node.Type)
|
|
}
|
|
}
|
|
|
|
// validateNumberNode validates a number node
|
|
func validateNumberNode(node *ExprNode) error {
|
|
if node.Value == "" {
|
|
return fmt.Errorf("number node has empty value")
|
|
}
|
|
|
|
// Check if it's a valid number format
|
|
if !isNumber(node.Value) {
|
|
return fmt.Errorf("invalid number format: %s", node.Value)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateStringNode validates a string node
|
|
func validateStringNode(node *ExprNode) error {
|
|
if node.Value == "" {
|
|
return fmt.Errorf("string node has empty value")
|
|
}
|
|
|
|
// Check if string is properly quoted
|
|
if !isStringLiteral(node.Value) {
|
|
return fmt.Errorf("invalid string literal format: %s", node.Value)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateFieldNode validates a field node
|
|
func validateFieldNode(node *ExprNode) error {
|
|
if node.Value == "" {
|
|
return fmt.Errorf("field node has empty value")
|
|
}
|
|
|
|
// Check if field name is a valid identifier
|
|
fieldName := node.Value
|
|
|
|
if !isValidFieldName(fieldName) {
|
|
return fmt.Errorf("invalid field name: %s", node.Value)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateOperatorNode validates an operator node
|
|
func validateOperatorNode(node *ExprNode) error {
|
|
if node.Value == "" {
|
|
return fmt.Errorf("operator node has empty value")
|
|
}
|
|
|
|
// Check if it's a valid operator
|
|
if !isOperator(node.Value) && !isComparisonOperator(node.Value) {
|
|
return fmt.Errorf("invalid operator: %s", node.Value)
|
|
}
|
|
|
|
// Check if it's a unary operator
|
|
if isUnaryOperator(node.Value) {
|
|
// Unary operators only need left operand
|
|
if node.Left == nil {
|
|
return fmt.Errorf("unary operator '%s' missing operand", node.Value)
|
|
}
|
|
// Unary operators should not have right operand
|
|
if node.Right != nil {
|
|
return fmt.Errorf("unary operator '%s' should not have right operand", node.Value)
|
|
}
|
|
// Validate left operand
|
|
return validateExpression(node.Left)
|
|
}
|
|
|
|
// Binary operators need both left and right operands
|
|
if node.Left == nil {
|
|
return fmt.Errorf("operator %s missing left operand", node.Value)
|
|
}
|
|
|
|
if node.Right == nil {
|
|
return fmt.Errorf("operator %s missing right operand", node.Value)
|
|
}
|
|
|
|
// Recursively validate operands
|
|
if err := validateExpression(node.Left); err != nil {
|
|
return fmt.Errorf("invalid left operand for operator %s: %v", node.Value, err)
|
|
}
|
|
|
|
if err := validateExpression(node.Right); err != nil {
|
|
return fmt.Errorf("invalid right operand for operator %s: %v", node.Value, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateFunctionNode validates a function node
|
|
// Use unified function registration system for validation
|
|
func validateFunctionNode(node *ExprNode) error {
|
|
if node.Value == "" {
|
|
return fmt.Errorf("function node has empty value")
|
|
}
|
|
|
|
// Check if function exists in registration system (using lowercase name)
|
|
fn, exists := functions.Get(strings.ToLower(node.Value))
|
|
if !exists {
|
|
return fmt.Errorf("unknown function: %s", node.Value)
|
|
}
|
|
|
|
// Use function's own validation logic for basic validation
|
|
// Create temporary argument array for validating argument count
|
|
tempArgs := make([]interface{}, len(node.Args))
|
|
if err := fn.Validate(tempArgs); err != nil {
|
|
return fmt.Errorf("function %s validation failed: %v", node.Value, err)
|
|
}
|
|
|
|
// Validate argument expressions
|
|
return validateFunctionArgs(node)
|
|
}
|
|
|
|
// validateFunctionArgs validates function arguments
|
|
func validateFunctionArgs(node *ExprNode) error {
|
|
for i, arg := range node.Args {
|
|
if err := validateExpression(arg); err != nil {
|
|
return fmt.Errorf("invalid argument %d for function %s: %v", i+1, node.Value, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// validateParenthesisNode validates a parenthesis node
|
|
func validateParenthesisNode(node *ExprNode) error {
|
|
// Parenthesis node should have a Left child containing the inner expression
|
|
if node.Left == nil {
|
|
return fmt.Errorf("parenthesis node missing inner expression")
|
|
}
|
|
|
|
// Validate the inner expression
|
|
return validateExpression(node.Left)
|
|
}
|
|
|
|
// validateCaseNode validates a CASE expression node
|
|
func validateCaseNode(node *ExprNode) error {
|
|
if node.CaseExpr == nil {
|
|
return fmt.Errorf("CASE expression is missing")
|
|
}
|
|
|
|
caseExpr := node.CaseExpr
|
|
|
|
// Validate the value part of CASE expression (if it's a simple CASE)
|
|
if caseExpr.Value != nil {
|
|
if err := validateExpression(caseExpr.Value); err != nil {
|
|
return fmt.Errorf("invalid CASE value expression: %v", err)
|
|
}
|
|
}
|
|
|
|
// Validate WHEN clauses
|
|
if len(caseExpr.WhenClauses) == 0 {
|
|
return fmt.Errorf("CASE expression must have at least one WHEN clause")
|
|
}
|
|
|
|
for i, whenClause := range caseExpr.WhenClauses {
|
|
if err := validateExpression(whenClause.Condition); err != nil {
|
|
return fmt.Errorf("invalid WHEN condition %d: %v", i+1, err)
|
|
}
|
|
if err := validateExpression(whenClause.Result); err != nil {
|
|
return fmt.Errorf("invalid THEN result %d: %v", i+1, err)
|
|
}
|
|
}
|
|
|
|
// Validate ELSE clause (if exists)
|
|
if caseExpr.ElseResult != nil {
|
|
if err := validateExpression(caseExpr.ElseResult); err != nil {
|
|
return fmt.Errorf("invalid ELSE expression: %v", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isValidFieldName checks if it's a valid field name
|
|
// isValidFieldName validates if field name is valid
|
|
// Supports normal identifiers and backtick-enclosed field names
|
|
func isValidFieldName(name string) bool {
|
|
if name == "" {
|
|
return false
|
|
}
|
|
|
|
// If it's a backtick-enclosed field name, check if backticks are properly closed
|
|
if len(name) >= 2 && name[0] == '`' && name[len(name)-1] == '`' {
|
|
// Content inside backticks can contain any character (except backticks themselves)
|
|
inner := name[1 : len(name)-1]
|
|
if inner == "" {
|
|
return false // Empty backticks not allowed
|
|
}
|
|
// Check if there are backticks inside
|
|
for _, r := range inner {
|
|
if r == '`' {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Normal field name: can only contain letters, numbers, underscores (dots not allowed)
|
|
for i, r := range name {
|
|
// For non-ASCII characters, return false directly
|
|
if r > 127 {
|
|
return false
|
|
}
|
|
ch := byte(r)
|
|
if i == 0 {
|
|
// First character must be letter or underscore
|
|
if !isLetter(ch) && r != '_' {
|
|
return false
|
|
}
|
|
} else {
|
|
// Subsequent characters can be letters, numbers, underscores
|
|
if !isLetter(ch) && !isDigit(ch) && r != '_' {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// validateTokens validates the validity of token list
|
|
func validateTokens(tokens []string) error {
|
|
if len(tokens) == 0 {
|
|
return fmt.Errorf("empty token list")
|
|
}
|
|
|
|
// Check parentheses matching
|
|
if err := validateParentheses(tokens); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check order of operators and operands
|
|
if err := validateTokenOrder(tokens); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateParentheses checks that "(" and ")" tokens are balanced.
// A ")" with no pending "(" is reported together with its token index;
// any "(" still open at the end is reported without a position.
func validateParentheses(tokens []string) error {
	depth := 0
	for pos, tok := range tokens {
		switch tok {
		case "(":
			depth++
		case ")":
			depth--
			if depth < 0 {
				return fmt.Errorf("unmatched closing parenthesis at position %d", pos)
			}
		}
	}

	if depth > 0 {
		return fmt.Errorf("unmatched opening parenthesis")
	}
	return nil
}
|
|
|
|
// validateTokenOrder validates token order
|
|
func validateTokenOrder(tokens []string) error {
|
|
if len(tokens) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Check cannot start with operator (except unary operators)
|
|
firstToken := tokens[0]
|
|
if isOperator(firstToken) && !isUnaryOperator(firstToken) {
|
|
return fmt.Errorf("expression cannot start with operator: %s", firstToken)
|
|
}
|
|
|
|
// Check cannot end with operator
|
|
lastToken := tokens[len(tokens)-1]
|
|
if isOperator(lastToken) {
|
|
return fmt.Errorf("expression cannot end with operator: %s", lastToken)
|
|
}
|
|
|
|
// Check consecutive operators and consecutive operands
|
|
for i := 0; i < len(tokens)-1; i++ {
|
|
current := tokens[i]
|
|
next := tokens[i+1]
|
|
|
|
// Check consecutive operators
|
|
if isOperator(current) && isOperator(next) {
|
|
// Allowed combination: operator followed by unary operator
|
|
if !isUnaryOperator(next) {
|
|
return fmt.Errorf("consecutive operators not allowed: %s %s at position %d", current, next, i)
|
|
}
|
|
}
|
|
|
|
// Check consecutive operands (two non-operator, non-parenthesis tokens adjacent)
|
|
// Special handling for CASE expression keywords
|
|
if !isOperator(current) && !isOperator(next) &&
|
|
current != "(" && current != ")" && next != "(" && next != ")" &&
|
|
current != "," && next != "," &&
|
|
!isCaseKeyword(current) && !isCaseKeyword(next) {
|
|
return fmt.Errorf("consecutive operands not allowed: %s %s at position %d", current, next, i)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isCaseKeyword reports whether token, compared case-insensitively via
// strings.ToUpper, is one of CASE, WHEN, THEN, ELSE or END.
func isCaseKeyword(token string) bool {
	upper := strings.ToUpper(token)
	for _, keyword := range [...]string{"CASE", "WHEN", "THEN", "ELSE", "END"} {
		if upper == keyword {
			return true
		}
	}
	return false
}
|
|
|
|
// validateSyntax runs cheap, purely textual checks on an expression string
// after trimming surrounding whitespace. In order, it rejects:
//   - an empty expression;
//   - any occurrence of "()";
//   - unbalanced parentheses;
//   - two operators separated by single spaces, or two different operators
//     written directly next to each other, unless the pair spells ">=",
//     "<=", "!=" or "<>";
//   - an expression that starts with "<operator> " or ends with " <operator>".
//
// The scan works on raw text and does not skip quoted strings.
func validateSyntax(expr string) error {
	src := strings.TrimSpace(expr)
	if len(src) == 0 {
		return fmt.Errorf("empty expression")
	}

	if strings.Contains(src, "()") {
		return fmt.Errorf("empty parentheses not allowed")
	}

	// Track nesting depth; it must never go negative and must end at zero.
	depth := 0
	for _, r := range src {
		switch r {
		case '(':
			depth++
		case ')':
			depth--
			if depth < 0 {
				return fmt.Errorf("mismatched parentheses")
			}
		}
	}
	if depth != 0 {
		return fmt.Errorf("mismatched parentheses")
	}

	ops := []string{"+", "-", "*", "/", "%", "^", "=", "!=", "<>", ">", "<", ">=", "<="}

	for _, first := range ops {
		for _, second := range ops {
			// Space-separated pair, e.g. "a + * b".
			if strings.Contains(src, " "+first+" "+second+" ") {
				return fmt.Errorf("consecutive operators")
			}

			// Directly adjacent pair of two different operators.
			if first == second {
				continue
			}
			pair := first + second
			if !strings.Contains(src, pair) {
				continue
			}
			switch pair {
			case ">=", "<=", "!=", "<>":
				// The pair is itself a legal operator.
			default:
				return fmt.Errorf("consecutive operators")
			}
		}
	}

	// Both edge checks are made per operator, in list order.
	for _, op := range ops {
		switch {
		case strings.HasPrefix(src, op+" "):
			return fmt.Errorf("expression cannot start with operator")
		case strings.HasSuffix(src, " "+op):
			return fmt.Errorf("expression cannot end with operator")
		}
	}

	return nil
}
|
|
|
|
// ValidateExpression public interface: validates expression string
|
|
func ValidateExpression(expr string) error {
|
|
// First validate syntax
|
|
if err := validateSyntax(expr); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Tokenize
|
|
tokens, err := tokenize(expr)
|
|
if err != nil {
|
|
return fmt.Errorf("tokenization error: %v", err)
|
|
}
|
|
|
|
// Validate tokens
|
|
if err := validateTokens(tokens); err != nil {
|
|
return fmt.Errorf("token validation error: %v", err)
|
|
}
|
|
|
|
// Parse to AST
|
|
node, err := parseExpression(tokens)
|
|
if err != nil {
|
|
return fmt.Errorf("parsing error: %v", err)
|
|
}
|
|
|
|
// Validate AST
|
|
if err := validateExpression(node); err != nil {
|
|
return fmt.Errorf("expression validation error: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|