feat:sql parser

This commit is contained in:
rulego-team
2025-03-13 15:28:24 +08:00
parent eb9a145ce3
commit a09aae6b75
5 changed files with 148 additions and 81 deletions

View File

@ -2,6 +2,7 @@ package rsql
import (
"fmt"
"github.com/rulego/streamsql/window"
"strings"
"time"
@ -24,7 +25,7 @@ type Field struct {
type WindowDefinition struct {
Type string
Params map[string]interface{}
Params []interface{}
}
// ToStreamConfig converts the AST into a Stream configuration.
@ -33,9 +34,15 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
return nil, "", fmt.Errorf("missing FROM clause")
}
// 解析窗口配置
windowType := "tumbling"
if s.Window.Type == "Sliding" {
windowType = "sliding"
windowType := window.TypeTumbling
if strings.ToUpper(s.Window.Type) == "TUMBLINGWINDOW" {
windowType = window.TypeTumbling
} else if strings.ToUpper(s.Window.Type) == "SLIDINGWINDOW" {
windowType = window.TypeSliding
} else if strings.ToUpper(s.Window.Type) == "COUNTINGWINDOW" {
windowType = window.TypeCounting
} else if strings.ToUpper(s.Window.Type) == "SESSIONWINDOW" {
windowType = window.TypeSession
}
params, err := parseWindowParams(s.Window.Params)
@ -92,23 +99,25 @@ func parseAggregateType(expr string) aggregator.AggregateType {
return ""
}
func parseWindowParams(params map[string]interface{}) (map[string]interface{}, error) {
func parseWindowParams(params []interface{}) (map[string]interface{}, error) {
result := make(map[string]interface{})
for k, v := range params {
switch k {
case "size", "slide":
if s, ok := v.(string); ok {
dur, err := time.ParseDuration(s)
if err != nil {
return nil, fmt.Errorf("invalid %s duration: %w", k, err)
}
result[k] = dur
} else {
return nil, fmt.Errorf("%s参数必须为字符串格式(如'5s')", k)
var key string
for index, v := range params {
if index == 0 {
key = "size"
} else if index == 1 {
key = "slide"
} else {
key = "offset"
}
if s, ok := v.(string); ok {
dur, err := time.ParseDuration(s)
if err != nil {
return nil, fmt.Errorf("invalid %s duration: %w", s, err)
}
default:
result[k] = v
result[key] = dur
} else {
return nil, fmt.Errorf("%s参数必须为字符串格式(如'5s')", s)
}
}

View File

@ -12,6 +12,10 @@ const (
TokenComma
TokenLParen
TokenRParen
TokenPlus
TokenMinus
TokenAsterisk
TokenSlash
TokenEQ
TokenNE
TokenGT
@ -28,6 +32,8 @@ const (
TokenAS
TokenTumbling
TokenSliding
TokenCounting
TokenSession
)
type Token struct {
@ -64,6 +70,18 @@ func (l *Lexer) NextToken() Token {
case ')':
l.readChar()
return Token{Type: TokenRParen, Value: ")"}
case '+':
l.readChar()
return Token{Type: TokenPlus, Value: "+"}
case '-':
l.readChar()
return Token{Type: TokenMinus, Value: "-"}
case '*':
l.readChar()
return Token{Type: TokenAsterisk, Value: "*"}
case '/':
l.readChar()
return Token{Type: TokenSlash, Value: "/"}
case '=':
l.readChar()
return Token{Type: TokenEQ, Value: "="}
@ -178,10 +196,14 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenOR, Value: ident}
case "AND":
return Token{Type: TokenAND, Value: ident}
case "TUMBLING":
case "TUMBLINGWINDOW":
return Token{Type: TokenTumbling, Value: ident}
case "SLIDING":
case "SLIDINGWINDOW":
return Token{Type: TokenSliding, Value: ident}
case "COUNTINGWINDOW":
return Token{Type: TokenCounting, Value: ident}
case "SESSIONWINDOW":
return Token{Type: TokenSession, Value: ident}
default:
return Token{Type: TokenIdent, Value: ident}
}

View File

@ -2,7 +2,6 @@ package rsql
import (
"errors"
"fmt"
"strconv"
"strings"
)
@ -42,49 +41,50 @@ func (p *Parser) Parse() (*SelectStatement, error) {
return stmt, nil
}
func (p *Parser) parseSelect(stmt *SelectStatement) error {
p.lexer.NextToken() // 跳过SELECT
currentToken := p.lexer.NextToken()
for {
tok := p.lexer.NextToken()
if tok.Type == TokenFROM {
break
var expr strings.Builder
for {
if currentToken.Type == TokenFROM || currentToken.Type == TokenComma || currentToken.Type == TokenAS {
break
}
expr.WriteString(currentToken.Value)
currentToken = p.lexer.NextToken()
}
field := Field{Expression: tok.Value}
if p.lexer.peekChar() == ' ' {
if aliasTok := p.lexer.NextToken(); aliasTok.Type == TokenAS {
field.Alias = p.lexer.NextToken().Value
}
field := Field{Expression: strings.TrimSpace(expr.String())}
// 处理别名
if currentToken.Type == TokenAS {
field.Alias = p.lexer.NextToken().Value
}
stmt.Fields = append(stmt.Fields, field)
if p.lexer.NextToken().Type != TokenComma {
currentToken = p.lexer.NextToken()
if currentToken.Type == TokenFROM {
break
}
}
return nil
}
// parseFrom pulls the next token from the lexer and stores its value as the
// statement's source. It returns an error when that token is not an
// identifier; stmt is left unmodified in that case.
func (p *Parser) parseFrom(stmt *SelectStatement) error {
	source := p.lexer.NextToken()
	switch source.Type {
	case TokenIdent:
		stmt.Source = source.Value
		return nil
	default:
		return errors.New("expected source identifier after FROM")
	}
}
func (p *Parser) parseWhere(stmt *SelectStatement) error {
var conditions []string
p.lexer.NextToken() // 跳过WHERE
current := p.lexer.NextToken() // 跳过WHERE
if current.Type != TokenWHERE {
return nil
}
for {
tok := p.lexer.NextToken()
if tok.Type == TokenGROUP || tok.Type == TokenEOF {
break
}
switch tok.Type {
case TokenIdent, TokenNumber, TokenString:
case TokenIdent, TokenNumber:
conditions = append(conditions, tok.Value)
case TokenString:
conditions = append(conditions, "'"+tok.Value+"'")
case TokenEQ:
conditions = append(conditions, "==")
case TokenAND:
@ -92,43 +92,37 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
case TokenOR:
conditions = append(conditions, "||")
default:
stmt.Condition = strings.Join(conditions, " ")
return nil
}
}
}
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
p.lexer.NextToken() // 跳过GROUP
p.lexer.NextToken() // 跳过BY
for {
tok := p.lexer.NextToken()
if tok.Type == TokenTumbling || tok.Type == TokenSliding {
return p.parseWindowFunction(stmt, tok.Value)
}
stmt.GroupBy = append(stmt.GroupBy, tok.Value)
if p.lexer.NextToken().Type != TokenComma {
break
// 处理字符串值的引号
if len(conditions) > 0 && conditions[len(conditions)-1] == "'" {
conditions[len(conditions)-1] = conditions[len(conditions)-1] + tok.Value
} else {
conditions = append(conditions, tok.Value)
}
}
}
stmt.Condition = strings.Join(conditions, " ")
return nil
}
func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) error {
p.lexer.NextToken() // 跳过函数名
params := make(map[string]interface{})
p.lexer.NextToken() // 跳过(
var params []interface{}
for p.lexer.peekChar() != ')' {
keyTok := p.lexer.NextToken()
if keyTok.Type != TokenIdent {
return fmt.Errorf("expected parameter key, got %v", keyTok)
}
valTok := p.lexer.NextToken()
params[keyTok.Value] = convertValue(valTok.Value)
if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
break
}
if valTok.Type == TokenComma {
continue
}
//valTok := p.lexer.NextToken()
// 处理引号包裹的值
if strings.HasPrefix(valTok.Value, "'") && strings.HasSuffix(valTok.Value, "'") {
valTok.Value = strings.Trim(valTok.Value, "'")
}
params = append(params, convertValue(valTok.Value))
}
stmt.Window = WindowDefinition{
@ -136,8 +130,8 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro
Params: params,
}
return nil
}
func convertValue(s string) interface{} {
if s == "true" {
return true
@ -151,8 +145,43 @@ func convertValue(s string) interface{} {
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f
}
// 处理引号包裹的字符串
if strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'") {
return strings.Trim(s, "'")
}
return s
}
// parseFrom pulls the next token from the lexer and stores its value as the
// statement's source. It returns an error when that token is not an
// identifier; stmt is left unmodified in that case.
func (p *Parser) parseFrom(stmt *SelectStatement) error {
	source := p.lexer.NextToken()
	switch source.Type {
	case TokenIdent:
		stmt.Source = source.Value
		return nil
	default:
		return errors.New("expected source identifier after FROM")
	}
}
// parseGroupBy collects the GROUP BY items of a statement.
//
// One token is discarded up front (the BY keyword, per the original author's
// note; GROUP is presumably consumed by the caller before this runs — confirm).
// After that every token is examined in turn:
//   - end of input finishes parsing with a nil error;
//   - commas are skipped;
//   - a window keyword (tumbling, sliding, counting, session) hands control to
//     parseWindowFunction, whose result is returned directly;
//   - any other token's value is appended to stmt.GroupBy.
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
	p.lexer.NextToken() // skip BY

	for {
		item := p.lexer.NextToken()
		switch item.Type {
		case TokenEOF:
			return nil
		case TokenComma:
			// Separator between items; nothing to record.
			continue
		case TokenTumbling, TokenSliding, TokenCounting, TokenSession:
			return p.parseWindowFunction(stmt, item.Value)
		default:
			stmt.GroupBy = append(stmt.GroupBy, item.Value)
		}
	}
}

View File

@ -16,7 +16,7 @@ func TestParseSQL(t *testing.T) {
condition string
}{
{
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow(size='10s')",
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')",
expected: &stream.Config{
WindowConfig: stream.WindowConfig{
Type: "tumbling",
@ -32,7 +32,7 @@ func TestParseSQL(t *testing.T) {
condition: "deviceId == 'aa'",
},
{
sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow(size='20s', slide='5s')",
sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')",
expected: &stream.Config{
WindowConfig: stream.WindowConfig{
Type: "sliding",
@ -67,7 +67,7 @@ func TestParseSQL(t *testing.T) {
}
}
func TestWindowParamParsing(t *testing.T) {
params := map[string]interface{}{"size": "10s", "slide": "5s"}
params := []interface{}{"10s", "5s"}
result, err := parseWindowParams(params)
assert.NoError(t, err)
assert.Equal(t, 10*time.Second, result["size"])

View File

@ -5,6 +5,13 @@ import (
"github.com/spf13/cast"
)
// Window type identifiers. CreateWindow switches on these values to decide
// which window implementation to build.
const (
// TypeTumbling selects a tumbling window; CreateWindow reads the "size"
// parameter for it.
TypeTumbling = "tumbling"
// TypeSliding selects a sliding window; CreateWindow reads the "size" and
// "slide" parameters for it.
TypeSliding = "sliding"
// TypeCounting selects a counting window; CreateWindow requires a positive
// "count" parameter for it.
TypeCounting = "counting"
// TypeSession identifies a session window.
// NOTE(review): its handling in CreateWindow is not visible here — confirm.
TypeSession = "session"
)
type Window interface {
Add(item interface{})
GetResults() []interface{}
@ -17,13 +24,13 @@ type Window interface {
func CreateWindow(windowType string, params map[string]interface{}) (Window, error) {
switch windowType {
case "tumbling":
case TypeTumbling:
size, err := cast.ToDurationE(params["size"])
if err != nil {
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
}
return NewTumblingWindow(size), nil
case "sliding":
case TypeSliding:
size, err := cast.ToDurationE(params["size"])
if err != nil {
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
@ -33,7 +40,7 @@ func CreateWindow(windowType string, params map[string]interface{}) (Window, err
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
}
return NewSlidingWindow(size, slide), nil
case "counting":
case TypeCounting:
count := cast.ToInt(params["count"])
if count <= 0 {
return nil, fmt.Errorf("count must be a positive integer")