feat: 支持不指定Group By子句

This commit is contained in:
dexter
2025-04-17 21:04:11 +08:00
parent 3eed15673e
commit 459a6ba3ca
6 changed files with 142 additions and 15 deletions
+8 -6
View File
@@ -95,12 +95,14 @@ func (ga *GroupAggregator) Add(data interface{}) error {
}
}
if key == "" {
return fmt.Errorf("key cannot be empty")
}
// 去除最后的 | 符号
key = key[:len(key)-1]
/**
sql中没有'Group By'时key为空串
// if key == "" {
// return fmt.Errorf("key cannot be empty")
// }
// // 去除最后的 | 符号
// key = key[:len(key)-1]
*/
if _, exists := ga.groups[key]; !exists {
ga.groups[key] = make(map[string]AggregatorFunction)
+12 -1
View File
@@ -123,7 +123,18 @@ func extractAggField(expr string) string {
start := strings.Index(expr, "(")
end := strings.LastIndex(expr, ")")
if start >= 0 && end > start {
return strings.TrimSpace(expr[start+1 : end])
// 提取括号内的内容
fieldExpr := strings.TrimSpace(expr[start+1 : end])
// TODO 后期需完善函数内的运算表达式解析
// 如果包含运算符,提取第一个操作数作为字段名,形如 temperature/10 的表达式应解析出字段temperature
for _, op := range []string{"/", "*", "+", "-"} {
if opIndex := strings.Index(fieldExpr, op); opIndex > 0 {
return strings.TrimSpace(fieldExpr[:opIndex])
}
}
return fieldExpr
}
return ""
}
+24
View File
@@ -37,6 +37,7 @@ const (
TokenWITH
TokenTimestamp
TokenTimeUnit
TokenOrder
)
type Token struct {
@@ -154,6 +155,27 @@ func (l *Lexer) readIdentifier() string {
return l.input[pos:l.pos]
}
// readPreviousIdentifier scans backwards from the lexer's current
// position and returns the contiguous run of letter characters that
// immediately precedes it. It returns "" when the character just before
// the cursor is not a letter, or when the cursor is at the start of the
// input. The lexer position itself is not modified.
func (l *Lexer) readPreviousIdentifier() string {
	end := l.pos
	start := end
	// Walk left while the preceding character is still a letter.
	for start > 0 && isLetter(l.input[start-1]) {
		start--
	}
	if start == end {
		// Nothing letter-like directly before the cursor.
		return ""
	}
	return l.input[start:end]
}
func (l *Lexer) readNumber() string {
pos := l.pos
for isDigit(l.ch) || l.ch == '.' {
@@ -213,6 +235,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenTimestamp, Value: ident}
case "TIMEUNIT":
return Token{Type: TokenTimeUnit, Value: ident}
case "ORDER":
return Token{Type: TokenOrder, Value: ident}
default:
return Token{Type: TokenIdent, Value: ident}
}
+12 -5
View File
@@ -82,7 +82,8 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
}
for {
tok := p.lexer.NextToken()
if tok.Type == TokenGROUP || tok.Type == TokenEOF {
if tok.Type == TokenGROUP || tok.Type == TokenEOF || tok.Type == TokenSliding ||
tok.Type == TokenTumbling || tok.Type == TokenCounting || tok.Type == TokenSession {
break
}
switch tok.Type {
@@ -172,19 +173,25 @@ func (p *Parser) parseFrom(stmt *SelectStatement) error {
}
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
//p.lexer.NextToken() // 跳过GROUP
p.lexer.NextToken() // 跳过BY
tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
p.parseWindowFunction(stmt, tok.Value)
}
if tok.Type == TokenGROUP {
p.lexer.NextToken() // 跳过BY
}
for {
tok := p.lexer.NextToken()
if tok.Type == TokenEOF {
if tok.Type == TokenWITH || tok.Type == TokenOrder || tok.Type == TokenEOF {
break
}
if tok.Type == TokenComma {
continue
}
if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
return p.parseWindowFunction(stmt, tok.Value)
p.parseWindowFunction(stmt, tok.Value)
continue
}
stmt.GroupBy = append(stmt.GroupBy, tok.Value)
+28 -3
View File
@@ -27,7 +27,10 @@ func TestParseSQL(t *testing.T) {
},
GroupFields: []string{"deviceId"},
SelectFields: map[string]aggregator.AggregateType{
"aa": "avg",
"temperature": "avg",
},
FieldAlias: map[string]string{
"temperature": "aa",
},
},
condition: "deviceId == 'aa'",
@@ -51,7 +54,7 @@ func TestParseSQL(t *testing.T) {
condition: "",
},
{
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ",
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by TumblingWindow('10s'), deviceId with (TIMESTAMP='ts') ",
expected: &model.Config{
WindowConfig: model.WindowConfig{
Type: "tumbling",
@@ -62,11 +65,33 @@ func TestParseSQL(t *testing.T) {
},
GroupFields: []string{"deviceId"},
SelectFields: map[string]aggregator.AggregateType{
"aa": "avg",
"temperature": "avg",
},
FieldAlias: map[string]string{
"temperature": "aa",
},
},
condition: "deviceId == 'aa'",
},
{
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' and temperature>0 TumblingWindow('10s') with (TIMESTAMP='ts') ",
expected: &model.Config{
WindowConfig: model.WindowConfig{
Type: "tumbling",
Params: map[string]interface{}{
"size": 10 * time.Second,
},
TsProp: "ts",
},
SelectFields: map[string]aggregator.AggregateType{
"temperature": "avg",
},
FieldAlias: map[string]string{
"temperature": "aa",
},
},
condition: "deviceId == 'aa' && temperature > 0",
},
}
for _, tt := range tests {
+58
View File
@@ -145,3 +145,61 @@ func TestStreamsql(t *testing.T) {
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
}
}
// TestStreamsqlWithoutGroupBy verifies that a query containing a window
// function but no GROUP BY clause still executes: all records fall into a
// single ungrouped bucket, producing one aggregate row per window.
func TestStreamsqlWithoutGroupBy(t *testing.T) {
	streamsql := New()
	// Sliding window of 2s advancing every 1s; event time is read from the
	// "Ts" field. NOTE(review): TIMEUNIT='ss' presumably means seconds —
	// confirm against the lexer's TIMEUNIT handling.
	var rsql = "SELECT max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')"
	err := streamsql.Execute(rsql)
	assert.Nil(t, err)
	strm := streamsql.stream
	baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
	// Three records within the first 2-second window; no grouping key is used.
	testData := []interface{}{
		map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime},
		map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)},
		map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime},
	}
	for _, data := range testData {
		strm.AddData(data)
	}
	// Capture emitted results. NOTE(review): the sink is registered after
	// AddData — this assumes the stream buffers results until a sink is
	// attached (or that emission is time-driven); verify against Stream.
	resultChan := make(chan interface{})
	strm.AddSink(func(result interface{}) {
		resultChan <- result
	})
	ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
	defer cancel()
	var actual interface{}
	// Wait for the first window result, or fail after the 4s deadline.
	select {
	case actual = <-resultChan:
		cancel()
	case <-ctx.Done():
		t.Fatal("Timeout waiting for results")
	}
	// With no GROUP BY, exactly one aggregate row is expected for the first
	// window [baseTime, baseTime+2s): max over {5, 10, 3} and min over
	// {100, 200, 300}.
	expected := []map[string]interface{}{
		{
			"max_age":   10.0,
			"min_score": 100.0,
			"start":     baseTime.UnixNano(),
			"end":       baseTime.Add(2 * time.Second).UnixNano(),
		},
	}
	assert.IsType(t, []map[string]interface{}{}, actual)
	resultSlice, ok := actual.([]map[string]interface{})
	require.True(t, ok)
	// A single ungrouped row — assert.Len guards the loop below.
	assert.Len(t, resultSlice, 1)
	for _, expectedResult := range expected {
		// Device-matching bookkeeping from the grouped variant of this test
		// is intentionally disabled: there is no per-device row here.
		//found := false
		for _, resultMap := range resultSlice {
			assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001)
			assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001)
			assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
			assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
		}
		//assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
	}
}