diff --git a/rsql/ast.go b/rsql/ast.go index e2df0cf..99cffa4 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -2,6 +2,7 @@ package rsql import ( "fmt" + "github.com/rulego/streamsql/window" "strings" "time" @@ -24,7 +25,7 @@ type Field struct { type WindowDefinition struct { Type string - Params map[string]interface{} + Params []interface{} } // ToStreamConfig 将AST转换为Stream配置 @@ -33,9 +34,15 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) { return nil, "", fmt.Errorf("missing FROM clause") } // 解析窗口配置 - windowType := "tumbling" - if s.Window.Type == "Sliding" { - windowType = "sliding" + windowType := window.TypeTumbling + if strings.ToUpper(s.Window.Type) == "TUMBLINGWINDOW" { + windowType = window.TypeTumbling + } else if strings.ToUpper(s.Window.Type) == "SLIDINGWINDOW" { + windowType = window.TypeSliding + } else if strings.ToUpper(s.Window.Type) == "COUNTINGWINDOW" { + windowType = window.TypeCounting + } else if strings.ToUpper(s.Window.Type) == "SESSIONWINDOW" { + windowType = window.TypeSession } params, err := parseWindowParams(s.Window.Params) @@ -92,23 +99,25 @@ func parseAggregateType(expr string) aggregator.AggregateType { return "" } -func parseWindowParams(params map[string]interface{}) (map[string]interface{}, error) { +func parseWindowParams(params []interface{}) (map[string]interface{}, error) { result := make(map[string]interface{}) - - for k, v := range params { - switch k { - case "size", "slide": - if s, ok := v.(string); ok { - dur, err := time.ParseDuration(s) - if err != nil { - return nil, fmt.Errorf("invalid %s duration: %w", k, err) - } - result[k] = dur - } else { - return nil, fmt.Errorf("%s参数必须为字符串格式(如'5s')", k) + var key string + for index, v := range params { + if index == 0 { + key = "size" + } else if index == 1 { + key = "slide" + } else { + key = "offset" + } + if s, ok := v.(string); ok { + dur, err := time.ParseDuration(s) + if err != nil { + return nil, fmt.Errorf("invalid %s duration: %w", s, err) } - default: - result[k] = v + result[key] = dur + } else { + return nil, fmt.Errorf("%s参数必须为字符串格式(如'5s')", s) } } diff --git a/rsql/lexer.go b/rsql/lexer.go index a384bd9..7ade707 100644 --- a/rsql/lexer.go +++ b/rsql/lexer.go @@ -12,6 +12,10 @@ const ( TokenComma TokenLParen TokenRParen + TokenPlus + TokenMinus + TokenAsterisk + TokenSlash TokenEQ TokenNE TokenGT @@ -28,6 +32,8 @@ const ( TokenAS TokenTumbling TokenSliding + TokenCounting + TokenSession ) type Token struct { @@ -64,6 +70,18 @@ func (l *Lexer) NextToken() Token { case ')': l.readChar() return Token{Type: TokenRParen, Value: ")"} + case '+': + l.readChar() + return Token{Type: TokenPlus, Value: "+"} + case '-': + l.readChar() + return Token{Type: TokenMinus, Value: "-"} + case '*': + l.readChar() + return Token{Type: TokenAsterisk, Value: "*"} + case '/': + l.readChar() + return Token{Type: TokenSlash, Value: "/"} case '=': l.readChar() return Token{Type: TokenEQ, Value: "="} @@ -178,10 +196,14 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenOR, Value: ident} case "AND": return Token{Type: TokenAND, Value: ident} - case "TUMBLING": + case "TUMBLINGWINDOW": return Token{Type: TokenTumbling, Value: ident} - case "SLIDING": + case "SLIDINGWINDOW": return Token{Type: TokenSliding, Value: ident} + case "COUNTINGWINDOW": + return Token{Type: TokenCounting, Value: ident} + case "SESSIONWINDOW": + return Token{Type: TokenSession, Value: ident} default: return Token{Type: TokenIdent, Value: ident} } diff --git a/rsql/parser.go b/rsql/parser.go index eb45993..c3696e3 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -2,7 +2,6 @@ package rsql import ( "errors" - "fmt" "strconv" "strings" ) @@ -42,49 +41,50 @@ func (p *Parser) Parse() (*SelectStatement, error) { return stmt, nil } - func (p *Parser) parseSelect(stmt *SelectStatement) error { p.lexer.NextToken() // 跳过SELECT - + currentToken := p.lexer.NextToken() for { - tok := p.lexer.NextToken() - if tok.Type == TokenFROM { - break + var expr strings.Builder + for { + if currentToken.Type == TokenFROM || currentToken.Type == TokenComma || currentToken.Type == TokenAS { + break + } + expr.WriteString(currentToken.Value) + currentToken = p.lexer.NextToken() } - field := Field{Expression: tok.Value} - if p.lexer.peekChar() == ' ' { - if aliasTok := p.lexer.NextToken(); aliasTok.Type == TokenAS { - field.Alias = p.lexer.NextToken().Value - } + field := Field{Expression: strings.TrimSpace(expr.String())} + + // 处理别名 + if currentToken.Type == TokenAS { + field.Alias = p.lexer.NextToken().Value } stmt.Fields = append(stmt.Fields, field) - - if p.lexer.NextToken().Type != TokenComma { + currentToken = p.lexer.NextToken() + if currentToken.Type == TokenFROM { break } } return nil } -func (p *Parser) parseFrom(stmt *SelectStatement) error { - tok := p.lexer.NextToken() - if tok.Type != TokenIdent { - return errors.New("expected source identifier after FROM") - } - stmt.Source = tok.Value - return nil -} - func (p *Parser) parseWhere(stmt *SelectStatement) error { var conditions []string - p.lexer.NextToken() // 跳过WHERE - + current := p.lexer.NextToken() // 跳过WHERE + if current.Type != TokenWHERE { + return nil + } for { tok := p.lexer.NextToken() + if tok.Type == TokenGROUP || tok.Type == TokenEOF { + break + } switch tok.Type { - case TokenIdent, TokenNumber, TokenString: + case TokenIdent, TokenNumber: conditions = append(conditions, tok.Value) + case TokenString: + conditions = append(conditions, "'"+tok.Value+"'") case TokenEQ: conditions = append(conditions, "==") case TokenAND: @@ -92,43 +92,37 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { case TokenOR: conditions = append(conditions, "||") default: - stmt.Condition = strings.Join(conditions, " ") - return nil - } - } -} - -func (p *Parser) parseGroupBy(stmt *SelectStatement) error { - p.lexer.NextToken() // 跳过GROUP - p.lexer.NextToken() // 跳过BY - - for { - tok := p.lexer.NextToken() - if tok.Type == TokenTumbling || tok.Type == TokenSliding { - return p.parseWindowFunction(stmt, tok.Value) - } - - stmt.GroupBy = append(stmt.GroupBy, tok.Value) - - if p.lexer.NextToken().Type != TokenComma { - break + // 处理字符串值的引号 + if len(conditions) > 0 && conditions[len(conditions)-1] == "'" { + conditions[len(conditions)-1] = conditions[len(conditions)-1] + tok.Value + } else { + conditions = append(conditions, tok.Value) + } } + } + stmt.Condition = strings.Join(conditions, " ") return nil } func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) error { - p.lexer.NextToken() // 跳过函数名 - params := make(map[string]interface{}) + p.lexer.NextToken() // 跳过( + var params []interface{} for p.lexer.peekChar() != ')' { - keyTok := p.lexer.NextToken() - if keyTok.Type != TokenIdent { - return fmt.Errorf("expected parameter key, got %v", keyTok) - } - valTok := p.lexer.NextToken() - params[keyTok.Value] = convertValue(valTok.Value) + if valTok.Type == TokenRParen || valTok.Type == TokenEOF { + break + } + if valTok.Type == TokenComma { + continue + } + //valTok := p.lexer.NextToken() + // 处理引号包裹的值 + if strings.HasPrefix(valTok.Value, "'") && strings.HasSuffix(valTok.Value, "'") { + valTok.Value = strings.Trim(valTok.Value, "'") + } + params = append(params, convertValue(valTok.Value)) } stmt.Window = WindowDefinition{ @@ -136,8 +130,8 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro Params: params, } return nil - } + func convertValue(s string) interface{} { if s == "true" { return true @@ -151,8 +145,43 @@ func convertValue(s string) interface{} { if f, err := strconv.ParseFloat(s, 64); err == nil { return f } + // 处理引号包裹的字符串 if strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'") { return strings.Trim(s, "'") } return s } + +func (p *Parser) parseFrom(stmt *SelectStatement) error { + tok := p.lexer.NextToken() + if tok.Type != TokenIdent { + return errors.New("expected source identifier after FROM") + } + stmt.Source = tok.Value + return nil +} + +func (p *Parser) parseGroupBy(stmt *SelectStatement) error { + //p.lexer.NextToken() // 跳过GROUP + p.lexer.NextToken() // 跳过BY + + for { + tok := p.lexer.NextToken() + if tok.Type == TokenEOF { + break + } + if tok.Type == TokenComma { + continue + } + if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { + return p.parseWindowFunction(stmt, tok.Value) + } + + stmt.GroupBy = append(stmt.GroupBy, tok.Value) + + //if p.lexer.NextToken().Type != TokenComma { + // break + //} + } + return nil +} diff --git a/rsql/parser_test.go b/rsql/parser_test.go index cc57b2c..2027545 100644 --- a/rsql/parser_test.go +++ b/rsql/parser_test.go @@ -16,7 +16,7 @@ func TestParseSQL(t *testing.T) { condition string }{ { - sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow(size='10s')", + sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')", expected: &stream.Config{ WindowConfig: stream.WindowConfig{ Type: "tumbling", @@ -32,7 +32,7 @@ func TestParseSQL(t *testing.T) { condition: "deviceId == 'aa'", }, { - sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow(size='20s', slide='5s')", + sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')", expected: &stream.Config{ WindowConfig: stream.WindowConfig{ Type: "sliding", @@ -67,7 +67,7 @@ func TestParseSQL(t *testing.T) { } } func TestWindowParamParsing(t *testing.T) { - params := map[string]interface{}{"size": "10s", "slide": "5s"} + params := []interface{}{"10s", "5s"} result, err := parseWindowParams(params) assert.NoError(t, err) assert.Equal(t, 10*time.Second, result["size"]) diff --git a/window/factory.go b/window/factory.go index 274a4c0..d58cce9 100644 --- a/window/factory.go +++ b/window/factory.go @@ -5,6 +5,13 @@ import ( "github.com/spf13/cast" ) +const ( + TypeTumbling = "tumbling" + TypeSliding = "sliding" + TypeCounting = "counting" + TypeSession = "session" +) + type Window interface { Add(item interface{}) GetResults() []interface{} @@ -17,13 +24,13 @@ type Window interface { func CreateWindow(windowType string, params map[string]interface{}) (Window, error) { switch windowType { - case "tumbling": + case TypeTumbling: size, err := cast.ToDurationE(params["size"]) if err != nil { return nil, fmt.Errorf("invalid size for tumbling window: %v", err) } return NewTumblingWindow(size), nil - case "sliding": + case TypeSliding: size, err := cast.ToDurationE(params["size"]) if err != nil { return nil, fmt.Errorf("invalid size for sliding window: %v", err) @@ -33,7 +40,7 @@ func CreateWindow(windowType string, params map[string]interface{}) (Window, err return nil, fmt.Errorf("invalid slide for sliding window: %v", err) } return NewSlidingWindow(size, slide), nil - case "counting": + case TypeCounting: count := cast.ToInt(params["count"]) if count <= 0 { return nil, fmt.Errorf("count must be a positive integer")