forked from GiteaTest2015/streamsql
Merge pull request #22 from rulego/dev
feat:aggregate function can be nested with the CASE function.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -27,8 +27,8 @@ func TestAggregatorFunctionInterface(t *testing.T) {
|
||||
// 测试重置
|
||||
aggInstance.Reset()
|
||||
result = aggInstance.Result()
|
||||
if result != 0.0 {
|
||||
t.Errorf("Expected 0.0 after reset, got %v", result)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil after reset (SQL standard: SUM with no rows returns NULL), got %v", result)
|
||||
}
|
||||
|
||||
// 测试克隆
|
||||
|
||||
@@ -12,12 +12,14 @@ import (
|
||||
// SumFunction 求和函数
|
||||
type SumFunction struct {
|
||||
*BaseFunction
|
||||
value float64
|
||||
value float64
|
||||
hasValues bool // 标记是否有非NULL值
|
||||
}
|
||||
|
||||
func NewSumFunction() *SumFunction {
|
||||
return &SumFunction{
|
||||
BaseFunction: NewBaseFunction("sum", TypeAggregation, "聚合函数", "计算数值总和", 1, -1),
|
||||
hasValues: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,12 +29,20 @@ func (f *SumFunction) Validate(args []interface{}) error {
|
||||
|
||||
func (f *SumFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
sum := 0.0
|
||||
hasValues := false
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
continue // 忽略NULL值
|
||||
}
|
||||
val, err := cast.ToFloat64E(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
continue // 忽略无法转换的值
|
||||
}
|
||||
sum += val
|
||||
hasValues = true
|
||||
}
|
||||
if !hasValues {
|
||||
return nil, nil // 当没有有效值时返回NULL
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
@@ -42,27 +52,40 @@ func (f *SumFunction) New() AggregatorFunction {
|
||||
return &SumFunction{
|
||||
BaseFunction: f.BaseFunction,
|
||||
value: 0,
|
||||
hasValues: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *SumFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
f.value += val
|
||||
f.hasValues = true
|
||||
}
|
||||
// 如果转换失败,也忽略该值
|
||||
}
|
||||
|
||||
func (f *SumFunction) Result() interface{} {
|
||||
if !f.hasValues {
|
||||
return nil // 当没有有效值时返回NULL而不是0.0
|
||||
}
|
||||
return f.value
|
||||
}
|
||||
|
||||
func (f *SumFunction) Reset() {
|
||||
f.value = 0
|
||||
f.hasValues = false
|
||||
}
|
||||
|
||||
func (f *SumFunction) Clone() AggregatorFunction {
|
||||
return &SumFunction{
|
||||
BaseFunction: f.BaseFunction,
|
||||
value: f.value,
|
||||
hasValues: f.hasValues,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,14 +108,22 @@ func (f *AvgFunction) Validate(args []interface{}) error {
|
||||
|
||||
func (f *AvgFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
sum := 0.0
|
||||
count := 0
|
||||
for _, arg := range args {
|
||||
if arg == nil {
|
||||
continue // 忽略NULL值
|
||||
}
|
||||
val, err := cast.ToFloat64E(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
continue // 忽略无法转换的值
|
||||
}
|
||||
sum += val
|
||||
count++
|
||||
}
|
||||
return sum / float64(len(args)), nil
|
||||
if count == 0 {
|
||||
return nil, nil // 当没有有效值时返回nil
|
||||
}
|
||||
return sum / float64(count), nil
|
||||
}
|
||||
|
||||
// 实现AggregatorFunction接口
|
||||
@@ -105,10 +136,16 @@ func (f *AvgFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *AvgFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
f.sum += val
|
||||
f.count++
|
||||
}
|
||||
// 如果转换失败,也忽略该值
|
||||
}
|
||||
|
||||
func (f *AvgFunction) Result() interface{} {
|
||||
@@ -172,6 +209,11 @@ func (f *MinFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *MinFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
if f.first || val < f.value {
|
||||
f.value = val
|
||||
@@ -241,6 +283,11 @@ func (f *MaxFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *MaxFunction) Add(value interface{}) {
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value == nil {
|
||||
return // 忽略NULL值
|
||||
}
|
||||
|
||||
if val, err := cast.ToFloat64E(value); err == nil {
|
||||
if f.first || val > f.value {
|
||||
f.value = val
|
||||
@@ -286,7 +333,13 @@ func (f *CountFunction) Validate(args []interface{}) error {
|
||||
}
|
||||
|
||||
func (f *CountFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) {
|
||||
return int64(len(args)), nil
|
||||
count := 0
|
||||
for _, arg := range args {
|
||||
if arg != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// 实现AggregatorFunction接口
|
||||
@@ -298,7 +351,10 @@ func (f *CountFunction) New() AggregatorFunction {
|
||||
}
|
||||
|
||||
func (f *CountFunction) Add(value interface{}) {
|
||||
f.count++
|
||||
// 增强的Add方法:忽略NULL值
|
||||
if value != nil {
|
||||
f.count++
|
||||
}
|
||||
}
|
||||
|
||||
func (f *CountFunction) Result() interface{} {
|
||||
|
||||
+90
-31
@@ -294,16 +294,40 @@ func (s *Stream) process() {
|
||||
hasNestedFields := strings.Contains(currentFieldExpr.Expression, ".")
|
||||
|
||||
if hasNestedFields {
|
||||
// 直接使用自定义表达式引擎处理嵌套字段
|
||||
// 直接使用自定义表达式引擎处理嵌套字段,支持NULL值
|
||||
expression, parseErr := expr.NewExpression(currentFieldExpr.Expression)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("expression parse failed: %w", parseErr)
|
||||
}
|
||||
|
||||
numResult, err := expression.Evaluate(dataMap)
|
||||
// 使用支持NULL的计算方法
|
||||
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expression evaluation failed: %w", err)
|
||||
}
|
||||
if isNull {
|
||||
return nil, nil // 返回nil表示NULL值
|
||||
}
|
||||
return numResult, nil
|
||||
}
|
||||
|
||||
// 检查是否为CASE表达式
|
||||
trimmedExpr := strings.TrimSpace(currentFieldExpr.Expression)
|
||||
upperExpr := strings.ToUpper(trimmedExpr)
|
||||
if strings.HasPrefix(upperExpr, "CASE") {
|
||||
// CASE表达式使用支持NULL的计算方法
|
||||
expression, parseErr := expr.NewExpression(currentFieldExpr.Expression)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("CASE expression parse failed: %w", parseErr)
|
||||
}
|
||||
|
||||
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CASE expression evaluation failed: %w", err)
|
||||
}
|
||||
if isNull {
|
||||
return nil, nil // 返回nil表示NULL值
|
||||
}
|
||||
return numResult, nil
|
||||
}
|
||||
|
||||
@@ -331,11 +355,14 @@ func (s *Stream) process() {
|
||||
return nil, fmt.Errorf("expression parse failed: %w", parseErr)
|
||||
}
|
||||
|
||||
// 计算表达式
|
||||
numResult, evalErr := expression.Evaluate(dataMap)
|
||||
// 计算表达式,支持NULL值
|
||||
numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap)
|
||||
if evalErr != nil {
|
||||
return nil, fmt.Errorf("expression evaluation failed: %w", evalErr)
|
||||
}
|
||||
if isNull {
|
||||
return nil, nil // 返回nil表示NULL值
|
||||
}
|
||||
return numResult, nil
|
||||
}
|
||||
|
||||
@@ -382,36 +409,70 @@ func (s *Stream) process() {
|
||||
|
||||
// 应用 HAVING 过滤条件
|
||||
if s.config.Having != "" {
|
||||
// 预处理HAVING条件中的LIKE语法,转换为expr-lang可理解的形式
|
||||
processedHaving := s.config.Having
|
||||
bridge := functions.GetExprBridge()
|
||||
if bridge.ContainsLikeOperator(s.config.Having) {
|
||||
if processed, err := bridge.PreprocessLikeExpression(s.config.Having); err == nil {
|
||||
processedHaving = processed
|
||||
}
|
||||
}
|
||||
// 检查HAVING条件是否包含CASE表达式
|
||||
hasCaseExpression := strings.Contains(strings.ToUpper(s.config.Having), "CASE")
|
||||
|
||||
// 预处理HAVING条件中的IS NULL语法
|
||||
if bridge.ContainsIsNullOperator(processedHaving) {
|
||||
if processed, err := bridge.PreprocessIsNullExpression(processedHaving); err == nil {
|
||||
processedHaving = processed
|
||||
}
|
||||
}
|
||||
var filteredResults []map[string]interface{}
|
||||
|
||||
// 创建 HAVING 条件
|
||||
havingFilter, err := condition.NewExprCondition(processedHaving)
|
||||
if err != nil {
|
||||
logger.Error("having filter error: %v", err)
|
||||
} else {
|
||||
// 应用 HAVING 过滤
|
||||
var filteredResults []map[string]interface{}
|
||||
for _, result := range finalResults {
|
||||
if havingFilter.Evaluate(result) {
|
||||
filteredResults = append(filteredResults, result)
|
||||
if hasCaseExpression {
|
||||
// HAVING条件包含CASE表达式,使用我们的表达式解析器
|
||||
expression, err := expr.NewExpression(s.config.Having)
|
||||
if err != nil {
|
||||
logger.Error("having filter error (CASE expression): %v", err)
|
||||
} else {
|
||||
// 应用 HAVING 过滤,使用CASE表达式计算器
|
||||
for _, result := range finalResults {
|
||||
// 使用EvaluateWithNull方法以支持NULL值处理
|
||||
havingResult, isNull, err := expression.EvaluateWithNull(result)
|
||||
if err != nil {
|
||||
logger.Error("having filter evaluation error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果结果是NULL,则不满足条件(SQL标准行为)
|
||||
if isNull {
|
||||
continue
|
||||
}
|
||||
|
||||
// 对于数值结果,大于0视为true(满足HAVING条件)
|
||||
if havingResult > 0 {
|
||||
filteredResults = append(filteredResults, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// HAVING条件不包含CASE表达式,使用原有的expr-lang处理
|
||||
// 预处理HAVING条件中的LIKE语法,转换为expr-lang可理解的形式
|
||||
processedHaving := s.config.Having
|
||||
bridge := functions.GetExprBridge()
|
||||
if bridge.ContainsLikeOperator(s.config.Having) {
|
||||
if processed, err := bridge.PreprocessLikeExpression(s.config.Having); err == nil {
|
||||
processedHaving = processed
|
||||
}
|
||||
}
|
||||
|
||||
// 预处理HAVING条件中的IS NULL语法
|
||||
if bridge.ContainsIsNullOperator(processedHaving) {
|
||||
if processed, err := bridge.PreprocessIsNullExpression(processedHaving); err == nil {
|
||||
processedHaving = processed
|
||||
}
|
||||
}
|
||||
|
||||
// 创建 HAVING 条件
|
||||
havingFilter, err := condition.NewExprCondition(processedHaving)
|
||||
if err != nil {
|
||||
logger.Error("having filter error: %v", err)
|
||||
} else {
|
||||
// 应用 HAVING 过滤
|
||||
for _, result := range finalResults {
|
||||
if havingFilter.Evaluate(result) {
|
||||
filteredResults = append(filteredResults, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
finalResults = filteredResults
|
||||
}
|
||||
|
||||
finalResults = filteredResults
|
||||
}
|
||||
|
||||
// 应用 LIMIT 限制
|
||||
@@ -496,13 +557,11 @@ func (s *Stream) processDirectData(data interface{}) {
|
||||
if bridge.ContainsIsNullOperator(processedExpr) {
|
||||
if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil {
|
||||
processedExpr = processed
|
||||
logger.Debug("Preprocessed IS NULL expression: %s -> %s", fieldExpr.Expression, processedExpr)
|
||||
}
|
||||
}
|
||||
if bridge.ContainsLikeOperator(processedExpr) {
|
||||
if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil {
|
||||
processedExpr = processed
|
||||
logger.Debug("Preprocessed LIKE expression: %s -> %s", fieldExpr.Expression, processedExpr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+617
-60
File diff suppressed because it is too large
Load Diff
@@ -2874,3 +2874,94 @@ func TestSelectAllFeature(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCaseNullValueHandlingInAggregation 测试CASE表达式在聚合函数中正确处理NULL值
|
||||
func TestCaseNullValueHandlingInAggregation(t *testing.T) {
|
||||
sql := `SELECT deviceType,
|
||||
SUM(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_sum,
|
||||
COUNT(CASE WHEN temperature > 30 THEN 1 ELSE NULL END) as high_temp_count,
|
||||
AVG(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_avg
|
||||
FROM stream
|
||||
GROUP BY deviceType, TumblingWindow('2s')`
|
||||
|
||||
// 创建StreamSQL实例
|
||||
ssql := New()
|
||||
defer ssql.Stop()
|
||||
|
||||
// 执行SQL
|
||||
err := ssql.Execute(sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 收集结果
|
||||
var results []map[string]interface{}
|
||||
resultChan := make(chan interface{}, 10)
|
||||
|
||||
ssql.Stream().AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
// 添加测试数据
|
||||
testData := []map[string]interface{}{
|
||||
{"deviceType": "sensor", "temperature": 35.0}, // 满足条件
|
||||
{"deviceType": "sensor", "temperature": 25.0}, // 不满足条件,返回NULL
|
||||
{"deviceType": "sensor", "temperature": 32.0}, // 满足条件
|
||||
{"deviceType": "monitor", "temperature": 28.0}, // 不满足条件,返回NULL
|
||||
{"deviceType": "monitor", "temperature": 33.0}, // 满足条件
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
ssql.Stream().AddData(data)
|
||||
}
|
||||
|
||||
// 等待窗口触发
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// 收集结果
|
||||
collecting:
|
||||
for {
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if resultSlice, ok := result.([]map[string]interface{}); ok {
|
||||
results = append(results, resultSlice...)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
break collecting
|
||||
}
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
assert.Len(t, results, 2, "应该有两个设备类型的结果")
|
||||
|
||||
// 验证各个deviceType的结果
|
||||
expectedResults := map[string]map[string]interface{}{
|
||||
"sensor": {
|
||||
"high_temp_sum": 67.0, // 35 + 32
|
||||
"high_temp_count": 2.0, // COUNT应该忽略NULL
|
||||
"high_temp_avg": 33.5, // (35 + 32) / 2
|
||||
},
|
||||
"monitor": {
|
||||
"high_temp_sum": 33.0, // 只有33
|
||||
"high_temp_count": 1.0, // COUNT应该忽略NULL
|
||||
"high_temp_avg": 33.0, // 只有33
|
||||
},
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
deviceType := result["deviceType"].(string)
|
||||
expected := expectedResults[deviceType]
|
||||
|
||||
assert.NotNil(t, expected, "应该有设备类型 %s 的期望结果", deviceType)
|
||||
|
||||
// 验证SUM聚合(忽略NULL值)
|
||||
assert.Equal(t, expected["high_temp_sum"], result["high_temp_sum"],
|
||||
"设备类型 %s 的SUM聚合结果应该正确", deviceType)
|
||||
|
||||
// 验证COUNT聚合(忽略NULL值)
|
||||
assert.Equal(t, expected["high_temp_count"], result["high_temp_count"],
|
||||
"设备类型 %s 的COUNT聚合结果应该正确", deviceType)
|
||||
|
||||
// 验证AVG聚合(忽略NULL值)
|
||||
assert.Equal(t, expected["high_temp_avg"], result["high_temp_avg"],
|
||||
"设备类型 %s 的AVG聚合结果应该正确", deviceType)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user