Files
2025-11-13 11:03:49 +08:00

695 lines
18 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package rsql
import (
"strings"
"testing"
"github.com/rulego/streamsql/window"
)
// TestSelectStatement_ToStreamConfig verifies that SelectStatement.ToStreamConfig
// turns a parsed SELECT statement into a stream configuration. It covers plain
// fields, SELECT *, aggregates over a tumbling window, DISTINCT, LIMIT,
// GROUP BY and HAVING, plus the error reported when the FROM clause is missing.
//
// Success cases run their own ToStreamConfig call inside checkFunc and stop the
// subtest with t.Fatal on an unexpected error or a nil config, so a nil config
// is reported as a failure instead of triggering a nil-pointer panic that would
// abort the remaining tests.
func TestSelectStatement_ToStreamConfig(t *testing.T) {
	tests := []struct {
		name      string
		stmt      *SelectStatement
		wantErr   bool
		errMsg    string // substring the error must contain when wantErr is set
		checkFunc func(*testing.T, *SelectStatement)
	}{
		{
			name: "基本 SELECT 语句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "temperature", Alias: "temp"},
					{Expression: "humidity", Alias: ""},
				},
				Source: "sensor_data",
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, condition, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if condition != "" {
					t.Errorf("Expected empty condition, got %s", condition)
				}
				if len(config.SimpleFields) != 2 {
					t.Errorf("Expected 2 simple fields, got %d", len(config.SimpleFields))
				}
			},
		},
		{
			name: "SELECT * 语句",
			stmt: &SelectStatement{
				SelectAll: true,
				Source:    "sensor_data",
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, _, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if len(config.SimpleFields) != 1 || config.SimpleFields[0] != "*" {
					t.Errorf("Expected SimpleFields to contain '*', got %v", config.SimpleFields)
				}
			},
		},
		{
			name: "带聚合函数的语句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "AVG(temperature)", Alias: "avg_temp"},
					{Expression: "COUNT(*)", Alias: "count"},
				},
				Source: "sensor_data",
				Window: WindowDefinition{
					Type:   "TUMBLINGWINDOW",
					Params: []interface{}{"10s"},
				},
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, _, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if config.WindowConfig.Type != window.TypeTumbling {
					t.Errorf("Expected tumbling window, got %v", config.WindowConfig.Type)
				}
				if !config.NeedWindow {
					t.Error("Expected NeedWindow to be true")
				}
			},
		},
		{
			name: "缺少 FROM 子句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "temperature"},
				},
			},
			wantErr: true,
			errMsg:  "missing FROM clause",
		},
		{
			name: "带 DISTINCT 的语句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "category"},
				},
				Distinct: true,
				Source:   "products",
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, _, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if !config.Distinct {
					t.Error("Expected Distinct to be true")
				}
			},
		},
		{
			name: "带 LIMIT 的语句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "name"},
				},
				Source: "users",
				Limit:  100,
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, _, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if config.Limit != 100 {
					t.Errorf("Expected Limit to be 100, got %d", config.Limit)
				}
			},
		},
		{
			name: "带 GROUP BY 的语句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "category"},
					{Expression: "COUNT(*)", Alias: "count"},
				},
				Source:  "products",
				GroupBy: []string{"category"},
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, _, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if len(config.GroupFields) != 1 || config.GroupFields[0] != "category" {
					t.Errorf("Expected GroupFields to contain 'category', got %v", config.GroupFields)
				}
			},
		},
		{
			name: "带 HAVING 的语句",
			stmt: &SelectStatement{
				Fields: []Field{
					{Expression: "category"},
					{Expression: "COUNT(*)", Alias: "count"},
				},
				Source:  "products",
				GroupBy: []string{"category"},
				Having:  "COUNT(*) > 10",
			},
			wantErr: false,
			checkFunc: func(t *testing.T, stmt *SelectStatement) {
				config, _, err := stmt.ToStreamConfig()
				if err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				if config == nil {
					t.Fatal("ToStreamConfig() returned nil config")
				}
				if config.Having != "COUNT(*) > 10" {
					t.Errorf("Expected Having to be 'COUNT(*) > 10', got %s", config.Having)
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if !tt.wantErr {
				if tt.checkFunc != nil {
					tt.checkFunc(t, tt.stmt)
					return
				}
				// No case-specific checks: at least require a successful conversion.
				if _, _, err := tt.stmt.ToStreamConfig(); err != nil {
					t.Fatalf("ToStreamConfig() error = %v", err)
				}
				return
			}
			_, _, err := tt.stmt.ToStreamConfig()
			if err == nil {
				t.Fatal("ToStreamConfig() expected error but got none")
			}
			if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
				t.Errorf("ToStreamConfig() error = %v, expected to contain %s", err, tt.errMsg)
			}
		})
	}
}
// TestSelectStatementEdgeCases covers boundary inputs for ToStreamConfig: a
// statement with an empty field list, and a grouped aggregate over a session
// window whose GROUP BY keys are expected in the window configuration.
//
// Each scenario runs as its own subtest so that a failure in one scenario does
// not prevent the other from running.
func TestSelectStatementEdgeCases(t *testing.T) {
	// An empty field list must still convert cleanly with no condition.
	t.Run("empty fields", func(t *testing.T) {
		stmt := &SelectStatement{
			Fields: []Field{},
			Source: "test_table",
		}
		config, condition, err := stmt.ToStreamConfig()
		if err != nil {
			t.Fatalf("ToStreamConfig() with empty fields error = %v", err)
		}
		if config == nil {
			t.Fatal("ToStreamConfig() returned nil config")
		}
		if condition != "" {
			t.Errorf("Expected empty condition, got %s", condition)
		}
	})

	// A session window with GROUP BY must carry the grouping keys into the
	// window configuration.
	t.Run("session window with group by", func(t *testing.T) {
		stmt := &SelectStatement{
			Fields: []Field{
				{Expression: "COUNT(*)", Alias: "count"},
			},
			Source: "test_table",
			Window: WindowDefinition{
				Type:   "SESSIONWINDOW",
				Params: []interface{}{"30s"},
			},
			GroupBy: []string{"user_id"},
		}
		config, _, err := stmt.ToStreamConfig()
		if err != nil {
			t.Fatalf("ToStreamConfig() with session window error = %v", err)
		}
		if config == nil {
			t.Fatal("ToStreamConfig() returned nil config")
		}
		if config.WindowConfig.Type != window.TypeSession {
			t.Errorf("Expected session window, got %v", config.WindowConfig.Type)
		}
		if keys := config.WindowConfig.GroupByKeys; len(keys) == 0 || keys[0] != "user_id" {
			t.Errorf("Expected GroupByKeys to contain 'user_id', got %v", keys)
		}
	})
}
// TestSelectStatementConcurrency calls ToStreamConfig on one shared
// SelectStatement from several goroutines at once. It checks that every call
// succeeds with a non-nil config and an empty condition. Running it with -race
// lets the race detector report unsynchronized access to the shared statement.
func TestSelectStatementConcurrency(t *testing.T) {
	const (
		workers    = 10
		iterations = 100
	)
	stmt := &SelectStatement{
		Fields: []Field{
			{Expression: "temperature", Alias: "temp"},
			{Expression: "COUNT(*)", Alias: "count"},
		},
		Source: "sensor_data",
		Window: WindowDefinition{
			Type:   "TUMBLINGWINDOW",
			Params: []interface{}{"10s"},
		},
	}
	// Buffered so no worker ever blocks when signalling completion.
	done := make(chan bool, workers)
	for i := 0; i < workers; i++ {
		go func() {
			// Signal completion on every exit path, including the early returns
			// after a failure. Without the defer, a failing worker would never
			// send, the receive loop below would block forever, and the test
			// would hang until the global timeout.
			defer func() { done <- true }()
			for j := 0; j < iterations; j++ {
				config, condition, err := stmt.ToStreamConfig()
				if err != nil {
					t.Errorf("Concurrent ToStreamConfig() error = %v", err)
					return
				}
				if config == nil {
					t.Error("Concurrent ToStreamConfig() returned nil config")
					return
				}
				if condition != "" {
					t.Errorf("Concurrent ToStreamConfig() expected empty condition, got %s", condition)
					return
				}
			}
		}()
	}
	// Wait for every worker; t.Error* is safe to call from these goroutines
	// because the test does not return until all of them have finished.
	for i := 0; i < workers; i++ {
		<-done
	}
}
// TestBuildSelectFields checks the two maps returned by buildSelectFields.
// The first maps each output name to its aggregate function. The second maps
// each output name to the source field. Per the cases below, the output name
// is expected to be the alias when one is given, and otherwise the function
// argument.
func TestBuildSelectFields(t *testing.T) {
	cases := []struct {
		name     string
		fields   []Field
		wantAggs map[string]string
		wantMap  map[string]string
	}{
		{
			name: "带别名的聚合函数",
			fields: []Field{
				{Expression: "AVG(temperature)", Alias: "avg_temp"},
				{Expression: "COUNT(*)", Alias: "total_count"},
			},
			wantAggs: map[string]string{
				"avg_temp":    "AVG",
				"total_count": "COUNT",
			},
			wantMap: map[string]string{
				"avg_temp":    "temperature",
				"total_count": "*",
			},
		},
		{
			name: "无别名的聚合函数",
			fields: []Field{
				{Expression: "SUM(amount)"},
				{Expression: "MAX(price)"},
			},
			wantAggs: map[string]string{
				"amount": "SUM",
				"price":  "MAX",
			},
			wantMap: map[string]string{
				"amount": "amount",
				"price":  "price",
			},
		},
		{
			name: "混合字段",
			fields: []Field{
				{Expression: "name"},
				{Expression: "COUNT(*)", Alias: "count"},
			},
			wantAggs: map[string]string{
				"count": "COUNT",
			},
			wantMap: map[string]string{
				"count": "*",
			},
		},
		{
			name:     "空字段列表",
			fields:   []Field{},
			wantAggs: map[string]string{},
			wantMap:  map[string]string{},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			aggMap, fieldMap, err := buildSelectFields(tc.fields)
			if err != nil {
				t.Errorf("buildSelectFields() error = %v", err)
				return
			}

			// Aggregate-function mapping: exact size, then each expected entry.
			if gotLen, wantLen := len(aggMap), len(tc.wantAggs); gotLen != wantLen {
				t.Errorf("buildSelectFields() aggMap length = %d, want %d", gotLen, wantLen)
			}
			for k, wantAgg := range tc.wantAggs {
				gotAgg := string(aggMap[k])
				if gotAgg != wantAgg {
					t.Errorf("buildSelectFields() aggMap[%s] = %s, want %s", k, gotAgg, wantAgg)
				}
			}

			// Output-name to source-field mapping: same two-step comparison.
			if gotLen, wantLen := len(fieldMap), len(tc.wantMap); gotLen != wantLen {
				t.Errorf("buildSelectFields() fieldMap length = %d, want %d", gotLen, wantLen)
			}
			for k, wantField := range tc.wantMap {
				gotField := fieldMap[k]
				if gotField != wantField {
					t.Errorf("buildSelectFields() fieldMap[%s] = %s, want %s", k, gotField, wantField)
				}
			}
		})
	}
}
// TestIsAggregationFunction checks which expressions isAggregationFunction
// reports as aggregate calls. Built-in aggregates are true, while fields,
// literals, operators and known scalar functions are false. An unrecognized
// function name is expected to be treated conservatively as an aggregate.
func TestIsAggregationFunction(t *testing.T) {
	cases := []struct {
		name string
		expr string
		want bool
	}{
		{name: "COUNT函数", expr: "COUNT(*)", want: true},
		{name: "AVG函数", expr: "AVG(temperature)", want: true},
		{name: "SUM函数", expr: "SUM(amount)", want: true},
		{name: "MAX函数", expr: "MAX(price)", want: true},
		{name: "MIN函数", expr: "MIN(value)", want: true},
		{name: "简单字段", expr: "temperature", want: false},
		{name: "字符串字面量", expr: "'hello'", want: false},
		{name: "数字字面量", expr: "123", want: false},
		{name: "空字符串", expr: "", want: false},
		{name: "表达式", expr: "temperature + 10", want: false},
		{name: "UPPER函数", expr: "UPPER(name)", want: false},
		{name: "CONCAT函数", expr: "CONCAT(first_name, last_name)", want: false},
		// Unknown functions are conservatively classified as aggregates.
		{name: "未知函数", expr: "UNKNOWN_FUNC(field)", want: true},
		{name: "复杂表达式", expr: "temperature > 25 AND humidity < 80", want: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isAggregationFunction(tc.expr)
			if got == tc.want {
				return
			}
			t.Errorf("isAggregationFunction(%s) = %v, want %v", tc.expr, got, tc.want)
		})
	}
}
// TestParseAggregateTypeWithExpression checks how ParseAggregateTypeWithExpression
// classifies a select expression. Per the cases below, aggregate calls report
// their function type and argument, while string literals and non-aggregate
// expressions report "expression". A plain field reports an empty type. A second
// table asserts that nested aggregate calls are rejected with an error.
func TestParseAggregateTypeWithExpression(t *testing.T) {
	// Substring that the nested-aggregation error must contain.
	const nestedAggErr = "aggregate function calls cannot be nested"

	cases := []struct {
		name           string
		exprStr        string
		wantAggType    string
		wantName       string
		wantExpression string   // compared only when non-empty
		wantFields     []string // compared only when non-nil
	}{
		{
			name:        "COUNT聚合函数",
			exprStr:     "COUNT(*)",
			wantAggType: "COUNT",
			wantName:    "*",
		},
		{
			name:        "AVG聚合函数",
			exprStr:     "AVG(temperature)",
			wantAggType: "AVG",
			wantName:    "temperature",
		},
		{
			name:           "字符串字面量",
			exprStr:        "'hello world'",
			wantAggType:    "expression",
			wantName:       "hello world",
			wantExpression: "'hello world'",
		},
		{
			name:           "双引号字符串",
			exprStr:        "\"test string\"",
			wantAggType:    "expression",
			wantName:       "test string",
			wantExpression: "\"test string\"",
		},
		{
			name:           "CASE表达式",
			exprStr:        "CASE WHEN temperature > 25 THEN 'hot' ELSE 'cold' END",
			wantAggType:    "expression",
			wantExpression: "CASE WHEN temperature > 25 THEN 'hot' ELSE 'cold' END",
		},
		{
			name:           "数学表达式",
			exprStr:        "temperature + 10",
			wantAggType:    "expression",
			wantExpression: "temperature + 10",
		},
		{
			name:           "比较表达式",
			exprStr:        "temperature > 25",
			wantAggType:    "expression",
			wantExpression: "temperature > 25",
		},
		{
			name:           "逻辑表达式",
			exprStr:        "temperature > 25 AND humidity < 80",
			wantAggType:    "expression",
			wantExpression: "temperature > 25 AND humidity < 80",
		},
		{
			name:        "简单字段",
			exprStr:     "temperature",
			wantAggType: "",
		},
		{
			name:           "UPPER字符串函数",
			exprStr:        "UPPER(name)",
			wantAggType:    "expression",
			wantName:       "name",
			wantExpression: "UPPER(name)",
		},
		{
			name:           "CONCAT字符串函数",
			exprStr:        "CONCAT(first_name, last_name)",
			wantAggType:    "expression",
			wantName:       "first_name",
			wantExpression: "CONCAT(first_name, last_name)",
		},
	}

	// Well-formed expressions: compare every returned component.
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			aggType, name, expression, allFields, err := ParseAggregateTypeWithExpression(tc.exprStr)
			if err != nil {
				t.Errorf("ParseAggregateTypeWithExpression() returned error: %v", err)
				return
			}
			if string(aggType) != tc.wantAggType {
				t.Errorf("ParseAggregateTypeWithExpression() aggType = %s, want %s", aggType, tc.wantAggType)
			}
			if name != tc.wantName {
				t.Errorf("ParseAggregateTypeWithExpression() name = %s, want %s", name, tc.wantName)
			}
			if tc.wantExpression != "" && expression != tc.wantExpression {
				t.Errorf("ParseAggregateTypeWithExpression() expression = %s, want %s", expression, tc.wantExpression)
			}
			if tc.wantFields == nil {
				return
			}
			if len(allFields) != len(tc.wantFields) {
				t.Errorf("ParseAggregateTypeWithExpression() allFields length = %d, want %d", len(allFields), len(tc.wantFields))
				return
			}
			for i, want := range tc.wantFields {
				if allFields[i] != want {
					t.Errorf("ParseAggregateTypeWithExpression() allFields[%d] = %s, want %s", i, allFields[i], want)
				}
			}
		})
	}

	// Nested aggregate calls must be rejected with a descriptive error.
	nestedCases := []struct {
		name    string
		exprStr string
	}{
		{name: "嵌套聚合函数 - MAX(AVG(temperature))", exprStr: "MAX(AVG(temperature))"},
		{name: "嵌套聚合函数 - COUNT(SUM(price))", exprStr: "COUNT(SUM(price))"},
		{name: "复杂嵌套 - MAX(ROUND(AVG(temperature), 1))", exprStr: "MAX(ROUND(AVG(temperature), 1))"},
	}
	for _, tc := range nestedCases {
		t.Run(tc.name, func(t *testing.T) {
			_, _, _, _, err := ParseAggregateTypeWithExpression(tc.exprStr)
			switch {
			case err == nil:
				t.Errorf("ParseAggregateTypeWithExpression() should return error for nested aggregation: %s", tc.exprStr)
			case !strings.Contains(err.Error(), nestedAggErr):
				t.Errorf("ParseAggregateTypeWithExpression() error message should contain 'aggregate function calls cannot be nested', got: %v", err)
			}
		})
	}
}
// TestDetectNestedAggregation checks that detectNestedAggregation returns an
// error only when an aggregate call appears (directly or via a scalar wrapper)
// inside another aggregate. Nesting of non-aggregate functions, and scalar
// functions inside an aggregate, are expected to pass.
func TestDetectNestedAggregation(t *testing.T) {
	cases := []struct {
		name      string
		exprStr   string
		wantError bool
	}{
		{name: "正常聚合函数", exprStr: "MAX(temperature)", wantError: false},
		{name: "嵌套聚合函数", exprStr: "MAX(AVG(temperature))", wantError: true},
		{name: "复杂嵌套", exprStr: "MAX(ROUND(AVG(temperature), 1))", wantError: true},
		{name: "非聚合函数嵌套", exprStr: "UPPER(CONCAT(first_name, last_name))", wantError: false},
		{name: "聚合函数包含非聚合函数", exprStr: "MAX(ROUND(temperature, 1))", wantError: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := detectNestedAggregation(tc.exprStr)
			switch {
			case tc.wantError && err == nil:
				t.Errorf("detectNestedAggregation() should return error for: %s", tc.exprStr)
			case !tc.wantError && err != nil:
				t.Errorf("detectNestedAggregation() should not return error for: %s, got: %v", tc.exprStr, err)
			}
		})
	}
}
// TestExtractAggFieldWithExpression checks the three results of
// extractAggFieldWithExpression: the primary field name, the extracted
// expression, and the list of referenced fields. The expression is compared only
// when wantExpression is set, and the field list only when wantAllFields is set.
// Malformed input (no call, unbalanced parenthesis) is expected to yield an empty
// field name.
func TestExtractAggFieldWithExpression(t *testing.T) {
	cases := []struct {
		name           string
		exprStr        string
		funcName       string
		wantFieldName  string
		wantExpression string
		wantAllFields  []string
	}{
		{
			name:          "COUNT星号",
			exprStr:       "COUNT(*)",
			funcName:      "count",
			wantFieldName: "*",
		},
		{
			name:          "简单字段",
			exprStr:       "AVG(temperature)",
			funcName:      "AVG",
			wantFieldName: "temperature",
		},
		{
			name:           "CONCAT函数",
			exprStr:        "CONCAT(first_name, last_name)",
			funcName:       "concat",
			wantFieldName:  "first_name",
			wantExpression: "concat(first_name, last_name)",
			wantAllFields:  []string{"first_name", "last_name"},
		},
		{
			name:           "复杂表达式",
			exprStr:        "SUM(price * quantity)",
			funcName:       "SUM",
			wantFieldName:  "price",
			wantExpression: "price * quantity",
		},
		{
			name:           "多参数函数",
			exprStr:        "DISTANCE(x1, y1, x2, y2)",
			funcName:       "DISTANCE",
			wantFieldName:  "x1",
			wantExpression: "x1, y1, x2, y2",
			// allFields is intentionally not checked: the actual split may
			// differ from what this case would predict.
		},
		{
			name:     "无效表达式",
			exprStr:  "INVALID",
			funcName: "COUNT",
		},
		{
			name:     "括号不匹配",
			exprStr:  "COUNT(",
			funcName: "COUNT",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			fieldName, expression, allFields := extractAggFieldWithExpression(tc.exprStr, tc.funcName)
			if fieldName != tc.wantFieldName {
				t.Errorf("extractAggFieldWithExpression() fieldName = %s, want %s", fieldName, tc.wantFieldName)
			}
			if tc.wantExpression != "" && expression != tc.wantExpression {
				t.Errorf("extractAggFieldWithExpression() expression = %s, want %s", expression, tc.wantExpression)
			}
			if tc.wantAllFields == nil {
				return
			}
			if len(allFields) != len(tc.wantAllFields) {
				t.Errorf("extractAggFieldWithExpression() allFields length = %d, want %d, got fields: %v", len(allFields), len(tc.wantAllFields), allFields)
				return
			}
			// Lengths are equal here, so allFields[i] is always in range.
			for i, want := range tc.wantAllFields {
				if allFields[i] != want {
					t.Errorf("extractAggFieldWithExpression() allFields[%d] = %s, want %s", i, allFields[i], want)
				}
			}
		})
	}
}