feat:支持聚合函数的后运算 #37

This commit is contained in:
rulego-team
2025-08-28 19:19:51 +08:00
parent de6ca91c01
commit 4615b7a308
27 changed files with 3404 additions and 1213 deletions
+24
View File
@@ -20,6 +20,7 @@ const (
WindowStart = functions.WindowStart
WindowEnd = functions.WindowEnd
Collect = functions.Collect
FirstValue = functions.FirstValue
LastValue = functions.LastValue
MergeAgg = functions.MergeAgg
StdDevS = functions.StdDevS
@@ -33,6 +34,8 @@ const (
HadChanged = functions.HadChanged
// Expression aggregator for handling custom functions
Expression = functions.Expression
// Post-aggregation marker
PostAggregation = functions.PostAggregation
)
// AggregatorFunction aggregator function interface, re-exports functions.LegacyAggregatorFunction
@@ -55,9 +58,30 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction {
}
}
// Special handling for post-aggregation type (placeholder aggregator)
if aggType == "post_aggregation" {
return &PostAggregationPlaceholder{}
}
return functions.CreateLegacyAggregator(aggType)
}
// PostAggregationPlaceholder is a no-op aggregator used for post-aggregation
// fields. It ignores every value it receives and always reports nil; per the
// original author's note, the real value is computed in a later
// post-processing step.
type PostAggregationPlaceholder struct{}

// New returns a fresh, independent placeholder instance.
func (p *PostAggregationPlaceholder) New() AggregatorFunction {
	return new(PostAggregationPlaceholder)
}

// Add discards value; the placeholder accumulates nothing.
func (p *PostAggregationPlaceholder) Add(value interface{}) {
	// Intentionally empty.
}

