From f8b4924d03cf130d6932d16ed56cef81b818ea4f Mon Sep 17 00:00:00 2001 From: rulego-team Date: Fri, 13 Jun 2025 21:37:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E6=9D=A1=E4=BB=B6?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=88=AB=E5=90=8D=E8=A7=A3=E6=9E=90=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rsql/ast.go | 35 ++++++++++++++++++++++++++--------- rsql/parser.go | 20 ++++++++++++-------- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/rsql/ast.go b/rsql/ast.go index 8d2f792..2ae4160 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -215,11 +215,20 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy // 解析聚合函数,并返回表达式信息 func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.AggregateType, name string, expression string, allFields []string) { + // 特殊处理 CASE 表达式 + if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(exprStr)), "CASE") { + // CASE 表达式作为特殊的表达式处理 + if parsedExpr, err := expr.NewExpression(exprStr); err == nil { + allFields = parsedExpr.GetFields() + } + return "expression", "", exprStr, allFields + } + // 检查是否是嵌套函数 if hasNestedFunctions(exprStr) { // 嵌套函数情况,提取所有函数 funcs := extractAllFunctions(exprStr) - + // 查找聚合函数 var aggregationFunc string for _, funcName := range funcs { @@ -231,7 +240,7 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg } } } - + if aggregationFunc != "" { // 有聚合函数的嵌套表达式,整个表达式作为expression处理 if parsedExpr, err := expr.NewExpression(exprStr); err == nil { @@ -246,11 +255,21 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg return "expression", "", exprStr, allFields } } - + // 单一函数的原有逻辑 // 提取函数名 funcName := extractFunctionName(exprStr) if funcName == "" { + // 如果不是函数调用,但包含运算符或关键字,可能是表达式 + if strings.ContainsAny(exprStr, "+-*/<>=!&|") || + strings.Contains(strings.ToUpper(exprStr), "AND") || + strings.Contains(strings.ToUpper(exprStr), "OR") { + // 作为表达式处理 + if parsedExpr, err := expr.NewExpression(exprStr); err == nil { + allFields = parsedExpr.GetFields() + } + return "expression", "", exprStr, allFields + } return "", "", "", nil } @@ -277,8 +296,6 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg // 窗口函数:使用函数名作为聚合类型 return aggregator.AggregateType(funcName), name, expression, allFields - - case functions.TypeString, functions.TypeConversion, functions.TypeCustom, functions.TypeMath: // 字符串函数、转换函数、自定义函数、数学函数:在聚合查询中作为表达式处理 // 使用 "expression" 作为特殊的聚合类型,表示这是一个表达式计算 @@ -319,7 +336,7 @@ func extractFunctionName(expr string) string { // 提取表达式中的所有函数名 func extractAllFunctions(expr string) []string { var funcNames []string - + // 简单的函数名匹配 i := 0 for i < len(expr) { @@ -328,7 +345,7 @@ func extractAllFunctions(expr string) []string { for i < len(expr) && (expr[i] >= 'a' && expr[i] <= 'z' || expr[i] >= 'A' && expr[i] <= 'Z' || expr[i] == '_') { i++ } - + if i < len(expr) && expr[i] == '(' && i > start { // 找到可能的函数名 funcName := expr[start:i] @@ -336,12 +353,12 @@ func extractAllFunctions(expr string) []string { funcNames = append(funcNames, funcName) } } - + if i < len(expr) { i++ } } - + return funcNames } diff --git a/rsql/parser.go b/rsql/parser.go index 5ecca15..a26b7db 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -49,12 +49,12 @@ func (p *Parser) expectToken(expected TokenType, context string) (Token, error) ) err.Context = context p.errorRecovery.AddError(err) - + // 尝试错误恢复 if err.IsRecoverable() && p.errorRecovery.RecoverFromError(ErrorTypeUnexpectedToken) { return p.expectToken(expected, context) } - + return tok, err } return tok, nil @@ -164,7 +164,7 @@ func (p *Parser) createCombinedError() error { if len(errors) == 1 { return p.createDetailedError(errors[0]) } - + var builder strings.Builder builder.WriteString(fmt.Sprintf("Found %d parsing errors:\n", len(errors))) for i, err := range errors { @@ -234,6 +234,10 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error { break } + // 如果不是第一个token,添加空格分隔符 + if expr.Len() > 0 { + expr.WriteString(" ") + } expr.WriteString(currentToken.Value) currentToken = p.lexer.NextToken() } @@ -252,7 +256,7 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error { validator := NewFunctionValidator(p.errorRecovery) pos, _, _ := p.lexer.GetPosition() validator.ValidateExpression(field.Expression, pos-len(field.Expression)) - + stmt.Fields = append(stmt.Fields, field) } @@ -325,7 +329,7 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { } } } - + // 验证WHERE条件中的函数 whereCondition := strings.Join(conditions, " ") if whereCondition != "" { @@ -333,7 +337,7 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { pos, _, _ := p.lexer.GetPosition() validator.ValidateExpression(whereCondition, pos-len(whereCondition)) } - + stmt.Condition = whereCondition return nil } @@ -424,7 +428,7 @@ func (p *Parser) parseFrom(stmt *SelectStatement) error { func (p *Parser) parseGroupBy(stmt *SelectStatement) error { tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()) if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { - p.parseWindowFunction(stmt, tok.Value) + _ = p.parseWindowFunction(stmt, tok.Value) } if tok.Type == TokenGROUP { p.lexer.NextToken() // 跳过BY @@ -450,7 +454,7 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error { continue } if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { - p.parseWindowFunction(stmt, tok.Value) + _ = p.parseWindowFunction(stmt, tok.Value) continue }