Files
streamsql/rsql/coverage_test.go
T
2025-08-06 17:18:50 +08:00

835 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package rsql
import (
"strings"
"testing"
"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/types"
)
// TestParseSmartParameters 测试智能参数解析函数
func TestParseSmartParameters(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "简单参数",
input: "param1,param2,param3",
expected: []string{"param1", "param2", "param3"},
},
{
name: "带引号的参数",
input: "'hello world',param2,\"quoted string\"",
expected: []string{"'hello world'", "param2", "\"quoted string\""},
},
{
name: "空字符串",
input: "",
expected: []string{},
},
{
name: "单个参数",
input: "single_param",
expected: []string{"single_param"},
},
{
name: "带逗号的引号内容",
input: "'hello,world',param2",
expected: []string{"'hello,world'", "param2"},
},
{
name: "混合引号类型",
input: "'single',\"double\",mixed",
expected: []string{"'single'", "\"double\"", "mixed"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseSmartParameters(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("parseSmartParameters() length = %v, want %v", len(result), len(tt.expected))
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("parseSmartParameters()[%d] = %v, want %v", i, result[i], expected)
}
}
})
}
}
// TestIsIdentifier 测试标识符验证函数
func TestIsIdentifier(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"有效标识符", "valid_identifier", true},
{"带数字", "valid123", true},
{"下划线开头", "_valid", true},
{"大写字母", "VALID", true},
{"混合大小写", "ValidIdentifier", true},
{"空字符串", "", false},
{"数字开头", "123invalid", false},
{"特殊字符", "invalid@", false},
{"空格", "invalid name", false},
{"单个字母", "a", true},
{"单个下划线", "_", true},
{"包含连字符", "invalid-name", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isIdentifier(tt.input)
if result != tt.expected {
t.Errorf("isIdentifier(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
// TestExtractSimpleField 测试简单字段提取函数
func TestExtractSimpleField(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"简单字段", "field_name", "field_name"},
{"加法运算", "field1 + field2", "field1"},
{"减法运算", "field1 - field2", "field1"},
{"乘法运算", "field1 * field2", "field1"},
{"除法运算", "field1 / field2", "field1"},
{"带空格", " field1 + field2 ", "field1"},
{"无运算符", "simple_field", "simple_field"},
{"复杂表达式", "field1 + field2 * field3", "field1 + field2"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractSimpleField(tt.input)
if result != tt.expected {
t.Errorf("extractSimpleField(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
// TestParseWindowParams 测试窗口参数解析函数
func TestParseWindowParams(t *testing.T) {
tests := []struct {
name string
params []interface{}
windowType string
expectError bool
}{
{
name: "会话窗口参数",
params: []interface{}{"10s", "5s"},
windowType: "SESSIONWINDOW",
expectError: false,
},
{
name: "滚动窗口参数",
params: []interface{}{"30s"},
windowType: "TUMBLINGWINDOW",
expectError: false,
},
{
name: "滑动窗口参数",
params: []interface{}{"60s", "30s"},
windowType: "SLIDINGWINDOW",
expectError: false,
},
{
name: "无效持续时间",
params: []interface{}{"invalid"},
windowType: "TUMBLINGWINDOW",
expectError: true,
},
{
name: "非字符串参数",
params: []interface{}{123},
windowType: "TUMBLINGWINDOW",
expectError: true,
},
{
name: "空参数",
params: []interface{}{},
windowType: "TUMBLINGWINDOW",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result map[string]interface{}
var err error
if tt.windowType == "SESSIONWINDOW" {
result, err = parseWindowParamsWithType(tt.params, "SESSIONWINDOW")
} else {
result, err = parseWindowParams(tt.params)
}
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got none")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result == nil {
t.Error("Expected result but got nil")
}
}
})
}
}
// TestParseAggregateExpression 测试聚合表达式解析函数
func TestParseAggregateExpression(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"AVG函数", "avg(temperature)", "avg"},
{"SUM函数", "sum(price)", "sum"},
{"MAX函数", "max(value)", "max"},
{"MIN函数", "min(count)", "min"},
{"嵌套函数", "avg(sum(temperature))", "avg"},
{"无聚合函数", "temperature", ""},
{"空字符串", "", ""},
{"部分匹配", "average(temperature)", ""},
{"带空格", " avg ( temperature ) ", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseAggregateExpression(tt.input)
if result != tt.expected {
t.Errorf("parseAggregateExpression(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
// TestExpectToken 测试期望token函数
func TestExpectToken(t *testing.T) {
// 测试正常情况
parser := NewParser("SELECT field FROM table")
tok, err := parser.expectToken(TokenSELECT, "SELECT clause")
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if tok.Type != TokenSELECT {
t.Errorf("Expected SELECT token, got %v", tok.Type)
}
// 测试错误情况 - 使用一个不会触发无限递归的情况
parser2 := NewParser("FROM table")
// 直接测试getTokenTypeName函数而不是expectToken
result := parser2.getTokenTypeName(TokenSELECT)
if result != "SELECT" {
t.Errorf("Expected 'SELECT', got %v", result)
}
}
// TestGetTokenTypeName 测试获取token类型名称函数
func TestGetTokenTypeName(t *testing.T) {
parser := NewParser("")
tests := []struct {
tokenType TokenType
expected string
}{
{TokenSELECT, "SELECT"},
{TokenFROM, "FROM"},
{TokenWHERE, "WHERE"},
{TokenGROUP, "GROUP"},
{TokenBY, "BY"},
{TokenComma, ","},
{TokenLParen, "("},
{TokenRParen, ")"},
{TokenIdent, "identifier"},
{TokenQuotedIdent, "quoted identifier"},
{TokenNumber, "number"},
{TokenString, "string"},
{TokenType(999), "unknown"}, // 未知类型
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := parser.getTokenTypeName(tt.tokenType)
if result != tt.expected {
t.Errorf("getTokenTypeName(%v) = %v, want %v", tt.tokenType, result, tt.expected)
}
})
}
}
// TestSkipToNextDelimiter 测试跳转到下一个分隔符函数
func TestSkipToNextDelimiter(t *testing.T) {
// 测试正常跳转
parser := NewParser("field1, field2 FROM table")
er := NewErrorRecovery(parser)
er.parser = parser
// 跳过到逗号
success := er.skipToNextDelimiter()
if !success {
t.Error("Expected successful skip to delimiter")
}
// 测试到达EOF
parser2 := NewParser("field1 field2")
er2 := NewErrorRecovery(parser2)
er2.parser = parser2
success = er2.skipToNextDelimiter()
if success {
t.Error("Expected failure when reaching EOF")
}
}
// TestCreateSemanticError 测试创建语义错误函数
func TestCreateSemanticError(t *testing.T) {
err := CreateSemanticError("Invalid field reference", 10)
if err == nil {
t.Error("Expected semantic error to be created")
}
if err.Type != ErrorTypeSemantics {
t.Errorf("Expected ErrorTypeSemantics, got %v", err.Type)
}
if err.Message != "Invalid field reference" {
t.Errorf("Expected message 'Invalid field reference', got %v", err.Message)
}
if err.Position != 10 {
t.Errorf("Expected position 10, got %v", err.Position)
}
}
// TestFormatErrorContext 测试格式化错误上下文函数
func TestFormatErrorContext(t *testing.T) {
input := "SELECT field FROM table WHERE condition"
position := 10
contextLength := 5
result := FormatErrorContext(input, position, contextLength)
if result == "" {
t.Error("Expected non-empty error context")
}
// 测试边界情况
result2 := FormatErrorContext("", 0, 0)
if result2 != "" {
t.Error("Expected empty result for empty input")
}
result3 := FormatErrorContext("short", 100, 10)
if result3 != "" {
t.Error("Expected empty result for out-of-bounds position")
}
}
// TestConvertValue 测试值转换函数
func TestConvertValue(t *testing.T) {
tests := []struct {
input string
expected interface{}
}{
{"123", 123},
{"3.14", 3.14},
{"true", true},
{"false", false},
{"hello", "hello"},
{"", ""},
{"null", "null"},
{"NULL", "NULL"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := convertValue(tt.input)
if result != tt.expected {
t.Errorf("convertValue(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
// TestHandleLimitToken 测试处理LIMIT token函数
func TestHandleLimitToken(t *testing.T) {
// 测试正常情况
parser := NewParser("10")
stmt := &SelectStatement{}
// 模拟LIMIT token
limitToken := Token{Type: TokenLIMIT, Value: "LIMIT"}
err := parser.handleLimitToken(stmt, limitToken)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if stmt.Limit != 10 {
t.Errorf("Expected limit 10, got %v", stmt.Limit)
}
// 测试无效LIMIT值
parser2 := NewParser("invalid")
stmt2 := &SelectStatement{}
invalidToken := Token{Type: TokenLIMIT, Value: "LIMIT"}
err = parser2.handleLimitToken(stmt2, invalidToken)
if err == nil {
t.Error("Expected error for invalid limit value")
}
}
// TestReadString 测试读取字符串函数
func TestReadString(t *testing.T) {
// 测试正常字符串
lexer := NewLexer("'hello world'")
lexer.readChar() // 跳过开始的引号
result := lexer.readString()
if result != "hello world'" {
t.Errorf("Expected 'hello world', got %v", result)
}
// 测试未闭合的字符串
lexer2 := NewLexer("'unclosed string")
lexer2.readChar()
result2 := lexer2.readString()
if result2 != "unclosed string" {
t.Errorf("Expected partial string, got %v", result2)
}
}
// TestGenerateFunctionSuggestions 测试生成函数建议函数
func TestGenerateFunctionSuggestions(t *testing.T) {
tests := []struct {
functionName string
expectEmpty bool
}{
{"AVG", false},
{"SUM", false},
{"MAX", false},
{"MIN", false},
{"COUNT", false},
{"unknown_func", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.functionName, func(t *testing.T) {
suggestions := generateFunctionSuggestions(tt.functionName)
if tt.expectEmpty && len(suggestions) > 0 {
t.Errorf("Expected empty suggestions for %q, got %v", tt.functionName, suggestions)
}
if !tt.expectEmpty && len(suggestions) == 0 {
t.Errorf("Expected suggestions for %q, got empty", tt.functionName)
}
})
}
}
// TestBuildSelectFieldsWithExpressions 测试带表达式的选择字段构建函数
func TestBuildSelectFieldsWithExpressions(t *testing.T) {
tests := []struct {
name string
fields []Field
checkFunc func(*testing.T, map[string]aggregator.AggregateType, map[string]string, map[string]types.FieldExpression)
}{
{
name: "带别名的聚合函数",
fields: []Field{
{Expression: "AVG(temperature)", Alias: "avg_temp"},
{Expression: "SUM(price)", Alias: "total_price"},
},
checkFunc: func(t *testing.T, aggMap map[string]aggregator.AggregateType, fieldMap map[string]string, expressions map[string]types.FieldExpression) {
if len(aggMap) != 2 {
t.Errorf("Expected 2 aggregate functions, got %d", len(aggMap))
}
if len(fieldMap) != 2 {
t.Errorf("Expected 2 field mappings, got %d", len(fieldMap))
}
if len(expressions) != 0 {
t.Errorf("Expected 0 expressions, got %d", len(expressions))
}
},
},
{
name: "无别名的聚合函数",
fields: []Field{
{Expression: "AVG(temperature)", Alias: ""},
{Expression: "COUNT(*)", Alias: ""},
},
checkFunc: func(t *testing.T, aggMap map[string]aggregator.AggregateType, fieldMap map[string]string, expressions map[string]types.FieldExpression) {
if len(aggMap) == 0 {
t.Error("Expected aggregate functions")
}
},
},
{
name: "非聚合字段",
fields: []Field{
{Expression: "temperature", Alias: "temp"},
{Expression: "humidity", Alias: "hum"},
},
checkFunc: func(t *testing.T, aggMap map[string]aggregator.AggregateType, fieldMap map[string]string, expressions map[string]types.FieldExpression) {
if len(aggMap) != 0 {
t.Errorf("Expected no aggregate functions, got %d", len(aggMap))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aggMap, fieldMap, expressions := buildSelectFieldsWithExpressions(tt.fields)
tt.checkFunc(t, aggMap, fieldMap, expressions)
})
}
}
// TestParserLexerErrorHandling 测试词法分析器错误处理
func TestParserLexerErrorHandling(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
errorCount int
}{
{
name: "无效字符",
input: "SELECT @ FROM table",
expectError: true,
errorCount: 1,
},
{
name: "未终止字符串",
input: "SELECT 'unterminated FROM table",
expectError: true,
errorCount: 1,
},
{
name: "无效数字格式",
input: "SELECT 123abc FROM table",
expectError: false, // 解析为数字123和标识符abc
},
{
name: "反引号不匹配",
input: "SELECT `column FROM table",
expectError: true,
errorCount: 1,
},
{
name: "多个语法错误",
input: "SELECT @ 'unterminated `column FROM table",
expectError: true,
errorCount: 2,
},
{
name: "空输入",
input: "",
expectError: false,
errorCount: 0,
},
{
name: "只有空格",
input: " \t\n ",
expectError: false,
errorCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lexer := NewLexer(tt.input)
er := NewErrorRecovery(nil)
lexer.SetErrorRecovery(er)
// 读取所有token
for {
token := lexer.NextToken()
if token.Type == TokenEOF {
break
}
}
if er.HasErrors() != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, er.HasErrors())
}
if tt.expectError && len(er.GetErrors()) != tt.errorCount {
t.Errorf("expected %d errors, got %d", tt.errorCount, len(er.GetErrors()))
}
})
}
}
// TestParserErrorRecoveryCases 测试解析器错误恢复
// 注意:当前解析器的错误恢复机制可能与预期不同
func TestParserErrorRecoveryCases(t *testing.T) {
tests := []struct {
name string
query string
expectError bool
shouldParse bool
}{
{
name: "语法错误恢复",
query: "SELECT FROM table WHERE", // 缺少列名
expectError: true,
shouldParse: false,
},
{
name: "FROM子句缺失",
query: "SELECT column WHERE value > 1",
expectError: true,
shouldParse: false,
},
{
name: "WHERE条件不完整",
query: "SELECT * FROM table WHERE field >",
expectError: false, // 解析器可能将其作为表达式处理
shouldParse: true,
},
{
name: "括号不匹配表达式",
query: "SELECT func(field FROM table",
expectError: true, // 解析器检测到语法错误
shouldParse: false,
},
{
name: "空查询",
query: "",
expectError: true,
shouldParse: false,
},
{
name: "只有SELECT",
query: "SELECT",
expectError: true,
shouldParse: false,
},
{
name: "长标识符",
query: "SELECT " + strings.Repeat("a", 100) + " FROM table",
expectError: false,
shouldParse: true, // 解析器支持长标识符
},
{
name: "子查询语法",
query: "SELECT column FROM (SELECT * FROM table) AS sub",
expectError: true, // 当前解析器不支持子查询
shouldParse: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.query)
_, err := parser.Parse()
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
// 验证是否能继续解析
if !tt.shouldParse && err == nil {
t.Errorf("expected parsing to fail, but it succeeded")
}
})
}
}
// TestParserCaseExpressionErrors 测试CASE表达式错误处理
// 注意当前解析器将CASE表达式作为普通表达式处理不进行特殊语法验证
func TestParserCaseExpressionErrors(t *testing.T) {
tests := []struct {
name string
query string
expectError bool
}{
{
name: "CASE作为表达式",
query: "SELECT CASE WHEN condition THEN value END FROM table",
expectError: false, // 当前解析器将其作为表达式处理
},
{
name: "简单CASE表达式",
query: "SELECT CASE field WHEN 1 THEN 'one' ELSE 'other' END FROM table",
expectError: false,
},
{
name: "CASE表达式在WHERE中",
query: "SELECT * FROM table WHERE CASE WHEN field > 0 THEN 1 ELSE 0 END = 1",
expectError: false,
},
{
name: "嵌套CASE表达式",
query: "SELECT CASE WHEN CASE WHEN field > 0 THEN 1 ELSE 0 END = 1 THEN 'positive' ELSE 'negative' END FROM table",
expectError: false,
},
{
name: "CASE表达式语法错误",
query: "SELECT CASE WHEN value > 10 AND < 20 THEN 'A' END FROM table",
expectError: false, // 解析器不验证表达式内部语法
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.query)
_, err := parser.Parse()
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
})
}
}
// TestParserComplexFieldAccess 测试复杂字段访问错误处理
// 注意:当前解析器将复杂字段访问作为表达式处理
func TestParserComplexFieldAccess(t *testing.T) {
tests := []struct {
name string
query string
expectError bool
}{
{
name: "数组索引语法错误",
query: "SELECT field[ FROM table",
expectError: false, // 解析器将其作为表达式处理
},
{
name: "嵌套数组访问",
query: "SELECT field[0][1] FROM table",
expectError: false,
},
{
name: "字符串键访问",
query: "SELECT field['key'] FROM table",
expectError: false,
},
{
name: "双引号字符串键",
query: "SELECT field[\"key\"] FROM table",
expectError: false,
},
{
name: "嵌套字段访问",
query: "SELECT field.nested.deep FROM table",
expectError: false,
},
{
name: "混合访问表达式",
query: "SELECT field.nested[0].deep FROM table",
expectError: true, // lexer不支持点号在表达式中
},
{
name: "标识符数组索引",
query: "SELECT field[abc] FROM table",
expectError: false,
},
{
name: "未闭合括号",
query: "SELECT field[0 FROM table",
expectError: false, // 解析器将其作为表达式处理
},
{
name: "空数组索引",
query: "SELECT field[] FROM table",
expectError: false, // 解析器将其作为表达式处理
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.query)
_, err := parser.Parse()
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
})
}
}
// TestParserBoundaryConditions 测试解析器边界条件
func TestParserBoundaryConditions(t *testing.T) {
tests := []struct {
name string
query string
expectError bool
}{
{
name: "超长查询",
query: "SELECT " + strings.Repeat("column", 100) + " FROM table",
expectError: false,
},
{
name: "大量字段",
query: "SELECT " + strings.Repeat("col,", 50) + "last FROM table",
expectError: false, // 解析器实际上没有字段数量限制
},
{
name: "复杂WHERE条件",
query: "SELECT * FROM table WHERE " + strings.Repeat("field > 0 AND ", 100) + "true",
expectError: false,
},
{
name: "特殊字符标识符",
query: "SELECT `column with spaces`, \"quoted column\" FROM table",
expectError: false,
},
{
name: "GROUP BY大量字段",
query: "SELECT * FROM table GROUP BY " + strings.Repeat("field,", 100) + "last",
expectError: false,
},
{
name: "HAVING子句",
query: "SELECT * FROM table HAVING COUNT(*) > 0",
expectError: false,
},
{
name: "ORDER BY子句",
query: "SELECT * FROM table ORDER BY field",
expectError: false,
},
{
name: "LIMIT子句",
query: "SELECT * FROM table LIMIT 100",
expectError: false,
},
{
name: "窗口函数",
query: "SELECT * FROM table TUMBLINGWINDOW(ss, 5)",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.query)
_, err := parser.Parse()
if (err != nil) != tt.expectError {
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
}
})
}
}