// Result always returns nil.
func (p *PostAggregationPlaceholder) Result() interface{} {
	return nil
}
// ExpressionAggregatorWrapper wraps expression aggregator to make it compatible with LegacyAggregatorFunction interface
type ExpressionAggregatorWrapper struct {
function *functions.ExpressionAggregatorFunction
+16 -3
View File
@@ -140,6 +140,12 @@ func (ga *GroupAggregator) isNumericAggregator(aggType AggregateType) bool {
return false
}
// shouldAllowNullValues reports whether nil field values must still be passed
// to an aggregator of the given type. Only FirstValue and LastValue qualify:
// they record the first/last value seen, even when that value is NULL.
func (ga *GroupAggregator) shouldAllowNullValues(aggType AggregateType) bool {
	switch aggType {
	case FirstValue, LastValue:
		return true
	default:
		return false
	}
}
func (ga *GroupAggregator) Add(data interface{}) error {
ga.mu.Lock()
defer ga.mu.Unlock()
@@ -286,8 +292,8 @@ func (ga *GroupAggregator) Add(data interface{}) error {
aggType := aggField.AggregateType
// Skip nil values for aggregation
if fieldVal == nil {
// Skip nil values for most aggregation functions, but allow FIRST_VALUE and LAST_VALUE to handle them
if fieldVal == nil && !ga.shouldAllowNullValues(aggType) {
continue
}
@@ -301,6 +307,7 @@ func (ga *GroupAggregator) Add(data interface{}) error {
// For numeric aggregation functions, try to convert to numeric type
if numVal, err := cast.ToFloat64E(fieldVal); err == nil {
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
groupAgg.Add(numVal)
}
} else {
@@ -309,6 +316,7 @@ func (ga *GroupAggregator) Add(data interface{}) error {
} else {
// For non-numeric aggregation functions, pass original value directly
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
groupAgg.Add(fieldVal)
}
}
@@ -336,7 +344,12 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
}
}
for field, agg := range aggregators {
group[field] = agg.Result()
result := agg.Result()
group[field] = result
// Debug: log aggregator results (can be removed in production)
// if strings.HasPrefix(field, "__") {
// fmt.Printf("Aggregator %s result: %v (%T)\n", field, result, result)
// }
}
result = append(result, group)
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+43
View File
@@ -23,6 +23,13 @@ type AnalyticalFunction interface {
AggregatorFunction
}
// ParameterizedFunction is an AggregatorFunction that additionally needs to be
// configured from its call arguments before it can be used.
type ParameterizedFunction interface {
	AggregatorFunction
	// Init configures the function from the parsed call arguments and returns
	// an error when they are not acceptable.
	Init(args []interface{}) error
}
// CreateAggregator creates an aggregator instance
func CreateAggregator(name string) (AggregatorFunction, error) {
fn, exists := Get(name)
@@ -37,6 +44,42 @@ func CreateAggregator(name string) (AggregatorFunction, error) {
return nil, fmt.Errorf("function %s is not an aggregator function", name)
}
// CreateParameterizedAggregator looks up the function registered under name
// and returns a new aggregator instance for it.
//
// When both the registered function and the instance produced by its New
// method implement ParameterizedFunction, the new instance is initialized with
// args before being returned. Otherwise args is ignored and a plain New()
// instance is returned.
//
// An error is returned when name is not registered, when the registered
// function is not an AggregatorFunction, or when Init fails. In the last case
// the Init error is wrapped with %w so callers can inspect it with errors.Is
// or errors.As.
func CreateParameterizedAggregator(name string, args []interface{}) (AggregatorFunction, error) {
	fn, exists := Get(name)
	if !exists {
		return nil, fmt.Errorf("aggregator function %s not found", name)
	}

	// Parameterized path: both the prototype and the fresh instance must
	// support Init, otherwise fall through to the regular path below.
	if paramFn, ok := fn.(ParameterizedFunction); ok {
		newInstance := paramFn.New()
		if paramNewInstance, ok := newInstance.(ParameterizedFunction); ok {
			if err := paramNewInstance.Init(args); err != nil {
				return nil, fmt.Errorf("failed to initialize parameterized function %s: %w", name, err)
			}
			return newInstance, nil
		}
	}

	// Fallback to regular aggregator creation.
	if aggFn, ok := fn.(AggregatorFunction); ok {
		return aggFn.New(), nil
	}

	return nil, fmt.Errorf("function %s is not an aggregator function", name)
}
// IsAggregatorFunction reports whether name is registered and the registered
// function implements AggregatorFunction. Unknown names yield false.
func IsAggregatorFunction(name string) bool {
	if registered, found := Get(name); found {
		_, isAggregator := registered.(AggregatorFunction)
		return isAggregator
	}
	return false
}
// CreateAnalytical creates an analytical function instance
func CreateAnalytical(name string) (AnalyticalFunction, error) {
fn, exists := Get(name)
+6
View File
@@ -18,6 +18,7 @@ const (
WindowStart AggregateType = "window_start"
WindowEnd AggregateType = "window_end"
Collect AggregateType = "collect"
FirstValue AggregateType = "first_value"
LastValue AggregateType = "last_value"
MergeAgg AggregateType = "merge_agg"
StdDev AggregateType = "stddev"
@@ -32,6 +33,8 @@ const (
HadChanged AggregateType = "had_changed"
// Expression aggregator for handling custom functions
Expression AggregateType = "expression"
// Post-aggregation marker for fields that need post-processing
PostAggregation AggregateType = "post_aggregation"
)
// String constant versions for convenience
@@ -46,6 +49,7 @@ const (
WindowStartStr = string(WindowStart)
WindowEndStr = string(WindowEnd)
CollectStr = string(Collect)
FirstValueStr = string(FirstValue)
LastValueStr = string(LastValue)
MergeAggStr = string(MergeAgg)
StdStr = "std"
@@ -61,6 +65,8 @@ const (
HadChangedStr = string(HadChanged)
// Expression aggregator
ExpressionStr = string(Expression)
// Post-aggregation marker
PostAggregationStr = string(PostAggregation)
)
// LegacyAggregatorFunction defines aggregator function interface compatible with legacy aggregator interface
+23 -3
View File
@@ -133,7 +133,7 @@ func TestCreateLegacyAggregatorPanic(t *testing.T) {
func TestFunctionAggregatorWrapper(t *testing.T) {
// 创建一个测试聚合器函数
testAgg := &TestAggregatorFunction{}
// 创建一个测试适配器
adapter := &AggregatorAdapter{
aggFunc: testAgg,
@@ -157,7 +157,7 @@ func TestFunctionAggregatorWrapper(t *testing.T) {
func TestAnalyticalAggregatorWrapper(t *testing.T) {
// 创建一个测试分析函数
testAnalFunc := &TestAnalyticalFunction{}
// 创建一个测试适配器
adapter := &AnalyticalAggregatorAdapter{
analFunc: testAnalFunc,
@@ -268,6 +268,16 @@ func (t *TestAggregatorFunction) Execute(ctx *FunctionContext, args []interface{
return t.Result(), nil
}
// GetMinArgs returns the minimum number of arguments; this test double always
// reports 1.
func (t *TestAggregatorFunction) GetMinArgs() int {
	return 1
}
// GetMaxArgs returns the maximum number of arguments; this test double always
// reports 1.
func (t *TestAggregatorFunction) GetMaxArgs() int {
	return 1
}
// TestAnalyticalFunction 测试用的分析函数实现
type TestAnalyticalFunction struct {
values []interface{}
@@ -335,4 +345,14 @@ func (t *TestAnalyticalFunction) Validate(args []interface{}) error {
// Execute 执行函数
func (t *TestAnalyticalFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
return t.Result(), nil
}
}
// GetMinArgs returns the minimum number of arguments; this test double always
// reports 1.
func (t *TestAnalyticalFunction) GetMinArgs() int {
	return 1
}
// GetMaxArgs returns the maximum number of arguments; this test double always
// reports 1.
func (t *TestAnalyticalFunction) GetMaxArgs() int {
	return 1
}
@@ -294,4 +294,4 @@ func (m *MockAnalyticalFunction) Clone() AggregatorFunction {
}
copy(newMock.values, m.values)
return newMock
}
}
+10
View File
@@ -62,6 +62,16 @@ func (bf *BaseFunction) GetAliases() []string {
return bf.aliases
}
// GetMinArgs returns the minimum number of arguments, as stored in the
// minArgs field.
func (bf *BaseFunction) GetMinArgs() int {
	return bf.minArgs
}
// GetMaxArgs returns the maximum number of arguments, as stored in the
// maxArgs field (-1 means unlimited).
func (bf *BaseFunction) GetMaxArgs() int {
	return bf.maxArgs
}
// ValidateArgCount validates the number of arguments
func (bf *BaseFunction) ValidateArgCount(args []interface{}) error {
argCount := len(args)
+3 -5
View File
@@ -83,6 +83,7 @@ func registerBuiltinFunctions() {
_ = Register(NewMedianAggregatorFunction())
_ = Register(NewPercentileFunction())
_ = Register(NewCollectFunction())
_ = Register(NewFirstValueFunction())
_ = Register(NewLastValueFunction())
_ = Register(NewMergeAggFunction())
_ = Register(NewStdDevSAggregatorFunction())
@@ -91,8 +92,9 @@ func registerBuiltinFunctions() {
_ = Register(NewVarSAggregatorFunction())
// Window functions
_ = Register(NewWindowStartFunction())
_ = Register(NewWindowEndFunction())
_ = Register(NewRowNumberFunction())
_ = Register(NewFirstValueFunction())
_ = Register(NewLeadFunction())
_ = Register(NewNthValueFunction())
@@ -102,10 +104,6 @@ func registerBuiltinFunctions() {
_ = Register(NewChangedColFunction())
_ = Register(NewHadChangedFunction())
// Window functions
_ = Register(NewWindowStartFunction())
_ = Register(NewWindowEndFunction())
// Expression functions
_ = Register(NewExpressionFunction())
_ = Register(NewExprFunction())
+7 -2
View File
@@ -54,12 +54,17 @@ func (bridge *ExprBridge) RegisterStreamSQLFunctionsToExpr() []expr.Option {
// Add function to expr environment
bridge.exprEnv[name] = wrappedFunc
bridge.exprEnv[strings.ToUpper(name)] = wrappedFunc
// Register function type information
// Register function type information for both lowercase and uppercase
options = append(options, expr.Function(
name,
wrappedFunc,
))
options = append(options, expr.Function(
strings.ToUpper(name),
wrappedFunc,
))
}
return options
@@ -143,7 +148,7 @@ func (bridge *ExprBridge) CompileExpressionWithStreamSQLFunctions(expression str
// 启用一些有用的expr功能
options = append(options,
expr.AllowUndefinedVariables(), // 允许未定义变量
expr.AsBool(), // 期望布尔结果(可根据需要调整)
// 移除 expr.AsBool() 以允许返回任意类型的值
)
return expr.Compile(expression, options...)
+80 -3
View File
@@ -150,7 +150,7 @@ func (f *AvgFunction) Add(value interface{}) {
func (f *AvgFunction) Result() interface{} {
if f.count == 0 {
return nil // Return nil when no valid values instead of 0.0
return nil // Return NULL when no valid values according to SQL standard
}
return f.sum / float64(f.count)
}
@@ -187,6 +187,13 @@ func (f *MinFunction) Validate(args []interface{}) error {
}
func (f *MinFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
// 检查是否有nil参数
for _, arg := range args {
if arg == nil {
return nil, nil
}
}
min := math.Inf(1)
for _, arg := range args {
val, err := cast.ToFloat64E(arg)
@@ -224,7 +231,7 @@ func (f *MinFunction) Add(value interface{}) {
func (f *MinFunction) Result() interface{} {
if f.first {
return nil
return nil // Return NULL when no data according to SQL standard
}
return f.value
}
@@ -261,6 +268,13 @@ func (f *MaxFunction) Validate(args []interface{}) error {
}
func (f *MaxFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
// 检查是否有nil参数
for _, arg := range args {
if arg == nil {
return nil, nil
}
}
max := math.Inf(-1)
for _, arg := range args {
val, err := cast.ToFloat64E(arg)
@@ -298,7 +312,7 @@ func (f *MaxFunction) Add(value interface{}) {
func (f *MaxFunction) Result() interface{} {
if f.first {
return nil
return nil // Return NULL when no data according to SQL standard
}
return f.value
}
@@ -582,6 +596,69 @@ func (f *CollectFunction) Clone() AggregatorFunction {
return newFunc
}
// FirstValueFunction is an aggregate function that reports the first value
// added to it.
type FirstValueFunction struct {
	*BaseFunction
	firstValue interface{} // value captured by the first Add call
	hasValue   bool        // true once firstValue has been captured
}

// NewFirstValueFunction builds the first_value function with an empty state.
func NewFirstValueFunction() *FirstValueFunction {
	base := NewBaseFunction("first_value", TypeAggregation, "聚合函数", "返回第一个值", 1, -1)
	return &FirstValueFunction{BaseFunction: base}
}

// Validate checks the argument count through the embedded BaseFunction.
func (f *FirstValueFunction) Validate(args []interface{}) error {
	return f.ValidateArgCount(args)
}

// Execute validates args and returns the first argument unchanged. It returns
// an error when validation fails or when args is empty.
func (f *FirstValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
	err := f.Validate(args)
	switch {
	case err != nil:
		return nil, err
	case len(args) == 0:
		return nil, fmt.Errorf("function %s requires at least one argument", f.GetName())
	}
	return args[0], nil
}

// New implements AggregatorFunction: it returns an empty instance that shares
// this function's BaseFunction.
func (f *FirstValueFunction) New() AggregatorFunction {
	return &FirstValueFunction{BaseFunction: f.BaseFunction}
}

// Add records value only if nothing has been recorded yet; later calls are
// ignored. A nil first value is recorded like any other.
func (f *FirstValueFunction) Add(value interface{}) {
	if f.hasValue {
		return
	}
	f.firstValue, f.hasValue = value, true
}

// Result returns the recorded first value, or nil when nothing was added.
func (f *FirstValueFunction) Result() interface{} {
	return f.firstValue
}

// Reset clears the recorded value so the next Add is captured again.
func (f *FirstValueFunction) Reset() {
	f.firstValue, f.hasValue = nil, false
}

// Clone returns a copy carrying the same BaseFunction and recorded state.
func (f *FirstValueFunction) Clone() AggregatorFunction {
	dup := *f
	return &dup
}
// LastValueFunction 最后值函数 - 返回组中最后一行的值
type LastValueFunction struct {
*BaseFunction
+11
View File
@@ -24,6 +24,17 @@ func (f *IfNullFunction) Validate(args []interface{}) error {
func (f *IfNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if args[0] == nil {
// 当第一个参数为nil时,返回第二个参数
// 如果第二个参数是数字0,确保返回float64类型以保持一致性
if args[1] != nil {
// 尝试转换为float64以保持数值类型一致性
if val, ok := args[1].(int); ok && val == 0 {
return 0.0, nil
}
if val, ok := args[1].(float32); ok {
return float64(val), nil
}
}
return args[1], nil
}
return args[0], nil
+10
View File
@@ -550,6 +550,11 @@ func (f *RoundFunction) Validate(args []interface{}) error {
}
func (f *RoundFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
// 检查第一个参数是否为nil
if args[0] == nil {
return nil, nil
}
val, err := cast.ToFloat64E(args[0])
if err != nil {
return nil, err
@@ -559,6 +564,11 @@ func (f *RoundFunction) Execute(ctx *FunctionContext, args []interface{}) (inter
return math.Round(val), nil
}
// 检查第二个参数是否为nil(如果存在)
if args[1] == nil {
return nil, nil
}
precision, err := cast.ToIntE(args[1])
if err != nil {
return nil, err
+75 -69
View File
@@ -38,7 +38,7 @@ type WindowStartFunction struct {
func NewWindowStartFunction() *WindowStartFunction {
return &WindowStartFunction{
BaseFunction: NewBaseFunction("window_start", TypeWindow, "window", "Return window start time", 0, 0),
BaseFunction: NewBaseFunction("window_start", TypeWindow, "窗口函数", "返回窗口开始时间", 0, 0),
}
}
@@ -88,7 +88,7 @@ type WindowEndFunction struct {
func NewWindowEndFunction() *WindowEndFunction {
return &WindowEndFunction{
BaseFunction: NewBaseFunction("window_end", TypeWindow, "window", "Return window end time", 0, 0),
BaseFunction: NewBaseFunction("window_end", TypeWindow, "窗口函数", "返回窗口结束时间", 0, 0),
}
}
@@ -246,63 +246,6 @@ func (f *ExpressionAggregatorFunction) Clone() AggregatorFunction {
}
}
// FirstValueFunction 返回窗口中第一个值
type FirstValueFunction struct {
*BaseFunction
firstValue interface{}
hasValue bool
}
func NewFirstValueFunction() *FirstValueFunction {
return &FirstValueFunction{
BaseFunction: NewBaseFunction("first_value", TypeWindow, "窗口函数", "返回窗口中第一个值", 1, 1),
hasValue: false,
}
}
func (f *FirstValueFunction) Validate(args []interface{}) error {
return f.ValidateArgCount(args)
}
func (f *FirstValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
if err := f.Validate(args); err != nil {
return nil, err
}
return f.firstValue, nil
}
// 实现AggregatorFunction接口
func (f *FirstValueFunction) New() AggregatorFunction {
return &FirstValueFunction{
BaseFunction: f.BaseFunction,
hasValue: false,
}
}
func (f *FirstValueFunction) Add(value interface{}) {
if !f.hasValue {
f.firstValue = value
f.hasValue = true
}
}
func (f *FirstValueFunction) Result() interface{} {
return f.firstValue
}
func (f *FirstValueFunction) Reset() {
f.firstValue = nil
f.hasValue = false
}
func (f *FirstValueFunction) Clone() AggregatorFunction {
return &FirstValueFunction{
BaseFunction: f.BaseFunction,
firstValue: f.firstValue,
hasValue: f.hasValue,
}
}
// LeadFunction 返回当前行之后第N行的值
type LeadFunction struct {
*BaseFunction
@@ -376,9 +319,9 @@ func (f *LeadFunction) New() AggregatorFunction {
return &LeadFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, 0),
offset: f.offset,
defaultValue: f.defaultValue,
hasDefault: f.hasDefault,
offset: f.offset, // 保持offset参数
defaultValue: f.defaultValue, // 保持默认值
hasDefault: f.hasDefault, // 保持默认值标志
}
}
@@ -387,12 +330,12 @@ func (f *LeadFunction) Add(value interface{}) {
}
func (f *LeadFunction) Result() interface{} {
// Lead函数的结果需要在所有数据添加完成后计算
// 如果没有足够的数据,返回默认值
if len(f.values) == 0 && f.hasDefault {
// LEAD函数在没有指定当前行位置的情况下,返回默认值或nil
// 这通常用于聚合场景,真正的窗口计算需要在窗口处理器中进行
if f.hasDefault {
return f.defaultValue
}
// 这里简化实现,返回nil
return nil
}
@@ -415,6 +358,41 @@ func (f *LeadFunction) Clone() AggregatorFunction {
return clone
}
// Init implements the ParameterizedFunction interface by configuring the LEAD
// offset and optional default value from the call arguments.
//
// args[0] is not inspected here. args[1], when present, is the offset: it must
// be an int, an int64, or a float64 holding a whole number, and it must not be
// negative. args[2], when present, becomes the default value. With fewer than
// two arguments the offset is set to 1.
//
// When an error is returned the receiver is left unmodified.
func (f *LeadFunction) Init(args []interface{}) error {
	if len(args) < 2 {
		// LEAD with default offset = 1
		f.offset = 1
		return nil
	}

	// Parse offset parameter.
	var offset int
	switch v := args[1].(type) {
	case int:
		offset = v
	case int64:
		offset = int(v)
	case float64:
		// The round trip through int64 fails for fractional, NaN, infinite
		// and out-of-range values, which must not be silently truncated.
		if v != float64(int64(v)) {
			return fmt.Errorf("lead offset must be an integer, got %v", v)
		}
		offset = int(v)
	default:
		return fmt.Errorf("lead offset must be an integer, got %T", args[1])
	}

	if offset < 0 {
		return fmt.Errorf("lead offset must be non-negative, got %d", offset)
	}
	f.offset = offset

	// Parse default value if provided.
	if len(args) >= 3 {
		f.defaultValue = args[2]
		f.hasDefault = true
	}

	return nil
}
// NthValueFunction 返回窗口中第N个值
type NthValueFunction struct {
*BaseFunction
@@ -484,11 +462,13 @@ func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (in
// 实现AggregatorFunction接口
func (f *NthValueFunction) New() AggregatorFunction {
return &NthValueFunction{
newInstance := &NthValueFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, 0),
n: f.n,
n: f.n, // 保持n参数
}
return newInstance
}
func (f *NthValueFunction) Add(value interface{}) {
@@ -510,8 +490,34 @@ func (f *NthValueFunction) Clone() AggregatorFunction {
clone := &NthValueFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, len(f.values)),
n: f.n,
n: f.n, // 保持n参数
}
copy(clone.values, f.values)
return clone
}
// Init implements the ParameterizedFunction interface by configuring which
// value NTH_VALUE selects.
//
// args[0] is not inspected here. args[1] is N: it must be an int, an int64, or
// a float64 holding a whole number, and it must be positive. Fewer than two
// arguments is an error.
//
// When an error is returned the receiver is left unmodified.
func (f *NthValueFunction) Init(args []interface{}) error {
	if len(args) < 2 {
		return fmt.Errorf("nth_value requires at least 2 arguments")
	}

	// Parse N parameter.
	var n int
	switch v := args[1].(type) {
	case int:
		n = v
	case int64:
		n = int(v)
	case float64:
		// The round trip through int64 fails for fractional, NaN, infinite
		// and out-of-range values, which must not be silently truncated.
		if v != float64(int64(v)) {
			return fmt.Errorf("nth_value n must be an integer, got %v", v)
		}
		n = int(v)
	default:
		return fmt.Errorf("nth_value n must be an integer, got %T", args[1])
	}

	if n <= 0 {
		return fmt.Errorf("nth_value n must be positive, got %d", n)
	}
	f.n = n

	return nil
}
+5
View File
@@ -61,6 +61,11 @@ type Function interface {
Execute(ctx *FunctionContext, args []interface{}) (interface{}, error)
// GetDescription returns the function description
GetDescription() string
// GetMinArgs returns the minimum number of arguments
GetMinArgs() int
// GetMaxArgs returns the maximum number of arguments (-1 means unlimited)
GetMaxArgs() int
}
// FunctionRegistry manages function registration and retrieval
+1 -1
View File
@@ -3,7 +3,7 @@ module github.com/rulego/streamsql
go 1.18
require (
github.com/expr-lang/expr v1.17.2
github.com/expr-lang/expr v1.17.6
github.com/stretchr/testify v1.10.0
)
+2 -2
View File
@@ -1,7 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/expr-lang/expr v1.17.2 h1:o0A99O/Px+/DTjEnQiodAgOIK9PPxL8DtXhBRKC+Iso=
github.com/expr-lang/expr v1.17.2/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec=
github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+477 -119
View File
File diff suppressed because it is too large Load Diff
+87 -2
View File
@@ -360,7 +360,11 @@ func TestBuildSelectFields(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aggMap, fieldMap := buildSelectFields(tt.fields)
aggMap, fieldMap, err := buildSelectFields(tt.fields)
if err != nil {
t.Errorf("buildSelectFields() error = %v", err)
return
}
// 检查聚合函数映射
if len(aggMap) != len(tt.wantAggs) {
@@ -498,9 +502,14 @@ func TestParseAggregateTypeWithExpression(t *testing.T) {
},
}
// 测试正常情况
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aggType, name, expression, allFields := ParseAggregateTypeWithExpression(tt.exprStr)
aggType, name, expression, allFields, err := ParseAggregateTypeWithExpression(tt.exprStr)
if err != nil {
t.Errorf("ParseAggregateTypeWithExpression() returned error: %v", err)
return
}
if string(aggType) != tt.wantAggType {
t.Errorf("ParseAggregateTypeWithExpression() aggType = %s, want %s", aggType, tt.wantAggType)
@@ -524,6 +533,82 @@ func TestParseAggregateTypeWithExpression(t *testing.T) {
}
})
}
// 测试嵌套聚合函数检测
nestedTests := []struct {
name string
exprStr string
}{
{
name: "嵌套聚合函数 - MAX(AVG(temperature))",
exprStr: "MAX(AVG(temperature))",
},
{
name: "嵌套聚合函数 - COUNT(SUM(price))",
exprStr: "COUNT(SUM(price))",
},
{
name: "复杂嵌套 - MAX(ROUND(AVG(temperature), 1))",
exprStr: "MAX(ROUND(AVG(temperature), 1))",
},
}
for _, tt := range nestedTests {
t.Run(tt.name, func(t *testing.T) {
_, _, _, _, err := ParseAggregateTypeWithExpression(tt.exprStr)
if err == nil {
t.Errorf("ParseAggregateTypeWithExpression() should return error for nested aggregation: %s", tt.exprStr)
} else if !strings.Contains(err.Error(), "aggregate function calls cannot be nested") {
t.Errorf("ParseAggregateTypeWithExpression() error message should contain 'aggregate function calls cannot be nested', got: %v", err)
}
})
}
}
// TestDetectNestedAggregation is a table-driven test for
// detectNestedAggregation: each case states whether the expression is expected
// to be reported as an error.
func TestDetectNestedAggregation(t *testing.T) {
	type testCase struct {
		name      string
		exprStr   string
		wantError bool
	}

	cases := []testCase{
		{name: "正常聚合函数", exprStr: "MAX(temperature)", wantError: false},
		{name: "嵌套聚合函数", exprStr: "MAX(AVG(temperature))", wantError: true},
		{name: "复杂嵌套", exprStr: "MAX(ROUND(AVG(temperature), 1))", wantError: true},
		{name: "非聚合函数嵌套", exprStr: "UPPER(CONCAT(first_name, last_name))", wantError: false},
		{name: "聚合函数包含非聚合函数", exprStr: "MAX(ROUND(temperature, 1))", wantError: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := detectNestedAggregation(tc.exprStr)
			switch {
			case tc.wantError && err == nil:
				t.Errorf("detectNestedAggregation() should return error for: %s", tc.exprStr)
			case !tc.wantError && err != nil:
				t.Errorf("detectNestedAggregation() should not return error for: %s, got: %v", tc.exprStr, err)
			}
		})
	}
}
// TestExtractAggFieldWithExpression 测试 extractAggFieldWithExpression 函数
+5 -1
View File
@@ -653,7 +653,11 @@ func TestBuildSelectFieldsWithExpressions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aggMap, fieldMap, expressions := buildSelectFieldsWithExpressions(tt.fields)
aggMap, fieldMap, expressions, _, err := buildSelectFieldsWithExpressions(tt.fields)
if err != nil {
t.Errorf("buildSelectFieldsWithExpressions() error = %v", err)
return
}
tt.checkFunc(t, aggMap, fieldMap, expressions)
})
}
+1 -1
View File
@@ -185,7 +185,7 @@ func (p *Parser) createCombinedError() error {
for i, err := range errors {
builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, err.Error()))
}
return fmt.Errorf(builder.String())
return fmt.Errorf("%s", builder.String())
}
func (p *Parser) parseSelect(stmt *SelectStatement) error {
+39 -1
View File
@@ -88,7 +88,30 @@ func (dp *DataProcessor) Process() {
func (dp *DataProcessor) initializeAggregator() {
// Convert to new AggregationField format
aggregationFields := convertToAggregationFields(dp.stream.config.SelectFields, dp.stream.config.FieldAlias)
dp.stream.aggregator = aggregator.NewGroupAggregator(dp.stream.config.GroupFields, aggregationFields)
// Check if we have post-aggregation expressions
if len(dp.stream.config.PostAggExpressions) > 0 {
// Use enhanced aggregator for post-aggregation support
enhancedAgg := aggregator.NewEnhancedGroupAggregator(dp.stream.config.GroupFields, aggregationFields)
// Add post-aggregation expressions
for _, postExpr := range dp.stream.config.PostAggExpressions {
err := enhancedAgg.AddPostAggregationExpression(
postExpr.OutputField,
postExpr.OriginalExpr,
convertToAggregationFieldInfos(postExpr.RequiredFields),
)
if err != nil {
// Log error but continue
fmt.Printf("Error adding post-aggregation expression %s: %v\n", postExpr.OriginalExpr, err)
}
}
dp.stream.aggregator = enhancedAgg
} else {
// Use regular aggregator
dp.stream.aggregator = aggregator.NewGroupAggregator(dp.stream.config.GroupFields, aggregationFields)
}
// Register expression calculators
for field, fieldExpr := range dp.stream.config.FieldExpressions {
@@ -96,6 +119,21 @@ func (dp *DataProcessor) initializeAggregator() {
}
}
// convertToAggregationFieldInfos maps each types.AggregationFieldInfo to the
// equivalent aggregator.AggregationFieldInfo, copying FuncName, InputField,
// Placeholder, AggType and FullCall. The result has the same length and order
// as fields.
func convertToAggregationFieldInfos(fields []types.AggregationFieldInfo) []aggregator.AggregationFieldInfo {
	converted := make([]aggregator.AggregationFieldInfo, 0, len(fields))
	for idx := range fields {
		src := fields[idx]
		converted = append(converted, aggregator.AggregationFieldInfo{
			FuncName:    src.FuncName,
			InputField:  src.InputField,
			Placeholder: src.Placeholder,
			AggType:     src.AggType,
			FullCall:    src.FullCall, // carried over unchanged
		})
	}
	return converted
}
// registerExpressionCalculator registers expression calculator
func (dp *DataProcessor) registerExpressionCalculator(field string, fieldExpr types.FieldExpression) {
// Create local variables to avoid closure issues
File diff suppressed because it is too large Load Diff
+27 -55
View File
@@ -640,7 +640,7 @@ func TestNestedFunctionSupport(t *testing.T) {
// 执行包含 avg(round(temperature, 2)) 的查询
query := "SELECT device, avg(round(temperature, 2)) as avg_rounded FROM stream GROUP BY device, TumblingWindow('1s')"
t.Logf("Executing query: %s", query)
err := streamsql.Execute(query)
assert.Nil(t, err)
@@ -686,7 +686,6 @@ func TestNestedFunctionSupport(t *testing.T) {
} else if val, ok := avgRounded.(float64); ok {
// 期望值:avg(20.57, 25.23, 30.12) = (20.57 + 25.23 + 30.12) / 3 = 25.31
assert.InEpsilon(t, 25.31, val, 0.01)
t.Logf("avg(round()) test passed: %v", val)
} else {
t.Errorf("avg_rounded is not a float64: %v (type: %T)", avgRounded, avgRounded)
}
@@ -740,9 +739,6 @@ func TestNestedFunctionSupport(t *testing.T) {
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
for key, value := range item {
t.Logf(" %s: %v (type: %T)", key, value, value)
}
assert.Equal(t, "sensor1", item["device"])
@@ -798,7 +794,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
t.Logf("Result: %+v", item)
// 验证执行顺序:round(25.67, 1) -> 25.7, concat('temp_', '25.7') -> 'temp_25.7', upper('temp_25.7') -> 'TEMP_25.7'
assert.Equal(t, "TEMP_25.7", item["formatted_temp"])
@@ -814,7 +809,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
defer streamsql.Stop()
query := "SELECT device, round(len(upper(device)), 0) as device_length FROM stream"
t.Logf("Executing query: %s", query)
err := streamsql.Execute(query)
assert.Nil(t, err)
@@ -838,7 +833,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
t.Logf("Result: %+v", item)
// 验证执行顺序:upper('sensor1') -> 'SENSOR1', len('SENSOR1') -> 7, round(7, 0) -> 7
assert.Equal(t, float64(7), item["device_length"])
@@ -854,7 +848,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
defer streamsql.Stop()
query := "SELECT device, abs(round(sqrt(temperature), 2)) as processed_temp FROM stream"
t.Logf("Executing query: %s", query)
err := streamsql.Execute(query)
assert.Nil(t, err)
@@ -878,8 +872,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
//t.Logf("Result: %+v", item)
// 验证执行顺序:sqrt(16) -> 4, round(4, 2) -> 4.00, abs(4.00) -> 4.00
assert.Equal(t, float64(4), item["processed_temp"])
case <-ctx.Done():
@@ -887,56 +879,40 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
}
})
// 测试6: 复杂的聚合函数嵌套
// 测试6: 复杂的聚合函数嵌套 - 应该报错
t.Run("ComplexAggregationNesting", func(t *testing.T) {
// 测试 max(round(avg(temperature), 1))
// 测试 max(round(avg(temperature), 1)) - 这是嵌套聚合函数,应该报错
streamsql := New()
defer streamsql.Stop()
query := "SELECT device, max(round(avg(temperature), 1)) as max_rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')"
t.Logf("Executing query: %s", query)
err := streamsql.Execute(query)
assert.Nil(t, err)
// 应该返回嵌套聚合函数错误
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "aggregate function calls cannot be nested")
})
strm := streamsql.stream
resultChan := make(chan interface{}, 10)
strm.AddSink(func(result []map[string]interface{}) {
resultChan <- result
})
// 测试7: 其他类型的嵌套聚合函数检测
t.Run("NestedAggregationDetection", func(t *testing.T) {
streamsql := New()
defer streamsql.Stop()
// 添加测试数据
testData := []map[string]interface{}{
{"device": "sensor1", "temperature": 20.567},
{"device": "sensor1", "temperature": 25.234},
{"device": "sensor1", "temperature": 30.123},
}
// 测试 sum(count(*)) - 聚合函数嵌套聚合函数
query1 := "SELECT sum(count(*)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')"
err1 := streamsql.Execute(query1)
assert.NotNil(t, err1)
assert.Contains(t, err1.Error(), "aggregate function calls cannot be nested")
for _, data := range testData {
strm.Emit(data)
}
// 测试 avg(min(temperature)) - 聚合函数嵌套聚合函数
query2 := "SELECT avg(min(temperature)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')"
err2 := streamsql.Execute(query2)
assert.NotNil(t, err2)
assert.Contains(t, err2.Error(), "aggregate function calls cannot be nested")
// 等待窗口初始化
time.Sleep(1 * time.Second)
strm.Window.Trigger()
// 等待结果
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
select {
case result := <-resultChan:
resultSlice, ok := result.([]map[string]interface{})
require.True(t, ok)
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
//t.Logf("Result: %+v", item)
// 验证执行顺序:avg(20.567, 25.234, 30.123) -> 25.308, round(25.308, 1) -> 25.3, max(25.3) -> 25.3
assert.InEpsilon(t, 25.3, item["max_rounded_avg"], 0.01)
case <-ctx.Done():
t.Fatal("测试超时")
}
// 测试 round(avg(temperature), 1) - 正常函数嵌套聚合函数,应该正常
query3 := "SELECT round(avg(temperature), 1) as normal_nesting FROM stream GROUP BY device, TumblingWindow('1s')"
err3 := streamsql.Execute(query3)
assert.Nil(t, err3) // 这种嵌套应该是允许的
})
// 测试7: 日期时间函数嵌套
@@ -946,7 +922,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
defer streamsql.Stop()
query := "SELECT device, year(date_add(created_at, 1, 'years')) as next_year FROM stream"
t.Logf("Executing query: %s", query)
err := streamsql.Execute(query)
assert.Nil(t, err)
@@ -970,7 +945,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
//t.Logf("Result: %+v", item)
// 验证执行顺序:date_add('2023-12-25 15:30:45', 1, 'years') -> '2024-12-25 15:30:45', year('2024-12-25 15:30:45') -> 2024
assert.Equal(t, float64(2024), item["next_year"])
@@ -986,7 +960,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
defer streamsql.Stop()
query := "SELECT device, sqrt(len(invalid_field)) as error_result FROM stream"
t.Logf("Executing query: %s", query)
err := streamsql.Execute(query)
assert.Nil(t, err)
@@ -1010,7 +983,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) {
assert.Len(t, resultSlice, 1)
item := resultSlice[0]
t.Logf("Error handling result: %+v", item)
// 验证错误处理:invalid_field不存在,应该返回nil或默认值
_, exists := item["error_result"]
File diff suppressed because it is too large Load Diff
+27 -9
View File
@@ -9,15 +9,16 @@ import (
// Config stream processing configuration
type Config struct {
// SQL processing related configuration
WindowConfig WindowConfig `json:"windowConfig"`
GroupFields []string `json:"groupFields"`
SelectFields map[string]aggregator.AggregateType `json:"selectFields"`
FieldAlias map[string]string `json:"fieldAlias"`
SimpleFields []string `json:"simpleFields"`
FieldExpressions map[string]FieldExpression `json:"fieldExpressions"`
FieldOrder []string `json:"fieldOrder"` // Original order of fields in SELECT statement
Where string `json:"where"`
Having string `json:"having"`
WindowConfig WindowConfig `json:"windowConfig"`
GroupFields []string `json:"groupFields"`
SelectFields map[string]aggregator.AggregateType `json:"selectFields"`
FieldAlias map[string]string `json:"fieldAlias"`
SimpleFields []string `json:"simpleFields"`
FieldExpressions map[string]FieldExpression `json:"fieldExpressions"`
PostAggExpressions []PostAggregationExpression `json:"postAggExpressions"` // Post-aggregation expressions
FieldOrder []string `json:"fieldOrder"` // Original order of fields in SELECT statement
Where string `json:"where"`
Having string `json:"having"`
// Feature switches
NeedWindow bool `json:"needWindow"`
@@ -47,6 +48,23 @@ type FieldExpression struct {
Fields []string `json:"fields"` // all fields referenced in expression
}
// PostAggregationExpression represents an expression that needs to be evaluated after aggregation
type PostAggregationExpression struct {
	OutputField        string                 `json:"outputField"`        // Output field name
	OriginalExpr       string                 `json:"originalExpr"`       // Original expression
	ExpressionTemplate string                 `json:"expressionTemplate"` // Expression template
	RequiredFields     []AggregationFieldInfo `json:"requiredFields"`     // Aggregation fields the expression depends on
}
// AggregationFieldInfo holds information about an aggregation function in an expression
type AggregationFieldInfo struct {
	FuncName    string                   `json:"funcName"`    // Function name, e.g. "first_value"
	InputField  string                   `json:"inputField"`  // Input field, e.g. "displayNum"
	Placeholder string                   `json:"placeholder"` // Placeholder, e.g. "__first_value_0__"
	AggType     aggregator.AggregateType `json:"aggType"`     // Aggregation type
	FullCall    string                   `json:"fullCall"`    // Full function call, e.g. "NTH_VALUE(value, 2)"
}
// ProjectionSourceType projection source type
type ProjectionSourceType int