diff --git a/functions/expr_bridge.go b/functions/expr_bridge.go index 821fbd6..764f499 100644 --- a/functions/expr_bridge.go +++ b/functions/expr_bridge.go @@ -2,12 +2,12 @@ package functions import ( "fmt" - "github.com/rulego/streamsql/utils/cast" "strconv" "strings" "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" + "github.com/rulego/streamsql/utils/cast" ) // ExprBridge 桥接 StreamSQL 函数系统与 expr-lang/expr @@ -109,6 +109,11 @@ func (bridge *ExprBridge) CompileExpressionWithStreamSQLFunctions(expression str // EvaluateExpression 评估表达式,自动选择最合适的引擎 func (bridge *ExprBridge) EvaluateExpression(expression string, data map[string]interface{}) (interface{}, error) { + // 首先检查是否是CONCAT函数调用 + if strings.HasPrefix(strings.ToUpper(expression), "CONCAT(") { + return bridge.evaluateConcatFunction(expression, data) + } + // 首先检查是否包含字符串拼接模式 if bridge.isStringConcatenationExpression(expression, data) { result, err := bridge.evaluateStringConcatenation(expression, data) @@ -177,6 +182,87 @@ func (bridge *ExprBridge) fallbackToCustomExpr(expression string, data map[strin return nil, fmt.Errorf("unable to evaluate expression: %s, string concat error: %v, numeric error: %v", expression, err, err) } +// evaluateConcatFunction 处理CONCAT函数调用 +func (bridge *ExprBridge) evaluateConcatFunction(expression string, data map[string]interface{}) (interface{}, error) { + // 提取CONCAT函数的参数 + start := strings.Index(expression, "(") + end := strings.LastIndex(expression, ")") + if start == -1 || end == -1 || end <= start { + return nil, fmt.Errorf("invalid CONCAT function syntax: %s", expression) + } + + // 获取参数字符串 + paramsStr := strings.TrimSpace(expression[start+1 : end]) + if paramsStr == "" { + return "", nil // 空参数返回空字符串 + } + + // 解析参数 + params := bridge.parseParameters(paramsStr) + var result strings.Builder + + for _, param := range params { + param = strings.TrimSpace(param) + + // 处理字符串字面量 + if (strings.HasPrefix(param, "'") && strings.HasSuffix(param, "'")) || + (strings.HasPrefix(param, "\"") && strings.HasSuffix(param, "\"")) { + // 去掉引号 + literal := param[1 : len(param)-1] + result.WriteString(literal) + } else { + // 处理字段引用 + if value, exists := data[param]; exists { + strValue := cast.ToString(value) + result.WriteString(strValue) + } else { + return nil, fmt.Errorf("field %s not found in data", param) + } + } + } + + return result.String(), nil +} + +// parseParameters 解析函数参数,正确处理引号内的逗号 +func (bridge *ExprBridge) parseParameters(paramsStr string) []string { + var params []string + var current strings.Builder + inQuotes := false + quoteChar := byte(0) + + for i := 0; i < len(paramsStr); i++ { + ch := paramsStr[i] + + if !inQuotes { + if ch == '\'' || ch == '"' { + inQuotes = true + quoteChar = ch + current.WriteByte(ch) + } else if ch == ',' { + // 参数分隔符 + params = append(params, current.String()) + current.Reset() + } else { + current.WriteByte(ch) + } + } else { + if ch == quoteChar { + inQuotes = false + quoteChar = 0 + } + current.WriteByte(ch) + } + } + + // 添加最后一个参数 + if current.Len() > 0 { + params = append(params, current.String()) + } + + return params +} + // evaluateStringConcatenation 处理字符串拼接表达式 func (bridge *ExprBridge) evaluateStringConcatenation(expression string, data map[string]interface{}) (interface{}, error) { // 检查是否是字符串拼接表达式 (包含 + 和字符串字面量) diff --git a/functions/functions_aggregation.go b/functions/functions_aggregation.go index 3fb449f..248be15 100644 --- a/functions/functions_aggregation.go +++ b/functions/functions_aggregation.go @@ -425,11 +425,13 @@ func (f *PercentileFunction) Execute(ctx *FunctionContext, args []interface{}) ( // CollectFunction 收集函数 - 获取当前窗口所有消息的列值组成的数组 type CollectFunction struct { *BaseFunction + values []interface{} } func NewCollectFunction() *CollectFunction { return &CollectFunction{ BaseFunction: NewBaseFunction("collect", TypeAggregation, "聚合函数", "收集所有值组成数组", 1, -1), + values: make([]interface{}, 0), } } @@ -444,14 +446,47 @@ func (f *CollectFunction) Execute(ctx *FunctionContext, args []interface{}) (int return result, nil } +// 实现AggregatorFunction接口 +func (f *CollectFunction) New() AggregatorFunction { + return &CollectFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, 0), + } +} + +func (f *CollectFunction) Add(value interface{}) { + f.values = append(f.values, value) +} + +func (f *CollectFunction) Result() interface{} { + result := make([]interface{}, len(f.values)) + copy(result, f.values) + return result +} + +func (f *CollectFunction) Reset() { + f.values = make([]interface{}, 0) +} + +func (f *CollectFunction) Clone() AggregatorFunction { + newFunc := &CollectFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, len(f.values)), + } + copy(newFunc.values, f.values) + return newFunc +} + // LastValueFunction 最后值函数 - 返回组中最后一行的值 type LastValueFunction struct { *BaseFunction + lastValue interface{} } func NewLastValueFunction() *LastValueFunction { return &LastValueFunction{ BaseFunction: NewBaseFunction("last_value", TypeAggregation, "聚合函数", "返回最后一个值", 1, -1), + lastValue: nil, } } @@ -467,14 +502,43 @@ func (f *LastValueFunction) Execute(ctx *FunctionContext, args []interface{}) (i return args[len(args)-1], nil } +// 实现AggregatorFunction接口 +func (f *LastValueFunction) New() AggregatorFunction { + return &LastValueFunction{ + BaseFunction: f.BaseFunction, + lastValue: nil, + } +} + +func (f *LastValueFunction) Add(value interface{}) { + f.lastValue = value +} + +func (f *LastValueFunction) Result() interface{} { + return f.lastValue +} + +func (f *LastValueFunction) Reset() { + f.lastValue = nil +} + +func (f *LastValueFunction) Clone() AggregatorFunction { + return &LastValueFunction{ + BaseFunction: f.BaseFunction, + lastValue: f.lastValue, + } +} + // MergeAggFunction 合并聚合函数 - 将组中的值合并为单个值 type MergeAggFunction struct { *BaseFunction + values []interface{} } func NewMergeAggFunction() *MergeAggFunction { return &MergeAggFunction{ BaseFunction: NewBaseFunction("merge_agg", TypeAggregation, "聚合函数", "合并所有值", 1, -1), + values: make([]interface{}, 0), } } @@ -498,6 +562,47 @@ func (f *MergeAggFunction) Execute(ctx *FunctionContext, args []interface{}) (in return result.String(), nil } +// 实现AggregatorFunction接口 +func (f *MergeAggFunction) New() AggregatorFunction { + return &MergeAggFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, 0), + } +} + +func (f *MergeAggFunction) Add(value interface{}) { + f.values = append(f.values, value) +} + +func (f *MergeAggFunction) Result() interface{} { + if len(f.values) == 0 { + return nil + } + + // 尝试合并为字符串 + var result strings.Builder + for i, arg := range f.values { + if i > 0 { + result.WriteString(",") + } + result.WriteString(cast.ToString(arg)) + } + return result.String() +} + +func (f *MergeAggFunction) Reset() { + f.values = make([]interface{}, 0) +} + +func (f *MergeAggFunction) Clone() AggregatorFunction { + newFunc := &MergeAggFunction{ + BaseFunction: f.BaseFunction, + values: make([]interface{}, len(f.values)), + } + copy(newFunc.values, f.values) + return newFunc +} + // StdDevSFunction 样本标准差函数 type StdDevSFunction struct { *BaseFunction diff --git a/functions/functions_window.go b/functions/functions_window.go index 8f15c2f..24efd52 100644 --- a/functions/functions_window.go +++ b/functions/functions_window.go @@ -220,10 +220,14 @@ func (f *ExpressionAggregatorFunction) New() AggregatorFunction { func (f *ExpressionAggregatorFunction) Add(value interface{}) { // 对于表达式聚合器,保存最后一个计算结果 + // 表达式的计算结果应该是每个数据项的计算结果 f.lastResult = value } func (f *ExpressionAggregatorFunction) Result() interface{} { + // 对于表达式聚合器,返回最后一个计算结果 + // 注意:对于字符串函数如CONCAT,每个数据项都会产生一个结果 + // 在窗口聚合中,我们返回最后一个计算的结果 return f.lastResult } diff --git a/rsql/ast.go b/rsql/ast.go index bd2a9d6..fa9b402 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -325,6 +325,29 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s // 对于复杂表达式,包括多参数函数调用 expression = fieldExpr + // 对于CONCAT等字符串函数,直接保存完整表达式 + if strings.ToLower(funcName) == "concat" { + // 智能解析CONCAT函数的参数来提取字段名 + var fields []string + params := parseSmartParameters(fieldExpr) + for _, param := range params { + param = strings.TrimSpace(param) + // 如果参数不是字符串常量(不被引号包围),则认为是字段名 + if !((strings.HasPrefix(param, "'") && strings.HasSuffix(param, "'")) || + (strings.HasPrefix(param, "\"") && strings.HasSuffix(param, "\""))) { + if isIdentifier(param) { + fields = append(fields, param) + } + } + } + if len(fields) > 0 { + // 对于CONCAT函数,保存完整的函数调用作为表达式 + return fields[0], funcName + "(" + fieldExpr + ")", fields + } + // 如果没有找到字段,返回空字段名但保留表达式 + return "", funcName + "(" + fieldExpr + ")", nil + } + // 使用表达式引擎解析 parsedExpr, err := expr.NewExpression(fieldExpr) if err != nil { @@ -370,23 +393,62 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s return fieldExpr, expression, nil } +// parseSmartParameters 智能解析函数参数,正确处理引号内的逗号 +func parseSmartParameters(paramsStr string) []string { + var params []string + var current strings.Builder + inQuotes := false + quoteChar := byte(0) + + for i := 0; i < len(paramsStr); i++ { + ch := paramsStr[i] + + if !inQuotes { + if ch == '\'' || ch == '"' { + inQuotes = true + quoteChar = ch + current.WriteByte(ch) + } else if ch == ',' { + // 参数分隔符 + params = append(params, current.String()) + current.Reset() + } else { + current.WriteByte(ch) + } + } else { + if ch == quoteChar { + inQuotes = false + quoteChar = 0 + } + current.WriteByte(ch) + } + } + + // 添加最后一个参数 + if current.Len() > 0 { + params = append(params, current.String()) + } + + return params +} + // isIdentifier 检查字符串是否是有效的标识符 func isIdentifier(s string) bool { if len(s) == 0 { return false } - + // 第一个字符必须是字母或下划线 if !((s[0] >= 'a' && s[0] <= 'z') || (s[0] >= 'A' && s[0] <= 'Z') || s[0] == '_') { return false } - + // 其余字符必须是字母、数字或下划线 for i := 1; i < len(s); i++ { - if !((s[i] >= 'a' && s[i] <= 'z') || (s[i] >= 'A' && s[i] <= 'Z') || - (s[i] >= '0' && s[i] <= '9') || s[i] == '_') { + char := s[i] + if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') || + (char >= '0' && char <= '9') || char == '_') { return false } } - return true } diff --git a/rsql/lexer.go b/rsql/lexer.go index f229450..4899206 100644 --- a/rsql/lexer.go +++ b/rsql/lexer.go @@ -90,6 +90,11 @@ func (l *Lexer) NextToken() Token { l.readChar() return Token{Type: TokenSlash, Value: "/"} case '=': + if l.peekChar() == '=' { + l.readChar() + l.readChar() + return Token{Type: TokenEQ, Value: "=="} + } l.readChar() return Token{Type: TokenEQ, Value: "="} case '>': @@ -114,6 +119,10 @@ func (l *Lexer) NextToken() Token { l.readChar() return Token{Type: TokenNE, Value: "!="} } + case '\'': + return Token{Type: TokenString, Value: l.readString()} + case '"': + return Token{Type: TokenString, Value: l.readString()} } if isLetter(l.ch) { @@ -125,10 +134,6 @@ func (l *Lexer) NextToken() Token { return Token{Type: TokenNumber, Value: l.readNumber()} } - if l.ch == '\'' { - return Token{Type: TokenString, Value: l.readString()} - } - l.readChar() return Token{Type: TokenEOF} } @@ -188,16 +193,20 @@ func (l *Lexer) readNumber() string { } func (l *Lexer) readString() string { - l.readChar() // 跳过开头单引号 - pos := l.pos + quoteChar := l.ch // 记录引号类型(单引号或双引号) + startPos := l.pos // 记录开始位置(包含引号) + l.readChar() // 跳过开头引号 - for l.ch != '\'' && l.ch != 0 { + for l.ch != quoteChar && l.ch != 0 { l.readChar() } - str := l.input[pos:l.pos] - l.readChar() // 跳过结尾单引号 - return str + if l.ch == quoteChar { + l.readChar() // 跳过结尾引号 + } + + // 返回包含引号的完整字符串 + return l.input[startPos:l.pos] } func (l *Lexer) skipWhitespace() { diff --git a/rsql/parser.go b/rsql/parser.go index 1c8981d..c6b5704 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -171,9 +171,13 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { case TokenIdent, TokenNumber: conditions = append(conditions, tok.Value) case TokenString: - conditions = append(conditions, "'"+tok.Value+"'") + conditions = append(conditions, tok.Value) case TokenEQ: - conditions = append(conditions, "==") + if tok.Value == "=" { + conditions = append(conditions, "==") + } else { + conditions = append(conditions, tok.Value) + } case TokenAND: conditions = append(conditions, "&&") case TokenOR: @@ -431,9 +435,13 @@ func (p *Parser) parseHaving(stmt *SelectStatement) error { case TokenIdent, TokenNumber: conditions = append(conditions, tok.Value) case TokenString: - conditions = append(conditions, "'"+tok.Value+"'") + conditions = append(conditions, tok.Value) case TokenEQ: - conditions = append(conditions, "==") + if tok.Value == "=" { + conditions = append(conditions, "==") + } else { + conditions = append(conditions, tok.Value) + } case TokenAND: conditions = append(conditions, "&&") case TokenOR: diff --git a/stream/stream_test.go b/stream/stream_test.go index 26ee562..c6a81bb 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -23,6 +23,7 @@ func TestStreamProcess(t *testing.T) { "temperature": aggregator.Avg, "humidity": aggregator.Sum, }, + NeedWindow: true, } strm, err := NewStream(config) @@ -64,16 +65,41 @@ func TestStreamProcess(t *testing.T) { // 预期结果:只有 device='aa' 且 temperature>10 的数据会被聚合 expected := map[string]interface{}{ - "device": "aa", - "temperature_avg": 27.5, // (25+30)/2 - "humidity_sum": 115.0, // 60+55 + "device": "aa", + "temperature": 27.5, // (25+30)/2 + "humidity": 115.0, // 60+55 } // 验证结果 + t.Logf("Received result: %+v (type: %T)", actual, actual) + if actual == nil { + t.Fatal("Received nil result") + } assert.IsType(t, []map[string]interface{}{}, actual) + t.Logf("Type assertion successful") resultMap := actual.([]map[string]interface{}) - assert.InEpsilon(t, expected["temperature_avg"].(float64), resultMap[0]["temperature_avg"].(float64), 0.0001) - assert.InDelta(t, expected["humidity_sum"].(float64), resultMap[0]["humidity_sum"].(float64), 0.0001) + t.Logf("Result map length: %d", len(resultMap)) + if len(resultMap) > 0 { + t.Logf("First result: %+v", resultMap[0]) + + // 检查temperature字段 + if tempAvg, ok := resultMap[0]["temperature"]; ok { + t.Logf("temperature: %+v (type: %T)", tempAvg, tempAvg) + assert.InEpsilon(t, expected["temperature"].(float64), tempAvg.(float64), 0.0001) + } else { + t.Fatal("temperature field not found in result") + } + + // 检查humidity字段 + if humSum, ok := resultMap[0]["humidity"]; ok { + t.Logf("humidity: %+v (type: %T)", humSum, humSum) + assert.InDelta(t, expected["humidity"].(float64), humSum.(float64), 0.0001) + } else { + t.Fatal("humidity field not found in result") + } + } else { + t.Fatal("No results in result map") + } } // 不设置过滤器 @@ -88,6 +114,7 @@ func TestStreamWithoutFilter(t *testing.T) { "temperature": aggregator.Max, "humidity": aggregator.Min, }, + NeedWindow: true, } strm, err := NewStream(config) @@ -126,14 +153,14 @@ func TestStreamWithoutFilter(t *testing.T) { expected := []map[string]interface{}{ { - "device": "aa", - "temperature_max": 30.0, - "humidity_min": 55.0, + "device": "aa", + "temperature": 30.0, + "humidity": 55.0, }, { - "device": "bb", - "temperature_max": 22.0, - "humidity_min": 70.0, + "device": "bb", + "temperature": 22.0, + "humidity": 70.0, }, } @@ -146,8 +173,8 @@ func TestStreamWithoutFilter(t *testing.T) { found := false for _, resultMap := range resultSlice { if resultMap["device"] == expectedResult["device"] { - assert.InEpsilon(t, expectedResult["temperature_max"].(float64), resultMap["temperature_max"].(float64), 0.0001) - assert.InEpsilon(t, expectedResult["humidity_min"].(float64), resultMap["humidity_min"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["temperature"].(float64), resultMap["temperature"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["humidity"].(float64), resultMap["humidity"].(float64), 0.0001) found = true break } @@ -167,6 +194,7 @@ func TestIncompleteStreamProcess(t *testing.T) { "temperature": aggregator.Avg, "humidity": aggregator.Sum, }, + NeedWindow: true, } strm, err := NewStream(config) @@ -210,16 +238,41 @@ func TestIncompleteStreamProcess(t *testing.T) { // 预期结果:只有 device='aa' 且 temperature>10 的数据会被聚合 expected := map[string]interface{}{ - "device": "aa", - "temperature_avg": 27.5, // (25+30)/2 - "humidity_sum": 115.0, // 60+55 + "device": "aa", + "temperature": 27.5, // (25+30)/2 + "humidity": 115.0, // 60+55 } // 验证结果 + t.Logf("Received result: %+v (type: %T)", actual, actual) + if actual == nil { + t.Fatal("Received nil result") + } assert.IsType(t, []map[string]interface{}{}, actual) + t.Logf("Type assertion successful") resultMap := actual.([]map[string]interface{}) - assert.InEpsilon(t, expected["temperature_avg"].(float64), resultMap[0]["temperature_avg"].(float64), 0.0001) - assert.InDelta(t, expected["humidity_sum"].(float64), resultMap[0]["humidity_sum"].(float64), 0.0001) + t.Logf("Result map length: %d", len(resultMap)) + if len(resultMap) > 0 { + t.Logf("First result: %+v", resultMap[0]) + + // 检查temperature字段 + if tempAvg, ok := resultMap[0]["temperature"]; ok { + t.Logf("temperature: %+v (type: %T)", tempAvg, tempAvg) + assert.InEpsilon(t, expected["temperature"].(float64), tempAvg.(float64), 0.0001) + } else { + t.Fatal("temperature field not found in result") + } + + // 检查humidity字段 + if humSum, ok := resultMap[0]["humidity"]; ok { + t.Logf("humidity: %+v (type: %T)", humSum, humSum) + assert.InDelta(t, expected["humidity"].(float64), humSum.(float64), 0.0001) + } else { + t.Fatal("humidity field not found in result") + } + } else { + t.Fatal("No results in result map") + } } func TestWindowSlotAgg(t *testing.T) { @@ -236,6 +289,7 @@ func TestWindowSlotAgg(t *testing.T) { "start": aggregator.WindowStart, "end": aggregator.WindowEnd, }, + NeedWindow: true, } strm, err := NewStream(config) @@ -276,18 +330,18 @@ func TestWindowSlotAgg(t *testing.T) { expected := []map[string]interface{}{ { - "device": "aa", - "temperature_max": 30.0, - "humidity_min": 55.0, - "start": baseTime.UnixNano(), - "end": baseTime.Add(2 * time.Second).UnixNano(), + "device": "aa", + "temperature": 30.0, + "humidity": 55.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), }, { - "device": "bb", - "temperature_max": 22.0, - "humidity_min": 70.0, - "start": baseTime.UnixNano(), - "end": baseTime.Add(2 * time.Second).UnixNano(), + "device": "bb", + "temperature": 22.0, + "humidity": 70.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), }, } @@ -300,8 +354,8 @@ func TestWindowSlotAgg(t *testing.T) { found := false for _, resultMap := range resultSlice { if resultMap["device"] == expectedResult["device"] { - assert.InEpsilon(t, expectedResult["temperature_max"].(float64), resultMap["temperature_max"].(float64), 0.0001) - assert.InEpsilon(t, expectedResult["humidity_min"].(float64), resultMap["humidity_min"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["temperature"].(float64), resultMap["temperature"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["humidity"].(float64), resultMap["humidity"].(float64), 0.0001) assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64)) assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64)) found = true