mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-05 07:39:38 +00:00
fix:修复别名映射和字段增加空格解析错误问题
This commit is contained in:
@ -40,34 +40,67 @@ type ExpressionEvaluator struct {
|
||||
}
|
||||
|
||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
|
||||
|
||||
aggregators := make(map[string]AggregatorFunction)
|
||||
|
||||
// 重新组织 fieldMap 和 fieldAlias
|
||||
// 测试中:fieldMap: {"temperature": Sum}, fieldAlias: {"temperature": "temperature_sum"}
|
||||
// 这意味着:输入字段"temperature",聚合类型Sum,输出别名"temperature_sum"
|
||||
// 处理两种可能的调用模式:
|
||||
// 1. SQL解析模式:fieldMap是输出字段名->聚合类型,fieldAlias是输出字段名->输入字段名
|
||||
// 2. 直接测试模式:fieldMap是输入字段名->聚合类型,fieldAlias是输入字段名->输出字段名
|
||||
|
||||
// 创建新的映射:输出字段名 -> 聚合类型
|
||||
newFieldMap := make(map[string]AggregateType)
|
||||
// 创建新的别名映射:输出字段名 -> 输入字段名
|
||||
newFieldAlias := make(map[string]string)
|
||||
// 创建最终的映射
|
||||
finalFieldMap := make(map[string]AggregateType)
|
||||
finalFieldAlias := make(map[string]string)
|
||||
|
||||
for inputField, aggType := range fieldMap {
|
||||
outputField := inputField // 默认输出字段名等于输入字段名
|
||||
if alias, exists := fieldAlias[inputField]; exists {
|
||||
outputField = alias // 如果有别名,使用别名作为输出字段名
|
||||
// 简化的检测逻辑:
|
||||
// 在直接测试模式中,fieldAlias 的值通常包含 "_sum", "_avg" 等后缀
|
||||
// 在SQL解析模式中,fieldAlias 的值是实际的数据字段名(如 "temperature")
|
||||
|
||||
isSQLMode := false
|
||||
if len(fieldAlias) > 0 {
|
||||
// 检查是否有任何 fieldAlias 的值看起来像 SQL 解析模式(不包含聚合后缀)
|
||||
for _, aliasValue := range fieldAlias {
|
||||
// 如果值不包含典型的聚合后缀,可能是SQL模式
|
||||
if !strings.Contains(aliasValue, "_sum") &&
|
||||
!strings.Contains(aliasValue, "_avg") &&
|
||||
!strings.Contains(aliasValue, "_min") &&
|
||||
!strings.Contains(aliasValue, "_max") &&
|
||||
!strings.Contains(aliasValue, "_count") {
|
||||
isSQLMode = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newFieldMap[outputField] = aggType
|
||||
newFieldAlias[outputField] = inputField
|
||||
aggregators[outputField] = CreateBuiltinAggregator(aggType)
|
||||
if isSQLMode {
|
||||
// SQL解析模式:fieldMap是输出字段名->聚合类型,fieldAlias是输出字段名->输入字段名
|
||||
finalFieldMap = fieldMap
|
||||
finalFieldAlias = fieldAlias
|
||||
} else {
|
||||
// 直接测试模式:fieldMap是输入字段名->聚合类型,fieldAlias是输入字段名->输出字段名
|
||||
for inputField, aggType := range fieldMap {
|
||||
outputField := inputField // 默认输出字段名等于输入字段名
|
||||
|
||||
// fieldAlias提供了:输入字段名 -> 输出别名的映射
|
||||
if alias, exists := fieldAlias[inputField]; exists {
|
||||
outputField = alias
|
||||
}
|
||||
|
||||
finalFieldMap[outputField] = aggType
|
||||
finalFieldAlias[outputField] = inputField
|
||||
}
|
||||
}
|
||||
|
||||
// 创建聚合器
|
||||
for outputField := range finalFieldMap {
|
||||
aggregators[outputField] = CreateBuiltinAggregator(finalFieldMap[outputField])
|
||||
}
|
||||
|
||||
return &GroupAggregator{
|
||||
fieldMap: newFieldMap, // 输出字段名 -> 聚合类型
|
||||
fieldMap: finalFieldMap, // 输出字段名 -> 聚合类型
|
||||
groupFields: groupFields,
|
||||
aggregators: aggregators,
|
||||
groups: make(map[string]map[string]AggregatorFunction),
|
||||
fieldAlias: newFieldAlias, // 输出字段名 -> 输入字段名
|
||||
fieldAlias: finalFieldAlias, // 输出字段名 -> 输入字段名
|
||||
expressions: make(map[string]*ExpressionEvaluator),
|
||||
}
|
||||
}
|
||||
@ -214,8 +247,8 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
// 获取实际的输入字段名
|
||||
// field现在是输出字段名(可能是别名),需要找到对应的输入字段名
|
||||
inputFieldName := field
|
||||
// 在聚合器内部,fieldAlias的映射方向是:输出字段名 -> 输入字段名
|
||||
if mappedField, exists := ga.fieldAlias[field]; exists {
|
||||
// 如果field是别名,获取实际输入字段名
|
||||
inputFieldName = mappedField
|
||||
}
|
||||
|
||||
|
19
rsql/ast.go
19
rsql/ast.go
@ -191,21 +191,24 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy
|
||||
for _, f := range fields {
|
||||
if alias := f.Alias; alias != "" {
|
||||
t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression)
|
||||
if n != "" {
|
||||
selectFields[n] = t
|
||||
fieldMap[n] = alias
|
||||
if t != "" {
|
||||
// 使用别名作为聚合器的key,而不是字段名
|
||||
selectFields[alias] = t
|
||||
if n != "" {
|
||||
fieldMap[n] = alias
|
||||
} else {
|
||||
// 当没有特定字段名时,使用表达式或别名
|
||||
fieldMap[alias] = alias
|
||||
}
|
||||
|
||||
// 如果存在表达式,保存表达式信息
|
||||
if expression != "" {
|
||||
fieldExpressions[n] = types.FieldExpression{
|
||||
fieldExpressions[alias] = types.FieldExpression{
|
||||
Field: n,
|
||||
Expression: expression,
|
||||
Fields: allFields,
|
||||
}
|
||||
}
|
||||
} else if t != "" {
|
||||
// 只有在聚合类型非空时才添加
|
||||
selectFields[alias] = t
|
||||
}
|
||||
// 如果聚合类型和字段名都为空,不做处理,避免空聚合器类型
|
||||
}
|
||||
@ -617,7 +620,7 @@ func buildSelectFieldsWithExpressions(fields []Field) (
|
||||
// 使用别名作为键,这样每个聚合函数都有唯一的键
|
||||
selectFields[alias] = t
|
||||
|
||||
// 字段映射:别名 -> 输入字段名
|
||||
// 字段映射:输出字段名 -> 输入字段名(直接为聚合器准备正确的映射)
|
||||
if n != "" {
|
||||
fieldMap[alias] = n
|
||||
} else {
|
||||
|
@ -149,6 +149,21 @@ func (p *Parser) Parse() (*SelectStatement, error) {
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
// isKeyword 检查给定的字符串是否是SQL关键字
|
||||
func isKeyword(word string) bool {
|
||||
keywords := map[string]bool{
|
||||
"SELECT": true, "FROM": true, "WHERE": true, "GROUP": true, "BY": true,
|
||||
"ORDER": true, "HAVING": true, "LIMIT": true, "WITH": true, "AS": true,
|
||||
"CASE": true, "WHEN": true, "THEN": true, "ELSE": true, "END": true,
|
||||
"AND": true, "OR": true, "NOT": true, "IN": true, "IS": true, "NULL": true,
|
||||
"DISTINCT": true, "COUNT": true, "SUM": true, "AVG": true, "MIN": true, "MAX": true,
|
||||
"INNER": true, "LEFT": true, "RIGHT": true, "FULL": true, "OUTER": true, "JOIN": true,
|
||||
"ON": true, "UNION": true, "ALL": true, "EXCEPT": true, "INTERSECT": true,
|
||||
"EXISTS": true, "BETWEEN": true, "LIKE": true, "ASC": true, "DESC": true,
|
||||
}
|
||||
return keywords[word]
|
||||
}
|
||||
|
||||
// createDetailedError 创建详细的错误信息
|
||||
func (p *Parser) createDetailedError(err error) error {
|
||||
if parseErr, ok := err.(*ParseError); ok {
|
||||
@ -235,8 +250,55 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error {
|
||||
}
|
||||
|
||||
// 如果不是第一个token,添加空格分隔符
|
||||
// 但要注意特殊情况:某些token之间不应该加空格
|
||||
if expr.Len() > 0 {
|
||||
expr.WriteString(" ")
|
||||
shouldAddSpace := true
|
||||
|
||||
// 获取前一个token的信息
|
||||
exprStr := expr.String()
|
||||
lastChar := exprStr[len(exprStr)-1:]
|
||||
|
||||
// 以下情况不添加空格:
|
||||
// 1. 函数名和左括号之间
|
||||
// 2. 标识符和数字之间(如 x1, y1)
|
||||
// 3. 数字和标识符之间
|
||||
// 4. 左括号之后
|
||||
// 5. 右括号之前
|
||||
if currentToken.Type == TokenLParen && lastChar != " " && lastChar != "(" {
|
||||
// 函数名和左括号之间不加空格
|
||||
shouldAddSpace = false
|
||||
} else if lastChar == "(" || currentToken.Type == TokenRParen {
|
||||
// 左括号之后或右括号之前不加空格
|
||||
shouldAddSpace = false
|
||||
} else if len(exprStr) > 0 && currentToken.Type == TokenNumber {
|
||||
// 检查前一个字符是否是字母(标识符的一部分),且前面没有空格
|
||||
// 这主要处理 x1, y1 这类标识符,但排除 THEN 1, ELSE 0 这类情况
|
||||
if ((lastChar[0] >= 'a' && lastChar[0] <= 'z') || (lastChar[0] >= 'A' && lastChar[0] <= 'Z') || lastChar[0] == '_') &&
|
||||
!strings.HasSuffix(exprStr, " ") {
|
||||
// 进一步检查:如果前面是SQL关键字,则应该加空格
|
||||
words := strings.Fields(exprStr)
|
||||
if len(words) > 0 {
|
||||
lastWord := strings.ToUpper(words[len(words)-1])
|
||||
// 如果是关键字,应该加空格
|
||||
if isKeyword(lastWord) {
|
||||
shouldAddSpace = true
|
||||
} else {
|
||||
shouldAddSpace = false
|
||||
}
|
||||
} else {
|
||||
shouldAddSpace = false
|
||||
}
|
||||
}
|
||||
} else if len(exprStr) > 0 && currentToken.Type == TokenIdent {
|
||||
// 检查前一个字符是否是数字,且前面没有空格
|
||||
if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") {
|
||||
shouldAddSpace = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldAddSpace {
|
||||
expr.WriteString(" ")
|
||||
}
|
||||
}
|
||||
expr.WriteString(currentToken.Value)
|
||||
currentToken = p.lexer.NextToken()
|
||||
|
@ -3,12 +3,13 @@ package streamsql
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
"math"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/expr"
|
||||
"github.com/rulego/streamsql/functions"
|
||||
@ -43,9 +44,7 @@ func TestCustomMathFunctions(t *testing.T) {
|
||||
func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) {
|
||||
x1 := cast.ToFloat64(args[0])
|
||||
y1 := cast.ToFloat64(args[1])
|
||||
|
||||
x2 := cast.ToFloat64(args[2])
|
||||
|
||||
y2 := cast.ToFloat64(args[3])
|
||||
|
||||
distance := math.Sqrt(math.Pow(x2-x1, 2) + math.Pow(y2-y1, 2))
|
||||
|
Reference in New Issue
Block a user