fix:修复条件函数别名解析错误

This commit is contained in:
rulego-team
2025-06-13 21:37:39 +08:00
parent 07488300f5
commit f8b4924d03
2 changed files with 38 additions and 17 deletions
+26 -9
View File
@@ -215,11 +215,20 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy
// 解析聚合函数,并返回表达式信息
func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.AggregateType, name string, expression string, allFields []string) {
// 特殊处理 CASE 表达式
if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(exprStr)), "CASE") {
// CASE 表达式作为特殊的表达式处理
if parsedExpr, err := expr.NewExpression(exprStr); err == nil {
allFields = parsedExpr.GetFields()
}
return "expression", "", exprStr, allFields
}
// 检查是否是嵌套函数
if hasNestedFunctions(exprStr) {
// 嵌套函数情况,提取所有函数
funcs := extractAllFunctions(exprStr)
// 查找聚合函数
var aggregationFunc string
for _, funcName := range funcs {
@@ -231,7 +240,7 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg
}
}
}
if aggregationFunc != "" {
// 有聚合函数的嵌套表达式,整个表达式作为expression处理
if parsedExpr, err := expr.NewExpression(exprStr); err == nil {
@@ -246,11 +255,21 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg
return "expression", "", exprStr, allFields
}
}
// 单一函数的原有逻辑
// 提取函数名
funcName := extractFunctionName(exprStr)
if funcName == "" {
// 如果不是函数调用,但包含运算符或关键字,可能是表达式
if strings.ContainsAny(exprStr, "+-*/<>=!&|") ||
strings.Contains(strings.ToUpper(exprStr), "AND") ||
strings.Contains(strings.ToUpper(exprStr), "OR") {
// 作为表达式处理
if parsedExpr, err := expr.NewExpression(exprStr); err == nil {
allFields = parsedExpr.GetFields()
}
return "expression", "", exprStr, allFields
}
return "", "", "", nil
}
@@ -277,8 +296,6 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg
// 窗口函数:使用函数名作为聚合类型
return aggregator.AggregateType(funcName), name, expression, allFields
case functions.TypeString, functions.TypeConversion, functions.TypeCustom, functions.TypeMath:
// 字符串函数、转换函数、自定义函数、数学函数:在聚合查询中作为表达式处理
// 使用 "expression" 作为特殊的聚合类型,表示这是一个表达式计算
@@ -319,7 +336,7 @@ func extractFunctionName(expr string) string {
// 提取表达式中的所有函数名
func extractAllFunctions(expr string) []string {
var funcNames []string
// 简单的函数名匹配
i := 0
for i < len(expr) {
@@ -328,7 +345,7 @@ func extractAllFunctions(expr string) []string {
for i < len(expr) && (expr[i] >= 'a' && expr[i] <= 'z' || expr[i] >= 'A' && expr[i] <= 'Z' || expr[i] == '_') {
i++
}
if i < len(expr) && expr[i] == '(' && i > start {
// 找到可能的函数名
funcName := expr[start:i]
@@ -336,12 +353,12 @@ func extractAllFunctions(expr string) []string {
funcNames = append(funcNames, funcName)
}
}
if i < len(expr) {
i++
}
}
return funcNames
}
+12 -8
View File
@@ -49,12 +49,12 @@ func (p *Parser) expectToken(expected TokenType, context string) (Token, error)
)
err.Context = context
p.errorRecovery.AddError(err)
// 尝试错误恢复
if err.IsRecoverable() && p.errorRecovery.RecoverFromError(ErrorTypeUnexpectedToken) {
return p.expectToken(expected, context)
}
return tok, err
}
return tok, nil
@@ -164,7 +164,7 @@ func (p *Parser) createCombinedError() error {
if len(errors) == 1 {
return p.createDetailedError(errors[0])
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("Found %d parsing errors:\n", len(errors)))
for i, err := range errors {
@@ -234,6 +234,10 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error {
break
}
// 如果不是第一个token,添加空格分隔符
if expr.Len() > 0 {
expr.WriteString(" ")
}
expr.WriteString(currentToken.Value)
currentToken = p.lexer.NextToken()
}
@@ -252,7 +256,7 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error {
validator := NewFunctionValidator(p.errorRecovery)
pos, _, _ := p.lexer.GetPosition()
validator.ValidateExpression(field.Expression, pos-len(field.Expression))
stmt.Fields = append(stmt.Fields, field)
}
@@ -325,7 +329,7 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
}
}
}
// 验证WHERE条件中的函数
whereCondition := strings.Join(conditions, " ")
if whereCondition != "" {
@@ -333,7 +337,7 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
pos, _, _ := p.lexer.GetPosition()
validator.ValidateExpression(whereCondition, pos-len(whereCondition))
}
stmt.Condition = whereCondition
return nil
}
@@ -424,7 +428,7 @@ func (p *Parser) parseFrom(stmt *SelectStatement) error {
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
p.parseWindowFunction(stmt, tok.Value)
_ = p.parseWindowFunction(stmt, tok.Value)
}
if tok.Type == TokenGROUP {
p.lexer.NextToken() // 跳过BY
@@ -450,7 +454,7 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
continue
}
if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
p.parseWindowFunction(stmt, tok.Value)
_ = p.parseWindowFunction(stmt, tok.Value)
continue
}