forked from GiteaTest2015/streamsql
339 lines
9.0 KiB
Go
339 lines
9.0 KiB
Go
package rsql
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// TestParseError 测试 ParseError 结构体
|
|
func TestParseError(t *testing.T) {
|
|
err := &ParseError{
|
|
Type: ErrorTypeSyntax,
|
|
Message: "Invalid syntax",
|
|
Position: 10,
|
|
Line: 2,
|
|
Column: 5,
|
|
Token: "SELECT",
|
|
Expected: []string{"FROM", "WHERE"},
|
|
Suggestions: []string{"Add FROM clause", "Check syntax"},
|
|
Context: "SELECT statement",
|
|
Recoverable: true,
|
|
}
|
|
|
|
// 测试 Error() 方法
|
|
errorStr := err.Error()
|
|
if !strings.Contains(errorStr, "SYNTAX_ERROR") {
|
|
t.Errorf("Error string should contain 'SYNTAX_ERROR', got: %s", errorStr)
|
|
}
|
|
if !strings.Contains(errorStr, "Invalid syntax") {
|
|
t.Errorf("Error string should contain message, got: %s", errorStr)
|
|
}
|
|
if !strings.Contains(errorStr, "line 2, column 5") {
|
|
t.Errorf("Error string should contain position info, got: %s", errorStr)
|
|
}
|
|
if !strings.Contains(errorStr, "found 'SELECT'") {
|
|
t.Errorf("Error string should contain token info, got: %s", errorStr)
|
|
}
|
|
if !strings.Contains(errorStr, "expected: FROM, WHERE") {
|
|
t.Errorf("Error string should contain expected tokens, got: %s", errorStr)
|
|
}
|
|
if !strings.Contains(errorStr, "Context: SELECT statement") {
|
|
t.Errorf("Error string should contain context, got: %s", errorStr)
|
|
}
|
|
if !strings.Contains(errorStr, "Suggestions: Add FROM clause; Check syntax") {
|
|
t.Errorf("Error string should contain suggestions, got: %s", errorStr)
|
|
}
|
|
|
|
// 测试 IsRecoverable() 方法
|
|
if !err.IsRecoverable() {
|
|
t.Error("Error should be recoverable")
|
|
}
|
|
}
|
|
|
|
// TestEnhancedErrorHandling 测试增强的错误处理
|
|
func TestEnhancedErrorHandling(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectedErrors int
|
|
errorType ErrorType
|
|
contains string
|
|
recoverable bool
|
|
}{
|
|
{
|
|
name: "Missing FROM keyword",
|
|
input: "SELECT * table1",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnexpectedToken,
|
|
contains: "Expected source identifier after FROM",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Typo in SELECT",
|
|
input: "SELCT * FROM table1",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownKeyword,
|
|
contains: "Unknown keyword 'SELCT'",
|
|
recoverable: true,
|
|
},
|
|
{
|
|
name: "Invalid character",
|
|
input: "SELECT * FROM table1 WHERE id # 5",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeLexical,
|
|
contains: "Unexpected character",
|
|
recoverable: false,
|
|
},
|
|
{
|
|
name: "Unterminated string",
|
|
input: "SELECT * FROM table1 WHERE name = 'test",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnterminatedString,
|
|
contains: "Unterminated string literal",
|
|
recoverable: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
parser := NewParser(test.input)
|
|
_, err := parser.Parse()
|
|
|
|
// 应该有错误
|
|
if err == nil && !parser.HasErrors() {
|
|
t.Errorf("Expected error but got none")
|
|
return
|
|
}
|
|
|
|
// 检查错误数量
|
|
if test.expectedErrors > 0 {
|
|
errors := parser.GetErrors()
|
|
if len(errors) != test.expectedErrors {
|
|
t.Errorf("Expected %d errors, got %d", test.expectedErrors, len(errors))
|
|
}
|
|
}
|
|
|
|
// 检查错误内容
|
|
if test.contains != "" {
|
|
errorFound := false
|
|
for _, parseErr := range parser.GetErrors() {
|
|
if strings.Contains(parseErr.Message, test.contains) {
|
|
errorFound = true
|
|
break
|
|
}
|
|
}
|
|
if !errorFound {
|
|
t.Errorf("Expected error containing '%s'", test.contains)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestErrorTypes 测试错误类型
|
|
func TestErrorTypes(t *testing.T) {
|
|
errorTypes := []ErrorType{
|
|
ErrorTypeSyntax,
|
|
ErrorTypeLexical,
|
|
ErrorTypeSemantics,
|
|
ErrorTypeUnexpectedToken,
|
|
ErrorTypeMissingToken,
|
|
ErrorTypeInvalidExpression,
|
|
ErrorTypeUnknownKeyword,
|
|
ErrorTypeInvalidNumber,
|
|
ErrorTypeUnterminatedString,
|
|
ErrorTypeMaxIterations,
|
|
ErrorTypeUnknownFunction,
|
|
}
|
|
|
|
for _, errorType := range errorTypes {
|
|
t.Run(fmt.Sprintf("ErrorType_%d", int(errorType)), func(t *testing.T) {
|
|
err := &ParseError{
|
|
Type: errorType,
|
|
Message: "Test error",
|
|
}
|
|
errorStr := err.Error()
|
|
if errorStr == "" {
|
|
t.Error("Error string should not be empty")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestErrorRecovery 测试错误恢复机制
|
|
func TestErrorRecovery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectError bool
|
|
errorCount int
|
|
}{
|
|
{
|
|
name: "Multiple syntax errors",
|
|
input: "SELCT * FORM table WHRE id = 1",
|
|
expectError: true,
|
|
errorCount: 3, // SELCT, FORM, WHRE
|
|
},
|
|
{
|
|
name: "Missing tokens",
|
|
input: "SELECT FROM WHERE",
|
|
expectError: true,
|
|
errorCount: 1,
|
|
},
|
|
{
|
|
name: "Incomplete WHERE clause",
|
|
input: "SELECT * FROM table WHERE (",
|
|
expectError: false,
|
|
errorCount: 0,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
parser := NewParser(test.input)
|
|
_, err := parser.Parse()
|
|
|
|
if test.expectError {
|
|
if err == nil && !parser.HasErrors() {
|
|
t.Errorf("Expected error but got none")
|
|
}
|
|
} else {
|
|
if err != nil || parser.HasErrors() {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestNewFunctionValidator 测试 FunctionValidator 创建
|
|
func TestNewFunctionValidator(t *testing.T) {
|
|
lexer := NewLexer("SELECT * FROM table")
|
|
parser := &Parser{lexer: lexer}
|
|
er := NewErrorRecovery(parser)
|
|
fv := NewFunctionValidator(er)
|
|
|
|
if fv == nil {
|
|
t.Error("NewFunctionValidator should not return nil")
|
|
return
|
|
}
|
|
|
|
if fv.errorRecovery != er {
|
|
t.Error("FunctionValidator should store the provided ErrorRecovery")
|
|
}
|
|
}
|
|
|
|
// TestFunctionValidatorValidateExpression 测试函数验证器的表达式验证
|
|
func TestFunctionValidatorValidateExpression(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
expression string
|
|
expectedErrors int
|
|
errorType ErrorType
|
|
errorMessage string
|
|
}{
|
|
{
|
|
name: "Valid builtin function",
|
|
expression: "abs(temperature)",
|
|
expectedErrors: 0,
|
|
},
|
|
{
|
|
name: "Valid nested builtin functions",
|
|
expression: "sqrt(abs(temperature))",
|
|
expectedErrors: 0,
|
|
},
|
|
{
|
|
name: "Unknown function",
|
|
expression: "unknown_func(temperature)",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownFunction,
|
|
errorMessage: "unknown_func",
|
|
},
|
|
{
|
|
name: "Multiple unknown functions",
|
|
expression: "unknown1(temperature) + unknown2(humidity)",
|
|
expectedErrors: 2,
|
|
errorType: ErrorTypeUnknownFunction,
|
|
},
|
|
{
|
|
name: "Mixed valid and invalid functions",
|
|
expression: "abs(temperature) + unknown_func(humidity)",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownFunction,
|
|
errorMessage: "unknown_func",
|
|
},
|
|
{
|
|
name: "No functions in expression",
|
|
expression: "temperature + humidity",
|
|
expectedErrors: 0,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
lexer := NewLexer("SELECT * FROM table")
|
|
parser := &Parser{lexer: lexer}
|
|
er := NewErrorRecovery(parser)
|
|
fv := NewFunctionValidator(er)
|
|
|
|
fv.ValidateExpression(test.expression, 0)
|
|
|
|
errors := er.GetErrors()
|
|
if len(errors) != test.expectedErrors {
|
|
t.Errorf("Expected %d errors, got %d", test.expectedErrors, len(errors))
|
|
return
|
|
}
|
|
|
|
if test.expectedErrors > 0 {
|
|
if errors[0].Type != test.errorType {
|
|
t.Errorf("Expected error type %v, got %v", test.errorType, errors[0].Type)
|
|
}
|
|
|
|
if test.errorMessage != "" && !strings.Contains(errors[0].Message, test.errorMessage) {
|
|
t.Errorf("Expected error message to contain '%s', got '%s'", test.errorMessage, errors[0].Message)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestFunctionValidatorBuiltins 测试函数验证器内置函数
|
|
func TestFunctionValidatorBuiltins(t *testing.T) {
|
|
lexer := NewLexer("SELECT * FROM table")
|
|
parser := &Parser{lexer: lexer}
|
|
er := NewErrorRecovery(parser)
|
|
validator := NewFunctionValidator(er)
|
|
|
|
// 测试内置函数验证(基于实际实现的数学函数)
|
|
builtinFunctions := []string{"ABS", "ROUND", "SQRT", "SIN", "COS", "FLOOR", "CEIL"}
|
|
for _, funcName := range builtinFunctions {
|
|
t.Run("Builtin_"+funcName, func(t *testing.T) {
|
|
if !validator.isBuiltinFunction(funcName) {
|
|
t.Errorf("Expected %s to be a valid builtin function", funcName)
|
|
}
|
|
})
|
|
}
|
|
|
|
// 测试聚合函数(这些在函数注册系统中是有效的)
|
|
aggregateFunctions := []string{"COUNT", "SUM", "AVG", "MAX", "MIN"}
|
|
for _, funcName := range aggregateFunctions {
|
|
t.Run("Aggregate_"+funcName, func(t *testing.T) {
|
|
// 聚合函数应该在函数注册系统中存在
|
|
if !validator.isBuiltinFunction(funcName) {
|
|
t.Errorf("Expected %s to be a valid function (it's registered in the function registry)", funcName)
|
|
}
|
|
})
|
|
}
|
|
|
|
// 测试无效函数
|
|
invalidFunctions := []string{"INVALID_FUNC", "UNKNOWN", ""}
|
|
for _, funcName := range invalidFunctions {
|
|
t.Run("Invalid_"+funcName, func(t *testing.T) {
|
|
if validator.isBuiltinFunction(funcName) {
|
|
t.Errorf("Expected %s to be an invalid function", funcName)
|
|
}
|
|
})
|
|
}
|
|
}
|