mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-03-11 21:07:18 +00:00
Fix processDirectDataSync nested field handling and data race in sync tests
This commit is contained in:
+54
-4
@@ -1678,6 +1678,17 @@ func (s *Stream) processDirectDataSync(data interface{}) (interface{}, error) {
|
||||
// 检查表达式是否是函数调用(包含括号)
|
||||
isFunctionCall := strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")")
|
||||
|
||||
// 检查表达式是否包含嵌套字段(但排除函数调用中的点号)
|
||||
hasNestedFields := false
|
||||
if !isFunctionCall && strings.Contains(fieldExpr.Expression, ".") {
|
||||
hasNestedFields = true
|
||||
}
|
||||
|
||||
// 检查是否为CASE表达式
|
||||
trimmedExpr := strings.TrimSpace(fieldExpr.Expression)
|
||||
upperExpr := strings.ToUpper(trimmedExpr)
|
||||
isCaseExpression := strings.HasPrefix(upperExpr, SQLKeywordCase)
|
||||
|
||||
var evalResult interface{}
|
||||
|
||||
if isFunctionCall {
|
||||
@@ -1689,15 +1700,54 @@ func (s *Stream) processDirectDataSync(data interface{}) (interface{}, error) {
|
||||
continue
|
||||
}
|
||||
evalResult = exprResult
|
||||
} else {
|
||||
// 直接使用桥接器处理表达式
|
||||
exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap)
|
||||
} else if hasNestedFields || isCaseExpression {
|
||||
// 检测到嵌套字段(非函数调用)或CASE表达式,使用自定义表达式引擎
|
||||
expression, parseErr := expr.NewExpression(fieldExpr.Expression)
|
||||
if parseErr != nil {
|
||||
logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr)
|
||||
result[fieldName] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// 使用支持NULL的计算方法
|
||||
numResult, isNull, err := expression.EvaluateWithNull(dataMap)
|
||||
if err != nil {
|
||||
logger.Error("Expression evaluation failed for field %s: %v", fieldName, err)
|
||||
result[fieldName] = nil
|
||||
continue
|
||||
}
|
||||
evalResult = exprResult
|
||||
if isNull {
|
||||
evalResult = nil // NULL值
|
||||
} else {
|
||||
evalResult = numResult
|
||||
}
|
||||
} else {
|
||||
// 尝试使用桥接器处理其他表达式
|
||||
exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap)
|
||||
if err != nil {
|
||||
// 如果桥接器失败,回退到原来的表达式引擎(使用原始表达式,不是预处理的)
|
||||
expression, parseErr := expr.NewExpression(fieldExpr.Expression)
|
||||
if parseErr != nil {
|
||||
logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr)
|
||||
result[fieldName] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算表达式,支持NULL值
|
||||
numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap)
|
||||
if evalErr != nil {
|
||||
logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr)
|
||||
result[fieldName] = nil
|
||||
continue
|
||||
}
|
||||
if isNull {
|
||||
evalResult = nil // NULL值
|
||||
} else {
|
||||
evalResult = numResult
|
||||
}
|
||||
} else {
|
||||
evalResult = exprResult
|
||||
}
|
||||
}
|
||||
|
||||
result[fieldName] = evalResult
|
||||
|
||||
+42
-20
@@ -17,6 +17,7 @@
|
||||
package streamsql
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -42,9 +43,12 @@ func TestEmitSyncWithAddSink(t *testing.T) {
|
||||
// 设置AddSink回调来收集异步结果
|
||||
var sinkCallCount int32
|
||||
var sinkResults []interface{}
|
||||
var sinkResultsMux sync.Mutex // 保护sinkResults访问
|
||||
ssql.AddSink(func(result interface{}) {
|
||||
atomic.AddInt32(&sinkCallCount, 1)
|
||||
sinkResultsMux.Lock()
|
||||
sinkResults = append(sinkResults, result)
|
||||
sinkResultsMux.Unlock()
|
||||
})
|
||||
|
||||
// 测试数据
|
||||
@@ -73,33 +77,51 @@ func TestEmitSyncWithAddSink(t *testing.T) {
|
||||
// 验证同步结果
|
||||
assert.Equal(t, 2, len(syncResults), "应该有2条同步结果(温度>20)")
|
||||
|
||||
// 安全读取异步回调结果
|
||||
sinkResultsMux.Lock()
|
||||
finalSinkResults := make([]interface{}, len(sinkResults))
|
||||
copy(finalSinkResults, sinkResults)
|
||||
sinkResultsMux.Unlock()
|
||||
|
||||
// 验证异步回调结果
|
||||
finalSinkCallCount := atomic.LoadInt32(&sinkCallCount)
|
||||
assert.Equal(t, int32(2), finalSinkCallCount, "AddSink应该被调用2次")
|
||||
assert.Equal(t, 2, len(sinkResults), "应该收集到2条异步结果")
|
||||
assert.Equal(t, 2, len(finalSinkResults), "应该收集到2条异步结果")
|
||||
|
||||
// 验证同步和异步结果的内容一致性
|
||||
if len(syncResults) > 0 && len(sinkResults) > 0 {
|
||||
// 检查第一个结果
|
||||
syncResult, ok1 := syncResults[0].(map[string]interface{})
|
||||
require.True(t, ok1, "同步结果应该是map类型")
|
||||
if len(syncResults) > 0 && len(finalSinkResults) > 0 {
|
||||
// 将结果转换为可比较的格式
|
||||
syncTemperatures := make([]float64, 0, len(syncResults))
|
||||
syncHumidities := make([]float64, 0, len(syncResults))
|
||||
asyncTemperatures := make([]float64, 0, len(finalSinkResults))
|
||||
asyncHumidities := make([]float64, 0, len(finalSinkResults))
|
||||
|
||||
// AddSink收到的是数组格式 []map[string]interface{}
|
||||
sinkResultArray, ok2 := sinkResults[0].([]map[string]interface{})
|
||||
require.True(t, ok2, "异步结果应该是数组类型")
|
||||
require.True(t, len(sinkResultArray) > 0, "异步结果数组不应为空")
|
||||
// 收集同步结果
|
||||
for _, result := range syncResults {
|
||||
if syncResult, ok := result.(map[string]interface{}); ok {
|
||||
syncTemperatures = append(syncTemperatures, syncResult["temperature"].(float64))
|
||||
syncHumidities = append(syncHumidities, syncResult["humidity"].(float64))
|
||||
}
|
||||
}
|
||||
|
||||
sinkResult := sinkResultArray[0]
|
||||
// 收集异步结果
|
||||
for _, result := range finalSinkResults {
|
||||
if sinkResultArray, ok := result.([]map[string]interface{}); ok && len(sinkResultArray) > 0 {
|
||||
sinkResult := sinkResultArray[0]
|
||||
asyncTemperatures = append(asyncTemperatures, sinkResult["temperature"].(float64))
|
||||
asyncHumidities = append(asyncHumidities, sinkResult["humidity"].(float64))
|
||||
}
|
||||
}
|
||||
|
||||
// 验证关键字段
|
||||
assert.Equal(t, 25.0, syncResult["temperature"])
|
||||
assert.Equal(t, 25.0, sinkResult["temperature"])
|
||||
assert.Equal(t, 60.0, syncResult["humidity"])
|
||||
assert.Equal(t, 60.0, sinkResult["humidity"])
|
||||
// 验证结果集合是否一致(不考虑顺序)
|
||||
assert.ElementsMatch(t, syncTemperatures, asyncTemperatures, "温度值集合应该一致")
|
||||
assert.ElementsMatch(t, syncHumidities, asyncHumidities, "湿度值集合应该一致")
|
||||
|
||||
// 验证计算字段
|
||||
assert.InDelta(t, 77.0, syncResult["temp_fahrenheit"], 0.1)
|
||||
assert.InDelta(t, 77.0, sinkResult["temp_fahrenheit"], 0.1)
|
||||
// 验证预期的数值是否都存在
|
||||
assert.Contains(t, syncTemperatures, 25.0, "同步结果应包含25.0")
|
||||
assert.Contains(t, syncTemperatures, 30.0, "同步结果应包含30.0")
|
||||
assert.Contains(t, asyncTemperatures, 25.0, "异步结果应包含25.0")
|
||||
assert.Contains(t, asyncTemperatures, 30.0, "异步结果应包含30.0")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -133,7 +155,7 @@ func TestEmitSyncWithAddSink(t *testing.T) {
|
||||
err := ssql.Execute(sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 添加多个AddSink回调
|
||||
// 添加多个AddSink回调,使用原子操作确保线程安全
|
||||
var sink1Count, sink2Count, sink3Count int32
|
||||
|
||||
ssql.AddSink(func(result interface{}) {
|
||||
@@ -201,7 +223,7 @@ func TestEmitSyncPerformance(t *testing.T) {
|
||||
err := ssql.Execute(sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 添加AddSink回调
|
||||
// 添加AddSink回调,使用原子操作确保线程安全
|
||||
var sinkCallCount int32
|
||||
ssql.AddSink(func(result interface{}) {
|
||||
atomic.AddInt32(&sinkCallCount, 1)
|
||||
|
||||
+2
-1
@@ -2,10 +2,11 @@ package window
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/utils/cast"
|
||||
|
||||
"github.com/rulego/streamsql/types"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user