package rsql import ( "errors" "fmt" "strconv" "strings" "time" "github.com/rulego/streamsql/types" ) type Parser struct { lexer *Lexer errorRecovery *ErrorRecovery currentToken Token input string } func NewParser(input string) *Parser { lexer := NewLexer(input) p := &Parser{ lexer: lexer, input: input, } p.errorRecovery = NewErrorRecovery(p) lexer.SetErrorRecovery(p.errorRecovery) return p } // GetErrors 获取解析过程中的所有错误 func (p *Parser) GetErrors() []*ParseError { return p.errorRecovery.GetErrors() } // HasErrors 检查是否有错误 func (p *Parser) HasErrors() bool { return p.errorRecovery.HasErrors() } // expectToken 期望特定类型的token func (p *Parser) expectToken(expected TokenType, context string) (Token, error) { tok := p.lexer.NextToken() if tok.Type != expected { err := CreateUnexpectedTokenError( tok.Value, []string{p.getTokenTypeName(expected)}, tok.Pos, ) err.Context = context p.errorRecovery.AddError(err) // 尝试错误恢复 if err.IsRecoverable() && p.errorRecovery.RecoverFromError(ErrorTypeUnexpectedToken) { return p.expectToken(expected, context) } return tok, err } return tok, nil } // getTokenTypeName 获取token类型名称 func (p *Parser) getTokenTypeName(tokenType TokenType) string { switch tokenType { case TokenSELECT: return "SELECT" case TokenFROM: return "FROM" case TokenWHERE: return "WHERE" case TokenGROUP: return "GROUP" case TokenBY: return "BY" case TokenComma: return "," case TokenLParen: return "(" case TokenRParen: return ")" case TokenIdent: return "identifier" case TokenNumber: return "number" case TokenString: return "string" default: return "unknown" } } func (p *Parser) Parse() (*SelectStatement, error) { stmt := &SelectStatement{} // 解析SELECT子句 if err := p.parseSelect(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } // 解析FROM子句 if err := p.parseFrom(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } // 解析WHERE子句 if err := p.parseWhere(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } // 解析GROUP BY子句 if err := p.parseGroupBy(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } // 解析 HAVING 子句 if err := p.parseHaving(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } if err := p.parseWith(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } // 解析LIMIT子句 if err := p.parseLimit(stmt); err != nil { if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) { return nil, p.createDetailedError(err) } } // 如果有错误但可以恢复,返回部分解析结果和错误信息 if p.errorRecovery.HasErrors() { return stmt, p.createCombinedError() } return stmt, nil } // createDetailedError 创建详细的错误信息 func (p *Parser) createDetailedError(err error) error { if parseErr, ok := err.(*ParseError); ok { parseErr.Context = FormatErrorContext(p.input, parseErr.Position, 20) return parseErr } return err } // createCombinedError 创建组合错误信息 func (p *Parser) createCombinedError() error { errors := p.errorRecovery.GetErrors() if len(errors) == 1 { return p.createDetailedError(errors[0]) } var builder strings.Builder builder.WriteString(fmt.Sprintf("Found %d parsing errors:\n", len(errors))) for i, err := range errors { builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, err.Error())) } return fmt.Errorf(builder.String()) } func (p *Parser) parseSelect(stmt *SelectStatement) error { // 验证第一个token是否为SELECT firstToken := p.lexer.NextToken() if firstToken.Type != TokenSELECT { // 如果不是SELECT,检查是否是拼写错误 if firstToken.Type == TokenIdent { // 这里的错误已经由lexer的checkForTypos处理了 // 我们继续解析,假设用户想要SELECT } else { return CreateSyntaxError( fmt.Sprintf("Expected SELECT, got %s", firstToken.Value), firstToken.Pos, firstToken.Value, []string{"SELECT"}, ) } } currentToken := p.lexer.NextToken() if currentToken.Type == TokenDISTINCT { stmt.Distinct = true currentToken = p.lexer.NextToken() // 消费 DISTINCT,移动到下一个 token } // 设置最大字段数量限制,防止无限循环 maxFields := 100 fieldCount := 0 for { fieldCount++ // 安全检查:防止无限循环 if fieldCount > maxFields { return errors.New("select field list parsing exceeded maximum fields, possible syntax error") } var expr strings.Builder parenthesesLevel := 0 // 跟踪括号嵌套层级 // 设置最大表达式长度,防止无限循环 maxExprParts := 100 exprPartCount := 0 for { exprPartCount++ // 安全检查:防止无限循环 if exprPartCount > maxExprParts { return errors.New("select field expression parsing exceeded maximum length, possible syntax error") } // 跟踪括号层级 if currentToken.Type == TokenLParen { parenthesesLevel++ } else if currentToken.Type == TokenRParen { parenthesesLevel-- } // 只有在括号层级为0时,逗号才被视为字段分隔符 if parenthesesLevel == 0 && (currentToken.Type == TokenFROM || currentToken.Type == TokenComma || currentToken.Type == TokenAS || currentToken.Type == TokenEOF) { break } expr.WriteString(currentToken.Value) currentToken = p.lexer.NextToken() } field := Field{Expression: strings.TrimSpace(expr.String())} // 处理别名 if currentToken.Type == TokenAS { field.Alias = p.lexer.NextToken().Value currentToken = p.lexer.NextToken() } // 如果表达式为空,跳过这个字段 if field.Expression != "" { // 验证表达式中的函数 validator := NewFunctionValidator(p.errorRecovery) pos, _, _ := p.lexer.GetPosition() validator.ValidateExpression(field.Expression, pos-len(field.Expression)) stmt.Fields = append(stmt.Fields, field) } if currentToken.Type == TokenFROM || currentToken.Type == TokenEOF { break } if currentToken.Type != TokenComma { // 如果不是逗号,那么应该是语法错误 return fmt.Errorf("unexpected token %v, expected comma or FROM", currentToken.Value) } currentToken = p.lexer.NextToken() } // 确保至少有一个字段 if len(stmt.Fields) == 0 { return errors.New("no fields specified in SELECT clause") } return nil } func (p *Parser) parseWhere(stmt *SelectStatement) error { var conditions []string current := p.lexer.NextToken() // 获取下一个token if current.Type != TokenWHERE { // 如果不是WHERE,回退token位置 return nil } // 设置最大次数限制,防止无限循环 maxIterations := 100 iterations := 0 for { iterations++ // 安全检查:防止无限循环 if iterations > maxIterations { return errors.New("WHERE clause parsing exceeded maximum iterations, possible syntax error") } tok := p.lexer.NextToken() if tok.Type == TokenGROUP || tok.Type == TokenEOF || tok.Type == TokenSliding || tok.Type == TokenTumbling || tok.Type == TokenCounting || tok.Type == TokenSession || tok.Type == TokenHAVING || tok.Type == TokenLIMIT { break } switch tok.Type { case TokenIdent, TokenNumber: conditions = append(conditions, tok.Value) case TokenString: conditions = append(conditions, tok.Value) case TokenEQ: if tok.Value == "=" { conditions = append(conditions, "==") } else { conditions = append(conditions, tok.Value) } case TokenAND: conditions = append(conditions, "&&") case TokenOR: conditions = append(conditions, "||") default: // 处理字符串值的引号 if len(conditions) > 0 && conditions[len(conditions)-1] == "'" { conditions[len(conditions)-1] = conditions[len(conditions)-1] + tok.Value } else { conditions = append(conditions, tok.Value) } } } // 验证WHERE条件中的函数 whereCondition := strings.Join(conditions, " ") if whereCondition != "" { validator := NewFunctionValidator(p.errorRecovery) pos, _, _ := p.lexer.GetPosition() validator.ValidateExpression(whereCondition, pos-len(whereCondition)) } stmt.Condition = whereCondition return nil } func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) error { p.lexer.NextToken() // 跳过( var params []interface{} // 设置最大次数限制,防止无限循环 maxIterations := 100 iterations := 0 for p.lexer.peekChar() != ')' { iterations++ // 安全检查:防止无限循环 if iterations > maxIterations { return errors.New("window function parameter parsing exceeded maximum iterations, possible syntax error") } valTok := p.lexer.NextToken() if valTok.Type == TokenRParen || valTok.Type == TokenEOF { break } if valTok.Type == TokenComma { continue } //valTok := p.lexer.NextToken() // 处理引号包裹的值 if strings.HasPrefix(valTok.Value, "'") && strings.HasSuffix(valTok.Value, "'") { valTok.Value = strings.Trim(valTok.Value, "'") } params = append(params, convertValue(valTok.Value)) } if &stmt.Window != nil { stmt.Window.Params = params stmt.Window.Type = winType } else { stmt.Window = WindowDefinition{ Type: winType, Params: params, } } return nil } func convertValue(s string) interface{} { if s == "true" { return true } if s == "false" { return false } if i, err := strconv.Atoi(s); err == nil { return i } if f, err := strconv.ParseFloat(s, 64); err == nil { return f } // 处理引号包裹的字符串 if strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'") { return strings.Trim(s, "'") } return s } func (p *Parser) parseFrom(stmt *SelectStatement) error { tok := p.lexer.NextToken() if tok.Type != TokenIdent { err := CreateUnexpectedTokenError( tok.Value, []string{"table_name", "stream_name"}, tok.Pos, ) err.Message = "Expected source identifier after FROM" err.Context = "FROM clause requires a table or stream name" err.Suggestions = []string{ "Ensure FROM is followed by a valid table or stream name", "Check if the table name is spelled correctly", } p.errorRecovery.AddError(err) return err } stmt.Source = tok.Value return nil } func (p *Parser) parseGroupBy(stmt *SelectStatement) error { tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()) if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { p.parseWindowFunction(stmt, tok.Value) } if tok.Type == TokenGROUP { p.lexer.NextToken() // 跳过BY } // 设置最大次数限制,防止无限循环 maxIterations := 100 iterations := 0 for { iterations++ // 安全检查:防止无限循环 if iterations > maxIterations { return errors.New("group by clause parsing exceeded maximum iterations, possible syntax error") } tok := p.lexer.NextToken() if tok.Type == TokenWITH || tok.Type == TokenOrder || tok.Type == TokenEOF || tok.Type == TokenHAVING || tok.Type == TokenLIMIT { break } if tok.Type == TokenComma { continue } if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession { p.parseWindowFunction(stmt, tok.Value) continue } stmt.GroupBy = append(stmt.GroupBy, tok.Value) } return nil } func (p *Parser) parseWith(stmt *SelectStatement) error { // 查看当前 token,如果不是 WITH,则返回 tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()) if tok.Type != TokenWITH { return nil // 没有 WITH 子句,不是错误 } p.lexer.NextToken() // 跳过( // 设置最大次数限制,防止无限循环 maxIterations := 100 iterations := 0 for p.lexer.peekChar() != ')' { iterations++ // 安全检查:防止无限循环 if iterations > maxIterations { return errors.New("WITH clause parsing exceeded maximum iterations, possible syntax error") } valTok := p.lexer.NextToken() if valTok.Type == TokenRParen || valTok.Type == TokenEOF { break } if valTok.Type == TokenComma { continue } if valTok.Type == TokenTimestamp { next := p.lexer.NextToken() if next.Type == TokenEQ { next = p.lexer.NextToken() if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") { next.Value = strings.Trim(next.Value, "'") } // 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition if stmt.Window.Type == "" { stmt.Window = WindowDefinition{ TsProp: next.Value, } } else { stmt.Window.TsProp = next.Value } } } if valTok.Type == TokenTimeUnit { timeUnit := time.Minute next := p.lexer.NextToken() if next.Type == TokenEQ { next = p.lexer.NextToken() if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") { next.Value = strings.Trim(next.Value, "'") } switch next.Value { case "dd": timeUnit = 24 * time.Hour case "hh": timeUnit = time.Hour case "mi": timeUnit = time.Minute case "ss": timeUnit = time.Second case "ms": timeUnit = time.Millisecond default: } // 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition if stmt.Window.Type == "" { stmt.Window = WindowDefinition{ TimeUnit: timeUnit, } } else { stmt.Window.TimeUnit = timeUnit } } } } return nil } // parseLimit 解析LIMIT子句 func (p *Parser) parseLimit(stmt *SelectStatement) error { // 查看当前token if p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()).Type == TokenLIMIT { // 获取下一个token,应该是一个数字 tok := p.lexer.NextToken() if tok.Type == TokenNumber { // 将数字字符串转换为整数 limit, err := strconv.Atoi(tok.Value) if err != nil { parseErr := CreateSyntaxError( "LIMIT value must be a valid integer", tok.Pos, tok.Value, []string{"positive_integer"}, ) parseErr.Context = "LIMIT clause" parseErr.Suggestions = []string{ "Use a positive integer, e.g., LIMIT 10", "Ensure the number format is correct", } p.errorRecovery.AddError(parseErr) return parseErr } if limit < 0 { parseErr := CreateSyntaxError( "LIMIT value must be positive", tok.Pos, tok.Value, []string{"positive_integer"}, ) parseErr.Suggestions = []string{"Use a positive integer, e.g., LIMIT 10"} p.errorRecovery.AddError(parseErr) return parseErr } stmt.Limit = limit } else { parseErr := CreateMissingTokenError("number", tok.Pos) parseErr.Message = "LIMIT must be followed by an integer" parseErr.Context = "LIMIT clause" parseErr.Suggestions = []string{ "Add a number after LIMIT, e.g., LIMIT 10", "Ensure LIMIT syntax is correct", } p.errorRecovery.AddError(parseErr) return parseErr } } return nil } // parseHaving 解析HAVING子句 func (p *Parser) parseHaving(stmt *SelectStatement) error { // 查看当前token tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()) if tok.Type != TokenHAVING { return nil // 没有 HAVING 子句,不是错误 } // 设置最大次数限制,防止无限循环 maxIterations := 100 iterations := 0 var conditions []string for { iterations++ // 安全检查:防止无限循环 if iterations > maxIterations { return errors.New("HAVING clause parsing exceeded maximum iterations, possible syntax error") } tok := p.lexer.NextToken() if tok.Type == TokenLIMIT || tok.Type == TokenEOF || tok.Type == TokenWITH { break } switch tok.Type { case TokenIdent, TokenNumber: conditions = append(conditions, tok.Value) case TokenString: conditions = append(conditions, tok.Value) case TokenEQ: if tok.Value == "=" { conditions = append(conditions, "==") } else { conditions = append(conditions, tok.Value) } case TokenAND: conditions = append(conditions, "&&") case TokenOR: conditions = append(conditions, "||") default: // 处理字符串值的引号 if len(conditions) > 0 && conditions[len(conditions)-1] == "'" { conditions[len(conditions)-1] = conditions[len(conditions)-1] + tok.Value } else { conditions = append(conditions, tok.Value) } } } // 验证HAVING条件中的函数 havingCondition := strings.Join(conditions, " ") if havingCondition != "" { validator := NewFunctionValidator(p.errorRecovery) pos, _, _ := p.lexer.GetPosition() validator.ValidateExpression(havingCondition, pos-len(havingCondition)) } stmt.Having = havingCondition return nil } // Parse 是包级别的Parse函数,用于解析SQL字符串并返回配置和条件 func Parse(sql string) (*types.Config, string, error) { parser := NewParser(sql) stmt, err := parser.Parse() if err != nil { return nil, "", err } config, condition, err := stmt.ToStreamConfig() if err != nil { return nil, "", err } return config, condition, nil }