From 368b33ae34b8a40b7a302dc970927e5781ae7e7d Mon Sep 17 00:00:00 2001 From: rulego-team Date: Sun, 15 Jun 2025 21:32:44 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E5=88=AB=E5=90=8D?= =?UTF-8?q?=E6=98=A0=E5=B0=84=E5=92=8C=E5=AD=97=E6=AE=B5=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E7=A9=BA=E6=A0=BC=E8=A7=A3=E6=9E=90=E9=94=99=E8=AF=AF=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aggregator/group_aggregator.go | 67 ++++++++++++++++++++++-------- rsql/ast.go | 19 +++++---- rsql/parser.go | 64 +++++++++++++++++++++++++++- streamsql_custom_functions_test.go | 5 +-- 4 files changed, 126 insertions(+), 29 deletions(-) diff --git a/aggregator/group_aggregator.go b/aggregator/group_aggregator.go index 23e9d68..c4befa8 100644 --- a/aggregator/group_aggregator.go +++ b/aggregator/group_aggregator.go @@ -40,34 +40,67 @@ type ExpressionEvaluator struct { } func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator { + aggregators := make(map[string]AggregatorFunction) - // 重新组织 fieldMap 和 fieldAlias - // 测试中:fieldMap: {"temperature": Sum}, fieldAlias: {"temperature": "temperature_sum"} - // 这意味着:输入字段"temperature",聚合类型Sum,输出别名"temperature_sum" + // 处理两种可能的调用模式: + // 1. SQL解析模式:fieldMap是输出字段名->聚合类型,fieldAlias是输出字段名->输入字段名 + // 2. 直接测试模式:fieldMap是输入字段名->聚合类型,fieldAlias是输入字段名->输出字段名 - // 创建新的映射:输出字段名 -> 聚合类型 - newFieldMap := make(map[string]AggregateType) - // 创建新的别名映射:输出字段名 -> 输入字段名 - newFieldAlias := make(map[string]string) + // 创建最终的映射 + finalFieldMap := make(map[string]AggregateType) + finalFieldAlias := make(map[string]string) - for inputField, aggType := range fieldMap { - outputField := inputField // 默认输出字段名等于输入字段名 - if alias, exists := fieldAlias[inputField]; exists { - outputField = alias // 如果有别名,使用别名作为输出字段名 + // 简化的检测逻辑: + // 在直接测试模式中,fieldAlias 的值通常包含 "_sum", "_avg" 等后缀 + // 在SQL解析模式中,fieldAlias 的值是实际的数据字段名(如 "temperature") + + isSQLMode := false + if len(fieldAlias) > 0 { + // 检查是否有任何 fieldAlias 的值看起来像 SQL 解析模式(不包含聚合后缀) + for _, aliasValue := range fieldAlias { + // 如果值不包含典型的聚合后缀,可能是SQL模式 + if !strings.Contains(aliasValue, "_sum") && + !strings.Contains(aliasValue, "_avg") && + !strings.Contains(aliasValue, "_min") && + !strings.Contains(aliasValue, "_max") && + !strings.Contains(aliasValue, "_count") { + isSQLMode = true + break + } } + } - newFieldMap[outputField] = aggType - newFieldAlias[outputField] = inputField - aggregators[outputField] = CreateBuiltinAggregator(aggType) + if isSQLMode { + // SQL解析模式:fieldMap是输出字段名->聚合类型,fieldAlias是输出字段名->输入字段名 + finalFieldMap = fieldMap + finalFieldAlias = fieldAlias + } else { + // 直接测试模式:fieldMap是输入字段名->聚合类型,fieldAlias是输入字段名->输出字段名 + for inputField, aggType := range fieldMap { + outputField := inputField // 默认输出字段名等于输入字段名 + + // fieldAlias提供了:输入字段名 -> 输出别名的映射 + if alias, exists := fieldAlias[inputField]; exists { + outputField = alias + } + + finalFieldMap[outputField] = aggType + finalFieldAlias[outputField] = inputField + } + } + + // 创建聚合器 + for outputField := range finalFieldMap { + aggregators[outputField] = CreateBuiltinAggregator(finalFieldMap[outputField]) } return &GroupAggregator{ - fieldMap: newFieldMap, // 输出字段名 -> 聚合类型 + fieldMap: finalFieldMap, // 输出字段名 -> 聚合类型 groupFields: groupFields, aggregators: aggregators, groups: make(map[string]map[string]AggregatorFunction), - fieldAlias: newFieldAlias, // 输出字段名 -> 输入字段名 + fieldAlias: finalFieldAlias, // 输出字段名 -> 输入字段名 expressions: make(map[string]*ExpressionEvaluator), } } @@ -214,8 +247,8 @@ func (ga *GroupAggregator) Add(data interface{}) error { // 获取实际的输入字段名 // field现在是输出字段名(可能是别名),需要找到对应的输入字段名 inputFieldName := field + // 在聚合器内部,fieldAlias的映射方向是:输出字段名 -> 输入字段名 if mappedField, exists := ga.fieldAlias[field]; exists { - // 如果field是别名,获取实际输入字段名 inputFieldName = mappedField } diff --git a/rsql/ast.go b/rsql/ast.go index 2ae4160..df6a7f6 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -191,21 +191,24 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy for _, f := range fields { if alias := f.Alias; alias != "" { t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) - if n != "" { - selectFields[n] = t - fieldMap[n] = alias + if t != "" { + // 使用别名作为聚合器的key,而不是字段名 + selectFields[alias] = t + if n != "" { + fieldMap[n] = alias + } else { + // 当没有特定字段名时,使用表达式或别名 + fieldMap[alias] = alias + } // 如果存在表达式,保存表达式信息 if expression != "" { - fieldExpressions[n] = types.FieldExpression{ + fieldExpressions[alias] = types.FieldExpression{ Field: n, Expression: expression, Fields: allFields, } } - } else if t != "" { - // 只有在聚合类型非空时才添加 - selectFields[alias] = t } // 如果聚合类型和字段名都为空,不做处理,避免空聚合器类型 } @@ -617,7 +620,7 @@ func buildSelectFieldsWithExpressions(fields []Field) ( // 使用别名作为键,这样每个聚合函数都有唯一的键 selectFields[alias] = t - // 字段映射:别名 -> 输入字段名 + // 字段映射:输出字段名 -> 输入字段名(直接为聚合器准备正确的映射) if n != "" { fieldMap[alias] = n } else { diff --git a/rsql/parser.go b/rsql/parser.go index a26b7db..e6af6d5 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -149,6 +149,21 @@ func (p *Parser) Parse() (*SelectStatement, error) { return stmt, nil } +// isKeyword 检查给定的字符串是否是SQL关键字 +func isKeyword(word string) bool { + keywords := map[string]bool{ + "SELECT": true, "FROM": true, "WHERE": true, "GROUP": true, "BY": true, + "ORDER": true, "HAVING": true, "LIMIT": true, "WITH": true, "AS": true, + "CASE": true, "WHEN": true, "THEN": true, "ELSE": true, "END": true, + "AND": true, "OR": true, "NOT": true, "IN": true, "IS": true, "NULL": true, + "DISTINCT": true, "COUNT": true, "SUM": true, "AVG": true, "MIN": true, "MAX": true, + "INNER": true, "LEFT": true, "RIGHT": true, "FULL": true, "OUTER": true, "JOIN": true, + "ON": true, "UNION": true, "ALL": true, "EXCEPT": true, "INTERSECT": true, + "EXISTS": true, "BETWEEN": true, "LIKE": true, "ASC": true, "DESC": true, + } + return keywords[word] +} + // createDetailedError 创建详细的错误信息 func (p *Parser) createDetailedError(err error) error { if parseErr, ok := err.(*ParseError); ok { @@ -235,8 +250,55 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error { } // 如果不是第一个token,添加空格分隔符 + // 但要注意特殊情况:某些token之间不应该加空格 if expr.Len() > 0 { - expr.WriteString(" ") + shouldAddSpace := true + + // 获取前一个token的信息 + exprStr := expr.String() + lastChar := exprStr[len(exprStr)-1:] + + // 以下情况不添加空格: + // 1. 函数名和左括号之间 + // 2. 标识符和数字之间(如 x1, y1) + // 3. 数字和标识符之间 + // 4. 左括号之后 + // 5. 右括号之前 + if currentToken.Type == TokenLParen && lastChar != " " && lastChar != "(" { + // 函数名和左括号之间不加空格 + shouldAddSpace = false + } else if lastChar == "(" || currentToken.Type == TokenRParen { + // 左括号之后或右括号之前不加空格 + shouldAddSpace = false + } else if len(exprStr) > 0 && currentToken.Type == TokenNumber { + // 检查前一个字符是否是字母(标识符的一部分),且前面没有空格 + // 这主要处理 x1, y1 这类标识符,但排除 THEN 1, ELSE 0 这类情况 + if ((lastChar[0] >= 'a' && lastChar[0] <= 'z') || (lastChar[0] >= 'A' && lastChar[0] <= 'Z') || lastChar[0] == '_') && + !strings.HasSuffix(exprStr, " ") { + // 进一步检查:如果前面是SQL关键字,则应该加空格 + words := strings.Fields(exprStr) + if len(words) > 0 { + lastWord := strings.ToUpper(words[len(words)-1]) + // 如果是关键字,应该加空格 + if isKeyword(lastWord) { + shouldAddSpace = true + } else { + shouldAddSpace = false + } + } else { + shouldAddSpace = false + } + } + } else if len(exprStr) > 0 && currentToken.Type == TokenIdent { + // 检查前一个字符是否是数字,且前面没有空格 + if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") { + shouldAddSpace = false + } + } + + if shouldAddSpace { + expr.WriteString(" ") + } } expr.WriteString(currentToken.Value) currentToken = p.lexer.NextToken() diff --git a/streamsql_custom_functions_test.go b/streamsql_custom_functions_test.go index 497e58e..0fd7e4f 100644 --- a/streamsql_custom_functions_test.go +++ b/streamsql_custom_functions_test.go @@ -3,12 +3,13 @@ package streamsql import ( "encoding/json" "fmt" - "github.com/rulego/streamsql/utils/cast" "math" "net" "testing" "time" + "github.com/rulego/streamsql/utils/cast" + "github.com/rulego/streamsql/aggregator" "github.com/rulego/streamsql/expr" "github.com/rulego/streamsql/functions" @@ -43,9 +44,7 @@ func TestCustomMathFunctions(t *testing.T) { func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { x1 := cast.ToFloat64(args[0]) y1 := cast.ToFloat64(args[1]) - x2 := cast.ToFloat64(args[2]) - y2 := cast.ToFloat64(args[3]) distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))