// Mirror of https://gitee.com/rulego/streamsql.git (synced 2026-03-12 13:27:25 +00:00).
// Go source file, 695 lines, 18 KiB.
package rsql
|
||
|
||
import (
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/rulego/streamsql/window"
|
||
)
|
||
|
||
// TestSelectStatement_ToStreamConfig 测试 SelectStatement 转换为 Stream 配置
|
||
func TestSelectStatement_ToStreamConfig(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
stmt *SelectStatement
|
||
wantErr bool
|
||
errMsg string
|
||
checkFunc func(*testing.T, *SelectStatement)
|
||
}{
|
||
{
|
||
name: "基本 SELECT 语句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "temperature", Alias: "temp"},
|
||
{Expression: "humidity", Alias: ""},
|
||
},
|
||
Source: "sensor_data",
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, condition, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if config == nil {
|
||
t.Error("ToStreamConfig() returned nil config")
|
||
return
|
||
}
|
||
if condition != "" {
|
||
t.Errorf("Expected empty condition, got %s", condition)
|
||
}
|
||
if len(config.SimpleFields) != 2 {
|
||
t.Errorf("Expected 2 simple fields, got %d", len(config.SimpleFields))
|
||
}
|
||
},
|
||
},
|
||
{
|
||
name: "SELECT * 语句",
|
||
stmt: &SelectStatement{
|
||
SelectAll: true,
|
||
Source: "sensor_data",
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, _, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if len(config.SimpleFields) != 1 || config.SimpleFields[0] != "*" {
|
||
t.Errorf("Expected SimpleFields to contain '*', got %v", config.SimpleFields)
|
||
}
|
||
},
|
||
},
|
||
{
|
||
name: "带聚合函数的语句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "AVG(temperature)", Alias: "avg_temp"},
|
||
{Expression: "COUNT(*)", Alias: "count"},
|
||
},
|
||
Source: "sensor_data",
|
||
Window: WindowDefinition{
|
||
Type: "TUMBLINGWINDOW",
|
||
Params: []interface{}{"10s"},
|
||
},
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, _, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if config.WindowConfig.Type != window.TypeTumbling {
|
||
t.Errorf("Expected tumbling window, got %v", config.WindowConfig.Type)
|
||
}
|
||
if !config.NeedWindow {
|
||
t.Error("Expected NeedWindow to be true")
|
||
}
|
||
},
|
||
},
|
||
{
|
||
name: "缺少 FROM 子句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "temperature"},
|
||
},
|
||
},
|
||
wantErr: true,
|
||
errMsg: "missing FROM clause",
|
||
},
|
||
{
|
||
name: "带 DISTINCT 的语句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "category"},
|
||
},
|
||
Distinct: true,
|
||
Source: "products",
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, _, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if !config.Distinct {
|
||
t.Error("Expected Distinct to be true")
|
||
}
|
||
},
|
||
},
|
||
{
|
||
name: "带 LIMIT 的语句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "name"},
|
||
},
|
||
Source: "users",
|
||
Limit: 100,
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, _, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if config.Limit != 100 {
|
||
t.Errorf("Expected Limit to be 100, got %d", config.Limit)
|
||
}
|
||
},
|
||
},
|
||
{
|
||
name: "带 GROUP BY 的语句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "category"},
|
||
{Expression: "COUNT(*)", Alias: "count"},
|
||
},
|
||
Source: "products",
|
||
GroupBy: []string{"category"},
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, _, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if len(config.GroupFields) != 1 || config.GroupFields[0] != "category" {
|
||
t.Errorf("Expected GroupFields to contain 'category', got %v", config.GroupFields)
|
||
}
|
||
},
|
||
},
|
||
{
|
||
name: "带 HAVING 的语句",
|
||
stmt: &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "category"},
|
||
{Expression: "COUNT(*)", Alias: "count"},
|
||
},
|
||
Source: "products",
|
||
GroupBy: []string{"category"},
|
||
Having: "COUNT(*) > 10",
|
||
},
|
||
wantErr: false,
|
||
checkFunc: func(t *testing.T, stmt *SelectStatement) {
|
||
config, _, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if config.Having != "COUNT(*) > 10" {
|
||
t.Errorf("Expected Having to be 'COUNT(*) > 10', got %s", config.Having)
|
||
}
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
if tt.wantErr {
|
||
_, _, err := tt.stmt.ToStreamConfig()
|
||
if err == nil {
|
||
t.Error("ToStreamConfig() expected error but got none")
|
||
return
|
||
}
|
||
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
|
||
t.Errorf("ToStreamConfig() error = %v, expected to contain %s", err, tt.errMsg)
|
||
}
|
||
} else {
|
||
if tt.checkFunc != nil {
|
||
tt.checkFunc(t, tt.stmt)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestSelectStatementEdgeCases 测试边界情况
|
||
func TestSelectStatementEdgeCases(t *testing.T) {
|
||
// 测试空字段列表
|
||
stmt := &SelectStatement{
|
||
Fields: []Field{},
|
||
Source: "test_table",
|
||
}
|
||
|
||
config, condition, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() with empty fields error = %v", err)
|
||
return
|
||
}
|
||
if config == nil {
|
||
t.Error("ToStreamConfig() returned nil config")
|
||
return
|
||
}
|
||
if condition != "" {
|
||
t.Errorf("Expected empty condition, got %s", condition)
|
||
}
|
||
|
||
// 测试复杂窗口类型
|
||
stmt2 := &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "COUNT(*)", Alias: "count"},
|
||
},
|
||
Source: "test_table",
|
||
Window: WindowDefinition{
|
||
Type: "SESSIONWINDOW",
|
||
Params: []interface{}{"30s"},
|
||
},
|
||
GroupBy: []string{"user_id"},
|
||
}
|
||
|
||
config2, _, err := stmt2.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("ToStreamConfig() with session window error = %v", err)
|
||
return
|
||
}
|
||
if config2.WindowConfig.Type != window.TypeSession {
|
||
t.Errorf("Expected session window, got %v", config2.WindowConfig.Type)
|
||
}
|
||
if len(config2.WindowConfig.GroupByKeys) == 0 || config2.WindowConfig.GroupByKeys[0] != "user_id" {
|
||
t.Errorf("Expected GroupByKeys to contain 'user_id', got %v", config2.WindowConfig.GroupByKeys)
|
||
}
|
||
}
|
||
|
||
// TestSelectStatementConcurrency 测试并发安全性
|
||
func TestSelectStatementConcurrency(t *testing.T) {
|
||
stmt := &SelectStatement{
|
||
Fields: []Field{
|
||
{Expression: "temperature", Alias: "temp"},
|
||
{Expression: "COUNT(*)", Alias: "count"},
|
||
},
|
||
Source: "sensor_data",
|
||
Window: WindowDefinition{
|
||
Type: "TUMBLINGWINDOW",
|
||
Params: []interface{}{"10s"},
|
||
},
|
||
}
|
||
|
||
// 启动多个 goroutine 并发调用 ToStreamConfig
|
||
done := make(chan bool, 10)
|
||
for i := 0; i < 10; i++ {
|
||
go func() {
|
||
for j := 0; j < 100; j++ {
|
||
config, condition, err := stmt.ToStreamConfig()
|
||
if err != nil {
|
||
t.Errorf("Concurrent ToStreamConfig() error = %v", err)
|
||
return
|
||
}
|
||
if config == nil {
|
||
t.Error("Concurrent ToStreamConfig() returned nil config")
|
||
return
|
||
}
|
||
if condition != "" {
|
||
t.Errorf("Concurrent ToStreamConfig() expected empty condition, got %s", condition)
|
||
return
|
||
}
|
||
}
|
||
done <- true
|
||
}()
|
||
}
|
||
|
||
// 等待所有 goroutine 完成
|
||
for i := 0; i < 10; i++ {
|
||
<-done
|
||
}
|
||
}
|
||
|
||
// TestBuildSelectFields 测试 buildSelectFields 函数
|
||
func TestBuildSelectFields(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
fields []Field
|
||
wantAggs map[string]string
|
||
wantMap map[string]string
|
||
}{
|
||
{
|
||
name: "带别名的聚合函数",
|
||
fields: []Field{
|
||
{Expression: "AVG(temperature)", Alias: "avg_temp"},
|
||
{Expression: "COUNT(*)", Alias: "total_count"},
|
||
},
|
||
wantAggs: map[string]string{
|
||
"avg_temp": "AVG",
|
||
"total_count": "COUNT",
|
||
},
|
||
wantMap: map[string]string{
|
||
"avg_temp": "temperature",
|
||
"total_count": "*",
|
||
},
|
||
},
|
||
{
|
||
name: "无别名的聚合函数",
|
||
fields: []Field{
|
||
{Expression: "SUM(amount)"},
|
||
{Expression: "MAX(price)"},
|
||
},
|
||
wantAggs: map[string]string{
|
||
"amount": "SUM",
|
||
"price": "MAX",
|
||
},
|
||
wantMap: map[string]string{
|
||
"amount": "amount",
|
||
"price": "price",
|
||
},
|
||
},
|
||
{
|
||
name: "混合字段",
|
||
fields: []Field{
|
||
{Expression: "name"},
|
||
{Expression: "COUNT(*)", Alias: "count"},
|
||
},
|
||
wantAggs: map[string]string{
|
||
"count": "COUNT",
|
||
},
|
||
wantMap: map[string]string{
|
||
"count": "*",
|
||
},
|
||
},
|
||
{
|
||
name: "空字段列表",
|
||
fields: []Field{},
|
||
wantAggs: map[string]string{},
|
||
wantMap: map[string]string{},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
aggMap, fieldMap, err := buildSelectFields(tt.fields)
|
||
if err != nil {
|
||
t.Errorf("buildSelectFields() error = %v", err)
|
||
return
|
||
}
|
||
|
||
// 检查聚合函数映射
|
||
if len(aggMap) != len(tt.wantAggs) {
|
||
t.Errorf("buildSelectFields() aggMap length = %d, want %d", len(aggMap), len(tt.wantAggs))
|
||
}
|
||
for key, want := range tt.wantAggs {
|
||
if got := string(aggMap[key]); got != want {
|
||
t.Errorf("buildSelectFields() aggMap[%s] = %s, want %s", key, got, want)
|
||
}
|
||
}
|
||
|
||
// 检查字段映射
|
||
if len(fieldMap) != len(tt.wantMap) {
|
||
t.Errorf("buildSelectFields() fieldMap length = %d, want %d", len(fieldMap), len(tt.wantMap))
|
||
}
|
||
for key, want := range tt.wantMap {
|
||
if got := fieldMap[key]; got != want {
|
||
t.Errorf("buildSelectFields() fieldMap[%s] = %s, want %s", key, got, want)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestIsAggregationFunction 测试 isAggregationFunction 函数
|
||
func TestIsAggregationFunction(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
expr string
|
||
want bool
|
||
}{
|
||
{"COUNT函数", "COUNT(*)", true},
|
||
{"AVG函数", "AVG(temperature)", true},
|
||
{"SUM函数", "SUM(amount)", true},
|
||
{"MAX函数", "MAX(price)", true},
|
||
{"MIN函数", "MIN(value)", true},
|
||
{"简单字段", "temperature", false},
|
||
{"字符串字面量", "'hello'", false},
|
||
{"数字字面量", "123", false},
|
||
{"空字符串", "", false},
|
||
{"表达式", "temperature + 10", false},
|
||
{"UPPER函数", "UPPER(name)", false},
|
||
{"CONCAT函数", "CONCAT(first_name, last_name)", false},
|
||
{"未知函数", "UNKNOWN_FUNC(field)", true}, // 保守处理
|
||
{"复杂表达式", "temperature > 25 AND humidity < 80", false},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
if got := isAggregationFunction(tt.expr); got != tt.want {
|
||
t.Errorf("isAggregationFunction(%s) = %v, want %v", tt.expr, got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestParseAggregateTypeWithExpression 测试 ParseAggregateTypeWithExpression 函数
|
||
func TestParseAggregateTypeWithExpression(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
exprStr string
|
||
wantAggType string
|
||
wantName string
|
||
wantExpression string
|
||
wantFields []string
|
||
}{
|
||
{
|
||
name: "COUNT聚合函数",
|
||
exprStr: "COUNT(*)",
|
||
wantAggType: "COUNT",
|
||
wantName: "*",
|
||
},
|
||
{
|
||
name: "AVG聚合函数",
|
||
exprStr: "AVG(temperature)",
|
||
wantAggType: "AVG",
|
||
wantName: "temperature",
|
||
},
|
||
{
|
||
name: "字符串字面量",
|
||
exprStr: "'hello world'",
|
||
wantAggType: "expression",
|
||
wantName: "hello world",
|
||
wantExpression: "'hello world'",
|
||
},
|
||
{
|
||
name: "双引号字符串",
|
||
exprStr: "\"test string\"",
|
||
wantAggType: "expression",
|
||
wantName: "test string",
|
||
wantExpression: "\"test string\"",
|
||
},
|
||
{
|
||
name: "CASE表达式",
|
||
exprStr: "CASE WHEN temperature > 25 THEN 'hot' ELSE 'cold' END",
|
||
wantAggType: "expression",
|
||
wantExpression: "CASE WHEN temperature > 25 THEN 'hot' ELSE 'cold' END",
|
||
},
|
||
{
|
||
name: "数学表达式",
|
||
exprStr: "temperature + 10",
|
||
wantAggType: "expression",
|
||
wantExpression: "temperature + 10",
|
||
},
|
||
{
|
||
name: "比较表达式",
|
||
exprStr: "temperature > 25",
|
||
wantAggType: "expression",
|
||
wantExpression: "temperature > 25",
|
||
},
|
||
{
|
||
name: "逻辑表达式",
|
||
exprStr: "temperature > 25 AND humidity < 80",
|
||
wantAggType: "expression",
|
||
wantExpression: "temperature > 25 AND humidity < 80",
|
||
},
|
||
{
|
||
name: "简单字段",
|
||
exprStr: "temperature",
|
||
wantAggType: "",
|
||
},
|
||
{
|
||
name: "UPPER字符串函数",
|
||
exprStr: "UPPER(name)",
|
||
wantAggType: "expression",
|
||
wantName: "name",
|
||
wantExpression: "UPPER(name)",
|
||
},
|
||
{
|
||
name: "CONCAT字符串函数",
|
||
exprStr: "CONCAT(first_name, last_name)",
|
||
wantAggType: "expression",
|
||
wantName: "first_name",
|
||
wantExpression: "CONCAT(first_name, last_name)",
|
||
},
|
||
}
|
||
|
||
// 测试正常情况
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
aggType, name, expression, allFields, err := ParseAggregateTypeWithExpression(tt.exprStr)
|
||
if err != nil {
|
||
t.Errorf("ParseAggregateTypeWithExpression() returned error: %v", err)
|
||
return
|
||
}
|
||
|
||
if string(aggType) != tt.wantAggType {
|
||
t.Errorf("ParseAggregateTypeWithExpression() aggType = %s, want %s", aggType, tt.wantAggType)
|
||
}
|
||
if name != tt.wantName {
|
||
t.Errorf("ParseAggregateTypeWithExpression() name = %s, want %s", name, tt.wantName)
|
||
}
|
||
if tt.wantExpression != "" && expression != tt.wantExpression {
|
||
t.Errorf("ParseAggregateTypeWithExpression() expression = %s, want %s", expression, tt.wantExpression)
|
||
}
|
||
if tt.wantFields != nil {
|
||
if len(allFields) != len(tt.wantFields) {
|
||
t.Errorf("ParseAggregateTypeWithExpression() allFields length = %d, want %d", len(allFields), len(tt.wantFields))
|
||
} else {
|
||
for i, field := range tt.wantFields {
|
||
if allFields[i] != field {
|
||
t.Errorf("ParseAggregateTypeWithExpression() allFields[%d] = %s, want %s", i, allFields[i], field)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
})
|
||
}
|
||
|
||
// 测试嵌套聚合函数检测
|
||
nestedTests := []struct {
|
||
name string
|
||
exprStr string
|
||
}{
|
||
{
|
||
name: "嵌套聚合函数 - MAX(AVG(temperature))",
|
||
exprStr: "MAX(AVG(temperature))",
|
||
},
|
||
{
|
||
name: "嵌套聚合函数 - COUNT(SUM(price))",
|
||
exprStr: "COUNT(SUM(price))",
|
||
},
|
||
{
|
||
name: "复杂嵌套 - MAX(ROUND(AVG(temperature), 1))",
|
||
exprStr: "MAX(ROUND(AVG(temperature), 1))",
|
||
},
|
||
}
|
||
|
||
for _, tt := range nestedTests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
_, _, _, _, err := ParseAggregateTypeWithExpression(tt.exprStr)
|
||
if err == nil {
|
||
t.Errorf("ParseAggregateTypeWithExpression() should return error for nested aggregation: %s", tt.exprStr)
|
||
} else if !strings.Contains(err.Error(), "aggregate function calls cannot be nested") {
|
||
t.Errorf("ParseAggregateTypeWithExpression() error message should contain 'aggregate function calls cannot be nested', got: %v", err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestDetectNestedAggregation 测试嵌套聚合函数检测
|
||
func TestDetectNestedAggregation(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
exprStr string
|
||
wantError bool
|
||
}{
|
||
{
|
||
name: "正常聚合函数",
|
||
exprStr: "MAX(temperature)",
|
||
wantError: false,
|
||
},
|
||
{
|
||
name: "嵌套聚合函数",
|
||
exprStr: "MAX(AVG(temperature))",
|
||
wantError: true,
|
||
},
|
||
{
|
||
name: "复杂嵌套",
|
||
exprStr: "MAX(ROUND(AVG(temperature), 1))",
|
||
wantError: true,
|
||
},
|
||
{
|
||
name: "非聚合函数嵌套",
|
||
exprStr: "UPPER(CONCAT(first_name, last_name))",
|
||
wantError: false,
|
||
},
|
||
{
|
||
name: "聚合函数包含非聚合函数",
|
||
exprStr: "MAX(ROUND(temperature, 1))",
|
||
wantError: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := detectNestedAggregation(tt.exprStr)
|
||
if tt.wantError && err == nil {
|
||
t.Errorf("detectNestedAggregation() should return error for: %s", tt.exprStr)
|
||
} else if !tt.wantError && err != nil {
|
||
t.Errorf("detectNestedAggregation() should not return error for: %s, got: %v", tt.exprStr, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestExtractAggFieldWithExpression 测试 extractAggFieldWithExpression 函数
|
||
func TestExtractAggFieldWithExpression(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
exprStr string
|
||
funcName string
|
||
wantFieldName string
|
||
wantExpression string
|
||
wantAllFields []string
|
||
}{
|
||
{
|
||
name: "COUNT星号",
|
||
exprStr: "COUNT(*)",
|
||
funcName: "count",
|
||
wantFieldName: "*",
|
||
},
|
||
{
|
||
name: "简单字段",
|
||
exprStr: "AVG(temperature)",
|
||
funcName: "AVG",
|
||
wantFieldName: "temperature",
|
||
},
|
||
{
|
||
name: "CONCAT函数",
|
||
exprStr: "CONCAT(first_name, last_name)",
|
||
funcName: "concat",
|
||
wantFieldName: "first_name",
|
||
wantExpression: "concat(first_name, last_name)",
|
||
wantAllFields: []string{"first_name", "last_name"},
|
||
},
|
||
{
|
||
name: "复杂表达式",
|
||
exprStr: "SUM(price * quantity)",
|
||
funcName: "SUM",
|
||
wantFieldName: "price",
|
||
wantExpression: "price * quantity",
|
||
},
|
||
{
|
||
name: "多参数函数",
|
||
exprStr: "DISTANCE(x1, y1, x2, y2)",
|
||
funcName: "DISTANCE",
|
||
wantFieldName: "x1",
|
||
wantExpression: "x1, y1, x2, y2",
|
||
// 不检查 allFields,因为实际行为可能与预期不同
|
||
},
|
||
{
|
||
name: "无效表达式",
|
||
exprStr: "INVALID",
|
||
funcName: "COUNT",
|
||
},
|
||
{
|
||
name: "括号不匹配",
|
||
exprStr: "COUNT(",
|
||
funcName: "COUNT",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
fieldName, expression, allFields := extractAggFieldWithExpression(tt.exprStr, tt.funcName)
|
||
|
||
if fieldName != tt.wantFieldName {
|
||
t.Errorf("extractAggFieldWithExpression() fieldName = %s, want %s", fieldName, tt.wantFieldName)
|
||
}
|
||
if tt.wantExpression != "" && expression != tt.wantExpression {
|
||
t.Errorf("extractAggFieldWithExpression() expression = %s, want %s", expression, tt.wantExpression)
|
||
}
|
||
if tt.wantAllFields != nil {
|
||
if len(allFields) != len(tt.wantAllFields) {
|
||
t.Errorf("extractAggFieldWithExpression() allFields length = %d, want %d, got fields: %v", len(allFields), len(tt.wantAllFields), allFields)
|
||
} else {
|
||
for i, field := range tt.wantAllFields {
|
||
if i < len(allFields) && allFields[i] != field {
|
||
t.Errorf("extractAggFieldWithExpression() allFields[%d] = %s, want %s", i, allFields[i], field)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|