From a5e4967021aa704c7ea2081229c2b5d8908cf5ad Mon Sep 17 00:00:00 2001 From: rulego-team Date: Wed, 11 Jun 2025 18:45:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=AE=8C=E5=96=84=E5=B5=8C=E5=A5=97?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rsql/ast.go | 94 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 76 insertions(+), 18 deletions(-) diff --git a/rsql/ast.go b/rsql/ast.go index 5811697..8d2f792 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -151,9 +151,7 @@ func isAggregationFunction(expr string) bool { case functions.TypeWindow: // 窗口函数需要聚合处理 return true - case functions.TypeMath: - // 数学函数在聚合上下文中需要聚合处理 - return true + default: // 其他类型的函数(字符串、转换等)不需要聚合处理 return false @@ -192,7 +190,7 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy for _, f := range fields { if alias := f.Alias; alias != "" { - t, n, expression, allFields := parseAggregateTypeWithExpression(f.Expression) + t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) if n != "" { selectFields[n] = t fieldMap[n] = alias @@ -216,7 +214,40 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy } // 解析聚合函数,并返回表达式信息 -func parseAggregateTypeWithExpression(exprStr string) (aggType aggregator.AggregateType, name string, expression string, allFields []string) { +func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.AggregateType, name string, expression string, allFields []string) { + // 检查是否是嵌套函数 + if hasNestedFunctions(exprStr) { + // 嵌套函数情况,提取所有函数 + funcs := extractAllFunctions(exprStr) + + // 查找聚合函数 + var aggregationFunc string + for _, funcName := range funcs { + if fn, exists := functions.Get(funcName); exists { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + aggregationFunc = funcName + break + } + } + } + + if aggregationFunc != "" { + // 有聚合函数的嵌套表达式,整个表达式作为expression处理 + if parsedExpr, err := expr.NewExpression(exprStr); err == nil { + allFields = parsedExpr.GetFields() + } + return aggregator.AggregateType(aggregationFunc), "", exprStr, allFields + } else { + // 没有聚合函数的嵌套表达式,作为普通表达式处理 + if parsedExpr, err := expr.NewExpression(exprStr); err == nil { + allFields = parsedExpr.GetFields() + } + return "expression", "", exprStr, allFields + } + } + + // 单一函数的原有逻辑 // 提取函数名 funcName := extractFunctionName(exprStr) if funcName == "" { @@ -246,18 +277,10 @@ func parseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg // 窗口函数:使用函数名作为聚合类型 return aggregator.AggregateType(funcName), name, expression, allFields - case functions.TypeMath: - // 数学函数:在聚合上下文中使用avg作为聚合类型 - if expression == "" { - expression = exprStr - if parsedExpr, err := expr.NewExpression(exprStr); err == nil { - allFields = parsedExpr.GetFields() - } - } - return "avg", name, expression, allFields - case functions.TypeString, functions.TypeConversion, functions.TypeCustom: - // 字符串函数、转换函数、自定义函数:在聚合查询中作为表达式处理 + + case functions.TypeString, functions.TypeConversion, functions.TypeCustom, functions.TypeMath: + // 字符串函数、转换函数、自定义函数、数学函数:在聚合查询中作为表达式处理 // 使用 "expression" 作为特殊的聚合类型,表示这是一个表达式计算 if expression == "" { expression = exprStr @@ -293,6 +316,41 @@ func extractFunctionName(expr string) string { return funcName } +// 提取表达式中的所有函数名 +func extractAllFunctions(expr string) []string { + var funcNames []string + + // 简单的函数名匹配 + i := 0 + for i < len(expr) { + // 查找函数名模式 + start := i + for i < len(expr) && (expr[i] >= 'a' && expr[i] <= 'z' || expr[i] >= 'A' && expr[i] <= 'Z' || expr[i] == '_') { + i++ + } + + if i < len(expr) && expr[i] == '(' && i > start { + // 找到可能的函数名 + funcName := expr[start:i] + if _, exists := functions.Get(funcName); exists { + funcNames = append(funcNames, funcName) + } + } + + if i < len(expr) { + i++ + } + } + + return funcNames +} + +// 检查表达式是否包含嵌套函数 +func hasNestedFunctions(expr string) bool { + funcs := extractAllFunctions(expr) + return len(funcs) > 1 +} + // 提取聚合函数字段,并解析表达式信息 func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName string, expression string, allFields []string) { start := strings.Index(strings.ToLower(exprStr), strings.ToLower(funcName)+"(") @@ -537,7 +595,7 @@ func buildSelectFieldsWithExpressions(fields []Field) ( for _, f := range fields { if alias := f.Alias; alias != "" { - t, n, expression, allFields := parseAggregateTypeWithExpression(f.Expression) + t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) if t != "" { // 使用别名作为键,这样每个聚合函数都有唯一的键 selectFields[alias] = t @@ -561,7 +619,7 @@ func buildSelectFieldsWithExpressions(fields []Field) ( } } else { // 没有别名的情况,使用表达式本身作为字段名 - t, n, expression, allFields := parseAggregateTypeWithExpression(f.Expression) + t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) if t != "" && n != "" { selectFields[n] = t fieldMap[n] = n