Files
streamsql/rsql/parser.go
2025-06-11 18:46:32 +08:00

672 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package rsql
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/rulego/streamsql/types"
)
// Parser turns a SQL-like query string into a SelectStatement. It pulls tokens
// from a Lexer and records problems through an ErrorRecovery.
type Parser struct {
// lexer supplies the tokens for the input string.
lexer *Lexer
// errorRecovery collects parse errors and is asked whether parsing may continue after one.
errorRecovery *ErrorRecovery
// currentToken is not touched by the methods visible in this part of the file.
// NOTE(review): confirm whether it is still used elsewhere.
currentToken Token
// input is the original query text; it is used when formatting error context.
input string
}
// NewParser builds a Parser for input. The parser and its lexer share a single
// ErrorRecovery instance, which is created with a reference back to the parser.
func NewParser(input string) *Parser {
	parser := &Parser{
		input: input,
		lexer: NewLexer(input),
	}
	recovery := NewErrorRecovery(parser)
	parser.errorRecovery = recovery
	parser.lexer.SetErrorRecovery(recovery)
	return parser
}
// GetErrors returns the errors reported by the parser's ErrorRecovery.
// It delegates directly to errorRecovery.GetErrors.
func (p *Parser) GetErrors() []*ParseError {
return p.errorRecovery.GetErrors()
}
// HasErrors reports whether the parser's ErrorRecovery currently holds any errors.
// It delegates directly to errorRecovery.HasErrors.
func (p *Parser) HasErrors() bool {
return p.errorRecovery.HasErrors()
}
// expectToken reads tokens from the lexer until one of type expected is found.
//
// Every mismatching token is recorded as an unexpected-token error carrying the
// given context. After a mismatch the next token is tried only if the error is
// recoverable and the error recovery reports success; otherwise the offending
// token and its error are returned.
func (p *Parser) expectToken(expected TokenType, context string) (Token, error) {
	for {
		tok := p.lexer.NextToken()
		if tok.Type == expected {
			return tok, nil
		}

		parseErr := CreateUnexpectedTokenError(
			tok.Value,
			[]string{p.getTokenTypeName(expected)},
			tok.Pos,
		)
		parseErr.Context = context
		p.errorRecovery.AddError(parseErr)

		// Give up unless recovery succeeded; on success loop and try again.
		if !parseErr.IsRecoverable() || !p.errorRecovery.RecoverFromError(ErrorTypeUnexpectedToken) {
			return tok, parseErr
		}
	}
}
// getTokenTypeName returns the human-readable name used in error messages for
// tokenType, or "unknown" for token types without a dedicated name.
func (p *Parser) getTokenTypeName(tokenType TokenType) string {
	names := map[TokenType]string{
		TokenSELECT: "SELECT",
		TokenFROM:   "FROM",
		TokenWHERE:  "WHERE",
		TokenGROUP:  "GROUP",
		TokenBY:     "BY",
		TokenComma:  ",",
		TokenLParen: "(",
		TokenRParen: ")",
		TokenIdent:  "identifier",
		TokenNumber: "number",
		TokenString: "string",
	}
	if name, known := names[tokenType]; known {
		return name
	}
	return "unknown"
}
// Parse runs the clause parsers in a fixed order — SELECT, FROM, WHERE,
// GROUP BY, HAVING, WITH, LIMIT — filling in a single SelectStatement.
//
// When a clause parser fails and the error recovery cannot recover from a
// syntax error, Parse returns a nil statement with the detailed error. When
// recovery succeeds parsing continues; if any errors were recorded by the end,
// the partially parsed statement is returned together with a combined error.
func (p *Parser) Parse() (*SelectStatement, error) {
	stmt := &SelectStatement{}

	clauses := []func(*SelectStatement) error{
		p.parseSelect,
		p.parseFrom,
		p.parseWhere,
		p.parseGroupBy,
		p.parseHaving,
		p.parseWith,
		p.parseLimit,
	}
	for _, parseClause := range clauses {
		err := parseClause(stmt)
		if err == nil {
			continue
		}
		if !p.errorRecovery.RecoverFromError(ErrorTypeSyntax) {
			return nil, p.createDetailedError(err)
		}
	}

	// Errors were recorded but recovered from: hand back the partial result
	// along with the error information.
	if p.errorRecovery.HasErrors() {
		return stmt, p.createCombinedError()
	}
	return stmt, nil
}
// createDetailedError enriches err with source context when it is a *ParseError:
// its Context field is overwritten with the result of FormatErrorContext for
// the parser input at the error position (with the argument 20). Any other
// error is returned unchanged.
func (p *Parser) createDetailedError(err error) error {
	parseErr, isParseErr := err.(*ParseError)
	if !isParseErr {
		return err
	}
	parseErr.Context = FormatErrorContext(p.input, parseErr.Position, 20)
	return parseErr
}
// createCombinedError folds every error recorded by the error recovery into a
// single error value.
//
// With exactly one recorded error, that error is returned with detailed
// context attached. Otherwise the result is a plain error whose message is a
// count header followed by one numbered line per recorded error.
func (p *Parser) createCombinedError() error {
	// Named parseErrs rather than errors so the errors package stays usable here.
	parseErrs := p.errorRecovery.GetErrors()
	if len(parseErrs) == 1 {
		return p.createDetailedError(parseErrs[0])
	}

	var builder strings.Builder
	fmt.Fprintf(&builder, "Found %d parsing errors:\n", len(parseErrs))
	for i, parseErr := range parseErrs {
		fmt.Fprintf(&builder, "%d. %s\n", i+1, parseErr.Error())
	}
	// The assembled text is a finished message, not a format string: passing it
	// to fmt.Errorf would mangle any '%' contained in the individual messages.
	return errors.New(builder.String())
}
// parseSelect consumes the leading SELECT keyword, an optional DISTINCT, and
// the comma-separated field list, appending every non-empty field to
// stmt.Fields. Scanning stops at FROM or at end of input; the FROM token itself
// is consumed by this function.
//
// It returns an error when the first token is neither SELECT nor an identifier,
// when one of the loop guards (100 fields, 100 tokens per expression) is
// exceeded, when a field is followed by something other than a comma or FROM,
// or when no field was collected.
func (p *Parser) parseSelect(stmt *SelectStatement) error {
// Verify that the first token is SELECT.
firstToken := p.lexer.NextToken()
if firstToken.Type != TokenSELECT {
// Not SELECT: check whether it might be a misspelling.
if firstToken.Type == TokenIdent {
// Per the original author's note, this case has already been reported by the lexer's checkForTypos.
// Parsing continues on the assumption that the user meant SELECT.
} else {
return CreateSyntaxError(
fmt.Sprintf("Expected SELECT, got %s", firstToken.Value),
firstToken.Pos,
firstToken.Value,
[]string{"SELECT"},
)
}
}
currentToken := p.lexer.NextToken()
if currentToken.Type == TokenDISTINCT {
stmt.Distinct = true
currentToken = p.lexer.NextToken() // consume DISTINCT and move to the next token
}
// Cap the number of fields to guard against an infinite loop.
maxFields := 100
fieldCount := 0
for {
fieldCount++
// Safety check: guard against an infinite loop.
if fieldCount > maxFields {
return errors.New("select field list parsing exceeded maximum fields, possible syntax error")
}
var expr strings.Builder
parenthesesLevel := 0 // tracks parenthesis nesting depth
// Cap the number of tokens per expression to guard against an infinite loop.
maxExprParts := 100
exprPartCount := 0
for {
exprPartCount++
// Safety check: guard against an infinite loop.
if exprPartCount > maxExprParts {
return errors.New("select field expression parsing exceeded maximum length, possible syntax error")
}
// Track the parenthesis depth.
if currentToken.Type == TokenLParen {
parenthesesLevel++
} else if currentToken.Type == TokenRParen {
parenthesesLevel--
}
// Only at parenthesis depth 0 do FROM, a comma, AS, or EOF end the current expression.
if parenthesesLevel == 0 && (currentToken.Type == TokenFROM || currentToken.Type == TokenComma || currentToken.Type == TokenAS || currentToken.Type == TokenEOF) {
break
}
// Token values are concatenated with no separator between them.
expr.WriteString(currentToken.Value)
currentToken = p.lexer.NextToken()
}
field := Field{Expression: strings.TrimSpace(expr.String())}
// Handle an alias: the token after AS is taken as the alias, whatever its type.
if currentToken.Type == TokenAS {
field.Alias = p.lexer.NextToken().Value
currentToken = p.lexer.NextToken()
}
// Skip the field when its expression is empty.
if field.Expression != "" {
// Validate the functions used in the expression.
validator := NewFunctionValidator(p.errorRecovery)
pos, _, _ := p.lexer.GetPosition()
// The offset passed is the current lexer position minus the expression length —
// presumably an approximation of where the expression starts; TODO confirm.
validator.ValidateExpression(field.Expression, pos-len(field.Expression))
stmt.Fields = append(stmt.Fields, field)
}
if currentToken.Type == TokenFROM || currentToken.Type == TokenEOF {
break
}
if currentToken.Type != TokenComma {
// Anything other than a comma at this point is a syntax error.
return fmt.Errorf("unexpected token %v, expected comma or FROM", currentToken.Value)
}
currentToken = p.lexer.NextToken()
}
// Make sure at least one field was collected.
if len(stmt.Fields) == 0 {
return errors.New("no fields specified in SELECT clause")
}
return nil
}
// parseWhere reads the next token and, if it is WHERE, collects the condition
// that follows into stmt.Condition.
//
// When the next token is not WHERE the function returns nil; the token has
// still been consumed and is not pushed back here. Condition tokens are
// gathered until GROUP, HAVING, LIMIT, a window keyword (Sliding, Tumbling,
// Counting, Session) or end of input, and are joined with single spaces. The
// operator "=" is rewritten to "==", AND to "&&" and OR to "||". A non-empty
// condition is passed to the function validator. An error is returned only
// when more than 100 loop iterations are needed.
func (p *Parser) parseWhere(stmt *SelectStatement) error {
	if p.lexer.NextToken().Type != TokenWHERE {
		return nil
	}

	// Upper bound on loop iterations, guarding against an infinite loop.
	const maxIterations = 100
	var parts []string

collect:
	for pass := 1; ; pass++ {
		if pass > maxIterations {
			return errors.New("WHERE clause parsing exceeded maximum iterations, possible syntax error")
		}

		tok := p.lexer.NextToken()
		switch tok.Type {
		case TokenGROUP, TokenEOF, TokenSliding, TokenTumbling,
			TokenCounting, TokenSession, TokenHAVING, TokenLIMIT:
			break collect
		case TokenIdent, TokenNumber, TokenString:
			parts = append(parts, tok.Value)
		case TokenEQ:
			if tok.Value == "=" {
				parts = append(parts, "==")
			} else {
				parts = append(parts, tok.Value)
			}
		case TokenAND:
			parts = append(parts, "&&")
		case TokenOR:
			parts = append(parts, "||")
		default:
			// Quote handling for string values: glue the token onto a lone
			// preceding quote instead of adding a separate part.
			if last := len(parts) - 1; last >= 0 && parts[last] == "'" {
				parts[last] += tok.Value
			} else {
				parts = append(parts, tok.Value)
			}
		}
	}

	// Validate the functions used in the WHERE condition.
	condition := strings.Join(parts, " ")
	if condition != "" {
		pos, _, _ := p.lexer.GetPosition()
		NewFunctionValidator(p.errorRecovery).ValidateExpression(condition, pos-len(condition))
	}
	stmt.Condition = condition
	return nil
}
// parseWindowFunction reads the parenthesised argument list of a window
// function and stores the result in stmt.Window: Type is set to winType and
// Params to the converted argument values.
//
// The first token (expected to be the opening parenthesis) is skipped without
// being checked. Arguments are read until the lexer's next character is ')' or
// a closing-parenthesis/EOF token is returned; commas are skipped. A value
// wrapped in single quotes has the quotes removed before it is passed to
// convertValue. An error is returned only when more than 100 loop iterations
// are needed.
func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) error {
	p.lexer.NextToken() // skip "("

	var params []interface{}

	// Upper bound on loop iterations, guarding against an infinite loop.
	const maxIterations = 100
	iterations := 0
	for p.lexer.peekChar() != ')' {
		iterations++
		if iterations > maxIterations {
			return errors.New("window function parameter parsing exceeded maximum iterations, possible syntax error")
		}

		valTok := p.lexer.NextToken()
		if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
			break
		}
		if valTok.Type == TokenComma {
			continue
		}

		// Strip the quotes from a quoted value.
		if strings.HasPrefix(valTok.Value, "'") && strings.HasSuffix(valTok.Value, "'") {
			valTok.Value = strings.Trim(valTok.Value, "'")
		}
		params = append(params, convertValue(valTok.Value))
	}

	// stmt.Window is a struct value, so its address is never nil; the fields
	// are assigned in place, leaving any other Window fields untouched.
	stmt.Window.Params = params
	stmt.Window.Type = winType
	return nil
}
// convertValue maps a textual literal to a Go value. The conversions are tried
// in order and the first match wins:
//
//   - "true" / "false" (exact match) become bool
//   - anything strconv.Atoi accepts becomes int
//   - anything strconv.ParseFloat accepts becomes float64
//   - text that both starts and ends with a single quote is returned with all
//     leading and trailing single quotes trimmed
//
// Everything else is returned unchanged as a string.
func convertValue(s string) interface{} {
	switch s {
	case "true":
		return true
	case "false":
		return false
	}

	if n, err := strconv.Atoi(s); err == nil {
		return n
	}
	if f, err := strconv.ParseFloat(s, 64); err == nil {
		return f
	}

	// Quoted string: drop the surrounding quotes.
	quoted := strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'")
	if quoted {
		return strings.Trim(s, "'")
	}
	return s
}
// parseFrom reads the next token as the query source and stores its value in
// stmt.Source. The FROM keyword itself is not read here.
//
// If the token is not an identifier, an unexpected-token error with a custom
// message, context and suggestions is recorded in the error recovery and
// returned.
func (p *Parser) parseFrom(stmt *SelectStatement) error {
	source := p.lexer.NextToken()
	if source.Type == TokenIdent {
		stmt.Source = source.Value
		return nil
	}

	parseErr := CreateUnexpectedTokenError(
		source.Value,
		[]string{"table_name", "stream_name"},
		source.Pos,
	)
	parseErr.Message = "Expected source identifier after FROM"
	parseErr.Context = "FROM clause requires a table or stream name"
	parseErr.Suggestions = []string{
		"Ensure FROM is followed by a valid table or stream name",
		"Check if the table name is spelled correctly",
	}
	p.errorRecovery.AddError(parseErr)
	return parseErr
}
// parseGroupBy collects GROUP BY fields into stmt.GroupBy and parses any
// window function (Tumbling, Sliding, Counting, Session) it encounters into
// stmt.Window.
//
// It starts from the identifier the lexer reports as the previous one: a
// window keyword is parsed as a window function right away, and for GROUP the
// following token (expected to be BY) is skipped. Tokens are then read until
// WITH, Order, HAVING, LIMIT or end of input; commas are skipped and every
// other token value is appended to stmt.GroupBy.
//
// NOTE(review): the field loop runs even when the previous identifier was
// neither GROUP nor a window keyword — confirm this is intended.
//
// An error is returned when a window function fails to parse or when more
// than 100 loop iterations are needed.
func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
	tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
	if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
		// Propagate the failure instead of silently continuing with a
		// half-parsed window definition.
		if err := p.parseWindowFunction(stmt, tok.Value); err != nil {
			return err
		}
	}
	if tok.Type == TokenGROUP {
		p.lexer.NextToken() // skip BY
	}

	// Upper bound on loop iterations, guarding against an infinite loop.
	const maxIterations = 100
	iterations := 0
	for {
		iterations++
		if iterations > maxIterations {
			return errors.New("group by clause parsing exceeded maximum iterations, possible syntax error")
		}

		tok := p.lexer.NextToken()
		if tok.Type == TokenWITH || tok.Type == TokenOrder || tok.Type == TokenEOF ||
			tok.Type == TokenHAVING || tok.Type == TokenLIMIT {
			break
		}
		if tok.Type == TokenComma {
			continue
		}
		if tok.Type == TokenTumbling || tok.Type == TokenSliding || tok.Type == TokenCounting || tok.Type == TokenSession {
			if err := p.parseWindowFunction(stmt, tok.Value); err != nil {
				return err
			}
			continue
		}
		stmt.GroupBy = append(stmt.GroupBy, tok.Value)
	}
	return nil
}
// parseWith parses the optional WITH (...) clause, which carries window
// options, into stmt.Window.
//
// The clause is recognised when the identifier the lexer reports as the
// previous one is WITH; otherwise the function returns nil without reading
// anything. The token after WITH (expected to be the opening parenthesis) is
// skipped without being checked. Two options are understood, each written as
// KEY = value with optional single quotes around the value:
//
//   - a Timestamp token sets stmt.Window.TsProp to the value
//   - a TimeUnit token sets stmt.Window.TimeUnit: "dd" is 24h, "hh" an hour,
//     "mi" a minute, "ss" a second, "ms" a millisecond; any other value falls
//     back to one minute
//
// Options are written into the existing stmt.Window field by field, so one
// option never discards another and a previously parsed window type and its
// parameters are preserved. An error is returned only when more than 100 loop
// iterations are needed.
func (p *Parser) parseWith(stmt *SelectStatement) error {
	// No WITH clause is not an error.
	tok := p.lexer.lookupIdent(p.lexer.readPreviousIdentifier())
	if tok.Type != TokenWITH {
		return nil
	}
	p.lexer.NextToken() // skip "("

	// Upper bound on loop iterations, guarding against an infinite loop.
	const maxIterations = 100
	iterations := 0
	for p.lexer.peekChar() != ')' {
		iterations++
		if iterations > maxIterations {
			return errors.New("WITH clause parsing exceeded maximum iterations, possible syntax error")
		}

		valTok := p.lexer.NextToken()
		if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
			break
		}
		if valTok.Type == TokenComma {
			continue
		}

		switch valTok.Type {
		case TokenTimestamp:
			next := p.lexer.NextToken()
			if next.Type == TokenEQ {
				next = p.lexer.NextToken()
				if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") {
					next.Value = strings.Trim(next.Value, "'")
				}
				// Assign in place: replacing the whole WindowDefinition here
				// would wipe a TimeUnit set by an earlier option.
				stmt.Window.TsProp = next.Value
			}
		case TokenTimeUnit:
			timeUnit := time.Minute
			next := p.lexer.NextToken()
			if next.Type == TokenEQ {
				next = p.lexer.NextToken()
				if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") {
					next.Value = strings.Trim(next.Value, "'")
				}
				switch next.Value {
				case "dd":
					timeUnit = 24 * time.Hour
				case "hh":
					timeUnit = time.Hour
				case "mi":
					timeUnit = time.Minute
				case "ss":
					timeUnit = time.Second
				case "ms":
					timeUnit = time.Millisecond
				}
				// Assign in place: replacing the whole WindowDefinition here
				// would wipe a TsProp set by an earlier option.
				stmt.Window.TimeUnit = timeUnit
			}
		}
	}
	return nil
}
// parseLimit parses the optional LIMIT clause into stmt.Limit.
//
// The clause is recognised when the identifier the lexer reports as the
// previous one is LIMIT; otherwise the function returns nil without reading
// anything. The next token must be a number that strconv.Atoi accepts and
// that is not negative. Each failure is recorded in the error recovery and
// returned as a parse error with suggestions.
func (p *Parser) parseLimit(stmt *SelectStatement) error {
	if p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()).Type != TokenLIMIT {
		return nil
	}

	// The token after LIMIT should be a number.
	tok := p.lexer.NextToken()
	if tok.Type != TokenNumber {
		parseErr := CreateMissingTokenError("number", tok.Pos)
		parseErr.Message = "LIMIT must be followed by an integer"
		parseErr.Context = "LIMIT clause"
		parseErr.Suggestions = []string{
			"Add a number after LIMIT, e.g., LIMIT 10",
			"Ensure LIMIT syntax is correct",
		}
		p.errorRecovery.AddError(parseErr)
		return parseErr
	}

	// Convert the numeric text to an integer.
	limit, err := strconv.Atoi(tok.Value)
	if err != nil {
		parseErr := CreateSyntaxError(
			"LIMIT value must be a valid integer",
			tok.Pos,
			tok.Value,
			[]string{"positive_integer"},
		)
		parseErr.Context = "LIMIT clause"
		parseErr.Suggestions = []string{
			"Use a positive integer, e.g., LIMIT 10",
			"Ensure the number format is correct",
		}
		p.errorRecovery.AddError(parseErr)
		return parseErr
	}

	if limit < 0 {
		parseErr := CreateSyntaxError(
			"LIMIT value must be positive",
			tok.Pos,
			tok.Value,
			[]string{"positive_integer"},
		)
		parseErr.Suggestions = []string{"Use a positive integer, e.g., LIMIT 10"}
		p.errorRecovery.AddError(parseErr)
		return parseErr
	}

	stmt.Limit = limit
	return nil
}
// parseHaving parses the optional HAVING clause into stmt.Having.
//
// The clause is recognised when the identifier the lexer reports as the
// previous one is HAVING; otherwise the function returns nil without reading
// anything. Condition tokens are gathered until LIMIT, WITH or end of input
// and joined with single spaces. The operator "=" is rewritten to "==", AND
// to "&&" and OR to "||". A non-empty condition is passed to the function
// validator. An error is returned only when more than 100 loop iterations are
// needed.
func (p *Parser) parseHaving(stmt *SelectStatement) error {
	// No HAVING clause is not an error.
	if p.lexer.lookupIdent(p.lexer.readPreviousIdentifier()).Type != TokenHAVING {
		return nil
	}

	// Upper bound on loop iterations, guarding against an infinite loop.
	const maxIterations = 100
	var parts []string

collect:
	for pass := 1; ; pass++ {
		if pass > maxIterations {
			return errors.New("HAVING clause parsing exceeded maximum iterations, possible syntax error")
		}

		tok := p.lexer.NextToken()
		switch tok.Type {
		case TokenLIMIT, TokenEOF, TokenWITH:
			break collect
		case TokenIdent, TokenNumber, TokenString:
			parts = append(parts, tok.Value)
		case TokenEQ:
			if tok.Value == "=" {
				parts = append(parts, "==")
			} else {
				parts = append(parts, tok.Value)
			}
		case TokenAND:
			parts = append(parts, "&&")
		case TokenOR:
			parts = append(parts, "||")
		default:
			// Quote handling for string values: glue the token onto a lone
			// preceding quote instead of adding a separate part.
			if last := len(parts) - 1; last >= 0 && parts[last] == "'" {
				parts[last] += tok.Value
			} else {
				parts = append(parts, tok.Value)
			}
		}
	}

	// Validate the functions used in the HAVING condition.
	condition := strings.Join(parts, " ")
	if condition != "" {
		pos, _, _ := p.lexer.GetPosition()
		NewFunctionValidator(p.errorRecovery).ValidateExpression(condition, pos-len(condition))
	}
	stmt.Having = condition
	return nil
}
// Parse is the package-level entry point: it parses sql with a fresh Parser
// and converts the resulting statement with ToStreamConfig, returning the
// stream configuration and the condition string.
//
// Any error from parsing or from the conversion is returned with a nil config
// and an empty condition; a partially parsed statement is never exposed.
func Parse(sql string) (*types.Config, string, error) {
	statement, parseErr := NewParser(sql).Parse()
	if parseErr != nil {
		return nil, "", parseErr
	}

	cfg, cond, convErr := statement.ToStreamConfig()
	if convErr != nil {
		return nil, "", convErr
	}
	return cfg, cond, nil
}