From 2e24739c0de709f5bbdda81705cc3ddd3d252b13 Mon Sep 17 00:00:00 2001 From: rulego-team Date: Sun, 27 Jul 2025 13:10:16 +0800 Subject: [PATCH] Fix processDirectDataSync nested field handling and data race in sync tests --- stream/stream.go | 58 +++++++++++++++++++++++++++++++++++++++++--- sync_sink_test.go | 62 ++++++++++++++++++++++++++++++++--------------- window/factory.go | 3 ++- 3 files changed, 98 insertions(+), 25 deletions(-) diff --git a/stream/stream.go b/stream/stream.go index e41c747..c56c973 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -1678,6 +1678,17 @@ func (s *Stream) processDirectDataSync(data interface{}) (interface{}, error) { // 检查表达式是否是函数调用(包含括号) isFunctionCall := strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")") + // 检查表达式是否包含嵌套字段(但排除函数调用中的点号) + hasNestedFields := false + if !isFunctionCall && strings.Contains(fieldExpr.Expression, ".") { + hasNestedFields = true + } + + // 检查是否为CASE表达式 + trimmedExpr := strings.TrimSpace(fieldExpr.Expression) + upperExpr := strings.ToUpper(trimmedExpr) + isCaseExpression := strings.HasPrefix(upperExpr, SQLKeywordCase) + var evalResult interface{} if isFunctionCall { @@ -1689,15 +1700,54 @@ func (s *Stream) processDirectDataSync(data interface{}) (interface{}, error) { continue } evalResult = exprResult - } else { - // 直接使用桥接器处理表达式 - exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + } else if hasNestedFields || isCaseExpression { + // 检测到嵌套字段(非函数调用)或CASE表达式,使用自定义表达式引擎 + expression, parseErr := expr.NewExpression(fieldExpr.Expression) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + continue + } + + // 使用支持NULL的计算方法 + numResult, isNull, err := expression.EvaluateWithNull(dataMap) if err != nil { logger.Error("Expression evaluation failed for field %s: %v", fieldName, err) result[fieldName] = nil continue } - evalResult = exprResult + if isNull { + evalResult = nil // NULL值 + } else { + evalResult = numResult + } + } else { + // 尝试使用桥接器处理其他表达式 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + // 如果桥接器失败,回退到原来的表达式引擎(使用原始表达式,不是预处理的) + expression, parseErr := expr.NewExpression(fieldExpr.Expression) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + continue + } + + // 计算表达式,支持NULL值 + numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap) + if evalErr != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr) + result[fieldName] = nil + continue + } + if isNull { + evalResult = nil // NULL值 + } else { + evalResult = numResult + } + } else { + evalResult = exprResult + } } result[fieldName] = evalResult diff --git a/sync_sink_test.go b/sync_sink_test.go index fc91a74..30c89be 100644 --- a/sync_sink_test.go +++ b/sync_sink_test.go @@ -17,6 +17,7 @@ package streamsql import ( + "sync" "sync/atomic" "testing" "time" @@ -42,9 +43,12 @@ func TestEmitSyncWithAddSink(t *testing.T) { // 设置AddSink回调来收集异步结果 var sinkCallCount int32 var sinkResults []interface{} + var sinkResultsMux sync.Mutex // 保护sinkResults访问 ssql.AddSink(func(result interface{}) { atomic.AddInt32(&sinkCallCount, 1) + sinkResultsMux.Lock() sinkResults = append(sinkResults, result) + sinkResultsMux.Unlock() }) // 测试数据 @@ -73,33 +77,51 @@ func TestEmitSyncWithAddSink(t *testing.T) { // 验证同步结果 assert.Equal(t, 2, len(syncResults), "应该有2条同步结果(温度>20)") + // 安全读取异步回调结果 + sinkResultsMux.Lock() + finalSinkResults := make([]interface{}, len(sinkResults)) + copy(finalSinkResults, sinkResults) + sinkResultsMux.Unlock() + // 验证异步回调结果 finalSinkCallCount := atomic.LoadInt32(&sinkCallCount) assert.Equal(t, int32(2), finalSinkCallCount, "AddSink应该被调用2次") - assert.Equal(t, 2, len(sinkResults), "应该收集到2条异步结果") + assert.Equal(t, 2, len(finalSinkResults), "应该收集到2条异步结果") // 验证同步和异步结果的内容一致性 - if len(syncResults) > 0 && len(sinkResults) > 0 { - // 检查第一个结果 - syncResult, ok1 := syncResults[0].(map[string]interface{}) - require.True(t, ok1, "同步结果应该是map类型") + if len(syncResults) > 0 && len(finalSinkResults) > 0 { + // 将结果转换为可比较的格式 + syncTemperatures := make([]float64, 0, len(syncResults)) + syncHumidities := make([]float64, 0, len(syncResults)) + asyncTemperatures := make([]float64, 0, len(finalSinkResults)) + asyncHumidities := make([]float64, 0, len(finalSinkResults)) - // AddSink收到的是数组格式 []map[string]interface{} - sinkResultArray, ok2 := sinkResults[0].([]map[string]interface{}) - require.True(t, ok2, "异步结果应该是数组类型") - require.True(t, len(sinkResultArray) > 0, "异步结果数组不应为空") + // 收集同步结果 + for _, result := range syncResults { + if syncResult, ok := result.(map[string]interface{}); ok { + syncTemperatures = append(syncTemperatures, syncResult["temperature"].(float64)) + syncHumidities = append(syncHumidities, syncResult["humidity"].(float64)) + } + } - sinkResult := sinkResultArray[0] + // 收集异步结果 + for _, result := range finalSinkResults { + if sinkResultArray, ok := result.([]map[string]interface{}); ok && len(sinkResultArray) > 0 { + sinkResult := sinkResultArray[0] + asyncTemperatures = append(asyncTemperatures, sinkResult["temperature"].(float64)) + asyncHumidities = append(asyncHumidities, sinkResult["humidity"].(float64)) + } + } - // 验证关键字段 - assert.Equal(t, 25.0, syncResult["temperature"]) - assert.Equal(t, 25.0, sinkResult["temperature"]) - assert.Equal(t, 60.0, syncResult["humidity"]) - assert.Equal(t, 60.0, sinkResult["humidity"]) + // 验证结果集合是否一致(不考虑顺序) + assert.ElementsMatch(t, syncTemperatures, asyncTemperatures, "温度值集合应该一致") + assert.ElementsMatch(t, syncHumidities, asyncHumidities, "湿度值集合应该一致") - // 验证计算字段 - assert.InDelta(t, 77.0, syncResult["temp_fahrenheit"], 0.1) - assert.InDelta(t, 77.0, sinkResult["temp_fahrenheit"], 0.1) + // 验证预期的数值是否都存在 + assert.Contains(t, syncTemperatures, 25.0, "同步结果应包含25.0") + assert.Contains(t, syncTemperatures, 30.0, "同步结果应包含30.0") + assert.Contains(t, asyncTemperatures, 25.0, "异步结果应包含25.0") + assert.Contains(t, asyncTemperatures, 30.0, "异步结果应包含30.0") } }) @@ -133,7 +155,7 @@ func TestEmitSyncWithAddSink(t *testing.T) { err := ssql.Execute(sql) require.NoError(t, err) - // 添加多个AddSink回调 + // 添加多个AddSink回调,使用原子操作确保线程安全 var sink1Count, sink2Count, sink3Count int32 ssql.AddSink(func(result interface{}) { @@ -201,7 +223,7 @@ func TestEmitSyncPerformance(t *testing.T) { err := ssql.Execute(sql) require.NoError(t, err) - // 添加AddSink回调 + // 添加AddSink回调,使用原子操作确保线程安全 var sinkCallCount int32 ssql.AddSink(func(result interface{}) { atomic.AddInt32(&sinkCallCount, 1) diff --git a/window/factory.go b/window/factory.go index 32dbaf0..9bc14ad 100644 --- a/window/factory.go +++ b/window/factory.go @@ -2,10 +2,11 @@ package window import ( "fmt" - "github.com/rulego/streamsql/utils/cast" "reflect" "time" + "github.com/rulego/streamsql/utils/cast" + "github.com/rulego/streamsql/types" )