diff --git a/doc.go b/doc.go index dfbc29f..4721f96 100644 --- a/doc.go +++ b/doc.go @@ -152,24 +152,57 @@ StreamSQL 提供灵活的日志配置选项: // 禁用日志(生产环境) ssql := streamsql.New(streamsql.WithDiscardLog()) -# 性能配置 - -对于生产环境,建议进行以下配置: - - ssql := streamsql.New( - streamsql.WithDiscardLog(), // 禁用日志提升性能 - // 其他配置选项... - ) - # 与RuleGo集成 -StreamSQL可以与RuleGo规则引擎无缝集成,利用RuleGo丰富的组件生态: +StreamSQL提供了与RuleGo规则引擎的深度集成,通过两个专用组件实现流式数据处理: - // TODO: 提供RuleGo集成示例 +• streamTransform (x/streamTransform) - 流转换器,处理非聚合SQL查询 +• streamAggregator (x/streamAggregator) - 流聚合器,处理聚合SQL查询 -更多详细信息和高级用法,请参阅: -• 自定义函数开发指南: docs/CUSTOM_FUNCTIONS_GUIDE.md -• 快速入门指南: docs/FUNCTION_QUICK_START.md -• 完整示例: examples/ +基本集成示例: + + package main + + import ( + "github.com/rulego/rulego" + "github.com/rulego/rulego/api/types" + // 注册StreamSQL组件 + _ "github.com/rulego/rulego-components/external/streamsql" + ) + + func main() { + // 规则链配置 + ruleChainJson := `{ + "ruleChain": {"id": "rule01"}, + "metadata": { + "nodes": [{ + "id": "transform1", + "type": "x/streamTransform", + "configuration": { + "sql": "SELECT deviceId, temperature * 1.8 + 32 as temp_f FROM stream WHERE temperature > 20" + } + }, { + "id": "aggregator1", + "type": "x/streamAggregator", + "configuration": { + "sql": "SELECT deviceId, AVG(temperature) as avg_temp FROM stream GROUP BY deviceId, TumblingWindow('5s')" + } + }], + "connections": [{ + "fromId": "transform1", + "toId": "aggregator1", + "type": "Success" + }] + } + }` + + // 创建规则引擎 + ruleEngine, _ := rulego.New("rule01", []byte(ruleChainJson)) + + // 发送数据 + data := `{"deviceId":"sensor01","temperature":25.5}` + msg := types.NewMsg(0, "TELEMETRY", types.JSON, types.NewMetadata(), data) + ruleEngine.OnMsg(msg) + } */ package streamsql diff --git a/stream/stream.go b/stream/stream.go index e811ced..c56c973 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -21,6 +21,68 @@ import ( "github.com/rulego/streamsql/window" ) +// 窗口相关常量 +const ( + WindowStartField = "window_start" + WindowEndField = "window_end" +) + +// 溢出策略常量 +const ( + StrategyDrop = "drop" + StrategyBlock = "block" + StrategyExpand = "expand" + StrategyPersist = "persist" +) + +// 统计信息字段常量 +const ( + StatsInputCount = "input_count" + StatsOutputCount = "output_count" + StatsDroppedCount = "dropped_count" + StatsDataChanLen = "data_chan_len" + StatsDataChanCap = "data_chan_cap" + StatsResultChanLen = "result_chan_len" + StatsResultChanCap = "result_chan_cap" + StatsSinkPoolLen = "sink_pool_len" + StatsSinkPoolCap = "sink_pool_cap" + StatsActiveRetries = "active_retries" + StatsExpanding = "expanding" +) + +// 详细统计信息字段常量 +const ( + StatsBasicStats = "basic_stats" + StatsDataChanUsage = "data_chan_usage" + StatsResultChanUsage = "result_chan_usage" + StatsSinkPoolUsage = "sink_pool_usage" + StatsProcessRate = "process_rate" + StatsDropRate = "drop_rate" + StatsPerformanceLevel = "performance_level" +) + +// 性能级别常量 +const ( + PerformanceLevelCritical = "CRITICAL" + PerformanceLevelWarning = "WARNING" + PerformanceLevelHighLoad = "HIGH_LOAD" + PerformanceLevelModerateLoad = "MODERATE_LOAD" + PerformanceLevelOptimal = "OPTIMAL" +) + +// 持久化相关常量 +const ( + PersistenceEnabled = "enabled" + PersistenceMessage = "message" + PersistenceNotEnabledMsg = "persistence not enabled" + PerformanceConfigKey = "performanceConfig" +) + +// SQL关键字常量 +const ( + SQLKeywordCase = "CASE" +) + // fieldProcessInfo 字段处理信息,用于缓存预编译的字段处理逻辑 type fieldProcessInfo struct { fieldName string // 原始字段名 @@ -117,7 +179,7 @@ func newStreamWithUnifiedConfig(config types.Config) (*Stream, error) { windowConfig.Params = make(map[string]interface{}) } // 传递完整的性能配置给窗口 - windowConfig.Params["performanceConfig"] = config.PerformanceConfig + windowConfig.Params[PerformanceConfigKey] = config.PerformanceConfig win, err = window.CreateWindow(windowConfig) if err != nil { @@ -142,7 +204,7 @@ func newStreamWithUnifiedConfig(config types.Config) (*Stream, error) { } // 如果是持久化策略,初始化持久化管理器 - if perfConfig.OverflowConfig.Strategy == "persist" && perfConfig.OverflowConfig.PersistenceConfig != nil { + if perfConfig.OverflowConfig.Strategy == StrategyPersist && perfConfig.OverflowConfig.PersistenceConfig != nil { persistConfig := perfConfig.OverflowConfig.PersistenceConfig stream.persistenceManager = NewPersistenceManagerWithConfig( persistConfig.DataDir, @@ -156,11 +218,11 @@ func newStreamWithUnifiedConfig(config types.Config) (*Stream, error) { // 性能优化:根据溢出策略预设AddData函数指针,避免运行时switch判断 switch perfConfig.OverflowConfig.Strategy { - case "block": + case StrategyBlock: stream.addDataFunc = stream.addDataBlocking - case "expand": + case StrategyExpand: stream.addDataFunc = stream.addDataWithExpansion - case "persist": + case StrategyPersist: stream.addDataFunc = stream.addDataWithPersistence default: stream.addDataFunc = stream.addDataWithDrop @@ -375,7 +437,7 @@ func (s *Stream) process() { // 检查是否为CASE表达式 trimmedExpr := strings.TrimSpace(currentFieldExpr.Expression) upperExpr := strings.ToUpper(trimmedExpr) - if strings.HasPrefix(upperExpr, "CASE") { + if strings.HasPrefix(upperExpr, SQLKeywordCase) { // CASE表达式使用支持NULL的计算方法 expression, parseErr := expr.NewExpression(currentFieldExpr.Expression) if parseErr != nil { @@ -437,11 +499,21 @@ func (s *Stream) process() { // 处理窗口模式 go func() { + defer func() { + if r := recover(); r != nil { + logger.Error("Window processing goroutine panic recovered: %v", r) + } + }() + for batch := range s.Window.OutputChan() { // 处理窗口批数据 for _, item := range batch { - _ = s.aggregator.Put("window_start", item.Slot.WindowStart()) - _ = s.aggregator.Put("window_end", item.Slot.WindowEnd()) + if err := s.aggregator.Put(WindowStartField, item.Slot.WindowStart()); err != nil { + logger.Error("failed to put window start: %v", err) + } + if err := s.aggregator.Put(WindowEndField, item.Slot.WindowEnd()); err != nil { + logger.Error("failed to put window end: %v", err) + } if err := s.aggregator.Add(item.Data); err != nil { logger.Error("aggregate error: %v", err) } @@ -471,7 +543,7 @@ func (s *Stream) process() { // 应用 HAVING 过滤条件 if s.config.Having != "" { // 检查HAVING条件是否包含CASE表达式 - hasCaseExpression := strings.Contains(strings.ToUpper(s.config.Having), "CASE") + hasCaseExpression := strings.Contains(strings.ToUpper(s.config.Having), SQLKeywordCase) var filteredResults []map[string]interface{} @@ -1274,17 +1346,17 @@ func (s *Stream) GetStats() map[string]int64 { s.dataChanMux.RUnlock() return map[string]int64{ - "input_count": atomic.LoadInt64(&s.inputCount), - "output_count": atomic.LoadInt64(&s.outputCount), - "dropped_count": atomic.LoadInt64(&s.droppedCount), - "data_chan_len": dataChanLen, - "data_chan_cap": dataChanCap, - "result_chan_len": int64(len(s.resultChan)), - "result_chan_cap": int64(cap(s.resultChan)), - "sink_pool_len": int64(len(s.sinkWorkerPool)), - "sink_pool_cap": int64(cap(s.sinkWorkerPool)), - "active_retries": int64(atomic.LoadInt32(&s.activeRetries)), - "expanding": int64(atomic.LoadInt32(&s.expanding)), + StatsInputCount: atomic.LoadInt64(&s.inputCount), + StatsOutputCount: atomic.LoadInt64(&s.outputCount), + StatsDroppedCount: atomic.LoadInt64(&s.droppedCount), + StatsDataChanLen: dataChanLen, + StatsDataChanCap: dataChanCap, + StatsResultChanLen: int64(len(s.resultChan)), + StatsResultChanCap: int64(cap(s.resultChan)), + StatsSinkPoolLen: int64(len(s.sinkWorkerPool)), + StatsSinkPoolCap: int64(cap(s.sinkWorkerPool)), + StatsActiveRetries: int64(atomic.LoadInt32(&s.activeRetries)), + StatsExpanding: int64(atomic.LoadInt32(&s.expanding)), } } @@ -1293,27 +1365,27 @@ func (s *Stream) GetDetailedStats() map[string]interface{} { stats := s.GetStats() // 计算使用率 - dataUsage := float64(stats["data_chan_len"]) / float64(stats["data_chan_cap"]) * 100 - resultUsage := float64(stats["result_chan_len"]) / float64(stats["result_chan_cap"]) * 100 - sinkUsage := float64(stats["sink_pool_len"]) / float64(stats["sink_pool_cap"]) * 100 + dataUsage := float64(stats[StatsDataChanLen]) / float64(stats[StatsDataChanCap]) * 100 + resultUsage := float64(stats[StatsResultChanLen]) / float64(stats[StatsResultChanCap]) * 100 + sinkUsage := float64(stats[StatsSinkPoolLen]) / float64(stats[StatsSinkPoolCap]) * 100 // 计算效率指标 var processRate float64 = 100.0 var dropRate float64 = 0.0 - if stats["input_count"] > 0 { - processRate = float64(stats["output_count"]) / float64(stats["input_count"]) * 100 - dropRate = float64(stats["dropped_count"]) / float64(stats["input_count"]) * 100 + if stats[StatsInputCount] > 0 { + processRate = float64(stats[StatsOutputCount]) / float64(stats[StatsInputCount]) * 100 + dropRate = float64(stats[StatsDroppedCount]) / float64(stats[StatsInputCount]) * 100 } return map[string]interface{}{ - "basic_stats": stats, - "data_chan_usage": dataUsage, - "result_chan_usage": resultUsage, - "sink_pool_usage": sinkUsage, - "process_rate": processRate, - "drop_rate": dropRate, - "performance_level": s.assessPerformanceLevel(dataUsage, dropRate), + StatsBasicStats: stats, + StatsDataChanUsage: dataUsage, + StatsResultChanUsage: resultUsage, + StatsSinkPoolUsage: sinkUsage, + StatsProcessRate: processRate, + StatsDropRate: dropRate, + StatsPerformanceLevel: s.assessPerformanceLevel(dataUsage, dropRate), } } @@ -1321,15 +1393,15 @@ func (s *Stream) GetDetailedStats() map[string]interface{} { func (s *Stream) assessPerformanceLevel(dataUsage, dropRate float64) string { switch { case dropRate > 50: - return "CRITICAL" // 严重性能问题 + return PerformanceLevelCritical // 严重性能问题 case dropRate > 20: - return "WARNING" // 性能警告 + return PerformanceLevelWarning // 性能警告 case dataUsage > 90: - return "HIGH_LOAD" // 高负载 + return PerformanceLevelHighLoad // 高负载 case dataUsage > 70: - return "MODERATE_LOAD" // 中等负载 + return PerformanceLevelModerateLoad // 中等负载 default: - return "OPTIMAL" // 最佳状态 + return PerformanceLevelOptimal // 最佳状态 } } @@ -1534,16 +1606,224 @@ func (s *Stream) LoadAndReprocessPersistedData() error { func (s *Stream) GetPersistenceStats() map[string]interface{} { if s.persistenceManager == nil { return map[string]interface{}{ - "enabled": false, - "message": "persistence not enabled", + PersistenceEnabled: false, + PersistenceMessage: PersistenceNotEnabledMsg, } } stats := s.persistenceManager.GetStats() - stats["enabled"] = true + stats[PersistenceEnabled] = true return stats } +// IsAggregationQuery 检查当前流是否为聚合查询 +func (s *Stream) IsAggregationQuery() bool { + return s.config.NeedWindow +} + +// ProcessSync 同步处理单条数据,立即返回结果 +// 仅适用于非聚合查询,聚合查询会返回错误 +func (s *Stream) ProcessSync(data interface{}) (interface{}, error) { + // 检查是否为聚合查询 + if s.config.NeedWindow { + return nil, fmt.Errorf("聚合查询不支持同步处理") + } + + // 应用过滤条件 + if s.filter != nil && !s.filter.Evaluate(data) { + return nil, nil // 不匹配过滤条件,返回nil + } + + // 直接处理数据并返回结果 + return s.processDirectDataSync(data) +} + +// processDirectDataSync 同步版本的直接数据处理 +func (s *Stream) processDirectDataSync(data interface{}) (interface{}, error) { + // 增加输入计数 + atomic.AddInt64(&s.inputCount, 1) + + // 简化:直接将数据作为map处理 + dataMap, ok := data.(map[string]interface{}) + if !ok { + atomic.AddInt64(&s.droppedCount, 1) + return nil, fmt.Errorf("不支持的数据类型: %T", data) + } + + // 创建结果map,预分配合适容量 + estimatedSize := len(s.config.FieldExpressions) + len(s.config.SimpleFields) + if estimatedSize < 8 { + estimatedSize = 8 // 最小容量 + } + result := make(map[string]interface{}, estimatedSize) + + // 处理表达式字段 + for fieldName, fieldExpr := range s.config.FieldExpressions { + // 使用桥接器计算表达式,支持IS NULL等语法 + bridge := functions.GetExprBridge() + + // 预处理表达式中的IS NULL和LIKE语法 + processedExpr := fieldExpr.Expression + if bridge.ContainsIsNullOperator(processedExpr) { + if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil { + processedExpr = processed + } + } + if bridge.ContainsLikeOperator(processedExpr) { + if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil { + processedExpr = processed + } + } + + // 检查表达式是否是函数调用(包含括号) + isFunctionCall := strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")") + + // 检查表达式是否包含嵌套字段(但排除函数调用中的点号) + hasNestedFields := false + if !isFunctionCall && strings.Contains(fieldExpr.Expression, ".") { + hasNestedFields = true + } + + // 检查是否为CASE表达式 + trimmedExpr := strings.TrimSpace(fieldExpr.Expression) + upperExpr := strings.ToUpper(trimmedExpr) + isCaseExpression := strings.HasPrefix(upperExpr, SQLKeywordCase) + + var evalResult interface{} + + if isFunctionCall { + // 对于函数调用,优先使用桥接器处理,这样可以保持原始类型 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + logger.Error("Function call evaluation failed for field %s: %v", fieldName, err) + result[fieldName] = nil + continue + } + evalResult = exprResult + } else if hasNestedFields || isCaseExpression { + // 检测到嵌套字段(非函数调用)或CASE表达式,使用自定义表达式引擎 + expression, parseErr := expr.NewExpression(fieldExpr.Expression) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + continue + } + + // 使用支持NULL的计算方法 + numResult, isNull, err := expression.EvaluateWithNull(dataMap) + if err != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, err) + result[fieldName] = nil + continue + } + if isNull { + evalResult = nil // NULL值 + } else { + evalResult = numResult + } + } else { + // 尝试使用桥接器处理其他表达式 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + // 如果桥接器失败,回退到原来的表达式引擎(使用原始表达式,不是预处理的) + expression, parseErr := expr.NewExpression(fieldExpr.Expression) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + continue + } + + // 计算表达式,支持NULL值 + numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap) + if evalErr != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr) + result[fieldName] = nil + continue + } + if isNull { + evalResult = nil // NULL值 + } else { + evalResult = numResult + } + } else { + evalResult = exprResult + } + } + + result[fieldName] = evalResult + } + + // 处理SimpleFields(复用现有逻辑) + if len(s.config.SimpleFields) > 0 { + for _, fieldSpec := range s.config.SimpleFields { + info := s.compiledFieldInfo[fieldSpec] + if info == nil { + // 如果没有预编译信息,回退到原逻辑(安全性保证) + s.processSingleFieldFallback(fieldSpec, dataMap, data, result) + continue + } + + if info.isSelectAll { + // SELECT *:批量复制所有字段,跳过表达式字段 + for k, v := range dataMap { + if _, isExpression := s.config.FieldExpressions[k]; !isExpression { + result[k] = v + } + } + continue + } + + // 跳过已经通过表达式字段处理的字段 + if _, isExpression := s.config.FieldExpressions[info.outputName]; isExpression { + continue + } + + if info.isFunctionCall { + // 执行函数调用 + if funcResult, err := s.executeFunction(info.fieldName, dataMap); err == nil { + result[info.outputName] = funcResult + } else { + logger.Error("Function execution error %s: %v", info.fieldName, err) + result[info.outputName] = nil + } + } else { + // 普通字段处理 + var value interface{} + var exists bool + + if info.hasNestedField { + value, exists = fieldpath.GetNestedField(data, info.fieldName) + } else { + value, exists = dataMap[info.fieldName] + } + + if exists { + result[info.outputName] = value + } else { + result[info.outputName] = nil + } + } + } + } else if len(s.config.FieldExpressions) == 0 { + // 如果没有指定字段且没有表达式字段,保留所有字段 + for k, v := range dataMap { + result[k] = v + } + } + + // 增加输出计数 + atomic.AddInt64(&s.outputCount, 1) + + // 包装结果为数组格式,保持与异步模式的一致性 + results := []map[string]interface{}{result} + + // 触发 AddSink 回调,保持同步和异步模式的一致性 + // 这样用户可以同时获得同步结果和异步回调 + s.callSinksAsync(results) + + return result, nil +} + // 向后兼容性函数 // NewStreamWithBuffers 创建带自定义缓冲区大小的Stream (已弃用,使用NewStreamWithCustomPerformance) @@ -1571,15 +1851,15 @@ func NewStreamWithoutDataLoss(config types.Config, strategy string) (*Stream, er // 应用用户指定的策略 validStrategies := map[string]bool{ - "drop": true, - "block": true, - "expand": true, - "persist": true, + StrategyDrop: true, + StrategyBlock: true, + StrategyExpand: true, + StrategyPersist: true, } if validStrategies[strategy] { perfConfig.OverflowConfig.Strategy = strategy - if strategy == "drop" { + if strategy == StrategyDrop { perfConfig.OverflowConfig.AllowDataLoss = true } } @@ -1599,7 +1879,7 @@ func NewStreamWithLossPolicy(config types.Config, dataBufSize, resultBufSize, si perfConfig.WorkerConfig.SinkPoolSize = sinkPoolSize perfConfig.OverflowConfig.Strategy = overflowStrategy perfConfig.OverflowConfig.BlockTimeout = timeout - perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == "drop") + perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == StrategyDrop) config.PerformanceConfig = perfConfig return newStreamWithUnifiedConfig(config) @@ -1616,10 +1896,10 @@ func NewStreamWithLossPolicyAndPersistence(config types.Config, dataBufSize, res perfConfig.WorkerConfig.SinkPoolSize = sinkPoolSize perfConfig.OverflowConfig.Strategy = overflowStrategy perfConfig.OverflowConfig.BlockTimeout = timeout - perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == "drop") + perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == StrategyDrop) // 设置持久化配置 - if overflowStrategy == "persist" { + if overflowStrategy == StrategyPersist { perfConfig.OverflowConfig.PersistenceConfig = &types.PersistenceConfig{ DataDir: persistDataDir, MaxFileSize: persistMaxFileSize, diff --git a/streamsql.go b/streamsql.go index 62f716a..1b30a0f 100644 --- a/streamsql.go +++ b/streamsql.go @@ -38,6 +38,9 @@ type Streamsql struct { // 性能配置模式 performanceMode string // "default", "high_performance", "low_latency", "zero_data_loss", "custom" customConfig *types.PerformanceConfig + + // 新增:同步处理模式配置 + enableSyncMode bool // 是否启用同步模式(用于非聚合查询) } // New 创建一个新的StreamSQL实例。 @@ -190,6 +193,64 @@ func (s *Streamsql) Emit(data interface{}) { } } +// EmitSync 同步处理数据,立即返回处理结果。 +// 仅适用于非聚合查询(如过滤、转换等),聚合查询会返回错误。 +// +// 对于非聚合查询,此方法提供同步的数据处理能力,同时: +// 1. 立即返回处理结果(同步) +// 2. 触发已注册的AddSink回调(异步) +// +// 这确保了同步和异步模式的一致性,用户可以同时获得: +// - 立即可用的处理结果 +// - 异步回调处理(用于日志、监控、持久化等) +// +// 参数: +// - data: 要处理的数据 +// +// 返回值: +// - interface{}: 处理后的结果,如果不匹配过滤条件返回nil +// - error: 处理错误,如果是聚合查询会返回错误 +// +// 示例: +// +// // 添加日志回调 +// ssql.AddSink(func(result interface{}) { +// fmt.Printf("异步日志: %v\n", result) +// }) +// +// // 同步处理并立即获取结果 +// result, err := ssql.EmitSync(map[string]interface{}{ +// "temperature": 25.5, +// "humidity": 60.0, +// }) +// if err != nil { +// // 处理错误 +// } else if result != nil { +// // 立即使用处理结果 +// fmt.Printf("同步结果: %v\n", result) +// // 同时异步回调也会被触发 +// } +func (s *Streamsql) EmitSync(data interface{}) (interface{}, error) { + if s.stream == nil { + return nil, fmt.Errorf("stream未初始化") + } + + // 检查是否为非聚合查询 + if s.stream.IsAggregationQuery() { + return nil, fmt.Errorf("同步模式仅支持非聚合查询,聚合查询请使用Emit()方法") + } + + return s.stream.ProcessSync(data) +} + +// IsAggregationQuery 检查当前查询是否为聚合查询 +func (s *Streamsql) IsAggregationQuery() bool { + if s.stream == nil { + return false + } + return s.stream.IsAggregationQuery() +} + // Stream 返回底层的流处理器实例。 // 通过此方法可以访问更底层的流处理功能。 // diff --git a/sync_sink_test.go b/sync_sink_test.go new file mode 100644 index 0000000..30c89be --- /dev/null +++ b/sync_sink_test.go @@ -0,0 +1,258 @@ +/* + * Copyright 2025 The RuleGo Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamsql + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEmitSyncWithAddSink 测试EmitSync同时触发AddSink回调 +func TestEmitSyncWithAddSink(t *testing.T) { + t.Run("非聚合查询同步+异步结果", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行非聚合查询 + sql := "SELECT temperature, humidity, temperature * 1.8 + 32 as temp_fahrenheit FROM stream WHERE temperature > 20" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 验证是非聚合查询 + assert.False(t, ssql.IsAggregationQuery()) + + // 设置AddSink回调来收集异步结果 + var sinkCallCount int32 + var sinkResults []interface{} + var sinkResultsMux sync.Mutex // 保护sinkResults访问 + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sinkCallCount, 1) + sinkResultsMux.Lock() + sinkResults = append(sinkResults, result) + sinkResultsMux.Unlock() + }) + + // 测试数据 + testData := []map[string]interface{}{ + {"temperature": 25.0, "humidity": 60.0}, // 符合条件 + {"temperature": 15.0, "humidity": 70.0}, // 被过滤 + {"temperature": 30.0, "humidity": 80.0}, // 符合条件 + } + + var syncResults []interface{} + + // 处理测试数据 + for _, data := range testData { + // 同步处理 + result, err := ssql.EmitSync(data) + require.NoError(t, err) + + if result != nil { + syncResults = append(syncResults, result) + } + } + + // 等待异步回调完成 + time.Sleep(100 * time.Millisecond) + + // 验证同步结果 + assert.Equal(t, 2, len(syncResults), "应该有2条同步结果(温度>20)") + + // 安全读取异步回调结果 + sinkResultsMux.Lock() + finalSinkResults := make([]interface{}, len(sinkResults)) + copy(finalSinkResults, sinkResults) + sinkResultsMux.Unlock() + + // 验证异步回调结果 + finalSinkCallCount := atomic.LoadInt32(&sinkCallCount) + assert.Equal(t, int32(2), finalSinkCallCount, "AddSink应该被调用2次") + assert.Equal(t, 2, len(finalSinkResults), "应该收集到2条异步结果") + + // 验证同步和异步结果的内容一致性 + if len(syncResults) > 0 && len(finalSinkResults) > 0 { + // 将结果转换为可比较的格式 + syncTemperatures := make([]float64, 0, len(syncResults)) + syncHumidities := make([]float64, 0, len(syncResults)) + asyncTemperatures := make([]float64, 0, len(finalSinkResults)) + asyncHumidities := make([]float64, 0, len(finalSinkResults)) + + // 收集同步结果 + for _, result := range syncResults { + if syncResult, ok := result.(map[string]interface{}); ok { + syncTemperatures = append(syncTemperatures, syncResult["temperature"].(float64)) + syncHumidities = append(syncHumidities, syncResult["humidity"].(float64)) + } + } + + // 收集异步结果 + for _, result := range finalSinkResults { + if sinkResultArray, ok := result.([]map[string]interface{}); ok && len(sinkResultArray) > 0 { + sinkResult := sinkResultArray[0] + asyncTemperatures = append(asyncTemperatures, sinkResult["temperature"].(float64)) + asyncHumidities = append(asyncHumidities, sinkResult["humidity"].(float64)) + } + } + + // 验证结果集合是否一致(不考虑顺序) + assert.ElementsMatch(t, syncTemperatures, asyncTemperatures, "温度值集合应该一致") + assert.ElementsMatch(t, syncHumidities, asyncHumidities, "湿度值集合应该一致") + + // 验证预期的数值是否都存在 + assert.Contains(t, syncTemperatures, 25.0, "同步结果应包含25.0") + assert.Contains(t, syncTemperatures, 30.0, "同步结果应包含30.0") + assert.Contains(t, asyncTemperatures, 25.0, "异步结果应包含25.0") + assert.Contains(t, asyncTemperatures, 30.0, "异步结果应包含30.0") + } + }) + + t.Run("聚合查询不支持EmitSync", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行聚合查询 + sql := "SELECT AVG(temperature) as avg_temp FROM stream GROUP BY TumblingWindow('1s')" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 验证是聚合查询 + assert.True(t, ssql.IsAggregationQuery()) + + // 尝试同步处理应该返回错误 + data := map[string]interface{}{"temperature": 25.0} + result, err := ssql.EmitSync(data) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "同步模式仅支持非聚合查询") + }) + + t.Run("多个AddSink回调都被触发", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行非聚合查询 + sql := "SELECT temperature FROM stream" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 添加多个AddSink回调,使用原子操作确保线程安全 + var sink1Count, sink2Count, sink3Count int32 + + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sink1Count, 1) + }) + + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sink2Count, 1) + }) + + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sink3Count, 1) + }) + + // 处理一条数据 + data := map[string]interface{}{"temperature": 25.0} + result, err := ssql.EmitSync(data) + require.NoError(t, err) + require.NotNil(t, result) + + // 等待异步回调 + time.Sleep(100 * time.Millisecond) + + // 验证所有回调都被触发 + assert.Equal(t, int32(1), atomic.LoadInt32(&sink1Count)) + assert.Equal(t, int32(1), atomic.LoadInt32(&sink2Count)) + assert.Equal(t, int32(1), atomic.LoadInt32(&sink3Count)) + }) + + t.Run("过滤条件不匹配时AddSink不触发", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行带过滤条件的查询 + sql := "SELECT temperature FROM stream WHERE temperature > 30" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 添加AddSink回调 + var sinkCallCount int32 + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sinkCallCount, 1) + }) + + // 处理不符合条件的数据 + data := map[string]interface{}{"temperature": 20.0} // 不符合 > 30 的条件 + result, err := ssql.EmitSync(data) + require.NoError(t, err) + assert.Nil(t, result, "不符合过滤条件应该返回nil") + + // 等待可能的异步回调 + time.Sleep(100 * time.Millisecond) + + // 验证AddSink没有被触发 + assert.Equal(t, int32(0), atomic.LoadInt32(&sinkCallCount), "过滤掉的数据不应触发AddSink") + }) +} + +// TestEmitSyncPerformance 测试EmitSync性能(包括AddSink触发) +func TestEmitSyncPerformance(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := "SELECT temperature, humidity FROM stream WHERE temperature > 0" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 添加AddSink回调,使用原子操作确保线程安全 + var sinkCallCount int32 + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sinkCallCount, 1) + }) + + // 性能测试 + testCount := 1000 + + start := time.Now() + for i := 0; i < testCount; i++ { + data := map[string]interface{}{ + "temperature": float64(20 + i%20), + "humidity": float64(50 + i%30), + } + + result, err := ssql.EmitSync(data) + require.NoError(t, err) + require.NotNil(t, result) + } + duration := time.Since(start) + + // 等待所有异步回调完成 + time.Sleep(200 * time.Millisecond) + + t.Logf("处理 %d 条数据耗时: %v", testCount, duration) + t.Logf("平均每条数据: %v", duration/time.Duration(testCount)) + t.Logf("AddSink 触发次数: %d", atomic.LoadInt32(&sinkCallCount)) + + // 验证性能和一致性 + assert.Less(t, duration, 1*time.Second, "性能应该足够好") + assert.Equal(t, int32(testCount), atomic.LoadInt32(&sinkCallCount), "所有数据都应触发AddSink") +} diff --git a/utils/reflectutil/reflectutil.go b/utils/reflectutil/reflectutil.go new file mode 100644 index 0000000..6b7429b --- /dev/null +++ b/utils/reflectutil/reflectutil.go @@ -0,0 +1,27 @@ +package reflectutil + +import ( + "fmt" + "reflect" +) + +// SafeFieldByName 安全地获取结构体字段 +func SafeFieldByName(v reflect.Value, fieldName string) (reflect.Value, error) { + // 检查Value是否有效 + if !v.IsValid() { + return reflect.Value{}, fmt.Errorf("invalid value") + } + + // 检查是否为结构体类型 + if v.Kind() != reflect.Struct { + return reflect.Value{}, fmt.Errorf("value is not a struct, got %v", v.Kind()) + } + + // 安全地获取字段 + field := v.FieldByName(fieldName) + if !field.IsValid() { + return reflect.Value{}, fmt.Errorf("field %s not found", fieldName) + } + + return field, nil +} diff --git a/window/factory.go b/window/factory.go index 32dbaf0..9bc14ad 100644 --- a/window/factory.go +++ b/window/factory.go @@ -2,10 +2,11 @@ package window import ( "fmt" - "github.com/rulego/streamsql/utils/cast" "reflect" "time" + "github.com/rulego/streamsql/utils/cast" + "github.com/rulego/streamsql/types" )