Files
streamsql/rsql/parser_test.go
T
2025-08-05 11:25:49 +08:00

496 lines
12 KiB
Go

package rsql
import (
"strings"
"testing"
)
// TestNewParser 测试解析器的创建
func TestNewParser(t *testing.T) {
input := "SELECT * FROM table"
parser := NewParser(input)
if parser == nil {
t.Fatal("Expected parser to be created, got nil")
}
if parser.input != input {
t.Errorf("Expected input %s, got %s", input, parser.input)
}
if parser.lexer == nil {
t.Error("Expected lexer to be initialized")
}
if parser.errorRecovery == nil {
t.Error("Expected error recovery to be initialized")
}
}
// TestParserGetErrors 测试错误获取功能
func TestParserGetErrors(t *testing.T) {
// 使用一个明显无效的SQL,确保会产生错误
parser := NewParser("SELECT * FROM table WHERE INVALID_FUNCTION()")
_, err := parser.Parse() // 这会产生错误
if err == nil {
t.Error("Expected parser to have errors")
}
if !parser.HasErrors() {
t.Error("Expected parser to have errors")
}
errors := parser.GetErrors()
if len(errors) == 0 {
t.Error("Expected at least one error")
}
}
// TestParserBasicSelect 测试基本SELECT语句解析
func TestParserBasicSelect(t *testing.T) {
tests := []struct {
input string
expectError bool
description string
}{
{"SELECT * FROM table", false, "基本SELECT语句"},
{"SELECT id, name FROM users", false, "指定字段的SELECT语句"},
{"SELECT DISTINCT category FROM products", false, "带DISTINCT的SELECT语句"},
{"SELECT COUNT(*) FROM orders", false, "带聚合函数的SELECT语句"},
{"SELECT * FROM events LIMIT 100", false, "带LIMIT的SELECT语句"},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
parser := NewParser(test.input)
_, err := parser.Parse()
if test.expectError {
if err == nil && !parser.HasErrors() {
t.Error("Expected error but got none")
}
} else {
if err != nil || parser.HasErrors() {
t.Errorf("Unexpected error: %v", err)
if parser.HasErrors() {
for _, parseErr := range parser.GetErrors() {
t.Errorf("Parse error: %s", parseErr.Error())
}
}
}
}
})
}
}
// TestParserErrorRecovery 测试错误恢复功能
func TestParserErrorRecovery(t *testing.T) {
tests := []struct {
input string
description string
}{
{"SELCT * FROM table", "typo in SELECT"},
{"SELECT * FORM table", "typo in FROM"},
{"SELECT * FROM", "missing table name"},
{"SELECT * FROM table LIMIT abc", "invalid limit value"},
{"SELECT * FROM table LIMIT -5", "negative limit value"},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
parser := NewParser(test.input)
_, err := parser.Parse()
// 对于 "SELECT FROM table" 这种情况,可能不会产生错误,因为解析器可能会将其解释为有效的语法
if test.input == "SELECT FROM table" {
// 这种情况下,我们不强制要求有错误
return
}
// 应该有错误
if err == nil && !parser.HasErrors() {
t.Errorf("Expected error but got none for input: %s", test.input)
return
}
// 检查是否记录了错误
if !parser.HasErrors() {
t.Errorf("Expected errors to be recorded for input: %s", test.input)
}
})
}
}
// TestParseBasicSQL 测试基本SQL解析功能
func TestParseBasicSQL(t *testing.T) {
tests := []struct {
name string
sql string
expectError bool
}{
{
name: "BasicSelect",
sql: "SELECT deviceId FROM Input",
expectError: false,
},
{
name: "SelectWithWhere",
sql: "SELECT deviceId FROM Input WHERE deviceId='aa'",
expectError: false,
},
{
name: "SelectWithGroupBy",
sql: "SELECT COUNT(*) FROM Input GROUP BY deviceId",
expectError: false,
},
{
name: "InvalidSQL",
sql: "INVALID SQL",
expectError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
config, condition, err := Parse(test.sql)
if test.expectError {
if err == nil {
t.Errorf("Expected error for %s but got none", test.sql)
}
} else {
if err != nil {
t.Errorf("Unexpected error for %s: %v", test.sql, err)
} else {
// 基本验证
if config == nil {
t.Errorf("Expected config but got nil for %s", test.sql)
}
// condition可以为空
_ = condition
}
}
})
}
}
// TestRSQLIntegration 测试RSQL包的集成功能
func TestRSQLIntegration(t *testing.T) {
tests := []struct {
name string
sql string
expectError bool
description string
}{
{
name: "BasicSelect",
sql: "SELECT * FROM events",
expectError: false,
description: "基本SELECT语句",
},
{
name: "SelectWithWhere",
sql: "SELECT id, name FROM users WHERE age > 18",
expectError: false,
description: "带WHERE条件的SELECT语句",
},
{
name: "SelectWithGroupBy",
sql: "SELECT COUNT(*) FROM orders GROUP BY status",
expectError: false,
description: "带GROUP BY的SELECT语句",
},
{
name: "SelectWithHaving",
sql: "SELECT COUNT(*) FROM products GROUP BY category HAVING COUNT(*) > 5",
expectError: false,
description: "带HAVING子句的SELECT语句",
},
{
name: "SelectWithLimit",
sql: "SELECT * FROM logs LIMIT 100",
expectError: false,
description: "带LIMIT的SELECT语句",
},
{
name: "SelectWithTumblingWindow",
sql: "SELECT COUNT(*) FROM events TUMBLINGWINDOW(5, 'mi') WITH (TIMESTAMP='ts', TIMEUNIT='mi')",
expectError: false,
description: "带滚动窗口的SELECT语句",
},
{
name: "InvalidSQL",
sql: "INVALID SQL STATEMENT",
expectError: true,
description: "无效的SQL语句",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parser := NewParser(test.sql)
_, err := parser.Parse()
if test.expectError {
if err == nil && !parser.HasErrors() {
t.Errorf("Expected error for %s but got none", test.description)
}
} else {
if err != nil || parser.HasErrors() {
t.Errorf("Unexpected error for %s: %v", test.description, err)
if parser.HasErrors() {
for _, parseErr := range parser.GetErrors() {
t.Errorf("Parse error: %s", parseErr.Error())
}
}
}
}
})
}
}
// TestEdgeCases 测试边界情况
func TestEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
description string
}{
{
name: "EmptyInput",
input: "",
expectError: true,
description: "空输入",
},
{
name: "WhitespaceOnly",
input: " \t\n ",
expectError: true,
description: "仅包含空白字符",
},
{
name: "SingleKeyword",
input: "SELECT",
expectError: true,
description: "单个关键字",
},
{
name: "VeryLongFieldList",
input: "SELECT " + strings.Repeat("field, ", 10) + "field FROM table",
expectError: false, // 改回false,因为这应该是有效的SQL
description: "长字段列表",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parser := NewParser(test.input)
_, err := parser.Parse()
if test.expectError {
if err == nil && !parser.HasErrors() {
t.Errorf("Expected error for %s but got none", test.description)
}
} else {
if err != nil || parser.HasErrors() {
t.Errorf("Unexpected error for %s: %v", test.description, err)
}
}
})
}
}
// TestParserAdvancedFeatures 测试解析器的高级功能
func TestParserAdvancedFeatures(t *testing.T) {
tests := []struct {
name string
sql string
expectError bool
}{
{
name: "WindowFunction",
sql: "SELECT COUNT(*) FROM events TUMBLINGWINDOW(5, 'mi')",
expectError: false,
},
{
name: "WithClause",
sql: "SELECT * FROM events WITH (TIMESTAMP='ts', TIMEUNIT='mi')",
expectError: false,
},
{
name: "ComplexExpression",
sql: "SELECT (temperature + humidity) * 2 as combined FROM sensors",
expectError: false,
},
{
name: "NestedParentheses",
sql: "SELECT * FROM events WHERE ((status = 'active') AND (priority > 5))",
expectError: false,
},
{
name: "FunctionCalls",
sql: "SELECT ABS(temperature), SQRT(humidity) FROM sensors",
expectError: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parser := NewParser(test.sql)
_, err := parser.Parse()
if test.expectError {
if err == nil && !parser.HasErrors() {
t.Errorf("Expected error for %s but got none", test.sql)
}
} else {
if err != nil || parser.HasErrors() {
t.Errorf("Unexpected error for %s: %v", test.sql, err)
if parser.HasErrors() {
for _, parseErr := range parser.GetErrors() {
t.Errorf("Parse error: %s", parseErr.Error())
}
}
}
}
})
}
}
// TestComplexQueries 测试复杂查询
func TestComplexQueries(t *testing.T) {
tests := []struct {
name string
query string
}{
{
name: "ComplexAggregation",
query: "SELECT COUNT(*), AVG(temperature), MAX(humidity), MIN(pressure) FROM sensors GROUP BY location, device_type HAVING COUNT(*) > 10",
},
{
name: "NestedFunctions",
query: "SELECT ROUND(AVG(ABS(temperature - 20)), 2) as avg_temp_diff FROM climate_data",
},
{
name: "MultipleConditions",
query: "SELECT * FROM events WHERE (status = 'active' OR status = 'pending') AND priority > 5 AND created_at > '2023-01-01'",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parser := NewParser(test.query)
_, err := parser.Parse()
if err != nil {
t.Errorf("Failed to parse complex query: %v", err)
}
if parser.HasErrors() {
for _, parseErr := range parser.GetErrors() {
t.Errorf("Parse error: %s", parseErr.Error())
}
}
})
}
}
// TestParserPerformance 测试解析器性能
func TestParserPerformance(t *testing.T) {
// 测试大量解析操作的性能
for i := 0; i < 1000; i++ {
sql := "SELECT field1, field2, field3 FROM table WHERE condition = 'value'"
parser := NewParser(sql)
_, err := parser.Parse()
if err != nil {
t.Errorf("Iteration %d failed: %v", i, err)
break
}
}
}
// TestParserConcurrency 测试解析器并发安全性
func TestParserConcurrency(t *testing.T) {
const numGoroutines = 10
const numIterations = 10
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer func() { done <- true }()
for j := 0; j < numIterations; j++ {
sql := "SELECT * FROM table" + string(rune('0'+id))
parser := NewParser(sql)
_, err := parser.Parse()
if err != nil {
t.Errorf("Goroutine %d iteration %d failed: %v", id, j, err)
}
}
}(i)
}
// 等待所有goroutines完成
for i := 0; i < numGoroutines; i++ {
<-done
}
}
// TestParserMemoryUsage 测试内存使用情况
func TestParserMemoryUsage(t *testing.T) {
// 测试大量解析操作不会导致内存泄漏
for i := 0; i < 1000; i++ {
sql := "SELECT field1, field2, field3 FROM table WHERE condition = 'value'"
parser := NewParser(sql)
_, err := parser.Parse()
if err != nil {
t.Errorf("Iteration %d failed: %v", i, err)
break
}
}
}
// TestParserWithDifferentInputSizes 测试不同输入大小的解析
func TestParserWithDifferentInputSizes(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
}{
{
name: "VeryShort",
input: "SELECT 1",
expectError: true, // 缺少FROM子句
},
{
name: "Short",
input: "SELECT * FROM t",
expectError: false,
},
{
name: "Medium",
input: "SELECT id, name, email FROM users WHERE active = true AND created_at > '2023-01-01'",
expectError: false,
},
{
name: "Long",
input: "SELECT u.id, u.name, u.email, p.title, p.content, c.name as category FROM users u JOIN posts p ON u.id = p.user_id JOIN categories c ON p.category_id = c.id WHERE u.active = true AND p.published = true AND c.visible = true ORDER BY p.created_at DESC LIMIT 100",
expectError: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parser := NewParser(test.input)
_, err := parser.Parse()
if test.expectError {
if err == nil && !parser.HasErrors() {
t.Errorf("Expected error for %s but got none", test.name)
}
} else {
if err != nil || parser.HasErrors() {
t.Errorf("Unexpected error for %s: %v", test.name, err)
}
}
})
}
}