feat: aggregate functions can be nested with the CASE expression.

This commit is contained in:
rulego-team
2025-06-16 20:22:37 +08:00
parent f9c21ff4bd
commit 1500dc5d23
6 changed files with 1286 additions and 74 deletions

File diff suppressed because it is too large Load Diff

View File

@ -27,8 +27,8 @@ func TestAggregatorFunctionInterface(t *testing.T) {
// 测试重置
aggInstance.Reset()
result = aggInstance.Result()
if result != 0.0 {
t.Errorf("Expected 0.0 after reset, got %v", result)
if result != nil {
t.Errorf("Expected nil after reset (SQL standard: SUM with no rows returns NULL), got %v", result)
}
// 测试克隆

View File

@ -12,12 +12,14 @@ import (
// SumFunction 求和函数
type SumFunction struct {
*BaseFunction
value float64
value float64
hasValues bool // 标记是否有非NULL值
}
// NewSumFunction creates the SUM aggregate in its empty state.
//
// The embedded BaseFunction is registered under the name "sum" with
// TypeAggregation. The trailing 1 and -1 are presumably the minimum and
// maximum argument counts (-1 meaning unbounded) — TODO confirm against
// NewBaseFunction. hasValues starts out false, so Result returns nil until
// Add accepts at least one non-nil, convertible value.
func NewSumFunction() *SumFunction {
	base := NewBaseFunction("sum", TypeAggregation, "聚合函数", "计算数值总和", 1, -1)
	fn := &SumFunction{BaseFunction: base}
	fn.hasValues = false // explicit: nothing has been summed yet
	return fn
}
@ -27,12 +29,20 @@ func (f *SumFunction) Validate(args []interface{}) error {
func (f *SumFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
sum := 0.0
hasValues := false
for _, arg := range args {
if arg == nil {
continue // 忽略NULL值
}
val, err := cast.ToFloat64E(arg)
if err != nil {
return nil, err
continue // 忽略无法转换的值
}
sum += val
hasValues = true
}
if !hasValues {
return nil, nil // 当没有有效值时返回NULL
}
return sum, nil
}
@ -42,27 +52,40 @@ func (f *SumFunction) New() AggregatorFunction {
return &SumFunction{
BaseFunction: f.BaseFunction,
value: 0,
hasValues: false,
}
}
// Add folds one input into the running sum.
//
// A nil value is treated as SQL NULL and skipped. A value that
// cast.ToFloat64E cannot convert is skipped as well, without reporting an
// error. Only a successfully converted value changes the total and marks the
// accumulator as non-empty.
func (f *SumFunction) Add(value interface{}) {
	if value == nil {
		return // NULL never contributes to the sum
	}
	num, err := cast.ToFloat64E(value)
	if err != nil {
		return // unconvertible input is ignored
	}
	f.value += num
	f.hasValues = true
}
// Result reports the accumulated sum as a float64, or nil (SQL NULL) when no
// value has been accepted since construction or the last Reset.
func (f *SumFunction) Result() interface{} {
	if f.hasValues {
		return f.value
	}
	// Nothing was summed: report NULL rather than 0.0.
	return nil
}
// Reset returns the accumulator to its empty state: the total goes back to
// zero and the "has values" marker is cleared, so Result yields nil again.
func (f *SumFunction) Reset() {
	f.value, f.hasValues = 0, false
}
// Clone returns a new SumFunction carrying the receiver's current total and
// "has values" marker. The BaseFunction pointer is shared with the receiver,
// not copied.
func (f *SumFunction) Clone() AggregatorFunction {
	dup := new(SumFunction)
	dup.BaseFunction = f.BaseFunction
	dup.value = f.value
	dup.hasValues = f.hasValues
	return dup
}
@ -85,14 +108,22 @@ func (f *AvgFunction) Validate(args []interface{}) error {
func (f *AvgFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
sum := 0.0
count := 0
for _, arg := range args {
if arg == nil {
continue // 忽略NULL值
}
val, err := cast.ToFloat64E(arg)
if err != nil {
return nil, err
continue // 忽略无法转换的值
}
sum += val
count++
}
return sum / float64(len(args)), nil
if count == 0 {
return nil, nil // 当没有有效值时返回nil
}
return sum / float64(count), nil
}
// 实现AggregatorFunction接口
@ -105,10 +136,16 @@ func (f *AvgFunction) New() AggregatorFunction {
}
// Add folds one input into the running average.
//
// A nil value is treated as SQL NULL and skipped. A value that
// cast.ToFloat64E cannot convert is skipped as well, without reporting an
// error. Only a successfully converted value is added to the sum and counted.
func (f *AvgFunction) Add(value interface{}) {
	if value == nil {
		return // NULL is excluded from both sum and count
	}
	num, err := cast.ToFloat64E(value)
	if err != nil {
		return // unconvertible input is ignored
	}
	f.sum += num
	f.count++
}
func (f *AvgFunction) Result() interface{} {
@ -172,6 +209,11 @@ func (f *MinFunction) New() AggregatorFunction {
}
func (f *MinFunction) Add(value interface{}) {
// 增强的Add方法忽略NULL值
if value == nil {
return // 忽略NULL值
}
if val, err := cast.ToFloat64E(value); err == nil {
if f.first || val < f.value {
f.value = val
@ -241,6 +283,11 @@ func (f *MaxFunction) New() AggregatorFunction {
}
func (f *MaxFunction) Add(value interface{}) {
// 增强的Add方法忽略NULL值
if value == nil {
return // 忽略NULL值
}
if val, err := cast.ToFloat64E(value); err == nil {
if f.first || val > f.value {
f.value = val
@ -286,7 +333,13 @@ func (f *CountFunction) Validate(args []interface{}) error {
}
func (f *CountFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
return int64(len(args)), nil
count := 0
for _, arg := range args {
if arg != nil {
count++
}
}
return int64(count), nil
}
// 实现AggregatorFunction接口
@ -298,7 +351,10 @@ func (f *CountFunction) New() AggregatorFunction {
}
func (f *CountFunction) Add(value interface{}) {
f.count++
// 增强的Add方法忽略NULL值
if value != nil {
f.count++
}
}
func (f *CountFunction) Result() interface{} {

View File

@ -294,16 +294,40 @@ func (s *Stream) process() {
hasNestedFields := strings.Contains(currentFieldExpr.Expression, ".")
if hasNestedFields {
// 直接使用自定义表达式引擎处理嵌套字段
// 直接使用自定义表达式引擎处理嵌套字段支持NULL值
expression, parseErr := expr.NewExpression(currentFieldExpr.Expression)
if parseErr != nil {
return nil, fmt.Errorf("expression parse failed: %w", parseErr)
}
numResult, err := expression.Evaluate(dataMap)
// 使用支持NULL的计算方法
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
if err != nil {
return nil, fmt.Errorf("expression evaluation failed: %w", err)
}
if isNull {
return nil, nil // 返回nil表示NULL值
}
return numResult, nil
}
// 检查是否为CASE表达式
trimmedExpr := strings.TrimSpace(currentFieldExpr.Expression)
upperExpr := strings.ToUpper(trimmedExpr)
if strings.HasPrefix(upperExpr, "CASE") {
// CASE表达式使用支持NULL的计算方法
expression, parseErr := expr.NewExpression(currentFieldExpr.Expression)
if parseErr != nil {
return nil, fmt.Errorf("CASE expression parse failed: %w", parseErr)
}
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
if err != nil {
return nil, fmt.Errorf("CASE expression evaluation failed: %w", err)
}
if isNull {
return nil, nil // 返回nil表示NULL值
}
return numResult, nil
}
@ -331,11 +355,14 @@ func (s *Stream) process() {
return nil, fmt.Errorf("expression parse failed: %w", parseErr)
}
// 计算表达式
numResult, evalErr := expression.Evaluate(dataMap)
// 计算表达式支持NULL值
numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap)
if evalErr != nil {
return nil, fmt.Errorf("expression evaluation failed: %w", evalErr)
}
if isNull {
return nil, nil // 返回nil表示NULL值
}
return numResult, nil
}
@ -496,13 +523,11 @@ func (s *Stream) processDirectData(data interface{}) {
if bridge.ContainsIsNullOperator(processedExpr) {
if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil {
processedExpr = processed
logger.Debug("Preprocessed IS NULL expression: %s -> %s", fieldExpr.Expression, processedExpr)
}
}
if bridge.ContainsLikeOperator(processedExpr) {
if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil {
processedExpr = processed
logger.Debug("Preprocessed LIKE expression: %s -> %s", fieldExpr.Expression, processedExpr)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -2874,3 +2874,94 @@ func TestSelectAllFeature(t *testing.T) {
}
})
}
// TestCaseNullValueHandlingInAggregation verifies that SUM, COUNT and AVG
// ignore the NULLs produced by a nested CASE expression.
//
// Five readings for two device types are fed into a query grouped by
// deviceType over a 2s tumbling window. Only readings with temperature > 30
// pass the CASE; the rest become NULL and must not affect any aggregate.
func TestCaseNullValueHandlingInAggregation(t *testing.T) {
	sql := `SELECT deviceType,
SUM(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_sum,
COUNT(CASE WHEN temperature > 30 THEN 1 ELSE NULL END) as high_temp_count,
AVG(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_avg
FROM stream
GROUP BY deviceType, TumblingWindow('2s')`

	// Create the StreamSQL instance.
	ssql := New()
	defer ssql.Stop()

	// Execute the SQL.
	err := ssql.Execute(sql)
	require.NoError(t, err)

	// The sink forwards every emitted result into a buffered channel so the
	// test goroutine can drain it after the window has fired.
	var results []map[string]interface{}
	resultChan := make(chan interface{}, 10)
	ssql.Stream().AddSink(func(result interface{}) {
		resultChan <- result
	})

	// Feed the test data.
	testData := []map[string]interface{}{
		{"deviceType": "sensor", "temperature": 35.0},  // matches the condition
		{"deviceType": "sensor", "temperature": 25.0},  // does not match: CASE yields NULL
		{"deviceType": "sensor", "temperature": 32.0},  // matches the condition
		{"deviceType": "monitor", "temperature": 28.0}, // does not match: CASE yields NULL
		{"deviceType": "monitor", "temperature": 33.0}, // matches the condition
	}
	for _, data := range testData {
		ssql.Stream().AddData(data)
	}

	// Wait for the 2s window to trigger.
	time.Sleep(3 * time.Second)

	// Drain results until the channel has been quiet for 500ms.
collecting:
	for {
		select {
		case result := <-resultChan:
			if resultSlice, ok := result.([]map[string]interface{}); ok {
				results = append(results, resultSlice...)
			}
		case <-time.After(500 * time.Millisecond):
			break collecting
		}
	}

	// Verify the results.
	assert.Len(t, results, 2, "应该有两个设备类型的结果")

	// Expected aggregates per deviceType.
	expectedResults := map[string]map[string]interface{}{
		"sensor": {
			"high_temp_sum":   67.0, // 35 + 32
			"high_temp_count": 2.0,  // COUNT must ignore NULL
			"high_temp_avg":   33.5, // (35 + 32) / 2
		},
		"monitor": {
			"high_temp_sum":   33.0, // only 33
			"high_temp_count": 1.0,  // COUNT must ignore NULL
			"high_temp_avg":   33.0, // only 33
		},
	}

	for _, result := range results {
		// A checked assertion fails the test cleanly instead of panicking
		// when the row has no string deviceType.
		deviceType, ok := result["deviceType"].(string)
		require.True(t, ok, "deviceType should be a string, got %T", result["deviceType"])

		expected := expectedResults[deviceType]
		if !assert.NotNil(t, expected, "应该有设备类型 %s 的期望结果", deviceType) {
			// No expectation for this group: the comparisons below would
			// only add noise on top of the failure already recorded.
			continue
		}

		// SUM must ignore NULL values.
		assert.Equal(t, expected["high_temp_sum"], result["high_temp_sum"],
			"设备类型 %s 的SUM聚合结果应该正确", deviceType)
		// COUNT must ignore NULL values.
		assert.Equal(t, expected["high_temp_count"], result["high_temp_count"],
			"设备类型 %s 的COUNT聚合结果应该正确", deviceType)
		// AVG must ignore NULL values.
		assert.Equal(t, expected["high_temp_avg"], result["high_temp_avg"],
			"设备类型 %s 的AVG聚合结果应该正确", deviceType)
	}
}