Files
streamsql/rsql/parser.go
T
2025-06-11 18:46:32 +08:00

672 lines
17 KiB
Go

package rsql
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/rulego/streamsql/types"
)
// Parser is a hand-written parser for the rsql SELECT dialect. It pulls
// tokens from lexer clause by clause and records problems in errorRecovery,
// which NewParser also installs on the lexer so both report to one place.
type Parser struct {
	// lexer supplies tokens for the input being parsed.
	lexer *Lexer
	// errorRecovery collects errors and decides whether parsing may continue
	// after one.
	errorRecovery *ErrorRecovery
	// NOTE(review): not read by any method in this file (parseSelect uses a
	// local variable of the same name); confirm whether other files use it.
	currentToken Token
	// input is the original SQL text; createDetailedError uses it to build an
	// excerpt around an error position.
	input string
}
// NewParser returns a Parser for input. The parser creates one ErrorRecovery
// bound to itself and hands it to the lexer as well, so lexing and parsing
// errors end up in the same collection.
func NewParser(input string) *Parser {
	parser := &Parser{
		lexer: NewLexer(input),
		input: input,
	}
	recovery := NewErrorRecovery(parser)
	parser.errorRecovery = recovery
	parser.lexer.SetErrorRecovery(recovery)
	return parser
}
// GetErrors returns every error recorded so far while lexing and parsing,
// as held by the parser's ErrorRecovery.
func (p *Parser) GetErrors() []*ParseError {
	recovery := p.errorRecovery
	return recovery.GetErrors()
}
// HasErrors reports whether the parser's ErrorRecovery has recorded at least
// one error.
func (p *Parser) HasErrors() bool {
	recovery := p.errorRecovery
	return recovery.HasErrors()
}
// expectToken reads the next token and checks that it has type expected.
//
// On a mismatch an unexpected-token error (carrying context as its Context)
// is recorded. If the error is recoverable and
// ErrorRecovery.RecoverFromError(ErrorTypeUnexpectedToken) succeeds, the
// next token is tried in the same way. Otherwise the mismatching token is
// returned together with the error.
func (p *Parser) expectToken(expected TokenType, context string) (Token, error) {
	for {
		tok := p.lexer.NextToken()
		if tok.Type == expected {
			return tok, nil
		}

		mismatch := CreateUnexpectedTokenError(
			tok.Value,
			[]string{p.getTokenTypeName(expected)},
			tok.Pos,
		)
		mismatch.Context = context
		p.errorRecovery.AddError(mismatch)

		// Recovery is only attempted for recoverable errors; if either check
		// fails, give up with the current token.
		if !mismatch.IsRecoverable() || !p.errorRecovery.RecoverFromError(ErrorTypeUnexpectedToken) {
			return tok, mismatch
		}
	}
}
// getTokenTypeName returns the human-readable spelling of tokenType used in
// "expected ..." error messages. Token types without a dedicated spelling
// are reported as "unknown".
func (p *Parser) getTokenTypeName(tokenType TokenType) string {
	name := "unknown"
	switch tokenType {
	case TokenSELECT:
		name = "SELECT"
	case TokenFROM:
		name = "FROM"
	case TokenWHERE:
		name = "WHERE"
	case TokenGROUP:
		name = "GROUP"
	case TokenBY:
		name = "BY"
	case TokenComma:
		name = ","
	case TokenLParen:
		name = "("
	case TokenRParen:
		name = ")"
	case TokenIdent:
		name = "identifier"
	case TokenNumber:
		name = "number"
	case TokenString:
		name = "string"
	}
	return name
}
// Parse parses the whole statement, clause by clause, in the fixed order
// SELECT, FROM, WHERE, GROUP BY, HAVING, WITH, LIMIT.
//
// When a clause parser fails, ErrorRecovery.RecoverFromError(ErrorTypeSyntax)
// decides whether to keep going. If it refuses, Parse returns a nil
// statement and the error (with a source excerpt added for *ParseError
// values). If all clauses run but errors were recorded along the way, the
// partially populated statement is returned together with a combined error.
func (p *Parser) Parse() (*SelectStatement, error) {
	stmt := &SelectStatement{}

	clauseParsers := []func(*SelectStatement) error{
		p.parseSelect,
		p.parseFrom,
		p.parseWhere,
		p.parseGroupBy,
		p.parseHaving,
		p.parseWith,
		p.parseLimit,
	}
	for _, parseClause := range clauseParsers {
		clauseErr := parseClause(stmt)
		if clauseErr == nil || p.errorRecovery.RecoverFromError(ErrorTypeSyntax) {
			continue
		}
		return nil, p.createDetailedError(clauseErr)
	}

	if !p.errorRecovery.HasErrors() {
		return stmt, nil
	}
	return stmt, p.createCombinedError()
}
// createDetailedError adds a source excerpt to parse errors. For a
// *ParseError it overwrites Context with FormatErrorContext(p.input,
// Position, 20) and returns the same pointer. NOTE(review): 20 is presumably
// the excerpt radius; confirm in FormatErrorContext. Any other error is
// returned unchanged.
func (p *Parser) createDetailedError(err error) error {
	switch pe := err.(type) {
	case *ParseError:
		pe.Context = FormatErrorContext(p.input, pe.Position, 20)
		return pe
	default:
		return err
	}
}
// createCombinedError merges all recorded errors into a single error value.
//
// A lone error goes through createDetailedError, so it keeps its *ParseError
// type and gains a source excerpt. Several errors are rendered as a numbered
// list headed "Found N parsing errors:".
func (p *Parser) createCombinedError() error {
	// Named errs (not errors) so the errors package stays reachable below.
	errs := p.errorRecovery.GetErrors()
	if len(errs) == 1 {
		return p.createDetailedError(errs[0])
	}

	var b strings.Builder
	fmt.Fprintf(&b, "Found %d parsing errors:\n", len(errs))
	for i, e := range errs {
		fmt.Fprintf(&b, "%d. %s\n", i+1, e.Error())
	}
	// errors.New rather than fmt.Errorf: the text embeds error messages and
	// user input, and any '%' in it would otherwise be parsed as a verb.
	return errors.New(b.String())
}
// parseSelect parses "SELECT [DISTINCT] expr [AS alias], ..." into
// stmt.Distinct and stmt.Fields.
//
// Each field expression is rebuilt by concatenating raw token values (no
// separators are inserted). It ends at a comma, AS, FROM or EOF seen at
// parenthesis depth 0, so commas inside function calls stay part of the
// expression. "AS <token>" sets the alias. Non-empty expressions are checked
// by a FunctionValidator that reports into p.errorRecovery; empty ones are
// skipped. The loop stops after consuming FROM (or at EOF), so parseFrom
// reads the source name next.
//
// An identifier in place of SELECT is tolerated. The original author notes
// the lexer's typo check (checkForTypos) has already reported it, and
// parsing continues as if SELECT had been written.
//
// Returned errors:
//   - a non-identifier first token;
//   - an unexpected token between fields;
//   - unbalanced parentheses;
//   - an empty field list;
//   - exceeding the 100-field or 100-token-per-field safety limits.
func (p *Parser) parseSelect(stmt *SelectStatement) error {
	// Safety limits that stop runaway loops on malformed input.
	const (
		maxFields    = 100
		maxExprParts = 100
	)

	firstToken := p.lexer.NextToken()
	if firstToken.Type != TokenSELECT && firstToken.Type != TokenIdent {
		return CreateSyntaxError(
			fmt.Sprintf("Expected SELECT, got %s", firstToken.Value),
			firstToken.Pos,
			firstToken.Value,
			[]string{"SELECT"},
		)
	}

	currentToken := p.lexer.NextToken()
	if currentToken.Type == TokenDISTINCT {
		stmt.Distinct = true
		currentToken = p.lexer.NextToken() // consume DISTINCT
	}

	for fieldCount := 1; ; fieldCount++ {
		if fieldCount > maxFields {
			return errors.New("select field list parsing exceeded maximum fields, possible syntax error")
		}

		var expr strings.Builder
		parenthesesLevel := 0 // nesting depth; separators only count at depth 0
		for exprPartCount := 1; ; exprPartCount++ {
			if exprPartCount > maxExprParts {
				return errors.New("select field expression parsing exceeded maximum length, possible syntax error")
			}
			switch currentToken.Type {
			case TokenLParen:
				parenthesesLevel++
			case TokenRParen:
				parenthesesLevel--
			}
			if parenthesesLevel == 0 && (currentToken.Type == TokenFROM || currentToken.Type == TokenComma ||
				currentToken.Type == TokenAS || currentToken.Type == TokenEOF) {
				break
			}
			// EOF at a non-zero depth (missing ")" or a stray ")") can never
			// reach depth 0 again; report it instead of spinning on EOF until
			// the length guard trips with a misleading message.
			if currentToken.Type == TokenEOF {
				return fmt.Errorf("unbalanced parentheses in select field %q", strings.TrimSpace(expr.String()))
			}
			expr.WriteString(currentToken.Value)
			currentToken = p.lexer.NextToken()
		}

		field := Field{Expression: strings.TrimSpace(expr.String())}
		if currentToken.Type == TokenAS {
			field.Alias = p.lexer.NextToken().Value
			currentToken = p.lexer.NextToken()
		}

		if field.Expression != "" {
			validator := NewFunctionValidator(p.errorRecovery)
			// NOTE(review): the offset is approximate. The lexer has already
			// consumed the terminating token (and any alias), and whitespace
			// between tokens is not part of the rebuilt expression.
			pos, _, _ := p.lexer.GetPosition()
			validator.ValidateExpression(field.Expression, pos-len(field.Expression))
			stmt.Fields = append(stmt.Fields, field)
		}

		if currentToken.Type == TokenFROM || currentToken.Type == TokenEOF {
			break
		}
		if currentToken.Type != TokenComma {
			return fmt.Errorf("unexpected token %v, expected comma or FROM", currentToken.Value)
		}
		currentToken = p.lexer.NextToken()
	}

	if len(stmt.Fields) == 0 {
		return errors.New("no fields specified in SELECT clause")
	}
	return nil
}
// parseWhere parses an optional WHERE clause into stmt.Condition.
//
// It always consumes one token. If that token is not WHERE, the method
// returns nil without pushing the token back; the later clause parsers look
// the keyword up again through p.lexer.readPreviousIdentifier.
// NOTE(review): inferred from how those parsers start; confirm against the
// lexer.
//
// Condition tokens are collected up to a clause boundary (GROUP, HAVING,
// LIMIT, a window keyword or EOF), which is consumed as well. "=" becomes
// "==", AND becomes "&&", OR becomes "||", and the pieces are joined with
// single spaces. A non-empty condition is passed to a FunctionValidator.
func (p *Parser) parseWhere(stmt *SelectStatement) error {
	if p.lexer.NextToken().Type != TokenWHERE {
		return nil
	}

	const maxIterations = 100 // guard against runaway loops on malformed input
	var parts []string
collect:
	for step := 1; ; step++ {
		if step > maxIterations {
			return errors.New("WHERE clause parsing exceeded maximum iterations, possible syntax error")
		}
		tok := p.lexer.NextToken()
		switch tok.Type {
		case TokenGROUP, TokenEOF, TokenSliding, TokenTumbling, TokenCounting, TokenSession, TokenHAVING, TokenLIMIT:
			break collect
		case TokenIdent, TokenNumber, TokenString:
			parts = append(parts, tok.Value)
		case TokenEQ:
			op := tok.Value
			if op == "=" {
				op = "=="
			}
			parts = append(parts, op)
		case TokenAND:
			parts = append(parts, "&&")
		case TokenOR:
			parts = append(parts, "||")
		default:
			// Glue a token onto a preceding lone quote piece.
			if last := len(parts) - 1; last >= 0 && parts[last] == "'" {
				parts[last] += tok.Value
			} else {
				parts = append(parts, tok.Value)
			}
		}
	}

	condition := strings.Join(parts, " ")
	if condition != "" {
		validator := NewFunctionValidator(p.errorRecovery)
		end, _, _ := p.lexer.GetPosition()
		validator.ValidateExpression(condition, end-len(condition))
	}
	stmt.Condition = condition
	return nil
}
// parseWindowFunction parses the argument list of a window function such as
// TumblingWindow('10s') and stores winType and the arguments in stmt.Window.
//
// The first token read is assumed to be "(" and is skipped unchecked.
// Arguments are read until a ")" token or EOF (or until the lexer's peekChar
// reports ')'); commas are skipped. Arguments wrapped in single quotes are
// unquoted and then converted with convertValue (bool, int, float64 or
// string).
//
// It returns an error only if more than 100 iterations are needed, which
// guards against runaway loops on malformed input.
func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) error {
	p.lexer.NextToken() // skip "("

	const maxIterations = 100
	var params []interface{}
	iterations := 0
	for p.lexer.peekChar() != ')' {
		iterations++
		if iterations > maxIterations {
			return errors.New("window function parameter parsing exceeded maximum iterations, possible syntax error")
		}
		valTok := p.lexer.NextToken()
		if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
			break
		}
		if valTok.Type == TokenComma {
			continue
		}
		// Unquote before convertValue so that '10' becomes the int 10 rather
		// than the string "10".
		value := valTok.Value
		if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") {
			value = strings.Trim(value, "'")
		}
		params = append(params, convertValue(value))
	}

	// stmt.Window is a struct value (not a pointer), so its fields can always
	// be assigned directly; other fields such as TsProp/TimeUnit are kept.
	stmt.Window.Params = params
	stmt.Window.Type = winType
	return nil
}
// convertValue turns a raw literal into a typed Go value. It tries these
// rules in order:
//   - "true"/"false" (lower case only) become bools;
//   - anything strconv.Atoi accepts becomes an int;
//   - anything strconv.ParseFloat accepts becomes a float64;
//   - text wrapped in single quotes has them trimmed;
//   - everything else is returned unchanged as a string.
func convertValue(s string) interface{} {
	switch s {
	case "true":
		return true
	case "false":
		return false
	}

	if n, convErr := strconv.Atoi(s); convErr == nil {
		return n
	}
	if x, convErr := strconv.ParseFloat(s, 64); convErr == nil {
		return x
	}

	quoted := strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'")
	if !quoted {
		return s
	}
	return strings.Trim(s, "'")
}
// parseFrom reads the source name that follows FROM (the FROM keyword itself
// was consumed by parseSelect) and stores it in stmt.Source.
//
// If the token is not an identifier, a descriptive unexpected-token error is
// recorded in p.errorRecovery and also returned.
func (p *Parser) parseFrom(stmt *SelectStatement) error {
	source := p.lexer.NextToken()
	if source.Type == TokenIdent {
		stmt.Source = source.Value
		return nil
	}

	parseErr := CreateUnexpectedTokenError(
		source.Value,
		[]string{"table_name", "stream_name"},
		source.Pos,
	)
	parseErr.Message = "Expected source identifier after FROM"
	parseErr.Context = "FROM clause requires a table or stream name"
	parseErr.Suggestions = []string{
		"Ensure FROM is followed by a valid table or stream name",
		"Check if the table name is spelled correctly",
	}
	p.errorRecovery.AddError(parseErr)
	return parseErr
}
// parseGroupBy parses the GROUP BY clause and any window functions
// (Tumbling/Sliding/Counting/Session windows) into stmt.GroupBy and
// stmt.Window.
//
// The clause keyword was already consumed by parseWhere, so it is looked up
// again via readPreviousIdentifier:
//   - GROUP: the following BY token is skipped (unchecked);
//   - a window keyword: its argument list is parsed;
//   - anything else: there is no GROUP BY, and the method returns nil
//     without consuming tokens, leaving them for HAVING/WITH/LIMIT.
//
// Items are then collected until WITH, ORDER, HAVING, LIMIT or EOF, and that
// token is consumed too. Commas are skipped, and window functions inside the
// list are parsed into stmt.Window. Errors from parseWindowFunction and the
// 100-iteration safety limit are returned.
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
	const maxIterations = 100 // guard against runaway loops on malformed input

	start := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
	switch start.Type {
	case TokenTumbling, TokenSliding, TokenCounting, TokenSession:
		if err := p.parseWindowFunction(stmt, start.Value); err != nil {
			return err
		}
	case TokenGROUP:
		p.lexer.NextToken() // skip BY
	default:
		// No GROUP BY and no window function: consuming tokens here would
		// swallow e.g. "LIMIT 10" or a HAVING condition into stmt.GroupBy.
		return nil
	}

	for iterations := 1; ; iterations++ {
		if iterations > maxIterations {
			return errors.New("group by clause parsing exceeded maximum iterations, possible syntax error")
		}
		next := p.lexer.NextToken()
		switch next.Type {
		case TokenWITH, TokenOrder, TokenEOF, TokenHAVING, TokenLIMIT:
			return nil
		case TokenComma:
			// item separator
		case TokenTumbling, TokenSliding, TokenCounting, TokenSession:
			if err := p.parseWindowFunction(stmt, next.Value); err != nil {
				return err
			}
		default:
			stmt.GroupBy = append(stmt.GroupBy, next.Value)
		}
	}
}
// parseWith parses an optional "WITH (option = value, ...)" clause into
// stmt.Window.
//
// The WITH keyword was consumed by an earlier clause parser and is
// recognised via readPreviousIdentifier; without it the method returns nil.
// The next token is assumed to be "(" and is skipped unchecked. Options are
// the tokens the lexer classifies as TokenTimestamp and TokenTimeUnit
// (presumably the TIMESTAMP and TIMEUNIT keywords):
//   - TokenTimestamp = 'field' sets stmt.Window.TsProp.
//   - TokenTimeUnit = 'unit' sets stmt.Window.TimeUnit. The unit is one of
//     dd, hh, mi, ss, ms; any other value silently falls back to
//     time.Minute.
//
// Single quotes around values are trimmed. Each option updates only its own
// field, so several options in one clause, and a window type set by GROUP
// BY, are all kept. Commas and unrecognised tokens are skipped.
func (p *Parser) parseWith(stmt *SelectStatement) error {
	tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
	if tok.Type != TokenWITH {
		return nil // no WITH clause; not an error
	}
	p.lexer.NextToken() // skip "("

	// unquote strips surrounding single quotes from an option value.
	unquote := func(s string) string {
		if strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'") {
			return strings.Trim(s, "'")
		}
		return s
	}

	const maxIterations = 100 // guard against runaway loops on malformed input
	iterations := 0
	for p.lexer.peekChar() != ')' {
		iterations++
		if iterations > maxIterations {
			return errors.New("WITH clause parsing exceeded maximum iterations, possible syntax error")
		}
		valTok := p.lexer.NextToken()
		if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
			break
		}
		switch valTok.Type {
		case TokenTimestamp:
			if p.lexer.NextToken().Type != TokenEQ {
				continue
			}
			stmt.Window.TsProp = unquote(p.lexer.NextToken().Value)
		case TokenTimeUnit:
			if p.lexer.NextToken().Type != TokenEQ {
				continue
			}
			timeUnit := time.Minute // fallback for unrecognised units
			switch unquote(p.lexer.NextToken().Value) {
			case "dd":
				timeUnit = 24 * time.Hour
			case "hh":
				timeUnit = time.Hour
			case "mi":
				timeUnit = time.Minute
			case "ss":
				timeUnit = time.Second
			case "ms":
				timeUnit = time.Millisecond
			}
			stmt.Window.TimeUnit = timeUnit
		}
	}
	return nil
}
// parseLimit parses an optional "LIMIT n" clause into stmt.Limit.
//
// The LIMIT keyword was consumed earlier and is recognised via
// readPreviousIdentifier; without it the method returns nil. The next token
// must be a number that strconv.Atoi accepts and that is not negative. Every
// failure is recorded in p.errorRecovery and also returned.
//
// NOTE(review): LIMIT 0 is accepted even though the error text says
// "positive". Confirm how a Limit of 0 is interpreted downstream.
func (p *Parser) parseLimit(stmt *SelectStatement) error {
	if p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()).Type != TokenLIMIT {
		return nil
	}

	numTok := p.lexer.NextToken()
	if numTok.Type != TokenNumber {
		missing := CreateMissingTokenError("number", numTok.Pos)
		missing.Message = "LIMIT must be followed by an integer"
		missing.Context = "LIMIT clause"
		missing.Suggestions = []string{
			"Add a number after LIMIT, e.g., LIMIT 10",
			"Ensure LIMIT syntax is correct",
		}
		p.errorRecovery.AddError(missing)
		return missing
	}

	n, convErr := strconv.Atoi(numTok.Value)
	if convErr != nil {
		notInt := CreateSyntaxError(
			"LIMIT value must be a valid integer",
			numTok.Pos,
			numTok.Value,
			[]string{"positive_integer"},
		)
		notInt.Context = "LIMIT clause"
		notInt.Suggestions = []string{
			"Use a positive integer, e.g., LIMIT 10",
			"Ensure the number format is correct",
		}
		p.errorRecovery.AddError(notInt)
		return notInt
	}

	if n < 0 {
		negative := CreateSyntaxError(
			"LIMIT value must be positive",
			numTok.Pos,
			numTok.Value,
			[]string{"positive_integer"},
		)
		negative.Suggestions = []string{"Use a positive integer, e.g., LIMIT 10"}
		p.errorRecovery.AddError(negative)
		return negative
	}

	stmt.Limit = n
	return nil
}
// parseHaving parses an optional HAVING clause into stmt.Having.
//
// The HAVING keyword was consumed earlier and is recognised via
// readPreviousIdentifier; without it the method returns nil. Condition
// tokens are collected until LIMIT, WITH or EOF, and that token is consumed
// too. The same rewriting as in WHERE is applied: "=" becomes "==", AND
// becomes "&&", OR becomes "||", and the pieces are joined with single
// spaces. A non-empty condition is passed to a FunctionValidator.
func (p *Parser) parseHaving(stmt *SelectStatement) error {
	if p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()).Type != TokenHAVING {
		return nil // no HAVING clause; not an error
	}

	const maxIterations = 100 // guard against runaway loops on malformed input
	var parts []string
