fix:修复别名映射和字段增加空格解析错误问题

This commit is contained in:
rulego-team
2025-06-15 21:32:44 +08:00
parent f8b4924d03
commit 368b33ae34
4 changed files with 126 additions and 29 deletions

View File

@ -40,34 +40,67 @@ type ExpressionEvaluator struct {
}
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
aggregators := make(map[string]AggregatorFunction)
// 重新组织 fieldMap 和 fieldAlias
// 测试中fieldMap: {"temperature": Sum}, fieldAlias: {"temperature": "temperature_sum"}
// 这意味着:输入字段"temperature"聚合类型Sum输出别名"temperature_sum"
// 处理两种可能的调用模式:
// 1. SQL解析模式fieldMap是输出字段名->聚合类型fieldAlias是输出字段名->输入字段名
// 2. 直接测试模式fieldMap是输入字段名->聚合类型fieldAlias是输入字段名->输出字段名
// 创建的映射:输出字段名 -> 聚合类型
newFieldMap := make(map[string]AggregateType)
// 创建新的别名映射:输出字段名 -> 输入字段名
newFieldAlias := make(map[string]string)
// 创建最终的映射
finalFieldMap := make(map[string]AggregateType)
finalFieldAlias := make(map[string]string)
for inputField, aggType := range fieldMap {
outputField := inputField // 默认输出字段名等于输入字段名
if alias, exists := fieldAlias[inputField]; exists {
outputField = alias // 如果有别名,使用别名作为输出字段名
// 简化的检测逻辑:
// 在直接测试模式中fieldAlias 的值通常包含 "_sum", "_avg" 等后缀
// 在SQL解析模式中fieldAlias 的值是实际的数据字段名(如 "temperature"
isSQLMode := false
if len(fieldAlias) > 0 {
// 检查是否有任何 fieldAlias 的值看起来像 SQL 解析模式(不包含聚合后缀)
for _, aliasValue := range fieldAlias {
// 如果值不包含典型的聚合后缀可能是SQL模式
if !strings.Contains(aliasValue, "_sum") &&
!strings.Contains(aliasValue, "_avg") &&
!strings.Contains(aliasValue, "_min") &&
!strings.Contains(aliasValue, "_max") &&
!strings.Contains(aliasValue, "_count") {
isSQLMode = true
break
}
}
}
newFieldMap[outputField] = aggType
newFieldAlias[outputField] = inputField
aggregators[outputField] = CreateBuiltinAggregator(aggType)
if isSQLMode {
// SQL解析模式fieldMap是输出字段名->聚合类型fieldAlias是输出字段名->输入字段名
finalFieldMap = fieldMap
finalFieldAlias = fieldAlias
} else {
// 直接测试模式fieldMap是输入字段名->聚合类型fieldAlias是输入字段名->输出字段名
for inputField, aggType := range fieldMap {
outputField := inputField // 默认输出字段名等于输入字段名
// fieldAlias提供了输入字段名 -> 输出别名的映射
if alias, exists := fieldAlias[inputField]; exists {
outputField = alias
}
finalFieldMap[outputField] = aggType
finalFieldAlias[outputField] = inputField
}
}
// 创建聚合器
for outputField := range finalFieldMap {
aggregators[outputField] = CreateBuiltinAggregator(finalFieldMap[outputField])
}
return &GroupAggregator{
fieldMap: newFieldMap, // 输出字段名 -> 聚合类型
fieldMap: finalFieldMap, // 输出字段名 -> 聚合类型
groupFields: groupFields,
aggregators: aggregators,
groups: make(map[string]map[string]AggregatorFunction),
fieldAlias: newFieldAlias, // 输出字段名 -> 输入字段名
fieldAlias: finalFieldAlias, // 输出字段名 -> 输入字段名
expressions: make(map[string]*ExpressionEvaluator),
}
}
@ -214,8 +247,8 @@ func (ga *GroupAggregator) Add(data interface{}) error {
// 获取实际的输入字段名
// field现在是输出字段名可能是别名需要找到对应的输入字段名
inputFieldName := field
// 在聚合器内部fieldAlias的映射方向是输出字段名 -> 输入字段名
if mappedField, exists := ga.fieldAlias[field]; exists {
// 如果field是别名获取实际输入字段名
inputFieldName = mappedField
}

View File

@ -191,21 +191,24 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy
for _, f := range fields {
if alias := f.Alias; alias != "" {
t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression)
if n != "" {
selectFields[n] = t
fieldMap[n] = alias
if t != "" {
// 使用别名作为聚合器的key而不是字段名
selectFields[alias] = t
if n != "" {
fieldMap[n] = alias
} else {
// 当没有特定字段名时,使用表达式或别名
fieldMap[alias] = alias
}
// 如果存在表达式,保存表达式信息
if expression != "" {
fieldExpressions[n] = types.FieldExpression{
fieldExpressions[alias] = types.FieldExpression{
Field: n,
Expression: expression,
Fields: allFields,
}
}
} else if t != "" {
// 只有在聚合类型非空时才添加
selectFields[alias] = t
}
// 如果聚合类型和字段名都为空,不做处理,避免空聚合器类型
}
@ -617,7 +620,7 @@ func buildSelectFieldsWithExpressions(fields []Field) (
// 使用别名作为键,这样每个聚合函数都有唯一的键
selectFields[alias] = t
// 字段映射:名 -> 输入字段名
// 字段映射:输出字段名 -> 输入字段名(直接为聚合器准备正确的映射)
if n != "" {
fieldMap[alias] = n
} else {

View File

@ -149,6 +149,21 @@ func (p *Parser) Parse() (*SelectStatement, error) {
return stmt, nil
}
// isKeyword 检查给定的字符串是否是SQL关键字
func isKeyword(word string) bool {
keywords := map[string]bool{
"SELECT": true, "FROM": true, "WHERE": true, "GROUP": true, "BY": true,
"ORDER": true, "HAVING": true, "LIMIT": true, "WITH": true, "AS": true,
"CASE": true, "WHEN": true, "THEN": true, "ELSE": true, "END": true,
"AND": true, "OR": true, "NOT": true, "IN": true, "IS": true, "NULL": true,
"DISTINCT": true, "COUNT": true, "SUM": true, "AVG": true, "MIN": true, "MAX": true,
"INNER": true, "LEFT": true, "RIGHT": true, "FULL": true, "OUTER": true, "JOIN": true,
"ON": true, "UNION": true, "ALL": true, "EXCEPT": true, "INTERSECT": true,
"EXISTS": true, "BETWEEN": true, "LIKE": true, "ASC": true, "DESC": true,
}
return keywords[word]
}
// createDetailedError 创建详细的错误信息
func (p *Parser) createDetailedError(err error) error {
if parseErr, ok := err.(*ParseError); ok {
@ -235,8 +250,55 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error {
}
// 如果不是第一个token添加空格分隔符
// 但要注意特殊情况某些token之间不应该加空格
if expr.Len() > 0 {
expr.WriteString(" ")
shouldAddSpace := true
// 获取前一个token的信息
exprStr := expr.String()
lastChar := exprStr[len(exprStr)-1:]
// 以下情况不添加空格:
// 1. 函数名和左括号之间
// 2. 标识符和数字之间(如 x1, y1
// 3. 数字和标识符之间
// 4. 左括号之后
// 5. 右括号之前
if currentToken.Type == TokenLParen && lastChar != " " && lastChar != "(" {
// 函数名和左括号之间不加空格
shouldAddSpace = false
} else if lastChar == "(" || currentToken.Type == TokenRParen {
// 左括号之后或右括号之前不加空格
shouldAddSpace = false
} else if len(exprStr) > 0 && currentToken.Type == TokenNumber {
// 检查前一个字符是否是字母(标识符的一部分),且前面没有空格
// 这主要处理 x1, y1 这类标识符,但排除 THEN 1, ELSE 0 这类情况
if ((lastChar[0] >= 'a' && lastChar[0] <= 'z') || (lastChar[0] >= 'A' && lastChar[0] <= 'Z') || lastChar[0] == '_') &&
!strings.HasSuffix(exprStr, " ") {
// 进一步检查如果前面是SQL关键字则应该加空格
words := strings.Fields(exprStr)
if len(words) > 0 {
lastWord := strings.ToUpper(words[len(words)-1])
// 如果是关键字,应该加空格
if isKeyword(lastWord) {
shouldAddSpace = true
} else {
shouldAddSpace = false
}
} else {
shouldAddSpace = false
}
}
} else if len(exprStr) > 0 && currentToken.Type == TokenIdent {
// 检查前一个字符是否是数字,且前面没有空格
if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") {
shouldAddSpace = false
}
}
if shouldAddSpace {
expr.WriteString(" ")
}
}
expr.WriteString(currentToken.Value)
currentToken = p.lexer.NextToken()

View File

@ -3,12 +3,13 @@ package streamsql
import (
"encoding/json"
"fmt"
"github.com/rulego/streamsql/utils/cast"
"math"
"net"
"testing"
"time"
"github.com/rulego/streamsql/utils/cast"
"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/expr"
"github.com/rulego/streamsql/functions"
@ -43,9 +44,7 @@ func TestCustomMathFunctions(t *testing.T) {
func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
x1 := cast.ToFloat64(args[0])
y1 := cast.ToFloat64(args[1])
x2 := cast.ToFloat64(args[2])
y2 := cast.ToFloat64(args[3])
distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))