mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-03-16 07:17:25 +00:00
381 lines
9.7 KiB
Go
381 lines
9.7 KiB
Go
package rsql
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestEnhancedErrorHandling(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectedErrors int
|
|
errorType ErrorType
|
|
contains string
|
|
recoverable bool
|
|
}{
|
|
{
|
|
name: "Missing FROM keyword",
|
|
input: "SELECT * table1",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnexpectedToken,
|
|
contains: "Expected source identifier after FROM",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Typo in SELECT",
|
|
input: "SELCT * FROM table1",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownKeyword,
|
|
contains: "Unknown keyword 'SELCT'",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Typo in FROM",
|
|
input: "SELECT * FORM table1",
|
|
expectedErrors: 2, // FORM typo + missing FROM
|
|
errorType: ErrorTypeUnexpectedToken,
|
|
contains: "Expected source identifier after FROM",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Invalid character",
|
|
input: "SELECT * FROM table1 WHERE id # 5",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeLexical,
|
|
contains: "Unexpected character",
|
|
recoverable: false,
|
|
},
|
|
{
|
|
name: "Unterminated string",
|
|
input: "SELECT * FROM table1 WHERE name = 'test",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnterminatedString,
|
|
contains: "Unterminated string literal",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Invalid number format",
|
|
input: "SELECT * FROM table1 WHERE id = 12.34.56",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeInvalidNumber,
|
|
contains: "Invalid number format",
|
|
recoverable: false,
|
|
},
|
|
{
|
|
name: "Invalid LIMIT value",
|
|
input: "SELECT * FROM table1 LIMIT abc",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeMissingToken, // 4
|
|
contains: "LIMIT must be followed by an integer",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Negative LIMIT value",
|
|
input: "SELECT * FROM table1 LIMIT -5",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeMissingToken, // 4
|
|
contains: "LIMIT must be followed by an integer",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Multiple errors",
|
|
input: "SELCT * FORM table1 WHERE id # 5",
|
|
expectedErrors: -1, // 任意数量的错误,只要有错误就行
|
|
errorType: ErrorTypeUnknownKeyword, // 不检查具体类型
|
|
contains: "", // 不检查具体消息
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Unknown function",
|
|
input: "SELECT unknown_func(value) FROM stream",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownFunction, // 11
|
|
contains: "Unknown function 'unknown_func'",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Misspelled function",
|
|
input: "SELECT coun(value) FROM stream",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownFunction, // 11
|
|
contains: "Unknown function 'coun'",
|
|
recoverable: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parser := NewParser(tt.input)
|
|
_, err := parser.Parse()
|
|
|
|
// 检查是否有错误
|
|
if !parser.HasErrors() && err == nil {
|
|
t.Errorf("Expected error but got none")
|
|
return
|
|
}
|
|
|
|
// 检查错误数量
|
|
errors := parser.GetErrors()
|
|
if tt.expectedErrors >= 0 && len(errors) != tt.expectedErrors {
|
|
t.Errorf("Expected %d errors, got %d", tt.expectedErrors, len(errors))
|
|
} else if tt.expectedErrors == -1 && len(errors) == 0 {
|
|
t.Errorf("Expected at least one error, got none")
|
|
}
|
|
|
|
// 检查错误类型(至少有一个匹配)
|
|
found := false
|
|
for _, parseErr := range errors {
|
|
if parseErr.Type == tt.errorType {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found && len(errors) > 0 {
|
|
// 如果没找到期望的错误类型,但有其他错误,记录实际的错误类型
|
|
t.Logf("Expected error type %v not found. Actual error types: %v", tt.errorType, getErrorTypes(errors))
|
|
// 对于多错误情况,只要有错误就算通过
|
|
if tt.name != "Multiple errors" {
|
|
t.Errorf("Expected error type %v not found", tt.errorType)
|
|
}
|
|
}
|
|
|
|
// 检查错误消息内容
|
|
if tt.contains != "" && len(errors) > 0 {
|
|
found := false
|
|
for _, parseErr := range errors {
|
|
if strings.Contains(parseErr.Message, tt.contains) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
errorMessage := ""
|
|
if err != nil {
|
|
errorMessage = err.Error()
|
|
} else if len(errors) > 0 {
|
|
errorMessage = errors[0].Error()
|
|
}
|
|
t.Errorf("Error message should contain '%s', got: %s", tt.contains, errorMessage)
|
|
}
|
|
}
|
|
|
|
// 检查可恢复性
|
|
if len(errors) > 0 && errors[0].IsRecoverable() != tt.recoverable {
|
|
t.Errorf("Expected recoverable=%v, got %v", tt.recoverable, errors[0].IsRecoverable())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestErrorRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
canParse bool // 是否能够部分解析
|
|
}{
|
|
{
|
|
name: "Recoverable syntax error",
|
|
input: "SELECT * FROM table1 WHERE id = 'unclosed",
|
|
canParse: true,
|
|
},
|
|
{
|
|
name: "Multiple recoverable errors",
|
|
input: "SELCT * FORM table1",
|
|
canParse: true,
|
|
},
|
|
{
|
|
name: "Non-recoverable error",
|
|
input: "SELECT * FROM table1 WHERE id = 12.34.56",
|
|
canParse: true, // 词法错误但解析器可以继续
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parser := NewParser(tt.input)
|
|
stmt, err := parser.Parse()
|
|
|
|
if tt.canParse {
|
|
if stmt == nil {
|
|
t.Errorf("Expected partial parsing result, got nil")
|
|
}
|
|
if !parser.HasErrors() {
|
|
t.Errorf("Expected errors to be recorded")
|
|
}
|
|
} else {
|
|
if err == nil {
|
|
t.Errorf("Expected parsing to fail completely")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestErrorPositioning(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectedLine int
|
|
expectedColumn int
|
|
}{
|
|
{
|
|
name: "Single line error",
|
|
input: "SELECT * FROM table1 WHERE id # 5",
|
|
expectedLine: 1,
|
|
expectedColumn: 30, // 大概位置
|
|
},
|
|
{
|
|
name: "Multi-line error",
|
|
input: "SELECT *\nFROM table1\nWHERE id # 5",
|
|
expectedLine: 3,
|
|
expectedColumn: 10, // 大概位置
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parser := NewParser(tt.input)
|
|
_, _ = parser.Parse()
|
|
|
|
errors := parser.GetErrors()
|
|
if len(errors) == 0 {
|
|
t.Errorf("Expected at least one error")
|
|
return
|
|
}
|
|
|
|
firstError := errors[0]
|
|
if firstError.Line != tt.expectedLine {
|
|
t.Errorf("Expected line %d, got %d", tt.expectedLine, firstError.Line)
|
|
}
|
|
|
|
// 列号检查相对宽松,因为计算可能有偏差
|
|
if firstError.Column < 1 {
|
|
t.Errorf("Expected column > 0, got %d", firstError.Column)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestErrorSuggestions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectedSuggestion string
|
|
}{
|
|
{
|
|
name: "SELECT typo",
|
|
input: "SELCT * FROM table1",
|
|
expectedSuggestion: "SELECT",
|
|
},
|
|
{
|
|
name: "FROM typo",
|
|
input: "SELECT * FORM table1",
|
|
expectedSuggestion: "FROM",
|
|
},
|
|
{
|
|
name: "WHERE typo",
|
|
input: "SELECT * FROM table1 WHER id = 1",
|
|
expectedSuggestion: "WHERE",
|
|
},
|
|
{
|
|
name: "Unterminated string",
|
|
input: "SELECT * FROM table1 WHERE name = 'test",
|
|
expectedSuggestion: "Add closing quote",
|
|
},
|
|
{
|
|
name: "Invalid LIMIT",
|
|
input: "SELECT * FROM table1 LIMIT abc",
|
|
expectedSuggestion: "Add a number after LIMIT",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parser := NewParser(tt.input)
|
|
_, _ = parser.Parse()
|
|
|
|
errors := parser.GetErrors()
|
|
if len(errors) == 0 {
|
|
t.Errorf("Expected at least one error")
|
|
return
|
|
}
|
|
|
|
found := false
|
|
for _, err := range errors {
|
|
for _, suggestion := range err.Suggestions {
|
|
if strings.Contains(suggestion, tt.expectedSuggestion) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if found {
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Expected suggestion containing '%s' not found", tt.expectedSuggestion)
|
|
t.Logf("Available suggestions: %v", errors[0].Suggestions)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestErrorContext(t *testing.T) {
|
|
input := "SELECT * FROM table1 WHERE id # 5"
|
|
parser := NewParser(input)
|
|
_, err := parser.Parse()
|
|
|
|
if err == nil {
|
|
t.Errorf("Expected error but got none")
|
|
return
|
|
}
|
|
|
|
errorMessage := err.Error()
|
|
if !strings.Contains(errorMessage, "WHERE id # 5") {
|
|
t.Errorf("Error message should contain context, got: %s", errorMessage)
|
|
}
|
|
|
|
if !strings.Contains(errorMessage, "^") {
|
|
t.Errorf("Error message should contain position pointer, got: %s", errorMessage)
|
|
}
|
|
}
|
|
|
|
func TestValidSQLParsing(t *testing.T) {
|
|
// 确保有效的SQL仍然能正常解析
|
|
validInputs := []string{
|
|
"SELECT * FROM table1",
|
|
"SELECT id, name FROM users WHERE age > 18",
|
|
"SELECT COUNT(*) FROM orders GROUP BY status",
|
|
"SELECT * FROM products LIMIT 10",
|
|
}
|
|
|
|
for _, input := range validInputs {
|
|
t.Run(input, func(t *testing.T) {
|
|
parser := NewParser(input)
|
|
stmt, err := parser.Parse()
|
|
|
|
if err != nil {
|
|
t.Errorf("Valid SQL should parse without error, got: %v", err)
|
|
}
|
|
|
|
if stmt == nil {
|
|
t.Errorf("Valid SQL should return statement")
|
|
}
|
|
|
|
if parser.HasErrors() {
|
|
t.Errorf("Valid SQL should not have errors")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// getErrorTypes 获取错误类型列表
|
|
func getErrorTypes(errors []*ParseError) []ErrorType {
|
|
types := make([]ErrorType, len(errors))
|
|
for i, err := range errors {
|
|
types[i] = err.Type
|
|
}
|
|
return types
|
|
} |