collect:
	for step := 1; ; step++ {
		if step > maxIterations {
			return errors.New("HAVING clause parsing exceeded maximum iterations, possible syntax error")
		}
		tok := p.lexer.NextToken()
		switch tok.Type {
		case TokenLIMIT, TokenEOF, TokenWITH:
			break collect
		case TokenIdent, TokenNumber, TokenString:
			parts = append(parts, tok.Value)
		case TokenEQ:
			op := tok.Value
			if op == "=" {
				op = "=="
			}
			parts = append(parts, op)
		case TokenAND:
			parts = append(parts, "&&")
		case TokenOR:
			parts = append(parts, "||")
		default:
			// Glue a token onto a preceding lone quote piece.
			if last := len(parts) - 1; last >= 0 && parts[last] == "'" {
				parts[last] += tok.Value
			} else {
				parts = append(parts, tok.Value)
			}
		}
	}

	condition := strings.Join(parts, " ")
	if condition != "" {
		validator := NewFunctionValidator(p.errorRecovery)
		end, _, _ := p.lexer.GetPosition()
		validator.ValidateExpression(condition, end-len(condition))
	}
	stmt.Having = condition
	return nil
}
// Parse parses sql and converts the resulting statement with
// SelectStatement.ToStreamConfig.
//
// It returns the stream configuration and the condition string produced by
// ToStreamConfig. Any parse error (including recovered ones reported by
// Parser.Parse) or conversion error yields (nil, "", err).
func Parse(sql string) (*types.Config, string, error) {
	stmt, parseErr := NewParser(sql).Parse()
	if parseErr != nil {
		return nil, "", parseErr
	}

	cfg, cond, convErr := stmt.ToStreamConfig()
	if convErr != nil {
		return nil, "", convErr
	}
	return cfg, cond, nil
}