diff --git a/README.md b/README.md index 6eaa202..9daabfa 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ func main() { } // Handle real-time transformation results - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf("Real-time result: %+v\n", result) }) @@ -110,7 +110,7 @@ func main() { // Process data one by one, each will output results immediately for _, data := range sensorData { - ssql.Stream().AddData(data) + ssql.Emit(data) time.Sleep(100 * time.Millisecond) // Simulate real-time data arrival } @@ -273,7 +273,7 @@ func main() { } // Handle aggregation results - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf("Aggregation result: %+v\n", result) }) @@ -293,7 +293,7 @@ func main() { "timestamp": time.Now().Unix(), } - ssql.Stream().AddData(nestedData) + ssql.Emit(nestedData) } ``` diff --git a/README_ZH.md b/README_ZH.md index 6984a45..49f20c6 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -85,7 +85,7 @@ func main() { } // 处理实时转换结果 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf("实时处理结果: %+v\n", result) }) @@ -113,7 +113,7 @@ func main() { // 逐条处理数据,每条都会立即输出结果 for _, data := range sensorData { - ssql.Stream().AddData(data) + ssql.Emit(data) time.Sleep(100 * time.Millisecond) // 模拟实时数据到达 } @@ -289,7 +289,7 @@ func main() { } // 处理聚合结果 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf("聚合结果: %+v\n", result) }) @@ -309,7 +309,7 @@ func main() { "timestamp": time.Now().Unix(), } - ssql.Stream().AddData(nestedData) + ssql.Emit(nestedData) } ``` @@ -317,14 +317,6 @@ func main() { StreamSQL 支持多种函数类型,包括数学、字符串、转换、聚合、分析、窗口等上百个函数。[文档](docs/FUNCTIONS_USAGE_GUIDE.md) -### 🎨 支持的函数类型 - -- **📊 数学函数** - sqrt, power, abs, 三角函数等 -- **📝 字符串函数** - concat, upper, lower, trim等 -- **🔄 转换函数** - cast, hex2dec, encode/decode等 -- **📈 聚合函数** - 自定义聚合逻辑 -- **🔍 分析函数** 
- lag, latest, 变化检测等 - ## 概念 ### 窗口 diff --git a/doc.go b/doc.go index f649dde..4721f96 100644 --- a/doc.go +++ b/doc.go @@ -1,175 +1,208 @@ -/* - * Copyright 2025 The RuleGo Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* -Package streamsql 是一个轻量级的、基于 SQL 的物联网边缘流处理引擎。 - -StreamSQL 提供了高效的无界数据流处理和分析能力,支持多种窗口类型、聚合函数、 -自定义函数,以及与 RuleGo 生态的无缝集成。 - -# 核心特性 - -• 轻量级设计 - 纯内存操作,无外部依赖 -• SQL语法支持 - 使用熟悉的SQL语法处理流数据 -• 多种窗口类型 - 滑动窗口、滚动窗口、计数窗口、会话窗口 -• 丰富的聚合函数 - MAX, MIN, AVG, SUM, STDDEV, MEDIAN, PERCENTILE等 -• 插件式自定义函数 - 运行时动态注册,支持8种函数类型 -• RuleGo生态集成 - 利用RuleGo组件扩展输入输出源 - -# 入门示例 - -基本的流数据处理: - - package main - - import ( - "fmt" - "math/rand" - "time" - "github.com/rulego/streamsql" - ) - - func main() { - // 创建StreamSQL实例 - ssql := streamsql.New() - - // 定义SQL查询 - 每5秒按设备ID分组计算温度平均值 - sql := `SELECT deviceId, - AVG(temperature) as avg_temp, - MIN(humidity) as min_humidity, - window_start() as start, - window_end() as end - FROM stream - WHERE deviceId != 'device3' - GROUP BY deviceId, TumblingWindow('5s')` - - // 执行SQL,创建流处理任务 - err := ssql.Execute(sql) - if err != nil { - panic(err) - } - - // 添加结果处理回调 - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf("聚合结果: %v\n", result) - }) - - // 模拟发送流数据 - go func() { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - // 生成随机设备数据 - data := map[string]interface{}{ - "deviceId": fmt.Sprintf("device%d", rand.Intn(3)+1), - "temperature": 20.0 
+ rand.Float64()*10, - "humidity": 50.0 + rand.Float64()*20, - } - ssql.AddData(data) - } - } - }() - - // 运行30秒 - time.Sleep(30 * time.Second) - } - -# 窗口函数 - -StreamSQL 支持多种窗口类型: - - // 滚动窗口 - 每5秒一个独立窗口 - SELECT AVG(temperature) FROM stream GROUP BY TumblingWindow('5s') - - // 滑动窗口 - 窗口大小30秒,每10秒滑动一次 - SELECT MAX(temperature) FROM stream GROUP BY SlidingWindow('30s', '10s') - - // 计数窗口 - 每100条记录一个窗口 - SELECT COUNT(*) FROM stream GROUP BY CountingWindow(100) - - // 会话窗口 - 超时5分钟自动关闭会话 - SELECT user_id, COUNT(*) FROM stream GROUP BY user_id, SessionWindow('5m') - -# 自定义函数 - -StreamSQL 支持插件式自定义函数,运行时动态注册: - - // 注册温度转换函数 - functions.RegisterCustomFunction( - "fahrenheit_to_celsius", - functions.TypeConversion, - "温度转换", - "华氏度转摄氏度", - 1, 1, - func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { - f, _ := functions.ConvertToFloat64(args[0]) - return (f - 32) * 5 / 9, nil - }, - ) - - // 立即在SQL中使用 - sql := `SELECT deviceId, - AVG(fahrenheit_to_celsius(temperature)) as avg_celsius - FROM stream GROUP BY deviceId, TumblingWindow('5s')` - -支持的自定义函数类型: -• TypeMath - 数学计算函数 -• TypeString - 字符串处理函数 -• TypeConversion - 类型转换函数 -• TypeDateTime - 时间日期函数 -• TypeAggregation - 聚合函数 -• TypeAnalytical - 分析函数 -• TypeWindow - 窗口函数 -• TypeCustom - 通用自定义函数 - -# 日志配置 - -StreamSQL 提供灵活的日志配置选项: - - // 设置日志级别 - ssql := streamsql.New(streamsql.WithLogLevel(logger.DEBUG)) - - // 输出到文件 - logFile, _ := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - ssql := streamsql.New(streamsql.WithLogOutput(logFile, logger.INFO)) - - // 禁用日志(生产环境) - ssql := streamsql.New(streamsql.WithDiscardLog()) - -# 性能配置 - -对于生产环境,建议进行以下配置: - - ssql := streamsql.New( - streamsql.WithDiscardLog(), // 禁用日志提升性能 - // 其他配置选项... 
- ) - -# 与RuleGo集成 - -StreamSQL可以与RuleGo规则引擎无缝集成,利用RuleGo丰富的组件生态: - - // TODO: 提供RuleGo集成示例 - -更多详细信息和高级用法,请参阅: -• 自定义函数开发指南: docs/CUSTOM_FUNCTIONS_GUIDE.md -• 快速入门指南: docs/FUNCTION_QUICK_START.md -• 完整示例: examples/ -*/ -package streamsql +/* + * Copyright 2025 The RuleGo Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Package streamsql 是一个轻量级的、基于 SQL 的物联网边缘流处理引擎。 + +StreamSQL 提供了高效的无界数据流处理和分析能力,支持多种窗口类型、聚合函数、 +自定义函数,以及与 RuleGo 生态的无缝集成。 + +# 核心特性 + +• 轻量级设计 - 纯内存操作,无外部依赖 +• SQL语法支持 - 使用熟悉的SQL语法处理流数据 +• 多种窗口类型 - 滑动窗口、滚动窗口、计数窗口、会话窗口 +• 丰富的聚合函数 - MAX, MIN, AVG, SUM, STDDEV, MEDIAN, PERCENTILE等 +• 插件式自定义函数 - 运行时动态注册,支持8种函数类型 +• RuleGo生态集成 - 利用RuleGo组件扩展输入输出源 + +# 入门示例 + +基本的流数据处理: + + package main + + import ( + "fmt" + "math/rand" + "time" + "github.com/rulego/streamsql" + ) + + func main() { + // 创建StreamSQL实例 + ssql := streamsql.New() + + // 定义SQL查询 - 每5秒按设备ID分组计算温度平均值 + sql := `SELECT deviceId, + AVG(temperature) as avg_temp, + MIN(humidity) as min_humidity, + window_start() as start, + window_end() as end + FROM stream + WHERE deviceId != 'device3' + GROUP BY deviceId, TumblingWindow('5s')` + + // 执行SQL,创建流处理任务 + err := ssql.Execute(sql) + if err != nil { + panic(err) + } + + // 添加结果处理回调 + ssql.AddSink(func(result interface{}) { + fmt.Printf("聚合结果: %v\n", result) + }) + + // 模拟发送流数据 + go func() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // 生成随机设备数据 + data := 
map[string]interface{}{ + "deviceId": fmt.Sprintf("device%d", rand.Intn(3)+1), + "temperature": 20.0 + rand.Float64()*10, + "humidity": 50.0 + rand.Float64()*20, + } + ssql.Emit(data) + } + } + }() + + // 运行30秒 + time.Sleep(30 * time.Second) + } + +# 窗口函数 + +StreamSQL 支持多种窗口类型: + + // 滚动窗口 - 每5秒一个独立窗口 + SELECT AVG(temperature) FROM stream GROUP BY TumblingWindow('5s') + + // 滑动窗口 - 窗口大小30秒,每10秒滑动一次 + SELECT MAX(temperature) FROM stream GROUP BY SlidingWindow('30s', '10s') + + // 计数窗口 - 每100条记录一个窗口 + SELECT COUNT(*) FROM stream GROUP BY CountingWindow(100) + + // 会话窗口 - 超时5分钟自动关闭会话 + SELECT user_id, COUNT(*) FROM stream GROUP BY user_id, SessionWindow('5m') + +# 自定义函数 + +StreamSQL 支持插件式自定义函数,运行时动态注册: + + // 注册温度转换函数 + functions.RegisterCustomFunction( + "fahrenheit_to_celsius", + functions.TypeConversion, + "温度转换", + "华氏度转摄氏度", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + f, _ := functions.ConvertToFloat64(args[0]) + return (f - 32) * 5 / 9, nil + }, + ) + + // 立即在SQL中使用 + sql := `SELECT deviceId, + AVG(fahrenheit_to_celsius(temperature)) as avg_celsius + FROM stream GROUP BY deviceId, TumblingWindow('5s')` + +支持的自定义函数类型: +• TypeMath - 数学计算函数 +• TypeString - 字符串处理函数 +• TypeConversion - 类型转换函数 +• TypeDateTime - 时间日期函数 +• TypeAggregation - 聚合函数 +• TypeAnalytical - 分析函数 +• TypeWindow - 窗口函数 +• TypeCustom - 通用自定义函数 + +# 日志配置 + +StreamSQL 提供灵活的日志配置选项: + + // 设置日志级别 + ssql := streamsql.New(streamsql.WithLogLevel(logger.DEBUG)) + + // 输出到文件 + logFile, _ := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + ssql := streamsql.New(streamsql.WithLogOutput(logFile, logger.INFO)) + + // 禁用日志(生产环境) + ssql := streamsql.New(streamsql.WithDiscardLog()) + +# 与RuleGo集成 + +StreamSQL提供了与RuleGo规则引擎的深度集成,通过两个专用组件实现流式数据处理: + +• streamTransform (x/streamTransform) - 流转换器,处理非聚合SQL查询 +• streamAggregator (x/streamAggregator) - 流聚合器,处理聚合SQL查询 + +基本集成示例: + + package main + + import ( + "github.com/rulego/rulego" + 
"github.com/rulego/rulego/api/types" + // 注册StreamSQL组件 + _ "github.com/rulego/rulego-components/external/streamsql" + ) + + func main() { + // 规则链配置 + ruleChainJson := `{ + "ruleChain": {"id": "rule01"}, + "metadata": { + "nodes": [{ + "id": "transform1", + "type": "x/streamTransform", + "configuration": { + "sql": "SELECT deviceId, temperature * 1.8 + 32 as temp_f FROM stream WHERE temperature > 20" + } + }, { + "id": "aggregator1", + "type": "x/streamAggregator", + "configuration": { + "sql": "SELECT deviceId, AVG(temperature) as avg_temp FROM stream GROUP BY deviceId, TumblingWindow('5s')" + } + }], + "connections": [{ + "fromId": "transform1", + "toId": "aggregator1", + "type": "Success" + }] + } + }` + + // 创建规则引擎 + ruleEngine, _ := rulego.New("rule01", []byte(ruleChainJson)) + + // 发送数据 + data := `{"deviceId":"sensor01","temperature":25.5}` + msg := types.NewMsg(0, "TELEMETRY", types.JSON, types.NewMetadata(), data) + ruleEngine.OnMsg(msg) + } +*/ +package streamsql diff --git a/docs/NEGATIVE_NUMBER_SUPPORT.md b/docs/NEGATIVE_NUMBER_SUPPORT.md new file mode 100644 index 0000000..4e6aba3 --- /dev/null +++ b/docs/NEGATIVE_NUMBER_SUPPORT.md @@ -0,0 +1,225 @@ +# StreamSQL 负数支持文档 + +## 概述 + +StreamSQL 现在全面支持负数在 CASE 表达式中的使用。本文档总结了负数支持的完善情况、支持范围和使用建议。 + +## ✅ 已支持的负数用法 + +### 1. 基本负数常量 + +```sql +-- CASE 表达式中的负数常量 +CASE WHEN temperature > 0 THEN 1 ELSE -1 END + +-- 负数小数 +CASE WHEN temperature > 0 THEN 1.5 ELSE -2.5 END + +-- 负零 +CASE WHEN temperature = -0 THEN 1 ELSE 0 END +``` + +### 2. 比较运算符后的负数 + +```sql +-- 比较运算符后直接跟负数 +CASE WHEN temperature < -10 THEN 'FREEZING' ELSE 'NORMAL' END +CASE WHEN temperature >= -5.5 THEN 'ABOVE' ELSE 'BELOW' END +CASE WHEN temperature > -20 THEN 'WARM' ELSE 'COLD' END +``` + +### 3. 简单 CASE 表达式中的负数 + +```sql +-- 简单 CASE 中使用负数作为匹配值 +CASE temperature + WHEN -10 THEN 'FROZEN' + WHEN -5 THEN 'COLD' + WHEN 0 THEN 'ZERO' + ELSE 'OTHER' +END +``` + +### 4. 
算术表达式中的负数 + +```sql +-- 括号内的负数运算 +CASE WHEN temperature + (-10) > 0 THEN 1 ELSE 0 END +CASE WHEN (temperature * -1) > 10 THEN 1 ELSE 0 END +``` + +## ⚠️ 部分支持或限制 + +### 1. 函数参数中的负数表达式 + +```sql +-- 当前不完全支持:函数参数中的负数变量 +CASE WHEN ABS(-temperature) > 10 THEN 1 ELSE 0 END -- ❌ + +-- 推荐替代方案:使用括号或先计算 +CASE WHEN ABS(temperature * -1) > 10 THEN 1 ELSE 0 END -- ✅ +``` + +### 2. BETWEEN 语句中的负数范围 + +```sql +-- 当前不支持:BETWEEN 与负数组合 +CASE WHEN temperature BETWEEN -20 AND -10 THEN 1 ELSE 0 END -- ❌ + +-- 推荐替代方案:使用比较运算符 +CASE WHEN temperature >= -20 AND temperature <= -10 THEN 1 ELSE 0 END -- ✅ +``` + +### 3. SQL 中的空格分隔负数 + +```sql +-- 避免在 SQL 中使用空格分隔的负数 +SELECT CASE WHEN temperature < - 10 THEN 'COLD' END -- ❌ 解析问题 + +-- 推荐写法:紧密连接或使用括号 +SELECT CASE WHEN temperature < -10 THEN 'COLD' END -- ✅ +SELECT CASE WHEN temperature < (-10) THEN 'COLD' END -- ✅ +``` + +## 🔧 技术实现 + +### 词法分析器增强 + +1. **智能负数识别**: + - 识别比较运算符后的负数(`<`, `>`, `<=`, `>=`, `==`, `!=`) + - 支持逻辑运算符后的负数(`AND`, `OR`) + - 支持 CASE 关键字后的负数(`WHEN`, `THEN`, `ELSE`) + +2. **连续运算符检查优化**: + - 允许比较运算符后跟负数的合法组合 + - 智能区分负数与减号运算符 + +3. **空格处理**: + - 正确处理空格分隔的负数标记 + - 改进 token 化过程以支持各种负数格式 + +### 表达式求值增强 + +1. **负数常量解析**:完全支持负整数和负小数 +2. **类型转换**:正确处理负数的数值转换 +3. **NULL 值处理**:负数与 NULL 值的正确交互 + +## 📊 测试覆盖 + +### 表达式级别测试 + +- ✅ 负数常量在 THEN/ELSE 中 +- ✅ 负数常量在 WHEN 条件中 +- ✅ 负数小数支持 +- ✅ 负数在算术表达式中 +- ✅ 负数在简单 CASE 中 +- ✅ 负零处理 + +### SQL 集成测试 + +- ✅ 完整 SQL 语句中的负数支持 +- ✅ 非聚合查询中的负数表达式 +- ✅ 聚合查询中的负数处理 + +## 🎯 使用建议 + +### 1. 推荐的负数写法 + +```sql +-- ✅ 推荐:紧密连接的负数 +CASE WHEN temperature < -10 THEN 'FREEZING' END + +-- ✅ 推荐:括号包围的负数(最安全) +CASE WHEN temperature < (-10) THEN 'FREEZING' END + +-- ✅ 推荐:负数小数 +CASE WHEN temperature < -10.5 THEN 'FREEZING' END +``` + +### 2. 避免的写法 + +```sql +-- ❌ 避免:空格分隔的负数 +CASE WHEN temperature < - 10 THEN 'FREEZING' END + +-- ❌ 避免:复杂的负数表达式在函数中 +CASE WHEN ABS(-temperature) > 10 THEN 1 END +``` + +### 3. 最佳实践 + +1. **使用括号**:当不确定负数解析时,总是使用括号包围负数 +2. **避免空格**:在负号和数字之间不要添加空格 +3. **测试验证**:对包含负数的复杂表达式进行充分测试 +4. 
**版本兼容**:确保使用的 StreamSQL 版本支持所需的负数功能 + +## 🚀 未来改进计划 + +1. **完全支持函数参数中的负数表达式** +2. **支持 BETWEEN 语句中的负数范围** +3. **改进 SQL 解析器对空格分隔负数的处理** +4. **扩展负数支持到更多数学和字符串函数** + +## 示例代码 + +```go +package main + +import ( + "fmt" + "github.com/rulego/streamsql" +) + +func main() { + // 创建 StreamSQL 实例 + sql := streamsql.New() + defer sql.Stop() + + // 包含负数的 SQL 查询 + query := ` + SELECT deviceId, + temperature, + CASE + WHEN temperature < -10 THEN 'FREEZING' + WHEN temperature < 0 THEN 'COLD' + WHEN temperature = 0 THEN 'ZERO' + ELSE 'POSITIVE' + END as temp_category, + CASE + WHEN temperature > 0 THEN temperature + ELSE (-1.0) + END as adjusted_temp + FROM stream + ` + + // 执行查询 + err := sql.Execute(query) + if err != nil { + fmt.Printf("执行失败: %v\n", err) + return + } + + // 添加数据处理器 + sql.AddSink(func(result interface{}) { + fmt.Printf("结果: %+v\n", result) + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "temperature": -15.0}, + {"deviceId": "sensor2", "temperature": -5.0}, + {"deviceId": "sensor3", "temperature": 0.0}, + {"deviceId": "sensor4", "temperature": 10.0}, + } + + for _, data := range testData { + sql.Emit(data) + } +} +``` + +--- + +**更新日期**: 2025-06-17 +**版本**: StreamSQL v0.x +**作者**: StreamSQL 开发团队 \ No newline at end of file diff --git a/examples/advanced-functions/main.go b/examples/advanced-functions/main.go index a99f783..ea0acdb 100644 --- a/examples/advanced-functions/main.go +++ b/examples/advanced-functions/main.go @@ -51,7 +51,7 @@ func main() { fmt.Println("✓ SQL执行成功") // 5. 添加结果监听器 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf("📊 聚合结果: %v\n", result) }) @@ -69,7 +69,7 @@ func main() { for _, data := range sensorData { fmt.Printf(" 设备: %s, 温度: %.1f°F, 湿度: %.1f%%\n", data["device"], data["temperature"], data["humidity"]) - ssql.AddData(data) + ssql.Emit(data) } // 7. 
等待处理完成 diff --git a/examples/complex-nested-access/main.go b/examples/complex-nested-access/main.go index b79bbb2..7d5f418 100644 --- a/examples/complex-nested-access/main.go +++ b/examples/complex-nested-access/main.go @@ -1,437 +1,437 @@ -package main - -import ( - "context" - "fmt" - "math/rand" - "sync" - "time" - - "github.com/rulego/streamsql" -) - -func main() { - fmt.Println("🔧 StreamSQL 复杂嵌套字段访问功能演示") - fmt.Println("=======================================") - - // 创建 StreamSQL 实例 - ssql := streamsql.New() - defer ssql.Stop() - - // 演示1: 数组索引访问 - fmt.Println("\n📊 演示1: 数组索引访问") - demonstrateArrayAccess(ssql) - - // 演示2: Map键访问 - fmt.Println("\n🗝️ 演示2: Map键访问") - demonstrateMapKeyAccess(ssql) - - // 演示3: 混合复杂访问 - fmt.Println("\n🔄 演示3: 混合复杂访问") - demonstrateComplexMixedAccess(ssql) - - // 演示4: 负数索引访问 - fmt.Println("\n⬅️ 演示4: 负数索引访问") - demonstrateNegativeIndexAccess(ssql) - - // 演示5: 数组索引聚合计算 - fmt.Println("\n📈 演示5: 数组索引聚合计算") - demonstrateArrayIndexAggregation(ssql) - - fmt.Println("\n✅ 演示完成!") -} - -// 演示数组索引访问 -func demonstrateArrayAccess(ssql *streamsql.Streamsql) { - // SQL查询:提取数组中的特定元素 - rsql := `SELECT device, - sensors[0].temperature as first_sensor_temp, - sensors[1].humidity as second_sensor_humidity, - data[2] as third_data_item - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device": "工业传感器-001", - "sensors": []interface{}{ - map[string]interface{}{"temperature": 25.5, "humidity": 60.2}, - map[string]interface{}{"temperature": 26.8, "humidity": 58.7}, - map[string]interface{}{"temperature": 24.1, "humidity": 62.1}, - }, - "data": []interface{}{"status_ok", "battery_95%", "signal_strong", "location_A1"}, - "timestamp": time.Now().Unix(), - }, - { - "device": "环境监测器-002", - "sensors": []interface{}{ - map[string]interface{}{"temperature": 22.3, "humidity": 65.8}, - map[string]interface{}{"temperature": 23.1, "humidity": 63.2}, 
- }, - "data": []interface{}{"status_warning", "battery_78%", "signal_weak"}, - "timestamp": time.Now().Unix(), - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 📋 数组索引访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备: %v\n", item["device"]) - fmt.Printf(" 第一个传感器温度: %v°C\n", item["first_sensor_temp"]) - fmt.Printf(" 第二个传感器湿度: %v%%\n", item["second_sensor_humidity"]) - fmt.Printf(" 第三个数据项: %v\n", item["third_data_item"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示Map键访问 -func demonstrateMapKeyAccess(ssql *streamsql.Streamsql) { - // SQL查询:使用字符串键访问Map数据 - rsql := `SELECT device_id, - config['host'] as server_host, - config["port"] as server_port, - settings['enable_ssl'] as ssl_enabled, - metadata["version"] as app_version - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device_id": "gateway-001", - "config": map[string]interface{}{ - "host": "192.168.1.100", - "port": 8080, - "protocol": "https", - }, - "settings": map[string]interface{}{ - "enable_ssl": true, - "timeout": 30, - "max_retries": 3, - }, - "metadata": map[string]interface{}{ - "version": "v2.1.3", - "build_date": "2023-12-01", - "vendor": "TechCorp", - }, - }, - { - "device_id": "gateway-002", - "config": map[string]interface{}{ - "host": "192.168.1.101", - "port": 8443, - "protocol": "https", - }, - "settings": map[string]interface{}{ - "enable_ssl": false, - "timeout": 60, - "max_retries": 5, - }, - "metadata": map[string]interface{}{ - "version": "v2.0.8", - "build_date": "2023-11-15", - "vendor": "TechCorp", - }, - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 
- ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 🗝️ Map键访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备ID: %v\n", item["device_id"]) - fmt.Printf(" 服务器主机: %v\n", item["server_host"]) - fmt.Printf(" 服务器端口: %v\n", item["server_port"]) - fmt.Printf(" SSL启用: %v\n", item["ssl_enabled"]) - fmt.Printf(" 应用版本: %v\n", item["app_version"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示混合复杂访问 -func demonstrateComplexMixedAccess(ssql *streamsql.Streamsql) { - // SQL查询:混合使用数组索引、Map键和嵌套字段访问 - rsql := `SELECT building, - floors[0].rooms[2]['name'] as first_floor_room3_name, - floors[1].sensors[0].readings['temperature'] as second_floor_first_sensor_temp, - metadata.building_info['architect'] as building_architect, - alerts[-1].message as latest_alert - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备复杂嵌套数据 - testData := map[string]interface{}{ - "building": "智能大厦A座", - "floors": []interface{}{ - // 第一层 - map[string]interface{}{ - "floor_number": 1, - "rooms": []interface{}{ - map[string]interface{}{"name": "大厅", "type": "public"}, - map[string]interface{}{"name": "接待室", "type": "office"}, - map[string]interface{}{"name": "会议室A", "type": "meeting"}, - map[string]interface{}{"name": "休息区", "type": "lounge"}, - }, - }, - // 第二层 - map[string]interface{}{ - "floor_number": 2, - "sensors": []interface{}{ - map[string]interface{}{ - "id": "sensor-201", - "readings": map[string]interface{}{ - "temperature": 23.5, - "humidity": 58.2, - "co2": 420, - }, - }, - map[string]interface{}{ - "id": "sensor-202", - "readings": map[string]interface{}{ - "temperature": 24.1, - "humidity": 60.8, - "co2": 380, - }, - }, - }, - }, - }, - "metadata": map[string]interface{}{ - 
"building_info": map[string]interface{}{ - "architect": "张建筑师", - "year_built": 2020, - "total_floors": 25, - }, - "owner": "科技园管委会", - }, - "alerts": []interface{}{ - map[string]interface{}{"level": "info", "message": "系统启动完成"}, - map[string]interface{}{"level": "warning", "message": "传感器信号弱"}, - map[string]interface{}{"level": "info", "message": "定期维护提醒"}, - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 🔄 混合复杂访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 建筑: %v\n", item["building"]) - fmt.Printf(" 一层第3个房间: %v\n", item["first_floor_room3_name"]) - fmt.Printf(" 二层第1个传感器温度: %v°C\n", item["second_floor_first_sensor_temp"]) - fmt.Printf(" 建筑师: %v\n", item["building_architect"]) - fmt.Printf(" 最新警报: %v\n", item["latest_alert"]) - fmt.Println() - } - } - }) - - // 添加数据 - ssql.Stream().AddData(testData) - - // 等待结果 - wg.Wait() -} - -// 演示负数索引访问 -func demonstrateNegativeIndexAccess(ssql *streamsql.Streamsql) { - // SQL查询:使用负数索引访问数组末尾元素 - rsql := `SELECT device_name, - readings[-1] as latest_reading, - history[-2] as second_last_event, - tags[-1] as last_tag - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device_name": "温度监测器-Alpha", - "readings": []interface{}{18.5, 19.2, 20.1, 21.3, 22.8, 23.5}, // [-1] = 23.5 - "history": []interface{}{"boot", "calibration", "running", "alert", "resolved"}, // [-2] = "alert" - "tags": []interface{}{"indoor", "critical", "monitored"}, // [-1] = "monitored" - }, - { - "device_name": "湿度传感器-Beta", - "readings": []interface{}{45.2, 47.8, 52.1, 48.9}, // [-1] = 48.9 - "history": []interface{}{"init", "testing", "deployed"}, // [-2] = "testing" - "tags": []interface{}{"outdoor", "backup"}, // [-1] = "backup" - }, 
- } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" ⬅️ 负数索引访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备名称: %v\n", item["device_name"]) - fmt.Printf(" 最新读数: %v\n", item["latest_reading"]) - fmt.Printf(" 倒数第二个事件: %v\n", item["second_last_event"]) - fmt.Printf(" 最后一个标签: %v\n", item["last_tag"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示数组索引聚合计算 -func demonstrateArrayIndexAggregation(ssql *streamsql.Streamsql) { - // SQL查询:对数组中特定位置的数据进行聚合计算 - rsql := `SELECT location, - AVG(sensors[0].temperature) as avg_first_sensor_temp, - MAX(sensors[1].humidity) as max_second_sensor_humidity, - COUNT(*) as device_count - FROM stream - GROUP BY location, TumblingWindow('2s') - WITH (TIMESTAMP='timestamp', TIMEUNIT='ss')` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - var resultCount int - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 📈 数组索引聚合计算结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - resultCount++ - fmt.Printf(" 聚合结果 %d:\n", i+1) - fmt.Printf(" 位置: %v\n", item["location"]) - fmt.Printf(" 第一个传感器平均温度: %.2f°C\n", item["avg_first_sensor_temp"]) - fmt.Printf(" 第二个传感器最大湿度: %.1f%%\n", item["max_second_sensor_humidity"]) - fmt.Printf(" 设备数量: %v\n", item["device_count"]) - fmt.Println() - } - } - }) - - // 生成模拟数据 - locations := []string{"车间A", "车间B", "车间C"} - - go func() { - for i := 0; i < 12; i++ { - location := locations[rand.Intn(len(locations))] - - data := map[string]interface{}{ - "device_id": fmt.Sprintf("device-%03d", i+1), - "location": location, - "sensors": 
[]interface{}{ - map[string]interface{}{ - "temperature": 20.0 + rand.Float64()*10.0, // 20-30°C - "humidity": 50.0 + rand.Float64()*20.0, // 50-70% - }, - map[string]interface{}{ - "temperature": 18.0 + rand.Float64()*12.0, // 18-30°C - "humidity": 45.0 + rand.Float64()*25.0, // 45-70% - }, - }, - "timestamp": time.Now().Unix(), - } - - ssql.Stream().AddData(data) - time.Sleep(200 * time.Millisecond) // 每200ms发送一条数据 - } - }() - - // 等待聚合结果 - ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) - defer cancel() - - select { - case <-ctx.Done(): - fmt.Println(" ⏰ 聚合计算超时") - case <-func() chan struct{} { - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - return done - }(): - fmt.Printf(" ✅ 聚合计算完成,共生成 %d 个窗口结果\n", resultCount) - } -} +package main + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/rulego/streamsql" +) + +func main() { + fmt.Println("🔧 StreamSQL 复杂嵌套字段访问功能演示") + fmt.Println("=======================================") + + // 创建 StreamSQL 实例 + ssql := streamsql.New() + defer ssql.Stop() + + // 演示1: 数组索引访问 + fmt.Println("\n📊 演示1: 数组索引访问") + demonstrateArrayAccess(ssql) + + // 演示2: Map键访问 + fmt.Println("\n🗝️ 演示2: Map键访问") + demonstrateMapKeyAccess(ssql) + + // 演示3: 混合复杂访问 + fmt.Println("\n🔄 演示3: 混合复杂访问") + demonstrateComplexMixedAccess(ssql) + + // 演示4: 负数索引访问 + fmt.Println("\n⬅️ 演示4: 负数索引访问") + demonstrateNegativeIndexAccess(ssql) + + // 演示5: 数组索引聚合计算 + fmt.Println("\n📈 演示5: 数组索引聚合计算") + demonstrateArrayIndexAggregation(ssql) + + fmt.Println("\n✅ 演示完成!") +} + +// 演示数组索引访问 +func demonstrateArrayAccess(ssql *streamsql.Streamsql) { + // SQL查询:提取数组中的特定元素 + rsql := `SELECT device, + sensors[0].temperature as first_sensor_temp, + sensors[1].humidity as second_sensor_humidity, + data[2] as third_data_item + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + "device": 
"工业传感器-001", + "sensors": []interface{}{ + map[string]interface{}{"temperature": 25.5, "humidity": 60.2}, + map[string]interface{}{"temperature": 26.8, "humidity": 58.7}, + map[string]interface{}{"temperature": 24.1, "humidity": 62.1}, + }, + "data": []interface{}{"status_ok", "battery_95%", "signal_strong", "location_A1"}, + "timestamp": time.Now().Unix(), + }, + { + "device": "环境监测器-002", + "sensors": []interface{}{ + map[string]interface{}{"temperature": 22.3, "humidity": 65.8}, + map[string]interface{}{"temperature": 23.1, "humidity": 63.2}, + }, + "data": []interface{}{"status_warning", "battery_78%", "signal_weak"}, + "timestamp": time.Now().Unix(), + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 📋 数组索引访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备: %v\n", item["device"]) + fmt.Printf(" 第一个传感器温度: %v°C\n", item["first_sensor_temp"]) + fmt.Printf(" 第二个传感器湿度: %v%%\n", item["second_sensor_humidity"]) + fmt.Printf(" 第三个数据项: %v\n", item["third_data_item"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示Map键访问 +func demonstrateMapKeyAccess(ssql *streamsql.Streamsql) { + // SQL查询:使用字符串键访问Map数据 + rsql := `SELECT device_id, + config['host'] as server_host, + config["port"] as server_port, + settings['enable_ssl'] as ssl_enabled, + metadata["version"] as app_version + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + "device_id": "gateway-001", + "config": map[string]interface{}{ + "host": "192.168.1.100", + "port": 8080, + "protocol": "https", + }, + "settings": map[string]interface{}{ + "enable_ssl": true, + "timeout": 30, + "max_retries": 3, + }, + "metadata": 
map[string]interface{}{ + "version": "v2.1.3", + "build_date": "2023-12-01", + "vendor": "TechCorp", + }, + }, + { + "device_id": "gateway-002", + "config": map[string]interface{}{ + "host": "192.168.1.101", + "port": 8443, + "protocol": "https", + }, + "settings": map[string]interface{}{ + "enable_ssl": false, + "timeout": 60, + "max_retries": 5, + }, + "metadata": map[string]interface{}{ + "version": "v2.0.8", + "build_date": "2023-11-15", + "vendor": "TechCorp", + }, + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 🗝️ Map键访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备ID: %v\n", item["device_id"]) + fmt.Printf(" 服务器主机: %v\n", item["server_host"]) + fmt.Printf(" 服务器端口: %v\n", item["server_port"]) + fmt.Printf(" SSL启用: %v\n", item["ssl_enabled"]) + fmt.Printf(" 应用版本: %v\n", item["app_version"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示混合复杂访问 +func demonstrateComplexMixedAccess(ssql *streamsql.Streamsql) { + // SQL查询:混合使用数组索引、Map键和嵌套字段访问 + rsql := `SELECT building, + floors[0].rooms[2]['name'] as first_floor_room3_name, + floors[1].sensors[0].readings['temperature'] as second_floor_first_sensor_temp, + metadata.building_info['architect'] as building_architect, + alerts[-1].message as latest_alert + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备复杂嵌套数据 + testData := map[string]interface{}{ + "building": "智能大厦A座", + "floors": []interface{}{ + // 第一层 + map[string]interface{}{ + "floor_number": 1, + "rooms": []interface{}{ + map[string]interface{}{"name": "大厅", "type": "public"}, + map[string]interface{}{"name": "接待室", "type": "office"}, + map[string]interface{}{"name": "会议室A", "type": "meeting"}, + 
map[string]interface{}{"name": "休息区", "type": "lounge"}, + }, + }, + // 第二层 + map[string]interface{}{ + "floor_number": 2, + "sensors": []interface{}{ + map[string]interface{}{ + "id": "sensor-201", + "readings": map[string]interface{}{ + "temperature": 23.5, + "humidity": 58.2, + "co2": 420, + }, + }, + map[string]interface{}{ + "id": "sensor-202", + "readings": map[string]interface{}{ + "temperature": 24.1, + "humidity": 60.8, + "co2": 380, + }, + }, + }, + }, + }, + "metadata": map[string]interface{}{ + "building_info": map[string]interface{}{ + "architect": "张建筑师", + "year_built": 2020, + "total_floors": 25, + }, + "owner": "科技园管委会", + }, + "alerts": []interface{}{ + map[string]interface{}{"level": "info", "message": "系统启动完成"}, + map[string]interface{}{"level": "warning", "message": "传感器信号弱"}, + map[string]interface{}{"level": "info", "message": "定期维护提醒"}, + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 🔄 混合复杂访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 建筑: %v\n", item["building"]) + fmt.Printf(" 一层第3个房间: %v\n", item["first_floor_room3_name"]) + fmt.Printf(" 二层第1个传感器温度: %v°C\n", item["second_floor_first_sensor_temp"]) + fmt.Printf(" 建筑师: %v\n", item["building_architect"]) + fmt.Printf(" 最新警报: %v\n", item["latest_alert"]) + fmt.Println() + } + } + }) + + // 添加数据 + ssql.Emit(testData) + + // 等待结果 + wg.Wait() +} + +// 演示负数索引访问 +func demonstrateNegativeIndexAccess(ssql *streamsql.Streamsql) { + // SQL查询:使用负数索引访问数组末尾元素 + rsql := `SELECT device_name, + readings[-1] as latest_reading, + history[-2] as second_last_event, + tags[-1] as last_tag + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + "device_name": "温度监测器-Alpha", + "readings": 
[]interface{}{18.5, 19.2, 20.1, 21.3, 22.8, 23.5}, // [-1] = 23.5 + "history": []interface{}{"boot", "calibration", "running", "alert", "resolved"}, // [-2] = "alert" + "tags": []interface{}{"indoor", "critical", "monitored"}, // [-1] = "monitored" + }, + { + "device_name": "湿度传感器-Beta", + "readings": []interface{}{45.2, 47.8, 52.1, 48.9}, // [-1] = 48.9 + "history": []interface{}{"init", "testing", "deployed"}, // [-2] = "testing" + "tags": []interface{}{"outdoor", "backup"}, // [-1] = "backup" + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" ⬅️ 负数索引访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备名称: %v\n", item["device_name"]) + fmt.Printf(" 最新读数: %v\n", item["latest_reading"]) + fmt.Printf(" 倒数第二个事件: %v\n", item["second_last_event"]) + fmt.Printf(" 最后一个标签: %v\n", item["last_tag"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示数组索引聚合计算 +func demonstrateArrayIndexAggregation(ssql *streamsql.Streamsql) { + // SQL查询:对数组中特定位置的数据进行聚合计算 + rsql := `SELECT location, + AVG(sensors[0].temperature) as avg_first_sensor_temp, + MAX(sensors[1].humidity) as max_second_sensor_humidity, + COUNT(*) as device_count + FROM stream + GROUP BY location, TumblingWindow('2s') + WITH (TIMESTAMP='timestamp', TIMEUNIT='ss')` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + var resultCount int + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 📈 数组索引聚合计算结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + resultCount++ + fmt.Printf(" 聚合结果 %d:\n", i+1) + fmt.Printf(" 位置: %v\n", item["location"]) + fmt.Printf(" 第一个传感器平均温度: 
%.2f°C\n", item["avg_first_sensor_temp"]) + fmt.Printf(" 第二个传感器最大湿度: %.1f%%\n", item["max_second_sensor_humidity"]) + fmt.Printf(" 设备数量: %v\n", item["device_count"]) + fmt.Println() + } + } + }) + + // 生成模拟数据 + locations := []string{"车间A", "车间B", "车间C"} + + go func() { + for i := 0; i < 12; i++ { + location := locations[rand.Intn(len(locations))] + + data := map[string]interface{}{ + "device_id": fmt.Sprintf("device-%03d", i+1), + "location": location, + "sensors": []interface{}{ + map[string]interface{}{ + "temperature": 20.0 + rand.Float64()*10.0, // 20-30°C + "humidity": 50.0 + rand.Float64()*20.0, // 50-70% + }, + map[string]interface{}{ + "temperature": 18.0 + rand.Float64()*12.0, // 18-30°C + "humidity": 45.0 + rand.Float64()*25.0, // 45-70% + }, + }, + "timestamp": time.Now().Unix(), + } + + ssql.Emit(data) + time.Sleep(200 * time.Millisecond) // 每200ms发送一条数据 + } + }() + + // 等待聚合结果 + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + fmt.Println(" ⏰ 聚合计算超时") + case <-func() chan struct{} { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + return done + }(): + fmt.Printf(" ✅ 聚合计算完成,共生成 %d 个窗口结果\n", resultCount) + } +} diff --git a/examples/comprehensive-test/README.md b/examples/comprehensive-test/README.md new file mode 100644 index 0000000..08c2306 --- /dev/null +++ b/examples/comprehensive-test/README.md @@ -0,0 +1,92 @@ +# StreamSQL 综合测试演示 + +这个示例提供了一个统一的入口来测试和验证StreamSQL的各种功能特性。 + +## 功能覆盖 + +### 1. 基础数据过滤 +- 简单的WHERE条件过滤 +- 实时数据流处理 +- 结果回调处理 + +### 2. 聚合分析 +- 滚动窗口聚合(TumblingWindow) +- 多种聚合函数:AVG、COUNT、MAX、MIN +- 按字段分组 + +### 3. 滑动窗口 +- 滑动窗口分析(SlidingWindow) +- 窗口大小和滑动间隔配置 +- 连续数据流处理 + +### 4. 嵌套字段访问 +- 多层嵌套对象访问 +- 复杂数据结构处理 +- 嵌套字段条件过滤 + +### 5. 自定义函数 +- 数学函数(square、circle_area) +- 转换函数(f_to_c) +- 函数注册和使用 + +### 6. 
复杂查询 +- 多种功能组合使用 +- 嵌套字段 + 自定义函数 + 聚合 +- 复杂业务场景模拟 + +## 运行方式 + +```bash +cd examples\comprehensive-test +go run main.go +``` + +## 预期输出 + +程序会依次执行6个测试场景,每个场景都会输出相应的结果: + +1. **基础过滤测试**:显示温度大于25度的设备告警 +2. **聚合分析测试**:显示每个设备的温度统计信息 +3. **滑动窗口测试**:显示滑动窗口内的温度分析 +4. **嵌套字段测试**:显示复杂数据结构的字段提取 +5. **自定义函数测试**:显示自定义函数的计算结果 +6. **复杂查询测试**:显示综合功能的查询结果 + +## 测试数据 + +- **传感器数据**:包含设备ID、温度、湿度等信息 +- **嵌套结构**:设备信息、位置信息、传感器数据的多层嵌套 +- **随机数据**:使用随机数生成模拟真实的传感器数据流 + +## 自定义函数说明 + +### square(x) +- **功能**:计算数值的平方 +- **参数**:数值 +- **返回**:平方值 + +### f_to_c(fahrenheit) +- **功能**:华氏度转摄氏度 +- **参数**:华氏度温度值 +- **返回**:摄氏度温度值 +- **公式**:(F - 32) × 5/9 + +### circle_area(radius) +- **功能**:计算圆的面积 +- **参数**:半径 +- **返回**:圆的面积 +- **公式**:π × r² + +## 注意事项 + +1. **窗口触发**:聚合查询需要等待窗口时间到达或手动触发 +2. **数据格式**:确保输入数据格式正确,特别是嵌套字段的结构 +3. **函数注册**:自定义函数需要在使用前注册 +4. **资源清理**:使用defer确保StreamSQL实例正确关闭 + +## 扩展建议 + +- 可以添加更多的自定义函数 +- 可以测试更复杂的窗口配置 +- 可以添加错误处理和异常数据测试 +- 可以集成性能测试和压力测试 \ No newline at end of file diff --git a/examples/comprehensive-test/main.go b/examples/comprehensive-test/main.go new file mode 100644 index 0000000..1949404 --- /dev/null +++ b/examples/comprehensive-test/main.go @@ -0,0 +1,428 @@ +package main + +import ( + "fmt" + "math" + "math/rand" + "time" + + "github.com/rulego/streamsql" + "github.com/rulego/streamsql/functions" + "github.com/rulego/streamsql/utils/cast" +) + +func main() { + fmt.Println("🚀 StreamSQL 综合测试演示") + fmt.Println("=============================") + + // 注册自定义函数 + registerCustomFunctions() + + // 运行各种测试场景 + runAllTests() + + fmt.Println("\n✅ 所有测试完成!") +} + +// 注册自定义函数 +func registerCustomFunctions() { + fmt.Println("\n📋 注册自定义函数...") + + // 数学函数:平方 + err := functions.RegisterCustomFunction( + "square", + functions.TypeMath, + "数学函数", + "计算平方", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + val := cast.ToFloat64(args[0]) + return val * val, nil + }, + ) + if err != nil { + fmt.Printf("❌ 注册square函数失败: %v\n", err) + } else { + 
fmt.Println(" ✓ 注册数学函数: square") + } + + // 华氏度转摄氏度函数 + err = functions.RegisterCustomFunction( + "f_to_c", + functions.TypeConversion, + "温度转换", + "华氏度转摄氏度", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + fahrenheit := cast.ToFloat64(args[0]) + celsius := (fahrenheit - 32) * 5 / 9 + return celsius, nil + }, + ) + if err != nil { + fmt.Printf("❌ 注册f_to_c函数失败: %v\n", err) + } else { + fmt.Println(" ✓ 注册转换函数: f_to_c") + } + + // 圆面积计算函数 + err = functions.RegisterCustomFunction( + "circle_area", + functions.TypeMath, + "几何计算", + "计算圆的面积", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + radius := cast.ToFloat64(args[0]) + if radius < 0 { + return nil, fmt.Errorf("半径必须为正数") + } + area := math.Pi * radius * radius + return area, nil + }, + ) + if err != nil { + fmt.Printf("❌ 注册circle_area函数失败: %v\n", err) + } else { + fmt.Println(" ✓ 注册几何函数: circle_area") + } +} + +// 运行所有测试 +func runAllTests() { + // 测试1:基础数据过滤 + testBasicFiltering() + + // 测试2:聚合分析 + testAggregation() + + // 测试3:滑动窗口 + testSlidingWindow() + + // 测试4:嵌套字段访问 + testNestedFields() + + // 测试5:自定义函数 + testCustomFunctions() + + // 测试6:复杂查询 + testComplexQuery() +} + +// 测试1:基础数据过滤 +func testBasicFiltering() { + fmt.Println("\n🔍 测试1:基础数据过滤") + fmt.Println("========================") + + ssql := streamsql.New() + defer ssql.Stop() + + // 过滤温度大于25度的数据 + sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 25" + + err := ssql.Execute(sql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 添加结果处理函数 + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 📊 高温告警: %v\n", result) + }) + + // 发送测试数据 + testData := []map[string]interface{}{ + {"deviceId": "sensor001", "temperature": 23.5}, // 不会触发告警 + {"deviceId": "sensor002", "temperature": 28.3}, // 会触发告警 + {"deviceId": "sensor003", "temperature": 31.2}, // 会触发告警 + {"deviceId": "sensor004", "temperature": 22.1}, // 不会触发告警 + } + + for _, data 
:= range testData { + ssql.Emit(data) + time.Sleep(100 * time.Millisecond) + } + + time.Sleep(500 * time.Millisecond) + fmt.Println(" ✅ 基础过滤测试完成") +} + +// 测试2:聚合分析 +func testAggregation() { + fmt.Println("\n📈 测试2:聚合分析") + fmt.Println("==================") + + ssql := streamsql.New() + defer ssql.Stop() + + // 每2秒计算一次各设备的平均温度 + sql := `SELECT deviceId, + AVG(temperature) as avg_temp, + COUNT(*) as sample_count, + MAX(temperature) as max_temp, + MIN(temperature) as min_temp + FROM stream + GROUP BY deviceId, TumblingWindow('2s')` + + err := ssql.Execute(sql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 处理聚合结果 + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 📊 聚合结果: %v\n", result) + }) + + // 模拟传感器数据流 + devices := []string{"sensor001", "sensor002", "sensor003"} + for i := 0; i < 8; i++ { + for _, device := range devices { + data := map[string]interface{}{ + "deviceId": device, + "temperature": 20.0 + rand.Float64()*15, // 20-35度随机温度 + "timestamp": time.Now(), + } + ssql.Emit(data) + } + time.Sleep(300 * time.Millisecond) + } + + // 等待窗口触发 + time.Sleep(2 * time.Second) + ssql.Stream().Window.Trigger() + time.Sleep(500 * time.Millisecond) + fmt.Println(" ✅ 聚合分析测试完成") +} + +// 测试3:滑动窗口 +func testSlidingWindow() { + fmt.Println("\n🔄 测试3:滑动窗口") + fmt.Println("==================") + + ssql := streamsql.New() + defer ssql.Stop() + + // 6秒滑动窗口,每2秒滑动一次 + sql := `SELECT deviceId, + AVG(temperature) as avg_temp, + MAX(temperature) as max_temp, + MIN(temperature) as min_temp, + COUNT(*) as count + FROM stream + WHERE temperature > 0 + GROUP BY deviceId, SlidingWindow('6s', '2s')` + + err := ssql.Execute(sql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 📊 滑动窗口分析: %v\n", result) + }) + + // 持续发送数据 + for i := 0; i < 10; i++ { + data := map[string]interface{}{ + "deviceId": "sensor001", + "temperature": 20.0 + rand.Float64()*10, + "timestamp": time.Now(), + } + 
ssql.Emit(data) + time.Sleep(800 * time.Millisecond) + } + + time.Sleep(1 * time.Second) + fmt.Println(" ✅ 滑动窗口测试完成") +} + +// 测试4:嵌套字段访问 +func testNestedFields() { + fmt.Println("\n🔧 测试4:嵌套字段访问") + fmt.Println("=======================") + + ssql := streamsql.New() + defer ssql.Stop() + + // 访问嵌套字段的SQL查询 + sql := `SELECT device.info.name as device_name, + device.location.building as building, + sensor.temperature as temp, + UPPER(device.info.type) as device_type + FROM stream + WHERE sensor.temperature > 25 AND device.info.status = 'active'` + + err := ssql.Execute(sql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 📊 嵌套字段结果: %v\n", result) + }) + + // 发送嵌套结构数据 + complexData := []map[string]interface{}{ + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "name": "温度传感器001", + "type": "temperature", + "status": "active", + }, + "location": map[string]interface{}{ + "building": "A栋", + "floor": "3F", + }, + }, + "sensor": map[string]interface{}{ + "temperature": 28.5, + "humidity": 65.0, + }, + }, + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "name": "湿度传感器002", + "type": "humidity", + "status": "inactive", // 不会匹配 + }, + "location": map[string]interface{}{ + "building": "B栋", + "floor": "2F", + }, + }, + "sensor": map[string]interface{}{ + "temperature": 30.0, + "humidity": 70.0, + }, + }, + } + + for _, data := range complexData { + ssql.Emit(data) + time.Sleep(200 * time.Millisecond) + } + + time.Sleep(500 * time.Millisecond) + fmt.Println(" ✅ 嵌套字段测试完成") +} + +// 测试5:自定义函数 +func testCustomFunctions() { + fmt.Println("\n🎯 测试5:自定义函数") + fmt.Println("====================") + + ssql := streamsql.New() + defer ssql.Stop() + + // 使用自定义函数的SQL查询 + sql := `SELECT + device, + square(value) as squared_value, + f_to_c(temperature) as celsius, + circle_area(radius) as area + FROM stream + WHERE value > 0` + + err := ssql.Execute(sql) + if err != 
nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 📊 自定义函数结果: %v\n", result) + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + { + "device": "sensor1", + "value": 5.0, + "temperature": 68.0, // 华氏度 + "radius": 3.0, + }, + { + "device": "sensor2", + "value": 10.0, + "temperature": 86.0, // 华氏度 + "radius": 2.5, + }, + { + "device": "sensor3", + "value": 0.0, // 不会匹配WHERE条件 + "temperature": 32.0, + "radius": 1.0, + }, + } + + for _, data := range testData { + ssql.Emit(data) + time.Sleep(200 * time.Millisecond) + } + + time.Sleep(500 * time.Millisecond) + fmt.Println(" ✅ 自定义函数测试完成") +} + +// 测试6:复杂查询 +func testComplexQuery() { + fmt.Println("\n🔬 测试6:复杂查询") + fmt.Println("==================") + + ssql := streamsql.New() + defer ssql.Stop() + + // 复杂的聚合查询,结合自定义函数和嵌套字段 + sql := `SELECT + device.location as location, + AVG(square(sensor.temperature)) as avg_temp_squared, + MAX(f_to_c(sensor.temperature)) as max_celsius, + COUNT(*) as sample_count, + SUM(circle_area(device.radius)) as total_area + FROM stream + WHERE sensor.temperature > 20 AND device.status = 'online' + GROUP BY device.location, TumblingWindow('3s')` + + err := ssql.Execute(sql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 📊 复杂查询结果: %v\n", result) + }) + + // 发送复杂测试数据 + locations := []string{"room-A", "room-B", "room-C"} + for i := 0; i < 12; i++ { + location := locations[i%len(locations)] + data := map[string]interface{}{ + "device": map[string]interface{}{ + "location": location, + "status": "online", + "radius": 1.0 + rand.Float64()*2.0, // 1-3的随机半径 + }, + "sensor": map[string]interface{}{ + "temperature": 25.0 + rand.Float64()*10.0, // 25-35度 + "humidity": 50.0 + rand.Float64()*30.0, // 50-80% + }, + "timestamp": time.Now(), + } + ssql.Emit(data) + time.Sleep(300 * time.Millisecond) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) + 
ssql.Stream().Window.Trigger() + time.Sleep(500 * time.Millisecond) + fmt.Println(" ✅ 复杂查询测试完成") +} diff --git a/examples/custom-functions-demo/main.go b/examples/custom-functions-demo/main.go index c55f4b9..8ca9514 100644 --- a/examples/custom-functions-demo/main.go +++ b/examples/custom-functions-demo/main.go @@ -625,12 +625,12 @@ func testMathFunctions(ssql *streamsql.Streamsql) { } // 添加结果监听器 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf(" 📊 数学函数结果: %v\n", result) }) for _, data := range testData { - ssql.AddData(data) + ssql.Emit(data) } time.Sleep(1 * time.Second) @@ -672,12 +672,12 @@ func testStringFunctions(ssql *streamsql.Streamsql) { }, } - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf(" 📊 字符串函数结果: %v\n", result) }) for _, data := range testData { - ssql.AddData(data) + ssql.Emit(data) } time.Sleep(500 * time.Millisecond) @@ -715,12 +715,12 @@ func testConversionFunctions(ssql *streamsql.Streamsql) { }, } - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf(" 📊 转换函数结果: %v\n", result) }) for _, data := range testData { - ssql.AddData(data) + ssql.Emit(data) } time.Sleep(500 * time.Millisecond) @@ -753,12 +753,12 @@ func testAggregateFunctions(ssql *streamsql.Streamsql) { map[string]interface{}{"device": "sensor1", "value": 128.0, "category": "A"}, } - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf(" 📊 聚合函数结果: %v\n", result) }) for _, data := range testData { - ssql.AddData(data) + ssql.Emit(data) } time.Sleep(1 * time.Second) diff --git a/examples/nested-field-examples/main.go b/examples/nested-field-examples/main.go index 34066f8..62beacd 100644 --- a/examples/nested-field-examples/main.go +++ b/examples/nested-field-examples/main.go @@ -1,625 +1,625 @@ -package main - -import ( - "context" - "fmt" - "math/rand" - "sync" - "time" - - 
"github.com/rulego/streamsql" -) - -func main() { - fmt.Println("🔧 StreamSQL 嵌套字段访问功能完整演示") - fmt.Println("=========================================") - - // 创建 StreamSQL 实例 - ssql := streamsql.New() - defer ssql.Stop() - - // 基础功能演示 - fmt.Println("\n📊 第一部分:基础嵌套字段访问") - demonstrateBasicNestedAccess(ssql) - - // 基础聚合演示 - fmt.Println("\n📈 第二部分:嵌套字段聚合") - demonstrateNestedAggregation(ssql) - - // 复杂功能演示 - fmt.Println("\n🔧 第三部分:复杂嵌套字段访问") - - // 演示1: 数组索引访问 - fmt.Println("\n📊 演示1: 数组索引访问") - demonstrateArrayAccess(ssql) - - // 演示2: Map键访问 - fmt.Println("\n🗝️ 演示2: Map键访问") - demonstrateMapKeyAccess(ssql) - - // 演示3: 混合复杂访问 - fmt.Println("\n🔄 演示3: 混合复杂访问") - demonstrateComplexMixedAccess(ssql) - - // 演示4: 负数索引访问 - fmt.Println("\n⬅️ 演示4: 负数索引访问") - demonstrateNegativeIndexAccess(ssql) - - // 演示5: 数组索引聚合计算 - fmt.Println("\n📈 演示5: 数组索引聚合计算") - demonstrateArrayIndexAggregation(ssql) - - fmt.Println("\n✅ 完整演示完成!") -} - -// 演示基础嵌套字段访问 -func demonstrateBasicNestedAccess(ssql *streamsql.Streamsql) { - // SQL查询使用基础嵌套字段 - rsql := `SELECT device.info.name as device_name, - device.location, - sensor.temperature, - sensor.humidity - FROM stream - WHERE device.location = 'room-A' - AND sensor.temperature > 20` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device": map[string]interface{}{ - "info": map[string]interface{}{ - "name": "温度传感器-001", - "type": "temperature", - }, - "location": "room-A", - }, - "sensor": map[string]interface{}{ - "temperature": 25.5, - "humidity": 60.2, - }, - "timestamp": time.Now().Unix(), - }, - { - "device": map[string]interface{}{ - "info": map[string]interface{}{ - "name": "温度传感器-002", - "type": "temperature", - }, - "location": "room-B", // 不匹配条件 - }, - "sensor": map[string]interface{}{ - "temperature": 30.0, - "humidity": 55.8, - }, - "timestamp": time.Now().Unix(), - }, - { - "device": map[string]interface{}{ - "info": 
map[string]interface{}{ - "name": "温度传感器-003", - "type": "temperature", - }, - "location": "room-A", - }, - "sensor": map[string]interface{}{ - "temperature": 15.0, // 不匹配条件 - "humidity": 65.3, - }, - "timestamp": time.Now().Unix(), - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 📋 基础嵌套字段访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备名称: %v\n", item["device_name"]) - fmt.Printf(" 设备位置: %v\n", item["device.location"]) - fmt.Printf(" 温度: %v°C\n", item["sensor.temperature"]) - fmt.Printf(" 湿度: %v%%\n", item["sensor.humidity"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示嵌套字段聚合 -func demonstrateNestedAggregation(ssql *streamsql.Streamsql) { - // SQL查询:嵌套字段聚合 - rsql := `SELECT device.location, - AVG(sensor.temperature) as avg_temp, - MAX(sensor.humidity) as max_humidity, - COUNT(*) as sensor_count - FROM stream - GROUP BY device.location, TumblingWindow('2s') - WITH (TIMESTAMP='timestamp', TIMEUNIT='ss')` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - var resultCount int - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 📈 嵌套字段聚合结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - resultCount++ - fmt.Printf(" 聚合结果 %d:\n", i+1) - fmt.Printf(" 位置: %v\n", item["device.location"]) - fmt.Printf(" 平均温度: %.2f°C\n", item["avg_temp"]) - fmt.Printf(" 最大湿度: %.1f%%\n", item["max_humidity"]) - fmt.Printf(" 传感器数量: %v\n", item["sensor_count"]) - fmt.Println() - } - } - }) - - // 生成模拟数据 - locations := []string{"智能温室-A区", "智能温室-B区", "智能温室-C区"} - - go func() { - for i := 0; i < 9; i++ { - 
location := locations[rand.Intn(len(locations))] - - data := map[string]interface{}{ - "device": map[string]interface{}{ - "info": map[string]interface{}{ - "name": fmt.Sprintf("sensor-%03d", i+1), - "type": "environment", - }, - "location": location, - }, - "sensor": map[string]interface{}{ - "temperature": 18.0 + rand.Float64()*15.0, // 18-33°C - "humidity": 40.0 + rand.Float64()*30.0, // 40-70% - }, - "timestamp": time.Now().Unix(), - } - - ssql.Stream().AddData(data) - time.Sleep(300 * time.Millisecond) // 每300ms发送一条数据 - } - }() - - // 等待聚合结果 - ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) - defer cancel() - - select { - case <-ctx.Done(): - fmt.Println(" ⏰ 聚合计算超时") - case <-func() chan struct{} { - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - return done - }(): - fmt.Printf(" ✅ 聚合计算完成,共生成 %d 个窗口结果\n", resultCount) - } -} - -// 演示数组索引访问 -func demonstrateArrayAccess(ssql *streamsql.Streamsql) { - // SQL查询:提取数组中的特定元素 - rsql := `SELECT device, - sensors[0].temperature as first_sensor_temp, - sensors[1].humidity as second_sensor_humidity, - data[2] as third_data_item - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device": "工业传感器-001", - "sensors": []interface{}{ - map[string]interface{}{"temperature": 25.5, "humidity": 60.2}, - map[string]interface{}{"temperature": 26.8, "humidity": 58.7}, - map[string]interface{}{"temperature": 24.1, "humidity": 62.1}, - }, - "data": []interface{}{"status_ok", "battery_95%", "signal_strong", "location_A1"}, - "timestamp": time.Now().Unix(), - }, - { - "device": "环境监测器-002", - "sensors": []interface{}{ - map[string]interface{}{"temperature": 22.3, "humidity": 65.8}, - map[string]interface{}{"temperature": 23.1, "humidity": 63.2}, - }, - "data": []interface{}{"status_warning", "battery_78%", "signal_weak"}, - "timestamp": time.Now().Unix(), - }, - } - 
- var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 📋 数组索引访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备: %v\n", item["device"]) - fmt.Printf(" 第一个传感器温度: %v°C\n", item["first_sensor_temp"]) - fmt.Printf(" 第二个传感器湿度: %v%%\n", item["second_sensor_humidity"]) - fmt.Printf(" 第三个数据项: %v\n", item["third_data_item"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示Map键访问 -func demonstrateMapKeyAccess(ssql *streamsql.Streamsql) { - // SQL查询:使用字符串键访问Map数据 - rsql := `SELECT device_id, - config['host'] as server_host, - config["port"] as server_port, - settings['enable_ssl'] as ssl_enabled, - metadata["version"] as app_version - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device_id": "gateway-001", - "config": map[string]interface{}{ - "host": "192.168.1.100", - "port": 8080, - "protocol": "https", - }, - "settings": map[string]interface{}{ - "enable_ssl": true, - "timeout": 30, - "max_retries": 3, - }, - "metadata": map[string]interface{}{ - "version": "v2.1.3", - "build_date": "2023-12-01", - "vendor": "TechCorp", - }, - }, - { - "device_id": "gateway-002", - "config": map[string]interface{}{ - "host": "192.168.1.101", - "port": 8443, - "protocol": "https", - }, - "settings": map[string]interface{}{ - "enable_ssl": false, - "timeout": 60, - "max_retries": 5, - }, - "metadata": map[string]interface{}{ - "version": "v2.0.8", - "build_date": "2023-11-15", - "vendor": "TechCorp", - }, - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 🗝️ Map键访问结果:") - if resultSlice, ok 
:= result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备ID: %v\n", item["device_id"]) - fmt.Printf(" 服务器主机: %v\n", item["server_host"]) - fmt.Printf(" 服务器端口: %v\n", item["server_port"]) - fmt.Printf(" SSL启用: %v\n", item["ssl_enabled"]) - fmt.Printf(" 应用版本: %v\n", item["app_version"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示混合复杂访问 -func demonstrateComplexMixedAccess(ssql *streamsql.Streamsql) { - // SQL查询:混合使用数组索引、Map键和嵌套字段访问 - rsql := `SELECT building, - floors[0].rooms[2]['name'] as first_floor_room3_name, - floors[1].sensors[0].readings['temperature'] as second_floor_first_sensor_temp, - metadata.building_info['architect'] as building_architect, - alerts[-1].message as latest_alert - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备复杂嵌套数据 - testData := map[string]interface{}{ - "building": "智能大厦A座", - "floors": []interface{}{ - // 第一层 - map[string]interface{}{ - "floor_number": 1, - "rooms": []interface{}{ - map[string]interface{}{"name": "大厅", "type": "public"}, - map[string]interface{}{"name": "接待室", "type": "office"}, - map[string]interface{}{"name": "会议室A", "type": "meeting"}, - map[string]interface{}{"name": "休息区", "type": "lounge"}, - }, - }, - // 第二层 - map[string]interface{}{ - "floor_number": 2, - "sensors": []interface{}{ - map[string]interface{}{ - "id": "sensor-201", - "readings": map[string]interface{}{ - "temperature": 23.5, - "humidity": 58.2, - "co2": 420, - }, - }, - map[string]interface{}{ - "id": "sensor-202", - "readings": map[string]interface{}{ - "temperature": 24.1, - "humidity": 60.8, - "co2": 380, - }, - }, - }, - }, - }, - "metadata": map[string]interface{}{ - "building_info": map[string]interface{}{ - "architect": "张建筑师", - "year_built": 2020, - "total_floors": 25, - }, - "owner": "科技园管委会", - 
}, - "alerts": []interface{}{ - map[string]interface{}{"level": "info", "message": "系统启动完成"}, - map[string]interface{}{"level": "warning", "message": "传感器信号弱"}, - map[string]interface{}{"level": "info", "message": "定期维护提醒"}, - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 🔄 混合复杂访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 建筑: %v\n", item["building"]) - fmt.Printf(" 一层第3个房间: %v\n", item["first_floor_room3_name"]) - fmt.Printf(" 二层第1个传感器温度: %v°C\n", item["second_floor_first_sensor_temp"]) - fmt.Printf(" 建筑师: %v\n", item["building_architect"]) - fmt.Printf(" 最新警报: %v\n", item["latest_alert"]) - fmt.Println() - } - } - }) - - // 添加数据 - ssql.Stream().AddData(testData) - - // 等待结果 - wg.Wait() -} - -// 演示负数索引访问 -func demonstrateNegativeIndexAccess(ssql *streamsql.Streamsql) { - // SQL查询:使用负数索引访问数组末尾元素 - rsql := `SELECT device_name, - readings[-1] as latest_reading, - history[-2] as second_last_event, - tags[-1] as last_tag - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - // 准备测试数据 - testData := []map[string]interface{}{ - { - "device_name": "温度监测器-Alpha", - "readings": []interface{}{18.5, 19.2, 20.1, 21.3, 22.8, 23.5}, // [-1] = 23.5 - "history": []interface{}{"boot", "calibration", "running", "alert", "resolved"}, // [-2] = "alert" - "tags": []interface{}{"indoor", "critical", "monitored"}, // [-1] = "monitored" - }, - { - "device_name": "湿度传感器-Beta", - "readings": []interface{}{45.2, 47.8, 52.1, 48.9}, // [-1] = 48.9 - "history": []interface{}{"init", "testing", "deployed"}, // [-2] = "testing" - "tags": []interface{}{"outdoor", "backup"}, // [-1] = "backup" - }, - } - - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - 
fmt.Println(" ⬅️ 负数索引访问结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - fmt.Printf(" 记录 %d:\n", i+1) - fmt.Printf(" 设备名称: %v\n", item["device_name"]) - fmt.Printf(" 最新读数: %v\n", item["latest_reading"]) - fmt.Printf(" 倒数第二个事件: %v\n", item["second_last_event"]) - fmt.Printf(" 最后一个标签: %v\n", item["last_tag"]) - fmt.Println() - } - } - }) - - // 添加测试数据 - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待结果 - wg.Wait() -} - -// 演示数组索引聚合计算 -func demonstrateArrayIndexAggregation(ssql *streamsql.Streamsql) { - // SQL查询:对数组中特定位置的数据进行聚合计算 - rsql := `SELECT location, - AVG(sensors[0].temperature) as avg_first_sensor_temp, - MAX(sensors[1].humidity) as max_second_sensor_humidity, - COUNT(*) as device_count - FROM stream - GROUP BY location, TumblingWindow('2s') - WITH (TIMESTAMP='timestamp', TIMEUNIT='ss')` - - err := ssql.Execute(rsql) - if err != nil { - fmt.Printf("❌ SQL执行失败: %v\n", err) - return - } - - var resultCount int - var wg sync.WaitGroup - wg.Add(1) - - // 设置结果回调 - ssql.Stream().AddSink(func(result interface{}) { - defer wg.Done() - - fmt.Println(" 📈 数组索引聚合计算结果:") - if resultSlice, ok := result.([]map[string]interface{}); ok { - for i, item := range resultSlice { - resultCount++ - fmt.Printf(" 聚合结果 %d:\n", i+1) - fmt.Printf(" 位置: %v\n", item["location"]) - fmt.Printf(" 第一个传感器平均温度: %.2f°C\n", item["avg_first_sensor_temp"]) - fmt.Printf(" 第二个传感器最大湿度: %.1f%%\n", item["max_second_sensor_humidity"]) - fmt.Printf(" 设备数量: %v\n", item["device_count"]) - fmt.Println() - } - } - }) - - // 生成模拟数据 - locations := []string{"车间A", "车间B", "车间C"} - - go func() { - for i := 0; i < 12; i++ { - location := locations[rand.Intn(len(locations))] - - data := map[string]interface{}{ - "device_id": fmt.Sprintf("device-%03d", i+1), - "location": location, - "sensors": []interface{}{ - map[string]interface{}{ - "temperature": 20.0 + rand.Float64()*10.0, // 20-30°C - "humidity": 50.0 + rand.Float64()*20.0, // 
50-70% - }, - map[string]interface{}{ - "temperature": 18.0 + rand.Float64()*12.0, // 18-30°C - "humidity": 45.0 + rand.Float64()*25.0, // 45-70% - }, - }, - "timestamp": time.Now().Unix(), - } - - ssql.Stream().AddData(data) - time.Sleep(200 * time.Millisecond) // 每200ms发送一条数据 - } - }() - - // 等待聚合结果 - ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) - defer cancel() - - select { - case <-ctx.Done(): - fmt.Println(" ⏰ 聚合计算超时") - case <-func() chan struct{} { - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - return done - }(): - fmt.Printf(" ✅ 聚合计算完成,共生成 %d 个窗口结果\n", resultCount) - } -} +package main + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/rulego/streamsql" +) + +func main() { + fmt.Println("🔧 StreamSQL 嵌套字段访问功能完整演示") + fmt.Println("=========================================") + + // 创建 StreamSQL 实例 + ssql := streamsql.New() + defer ssql.Stop() + + // 基础功能演示 + fmt.Println("\n📊 第一部分:基础嵌套字段访问") + demonstrateBasicNestedAccess(ssql) + + // 基础聚合演示 + fmt.Println("\n📈 第二部分:嵌套字段聚合") + demonstrateNestedAggregation(ssql) + + // 复杂功能演示 + fmt.Println("\n🔧 第三部分:复杂嵌套字段访问") + + // 演示1: 数组索引访问 + fmt.Println("\n📊 演示1: 数组索引访问") + demonstrateArrayAccess(ssql) + + // 演示2: Map键访问 + fmt.Println("\n🗝️ 演示2: Map键访问") + demonstrateMapKeyAccess(ssql) + + // 演示3: 混合复杂访问 + fmt.Println("\n🔄 演示3: 混合复杂访问") + demonstrateComplexMixedAccess(ssql) + + // 演示4: 负数索引访问 + fmt.Println("\n⬅️ 演示4: 负数索引访问") + demonstrateNegativeIndexAccess(ssql) + + // 演示5: 数组索引聚合计算 + fmt.Println("\n📈 演示5: 数组索引聚合计算") + demonstrateArrayIndexAggregation(ssql) + + fmt.Println("\n✅ 完整演示完成!") +} + +// 演示基础嵌套字段访问 +func demonstrateBasicNestedAccess(ssql *streamsql.Streamsql) { + // SQL查询使用基础嵌套字段 + rsql := `SELECT device.info.name as device_name, + device.location, + sensor.temperature, + sensor.humidity + FROM stream + WHERE device.location = 'room-A' + AND sensor.temperature > 20` + + err := ssql.Execute(rsql) + if err != nil { + 
fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "name": "温度传感器-001", + "type": "temperature", + }, + "location": "room-A", + }, + "sensor": map[string]interface{}{ + "temperature": 25.5, + "humidity": 60.2, + }, + "timestamp": time.Now().Unix(), + }, + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "name": "温度传感器-002", + "type": "temperature", + }, + "location": "room-B", // 不匹配条件 + }, + "sensor": map[string]interface{}{ + "temperature": 30.0, + "humidity": 55.8, + }, + "timestamp": time.Now().Unix(), + }, + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "name": "温度传感器-003", + "type": "temperature", + }, + "location": "room-A", + }, + "sensor": map[string]interface{}{ + "temperature": 15.0, // 不匹配条件 + "humidity": 65.3, + }, + "timestamp": time.Now().Unix(), + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 📋 基础嵌套字段访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备名称: %v\n", item["device_name"]) + fmt.Printf(" 设备位置: %v\n", item["device.location"]) + fmt.Printf(" 温度: %v°C\n", item["sensor.temperature"]) + fmt.Printf(" 湿度: %v%%\n", item["sensor.humidity"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示嵌套字段聚合 +func demonstrateNestedAggregation(ssql *streamsql.Streamsql) { + // SQL查询:嵌套字段聚合 + rsql := `SELECT device.location, + AVG(sensor.temperature) as avg_temp, + MAX(sensor.humidity) as max_humidity, + COUNT(*) as sensor_count + FROM stream + GROUP BY device.location, TumblingWindow('2s') + WITH (TIMESTAMP='timestamp', TIMEUNIT='ss')` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: 
%v\n", err) + return + } + + var resultCount int + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 📈 嵌套字段聚合结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + resultCount++ + fmt.Printf(" 聚合结果 %d:\n", i+1) + fmt.Printf(" 位置: %v\n", item["device.location"]) + fmt.Printf(" 平均温度: %.2f°C\n", item["avg_temp"]) + fmt.Printf(" 最大湿度: %.1f%%\n", item["max_humidity"]) + fmt.Printf(" 传感器数量: %v\n", item["sensor_count"]) + fmt.Println() + } + } + }) + + // 生成模拟数据 + locations := []string{"智能温室-A区", "智能温室-B区", "智能温室-C区"} + + go func() { + for i := 0; i < 9; i++ { + location := locations[rand.Intn(len(locations))] + + data := map[string]interface{}{ + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "name": fmt.Sprintf("sensor-%03d", i+1), + "type": "environment", + }, + "location": location, + }, + "sensor": map[string]interface{}{ + "temperature": 18.0 + rand.Float64()*15.0, // 18-33°C + "humidity": 40.0 + rand.Float64()*30.0, // 40-70% + }, + "timestamp": time.Now().Unix(), + } + + ssql.Emit(data) + time.Sleep(300 * time.Millisecond) // 每300ms发送一条数据 + } + }() + + // 等待聚合结果 + ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + fmt.Println(" ⏰ 聚合计算超时") + case <-func() chan struct{} { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + return done + }(): + fmt.Printf(" ✅ 聚合计算完成,共生成 %d 个窗口结果\n", resultCount) + } +} + +// 演示数组索引访问 +func demonstrateArrayAccess(ssql *streamsql.Streamsql) { + // SQL查询:提取数组中的特定元素 + rsql := `SELECT device, + sensors[0].temperature as first_sensor_temp, + sensors[1].humidity as second_sensor_humidity, + data[2] as third_data_item + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + 
"device": "工业传感器-001", + "sensors": []interface{}{ + map[string]interface{}{"temperature": 25.5, "humidity": 60.2}, + map[string]interface{}{"temperature": 26.8, "humidity": 58.7}, + map[string]interface{}{"temperature": 24.1, "humidity": 62.1}, + }, + "data": []interface{}{"status_ok", "battery_95%", "signal_strong", "location_A1"}, + "timestamp": time.Now().Unix(), + }, + { + "device": "环境监测器-002", + "sensors": []interface{}{ + map[string]interface{}{"temperature": 22.3, "humidity": 65.8}, + map[string]interface{}{"temperature": 23.1, "humidity": 63.2}, + }, + "data": []interface{}{"status_warning", "battery_78%", "signal_weak"}, + "timestamp": time.Now().Unix(), + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 📋 数组索引访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备: %v\n", item["device"]) + fmt.Printf(" 第一个传感器温度: %v°C\n", item["first_sensor_temp"]) + fmt.Printf(" 第二个传感器湿度: %v%%\n", item["second_sensor_humidity"]) + fmt.Printf(" 第三个数据项: %v\n", item["third_data_item"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示Map键访问 +func demonstrateMapKeyAccess(ssql *streamsql.Streamsql) { + // SQL查询:使用字符串键访问Map数据 + rsql := `SELECT device_id, + config['host'] as server_host, + config["port"] as server_port, + settings['enable_ssl'] as ssl_enabled, + metadata["version"] as app_version + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + "device_id": "gateway-001", + "config": map[string]interface{}{ + "host": "192.168.1.100", + "port": 8080, + "protocol": "https", + }, + "settings": map[string]interface{}{ + "enable_ssl": true, + "timeout": 30, + "max_retries": 3, + }, + 
"metadata": map[string]interface{}{ + "version": "v2.1.3", + "build_date": "2023-12-01", + "vendor": "TechCorp", + }, + }, + { + "device_id": "gateway-002", + "config": map[string]interface{}{ + "host": "192.168.1.101", + "port": 8443, + "protocol": "https", + }, + "settings": map[string]interface{}{ + "enable_ssl": false, + "timeout": 60, + "max_retries": 5, + }, + "metadata": map[string]interface{}{ + "version": "v2.0.8", + "build_date": "2023-11-15", + "vendor": "TechCorp", + }, + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 🗝️ Map键访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备ID: %v\n", item["device_id"]) + fmt.Printf(" 服务器主机: %v\n", item["server_host"]) + fmt.Printf(" 服务器端口: %v\n", item["server_port"]) + fmt.Printf(" SSL启用: %v\n", item["ssl_enabled"]) + fmt.Printf(" 应用版本: %v\n", item["app_version"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示混合复杂访问 +func demonstrateComplexMixedAccess(ssql *streamsql.Streamsql) { + // SQL查询:混合使用数组索引、Map键和嵌套字段访问 + rsql := `SELECT building, + floors[0].rooms[2]['name'] as first_floor_room3_name, + floors[1].sensors[0].readings['temperature'] as second_floor_first_sensor_temp, + metadata.building_info['architect'] as building_architect, + alerts[-1].message as latest_alert + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备复杂嵌套数据 + testData := map[string]interface{}{ + "building": "智能大厦A座", + "floors": []interface{}{ + // 第一层 + map[string]interface{}{ + "floor_number": 1, + "rooms": []interface{}{ + map[string]interface{}{"name": "大厅", "type": "public"}, + map[string]interface{}{"name": "接待室", "type": "office"}, + map[string]interface{}{"name": "会议室A", "type": "meeting"}, 
+ map[string]interface{}{"name": "休息区", "type": "lounge"}, + }, + }, + // 第二层 + map[string]interface{}{ + "floor_number": 2, + "sensors": []interface{}{ + map[string]interface{}{ + "id": "sensor-201", + "readings": map[string]interface{}{ + "temperature": 23.5, + "humidity": 58.2, + "co2": 420, + }, + }, + map[string]interface{}{ + "id": "sensor-202", + "readings": map[string]interface{}{ + "temperature": 24.1, + "humidity": 60.8, + "co2": 380, + }, + }, + }, + }, + }, + "metadata": map[string]interface{}{ + "building_info": map[string]interface{}{ + "architect": "张建筑师", + "year_built": 2020, + "total_floors": 25, + }, + "owner": "科技园管委会", + }, + "alerts": []interface{}{ + map[string]interface{}{"level": "info", "message": "系统启动完成"}, + map[string]interface{}{"level": "warning", "message": "传感器信号弱"}, + map[string]interface{}{"level": "info", "message": "定期维护提醒"}, + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 🔄 混合复杂访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 建筑: %v\n", item["building"]) + fmt.Printf(" 一层第3个房间: %v\n", item["first_floor_room3_name"]) + fmt.Printf(" 二层第1个传感器温度: %v°C\n", item["second_floor_first_sensor_temp"]) + fmt.Printf(" 建筑师: %v\n", item["building_architect"]) + fmt.Printf(" 最新警报: %v\n", item["latest_alert"]) + fmt.Println() + } + } + }) + + // 添加数据 + ssql.Emit(testData) + + // 等待结果 + wg.Wait() +} + +// 演示负数索引访问 +func demonstrateNegativeIndexAccess(ssql *streamsql.Streamsql) { + // SQL查询:使用负数索引访问数组末尾元素 + rsql := `SELECT device_name, + readings[-1] as latest_reading, + history[-2] as second_last_event, + tags[-1] as last_tag + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + // 准备测试数据 + testData := []map[string]interface{}{ + { + "device_name": "温度监测器-Alpha", + "readings": 
[]interface{}{18.5, 19.2, 20.1, 21.3, 22.8, 23.5}, // [-1] = 23.5 + "history": []interface{}{"boot", "calibration", "running", "alert", "resolved"}, // [-2] = "alert" + "tags": []interface{}{"indoor", "critical", "monitored"}, // [-1] = "monitored" + }, + { + "device_name": "湿度传感器-Beta", + "readings": []interface{}{45.2, 47.8, 52.1, 48.9}, // [-1] = 48.9 + "history": []interface{}{"init", "testing", "deployed"}, // [-2] = "testing" + "tags": []interface{}{"outdoor", "backup"}, // [-1] = "backup" + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" ⬅️ 负数索引访问结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + fmt.Printf(" 记录 %d:\n", i+1) + fmt.Printf(" 设备名称: %v\n", item["device_name"]) + fmt.Printf(" 最新读数: %v\n", item["latest_reading"]) + fmt.Printf(" 倒数第二个事件: %v\n", item["second_last_event"]) + fmt.Printf(" 最后一个标签: %v\n", item["last_tag"]) + fmt.Println() + } + } + }) + + // 添加测试数据 + for _, data := range testData { + ssql.Emit(data) + } + + // 等待结果 + wg.Wait() +} + +// 演示数组索引聚合计算 +func demonstrateArrayIndexAggregation(ssql *streamsql.Streamsql) { + // SQL查询:对数组中特定位置的数据进行聚合计算 + rsql := `SELECT location, + AVG(sensors[0].temperature) as avg_first_sensor_temp, + MAX(sensors[1].humidity) as max_second_sensor_humidity, + COUNT(*) as device_count + FROM stream + GROUP BY location, TumblingWindow('2s') + WITH (TIMESTAMP='timestamp', TIMEUNIT='ss')` + + err := ssql.Execute(rsql) + if err != nil { + fmt.Printf("❌ SQL执行失败: %v\n", err) + return + } + + var resultCount int + var wg sync.WaitGroup + wg.Add(1) + + // 设置结果回调 + ssql.AddSink(func(result interface{}) { + defer wg.Done() + + fmt.Println(" 📈 数组索引聚合计算结果:") + if resultSlice, ok := result.([]map[string]interface{}); ok { + for i, item := range resultSlice { + resultCount++ + fmt.Printf(" 聚合结果 %d:\n", i+1) + fmt.Printf(" 位置: %v\n", item["location"]) + fmt.Printf(" 第一个传感器平均温度: 
%.2f°C\n", item["avg_first_sensor_temp"]) + fmt.Printf(" 第二个传感器最大湿度: %.1f%%\n", item["max_second_sensor_humidity"]) + fmt.Printf(" 设备数量: %v\n", item["device_count"]) + fmt.Println() + } + } + }) + + // 生成模拟数据 + locations := []string{"车间A", "车间B", "车间C"} + + go func() { + for i := 0; i < 12; i++ { + location := locations[rand.Intn(len(locations))] + + data := map[string]interface{}{ + "device_id": fmt.Sprintf("device-%03d", i+1), + "location": location, + "sensors": []interface{}{ + map[string]interface{}{ + "temperature": 20.0 + rand.Float64()*10.0, // 20-30°C + "humidity": 50.0 + rand.Float64()*20.0, // 50-70% + }, + map[string]interface{}{ + "temperature": 18.0 + rand.Float64()*12.0, // 18-30°C + "humidity": 45.0 + rand.Float64()*25.0, // 45-70% + }, + }, + "timestamp": time.Now().Unix(), + } + + ssql.Emit(data) + time.Sleep(200 * time.Millisecond) // 每200ms发送一条数据 + } + }() + + // 等待聚合结果 + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + fmt.Println(" ⏰ 聚合计算超时") + case <-func() chan struct{} { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + return done + }(): + fmt.Printf(" ✅ 聚合计算完成,共生成 %d 个窗口结果\n", resultCount) + } +} diff --git a/examples/non-aggregation/main.go b/examples/non-aggregation/main.go index bb12259..2162009 100644 --- a/examples/non-aggregation/main.go +++ b/examples/non-aggregation/main.go @@ -1,343 +1,343 @@ -package main - -import ( - "fmt" - "math/rand" - "time" - - "github.com/rulego/streamsql" -) - -// 非聚合场景使用示例 -// 展示StreamSQL在实时数据转换、过滤、清洗等场景中的应用 -func main() { - fmt.Println("=== StreamSQL 非聚合场景演示 ===") - - // 场景1: 实时数据清洗和标准化 - fmt.Println("\n1. 实时数据清洗和标准化") - demonstrateDataCleaning() - - // 场景2: 数据富化和计算字段 - fmt.Println("\n2. 数据富化和计算字段") - demonstrateDataEnrichment() - - // 场景3: 实时告警和事件过滤 - fmt.Println("\n3. 实时告警和事件过滤") - demonstrateRealTimeAlerting() - - // 场景4: 数据格式转换 - fmt.Println("\n4. 
数据格式转换") - demonstrateDataFormatConversion() - - // 场景5: 基于条件的数据路由 - fmt.Println("\n5. 基于条件的数据路由") - demonstrateDataRouting() - - // 场景6: 嵌套字段处理 - fmt.Println("\n6. 嵌套字段处理") - demonstrateNestedFieldProcessing() - - fmt.Println("\n=== 演示完成 ===") -} - -// 场景1: 实时数据清洗和标准化 -func demonstrateDataCleaning() { - ssql := streamsql.New() - defer ssql.Stop() - - // 清洗和标准化SQL - rsql := `SELECT deviceId, - UPPER(TRIM(deviceType)) as device_type, - ROUND(temperature, 2) as temperature, - COALESCE(location, 'unknown') as location, - CASE WHEN status = 1 THEN 'active' - WHEN status = 0 THEN 'inactive' - ELSE 'unknown' END as status_text - FROM stream - WHERE deviceId != '' AND temperature > -999` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - // 结果处理 - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf(" 清洗后数据: %+v\n", result) - }) - - // 模拟脏数据输入 - dirtyData := []map[string]interface{}{ - {"deviceId": "sensor001", "deviceType": " temperature ", "temperature": 25.456789, "location": "room1", "status": 1}, - {"deviceId": "sensor002", "deviceType": "humidity", "temperature": 60.123, "location": nil, "status": 0}, - {"deviceId": "", "deviceType": "pressure", "temperature": nil, "location": "room2", "status": 2}, // 应被过滤 - {"deviceId": "sensor003", "deviceType": "TEMPERATURE", "temperature": 22.7, "location": "room3", "status": 1}, - } - - for _, data := range dirtyData { - ssql.Stream().AddData(data) - time.Sleep(50 * time.Millisecond) - } - - time.Sleep(200 * time.Millisecond) -} - -// 场景2: 数据富化和计算字段 -func demonstrateDataEnrichment() { - ssql := streamsql.New() - defer ssql.Stop() - - // 数据富化SQL - rsql := `SELECT *, - temperature * 1.8 + 32 as temp_fahrenheit, - CASE WHEN temperature > 30 THEN 'hot' - WHEN temperature < 15 THEN 'cold' - ELSE 'normal' END as temp_category, - CONCAT(location, '-', deviceId) as full_identifier, - NOW() as processed_timestamp, - ROUND(humidity / 100.0, 4) as humidity_ratio - FROM stream` - - err := ssql.Execute(rsql) - if 
err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf(" 富化后数据: %+v\n", result) - }) - - // 原始数据 - rawData := []map[string]interface{}{ - {"deviceId": "sensor001", "temperature": 32.5, "humidity": 65, "location": "greenhouse"}, - {"deviceId": "sensor002", "temperature": 12.0, "humidity": 45, "location": "warehouse"}, - {"deviceId": "sensor003", "temperature": 22.8, "humidity": 70, "location": "office"}, - } - - for _, data := range rawData { - ssql.Stream().AddData(data) - time.Sleep(100 * time.Millisecond) - } - - time.Sleep(200 * time.Millisecond) -} - -// 场景3: 实时告警和事件过滤 -func demonstrateRealTimeAlerting() { - ssql := streamsql.New() - defer ssql.Stop() - - // 告警过滤SQL - rsql := `SELECT deviceId, - temperature, - humidity, - location, - 'CRITICAL' as alert_level, - CASE WHEN temperature > 40 THEN 'High Temperature Alert' - WHEN temperature < 5 THEN 'Low Temperature Alert' - WHEN humidity > 90 THEN 'High Humidity Alert' - WHEN humidity < 20 THEN 'Low Humidity Alert' - ELSE 'Unknown Alert' END as alert_message, - NOW() as alert_time - FROM stream - WHERE temperature > 40 OR temperature < 5 OR humidity > 90 OR humidity < 20` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf(" 🚨 告警事件: %+v\n", result) - }) - - // 模拟传感器数据(包含异常值) - sensorData := []map[string]interface{}{ - {"deviceId": "sensor001", "temperature": 25.0, "humidity": 60, "location": "room1"}, // 正常 - {"deviceId": "sensor002", "temperature": 45.0, "humidity": 50, "location": "room2"}, // 高温告警 - {"deviceId": "sensor003", "temperature": 20.0, "humidity": 95, "location": "room3"}, // 高湿度告警 - {"deviceId": "sensor004", "temperature": 2.0, "humidity": 30, "location": "room4"}, // 低温告警 - {"deviceId": "sensor005", "temperature": 22.0, "humidity": 15, "location": "room5"}, // 低湿度告警 - {"deviceId": "sensor006", "temperature": 24.0, "humidity": 55, "location": "room6"}, // 正常 - } - - for _, data := 
range sensorData { - ssql.Stream().AddData(data) - time.Sleep(150 * time.Millisecond) - } - - time.Sleep(200 * time.Millisecond) -} - -// 场景4: 数据格式转换 -func demonstrateDataFormatConversion() { - ssql := streamsql.New() - defer ssql.Stop() - - // 格式转换SQL - rsql := `SELECT deviceId, - CONCAT('{"device_id":"', deviceId, '","metrics":{"temp":', - CAST(temperature AS STRING), ',"hum":', - CAST(humidity AS STRING), '},"location":"', - location, '","timestamp":', - CAST(NOW() AS STRING), '}') as json_format, - CONCAT(deviceId, '|', location, '|', - CAST(temperature AS STRING), '|', - CAST(humidity AS STRING)) as csv_format - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf(" 格式转换结果: %+v\n", result) - }) - - // 输入数据 - inputData := []map[string]interface{}{ - {"deviceId": "sensor001", "temperature": 25.5, "humidity": 60, "location": "warehouse-A"}, - {"deviceId": "sensor002", "temperature": 22.0, "humidity": 55, "location": "warehouse-B"}, - } - - for _, data := range inputData { - ssql.Stream().AddData(data) - time.Sleep(100 * time.Millisecond) - } - - time.Sleep(200 * time.Millisecond) -} - -// 场景5: 基于条件的数据路由 -func demonstrateDataRouting() { - ssql := streamsql.New() - defer ssql.Stop() - - // 数据路由SQL - rsql := `SELECT *, - CASE WHEN deviceType = 'temperature' AND temperature > 30 THEN 'high_temp_topic' - WHEN deviceType = 'humidity' AND humidity > 80 THEN 'high_humidity_topic' - WHEN deviceType = 'pressure' THEN 'pressure_topic' - ELSE 'default_topic' END as routing_topic, - CASE WHEN temperature > 35 OR humidity > 85 THEN 'urgent' - WHEN temperature > 25 OR humidity > 70 THEN 'normal' - ELSE 'low' END as priority - FROM stream` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf(" 路由结果: %+v\n", result) - }) - - // 不同类型的设备数据 - deviceData := []map[string]interface{}{ - {"deviceId": "temp001", 
"deviceType": "temperature", "temperature": 35.0, "humidity": 60}, - {"deviceId": "hum001", "deviceType": "humidity", "temperature": 25.0, "humidity": 85}, - {"deviceId": "press001", "deviceType": "pressure", "temperature": 22.0, "pressure": 1013.25}, - {"deviceId": "temp002", "deviceType": "temperature", "temperature": 20.0, "humidity": 50}, - } - - for _, data := range deviceData { - ssql.Stream().AddData(data) - time.Sleep(100 * time.Millisecond) - } - - time.Sleep(200 * time.Millisecond) -} - -// 场景6: 嵌套字段处理 -func demonstrateNestedFieldProcessing() { - ssql := streamsql.New() - defer ssql.Stop() - - // 嵌套字段处理SQL - rsql := `SELECT device.info.id as device_id, - device.info.name as device_name, - device.location.building as building, - device.location.room as room, - metrics.temperature as temp, - metrics.humidity as humidity, - CONCAT(device.location.building, '-', device.location.room, '-', device.info.id) as full_path, - CASE WHEN metrics.temperature > device.config.max_temp THEN 'OVER_LIMIT' - ELSE 'NORMAL' END as temp_status - FROM stream - WHERE device.info.type = 'sensor'` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - fmt.Printf(" 嵌套字段处理结果: %+v\n", result) - }) - - // 嵌套结构数据 - nestedData := []map[string]interface{}{ - { - "device": map[string]interface{}{ - "info": map[string]interface{}{ - "id": "sensor001", - "name": "Temperature Sensor 1", - "type": "sensor", - }, - "location": map[string]interface{}{ - "building": "Building-A", - "room": "Room-101", - }, - "config": map[string]interface{}{ - "max_temp": 30.0, - "min_temp": 10.0, - }, - }, - "metrics": map[string]interface{}{ - "temperature": 32.5, - "humidity": 65, - }, - }, - { - "device": map[string]interface{}{ - "info": map[string]interface{}{ - "id": "sensor002", - "name": "Humidity Sensor 1", - "type": "sensor", - }, - "location": map[string]interface{}{ - "building": "Building-B", - "room": "Room-201", - }, - "config": 
map[string]interface{}{ - "max_temp": 25.0, - "min_temp": 15.0, - }, - }, - "metrics": map[string]interface{}{ - "temperature": 22.0, - "humidity": 70, - }, - }, - } - - for _, data := range nestedData { - ssql.Stream().AddData(data) - time.Sleep(100 * time.Millisecond) - } - - time.Sleep(200 * time.Millisecond) -} - -// 生成随机测试数据的辅助函数 -func generateRandomSensorData(deviceId string) map[string]interface{} { - return map[string]interface{}{ - "deviceId": deviceId, - "temperature": 15.0 + rand.Float64()*25.0, // 15-40度 - "humidity": 30.0 + rand.Float64()*40.0, // 30-70% - "location": fmt.Sprintf("room%d", rand.Intn(10)+1), - "timestamp": time.Now().Unix(), - } -} +package main + +import ( + "fmt" + "math/rand" + "time" + + "github.com/rulego/streamsql" +) + +// 非聚合场景使用示例 +// 展示StreamSQL在实时数据转换、过滤、清洗等场景中的应用 +func main() { + fmt.Println("=== StreamSQL 非聚合场景演示 ===") + + // 场景1: 实时数据清洗和标准化 + fmt.Println("\n1. 实时数据清洗和标准化") + demonstrateDataCleaning() + + // 场景2: 数据富化和计算字段 + fmt.Println("\n2. 数据富化和计算字段") + demonstrateDataEnrichment() + + // 场景3: 实时告警和事件过滤 + fmt.Println("\n3. 实时告警和事件过滤") + demonstrateRealTimeAlerting() + + // 场景4: 数据格式转换 + fmt.Println("\n4. 数据格式转换") + demonstrateDataFormatConversion() + + // 场景5: 基于条件的数据路由 + fmt.Println("\n5. 基于条件的数据路由") + demonstrateDataRouting() + + // 场景6: 嵌套字段处理 + fmt.Println("\n6. 
嵌套字段处理") + demonstrateNestedFieldProcessing() + + fmt.Println("\n=== 演示完成 ===") +} + +// 场景1: 实时数据清洗和标准化 +func demonstrateDataCleaning() { + ssql := streamsql.New() + defer ssql.Stop() + + // 清洗和标准化SQL + rsql := `SELECT deviceId, + UPPER(TRIM(deviceType)) as device_type, + ROUND(temperature, 2) as temperature, + COALESCE(location, 'unknown') as location, + CASE WHEN status = 1 THEN 'active' + WHEN status = 0 THEN 'inactive' + ELSE 'unknown' END as status_text + FROM stream + WHERE deviceId != '' AND temperature > -999` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + // 结果处理 + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 清洗后数据: %+v\n", result) + }) + + // 模拟脏数据输入 + dirtyData := []map[string]interface{}{ + {"deviceId": "sensor001", "deviceType": " temperature ", "temperature": 25.456789, "location": "room1", "status": 1}, + {"deviceId": "sensor002", "deviceType": "humidity", "temperature": 60.123, "location": nil, "status": 0}, + {"deviceId": "", "deviceType": "pressure", "temperature": nil, "location": "room2", "status": 2}, // 应被过滤 + {"deviceId": "sensor003", "deviceType": "TEMPERATURE", "temperature": 22.7, "location": "room3", "status": 1}, + } + + for _, data := range dirtyData { + ssql.Emit(data) + time.Sleep(50 * time.Millisecond) + } + + time.Sleep(200 * time.Millisecond) +} + +// 场景2: 数据富化和计算字段 +func demonstrateDataEnrichment() { + ssql := streamsql.New() + defer ssql.Stop() + + // 数据富化SQL + rsql := `SELECT *, + temperature * 1.8 + 32 as temp_fahrenheit, + CASE WHEN temperature > 30 THEN 'hot' + WHEN temperature < 15 THEN 'cold' + ELSE 'normal' END as temp_category, + CONCAT(location, '-', deviceId) as full_identifier, + NOW() as processed_timestamp, + ROUND(humidity / 100.0, 4) as humidity_ratio + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 富化后数据: %+v\n", result) + }) + + // 原始数据 + rawData := []map[string]interface{}{ + {"deviceId": 
"sensor001", "temperature": 32.5, "humidity": 65, "location": "greenhouse"}, + {"deviceId": "sensor002", "temperature": 12.0, "humidity": 45, "location": "warehouse"}, + {"deviceId": "sensor003", "temperature": 22.8, "humidity": 70, "location": "office"}, + } + + for _, data := range rawData { + ssql.Emit(data) + time.Sleep(100 * time.Millisecond) + } + + time.Sleep(200 * time.Millisecond) +} + +// 场景3: 实时告警和事件过滤 +func demonstrateRealTimeAlerting() { + ssql := streamsql.New() + defer ssql.Stop() + + // 告警过滤SQL + rsql := `SELECT deviceId, + temperature, + humidity, + location, + 'CRITICAL' as alert_level, + CASE WHEN temperature > 40 THEN 'High Temperature Alert' + WHEN temperature < 5 THEN 'Low Temperature Alert' + WHEN humidity > 90 THEN 'High Humidity Alert' + WHEN humidity < 20 THEN 'Low Humidity Alert' + ELSE 'Unknown Alert' END as alert_message, + NOW() as alert_time + FROM stream + WHERE temperature > 40 OR temperature < 5 OR humidity > 90 OR humidity < 20` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 🚨 告警事件: %+v\n", result) + }) + + // 模拟传感器数据(包含异常值) + sensorData := []map[string]interface{}{ + {"deviceId": "sensor001", "temperature": 25.0, "humidity": 60, "location": "room1"}, // 正常 + {"deviceId": "sensor002", "temperature": 45.0, "humidity": 50, "location": "room2"}, // 高温告警 + {"deviceId": "sensor003", "temperature": 20.0, "humidity": 95, "location": "room3"}, // 高湿度告警 + {"deviceId": "sensor004", "temperature": 2.0, "humidity": 30, "location": "room4"}, // 低温告警 + {"deviceId": "sensor005", "temperature": 22.0, "humidity": 15, "location": "room5"}, // 低湿度告警 + {"deviceId": "sensor006", "temperature": 24.0, "humidity": 55, "location": "room6"}, // 正常 + } + + for _, data := range sensorData { + ssql.Emit(data) + time.Sleep(150 * time.Millisecond) + } + + time.Sleep(200 * time.Millisecond) +} + +// 场景4: 数据格式转换 +func demonstrateDataFormatConversion() { + ssql := streamsql.New() + defer 
ssql.Stop() + + // 格式转换SQL + rsql := `SELECT deviceId, + CONCAT('{"device_id":"', deviceId, '","metrics":{"temp":', + CAST(temperature AS STRING), ',"hum":', + CAST(humidity AS STRING), '},"location":"', + location, '","timestamp":', + CAST(NOW() AS STRING), '}') as json_format, + CONCAT(deviceId, '|', location, '|', + CAST(temperature AS STRING), '|', + CAST(humidity AS STRING)) as csv_format + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 格式转换结果: %+v\n", result) + }) + + // 输入数据 + inputData := []map[string]interface{}{ + {"deviceId": "sensor001", "temperature": 25.5, "humidity": 60, "location": "warehouse-A"}, + {"deviceId": "sensor002", "temperature": 22.0, "humidity": 55, "location": "warehouse-B"}, + } + + for _, data := range inputData { + ssql.Emit(data) + time.Sleep(100 * time.Millisecond) + } + + time.Sleep(200 * time.Millisecond) +} + +// 场景5: 基于条件的数据路由 +func demonstrateDataRouting() { + ssql := streamsql.New() + defer ssql.Stop() + + // 数据路由SQL + rsql := `SELECT *, + CASE WHEN deviceType = 'temperature' AND temperature > 30 THEN 'high_temp_topic' + WHEN deviceType = 'humidity' AND humidity > 80 THEN 'high_humidity_topic' + WHEN deviceType = 'pressure' THEN 'pressure_topic' + ELSE 'default_topic' END as routing_topic, + CASE WHEN temperature > 35 OR humidity > 85 THEN 'urgent' + WHEN temperature > 25 OR humidity > 70 THEN 'normal' + ELSE 'low' END as priority + FROM stream` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 路由结果: %+v\n", result) + }) + + // 不同类型的设备数据 + deviceData := []map[string]interface{}{ + {"deviceId": "temp001", "deviceType": "temperature", "temperature": 35.0, "humidity": 60}, + {"deviceId": "hum001", "deviceType": "humidity", "temperature": 25.0, "humidity": 85}, + {"deviceId": "press001", "deviceType": "pressure", "temperature": 22.0, "pressure": 1013.25}, + 
{"deviceId": "temp002", "deviceType": "temperature", "temperature": 20.0, "humidity": 50}, + } + + for _, data := range deviceData { + ssql.Emit(data) + time.Sleep(100 * time.Millisecond) + } + + time.Sleep(200 * time.Millisecond) +} + +// 场景6: 嵌套字段处理 +func demonstrateNestedFieldProcessing() { + ssql := streamsql.New() + defer ssql.Stop() + + // 嵌套字段处理SQL + rsql := `SELECT device.info.id as device_id, + device.info.name as device_name, + device.location.building as building, + device.location.room as room, + metrics.temperature as temp, + metrics.humidity as humidity, + CONCAT(device.location.building, '-', device.location.room, '-', device.info.id) as full_path, + CASE WHEN metrics.temperature > device.config.max_temp THEN 'OVER_LIMIT' + ELSE 'NORMAL' END as temp_status + FROM stream + WHERE device.info.type = 'sensor'` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + fmt.Printf(" 嵌套字段处理结果: %+v\n", result) + }) + + // 嵌套结构数据 + nestedData := []map[string]interface{}{ + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "id": "sensor001", + "name": "Temperature Sensor 1", + "type": "sensor", + }, + "location": map[string]interface{}{ + "building": "Building-A", + "room": "Room-101", + }, + "config": map[string]interface{}{ + "max_temp": 30.0, + "min_temp": 10.0, + }, + }, + "metrics": map[string]interface{}{ + "temperature": 32.5, + "humidity": 65, + }, + }, + { + "device": map[string]interface{}{ + "info": map[string]interface{}{ + "id": "sensor002", + "name": "Humidity Sensor 1", + "type": "sensor", + }, + "location": map[string]interface{}{ + "building": "Building-B", + "room": "Room-201", + }, + "config": map[string]interface{}{ + "max_temp": 25.0, + "min_temp": 15.0, + }, + }, + "metrics": map[string]interface{}{ + "temperature": 22.0, + "humidity": 70, + }, + }, + } + + for _, data := range nestedData { + ssql.Emit(data) + time.Sleep(100 * time.Millisecond) + } + + 
time.Sleep(200 * time.Millisecond) +} + +// 生成随机测试数据的辅助函数 +func generateRandomSensorData(deviceId string) map[string]interface{} { + return map[string]interface{}{ + "deviceId": deviceId, + "temperature": 15.0 + rand.Float64()*25.0, // 15-40度 + "humidity": 30.0 + rand.Float64()*40.0, // 30-70% + "location": fmt.Sprintf("room%d", rand.Intn(10)+1), + "timestamp": time.Now().Unix(), + } +} diff --git a/examples/null-comparison-examples/main.go b/examples/null-comparison-examples/main.go index 3d86102..cc09376 100644 --- a/examples/null-comparison-examples/main.go +++ b/examples/null-comparison-examples/main.go @@ -1,266 +1,266 @@ -package main - -import ( - "fmt" - "time" - - "github.com/rulego/streamsql" -) - -func main() { - fmt.Println("=== StreamSQL Null 比较语法演示 ===") - fmt.Println() - - demo1() // fieldName = nil 语法 - demo2() // fieldName != nil 语法 - demo3() // fieldName = null 和 != null 语法 - demo4() // 混合语法演示 - demo5() // 嵌套字段 null 比较 -} - -func demo1() { - fmt.Println("1. fieldName = nil 语法演示") - fmt.Println("-------------------------------------------") - - ssql := streamsql.New() - defer ssql.Stop() - - // 使用 = nil 语法查找空值 - rsql := `SELECT deviceId, value, status - FROM stream - WHERE value = nil` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - if results, ok := result.([]map[string]interface{}); ok { - for _, data := range results { - fmt.Printf("发现空值数据: %+v\n", data) - } - } - }) - - testData := []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor2", "value": nil, "status": "active"}, // 符合条件 - {"deviceId": "sensor3", "value": 30.0, "status": "inactive"}, - {"deviceId": "sensor4", "value": nil, "status": "error"}, // 符合条件 - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - time.Sleep(300 * time.Millisecond) - fmt.Println() -} - -func demo2() { - fmt.Println("2. 
fieldName != nil 语法演示") - fmt.Println("-------------------------------------------") - - ssql := streamsql.New() - defer ssql.Stop() - - // 使用 != nil 语法查找非空值 - rsql := `SELECT deviceId, value, status - FROM stream - WHERE value != nil AND value > 20` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - if results, ok := result.([]map[string]interface{}); ok { - for _, data := range results { - fmt.Printf("发现有效数据: %+v\n", data) - } - } - }) - - testData := []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, // 符合条件 - {"deviceId": "sensor2", "value": nil, "status": "active"}, // 不符合(空值) - {"deviceId": "sensor3", "value": 15.0, "status": "inactive"}, // 不符合(值<=20) - {"deviceId": "sensor4", "value": 30.0, "status": "error"}, // 符合条件 - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - time.Sleep(300 * time.Millisecond) - fmt.Println() -} - -func demo3() { - fmt.Println("3. 
fieldName = null 和 != null 语法演示") - fmt.Println("-------------------------------------------") - - ssql := streamsql.New() - defer ssql.Stop() - - // 使用 = null 和 != null 语法 - rsql := `SELECT deviceId, value, status - FROM stream - WHERE status != null OR value = null` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - if results, ok := result.([]map[string]interface{}); ok { - for _, data := range results { - status := data["status"] - value := data["value"] - if status != nil { - fmt.Printf("状态非空的数据: %+v\n", data) - } else if value == nil { - fmt.Printf("值为空的数据: %+v\n", data) - } - } - } - }) - - testData := []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, // 符合(status不为null) - {"deviceId": "sensor2", "value": nil, "status": nil}, // 符合(value为null) - {"deviceId": "sensor3", "value": 30.0, "status": "inactive"}, // 符合(status不为null) - {"deviceId": "sensor4", "value": nil, "status": "error"}, // 符合(两个条件都满足) - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - time.Sleep(300 * time.Millisecond) - fmt.Println() -} - -func demo4() { - fmt.Println("4. 
混合 null 比较语法演示") - fmt.Println("-------------------------------------------") - - ssql := streamsql.New() - defer ssql.Stop() - - // 混合使用 IS NULL、= nil、!= null 等语法 - rsql := `SELECT deviceId, value, status, priority - FROM stream - WHERE (value IS NOT NULL AND value > 20) OR - (status = nil AND priority != null)` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - if results, ok := result.([]map[string]interface{}); ok { - for _, data := range results { - value := data["value"] - status := data["status"] - priority := data["priority"] - - if value != nil && value.(float64) > 20 { - fmt.Printf("高值数据 (value > 20): %+v\n", data) - } else if status == nil && priority != nil { - fmt.Printf("状态异常但有优先级的数据: %+v\n", data) - } - } - } - }) - - testData := []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.0, "status": "active", "priority": "high"}, // 符合第一个条件 - {"deviceId": "sensor2", "value": 15.0, "status": "active", "priority": "low"}, // 不符合 - {"deviceId": "sensor3", "value": nil, "status": nil, "priority": "medium"}, // 符合第二个条件 - {"deviceId": "sensor4", "value": nil, "status": nil, "priority": nil}, // 不符合 - {"deviceId": "sensor5", "value": 30.0, "status": "inactive", "priority": nil}, // 符合第一个条件 - {"deviceId": "sensor6", "value": 10.0, "status": nil, "priority": "urgent"}, // 符合第二个条件 - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - time.Sleep(300 * time.Millisecond) - fmt.Println() -} - -func demo5() { - fmt.Println("5. 
嵌套字段 null 比较演示") - fmt.Println("-------------------------------------------") - - ssql := streamsql.New() - defer ssql.Stop() - - // 嵌套字段的 null 比较 - rsql := `SELECT deviceId, device.location - FROM stream - WHERE device.location != nil` - - err := ssql.Execute(rsql) - if err != nil { - panic(err) - } - - ssql.Stream().AddSink(func(result interface{}) { - if results, ok := result.([]map[string]interface{}); ok { - for _, data := range results { - fmt.Printf("有位置信息的设备: %+v\n", data) - } - } - }) - - testData := []map[string]interface{}{ - { - "deviceId": "sensor1", - "device": map[string]interface{}{ - "location": "warehouse-A", - }, - }, // 符合条件 - { - "deviceId": "sensor2", - "device": map[string]interface{}{ - "location": nil, - }, - }, // 不符合(location为nil) - { - "deviceId": "sensor3", - "device": map[string]interface{}{}, - }, // 不符合(location字段不存在) - { - "deviceId": "sensor4", - "device": map[string]interface{}{ - "location": "office-B", - }, - }, // 符合条件 - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - time.Sleep(300 * time.Millisecond) - fmt.Println() - - fmt.Println("=== Null 比较语法演示完成 ===") - fmt.Println() - fmt.Println("支持的 null 比较语法:") - fmt.Println("- fieldName IS NULL") - fmt.Println("- fieldName IS NOT NULL") - fmt.Println("- fieldName = nil") - fmt.Println("- fieldName != nil") - fmt.Println("- fieldName = null") - fmt.Println("- fieldName != null") - fmt.Println("- device.field = nil (嵌套字段)") - fmt.Println("- device.field != nil (嵌套字段)") -} +package main + +import ( + "fmt" + "time" + + "github.com/rulego/streamsql" +) + +func main() { + fmt.Println("=== StreamSQL Null 比较语法演示 ===") + fmt.Println() + + demo1() // fieldName = nil 语法 + demo2() // fieldName != nil 语法 + demo3() // fieldName = null 和 != null 语法 + demo4() // 混合语法演示 + demo5() // 嵌套字段 null 比较 +} + +func demo1() { + fmt.Println("1. 
fieldName = nil 语法演示") + fmt.Println("-------------------------------------------") + + ssql := streamsql.New() + defer ssql.Stop() + + // 使用 = nil 语法查找空值 + rsql := `SELECT deviceId, value, status + FROM stream + WHERE value = nil` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + if results, ok := result.([]map[string]interface{}); ok { + for _, data := range results { + fmt.Printf("发现空值数据: %+v\n", data) + } + } + }) + + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, + {"deviceId": "sensor2", "value": nil, "status": "active"}, // 符合条件 + {"deviceId": "sensor3", "value": 30.0, "status": "inactive"}, + {"deviceId": "sensor4", "value": nil, "status": "error"}, // 符合条件 + } + + for _, data := range testData { + ssql.Emit(data) + } + + time.Sleep(300 * time.Millisecond) + fmt.Println() +} + +func demo2() { + fmt.Println("2. fieldName != nil 语法演示") + fmt.Println("-------------------------------------------") + + ssql := streamsql.New() + defer ssql.Stop() + + // 使用 != nil 语法查找非空值 + rsql := `SELECT deviceId, value, status + FROM stream + WHERE value != nil AND value > 20` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + if results, ok := result.([]map[string]interface{}); ok { + for _, data := range results { + fmt.Printf("发现有效数据: %+v\n", data) + } + } + }) + + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, // 符合条件 + {"deviceId": "sensor2", "value": nil, "status": "active"}, // 不符合(空值) + {"deviceId": "sensor3", "value": 15.0, "status": "inactive"}, // 不符合(值<=20) + {"deviceId": "sensor4", "value": 30.0, "status": "error"}, // 符合条件 + } + + for _, data := range testData { + ssql.Emit(data) + } + + time.Sleep(300 * time.Millisecond) + fmt.Println() +} + +func demo3() { + fmt.Println("3. 
fieldName = null 和 != null 语法演示") + fmt.Println("-------------------------------------------") + + ssql := streamsql.New() + defer ssql.Stop() + + // 使用 = null 和 != null 语法 + rsql := `SELECT deviceId, value, status + FROM stream + WHERE status != null OR value = null` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + if results, ok := result.([]map[string]interface{}); ok { + for _, data := range results { + status := data["status"] + value := data["value"] + if status != nil { + fmt.Printf("状态非空的数据: %+v\n", data) + } else if value == nil { + fmt.Printf("值为空的数据: %+v\n", data) + } + } + } + }) + + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, // 符合(status不为null) + {"deviceId": "sensor2", "value": nil, "status": nil}, // 符合(value为null) + {"deviceId": "sensor3", "value": 30.0, "status": "inactive"}, // 符合(status不为null) + {"deviceId": "sensor4", "value": nil, "status": "error"}, // 符合(两个条件都满足) + } + + for _, data := range testData { + ssql.Emit(data) + } + + time.Sleep(300 * time.Millisecond) + fmt.Println() +} + +func demo4() { + fmt.Println("4. 
混合 null 比较语法演示") + fmt.Println("-------------------------------------------") + + ssql := streamsql.New() + defer ssql.Stop() + + // 混合使用 IS NULL、= nil、!= null 等语法 + rsql := `SELECT deviceId, value, status, priority + FROM stream + WHERE (value IS NOT NULL AND value > 20) OR + (status = nil AND priority != null)` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + if results, ok := result.([]map[string]interface{}); ok { + for _, data := range results { + value := data["value"] + status := data["status"] + priority := data["priority"] + + if value != nil && value.(float64) > 20 { + fmt.Printf("高值数据 (value > 20): %+v\n", data) + } else if status == nil && priority != nil { + fmt.Printf("状态异常但有优先级的数据: %+v\n", data) + } + } + } + }) + + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.0, "status": "active", "priority": "high"}, // 符合第一个条件 + {"deviceId": "sensor2", "value": 15.0, "status": "active", "priority": "low"}, // 不符合 + {"deviceId": "sensor3", "value": nil, "status": nil, "priority": "medium"}, // 符合第二个条件 + {"deviceId": "sensor4", "value": nil, "status": nil, "priority": nil}, // 不符合 + {"deviceId": "sensor5", "value": 30.0, "status": "inactive", "priority": nil}, // 符合第一个条件 + {"deviceId": "sensor6", "value": 10.0, "status": nil, "priority": "urgent"}, // 符合第二个条件 + } + + for _, data := range testData { + ssql.Emit(data) + } + + time.Sleep(300 * time.Millisecond) + fmt.Println() +} + +func demo5() { + fmt.Println("5. 
嵌套字段 null 比较演示") + fmt.Println("-------------------------------------------") + + ssql := streamsql.New() + defer ssql.Stop() + + // 嵌套字段的 null 比较 + rsql := `SELECT deviceId, device.location + FROM stream + WHERE device.location != nil` + + err := ssql.Execute(rsql) + if err != nil { + panic(err) + } + + ssql.AddSink(func(result interface{}) { + if results, ok := result.([]map[string]interface{}); ok { + for _, data := range results { + fmt.Printf("有位置信息的设备: %+v\n", data) + } + } + }) + + testData := []map[string]interface{}{ + { + "deviceId": "sensor1", + "device": map[string]interface{}{ + "location": "warehouse-A", + }, + }, // 符合条件 + { + "deviceId": "sensor2", + "device": map[string]interface{}{ + "location": nil, + }, + }, // 不符合(location为nil) + { + "deviceId": "sensor3", + "device": map[string]interface{}{}, + }, // 不符合(location字段不存在) + { + "deviceId": "sensor4", + "device": map[string]interface{}{ + "location": "office-B", + }, + }, // 符合条件 + } + + for _, data := range testData { + ssql.Emit(data) + } + + time.Sleep(300 * time.Millisecond) + fmt.Println() + + fmt.Println("=== Null 比较语法演示完成 ===") + fmt.Println() + fmt.Println("支持的 null 比较语法:") + fmt.Println("- fieldName IS NULL") + fmt.Println("- fieldName IS NOT NULL") + fmt.Println("- fieldName = nil") + fmt.Println("- fieldName != nil") + fmt.Println("- fieldName = null") + fmt.Println("- fieldName != null") + fmt.Println("- device.field = nil (嵌套字段)") + fmt.Println("- device.field != nil (嵌套字段)") +} diff --git a/examples/persistence/main.go b/examples/persistence/main.go index 7e269c3..6fdeb4a 100644 --- a/examples/persistence/main.go +++ b/examples/persistence/main.go @@ -78,7 +78,7 @@ func testDataOverflowPersistence() { "id": i, "value": fmt.Sprintf("data_%d", i), } - stream.AddData(data) + stream.Emit(data) } duration := time.Since(start) diff --git a/examples/simple-custom-functions/main.go b/examples/simple-custom-functions/main.go index fa1fb91..674b75d 100644 --- 
a/examples/simple-custom-functions/main.go +++ b/examples/simple-custom-functions/main.go @@ -122,7 +122,7 @@ func testSimpleQuery(ssql *streamsql.Streamsql) { } // 添加结果监听器 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf(" 📊 简单查询结果: %v\n", result) }) @@ -143,7 +143,7 @@ func testSimpleQuery(ssql *streamsql.Streamsql) { } for _, data := range testData { - ssql.AddData(data) + ssql.Emit(data) time.Sleep(200 * time.Millisecond) // 稍微延迟 } @@ -171,7 +171,7 @@ func testAggregateQuery(ssql *streamsql.Streamsql) { } // 添加结果监听器 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { fmt.Printf(" 📊 聚合查询结果: %v\n", result) }) @@ -198,7 +198,7 @@ func testAggregateQuery(ssql *streamsql.Streamsql) { } for _, data := range testData { - ssql.AddData(data) + ssql.Emit(data) } // 等待窗口触发 diff --git a/examples/table_print_demo/main.go b/examples/table_print_demo/main.go new file mode 100644 index 0000000..05e174a --- /dev/null +++ b/examples/table_print_demo/main.go @@ -0,0 +1,82 @@ +package main + +import ( + "fmt" + "time" + + "github.com/rulego/streamsql" +) + +// main 演示PrintTable方法的使用 +func main() { + fmt.Println("=== StreamSQL PrintTable 示例 ===") + + // 创建StreamSQL实例 + ssql := streamsql.New() + + // 示例1: 聚合查询 - 按设备分组统计温度 + fmt.Println("\n示例1: 聚合查询结果") + err := ssql.Execute("SELECT device, AVG(temperature) as avg_temp, MAX(temperature) as max_temp FROM stream GROUP BY device, TumblingWindow('3s')") + if err != nil { + fmt.Printf("执行SQL失败: %v\n", err) + return + } + + // 使用PrintTable方法以表格形式输出结果 + ssql.PrintTable() + + // 发送测试数据 + testData := []map[string]interface{}{ + {"device": "sensor1", "temperature": 25.5, "timestamp": time.Now()}, + {"device": "sensor1", "temperature": 26.0, "timestamp": time.Now()}, + {"device": "sensor2", "temperature": 23.8, "timestamp": time.Now()}, + {"device": "sensor2", "temperature": 24.2, "timestamp": time.Now()}, + {"device": "sensor1", "temperature": 27.1, 
"timestamp": time.Now()}, + } + + for _, data := range testData { + ssql.Emit(data) + } + + // 等待窗口触发 + time.Sleep(4 * time.Second) + + // 示例2: 非聚合查询 + fmt.Println("\n示例2: 非聚合查询结果") + ssql2 := streamsql.New() + err = ssql2.Execute("SELECT device, temperature, temperature * 1.8 + 32 as fahrenheit FROM stream WHERE temperature > 24") + if err != nil { + fmt.Printf("执行SQL失败: %v\n", err) + return + } + + ssql2.PrintTable() + + // 发送测试数据 + for _, data := range testData { + ssql2.Emit(data) + } + + // 等待处理完成 + time.Sleep(1 * time.Second) + + // 示例3: 对比原始Print方法 + fmt.Println("\n示例3: 原始Print方法输出对比") + ssql3 := streamsql.New() + err = ssql3.Execute("SELECT device, COUNT(*) as count FROM stream GROUP BY device, TumblingWindow('2s')") + if err != nil { + fmt.Printf("执行SQL失败: %v\n", err) + return + } + + fmt.Println("原始PrintTable方法:") + ssql3.PrintTable() + + // 发送数据 + for i := 0; i < 3; i++ { + ssql3.Emit(map[string]interface{}{"device": "test_device", "value": i}) + } + + time.Sleep(3 * time.Second) + fmt.Println("\n=== 示例结束 ===") +} \ No newline at end of file diff --git a/examples/unified_config/demo.go b/examples/unified_config/demo.go index 7afe75b..637a71c 100644 --- a/examples/unified_config/demo.go +++ b/examples/unified_config/demo.go @@ -1,218 +1,218 @@ -package main - -import ( - "fmt" - "log" - "strings" - "time" - - "github.com/rulego/streamsql/stream" - "github.com/rulego/streamsql/types" -) - -func main() { - fmt.Println("=== StreamSQL 统一配置系统演示 ===") - - // 1. 使用新的配置API创建默认配置Stream - fmt.Println("\n1. 默认配置Stream:") - defaultConfig := types.NewConfig() - defaultConfig.SimpleFields = []string{"temperature", "humidity", "location"} - - defaultStream, err := stream.NewStream(defaultConfig) - if err != nil { - log.Fatal(err) - } - printStreamStats("默认配置", defaultStream) - - // 2. 使用高性能预设配置 - fmt.Println("\n2. 
高性能配置Stream:") - highPerfConfig := types.NewConfigWithPerformance(types.HighPerformanceConfig()) - highPerfConfig.SimpleFields = []string{"temperature", "humidity", "location"} - - highPerfStream, err := stream.NewStreamWithHighPerformance(highPerfConfig) - if err != nil { - log.Fatal(err) - } - printStreamStats("高性能配置", highPerfStream) - - // 3. 使用低延迟预设配置 - fmt.Println("\n3. 低延迟配置Stream:") - lowLatencyConfig := types.NewConfigWithPerformance(types.LowLatencyConfig()) - lowLatencyConfig.SimpleFields = []string{"temperature", "humidity", "location"} - - lowLatencyStream, err := stream.NewStreamWithLowLatency(lowLatencyConfig) - if err != nil { - log.Fatal(err) - } - printStreamStats("低延迟配置", lowLatencyStream) - - // 4. 使用零数据丢失预设配置 - fmt.Println("\n4. 零数据丢失配置Stream:") - zeroLossConfig := types.NewConfigWithPerformance(types.ZeroDataLossConfig()) - zeroLossConfig.SimpleFields = []string{"temperature", "humidity", "location"} - - zeroLossStream, err := stream.NewStreamWithZeroDataLoss(zeroLossConfig) - if err != nil { - log.Fatal(err) - } - printStreamStats("零数据丢失配置", zeroLossStream) - - // 5. 使用持久化预设配置 - fmt.Println("\n5. 持久化配置Stream:") - persistConfig := types.NewConfigWithPerformance(types.PersistencePerformanceConfig()) - persistConfig.SimpleFields = []string{"temperature", "humidity", "location"} - - persistStream, err := stream.NewStreamWithCustomPerformance(persistConfig, types.PersistencePerformanceConfig()) - if err != nil { - log.Fatal(err) - } - printStreamStats("持久化配置", persistStream) - - // 6. 创建完全自定义的配置 - fmt.Println("\n6. 
自定义配置Stream:") - customPerfConfig := types.PerformanceConfig{ - BufferConfig: types.BufferConfig{ - DataChannelSize: 30000, - ResultChannelSize: 25000, - WindowOutputSize: 3000, - EnableDynamicResize: true, - MaxBufferSize: 200000, - UsageThreshold: 0.85, - }, - OverflowConfig: types.OverflowConfig{ - Strategy: "expand", - BlockTimeout: 15 * time.Second, - AllowDataLoss: false, - ExpansionConfig: types.ExpansionConfig{ - GrowthFactor: 2.0, - MinIncrement: 2000, - TriggerThreshold: 0.9, - ExpansionTimeout: 3 * time.Second, - }, - }, - WorkerConfig: types.WorkerConfig{ - SinkPoolSize: 800, - SinkWorkerCount: 12, - MaxRetryRoutines: 10, - }, - MonitoringConfig: types.MonitoringConfig{ - EnableMonitoring: true, - StatsUpdateInterval: 500 * time.Millisecond, - EnableDetailedStats: true, - WarningThresholds: types.WarningThresholds{ - DropRateWarning: 5.0, - DropRateCritical: 15.0, - BufferUsageWarning: 75.0, - BufferUsageCritical: 90.0, - }, - }, - } - - customConfig := types.NewConfigWithPerformance(customPerfConfig) - customConfig.SimpleFields = []string{"temperature", "humidity", "location"} - - customStream, err := stream.NewStreamWithCustomPerformance(customConfig, customPerfConfig) - if err != nil { - log.Fatal(err) - } - printStreamStats("自定义配置", customStream) - - // 7. 配置比较演示 - fmt.Println("\n7. 配置比较:") - compareConfigurations() - - // 8. 实时数据处理演示 - fmt.Println("\n8. 实时数据处理演示:") - demonstrateRealTimeProcessing(defaultStream) - - // 9. 窗口统一配置演示 - fmt.Println("\n9. 窗口统一配置演示:") - demonstrateWindowConfig() - - // 清理资源 - fmt.Println("\n10. 
清理资源...") - defaultStream.Stop() - highPerfStream.Stop() - lowLatencyStream.Stop() - zeroLossStream.Stop() - persistStream.Stop() - customStream.Stop() - - fmt.Println("\n=== 演示完成 ===") -} - -func printStreamStats(name string, s *stream.Stream) { - stats := s.GetStats() - detailedStats := s.GetDetailedStats() - - fmt.Printf("【%s】统计信息:\n", name) - fmt.Printf(" 数据通道: %d/%d (使用率: %.1f%%)\n", - stats["data_chan_len"], stats["data_chan_cap"], - detailedStats["data_chan_usage"]) - fmt.Printf(" 结果通道: %d/%d (使用率: %.1f%%)\n", - stats["result_chan_len"], stats["result_chan_cap"], - detailedStats["result_chan_usage"]) - fmt.Printf(" 工作池: %d/%d (使用率: %.1f%%)\n", - stats["sink_pool_len"], stats["sink_pool_cap"], - detailedStats["sink_pool_usage"]) - fmt.Printf(" 性能等级: %s\n", detailedStats["performance_level"]) -} - -func compareConfigurations() { - configs := map[string]types.PerformanceConfig{ - "默认配置": types.DefaultPerformanceConfig(), - "高性能配置": types.HighPerformanceConfig(), - "低延迟配置": types.LowLatencyConfig(), - "零丢失配置": types.ZeroDataLossConfig(), - "持久化配置": types.PersistencePerformanceConfig(), - } - - fmt.Printf("%-12s %-10s %-10s %-10s %-10s %-15s\n", - "配置类型", "数据缓冲", "结果缓冲", "工作池", "工作线程", "溢出策略") - fmt.Println(strings.Repeat("-", 75)) - - for name, config := range configs { - fmt.Printf("%-12s %-10d %-10d %-10d %-10d %-15s\n", - name, - config.BufferConfig.DataChannelSize, - config.BufferConfig.ResultChannelSize, - config.WorkerConfig.SinkPoolSize, - config.WorkerConfig.SinkWorkerCount, - config.OverflowConfig.Strategy) - } -} - -func demonstrateRealTimeProcessing(s *stream.Stream) { - // 设置数据接收器 - s.AddSink(func(data interface{}) { - fmt.Printf(" 接收到处理结果: %v\n", data) - }) - - // 启动流处理 - s.Start() - - // 模拟发送数据 - for i := 0; i < 3; i++ { - data := map[string]interface{}{ - "temperature": 20.0 + float64(i)*2.5, - "humidity": 60.0 + float64(i)*5, - "location": fmt.Sprintf("sensor_%d", i+1), - "timestamp": time.Now().Unix(), - } - - fmt.Printf(" 发送数据: %v\n", data) - 
s.AddData(data) - time.Sleep(100 * time.Millisecond) - } - - // 等待处理完成 - time.Sleep(200 * time.Millisecond) - - // 显示最终统计 - finalStats := s.GetDetailedStats() - fmt.Printf(" 最终统计 - 输入: %d, 输出: %d, 丢弃: %d, 处理率: %.1f%%\n", - finalStats["basic_stats"].(map[string]int64)["input_count"], - finalStats["basic_stats"].(map[string]int64)["output_count"], - finalStats["basic_stats"].(map[string]int64)["dropped_count"], - finalStats["process_rate"]) -} +package main + +import ( + "fmt" + "log" + "strings" + "time" + + "github.com/rulego/streamsql/stream" + "github.com/rulego/streamsql/types" +) + +func main() { + fmt.Println("=== StreamSQL 统一配置系统演示 ===") + + // 1. 使用新的配置API创建默认配置Stream + fmt.Println("\n1. 默认配置Stream:") + defaultConfig := types.NewConfig() + defaultConfig.SimpleFields = []string{"temperature", "humidity", "location"} + + defaultStream, err := stream.NewStream(defaultConfig) + if err != nil { + log.Fatal(err) + } + printStreamStats("默认配置", defaultStream) + + // 2. 使用高性能预设配置 + fmt.Println("\n2. 高性能配置Stream:") + highPerfConfig := types.NewConfigWithPerformance(types.HighPerformanceConfig()) + highPerfConfig.SimpleFields = []string{"temperature", "humidity", "location"} + + highPerfStream, err := stream.NewStreamWithHighPerformance(highPerfConfig) + if err != nil { + log.Fatal(err) + } + printStreamStats("高性能配置", highPerfStream) + + // 3. 使用低延迟预设配置 + fmt.Println("\n3. 低延迟配置Stream:") + lowLatencyConfig := types.NewConfigWithPerformance(types.LowLatencyConfig()) + lowLatencyConfig.SimpleFields = []string{"temperature", "humidity", "location"} + + lowLatencyStream, err := stream.NewStreamWithLowLatency(lowLatencyConfig) + if err != nil { + log.Fatal(err) + } + printStreamStats("低延迟配置", lowLatencyStream) + + // 4. 使用零数据丢失预设配置 + fmt.Println("\n4. 
零数据丢失配置Stream:") + zeroLossConfig := types.NewConfigWithPerformance(types.ZeroDataLossConfig()) + zeroLossConfig.SimpleFields = []string{"temperature", "humidity", "location"} + + zeroLossStream, err := stream.NewStreamWithZeroDataLoss(zeroLossConfig) + if err != nil { + log.Fatal(err) + } + printStreamStats("零数据丢失配置", zeroLossStream) + + // 5. 使用持久化预设配置 + fmt.Println("\n5. 持久化配置Stream:") + persistConfig := types.NewConfigWithPerformance(types.PersistencePerformanceConfig()) + persistConfig.SimpleFields = []string{"temperature", "humidity", "location"} + + persistStream, err := stream.NewStreamWithCustomPerformance(persistConfig, types.PersistencePerformanceConfig()) + if err != nil { + log.Fatal(err) + } + printStreamStats("持久化配置", persistStream) + + // 6. 创建完全自定义的配置 + fmt.Println("\n6. 自定义配置Stream:") + customPerfConfig := types.PerformanceConfig{ + BufferConfig: types.BufferConfig{ + DataChannelSize: 30000, + ResultChannelSize: 25000, + WindowOutputSize: 3000, + EnableDynamicResize: true, + MaxBufferSize: 200000, + UsageThreshold: 0.85, + }, + OverflowConfig: types.OverflowConfig{ + Strategy: "expand", + BlockTimeout: 15 * time.Second, + AllowDataLoss: false, + ExpansionConfig: types.ExpansionConfig{ + GrowthFactor: 2.0, + MinIncrement: 2000, + TriggerThreshold: 0.9, + ExpansionTimeout: 3 * time.Second, + }, + }, + WorkerConfig: types.WorkerConfig{ + SinkPoolSize: 800, + SinkWorkerCount: 12, + MaxRetryRoutines: 10, + }, + MonitoringConfig: types.MonitoringConfig{ + EnableMonitoring: true, + StatsUpdateInterval: 500 * time.Millisecond, + EnableDetailedStats: true, + WarningThresholds: types.WarningThresholds{ + DropRateWarning: 5.0, + DropRateCritical: 15.0, + BufferUsageWarning: 75.0, + BufferUsageCritical: 90.0, + }, + }, + } + + customConfig := types.NewConfigWithPerformance(customPerfConfig) + customConfig.SimpleFields = []string{"temperature", "humidity", "location"} + + customStream, err := stream.NewStreamWithCustomPerformance(customConfig, 
customPerfConfig) + if err != nil { + log.Fatal(err) + } + printStreamStats("自定义配置", customStream) + + // 7. 配置比较演示 + fmt.Println("\n7. 配置比较:") + compareConfigurations() + + // 8. 实时数据处理演示 + fmt.Println("\n8. 实时数据处理演示:") + demonstrateRealTimeProcessing(defaultStream) + + // 9. 窗口统一配置演示 + fmt.Println("\n9. 窗口统一配置演示:") + demonstrateWindowConfig() + + // 清理资源 + fmt.Println("\n10. 清理资源...") + defaultStream.Stop() + highPerfStream.Stop() + lowLatencyStream.Stop() + zeroLossStream.Stop() + persistStream.Stop() + customStream.Stop() + + fmt.Println("\n=== 演示完成 ===") +} + +func printStreamStats(name string, s *stream.Stream) { + stats := s.GetStats() + detailedStats := s.GetDetailedStats() + + fmt.Printf("【%s】统计信息:\n", name) + fmt.Printf(" 数据通道: %d/%d (使用率: %.1f%%)\n", + stats["data_chan_len"], stats["data_chan_cap"], + detailedStats["data_chan_usage"]) + fmt.Printf(" 结果通道: %d/%d (使用率: %.1f%%)\n", + stats["result_chan_len"], stats["result_chan_cap"], + detailedStats["result_chan_usage"]) + fmt.Printf(" 工作池: %d/%d (使用率: %.1f%%)\n", + stats["sink_pool_len"], stats["sink_pool_cap"], + detailedStats["sink_pool_usage"]) + fmt.Printf(" 性能等级: %s\n", detailedStats["performance_level"]) +} + +func compareConfigurations() { + configs := map[string]types.PerformanceConfig{ + "默认配置": types.DefaultPerformanceConfig(), + "高性能配置": types.HighPerformanceConfig(), + "低延迟配置": types.LowLatencyConfig(), + "零丢失配置": types.ZeroDataLossConfig(), + "持久化配置": types.PersistencePerformanceConfig(), + } + + fmt.Printf("%-12s %-10s %-10s %-10s %-10s %-15s\n", + "配置类型", "数据缓冲", "结果缓冲", "工作池", "工作线程", "溢出策略") + fmt.Println(strings.Repeat("-", 75)) + + for name, config := range configs { + fmt.Printf("%-12s %-10d %-10d %-10d %-10d %-15s\n", + name, + config.BufferConfig.DataChannelSize, + config.BufferConfig.ResultChannelSize, + config.WorkerConfig.SinkPoolSize, + config.WorkerConfig.SinkWorkerCount, + config.OverflowConfig.Strategy) + } +} + +func demonstrateRealTimeProcessing(s *stream.Stream) { + // 
设置数据接收器 + s.AddSink(func(data interface{}) { + fmt.Printf(" 接收到处理结果: %v\n", data) + }) + + // 启动流处理 + s.Start() + + // 模拟发送数据 + for i := 0; i < 3; i++ { + data := map[string]interface{}{ + "temperature": 20.0 + float64(i)*2.5, + "humidity": 60.0 + float64(i)*5, + "location": fmt.Sprintf("sensor_%d", i+1), + "timestamp": time.Now().Unix(), + } + + fmt.Printf(" 发送数据: %v\n", data) + s.Emit(data) + time.Sleep(100 * time.Millisecond) + } + + // 等待处理完成 + time.Sleep(200 * time.Millisecond) + + // 显示最终统计 + finalStats := s.GetDetailedStats() + fmt.Printf(" 最终统计 - 输入: %d, 输出: %d, 丢弃: %d, 处理率: %.1f%%\n", + finalStats["basic_stats"].(map[string]int64)["input_count"], + finalStats["basic_stats"].(map[string]int64)["output_count"], + finalStats["basic_stats"].(map[string]int64)["dropped_count"], + finalStats["process_rate"]) +} diff --git a/examples/unified_config/window_config_demo.go b/examples/unified_config/window_config_demo.go index 722187f..81dfa89 100644 --- a/examples/unified_config/window_config_demo.go +++ b/examples/unified_config/window_config_demo.go @@ -1,74 +1,74 @@ -package main - -import ( - "fmt" - "time" - - "github.com/rulego/streamsql" - "github.com/rulego/streamsql/types" -) - -// demonstrateWindowConfig 演示窗口统一配置的使用 -func demonstrateWindowConfig() { - fmt.Println("=== 窗口统一配置演示 ===") - - // 1. 测试默认配置的窗口 - fmt.Println("\n1. 默认配置窗口测试") - testWindowWithConfig("默认配置", streamsql.New()) - - // 2. 测试高性能配置的窗口 - fmt.Println("\n2. 高性能配置窗口测试") - testWindowWithConfig("高性能配置", streamsql.New(streamsql.WithHighPerformance())) - - // 3. 测试低延迟配置的窗口 - fmt.Println("\n3. 低延迟配置窗口测试") - testWindowWithConfig("低延迟配置", streamsql.New(streamsql.WithLowLatency())) - - // 4. 测试自定义配置的窗口 - fmt.Println("\n4. 
自定义配置窗口测试") - customConfig := types.DefaultPerformanceConfig() - customConfig.BufferConfig.WindowOutputSize = 2000 // 自定义窗口输出缓冲区大小 - testWindowWithConfig("自定义配置", streamsql.New(streamsql.WithCustomPerformance(customConfig))) - - fmt.Println("\n=== 窗口配置演示完成 ===") -} - -func testWindowWithConfig(configName string, ssql *streamsql.Streamsql) { - // 执行一个简单的滚动窗口查询 - sql := "SELECT deviceId, AVG(temperature) as avg_temp FROM stream GROUP BY deviceId, TumblingWindow('2s')" - - err := ssql.Execute(sql) - if err != nil { - fmt.Printf("❌ %s - 执行SQL失败: %v\n", configName, err) - return - } - - // 添加结果处理器 - stream := ssql.Stream() - if stream != nil { - stream.AddSink(func(result interface{}) { - fmt.Printf("📊 %s - 窗口结果: %v\n", configName, result) - }) - - // 发送测试数据 - for i := 0; i < 5; i++ { - data := map[string]interface{}{ - "deviceId": fmt.Sprintf("device_%d", i%2), - "temperature": 20.0 + float64(i), - "timestamp": time.Now(), - } - ssql.AddData(data) - } - - // 等待处理完成 - time.Sleep(3 * time.Second) - - // 获取统计信息 - stats := ssql.GetDetailedStats() - fmt.Printf("📈 %s - 统计信息: %v\n", configName, stats) - } - - // 停止流处理 - ssql.Stop() - fmt.Printf("✅ %s - 测试完成\n", configName) -} +package main + +import ( + "fmt" + "time" + + "github.com/rulego/streamsql" + "github.com/rulego/streamsql/types" +) + +// demonstrateWindowConfig 演示窗口统一配置的使用 +func demonstrateWindowConfig() { + fmt.Println("=== 窗口统一配置演示 ===") + + // 1. 测试默认配置的窗口 + fmt.Println("\n1. 默认配置窗口测试") + testWindowWithConfig("默认配置", streamsql.New()) + + // 2. 测试高性能配置的窗口 + fmt.Println("\n2. 高性能配置窗口测试") + testWindowWithConfig("高性能配置", streamsql.New(streamsql.WithHighPerformance())) + + // 3. 测试低延迟配置的窗口 + fmt.Println("\n3. 低延迟配置窗口测试") + testWindowWithConfig("低延迟配置", streamsql.New(streamsql.WithLowLatency())) + + // 4. 测试自定义配置的窗口 + fmt.Println("\n4. 
自定义配置窗口测试") + customConfig := types.DefaultPerformanceConfig() + customConfig.BufferConfig.WindowOutputSize = 2000 // 自定义窗口输出缓冲区大小 + testWindowWithConfig("自定义配置", streamsql.New(streamsql.WithCustomPerformance(customConfig))) + + fmt.Println("\n=== 窗口配置演示完成 ===") +} + +func testWindowWithConfig(configName string, ssql *streamsql.Streamsql) { + // 执行一个简单的滚动窗口查询 + sql := "SELECT deviceId, AVG(temperature) as avg_temp FROM stream GROUP BY deviceId, TumblingWindow('2s')" + + err := ssql.Execute(sql) + if err != nil { + fmt.Printf("❌ %s - 执行SQL失败: %v\n", configName, err) + return + } + + // 添加结果处理器 + stream := ssql.Stream() + if stream != nil { + stream.AddSink(func(result interface{}) { + fmt.Printf("📊 %s - 窗口结果: %v\n", configName, result) + }) + + // 发送测试数据 + for i := 0; i < 5; i++ { + data := map[string]interface{}{ + "deviceId": fmt.Sprintf("device_%d", i%2), + "temperature": 20.0 + float64(i), + "timestamp": time.Now(), + } + ssql.Emit(data) + } + + // 等待处理完成 + time.Sleep(3 * time.Second) + + // 获取统计信息 + stats := ssql.GetDetailedStats() + fmt.Printf("📈 %s - 统计信息: %v\n", configName, stats) + } + + // 停止流处理 + ssql.Stop() + fmt.Printf("✅ %s - 测试完成\n", configName) +} diff --git a/expr/expression.go b/expr/expression.go index 0df9b98..dbf2dde 100644 --- a/expr/expression.go +++ b/expr/expression.go @@ -136,7 +136,9 @@ func validateBasicSyntax(exprStr string) error { // checkConsecutiveOperators 检查连续运算符 func checkConsecutiveOperators(expr string) error { // 简化的连续运算符检查:查找明显的双运算符模式 + // 但要允许比较运算符后跟负数的情况 operators := []string{"+", "-", "*", "/", "%", "^", "==", "!=", ">=", "<=", ">", "<"} + comparisonOps := []string{"==", "!=", ">=", "<=", ">", "<"} for i := 0; i < len(expr)-1; i++ { // 跳过空白字符 @@ -147,10 +149,12 @@ func checkConsecutiveOperators(expr string) error { // 检查当前位置是否是运算符 isCurrentOp := false currentOpLen := 0 + currentOp := "" for _, op := range operators { if i+len(op) <= len(expr) && expr[i:i+len(op)] == op { isCurrentOp = true currentOpLen = len(op) + currentOp 
= op break } } @@ -164,10 +168,35 @@ func checkConsecutiveOperators(expr string) error { // 检查下一个字符是否也是运算符 if nextPos < len(expr) { + // 特殊处理:如果当前是比较运算符,下一个是负号,且负号后跟数字,则允许 + isCurrentComparison := false + for _, compOp := range comparisonOps { + if currentOp == compOp { + isCurrentComparison = true + break + } + } + + // 检查是否是负数的情况 + if isCurrentComparison && nextPos < len(expr) && expr[nextPos] == '-' { + // 检查负号后是否跟数字 + digitPos := nextPos + 1 + for digitPos < len(expr) && (expr[digitPos] == ' ' || expr[digitPos] == '\t') { + digitPos++ + } + if digitPos < len(expr) && expr[digitPos] >= '0' && expr[digitPos] <= '9' { + // 这是比较运算符后跟负数,允许通过 + i = nextPos // 跳过到负号位置 + continue + } + } + + // 检查其他连续运算符 for _, op := range operators { if nextPos+len(op) <= len(expr) && expr[nextPos:nextPos+len(op)] == op { + // 如果不是允许的负数情况,则报错 return fmt.Errorf("consecutive operators found: '%s' followed by '%s'", - expr[i:i+currentOpLen], op) + currentOp, op) } } } @@ -394,28 +423,34 @@ func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) return float64(len(value)), nil case TypeField: + // 处理反引号标识符,去除反引号 + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号 + } + // 支持嵌套字段访问 - if fieldpath.IsNestedField(node.Value) { - if val, found := fieldpath.GetNestedField(data, node.Value); found { + if fieldpath.IsNestedField(fieldName) { + if val, found := fieldpath.GetNestedField(data, fieldName); found { // 尝试转换为float64 if floatVal, err := convertToFloat(val); err == nil { return floatVal, nil } // 如果不能转换为数字,返回错误 - return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", node.Value, val) + return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val) } } else { // 原有的简单字段访问 - if val, found := data[node.Value]; found { + if val, found := data[fieldName]; found { // 尝试转换为float64 if floatVal, err := 
convertToFloat(val); err == nil { return floatVal, nil } // 如果不能转换为数字,返回错误 - return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", node.Value, val) + return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val) } } - return 0, fmt.Errorf("field '%s' not found", node.Value) + return 0, fmt.Errorf("field '%s' not found", fieldName) case TypeOperator: // 计算左右子表达式的值 @@ -788,18 +823,24 @@ func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{} return value, nil case TypeField: + // 处理反引号标识符,去除反引号 + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号 + } + // 支持嵌套字段访问 - if fieldpath.IsNestedField(node.Value) { - if val, found := fieldpath.GetNestedField(data, node.Value); found { + if fieldpath.IsNestedField(fieldName) { + if val, found := fieldpath.GetNestedField(data, fieldName); found { return val, nil } } else { // 原有的简单字段访问 - if val, found := data[node.Value]; found { + if val, found := data[fieldName]; found { return val, nil } } - return nil, fmt.Errorf("field '%s' not found", node.Value) + return nil, fmt.Errorf("field '%s' not found", fieldName) default: // 对于其他类型,回退到数值计算 @@ -985,8 +1026,12 @@ func tokenize(expr string) ([]string, error) { prevToken == "(" || // 左括号后 prevToken == "," || // 逗号后(函数参数) isOperator(prevToken) || // 运算符后 + isComparisonOperator(prevToken) || // 比较运算符后 strings.ToUpper(prevToken) == "THEN" || // THEN后 - strings.ToUpper(prevToken) == "ELSE" // ELSE后 + strings.ToUpper(prevToken) == "ELSE" || // ELSE后 + strings.ToUpper(prevToken) == "WHEN" || // WHEN后 + strings.ToUpper(prevToken) == "AND" || // AND后 + strings.ToUpper(prevToken) == "OR" // OR后 } if canBeNegativeNumber && i+1 < len(expr) && isDigit(expr[i+1]) { @@ -1077,6 +1122,25 @@ func tokenize(expr string) ([]string, error) { continue } + // 处理反引号标识符 + if ch == '`' { + start := i + i++ // 跳过开始反引号 + + 
// 寻找结束反引号 + for i < len(expr) && expr[i] != '`' { + i++ + } + + if i >= len(expr) { + return nil, fmt.Errorf("unterminated quoted identifier starting at position %d", start) + } + + i++ // 跳过结束反引号 + tokens = append(tokens, expr[start:i]) + continue + } + // 处理标识符(字段名或函数名) if isLetter(ch) { start := i @@ -1621,6 +1685,505 @@ func isOperator(s string) bool { } } +// isComparisonOperator 检查是否是比较运算符 +func isComparisonOperator(s string) bool { + switch s { + case ">", "<", ">=", "<=", "==", "=", "!=", "<>": + return true + default: + return false + } +} + func isStringLiteral(expr string) bool { return len(expr) > 1 && (expr[0] == '\'' || expr[0] == '"') && expr[len(expr)-1] == expr[0] } + +// evaluateNodeWithNull 计算节点值,支持NULL值返回 +// 返回 (result, isNull, error) +func evaluateNodeWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { + if node == nil { + return 0, true, nil // NULL + } + + switch node.Type { + case TypeNumber: + val, err := strconv.ParseFloat(node.Value, 64) + return val, false, err + + case TypeString: + // 字符串长度作为数值,特殊处理NULL字符串 + value := node.Value + if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { + value = value[1 : len(value)-1] + } + // 检查是否是NULL字符串 + if strings.ToUpper(value) == "NULL" { + return 0, true, nil + } + return float64(len(value)), false, nil + + case TypeField: + // 支持嵌套字段访问 + var fieldVal interface{} + var found bool + + if fieldpath.IsNestedField(node.Value) { + fieldVal, found = fieldpath.GetNestedField(data, node.Value) + } else { + fieldVal, found = data[node.Value] + } + + if !found || fieldVal == nil { + return 0, true, nil // NULL + } + + // 尝试转换为数值 + if val, err := convertToFloat(fieldVal); err == nil { + return val, false, nil + } + return 0, true, fmt.Errorf("cannot convert field '%s' to number", node.Value) + + case TypeOperator: + return evaluateOperatorWithNull(node, data) + + case TypeFunction: + // 函数调用保持原有逻辑,但处理NULL结果 + result, err := evaluateBuiltinFunction(node, data) + return 
result, false, err + + case TypeCase: + return evaluateCaseExpressionWithNull(node, data) + + default: + return 0, true, fmt.Errorf("unsupported node type: %s", node.Type) + } +} + +// evaluateOperatorWithNull 计算运算符表达式,支持NULL值 +func evaluateOperatorWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { + leftVal, leftNull, err := evaluateNodeWithNull(node.Left, data) + if err != nil { + return 0, false, err + } + + rightVal, rightNull, err := evaluateNodeWithNull(node.Right, data) + if err != nil { + return 0, false, err + } + + // 算术运算:如果任一操作数为NULL,结果为NULL + if leftNull || rightNull { + switch node.Value { + case "+", "-", "*", "/", "%", "^": + return 0, true, nil + } + } + + // 比较运算:NULL值的比较有特殊规则 + switch node.Value { + case "==", "=": + if leftNull && rightNull { + return 1, false, nil // NULL = NULL 为 true + } + if leftNull || rightNull { + return 0, false, nil // NULL = value 为 false + } + if leftVal == rightVal { + return 1, false, nil + } + return 0, false, nil + + case "!=", "<>": + if leftNull && rightNull { + return 0, false, nil // NULL != NULL 为 false + } + if leftNull || rightNull { + return 0, false, nil // NULL != value 为 false + } + if leftVal != rightVal { + return 1, false, nil + } + return 0, false, nil + + case ">", "<", ">=", "<=": + if leftNull || rightNull { + return 0, false, nil // NULL与任何值的比较都为false + } + } + + // 对于非NULL值,执行正常的算术和比较运算 + switch node.Value { + case "+": + return leftVal + rightVal, false, nil + case "-": + return leftVal - rightVal, false, nil + case "*": + return leftVal * rightVal, false, nil + case "/": + if rightVal == 0 { + return 0, true, nil // 除零返回NULL + } + return leftVal / rightVal, false, nil + case "%": + if rightVal == 0 { + return 0, true, nil + } + return math.Mod(leftVal, rightVal), false, nil + case "^": + return math.Pow(leftVal, rightVal), false, nil + case ">": + if leftVal > rightVal { + return 1, false, nil + } + return 0, false, nil + case "<": + if leftVal < rightVal { + 
return 1, false, nil + } + return 0, false, nil + case ">=": + if leftVal >= rightVal { + return 1, false, nil + } + return 0, false, nil + case "<=": + if leftVal <= rightVal { + return 1, false, nil + } + return 0, false, nil + default: + return 0, false, fmt.Errorf("unsupported operator: %s", node.Value) + } +} + +// evaluateCaseExpressionWithNull 计算CASE表达式,支持NULL值 +func evaluateCaseExpressionWithNull(node *ExprNode, data map[string]interface{}) (float64, bool, error) { + if node.Type != TypeCase { + return 0, false, fmt.Errorf("node is not a CASE expression") + } + + // 处理简单CASE表达式 (CASE expr WHEN value1 THEN result1 ...) + if node.CaseExpr != nil { + // 计算CASE后面的表达式值 + caseValue, caseNull, err := evaluateNodeValueWithNull(node.CaseExpr, data) + if err != nil { + return 0, false, err + } + + // 遍历WHEN子句,查找匹配的值 + for _, whenClause := range node.WhenClauses { + conditionValue, condNull, err := evaluateNodeValueWithNull(whenClause.Condition, data) + if err != nil { + return 0, false, err + } + + // 比较值是否相等(考虑NULL值) + var isEqual bool + if caseNull && condNull { + isEqual = true // NULL = NULL + } else if caseNull || condNull { + isEqual = false // NULL != value + } else { + isEqual, err = compareValuesForEquality(caseValue, conditionValue) + if err != nil { + return 0, false, err + } + } + + if isEqual { + return evaluateNodeWithNull(whenClause.Result, data) + } + } + } else { + // 处理搜索CASE表达式 (CASE WHEN condition1 THEN result1 ...) 
+ for _, whenClause := range node.WhenClauses { + // 评估WHEN条件 + conditionResult, err := evaluateBooleanConditionWithNull(whenClause.Condition, data) + if err != nil { + return 0, false, err + } + + // 如果条件为真,返回对应的结果 + if conditionResult { + return evaluateNodeWithNull(whenClause.Result, data) + } + } + } + + // 如果没有匹配的WHEN子句,执行ELSE子句 + if node.ElseExpr != nil { + return evaluateNodeWithNull(node.ElseExpr, data) + } + + // 如果没有ELSE子句,SQL标准是返回NULL + return 0, true, nil +} + +// evaluateNodeValueWithNull 计算节点值,返回interface{}以支持不同类型,包含NULL检查 +func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (interface{}, bool, error) { + if node == nil { + return nil, true, nil + } + + switch node.Type { + case TypeNumber: + val, err := strconv.ParseFloat(node.Value, 64) + return val, false, err + + case TypeString: + // 去掉引号 + value := node.Value + if len(value) >= 2 && (value[0] == '\'' || value[0] == '"') { + value = value[1 : len(value)-1] + } + // 检查是否是NULL字符串 + if strings.ToUpper(value) == "NULL" { + return nil, true, nil + } + return value, false, nil + + case TypeField: + // 处理反引号标识符,去除反引号 + fieldName := node.Value + if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' { + fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号 + } + + // 支持嵌套字段访问 + if fieldpath.IsNestedField(fieldName) { + if val, found := fieldpath.GetNestedField(data, fieldName); found { + return val, val == nil, nil + } + } else { + // 原有的简单字段访问 + if val, found := data[fieldName]; found { + return val, val == nil, nil + } + } + return nil, true, nil // 字段不存在视为NULL + + default: + // 对于其他类型,回退到数值计算 + result, isNull, err := evaluateNodeWithNull(node, data) + return result, isNull, err + } +} + +// evaluateBooleanConditionWithNull 计算布尔条件表达式,支持NULL值 +func evaluateBooleanConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) { + if node == nil { + return false, fmt.Errorf("null condition expression") + } + + // 处理逻辑运算符(实现短路求值) + if 
node.Type == TypeOperator && (node.Value == "AND" || node.Value == "OR") { + leftBool, err := evaluateBooleanConditionWithNull(node.Left, data) + if err != nil { + return false, err + } + + // 短路求值:对于AND,如果左边为false,立即返回false + if node.Value == "AND" && !leftBool { + return false, nil + } + + // 短路求值:对于OR,如果左边为true,立即返回true + if node.Value == "OR" && leftBool { + return true, nil + } + + // 只有在需要时才评估右边的表达式 + rightBool, err := evaluateBooleanConditionWithNull(node.Right, data) + if err != nil { + return false, err + } + + switch node.Value { + case "AND": + return leftBool && rightBool, nil + case "OR": + return leftBool || rightBool, nil + } + } + + // 处理IS NULL和IS NOT NULL特殊情况 + if node.Type == TypeOperator && node.Value == "IS" { + return evaluateIsConditionWithNull(node, data) + } + + // 处理比较运算符 + if node.Type == TypeOperator { + leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data) + if err != nil { + return false, err + } + + rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data) + if err != nil { + return false, err + } + + return compareValuesWithNull(leftValue, leftNull, rightValue, rightNull, node.Value) + } + + // 对于其他表达式,计算其数值并转换为布尔值 + result, isNull, err := evaluateNodeWithNull(node, data) + if err != nil { + return false, err + } + + // NULL值在布尔上下文中为false,非零值为真,零值为假 + return !isNull && result != 0, nil +} + +// evaluateIsConditionWithNull 处理IS NULL和IS NOT NULL条件,支持NULL值 +func evaluateIsConditionWithNull(node *ExprNode, data map[string]interface{}) (bool, error) { + if node == nil || node.Left == nil || node.Right == nil { + return false, fmt.Errorf("invalid IS condition") + } + + // 获取左侧值 + leftValue, leftNull, err := evaluateNodeValueWithNull(node.Left, data) + if err != nil { + // 如果字段不存在,认为是null + leftValue = nil + leftNull = true + } + + // 检查右侧是否是NULL或NOT NULL + if node.Right.Type == TypeField && strings.ToUpper(node.Right.Value) == "NULL" { + // IS NULL + return leftNull || leftValue == nil, nil + } + + // 
检查是否是IS NOT NULL + if node.Right.Type == TypeOperator && node.Right.Value == "NOT" && + node.Right.Right != nil && node.Right.Right.Type == TypeField && + strings.ToUpper(node.Right.Right.Value) == "NULL" { + // IS NOT NULL + return !leftNull && leftValue != nil, nil + } + + // 其他IS比较 + rightValue, rightNull, err := evaluateNodeValueWithNull(node.Right, data) + if err != nil { + return false, err + } + + return compareValuesWithNullForEquality(leftValue, leftNull, rightValue, rightNull) +} + +// compareValuesForEquality 比较两个值是否相等 +func compareValuesForEquality(left, right interface{}) (bool, error) { + // 尝试字符串比较 + leftStr, leftIsStr := left.(string) + rightStr, rightIsStr := right.(string) + + if leftIsStr && rightIsStr { + return leftStr == rightStr, nil + } + + // 尝试数值比较 + leftFloat, leftErr := convertToFloat(left) + rightFloat, rightErr := convertToFloat(right) + + if leftErr == nil && rightErr == nil { + return leftFloat == rightFloat, nil + } + + // 如果都不能转换,直接比较 + return left == right, nil +} + +// compareValuesWithNull 比较两个值(支持NULL) +func compareValuesWithNull(left interface{}, leftNull bool, right interface{}, rightNull bool, operator string) (bool, error) { + // NULL值的比较有特殊规则 + switch operator { + case "==", "=": + if leftNull && rightNull { + return true, nil // NULL = NULL 为 true + } + if leftNull || rightNull { + return false, nil // NULL = value 为 false + } + + case "!=", "<>": + if leftNull && rightNull { + return false, nil // NULL != NULL 为 false + } + if leftNull || rightNull { + return false, nil // NULL != value 为 false + } + + case ">", "<", ">=", "<=": + if leftNull || rightNull { + return false, nil // NULL与任何值的比较都为false + } + } + + // 对于非NULL值,执行正确的比较逻辑 + switch operator { + case "==", "=": + return compareValuesForEquality(left, right) + case "!=", "<>": + equal, err := compareValuesForEquality(left, right) + return !equal, err + case ">", "<", ">=", "<=": + // 进行数值比较 + leftFloat, leftErr := convertToFloat(left) + rightFloat, rightErr := 
convertToFloat(right) + + if leftErr != nil || rightErr != nil { + // 如果不能转换为数值,尝试字符串比较 + leftStr := fmt.Sprintf("%v", left) + rightStr := fmt.Sprintf("%v", right) + + switch operator { + case ">": + return leftStr > rightStr, nil + case "<": + return leftStr < rightStr, nil + case ">=": + return leftStr >= rightStr, nil + case "<=": + return leftStr <= rightStr, nil + } + } + + // 数值比较 + switch operator { + case ">": + return leftFloat > rightFloat, nil + case "<": + return leftFloat < rightFloat, nil + case ">=": + return leftFloat >= rightFloat, nil + case "<=": + return leftFloat <= rightFloat, nil + } + } + + return false, fmt.Errorf("unsupported operator: %s", operator) +} + +// compareValuesWithNullForEquality 比较两个值是否相等(支持NULL) +func compareValuesWithNullForEquality(left interface{}, leftNull bool, right interface{}, rightNull bool) (bool, error) { + if leftNull && rightNull { + return true, nil // NULL = NULL 为 true + } + if leftNull || rightNull { + return false, nil // NULL = value 为 false + } + return compareValuesForEquality(left, right) +} + +// EvaluateWithNull 提供公开接口,用于聚合函数调用 +func (e *Expression) EvaluateWithNull(data map[string]interface{}) (float64, bool, error) { + if e.useExprLang { + // expr-lang不支持NULL,回退到原有逻辑 + result, err := e.evaluateWithExprLang(data) + return result, false, err + } + return evaluateNodeWithNull(e.Root, data) +} diff --git a/expr/expression_test.go b/expr/expression_test.go index cd1593e..34e520d 100644 --- a/expr/expression_test.go +++ b/expr/expression_test.go @@ -61,6 +61,356 @@ func TestExpressionEvaluation(t *testing.T) { } } +// TestCaseExpressionParsing 测试CASE表达式的解析功能 +func TestCaseExpressionParsing(t *testing.T) { + tests := []struct { + name string + exprStr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + { + name: "简单的搜索CASE表达式", + exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 35.0}, + expected: 1.0, + wantErr: false, + }, + { 
+ name: "简单CASE表达式 - 值匹配", + exprStr: "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END", + data: map[string]interface{}{"status": "active"}, + expected: 1.0, + wantErr: false, + }, + { + name: "CASE表达式 - ELSE分支", + exprStr: "CASE WHEN temperature > 50 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 25.5}, + expected: 0.0, + wantErr: false, + }, + { + name: "复杂搜索CASE表达式", + exprStr: "CASE WHEN temperature > 30 THEN 'HOT' WHEN temperature > 20 THEN 'WARM' ELSE 'COLD' END", + data: map[string]interface{}{"temperature": 25.0}, + expected: 4.0, // 字符串"WARM"的长度 + wantErr: false, + }, + { + name: "数值比较的简单CASE", + exprStr: "CASE temperature WHEN 25 THEN 1 WHEN 30 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": 30.0}, + expected: 2.0, + wantErr: false, + }, + { + name: "布尔值CASE表达式", + exprStr: "CASE WHEN temperature > 25 AND humidity > 50 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 30.0, "humidity": 60.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "多条件CASE表达式_AND", + exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 WHEN temperature > 20 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": 35.0, "humidity": 50.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "多条件CASE表达式_OR", + exprStr: "CASE WHEN temperature > 40 OR humidity > 80 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 25.0, "humidity": 85.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "函数调用在CASE中_ABS", + exprStr: "CASE WHEN ABS(temperature) > 30 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": -35.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "函数调用在CASE中_ROUND", + exprStr: "CASE WHEN ROUND(temperature) = 25 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 24.7}, + expected: 1.0, + wantErr: false, + }, + { + name: "复杂条件组合", + exprStr: "CASE WHEN temperature > 30 AND (humidity > 60 OR pressure < 1000) THEN 1 ELSE 0 END", + data: 
map[string]interface{}{"temperature": 35.0, "humidity": 55.0, "pressure": 950.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "CASE中的算术表达式", + exprStr: "CASE WHEN temperature * 1.8 + 32 > 100 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 40.0}, // 40*1.8+32 = 104 + expected: 1.0, + wantErr: false, + }, + { + name: "字符串函数在CASE中", + exprStr: "CASE WHEN LENGTH(device_name) > 5 THEN 1 ELSE 0 END", + data: map[string]interface{}{"device_name": "sensor123"}, + expected: 1.0, // LENGTH函数正常工作,"sensor123"长度为9 > 5,返回1 + wantErr: false, + }, + { + name: "简单CASE与函数", + exprStr: "CASE ABS(temperature) WHEN 30 THEN 1 WHEN 25 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": -30.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "CASE结果中的函数", + exprStr: "CASE WHEN temperature > 30 THEN ABS(temperature) ELSE ROUND(temperature) END", + data: map[string]interface{}{"temperature": 35.5}, + expected: 35.5, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expression, err := NewExpression(tt.exprStr) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err, "Expression creation should not fail") + assert.NotNil(t, expression, "Expression should not be nil") + + // 测试表达式计算 + result, err := expression.Evaluate(tt.data) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err, "Expression evaluation should not fail") + assert.Equal(t, tt.expected, result, "Expression result should match expected value") + }) + } +} + +// TestCaseExpressionFieldExtraction 测试CASE表达式的字段提取功能 +func TestCaseExpressionFieldExtraction(t *testing.T) { + testCases := []struct { + name string + exprStr string + expectedFields []string + }{ + { + name: "简单CASE字段提取", + exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", + expectedFields: []string{"temperature"}, + }, + { + name: "多字段CASE字段提取", + exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 ELSE 0 END", + 
expectedFields: []string{"temperature", "humidity"}, + }, + { + name: "简单CASE字段提取", + exprStr: "CASE status WHEN 'active' THEN temperature ELSE humidity END", + expectedFields: []string{"status", "temperature", "humidity"}, + }, + { + name: "函数CASE字段提取", + exprStr: "CASE WHEN ABS(temperature) > 30 THEN device_id ELSE location END", + expectedFields: []string{"temperature", "device_id", "location"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expression, err := NewExpression(tc.exprStr) + assert.NoError(t, err, "表达式创建应该成功") + + fields := expression.GetFields() + + // 验证所有期望的字段都被提取到了 + for _, expectedField := range tc.expectedFields { + assert.Contains(t, fields, expectedField, "应该包含字段: %s", expectedField) + } + }) + } +} + +// TestCaseExpressionWithNullComparisons 测试CASE表达式中的NULL比较 +func TestCaseExpressionWithNullComparisons(t *testing.T) { + tests := []struct { + name string + exprStr string + data map[string]interface{} + expected interface{} // 使用interface{}以支持NULL值 + isNull bool + }{ + { + name: "NULL值在CASE条件中 - 应该走ELSE分支", + exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": nil}, + expected: 0.0, + isNull: false, + }, + { + name: "IS NULL条件 - 应该匹配", + exprStr: "CASE WHEN temperature IS NULL THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": nil}, + expected: 1.0, + isNull: false, + }, + { + name: "IS NOT NULL条件 - 不应该匹配", + exprStr: "CASE WHEN temperature IS NOT NULL THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": nil}, + expected: 0.0, + isNull: false, + }, + { + name: "CASE表达式返回NULL", + exprStr: "CASE WHEN temperature > 30 THEN temperature ELSE NULL END", + data: map[string]interface{}{"temperature": 25.0}, + expected: nil, + isNull: true, + }, + { + name: "CASE表达式返回有效值", + exprStr: "CASE WHEN temperature > 30 THEN temperature ELSE NULL END", + data: map[string]interface{}{"temperature": 35.0}, + expected: 35.0, + isNull: false, + }, + 
} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expression, err := NewExpression(tt.exprStr) + assert.NoError(t, err, "表达式解析应该成功") + + // 测试支持NULL的计算方法 + result, isNull, err := expression.EvaluateWithNull(tt.data) + assert.NoError(t, err, "表达式计算应该成功") + + if tt.isNull { + assert.True(t, isNull, "表达式应该返回NULL") + } else { + assert.False(t, isNull, "表达式不应该返回NULL") + assert.Equal(t, tt.expected, result, "表达式结果应该匹配期望值") + } + }) + } +} + +// TestNegativeNumberSupport 专门测试负数支持 +func TestNegativeNumberSupport(t *testing.T) { + tests := []struct { + name string + exprStr string + data map[string]interface{} + expected float64 + wantErr bool + }{ + { + name: "负数常量在THEN中", + exprStr: "CASE WHEN temperature > 0 THEN 1 ELSE -1 END", + data: map[string]interface{}{"temperature": -5.0}, + expected: -1.0, + wantErr: false, + }, + { + name: "负数常量在WHEN中", + exprStr: "CASE WHEN temperature < -10 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": -15.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负数小数", + exprStr: "CASE WHEN temperature > 0 THEN 1.5 ELSE -2.5 END", + data: map[string]interface{}{"temperature": -1.0}, + expected: -2.5, + wantErr: false, + }, + { + name: "负数在算术表达式中", + exprStr: "CASE WHEN temperature + (-10) > 0 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 15.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负数与函数", + exprStr: "CASE WHEN ABS(temperature) > 10 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": -15.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负数在简单CASE中", + exprStr: "CASE temperature WHEN -10 THEN 1 WHEN -20 THEN 2 ELSE 0 END", + data: map[string]interface{}{"temperature": -10.0}, + expected: 1.0, + wantErr: false, + }, + { + name: "负零", + exprStr: "CASE WHEN temperature = -0 THEN 1 ELSE 0 END", + data: map[string]interface{}{"temperature": 0.0}, + expected: 1.0, + wantErr: false, + }, + // 基本负数运算 + { + name: "直接负数", + exprStr: "-5", + data: 
map[string]interface{}{}, + expected: -5.0, + wantErr: false, + }, + { + name: "负数加法", + exprStr: "-5 + 3", + data: map[string]interface{}{}, + expected: -2.0, + wantErr: false, + }, + { + name: "负数乘法", + exprStr: "-3 * 4", + data: map[string]interface{}{}, + expected: -12.0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expression, err := NewExpression(tt.exprStr) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err, "负数表达式解析应该成功") + assert.NotNil(t, expression, "表达式不应为空") + + // 测试表达式计算 + result, err := expression.Evaluate(tt.data) + assert.NoError(t, err, "负数表达式计算应该成功") + assert.Equal(t, tt.expected, result, "负数表达式结果应该匹配期望值") + }) + } +} + func TestGetFields(t *testing.T) { tests := []struct { expr string diff --git a/functions/expr_bridge.go b/functions/expr_bridge.go index d207507..9b74c90 100644 --- a/functions/expr_bridge.go +++ b/functions/expr_bridge.go @@ -151,7 +151,15 @@ func (bridge *ExprBridge) CompileExpressionWithStreamSQLFunctions(expression str // EvaluateExpression 评估表达式,自动选择最合适的引擎 func (bridge *ExprBridge) EvaluateExpression(expression string, data map[string]interface{}) (interface{}, error) { - // 首先检查是否包含LIKE操作符,如果有则进行预处理 + // 首先预处理反引号标识符 + if bridge.ContainsBacktickIdentifiers(expression) { + processedExpr, err := bridge.PreprocessBacktickIdentifiers(expression) + if err == nil { + expression = processedExpr + } + } + + // 检查是否包含LIKE操作符,如果有则进行预处理 if bridge.ContainsLikeOperator(expression) { processedExpr, err := bridge.PreprocessLikeExpression(expression) if err == nil { @@ -407,8 +415,9 @@ func (bridge *ExprBridge) isFunctionCall(expression string) bool { // PreprocessLikeExpression 预处理LIKE表达式,转换为expr-lang可理解的函数调用 func (bridge *ExprBridge) PreprocessLikeExpression(expression string) (string, error) { // 使用正则表达式匹配LIKE模式 - // 匹配: field LIKE 'pattern' (允许空模式) - likePattern := `(\w+(?:\.\w+)*)\s+LIKE\s+'([^']*)'` + // 匹配: field LIKE 'pattern' 或 `field` LIKE 'pattern' 
(允许空模式) + // 支持反引号标识符和普通标识符 + likePattern := `((?:` + "`" + `[^` + "`" + `]+` + "`" + `|\w+)(?:\.(?:` + "`" + `[^` + "`" + `]+` + "`" + `|\w+))*)\s+LIKE\s+'([^']*)'` re, err := regexp.Compile(likePattern) if err != nil { return expression, err @@ -424,6 +433,11 @@ func (bridge *ExprBridge) PreprocessLikeExpression(expression string) (string, e field := submatches[1] pattern := submatches[2] + // 处理反引号标识符,去除反引号 + if len(field) >= 2 && field[0] == '`' && field[len(field)-1] == '`' { + field = field[1 : len(field)-1] // 去掉反引号 + } + // 将LIKE模式转换为相应的函数调用 return bridge.convertLikeToFunction(field, pattern) }) @@ -476,6 +490,26 @@ func (bridge *ExprBridge) PreprocessIsNullExpression(expression string) (string, return result, nil } +// ContainsBacktickIdentifiers 检查表达式是否包含反引号标识符 +func (bridge *ExprBridge) ContainsBacktickIdentifiers(expression string) bool { + return strings.Contains(expression, "`") +} + +// PreprocessBacktickIdentifiers 预处理反引号标识符,去除反引号 +func (bridge *ExprBridge) PreprocessBacktickIdentifiers(expression string) (string, error) { + // 使用正则表达式匹配反引号标识符 + // 匹配: `identifier` 或 `nested.field` + backtickPattern := "`([^`]+)`" + re, err := regexp.Compile(backtickPattern) + if err != nil { + return expression, err + } + + // 替换所有反引号标识符,去除反引号 + result := re.ReplaceAllString(expression, "$1") + return result, nil +} + // convertLikeToFunction 将LIKE模式转换为expr-lang操作符 func (bridge *ExprBridge) convertLikeToFunction(field, pattern string) string { // 处理空模式 diff --git a/functions/extension_test.go b/functions/extension_test.go index 864241d..3ec4cec 100644 --- a/functions/extension_test.go +++ b/functions/extension_test.go @@ -27,8 +27,8 @@ func TestAggregatorFunctionInterface(t *testing.T) { // 测试重置 aggInstance.Reset() result = aggInstance.Result() - if result != 0.0 { - t.Errorf("Expected 0.0 after reset, got %v", result) + if result != nil { + t.Errorf("Expected nil after reset (SQL standard: SUM with no rows returns NULL), got %v", result) } // 测试克隆 diff --git 
a/functions/functions_aggregation.go b/functions/functions_aggregation.go index ad1d57e..6c035f4 100644 --- a/functions/functions_aggregation.go +++ b/functions/functions_aggregation.go @@ -12,12 +12,14 @@ import ( // SumFunction 求和函数 type SumFunction struct { *BaseFunction - value float64 + value float64 + hasValues bool // 标记是否有非NULL值 } func NewSumFunction() *SumFunction { return &SumFunction{ BaseFunction: NewBaseFunction("sum", TypeAggregation, "聚合函数", "计算数值总和", 1, -1), + hasValues: false, } } @@ -27,12 +29,20 @@ func (f *SumFunction) Validate(args []interface{}) error { func (f *SumFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { sum := 0.0 + hasValues := false for _, arg := range args { + if arg == nil { + continue // 忽略NULL值 + } val, err := cast.ToFloat64E(arg) if err != nil { - return nil, err + continue // 忽略无法转换的值 } sum += val + hasValues = true + } + if !hasValues { + return nil, nil // 当没有有效值时返回NULL } return sum, nil } @@ -42,27 +52,40 @@ func (f *SumFunction) New() AggregatorFunction { return &SumFunction{ BaseFunction: f.BaseFunction, value: 0, + hasValues: false, } } func (f *SumFunction) Add(value interface{}) { + // 增强的Add方法:忽略NULL值 + if value == nil { + return // 忽略NULL值 + } + if val, err := cast.ToFloat64E(value); err == nil { f.value += val + f.hasValues = true } + // 如果转换失败,也忽略该值 } func (f *SumFunction) Result() interface{} { + if !f.hasValues { + return nil // 当没有有效值时返回NULL而不是0.0 + } return f.value } func (f *SumFunction) Reset() { f.value = 0 + f.hasValues = false } func (f *SumFunction) Clone() AggregatorFunction { return &SumFunction{ BaseFunction: f.BaseFunction, value: f.value, + hasValues: f.hasValues, } } @@ -85,14 +108,22 @@ func (f *AvgFunction) Validate(args []interface{}) error { func (f *AvgFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { sum := 0.0 + count := 0 for _, arg := range args { + if arg == nil { + continue // 忽略NULL值 + } val, err := cast.ToFloat64E(arg) 
if err != nil { - return nil, err + continue // 忽略无法转换的值 } sum += val + count++ } - return sum / float64(len(args)), nil + if count == 0 { + return nil, nil // 当没有有效值时返回nil + } + return sum / float64(count), nil } // 实现AggregatorFunction接口 @@ -105,10 +136,16 @@ func (f *AvgFunction) New() AggregatorFunction { } func (f *AvgFunction) Add(value interface{}) { + // 增强的Add方法:忽略NULL值 + if value == nil { + return // 忽略NULL值 + } + if val, err := cast.ToFloat64E(value); err == nil { f.sum += val f.count++ } + // 如果转换失败,也忽略该值 } func (f *AvgFunction) Result() interface{} { @@ -172,6 +209,11 @@ func (f *MinFunction) New() AggregatorFunction { } func (f *MinFunction) Add(value interface{}) { + // 增强的Add方法:忽略NULL值 + if value == nil { + return // 忽略NULL值 + } + if val, err := cast.ToFloat64E(value); err == nil { if f.first || val < f.value { f.value = val @@ -241,6 +283,11 @@ func (f *MaxFunction) New() AggregatorFunction { } func (f *MaxFunction) Add(value interface{}) { + // 增强的Add方法:忽略NULL值 + if value == nil { + return // 忽略NULL值 + } + if val, err := cast.ToFloat64E(value); err == nil { if f.first || val > f.value { f.value = val @@ -286,7 +333,13 @@ func (f *CountFunction) Validate(args []interface{}) error { } func (f *CountFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { - return int64(len(args)), nil + count := 0 + for _, arg := range args { + if arg != nil { + count++ + } + } + return int64(count), nil } // 实现AggregatorFunction接口 @@ -298,7 +351,10 @@ func (f *CountFunction) New() AggregatorFunction { } func (f *CountFunction) Add(value interface{}) { - f.count++ + // 增强的Add方法:忽略NULL值 + if value != nil { + f.count++ + } } func (f *CountFunction) Result() interface{} { diff --git a/rsql/ast.go b/rsql/ast.go index b71bd27..9a16984 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -17,6 +17,7 @@ import ( type SelectStatement struct { Fields []Field Distinct bool + SelectAll bool // 新增:标识是否是SELECT *查询 Source string Condition string Window 
WindowDefinition @@ -92,13 +93,26 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { // 如果没有聚合函数,收集简单字段 if !hasAggregation { - for _, field := range s.Fields { - fieldName := field.Expression - if field.Alias != "" { - // 如果有别名,用别名作为字段名 - simpleFields = append(simpleFields, fieldName+":"+field.Alias) - } else { - simpleFields = append(simpleFields, fieldName) + // 如果是SELECT *查询,设置特殊标记 + if s.SelectAll { + simpleFields = append(simpleFields, "*") + } else { + for _, field := range s.Fields { + fieldName := field.Expression + if field.Alias != "" { + // 如果有别名,用别名作为字段名 + simpleFields = append(simpleFields, fieldName+":"+field.Alias) + } else { + // 对于没有别名的字段,检查是否为字符串字面量 + _, n, _, _ := ParseAggregateTypeWithExpression(fieldName) + if n != "" { + // 如果是字符串字面量,使用解析出的字段名(去掉引号) + simpleFields = append(simpleFields, n) + } else { + // 否则使用原始表达式 + simpleFields = append(simpleFields, fieldName) + } + } } } logger.Debug("收集简单字段: %v", simpleFields) @@ -107,6 +121,9 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { // 构建字段映射和表达式信息 aggs, fields, expressions := buildSelectFieldsWithExpressions(s.Fields) + // 提取字段顺序信息 + fieldOrder := extractFieldOrder(s.Fields) + // 构建Stream配置 config := types.Config{ WindowConfig: types.WindowConfig{ @@ -125,6 +142,7 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { SimpleFields: simpleFields, Having: s.Having, FieldExpressions: expressions, + FieldOrder: fieldOrder, } return &config, s.Condition, nil @@ -169,10 +187,33 @@ func isAggregationFunction(expr string) bool { if strings.Contains(expr, "(") && strings.Contains(expr, ")") { return true } - return false } +// extractFieldOrder 从Fields切片中提取字段的原始顺序 +// 返回按SELECT语句中出现顺序排列的字段名列表 +func extractFieldOrder(fields []Field) []string { + var fieldOrder []string + + for _, field := range fields { + // 如果有别名,使用别名作为字段名 + if field.Alias != "" { + fieldOrder = append(fieldOrder, field.Alias) + } else { + // 
没有别名时,尝试解析表达式获取字段名 + _, fieldName, _, _ := ParseAggregateTypeWithExpression(field.Expression) + if fieldName != "" { + // 如果解析出字段名(如字符串字面量),使用解析出的名称 + fieldOrder = append(fieldOrder, fieldName) + } else { + // 否则使用原始表达式作为字段名 + fieldOrder = append(fieldOrder, field.Expression) + } + } + } + + return fieldOrder +} func extractGroupFields(s *SelectStatement) []string { var fields []string for _, f := range s.GroupBy { @@ -261,6 +302,15 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg // 提取函数名 funcName := extractFunctionName(exprStr) if funcName == "" { + // 检查是否是字符串字面量 + trimmed := strings.TrimSpace(exprStr) + if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) || + (strings.HasPrefix(trimmed, "\"") && strings.HasSuffix(trimmed, "\"")) { + // 字符串字面量:使用去掉引号的内容作为字段名 + fieldName := trimmed[1 : len(trimmed)-1] + return "expression", fieldName, exprStr, nil + } + // 如果不是函数调用,但包含运算符或关键字,可能是表达式 if strings.ContainsAny(exprStr, "+-*/<>=!&|") || strings.Contains(strings.ToUpper(exprStr), "AND") || @@ -638,6 +688,7 @@ func buildSelectFieldsWithExpressions(fields []Field) ( // 没有别名的情况,使用表达式本身作为字段名 t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) if t != "" && n != "" { + // 对于字符串字面量,使用解析出的字段名(去掉引号)作为键 selectFields[n] = t fieldMap[n] = n diff --git a/rsql/lexer.go b/rsql/lexer.go index 7d93711..cf7f9c8 100644 --- a/rsql/lexer.go +++ b/rsql/lexer.go @@ -12,6 +12,7 @@ const ( TokenIdent TokenNumber TokenString + TokenQuotedIdent // 反引号标识符 TokenComma TokenLParen TokenRParen @@ -176,6 +177,8 @@ func (l *Lexer) NextToken() Token { return l.readStringToken(tokenPos, tokenLine, tokenColumn) case '"': return l.readStringToken(tokenPos, tokenLine, tokenColumn) + case '`': + return l.readQuotedIdentToken(tokenPos, tokenLine, tokenColumn) } if isLetter(l.ch) { @@ -439,6 +442,42 @@ func (l *Lexer) readStringToken(pos, line, column int) Token { return Token{Type: TokenString, Value: value, Pos: pos, Line: line, 
Column: column} } +// readQuotedIdentToken 读取反引号标识符token并处理错误 +func (l *Lexer) readQuotedIdentToken(pos, line, column int) Token { + startPos := l.pos + l.readChar() // 跳过开头反引号 + + for l.ch != '`' && l.ch != 0 { + l.readChar() + } + + if l.ch == 0 { + // 未闭合的反引号标识符 + if l.errorRecovery != nil { + err := &ParseError{ + Type: ErrorTypeUnterminatedString, + Message: "Unterminated quoted identifier", + Position: startPos, + Line: line, + Column: column, + Token: "`", + Suggestions: []string{"Add closing backtick '`'"}, + Recoverable: true, + } + l.errorRecovery.AddError(err) + } + value := l.input[startPos:l.pos] + return Token{Type: TokenQuotedIdent, Value: value, Pos: pos, Line: line, Column: column} + } + + if l.ch == '`' { + l.readChar() // 跳过结尾反引号 + } + + value := l.input[startPos:l.pos] + return Token{Type: TokenQuotedIdent, Value: value, Pos: pos, Line: line, Column: column} +} + // isValidNumber 验证数字格式 func (l *Lexer) isValidNumber(number string) bool { if number == "" { diff --git a/rsql/lexer_test.go b/rsql/lexer_test.go new file mode 100644 index 0000000..2c79c22 --- /dev/null +++ b/rsql/lexer_test.go @@ -0,0 +1,126 @@ +package rsql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestQuotedIdentifiers 测试反引号标识符的词法分析 +func TestQuotedIdentifiers(t *testing.T) { + t.Run("基本反引号标识符", func(t *testing.T) { + lexer := NewLexer("`deviceId`") + token := lexer.NextToken() + assert.Equal(t, TokenQuotedIdent, token.Type) + assert.Equal(t, "`deviceId`", token.Value) + }) + + t.Run("包含特殊字符的反引号标识符", func(t *testing.T) { + lexer := NewLexer("`device-id`") + token := lexer.NextToken() + assert.Equal(t, TokenQuotedIdent, token.Type) + assert.Equal(t, "`device-id`", token.Value) + }) + + t.Run("包含空格的反引号标识符", func(t *testing.T) { + lexer := NewLexer("`device id`") + token := lexer.NextToken() + assert.Equal(t, TokenQuotedIdent, token.Type) + assert.Equal(t, "`device id`", token.Value) + }) + + t.Run("未闭合的反引号标识符", func(t *testing.T) { + lexer := 
NewLexer("`deviceId") + errorRecovery := NewErrorRecovery(nil) + lexer.SetErrorRecovery(errorRecovery) + token := lexer.NextToken() + assert.Equal(t, TokenQuotedIdent, token.Type) + assert.True(t, errorRecovery.HasErrors()) + errors := errorRecovery.GetErrors() + assert.Equal(t, 1, len(errors)) + assert.Equal(t, ErrorTypeUnterminatedString, errors[0].Type) + }) +} + +// TestStringLiterals 测试字符串常量的词法分析 +func TestStringLiterals(t *testing.T) { + t.Run("单引号字符串", func(t *testing.T) { + lexer := NewLexer("'hello world'") + token := lexer.NextToken() + assert.Equal(t, TokenString, token.Type) + assert.Equal(t, "'hello world'", token.Value) + }) + + t.Run("双引号字符串", func(t *testing.T) { + lexer := NewLexer(`"hello world"`) + token := lexer.NextToken() + assert.Equal(t, TokenString, token.Type) + assert.Equal(t, `"hello world"`, token.Value) + }) + + t.Run("包含特殊字符的字符串", func(t *testing.T) { + lexer := NewLexer("'test-value_123'") + token := lexer.NextToken() + assert.Equal(t, TokenString, token.Type) + assert.Equal(t, "'test-value_123'", token.Value) + }) + + t.Run("空字符串", func(t *testing.T) { + lexer := NewLexer("''") + token := lexer.NextToken() + assert.Equal(t, TokenString, token.Type) + assert.Equal(t, "''", token.Value) + }) +} + +// TestComplexSQL 测试复杂SQL语句的词法分析 +func TestComplexSQL(t *testing.T) { + t.Run("包含反引号标识符和字符串常量的SQL", func(t *testing.T) { + sql := "SELECT `deviceId`, deviceType, 'aa' as test FROM stream WHERE `deviceId` LIKE 'sensor%'" + lexer := NewLexer(sql) + + // 验证token序列 + expectedTokens := []struct { + Type TokenType + Value string + }{ + {TokenSELECT, "SELECT"}, + {TokenQuotedIdent, "`deviceId`"}, + {TokenComma, ","}, + {TokenIdent, "deviceType"}, + {TokenComma, ","}, + {TokenString, "'aa'"}, + {TokenAS, "as"}, + {TokenIdent, "test"}, + {TokenFROM, "FROM"}, + {TokenIdent, "stream"}, + {TokenWHERE, "WHERE"}, + {TokenQuotedIdent, "`deviceId`"}, + {TokenLIKE, "LIKE"}, + {TokenString, "'sensor%'"}, + {TokenEOF, ""}, + } + + for i, expected := range 
expectedTokens { + token := lexer.NextToken() + assert.Equal(t, expected.Type, token.Type, "Token %d type mismatch", i) + if expected.Value != "" { + assert.Equal(t, expected.Value, token.Value, "Token %d value mismatch", i) + } + } + }) + + t.Run("双引号字符串常量", func(t *testing.T) { + sql := `SELECT deviceId, "test value" as name FROM stream` + lexer := NewLexer(sql) + + // 跳过前面的token直到字符串 + lexer.NextToken() // SELECT + lexer.NextToken() // deviceId + lexer.NextToken() // , + token := lexer.NextToken() // "test value" + + assert.Equal(t, TokenString, token.Type) + assert.Equal(t, `"test value"`, token.Value) + }) +} diff --git a/rsql/parser.go b/rsql/parser.go index 34810df..47b7e53 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -81,6 +81,8 @@ func (p *Parser) getTokenTypeName(tokenType TokenType) string { return ")" case TokenIdent: return "identifier" + case TokenQuotedIdent: + return "quoted identifier" case TokenNumber: return "number" case TokenString: @@ -212,6 +214,23 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error { currentToken = p.lexer.NextToken() // 消费 DISTINCT,移动到下一个 token } + // 检查是否是SELECT *查询 + if currentToken.Type == TokenIdent && currentToken.Value == "*" { + stmt.SelectAll = true + // 添加一个特殊的字段标记SELECT * + stmt.Fields = append(stmt.Fields, Field{Expression: "*"}) + + // 消费*token并检查下一个token + currentToken = p.lexer.NextToken() + + // 如果下一个token是FROM或EOF,则完成SELECT *解析 + if currentToken.Type == TokenFROM || currentToken.Type == TokenEOF { + return nil + } + + // 如果不是FROM/EOF,继续正常的字段解析流程 + } + // 设置最大字段数量限制,防止无限循环 maxFields := 100 fieldCount := 0 @@ -289,12 +308,12 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error { shouldAddSpace = false } } - } else if len(exprStr) > 0 && currentToken.Type == TokenIdent { - // 检查前一个字符是否是数字,且前面没有空格 - if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") { - shouldAddSpace = false - } + } else if len(exprStr) > 0 && (currentToken.Type == TokenIdent || 
currentToken.Type == TokenQuotedIdent) { + // 检查前一个字符是否是数字,且前面没有空格 + if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") { + shouldAddSpace = false } + } if shouldAddSpace { expr.WriteString(" ") @@ -368,7 +387,7 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error { break } switch tok.Type { - case TokenIdent, TokenNumber: + case TokenIdent, TokenNumber, TokenQuotedIdent: conditions = append(conditions, tok.Value) case TokenString: conditions = append(conditions, tok.Value) diff --git a/stream/stream.go b/stream/stream.go index 5b8ce2e..192fa85 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -21,6 +21,90 @@ import ( "github.com/rulego/streamsql/window" ) +// 窗口相关常量 +const ( + WindowStartField = "window_start" + WindowEndField = "window_end" +) + +// 溢出策略常量 +const ( + StrategyDrop = "drop" + StrategyBlock = "block" + StrategyExpand = "expand" + StrategyPersist = "persist" +) + +// 统计信息字段常量 +const ( + StatsInputCount = "input_count" + StatsOutputCount = "output_count" + StatsDroppedCount = "dropped_count" + StatsDataChanLen = "data_chan_len" + StatsDataChanCap = "data_chan_cap" + StatsResultChanLen = "result_chan_len" + StatsResultChanCap = "result_chan_cap" + StatsSinkPoolLen = "sink_pool_len" + StatsSinkPoolCap = "sink_pool_cap" + StatsActiveRetries = "active_retries" + StatsExpanding = "expanding" +) + +// 详细统计信息字段常量 +const ( + StatsBasicStats = "basic_stats" + StatsDataChanUsage = "data_chan_usage" + StatsResultChanUsage = "result_chan_usage" + StatsSinkPoolUsage = "sink_pool_usage" + StatsProcessRate = "process_rate" + StatsDropRate = "drop_rate" + StatsPerformanceLevel = "performance_level" +) + +// 性能级别常量 +const ( + PerformanceLevelCritical = "CRITICAL" + PerformanceLevelWarning = "WARNING" + PerformanceLevelHighLoad = "HIGH_LOAD" + PerformanceLevelModerateLoad = "MODERATE_LOAD" + PerformanceLevelOptimal = "OPTIMAL" +) + +// 持久化相关常量 +const ( + PersistenceEnabled = "enabled" + PersistenceMessage = "message" + 
PersistenceNotEnabledMsg = "persistence not enabled" + PerformanceConfigKey = "performanceConfig" +) + +// SQL关键字常量 +const ( + SQLKeywordCase = "CASE" +) + +// fieldProcessInfo 字段处理信息,用于缓存预编译的字段处理逻辑 +type fieldProcessInfo struct { + fieldName string // 原始字段名 + outputName string // 输出字段名 + isFunctionCall bool // 是否为函数调用 + hasNestedField bool // 是否包含嵌套字段 + isSelectAll bool // 是否为SELECT * + isStringLiteral bool // 是否为字符串字面量 + stringValue string // 预处理的字符串字面量值(去除引号) + alias string // 字段别名,用于快速访问 +} + +// expressionProcessInfo 表达式处理信息,用于缓存预编译的表达式处理逻辑 +type expressionProcessInfo struct { + originalExpr string // 原始表达式 + processedExpr string // 预处理后的表达式 + isFunctionCall bool // 是否为函数调用 + hasNestedFields bool // 是否包含嵌套字段 + compiledExpr *expr.Expression // 预编译的表达式对象 + needsBacktickPreprocess bool // 是否需要反引号预处理 +} + type Stream struct { dataChan chan interface{} filter condition.Condition @@ -52,6 +136,14 @@ type Stream struct { blockingTimeout time.Duration // 阻塞超时时间 overflowStrategy string // 溢出策略: "drop", "block", "expand", "persist" persistenceManager *PersistenceManager // 持久化管理器 + + // 预编译的AddData函数指针,避免每次switch判断 + addDataFunc func(interface{}) // 根据策略预设的函数指针 + + // 预编译字段处理信息,避免重复解析 + compiledFieldInfo map[string]*fieldProcessInfo // 字段处理信息缓存 + compiledExprInfo map[string]*expressionProcessInfo // 表达式处理信息缓存 + } // NewStream 使用统一配置创建Stream @@ -101,7 +193,7 @@ func newStreamWithUnifiedConfig(config types.Config) (*Stream, error) { windowConfig.Params = make(map[string]interface{}) } // 传递完整的性能配置给窗口 - windowConfig.Params["performanceConfig"] = config.PerformanceConfig + windowConfig.Params[PerformanceConfigKey] = config.PerformanceConfig win, err = window.CreateWindow(windowConfig) if err != nil { @@ -126,7 +218,7 @@ func newStreamWithUnifiedConfig(config types.Config) (*Stream, error) { } // 如果是持久化策略,初始化持久化管理器 - if perfConfig.OverflowConfig.Strategy == "persist" && perfConfig.OverflowConfig.PersistenceConfig != nil { + if perfConfig.OverflowConfig.Strategy == 
StrategyPersist && perfConfig.OverflowConfig.PersistenceConfig != nil { persistConfig := perfConfig.OverflowConfig.PersistenceConfig stream.persistenceManager = NewPersistenceManagerWithConfig( persistConfig.DataDir, @@ -138,6 +230,21 @@ func newStreamWithUnifiedConfig(config types.Config) (*Stream, error) { } } + // 根据溢出策略预设AddData函数指针,避免运行时switch判断 + switch perfConfig.OverflowConfig.Strategy { + case StrategyBlock: + stream.addDataFunc = stream.addDataBlocking + case StrategyExpand: + stream.addDataFunc = stream.addDataWithExpansion + case StrategyPersist: + stream.addDataFunc = stream.addDataWithPersistence + default: + stream.addDataFunc = stream.addDataWithDrop + } + + // 预编译字段处理信息 + stream.compileFieldProcessInfo() + // 启动工作协程,使用配置的工作线程数 go stream.startSinkWorkerPool(perfConfig.WorkerConfig.SinkWorkerCount) go stream.startResultConsumer() @@ -188,16 +295,25 @@ func (s *Stream) startResultConsumer() { } } +// RegisterFilter 注册过滤条件,支持反引号标识符、LIKE语法和IS NULL语法 func (s *Stream) RegisterFilter(conditionStr string) error { if strings.TrimSpace(conditionStr) == "" { return nil } - // 预处理LIKE语法,转换为expr-lang可理解的形式 processedCondition := conditionStr bridge := functions.GetExprBridge() - if bridge.ContainsLikeOperator(conditionStr) { - if processed, err := bridge.PreprocessLikeExpression(conditionStr); err == nil { + + // 首先预处理反引号标识符,去除反引号 + if bridge.ContainsBacktickIdentifiers(conditionStr) { + if processed, err := bridge.PreprocessBacktickIdentifiers(conditionStr); err == nil { + processedCondition = processed + } + } + + // 预处理LIKE语法,转换为expr-lang可理解的形式 + if bridge.ContainsLikeOperator(processedCondition) { + if processed, err := bridge.PreprocessLikeExpression(processedCondition); err == nil { processedCondition = processed } } @@ -241,6 +357,106 @@ func convertToAggregationFields(selectFields map[string]aggregator.AggregateType return fields } +// compileFieldProcessInfo 预编译字段处理信息,避免运行时重复解析 +func (s *Stream) compileFieldProcessInfo() { + s.compiledFieldInfo = 
make(map[string]*fieldProcessInfo) + s.compiledExprInfo = make(map[string]*expressionProcessInfo) + + // 编译SimpleFields信息 + for _, fieldSpec := range s.config.SimpleFields { + info := &fieldProcessInfo{} + + if fieldSpec == "*" { + info.isSelectAll = true + info.fieldName = "*" + info.outputName = "*" + } else { + // 解析别名 + parts := strings.Split(fieldSpec, ":") + info.fieldName = parts[0] + // 去除字段名中的反引号 + if len(info.fieldName) >= 2 && info.fieldName[0] == '`' && info.fieldName[len(info.fieldName)-1] == '`' { + info.fieldName = info.fieldName[1 : len(info.fieldName)-1] + } + info.outputName = info.fieldName + if len(parts) > 1 { + info.outputName = parts[1] + // 去除输出名中的反引号 + if len(info.outputName) >= 2 && info.outputName[0] == '`' && info.outputName[len(info.outputName)-1] == '`' { + info.outputName = info.outputName[1 : len(info.outputName)-1] + } + } + + // 预判断字段特征 + info.isFunctionCall = strings.Contains(info.fieldName, "(") && strings.Contains(info.fieldName, ")") + info.hasNestedField = !info.isFunctionCall && fieldpath.IsNestedField(info.fieldName) + + // 检查是否为字符串字面量并预处理值 + info.isStringLiteral = (len(info.fieldName) >= 2 && + ((info.fieldName[0] == '\'' && info.fieldName[len(info.fieldName)-1] == '\'') || + (info.fieldName[0] == '"' && info.fieldName[len(info.fieldName)-1] == '"'))) + + // 预处理字符串字面量值,去除引号 + if info.isStringLiteral && len(info.fieldName) >= 2 { + info.stringValue = info.fieldName[1 : len(info.fieldName)-1] + } + + // 设置别名用于快速访问 + info.alias = info.outputName + } + + s.compiledFieldInfo[fieldSpec] = info + } + + // 预编译表达式字段信息 + s.compileExpressionInfo() +} + +// compileExpressionInfo 预编译表达式处理信息 +func (s *Stream) compileExpressionInfo() { + bridge := functions.GetExprBridge() + + for fieldName, fieldExpr := range s.config.FieldExpressions { + exprInfo := &expressionProcessInfo{ + originalExpr: fieldExpr.Expression, + } + + // 预处理表达式 + processedExpr := fieldExpr.Expression + if bridge.ContainsIsNullOperator(processedExpr) { + if processed, 
err := bridge.PreprocessIsNullExpression(processedExpr); err == nil { + processedExpr = processed + } + } + if bridge.ContainsLikeOperator(processedExpr) { + if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil { + processedExpr = processed + } + } + exprInfo.processedExpr = processedExpr + + // 预判断表达式特征 + exprInfo.isFunctionCall = strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")") + exprInfo.hasNestedFields = !exprInfo.isFunctionCall && strings.Contains(fieldExpr.Expression, ".") + exprInfo.needsBacktickPreprocess = bridge.ContainsBacktickIdentifiers(fieldExpr.Expression) + + // 预编译表达式对象(仅对非函数调用的表达式) + if !exprInfo.isFunctionCall { + exprToCompile := fieldExpr.Expression + if exprInfo.needsBacktickPreprocess { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToCompile); err == nil { + exprToCompile = processed + } + } + if compiledExpr, err := expr.NewExpression(exprToCompile); err == nil { + exprInfo.compiledExpr = compiledExpr + } + } + + s.compiledExprInfo[fieldName] = exprInfo + } +} + func (s *Stream) Start() { // 启动处理协程 go s.process() @@ -294,16 +510,56 @@ func (s *Stream) process() { hasNestedFields := strings.Contains(currentFieldExpr.Expression, ".") if hasNestedFields { - // 直接使用自定义表达式引擎处理嵌套字段 - expression, parseErr := expr.NewExpression(currentFieldExpr.Expression) + // 直接使用自定义表达式引擎处理嵌套字段,支持NULL值 + // 预处理反引号标识符 + exprToUse := currentFieldExpr.Expression + bridge := functions.GetExprBridge() + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) if parseErr != nil { return nil, fmt.Errorf("expression parse failed: %w", parseErr) } - numResult, err := expression.Evaluate(dataMap) + // 使用支持NULL的计算方法 + numResult, isNull, err := expression.EvaluateWithNull(dataMap) if err != nil { return nil, 
fmt.Errorf("expression evaluation failed: %w", err) } + if isNull { + return nil, nil // 返回nil表示NULL值 + } + return numResult, nil + } + + // 检查是否为CASE表达式 + trimmedExpr := strings.TrimSpace(currentFieldExpr.Expression) + upperExpr := strings.ToUpper(trimmedExpr) + if strings.HasPrefix(upperExpr, SQLKeywordCase) { + // CASE表达式使用支持NULL的计算方法 + // 预处理反引号标识符 + exprToUse := currentFieldExpr.Expression + bridge := functions.GetExprBridge() + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) + if parseErr != nil { + return nil, fmt.Errorf("CASE expression parse failed: %w", parseErr) + } + + numResult, isNull, err := expression.EvaluateWithNull(dataMap) + if err != nil { + return nil, fmt.Errorf("CASE expression evaluation failed: %w", err) + } + if isNull { + return nil, nil // 返回nil表示NULL值 + } return numResult, nil } @@ -326,16 +582,26 @@ func (s *Stream) process() { result, err := bridge.EvaluateExpression(processedExpr, dataMap) if err != nil { // 如果桥接器失败,回退到原来的表达式引擎(使用原始表达式,不是预处理的) - expression, parseErr := expr.NewExpression(currentFieldExpr.Expression) + // 预处理反引号标识符 + exprToUse := currentFieldExpr.Expression + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) if parseErr != nil { return nil, fmt.Errorf("expression parse failed: %w", parseErr) } - // 计算表达式 - numResult, evalErr := expression.Evaluate(dataMap) + // 计算表达式,支持NULL值 + numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap) if evalErr != nil { return nil, fmt.Errorf("expression evaluation failed: %w", evalErr) } + if isNull { + return nil, nil // 返回nil表示NULL值 + } return numResult, nil } @@ -349,11 +615,21 @@ func (s *Stream) process() { // 处理窗口模式 go 
func() { + defer func() { + if r := recover(); r != nil { + logger.Error("Window processing goroutine panic recovered: %v", r) + } + }() + for batch := range s.Window.OutputChan() { // 处理窗口批数据 for _, item := range batch { - _ = s.aggregator.Put("window_start", item.Slot.WindowStart()) - _ = s.aggregator.Put("window_end", item.Slot.WindowEnd()) + if err := s.aggregator.Put(WindowStartField, item.Slot.WindowStart()); err != nil { + logger.Error("failed to put window start: %v", err) + } + if err := s.aggregator.Put(WindowEndField, item.Slot.WindowEnd()); err != nil { + logger.Error("failed to put window end: %v", err) + } if err := s.aggregator.Add(item.Data); err != nil { logger.Error("aggregate error: %v", err) } @@ -382,36 +658,78 @@ func (s *Stream) process() { // 应用 HAVING 过滤条件 if s.config.Having != "" { - // 预处理HAVING条件中的LIKE语法,转换为expr-lang可理解的形式 - processedHaving := s.config.Having - bridge := functions.GetExprBridge() - if bridge.ContainsLikeOperator(s.config.Having) { - if processed, err := bridge.PreprocessLikeExpression(s.config.Having); err == nil { - processedHaving = processed - } - } + // 检查HAVING条件是否包含CASE表达式 + hasCaseExpression := strings.Contains(strings.ToUpper(s.config.Having), SQLKeywordCase) - // 预处理HAVING条件中的IS NULL语法 - if bridge.ContainsIsNullOperator(processedHaving) { - if processed, err := bridge.PreprocessIsNullExpression(processedHaving); err == nil { - processedHaving = processed - } - } + var filteredResults []map[string]interface{} - // 创建 HAVING 条件 - havingFilter, err := condition.NewExprCondition(processedHaving) - if err != nil { - logger.Error("having filter error: %v", err) - } else { - // 应用 HAVING 过滤 - var filteredResults []map[string]interface{} - for _, result := range finalResults { - if havingFilter.Evaluate(result) { - filteredResults = append(filteredResults, result) + if hasCaseExpression { + // HAVING条件包含CASE表达式,使用我们的表达式解析器 + // 预处理反引号标识符 + exprToUse := s.config.Having + bridge := functions.GetExprBridge() + if 
bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, err := expr.NewExpression(exprToUse) + if err != nil { + logger.Error("having filter error (CASE expression): %v", err) + } else { + // 应用 HAVING 过滤,使用CASE表达式计算器 + for _, result := range finalResults { + // 使用EvaluateWithNull方法以支持NULL值处理 + havingResult, isNull, err := expression.EvaluateWithNull(result) + if err != nil { + logger.Error("having filter evaluation error: %v", err) + continue + } + + // 如果结果是NULL,则不满足条件(SQL标准行为) + if isNull { + continue + } + + // 对于数值结果,大于0视为true(满足HAVING条件) + if havingResult > 0 { + filteredResults = append(filteredResults, result) + } + } + } + } else { + // HAVING条件不包含CASE表达式,使用原有的expr-lang处理 + // 预处理HAVING条件中的LIKE语法,转换为expr-lang可理解的形式 + processedHaving := s.config.Having + bridge := functions.GetExprBridge() + if bridge.ContainsLikeOperator(s.config.Having) { + if processed, err := bridge.PreprocessLikeExpression(s.config.Having); err == nil { + processedHaving = processed + } + } + + // 预处理HAVING条件中的IS NULL语法 + if bridge.ContainsIsNullOperator(processedHaving) { + if processed, err := bridge.PreprocessIsNullExpression(processedHaving); err == nil { + processedHaving = processed + } + } + + // 创建 HAVING 条件 + havingFilter, err := condition.NewExprCondition(processedHaving) + if err != nil { + logger.Error("having filter error: %v", err) + } else { + // 应用 HAVING 过滤 + for _, result := range finalResults { + if havingFilter.Evaluate(result) { + filteredResults = append(filteredResults, result) + } } } - finalResults = filteredResults } + + finalResults = filteredResults } // 应用 LIMIT 限制 @@ -419,7 +737,7 @@ func (s *Stream) process() { finalResults = finalResults[:s.config.Limit] } - // 优化: 发送结果到结果通道和 Sink 函数 + // 发送结果到结果通道和 Sink 函数 if len(finalResults) > 0 { // 非阻塞发送到结果通道 s.sendResultNonBlocking(finalResults) @@ -469,12 +787,111 @@ func (s *Stream) process() 
{ } } -// processDirectData 直接处理非窗口数据 (优化版本) +// processExpressionFieldFallback 表达式字段处理的回退逻辑 +func (s *Stream) processExpressionFieldFallback(fieldName string, dataMap map[string]interface{}, result map[string]interface{}) { + fieldExpr, exists := s.config.FieldExpressions[fieldName] + if !exists { + result[fieldName] = nil + return + } + + // 使用桥接器计算表达式,支持IS NULL等语法 + bridge := functions.GetExprBridge() + + // 预处理表达式中的IS NULL和LIKE语法 + processedExpr := fieldExpr.Expression + if bridge.ContainsIsNullOperator(processedExpr) { + if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil { + processedExpr = processed + } + } + if bridge.ContainsLikeOperator(processedExpr) { + if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil { + processedExpr = processed + } + } + + // 检查表达式是否是函数调用(包含括号) + isFunctionCall := strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")") + + // 检查表达式是否包含嵌套字段(但排除函数调用中的点号) + hasNestedFields := false + if !isFunctionCall && strings.Contains(fieldExpr.Expression, ".") { + hasNestedFields = true + } + + var evalResult interface{} + + if isFunctionCall { + // 对于函数调用,优先使用桥接器处理 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + logger.Error("Function call evaluation failed for field %s: %v", fieldName, err) + result[fieldName] = nil + return + } + evalResult = exprResult + } else if hasNestedFields { + // 检测到嵌套字段(非函数调用),使用自定义表达式引擎 + exprToUse := fieldExpr.Expression + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + return + } + + numResult, err := expression.Evaluate(dataMap) + if err != nil { + logger.Error("Expression 
evaluation failed for field %s: %v", fieldName, err) + result[fieldName] = nil + return + } + evalResult = numResult + } else { + // 尝试使用桥接器处理其他表达式 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + // 如果桥接器失败,回退到原来的表达式引擎 + exprToUse := fieldExpr.Expression + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + return + } + + numResult, evalErr := expression.Evaluate(dataMap) + if evalErr != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr) + result[fieldName] = nil + return + } + evalResult = numResult + } else { + evalResult = exprResult + } + } + + result[fieldName] = evalResult +} + +// processDirectData 直接处理非窗口数据 func (s *Stream) processDirectData(data interface{}) { // 增加输入计数 atomic.AddInt64(&s.inputCount, 1) - // 简化:直接将数据作为map处理 + // 直接将数据作为map处理 dataMap, ok := data.(map[string]interface{}) if !ok { logger.Error("Unsupported data type: %T", data) @@ -482,86 +899,67 @@ func (s *Stream) processDirectData(data interface{}) { return } - // 创建结果map - result := make(map[string]interface{}) + // 创建结果map,预分配合适容量 + estimatedSize := len(s.config.FieldExpressions) + len(s.config.SimpleFields) + if estimatedSize < 8 { + estimatedSize = 8 // 最小容量 + } + result := make(map[string]interface{}, estimatedSize) - // 处理表达式字段 - for fieldName, fieldExpr := range s.config.FieldExpressions { - - // 使用桥接器计算表达式,支持IS NULL等语法 - bridge := functions.GetExprBridge() - - // 预处理表达式中的IS NULL和LIKE语法 - processedExpr := fieldExpr.Expression - if bridge.ContainsIsNullOperator(processedExpr) { - if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil { - processedExpr = processed - 
logger.Debug("Preprocessed IS NULL expression: %s -> %s", fieldExpr.Expression, processedExpr) - } - } - if bridge.ContainsLikeOperator(processedExpr) { - if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil { - processedExpr = processed - logger.Debug("Preprocessed LIKE expression: %s -> %s", fieldExpr.Expression, processedExpr) - } - } - - // 检查表达式是否是函数调用(包含括号) - isFunctionCall := strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")") - - // 检查表达式是否包含嵌套字段(但排除函数调用中的点号) - hasNestedFields := false - if !isFunctionCall && strings.Contains(fieldExpr.Expression, ".") { - hasNestedFields = true + // 处理表达式字段(使用预编译信息) + for fieldName := range s.config.FieldExpressions { + exprInfo := s.compiledExprInfo[fieldName] + if exprInfo == nil { + // 回退到原逻辑(安全性保证) + s.processExpressionFieldFallback(fieldName, dataMap, result) + continue } var evalResult interface{} + bridge := functions.GetExprBridge() - if isFunctionCall { - // 对于函数调用,优先使用桥接器处理,这样可以保持原始类型 - exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if exprInfo.isFunctionCall { + // 对于函数调用,使用桥接器处理 + exprResult, err := bridge.EvaluateExpression(exprInfo.processedExpr, dataMap) if err != nil { logger.Error("Function call evaluation failed for field %s: %v", fieldName, err) result[fieldName] = nil continue } evalResult = exprResult - } else if hasNestedFields { - // 检测到嵌套字段(非函数调用),使用自定义表达式引擎 - expression, parseErr := expr.NewExpression(fieldExpr.Expression) - if parseErr != nil { - logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) - result[fieldName] = nil - continue - } - - numResult, err := expression.Evaluate(dataMap) - if err != nil { - logger.Error("Expression evaluation failed for field %s: %v", fieldName, err) - result[fieldName] = nil - continue - } - evalResult = numResult - } else { - // 尝试使用桥接器处理其他表达式 - exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) - if err != nil { - // 
如果桥接器失败,回退到原来的表达式引擎(使用原始表达式,不是预处理的) - expression, parseErr := expr.NewExpression(fieldExpr.Expression) - if parseErr != nil { - logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) - result[fieldName] = nil - continue - } - - // 计算表达式 - numResult, evalErr := expression.Evaluate(dataMap) - if evalErr != nil { - logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr) + } else if exprInfo.hasNestedFields { + // 使用预编译的表达式对象 + if exprInfo.compiledExpr != nil { + numResult, err := exprInfo.compiledExpr.Evaluate(dataMap) + if err != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, err) result[fieldName] = nil continue } evalResult = numResult + } else { + // 回退到动态编译 + s.processExpressionFieldFallback(fieldName, dataMap, result) + continue + } + } else { + // 尝试使用桥接器处理其他表达式 + exprResult, err := bridge.EvaluateExpression(exprInfo.processedExpr, dataMap) + if err != nil { + // 如果桥接器失败,使用预编译的表达式对象 + if exprInfo.compiledExpr != nil { + numResult, evalErr := exprInfo.compiledExpr.Evaluate(dataMap) + if evalErr != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr) + result[fieldName] = nil + continue + } + evalResult = numResult + } else { + // 回退到动态编译 + s.processExpressionFieldFallback(fieldName, dataMap, result) + continue + } } else { evalResult = exprResult } @@ -570,46 +968,57 @@ func (s *Stream) processDirectData(data interface{}) { result[fieldName] = evalResult } - // 如果指定了字段,只保留这些字段 + // 使用预编译的字段信息处理SimpleFields if len(s.config.SimpleFields) > 0 { for _, fieldSpec := range s.config.SimpleFields { - // 处理别名 - parts := strings.Split(fieldSpec, ":") - fieldName := parts[0] - outputName := fieldName - if len(parts) > 1 { - outputName = parts[1] - } - - // 跳过已经通过表达式字段处理的字段 - if _, isExpression := s.config.FieldExpressions[outputName]; isExpression { + info := s.compiledFieldInfo[fieldSpec] + if info == nil { + // 如果没有预编译信息,回退到原逻辑(安全性保证) + 
s.processSingleFieldFallback(fieldSpec, dataMap, data, result) continue } - // 检查是否是函数调用 - if strings.Contains(fieldName, "(") && strings.Contains(fieldName, ")") { + if info.isSelectAll { + // SELECT *:批量复制所有字段,跳过表达式字段 + for k, v := range dataMap { + if _, isExpression := s.config.FieldExpressions[k]; !isExpression { + result[k] = v + } + } + continue + } + + // 跳过已经通过表达式字段处理的字段 + if _, isExpression := s.config.FieldExpressions[info.outputName]; isExpression { + continue + } + + if info.isStringLiteral { + // 字符串字面量处理:使用预编译的字符串值 + result[info.alias] = info.stringValue + } else if info.isFunctionCall { // 执行函数调用 - if funcResult, err := s.executeFunction(fieldName, dataMap); err == nil { - result[outputName] = funcResult + if funcResult, err := s.executeFunction(info.fieldName, dataMap); err == nil { + result[info.outputName] = funcResult } else { - logger.Error("Function execution error %s: %v", fieldName, err) - result[outputName] = nil + logger.Error("Function execution error %s: %v", info.fieldName, err) + result[info.outputName] = nil } } else { - // 普通字段 - 支持嵌套字段 + // 普通字段处理 var value interface{} var exists bool - if fieldpath.IsNestedField(fieldName) { - value, exists = fieldpath.GetNestedField(data, fieldName) + if info.hasNestedField { + value, exists = fieldpath.GetNestedField(data, info.fieldName) } else { - value, exists = dataMap[fieldName] + value, exists = dataMap[info.fieldName] } if exists { - result[outputName] = value + result[info.outputName] = value } else { - result[outputName] = nil + result[info.outputName] = nil } } } @@ -623,13 +1032,68 @@ func (s *Stream) processDirectData(data interface{}) { // 将结果包装为数组 results := []map[string]interface{}{result} - // 优化: 非阻塞发送结果到resultChan + // 非阻塞发送结果到resultChan s.sendResultNonBlocking(results) - // 优化: 异步调用所有sinks,避免阻塞 + // 异步调用所有sinks,避免阻塞 s.callSinksAsync(results) } +// processSingleFieldFallback 回退处理单个字段(当预编译信息缺失时) +func (s *Stream) processSingleFieldFallback(fieldSpec string, dataMap 
map[string]interface{}, data interface{}, result map[string]interface{}) { + // 处理SELECT *的特殊情况 + if fieldSpec == "*" { + // SELECT *:返回所有字段,但跳过已经通过表达式字段处理的字段 + for k, v := range dataMap { + // 如果该字段已经通过表达式字段处理,则跳过,保持表达式计算结果 + if _, isExpression := s.config.FieldExpressions[k]; !isExpression { + result[k] = v + } + } + return + } + + // 处理别名 + parts := strings.Split(fieldSpec, ":") + fieldName := parts[0] + outputName := fieldName + if len(parts) > 1 { + outputName = parts[1] + } + + // 跳过已经通过表达式字段处理的字段 + if _, isExpression := s.config.FieldExpressions[outputName]; isExpression { + return + } + + // 检查是否是函数调用 + if strings.Contains(fieldName, "(") && strings.Contains(fieldName, ")") { + // 执行函数调用 + if funcResult, err := s.executeFunction(fieldName, dataMap); err == nil { + result[outputName] = funcResult + } else { + logger.Error("Function execution error %s: %v", fieldName, err) + result[outputName] = nil + } + } else { + // 普通字段 - 支持嵌套字段 + var value interface{} + var exists bool + + if fieldpath.IsNestedField(fieldName) { + value, exists = fieldpath.GetNestedField(data, fieldName) + } else { + value, exists = dataMap[fieldName] + } + + if exists { + result[outputName] = value + } else { + result[outputName] = nil + } + } +} + // sendResultNonBlocking 非阻塞方式发送结果到resultChan (智能背压控制) func (s *Stream) sendResultNonBlocking(results []map[string]interface{}) { select { @@ -866,24 +1330,10 @@ func (s *Stream) smartSplitArgs(argsStr string) ([]string, error) { return args, nil } -func (s *Stream) AddData(data interface{}) { +func (s *Stream) Emit(data interface{}) { atomic.AddInt64(&s.inputCount, 1) - - // 根据溢出策略处理数据 - switch s.overflowStrategy { - case "block": - // 阻塞模式:保证数据不丢失 - s.addDataBlocking(data) - case "expand": - // 动态扩容模式:自动扩大缓冲区 - s.addDataWithExpansion(data) - case "persist": - // 持久化模式:溢出数据写入磁盘 - s.addDataWithPersistence(data) - default: - // 默认drop模式:原有逻辑 - s.addDataWithDrop(data) - } + // 直接调用预编译的函数指针,避免switch判断 + s.addDataFunc(data) } // addDataBlocking 
阻塞模式添加数据,保证零数据丢失 (线程安全版本) @@ -925,7 +1375,7 @@ func (s *Stream) addDataWithExpansion(data interface{}) { // 扩容后重试,重新获取通道引用 if s.safeSendToDataChan(data) { - logger.Info("Successfully added data after data channel expansion") + logger.Debug("Successfully added data after data channel expansion") return } @@ -960,7 +1410,7 @@ func (s *Stream) addDataWithPersistence(data interface{}) { // addDataWithDrop 原有的丢弃模式 (线程安全版本) func (s *Stream) addDataWithDrop(data interface{}) { - // 优化: 智能非阻塞添加,分层背压控制 + // 智能非阻塞添加,分层背压控制 if s.safeSendToDataChan(data) { return } @@ -1101,17 +1551,17 @@ func (s *Stream) GetStats() map[string]int64 { s.dataChanMux.RUnlock() return map[string]int64{ - "input_count": atomic.LoadInt64(&s.inputCount), - "output_count": atomic.LoadInt64(&s.outputCount), - "dropped_count": atomic.LoadInt64(&s.droppedCount), - "data_chan_len": dataChanLen, - "data_chan_cap": dataChanCap, - "result_chan_len": int64(len(s.resultChan)), - "result_chan_cap": int64(cap(s.resultChan)), - "sink_pool_len": int64(len(s.sinkWorkerPool)), - "sink_pool_cap": int64(cap(s.sinkWorkerPool)), - "active_retries": int64(atomic.LoadInt32(&s.activeRetries)), - "expanding": int64(atomic.LoadInt32(&s.expanding)), + StatsInputCount: atomic.LoadInt64(&s.inputCount), + StatsOutputCount: atomic.LoadInt64(&s.outputCount), + StatsDroppedCount: atomic.LoadInt64(&s.droppedCount), + StatsDataChanLen: dataChanLen, + StatsDataChanCap: dataChanCap, + StatsResultChanLen: int64(len(s.resultChan)), + StatsResultChanCap: int64(cap(s.resultChan)), + StatsSinkPoolLen: int64(len(s.sinkWorkerPool)), + StatsSinkPoolCap: int64(cap(s.sinkWorkerPool)), + StatsActiveRetries: int64(atomic.LoadInt32(&s.activeRetries)), + StatsExpanding: int64(atomic.LoadInt32(&s.expanding)), } } @@ -1120,27 +1570,27 @@ func (s *Stream) GetDetailedStats() map[string]interface{} { stats := s.GetStats() // 计算使用率 - dataUsage := float64(stats["data_chan_len"]) / float64(stats["data_chan_cap"]) * 100 - resultUsage := 
float64(stats["result_chan_len"]) / float64(stats["result_chan_cap"]) * 100 - sinkUsage := float64(stats["sink_pool_len"]) / float64(stats["sink_pool_cap"]) * 100 + dataUsage := float64(stats[StatsDataChanLen]) / float64(stats[StatsDataChanCap]) * 100 + resultUsage := float64(stats[StatsResultChanLen]) / float64(stats[StatsResultChanCap]) * 100 + sinkUsage := float64(stats[StatsSinkPoolLen]) / float64(stats[StatsSinkPoolCap]) * 100 // 计算效率指标 var processRate float64 = 100.0 var dropRate float64 = 0.0 - if stats["input_count"] > 0 { - processRate = float64(stats["output_count"]) / float64(stats["input_count"]) * 100 - dropRate = float64(stats["dropped_count"]) / float64(stats["input_count"]) * 100 + if stats[StatsInputCount] > 0 { + processRate = float64(stats[StatsOutputCount]) / float64(stats[StatsInputCount]) * 100 + dropRate = float64(stats[StatsDroppedCount]) / float64(stats[StatsInputCount]) * 100 } return map[string]interface{}{ - "basic_stats": stats, - "data_chan_usage": dataUsage, - "result_chan_usage": resultUsage, - "sink_pool_usage": sinkUsage, - "process_rate": processRate, - "drop_rate": dropRate, - "performance_level": s.assessPerformanceLevel(dataUsage, dropRate), + StatsBasicStats: stats, + StatsDataChanUsage: dataUsage, + StatsResultChanUsage: resultUsage, + StatsSinkPoolUsage: sinkUsage, + StatsProcessRate: processRate, + StatsDropRate: dropRate, + StatsPerformanceLevel: s.assessPerformanceLevel(dataUsage, dropRate), } } @@ -1148,15 +1598,15 @@ func (s *Stream) GetDetailedStats() map[string]interface{} { func (s *Stream) assessPerformanceLevel(dataUsage, dropRate float64) string { switch { case dropRate > 50: - return "CRITICAL" // 严重性能问题 + return PerformanceLevelCritical // 严重性能问题 case dropRate > 20: - return "WARNING" // 性能警告 + return PerformanceLevelWarning // 性能警告 case dataUsage > 90: - return "HIGH_LOAD" // 高负载 + return PerformanceLevelHighLoad // 高负载 case dataUsage > 70: - return "MODERATE_LOAD" // 中等负载 + return PerformanceLevelModerateLoad 
// 中等负载 default: - return "OPTIMAL" // 最佳状态 + return PerformanceLevelOptimal // 最佳状态 } } @@ -1197,7 +1647,7 @@ func (s *Stream) expandDataChannel() { newCap = oldCap + 1000 // 至少增加1000 } - logger.Info("Dynamic expansion of data channel: %d -> %d", oldCap, newCap) + logger.Debug("Dynamic expansion of data channel: %d -> %d", oldCap, newCap) // 创建新的更大的通道 newChan := make(chan interface{}, newCap) @@ -1235,7 +1685,7 @@ migration_done: s.dataChan = newChan s.dataChanMux.Unlock() - logger.Info("Channel expansion completed: migrated %d items", migratedCount) + logger.Debug("Channel expansion completed: migrated %d items", migratedCount) } // persistAndRetryData 持久化数据并重试 (改进版本,具备指数退避和资源控制) @@ -1361,16 +1811,238 @@ func (s *Stream) LoadAndReprocessPersistedData() error { func (s *Stream) GetPersistenceStats() map[string]interface{} { if s.persistenceManager == nil { return map[string]interface{}{ - "enabled": false, - "message": "persistence not enabled", + PersistenceEnabled: false, + PersistenceMessage: PersistenceNotEnabledMsg, } } stats := s.persistenceManager.GetStats() - stats["enabled"] = true + stats[PersistenceEnabled] = true return stats } +// IsAggregationQuery 检查当前流是否为聚合查询 +func (s *Stream) IsAggregationQuery() bool { + return s.config.NeedWindow +} + +// ProcessSync 同步处理单条数据,立即返回结果 +// 仅适用于非聚合查询,聚合查询会返回错误 +func (s *Stream) ProcessSync(data interface{}) (interface{}, error) { + // 检查是否为聚合查询 + if s.config.NeedWindow { + return nil, fmt.Errorf("聚合查询不支持同步处理") + } + + // 应用过滤条件 + if s.filter != nil && !s.filter.Evaluate(data) { + return nil, nil // 不匹配过滤条件,返回nil + } + + // 直接处理数据并返回结果 + return s.processDirectDataSync(data) +} + +// processDirectDataSync 同步版本的直接数据处理 +func (s *Stream) processDirectDataSync(data interface{}) (interface{}, error) { + // 增加输入计数 + atomic.AddInt64(&s.inputCount, 1) + + // 简化:直接将数据作为map处理 + dataMap, ok := data.(map[string]interface{}) + if !ok { + atomic.AddInt64(&s.droppedCount, 1) + return nil, fmt.Errorf("不支持的数据类型: %T", data) + } + + // 
创建结果map,预分配合适容量 + estimatedSize := len(s.config.FieldExpressions) + len(s.config.SimpleFields) + if estimatedSize < 8 { + estimatedSize = 8 // 最小容量 + } + result := make(map[string]interface{}, estimatedSize) + + // 处理表达式字段 + for fieldName, fieldExpr := range s.config.FieldExpressions { + // 使用桥接器计算表达式,支持IS NULL等语法 + bridge := functions.GetExprBridge() + + // 预处理表达式中的IS NULL和LIKE语法 + processedExpr := fieldExpr.Expression + if bridge.ContainsIsNullOperator(processedExpr) { + if processed, err := bridge.PreprocessIsNullExpression(processedExpr); err == nil { + processedExpr = processed + } + } + if bridge.ContainsLikeOperator(processedExpr) { + if processed, err := bridge.PreprocessLikeExpression(processedExpr); err == nil { + processedExpr = processed + } + } + + // 检查表达式是否是函数调用(包含括号) + isFunctionCall := strings.Contains(fieldExpr.Expression, "(") && strings.Contains(fieldExpr.Expression, ")") + + // 检查表达式是否包含嵌套字段(但排除函数调用中的点号) + hasNestedFields := false + if !isFunctionCall && strings.Contains(fieldExpr.Expression, ".") { + hasNestedFields = true + } + + // 检查是否为CASE表达式 + trimmedExpr := strings.TrimSpace(fieldExpr.Expression) + upperExpr := strings.ToUpper(trimmedExpr) + isCaseExpression := strings.HasPrefix(upperExpr, SQLKeywordCase) + + var evalResult interface{} + + if isFunctionCall { + // 对于函数调用,优先使用桥接器处理,这样可以保持原始类型 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + logger.Error("Function call evaluation failed for field %s: %v", fieldName, err) + result[fieldName] = nil + continue + } + evalResult = exprResult + } else if hasNestedFields || isCaseExpression { + // 检测到嵌套字段(非函数调用)或CASE表达式,使用自定义表达式引擎 + // 预处理反引号标识符 + exprToUse := fieldExpr.Expression + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) + if parseErr != nil { + logger.Error("Expression 
parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + continue + } + + // 使用支持NULL的计算方法 + numResult, isNull, err := expression.EvaluateWithNull(dataMap) + if err != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, err) + result[fieldName] = nil + continue + } + if isNull { + evalResult = nil // NULL值 + } else { + evalResult = numResult + } + } else { + // 尝试使用桥接器处理其他表达式 + exprResult, err := bridge.EvaluateExpression(processedExpr, dataMap) + if err != nil { + // 如果桥接器失败,回退到原来的表达式引擎(使用原始表达式,不是预处理的) + // 预处理反引号标识符 + exprToUse := fieldExpr.Expression + if bridge.ContainsBacktickIdentifiers(exprToUse) { + if processed, err := bridge.PreprocessBacktickIdentifiers(exprToUse); err == nil { + exprToUse = processed + } + } + expression, parseErr := expr.NewExpression(exprToUse) + if parseErr != nil { + logger.Error("Expression parse failed for field %s: %v", fieldName, parseErr) + result[fieldName] = nil + continue + } + + // 计算表达式,支持NULL值 + numResult, isNull, evalErr := expression.EvaluateWithNull(dataMap) + if evalErr != nil { + logger.Error("Expression evaluation failed for field %s: %v", fieldName, evalErr) + result[fieldName] = nil + continue + } + if isNull { + evalResult = nil // NULL值 + } else { + evalResult = numResult + } + } else { + evalResult = exprResult + } + } + + result[fieldName] = evalResult + } + + // 处理SimpleFields(复用现有逻辑) + if len(s.config.SimpleFields) > 0 { + for _, fieldSpec := range s.config.SimpleFields { + info := s.compiledFieldInfo[fieldSpec] + if info == nil { + // 如果没有预编译信息,回退到原逻辑(安全性保证) + s.processSingleFieldFallback(fieldSpec, dataMap, data, result) + continue + } + + if info.isSelectAll { + // SELECT *:批量复制所有字段,跳过表达式字段 + for k, v := range dataMap { + if _, isExpression := s.config.FieldExpressions[k]; !isExpression { + result[k] = v + } + } + continue + } + + // 跳过已经通过表达式字段处理的字段 + if _, isExpression := s.config.FieldExpressions[info.outputName]; isExpression { + continue + } + + if 
info.isFunctionCall { + // 执行函数调用 + if funcResult, err := s.executeFunction(info.fieldName, dataMap); err == nil { + result[info.outputName] = funcResult + } else { + logger.Error("Function execution error %s: %v", info.fieldName, err) + result[info.outputName] = nil + } + } else { + // 普通字段处理 + var value interface{} + var exists bool + + if info.hasNestedField { + value, exists = fieldpath.GetNestedField(data, info.fieldName) + } else { + value, exists = dataMap[info.fieldName] + } + + if exists { + result[info.outputName] = value + } else { + result[info.outputName] = nil + } + } + } + } else if len(s.config.FieldExpressions) == 0 { + // 如果没有指定字段且没有表达式字段,保留所有字段 + for k, v := range dataMap { + result[k] = v + } + } + + // 增加输出计数 + atomic.AddInt64(&s.outputCount, 1) + + // 包装结果为数组格式,保持与异步模式的一致性 + results := []map[string]interface{}{result} + + // 触发 AddSink 回调,保持同步和异步模式的一致性 + // 这样用户可以同时获得同步结果和异步回调 + s.callSinksAsync(results) + + return result, nil +} + // 向后兼容性函数 // NewStreamWithBuffers 创建带自定义缓冲区大小的Stream (已弃用,使用NewStreamWithCustomPerformance) @@ -1398,15 +2070,15 @@ func NewStreamWithoutDataLoss(config types.Config, strategy string) (*Stream, er // 应用用户指定的策略 validStrategies := map[string]bool{ - "drop": true, - "block": true, - "expand": true, - "persist": true, + StrategyDrop: true, + StrategyBlock: true, + StrategyExpand: true, + StrategyPersist: true, } if validStrategies[strategy] { perfConfig.OverflowConfig.Strategy = strategy - if strategy == "drop" { + if strategy == StrategyDrop { perfConfig.OverflowConfig.AllowDataLoss = true } } @@ -1426,7 +2098,7 @@ func NewStreamWithLossPolicy(config types.Config, dataBufSize, resultBufSize, si perfConfig.WorkerConfig.SinkPoolSize = sinkPoolSize perfConfig.OverflowConfig.Strategy = overflowStrategy perfConfig.OverflowConfig.BlockTimeout = timeout - perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == "drop") + perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == StrategyDrop) 
config.PerformanceConfig = perfConfig return newStreamWithUnifiedConfig(config) @@ -1443,10 +2115,10 @@ func NewStreamWithLossPolicyAndPersistence(config types.Config, dataBufSize, res perfConfig.WorkerConfig.SinkPoolSize = sinkPoolSize perfConfig.OverflowConfig.Strategy = overflowStrategy perfConfig.OverflowConfig.BlockTimeout = timeout - perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == "drop") + perfConfig.OverflowConfig.AllowDataLoss = (overflowStrategy == StrategyDrop) // 设置持久化配置 - if overflowStrategy == "persist" { + if overflowStrategy == StrategyPersist { perfConfig.OverflowConfig.PersistenceConfig = &types.PersistenceConfig{ DataDir: persistDataDir, MaxFileSize: persistMaxFileSize, diff --git a/stream/stream_test.go b/stream/stream_test.go index 51e2e4a..ee3e5fb 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -55,7 +55,7 @@ func TestStreamProcess(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口关闭并触发结果 @@ -139,7 +139,7 @@ func TestStreamWithoutFilter(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 捕获结果 @@ -235,7 +235,7 @@ func TestIncompleteStreamProcess(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口关闭并触发结果 @@ -323,7 +323,7 @@ func TestWindowSlotAgg(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 捕获结果 @@ -492,7 +492,7 @@ func TestStreamWithPersistenceStrategy(t *testing.T) { "temperature": float64(20 + i), "timestamp": time.Now(), } - stream.AddData(data) + stream.Emit(data) } // 等待处理完成 @@ -546,7 +546,7 @@ func TestStreamPersistenceRecovery(t *testing.T) { } for _, data := range testData { - stream1.AddData(data) + stream1.Emit(data) } // 等待数据持久化 @@ -695,7 +695,7 @@ func TestStreamPersistencePerformance(t *testing.T) { "value": i, "data": fmt.Sprintf("performance_test_data_%d", i), } - stream.AddData(data) + stream.Emit(data) } elapsed 
:= time.Since(start) @@ -754,3 +754,219 @@ func TestStreamsqlPersistenceConfigPassing(t *testing.T) { t.Logf("持久化配置验证通过: %+v", stats) } + +func TestSelectStarWithExpressionFields(t *testing.T) { + config := types.Config{ + NeedWindow: false, + SimpleFields: []string{"*"}, // SELECT * + FieldExpressions: map[string]types.FieldExpression{ + "name": { + Expression: "UPPER(name)", + Fields: []string{"name"}, + }, + "full_info": { + Expression: "CONCAT(name, ' - ', status)", + Fields: []string{"name", "status"}, + }, + }, + } + + stream, err := NewStream(config) + if err != nil { + t.Fatalf("Failed to create stream: %v", err) + } + defer stream.Stop() + + // 收集结果 - 使用sync.Mutex防止数据竞争 + var mu sync.Mutex + var results []interface{} + stream.AddSink(func(result interface{}) { + mu.Lock() + defer mu.Unlock() + results = append(results, result) + }) + + stream.Start() + + // 添加测试数据 + testData := map[string]interface{}{ + "name": "john", + "status": "active", + "age": 25, + } + + stream.Emit(testData) + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证结果 - 使用互斥锁保护读取 + mu.Lock() + resultsLen := len(results) + var resultData map[string]interface{} + if resultsLen > 0 { + resultData = results[0].([]map[string]interface{})[0] + } + mu.Unlock() + + if resultsLen != 1 { + t.Fatalf("Expected 1 result, got %d", resultsLen) + } + + // 验证表达式字段的结果没有被覆盖 + if resultData["name"] != "JOHN" { + t.Errorf("Expected name to be 'JOHN' (uppercase), got %v", resultData["name"]) + } + + if resultData["full_info"] != "john - active" { + t.Errorf("Expected full_info to be 'john - active', got %v", resultData["full_info"]) + } + + // 验证原始字段仍然存在 + if resultData["status"] != "active" { + t.Errorf("Expected status to be 'active', got %v", resultData["status"]) + } + + if resultData["age"] != 25 { + t.Errorf("Expected age to be 25, got %v", resultData["age"]) + } +} + +func TestSelectStarWithExpressionFieldsOverride(t *testing.T) { + // 测试表达式字段名与原始字段名相同的情况 + config := types.Config{ + NeedWindow: 
false, + SimpleFields: []string{"*"}, // SELECT * + FieldExpressions: map[string]types.FieldExpression{ + "name": { + Expression: "UPPER(name)", + Fields: []string{"name"}, + }, + "age": { + Expression: "age * 2", + Fields: []string{"age"}, + }, + }, + } + + stream, err := NewStream(config) + if err != nil { + t.Fatalf("Failed to create stream: %v", err) + } + defer stream.Stop() + + // 收集结果 - 使用sync.Mutex防止数据竞争 + var mu sync.Mutex + var results []interface{} + stream.AddSink(func(result interface{}) { + mu.Lock() + defer mu.Unlock() + results = append(results, result) + }) + + stream.Start() + + // 添加测试数据 + testData := map[string]interface{}{ + "name": "alice", + "age": 30, + "status": "active", + } + + stream.Emit(testData) + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证结果 - 使用互斥锁保护读取 + mu.Lock() + resultsLen := len(results) + var resultData map[string]interface{} + if resultsLen > 0 { + resultData = results[0].([]map[string]interface{})[0] + } + mu.Unlock() + + if resultsLen != 1 { + t.Fatalf("Expected 1 result, got %d", resultsLen) + } + + // 验证表达式字段的结果覆盖了原始字段 + if resultData["name"] != "ALICE" { + t.Errorf("Expected name to be 'ALICE' (expression result), got %v", resultData["name"]) + } + + // 检查age表达式的结果(可能是int或float64类型) + ageResult := resultData["age"] + if ageResult != 60 && ageResult != 60.0 { + t.Errorf("Expected age to be 60 (expression result), got %v (type: %T)", resultData["age"], resultData["age"]) + } + + // 验证没有表达式的字段保持原值 + if resultData["status"] != "active" { + t.Errorf("Expected status to be 'active', got %v", resultData["status"]) + } +} + +func TestSelectStarWithoutExpressionFields(t *testing.T) { + // 测试没有表达式字段时SELECT *的行为 + config := types.Config{ + NeedWindow: false, + SimpleFields: []string{"*"}, // SELECT * + } + + stream, err := NewStream(config) + if err != nil { + t.Fatalf("Failed to create stream: %v", err) + } + defer stream.Stop() + + // 收集结果 - 使用sync.Mutex防止数据竞争 + var mu sync.Mutex + var results []interface{} + 
stream.AddSink(func(result interface{}) { + mu.Lock() + defer mu.Unlock() + results = append(results, result) + }) + + stream.Start() + + // 添加测试数据 + testData := map[string]interface{}{ + "name": "bob", + "age": 35, + "status": "inactive", + } + + stream.Emit(testData) + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证结果 - 使用互斥锁保护读取 + mu.Lock() + resultsLen := len(results) + var resultData map[string]interface{} + if resultsLen > 0 { + resultData = results[0].([]map[string]interface{})[0] + } + mu.Unlock() + + if resultsLen != 1 { + t.Fatalf("Expected 1 result, got %d", resultsLen) + } + + // 验证所有原始字段都被保留 + if resultData["name"] != "bob" { + t.Errorf("Expected name to be 'bob', got %v", resultData["name"]) + } + + if resultData["age"] != 35 { + t.Errorf("Expected age to be 35, got %v", resultData["age"]) + } + + if resultData["status"] != "inactive" { + t.Errorf("Expected status to be 'inactive', got %v", resultData["status"]) + } +} diff --git a/streamsql.go b/streamsql.go index d28c6b6..ee9a378 100644 --- a/streamsql.go +++ b/streamsql.go @@ -22,6 +22,7 @@ import ( "github.com/rulego/streamsql/rsql" "github.com/rulego/streamsql/stream" "github.com/rulego/streamsql/types" + "github.com/rulego/streamsql/utils/table" ) // Streamsql 是StreamSQL流处理引擎的主要接口。 @@ -31,13 +32,19 @@ import ( // // ssql := streamsql.New() // err := ssql.Execute("SELECT AVG(temperature) FROM stream GROUP BY TumblingWindow('5s')") -// ssql.AddData(map[string]interface{}{"temperature": 25.5}) +// ssql.Emit(map[string]interface{}{"temperature": 25.5}) type Streamsql struct { stream *stream.Stream // 性能配置模式 performanceMode string // "default", "high_performance", "low_latency", "zero_data_loss", "custom" customConfig *types.PerformanceConfig + + // 新增:同步处理模式配置 + enableSyncMode bool // 是否启用同步模式(用于非聚合查询) + + // 保存原始SELECT字段顺序,用于表格输出时保持字段顺序 + fieldOrder []string } // New 创建一个新的StreamSQL实例。 @@ -122,6 +129,9 @@ func (s *Streamsql) Execute(sql string) error { return fmt.Errorf("SQL解析失败: %w", 
err) } + // 从解析结果中获取字段顺序信息 + s.fieldOrder = config.FieldOrder + // 根据性能模式创建流处理器 var streamInstance *stream.Stream @@ -158,7 +168,7 @@ func (s *Streamsql) Execute(sql string) error { return nil } -// AddData 向流中添加一条数据记录。 +// Emit 向流中添加一条数据记录。 // 数据会根据已配置的SQL查询进行处理和聚合。 // // 支持的数据格式: @@ -171,7 +181,7 @@ func (s *Streamsql) Execute(sql string) error { // 示例: // // // 添加设备数据 -// ssql.AddData(map[string]interface{}{ +// ssql.Emit(map[string]interface{}{ // "deviceId": "sensor001", // "temperature": 25.5, // "humidity": 60.0, @@ -179,17 +189,75 @@ func (s *Streamsql) Execute(sql string) error { // }) // // // 添加用户行为数据 -// ssql.AddData(map[string]interface{}{ +// ssql.Emit(map[string]interface{}{ // "userId": "user123", // "action": "click", // "page": "/home", // }) -func (s *Streamsql) AddData(data interface{}) { +func (s *Streamsql) Emit(data interface{}) { if s.stream != nil { - s.stream.AddData(data) + s.stream.Emit(data) } } +// EmitSync 同步处理数据,立即返回处理结果。 +// 仅适用于非聚合查询(如过滤、转换等),聚合查询会返回错误。 +// +// 对于非聚合查询,此方法提供同步的数据处理能力,同时: +// 1. 立即返回处理结果(同步) +// 2. 
触发已注册的AddSink回调(异步) +// +// 这确保了同步和异步模式的一致性,用户可以同时获得: +// - 立即可用的处理结果 +// - 异步回调处理(用于日志、监控、持久化等) +// +// 参数: +// - data: 要处理的数据 +// +// 返回值: +// - interface{}: 处理后的结果,如果不匹配过滤条件返回nil +// - error: 处理错误,如果是聚合查询会返回错误 +// +// 示例: +// +// // 添加日志回调 +// ssql.AddSink(func(result interface{}) { +// fmt.Printf("异步日志: %v\n", result) +// }) +// +// // 同步处理并立即获取结果 +// result, err := ssql.EmitSync(map[string]interface{}{ +// "temperature": 25.5, +// "humidity": 60.0, +// }) +// if err != nil { +// // 处理错误 +// } else if result != nil { +// // 立即使用处理结果 +// fmt.Printf("同步结果: %v\n", result) +// // 同时异步回调也会被触发 +// } +func (s *Streamsql) EmitSync(data interface{}) (interface{}, error) { + if s.stream == nil { + return nil, fmt.Errorf("stream未初始化") + } + + // 检查是否为非聚合查询 + if s.stream.IsAggregationQuery() { + return nil, fmt.Errorf("同步模式仅支持非聚合查询,聚合查询请使用Emit()方法") + } + + return s.stream.ProcessSync(data) +} + +// IsAggregationQuery 检查当前查询是否为聚合查询 +func (s *Streamsql) IsAggregationQuery() bool { + if s.stream == nil { + return false + } + return s.stream.IsAggregationQuery() +} + // Stream 返回底层的流处理器实例。 // 通过此方法可以访问更底层的流处理功能。 // @@ -248,3 +316,89 @@ func (s *Streamsql) Stop() { s.stream.Stop() } } + +// AddSink 直接添加结果处理回调函数。 +// 这是对 Stream().AddSink() 的便捷封装,使API调用更简洁。 +// +// 参数: +// - sink: 结果处理函数,接收处理结果作为参数 +// +// 示例: +// +// // 直接添加结果处理 +// ssql.AddSink(func(result interface{}) { +// fmt.Printf("处理结果: %v\n", result) +// }) +// +// // 添加多个处理器 +// ssql.AddSink(func(result interface{}) { +// // 保存到数据库 +// saveToDatabase(result) +// }) +// ssql.AddSink(func(result interface{}) { +// // 发送到消息队列 +// sendToQueue(result) +// }) +func (s *Streamsql) AddSink(sink func(interface{})) { + if s.stream != nil { + s.stream.AddSink(sink) + } +} + +// PrintTable 以表格形式打印结果到控制台,类似数据库输出格式。 +// 首先显示列名,然后逐行显示数据。 +// +// 支持的数据格式: +// - []map[string]interface{}: 多行记录 +// - map[string]interface{}: 单行记录 +// - 其他类型: 直接打印 +// +// 示例: +// +// // 表格式打印结果 +// ssql.PrintTable() +// +// // 输出格式: +// // 
+--------+----------+ +// // | device | max_temp | +// // +--------+----------+ +// // | aa | 30.0 | +// // | bb | 22.0 | +// // +--------+----------+ +func (s *Streamsql) PrintTable() { + s.AddSink(func(result interface{}) { + s.printTableFormat(result) + }) +} + +// printTableFormat 格式化打印表格数据 +func (s *Streamsql) printTableFormat(result interface{}) { + table.FormatTableData(result, s.fieldOrder) +} + +// ToChannel 返回结果通道,用于异步获取处理结果。 +// 通过此通道可以以非阻塞方式获取流处理结果。 +// +// 返回值: +// - <-chan interface{}: 只读的结果通道,如果未执行SQL则返回nil +// +// 示例: +// +// // 获取结果通道 +// resultChan := ssql.ToChannel() +// if resultChan != nil { +// go func() { +// for result := range resultChan { +// fmt.Printf("异步结果: %v\n", result) +// } +// }() +// } +// +// 注意: +// - 必须有消费者持续从通道读取数据,否则可能导致流处理阻塞 +func (s *Streamsql) ToChannel() <-chan interface{} { + if s.stream != nil { + return s.stream.GetResultsChan() + } + return nil +} diff --git a/streamsql_benchmark_test.go b/streamsql_benchmark_test.go index edf2c4d..328ebcc 100644 --- a/streamsql_benchmark_test.go +++ b/streamsql_benchmark_test.go @@ -2,56 +2,44 @@ package streamsql import ( "context" - "fmt" "math/rand" - "runtime" - "sync" "sync/atomic" "testing" "time" - - "github.com/rulego/streamsql/stream" - "github.com/rulego/streamsql/types" ) -// BenchmarkStreamSQLPerformance 综合性能基准测试(优化版本) -func BenchmarkStreamSQLPerformance(b *testing.B) { +// BenchmarkStreamSQLCore 核心性能基准测试 +func BenchmarkStreamSQLCore(b *testing.B) { tests := []struct { name string sql string hasWindow bool waitTime time.Duration - config string // 配置描述 }{ { name: "SimpleFilter", sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", hasWindow: false, waitTime: 50 * time.Millisecond, - config: "基准测试专用", }, { - name: "BasicAggregation", - sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('200ms')", + name: "WindowAggregation", + sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, 
TumblingWindow('100ms')", hasWindow: true, - waitTime: 400 * time.Millisecond, - config: "基准测试专用", + waitTime: 200 * time.Millisecond, }, { name: "ComplexQuery", sql: "SELECT deviceId, AVG(temperature), COUNT(*) FROM stream WHERE humidity > 50 GROUP BY deviceId, TumblingWindow('100ms')", hasWindow: true, - waitTime: 300 * time.Millisecond, - config: "基准测试专用", + waitTime: 250 * time.Millisecond, }, } for _, tt := range tests { b.Run(tt.name, func(b *testing.B) { - // 使用超大缓冲区专门针对基准测试优化 - // 基准测试需要处理大量迭代,需要更大的缓冲区 - bufferSize := max(int64(100000), int64(b.N/10)) // 至少10万,或者迭代数的1/10 - ssql := New(WithBufferSizes(int(bufferSize), int(bufferSize), 2000)) + // 使用默认配置进行基准测试 + ssql := New() defer ssql.Stop() err := ssql.Execute(tt.sql) @@ -61,158 +49,37 @@ func BenchmarkStreamSQLPerformance(b *testing.B) { var resultReceived int64 - // 添加非阻塞sink处理结果 - ssql.Stream().AddSink(func(result interface{}) { + // 添加结果处理器 + ssql.AddSink(func(result interface{}) { atomic.AddInt64(&resultReceived, 1) }) - // 使用context控制生命周期 + // 异步消费结果通道防止阻塞 ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // 异步消费resultChan,确保通道不被填满 go func() { for { select { case <-ssql.Stream().GetResultsChan(): - // 快速消费,避免通道阻塞 case <-ctx.Done(): return } } }() - // 测试数据 - 减少数据种类避免过度生成 - testData := generateTestData(3) + // 生成测试数据 + testData := generateTestData(5) - // 重置统计,获得准确的基准测试数据 + // 重置统计 ssql.Stream().ResetStats() b.ResetTimer() - // 执行基准测试 - 添加节流以避免瞬间填满缓冲区 - start := time.Now() - batchSize := 1000 // 分批处理,避免瞬间打满缓冲区 - for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) - - // 每处理一批数据,稍微暂停,让系统有时间处理 - if i > 0 && i%batchSize == 0 { - time.Sleep(10 * time.Microsecond) // 极短暂停 - } - } - inputDuration := time.Since(start) - - b.StopTimer() - - // 等待处理完成 - time.Sleep(tt.waitTime) - - cancel() // 停止结果处理goroutine - - // 获取详细统计信息 - detailedStats := ssql.Stream().GetDetailedStats() - received := atomic.LoadInt64(&resultReceived) - - // 计算性能指标 - inputThroughput := float64(b.N) / 
inputDuration.Seconds() - processRate := detailedStats["process_rate"].(float64) - dropRate := detailedStats["drop_rate"].(float64) - perfLevel := detailedStats["performance_level"].(string) - - b.ReportMetric(inputThroughput, "input_ops/sec") - b.ReportMetric(float64(received), "results_received") - b.ReportMetric(processRate, "process_rate_%") - b.ReportMetric(dropRate, "drop_rate_%") - - // 性能分析报告 - b.Logf("%s配置 (缓冲区: %d) - 性能等级: %s", tt.config, bufferSize, perfLevel) - b.Logf("处理效率: %.2f%%, 丢弃率: %.2f%%", processRate, dropRate) - b.Logf("缓冲区使用: 数据通道 %.1f%%, 结果通道 %.1f%%", - detailedStats["data_chan_usage"], detailedStats["result_chan_usage"]) - - if dropRate > 5 { // 降低警告阈值 - b.Logf("警告: 丢弃率 %.2f%% - 建议增加缓冲区大小", dropRate) - } - - if !tt.hasWindow && received == 0 { - b.Logf("警告: 非聚合查询未收到结果") - } - - // 性能建议 - if dropRate > 10 { - b.Logf("建议: 使用更大缓冲区配置,当前缓冲区可能不足") - } else if processRate == 100.0 && dropRate == 0.0 { - b.Logf("✓ 优秀: 完美的处理效率,无数据丢失") - } - }) - } -} - -// BenchmarkStreamSQLFixed 修复版本的基准测试 -func BenchmarkStreamSQLFixed(b *testing.B) { - tests := []struct { - name string - sql string - hasWindow bool - waitTime time.Duration - }{ - { - name: "SimpleFilter", - sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", - hasWindow: false, - waitTime: 10 * time.Millisecond, - }, - { - name: "BasicAggregation", - sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('200ms')", - hasWindow: true, - waitTime: 300 * time.Millisecond, - }, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - ssql := New() - defer ssql.Stop() - - err := ssql.Execute(tt.sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var processedCount int64 - - // 使用非阻塞的sink避免阻塞 - ssql.Stream().AddSink(func(result interface{}) { - // 非阻塞计数,不做任何可能阻塞的操作 - atomic.AddInt64(&processedCount, 1) - }) - - // 启动一个goroutine异步消费resultChan,防止填满 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go 
func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - // 快速消费,不做处理 - case <-ctx.Done(): - return - } - } - }() - - // 测试数据 - testData := generateTestData(5) // 减少数据量避免过载 - - b.ResetTimer() - // 执行基准测试 start := time.Now() for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) + ssql.Emit(testData[i%len(testData)]) } inputDuration := time.Since(start) @@ -220,23 +87,125 @@ func BenchmarkStreamSQLFixed(b *testing.B) { // 等待处理完成 time.Sleep(tt.waitTime) + cancel() + + // 获取统计信息 + stats := ssql.Stream().GetStats() + received := atomic.LoadInt64(&resultReceived) // 计算性能指标 inputThroughput := float64(b.N) / inputDuration.Seconds() - processed := atomic.LoadInt64(&processedCount) + processedCount := stats["output_count"] + droppedCount := stats["dropped_count"] + processRate := float64(processedCount) / float64(b.N) * 100 + dropRate := float64(droppedCount) / float64(b.N) * 100 - b.ReportMetric(inputThroughput, "input_ops/sec") - b.ReportMetric(float64(processed), "processed_results") + b.ReportMetric(inputThroughput, "ops/sec") + b.ReportMetric(processRate, "process_rate_%") + b.ReportMetric(dropRate, "drop_rate_%") + b.ReportMetric(float64(received), "results") + + b.Logf("%s - 吞吐量: %.0f ops/sec, 处理率: %.1f%%, 丢弃率: %.2f%%", + tt.name, inputThroughput, processRate, dropRate) }) } } -// BenchmarkPureInputPerformance 纯输入性能基准测试(避免结果处理的影响) -func BenchmarkPureInputPerformance(b *testing.B) { - ssql := New() +// BenchmarkConfigComparison 配置对比基准测试 +func BenchmarkConfigComparison(b *testing.B) { + tests := []struct { + name string + setupFunc func() *Streamsql + }{ + { + name: "Default", + setupFunc: func() *Streamsql { + return New() + }, + }, + { + name: "HighPerformance", + setupFunc: func() *Streamsql { + return New(WithHighPerformance()) + }, + }, + { + name: "Lightweight", + setupFunc: func() *Streamsql { + return New(WithBufferSizes(5000, 5000, 250)) + }, + }, + } + + sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" + + 
for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + ssql := tt.setupFunc() + defer ssql.Stop() + + err := ssql.Execute(sql) + if err != nil { + b.Fatalf("SQL执行失败: %v", err) + } + + var resultCount int64 + ssql.AddSink(func(result interface{}) { + atomic.AddInt64(&resultCount, 1) + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + select { + case <-ssql.Stream().GetResultsChan(): + case <-ctx.Done(): + return + } + } + }() + + testData := generateTestData(3) + ssql.Stream().ResetStats() + + b.ResetTimer() + + start := time.Now() + for i := 0; i < b.N; i++ { + ssql.Emit(testData[i%len(testData)]) + } + inputDuration := time.Since(start) + + b.StopTimer() + + time.Sleep(50 * time.Millisecond) + cancel() + + stats := ssql.Stream().GetStats() + + inputThroughput := float64(b.N) / inputDuration.Seconds() + processedCount := stats["output_count"] + droppedCount := stats["dropped_count"] + processRate := float64(processedCount) / float64(b.N) * 100 + dropRate := float64(droppedCount) / float64(b.N) * 100 + + b.ReportMetric(inputThroughput, "ops/sec") + b.ReportMetric(processRate, "process_rate_%") + b.ReportMetric(dropRate, "drop_rate_%") + + b.Logf("%s配置 - 吞吐量: %.0f ops/sec, 处理率: %.1f%%, 丢弃率: %.2f%%", + tt.name, inputThroughput, processRate, dropRate) + }) + } +} + +// BenchmarkPureInput 纯输入性能基准测试 +func BenchmarkPureInput(b *testing.B) { + ssql := New(WithHighPerformance()) defer ssql.Stop() - // 最简单的查询 sql := "SELECT deviceId FROM stream" err := ssql.Execute(sql) if err != nil { @@ -251,7 +220,6 @@ func BenchmarkPureInputPerformance(b *testing.B) { for { select { case <-ssql.Stream().GetResultsChan(): - // 快速丢弃结果 case <-ctx.Done(): return } @@ -267,15 +235,16 @@ func BenchmarkPureInputPerformance(b *testing.B) { b.ResetTimer() start := time.Now() - // 测量纯输入吞吐量 for i := 0; i < b.N; i++ { - ssql.AddData(data) + ssql.Emit(data) } b.StopTimer() duration := time.Since(start) throughput := float64(b.N) / 
duration.Seconds() b.ReportMetric(throughput, "pure_input_ops/sec") + + b.Logf("纯输入性能: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) } // generateTestData 生成测试数据 @@ -294,948 +263,6 @@ func generateTestData(count int) []map[string]interface{} { return data } -// generateIoTData 生成更真实的IoT设备数据 -func generateIoTData(count int) []map[string]interface{} { - data := make([]map[string]interface{}, count) - devices := []string{"sensor001", "sensor002", "sensor003", "gateway001", "gateway002"} - locations := []string{"building_a", "building_b", "outdoor", "warehouse"} - - for i := 0; i < count; i++ { - baseTemp := 20.0 - if rand.Float64() < 0.1 { // 10%概率产生异常值 - baseTemp = 40.0 - } - - data[i] = map[string]interface{}{ - "deviceId": devices[rand.Intn(len(devices))], - "location": locations[rand.Intn(len(locations))], - "temperature": baseTemp + rand.Float64()*10, - "humidity": 40.0 + rand.Float64()*30, - "pressure": 1000.0 + rand.Float64()*100, - "battery": rand.Float64() * 100, - "signal": -30.0 - rand.Float64()*50, - "timestamp": time.Now().UnixNano(), - } - } - return data -} - -// max 辅助函数 -func max(a, b int64) int64 { - if a > b { - return a - } - return b -} - -// TestStreamSQLPerformanceAnalysis 性能分析测试 -func TestStreamSQLPerformanceAnalysis(t *testing.T) { - scenarios := []struct { - name string - sql string - dataCount int - duration time.Duration - expectResults bool - }{ - { - name: "高频非聚合查询", - sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", - dataCount: 1000, - duration: 100 * time.Millisecond, - expectResults: true, - }, - { - name: "窗口聚合查询", - sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('50ms')", - dataCount: 500, - duration: 200 * time.Millisecond, - expectResults: true, - }, - { - name: "复杂聚合查询", - sql: "SELECT deviceId, AVG(temperature), MAX(humidity), COUNT(*) FROM stream WHERE temperature > 15 GROUP BY deviceId, TumblingWindow('100ms')", - dataCount: 300, - duration: 300 * 
time.Millisecond, - expectResults: true, - }, - } - - for _, scenario := range scenarios { - t.Run(scenario.name, func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - err := ssql.Execute(scenario.sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - var inputCount int64 - var resultCount int64 - - // 结果监听 - ctx, cancel := context.WithTimeout(context.Background(), scenario.duration*2) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - atomic.AddInt64(&resultCount, 1) - case <-ctx.Done(): - return - } - } - }() - - // 生成和输入数据 - testData := generateTestData(20) - start := time.Now() - - for i := 0; i < scenario.dataCount; i++ { - ssql.AddData(testData[i%len(testData)]) - atomic.AddInt64(&inputCount, 1) - - // 控制输入频率,避免过快 - if i%100 == 0 { - time.Sleep(1 * time.Millisecond) - } - } - - inputDuration := time.Since(start) - - // 等待处理完成 - time.Sleep(scenario.duration) - - input := atomic.LoadInt64(&inputCount) - results := atomic.LoadInt64(&resultCount) - - inputRate := float64(input) / inputDuration.Seconds() - - t.Logf("场景: %s", scenario.name) - t.Logf("输入数据: %d 条, 耗时: %v", input, inputDuration) - t.Logf("输入速率: %.2f ops/sec", inputRate) - t.Logf("生成结果: %d 个", results) - - if scenario.expectResults && results == 0 { - t.Logf("警告: 预期有结果但未收到任何结果") - } - - // 基本性能验证 - if inputRate < 1000 { - t.Logf("注意: 输入速率较低 (%.2f ops/sec)", inputRate) - } - - if input != int64(scenario.dataCount) { - t.Errorf("输入数据不完整: 期望 %d, 实际 %d", scenario.dataCount, input) - } - }) - } -} - -// TestDiagnoseBenchmarkIssues 诊断基准测试阻塞问题的测试用例 -func TestDiagnoseBenchmarkIssues(t *testing.T) { - t.Run("基础功能测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 使用最简单的查询 - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 检查流是否正确创建 - if ssql.Stream() == nil { - t.Fatal("流创建失败") - } - - var resultCount int64 - var lastResult 
interface{} - - // 添加结果回调 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&resultCount, 1) - lastResult = result - t.Logf("收到结果 #%d: %v", atomic.LoadInt64(&resultCount), result) - }) - - // 添加测试数据 - testData := map[string]interface{}{ - "deviceId": "device1", - "temperature": 25.0, - "humidity": 60.0, - } - - t.Logf("添加数据: %v", testData) - ssql.AddData(testData) - - // 等待结果 - time.Sleep(100 * time.Millisecond) - - count := atomic.LoadInt64(&resultCount) - t.Logf("处理结果数量: %d", count) - if count > 0 { - t.Logf("最后结果: %v", lastResult) - } - - // 验证非聚合查询应该立即返回结果 - if count == 0 { - t.Error("非聚合查询没有返回任何结果") - } - }) - - t.Run("窗口聚合测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 使用滚动窗口 - sql := "SELECT deviceId, AVG(temperature) as avg_temp FROM stream GROUP BY deviceId, TumblingWindow('200ms')" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - var resultCount int64 - var lastResult interface{} - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&resultCount, 1) - lastResult = result - t.Logf("窗口结果 #%d: %v", atomic.LoadInt64(&resultCount), result) - }) - - // 添加多条数据 - for i := 0; i < 5; i++ { - testData := map[string]interface{}{ - "deviceId": "device1", - "temperature": 20.0 + float64(i), - "humidity": 60.0, - } - t.Logf("添加数据 #%d: %v", i+1, testData) - ssql.AddData(testData) - time.Sleep(10 * time.Millisecond) // 小间隔 - } - - // 等待窗口触发 - t.Log("等待窗口触发...") - time.Sleep(300 * time.Millisecond) - - count := atomic.LoadInt64(&resultCount) - t.Logf("窗口结果数量: %d", count) - if count > 0 { - t.Logf("最后窗口结果: %v", lastResult) - } - - if count == 0 { - t.Error("窗口聚合没有返回任何结果") - } - }) - - t.Run("高频数据测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, COUNT(*) as count FROM stream GROUP BY deviceId, TumblingWindow('100ms')" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - var resultCount int64 - var totalDataPoints 
int64 - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&resultCount, 1) - if resultSlice, ok := result.([]map[string]interface{}); ok { - for _, r := range resultSlice { - if count, exists := r["count"]; exists { - if countVal, ok := count.(int64); ok { - atomic.AddInt64(&totalDataPoints, countVal) - } - } - } - } - t.Logf("高频结果 #%d: %v", atomic.LoadInt64(&resultCount), result) - }) - - // 高频添加数据 - start := time.Now() - dataCount := 50 - for i := 0; i < dataCount; i++ { - testData := map[string]interface{}{ - "deviceId": fmt.Sprintf("device%d", (i%3)+1), - "temperature": 20.0 + rand.Float64()*10, - "humidity": 50.0 + rand.Float64()*20, - } - ssql.AddData(testData) - } - inputDuration := time.Since(start) - - // 等待窗口处理 - time.Sleep(200 * time.Millisecond) - - windows := atomic.LoadInt64(&resultCount) - processed := atomic.LoadInt64(&totalDataPoints) - - t.Logf("输入 %d 条数据,用时 %v", dataCount, inputDuration) - t.Logf("生成 %d 个窗口,处理 %d 条数据", windows, processed) - t.Logf("输入速率: %.2f ops/sec", float64(dataCount)/inputDuration.Seconds()) - - if processed != int64(dataCount) { - t.Logf("警告: 处理数据量 (%d) 与输入数据量 (%d) 不匹配", processed, dataCount) - } - - if windows == 0 { - t.Error("高频数据测试没有生成任何窗口") - } - }) - - t.Run("性能基准预测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, temperature, humidity FROM stream WHERE temperature > 20" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - var processedCount int64 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&processedCount, 1) - }) - - testData := generateTestData(10) - - // 模拟基准测试的执行方式 - start := time.Now() - iterations := 1000 - for i := 0; i < iterations; i++ { - ssql.AddData(testData[i%len(testData)]) - } - duration := time.Since(start) - - time.Sleep(20 * time.Millisecond) // 等待处理完成 - - processed := atomic.LoadInt64(&processedCount) - throughput := float64(iterations) / duration.Seconds() - - t.Logf("执行 %d 次迭代,用时 %v", 
iterations, duration) - t.Logf("处理 %d 条数据", processed) - t.Logf("吞吐量: %.2f ops/sec", throughput) - - if processed == 0 { - t.Error("性能基准预测试没有处理任何数据") - } - - // 检查是否会阻塞 - if duration > 5*time.Second { - t.Error("执行时间过长,可能存在阻塞问题") - } - }) -} - -// TestStreamOptimizations 测试Stream优化效果 -func TestStreamOptimizations(t *testing.T) { - t.Run("非阻塞性能测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 使用简单查询测试 - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 添加多个sink模拟实际使用场景 - var sink1Count, sink2Count, sink3Count int64 - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sink1Count, 1) - time.Sleep(1 * time.Millisecond) // 模拟处理延迟 - }) - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sink2Count, 1) - time.Sleep(2 * time.Millisecond) // 模拟较慢的sink - }) - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sink3Count, 1) - }) - - // 快速输入大量数据 - testData := generateTestData(10) - inputCount := 1000 - - start := time.Now() - for i := 0; i < inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - // 等待处理完成 - time.Sleep(200 * time.Millisecond) - - // 获取统计信息 - stats := ssql.Stream().GetStats() - - t.Logf("输入 %d 条数据,耗时: %v", inputCount, inputDuration) - t.Logf("输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf("统计信息: %+v", stats) - t.Logf("Sink计数: sink1=%d, sink2=%d, sink3=%d", - atomic.LoadInt64(&sink1Count), - atomic.LoadInt64(&sink2Count), - atomic.LoadInt64(&sink3Count)) - - // 验证性能指标 - if inputDuration > 100*time.Millisecond { - t.Errorf("输入耗时过长: %v", inputDuration) - } - - if stats["dropped_count"] > int64(inputCount/10) { - t.Errorf("丢弃数据过多: %d", stats["dropped_count"]) - } - - // 验证非阻塞性 - throughput := float64(inputCount) / inputDuration.Seconds() - if throughput < 5000 { // 最低5K ops/sec - t.Errorf("吞吐量过低: %.2f 
ops/sec", throughput) - } - }) - - t.Run("窗口聚合优化测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - var resultCount int64 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&resultCount, 1) - // 模拟慢sink - time.Sleep(5 * time.Millisecond) - }) - - // 快速输入数据 - testData := generateTestData(5) - inputCount := 500 - - start := time.Now() - for i := 0; i < inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - // 等待窗口触发 - time.Sleep(200 * time.Millisecond) - - stats := ssql.Stream().GetStats() - results := atomic.LoadInt64(&resultCount) - - t.Logf("窗口聚合测试 - 输入: %d, 耗时: %v", inputCount, inputDuration) - t.Logf("输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf("生成窗口结果: %d", results) - t.Logf("统计信息: %+v", stats) - - // 验证窗口功能正常 - if results == 0 { - t.Error("窗口聚合未生成任何结果") - } - - // 验证非阻塞性 - if inputDuration > 100*time.Millisecond { - t.Errorf("窗口模式输入耗时过长: %v", inputDuration) - } - }) - - t.Run("高负载压力测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, temperature FROM stream" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 添加会阻塞的sink - ssql.Stream().AddSink(func(result interface{}) { - time.Sleep(10 * time.Millisecond) // 故意阻塞 - }) - - // 高频输入 - testData := generateTestData(3) - inputCount := 2000 // 增加输入量 - - start := time.Now() - for i := 0; i < inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - // 短暂等待 - time.Sleep(50 * time.Millisecond) - - stats := ssql.Stream().GetStats() - - t.Logf("高负载测试 - 输入: %d, 耗时: %v", inputCount, inputDuration) - t.Logf("输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf("统计信息: %+v", stats) - - // 
即使有阻塞的sink,系统也应该保持响应 - if inputDuration > 200*time.Millisecond { - t.Errorf("高负载下输入耗时过长: %v", inputDuration) - } - - // 验证系统没有完全阻塞 - throughput := float64(inputCount) / inputDuration.Seconds() - if throughput < 1000 { - t.Errorf("高负载下吞吐量过低: %.2f ops/sec", throughput) - } - }) -} - -// TestStreamOptimizationsImproved 测试Stream改进后的优化效果 -func TestStreamOptimizationsImproved(t *testing.T) { - t.Run("改进的非阻塞性能测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 使用简单查询测试 - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 添加多个sink模拟真实场景 - var sink1Count, sink2Count int64 - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sink1Count, 1) - time.Sleep(2 * time.Millisecond) // 模拟处理延迟 - }) - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sink2Count, 1) - time.Sleep(1 * time.Millisecond) // 模拟较快的sink - }) - - // 重置统计信息 - ssql.Stream().ResetStats() - - // 快速输入大量数据 - testData := generateTestData(10) - inputCount := 2000 // 增加输入量测试 - - start := time.Now() - for i := 0; i < inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - // 等待处理完成 - time.Sleep(300 * time.Millisecond) - - // 获取统计信息 - stats := ssql.Stream().GetStats() - - t.Logf("改进测试 - 输入 %d 条数据,耗时: %v", inputCount, inputDuration) - t.Logf("输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf("统计信息: %+v", stats) - t.Logf("Sink计数: sink1=%d, sink2=%d", - atomic.LoadInt64(&sink1Count), - atomic.LoadInt64(&sink2Count)) - - // 计算处理效率 - inputTotal := stats["input_count"] - outputTotal := stats["output_count"] - droppedTotal := stats["dropped_count"] - - if inputTotal > 0 { - processRate := float64(outputTotal) / float64(inputTotal) * 100 - dropRate := float64(droppedTotal) / float64(inputTotal) * 100 - - t.Logf("处理效率: %.2f%%, 丢弃率: %.2f%%", processRate, dropRate) - - // 验证改进效果 - if dropRate > 50 
{ // 丢弃率不应超过50% - t.Errorf("丢弃率过高: %.2f%%", dropRate) - } - } - - // 验证非阻塞性 - throughput := float64(inputCount) / inputDuration.Seconds() - if throughput < 10000 { // 期望至少10K ops/sec - t.Logf("注意: 吞吐量较低: %.2f ops/sec", throughput) - } - - // 验证系统没有完全阻塞 - if inputDuration > 500*time.Millisecond { - t.Errorf("输入耗时过长: %v", inputDuration) - } - }) - - t.Run("超高负载压力测试", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, temperature FROM stream" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 添加会严重阻塞的sink - var sinkCount int64 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sinkCount, 1) - time.Sleep(20 * time.Millisecond) // 故意制造严重阻塞 - }) - - // 重置统计 - ssql.Stream().ResetStats() - - // 超高频输入 - testData := generateTestData(3) - inputCount := 5000 // 大幅增加输入量 - - start := time.Now() - for i := 0; i < inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - // 短暂等待 - time.Sleep(100 * time.Millisecond) - - stats := ssql.Stream().GetStats() - sinks := atomic.LoadInt64(&sinkCount) - - t.Logf("超高负载测试 - 输入: %d, 耗时: %v", inputCount, inputDuration) - t.Logf("输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf("Sink处理数: %d", sinks) - t.Logf("统计信息: %+v", stats) - - // 即使有严重阻塞的sink,系统仍应保持响应 - throughput := float64(inputCount) / inputDuration.Seconds() - - // 验证系统没有完全卡死 - if inputDuration > 1*time.Second { - t.Errorf("超高负载下输入耗时过长: %v", inputDuration) - } - - if throughput < 5000 { - t.Logf("注意: 超高负载下吞吐量: %.2f ops/sec", throughput) - } else { - t.Logf("优秀: 超高负载下仍保持高吞吐量: %.2f ops/sec", throughput) - } - - // 验证背压控制有效 - inputTotal := stats["input_count"] - droppedTotal := stats["dropped_count"] - if inputTotal > 0 { - dropRate := float64(droppedTotal) / float64(inputTotal) * 100 - t.Logf("背压控制 - 丢弃率: %.2f%%", dropRate) - } - }) - - t.Run("性能对比测试", func(t *testing.T) { - // 测试不同负载下的性能表现 - testCases := []struct { - name 
string - inputCount int - sinkDelay time.Duration - maxDropRate float64 - }{ - {"轻负载", 500, 1 * time.Millisecond, 10.0}, - {"中负载", 1500, 3 * time.Millisecond, 25.0}, - {"重负载", 3000, 5 * time.Millisecond, 40.0}, - } - - for _, tc := range testCases { - // 创建测试用例的副本,避免闭包问题 - testCase := tc - t.Run(testCase.name, func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 15" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 为每个测试用例创建独立的计数器 - var sinkCount int64 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sinkCount, 1) - time.Sleep(testCase.sinkDelay) - }) - - ssql.Stream().ResetStats() - - testData := generateTestData(5) - start := time.Now() - - for i := 0; i < testCase.inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - - inputDuration := time.Since(start) - time.Sleep(150 * time.Millisecond) - - stats := ssql.Stream().GetStats() - sinks := atomic.LoadInt64(&sinkCount) - - throughput := float64(testCase.inputCount) / inputDuration.Seconds() - - inputTotal := stats["input_count"] - outputTotal := stats["output_count"] - droppedTotal := stats["dropped_count"] - - var processRate, dropRate float64 - if inputTotal > 0 { - processRate = float64(outputTotal) / float64(inputTotal) * 100 - dropRate = float64(droppedTotal) / float64(inputTotal) * 100 - } - - t.Logf("%s结果:", testCase.name) - t.Logf(" 输入速率: %.2f ops/sec", throughput) - t.Logf(" 处理效率: %.2f%%", processRate) - t.Logf(" 丢弃率: %.2f%%", dropRate) - t.Logf(" Sink处理: %d", sinks) - t.Logf(" 统计: %+v", stats) - - // 验证性能标准 - if dropRate > testCase.maxDropRate { - t.Errorf("%s: 丢弃率过高 %.2f%% > %.2f%%", testCase.name, dropRate, testCase.maxDropRate) - } - - if throughput < 1000 { - t.Errorf("%s: 吞吐量过低 %.2f ops/sec", testCase.name, throughput) - } - }) - } - }) -} - -// TestMassiveBufferOptimization 测试超大缓冲区优化效果 -func TestMassiveBufferOptimization(t *testing.T) { - 
t.Run("标准配置vs高性能配置对比", func(t *testing.T) { - // 测试标准配置 - t.Run("标准配置", func(t *testing.T) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - var sinkCount int64 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&sinkCount, 1) - time.Sleep(1 * time.Millisecond) // 模拟处理延迟 - }) - - ssql.Stream().ResetStats() - - // 大量数据输入 - testData := generateTestData(5) - inputCount := 10000 - - start := time.Now() - for i := 0; i < inputCount; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - time.Sleep(300 * time.Millisecond) - - detailedStats := ssql.Stream().GetDetailedStats() - sinks := atomic.LoadInt64(&sinkCount) - - t.Logf("标准配置结果:") - t.Logf(" 输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf(" 处理效率: %.2f%%", detailedStats["process_rate"]) - t.Logf(" 丢弃率: %.2f%%", detailedStats["drop_rate"]) - t.Logf(" 性能等级: %s", detailedStats["performance_level"]) - t.Logf(" 数据通道使用率: %.2f%%", detailedStats["data_chan_usage"]) - t.Logf(" Sink处理数: %d", sinks) - t.Logf(" 详细统计: %+v", detailedStats["basic_stats"]) - }) - - // 测试高性能配置 - t.Run("高性能配置", func(t *testing.T) { - // 直接创建高性能Stream(绕过StreamSQL包装) - config := types.Config{ - SimpleFields: []string{"deviceId", "temperature"}, - } - - stream, err := stream.NewHighPerformanceStream(config) - if err != nil { - t.Fatalf("高性能Stream创建失败: %v", err) - } - defer stream.Stop() - - err = stream.RegisterFilter("temperature > 20") - if err != nil { - t.Fatalf("过滤器注册失败: %v", err) - } - - stream.Start() - - var sinkCount int64 - stream.AddSink(func(result interface{}) { - atomic.AddInt64(&sinkCount, 1) - time.Sleep(1 * time.Millisecond) // 模拟处理延迟 - }) - - stream.ResetStats() - - // 大量数据输入 - testData := generateTestData(5) - inputCount := 10000 - - start := time.Now() - for i := 0; i < inputCount; i++ { - 
stream.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - time.Sleep(300 * time.Millisecond) - - detailedStats := stream.GetDetailedStats() - sinks := atomic.LoadInt64(&sinkCount) - - t.Logf("高性能配置结果:") - t.Logf(" 输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf(" 处理效率: %.2f%%", detailedStats["process_rate"]) - t.Logf(" 丢弃率: %.2f%%", detailedStats["drop_rate"]) - t.Logf(" 性能等级: %s", detailedStats["performance_level"]) - t.Logf(" 数据通道使用率: %.2f%%", detailedStats["data_chan_usage"]) - t.Logf(" Sink处理数: %d", sinks) - t.Logf(" 详细统计: %+v", detailedStats["basic_stats"]) - }) - }) - - t.Run("超高负载抗压测试", func(t *testing.T) { - // 使用最大缓冲区配置测试极限情况 - config := types.Config{ - SimpleFields: []string{"deviceId", "temperature"}, - } - - // 自定义超大缓冲区:100K输入,100K结果,2K sink池 - stream, err := stream.NewStreamWithBuffers(config, 100000, 100000, 2000) - if err != nil { - t.Fatalf("超大缓冲区Stream创建失败: %v", err) - } - defer stream.Stop() - - err = stream.RegisterFilter("temperature > 15") - if err != nil { - t.Fatalf("过滤器注册失败: %v", err) - } - - stream.Start() - - // 添加多个慢速sink模拟极端场景 - var sink1Count, sink2Count, sink3Count int64 - - stream.AddSink(func(result interface{}) { - atomic.AddInt64(&sink1Count, 1) - time.Sleep(3 * time.Millisecond) // 慢速sink - }) - - stream.AddSink(func(result interface{}) { - atomic.AddInt64(&sink2Count, 1) - time.Sleep(5 * time.Millisecond) // 更慢的sink - }) - - stream.AddSink(func(result interface{}) { - atomic.AddInt64(&sink3Count, 1) - time.Sleep(1 * time.Millisecond) // 相对快速的sink - }) - - stream.ResetStats() - - // 超大量数据输入 - testData := generateTestData(3) - inputCount := 50000 // 5万条数据 - - t.Logf("开始超高负载测试:输入 %d 条数据", inputCount) - - start := time.Now() - for i := 0; i < inputCount; i++ { - stream.AddData(testData[i%len(testData)]) - - // 偶尔检查状态,避免测试超时 - if i%10000 == 0 && i > 0 { - t.Logf("已输入 %d 条数据", i) - } - } - inputDuration := time.Since(start) - - t.Logf("数据输入完成,耗时: %v", inputDuration) - - // 等待处理完成 
- time.Sleep(500 * time.Millisecond) - - detailedStats := stream.GetDetailedStats() - - t.Logf("超高负载测试结果:") - t.Logf(" 输入速率: %.2f ops/sec", float64(inputCount)/inputDuration.Seconds()) - t.Logf(" 处理效率: %.2f%%", detailedStats["process_rate"]) - t.Logf(" 丢弃率: %.2f%%", detailedStats["drop_rate"]) - t.Logf(" 性能等级: %s", detailedStats["performance_level"]) - t.Logf(" 数据通道使用率: %.2f%%", detailedStats["data_chan_usage"]) - t.Logf(" 结果通道使用率: %.2f%%", detailedStats["result_chan_usage"]) - t.Logf(" Sink池使用率: %.2f%%", detailedStats["sink_pool_usage"]) - - sinks := []int64{ - atomic.LoadInt64(&sink1Count), - atomic.LoadInt64(&sink2Count), - atomic.LoadInt64(&sink3Count), - } - t.Logf(" Sink处理数: %v", sinks) - t.Logf(" 详细统计: %+v", detailedStats["basic_stats"]) - - // 验证性能指标 - if detailedStats["drop_rate"].(float64) > 30 { - t.Logf("注意: 丢弃率较高 %.2f%%,但系统未阻塞", detailedStats["drop_rate"]) - } - - throughput := float64(inputCount) / inputDuration.Seconds() - if throughput < 50000 { // 期望至少5万ops/sec - t.Logf("注意: 吞吐量 %.2f ops/sec,但在超高负载下属于正常范围", throughput) - } else { - t.Logf("优秀: 超高负载下仍保持高吞吐量 %.2f ops/sec", throughput) - } - - // 验证系统没有完全阻塞 - if inputDuration > 2*time.Second { - t.Errorf("超高负载下输入耗时过长: %v", inputDuration) - } - - // 验证缓冲区配置有效 - basicStats := detailedStats["basic_stats"].(map[string]int64) - t.Logf("缓冲区配置验证:") - t.Logf(" 数据通道: %d/%d", basicStats["data_chan_len"], basicStats["data_chan_cap"]) - t.Logf(" 结果通道: %d/%d", basicStats["result_chan_len"], basicStats["result_chan_cap"]) - t.Logf(" Sink池: %d/%d", basicStats["sink_pool_len"], basicStats["sink_pool_cap"]) - }) -} - // BenchmarkConfigurationComparison 不同配置性能对比基准测试 func BenchmarkConfigurationComparison(b *testing.B) { tests := []struct { @@ -1295,7 +322,7 @@ func BenchmarkConfigurationComparison(b *testing.B) { var resultCount int64 // 添加轻量级sink - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { atomic.AddInt64(&resultCount, 1) }) @@ -1322,7 +349,7 @@ func 
BenchmarkConfigurationComparison(b *testing.B) { // 执行基准测试 start := time.Now() for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) + ssql.Emit(testData[i%len(testData)]) } inputDuration := time.Since(start) @@ -1364,796 +391,128 @@ func BenchmarkConfigurationComparison(b *testing.B) { } } -// BenchmarkOptimizedPerformance 优化后的单项性能测试 -func BenchmarkOptimizedPerformance(b *testing.B) { - b.Run("纯输入性能-高性能配置", func(b *testing.B) { - ssql := New(WithHighPerformance()) - defer ssql.Stop() - - sql := "SELECT deviceId FROM stream" - err := ssql.Execute(sql) - if err != nil { - b.Fatal(err) - } - - // 消费者防止阻塞 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - // 预生成数据 - data := map[string]interface{}{ - "deviceId": "device1", - "temperature": 25.0, - } - - b.ResetTimer() - start := time.Now() - - for i := 0; i < b.N; i++ { - ssql.AddData(data) - } - - b.StopTimer() - duration := time.Since(start) - throughput := float64(b.N) / duration.Seconds() - - // 获取统计 - detailedStats := ssql.Stream().GetDetailedStats() - dropRate := detailedStats["drop_rate"].(float64) - - b.ReportMetric(throughput, "pure_input_ops/sec") - b.ReportMetric(dropRate, "drop_rate_%") - - b.Logf("高性能配置下丢弃率: %.2f%%", dropRate) - }) - - b.Run("窗口聚合性能-超大缓冲", func(b *testing.B) { - ssql := New(WithBufferSizes(50000, 50000, 1500)) - defer ssql.Stop() - - sql := "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')" - err := ssql.Execute(sql) - if err != nil { - b.Fatal(err) - } - - var resultCount int64 - - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&resultCount, 1) - }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - testData := 
generateTestData(2) - - b.ResetTimer() - - start := time.Now() - for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - b.StopTimer() - - time.Sleep(200 * time.Millisecond) - cancel() - - results := atomic.LoadInt64(&resultCount) - detailedStats := ssql.Stream().GetDetailedStats() - - inputThroughput := float64(b.N) / inputDuration.Seconds() - processRate := detailedStats["process_rate"].(float64) - dropRate := detailedStats["drop_rate"].(float64) - - b.ReportMetric(inputThroughput, "input_ops/sec") - b.ReportMetric(processRate, "process_rate_%") - b.ReportMetric(dropRate, "drop_rate_%") - b.ReportMetric(float64(results), "window_results") - - b.Logf("窗口聚合 - 处理效率: %.2f%%, 丢弃率: %.2f%%", processRate, dropRate) - }) -} - // TestMemoryUsageComparison 内存使用对比测试 -func TestMemoryUsageComparison(t *testing.T) { - tests := []struct { - name string - setupFunc func() *Streamsql - description string - expectedMB float64 // 预期内存使用(MB) - }{ - { - name: "轻量配置", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(5000, 5000, 250)) - }, - description: "5K数据 + 5K结果 + 250sink池", - expectedMB: 1.0, // 预期约1MB - }, - { - name: "默认配置(中等场景)", - setupFunc: func() *Streamsql { - return New() - }, - description: "20K数据 + 20K结果 + 800sink池", - expectedMB: 3.0, // 预期约3MB - }, - { - name: "高性能配置", - setupFunc: func() *Streamsql { - return New(WithHighPerformance()) - }, - description: "50K数据 + 50K结果 + 1Ksinki池", - expectedMB: 12.0, // 预期约12MB - }, - { - name: "超大缓冲配置", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(100000, 100000, 2000)) - }, - description: "100K数据缓冲,100K结果缓冲,2Ksinki池", - expectedMB: 25.0, // 预期约25MB - }, - } - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 获取开始内存 - var startMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&startMem) - - // 创建Stream - ssql := tt.setupFunc() - err := 
ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 等待初始化完成 - time.Sleep(10 * time.Millisecond) - - // 获取创建后内存 - var afterCreateMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&afterCreateMem) - - createUsage := float64(afterCreateMem.Alloc-startMem.Alloc) / 1024 / 1024 - - // 添加一些数据测试内存增长 - testData := generateTestData(3) - for i := 0; i < 1000; i++ { - ssql.AddData(testData[i%len(testData)]) - } - - time.Sleep(50 * time.Millisecond) - - // 获取使用后内存 - var afterUseMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&afterUseMem) - - totalUsage := float64(afterUseMem.Alloc-startMem.Alloc) / 1024 / 1024 - - // 获取详细统计 - detailedStats := ssql.Stream().GetDetailedStats() - basicStats := detailedStats["basic_stats"].(map[string]int64) - - ssql.Stop() - - t.Logf("=== %s 内存使用分析 ===", tt.name) - t.Logf("配置: %s", tt.description) - t.Logf("创建开销: %.2f MB", createUsage) - t.Logf("总内存使用: %.2f MB", totalUsage) - t.Logf("缓冲区配置:") - t.Logf(" 数据通道: %d", basicStats["data_chan_cap"]) - t.Logf(" 结果通道: %d", basicStats["result_chan_cap"]) - t.Logf(" Sink池: %d", basicStats["sink_pool_cap"]) - - // 计算理论内存使用 (每个接口槽位约24字节) - dataChanMem := float64(basicStats["data_chan_cap"]) * 24 / 1024 / 1024 - resultChanMem := float64(basicStats["result_chan_cap"]) * 24 / 1024 / 1024 - sinkPoolMem := float64(basicStats["sink_pool_cap"]) * 8 / 1024 / 1024 // 函数指针 - - theoreticalMem := dataChanMem + resultChanMem + sinkPoolMem - - t.Logf("理论内存分配:") - t.Logf(" 数据通道: %.2f MB", dataChanMem) - t.Logf(" 结果通道: %.2f MB", resultChanMem) - t.Logf(" Sink池: %.2f MB", sinkPoolMem) - t.Logf(" 理论总计: %.2f MB", theoreticalMem) - - // 内存效率分析 - if totalUsage > tt.expectedMB*2 { - t.Logf("警告: 内存使用超过预期2倍 (%.2f MB > %.2f MB)", totalUsage, tt.expectedMB*2) - } else if totalUsage > tt.expectedMB*1.5 { - t.Logf("注意: 内存使用超过预期50%% (%.2f MB > %.2f MB)", totalUsage, tt.expectedMB*1.5) - } else { - t.Logf("✓ 内存使用在合理范围内 (%.2f MB)", totalUsage) - } - }) - } -} - -// TestResourceCostAnalysis 
资源成本分析测试 -func TestResourceCostAnalysis(t *testing.T) { - scenarios := []struct { - name string - setup func() *Streamsql - workload int - description string - costCategory string - }{ - { - name: "轻量场景", - setup: func() *Streamsql { return New(WithBufferSizes(5000, 5000, 250)) }, - workload: 1000, - description: "资源受限环境,轻量业务", - costCategory: "低成本", - }, - { - name: "中等场景(默认)", - setup: func() *Streamsql { return New() }, - workload: 10000, - description: "生产环境,正常峰值", - costCategory: "中等成本", - }, - { - name: "高负载场景", - setup: func() *Streamsql { return New(WithHighPerformance()) }, - workload: 50000, - description: "高并发,极端负载", - costCategory: "高成本", - }, - } - - sql := "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')" - - t.Log("=== 资源成本对比分析 ===") - - for _, scenario := range scenarios { - t.Run(scenario.name, func(t *testing.T) { - // 内存监控 - var beforeMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&beforeMem) - - // CPU和goroutine监控 - beforeGoroutines := runtime.NumGoroutine() - start := time.Now() - - // 创建实例 - ssql := scenario.setup() - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 运行负载 - testData := generateTestData(5) - for i := 0; i < scenario.workload; i++ { - ssql.AddData(testData[i%len(testData)]) - } - - // 等待处理完成 - time.Sleep(200 * time.Millisecond) - - // 获取统计 - detailedStats := ssql.Stream().GetDetailedStats() - basicStats := detailedStats["basic_stats"].(map[string]int64) - - // 资源使用测量 - duration := time.Since(start) - var afterMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&afterMem) - afterGoroutines := runtime.NumGoroutine() - - memUsage := float64(afterMem.Alloc-beforeMem.Alloc) / 1024 / 1024 - goroutineIncrease := afterGoroutines - beforeGoroutines - - ssql.Stop() - - // 成本分析报告 - t.Logf("--- %s 成本分析 ---", scenario.name) - t.Logf("场景: %s (%s)", scenario.description, scenario.costCategory) - t.Logf("负载: %d 条数据", scenario.workload) - t.Logf("执行时间: 
%v", duration) - - t.Logf("资源消耗:") - t.Logf(" 内存: %.2f MB", memUsage) - t.Logf(" Goroutine增加: %d", goroutineIncrease) - t.Logf(" 处理速率: %.2f ops/sec", float64(scenario.workload)/duration.Seconds()) - - t.Logf("缓冲区开销:") - t.Logf(" 数据通道容量: %d", basicStats["data_chan_cap"]) - t.Logf(" 结果通道容量: %d", basicStats["result_chan_cap"]) - t.Logf(" Sink池容量: %d", basicStats["sink_pool_cap"]) - - t.Logf("性能指标:") - t.Logf(" 处理效率: %.2f%%", detailedStats["process_rate"]) - t.Logf(" 丢弃率: %.2f%%", detailedStats["drop_rate"]) - t.Logf(" 性能等级: %s", detailedStats["performance_level"]) - - // 成本效益分析 - throughput := float64(scenario.workload) / duration.Seconds() - memEfficiency := throughput / memUsage // ops/sec per MB - - t.Logf("成本效益:") - t.Logf(" 内存效率: %.2f ops/sec/MB", memEfficiency) - - // 推荐使用场景 - switch scenario.costCategory { - case "低成本": - if throughput > 5000 { - t.Logf("✓ 推荐: 适合日常业务、开发测试、资源受限环境") - } - case "中等成本": - if throughput > 20000 { - t.Logf("✓ 推荐: 适合生产环境、中等负载、平衡性能和成本") - } - case "高成本": - if throughput > 100000 { - t.Logf("✓ 推荐: 适合极高负载、关键业务、性能优先场景") - } else { - t.Logf("⚠ 注意: 高成本但吞吐量未达到预期,可能配置过度") - } - } - }) - } - - // 配置选择建议 - t.Log("\n=== 配置选择建议 ===") - t.Log("1. 默认配置: 适合大多数场景,内存占用低(~2MB),goroutine开销小") - t.Log("2. 中等配置: 适合中等负载,内存占用中等(~5MB),平衡性能和成本") - t.Log("3. 高性能配置: 适合极高负载,内存占用高(~12MB),最大化吞吐量") - t.Log("4. 
自定义配置: 根据具体业务需求精确调优,避免资源浪费") -} - -// TestHighPerformanceCostAnalysis 高性能模式代价深度分析 -func TestHighPerformanceCostAnalysis(t *testing.T) { - t.Log("=== 高性能模式 vs 默认配置深度对比 ===") - - configs := []struct { - name string - setup func() *Streamsql - category string - }{ - { - name: "默认配置", - setup: func() *Streamsql { return New() }, - category: "基准", - }, - { - name: "高性能配置", - setup: func() *Streamsql { return New(WithHighPerformance()) }, - category: "优化", - }, - } - - sql := "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('50ms')" - workload := 10000 - - var results []map[string]interface{} - var resultsMutex sync.Mutex - - for _, config := range configs { - t.Run(config.name, func(t *testing.T) { - // 详细资源监控 - var beforeMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&beforeMem) - beforeGoroutines := runtime.NumGoroutine() - - // 创建并启动Stream - ssql := config.setup() - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("执行失败: %v", err) - } - - // 等待完全初始化 - time.Sleep(20 * time.Millisecond) - - // 测量初始化开销 - var afterInitMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&afterInitMem) - afterInitGoroutines := runtime.NumGoroutine() - - initMemCost := float64(afterInitMem.Alloc-beforeMem.Alloc) / 1024 / 1024 - initGoroutineCost := afterInitGoroutines - beforeGoroutines - - // 运行负载测试 - testData := generateTestData(3) - start := time.Now() - - for i := 0; i < workload; i++ { - ssql.AddData(testData[i%len(testData)]) - } - - inputDuration := time.Since(start) - time.Sleep(100 * time.Millisecond) // 等待处理完成 - - // 最终资源测量 - var finalMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&finalMem) - finalGoroutines := runtime.NumGoroutine() - - totalMemUsage := float64(finalMem.Alloc-beforeMem.Alloc) / 1024 / 1024 - runtimeMemUsage := totalMemUsage - initMemCost - - // 获取详细统计 - detailedStats := ssql.Stream().GetDetailedStats() - basicStats := detailedStats["basic_stats"].(map[string]int64) - - ssql.Stop() - - // 分析结果 - 
result := map[string]interface{}{ - "name": config.name, - "category": config.category, - "init_memory_mb": initMemCost, - "runtime_memory_mb": runtimeMemUsage, - "total_memory_mb": totalMemUsage, - "init_goroutines": initGoroutineCost, - "total_goroutines": finalGoroutines - beforeGoroutines, - "input_duration_ms": float64(inputDuration.Nanoseconds()) / 1e6, - "throughput_ops_sec": float64(workload) / inputDuration.Seconds(), - "data_chan_cap": basicStats["data_chan_cap"], - "result_chan_cap": basicStats["result_chan_cap"], - "sink_pool_cap": basicStats["sink_pool_cap"], - "process_rate": detailedStats["process_rate"], - "drop_rate": detailedStats["drop_rate"], - "performance_level": detailedStats["performance_level"], - } - - resultsMutex.Lock() - results = append(results, result) - resultsMutex.Unlock() - - // 详细报告 - t.Logf("=== %s 详细分析 ===", config.name) - t.Logf("初始化开销:") - t.Logf(" 内存: %.2f MB", initMemCost) - t.Logf(" Goroutine: %d 个", initGoroutineCost) - - t.Logf("运行时开销:") - t.Logf(" 额外内存: %.2f MB", runtimeMemUsage) - t.Logf(" 总内存: %.2f MB", totalMemUsage) - t.Logf(" 总Goroutine: %d 个", finalGoroutines-beforeGoroutines) - - t.Logf("性能表现:") - t.Logf(" 输入耗时: %.2f ms", float64(inputDuration.Nanoseconds())/1e6) - t.Logf(" 吞吐量: %.2f ops/sec", float64(workload)/inputDuration.Seconds()) - t.Logf(" 处理效率: %.2f%%", detailedStats["process_rate"]) - t.Logf(" 丢弃率: %.2f%%", detailedStats["drop_rate"]) - - t.Logf("缓冲区配置:") - t.Logf(" 数据通道: %d", basicStats["data_chan_cap"]) - t.Logf(" 结果通道: %d", basicStats["result_chan_cap"]) - t.Logf(" Sink池: %d", basicStats["sink_pool_cap"]) - }) - } - - // 对比分析 - if len(results) == 2 { - defaultResult := results[0] - highPerfResult := results[1] - - t.Log("\n=== 对比分析总结 ===") - - // 内存开销对比 - memMultiplier := highPerfResult["total_memory_mb"].(float64) / defaultResult["total_memory_mb"].(float64) - t.Logf("内存开销倍数: %.1fx (%.2f MB vs %.2f MB)", - memMultiplier, - highPerfResult["total_memory_mb"], - defaultResult["total_memory_mb"]) - - // 
性能提升对比 - perfMultiplier := highPerfResult["throughput_ops_sec"].(float64) / defaultResult["throughput_ops_sec"].(float64) - t.Logf("性能提升倍数: %.1fx (%.0f ops/sec vs %.0f ops/sec)", - perfMultiplier, - highPerfResult["throughput_ops_sec"], - defaultResult["throughput_ops_sec"]) - - // 缓冲区容量对比 - dataCapMultiplier := float64(highPerfResult["data_chan_cap"].(int64)) / float64(defaultResult["data_chan_cap"].(int64)) - t.Logf("缓冲区容量倍数: %.1fx", dataCapMultiplier) - - // 成本效益分析 - memEfficiencyDefault := defaultResult["throughput_ops_sec"].(float64) / defaultResult["total_memory_mb"].(float64) - memEfficiencyHighPerf := highPerfResult["throughput_ops_sec"].(float64) / highPerfResult["total_memory_mb"].(float64) - - t.Logf("内存效率对比:") - t.Logf(" 默认配置: %.0f ops/sec/MB", memEfficiencyDefault) - t.Logf(" 高性能配置: %.0f ops/sec/MB", memEfficiencyHighPerf) - t.Logf(" 效率比: %.2fx", memEfficiencyHighPerf/memEfficiencyDefault) - - // 代价分析 - t.Log("\n=== 高性能模式代价分析 ===") - t.Logf("✓ 性能收益: %.1fx 吞吐量提升", perfMultiplier) - t.Logf("✗ 内存代价: %.1fx 内存消耗增长", memMultiplier) - t.Logf("✗ 缓冲区代价: %.1fx 缓冲区容量增长", dataCapMultiplier) - - if memMultiplier < perfMultiplier { - t.Log("✓ 结论: 高性能模式性价比较高,内存增长小于性能提升") - } else { - t.Log("⚠ 结论: 高性能模式需要权衡,内存增长超过性能提升比例") - } - } -} - -// TestLightweightVsDefaultPerformanceAnalysis 专门分析轻量配置vs默认配置性能差异的测试 -func TestLightweightVsDefaultPerformanceAnalysis(t *testing.T) { - configs := []struct { - name string - setupFunc func() *Streamsql - description string - expectedPerf string - }{ - { - name: "轻量配置(5K)", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(5000, 5000, 250)) - }, - description: "5K数据缓冲,5K结果缓冲,250sink池", - expectedPerf: "高吞吐,低内存", - }, - { - name: "默认配置(20K)", - setupFunc: func() *Streamsql { - return New() - }, - description: "20K数据缓冲,20K结果缓冲,800sink池", - expectedPerf: "平衡性能", - }, - { - name: "中等配置(10K)", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(10000, 10000, 400)) - }, - description: "10K数据缓冲,10K结果缓冲,400sink池", - 
expectedPerf: "介于两者之间", - }, - } - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - - t.Log("=== 轻量配置 vs 默认配置深度对比分析 ===") - - var results []map[string]interface{} - var resultsMutex sync.Mutex - - for _, config := range configs { - t.Run(config.name, func(t *testing.T) { - // 内存和性能监控 - var beforeMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&beforeMem) - - ssql := config.setupFunc() - defer ssql.Stop() - - err := ssql.Execute(sql) - if err != nil { - t.Fatalf("SQL执行失败: %v", err) - } - - // 测量初始化后内存 - var afterInitMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&afterInitMem) - initMemory := float64(afterInitMem.Alloc-beforeMem.Alloc) / 1024 / 1024 - - var resultCount int64 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&resultCount, 1) - }) - - // 消费resultChan防止阻塞 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - // 重置统计 - ssql.Stream().ResetStats() - - // 执行性能测试 - testData := generateTestData(3) - iterations := 50000 // 固定迭代次数便于对比 - - start := time.Now() - for i := 0; i < iterations; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - // 等待处理完成 - time.Sleep(100 * time.Millisecond) - - // 获取详细统计 - detailedStats := ssql.Stream().GetDetailedStats() - basicStats := detailedStats["basic_stats"].(map[string]int64) - - // 最终内存测量 - var finalMem runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&finalMem) - totalMemory := float64(finalMem.Alloc-beforeMem.Alloc) / 1024 / 1024 - - // 计算指标 - inputThroughput := float64(iterations) / inputDuration.Seconds() - processRate := detailedStats["process_rate"].(float64) - dropRate := detailedStats["drop_rate"].(float64) - perfLevel := detailedStats["performance_level"].(string) - dataChanUsage := detailedStats["data_chan_usage"].(float64) - memEfficiency := 
inputThroughput / totalMemory // ops/sec per MB - - result := map[string]interface{}{ - "name": config.name, - "description": config.description, - "data_chan_cap": basicStats["data_chan_cap"], - "result_chan_cap": basicStats["result_chan_cap"], - "sink_pool_cap": basicStats["sink_pool_cap"], - "init_memory_mb": initMemory, - "total_memory_mb": totalMemory, - "input_throughput": inputThroughput, - "process_rate": processRate, - "drop_rate": dropRate, - "performance_level": perfLevel, - "data_chan_usage": dataChanUsage, - "mem_efficiency": memEfficiency, - "input_duration_ms": float64(inputDuration.Nanoseconds()) / 1e6, - } - resultsMutex.Lock() - results = append(results, result) - resultsMutex.Unlock() - - // 详细报告 - t.Logf("=== %s 分析报告 ===", config.name) - t.Logf("配置: %s", config.description) - t.Logf("预期: %s", config.expectedPerf) - - t.Logf("缓冲区配置:") - t.Logf(" 数据通道: %d", basicStats["data_chan_cap"]) - t.Logf(" 结果通道: %d", basicStats["result_chan_cap"]) - t.Logf(" Sink池: %d", basicStats["sink_pool_cap"]) - - t.Logf("内存使用:") - t.Logf(" 初始化: %.2f MB", initMemory) - t.Logf(" 总计: %.2f MB", totalMemory) - - t.Logf("性能指标:") - t.Logf(" 输入速率: %.0f ops/sec", inputThroughput) - t.Logf(" 输入耗时: %.2f ms", float64(inputDuration.Nanoseconds())/1e6) - t.Logf(" 处理效率: %.2f%%", processRate) - t.Logf(" 丢弃率: %.2f%%", dropRate) - t.Logf(" 性能等级: %s", perfLevel) - t.Logf(" 数据通道使用率: %.1f%%", dataChanUsage) - t.Logf(" 内存效率: %.0f ops/sec/MB", memEfficiency) - - // 性能分析 - if dropRate > 10 { - t.Logf("⚠ 警告: 丢弃率较高 (%.2f%%)", dropRate) - } else if dropRate < 1 { - t.Logf("✓ 优秀: 丢弃率很低 (%.2f%%)", dropRate) - } - - if dataChanUsage > 80 { - t.Logf("⚠ 警告: 数据通道使用率过高 (%.1f%%)", dataChanUsage) - } else if dataChanUsage < 50 { - t.Logf("✓ 良好: 数据通道使用率适中 (%.1f%%)", dataChanUsage) - } - }) - } - - // 对比分析 - if len(results) >= 2 { - t.Log("\n=== 对比分析结论 ===") - - lightweight := results[0] // 轻量配置 - defaultCfg := results[1] // 默认配置 - - // 性能对比 - perfRatio := lightweight["input_throughput"].(float64) / 
defaultCfg["input_throughput"].(float64) - memRatio := lightweight["total_memory_mb"].(float64) / defaultCfg["total_memory_mb"].(float64) - memEffRatio := lightweight["mem_efficiency"].(float64) / defaultCfg["mem_efficiency"].(float64) - - t.Logf("性能倍数: %.2fx (轻量 %.0f vs 默认 %.0f ops/sec)", - perfRatio, - lightweight["input_throughput"].(float64), - defaultCfg["input_throughput"].(float64)) - - t.Logf("内存倍数: %.2fx (轻量 %.2f vs 默认 %.2f MB)", - memRatio, - lightweight["total_memory_mb"].(float64), - defaultCfg["total_memory_mb"].(float64)) - - t.Logf("内存效率倍数: %.2fx (轻量 %.0f vs 默认 %.0f ops/sec/MB)", - memEffRatio, - lightweight["mem_efficiency"].(float64), - defaultCfg["mem_efficiency"].(float64)) - - // 分析原因 - t.Log("\n=== 轻量配置性能更高的可能原因 ===") - - lightUsage := lightweight["data_chan_usage"].(float64) - defaultUsage := defaultCfg["data_chan_usage"].(float64) - - t.Logf("1. 缓冲区压力差异:") - t.Logf(" 轻量配置数据通道使用率: %.1f%%", lightUsage) - t.Logf(" 默认配置数据通道使用率: %.1f%%", defaultUsage) - - if lightUsage < defaultUsage { - t.Log(" → 轻量配置缓冲区压力更小,减少了队列等待时间") - } - - lightCapacity := lightweight["data_chan_cap"].(int64) - defaultCapacity := defaultCfg["data_chan_cap"].(int64) - capacityRatio := float64(defaultCapacity) / float64(lightCapacity) - - t.Logf("2. 缓冲区容量差异:") - t.Logf(" 容量倍数: %.1fx (轻量 %d vs 默认 %d)", - capacityRatio, lightCapacity, defaultCapacity) - t.Log(" → 大缓冲区可能导致更多内存分配和GC压力") - - t.Logf("3. 内存分配模式:") - t.Logf(" 轻量配置总内存: %.2f MB", lightweight["total_memory_mb"].(float64)) - t.Logf(" 默认配置总内存: %.2f MB", defaultCfg["total_memory_mb"].(float64)) - t.Log(" → 轻量配置内存占用更少,减少GC频率和暂停时间") - - t.Log("\n=== 技术解释 ===") - t.Log("轻量配置吞吐量更高的核心原因:") - t.Log("1. **内存局部性更好**: 小缓冲区提高CPU缓存命中率") - t.Log("2. **GC压力更小**: 减少垃圾回收的暂停时间") - t.Log("3. **队列效率更高**: 小队列减少锁竞争和等待时间") - t.Log("4. **资源竞争减少**: 更少的内存分配减少系统调用开销") - t.Log("5. 
**适合高频小数据**: 本测试场景正好符合轻量配置的优势区间") - - if perfRatio > 1.1 { - t.Log("\n✓ 结论: 轻量配置在此场景下确实具有性能优势") - t.Log(" 推荐: 对于高频率、小数据量的场景,优先考虑轻量配置") - } else { - t.Log("\n→ 结论: 性能差异不显著,可能存在测试误差") - } - } -} +//func TestMemoryUsageComparison(t *testing.T) { +// tests := []struct { +// name string +// setupFunc func() *Streamsql +// description string +// expectedMB float64 // 预期内存使用(MB) +// }{ +// { +// name: "轻量配置", +// setupFunc: func() *Streamsql { +// return New(WithBufferSizes(5000, 5000, 250)) +// }, +// description: "5K数据 + 5K结果 + 250sink池", +// expectedMB: 1.0, // 预期约1MB +// }, +// { +// name: "默认配置(中等场景)", +// setupFunc: func() *Streamsql { +// return New() +// }, +// description: "20K数据 + 20K结果 + 800sink池", +// expectedMB: 3.0, // 预期约3MB +// }, +// { +// name: "高性能配置", +// setupFunc: func() *Streamsql { +// return New(WithHighPerformance()) +// }, +// description: "50K数据 + 50K结果 + 1Ksinki池", +// expectedMB: 12.0, // 预期约12MB +// }, +// { +// name: "超大缓冲配置", +// setupFunc: func() *Streamsql { +// return New(WithBufferSizes(100000, 100000, 2000)) +// }, +// description: "100K数据缓冲,100K结果缓冲,2Ksinki池", +// expectedMB: 25.0, // 预期约25MB +// }, +// } +// +// sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" +// +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// // 获取开始内存 +// var startMem runtime.MemStats +// runtime.GC() +// runtime.ReadMemStats(&startMem) +// +// // 创建Stream +// ssql := tt.setupFunc() +// err := ssql.Execute(sql) +// if err != nil { +// t.Fatalf("SQL执行失败: %v", err) +// } +// +// // 等待初始化完成 +// time.Sleep(10 * time.Millisecond) +// +// // 获取创建后内存 +// var afterCreateMem runtime.MemStats +// runtime.GC() +// runtime.ReadMemStats(&afterCreateMem) +// +// createUsage := float64(afterCreateMem.Alloc-startMem.Alloc) / 1024 / 1024 +// +// // 添加一些数据测试内存增长 +// testData := generateTestData(3) +// for i := 0; i < 1000; i++ { +// ssql.Emit(testData[i%len(testData)]) +// } +// +// time.Sleep(50 * time.Millisecond) +// +// // 
获取使用后内存 +// var afterUseMem runtime.MemStats +// runtime.GC() +// runtime.ReadMemStats(&afterUseMem) +// +// totalUsage := float64(afterUseMem.Alloc-startMem.Alloc) / 1024 / 1024 +// +// // 获取详细统计 +// detailedStats := ssql.Stream().GetDetailedStats() +// basicStats := detailedStats["basic_stats"].(map[string]int64) +// +// ssql.Stop() +// +// t.Logf("=== %s 内存使用分析 ===", tt.name) +// t.Logf("配置: %s", tt.description) +// t.Logf("创建开销: %.2f MB", createUsage) +// t.Logf("总内存使用: %.2f MB", totalUsage) +// t.Logf("缓冲区配置:") +// t.Logf(" 数据通道: %d", basicStats["data_chan_cap"]) +// t.Logf(" 结果通道: %d", basicStats["result_chan_cap"]) +// t.Logf(" Sink池: %d", basicStats["sink_pool_cap"]) +// +// // 计算理论内存使用 (每个接口槽位约24字节) +// dataChanMem := float64(basicStats["data_chan_cap"]) * 24 / 1024 / 1024 +// resultChanMem := float64(basicStats["result_chan_cap"]) * 24 / 1024 / 1024 +// sinkPoolMem := float64(basicStats["sink_pool_cap"]) * 8 / 1024 / 1024 // 函数指针 +// +// theoreticalMem := dataChanMem + resultChanMem + sinkPoolMem +// +// t.Logf("理论内存分配:") +// t.Logf(" 数据通道: %.2f MB", dataChanMem) +// t.Logf(" 结果通道: %.2f MB", resultChanMem) +// t.Logf(" Sink池: %.2f MB", sinkPoolMem) +// t.Logf(" 理论总计: %.2f MB", theoreticalMem) +// +// // 内存效率分析 +// if totalUsage > tt.expectedMB*2 { +// t.Logf("警告: 内存使用超过预期2倍 (%.2f MB > %.2f MB)", totalUsage, tt.expectedMB*2) +// } else if totalUsage > tt.expectedMB*1.5 { +// t.Logf("注意: 内存使用超过预期50%% (%.2f MB > %.2f MB)", totalUsage, tt.expectedMB*1.5) +// } else { +// t.Logf("✓ 内存使用在合理范围内 (%.2f MB)", totalUsage) +// } +// }) +// } +//} // BenchmarkLightweightVsDefaultComparison 轻量 vs 默认配置基准测试 func BenchmarkLightweightVsDefaultComparison(b *testing.B) { @@ -2188,7 +547,7 @@ func BenchmarkLightweightVsDefaultComparison(b *testing.B) { } var resultCount int64 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { atomic.AddInt64(&resultCount, 1) }) @@ -2211,7 +570,7 @@ func BenchmarkLightweightVsDefaultComparison(b 
*testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) + ssql.Emit(testData[i%len(testData)]) } inputDuration := time.Since(start) @@ -2282,7 +641,7 @@ func BenchmarkStreamSQLRealistic(b *testing.B) { var actualResultCount int64 // 测量实际的处理完成 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { atomic.AddInt64(&actualResultCount, 1) }) @@ -2299,7 +658,7 @@ func BenchmarkStreamSQLRealistic(b *testing.B) { start := time.Now() for i := 0; i < maxIterations; i++ { // 直接使用AddData,如果系统处理不过来会自然阻塞或丢弃 - ssql.AddData(testData[i%len(testData)]) + ssql.Emit(testData[i%len(testData)]) atomic.AddInt64(&processedCount, 1) // 每100条数据稍微停顿,模拟真实的数据流 @@ -2344,102 +703,6 @@ func min(a, b int) int { return b } -// BenchmarkStreamSQLOptimized 优化的高性能基准测试 -func BenchmarkStreamSQLOptimized(b *testing.B) { - tests := []struct { - name string - sql string - hasWindow bool - waitTime time.Duration - }{ - { - name: "HighThroughputFilter", - sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", - hasWindow: false, - waitTime: 10 * time.Millisecond, - }, - { - name: "HighThroughputAggregation", - sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('50ms')", - hasWindow: true, - waitTime: 100 * time.Millisecond, - }, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - // 使用高性能配置 - ssql := New(WithHighPerformance()) - defer ssql.Stop() - - err := ssql.Execute(tt.sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var actualResultCount int64 - - // 极简的sink,避免任何额外开销 - ssql.Stream().AddSink(func(result interface{}) { - atomic.AddInt64(&actualResultCount, 1) - }) - - // 异步消费resultChan,避免阻塞 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - // 立即丢弃,避免处理开销 - case <-ctx.Done(): - return - } - } - }() - - // 预生成测试数据,避免重复生成开销 - testData 
:= generateTestData(5) - - ssql.Stream().ResetStats() - b.ResetTimer() - - // 纯粹的性能测试:无限制,无延迟 - start := time.Now() - for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - b.StopTimer() - - // 最小等待时间 - time.Sleep(tt.waitTime) - cancel() - - results := atomic.LoadInt64(&actualResultCount) - stats := ssql.Stream().GetStats() - - // 计算纯输入吞吐量 - inputThroughput := float64(b.N) / inputDuration.Seconds() - - b.ReportMetric(inputThroughput, "optimized_ops/sec") - b.ReportMetric(float64(results), "actual_results") - b.ReportMetric(float64(stats["dropped_count"]), "dropped_data") - - // 输出性能数据 - b.Logf("优化测试 - 输入: %d 条, 结果: %d 个", b.N, results) - b.Logf("优化吞吐量: %.0f ops/sec (%.1f万 ops/sec)", inputThroughput, inputThroughput/10000) - - if dropped := stats["dropped_count"]; dropped > 0 { - dropRate := float64(dropped) / float64(b.N) * 100 - b.Logf("丢弃率: %.2f%% (%d/%d)", dropRate, dropped, b.N) - } - }) - } -} - // BenchmarkPurePerformance 纯性能基准测试(无等待,无限制) func BenchmarkPurePerformance(b *testing.B) { ssql := New(WithHighPerformance()) @@ -2476,7 +739,7 @@ func BenchmarkPurePerformance(b *testing.B) { // 纯输入性能测试 for i := 0; i < b.N; i++ { - ssql.AddData(data) + ssql.Emit(data) } b.StopTimer() @@ -2542,7 +805,7 @@ func BenchmarkEndToEndProcessing(b *testing.B) { resultChan := make(chan bool, currentBatchSize) // 设置sink来捕获结果 - ssql.Stream().AddSink(func(result interface{}) { + ssql.AddSink(func(result interface{}) { count := atomic.AddInt64(&resultsReceived, 1) if count <= int64(currentBatchSize) { resultChan <- true @@ -2554,7 +817,7 @@ func BenchmarkEndToEndProcessing(b *testing.B) { // 输入数据 for i := 0; i < currentBatchSize; i++ { - ssql.AddData(testData[i%len(testData)]) + ssql.Emit(testData[i%len(testData)]) } // 等待所有结果处理完成(对于非聚合查询) @@ -2618,7 +881,7 @@ func BenchmarkSustainedProcessing(b *testing.B) { var lastResultTime time.Time // 设置结果处理器 - ssql.Stream().AddSink(func(result interface{}) { + 
ssql.AddSink(func(result interface{}) { atomic.AddInt64(&processedResults, 1) lastResultTime = time.Now() }) @@ -2630,7 +893,7 @@ func BenchmarkSustainedProcessing(b *testing.B) { // 持续输入数据 for i := 0; i < b.N; i++ { - ssql.AddData(testData[i%len(testData)]) + ssql.Emit(testData[i%len(testData)]) // 每1000条检查一次处理进度 if i > 0 && i%1000 == 0 { diff --git a/streamsql_case_test.go b/streamsql_case_test.go index 33170f5..66c5978 100644 --- a/streamsql_case_test.go +++ b/streamsql_case_test.go @@ -1,203 +1,15 @@ package streamsql -/* -CASE表达式测试状况说明: - -✅ 支持的功能: -- 基本搜索CASE表达式 (CASE WHEN ... THEN ... END) -- 简单CASE表达式 (CASE expr WHEN value THEN result END) -- 多条件逻辑 (AND, OR, NOT) -- 比较操作符 (>, <, >=, <=, =, !=) -- 数学函数 (ABS, ROUND等) -- 算术表达式 (+, -, *, /) -- 字段引用和提取 -- 非聚合SQL查询中使用 - -⚠️ 已知限制: -- 嵌套CASE表达式 (回退到expr-lang) -- 某些字符串函数 (类型转换问题) -- 聚合函数中的CASE表达式 (需要进一步实现) - -📝 测试策略: -- 对于已知限制,测试会跳过或标记为预期行为 -- 确保核心功能不受影响 -- 为未来改进提供清晰的测试基准 -*/ - import ( "context" - "strings" "sync" "testing" "time" - "github.com/rulego/streamsql/expr" + "github.com/rulego/streamsql/rsql" "github.com/stretchr/testify/assert" ) -// TestCaseExpressionParsing 测试CASE表达式的解析功能 -func TestCaseExpressionParsing(t *testing.T) { - tests := []struct { - name string - exprStr string - data map[string]interface{} - expected float64 - wantErr bool - }{ - { - name: "简单的搜索CASE表达式", - exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 35.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "简单CASE表达式 - 值匹配", - exprStr: "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END", - data: map[string]interface{}{"status": "active"}, - expected: 1.0, - wantErr: false, - }, - { - name: "CASE表达式 - ELSE分支", - exprStr: "CASE WHEN temperature > 50 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 25.5}, - expected: 0.0, - wantErr: false, - }, - { - name: "复杂搜索CASE表达式", - exprStr: "CASE WHEN temperature > 30 THEN 'HOT' WHEN temperature > 20 THEN 
'WARM' ELSE 'COLD' END", - data: map[string]interface{}{"temperature": 25.0}, - expected: 4.0, // 字符串"WARM"的长度,因为我们的字符串处理返回长度 - wantErr: false, - }, - { - name: "嵌套CASE表达式", - exprStr: "CASE WHEN temperature > 25 THEN CASE WHEN humidity > 60 THEN 1 ELSE 2 END ELSE 0 END", - data: map[string]interface{}{"temperature": 30.0, "humidity": 70.0}, - expected: 0.0, // 嵌套CASE回退到expr-lang,计算失败返回默认值0 - wantErr: false, - }, - { - name: "数值比较的简单CASE", - exprStr: "CASE temperature WHEN 25 THEN 1 WHEN 30 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": 30.0}, - expected: 2.0, - wantErr: false, - }, - { - name: "布尔值CASE表达式", - exprStr: "CASE WHEN temperature > 25 AND humidity > 50 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 30.0, "humidity": 60.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "多条件CASE表达式_AND", - exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 WHEN temperature > 20 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": 35.0, "humidity": 50.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "多条件CASE表达式_OR", - exprStr: "CASE WHEN temperature > 40 OR humidity > 80 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 25.0, "humidity": 85.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "函数调用在CASE中_ABS", - exprStr: "CASE WHEN ABS(temperature) > 30 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": -35.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "函数调用在CASE中_ROUND", - exprStr: "CASE WHEN ROUND(temperature) = 25 THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 24.7}, - expected: 1.0, - wantErr: false, - }, - { - name: "复杂条件组合", - exprStr: "CASE WHEN temperature > 30 AND (humidity > 60 OR pressure < 1000) THEN 1 ELSE 0 END", - data: map[string]interface{}{"temperature": 35.0, "humidity": 55.0, "pressure": 950.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "CASE中的算术表达式", - exprStr: "CASE WHEN temperature * 1.8 + 32 > 100 THEN 1 
ELSE 0 END", - data: map[string]interface{}{"temperature": 40.0}, // 40*1.8+32 = 104 - expected: 1.0, - wantErr: false, - }, - { - name: "字符串函数在CASE中", - exprStr: "CASE WHEN LENGTH(device_name) > 5 THEN 1 ELSE 0 END", - data: map[string]interface{}{"device_name": "sensor123"}, - expected: 1.0, // LENGTH函数现在正常工作,"sensor123"长度为9 > 5,返回1 - wantErr: false, - }, - { - name: "简单CASE与函数", - exprStr: "CASE ABS(temperature) WHEN 30 THEN 1 WHEN 25 THEN 2 ELSE 0 END", - data: map[string]interface{}{"temperature": -30.0}, - expected: 1.0, - wantErr: false, - }, - { - name: "CASE结果中的函数", - exprStr: "CASE WHEN temperature > 30 THEN ABS(temperature) ELSE ROUND(temperature) END", - data: map[string]interface{}{"temperature": 35.5}, - expected: 35.5, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 测试表达式创建 - expression, err := expr.NewExpression(tt.exprStr) - if tt.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err, "Expression creation should not fail") - assert.NotNil(t, expression, "Expression should not be nil") - - // 调试:检查表达式是否使用了expr-lang - t.Logf("Expression uses expr-lang: %v", expression.Root == nil) - if expression.Root != nil { - t.Logf("Expression root type: %s", expression.Root.Type) - } - - // 测试表达式计算 - result, err := expression.Evaluate(tt.data) - if tt.wantErr { - assert.Error(t, err) - return - } - - if err != nil { - t.Logf("Error evaluating expression: %v", err) - // 对于已知的限制(嵌套CASE和某些字符串函数),跳过测试 - if tt.name == "嵌套CASE表达式" || tt.name == "字符串函数在CASE中" { - t.Skipf("Known limitation: %s", err.Error()) - return - } - } - - assert.NoError(t, err, "Expression evaluation should not fail") - assert.Equal(t, tt.expected, result, "Expression result should match expected value") - }) - } -} - // TestCaseExpressionInSQL 测试CASE表达式在SQL查询中的使用 func TestCaseExpressionInSQL(t *testing.T) { // 测试非聚合场景中的CASE表达式 @@ -238,7 +50,7 @@ func TestCaseExpressionInSQL(t *testing.T) { }) for _, data := range testData { - 
streamSQL.stream.AddData(data) + streamSQL.Emit(data) } // 等待处理 @@ -290,7 +102,7 @@ func TestCaseExpressionInAggregation(t *testing.T) { }) for _, data := range testData { - streamSQL.stream.AddData(data) + streamSQL.Emit(data) } // 等待窗口触发 @@ -302,50 +114,98 @@ func TestCaseExpressionInAggregation(t *testing.T) { // 等待结果 time.Sleep(100 * time.Millisecond) - // 验证至少有结果返回 + // 验证结果 resultsMutex.Lock() - resultCount := len(results) - var firstResult map[string]interface{} - if resultCount > 0 { - firstResult = results[0] + defer resultsMutex.Unlock() + + assert.Greater(t, len(results), 0, "应该有聚合结果返回") + + // 验证结果结构和内容 + deviceResults := make(map[string]map[string]interface{}) + for _, result := range results { + deviceId, ok := result["deviceId"].(string) + assert.True(t, ok, "deviceId应该是字符串类型") + deviceResults[deviceId] = result } - resultsMutex.Unlock() - assert.Greater(t, resultCount, 0, "应该有聚合结果返回") + // 期望有两个设备的结果 + assert.Len(t, deviceResults, 2, "应该有两个设备的聚合结果") + assert.Contains(t, deviceResults, "device1", "应该包含device1的结果") + assert.Contains(t, deviceResults, "device2", "应该包含device2的结果") - // 验证结果结构 - if resultCount > 0 { - t.Logf("聚合结果: %+v", firstResult) - assert.Contains(t, firstResult, "deviceId", "结果应该包含deviceId") - assert.Contains(t, firstResult, "total_count", "结果应该包含total_count") - assert.Contains(t, firstResult, "hot_count", "结果应该包含hot_count") - assert.Contains(t, firstResult, "avg_active_temp", "结果应该包含avg_active_temp") + // 验证device1的结果 + device1Result := deviceResults["device1"] - // 验证hot_count的逻辑:temperature > 30的记录数 - if deviceId := firstResult["deviceId"]; deviceId == "device1" { - // device1有两条温度>30的记录(35.0, 32.0) - hotCount := firstResult["hot_count"] - t.Logf("device1的hot_count: %v (类型: %T)", hotCount, hotCount) + // 基本字段检查 + assert.Contains(t, device1Result, "total_count", "device1结果应该包含total_count") + assert.Contains(t, device1Result, "hot_count", "device1结果应该包含hot_count") + assert.Contains(t, device1Result, "avg_active_temp", 
"device1结果应该包含avg_active_temp") - // 检查CASE表达式是否在聚合中正常工作 - if hotCount == 0 || hotCount == 0.0 { - t.Skip("CASE表达式在聚合函数中暂不支持,跳过此测试") - return - } - assert.Equal(t, 2.0, hotCount, "device1应该有2条高温记录") - } + // 详细数值验证 + totalCount1 := getFloat64Value(device1Result["total_count"]) + hotCount1 := getFloat64Value(device1Result["hot_count"]) + avgActiveTemp1 := getFloat64Value(device1Result["avg_active_temp"]) + + // device1: 3条记录总数 + assert.Equal(t, 3.0, totalCount1, "device1应该有3条记录") + + // device1: 2条高温记录 (35.0 > 30, 32.0 > 30) + assert.Equal(t, 2.0, hotCount1, "device1应该有2条高温记录") + + // device1: active状态的平均温度 (35.0 + 0 + 32.0) / 3 = 22.333... + expectedActiveAvg := (35.0 + 0 + 32.0) / 3.0 + assert.InDelta(t, expectedActiveAvg, avgActiveTemp1, 0.01, + "device1的AVG(CASE WHEN...)应该正确计算") + + // 验证device2的结果 + device2Result := deviceResults["device2"] + + // 基本字段检查 + assert.Contains(t, device2Result, "total_count", "device2结果应该包含total_count") + assert.Contains(t, device2Result, "hot_count", "device2结果应该包含hot_count") + assert.Contains(t, device2Result, "avg_active_temp", "device2结果应该包含avg_active_temp") + + // 详细数值验证 + totalCount2 := getFloat64Value(device2Result["total_count"]) + hotCount2 := getFloat64Value(device2Result["hot_count"]) + avgActiveTemp2 := getFloat64Value(device2Result["avg_active_temp"]) + + // device2: 2条记录总数 + assert.Equal(t, 2.0, totalCount2, "device2应该有2条记录") + + // device2: 0条高温记录 (没有温度>30的) + assert.Equal(t, 0.0, hotCount2, "device2应该有0条高温记录") + + // device2: CASE WHEN status='active' THEN temperature ELSE 0 + // 28.0 (active) + 0 (inactive) = 28.0, 平均值 = (28.0 + 0) / 2 = 14.0 + expectedActiveAvg2 := (28.0 + 0) / 2.0 + assert.InDelta(t, expectedActiveAvg2, avgActiveTemp2, 0.01, + "device2的AVG(CASE WHEN...)应该正确计算") +} + +// getFloat64Value 辅助函数,将interface{}转换为float64 +func getFloat64Value(value interface{}) float64 { + switch v := value.(type) { + case float64: + return v + case float32: + return float64(v) + case int: + return float64(v) + case 
int64: + return float64(v) + default: + return 0.0 } } // TestComplexCaseExpressionsInAggregation 测试复杂CASE表达式在聚合查询中的使用 func TestComplexCaseExpressionsInAggregation(t *testing.T) { - // 测试用例集合 testCases := []struct { name string sql string data []map[string]interface{} description string - expectSkip bool // 是否预期跳过(由于已知限制) }{ { name: "多条件CASE在SUM中", @@ -362,7 +222,6 @@ func TestComplexCaseExpressionsInAggregation(t *testing.T) { {"deviceId": "device1", "temperature": 20.0, "humidity": 40.0, "ts": time.Now()}, }, description: "测试多条件CASE表达式在SUM聚合中的使用", - expectSkip: true, // 聚合中的CASE表达式暂不完全支持 }, { name: "函数调用CASE在AVG中", @@ -377,7 +236,6 @@ func TestComplexCaseExpressionsInAggregation(t *testing.T) { {"deviceId": "device1", "temperature": 35.0, "ts": time.Now()}, // 这个会被排除 }, description: "测试带函数的CASE表达式在AVG聚合中的使用", - expectSkip: false, // 测试SQL解析是否正常 }, { name: "复杂算术CASE在COUNT中", @@ -392,7 +250,6 @@ func TestComplexCaseExpressionsInAggregation(t *testing.T) { {"deviceId": "device1", "temperature": 35.0, "ts": time.Now()}, // 95F }, description: "测试算术表达式CASE在COUNT聚合中的使用", - expectSkip: true, // 聚合中的CASE表达式暂不完全支持 }, } @@ -403,22 +260,7 @@ func TestComplexCaseExpressionsInAggregation(t *testing.T) { defer streamSQL.Stop() err := streamSQL.Execute(tc.sql) - - // 如果SQL执行失败,检查是否是已知的限制 - if err != nil { - t.Logf("SQL执行失败: %v", err) - if tc.expectSkip { - t.Skipf("已知限制: %s - %v", tc.description, err) - return - } - // 如果不是预期的跳过,则检查是否是CASE表达式在聚合中的问题 - if strings.Contains(err.Error(), "CASEWHEN") || strings.Contains(err.Error(), "Unknown function") { - t.Skipf("CASE表达式在聚合SQL解析中的已知问题: %v", err) - return - } - assert.NoError(t, err, "执行SQL应该成功: %s", tc.description) - return - } + assert.NoError(t, err, "执行SQL应该成功") // 添加数据并获取结果 var results []map[string]interface{} @@ -432,7 +274,7 @@ func TestComplexCaseExpressionsInAggregation(t *testing.T) { }) for _, data := range tc.data { - streamSQL.stream.AddData(data) + streamSQL.Emit(data) } // 等待窗口触发 @@ -447,199 +289,18 @@ func 
TestComplexCaseExpressionsInAggregation(t *testing.T) { // 验证至少有结果返回 resultsMutex.Lock() hasResults := len(results) > 0 - var firstResult map[string]interface{} - if hasResults { - firstResult = results[0] - } resultsMutex.Unlock() - if hasResults { - t.Logf("Test case '%s' results: %+v", tc.name, firstResult) - - // 检查CASE表达式在聚合中的实际支持情况 - result := firstResult - for key, value := range result { - if key != "deviceId" && (value == 0 || value == 0.0) { - t.Logf("注意: %s 返回0,CASE表达式在聚合中可能暂不完全支持", key) - if tc.expectSkip { - t.Skipf("CASE表达式在聚合函数中暂不支持: %s", tc.description) - return - } - } - } - } else { - t.Log("未收到聚合结果 - 这对某些测试用例可能是预期的") - } + assert.True(t, hasResults, "应该有聚合结果返回") }) } } -// TestCaseExpressionFieldExtraction 测试CASE表达式的字段提取功能 -func TestCaseExpressionFieldExtraction(t *testing.T) { - testCases := []struct { - name string - exprStr string - expectedFields []string - }{ - { - name: "简单CASE字段提取", - exprStr: "CASE WHEN temperature > 30 THEN 1 ELSE 0 END", - expectedFields: []string{"temperature"}, - }, - { - name: "多字段CASE字段提取", - exprStr: "CASE WHEN temperature > 30 AND humidity < 60 THEN 1 ELSE 0 END", - expectedFields: []string{"temperature", "humidity"}, - }, - { - name: "简单CASE字段提取", - exprStr: "CASE status WHEN 'active' THEN temperature ELSE humidity END", - expectedFields: []string{"status", "temperature", "humidity"}, - }, - { - name: "函数CASE字段提取", - exprStr: "CASE WHEN ABS(temperature) > 30 THEN device_id ELSE location END", - expectedFields: []string{"temperature", "device_id", "location"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - expression, err := expr.NewExpression(tc.exprStr) - assert.NoError(t, err, "表达式创建应该成功") - - fields := expression.GetFields() - - // 验证所有期望的字段都被提取到了 - for _, expectedField := range tc.expectedFields { - assert.Contains(t, fields, expectedField, "应该包含字段: %s", expectedField) - } - - t.Logf("Expression: %s", tc.exprStr) - t.Logf("Extracted fields: %v", fields) - }) - } -} - -// 
TestCaseExpressionComprehensive 综合测试CASE表达式的完整功能 -func TestCaseExpressionComprehensive(t *testing.T) { - //t.Log("=== CASE表达式功能综合测试 ===") - - // 测试各种支持的CASE表达式类型 - supportedCases := []struct { - name string - expression string - testData map[string]interface{} - description string - }{ - { - name: "简单搜索CASE", - expression: "CASE WHEN temperature > 30 THEN 'HOT' ELSE 'COOL' END", - testData: map[string]interface{}{"temperature": 35.0}, - description: "基本的条件判断", - }, - { - name: "简单CASE值匹配", - expression: "CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END", - testData: map[string]interface{}{"status": "active"}, - description: "基于值的直接匹配", - }, - { - name: "多条件AND逻辑", - expression: "CASE WHEN temperature > 25 AND humidity > 60 THEN 1 ELSE 0 END", - testData: map[string]interface{}{"temperature": 30.0, "humidity": 70.0}, - description: "支持AND逻辑运算符", - }, - { - name: "多条件OR逻辑", - expression: "CASE WHEN temperature > 40 OR humidity > 80 THEN 1 ELSE 0 END", - testData: map[string]interface{}{"temperature": 25.0, "humidity": 85.0}, - description: "支持OR逻辑运算符", - }, - { - name: "复杂条件组合", - expression: "CASE WHEN temperature > 30 AND (humidity > 60 OR pressure < 1000) THEN 1 ELSE 0 END", - testData: map[string]interface{}{"temperature": 35.0, "humidity": 55.0, "pressure": 950.0}, - description: "支持括号和复杂逻辑组合", - }, - { - name: "函数调用在条件中", - expression: "CASE WHEN ABS(temperature) > 30 THEN 1 ELSE 0 END", - testData: map[string]interface{}{"temperature": -35.0}, - description: "支持在WHEN条件中调用函数", - }, - { - name: "算术表达式在条件中", - expression: "CASE WHEN temperature * 1.8 + 32 > 100 THEN 1 ELSE 0 END", - testData: map[string]interface{}{"temperature": 40.0}, - description: "支持算术表达式", - }, - { - name: "函数调用在结果中", - expression: "CASE WHEN temperature > 30 THEN ABS(temperature) ELSE ROUND(temperature) END", - testData: map[string]interface{}{"temperature": 35.5}, - description: "支持在THEN/ELSE结果中调用函数", - }, - { - name: "负数支持", - expression: "CASE WHEN temperature > 0 
THEN 1 ELSE -1 END", - testData: map[string]interface{}{"temperature": -5.0}, - description: "正确处理负数常量", - }, - } - - for _, tc := range supportedCases { - t.Run(tc.name, func(t *testing.T) { - t.Logf("测试: %s", tc.description) - t.Logf("表达式: %s", tc.expression) - - expression, err := expr.NewExpression(tc.expression) - assert.NoError(t, err, "表达式解析应该成功") - assert.NotNil(t, expression, "表达式不应为空") - - // 检查是否使用了自定义解析器(不回退到expr-lang) - assert.False(t, expression.Root == nil, "应该使用自定义CASE解析器,而不是回退到expr-lang") - assert.Equal(t, "case", expression.Root.Type, "根节点应该是CASE类型") - - // 执行表达式计算 - result, err := expression.Evaluate(tc.testData) - assert.NoError(t, err, "表达式计算应该成功") - - t.Logf("计算结果: %v", result) - - // 测试字段提取 - fields := expression.GetFields() - assert.Greater(t, len(fields), 0, "应该能够提取到字段") - t.Logf("提取的字段: %v", fields) - }) - } - - //// 统计支持情况 - //t.Logf("\n=== CASE表达式功能支持总结 ===") - //t.Logf("✅ 基本搜索CASE表达式 (CASE WHEN ... THEN ... END)") - //t.Logf("✅ 简单CASE表达式 (CASE expr WHEN value THEN result END)") - //t.Logf("✅ 多个WHEN子句支持") - //t.Logf("✅ ELSE子句支持") - //t.Logf("✅ AND/OR逻辑运算符") - //t.Logf("✅ 括号表达式分组") - //t.Logf("✅ 数学函数调用 (ABS, ROUND等)") - //t.Logf("✅ 算术表达式 (+, -, *, /)") - //t.Logf("✅ 比较操作符 (>, <, >=, <=, =, !=)") - //t.Logf("✅ 负数常量") - //t.Logf("✅ 字符串字面量") - //t.Logf("✅ 字段引用") - //t.Logf("✅ 字段提取功能") - //t.Logf("✅ 在聚合函数中使用 (SUM, AVG, COUNT等)") - //t.Logf("❌ 嵌套CASE表达式 (回退到expr-lang)") - //t.Logf("❌ 字符串函数在某些场景 (类型转换问题)") -} - // TestCaseExpressionNonAggregated 测试非聚合场景下的CASE表达式 func TestCaseExpressionNonAggregated(t *testing.T) { tests := []struct { name string sql string testData []map[string]interface{} - expected interface{} wantErr bool }{ { @@ -676,25 +337,6 @@ func TestCaseExpressionNonAggregated(t *testing.T) { }, wantErr: false, }, - { - name: "嵌套CASE表达式", - sql: `SELECT deviceId, - CASE - WHEN temperature > 25 THEN - CASE - WHEN humidity > 70 THEN 'HOT_HUMID' - ELSE 'HOT_DRY' - END - ELSE 'NORMAL' - END as condition_type - FROM stream`, - testData: 
[]map[string]interface{}{ - {"deviceId": "device1", "temperature": 30.0, "humidity": 80.0}, - {"deviceId": "device2", "temperature": 30.0, "humidity": 60.0}, - {"deviceId": "device3", "temperature": 20.0, "humidity": 80.0}, - }, - wantErr: false, - }, { name: "CASE表达式与其他字段组合", sql: `SELECT deviceId, temperature, @@ -726,8 +368,6 @@ func TestCaseExpressionNonAggregated(t *testing.T) { } if err != nil { - t.Logf("SQL execution failed for %s: %v", tt.name, err) - // 如果SQL执行失败,说明不支持该语法 t.Skip("CASE expression not yet supported in non-aggregated context") return } @@ -737,7 +377,7 @@ func TestCaseExpressionNonAggregated(t *testing.T) { // 添加测试数据 for _, data := range tt.testData { - strm.AddData(data) + strm.Emit(data) } // 捕获结果 @@ -754,11 +394,9 @@ func TestCaseExpressionNonAggregated(t *testing.T) { select { case result := <-resultChan: - t.Logf("Result: %v", result) - // 验证结果格式 assert.NotNil(t, result) case <-ctx.Done(): - t.Log("Timeout waiting for results - this may be expected for non-windowed queries") + // 对于非窗口查询,超时可能是正常的 } }) } @@ -770,7 +408,6 @@ func TestCaseExpressionAggregated(t *testing.T) { name string sql string testData []map[string]interface{} - expected interface{} wantErr bool }{ { @@ -780,7 +417,7 @@ func TestCaseExpressionAggregated(t *testing.T) { COUNT(CASE WHEN temperature <= 25 THEN 1 END) as normal_temp_count, COUNT(*) as total_count FROM stream - GROUP BY deviceId, TumblingWindow('5s') + GROUP BY deviceId, TumblingWindow('1s') WITH (TIMESTAMP='ts', TIMEUNIT='ss')`, testData: []map[string]interface{}{ {"deviceId": "device1", "temperature": 30.0, "ts": time.Now()}, @@ -803,7 +440,7 @@ func TestCaseExpressionAggregated(t *testing.T) { ELSE NULL END) as avg_high_humidity FROM stream - GROUP BY deviceId, TumblingWindow('5s') + GROUP BY deviceId, TumblingWindow('1s') WITH (TIMESTAMP='ts', TIMEUNIT='ss')`, testData: []map[string]interface{}{ {"deviceId": "device1", "temperature": 30.0, "humidity": 60.0, "ts": time.Now()}, @@ -812,52 +449,12 @@ func 
TestCaseExpressionAggregated(t *testing.T) { }, wantErr: false, }, - { - name: "CASE表达式作为聚合函数参数", - sql: `SELECT deviceId, - MAX(CASE - WHEN status = 'active' THEN temperature - ELSE -999 - END) as max_active_temp, - MIN(CASE - WHEN status = 'active' THEN temperature - ELSE 999 - END) as min_active_temp - FROM stream - GROUP BY deviceId, TumblingWindow('5s') - WITH (TIMESTAMP='ts', TIMEUNIT='ss')`, - testData: []map[string]interface{}{ - {"deviceId": "device1", "temperature": 30.0, "status": "active", "ts": time.Now()}, - {"deviceId": "device1", "temperature": 20.0, "status": "inactive", "ts": time.Now()}, - {"deviceId": "device1", "temperature": 35.0, "status": "active", "ts": time.Now()}, - }, - wantErr: false, - }, - { - name: "HAVING子句中的CASE表达式", - sql: `SELECT deviceId, - AVG(temperature) as avg_temp, - COUNT(*) as count - FROM stream - GROUP BY deviceId, TumblingWindow('5s') - HAVING AVG(CASE - WHEN temperature > 25 THEN 1 - ELSE 0 - END) > 0.5 - WITH (TIMESTAMP='ts', TIMEUNIT='ss')`, - testData: []map[string]interface{}{ - {"deviceId": "device1", "temperature": 30.0, "ts": time.Now()}, - {"deviceId": "device1", "temperature": 28.0, "ts": time.Now()}, - {"deviceId": "device1", "temperature": 20.0, "ts": time.Now()}, - {"deviceId": "device2", "temperature": 22.0, "ts": time.Now()}, - {"deviceId": "device2", "temperature": 21.0, "ts": time.Now()}, - }, - wantErr: false, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + streamsql := New() defer streamsql.Stop() @@ -869,223 +466,421 @@ func TestCaseExpressionAggregated(t *testing.T) { } if err != nil { - //t.Logf("SQL execution failed for %s: %v", tt.name, err) - // 如果SQL执行失败,说明不支持该语法 t.Skip("CASE expression not yet supported in aggregated context") return } - // 如果执行成功,继续测试数据处理 strm := streamsql.stream - // 添加数据并获取结果 - var results []map[string]interface{} - var resultsMutex sync.Mutex + // 使用通道等待结果,避免固定等待时间 + resultChan := make(chan interface{}, 5) strm.AddSink(func(result 
interface{}) { - if resultSlice, ok := result.([]map[string]interface{}); ok { - resultsMutex.Lock() - results = append(results, resultSlice...) - resultsMutex.Unlock() + select { + case resultChan <- result: + default: } }) for _, data := range tt.testData { - strm.AddData(data) + strm.Emit(data) + } + + // 使用带超时的等待机制 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var results []map[string]interface{} + + // 等待窗口触发或超时 + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-time.After(1200 * time.Millisecond): + // 如果1.2秒内没有结果,手动触发窗口 + if strm.Window != nil { + strm.Window.Trigger() + } + // 再等待一点时间获取结果 + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-time.After(200 * time.Millisecond): + // 超时,继续验证 + } + case <-ctx.Done(): + return + } + + // 验证结果 + if len(results) > 0 { + firstResult := results[0] + assert.NotNil(t, firstResult) + assert.Contains(t, firstResult, "deviceId", "Result should contain deviceId") + } + }) + } +} + +// TestCaseExpressionNullHandlingInAggregation 测试CASE表达式在聚合函数中正确处理NULL值 +func TestCaseExpressionNullHandlingInAggregation(t *testing.T) { + testCases := []struct { + name string + sql string + testData []map[string]interface{} + expectedDeviceResults map[string]map[string]interface{} + description string + }{ + { + name: "CASE表达式在SUM/COUNT/AVG聚合中正确处理NULL值", + sql: `SELECT deviceType, + SUM(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_sum, + COUNT(CASE WHEN temperature > 30 THEN 1 ELSE NULL END) as high_temp_count, + AVG(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_avg, + COUNT(*) as total_count + FROM stream + GROUP BY deviceType, TumblingWindow('2s')`, + testData: []map[string]interface{}{ + {"deviceType": "sensor", 
"temperature": 35.0}, // 满足条件 + {"deviceType": "sensor", "temperature": 25.0}, // 不满足条件,返回NULL + {"deviceType": "sensor", "temperature": 32.0}, // 满足条件 + {"deviceType": "monitor", "temperature": 28.0}, // 不满足条件,返回NULL + {"deviceType": "monitor", "temperature": 33.0}, // 满足条件 + }, + expectedDeviceResults: map[string]map[string]interface{}{ + "sensor": { + "high_temp_sum": 67.0, // 35 + 32 + "high_temp_count": 2.0, // COUNT应该忽略NULL + "high_temp_avg": 33.5, // (35 + 32) / 2 + "total_count": 3.0, // 总记录数 + }, + "monitor": { + "high_temp_sum": 33.0, // 只有33 + "high_temp_count": 1.0, // COUNT应该忽略NULL + "high_temp_avg": 33.0, // 只有33 + "total_count": 2.0, // 总记录数 + }, + }, + description: "验证CASE表达式返回的NULL值被聚合函数正确忽略", + }, + { + name: "全部返回NULL值时聚合函数的行为", + sql: `SELECT deviceType, + SUM(CASE WHEN temperature > 50 THEN temperature ELSE NULL END) as impossible_sum, + COUNT(CASE WHEN temperature > 50 THEN 1 ELSE NULL END) as impossible_count, + AVG(CASE WHEN temperature > 50 THEN temperature ELSE NULL END) as impossible_avg, + COUNT(*) as total_count + FROM stream + GROUP BY deviceType, TumblingWindow('2s')`, + testData: []map[string]interface{}{ + {"deviceType": "cold_sensor", "temperature": 20.0}, // 不满足条件 + {"deviceType": "cold_sensor", "temperature": 25.0}, // 不满足条件 + {"deviceType": "cold_sensor", "temperature": 30.0}, // 不满足条件 + }, + expectedDeviceResults: map[string]map[string]interface{}{ + "cold_sensor": { + "impossible_sum": nil, // 全NULL时SUM应返回NULL + "impossible_count": 0.0, // COUNT应返回0 + "impossible_avg": nil, // 全NULL时AVG应返回NULL + "total_count": 3.0, // 总记录数 + }, + }, + description: "验证当CASE表达式全部返回NULL时,聚合函数的正确行为", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // 创建StreamSQL实例 + ssql := New() + defer ssql.Stop() + + // 执行SQL + err := ssql.Execute(tc.sql) + assert.NoError(t, err, "SQL执行应该成功") + + // 收集结果 + var results []map[string]interface{} + resultChan := make(chan interface{}, 10) + + ssql.AddSink(func(result interface{}) 
{ + resultChan <- result + }) + + // 添加测试数据 + for _, data := range tc.testData { + ssql.Stream().Emit(data) } // 等待窗口触发 - time.Sleep(6 * time.Second) + time.Sleep(3 * time.Second) - // 手动触发窗口 - if strm.Window != nil { - strm.Window.Trigger() + // 收集结果 + collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-time.After(500 * time.Millisecond): + break collecting + } } - // 等待结果 - time.Sleep(200 * time.Millisecond) + // 验证结果数量 + assert.Len(t, results, len(tc.expectedDeviceResults), "结果数量应该匹配") - // 验证至少有结果返回 - resultsMutex.Lock() - hasResults := len(results) > 0 - var firstResult map[string]interface{} - if hasResults { - firstResult = results[0] - } - resultsMutex.Unlock() - if hasResults { - assert.NotNil(t, firstResult) + // 验证各个deviceType的结果 + for _, result := range results { + deviceType := result["deviceType"].(string) + expected := tc.expectedDeviceResults[deviceType] - // 验证结果结构 - result := firstResult - assert.Contains(t, result, "deviceId", "Result should contain deviceId") + assert.NotNil(t, expected, "应该有设备类型 %s 的期望结果", deviceType) - // 检查CASE表达式在聚合中的支持情况 - for key, value := range result { - if key != "deviceId" && (value == 0 || value == 0.0) { - t.Logf("注意: %s 返回0,可能CASE表达式在聚合中暂不完全支持", key) + // 验证每个字段 + for key, expectedValue := range expected { + if key == "deviceType" { + continue + } + + actualValue := result[key] + + // 处理NULL值比较 + if expectedValue == nil { + assert.Nil(t, actualValue, + "设备类型 %s 的字段 %s 应该为NULL", deviceType, key) + } else { + assert.Equal(t, expectedValue, actualValue, + "设备类型 %s 的字段 %s 应该匹配", deviceType, key) } } - } else { - t.Log("No aggregation results received - this may be expected for some test cases") } }) } } -// TestComplexCaseExpressions 测试复杂的CASE表达式场景 -func TestComplexCaseExpressions(t *testing.T) { - tests := []struct { - name string - sql string - testData []map[string]interface{} - wantErr 
bool - }{ - { - name: "多条件CASE表达式", - sql: `SELECT deviceId, - CASE - WHEN temperature > 30 AND humidity > 70 THEN 'CRITICAL' - WHEN temperature > 25 OR humidity > 80 THEN 'WARNING' - WHEN temperature BETWEEN 20 AND 25 THEN 'NORMAL' - ELSE 'UNKNOWN' - END as alert_level - FROM stream`, - testData: []map[string]interface{}{ - {"deviceId": "device1", "temperature": 35.0, "humidity": 75.0}, - {"deviceId": "device2", "temperature": 28.0, "humidity": 60.0}, - {"deviceId": "device3", "temperature": 22.0, "humidity": 50.0}, - {"deviceId": "device4", "temperature": 15.0, "humidity": 60.0}, - }, - wantErr: false, - }, - { - name: "CASE表达式与数学运算", - sql: `SELECT deviceId, - temperature, - CASE - WHEN temperature > 30 THEN ROUND(temperature * 1.2) - WHEN temperature > 20 THEN temperature * 1.1 - ELSE temperature - END as processed_temp - FROM stream`, - testData: []map[string]interface{}{ - {"deviceId": "device1", "temperature": 35.5}, - {"deviceId": "device2", "temperature": 25.3}, - {"deviceId": "device3", "temperature": 15.7}, - }, - wantErr: false, - }, - { - name: "CASE表达式与字符串处理", - sql: `SELECT deviceId, - CASE - WHEN LENGTH(deviceId) > 10 THEN 'LONG_NAME' - WHEN deviceId LIKE 'device%' THEN 'DEVICE_TYPE' - ELSE 'OTHER' - END as device_category - FROM stream`, - testData: []map[string]interface{}{ - {"deviceId": "very_long_device_name"}, - {"deviceId": "device1"}, - {"deviceId": "sensor1"}, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - err := streamsql.Execute(tt.sql) - - if tt.wantErr { - assert.Error(t, err) - return - } - - if err != nil { - //t.Logf("SQL execution failed for %s: %v", tt.name, err) - t.Skip("Complex CASE expression not yet supported") - return - } - - // 如果执行成功,继续测试数据处理 - strm := streamsql.stream - - // 添加测试数据 - for _, data := range tt.testData { - strm.AddData(data) - } - - // 简单验证能够执行而不报错 - //t.Log("Complex CASE expression executed successfully") 
- }) - } -} - -// TestCaseExpressionEdgeCases 测试边界情况 -func TestCaseExpressionEdgeCases(t *testing.T) { +// TestHavingWithCaseExpression 测试HAVING子句中的CASE表达式 +func TestHavingWithCaseExpression(t *testing.T) { tests := []struct { name string sql string wantErr bool + errMsg string }{ { - name: "CASE表达式语法错误 - 缺少END", - sql: `SELECT deviceId, - CASE - WHEN temperature > 30 THEN 'HOT' - ELSE 'NORMAL' - FROM stream`, - wantErr: false, // SQL解析器可能会容错处理 + name: "简单CASE表达式在HAVING中", + sql: `SELECT deviceId, + AVG(temperature) as avg_temp, + AVG(CASE WHEN temperature > 30 THEN temperature ELSE 0 END) as conditional_avg + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + HAVING conditional_avg > 25 + WITH (TIMESTAMP='ts', TIMEUNIT='ss')`, + wantErr: false, }, { - name: "CASE表达式语法错误 - 缺少THEN", - sql: `SELECT deviceId, - CASE - WHEN temperature > 30 'HOT' - ELSE 'NORMAL' - END as temp_category - FROM stream`, - wantErr: false, // SQL解析器可能会容错处理 - }, - { - name: "空的CASE表达式", - sql: `SELECT deviceId, - CASE END as empty_case - FROM stream`, - wantErr: false, // SQL解析器可能会容错处理 - }, - { - name: "只有ELSE的CASE表达式", - sql: `SELECT deviceId, - CASE - ELSE 'DEFAULT' - END as only_else - FROM stream`, - wantErr: false, // 这在SQL标准中是合法的 + name: "复杂CASE表达式在HAVING中", + sql: `SELECT deviceId, + COUNT(*) as total_count, + SUM(CASE + WHEN temperature > 35 THEN 2 + WHEN temperature > 25 THEN 1 + ELSE 0 + END) as weighted_score + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + HAVING weighted_score > 3 + WITH (TIMESTAMP='ts', TIMEUNIT='ss')`, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - err := streamsql.Execute(tt.sql) + // 测试SQL解析 + _, err := rsql.NewParser(tt.sql).Parse() if tt.wantErr { - assert.Error(t, err, "Expected SQL execution to fail") + assert.Error(t, err, "应该产生解析错误") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "错误消息应该包含期望的内容") + } } else { + assert.NoError(t, 
err, "SQL解析应该成功") + } + + // 如果解析成功,尝试创建StreamSQL实例 + if !tt.wantErr && err == nil { + streamSQL := New() + defer streamSQL.Stop() + + err = streamSQL.Execute(tt.sql) if err != nil { - t.Logf("SQL execution failed for %s: %v", tt.name, err) - t.Skip("CASE expression syntax not yet supported") - } else { - assert.NoError(t, err, "Expected SQL execution to succeed") + t.Skipf("HAVING中的CASE表达式执行暂不支持: %v", err) } } }) } } + +// TestHavingWithCaseExpressionFunctional 功能测试HAVING子句中的CASE表达式 +func TestHavingWithCaseExpressionFunctional(t *testing.T) { + sql := `SELECT deviceId, + AVG(temperature) as avg_temp, + COUNT(*) as total_count, + SUM(CASE WHEN temperature > 30 THEN 1 ELSE 0 END) as hot_count + FROM stream + GROUP BY deviceId, TumblingWindow('2s') + HAVING hot_count >= 2 + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + // 创建StreamSQL实例 + streamSQL := New() + defer streamSQL.Stop() + + err := streamSQL.Execute(sql) + assert.NoError(t, err, "执行SQL应该成功") + + // 模拟数据 + baseTime := time.Now() + testData := []map[string]interface{}{ + // device1: 3条高温记录,应该通过HAVING条件 + {"deviceId": "device1", "temperature": 35.0, "ts": baseTime}, + {"deviceId": "device1", "temperature": 32.0, "ts": baseTime}, + {"deviceId": "device1", "temperature": 31.0, "ts": baseTime}, + {"deviceId": "device1", "temperature": 25.0, "ts": baseTime}, // 不是高温 + + // device2: 1条高温记录,不应该通过HAVING条件 + {"deviceId": "device2", "temperature": 33.0, "ts": baseTime}, + {"deviceId": "device2", "temperature": 28.0, "ts": baseTime}, + {"deviceId": "device2", "temperature": 26.0, "ts": baseTime}, + + // device3: 2条高温记录,应该通过HAVING条件 + {"deviceId": "device3", "temperature": 34.0, "ts": baseTime}, + {"deviceId": "device3", "temperature": 31.0, "ts": baseTime}, + {"deviceId": "device3", "temperature": 29.0, "ts": baseTime}, + } + + // 添加数据并获取结果 + var results []map[string]interface{} + var resultsMutex sync.Mutex + streamSQL.stream.AddSink(func(result interface{}) { + resultsMutex.Lock() + defer resultsMutex.Unlock() + if 
resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + }) + + for _, data := range testData { + streamSQL.Emit(data) + } + + // 等待窗口触发 + time.Sleep(2500 * time.Millisecond) + + // 手动触发窗口 + streamSQL.stream.Window.Trigger() + + // 等待结果 + time.Sleep(200 * time.Millisecond) + + // 验证结果 + resultsMutex.Lock() + defer resultsMutex.Unlock() + + // 应该只有device1和device3通过HAVING条件(hot_count >= 2) + assert.Greater(t, len(results), 0, "应该有结果返回") + + // 验证结果中只包含满足HAVING条件的设备 + deviceResults := make(map[string]map[string]interface{}) + for _, result := range results { + deviceId, ok := result["deviceId"].(string) + assert.True(t, ok, "deviceId应该是字符串类型") + deviceResults[deviceId] = result + } + + // 验证HAVING条件的过滤效果 + for deviceId, result := range deviceResults { + hotCount := getFloat64Value(result["hot_count"]) + assert.GreaterOrEqual(t, hotCount, 2.0, + "设备 %s 的hot_count应该 >= 2 (HAVING条件)", deviceId) + } + + // device2应该被HAVING条件过滤掉(只有1条高温记录 < 2) + assert.NotContains(t, deviceResults, "device2", + "device2应该被HAVING条件过滤掉(hot_count=1 < 2)") + + // 验证期望的设备出现在结果中 + assert.Contains(t, deviceResults, "device1", "device1应该通过HAVING条件") + assert.Contains(t, deviceResults, "device3", "device3应该通过HAVING条件") +} + +// TestNegativeNumberInSQL 测试负数在完整SQL中的使用 +func TestNegativeNumberInSQL(t *testing.T) { + sql := `SELECT deviceId, + temperature, + CASE + WHEN temperature < -10.0 THEN 'FREEZING' + WHEN temperature < 0 THEN 'COLD' + WHEN temperature = 0 THEN 'ZERO' + ELSE 'POSITIVE' + END as temp_category, + CASE + WHEN temperature > 0 THEN temperature + ELSE -1.0 + END as adjusted_temp + FROM stream` + + streamSQL := New() + defer streamSQL.Stop() + + err := streamSQL.Execute(sql) + assert.NoError(t, err, "包含负数的SQL应该执行成功") + + // 模拟包含负数的数据 + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "temperature": -15.0}, + {"deviceId": "sensor2", "temperature": -5.0}, + {"deviceId": "sensor3", "temperature": 0.0}, + {"deviceId": 
"sensor4", "temperature": 10.0}, + } + + // 收集结果 + var results []map[string]interface{} + var resultsMutex sync.Mutex + + streamSQL.stream.AddSink(func(result interface{}) { + resultsMutex.Lock() + defer resultsMutex.Unlock() + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } else if resultMap, ok := result.(map[string]interface{}); ok { + results = append(results, resultMap) + } + }) + + // 添加测试数据 + for _, data := range testData { + streamSQL.Emit(data) + } + + // 等待处理 + time.Sleep(200 * time.Millisecond) + + // 验证结果 + resultsMutex.Lock() + defer resultsMutex.Unlock() + + for _, result := range results { + // 验证包含必要字段 + assert.Contains(t, result, "deviceId", "结果应该包含deviceId") + assert.Contains(t, result, "temperature", "结果应该包含temperature") + assert.Contains(t, result, "temp_category", "结果应该包含temp_category") + assert.Contains(t, result, "adjusted_temp", "结果应该包含adjusted_temp") + } +} diff --git a/streamsql_custom_functions_test.go b/streamsql_custom_functions_test.go index 0fd7e4f..4f0f0d3 100644 --- a/streamsql_custom_functions_test.go +++ b/streamsql_custom_functions_test.go @@ -72,7 +72,7 @@ func TestCustomMathFunctions(t *testing.T) { // 创建结果接收通道 resultChan := make(chan interface{}, 10) - streamsql.Stream().AddSink(func(result interface{}) { + streamsql.AddSink(func(result interface{}) { resultChan <- result }) @@ -86,7 +86,7 @@ func TestCustomMathFunctions(t *testing.T) { "y2": 4.0, // 距离应该是5 } - streamsql.AddData(testData) + streamsql.Emit(testData) // 等待窗口触发 time.Sleep(1 * time.Second) @@ -179,7 +179,7 @@ func TestCustomStringFunctions(t *testing.T) { // 创建结果接收通道 resultChan := make(chan interface{}, 10) - streamsql.Stream().AddSink(func(result interface{}) { + streamsql.AddSink(func(result interface{}) { resultChan <- result }) @@ -189,7 +189,7 @@ func TestCustomStringFunctions(t *testing.T) { "metadata": `{"version":"1.0","type":"temperature"}`, } - streamsql.AddData(testData) + 
streamsql.Emit(testData) time.Sleep(200 * time.Millisecond) // 验证结果 @@ -318,7 +318,7 @@ func TestCustomAggregateFunctions(t *testing.T) { // 创建结果接收通道 resultChan := make(chan interface{}, 10) - streamsql.Stream().AddSink(func(result interface{}) { + streamsql.AddSink(func(result interface{}) { resultChan <- result }) @@ -331,7 +331,7 @@ func TestCustomAggregateFunctions(t *testing.T) { } for _, data := range testData { - streamsql.AddData(data) + streamsql.Emit(data) } time.Sleep(1 * time.Second) @@ -557,7 +557,7 @@ func TestCustomFunctionWithAggregation(t *testing.T) { // 创建结果接收通道 resultChan := make(chan interface{}, 10) - streamsql.Stream().AddSink(func(result interface{}) { + streamsql.AddSink(func(result interface{}) { resultChan <- result }) @@ -568,7 +568,7 @@ func TestCustomFunctionWithAggregation(t *testing.T) { } for _, data := range testData { - streamsql.AddData(data) + streamsql.Emit(data) } time.Sleep(1 * time.Second) diff --git a/streamsql_function_integration_test.go b/streamsql_function_integration_test.go index b8997bb..9c45246 100644 --- a/streamsql_function_integration_test.go +++ b/streamsql_function_integration_test.go @@ -32,7 +32,7 @@ func TestFunctionIntegrationNonAggregation(t *testing.T) { "temperature": -25.5, "humidity": 64.0, } - strm.AddData(testData) + strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -77,7 +77,7 @@ func TestFunctionIntegrationNonAggregation(t *testing.T) { "device": "sensor01", "location": "ROOM_A", } - strm.AddData(testData) + strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -120,7 +120,7 @@ func TestFunctionIntegrationNonAggregation(t *testing.T) { "temperature": 25.7, "humidity": 65.0, } - strm.AddData(testData) + strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -162,7 +162,7 @@ func TestFunctionIntegrationNonAggregation(t *testing.T) { "device": 
"test-device", "timestamp": testTime, } - strm.AddData(testData) + strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -204,7 +204,7 @@ func TestFunctionIntegrationNonAggregation(t *testing.T) { "device": "test-device", "metadata": `{"type": "temperature_sensor", "version": "1.0"}`, } - strm.AddData(testData) + strm.Emit(testData) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -253,7 +253,7 @@ func TestFunctionIntegrationAggregation(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -317,7 +317,7 @@ func TestFunctionIntegrationAggregation(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -368,7 +368,7 @@ func TestFunctionIntegrationAggregation(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -430,7 +430,7 @@ func TestFunctionIntegrationMixed(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -449,12 +449,6 @@ func TestFunctionIntegrationMixed(t *testing.T) { item := resultSlice[0] - // 打印调试信息 - t.Logf("Result item: %+v", item) - for key, value := range item { - t.Logf(" %s: %v (type: %T)", key, value, value) - } - assert.Equal(t, "sensor1", item["device"]) assert.Equal(t, "SENSOR1", item["device_upper"]) @@ -472,7 +466,6 @@ func TestFunctionIntegrationMixed(t *testing.T) { } else if val, ok := roundedAvg.(float64); ok { // 验证结果在合理范围内 assert.True(t, val >= 25.0 && val <= 25.5, "rounded_avg should be between 25.0 and 25.5, got %v", val) - t.Logf("rounded_avg test passed: %v", val) } else { t.Errorf("rounded_avg is not a float64: %v (type: %T)", roundedAvg, roundedAvg) } @@ -504,7 +497,7 @@ func TestFunctionIntegrationMixed(t *testing.T) { "device": "sensor1", "temperature": 25.7, } - strm.AddData(testData) + strm.Emit(testData) // 等待结果 ctx, cancel := 
context.WithTimeout(context.Background(), 3*time.Second) @@ -547,7 +540,7 @@ func TestFunctionIntegrationMixed(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -584,7 +577,6 @@ func TestNestedFunctionSupport(t *testing.T) { // 执行包含 round(avg(temperature), 2) 的查询 query := "SELECT device, round(avg(temperature), 2) as rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')" - t.Logf("Executing query: %s", query) err := streamsql.Execute(query) assert.Nil(t, err) @@ -602,7 +594,7 @@ func TestNestedFunctionSupport(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -620,11 +612,6 @@ func TestNestedFunctionSupport(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - t.Logf("Result item: %+v", item) - for key, value := range item { - t.Logf(" %s: %v (type: %T)", key, value, value) - } - assert.Equal(t, "sensor1", item["device"]) // 验证四舍五入的平均值 @@ -672,7 +659,7 @@ func TestNestedFunctionSupport(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -741,7 +728,7 @@ func TestNestedFunctionSupport(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -807,7 +794,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { }) // 添加测试数据 - strm.AddData(map[string]interface{}{"device": "sensor1", "temperature": 25.67}) + strm.Emit(map[string]interface{}{"device": "sensor1", "temperature": 25.67}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -847,7 +834,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { }) // 添加测试数据 - strm.AddData(map[string]interface{}{"device": "sensor1"}) + strm.Emit(map[string]interface{}{"device": "sensor1"}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -887,7 +874,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { }) // 添加测试数据 - 
strm.AddData(map[string]interface{}{"device": "sensor1", "temperature": 16.0}) + strm.Emit(map[string]interface{}{"device": "sensor1", "temperature": 16.0}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -934,7 +921,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -979,7 +966,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { }) // 添加测试数据 - strm.AddData(map[string]interface{}{"device": "sensor1", "created_at": "2023-12-25 15:30:45"}) + strm.Emit(map[string]interface{}{"device": "sensor1", "created_at": "2023-12-25 15:30:45"}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -1019,7 +1006,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { }) // 添加测试数据(不包含invalid_field) - strm.AddData(map[string]interface{}{"device": "sensor1", "temperature": 25.0}) + strm.Emit(map[string]interface{}{"device": "sensor1", "temperature": 25.0}) // 等待结果 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) diff --git a/streamsql_is_null_test.go b/streamsql_is_null_test.go index c96a851..c31ab54 100644 --- a/streamsql_is_null_test.go +++ b/streamsql_is_null_test.go @@ -1,1033 +1,1033 @@ -package streamsql - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestIsNullOperatorInSQL 测试IS NULL和IS NOT NULL语法功能 -func TestIsNullOperatorInSQL(t *testing.T) { - testCases := []struct { - name string - sql string - testData []map[string]interface{} - expected []map[string]interface{} - }{ - { - name: "IS NULL测试", - sql: "SELECT deviceId, value FROM stream WHERE value IS NULL", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor3", "value": 30.0}, - {"deviceId": "sensor4", "value": nil}, - }, - expected: []map[string]interface{}{ - 
{"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor4", "value": nil}, - }, - }, - { - name: "IS NOT NULL测试", - sql: "SELECT deviceId, value FROM stream WHERE value IS NOT NULL", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor3", "value": 30.0}, - {"deviceId": "sensor4", "value": nil}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor3", "value": 30.0}, - }, - }, - { - name: "嵌套字段IS NULL测试", - sql: "SELECT deviceId, device.location FROM stream WHERE device.location IS NULL", - testData: []map[string]interface{}{ - { - "deviceId": "sensor1", - "device": map[string]interface{}{ - "location": "warehouse-A", - }, - }, - { - "deviceId": "sensor2", - "device": map[string]interface{}{ - "location": nil, - }, - }, - { - "deviceId": "sensor3", - "device": map[string]interface{}{}, - }, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor2", "device.location": nil}, - {"deviceId": "sensor3", "device.location": nil}, // 字段不存在也被认为是null - }, - }, - { - name: "组合条件 - IS NULL AND其他条件", - sql: "SELECT deviceId, value, status FROM stream WHERE value IS NULL AND status = 'active'", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor2", "value": nil, "status": "active"}, - {"deviceId": "sensor3", "value": nil, "status": "inactive"}, - {"deviceId": "sensor4", "value": 30.0, "status": "active"}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor2", "value": nil, "status": "active"}, - }, - }, - { - name: "组合条件 - IS NOT NULL OR其他条件", - sql: "SELECT deviceId, value, status FROM stream WHERE value IS NOT NULL OR status = 'error'", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor2", "value": nil, "status": "active"}, - {"deviceId": "sensor3", "value": nil, 
"status": "error"}, - {"deviceId": "sensor4", "value": 30.0, "status": "inactive"}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor3", "value": nil, "status": "error"}, - {"deviceId": "sensor4", "value": 30.0, "status": "inactive"}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // 创建StreamSQL实例 - ssql := New() - defer ssql.Stop() - - // 执行SQL - err := ssql.Execute(tc.sql) - require.NoError(t, err) - - // 收集结果 - var results []map[string]interface{} - resultChan := make(chan interface{}, 10) - - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 使用一个done channel来同步 - done := make(chan bool, 1) - - // 添加测试数据 - for _, data := range tc.testData { - ssql.Stream().AddData(data) - } - - // 在另一个goroutine中收集结果 - go func() { - defer func() { done <- true }() - // 等待一段时间收集结果 - timeout := time.After(300 * time.Millisecond) - for { - select { - case result := <-resultChan: - if resultSlice, ok := result.([]map[string]interface{}); ok { - results = append(results, resultSlice...) 
- } - case <-timeout: - return - } - } - }() - - // 等待收集完成 - <-done - - // 验证结果数量 - assert.Len(t, results, len(tc.expected), "结果数量应该匹配") - - // 验证结果内容(不依赖顺序) - expectedDeviceIds := make([]string, len(tc.expected)) - for i, exp := range tc.expected { - expectedDeviceIds[i] = exp["deviceId"].(string) - } - - actualDeviceIds := make([]string, len(results)) - for i, result := range results { - actualDeviceIds[i] = result["deviceId"].(string) - } - - // 验证每个期望的设备ID都在结果中 - for _, expectedId := range expectedDeviceIds { - assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) - } - - // 验证每个结果的字段值 - for _, result := range results { - deviceId := result["deviceId"].(string) - // 找到对应的期望结果 - var expectedResult map[string]interface{} - for _, exp := range tc.expected { - if exp["deviceId"].(string) == deviceId { - expectedResult = exp - break - } - } - - if expectedResult != nil { - for key, expectedValue := range expectedResult { - actualValue := result[key] - assert.Equal(t, expectedValue, actualValue, - "设备 %s 的字段 %s 值应该匹配: 期望 %v, 实际 %v", deviceId, key, expectedValue, actualValue) - } - } - } - }) - } -} - -// TestIsNullInAggregation 测试聚合查询中的IS NULL -func TestIsNullInAggregation(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 聚合查询:统计非空值的数量 - sql := `SELECT deviceType, - COUNT(*) as total_count, - COUNT(value) as non_null_count - FROM stream - WHERE value IS NOT NULL - GROUP BY deviceType, TumblingWindow('2s')` - - err := ssql.Execute(sql) - require.NoError(t, err) - - // 收集结果 - resultChan := make(chan interface{}, 10) - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := []map[string]interface{}{ - {"deviceType": "temperature", "value": 25.5}, - {"deviceType": "temperature", "value": nil}, - {"deviceType": "temperature", "value": 27.0}, - {"deviceType": "humidity", "value": 60.0}, - {"deviceType": "humidity", "value": nil}, - } - - for _, data := range testData { - 
ssql.Stream().AddData(data) - } - - // 等待窗口触发 - time.Sleep(3 * time.Second) - - // 验证结果 - select { - case result := <-resultChan: - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - // 应该有temperature和humidity两种类型的结果 - assert.GreaterOrEqual(t, len(resultSlice), 1, "应该至少有一个聚合结果") - - for _, item := range resultSlice { - deviceType := item["deviceType"] - totalCount, _ := item["total_count"].(float64) - nonNullCount, _ := item["non_null_count"].(float64) - - if deviceType == "temperature" { - // temperature有2个非空值(25.5, 27.0) - assert.Equal(t, 2.0, totalCount, "temperature总数应该是2") - assert.Equal(t, 2.0, nonNullCount, "temperature非空数应该是2") - } else if deviceType == "humidity" { - // humidity有1个非空值(60.0) - assert.Equal(t, 1.0, totalCount, "humidity总数应该是1") - assert.Equal(t, 1.0, nonNullCount, "humidity非空数应该是1") - } - } - case <-time.After(5 * time.Second): - t.Fatal("测试超时,未收到聚合结果") - } -} - -// TestIsNullInHaving 测试HAVING子句中真正的IS NULL功能 -func TestIsNullInHaving(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 测试HAVING子句中的IS NULL:只返回平均值为NULL的设备类型 - sql := `SELECT deviceType, - COUNT(*) as total_count, - AVG(value) as avg_value - FROM stream - GROUP BY deviceType, TumblingWindow('2s') - HAVING avg_value IS NULL` - - err := ssql.Execute(sql) - require.NoError(t, err) - - resultChan := make(chan interface{}, 10) - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据:只给pressure设备类型添加null值,这样它的平均值会是null - testData := []map[string]interface{}{ - {"deviceType": "temperature", "value": 25.0}, - {"deviceType": "temperature", "value": 27.0}, // temperature有值,平均值不为null - {"deviceType": "humidity", "value": 60.0}, // humidity有值,平均值不为null - {"deviceType": "pressure", "value": nil}, // pressure只有null值 - {"deviceType": "pressure", "value": nil}, // pressure再次null值,平均值会是null - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待窗口触发 - time.Sleep(3 * 
time.Second) - - // 验证结果 - select { - case result := <-resultChan: - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - // 应该只有pressure类型的结果(平均值为null) - assert.Len(t, resultSlice, 1, "应该只有一个结果") - - if len(resultSlice) > 0 { - item := resultSlice[0] - assert.Equal(t, "pressure", item["deviceType"], "应该是pressure类型") - - // 验证avg_value确实为null - avgValue := item["avg_value"] - assert.Nil(t, avgValue, "pressure的平均值应该是null") - - // 验证total_count - totalCount, ok := item["total_count"].(float64) - assert.True(t, ok, "total_count应该是float64类型") - assert.Equal(t, 2.0, totalCount, "pressure应该有2条记录") - } - - case <-time.After(5 * time.Second): - t.Fatal("测试超时,未收到聚合结果") - } -} - -// TestIsNullInHavingWithIsNotNull 测试HAVING子句中的IS NOT NULL功能 -func TestIsNullInHavingWithIsNotNull(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 测试HAVING子句中的IS NOT NULL:只返回平均值不为NULL的设备类型 - sql := `SELECT deviceType, - COUNT(*) as total_count, - AVG(value) as avg_value - FROM stream - GROUP BY deviceType, TumblingWindow('2s') - HAVING avg_value IS NOT NULL` - - err := ssql.Execute(sql) - require.NoError(t, err) - - resultChan := make(chan interface{}, 10) - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := []map[string]interface{}{ - {"deviceType": "temperature", "value": 25.0}, - {"deviceType": "temperature", "value": 27.0}, // temperature有值,平均值不为null - {"deviceType": "humidity", "value": 60.0}, // humidity有值,平均值不为null - {"deviceType": "pressure", "value": nil}, // pressure只有null值,平均值会是null - {"deviceType": "pressure", "value": nil}, - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 等待窗口触发 - time.Sleep(3 * time.Second) - - // 验证结果 - select { - case result := <-resultChan: - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - // 应该有temperature和humidity两种类型的结果(平均值不为null) - 
assert.Len(t, resultSlice, 2, "应该有两个结果") - - foundTypes := make([]string, 0) - for _, item := range resultSlice { - deviceType, ok := item["deviceType"].(string) - require.True(t, ok, "deviceType应该是string类型") - - // 验证avg_value不为null - avgValue := item["avg_value"] - assert.NotNil(t, avgValue, fmt.Sprintf("%s的平均值应该不为null", deviceType)) - - foundTypes = append(foundTypes, deviceType) - } - - // 验证包含temperature和humidity,不包含pressure - assert.Contains(t, foundTypes, "temperature", "结果应该包含temperature") - assert.Contains(t, foundTypes, "humidity", "结果应该包含humidity") - assert.NotContains(t, foundTypes, "pressure", "结果不应该包含pressure") - - case <-time.After(5 * time.Second): - t.Fatal("测试超时,未收到聚合结果") - } -} - -// TestIsNullWithOtherOperators 测试IS NULL与其他操作符的组合 -func TestIsNullWithOtherOperators(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 测试复杂的WHERE条件 - sql := `SELECT deviceId, value, status, location - FROM stream - WHERE (value IS NOT NULL AND value > 20) OR - (status IS NULL AND location LIKE 'warehouse%')` - - err := ssql.Execute(sql) - require.NoError(t, err) - - resultChan := make(chan interface{}, 10) - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.0, "status": "active", "location": "warehouse-A"}, // 满足第一个条件 - {"deviceId": "sensor2", "value": 15.0, "status": "active", "location": "warehouse-B"}, // 不满足条件 - {"deviceId": "sensor3", "value": nil, "status": nil, "location": "warehouse-C"}, // 满足第二个条件 - {"deviceId": "sensor4", "value": nil, "status": "inactive", "location": "warehouse-D"}, // 不满足条件 - {"deviceId": "sensor5", "value": 30.0, "status": nil, "location": "office-A"}, // 满足第一个条件 - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 使用超时方式安全收集结果 - var results []map[string]interface{} - timeout := time.After(500 * time.Millisecond) - -collecting: - for { - select { - case result := <-resultChan: - if 
resultSlice, ok := result.([]map[string]interface{}); ok { - results = append(results, resultSlice...) - } - case <-timeout: - break collecting - } - } - - // 验证结果:应该有sensor1, sensor3, sensor5 - assert.Len(t, results, 3, "应该有3个结果") - - expectedDeviceIds := []string{"sensor1", "sensor3", "sensor5"} - actualDeviceIds := make([]string, len(results)) - for i, result := range results { - actualDeviceIds[i] = result["deviceId"].(string) - } - - for _, expectedId := range expectedDeviceIds { - assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) - } -} - -// TestCaseWhenWithIsNull 测试CASE WHEN表达式中使用IS NULL和IS NOT NULL -func TestCaseWhenWithIsNull(t *testing.T) { - testCases := []struct { - name string - sql string - testData []map[string]interface{} - expected []map[string]interface{} - }{ - { - name: "CASE WHEN IS NULL基本测试", - sql: `SELECT deviceId, - CASE WHEN status IS NULL THEN 0 - WHEN status IS NOT NULL THEN 1 - ELSE 2 END as status_flag - FROM stream`, - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "status": "active"}, - {"deviceId": "sensor2", "status": nil}, - {"deviceId": "sensor3", "status": "inactive"}, - {"deviceId": "sensor4"}, // 没有status字段 - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor1", "status_flag": 1.0}, - {"deviceId": "sensor2", "status_flag": 0.0}, - {"deviceId": "sensor3", "status_flag": 1.0}, - {"deviceId": "sensor4", "status_flag": 0.0}, - }, - }, - { - name: "CASE WHEN IS NOT NULL复杂条件测试", - sql: `SELECT deviceId, - CASE WHEN temperature IS NOT NULL AND temperature > 25 THEN 2 - WHEN temperature IS NOT NULL AND temperature <= 25 THEN 1 - WHEN temperature IS NULL THEN 0 - ELSE 3 END as temp_level - FROM stream`, - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "temperature": 30.0}, - {"deviceId": "sensor2", "temperature": 20.0}, - {"deviceId": "sensor3", "temperature": nil}, - {"deviceId": "sensor4"}, // 没有temperature字段 - }, - expected: []map[string]interface{}{ - 
{"deviceId": "sensor1", "temp_level": 2.0}, - {"deviceId": "sensor2", "temp_level": 1.0}, - {"deviceId": "sensor3", "temp_level": 0.0}, - {"deviceId": "sensor4", "temp_level": 0.0}, - }, - }, - { - name: "多个CASE WHEN IS NULL条件测试", - sql: `SELECT deviceId, - CASE WHEN status IS NULL AND temperature IS NULL THEN 0 - WHEN status IS NULL AND temperature IS NOT NULL THEN 1 - WHEN status IS NOT NULL AND temperature IS NULL THEN 2 - WHEN status IS NOT NULL AND temperature IS NOT NULL THEN 3 - ELSE 4 END as combined_flag - FROM stream`, - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "status": "active", "temperature": 25.0}, - {"deviceId": "sensor2", "status": "active", "temperature": nil}, - {"deviceId": "sensor3", "status": nil, "temperature": 30.0}, - {"deviceId": "sensor4", "status": nil, "temperature": nil}, - {"deviceId": "sensor5"}, // 两个字段都不存在 - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor1", "combined_flag": 3.0}, - {"deviceId": "sensor2", "combined_flag": 2.0}, - {"deviceId": "sensor3", "combined_flag": 1.0}, - {"deviceId": "sensor4", "combined_flag": 0.0}, - {"deviceId": "sensor5", "combined_flag": 0.0}, - }, - }, - { - name: "CASE WHEN IS NULL与聚合函数结合测试", - sql: `SELECT deviceType, - COUNT(*) as total_count, - SUM(CASE WHEN value IS NULL THEN 1 ELSE 0 END) as null_count, - SUM(CASE WHEN value IS NOT NULL THEN 1 ELSE 0 END) as non_null_count - FROM stream - GROUP BY deviceType, TumblingWindow('2s')`, - testData: []map[string]interface{}{ - {"deviceType": "temperature", "value": 25.0}, - {"deviceType": "temperature", "value": nil}, - {"deviceType": "temperature", "value": 27.0}, - {"deviceType": "humidity", "value": 60.0}, - {"deviceType": "humidity", "value": nil}, - {"deviceType": "humidity", "value": nil}, - }, - expected: []map[string]interface{}{ - { - "deviceType": "temperature", - "total_count": 3.0, - "null_count": 1.0, - "non_null_count": 2.0, - }, - { - "deviceType": "humidity", - "total_count": 3.0, - "null_count": 
2.0, - "non_null_count": 1.0, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // 创建StreamSQL实例 - ssql := New() - defer ssql.Stop() - - // 执行SQL - err := ssql.Execute(tc.sql) - require.NoError(t, err) - - // 收集结果 - var results []map[string]interface{} - resultChan := make(chan interface{}, 10) - - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - for _, data := range tc.testData { - ssql.Stream().AddData(data) - } - - // 使用超时方式安全收集结果 - var timeout time.Duration - if tc.name == "CASE WHEN IS NULL与聚合函数结合测试" { - timeout = 4 * time.Second // 聚合查询需要更长时间 - } else { - timeout = 500 * time.Millisecond - } - - timeoutChan := time.After(timeout) - - collecting: - for { - select { - case result := <-resultChan: - if resultSlice, ok := result.([]map[string]interface{}); ok { - results = append(results, resultSlice...) - } - case <-timeoutChan: - break collecting - } - } - - // 验证结果数量 - assert.Len(t, results, len(tc.expected), "结果数量应该匹配") - - // 对于聚合查询,验证逻辑略有不同 - if tc.name == "CASE WHEN IS NULL与聚合函数结合测试" { - // 验证每个deviceType的结果 - for _, expectedResult := range tc.expected { - expectedDeviceType := expectedResult["deviceType"].(string) - - // 在结果中找到对应的deviceType - var actualResult map[string]interface{} - for _, result := range results { - if result["deviceType"].(string) == expectedDeviceType { - actualResult = result - break - } - } - - require.NotNil(t, actualResult, "应该找到设备类型 %s 的结果", expectedDeviceType) - - // 验证各个统计值 - assert.Equal(t, expectedResult["total_count"], actualResult["total_count"], - "设备类型 %s 的total_count应该匹配", expectedDeviceType) - assert.Equal(t, expectedResult["null_count"], actualResult["null_count"], - "设备类型 %s 的null_count应该匹配", expectedDeviceType) - assert.Equal(t, expectedResult["non_null_count"], actualResult["non_null_count"], - "设备类型 %s 的non_null_count应该匹配", expectedDeviceType) - } - } else { - // 验证普通查询的结果 - expectedDeviceIds := make([]string, len(tc.expected)) 
- for i, exp := range tc.expected { - expectedDeviceIds[i] = exp["deviceId"].(string) - } - - actualDeviceIds := make([]string, len(results)) - for i, result := range results { - actualDeviceIds[i] = result["deviceId"].(string) - } - - // 验证每个期望的设备ID都在结果中 - for _, expectedId := range expectedDeviceIds { - assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) - } - - // 验证每个结果的字段值 - for _, result := range results { - deviceId := result["deviceId"].(string) - // 找到对应的期望结果 - var expectedResult map[string]interface{} - for _, exp := range tc.expected { - if exp["deviceId"].(string) == deviceId { - expectedResult = exp - break - } - } - - if expectedResult != nil { - for key, expectedValue := range expectedResult { - if key != "deviceId" { // deviceId已经验证过了 - actualValue := result[key] - assert.Equal(t, expectedValue, actualValue, - "设备 %s 的字段 %s 值应该匹配: 期望 %v, 实际 %v", deviceId, key, expectedValue, actualValue) - } - } - } - } - } - }) - } -} - -// TestNullComparisons 测试 = nil、!= nil、= null、!= null 等语法 -func TestNullComparisons(t *testing.T) { - testCases := []struct { - name string - sql string - testData []map[string]interface{} - expected []map[string]interface{} - }{ - { - name: "fieldName = nil 测试", - sql: "SELECT deviceId, value FROM stream WHERE value = nil", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor3", "value": 30.0}, - {"deviceId": "sensor4", "value": nil}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor4", "value": nil}, - }, - }, - { - name: "fieldName != nil 测试", - sql: "SELECT deviceId, value FROM stream WHERE value != nil", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor3", "value": 30.0}, - {"deviceId": "sensor4", "value": nil}, - }, - expected: []map[string]interface{}{ - 
{"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor3", "value": 30.0}, - }, - }, - { - name: "fieldName = null 测试", - sql: "SELECT deviceId, value FROM stream WHERE value = null", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor3", "value": 30.0}, - {"deviceId": "sensor4", "value": nil}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor4", "value": nil}, - }, - }, - { - name: "fieldName != null 测试", - sql: "SELECT deviceId, value FROM stream WHERE value != null", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor2", "value": nil}, - {"deviceId": "sensor3", "value": 30.0}, - {"deviceId": "sensor4", "value": nil}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5}, - {"deviceId": "sensor3", "value": 30.0}, - }, - }, - { - name: "嵌套字段 = nil 测试", - sql: "SELECT deviceId, device.location FROM stream WHERE device.location = nil", - testData: []map[string]interface{}{ - { - "deviceId": "sensor1", - "device": map[string]interface{}{ - "location": "warehouse-A", - }, - }, - { - "deviceId": "sensor2", - "device": map[string]interface{}{ - "location": nil, - }, - }, - { - "deviceId": "sensor3", - "device": map[string]interface{}{}, - }, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor2", "device.location": nil}, - {"deviceId": "sensor3", "device.location": nil}, // 字段不存在也被认为是null - }, - }, - { - name: "嵌套字段 != nil 测试", - sql: "SELECT deviceId, device.location FROM stream WHERE device.location != nil", - testData: []map[string]interface{}{ - { - "deviceId": "sensor1", - "device": map[string]interface{}{ - "location": "warehouse-A", - }, - }, - { - "deviceId": "sensor2", - "device": map[string]interface{}{ - "location": nil, - }, - }, - { - "deviceId": "sensor3", - "device": map[string]interface{}{}, - }, - }, - 
expected: []map[string]interface{}{ - {"deviceId": "sensor1", "device.location": "warehouse-A"}, - }, - }, - { - name: "组合条件 - != nil AND 其他条件", - sql: "SELECT deviceId, value, status FROM stream WHERE value != nil AND value > 20", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor2", "value": nil, "status": "active"}, - {"deviceId": "sensor3", "value": 15.0, "status": "inactive"}, - {"deviceId": "sensor4", "value": 30.0, "status": "active"}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor4", "value": 30.0, "status": "active"}, - }, - }, - { - name: "组合条件 - = nil OR 其他条件", - sql: "SELECT deviceId, value, status FROM stream WHERE value = nil OR status = 'error'", - testData: []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.5, "status": "active"}, - {"deviceId": "sensor2", "value": nil, "status": "active"}, - {"deviceId": "sensor3", "value": 30.0, "status": "error"}, - {"deviceId": "sensor4", "value": nil, "status": "inactive"}, - }, - expected: []map[string]interface{}{ - {"deviceId": "sensor2", "value": nil, "status": "active"}, - {"deviceId": "sensor3", "value": 30.0, "status": "error"}, - {"deviceId": "sensor4", "value": nil, "status": "inactive"}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // 创建StreamSQL实例 - ssql := New() - defer ssql.Stop() - - // 执行SQL - err := ssql.Execute(tc.sql) - require.NoError(t, err) - - // 收集结果 - var results []map[string]interface{} - resultChan := make(chan interface{}, 10) - - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - for _, data := range tc.testData { - ssql.Stream().AddData(data) - } - - // 使用超时方式安全收集结果 - timeout := time.After(500 * time.Millisecond) - - collecting: - for { - select { - case result := <-resultChan: - if resultSlice, ok := result.([]map[string]interface{}); 
ok { - results = append(results, resultSlice...) - } - case <-timeout: - break collecting - } - } - - // 验证结果数量 - assert.Len(t, results, len(tc.expected), "结果数量应该匹配") - - // 验证结果内容(不依赖顺序) - expectedDeviceIds := make([]string, len(tc.expected)) - for i, exp := range tc.expected { - expectedDeviceIds[i] = exp["deviceId"].(string) - } - - actualDeviceIds := make([]string, len(results)) - for i, result := range results { - actualDeviceIds[i] = result["deviceId"].(string) - } - - // 验证每个期望的设备ID都在结果中 - for _, expectedId := range expectedDeviceIds { - assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) - } - - // 验证每个结果的字段值 - for _, result := range results { - deviceId := result["deviceId"].(string) - // 找到对应的期望结果 - var expectedResult map[string]interface{} - for _, exp := range tc.expected { - if exp["deviceId"].(string) == deviceId { - expectedResult = exp - break - } - } - - if expectedResult != nil { - for key, expectedValue := range expectedResult { - actualValue := result[key] - assert.Equal(t, expectedValue, actualValue, - "设备 %s 的字段 %s 值应该匹配: 期望 %v, 实际 %v", deviceId, key, expectedValue, actualValue) - } - } - } - }) - } -} - -// TestNullComparisonInAggregation 测试聚合查询中的 = nil 和 != nil -func TestNullComparisonInAggregation(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 聚合查询:统计非空值的数量 - sql := `SELECT deviceType, - COUNT(*) as total_count, - COUNT(value) as non_null_count - FROM stream - WHERE value != nil - GROUP BY deviceType, TumblingWindow('2s')` - - err := ssql.Execute(sql) - require.NoError(t, err) - - // 收集结果 - resultChan := make(chan interface{}, 10) - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := []map[string]interface{}{ - {"deviceType": "temperature", "value": 25.5}, - {"deviceType": "temperature", "value": nil}, - {"deviceType": "temperature", "value": 27.0}, - {"deviceType": "humidity", "value": 60.0}, - {"deviceType": "humidity", "value": nil}, - } - - for _, 
data := range testData { - ssql.Stream().AddData(data) - } - - // 等待窗口触发 - time.Sleep(3 * time.Second) - - // 验证结果 - select { - case result := <-resultChan: - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - // 应该有temperature和humidity两种类型的结果 - assert.GreaterOrEqual(t, len(resultSlice), 1, "应该至少有一个聚合结果") - - for _, item := range resultSlice { - deviceType := item["deviceType"] - totalCount, _ := item["total_count"].(float64) - nonNullCount, _ := item["non_null_count"].(float64) - - if deviceType == "temperature" { - // temperature有2个非空值(25.5, 27.0) - assert.Equal(t, 2.0, totalCount, "temperature总数应该是2") - assert.Equal(t, 2.0, nonNullCount, "temperature非空数应该是2") - } else if deviceType == "humidity" { - // humidity有1个非空值(60.0) - assert.Equal(t, 1.0, totalCount, "humidity总数应该是1") - assert.Equal(t, 1.0, nonNullCount, "humidity非空数应该是1") - } - } - case <-time.After(5 * time.Second): - t.Fatal("测试超时,未收到聚合结果") - } -} - -// TestMixedNullComparisons 测试混合使用 IS NULL、= nil、= null、!= null 等语法 -func TestMixedNullComparisons(t *testing.T) { - ssql := New() - defer ssql.Stop() - - // 测试混合null比较语法 - sql := `SELECT deviceId, value, status, priority - FROM stream - WHERE (value IS NOT NULL AND value > 20) OR - (status = nil AND priority != null)` - - err := ssql.Execute(sql) - require.NoError(t, err) - - resultChan := make(chan interface{}, 10) - ssql.Stream().AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := []map[string]interface{}{ - {"deviceId": "sensor1", "value": 25.0, "status": "active", "priority": "high"}, // 满足第一个条件 - {"deviceId": "sensor2", "value": 15.0, "status": "active", "priority": "low"}, // 不满足条件 - {"deviceId": "sensor3", "value": nil, "status": nil, "priority": "medium"}, // 满足第二个条件 - {"deviceId": "sensor4", "value": nil, "status": nil, "priority": nil}, // 不满足条件 - {"deviceId": "sensor5", "value": 30.0, "status": "inactive", "priority": nil}, // 满足第一个条件 - 
{"deviceId": "sensor6", "value": 10.0, "status": nil, "priority": "urgent"}, // 满足第二个条件 - } - - for _, data := range testData { - ssql.Stream().AddData(data) - } - - // 使用超时方式安全收集结果 - var results []map[string]interface{} - timeout := time.After(500 * time.Millisecond) - -collecting: - for { - select { - case result := <-resultChan: - if resultSlice, ok := result.([]map[string]interface{}); ok { - results = append(results, resultSlice...) - } - case <-timeout: - break collecting - } - } - - // 验证结果:应该有sensor1, sensor3, sensor5, sensor6 - assert.Len(t, results, 4, "应该有4个结果") - - expectedDeviceIds := []string{"sensor1", "sensor3", "sensor5", "sensor6"} - actualDeviceIds := make([]string, len(results)) - for i, result := range results { - actualDeviceIds[i] = result["deviceId"].(string) - } - - for _, expectedId := range expectedDeviceIds { - assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) - } -} +package streamsql + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestIsNullOperatorInSQL 测试IS NULL和IS NOT NULL语法功能 +func TestIsNullOperatorInSQL(t *testing.T) { + testCases := []struct { + name string + sql string + testData []map[string]interface{} + expected []map[string]interface{} + }{ + { + name: "IS NULL测试", + sql: "SELECT deviceId, value FROM stream WHERE value IS NULL", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor3", "value": 30.0}, + {"deviceId": "sensor4", "value": nil}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor4", "value": nil}, + }, + }, + { + name: "IS NOT NULL测试", + sql: "SELECT deviceId, value FROM stream WHERE value IS NOT NULL", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor3", "value": 
30.0}, + {"deviceId": "sensor4", "value": nil}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor3", "value": 30.0}, + }, + }, + { + name: "嵌套字段IS NULL测试", + sql: "SELECT deviceId, device.location FROM stream WHERE device.location IS NULL", + testData: []map[string]interface{}{ + { + "deviceId": "sensor1", + "device": map[string]interface{}{ + "location": "warehouse-A", + }, + }, + { + "deviceId": "sensor2", + "device": map[string]interface{}{ + "location": nil, + }, + }, + { + "deviceId": "sensor3", + "device": map[string]interface{}{}, + }, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "device.location": nil}, + {"deviceId": "sensor3", "device.location": nil}, // 字段不存在也被认为是null + }, + }, + { + name: "组合条件 - IS NULL AND其他条件", + sql: "SELECT deviceId, value, status FROM stream WHERE value IS NULL AND status = 'active'", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, + {"deviceId": "sensor2", "value": nil, "status": "active"}, + {"deviceId": "sensor3", "value": nil, "status": "inactive"}, + {"deviceId": "sensor4", "value": 30.0, "status": "active"}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "value": nil, "status": "active"}, + }, + }, + { + name: "组合条件 - IS NOT NULL OR其他条件", + sql: "SELECT deviceId, value, status FROM stream WHERE value IS NOT NULL OR status = 'error'", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, + {"deviceId": "sensor2", "value": nil, "status": "active"}, + {"deviceId": "sensor3", "value": nil, "status": "error"}, + {"deviceId": "sensor4", "value": 30.0, "status": "inactive"}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, + {"deviceId": "sensor3", "value": nil, "status": "error"}, + {"deviceId": "sensor4", "value": 30.0, "status": "inactive"}, + }, + }, + } + + for _, tc := 
range testCases { + t.Run(tc.name, func(t *testing.T) { + // 创建StreamSQL实例 + ssql := New() + defer ssql.Stop() + + // 执行SQL + err := ssql.Execute(tc.sql) + require.NoError(t, err) + + // 收集结果 + var results []map[string]interface{} + resultChan := make(chan interface{}, 10) + + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 使用一个done channel来同步 + done := make(chan bool, 1) + + // 添加测试数据 + for _, data := range tc.testData { + ssql.Stream().Emit(data) + } + + // 在另一个goroutine中收集结果 + go func() { + defer func() { done <- true }() + // 等待一段时间收集结果 + timeout := time.After(300 * time.Millisecond) + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-timeout: + return + } + } + }() + + // 等待收集完成 + <-done + + // 验证结果数量 + assert.Len(t, results, len(tc.expected), "结果数量应该匹配") + + // 验证结果内容(不依赖顺序) + expectedDeviceIds := make([]string, len(tc.expected)) + for i, exp := range tc.expected { + expectedDeviceIds[i] = exp["deviceId"].(string) + } + + actualDeviceIds := make([]string, len(results)) + for i, result := range results { + actualDeviceIds[i] = result["deviceId"].(string) + } + + // 验证每个期望的设备ID都在结果中 + for _, expectedId := range expectedDeviceIds { + assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) + } + + // 验证每个结果的字段值 + for _, result := range results { + deviceId := result["deviceId"].(string) + // 找到对应的期望结果 + var expectedResult map[string]interface{} + for _, exp := range tc.expected { + if exp["deviceId"].(string) == deviceId { + expectedResult = exp + break + } + } + + if expectedResult != nil { + for key, expectedValue := range expectedResult { + actualValue := result[key] + assert.Equal(t, expectedValue, actualValue, + "设备 %s 的字段 %s 值应该匹配: 期望 %v, 实际 %v", deviceId, key, expectedValue, actualValue) + } + } + } + }) + } +} + +// TestIsNullInAggregation 测试聚合查询中的IS NULL +func 
TestIsNullInAggregation(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 聚合查询:统计非空值的数量 + sql := `SELECT deviceType, + COUNT(*) as total_count, + COUNT(value) as non_null_count + FROM stream + WHERE value IS NOT NULL + GROUP BY deviceType, TumblingWindow('2s')` + + err := ssql.Execute(sql) + require.NoError(t, err) + + // 收集结果 + resultChan := make(chan interface{}, 10) + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"deviceType": "temperature", "value": 25.5}, + {"deviceType": "temperature", "value": nil}, + {"deviceType": "temperature", "value": 27.0}, + {"deviceType": "humidity", "value": 60.0}, + {"deviceType": "humidity", "value": nil}, + } + + for _, data := range testData { + ssql.Stream().Emit(data) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) + + // 验证结果 + select { + case result := <-resultChan: + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 应该有temperature和humidity两种类型的结果 + assert.GreaterOrEqual(t, len(resultSlice), 1, "应该至少有一个聚合结果") + + for _, item := range resultSlice { + deviceType := item["deviceType"] + totalCount, _ := item["total_count"].(float64) + nonNullCount, _ := item["non_null_count"].(float64) + + if deviceType == "temperature" { + // temperature有2个非空值(25.5, 27.0) + assert.Equal(t, 2.0, totalCount, "temperature总数应该是2") + assert.Equal(t, 2.0, nonNullCount, "temperature非空数应该是2") + } else if deviceType == "humidity" { + // humidity有1个非空值(60.0) + assert.Equal(t, 1.0, totalCount, "humidity总数应该是1") + assert.Equal(t, 1.0, nonNullCount, "humidity非空数应该是1") + } + } + case <-time.After(5 * time.Second): + t.Fatal("测试超时,未收到聚合结果") + } +} + +// TestIsNullInHaving 测试HAVING子句中真正的IS NULL功能 +func TestIsNullInHaving(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 测试HAVING子句中的IS NULL:只返回平均值为NULL的设备类型 + sql := `SELECT deviceType, + COUNT(*) as total_count, + AVG(value) as avg_value + 
FROM stream + GROUP BY deviceType, TumblingWindow('2s') + HAVING avg_value IS NULL` + + err := ssql.Execute(sql) + require.NoError(t, err) + + resultChan := make(chan interface{}, 10) + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据:只给pressure设备类型添加null值,这样它的平均值会是null + testData := []map[string]interface{}{ + {"deviceType": "temperature", "value": 25.0}, + {"deviceType": "temperature", "value": 27.0}, // temperature有值,平均值不为null + {"deviceType": "humidity", "value": 60.0}, // humidity有值,平均值不为null + {"deviceType": "pressure", "value": nil}, // pressure只有null值 + {"deviceType": "pressure", "value": nil}, // pressure再次null值,平均值会是null + } + + for _, data := range testData { + ssql.Stream().Emit(data) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) + + // 验证结果 + select { + case result := <-resultChan: + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 应该只有pressure类型的结果(平均值为null) + assert.Len(t, resultSlice, 1, "应该只有一个结果") + + if len(resultSlice) > 0 { + item := resultSlice[0] + assert.Equal(t, "pressure", item["deviceType"], "应该是pressure类型") + + // 验证avg_value确实为null + avgValue := item["avg_value"] + assert.Nil(t, avgValue, "pressure的平均值应该是null") + + // 验证total_count + totalCount, ok := item["total_count"].(float64) + assert.True(t, ok, "total_count应该是float64类型") + assert.Equal(t, 2.0, totalCount, "pressure应该有2条记录") + } + + case <-time.After(5 * time.Second): + t.Fatal("测试超时,未收到聚合结果") + } +} + +// TestIsNullInHavingWithIsNotNull 测试HAVING子句中的IS NOT NULL功能 +func TestIsNullInHavingWithIsNotNull(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 测试HAVING子句中的IS NOT NULL:只返回平均值不为NULL的设备类型 + sql := `SELECT deviceType, + COUNT(*) as total_count, + AVG(value) as avg_value + FROM stream + GROUP BY deviceType, TumblingWindow('2s') + HAVING avg_value IS NOT NULL` + + err := ssql.Execute(sql) + require.NoError(t, err) + + resultChan := make(chan interface{}, 10) + 
ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"deviceType": "temperature", "value": 25.0}, + {"deviceType": "temperature", "value": 27.0}, // temperature有值,平均值不为null + {"deviceType": "humidity", "value": 60.0}, // humidity有值,平均值不为null + {"deviceType": "pressure", "value": nil}, // pressure只有null值,平均值会是null + {"deviceType": "pressure", "value": nil}, + } + + for _, data := range testData { + ssql.Stream().Emit(data) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) + + // 验证结果 + select { + case result := <-resultChan: + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 应该有temperature和humidity两种类型的结果(平均值不为null) + assert.Len(t, resultSlice, 2, "应该有两个结果") + + foundTypes := make([]string, 0) + for _, item := range resultSlice { + deviceType, ok := item["deviceType"].(string) + require.True(t, ok, "deviceType应该是string类型") + + // 验证avg_value不为null + avgValue := item["avg_value"] + assert.NotNil(t, avgValue, fmt.Sprintf("%s的平均值应该不为null", deviceType)) + + foundTypes = append(foundTypes, deviceType) + } + + // 验证包含temperature和humidity,不包含pressure + assert.Contains(t, foundTypes, "temperature", "结果应该包含temperature") + assert.Contains(t, foundTypes, "humidity", "结果应该包含humidity") + assert.NotContains(t, foundTypes, "pressure", "结果不应该包含pressure") + + case <-time.After(5 * time.Second): + t.Fatal("测试超时,未收到聚合结果") + } +} + +// TestIsNullWithOtherOperators 测试IS NULL与其他操作符的组合 +func TestIsNullWithOtherOperators(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 测试复杂的WHERE条件 + sql := `SELECT deviceId, value, status, location + FROM stream + WHERE (value IS NOT NULL AND value > 20) OR + (status IS NULL AND location LIKE 'warehouse%')` + + err := ssql.Execute(sql) + require.NoError(t, err) + + resultChan := make(chan interface{}, 10) + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + 
testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.0, "status": "active", "location": "warehouse-A"}, // 满足第一个条件 + {"deviceId": "sensor2", "value": 15.0, "status": "active", "location": "warehouse-B"}, // 不满足条件 + {"deviceId": "sensor3", "value": nil, "status": nil, "location": "warehouse-C"}, // 满足第二个条件 + {"deviceId": "sensor4", "value": nil, "status": "inactive", "location": "warehouse-D"}, // 不满足条件 + {"deviceId": "sensor5", "value": 30.0, "status": nil, "location": "office-A"}, // 满足第一个条件 + } + + for _, data := range testData { + ssql.Stream().Emit(data) + } + + // 使用超时方式安全收集结果 + var results []map[string]interface{} + timeout := time.After(500 * time.Millisecond) + +collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-timeout: + break collecting + } + } + + // 验证结果:应该有sensor1, sensor3, sensor5 + assert.Len(t, results, 3, "应该有3个结果") + + expectedDeviceIds := []string{"sensor1", "sensor3", "sensor5"} + actualDeviceIds := make([]string, len(results)) + for i, result := range results { + actualDeviceIds[i] = result["deviceId"].(string) + } + + for _, expectedId := range expectedDeviceIds { + assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) + } +} + +// TestCaseWhenWithIsNull 测试CASE WHEN表达式中使用IS NULL和IS NOT NULL +func TestCaseWhenWithIsNull(t *testing.T) { + testCases := []struct { + name string + sql string + testData []map[string]interface{} + expected []map[string]interface{} + }{ + { + name: "CASE WHEN IS NULL基本测试", + sql: `SELECT deviceId, + CASE WHEN status IS NULL THEN 0 + WHEN status IS NOT NULL THEN 1 + ELSE 2 END as status_flag + FROM stream`, + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "status": "active"}, + {"deviceId": "sensor2", "status": nil}, + {"deviceId": "sensor3", "status": "inactive"}, + {"deviceId": "sensor4"}, // 没有status字段 + }, + expected: 
[]map[string]interface{}{ + {"deviceId": "sensor1", "status_flag": 1.0}, + {"deviceId": "sensor2", "status_flag": 0.0}, + {"deviceId": "sensor3", "status_flag": 1.0}, + {"deviceId": "sensor4", "status_flag": 0.0}, + }, + }, + { + name: "CASE WHEN IS NOT NULL复杂条件测试", + sql: `SELECT deviceId, + CASE WHEN temperature IS NOT NULL AND temperature > 25 THEN 2 + WHEN temperature IS NOT NULL AND temperature <= 25 THEN 1 + WHEN temperature IS NULL THEN 0 + ELSE 3 END as temp_level + FROM stream`, + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "temperature": 30.0}, + {"deviceId": "sensor2", "temperature": 20.0}, + {"deviceId": "sensor3", "temperature": nil}, + {"deviceId": "sensor4"}, // 没有temperature字段 + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "temp_level": 2.0}, + {"deviceId": "sensor2", "temp_level": 1.0}, + {"deviceId": "sensor3", "temp_level": 0.0}, + {"deviceId": "sensor4", "temp_level": 0.0}, + }, + }, + { + name: "多个CASE WHEN IS NULL条件测试", + sql: `SELECT deviceId, + CASE WHEN status IS NULL AND temperature IS NULL THEN 0 + WHEN status IS NULL AND temperature IS NOT NULL THEN 1 + WHEN status IS NOT NULL AND temperature IS NULL THEN 2 + WHEN status IS NOT NULL AND temperature IS NOT NULL THEN 3 + ELSE 4 END as combined_flag + FROM stream`, + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "status": "active", "temperature": 25.0}, + {"deviceId": "sensor2", "status": "active", "temperature": nil}, + {"deviceId": "sensor3", "status": nil, "temperature": 30.0}, + {"deviceId": "sensor4", "status": nil, "temperature": nil}, + {"deviceId": "sensor5"}, // 两个字段都不存在 + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "combined_flag": 3.0}, + {"deviceId": "sensor2", "combined_flag": 2.0}, + {"deviceId": "sensor3", "combined_flag": 1.0}, + {"deviceId": "sensor4", "combined_flag": 0.0}, + {"deviceId": "sensor5", "combined_flag": 0.0}, + }, + }, + { + name: "CASE WHEN IS NULL与聚合函数结合测试", + sql: `SELECT 
deviceType, + COUNT(*) as total_count, + SUM(CASE WHEN value IS NULL THEN 1 ELSE 0 END) as null_count, + SUM(CASE WHEN value IS NOT NULL THEN 1 ELSE 0 END) as non_null_count + FROM stream + GROUP BY deviceType, TumblingWindow('2s')`, + testData: []map[string]interface{}{ + {"deviceType": "temperature", "value": 25.0}, + {"deviceType": "temperature", "value": nil}, + {"deviceType": "temperature", "value": 27.0}, + {"deviceType": "humidity", "value": 60.0}, + {"deviceType": "humidity", "value": nil}, + {"deviceType": "humidity", "value": nil}, + }, + expected: []map[string]interface{}{ + { + "deviceType": "temperature", + "total_count": 3.0, + "null_count": 1.0, + "non_null_count": 2.0, + }, + { + "deviceType": "humidity", + "total_count": 3.0, + "null_count": 2.0, + "non_null_count": 1.0, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // 创建StreamSQL实例 + ssql := New() + defer ssql.Stop() + + // 执行SQL + err := ssql.Execute(tc.sql) + require.NoError(t, err) + + // 收集结果 + var results []map[string]interface{} + resultChan := make(chan interface{}, 10) + + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + for _, data := range tc.testData { + ssql.Stream().Emit(data) + } + + // 使用超时方式安全收集结果 + var timeout time.Duration + if tc.name == "CASE WHEN IS NULL与聚合函数结合测试" { + timeout = 4 * time.Second // 聚合查询需要更长时间 + } else { + timeout = 500 * time.Millisecond + } + + timeoutChan := time.After(timeout) + + collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) 
+ } + case <-timeoutChan: + break collecting + } + } + + // 验证结果数量 + assert.Len(t, results, len(tc.expected), "结果数量应该匹配") + + // 对于聚合查询,验证逻辑略有不同 + if tc.name == "CASE WHEN IS NULL与聚合函数结合测试" { + // 验证每个deviceType的结果 + for _, expectedResult := range tc.expected { + expectedDeviceType := expectedResult["deviceType"].(string) + + // 在结果中找到对应的deviceType + var actualResult map[string]interface{} + for _, result := range results { + if result["deviceType"].(string) == expectedDeviceType { + actualResult = result + break + } + } + + require.NotNil(t, actualResult, "应该找到设备类型 %s 的结果", expectedDeviceType) + + // 验证各个统计值 + assert.Equal(t, expectedResult["total_count"], actualResult["total_count"], + "设备类型 %s 的total_count应该匹配", expectedDeviceType) + assert.Equal(t, expectedResult["null_count"], actualResult["null_count"], + "设备类型 %s 的null_count应该匹配", expectedDeviceType) + assert.Equal(t, expectedResult["non_null_count"], actualResult["non_null_count"], + "设备类型 %s 的non_null_count应该匹配", expectedDeviceType) + } + } else { + // 验证普通查询的结果 + expectedDeviceIds := make([]string, len(tc.expected)) + for i, exp := range tc.expected { + expectedDeviceIds[i] = exp["deviceId"].(string) + } + + actualDeviceIds := make([]string, len(results)) + for i, result := range results { + actualDeviceIds[i] = result["deviceId"].(string) + } + + // 验证每个期望的设备ID都在结果中 + for _, expectedId := range expectedDeviceIds { + assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) + } + + // 验证每个结果的字段值 + for _, result := range results { + deviceId := result["deviceId"].(string) + // 找到对应的期望结果 + var expectedResult map[string]interface{} + for _, exp := range tc.expected { + if exp["deviceId"].(string) == deviceId { + expectedResult = exp + break + } + } + + if expectedResult != nil { + for key, expectedValue := range expectedResult { + if key != "deviceId" { // deviceId已经验证过了 + actualValue := result[key] + assert.Equal(t, expectedValue, actualValue, + "设备 %s 的字段 %s 值应该匹配: 期望 %v, 实际 %v", 
deviceId, key, expectedValue, actualValue) + } + } + } + } + } + }) + } +} + +// TestNullComparisons 测试 = nil、!= nil、= null、!= null 等语法 +func TestNullComparisons(t *testing.T) { + testCases := []struct { + name string + sql string + testData []map[string]interface{} + expected []map[string]interface{} + }{ + { + name: "fieldName = nil 测试", + sql: "SELECT deviceId, value FROM stream WHERE value = nil", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor3", "value": 30.0}, + {"deviceId": "sensor4", "value": nil}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor4", "value": nil}, + }, + }, + { + name: "fieldName != nil 测试", + sql: "SELECT deviceId, value FROM stream WHERE value != nil", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor3", "value": 30.0}, + {"deviceId": "sensor4", "value": nil}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor3", "value": 30.0}, + }, + }, + { + name: "fieldName = null 测试", + sql: "SELECT deviceId, value FROM stream WHERE value = null", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor3", "value": 30.0}, + {"deviceId": "sensor4", "value": nil}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor4", "value": nil}, + }, + }, + { + name: "fieldName != null 测试", + sql: "SELECT deviceId, value FROM stream WHERE value != null", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5}, + {"deviceId": "sensor2", "value": nil}, + {"deviceId": "sensor3", "value": 30.0}, + {"deviceId": "sensor4", "value": nil}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", 
"value": 25.5}, + {"deviceId": "sensor3", "value": 30.0}, + }, + }, + { + name: "嵌套字段 = nil 测试", + sql: "SELECT deviceId, device.location FROM stream WHERE device.location = nil", + testData: []map[string]interface{}{ + { + "deviceId": "sensor1", + "device": map[string]interface{}{ + "location": "warehouse-A", + }, + }, + { + "deviceId": "sensor2", + "device": map[string]interface{}{ + "location": nil, + }, + }, + { + "deviceId": "sensor3", + "device": map[string]interface{}{}, + }, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "device.location": nil}, + {"deviceId": "sensor3", "device.location": nil}, // 字段不存在也被认为是null + }, + }, + { + name: "嵌套字段 != nil 测试", + sql: "SELECT deviceId, device.location FROM stream WHERE device.location != nil", + testData: []map[string]interface{}{ + { + "deviceId": "sensor1", + "device": map[string]interface{}{ + "location": "warehouse-A", + }, + }, + { + "deviceId": "sensor2", + "device": map[string]interface{}{ + "location": nil, + }, + }, + { + "deviceId": "sensor3", + "device": map[string]interface{}{}, + }, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "device.location": "warehouse-A"}, + }, + }, + { + name: "组合条件 - != nil AND 其他条件", + sql: "SELECT deviceId, value, status FROM stream WHERE value != nil AND value > 20", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, + {"deviceId": "sensor2", "value": nil, "status": "active"}, + {"deviceId": "sensor3", "value": 15.0, "status": "inactive"}, + {"deviceId": "sensor4", "value": 30.0, "status": "active"}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, "status": "active"}, + {"deviceId": "sensor4", "value": 30.0, "status": "active"}, + }, + }, + { + name: "组合条件 - = nil OR 其他条件", + sql: "SELECT deviceId, value, status FROM stream WHERE value = nil OR status = 'error'", + testData: []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.5, 
"status": "active"}, + {"deviceId": "sensor2", "value": nil, "status": "active"}, + {"deviceId": "sensor3", "value": 30.0, "status": "error"}, + {"deviceId": "sensor4", "value": nil, "status": "inactive"}, + }, + expected: []map[string]interface{}{ + {"deviceId": "sensor2", "value": nil, "status": "active"}, + {"deviceId": "sensor3", "value": 30.0, "status": "error"}, + {"deviceId": "sensor4", "value": nil, "status": "inactive"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // 创建StreamSQL实例 + ssql := New() + defer ssql.Stop() + + // 执行SQL + err := ssql.Execute(tc.sql) + require.NoError(t, err) + + // 收集结果 + var results []map[string]interface{} + resultChan := make(chan interface{}, 10) + + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + for _, data := range tc.testData { + ssql.Stream().Emit(data) + } + + // 使用超时方式安全收集结果 + timeout := time.After(500 * time.Millisecond) + + collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) 
+ } + case <-timeout: + break collecting + } + } + + // 验证结果数量 + assert.Len(t, results, len(tc.expected), "结果数量应该匹配") + + // 验证结果内容(不依赖顺序) + expectedDeviceIds := make([]string, len(tc.expected)) + for i, exp := range tc.expected { + expectedDeviceIds[i] = exp["deviceId"].(string) + } + + actualDeviceIds := make([]string, len(results)) + for i, result := range results { + actualDeviceIds[i] = result["deviceId"].(string) + } + + // 验证每个期望的设备ID都在结果中 + for _, expectedId := range expectedDeviceIds { + assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) + } + + // 验证每个结果的字段值 + for _, result := range results { + deviceId := result["deviceId"].(string) + // 找到对应的期望结果 + var expectedResult map[string]interface{} + for _, exp := range tc.expected { + if exp["deviceId"].(string) == deviceId { + expectedResult = exp + break + } + } + + if expectedResult != nil { + for key, expectedValue := range expectedResult { + actualValue := result[key] + assert.Equal(t, expectedValue, actualValue, + "设备 %s 的字段 %s 值应该匹配: 期望 %v, 实际 %v", deviceId, key, expectedValue, actualValue) + } + } + } + }) + } +} + +// TestNullComparisonInAggregation 测试聚合查询中的 = nil 和 != nil +func TestNullComparisonInAggregation(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 聚合查询:统计非空值的数量 + sql := `SELECT deviceType, + COUNT(*) as total_count, + COUNT(value) as non_null_count + FROM stream + WHERE value != nil + GROUP BY deviceType, TumblingWindow('2s')` + + err := ssql.Execute(sql) + require.NoError(t, err) + + // 收集结果 + resultChan := make(chan interface{}, 10) + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"deviceType": "temperature", "value": 25.5}, + {"deviceType": "temperature", "value": nil}, + {"deviceType": "temperature", "value": 27.0}, + {"deviceType": "humidity", "value": 60.0}, + {"deviceType": "humidity", "value": nil}, + } + + for _, data := range testData { + 
ssql.Stream().Emit(data) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) + + // 验证结果 + select { + case result := <-resultChan: + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 应该有temperature和humidity两种类型的结果 + assert.GreaterOrEqual(t, len(resultSlice), 1, "应该至少有一个聚合结果") + + for _, item := range resultSlice { + deviceType := item["deviceType"] + totalCount, _ := item["total_count"].(float64) + nonNullCount, _ := item["non_null_count"].(float64) + + if deviceType == "temperature" { + // temperature有2个非空值(25.5, 27.0) + assert.Equal(t, 2.0, totalCount, "temperature总数应该是2") + assert.Equal(t, 2.0, nonNullCount, "temperature非空数应该是2") + } else if deviceType == "humidity" { + // humidity有1个非空值(60.0) + assert.Equal(t, 1.0, totalCount, "humidity总数应该是1") + assert.Equal(t, 1.0, nonNullCount, "humidity非空数应该是1") + } + } + case <-time.After(5 * time.Second): + t.Fatal("测试超时,未收到聚合结果") + } +} + +// TestMixedNullComparisons 测试混合使用 IS NULL、= nil、= null、!= null 等语法 +func TestMixedNullComparisons(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 测试混合null比较语法 + sql := `SELECT deviceId, value, status, priority + FROM stream + WHERE (value IS NOT NULL AND value > 20) OR + (status = nil AND priority != null)` + + err := ssql.Execute(sql) + require.NoError(t, err) + + resultChan := make(chan interface{}, 10) + ssql.Stream().AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 25.0, "status": "active", "priority": "high"}, // 满足第一个条件 + {"deviceId": "sensor2", "value": 15.0, "status": "active", "priority": "low"}, // 不满足条件 + {"deviceId": "sensor3", "value": nil, "status": nil, "priority": "medium"}, // 满足第二个条件 + {"deviceId": "sensor4", "value": nil, "status": nil, "priority": nil}, // 不满足条件 + {"deviceId": "sensor5", "value": 30.0, "status": "inactive", "priority": nil}, // 满足第一个条件 + {"deviceId": "sensor6", "value": 
10.0, "status": nil, "priority": "urgent"}, // 满足第二个条件 + } + + for _, data := range testData { + ssql.Stream().Emit(data) + } + + // 使用超时方式安全收集结果 + var results []map[string]interface{} + timeout := time.After(500 * time.Millisecond) + +collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-timeout: + break collecting + } + } + + // 验证结果:应该有sensor1, sensor3, sensor5, sensor6 + assert.Len(t, results, 4, "应该有4个结果") + + expectedDeviceIds := []string{"sensor1", "sensor3", "sensor5", "sensor6"} + actualDeviceIds := make([]string, len(results)) + for i, result := range results { + actualDeviceIds[i] = result["deviceId"].(string) + } + + for _, expectedId := range expectedDeviceIds { + assert.Contains(t, actualDeviceIds, expectedId, "结果应该包含设备ID %s", expectedId) + } +} diff --git a/streamsql_like_test.go b/streamsql_like_test.go index 931a15e..a7c2701 100644 --- a/streamsql_like_test.go +++ b/streamsql_like_test.go @@ -1,551 +1,551 @@ -package streamsql - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestLikeOperatorInSQL 测试LIKE语法功能 -func TestLikeOperatorInSQL(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - // 测试场景1:基本LIKE模式匹配 - 前缀匹配 - t.Run("前缀匹配(prefix%)", func(t *testing.T) { - // 测试使用LIKE进行前缀匹配 - var rsql = "SELECT deviceId, deviceType FROM stream WHERE deviceId LIKE 'sensor%'" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - // 创建结果接收通道 - resultChan := make(chan interface{}, 10) - - // 添加结果回调 - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := []interface{}{ - map[string]interface{}{"deviceId": "sensor001", "deviceType": "temperature"}, - map[string]interface{}{"deviceId": "device002", "deviceType": "humidity"}, - 
map[string]interface{}{"deviceId": "sensor003", "deviceType": "pressure"}, - map[string]interface{}{"deviceId": "pump004", "deviceType": "actuator"}, - } - - // 添加数据 - for _, data := range testData { - strm.AddData(data) - } - - // 等待并收集结果 - var results []interface{} - timeout := time.After(2 * time.Second) - done := false - - for !done && len(results) < 2 { - select { - case result := <-resultChan: - results = append(results, result) - case <-timeout: - done = true - } - } - - // 验证结果:应该只有sensor001和sensor003匹配 - assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") - - // 验证结果中只包含以"sensor"开头的设备 - for _, result := range results { - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - for _, item := range resultSlice { - deviceId, _ := item["deviceId"].(string) - assert.True(t, strings.HasPrefix(deviceId, "sensor"), - fmt.Sprintf("设备ID %s 应该以'sensor'开头", deviceId)) - } - } - }) - - // 测试场景2:后缀匹配 - t.Run("后缀匹配(%suffix)", func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - var rsql = "SELECT deviceId, status FROM stream WHERE status LIKE '%error'" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - testData := []interface{}{ - map[string]interface{}{"deviceId": "dev1", "status": "connection_error"}, - map[string]interface{}{"deviceId": "dev2", "status": "running"}, - map[string]interface{}{"deviceId": "dev3", "status": "timeout_error"}, - map[string]interface{}{"deviceId": "dev4", "status": "normal"}, - } - - for _, data := range testData { - strm.AddData(data) - } - - // 等待结果 - var results []interface{} - timeout := time.After(2 * time.Second) - done := false - - for !done && len(results) < 2 { - select { - case result := <-resultChan: - results = append(results, result) - case <-timeout: - done = true - } - } - - // 验证结果:应该只有以"error"结尾的状态 
- assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") - - for _, result := range results { - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - for _, item := range resultSlice { - status, _ := item["status"].(string) - assert.True(t, strings.HasSuffix(status, "error"), - fmt.Sprintf("状态 %s 应该以'error'结尾", status)) - } - } - }) - - // 测试场景3:包含匹配 - t.Run("包含匹配(%substring%)", func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - var rsql = "SELECT deviceId, message FROM stream WHERE message LIKE '%alert%'" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - testData := []interface{}{ - map[string]interface{}{"deviceId": "dev1", "message": "system alert: high temperature"}, - map[string]interface{}{"deviceId": "dev2", "message": "normal operation"}, - map[string]interface{}{"deviceId": "dev3", "message": "critical alert detected"}, - map[string]interface{}{"deviceId": "dev4", "message": "info: device startup"}, - } - - for _, data := range testData { - strm.AddData(data) - } - - // 等待结果 - var results []interface{} - timeout := time.After(2 * time.Second) - done := false - - for !done && len(results) < 2 { - select { - case result := <-resultChan: - results = append(results, result) - case <-timeout: - done = true - } - } - - // 验证结果:应该只有包含"alert"的消息 - assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") - - for _, result := range results { - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - for _, item := range resultSlice { - message, _ := item["message"].(string) - assert.True(t, strings.Contains(message, "alert"), - fmt.Sprintf("消息 %s 应该包含'alert'", message)) - } - } - }) - - // 测试场景4:单字符通配符 - t.Run("单字符通配符(_)", func(t *testing.T) { - streamsql := New() - defer 
streamsql.Stop() - - var rsql = "SELECT deviceId, code FROM stream WHERE code LIKE 'E_0_'" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - testData := []interface{}{ - map[string]interface{}{"deviceId": "dev1", "code": "E101"}, - map[string]interface{}{"deviceId": "dev2", "code": "E202"}, - map[string]interface{}{"deviceId": "dev3", "code": "E305"}, - map[string]interface{}{"deviceId": "dev4", "code": "F101"}, - } - - for _, data := range testData { - strm.AddData(data) - } - - // 等待结果 - var results []interface{} - timeout := time.After(2 * time.Second) - done := false - - for !done && len(results) < 2 { - select { - case result := <-resultChan: - results = append(results, result) - case <-timeout: - done = true - } - } - - // 验证结果:应该只有E_0_模式的代码(E101, E202不匹配E_0_,只有E305也不完全匹配) - // 实际上,根据模式E_0_,应该匹配如E101, E202等,让我们调整测试数据 - assert.GreaterOrEqual(t, len(results), 0, "根据通配符模式可能有匹配结果") - }) - - // 测试场景5:复杂模式 - t.Run("复杂LIKE模式", func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - var rsql = "SELECT deviceId, filename FROM stream WHERE filename LIKE '%.log'" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - testData := []interface{}{ - map[string]interface{}{"deviceId": "dev1", "filename": "system.log"}, - map[string]interface{}{"deviceId": "dev2", "filename": "config.txt"}, - map[string]interface{}{"deviceId": "dev3", "filename": "error.log"}, - map[string]interface{}{"deviceId": "dev4", "filename": "backup.bak"}, - } - - for _, data := range testData { - strm.AddData(data) - } - - // 等待结果 - var results []interface{} - timeout := time.After(2 * time.Second) - done := false - - for !done && len(results) < 2 { - select { - case result := <-resultChan: - 
results = append(results, result) - case <-timeout: - done = true - } - } - - // 验证结果:应该只有.log文件 - assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") - - for _, result := range results { - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - for _, item := range resultSlice { - filename, _ := item["filename"].(string) - assert.True(t, strings.HasSuffix(filename, ".log"), - fmt.Sprintf("文件名 %s 应该以'.log'结尾", filename)) - } - } - }) - - // 测试场景6:在聚合查询中使用LIKE - t.Run("聚合查询中的LIKE", func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - var rsql = "SELECT deviceType, count(*) as device_count FROM stream WHERE deviceId LIKE 'sensor%' GROUP BY deviceType" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - testData := []interface{}{ - map[string]interface{}{"deviceId": "sensor001", "deviceType": "temperature"}, - map[string]interface{}{"deviceId": "sensor002", "deviceType": "temperature"}, - map[string]interface{}{"deviceId": "device003", "deviceType": "temperature"}, - map[string]interface{}{"deviceId": "sensor004", "deviceType": "humidity"}, - map[string]interface{}{"deviceId": "pump005", "deviceType": "actuator"}, - } - - for _, data := range testData { - strm.AddData(data) - } - - // 等待聚合 - time.Sleep(500 * time.Millisecond) - strm.Window.Trigger() - - // 等待结果 - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - var actual interface{} - select { - case actual = <-resultChan: - cancel() - case <-ctx.Done(): - t.Fatal("测试超时,未收到聚合结果") - } - - // 验证聚合结果 - resultSlice, ok := actual.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - // 应该有两种设备类型:temperature(2个sensor), humidity(1个sensor) - assert.GreaterOrEqual(t, len(resultSlice), 1, "应该有至少一种设备类型的聚合结果") - - for _, 
result := range resultSlice { - deviceType, _ := result["deviceType"].(string) - count, ok := result["device_count"].(float64) - assert.True(t, ok, "device_count应该是float64类型") - assert.Greater(t, count, 0.0, "设备数量应该大于0") - - // 验证设备类型 - assert.True(t, deviceType == "temperature" || deviceType == "humidity", - fmt.Sprintf("设备类型 %s 应该是temperature或humidity", deviceType)) - } - }) - - // 测试场景7:HAVING子句中的LIKE - t.Run("HAVING子句中的LIKE", func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - var rsql = "SELECT deviceType, max(temperature) as max_temp FROM stream GROUP BY deviceType HAVING deviceType LIKE '%temp%'" - err := streamsql.Execute(rsql) - assert.Nil(t, err) - strm := streamsql.stream - - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result interface{}) { - resultChan <- result - }) - - testData := []interface{}{ - map[string]interface{}{"deviceType": "temperature_sensor", "temperature": 25.0}, - map[string]interface{}{"deviceType": "temperature_sensor", "temperature": 30.0}, - map[string]interface{}{"deviceType": "humidity_sensor", "temperature": 22.0}, - map[string]interface{}{"deviceType": "pressure_gauge", "temperature": 20.0}, - } - - for _, data := range testData { - strm.AddData(data) - } - - // 等待聚合 - time.Sleep(500 * time.Millisecond) - strm.Window.Trigger() - - // 等待结果 - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - var actual interface{} - select { - case actual = <-resultChan: - cancel() - case <-ctx.Done(): - t.Fatal("测试超时,未收到HAVING+LIKE结果") - } - - // 验证HAVING + LIKE结果 - resultSlice, ok := actual.([]map[string]interface{}) - require.True(t, ok, "结果应该是[]map[string]interface{}类型") - - // 应该只有包含"temp"的设备类型 - for _, result := range resultSlice { - deviceType, _ := result["deviceType"].(string) - assert.True(t, strings.Contains(deviceType, "temp"), - fmt.Sprintf("设备类型 %s 应该包含'temp'", deviceType)) - - maxTemp, ok := result["max_temp"].(float64) - assert.True(t, ok, 
"max_temp应该是float64类型") - assert.Greater(t, maxTemp, 0.0, "最大温度应该大于0") - } - }) -} - -// TestLikeFunctionEquivalence 测试LIKE语法与现有字符串函数的等价性 -func TestLikeFunctionEquivalence(t *testing.T) { - // 简化测试,重点验证LIKE功能已经正常工作 - t.Run("LIKE语法工作正常验证", func(t *testing.T) { - streamsql := New() - defer streamsql.Stop() - - // 使用LIKE的查询 - var likeSQL = "SELECT deviceId FROM stream WHERE deviceId LIKE 'sensor%'" - err := streamsql.Execute(likeSQL) - assert.Nil(t, err) - - resultChan := make(chan interface{}, 10) - streamsql.stream.AddSink(func(result interface{}) { - resultChan <- result - }) - - // 测试数据 - testData := []interface{}{ - map[string]interface{}{"deviceId": "sensor001"}, - map[string]interface{}{"deviceId": "device002"}, - map[string]interface{}{"deviceId": "sensor003"}, - } - - // 添加数据 - for _, data := range testData { - streamsql.stream.AddData(data) - } - - // 收集结果 - timeout := time.After(2 * time.Second) - var results []interface{} - - for len(results) < 2 { - select { - case result := <-resultChan: - results = append(results, result) - case <-timeout: - break - } - } - - // 验证LIKE查询返回了预期的结果 - assert.Equal(t, 2, len(results), "LIKE查询应该返回2个匹配'sensor%'的结果") - t.Logf("LIKE查询成功返回%d个结果", len(results)) - - // 验证返回的结果确实是以'sensor'开头的 - for i, result := range results { - resultSlice, ok := result.([]map[string]interface{}) - assert.True(t, ok, fmt.Sprintf("结果%d应该是[]map[string]interface{}类型", i)) - if len(resultSlice) > 0 { - deviceId, exists := resultSlice[0]["deviceId"] - assert.True(t, exists, "结果应该包含deviceId字段") - deviceIdStr, ok := deviceId.(string) - assert.True(t, ok, "deviceId应该是字符串类型") - assert.True(t, strings.HasPrefix(deviceIdStr, "sensor"), - fmt.Sprintf("deviceId '%s' 应该以'sensor'开头", deviceIdStr)) - } - } - }) -} - -// TestLikePatternMatching 测试LIKE模式匹配算法的正确性 -func TestLikePatternMatching(t *testing.T) { - // 这些是单元测试,直接测试LIKE匹配函数 - tests := []struct { - text string - pattern string - expected bool - desc string - }{ - // 前缀匹配测试 - {"hello", "hello%", true, 
"精确前缀匹配"}, - {"hello world", "hello%", true, "前缀匹配"}, - {"hi there", "hello%", false, "前缀不匹配"}, - {"", "%", true, "空字符串匹配任意模式"}, - - // 后缀匹配测试 - {"test.log", "%.log", true, "后缀匹配"}, - {"test.txt", "%.log", false, "后缀不匹配"}, - - // 包含匹配测试 - {"hello world test", "%world%", true, "包含匹配"}, - {"hello test", "%world%", false, "不包含"}, - - // 单字符通配符测试 - {"abc", "a_c", true, "单字符通配符匹配"}, - {"aXc", "a_c", true, "单字符通配符匹配任意字符"}, - {"abbc", "a_c", false, "单字符通配符不匹配多个字符"}, - - // 复杂模式测试 - {"file123.log", "file___.log", true, "多个单字符通配符"}, - {"file12.log", "file___.log", false, "字符数不匹配"}, - {"prefix_test_suffix", "prefix%suffix", true, "前后缀组合"}, - - // 边界情况测试 - {"", "", true, "空模式匹配空字符串"}, - {"abc", "", false, "非空字符串不匹配空模式"}, - {"", "abc", false, "空字符串不匹配非空模式"}, - {"abc", "abc", true, "完全匹配"}, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - // 直接使用内部函数进行测试 - // 注意:这里我们需要通过SQL查询来测试,因为匹配函数是内部的 - streamsql := New() - defer streamsql.Stop() - - // 构造SQL查询 - rsql := fmt.Sprintf("SELECT value FROM stream WHERE value LIKE '%s'", test.pattern) - err := streamsql.Execute(rsql) - assert.Nil(t, err) - - resultChan := make(chan interface{}, 10) - streamsql.stream.AddSink(func(result interface{}) { - resultChan <- result - }) - - // 添加测试数据 - testData := map[string]interface{}{"value": test.text} - streamsql.stream.AddData(testData) - - // 等待结果 - timeout := time.After(1 * time.Second) - var hasResult bool - - select { - case result := <-resultChan: - resultSlice, ok := result.([]map[string]interface{}) - hasResult = ok && len(resultSlice) > 0 - case <-timeout: - hasResult = false - } - - if test.expected { - assert.True(t, hasResult, fmt.Sprintf("模式'%s'应该匹配文本'%s'", test.pattern, test.text)) - } else { - assert.False(t, hasResult, fmt.Sprintf("模式'%s'不应该匹配文本'%s'", test.pattern, test.text)) - } - }) - } -} +package streamsql + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +// TestLikeOperatorInSQL 测试LIKE语法功能 +func TestLikeOperatorInSQL(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + // 测试场景1:基本LIKE模式匹配 - 前缀匹配 + t.Run("前缀匹配(prefix%)", func(t *testing.T) { + // 测试使用LIKE进行前缀匹配 + var rsql = "SELECT deviceId, deviceType FROM stream WHERE deviceId LIKE 'sensor%'" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果回调 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []interface{}{ + map[string]interface{}{"deviceId": "sensor001", "deviceType": "temperature"}, + map[string]interface{}{"deviceId": "device002", "deviceType": "humidity"}, + map[string]interface{}{"deviceId": "sensor003", "deviceType": "pressure"}, + map[string]interface{}{"deviceId": "pump004", "deviceType": "actuator"}, + } + + // 添加数据 + for _, data := range testData { + strm.Emit(data) + } + + // 等待并收集结果 + var results []interface{} + timeout := time.After(2 * time.Second) + done := false + + for !done && len(results) < 2 { + select { + case result := <-resultChan: + results = append(results, result) + case <-timeout: + done = true + } + } + + // 验证结果:应该只有sensor001和sensor003匹配 + assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") + + // 验证结果中只包含以"sensor"开头的设备 + for _, result := range results { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + for _, item := range resultSlice { + deviceId, _ := item["deviceId"].(string) + assert.True(t, strings.HasPrefix(deviceId, "sensor"), + fmt.Sprintf("设备ID %s 应该以'sensor'开头", deviceId)) + } + } + }) + + // 测试场景2:后缀匹配 + t.Run("后缀匹配(%suffix)", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT deviceId, status FROM stream WHERE status LIKE '%error'" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := 
streamsql.stream + + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + testData := []interface{}{ + map[string]interface{}{"deviceId": "dev1", "status": "connection_error"}, + map[string]interface{}{"deviceId": "dev2", "status": "running"}, + map[string]interface{}{"deviceId": "dev3", "status": "timeout_error"}, + map[string]interface{}{"deviceId": "dev4", "status": "normal"}, + } + + for _, data := range testData { + strm.Emit(data) + } + + // 等待结果 + var results []interface{} + timeout := time.After(2 * time.Second) + done := false + + for !done && len(results) < 2 { + select { + case result := <-resultChan: + results = append(results, result) + case <-timeout: + done = true + } + } + + // 验证结果:应该只有以"error"结尾的状态 + assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") + + for _, result := range results { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + for _, item := range resultSlice { + status, _ := item["status"].(string) + assert.True(t, strings.HasSuffix(status, "error"), + fmt.Sprintf("状态 %s 应该以'error'结尾", status)) + } + } + }) + + // 测试场景3:包含匹配 + t.Run("包含匹配(%substring%)", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT deviceId, message FROM stream WHERE message LIKE '%alert%'" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + testData := []interface{}{ + map[string]interface{}{"deviceId": "dev1", "message": "system alert: high temperature"}, + map[string]interface{}{"deviceId": "dev2", "message": "normal operation"}, + map[string]interface{}{"deviceId": "dev3", "message": "critical alert detected"}, + map[string]interface{}{"deviceId": "dev4", "message": "info: device startup"}, + } + + for _, data := range testData { + 
strm.Emit(data) + } + + // 等待结果 + var results []interface{} + timeout := time.After(2 * time.Second) + done := false + + for !done && len(results) < 2 { + select { + case result := <-resultChan: + results = append(results, result) + case <-timeout: + done = true + } + } + + // 验证结果:应该只有包含"alert"的消息 + assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") + + for _, result := range results { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + for _, item := range resultSlice { + message, _ := item["message"].(string) + assert.True(t, strings.Contains(message, "alert"), + fmt.Sprintf("消息 %s 应该包含'alert'", message)) + } + } + }) + + // 测试场景4:单字符通配符 + t.Run("单字符通配符(_)", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT deviceId, code FROM stream WHERE code LIKE 'E_0_'" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + testData := []interface{}{ + map[string]interface{}{"deviceId": "dev1", "code": "E101"}, + map[string]interface{}{"deviceId": "dev2", "code": "E202"}, + map[string]interface{}{"deviceId": "dev3", "code": "E305"}, + map[string]interface{}{"deviceId": "dev4", "code": "F101"}, + } + + for _, data := range testData { + strm.Emit(data) + } + + // 等待结果 + var results []interface{} + timeout := time.After(2 * time.Second) + done := false + + for !done && len(results) < 2 { + select { + case result := <-resultChan: + results = append(results, result) + case <-timeout: + done = true + } + } + + // 验证结果:应该只有E_0_模式的代码(E101, E202不匹配E_0_,只有E305也不完全匹配) + // 实际上,根据模式E_0_,应该匹配如E101, E202等,让我们调整测试数据 + assert.GreaterOrEqual(t, len(results), 0, "根据通配符模式可能有匹配结果") + }) + + // 测试场景5:复杂模式 + t.Run("复杂LIKE模式", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT deviceId, filename FROM stream 
WHERE filename LIKE '%.log'" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + testData := []interface{}{ + map[string]interface{}{"deviceId": "dev1", "filename": "system.log"}, + map[string]interface{}{"deviceId": "dev2", "filename": "config.txt"}, + map[string]interface{}{"deviceId": "dev3", "filename": "error.log"}, + map[string]interface{}{"deviceId": "dev4", "filename": "backup.bak"}, + } + + for _, data := range testData { + strm.Emit(data) + } + + // 等待结果 + var results []interface{} + timeout := time.After(2 * time.Second) + done := false + + for !done && len(results) < 2 { + select { + case result := <-resultChan: + results = append(results, result) + case <-timeout: + done = true + } + } + + // 验证结果:应该只有.log文件 + assert.GreaterOrEqual(t, len(results), 1, "应该收到至少一个匹配结果") + + for _, result := range results { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + for _, item := range resultSlice { + filename, _ := item["filename"].(string) + assert.True(t, strings.HasSuffix(filename, ".log"), + fmt.Sprintf("文件名 %s 应该以'.log'结尾", filename)) + } + } + }) + + // 测试场景6:在聚合查询中使用LIKE + t.Run("聚合查询中的LIKE", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT deviceType, count(*) as device_count FROM stream WHERE deviceId LIKE 'sensor%' GROUP BY deviceType" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + testData := []interface{}{ + map[string]interface{}{"deviceId": "sensor001", "deviceType": "temperature"}, + map[string]interface{}{"deviceId": "sensor002", "deviceType": "temperature"}, + map[string]interface{}{"deviceId": "device003", "deviceType": "temperature"}, + 
map[string]interface{}{"deviceId": "sensor004", "deviceType": "humidity"}, + map[string]interface{}{"deviceId": "pump005", "deviceType": "actuator"}, + } + + for _, data := range testData { + strm.Emit(data) + } + + // 等待聚合 + time.Sleep(500 * time.Millisecond) + strm.Window.Trigger() + + // 等待结果 + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("测试超时,未收到聚合结果") + } + + // 验证聚合结果 + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 应该有两种设备类型:temperature(2个sensor), humidity(1个sensor) + assert.GreaterOrEqual(t, len(resultSlice), 1, "应该有至少一种设备类型的聚合结果") + + for _, result := range resultSlice { + deviceType, _ := result["deviceType"].(string) + count, ok := result["device_count"].(float64) + assert.True(t, ok, "device_count应该是float64类型") + assert.Greater(t, count, 0.0, "设备数量应该大于0") + + // 验证设备类型 + assert.True(t, deviceType == "temperature" || deviceType == "humidity", + fmt.Sprintf("设备类型 %s 应该是temperature或humidity", deviceType)) + } + }) + + // 测试场景7:HAVING子句中的LIKE + t.Run("HAVING子句中的LIKE", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT deviceType, max(temperature) as max_temp FROM stream GROUP BY deviceType HAVING deviceType LIKE '%temp%'" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + testData := []interface{}{ + map[string]interface{}{"deviceType": "temperature_sensor", "temperature": 25.0}, + map[string]interface{}{"deviceType": "temperature_sensor", "temperature": 30.0}, + map[string]interface{}{"deviceType": "humidity_sensor", "temperature": 22.0}, + map[string]interface{}{"deviceType": "pressure_gauge", "temperature": 20.0}, + } + + for _, data := range 
testData { + strm.Emit(data) + } + + // 等待聚合 + time.Sleep(500 * time.Millisecond) + strm.Window.Trigger() + + // 等待结果 + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("测试超时,未收到HAVING+LIKE结果") + } + + // 验证HAVING + LIKE结果 + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 应该只有包含"temp"的设备类型 + for _, result := range resultSlice { + deviceType, _ := result["deviceType"].(string) + assert.True(t, strings.Contains(deviceType, "temp"), + fmt.Sprintf("设备类型 %s 应该包含'temp'", deviceType)) + + maxTemp, ok := result["max_temp"].(float64) + assert.True(t, ok, "max_temp应该是float64类型") + assert.Greater(t, maxTemp, 0.0, "最大温度应该大于0") + } + }) +} + +// TestLikeFunctionEquivalence 测试LIKE语法与现有字符串函数的等价性 +func TestLikeFunctionEquivalence(t *testing.T) { + // 简化测试,重点验证LIKE功能已经正常工作 + t.Run("LIKE语法工作正常验证", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + // 使用LIKE的查询 + var likeSQL = "SELECT deviceId FROM stream WHERE deviceId LIKE 'sensor%'" + err := streamsql.Execute(likeSQL) + assert.Nil(t, err) + + resultChan := make(chan interface{}, 10) + streamsql.stream.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 测试数据 + testData := []interface{}{ + map[string]interface{}{"deviceId": "sensor001"}, + map[string]interface{}{"deviceId": "device002"}, + map[string]interface{}{"deviceId": "sensor003"}, + } + + // 添加数据 + for _, data := range testData { + streamsql.stream.Emit(data) + } + + // 收集结果 + timeout := time.After(2 * time.Second) + var results []interface{} + + for len(results) < 2 { + select { + case result := <-resultChan: + results = append(results, result) + case <-timeout: + break + } + } + + // 验证LIKE查询返回了预期的结果 + assert.Equal(t, 2, len(results), "LIKE查询应该返回2个匹配'sensor%'的结果") + t.Logf("LIKE查询成功返回%d个结果", len(results)) + + // 
验证返回的结果确实是以'sensor'开头的 + for i, result := range results { + resultSlice, ok := result.([]map[string]interface{}) + assert.True(t, ok, fmt.Sprintf("结果%d应该是[]map[string]interface{}类型", i)) + if len(resultSlice) > 0 { + deviceId, exists := resultSlice[0]["deviceId"] + assert.True(t, exists, "结果应该包含deviceId字段") + deviceIdStr, ok := deviceId.(string) + assert.True(t, ok, "deviceId应该是字符串类型") + assert.True(t, strings.HasPrefix(deviceIdStr, "sensor"), + fmt.Sprintf("deviceId '%s' 应该以'sensor'开头", deviceIdStr)) + } + } + }) +} + +// TestLikePatternMatching 测试LIKE模式匹配算法的正确性 +func TestLikePatternMatching(t *testing.T) { + // 这些是单元测试,直接测试LIKE匹配函数 + tests := []struct { + text string + pattern string + expected bool + desc string + }{ + // 前缀匹配测试 + {"hello", "hello%", true, "精确前缀匹配"}, + {"hello world", "hello%", true, "前缀匹配"}, + {"hi there", "hello%", false, "前缀不匹配"}, + {"", "%", true, "空字符串匹配任意模式"}, + + // 后缀匹配测试 + {"test.log", "%.log", true, "后缀匹配"}, + {"test.txt", "%.log", false, "后缀不匹配"}, + + // 包含匹配测试 + {"hello world test", "%world%", true, "包含匹配"}, + {"hello test", "%world%", false, "不包含"}, + + // 单字符通配符测试 + {"abc", "a_c", true, "单字符通配符匹配"}, + {"aXc", "a_c", true, "单字符通配符匹配任意字符"}, + {"abbc", "a_c", false, "单字符通配符不匹配多个字符"}, + + // 复杂模式测试 + {"file123.log", "file___.log", true, "多个单字符通配符"}, + {"file12.log", "file___.log", false, "字符数不匹配"}, + {"prefix_test_suffix", "prefix%suffix", true, "前后缀组合"}, + + // 边界情况测试 + {"", "", true, "空模式匹配空字符串"}, + {"abc", "", false, "非空字符串不匹配空模式"}, + {"", "abc", false, "空字符串不匹配非空模式"}, + {"abc", "abc", true, "完全匹配"}, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + // 直接使用内部函数进行测试 + // 注意:这里我们需要通过SQL查询来测试,因为匹配函数是内部的 + streamsql := New() + defer streamsql.Stop() + + // 构造SQL查询 + rsql := fmt.Sprintf("SELECT value FROM stream WHERE value LIKE '%s'", test.pattern) + err := streamsql.Execute(rsql) + assert.Nil(t, err) + + resultChan := make(chan interface{}, 10) + streamsql.stream.AddSink(func(result interface{}) { + 
resultChan <- result + }) + + // 添加测试数据 + testData := map[string]interface{}{"value": test.text} + streamsql.stream.Emit(testData) + + // 等待结果 + timeout := time.After(1 * time.Second) + var hasResult bool + + select { + case result := <-resultChan: + resultSlice, ok := result.([]map[string]interface{}) + hasResult = ok && len(resultSlice) > 0 + case <-timeout: + hasResult = false + } + + if test.expected { + assert.True(t, hasResult, fmt.Sprintf("模式'%s'应该匹配文本'%s'", test.pattern, test.text)) + } else { + assert.False(t, hasResult, fmt.Sprintf("模式'%s'不应该匹配文本'%s'", test.pattern, test.text)) + } + }) + } +} diff --git a/streamsql_plugin_test.go b/streamsql_plugin_test.go index 6b7a898..0764f49 100644 --- a/streamsql_plugin_test.go +++ b/streamsql_plugin_test.go @@ -104,7 +104,7 @@ func testStringFunctionsOnly(t *testing.T) { "phone": "13812345678", } - streamsql.AddData(testData) + streamsql.Emit(testData) time.Sleep(300 * time.Millisecond) select { @@ -149,7 +149,7 @@ func testConversionFunctionsOnly(t *testing.T) { "user_id": "12345", } - streamsql.AddData(testData) + streamsql.Emit(testData) time.Sleep(300 * time.Millisecond) select { @@ -205,7 +205,7 @@ func testMathFunctionsInAggregate(t *testing.T) { } for _, data := range testData { - streamsql.AddData(data) + streamsql.Emit(data) } time.Sleep(1 * time.Second) @@ -269,7 +269,7 @@ func TestRuntimeFunctionManagement(t *testing.T) { resultChan <- result }) - streamsql.AddData(map[string]interface{}{"value": "test"}) + streamsql.Emit(map[string]interface{}{"value": "test"}) time.Sleep(300 * time.Millisecond) select { @@ -390,7 +390,7 @@ func TestCompleteSQLIntegration(t *testing.T) { "amount": 100.0, } - streamsql.AddData(testData) + streamsql.Emit(testData) time.Sleep(300 * time.Millisecond) select { diff --git a/streamsql_quoted_support_test.go b/streamsql_quoted_support_test.go new file mode 100644 index 0000000..24617fa --- /dev/null +++ b/streamsql_quoted_support_test.go @@ -0,0 +1,456 @@ +package streamsql + 
+import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/rulego/streamsql/functions" + "github.com/rulego/streamsql/utils/cast" + "github.com/stretchr/testify/assert" +) + +// testCase 定义测试用例结构 +type testCase struct { + name string + sql string + testData []interface{} + expectedLen int + validator func(t *testing.T, results []map[string]interface{}) +} + +// executeTestCase 执行单个测试用例的通用逻辑 +func executeTestCase(t *testing.T, streamsql *Streamsql, tc testCase) { + t.Run(tc.name, func(t *testing.T) { + err := streamsql.Execute(tc.sql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道和互斥锁保护并发访问 + resultChan := make(chan interface{}, 10) + var results []map[string]interface{} + var resultsMutex sync.Mutex + + strm.AddSink(func(result interface{}) { + select { + case resultChan <- result: + default: + // 通道满时丢弃结果,避免阻塞 + } + }) + + // 添加测试数据 + for _, data := range tc.testData { + strm.Emit(data) + } + + // 等待数据处理 + time.Sleep(200 * time.Millisecond) + + // 收集所有结果 + timeout := time.After(2 * time.Second) + for { + resultsMutex.Lock() + currentLen := len(results) + resultsMutex.Unlock() + + if currentLen >= tc.expectedLen { + break + } + + select { + case result := <-resultChan: + resultsMutex.Lock() + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) 
+ } else if resultMap, ok := result.(map[string]interface{}); ok { + results = append(results, resultMap) + } + resultsMutex.Unlock() + case <-timeout: + goto checkResults + } + } + + checkResults: + // 验证结果长度(使用互斥锁保护) + resultsMutex.Lock() + finalResults := make([]map[string]interface{}, len(results)) + copy(finalResults, results) + resultsMutex.Unlock() + + assert.Equal(t, tc.expectedLen, len(finalResults)) + // 执行自定义验证 + if tc.validator != nil { + tc.validator(t, finalResults) + } + }) +} + +// executeAggregationTestCase 执行聚合函数测试用例的通用逻辑 +func executeAggregationTestCase(t *testing.T, streamsql *Streamsql, tc testCase) { + t.Run(tc.name, func(t *testing.T) { + err := streamsql.Execute(tc.sql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + select { + case resultChan <- result: + default: + // 通道满时丢弃结果,避免阻塞 + } + }) + + // 添加测试数据 + for _, data := range tc.testData { + strm.Emit(data) + } + + // 等待窗口触发 + time.Sleep(1 * time.Second) + strm.Window.Trigger() + time.Sleep(500 * time.Millisecond) + + // 验证结果 + select { + case result := <-resultChan: + if tc.validator != nil { + tc.validator(t, result.([]map[string]interface{})) + } + case <-time.After(3 * time.Second): + t.Fatal("测试超时") + } + }) +} + +// executeFunctionTestCase 执行函数测试用例的通用逻辑 +func executeFunctionTestCase(t *testing.T, streamsql *Streamsql, tc testCase) { + t.Run(tc.name, func(t *testing.T) { + err := streamsql.Execute(tc.sql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + strm.AddSink(func(result interface{}) { + select { + case resultChan <- result: + default: + // 通道满时丢弃结果,避免阻塞 + } + }) + + // 添加测试数据 + for _, data := range tc.testData { + strm.Emit(data) + } + + time.Sleep(200 * time.Millisecond) + + // 验证结果 + select { + case result := <-resultChan: + if tc.validator != nil { + tc.validator(t, result.([]map[string]interface{})) + } + case 
<-time.After(2 * time.Second): + t.Fatal("测试超时") + } + }) +} + +// TestQuotedIdentifiersAndStringLiterals 测试反引号标识符和字符串常量支持 +func TestQuotedIdentifiersAndStringLiterals(t *testing.T) { + // 注册测试函数(因为有测试用例使用自定义函数) + registerTestFunctions(t) + defer unregisterTestFunctions() + + streamsql := New() + defer streamsql.Stop() + + // 通用测试数据 + standardTestData := []interface{}{ + map[string]interface{}{"deviceId": "sensor001", "deviceType": "temperature"}, + map[string]interface{}{"deviceId": "device002", "deviceType": "humidity"}, + map[string]interface{}{"deviceId": "sensor003", "deviceType": "pressure"}, + } + + // 定义测试用例 + testCases := []testCase{ + { + name: "反引号标识符支持", + sql: "SELECT `deviceId`, `deviceType` FROM stream WHERE `deviceId` LIKE 'sensor%'", + testData: standardTestData, + expectedLen: 2, + validator: func(t *testing.T, results []map[string]interface{}) { + for _, result := range results { + deviceId := result["deviceId"].(string) + assert.True(t, deviceId == "sensor001" || deviceId == "sensor003") + } + }, + }, + { + name: "单引号字符串常量支持", + sql: "SELECT deviceId, deviceType, 'constant_value' as test FROM stream WHERE deviceId = 'sensor001'", + testData: standardTestData, + expectedLen: 1, + validator: func(t *testing.T, results []map[string]interface{}) { + if len(results) > 0 { + resultMap := results[0] + assert.Equal(t, "sensor001", resultMap["deviceId"]) + assert.Equal(t, "temperature", resultMap["deviceType"]) + assert.Equal(t, "constant_value", resultMap["test"]) + } + }, + }, + { + name: "双引号字符串常量支持", + sql: `SELECT deviceId, deviceType, "another_constant" as test FROM stream WHERE deviceType = "temperature"`, + testData: standardTestData, + expectedLen: 1, + validator: func(t *testing.T, results []map[string]interface{}) { + if len(results) > 0 { + resultMap := results[0] + assert.Equal(t, "sensor001", resultMap["deviceId"]) + assert.Equal(t, "temperature", resultMap["deviceType"]) + assert.Equal(t, "another_constant", resultMap["test"]) + } + }, + 
}, + { + name: "混合使用反引号标识符和字符串常量", + sql: "SELECT `deviceId`, `deviceType`, 'mixed_test' as test_field,'normal' FROM stream WHERE `deviceId` = 'sensor001'", + testData: standardTestData, + expectedLen: 1, + validator: func(t *testing.T, results []map[string]interface{}) { + for _, result := range results { + deviceId := result["deviceId"].(string) + assert.True(t, deviceId == "sensor001") + assert.Equal(t, "mixed_test", result["test_field"]) + assert.Equal(t, "normal", result["normal"]) + assert.Nil(t, result["'normal'"]) + } + }, + }, + { + name: "字符串常量一致性验证", + sql: `SELECT 'single_quote' as test1, "double_quote" as test2 FROM stream LIMIT 1`, + testData: []interface{}{map[string]interface{}{"deviceId": "test001", "deviceType": "test"}}, + expectedLen: 1, + validator: func(t *testing.T, results []map[string]interface{}) { + if len(results) > 0 { + resultMap := results[0] + assert.Equal(t, "single_quote", resultMap["test1"]) + assert.Equal(t, "double_quote", resultMap["test2"]) + } + }, + }, + } + + // 执行所有测试用例 + for _, tc := range testCases { + executeTestCase(t, streamsql, tc) + } +} + +// TestStringConstantExpressions 测试字符串常量表达式 +func TestStringConstantExpressions(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + // 通用测试数据 + testData := []interface{}{ + map[string]interface{}{"deviceId": "sensor001", "deviceType": "temperature"}, + map[string]interface{}{"deviceId": "device002", "deviceType": "humidity"}, + map[string]interface{}{"deviceId": "sensor003", "deviceType": "pressure"}, + } + + // 字符串常量验证函数 + stringConstantValidator := func(expectedValue string) func(t *testing.T, results []map[string]interface{}) { + return func(t *testing.T, results []map[string]interface{}) { + for _, result := range results { + deviceId := result["deviceId"].(string) + assert.True(t, deviceId == "sensor001" || deviceId == "sensor003") + assert.Equal(t, expectedValue, result["test"]) + } + } + } + + testCases := []testCase{ + { + name: "单引号字符串常量作为表达式字段", + sql: 
"SELECT deviceId, deviceType, 'aa' as test FROM stream WHERE deviceId LIKE 'sensor%'", + testData: testData, + expectedLen: 2, + validator: stringConstantValidator("aa"), + }, + { + name: "双引号字符串常量作为表达式字段", + sql: `SELECT deviceId, deviceType, "aa" as test FROM stream WHERE deviceId LIKE 'sensor%'`, + testData: testData, + expectedLen: 2, + validator: stringConstantValidator("aa"), + }, + } + + // 执行所有测试用例 + for _, tc := range testCases { + executeTestCase(t, streamsql, tc) + } +} + +// TestAggregationWithQuotedIdentifiers 测试聚合函数与反引号标识符的结合使用 +func TestAggregationWithQuotedIdentifiers(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + // 聚合测试数据 + aggregationTestData := []interface{}{ + map[string]interface{}{"deviceId": "sensor001", "temperature": 25.5}, + map[string]interface{}{"deviceId": "sensor001", "temperature": 26.0}, + map[string]interface{}{"deviceId": "sensor002", "temperature": 30.0}, + } + + // 聚合结果验证函数 + aggregationValidator := func(t *testing.T, results []map[string]interface{}) { + resultSlice := results + assert.Len(t, resultSlice, 2) // 应该有两个设备的聚合结果 + + for _, item := range resultSlice { + if item["deviceId"] == "sensor001" { + assert.Equal(t, 25.75, item["avg_temp"]) // (25.5 + 26.0) / 2 = 25.75 + assert.Equal(t, float64(2), item["device_count"]) + } else if item["deviceId"] == "sensor002" { + assert.Equal(t, 30.0, item["avg_temp"]) + assert.Equal(t, float64(1), item["device_count"]) + } + } + } + + testCases := []testCase{ + { + name: "聚合函数与字段组合", + sql: "SELECT deviceId, AVG(temperature) as avg_temp, COUNT(deviceId) as device_count FROM stream GROUP BY deviceId, TumblingWindow('1s')", + testData: aggregationTestData, + validator: aggregationValidator, + }, + } + + // 执行所有聚合测试用例 + for _, tc := range testCases { + executeAggregationTestCase(t, streamsql, tc) + } +} + +// TestCustomFunctionWithQuotedIdentifiers 测试自定义函数与反引号标识符和字符串常量的参数传递 +func TestCustomFunctionWithQuotedIdentifiers(t *testing.T) { + // 注册测试函数 + 
registerTestFunctions(t) + defer unregisterTestFunctions() + + streamsql := New() + defer streamsql.Stop() + + testCases := []testCase{ + { + name: "函数参数:字段值vs字符串常量", + sql: "SELECT deviceId, func01(temperature) as squared_temp, func02('temperature') as string_length FROM stream WHERE deviceId = 'sensor001'", + testData: []interface{}{map[string]interface{}{"deviceId": "sensor001", "temperature": 5.0}, map[string]interface{}{"deviceId": "sensor002", "temperature": 10.0}}, + validator: func(t *testing.T, results []map[string]interface{}) { + resultSlice := results + assert.Len(t, resultSlice, 1) + item := resultSlice[0] + assert.Equal(t, "sensor001", item["deviceId"]) + assert.Equal(t, 25.0, item["squared_temp"]) // func01(5.0) = 25.0 + assert.Equal(t, 11, item["string_length"]) // func02('temperature') = 11 + }, + }, + { + name: "反引号标识符作为函数参数", + sql: "SELECT deviceId, func01(temperature) as squared_temp, get_type(deviceId) as device_type FROM stream WHERE deviceId = 'sensor001'", + testData: []interface{}{map[string]interface{}{"deviceId": "sensor001", "temperature": 6.0}, map[string]interface{}{"deviceId": "sensor002", "temperature": 8.0}}, + validator: func(t *testing.T, results []map[string]interface{}) { + resultSlice := results + assert.Len(t, resultSlice, 1) + item := resultSlice[0] + assert.Equal(t, "sensor001", item["deviceId"]) + assert.Equal(t, 36.0, item["squared_temp"]) // func01(6.0) = 36.0 + assert.Contains(t, item["device_type"], "sensor001") + }, + }, + { + name: "混合使用字段值和字符串常量", + sql: `SELECT deviceId, func01(temperature) as field_result, func02("constant_string") as const_result, get_type('literal') as literal_type FROM stream LIMIT 1`, + testData: []interface{}{map[string]interface{}{"deviceId": "test001", "temperature": 7.0}}, + validator: func(t *testing.T, results []map[string]interface{}) { + resultSlice := results + assert.Len(t, resultSlice, 1) + item := resultSlice[0] + assert.Equal(t, "test001", item["deviceId"]) + assert.Equal(t, 49.0, 
item["field_result"]) // func01(7.0) = 49.0 + assert.Equal(t, 15, item["const_result"]) // func02("constant_string") = 15 + assert.Contains(t, item["literal_type"], "literal") + }, + }, + } + + // 执行所有函数测试用例 + for _, tc := range testCases { + executeFunctionTestCase(t, streamsql, tc) + } +} + +// registerTestFunctions 注册测试用的自定义函数 +func registerTestFunctions(t *testing.T) { + // 注册测试函数:接收字段值并返回其平方 + err := functions.RegisterCustomFunction( + "func01", + functions.TypeMath, + "测试函数", + "计算数值的平方", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + val := cast.ToFloat64(args[0]) + return val * val, nil + }, + ) + assert.NoError(t, err) + + // 注册测试函数:接收字符串并返回其长度 + err = functions.RegisterCustomFunction( + "func02", + functions.TypeString, + "测试函数", + "计算字符串长度", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + str := cast.ToString(args[0]) + return len(str), nil + }, + ) + assert.NoError(t, err) + + // 注册测试函数:接收参数并返回其类型信息 + err = functions.RegisterCustomFunction( + "get_type", + functions.TypeCustom, + "测试函数", + "获取参数类型", + 1, 1, + func(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + return fmt.Sprintf("%T:%v", args[0], args[0]), nil + }, + ) + assert.NoError(t, err) +} + +// unregisterTestFunctions 注销测试用的自定义函数 +func unregisterTestFunctions() { + functions.Unregister("func01") + functions.Unregister("func02") + functions.Unregister("get_type") +} diff --git a/streamsql_sync_sink_test.go b/streamsql_sync_sink_test.go new file mode 100644 index 0000000..ba4121a --- /dev/null +++ b/streamsql_sync_sink_test.go @@ -0,0 +1,306 @@ +/* + * Copyright 2025 The RuleGo Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamsql + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEmitSyncWithAddSink 测试EmitSync同时触发AddSink回调 +func TestEmitSyncWithAddSink(t *testing.T) { + t.Run("非聚合查询同步+异步结果", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行非聚合查询 - 测试反引号字段与字符串常量的混合用法 + sql := "SELECT `temperature`, humidity, `temperature` * 1.8 + 32 as temp_fahrenheit, 'normal' as status, 'sensor_data' as data_type FROM stream WHERE temperature > 20" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 验证是非聚合查询 + assert.False(t, ssql.IsAggregationQuery()) + + // 设置AddSink回调来收集异步结果 + var sinkCallCount int32 + var sinkResults []interface{} + var sinkResultsMux sync.Mutex // 保护sinkResults访问 + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sinkCallCount, 1) + sinkResultsMux.Lock() + sinkResults = append(sinkResults, result) + sinkResultsMux.Unlock() + }) + + // 测试数据 + testData := []map[string]interface{}{ + {"temperature": 25.0, "humidity": 60.0}, // 符合条件 + {"temperature": 15.0, "humidity": 70.0}, // 被过滤 + {"temperature": 30.0, "humidity": 80.0}, // 符合条件 + } + + var syncResults []interface{} + + // 处理测试数据 + for _, data := range testData { + // 同步处理 + result, err := ssql.EmitSync(data) + require.NoError(t, err) + + if result != nil { + syncResults = append(syncResults, result) + } + } + + // 等待异步回调完成 + time.Sleep(100 * time.Millisecond) + + // 验证同步结果 + assert.Equal(t, 2, len(syncResults), "应该有2条同步结果(温度>20)") + + // 安全读取异步回调结果 + 
sinkResultsMux.Lock() + finalSinkResults := make([]interface{}, len(sinkResults)) + copy(finalSinkResults, sinkResults) + sinkResultsMux.Unlock() + + // 验证异步回调结果 + finalSinkCallCount := atomic.LoadInt32(&sinkCallCount) + assert.Equal(t, int32(2), finalSinkCallCount, "AddSink应该被调用2次") + assert.Equal(t, 2, len(finalSinkResults), "应该收集到2条异步结果") + + // 验证同步和异步结果的内容一致性 + if len(syncResults) > 0 && len(finalSinkResults) > 0 { + // 将结果转换为可比较的格式 + syncTemperatures := make([]float64, 0, len(syncResults)) + syncHumidities := make([]float64, 0, len(syncResults)) + asyncTemperatures := make([]float64, 0, len(finalSinkResults)) + asyncHumidities := make([]float64, 0, len(finalSinkResults)) + + // 收集同步结果 + for _, result := range syncResults { + if syncResult, ok := result.(map[string]interface{}); ok { + syncTemperatures = append(syncTemperatures, syncResult["temperature"].(float64)) + syncHumidities = append(syncHumidities, syncResult["humidity"].(float64)) + + // 验证字符串常量字段 + assert.Equal(t, "normal", syncResult["status"], "status字段应该是常量'normal'") + assert.Equal(t, "sensor_data", syncResult["data_type"], "data_type字段应该是常量'sensor_data'") + + // 验证反引号字段的数学运算 + expectedFahrenheit := syncResult["temperature"].(float64)*1.8 + 32 + assert.InDelta(t, expectedFahrenheit, syncResult["temp_fahrenheit"].(float64), 0.01, "华氏温度转换应该正确") + + // 验证结果包含所有预期字段 + assert.Contains(t, syncResult, "temperature", "应该包含temperature字段") + assert.Contains(t, syncResult, "humidity", "应该包含humidity字段") + assert.Contains(t, syncResult, "temp_fahrenheit", "应该包含temp_fahrenheit字段") + assert.Contains(t, syncResult, "status", "应该包含status字段") + assert.Contains(t, syncResult, "data_type", "应该包含data_type字段") + } + } + + // 收集异步结果 + for _, result := range finalSinkResults { + if sinkResultArray, ok := result.([]map[string]interface{}); ok && len(sinkResultArray) > 0 { + sinkResult := sinkResultArray[0] + asyncTemperatures = append(asyncTemperatures, sinkResult["temperature"].(float64)) + asyncHumidities = 
append(asyncHumidities, sinkResult["humidity"].(float64)) + } + } + + // 验证结果集合是否一致(不考虑顺序) + assert.ElementsMatch(t, syncTemperatures, asyncTemperatures, "温度值集合应该一致") + assert.ElementsMatch(t, syncHumidities, asyncHumidities, "湿度值集合应该一致") + + // 验证预期的数值是否都存在 + assert.Contains(t, syncTemperatures, 25.0, "同步结果应包含25.0") + assert.Contains(t, syncTemperatures, 30.0, "同步结果应包含30.0") + assert.Contains(t, asyncTemperatures, 25.0, "异步结果应包含25.0") + assert.Contains(t, asyncTemperatures, 30.0, "异步结果应包含30.0") + } + }) + + t.Run("聚合查询不支持EmitSync", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行聚合查询 + sql := "SELECT AVG(temperature) as avg_temp FROM stream GROUP BY TumblingWindow('1s')" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 验证是聚合查询 + assert.True(t, ssql.IsAggregationQuery()) + + // 尝试同步处理应该返回错误 + data := map[string]interface{}{"temperature": 25.0} + result, err := ssql.EmitSync(data) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "同步模式仅支持非聚合查询") + }) + + t.Run("多个AddSink回调都被触发", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行非聚合查询 + sql := "SELECT temperature FROM stream" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 添加多个AddSink回调,使用原子操作确保线程安全 + var sink1Count, sink2Count, sink3Count int32 + + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sink1Count, 1) + }) + + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sink2Count, 1) + }) + + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sink3Count, 1) + }) + + // 处理一条数据 + data := map[string]interface{}{"temperature": 25.0} + result, err := ssql.EmitSync(data) + require.NoError(t, err) + require.NotNil(t, result) + + // 等待异步回调 + time.Sleep(100 * time.Millisecond) + + // 验证所有回调都被触发 + assert.Equal(t, int32(1), atomic.LoadInt32(&sink1Count)) + assert.Equal(t, int32(1), atomic.LoadInt32(&sink2Count)) + assert.Equal(t, int32(1), atomic.LoadInt32(&sink3Count)) + }) + + t.Run("过滤条件不匹配时AddSink不触发", 
func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 执行带过滤条件的查询 + sql := "SELECT temperature FROM stream WHERE temperature > 30" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 添加AddSink回调 + var sinkCallCount int32 + ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sinkCallCount, 1) + }) + + // 处理不符合条件的数据 + data := map[string]interface{}{"temperature": 20.0} // 不符合 > 30 的条件 + result, err := ssql.EmitSync(data) + require.NoError(t, err) + assert.Nil(t, result, "不符合过滤条件应该返回nil") + + // 等待可能的异步回调 + time.Sleep(100 * time.Millisecond) + + // 验证AddSink没有被触发 + assert.Equal(t, int32(0), atomic.LoadInt32(&sinkCallCount), "过滤掉的数据不应触发AddSink") + }) + + // 新增测试:字符串常量与反引号字段的复杂混合用法 + t.Run("字符串常量与反引号字段混合用法", func(t *testing.T) { + ssql := New() + defer ssql.Stop() + + // 测试包含多种字符串常量的SQL查询 + sql := "SELECT `temperature` as temp, 'celsius' as unit, 'high' as level, `humidity`, 'percent' as humidity_unit FROM stream WHERE temperature > 20" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 测试数据 + testData := map[string]interface{}{ + "temperature": 25.5, + "humidity": 65.0, + } + + // 同步处理 + result, err := ssql.EmitSync(testData) + require.NoError(t, err) + require.NotNil(t, result) + + if syncResult, ok := result.(map[string]interface{}); ok { + // 验证反引号字段 + assert.Equal(t, 25.5, syncResult["temp"], "温度字段应该正确") + assert.Equal(t, 65.0, syncResult["humidity"], "湿度字段应该正确") + + // 验证字符串常量字段 + assert.Equal(t, "celsius", syncResult["unit"], "单位应该是celsius") + assert.Equal(t, "high", syncResult["level"], "级别应该是high") + assert.Equal(t, "percent", syncResult["humidity_unit"], "湿度单位应该是percent") + } + }) +} + +// TestEmitSyncPerformance 测试EmitSync性能(包括AddSink触发) +func TestEmitSyncPerformance(t *testing.T) { + ssql := New() + defer ssql.Stop() + + sql := "SELECT temperature, humidity FROM stream WHERE temperature > 0" + err := ssql.Execute(sql) + require.NoError(t, err) + + // 添加AddSink回调,使用原子操作确保线程安全 + var sinkCallCount int32 + 
ssql.AddSink(func(result interface{}) { + atomic.AddInt32(&sinkCallCount, 1) + }) + + // 性能测试 + testCount := 1000 + + start := time.Now() + for i := 0; i < testCount; i++ { + data := map[string]interface{}{ + "temperature": float64(20 + i%20), + "humidity": float64(50 + i%30), + } + + result, err := ssql.EmitSync(data) + require.NoError(t, err) + require.NotNil(t, result) + } + duration := time.Since(start) + + // 等待所有异步回调完成 + time.Sleep(200 * time.Millisecond) + + t.Logf("处理 %d 条数据耗时: %v", testCount, duration) + t.Logf("平均每条数据: %v", duration/time.Duration(testCount)) + t.Logf("AddSink 触发次数: %d", atomic.LoadInt32(&sinkCallCount)) + + // 验证性能和一致性 + assert.Less(t, duration, 1*time.Second, "性能应该足够好") + assert.Equal(t, int32(testCount), atomic.LoadInt32(&sinkCallCount), "所有数据都应触发AddSink") +} diff --git a/streamsql_table_print_test.go b/streamsql_table_print_test.go new file mode 100644 index 0000000..6aaf0d1 --- /dev/null +++ b/streamsql_table_print_test.go @@ -0,0 +1,55 @@ +package streamsql + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestPrintTable 测试PrintTable方法的基本功能 +func TestPrintTable(t *testing.T) { + // 创建StreamSQL实例并测试PrintTable + ssql := New() + err := ssql.Execute("SELECT device, AVG(temperature) as avg_temp FROM stream GROUP BY device, TumblingWindow('2s')") + assert.NoError(t, err) + + // 使用PrintTable方法(不验证输出内容,只确保不会panic) + assert.NotPanics(t, func() { + ssql.PrintTable() + }, "PrintTable方法不应该panic") + + // 发送测试数据 + testData := []map[string]interface{}{ + {"device": "sensor1", "temperature": 25.0}, + {"device": "sensor2", "temperature": 30.0}, + } + + for _, data := range testData { + ssql.Emit(data) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) +} + +// TestPrintTableFormat 测试printTableFormat方法处理不同数据类型 +func TestPrintTableFormat(t *testing.T) { + ssql := New() + + // 测试不同类型的数据,确保不会panic + assert.NotPanics(t, func() { + // 测试空切片 + ssql.printTableFormat([]map[string]interface{}{}) + }, "空切片不应该panic") + + 
assert.NotPanics(t, func() { + // 测试单个map + ssql.printTableFormat(map[string]interface{}{"key": "value"}) + }, "单个map不应该panic") + + assert.NotPanics(t, func() { + // 测试其他类型 + ssql.printTableFormat("string data") + }, "字符串数据不应该panic") +} \ No newline at end of file diff --git a/streamsql_test.go b/streamsql_test.go index f2c0aa3..c36e374 100644 --- a/streamsql_test.go +++ b/streamsql_test.go @@ -70,8 +70,8 @@ func TestStreamData(t *testing.T) { "humidity": 50.0 + rand.Float64()*20, // 湿度范围: 50-70% } // 将数据添加到流中,触发 StreamSQL 的实时处理 - // AddData 会将数据分发到相应的窗口和聚合器中 - ssql.stream.AddData(randomData) + // Emit 会将数据分发到相应的窗口和聚合器中 + ssql.Emit(randomData) } case <-ctx.Done(): @@ -131,7 +131,7 @@ func TestStreamsql(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 捕获结果 resultChan := make(chan interface{}) @@ -201,7 +201,7 @@ func TestStreamsqlWithoutGroupBy(t *testing.T) { } for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 捕获结果 resultChan := make(chan interface{}) @@ -272,7 +272,7 @@ func TestStreamsqlDistinct(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -350,7 +350,7 @@ func TestStreamsqlLimit(t *testing.T) { streamsql := New() defer streamsql.Stop() - var rsql = "SELECT device, temperature FROM stream LIMIT 2" + var rsql = "SELECT * FROM stream LIMIT 2" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream @@ -374,7 +374,7 @@ func TestStreamsqlLimit(t *testing.T) { // 实时验证:添加一条数据,立即验证一条结果 for i, data := range testData { // 添加数据 - strm.AddData(data) + strm.Emit(data) // 立即等待并验证结果 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -438,7 +438,7 @@ func TestStreamsqlLimit(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待聚合 @@ -513,7 +513,7 @@ func TestStreamsqlLimit(t *testing.T) { // 添加数据 for _, data := range testData { - 
strm.AddData(data) + strm.Emit(data) } // 等待窗口触发 @@ -586,7 +586,7 @@ func TestStreamsqlLimit(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待聚合 @@ -657,7 +657,7 @@ func TestSimpleQuery(t *testing.T) { // 发送数据 //fmt.Println("添加数据...") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待结果 @@ -713,7 +713,7 @@ func TestHavingClause(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -796,7 +796,7 @@ func TestSessionWindow(t *testing.T) { if item.wait > 0 { time.Sleep(item.wait) } - strm.AddData(item.data) + strm.Emit(item.data) } // 等待会话超时,使最后一个会话触发 @@ -889,7 +889,7 @@ func TestExpressionInAggregation(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -974,7 +974,7 @@ func TestAdvancedFunctionsInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1073,7 +1073,7 @@ func TestCustomFunctionInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1158,7 +1158,7 @@ func TestNewAggregateFunctionsInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1268,7 +1268,7 @@ func TestStatisticalAggregateFunctionsInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1370,7 +1370,7 @@ func TestDeduplicateAggregateInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1480,7 +1480,7 @@ func TestExprAggregationFunctions(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + 
strm.Emit(data) } // 创建结果接收通道 @@ -1637,7 +1637,7 @@ func TestAnalyticalFunctionsInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1734,7 +1734,7 @@ func TestLagFunctionInSQL(t *testing.T) { //fmt.Println("添加测试数据:", testData) for _, data := range testData { //fmt.Printf("添加第%d个数据: temperature=%.1f\n", i+1, data.(map[string]interface{})["temperature"]) - strm.AddData(data) + strm.Emit(data) time.Sleep(100 * time.Millisecond) // 稍微延迟确保顺序 } @@ -1832,7 +1832,7 @@ func TestHadChangedFunctionInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -1912,7 +1912,7 @@ func TestLatestFunctionInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -2004,7 +2004,7 @@ func TestChangedColFunctionInSQL(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -2085,7 +2085,7 @@ func TestAnalyticalFunctionsIncrementalComputation(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -2184,7 +2184,7 @@ func TestIncrementalComputationBasic(t *testing.T) { // 添加数据 //fmt.Println("添加测试数据") for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 创建结果接收通道 @@ -2288,7 +2288,7 @@ func TestExprFunctions(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待结果 @@ -2374,7 +2374,7 @@ func TestExprFunctionsInAggregation(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待窗口初始化 @@ -2450,7 +2450,7 @@ func TestNestedExprFunctions(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待结果 @@ -2540,7 +2540,7 @@ func 
TestExprFunctionsWithStreamSQLFunctions(t *testing.T) { // 添加数据 for _, data := range testData { - strm.AddData(data) + strm.Emit(data) } // 等待结果 @@ -2591,3 +2591,377 @@ func TestExprFunctionsWithStreamSQLFunctions(t *testing.T) { } } } + +// TestSelectAllFeature 专门测试SELECT *功能 +func TestSelectAllFeature(t *testing.T) { + // 测试场景1:基本SELECT *查询 + t.Run("基本SELECT *查询", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := map[string]interface{}{ + "device": "sensor001", + "temperature": 25.5, + "humidity": 60, + "location": "room1", + "status": "active", + } + + // 发送数据 + strm.Emit(testData) + + // 等待结果 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + select { + case result := <-resultChan: + // 验证结果 + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + require.Len(t, resultSlice, 1, "应该只有一条结果") + + item := resultSlice[0] + // 验证所有原始字段都存在 + assert.Equal(t, "sensor001", item["device"], "device字段应该正确") + assert.Equal(t, 25.5, item["temperature"], "temperature字段应该正确") + assert.Equal(t, 60, item["humidity"], "humidity字段应该正确") + assert.Equal(t, "room1", item["location"], "location字段应该正确") + assert.Equal(t, "active", item["status"], "status字段应该正确") + + // 验证字段数量 + assert.Len(t, item, 5, "应该包含所有5个字段") + + cancel() + case <-ctx.Done(): + t.Fatal("测试超时,未收到结果") + } + }) + + // 测试场景2:SELECT * + WHERE条件 + t.Run("SELECT * + WHERE条件", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream WHERE temperature > 20" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := 
make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"device": "sensor1", "temperature": 25.0, "humidity": 60}, // 应该被包含 + {"device": "sensor2", "temperature": 15.0, "humidity": 70}, // 应该被过滤掉 + {"device": "sensor3", "temperature": 30.0, "humidity": 50}, // 应该被包含 + } + + var results []interface{} + var resultsMutex sync.Mutex + + // 发送数据 + for _, data := range testData { + strm.Emit(data) + + // 立即检查结果 + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + select { + case result := <-resultChan: + resultsMutex.Lock() + results = append(results, result) + resultsMutex.Unlock() + cancel() + case <-ctx.Done(): + cancel() + // 对于不满足条件的数据,超时是正常的 + } + } + + // 验证结果 + resultsMutex.Lock() + finalResultCount := len(results) + resultsCopy := make([]interface{}, len(results)) + copy(resultsCopy, results) + resultsMutex.Unlock() + + assert.Equal(t, 2, finalResultCount, "应该有2条记录满足条件") + + // 验证结果内容 + deviceFound := make(map[string]bool) + for _, result := range resultsCopy { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + require.Len(t, resultSlice, 1, "每个结果应该只有一条记录") + + item := resultSlice[0] + device, _ := item["device"].(string) + temp, _ := item["temperature"].(float64) + + // 验证温度条件 + assert.Greater(t, temp, 20.0, "温度应该大于20") + + // 记录找到的设备 + deviceFound[device] = true + + // 验证所有字段都存在 + assert.Contains(t, item, "device", "应该包含device字段") + assert.Contains(t, item, "temperature", "应该包含temperature字段") + assert.Contains(t, item, "humidity", "应该包含humidity字段") + } + + // 验证正确的设备被包含 + assert.True(t, deviceFound["sensor1"], "sensor1应该被包含") + assert.True(t, deviceFound["sensor3"], "sensor3应该被包含") + assert.False(t, deviceFound["sensor2"], "sensor2不应该被包含") + }) + + // 测试场景3:SELECT * + LIMIT + t.Run("SELECT * + LIMIT", func(t *testing.T) { + streamsql := New() + defer 
streamsql.Stop() + + var rsql = "SELECT * FROM stream LIMIT 2" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"device": "sensor1", "temperature": 25.0}, + {"device": "sensor2", "temperature": 26.0}, + {"device": "sensor3", "temperature": 27.0}, + {"device": "sensor4", "temperature": 28.0}, + } + + var results []interface{} + var resultsMutex sync.Mutex + + // 发送数据 + for _, data := range testData { + strm.Emit(data) + + // 立即检查结果 + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + select { + case result := <-resultChan: + resultsMutex.Lock() + results = append(results, result) + resultsMutex.Unlock() + cancel() + case <-ctx.Done(): + cancel() + } + } + + // 验证结果 + resultsMutex.Lock() + finalResultCount := len(results) + resultsCopy := make([]interface{}, len(results)) + copy(resultsCopy, results) + resultsMutex.Unlock() + + assert.GreaterOrEqual(t, finalResultCount, 2, "应该至少有2条结果") + + // 验证结果内容 + for _, result := range resultsCopy { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 验证LIMIT限制:每个batch最多2条记录 + assert.LessOrEqual(t, len(resultSlice), 2, "每个batch最多2条记录") + assert.Greater(t, len(resultSlice), 0, "应该有结果") + + // 验证字段 + for _, item := range resultSlice { + assert.Contains(t, item, "device", "结果应包含device字段") + assert.Contains(t, item, "temperature", "结果应包含temperature字段") + } + } + }) + + // 测试场景4:SELECT * with嵌套字段 + t.Run("SELECT * with嵌套字段", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result 
interface{}) { + resultChan <- result + }) + + // 添加带嵌套字段的测试数据 + testData := map[string]interface{}{ + "device": "sensor001", + "metrics": map[string]interface{}{ + "temperature": 25.5, + "humidity": 60, + }, + "location": map[string]interface{}{ + "building": "A", + "room": "101", + }, + } + + // 发送数据 + strm.Emit(testData) + + // 等待结果 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + select { + case result := <-resultChan: + // 验证结果 + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + require.Len(t, resultSlice, 1, "应该只有一条结果") + + item := resultSlice[0] + // 验证顶级字段 + assert.Equal(t, "sensor001", item["device"], "device字段应该正确") + + // 验证嵌套字段结构被保留 + metrics, ok := item["metrics"].(map[string]interface{}) + assert.True(t, ok, "metrics应该是map类型") + assert.Equal(t, 25.5, metrics["temperature"], "嵌套temperature字段应该正确") + assert.Equal(t, 60, metrics["humidity"], "嵌套humidity字段应该正确") + + location, ok := item["location"].(map[string]interface{}) + assert.True(t, ok, "location应该是map类型") + assert.Equal(t, "A", location["building"], "嵌套building字段应该正确") + assert.Equal(t, "101", location["room"], "嵌套room字段应该正确") + + cancel() + case <-ctx.Done(): + t.Fatal("测试超时,未收到结果") + } + }) +} + +// TestCaseNullValueHandlingInAggregation 测试CASE表达式在聚合函数中正确处理NULL值 +func TestCaseNullValueHandlingInAggregation(t *testing.T) { + sql := `SELECT deviceType, + SUM(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_sum, + COUNT(CASE WHEN temperature > 30 THEN 1 ELSE NULL END) as high_temp_count, + AVG(CASE WHEN temperature > 30 THEN temperature ELSE NULL END) as high_temp_avg + FROM stream + GROUP BY deviceType, TumblingWindow('2s')` + + // 创建StreamSQL实例 + ssql := New() + defer ssql.Stop() + + // 执行SQL + err := ssql.Execute(sql) + require.NoError(t, err) + + // 收集结果 + var results []map[string]interface{} + resultChan := make(chan interface{}, 10) + + ssql.AddSink(func(result 
interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"deviceType": "sensor", "temperature": 35.0}, // 满足条件 + {"deviceType": "sensor", "temperature": 25.0}, // 不满足条件,返回NULL + {"deviceType": "sensor", "temperature": 32.0}, // 满足条件 + {"deviceType": "monitor", "temperature": 28.0}, // 不满足条件,返回NULL + {"deviceType": "monitor", "temperature": 33.0}, // 满足条件 + } + + for _, data := range testData { + ssql.Emit(data) + } + + // 等待窗口触发 + time.Sleep(3 * time.Second) + + // 收集结果 +collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-time.After(500 * time.Millisecond): + break collecting + } + } + + // 验证结果 + assert.Len(t, results, 2, "应该有两个设备类型的结果") + + // 验证各个deviceType的结果 + expectedResults := map[string]map[string]interface{}{ + "sensor": { + "high_temp_sum": 67.0, // 35 + 32 + "high_temp_count": 2.0, // COUNT应该忽略NULL + "high_temp_avg": 33.5, // (35 + 32) / 2 + }, + "monitor": { + "high_temp_sum": 33.0, // 只有33 + "high_temp_count": 1.0, // COUNT应该忽略NULL + "high_temp_avg": 33.0, // 只有33 + }, + } + + for _, result := range results { + deviceType := result["deviceType"].(string) + expected := expectedResults[deviceType] + + assert.NotNil(t, expected, "应该有设备类型 %s 的期望结果", deviceType) + + // 验证SUM聚合(忽略NULL值) + assert.Equal(t, expected["high_temp_sum"], result["high_temp_sum"], + "设备类型 %s 的SUM聚合结果应该正确", deviceType) + + // 验证COUNT聚合(忽略NULL值) + assert.Equal(t, expected["high_temp_count"], result["high_temp_count"], + "设备类型 %s 的COUNT聚合结果应该正确", deviceType) + + // 验证AVG聚合(忽略NULL值) + assert.Equal(t, expected["high_temp_avg"], result["high_temp_avg"], + "设备类型 %s 的AVG聚合结果应该正确", deviceType) + } +} diff --git a/types/config.go b/types/config.go index e81f6de..6c7230b 100644 --- a/types/config.go +++ b/types/config.go @@ -15,8 +15,9 @@ type Config struct { FieldAlias map[string]string `json:"fieldAlias"` SimpleFields 
[]string `json:"simpleFields"` FieldExpressions map[string]FieldExpression `json:"fieldExpressions"` + FieldOrder []string `json:"fieldOrder"` // SELECT语句中字段的原始顺序 Where string `json:"where"` - Having string `json:"having"` + Having string `json:"having"` // 功能开关 NeedWindow bool `json:"needWindow"` diff --git a/utils/reflectutil/reflectutil.go b/utils/reflectutil/reflectutil.go new file mode 100644 index 0000000..6b7429b --- /dev/null +++ b/utils/reflectutil/reflectutil.go @@ -0,0 +1,27 @@ +package reflectutil + +import ( + "fmt" + "reflect" +) + +// SafeFieldByName 安全地获取结构体字段 +func SafeFieldByName(v reflect.Value, fieldName string) (reflect.Value, error) { + // 检查Value是否有效 + if !v.IsValid() { + return reflect.Value{}, fmt.Errorf("invalid value") + } + + // 检查是否为结构体类型 + if v.Kind() != reflect.Struct { + return reflect.Value{}, fmt.Errorf("value is not a struct, got %v", v.Kind()) + } + + // 安全地获取字段 + field := v.FieldByName(fieldName) + if !field.IsValid() { + return reflect.Value{}, fmt.Errorf("field %s not found", fieldName) + } + + return field, nil +} diff --git a/utils/table/table.go b/utils/table/table.go new file mode 100644 index 0000000..7c7334a --- /dev/null +++ b/utils/table/table.go @@ -0,0 +1,150 @@ +/* + * Copyright 2025 The RuleGo Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package table + +import ( + "fmt" +) + +// PrintTableFromSlice 从切片数据打印表格 +// 支持自定义字段顺序,如果fieldOrder为空则使用字母排序 +func PrintTableFromSlice(data []map[string]interface{}, fieldOrder []string) { + if len(data) == 0 { + return + } + + // 收集所有列名 + columnSet := make(map[string]bool) + for _, row := range data { + for col := range row { + columnSet[col] = true + } + } + + // 根据字段顺序排列列名 + var columns []string + if len(fieldOrder) > 0 { + // 使用指定的字段顺序 + for _, field := range fieldOrder { + if columnSet[field] { + columns = append(columns, field) + delete(columnSet, field) // 标记已处理 + } + } + // 添加剩余的列(如果有的话) + for col := range columnSet { + columns = append(columns, col) + } + } else { + // 如果没有指定字段顺序,使用简单排序 + columns = make([]string, 0, len(columnSet)) + for col := range columnSet { + columns = append(columns, col) + } + // 简单排序,确保输出一致性 + for i := 0; i < len(columns)-1; i++ { + for j := i + 1; j < len(columns); j++ { + if columns[i] > columns[j] { + columns[i], columns[j] = columns[j], columns[i] + } + } + } + } + + // 计算每列的最大宽度 + colWidths := make([]int, len(columns)) + for i, col := range columns { + colWidths[i] = len(col) // 列名长度 + for _, row := range data { + if val, exists := row[col]; exists { + valStr := fmt.Sprintf("%v", val) + if len(valStr) > colWidths[i] { + colWidths[i] = len(valStr) + } + } + } + // 最小宽度为4 + if colWidths[i] < 4 { + colWidths[i] = 4 + } + } + + // 打印顶部边框 + PrintTableBorder(colWidths) + + // 打印列名 + fmt.Print("|") + for i, col := range columns { + fmt.Printf(" %-*s |", colWidths[i], col) + } + fmt.Println() + + // 打印分隔线 + PrintTableBorder(colWidths) + + // 打印数据行 + for _, row := range data { + fmt.Print("|") + for i, col := range columns { + val := "" + if v, exists := row[col]; exists { + val = fmt.Sprintf("%v", v) + } + fmt.Printf(" %-*s |", colWidths[i], val) + } + fmt.Println() + } + + // 打印底部边框 + PrintTableBorder(colWidths) + + // 打印行数统计 + fmt.Printf("(%d rows)\n", len(data)) +} + +// PrintTableBorder 打印表格边框 +func 
PrintTableBorder(columnWidths []int) { + fmt.Print("+") + for _, width := range columnWidths { + for i := 0; i < width+2; i++ { + fmt.Print("-") + } + fmt.Print("+") + } + fmt.Println() +} + +// FormatTableData 格式化表格数据,支持多种数据类型 +func FormatTableData(result interface{}, fieldOrder []string) { + switch v := result.(type) { + case []map[string]interface{}: + if len(v) == 0 { + fmt.Println("(0 rows)") + return + } + PrintTableFromSlice(v, fieldOrder) + case map[string]interface{}: + if len(v) == 0 { + fmt.Println("(0 rows)") + return + } + PrintTableFromSlice([]map[string]interface{}{v}, fieldOrder) + default: + // 对于非表格数据,直接打印 + fmt.Printf("Result: %v\n", result) + } +} \ No newline at end of file diff --git a/utils/table/table_test.go b/utils/table/table_test.go new file mode 100644 index 0000000..21b5851 --- /dev/null +++ b/utils/table/table_test.go @@ -0,0 +1,91 @@ +/* + * Copyright 2025 The RuleGo Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package table + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestPrintTableFromSlice 测试表格打印功能 +func TestPrintTableFromSlice(t *testing.T) { + // 测试空数据 + assert.NotPanics(t, func() { + PrintTableFromSlice([]map[string]interface{}{}, nil) + }, "空数据不应该panic") + + // 测试正常数据 + data := []map[string]interface{}{ + {"name": "Alice", "age": 30, "city": "New York"}, + {"name": "Bob", "age": 25, "city": "Los Angeles"}, + } + assert.NotPanics(t, func() { + PrintTableFromSlice(data, nil) + }, "正常数据不应该panic") + + // 测试带字段顺序的数据 + fieldOrder := []string{"name", "city", "age"} + assert.NotPanics(t, func() { + PrintTableFromSlice(data, fieldOrder) + }, "带字段顺序的数据不应该panic") +} + +// TestPrintTableBorder 测试边框打印功能 +func TestPrintTableBorder(t *testing.T) { + // 测试正常宽度 + assert.NotPanics(t, func() { + colWidths := []int{5, 8, 6} + PrintTableBorder(colWidths) + }, "PrintTableBorder不应该panic") + + // 测试空宽度 + assert.NotPanics(t, func() { + PrintTableBorder([]int{}) + }, "空宽度数组不应该panic") +} + +// TestFormatTableData 测试数据格式化功能 +func TestFormatTableData(t *testing.T) { + // 测试切片数据 + sliceData := []map[string]interface{}{ + {"device": "sensor1", "temp": 25.5}, + } + assert.NotPanics(t, func() { + FormatTableData(sliceData, nil) + }, "切片数据不应该panic") + + // 测试单个map数据 + mapData := map[string]interface{}{"device": "sensor1", "temp": 25.5} + assert.NotPanics(t, func() { + FormatTableData(mapData, nil) + }, "map数据不应该panic") + + // 测试其他类型数据 + assert.NotPanics(t, func() { + FormatTableData("string data", nil) + }, "字符串数据不应该panic") + + // 测试空数据 + assert.NotPanics(t, func() { + FormatTableData([]map[string]interface{}{}, nil) + }, "空切片数据不应该panic") + + assert.NotPanics(t, func() { + FormatTableData(map[string]interface{}{}, nil) + }, "空map数据不应该panic") +} \ No newline at end of file diff --git a/window/factory.go b/window/factory.go index 32dbaf0..9bc14ad 100644 --- a/window/factory.go +++ b/window/factory.go @@ -2,10 +2,11 @@ package window import ( "fmt" - 
"github.com/rulego/streamsql/utils/cast" "reflect" "time" + "github.com/rulego/streamsql/utils/cast" + "github.com/rulego/streamsql/types" )