Files
2025-08-07 19:23:48 +08:00

339 lines
9.0 KiB
Go

package rsql
import (
"fmt"
"strings"
"testing"
)
// TestParseError verifies that ParseError.Error renders every populated field
// (type name, message, line/column, offending token, expected tokens, context
// and suggestions) and that IsRecoverable reports true for an error built with
// Recoverable set.
func TestParseError(t *testing.T) {
	parseErr := &ParseError{
		Type:        ErrorTypeSyntax,
		Message:     "Invalid syntax",
		Line:        2,
		Column:      5,
		Position:    10,
		Token:       "SELECT",
		Context:     "SELECT statement",
		Expected:    []string{"FROM", "WHERE"},
		Suggestions: []string{"Add FROM clause", "Check syntax"},
		Recoverable: true,
	}

	rendered := parseErr.Error()

	// Each entry pairs a fragment that must appear in the rendered text with
	// the wording used to describe it in the failure report.
	fragments := []struct {
		substr string
		what   string
	}{
		{"SYNTAX_ERROR", "'SYNTAX_ERROR'"},
		{"Invalid syntax", "message"},
		{"line 2, column 5", "position info"},
		{"found 'SELECT'", "token info"},
		{"expected: FROM, WHERE", "expected tokens"},
		{"Context: SELECT statement", "context"},
		{"Suggestions: Add FROM clause; Check syntax", "suggestions"},
	}
	for _, f := range fragments {
		if !strings.Contains(rendered, f.substr) {
			t.Errorf("Error string should contain %s, got: %s", f.what, rendered)
		}
	}

	if !parseErr.IsRecoverable() {
		t.Error("Error should be recoverable")
	}
}
// TestEnhancedErrorHandling parses malformed statements and checks that the
// parser reports the expected number of errors, at least one of which carries
// a message containing the expected fragment. On failure the recorded error
// messages are printed so a mismatch can be diagnosed without a debugger.
//
// NOTE(review): the errorType and recoverable fields of each case are declared
// but not asserted yet. Confirm they match what the parser actually records
// before turning them into checks.
func TestEnhancedErrorHandling(t *testing.T) {
	tests := []struct {
		name           string
		input          string
		expectedErrors int
		errorType      ErrorType
		contains       string
		recoverable    bool
	}{
		{
			name:           "Missing FROM keyword",
			input:          "SELECT * table1",
			expectedErrors: 1,
			errorType:      ErrorTypeUnexpectedToken,
			contains:       "Expected source identifier after FROM",
			recoverable:    true,
		},
		{
			name:           "Typo in SELECT",
			input:          "SELCT * FROM table1",
			expectedErrors: 1,
			errorType:      ErrorTypeUnknownKeyword,
			contains:       "Unknown keyword 'SELCT'",
			recoverable:    true,
		},
		{
			name:           "Invalid character",
			input:          "SELECT * FROM table1 WHERE id # 5",
			expectedErrors: 1,
			errorType:      ErrorTypeLexical,
			contains:       "Unexpected character",
			recoverable:    false,
		},
		{
			name:           "Unterminated string",
			input:          "SELECT * FROM table1 WHERE name = 'test",
			expectedErrors: 1,
			errorType:      ErrorTypeUnterminatedString,
			contains:       "Unterminated string literal",
			recoverable:    true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			parser := NewParser(test.input)
			_, err := parser.Parse()

			// The input is malformed, so the failure must surface either as the
			// returned error or in the parser's recorded error list.
			if err == nil && !parser.HasErrors() {
				t.Errorf("Expected error but got none")
				return
			}

			parseErrs := parser.GetErrors()
			messages := make([]string, 0, len(parseErrs))
			for _, parseErr := range parseErrs {
				messages = append(messages, parseErr.Message)
			}

			// A non-positive expectedErrors disables the count check.
			if test.expectedErrors > 0 && len(parseErrs) != test.expectedErrors {
				t.Errorf("Expected %d errors, got %d: %q", test.expectedErrors, len(parseErrs), messages)
			}

			if test.contains == "" {
				return
			}
			for _, msg := range messages {
				if strings.Contains(msg, test.contains) {
					return
				}
			}
			t.Errorf("Expected error containing %q, got errors %q (Parse returned: %v)", test.contains, messages, err)
		})
	}
}
// TestErrorTypes builds a minimal ParseError for every ErrorType constant and
// checks that Error() produces non-empty text that includes the message.
// Previously only non-emptiness was checked, so any constant string passed.
func TestErrorTypes(t *testing.T) {
	errorTypes := []ErrorType{
		ErrorTypeSyntax,
		ErrorTypeLexical,
		ErrorTypeSemantics,
		ErrorTypeUnexpectedToken,
		ErrorTypeMissingToken,
		ErrorTypeInvalidExpression,
		ErrorTypeUnknownKeyword,
		ErrorTypeInvalidNumber,
		ErrorTypeUnterminatedString,
		ErrorTypeMaxIterations,
		ErrorTypeUnknownFunction,
	}
	for _, errorType := range errorTypes {
		t.Run(fmt.Sprintf("ErrorType_%d", int(errorType)), func(t *testing.T) {
			err := &ParseError{
				Type:    errorType,
				Message: "Test error",
			}
			errorStr := err.Error()
			if errorStr == "" {
				t.Error("Error string should not be empty")
				return
			}
			// Whatever the type, the human-readable message must survive
			// formatting. TestParseError relies on the same property.
			if !strings.Contains(errorStr, "Test error") {
				t.Errorf("Error string should contain message, got: %s", errorStr)
			}
		})
	}
}
// TestErrorRecovery exercises the parser's error-recovery path. Inputs marked
// expectError must surface a failure, either as the returned error or as
// recorded parse errors. All other inputs must produce neither.
//
// NOTE(review): errorCount records how many errors each malformed input is
// expected to yield, but it is not asserted. Confirm the parser's actual counts
// before enforcing it. Also confirm that "SELECT * FROM table WHERE (" is
// really meant to parse without errors.
func TestErrorRecovery(t *testing.T) {
	tests := []struct {
		name        string
		input       string
		expectError bool
		errorCount  int
	}{
		{
			name:        "Multiple syntax errors",
			input:       "SELCT * FORM table WHRE id = 1",
			expectError: true,
			errorCount:  3, // SELCT, FORM, WHRE
		},
		{
			name:        "Missing tokens",
			input:       "SELECT FROM WHERE",
			expectError: true,
			errorCount:  1,
		},
		{
			name:        "Incomplete WHERE clause",
			input:       "SELECT * FROM table WHERE (",
			expectError: false,
			errorCount:  0,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			parser := NewParser(test.input)
			_, err := parser.Parse()

			if test.expectError {
				if err == nil && !parser.HasErrors() {
					t.Errorf("Expected error but got none")
				}
				return
			}

			if err != nil || parser.HasErrors() {
				// Parse can return nil while errors are still recorded, so
				// printing err alone would show a misleading "<nil>". Report
				// both sources.
				parseErrs := parser.GetErrors()
				messages := make([]string, 0, len(parseErrs))
				for _, parseErr := range parseErrs {
					messages = append(messages, parseErr.Message)
				}
				t.Errorf("Unexpected error: %v (recorded parse errors: %q)", err, messages)
			}
		})
	}
}
// TestNewFunctionValidator 测试 FunctionValidator 创建
func TestNewFunctionValidator(t *testing.T) {
lexer := NewLexer("SELECT * FROM table")
parser := &Parser{lexer: lexer}
er := NewErrorRecovery(parser)
fv := NewFunctionValidator(er)
if fv == nil {
t.Error("NewFunctionValidator should not return nil")
return
}
if fv.errorRecovery != er {
t.Error("FunctionValidator should store the provided ErrorRecovery")
}
}
// TestFunctionValidatorValidateExpression runs FunctionValidator.ValidateExpression
// on expressions with known, unknown and no function calls. It checks three things:
//   - the number of errors recorded on the ErrorRecovery;
//   - that every recorded error has the expected type, not just the first;
//   - when errorMessage is set, that the first error's message contains it.
func TestFunctionValidatorValidateExpression(t *testing.T) {
	tests := []struct {
		name           string
		expression     string
		expectedErrors int
		errorType      ErrorType
		errorMessage   string
	}{
		{
			name:           "Valid builtin function",
			expression:     "abs(temperature)",
			expectedErrors: 0,
		},
		{
			name:           "Valid nested builtin functions",
			expression:     "sqrt(abs(temperature))",
			expectedErrors: 0,
		},
		{
			name:           "Unknown function",
			expression:     "unknown_func(temperature)",
			expectedErrors: 1,
			errorType:      ErrorTypeUnknownFunction,
			errorMessage:   "unknown_func",
		},
		{
			name:           "Multiple unknown functions",
			expression:     "unknown1(temperature) + unknown2(humidity)",
			expectedErrors: 2,
			errorType:      ErrorTypeUnknownFunction,
		},
		{
			name:           "Mixed valid and invalid functions",
			expression:     "abs(temperature) + unknown_func(humidity)",
			expectedErrors: 1,
			errorType:      ErrorTypeUnknownFunction,
			errorMessage:   "unknown_func",
		},
		{
			name:           "No functions in expression",
			expression:     "temperature + humidity",
			expectedErrors: 0,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			lexer := NewLexer("SELECT * FROM table")
			parser := &Parser{lexer: lexer}
			er := NewErrorRecovery(parser)
			fv := NewFunctionValidator(er)
			// The second argument (0) is passed through unchanged from the
			// original test. Presumably it is a position offset; TODO confirm.
			fv.ValidateExpression(test.expression, 0)

			// "errs" avoids shadowing the standard library's errors package.
			errs := er.GetErrors()
			messages := make([]string, 0, len(errs))
			for _, e := range errs {
				messages = append(messages, e.Message)
			}

			if len(errs) != test.expectedErrors {
				t.Errorf("Expected %d errors, got %d: %q", test.expectedErrors, len(errs), messages)
				return
			}

			// Check every error, not only the first. Otherwise the second error
			// in the multi-error case would go unverified.
			for i, e := range errs {
				if e.Type != test.errorType {
					t.Errorf("error %d: expected error type %v, got %v", i, test.errorType, e.Type)
				}
			}

			if len(errs) > 0 && test.errorMessage != "" && !strings.Contains(errs[0].Message, test.errorMessage) {
				t.Errorf("Expected error message to contain '%s', got '%s'", test.errorMessage, errs[0].Message)
			}
		})
	}
}
// TestFunctionValidatorBuiltins checks FunctionValidator.isBuiltinFunction
// against three groups of names:
//   - math builtins;
//   - aggregate functions, which are expected to be recognised;
//   - names that must be rejected, including the empty string.
//
// Failure messages quote the name with %q so that an empty or oddly spaced
// name is still visible in the report.
func TestFunctionValidatorBuiltins(t *testing.T) {
	lexer := NewLexer("SELECT * FROM table")
	parser := &Parser{lexer: lexer}
	er := NewErrorRecovery(parser)
	validator := NewFunctionValidator(er)

	// Math builtins the validator is expected to recognise (upper-case names).
	builtinFunctions := []string{"ABS", "ROUND", "SQRT", "SIN", "COS", "FLOOR", "CEIL"}
	for _, funcName := range builtinFunctions {
		t.Run("Builtin_"+funcName, func(t *testing.T) {
			if !validator.isBuiltinFunction(funcName) {
				t.Errorf("Expected %q to be a valid builtin function", funcName)
			}
		})
	}

	// Aggregate functions. The original test notes that these are registered
	// in the function registry, so they must be accepted as well.
	aggregateFunctions := []string{"COUNT", "SUM", "AVG", "MAX", "MIN"}
	for _, funcName := range aggregateFunctions {
		t.Run("Aggregate_"+funcName, func(t *testing.T) {
			if !validator.isBuiltinFunction(funcName) {
				t.Errorf("Expected %q to be a valid function (it's registered in the function registry)", funcName)
			}
		})
	}

	// Names that must be rejected. The empty string yields the subtest name
	// "Invalid_".
	invalidFunctions := []string{"INVALID_FUNC", "UNKNOWN", ""}
	for _, funcName := range invalidFunctions {
		t.Run("Invalid_"+funcName, func(t *testing.T) {
			if validator.isBuiltinFunction(funcName) {
				t.Errorf("Expected %q to be an invalid function", funcName)
			}
		})
	}
}