Files
streamsql/rsql/error_test.go
T
2025-08-05 11:25:49 +08:00

338 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package rsql
import (
"strings"
"testing"
"fmt"
)
// TestParseError checks that ParseError.Error renders every populated field
// (type, message, position, token, expected tokens, context, suggestions) and
// that IsRecoverable reports true for an error built with Recoverable: true.
func TestParseError(t *testing.T) {
	pe := &ParseError{
		Type:        ErrorTypeSyntax,
		Message:     "Invalid syntax",
		Position:    10,
		Line:        2,
		Column:      5,
		Token:       "SELECT",
		Expected:    []string{"FROM", "WHERE"},
		Suggestions: []string{"Add FROM clause", "Check syntax"},
		Context:     "SELECT statement",
		Recoverable: true,
	}

	rendered := pe.Error()

	// Each fragment must appear somewhere in the rendered message; failFormat
	// is the report emitted (with the rendered message) when it does not.
	checks := []struct {
		fragment   string
		failFormat string
	}{
		{"SYNTAX_ERROR", "Error string should contain 'SYNTAX_ERROR', got: %s"},
		{"Invalid syntax", "Error string should contain message, got: %s"},
		{"line 2, column 5", "Error string should contain position info, got: %s"},
		{"found 'SELECT'", "Error string should contain token info, got: %s"},
		{"expected: FROM, WHERE", "Error string should contain expected tokens, got: %s"},
		{"Context: SELECT statement", "Error string should contain context, got: %s"},
		{"Suggestions: Add FROM clause; Check syntax", "Error string should contain suggestions, got: %s"},
	}
	for _, c := range checks {
		if !strings.Contains(rendered, c.fragment) {
			t.Errorf(c.failFormat, rendered)
		}
	}

	// The error above was constructed with Recoverable: true.
	if !pe.IsRecoverable() {
		t.Error("Error should be recoverable")
	}
}
// TestEnhancedErrorHandling parses malformed queries and checks that the
// parser reports a failure, records the expected number of errors, and that at
// least one recorded error message contains the expected fragment.
//
// On failure, the messages of all errors collected by the parser are included
// in the report so mismatches can be diagnosed without re-running.
//
// NOTE(review): the errorType and recoverable columns are declared for every
// case but are not asserted here; confirm them against the parser
// implementation before adding checks on them.
func TestEnhancedErrorHandling(t *testing.T) {
	tests := []struct {
		name           string
		input          string
		expectedErrors int
		errorType      ErrorType
		contains       string
		recoverable    bool
	}{
		{
			name:           "Missing FROM keyword",
			input:          "SELECT * table1",
			expectedErrors: 1,
			errorType:      ErrorTypeUnexpectedToken,
			contains:       "Expected source identifier after FROM",
			recoverable:    true,
		},
		{
			name:           "Typo in SELECT",
			input:          "SELCT * FROM table1",
			expectedErrors: 1,
			errorType:      ErrorTypeUnknownKeyword,
			contains:       "Unknown keyword 'SELCT'",
			recoverable:    true,
		},
		{
			name:           "Invalid character",
			input:          "SELECT * FROM table1 WHERE id # 5",
			expectedErrors: 1,
			errorType:      ErrorTypeLexical,
			contains:       "Unexpected character",
			recoverable:    false,
		},
		{
			name:           "Unterminated string",
			input:          "SELECT * FROM table1 WHERE name = 'test",
			expectedErrors: 1,
			errorType:      ErrorTypeUnterminatedString,
			contains:       "Unterminated string literal",
			recoverable:    true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			parser := NewParser(test.input)
			_, err := parser.Parse()

			// A failure may surface either as the returned error or as errors
			// collected on the parser; at least one of them must be present.
			if err == nil && !parser.HasErrors() {
				t.Errorf("Expected error but got none")
				return
			}

			parseErrors := parser.GetErrors()
			messages := make([]string, 0, len(parseErrors))
			for _, pe := range parseErrors {
				messages = append(messages, pe.Message)
			}

			// Check the number of recorded errors.
			if test.expectedErrors > 0 && len(parseErrors) != test.expectedErrors {
				t.Errorf("Expected %d errors, got %d: %q", test.expectedErrors, len(parseErrors), messages)
			}

			// Check that some recorded error mentions the expected fragment.
			if test.contains != "" {
				found := false
				for _, msg := range messages {
					if strings.Contains(msg, test.contains) {
						found = true
						break
					}
				}
				if !found {
					t.Errorf("Expected error containing '%s', got: %q", test.contains, messages)
				}
			}
		})
	}
}
// TestErrorTypes builds a minimal ParseError for each listed error type and
// checks that Error() yields a non-empty string. Subtests are named after the
// numeric value of the type.
func TestErrorTypes(t *testing.T) {
	for _, et := range []ErrorType{
		ErrorTypeSyntax,
		ErrorTypeLexical,
		ErrorTypeSemantics,
		ErrorTypeUnexpectedToken,
		ErrorTypeMissingToken,
		ErrorTypeInvalidExpression,
		ErrorTypeUnknownKeyword,
		ErrorTypeInvalidNumber,
		ErrorTypeUnterminatedString,
		ErrorTypeMaxIterations,
		ErrorTypeUnknownFunction,
	} {
		subtest := fmt.Sprintf("ErrorType_%d", int(et))
		t.Run(subtest, func(t *testing.T) {
			pe := &ParseError{Type: et, Message: "Test error"}
			if msg := pe.Error(); len(msg) == 0 {
				t.Error("Error string should not be empty")
			}
		})
	}
}
// TestErrorRecovery checks, for each input, whether parsing reports a failure.
// A failure is either a non-nil error from Parse or errors collected on the
// parser (HasErrors).
//
// NOTE(review): errorCount is declared per case but not asserted; the
// "Multiple syntax errors" case expects 3 while the parser may stop after the
// first error — confirm before enforcing it.
func TestErrorRecovery(t *testing.T) {
	tests := []struct {
		name        string
		input       string
		expectError bool
		errorCount  int
	}{
		{
			name:        "Multiple syntax errors",
			input:       "SELCT * FORM table WHRE id = 1",
			expectError: true,
			errorCount:  3, // SELCT, FORM, WHRE
		},
		{
			name:        "Missing tokens",
			input:       "SELECT FROM WHERE",
			expectError: true,
			errorCount:  1,
		},
		{
			name:        "Incomplete WHERE clause",
			input:       "SELECT * FROM table WHERE (",
			expectError: false,
			errorCount:  0,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			parser := NewParser(test.input)
			_, err := parser.Parse()

			// HasErrors is only consulted when Parse returned nil.
			failed := err != nil || parser.HasErrors()

			switch {
			case test.expectError && !failed:
				t.Errorf("Expected error but got none")
			case !test.expectError && failed:
				// Parse may return nil while the parser still collected errors;
				// report both so the message is not a bare "<nil>".
				parseErrors := parser.GetErrors()
				messages := make([]string, 0, len(parseErrors))
				for _, pe := range parseErrors {
					messages = append(messages, pe.Message)
				}
				t.Errorf("Unexpected error: %v (parser errors: %q)", err, messages)
			}
		})
	}
}
// TestNewFunctionValidator 测试 FunctionValidator 创建
func TestNewFunctionValidator(t *testing.T) {
lexer := NewLexer("SELECT * FROM table")
parser := &Parser{lexer: lexer}
er := NewErrorRecovery(parser)
fv := NewFunctionValidator(er)
if fv == nil {
t.Error("NewFunctionValidator should not return nil")
return
}
if fv.errorRecovery != er {
t.Error("FunctionValidator should store the provided ErrorRecovery")
}
}
// TestFunctionValidatorValidateExpression runs ValidateExpression on a set of
// expressions and checks the number of errors recorded on the ErrorRecovery.
// When errors are expected, it also checks the type of the first error and,
// optionally, that its message contains a given fragment.
func TestFunctionValidatorValidateExpression(t *testing.T) {
	cases := []struct {
		name           string
		expression     string
		expectedErrors int
		errorType      ErrorType
		errorMessage   string
	}{
		{
			name:           "Valid builtin function",
			expression:     "abs(temperature)",
			expectedErrors: 0,
		},
		{
			name:           "Valid nested builtin functions",
			expression:     "sqrt(abs(temperature))",
			expectedErrors: 0,
		},
		{
			name:           "Unknown function",
			expression:     "unknown_func(temperature)",
			expectedErrors: 1,
			errorType:      ErrorTypeUnknownFunction,
			errorMessage:   "unknown_func",
		},
		{
			name:           "Multiple unknown functions",
			expression:     "unknown1(temperature) + unknown2(humidity)",
			expectedErrors: 2,
			errorType:      ErrorTypeUnknownFunction,
		},
		{
			name:           "Mixed valid and invalid functions",
			expression:     "abs(temperature) + unknown_func(humidity)",
			expectedErrors: 1,
			errorType:      ErrorTypeUnknownFunction,
			errorMessage:   "unknown_func",
		},
		{
			name:           "No functions in expression",
			expression:     "temperature + humidity",
			expectedErrors: 0,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Fresh recovery per case so errors do not accumulate across cases.
			recovery := NewErrorRecovery(&Parser{lexer: NewLexer("SELECT * FROM table")})
			validator := NewFunctionValidator(recovery)
			validator.ValidateExpression(tc.expression, 0)

			got := recovery.GetErrors()
			if len(got) != tc.expectedErrors {
				t.Errorf("Expected %d errors, got %d", tc.expectedErrors, len(got))
				return
			}
			if tc.expectedErrors == 0 {
				return
			}

			// Only the first recorded error is inspected in detail.
			if got[0].Type != tc.errorType {
				t.Errorf("Expected error type %v, got %v", tc.errorType, got[0].Type)
			}
			if tc.errorMessage != "" && !strings.Contains(got[0].Message, tc.errorMessage) {
				t.Errorf("Expected error message to contain '%s', got '%s'", tc.errorMessage, got[0].Message)
			}
		})
	}
}
// TestFunctionValidatorBuiltins 测试函数验证器内置函数
func TestFunctionValidatorBuiltins(t *testing.T) {
lexer := NewLexer("SELECT * FROM table")
parser := &Parser{lexer: lexer}
er := NewErrorRecovery(parser)
validator := NewFunctionValidator(er)
// 测试内置函数验证(基于实际实现的数学函数)
builtinFunctions := []string{"ABS", "ROUND", "SQRT", "SIN", "COS", "FLOOR", "CEIL"}
for _, funcName := range builtinFunctions {
t.Run("Builtin_"+funcName, func(t *testing.T) {
if !validator.isBuiltinFunction(funcName) {
t.Errorf("Expected %s to be a valid builtin function", funcName)
}
})
}
// 测试聚合函数这些不在isBuiltinFunction中但在SQL中是有效的
aggregateFunctions := []string{"COUNT", "SUM", "AVG", "MAX", "MIN"}
for _, funcName := range aggregateFunctions {
t.Run("Aggregate_"+funcName, func(t *testing.T) {
// 聚合函数不在isBuiltinFunction中这是正确的
if validator.isBuiltinFunction(funcName) {
t.Errorf("Expected %s to not be in builtin functions (it's an aggregate function)", funcName)
}
})
}
// 测试无效函数
invalidFunctions := []string{"INVALID_FUNC", "UNKNOWN", ""}
for _, funcName := range invalidFunctions {
t.Run("Invalid_"+funcName, func(t *testing.T) {
if validator.isBuiltinFunction(funcName) {
t.Errorf("Expected %s to be an invalid function", funcName)
}
})
}
}