mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-05 15:49:14 +00:00
feat:aggregate function can be nested with the CASE function.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -27,8 +27,8 @@ func TestAggregatorFunctionInterface(t *testing.T) {
|
||||
// 测试重置
|
||||
aggInstance.Reset()
|
||||
result = aggInstance.Result()
|
||||
if result != 0.0 {
|
||||
t.Errorf("Expected 0.0 after reset, got %v", result)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil after reset (SQL standard: SUM with no rows returns NULL), got %v", result)
|
||||
}
|
||||
|
||||
// 测试克隆
|
||||
|
@ -12,12 +12,14 @@ import (
|
||||
// SumFunction 求和函数
|
||||
type SumFunction struct {
|
||||
*BaseFunction
|
||||
value float64
|
||||
value float64
|
||||
hasValues bool // 标记是否有非NULL值
|
||||
}
|
||||
|
||||
func NewSumFunction() *SumFunction {
|
||||
return &SumFunction{
|
||||
BaseFunction: NewBaseFunction("sum", TypeAggregation, "聚合函数", "计算数值总和", 1, -1),
|
||||
hasValues: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,12 +29,20 @@ func (f *SumFunction) Validate(args []interface{}) error {
|
||||
|
||||
func (f *SumFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
sum := 0.0
|
||||
hasValues := false
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
continue // 忽略NULL值
|
||||
}
|
||||
val, err := cast.ToFloat64E(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
continue // 忽略无法转换的值
|
||||
}
|
||||
sum += val
|
||||
hasValues = true
|
||||
}
|
||||
if !hasValues {
|
||||
return nil, nil // 当没有有效值时返回NULL
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
@ -42,27 +52,40 @@ func (f *SumFunction) New() AggregatorFunction {
|
||||
return &SumFunction{
|
||||
BaseFunction: f.BaseFunction,
|
||||
value: 0,
|
||||
hasValues: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *SumFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
f.value += val
|
||||
f.hasValues = true
|
||||
}
|
||||
// 如果转换失败,也忽略该值
|
||||
}
|
||||
|
||||
func (f *SumFunction) Result() interface{} {
|
||||
if !f.hasValues {
|
||||
return nil // 当没有有效值时返回NULL而不是0.0
|
||||
}
|
||||
return f.value
|
||||
}
|
||||
|
||||
func (f *SumFunction) Reset() {
|
||||
f.value = 0
|
||||
f.hasValues = false
|
||||
}
|
||||
|
||||
func (f *SumFunction) Clone() AggregatorFunction {
|
||||
return &SumFunction{
|
||||
BaseFunction: f.BaseFunction,
|
||||
value: f.value,
|
||||
hasValues: f.hasValues,
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,14 +108,22 @@ func (f *AvgFunction) Validate(args []interface{}) error {
|
||||
|
||||
func (f *AvgFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
sum := 0.0
|
||||
count := 0
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
continue // 忽略NULL值
|
||||
}
|
||||
val, err := cast.ToFloat64E(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
continue // 忽略无法转换的值
|
||||
}
|
||||
sum += val
|
||||
count++
|
||||
}
|
||||
return sum / float64(len(args)), nil
|
||||
if count == 0 {
|
||||
return nil, nil // 当没有有效值时返回nil
|
||||
}
|
||||
return sum / float64(count), nil
|
||||
}
|
||||
|
||||
// 实现AggregatorFunction接口
|
||||
@ -105,10 +136,16 @@ func (f *AvgFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *AvgFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
f.sum += val
|
||||
f.count++
|
||||
}
|
||||
// 如果转换失败,也忽略该值
|
||||
}
|
||||
|
||||
func (f *AvgFunction) Result() interface{} {
|
||||
@ -172,6 +209,11 @@ func (f *MinFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *MinFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
if f.first || val < f.value {
|
||||
f.value = val
|
||||
@ -241,6 +283,11 @@ func (f *MaxFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *MaxFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
if f.first || val > f.value {
|
||||
f.value = val
|
||||
@ -286,7 +333,13 @@ func (f *CountFunction) Validate(args []interface{}) error {
|
||||
}
|
||||
|
||||
func (f *CountFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
return int64(len(args)), nil
|
||||
count := 0
|
||||
for _, arg := range args {
|
||||
if arg != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// 实现AggregatorFunction接口
|
||||
@ -298,7 +351,10 @@ func (f *CountFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *CountFunction) Add(value interface{}) {
|
||||
f.count++
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value != nil {
|
||||
f.count++
|
||||
}
|
||||
}
|
||||
|
||||
func (f *CountFunction) Result() interface{} {
|
||||
|
@ -294,16 +294,40 @@ func (s *Stream) process() {
|
||||
hasNestedFields := strings.Contains(currentFieldExpr.Expression, ".")
|
||||
|
||||
if hasNestedFields {
|
||||
// 直接使用自定义表达式引擎处理嵌套字段
|
||||
// 直接使用自定义表达式引擎处理嵌套字段,支持NULL值
|
||||
expression, parseErr := expr.NewExpression(currentFieldExpr.Expression)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("expression parse failed: %w", parseErr)
|
||||
}
|
||||
|
||||
numResult, err := expression.Evaluate(dataMap)
|
||||
// 使用支持NULL的计算方法
|
||||
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expression evaluation failed: %w", err)
|
||||
}
|
||||
if isNull {
|
||||
return nil, nil // 返回nil表示NULL值
|
||||
}
|
||||
return numResult, nil
|
||||
}
|
||||
|
||||
// 检查是否为CASE表达式
|
||||
trimmedExpr := strings.TrimSpace(currentFieldExpr.Expression)
|
||||
upperExpr := strings.ToUpper(trimmedExpr)
|
||||
if strings.HasPrefix(upperExpr, "CASE") {
|
||||
// CASE表达式使用支持NULL的计算方法
|
||||
expression, parseErr := expr.NewExpression(currentFieldExpr.Expression)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("CASE expression parse failed: %w", parseErr)
|
||||
}
|
||||
|
||||
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CASE expression evaluation failed: %w", err)
|
||||
}
|
||||
if isNull {
|
||||
return nil, nil // 返回nil表示NULL值
|
||||
}
|
||||
return numResult, nil
|
||||
}
|
||||
|
||||
@ -331,11 +355,14 @@ func (s *Stream) process() {
|
||||
return nil, fmt.Errorf("expression parse failed: %w", parseErr)
|
||||
}
|
||||
|
||||
// 计算表达式
|
||||
numResult, evalErr := expression.Evaluate(dataMap)
|
||||
// 计算表达式,支持NULL值
|
||||
numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap)
|
||||
if evalErr != nil {
|
||||
return nil, fmt.Errorf("expression evaluation failed: %w", evalErr)
|
||||
}
|
||||
if isNull {
|
||||
return nil, nil // 返回nil表示NULL值
|
||||
}
|
||||
return numResult, nil
|
||||
}
|
||||
|
||||
@ -496,13 +523,11 @@ func (s *Stream) processDirectData(data interface{}) {
|
||||
if bridge.ContainsIsNullOperator(processedExpr) {
|
||||
if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil {
|
||||
processedExpr = processed
|
||||
logger.Debug("Preprocessed IS NULL expression: %s -> %s", fieldExpr.Expression, processedExpr)
|
||||
}
|
||||
}
|
||||
if bridge.ContainsLikeOperator(processedExpr) {
|
||||
if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil {
|
||||
processedExpr = processed
|
||||
logger.Debug("Preprocessed LIKE expression: %s -> %s", fieldExpr.Expression, processedExpr)
|
||||
}
|
||||
}
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -2874,3 +2874,94 @@ func TestSelectAllFeature(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCaseNullValueHandlingInAggregation 测试CASE表达式在聚合函数中正确处理NULL值
|
||||
func TestCaseNullValueHandlingInAggregation(t *testing.T) {
|
||||
sql := `SELECT deviceType,
|
||||
SUM(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_sum,
|
||||
COUNT(CASE WHEN temperature > 30 THEN 1 ELSE NULL END) as high_temp_count,
|
||||
AVG(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_avg
|
||||
FROM stream
|
||||
GROUP BY deviceType, TumblingWindow('2s')`
|
||||
|
||||
// 创建StreamSQL实例
|
||||
ssql := New()
|
||||
defer ssql.Stop()
|
||||
|
||||
// 执行SQL
|
||||
err := ssql.Execute(sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 收集结果
|
||||
var results []map[string]interface{}
|
||||
resultChan := make(chan interface{}, 10)
|
||||
|
||||
ssql.Stream().AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
// 添加测试数据
|
||||
testData := []map[string]interface{}{
|
||||
{"deviceType": "sensor", "temperature": 35.0}, // 满足条件
|
||||
{"deviceType": "sensor", "temperature": 25.0}, // 不满足条件,返回NULL
|
||||
{"deviceType": "sensor", "temperature": 32.0}, // 满足条件
|
||||
{"deviceType": "monitor", "temperature": 28.0}, // 不满足条件,返回NULL
|
||||
{"deviceType": "monitor", "temperature": 33.0}, // 满足条件
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
ssql.Stream().AddData(data)
|
||||
}
|
||||
|
||||
// 等待窗口触发
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// 收集结果
|
||||
collecting:
|
||||
for {
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if resultSlice, ok := result.([]map[string]interface{}); ok {
|
||||
results = append(results, resultSlice...)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
break collecting
|
||||
}
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
assert.Len(t, results, 2, "应该有两个设备类型的结果")
|
||||
|
||||
// 验证各个deviceType的结果
|
||||
expectedResults := map[string]map[string]interface{}{
|
||||
"sensor": {
|
||||
"high_temp_sum": 67.0, // 35 + 32
|
||||
"high_temp_count": 2.0, // COUNT应该忽略NULL
|
||||
"high_temp_avg": 33.5, // (35 + 32) / 2
|
||||
},
|
||||
"monitor": {
|
||||
"high_temp_sum": 33.0, // 只有33
|
||||
"high_temp_count": 1.0, // COUNT应该忽略NULL
|
||||
"high_temp_avg": 33.0, // 只有33
|
||||
},
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
deviceType := result["deviceType"].(string)
|
||||
expected := expectedResults[deviceType]
|
||||
|
||||
assert.NotNil(t, expected, "应该有设备类型 %s 的期望结果", deviceType)
|
||||
|
||||
// 验证SUM聚合(忽略NULL值)
|
||||
assert.Equal(t, expected["high_temp_sum"], result["high_temp_sum"],
|
||||
"设备类型 %s 的SUM聚合结果应该正确", deviceType)
|
||||
|
||||
// 验证COUNT聚合(忽略NULL值)
|
||||
assert.Equal(t, expected["high_temp_count"], result["high_temp_count"],
|
||||
"设备类型 %s 的COUNT聚合结果应该正确", deviceType)
|
||||
|
||||
// 验证AVG聚合(忽略NULL值)
|
||||
assert.Equal(t, expected["high_temp_avg"], result["high_temp_avg"],
|
||||
"设备类型 %s 的AVG聚合结果应该正确", deviceType)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user