From 459a6ba3ca2f886f758734651f44ec00f59c42fb Mon Sep 17 00:00:00 2001 From: dexter Date: Thu, 17 Apr 2025 21:04:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E4=B8=8D=E6=8C=87?= =?UTF-8?q?=E5=AE=9AGroup=20By=E5=AD=90=E5=8F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aggregator/group_aggregator.go | 14 ++++---- rsql/ast.go | 13 +++++++- rsql/lexer.go | 24 ++++++++++++++ rsql/parser.go | 17 +++++++--- rsql/parser_test.go | 31 ++++++++++++++++-- streamsql_test.go | 58 ++++++++++++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 15 deletions(-) diff --git a/aggregator/group_aggregator.go b/aggregator/group_aggregator.go index 0402541..9c81287 100644 --- a/aggregator/group_aggregator.go +++ b/aggregator/group_aggregator.go @@ -95,12 +95,14 @@ func (ga *GroupAggregator) Add(data interface{}) error { } } - if key == "" { - return fmt.Errorf("key cannot be empty") - } - - // 去除最后的 | 符号 - key = key[:len(key)-1] + /** + sql中没有'Group By'时,key为空串 + // if key == "" { + // return fmt.Errorf("key cannot be empty") + // } + // // 去除最后的 | 符号 + // key = key[:len(key)-1] + */ if _, exists := ga.groups[key]; !exists { ga.groups[key] = make(map[string]AggregatorFunction) diff --git a/rsql/ast.go b/rsql/ast.go index f9da831..a730535 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -123,7 +123,18 @@ func extractAggField(expr string) string { start := strings.Index(expr, "(") end := strings.LastIndex(expr, ")") if start >= 0 && end > start { - return strings.TrimSpace(expr[start+1 : end]) + // 提取括号内的内容 + fieldExpr := strings.TrimSpace(expr[start+1 : end]) + + // TODO 后期需完善函数内的运算表达式解析 + // 如果包含运算符,提取第一个操作数作为字段名,形如 temperature/10 的表达式,应解析出字段temperature + for _, op := range []string{"/", "*", "+", "-"} { + if opIndex := strings.Index(fieldExpr, op); opIndex > 0 { + return strings.TrimSpace(fieldExpr[:opIndex]) + } + } + + return fieldExpr } return "" } diff --git a/rsql/lexer.go b/rsql/lexer.go index 299b6aa..dadca59 100644 --- a/rsql/lexer.go +++ b/rsql/lexer.go @@ -37,6 +37,7 @@ const ( TokenWITH TokenTimestamp TokenTimeUnit + TokenOrder ) type Token struct { @@ -154,6 +155,27 @@ func (l *Lexer) readIdentifier() string { return l.input[pos:l.pos] } +func (l *Lexer) readPreviousIdentifier() string { + // 保存当前位置 + endPos := l.pos + + // 向前移动直到找到非字母字符或到达输入开始 + startPos := endPos - 1 + for startPos >= 0 && isLetter(l.input[startPos]) { + startPos-- + } + + // 调整到第一个字母字符的位置 + startPos++ + + // 如果找到有效的标识符,返回它 + if startPos < endPos { + return l.input[startPos:endPos] + } + + return "" +} + func (l *Lexer) readNumber() string { pos := l.pos for isDigit(l.ch) || l.ch == '.' { @@ -213,6 +235,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenTimestamp, Value: ident} case "TIMEUNIT": return Token{Type: TokenTimeUnit, Value: ident} + case "ORDER": + return Token{Type: TokenOrder, Value: ident} default: return Token{Type: TokenIdent, Value: ident} } diff --git a/rsql/parser.go b/rsql/parser.go index 01a451c..7d3967f 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -82,7 +82,8 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { } for { tok := p.lexer.NextToken() - if tok.Type == TokenGROUP || tok.Type == TokenEOF { + if tok.Type == TokenGROUP || tok.Type == TokenEOF || tok.Type == TokenSliding || + tok.Type == TokenTumbling || tok.Type == TokenCounting || tok.Type == TokenSession { break } switch tok.Type { @@ -172,19 +173,25 @@ func (p *Parser) parseFrom(stmt *SelectStatement) error { } func (p *Parser) parseGroupBy(stmt *SelectStatement) error { - //p.lexer.NextToken() // 跳过GROUP - p.lexer.NextToken() // 跳过BY + tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()) + if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { + p.parseWindowFunction(stmt, tok.Value) + } + if tok.Type == TokenGROUP { + p.lexer.NextToken() // 跳过BY + } for { tok := p.lexer.NextToken() - if tok.Type == TokenEOF { + if tok.Type == TokenWITH || tok.Type == TokenOrder || tok.Type == TokenEOF { break } if tok.Type == TokenComma { continue } if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { - return p.parseWindowFunction(stmt, tok.Value) + p.parseWindowFunction(stmt, tok.Value) + continue } stmt.GroupBy = append(stmt.GroupBy, tok.Value) diff --git a/rsql/parser_test.go b/rsql/parser_test.go index b5f9de2..cb46c9d 100644 --- a/rsql/parser_test.go +++ b/rsql/parser_test.go @@ -27,7 +27,10 @@ func TestParseSQL(t *testing.T) { }, GroupFields: []string{"deviceId"}, SelectFields: map[string]aggregator.AggregateType{ - "aa": "avg", + "temperature": "avg", + }, + FieldAlias: map[string]string{ + "temperature": "aa", }, }, condition: "deviceId == 'aa'", @@ -51,7 +54,7 @@ func TestParseSQL(t *testing.T) { condition: "", }, { - sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ", + sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by TumblingWindow('10s'), deviceId with (TIMESTAMP='ts') ", expected: &model.Config{ WindowConfig: model.WindowConfig{ Type: "tumbling", @@ -62,11 +65,33 @@ func TestParseSQL(t *testing.T) { }, GroupFields: []string{"deviceId"}, SelectFields: map[string]aggregator.AggregateType{ - "aa": "avg", + "temperature": "avg", + }, + FieldAlias: map[string]string{ + "temperature": "aa", }, }, condition: "deviceId == 'aa'", }, + { + sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' and temperature>0 TumblingWindow('10s') with (TIMESTAMP='ts') ", + expected: &model.Config{ + WindowConfig: model.WindowConfig{ + Type: "tumbling", + Params: map[string]interface{}{ + "size": 10 * time.Second, + }, + TsProp: "ts", + }, + SelectFields: map[string]aggregator.AggregateType{ + "temperature": "avg", + }, + FieldAlias: map[string]string{ + "temperature": "aa", + }, + }, + condition: "deviceId == 'aa' && temperature > 0", + }, } for _, tt := range tests { diff --git a/streamsql_test.go b/streamsql_test.go index c96afe1..b83eeb2 100644 --- a/streamsql_test.go +++ b/streamsql_test.go @@ -145,3 +145,61 @@ func TestStreamsql(t *testing.T) { assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"])) } } + +func TestStreamsqlWithoutGroupBy(t *testing.T) { + streamsql := New() + var rsql = "SELECT max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + testData := []interface{}{ + map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime}, + map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)}, + map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime}, + } + + for _, data := range testData { + strm.AddData(data) + } + // 捕获结果 + resultChan := make(chan interface{}) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("Timeout waiting for results") + } + + expected := []map[string]interface{}{ + { + "max_age": 10.0, + "min_score": 100.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + } + + assert.IsType(t, []map[string]interface{}{}, actual) + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok) + assert.Len(t, resultSlice, 1) + for _, expectedResult := range expected { + //found := false + for _, resultMap := range resultSlice { + assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001) + assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64)) + assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64)) + } + //assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"])) + } +}