mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-03-14 22:37:20 +00:00
feat: 支持不指定Group By子句
This commit is contained in:
@@ -95,12 +95,14 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
return fmt.Errorf("key cannot be empty")
|
||||
}
|
||||
|
||||
// 去除最后的 | 符号
|
||||
key = key[:len(key)-1]
|
||||
/**
|
||||
sql中没有'Group By'时,key为空串
|
||||
// if key == "" {
|
||||
// return fmt.Errorf("key cannot be empty")
|
||||
// }
|
||||
// // 去除最后的 | 符号
|
||||
// key = key[:len(key)-1]
|
||||
*/
|
||||
|
||||
if _, exists := ga.groups[key]; !exists {
|
||||
ga.groups[key] = make(map[string]AggregatorFunction)
|
||||
|
||||
+12
-1
@@ -123,7 +123,18 @@ func extractAggField(expr string) string {
|
||||
start := strings.Index(expr, "(")
|
||||
end := strings.LastIndex(expr, ")")
|
||||
if start >= 0 && end > start {
|
||||
return strings.TrimSpace(expr[start+1 : end])
|
||||
// 提取括号内的内容
|
||||
fieldExpr := strings.TrimSpace(expr[start+1 : end])
|
||||
|
||||
// TODO 后期需完善函数内的运算表达式解析
|
||||
// 如果包含运算符,提取第一个操作数作为字段名,形如 temperature/10 的表达式,应解析出字段temperature
|
||||
for _, op := range []string{"/", "*", "+", "-"} {
|
||||
if opIndex := strings.Index(fieldExpr, op); opIndex > 0 {
|
||||
return strings.TrimSpace(fieldExpr[:opIndex])
|
||||
}
|
||||
}
|
||||
|
||||
return fieldExpr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ const (
|
||||
TokenWITH
|
||||
TokenTimestamp
|
||||
TokenTimeUnit
|
||||
TokenOrder
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
@@ -154,6 +155,27 @@ func (l *Lexer) readIdentifier() string {
|
||||
return l.input[pos:l.pos]
|
||||
}
|
||||
|
||||
func (l *Lexer) readPreviousIdentifier() string {
|
||||
// 保存当前位置
|
||||
endPos := l.pos
|
||||
|
||||
// 向前移动直到找到非字母字符或到达输入开始
|
||||
startPos := endPos - 1
|
||||
for startPos >= 0 && isLetter(l.input[startPos]) {
|
||||
startPos--
|
||||
}
|
||||
|
||||
// 调整到第一个字母字符的位置
|
||||
startPos++
|
||||
|
||||
// 如果找到有效的标识符,返回它
|
||||
if startPos < endPos {
|
||||
return l.input[startPos:endPos]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (l *Lexer) readNumber() string {
|
||||
pos := l.pos
|
||||
for isDigit(l.ch) || l.ch == '.' {
|
||||
@@ -213,6 +235,8 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenTimestamp, Value: ident}
|
||||
case "TIMEUNIT":
|
||||
return Token{Type: TokenTimeUnit, Value: ident}
|
||||
case "ORDER":
|
||||
return Token{Type: TokenOrder, Value: ident}
|
||||
default:
|
||||
return Token{Type: TokenIdent, Value: ident}
|
||||
}
|
||||
|
||||
+12
-5
@@ -82,7 +82,8 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
|
||||
}
|
||||
for {
|
||||
tok := p.lexer.NextToken()
|
||||
if tok.Type == TokenGROUP || tok.Type == TokenEOF {
|
||||
if tok.Type == TokenGROUP || tok.Type == TokenEOF || tok.Type == TokenSliding ||
|
||||
tok.Type == TokenTumbling || tok.Type == TokenCounting || tok.Type == TokenSession {
|
||||
break
|
||||
}
|
||||
switch tok.Type {
|
||||
@@ -172,19 +173,25 @@ func (p *Parser) parseFrom(stmt *SelectStatement) error {
|
||||
}
|
||||
|
||||
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
|
||||
//p.lexer.NextToken() // 跳过GROUP
|
||||
p.lexer.NextToken() // 跳过BY
|
||||
tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
|
||||
if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
|
||||
p.parseWindowFunction(stmt, tok.Value)
|
||||
}
|
||||
if tok.Type == TokenGROUP {
|
||||
p.lexer.NextToken() // 跳过BY
|
||||
}
|
||||
|
||||
for {
|
||||
tok := p.lexer.NextToken()
|
||||
if tok.Type == TokenEOF {
|
||||
if tok.Type == TokenWITH || tok.Type == TokenOrder || tok.Type == TokenEOF {
|
||||
break
|
||||
}
|
||||
if tok.Type == TokenComma {
|
||||
continue
|
||||
}
|
||||
if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
|
||||
return p.parseWindowFunction(stmt, tok.Value)
|
||||
p.parseWindowFunction(stmt, tok.Value)
|
||||
continue
|
||||
}
|
||||
|
||||
stmt.GroupBy = append(stmt.GroupBy, tok.Value)
|
||||
|
||||
+28
-3
@@ -27,7 +27,10 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
GroupFields: []string{"deviceId"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"aa": "avg",
|
||||
"temperature": "avg",
|
||||
},
|
||||
FieldAlias: map[string]string{
|
||||
"temperature": "aa",
|
||||
},
|
||||
},
|
||||
condition: "deviceId == 'aa'",
|
||||
@@ -51,7 +54,7 @@ func TestParseSQL(t *testing.T) {
|
||||
condition: "",
|
||||
},
|
||||
{
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ",
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by TumblingWindow('10s'), deviceId with (TIMESTAMP='ts') ",
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
@@ -62,11 +65,33 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
GroupFields: []string{"deviceId"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"aa": "avg",
|
||||
"temperature": "avg",
|
||||
},
|
||||
FieldAlias: map[string]string{
|
||||
"temperature": "aa",
|
||||
},
|
||||
},
|
||||
condition: "deviceId == 'aa'",
|
||||
},
|
||||
{
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' and temperature>0 TumblingWindow('10s') with (TIMESTAMP='ts') ",
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{
|
||||
"size": 10 * time.Second,
|
||||
},
|
||||
TsProp: "ts",
|
||||
},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"temperature": "avg",
|
||||
},
|
||||
FieldAlias: map[string]string{
|
||||
"temperature": "aa",
|
||||
},
|
||||
},
|
||||
condition: "deviceId == 'aa' && temperature > 0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -145,3 +145,61 @@ func TestStreamsql(t *testing.T) {
|
||||
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamsqlWithoutGroupBy(t *testing.T) {
|
||||
streamsql := New()
|
||||
var rsql = "SELECT max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
|
||||
testData := []interface{}{
|
||||
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime},
|
||||
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)},
|
||||
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime},
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
}
|
||||
// 捕获结果
|
||||
resultChan := make(chan interface{})
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var actual interface{}
|
||||
select {
|
||||
case actual = <-resultChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Timeout waiting for results")
|
||||
}
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"max_age": 10.0,
|
||||
"min_score": 100.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
}
|
||||
|
||||
assert.IsType(t, []map[string]interface{}{}, actual)
|
||||
resultSlice, ok := actual.([]map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, resultSlice, 1)
|
||||
for _, expectedResult := range expected {
|
||||
//found := false
|
||||
for _, resultMap := range resultSlice {
|
||||
assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001)
|
||||
assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001)
|
||||
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
|
||||
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
|
||||
}
|
||||
//assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user