5 Commits

Author SHA1 Message Date
rulego-team 8475810768 fix: resolve window initialization deadlock in tests 2025-07-30 19:34:27 +08:00
Whki ffb050dc51 Merge pull request #30 from rulego/dev
Dev
2025-07-30 18:22:43 +08:00
rulego-team 48ab3c5276 feat:增加打印table到控制台 2025-07-30 18:15:39 +08:00
rulego-team 9c961f2de0 feat(ast):增加select字段原始顺序列表 2025-07-30 18:13:16 +08:00
rulego-team 596d1cb769 feat:增加反引号单引号和字符串常量的语法识别 #29 2025-07-30 17:27:20 +08:00
20 changed files with 1587 additions and 202 deletions
+2 -49
View File
@@ -29,12 +29,7 @@ jobs:
run: go mod download
- name: Run all tests
run: go test -v -race -timeout 30s ./...
- name: Run CASE expression tests specifically
run: |
echo "Testing CASE expression functionality..."
go test -v -run TestCaseExpression -timeout 20s
run: go test -v -race -timeout 300s ./...
release:
name: Create Release
@@ -61,52 +56,10 @@ jobs:
- name: Download dependencies
run: go mod download
- name: Build binaries
run: |
# Build for multiple platforms
GOOS=linux GOARCH=amd64 go build -o streamsql-linux-amd64 ./...
GOOS=windows GOARCH=amd64 go build -o streamsql-windows-amd64.exe ./...
GOOS=darwin GOARCH=amd64 go build -o streamsql-darwin-amd64 ./...
GOOS=darwin GOARCH=arm64 go build -o streamsql-darwin-arm64 ./...
- name: Generate changelog
id: changelog
run: |
echo "CHANGELOG<<EOF" >> $GITHUB_OUTPUT
echo "## 🚀 StreamSQL $(echo ${{ github.ref }} | sed 's/refs\/tags\///')" >> $GITHUB_OUTPUT
echo "" >> $GITHUB_OUTPUT
echo "### ✨ 新增功能" >> $GITHUB_OUTPUT
echo "- 完善的CASE表达式支持" >> $GITHUB_OUTPUT
echo "- 多条件逻辑表达式 (AND, OR)" >> $GITHUB_OUTPUT
echo "- 数学函数集成" >> $GITHUB_OUTPUT
echo "- 字段提取和引用功能" >> $GITHUB_OUTPUT
echo "" >> $GITHUB_OUTPUT
echo "### 🔧 改进" >> $GITHUB_OUTPUT
echo "- 负数解析优化" >> $GITHUB_OUTPUT
echo "- 字符串和数值混合比较" >> $GITHUB_OUTPUT
echo "- 表达式解析性能提升" >> $GITHUB_OUTPUT
echo "" >> $GITHUB_OUTPUT
echo "### 📋 测试覆盖" >> $GITHUB_OUTPUT
echo "- ✅ 基础CASE表达式解析" >> $GITHUB_OUTPUT
echo "- ✅ 复杂条件组合" >> $GITHUB_OUTPUT
echo "- ✅ 函数调用支持" >> $GITHUB_OUTPUT
echo "- ✅ 字段提取功能" >> $GITHUB_OUTPUT
echo "- ⚠️ 聚合函数中的使用 (部分支持)" >> $GITHUB_OUTPUT
echo "" >> $GITHUB_OUTPUT
echo "---" >> $GITHUB_OUTPUT
echo "📖 **完整文档**: [README.md](README.md) | [中文文档](README_ZH.md)" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Create Release
uses: softprops/action-gh-release@v1
with:
body: ${{ steps.changelog.outputs.CHANGELOG }}
files: |
streamsql-linux-amd64
streamsql-windows-amd64.exe
streamsql-darwin-amd64
streamsql-darwin-arm64
draft: false
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+82
View File
@@ -0,0 +1,82 @@
package main
import (
"fmt"
"time"
"github.com/rulego/streamsql"
)
// main demonstrates the PrintTable method: it runs aggregate and
// non-aggregate queries and renders their results as ASCII tables.
func main() {
	fmt.Println("=== StreamSQL PrintTable 示例 ===")
	// Create a StreamSQL instance.
	ssql := streamsql.New()
	// Example 1: aggregate query — per-device temperature statistics over a
	// 3-second tumbling window.
	fmt.Println("\n示例1: 聚合查询结果")
	err := ssql.Execute("SELECT device, AVG(temperature) as avg_temp, MAX(temperature) as max_temp FROM stream GROUP BY device, TumblingWindow('3s')")
	if err != nil {
		fmt.Printf("执行SQL失败: %v\n", err)
		return
	}
	// Register the table-formatting sink so results print as a table.
	ssql.PrintTable()
	// Feed test data.
	testData := []map[string]interface{}{
		{"device": "sensor1", "temperature": 25.5, "timestamp": time.Now()},
		{"device": "sensor1", "temperature": 26.0, "timestamp": time.Now()},
		{"device": "sensor2", "temperature": 23.8, "timestamp": time.Now()},
		{"device": "sensor2", "temperature": 24.2, "timestamp": time.Now()},
		{"device": "sensor1", "temperature": 27.1, "timestamp": time.Now()},
	}
	for _, data := range testData {
		ssql.Emit(data)
	}
	// Wait for the tumbling window to fire.
	time.Sleep(4 * time.Second)
	// Example 2: non-aggregate (filter + projection) query.
	fmt.Println("\n示例2: 非聚合查询结果")
	ssql2 := streamsql.New()
	err = ssql2.Execute("SELECT device, temperature, temperature * 1.8 + 32 as fahrenheit FROM stream WHERE temperature > 24")
	if err != nil {
		fmt.Printf("执行SQL失败: %v\n", err)
		return
	}
	ssql2.PrintTable()
	// Reuse the same test data.
	for _, data := range testData {
		ssql2.Emit(data)
	}
	// Give the non-windowed pipeline time to process.
	time.Sleep(1 * time.Second)
	// Example 3: comparison with the original Print method.
	// NOTE(review): despite the printed label, this also calls PrintTable —
	// presumably Print was superseded; confirm against the streamsql API.
	fmt.Println("\n示例3: 原始Print方法输出对比")
	ssql3 := streamsql.New()
	err = ssql3.Execute("SELECT device, COUNT(*) as count FROM stream GROUP BY device, TumblingWindow('2s')")
	if err != nil {
		fmt.Printf("执行SQL失败: %v\n", err)
		return
	}
	fmt.Println("原始PrintTable方法:")
	ssql3.PrintTable()
	// Emit a few records to trigger the 2-second window.
	for i := 0; i < 3; i++ {
		ssql3.Emit(map[string]interface{}{"device": "test_device", "value": i})
	}
	time.Sleep(3 * time.Second)
	fmt.Println("\n=== 示例结束 ===")
}
+50 -13
View File
@@ -423,28 +423,34 @@ func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error)
return float64(len(value)), nil
case TypeField:
// 处理反引号标识符,去除反引号
fieldName := node.Value
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号
}
// 支持嵌套字段访问
if fieldpath.IsNestedField(node.Value) {
if val, found := fieldpath.GetNestedField(data, node.Value); found {
if fieldpath.IsNestedField(fieldName) {
if val, found := fieldpath.GetNestedField(data, fieldName); found {
// 尝试转换为float64
if floatVal, err := convertToFloat(val); err == nil {
return floatVal, nil
}
// 如果不能转换为数字,返回错误
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", node.Value, val)
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val)
}
} else {
// 原有的简单字段访问
if val, found := data[node.Value]; found {
if val, found := data[fieldName]; found {
// 尝试转换为float64
if floatVal, err := convertToFloat(val); err == nil {
return floatVal, nil
}
// 如果不能转换为数字,返回错误
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", node.Value, val)
return 0, fmt.Errorf("field '%s' value cannot be converted to number: %v", fieldName, val)
}
}
return 0, fmt.Errorf("field '%s' not found", node.Value)
return 0, fmt.Errorf("field '%s' not found", fieldName)
case TypeOperator:
// 计算左右子表达式的值
@@ -817,18 +823,24 @@ func evaluateNodeValue(node *ExprNode, data map[string]interface{}) (interface{}
return value, nil
case TypeField:
// 处理反引号标识符,去除反引号
fieldName := node.Value
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号
}
// 支持嵌套字段访问
if fieldpath.IsNestedField(node.Value) {
if val, found := fieldpath.GetNestedField(data, node.Value); found {
if fieldpath.IsNestedField(fieldName) {
if val, found := fieldpath.GetNestedField(data, fieldName); found {
return val, nil
}
} else {
// 原有的简单字段访问
if val, found := data[node.Value]; found {
if val, found := data[fieldName]; found {
return val, nil
}
}
return nil, fmt.Errorf("field '%s' not found", node.Value)
return nil, fmt.Errorf("field '%s' not found", fieldName)
default:
// 对于其他类型,回退到数值计算
@@ -1110,6 +1122,25 @@ func tokenize(expr string) ([]string, error) {
continue
}
// 处理反引号标识符
if ch == '`' {
start := i
i++ // 跳过开始反引号
// 寻找结束反引号
for i < len(expr) && expr[i] != '`' {
i++
}
if i >= len(expr) {
return nil, fmt.Errorf("unterminated quoted identifier starting at position %d", start)
}
i++ // 跳过结束反引号
tokens = append(tokens, expr[start:i])
continue
}
// 处理标识符(字段名或函数名)
if isLetter(ch) {
start := i
@@ -1913,14 +1944,20 @@ func evaluateNodeValueWithNull(node *ExprNode, data map[string]interface{}) (int
return value, false, nil
case TypeField:
// 处理反引号标识符,去除反引号
fieldName := node.Value
if len(fieldName) >= 2 && fieldName[0] == '`' && fieldName[len(fieldName)-1] == '`' {
fieldName = fieldName[1 : len(fieldName)-1] // 去掉反引号
}
// 支持嵌套字段访问
if fieldpath.IsNestedField(node.Value) {
if val, found := fieldpath.GetNestedField(data, node.Value); found {
if fieldpath.IsNestedField(fieldName) {
if val, found := fieldpath.GetNestedField(data, fieldName); found {
return val, val == nil, nil
}
} else {
// 原有的简单字段访问
if val, found := data[node.Value]; found {
if val, found := data[fieldName]; found {
return val, val == nil, nil
}
}
+37 -3
View File
@@ -151,7 +151,15 @@ func (bridge *ExprBridge) CompileExpressionWithStreamSQLFunctions(expression str
// EvaluateExpression 评估表达式,自动选择最合适的引擎
func (bridge *ExprBridge) EvaluateExpression(expression string, data map[string]interface{}) (interface{}, error) {
// 首先检查是否包含LIKE操作符如果有则进行预处理
// 首先预处理反引号标识符
if bridge.ContainsBacktickIdentifiers(expression) {
processedExpr, err := bridge.PreprocessBacktickIdentifiers(expression)
if err == nil {
expression = processedExpr
}
}
// 检查是否包含LIKE操作符如果有则进行预处理
if bridge.ContainsLikeOperator(expression) {
processedExpr, err := bridge.PreprocessLikeExpression(expression)
if err == nil {
@@ -407,8 +415,9 @@ func (bridge *ExprBridge) isFunctionCall(expression string) bool {
// PreprocessLikeExpression 预处理LIKE表达式转换为expr-lang可理解的函数调用
func (bridge *ExprBridge) PreprocessLikeExpression(expression string) (string, error) {
// 使用正则表达式匹配LIKE模式
// 匹配: field LIKE 'pattern' (允许空模式)
likePattern := `(\w+(?:\.\w+)*)\s+LIKE\s+'([^']*)'`
// 匹配: field LIKE 'pattern' 或 `field` LIKE 'pattern' (允许空模式)
// 支持反引号标识符和普通标识符
likePattern := `((?:` + "`" + `[^` + "`" + `]+` + "`" + `|\w+)(?:\.(?:` + "`" + `[^` + "`" + `]+` + "`" + `|\w+))*)\s+LIKE\s+'([^']*)'`
re, err := regexp.Compile(likePattern)
if err != nil {
return expression, err
@@ -424,6 +433,11 @@ func (bridge *ExprBridge) PreprocessLikeExpression(expression string) (string, e
field := submatches[1]
pattern := submatches[2]
// 处理反引号标识符,去除反引号
if len(field) >= 2 && field[0] == '`' && field[len(field)-1] == '`' {
field = field[1 : len(field)-1] // 去掉反引号
}
// 将LIKE模式转换为相应的函数调用
return bridge.convertLikeToFunction(field, pattern)
})
@@ -476,6 +490,26 @@ func (bridge *ExprBridge) PreprocessIsNullExpression(expression string) (string,
return result, nil
}
// backtickIdentRe matches one backtick-quoted identifier such as `deviceId`
// or `device id`; the capture group holds the name without the backticks.
// Compiled once at package init instead of on every call — this runs on the
// expression-evaluation path (see EvaluateExpression).
var backtickIdentRe = regexp.MustCompile("`([^`]+)`")

// ContainsBacktickIdentifiers reports whether the expression contains any
// backtick character, i.e. whether backtick-identifier preprocessing is
// needed at all.
func (bridge *ExprBridge) ContainsBacktickIdentifiers(expression string) bool {
	return strings.Contains(expression, "`")
}

// PreprocessBacktickIdentifiers strips the backticks from quoted identifiers
// so the expression can be handed to engines that do not understand them.
// The error return is kept for signature compatibility and is always nil.
// NOTE(review): a backtick inside a string literal would also be stripped —
// assumed not to occur in practice; confirm against the expression grammar.
func (bridge *ExprBridge) PreprocessBacktickIdentifiers(expression string) (string, error) {
	return backtickIdentRe.ReplaceAllString(expression, "$1"), nil
}
// convertLikeToFunction 将LIKE模式转换为expr-lang操作符
func (bridge *ExprBridge) convertLikeToFunction(field, pattern string) string {
// 处理空模式
+47 -2
View File
@@ -103,7 +103,15 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
// 如果有别名,用别名作为字段名
simpleFields = append(simpleFields, fieldName+":"+field.Alias)
} else {
simpleFields = append(simpleFields, fieldName)
// 对于没有别名的字段,检查是否为字符串字面量
_, n, _, _ := ParseAggregateTypeWithExpression(fieldName)
if n != "" {
// 如果是字符串字面量,使用解析出的字段名(去掉引号)
simpleFields = append(simpleFields, n)
} else {
// 否则使用原始表达式
simpleFields = append(simpleFields, fieldName)
}
}
}
}
@@ -113,6 +121,9 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
// 构建字段映射和表达式信息
aggs, fields, expressions := buildSelectFieldsWithExpressions(s.Fields)
// 提取字段顺序信息
fieldOrder := extractFieldOrder(s.Fields)
// 构建Stream配置
config := types.Config{
WindowConfig: types.WindowConfig{
@@ -131,6 +142,7 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
SimpleFields: simpleFields,
Having: s.Having,
FieldExpressions: expressions,
FieldOrder: fieldOrder,
}
return &config, s.Condition, nil
@@ -175,10 +187,33 @@ func isAggregationFunction(expr string) bool {
if strings.Contains(expr, "(") && strings.Contains(expr, ")") {
return true
}
return false
}
// extractFieldOrder returns the output field names in the order they appear
// in the SELECT clause. An alias wins when present; otherwise the name
// parsed from the expression (e.g. a string literal with its quotes
// stripped) is used, falling back to the raw expression text.
func extractFieldOrder(fields []Field) []string {
	var fieldOrder []string
	for _, field := range fields {
		name := field.Alias
		if name == "" {
			// No alias: ask the expression parser for a usable field name.
			_, parsedName, _, _ := ParseAggregateTypeWithExpression(field.Expression)
			if parsedName != "" {
				name = parsedName
			} else {
				name = field.Expression
			}
		}
		fieldOrder = append(fieldOrder, name)
	}
	return fieldOrder
}
func extractGroupFields(s *SelectStatement) []string {
var fields []string
for _, f := range s.GroupBy {
@@ -267,6 +302,15 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg
// 提取函数名
funcName := extractFunctionName(exprStr)
if funcName == "" {
// 检查是否是字符串字面量
trimmed := strings.TrimSpace(exprStr)
if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) ||
(strings.HasPrefix(trimmed, "\"") && strings.HasSuffix(trimmed, "\"")) {
// 字符串字面量:使用去掉引号的内容作为字段名
fieldName := trimmed[1 : len(trimmed)-1]
return "expression", fieldName, exprStr, nil
}
// 如果不是函数调用,但包含运算符或关键字,可能是表达式
if strings.ContainsAny(exprStr, "+-*/<>=!&|") ||
strings.Contains(strings.ToUpper(exprStr), "AND") ||
@@ -644,6 +688,7 @@ func buildSelectFieldsWithExpressions(fields []Field) (
// 没有别名的情况,使用表达式本身作为字段名
t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression)
if t != "" && n != "" {
// 对于字符串字面量,使用解析出的字段名(去掉引号)作为键
selectFields[n] = t
fieldMap[n] = n
+39
View File
@@ -12,6 +12,7 @@ const (
TokenIdent
TokenNumber
TokenString
TokenQuotedIdent // 反引号标识符
TokenComma
TokenLParen
TokenRParen
@@ -176,6 +177,8 @@ func (l *Lexer) NextToken() Token {
return l.readStringToken(tokenPos, tokenLine, tokenColumn)
case '"':
return l.readStringToken(tokenPos, tokenLine, tokenColumn)
case '`':
return l.readQuotedIdentToken(tokenPos, tokenLine, tokenColumn)
}
if isLetter(l.ch) {
@@ -439,6 +442,42 @@ func (l *Lexer) readStringToken(pos, line, column int) Token {
return Token{Type: TokenString, Value: value, Pos: pos, Line: line, Column: column}
}
// readQuotedIdentToken reads a backtick-quoted identifier token, reporting a
// recoverable error through the lexer's error-recovery hook (if set) when the
// closing backtick is missing. The returned token value keeps the surrounding
// backticks; stripping them is left to later stages.
func (l *Lexer) readQuotedIdentToken(pos, line, column int) Token {
	startPos := l.pos
	l.readChar() // skip the opening backtick
	// Consume until the closing backtick or end of input (ch == 0).
	for l.ch != '`' && l.ch != 0 {
		l.readChar()
	}
	if l.ch == 0 {
		// Unterminated quoted identifier: record the error and return the
		// partial token so lexing can continue.
		if l.errorRecovery != nil {
			err := &ParseError{
				Type:        ErrorTypeUnterminatedString,
				Message:     "Unterminated quoted identifier",
				Position:    startPos,
				Line:        line,
				Column:      column,
				Token:       "`",
				Suggestions: []string{"Add closing backtick '`'"},
				Recoverable: true,
			}
			l.errorRecovery.AddError(err)
		}
		value := l.input[startPos:l.pos]
		return Token{Type: TokenQuotedIdent, Value: value, Pos: pos, Line: line, Column: column}
	}
	if l.ch == '`' {
		l.readChar() // skip the closing backtick
	}
	value := l.input[startPos:l.pos]
	return Token{Type: TokenQuotedIdent, Value: value, Pos: pos, Line: line, Column: column}
}
// isValidNumber 验证数字格式
func (l *Lexer) isValidNumber(number string) bool {
if number == "" {
+126
View File
@@ -0,0 +1,126 @@
package rsql
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestQuotedIdentifiers verifies lexing of backtick-quoted identifiers,
// including recovery from an unterminated identifier.
func TestQuotedIdentifiers(t *testing.T) {
	// In every well-formed case the token value is the full quoted input.
	cases := []struct {
		name  string
		input string
	}{
		{"基本反引号标识符", "`deviceId`"},
		{"包含特殊字符的反引号标识符", "`device-id`"},
		{"包含空格的反引号标识符", "`device id`"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			token := NewLexer(tc.input).NextToken()
			assert.Equal(t, TokenQuotedIdent, token.Type)
			assert.Equal(t, tc.input, token.Value)
		})
	}

	t.Run("未闭合的反引号标识符", func(t *testing.T) {
		lexer := NewLexer("`deviceId")
		errorRecovery := NewErrorRecovery(nil)
		lexer.SetErrorRecovery(errorRecovery)

		// A token is still produced, but exactly one recoverable
		// unterminated-string error must be recorded.
		token := lexer.NextToken()
		assert.Equal(t, TokenQuotedIdent, token.Type)
		assert.True(t, errorRecovery.HasErrors())
		errs := errorRecovery.GetErrors()
		assert.Equal(t, 1, len(errs))
		assert.Equal(t, ErrorTypeUnterminatedString, errs[0].Type)
	})
}
// TestStringLiterals verifies lexing of string constants: single- and
// double-quoted, with special characters, and empty.
func TestStringLiterals(t *testing.T) {
	// In every case the token value is the full quoted input.
	cases := []struct {
		name  string
		input string
	}{
		{"单引号字符串", "'hello world'"},
		{"双引号字符串", `"hello world"`},
		{"包含特殊字符的字符串", "'test-value_123'"},
		{"空字符串", "''"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			token := NewLexer(tc.input).NextToken()
			assert.Equal(t, TokenString, token.Type)
			assert.Equal(t, tc.input, token.Value)
		})
	}
}
// TestComplexSQL lexes full SQL statements mixing backtick identifiers and
// string constants, and checks the exact token sequence produced.
func TestComplexSQL(t *testing.T) {
	t.Run("包含反引号标识符和字符串常量的SQL", func(t *testing.T) {
		sql := "SELECT `deviceId`, deviceType, 'aa' as test FROM stream WHERE `deviceId` LIKE 'sensor%'"
		lexer := NewLexer(sql)
		// Expected token sequence; an empty Value means "check type only".
		expectedTokens := []struct {
			Type  TokenType
			Value string
		}{
			{TokenSELECT, "SELECT"},
			{TokenQuotedIdent, "`deviceId`"},
			{TokenComma, ","},
			{TokenIdent, "deviceType"},
			{TokenComma, ","},
			{TokenString, "'aa'"},
			{TokenAS, "as"},
			{TokenIdent, "test"},
			{TokenFROM, "FROM"},
			{TokenIdent, "stream"},
			{TokenWHERE, "WHERE"},
			{TokenQuotedIdent, "`deviceId`"},
			{TokenLIKE, "LIKE"},
			{TokenString, "'sensor%'"},
			{TokenEOF, ""},
		}
		for i, expected := range expectedTokens {
			token := lexer.NextToken()
			assert.Equal(t, expected.Type, token.Type, "Token %d type mismatch", i)
			if expected.Value != "" {
				assert.Equal(t, expected.Value, token.Value, "Token %d value mismatch", i)
			}
		}
	})
	t.Run("双引号字符串常量", func(t *testing.T) {
		sql := `SELECT deviceId, "test value" as name FROM stream`
		lexer := NewLexer(sql)
		// Skip leading tokens until the string constant.
		lexer.NextToken() // SELECT
		lexer.NextToken() // deviceId
		lexer.NextToken() // ,
		token := lexer.NextToken() // "test value"
		assert.Equal(t, TokenString, token.Type)
		assert.Equal(t, `"test value"`, token.Value)
	})
}
+8 -6
View File
@@ -81,6 +81,8 @@ func (p *Parser) getTokenTypeName(tokenType TokenType) string {
return ")"
case TokenIdent:
return "identifier"
case TokenQuotedIdent:
return "quoted identifier"
case TokenNumber:
return "number"
case TokenString:
@@ -306,12 +308,12 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error {
shouldAddSpace = false
}
}
} else if len(exprStr) > 0 && currentToken.Type == TokenIdent {
// 检查前一个字符是否是数字,且前面没有空格
if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") {
shouldAddSpace = false
}
} else if len(exprStr) > 0 && (currentToken.Type == TokenIdent || currentToken.Type == TokenQuotedIdent) {
// 检查前一个字符是否是数字,且前面没有空格
if (lastChar[0] >= '0' && lastChar[0] <= '9') && !strings.HasSuffix(exprStr, " ") {
shouldAddSpace = false
}
}
if shouldAddSpace {
expr.WriteString(" ")
@@ -385,7 +387,7 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
break
}
switch tok.Type {
case TokenIdent, TokenNumber:
case TokenIdent, TokenNumber, TokenQuotedIdent:
conditions = append(conditions, tok.Value)
case TokenString:
conditions = append(conditions, tok.Value)
+308 -89
View File
File diff suppressed because it is too large Load Diff
+30 -10
View File
@@ -22,6 +22,7 @@ import (
"github.com/rulego/streamsql/rsql"
"github.com/rulego/streamsql/stream"
"github.com/rulego/streamsql/types"
"github.com/rulego/streamsql/utils/table"
)
// Streamsql 是StreamSQL流处理引擎的主要接口。
@@ -41,6 +42,9 @@ type Streamsql struct {
// 新增:同步处理模式配置
enableSyncMode bool // 是否启用同步模式(用于非聚合查询)
// 保存原始SELECT字段顺序用于表格输出时保持字段顺序
fieldOrder []string
}
// New 创建一个新的StreamSQL实例。
@@ -125,6 +129,9 @@ func (s *Streamsql) Execute(sql string) error {
return fmt.Errorf("SQL解析失败: %w", err)
}
// 从解析结果中获取字段顺序信息
s.fieldOrder = config.FieldOrder
// 根据性能模式创建流处理器
var streamInstance *stream.Stream
@@ -338,24 +345,37 @@ func (s *Streamsql) AddSink(sink func(interface{})) {
}
}
// Print 打印结果到控制台
// 这是一个便捷方法自动添加一个打印结果的sink函数
// PrintTable 以表格形式打印结果到控制台,类似数据库输出格式
// 首先显示列名,然后逐行显示数据
//
// 支持的数据格式:
// - []map[string]interface{}: 多行记录
// - map[string]interface{}: 单行记录
// - 其他类型: 直接打印
//
// 示例:
//
// // 简单打印结果
// ssql.Print()
// // 表格式打印结果
// ssql.PrintTable()
//
// // 等价于:
// ssql.AddSink(func(result interface{}) {
// fmt.Printf("Ressult: %v\n", result)
// })
func (s *Streamsql) Print() {
// // 输出格式:
// // +--------+----------+
// // | device | max_temp |
// // +--------+----------+
// // | aa | 30.0 |
// // | bb | 22.0 |
// // +--------+----------+
func (s *Streamsql) PrintTable() {
s.AddSink(func(result interface{}) {
fmt.Printf("Ressult: %v\n", result)
s.printTableFormat(result)
})
}
// printTableFormat renders a single sink result as a table, preserving the
// SELECT field order captured during Execute (s.fieldOrder).
func (s *Streamsql) printTableFormat(result interface{}) {
	table.FormatTableData(result, s.fieldOrder)
}
// ToChannel 返回结果通道,用于异步获取处理结果。
// 通过此通道可以以非阻塞方式获取流处理结果。
//
+23 -23
View File
@@ -2,6 +2,7 @@ package streamsql
import (
"fmt"
"sync"
"testing"
"time"
@@ -114,38 +115,33 @@ func TestIsNullOperatorInSQL(t *testing.T) {
// 收集结果
var results []map[string]interface{}
resultChan := make(chan interface{}, 10)
resultsMutex := sync.Mutex{}
ssql.Stream().AddSink(func(result interface{}) {
resultChan <- result
})
// 使用一个done channel来同步
done := make(chan bool, 1)
// 添加测试数据
for _, data := range tc.testData {
ssql.Stream().Emit(data)
}
// 在另一个goroutine中收集结果
go func() {
defer func() { done <- true }()
// 等待一段时间收集结果
timeout := time.After(300 * time.Millisecond)
for {
select {
case result := <-resultChan:
if resultSlice, ok := result.([]map[string]interface{}); ok {
results = append(results, resultSlice...)
}
case <-timeout:
return
}
}
}()
// 使用更短的超时时间避免在CI环境中长时间等待
timeout := time.After(500 * time.Millisecond)
// 等待收集完成
<-done
collecting:
for {
select {
case result := <-resultChan:
resultsMutex.Lock()
if resultSlice, ok := result.([]map[string]interface{}); ok {
results = append(results, resultSlice...)
}
resultsMutex.Unlock()
case <-timeout:
break collecting
}
}
// 验证结果数量
assert.Len(t, results, len(tc.expected), "结果数量应该匹配")
@@ -156,10 +152,12 @@ func TestIsNullOperatorInSQL(t *testing.T) {
expectedDeviceIds[i] = exp["deviceId"].(string)
}
resultsMutex.Lock()
actualDeviceIds := make([]string, len(results))
for i, result := range results {
actualDeviceIds[i] = result["deviceId"].(string)
}
resultsMutex.Unlock()
// 验证每个期望的设备ID都在结果中
for _, expectedId := range expectedDeviceIds {
@@ -167,6 +165,7 @@ func TestIsNullOperatorInSQL(t *testing.T) {
}
// 验证每个结果的字段值
resultsMutex.Lock()
for _, result := range results {
deviceId := result["deviceId"].(string)
// 找到对应的期望结果
@@ -186,6 +185,7 @@ func TestIsNullOperatorInSQL(t *testing.T) {
}
}
}
resultsMutex.Unlock()
})
}
}
@@ -424,7 +424,7 @@ func TestIsNullWithOtherOperators(t *testing.T) {
// 使用超时方式安全收集结果
var results []map[string]interface{}
timeout := time.After(500 * time.Millisecond)
timeout := time.After(2 * time.Second)
collecting:
for {
@@ -1004,7 +1004,7 @@ func TestMixedNullComparisons(t *testing.T) {
// 使用超时方式安全收集结果
var results []map[string]interface{}
timeout := time.After(500 * time.Millisecond)
timeout := time.After(2 * time.Second)
collecting:
for {
File diff suppressed because it is too large Load Diff
@@ -32,8 +32,8 @@ func TestEmitSyncWithAddSink(t *testing.T) {
ssql := New()
defer ssql.Stop()
// 执行非聚合查询
sql := "SELECT temperature, humidity, temperature * 1.8 + 32 as temp_fahrenheit FROM stream WHERE temperature > 20"
// 执行非聚合查询 - 测试反引号字段与字符串常量的混合用法
sql := "SELECT `temperature`, humidity, `temperature` * 1.8 + 32 as temp_fahrenheit, 'normal' as status, 'sensor_data' as data_type FROM stream WHERE temperature > 20"
err := ssql.Execute(sql)
require.NoError(t, err)
@@ -101,6 +101,21 @@ func TestEmitSyncWithAddSink(t *testing.T) {
if syncResult, ok := result.(map[string]interface{}); ok {
syncTemperatures = append(syncTemperatures, syncResult["temperature"].(float64))
syncHumidities = append(syncHumidities, syncResult["humidity"].(float64))
// 验证字符串常量字段
assert.Equal(t, "normal", syncResult["status"], "status字段应该是常量'normal'")
assert.Equal(t, "sensor_data", syncResult["data_type"], "data_type字段应该是常量'sensor_data'")
// 验证反引号字段的数学运算
expectedFahrenheit := syncResult["temperature"].(float64)*1.8 + 32
assert.InDelta(t, expectedFahrenheit, syncResult["temp_fahrenheit"].(float64), 0.01, "华氏温度转换应该正确")
// 验证结果包含所有预期字段
assert.Contains(t, syncResult, "temperature", "应该包含temperature字段")
assert.Contains(t, syncResult, "humidity", "应该包含humidity字段")
assert.Contains(t, syncResult, "temp_fahrenheit", "应该包含temp_fahrenheit字段")
assert.Contains(t, syncResult, "status", "应该包含status字段")
assert.Contains(t, syncResult, "data_type", "应该包含data_type字段")
}
}
@@ -212,6 +227,39 @@ func TestEmitSyncWithAddSink(t *testing.T) {
// 验证AddSink没有被触发
assert.Equal(t, int32(0), atomic.LoadInt32(&sinkCallCount), "过滤掉的数据不应触发AddSink")
})
// 新增测试:字符串常量与反引号字段的复杂混合用法
t.Run("字符串常量与反引号字段混合用法", func(t *testing.T) {
ssql := New()
defer ssql.Stop()
// 测试包含多种字符串常量的SQL查询
sql := "SELECT `temperature` as temp, 'celsius' as unit, 'high' as level, `humidity`, 'percent' as humidity_unit FROM stream WHERE temperature > 20"
err := ssql.Execute(sql)
require.NoError(t, err)
// 测试数据
testData := map[string]interface{}{
"temperature": 25.5,
"humidity": 65.0,
}
// 同步处理
result, err := ssql.EmitSync(testData)
require.NoError(t, err)
require.NotNil(t, result)
if syncResult, ok := result.(map[string]interface{}); ok {
// 验证反引号字段
assert.Equal(t, 25.5, syncResult["temp"], "温度字段应该正确")
assert.Equal(t, 65.0, syncResult["humidity"], "湿度字段应该正确")
// 验证字符串常量字段
assert.Equal(t, "celsius", syncResult["unit"], "单位应该是celsius")
assert.Equal(t, "high", syncResult["level"], "级别应该是high")
assert.Equal(t, "percent", syncResult["humidity_unit"], "湿度单位应该是percent")
}
})
}
// TestEmitSyncPerformance 测试EmitSync性能包括AddSink触发
+55
View File
@@ -0,0 +1,55 @@
package streamsql
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestPrintTable checks the basic PrintTable flow: registering the table
// sink must not panic, and emitting data through a tumbling window must
// complete. Output content is deliberately not verified.
func TestPrintTable(t *testing.T) {
	// Build a StreamSQL instance with an aggregate windowed query.
	ssql := New()
	err := ssql.Execute("SELECT device, AVG(temperature) as avg_temp FROM stream GROUP BY device, TumblingWindow('2s')")
	assert.NoError(t, err)
	// Only assert that registering the sink does not panic.
	assert.NotPanics(t, func() {
		ssql.PrintTable()
	}, "PrintTable方法不应该panic")
	// Feed a couple of records.
	testData := []map[string]interface{}{
		{"device": "sensor1", "temperature": 25.0},
		{"device": "sensor2", "temperature": 30.0},
	}
	for _, data := range testData {
		ssql.Emit(data)
	}
	// Wait long enough for the 2-second tumbling window to fire.
	time.Sleep(3 * time.Second)
}
// TestPrintTableFormat exercises printTableFormat with several payload
// shapes and checks that none of them panic.
func TestPrintTableFormat(t *testing.T) {
	ssql := New()

	payloads := []struct {
		msg  string
		data interface{}
	}{
		{"空切片不应该panic", []map[string]interface{}{}},
		{"单个map不应该panic", map[string]interface{}{"key": "value"}},
		{"字符串数据不应该panic", "string data"},
	}
	for _, p := range payloads {
		assert.NotPanics(t, func() {
			ssql.printTableFormat(p.data)
		}, p.msg)
	}
}
+2 -1
View File
@@ -15,8 +15,9 @@ type Config struct {
FieldAlias map[string]string `json:"fieldAlias"`
SimpleFields []string `json:"simpleFields"`
FieldExpressions map[string]FieldExpression `json:"fieldExpressions"`
FieldOrder []string `json:"fieldOrder"` // SELECT语句中字段的原始顺序
Where string `json:"where"`
Having string `json:"having"`
Having string `json:"having"`
// 功能开关
NeedWindow bool `json:"needWindow"`
+150
View File
@@ -0,0 +1,150 @@
/*
* Copyright 2025 The RuleGo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package table
import (
	"fmt"
	"sort"
)
// PrintTableFromSlice prints rows of map data as an ASCII table.
// Columns follow fieldOrder when provided; any columns not covered by
// fieldOrder (and all columns when fieldOrder is empty) are sorted
// alphabetically so the output is deterministic — Go map iteration order is
// random, so appending leftovers unsorted would shuffle columns between
// runs. Empty input prints nothing.
func PrintTableFromSlice(data []map[string]interface{}, fieldOrder []string) {
	if len(data) == 0 {
		return
	}

	// Collect the union of column names across all rows.
	columnSet := make(map[string]bool)
	for _, row := range data {
		for col := range row {
			columnSet[col] = true
		}
	}

	// Order columns: requested order first, then remaining columns sorted.
	var columns []string
	if len(fieldOrder) > 0 {
		for _, field := range fieldOrder {
			if columnSet[field] {
				columns = append(columns, field)
				delete(columnSet, field) // mark as handled
			}
		}
		rest := make([]string, 0, len(columnSet))
		for col := range columnSet {
			rest = append(rest, col)
		}
		sort.Strings(rest)
		columns = append(columns, rest...)
	} else {
		columns = make([]string, 0, len(columnSet))
		for col := range columnSet {
			columns = append(columns, col)
		}
		// stdlib sort replaces the previous hand-rolled O(n^2) bubble sort.
		sort.Strings(columns)
	}

	// Column width = max(header length, widest value), with a minimum of 4.
	colWidths := make([]int, len(columns))
	for i, col := range columns {
		colWidths[i] = len(col)
		for _, row := range data {
			if val, exists := row[col]; exists {
				valStr := fmt.Sprintf("%v", val)
				if len(valStr) > colWidths[i] {
					colWidths[i] = len(valStr)
				}
			}
		}
		if colWidths[i] < 4 {
			colWidths[i] = 4
		}
	}

	// Header: top border, column names, separator.
	PrintTableBorder(colWidths)
	fmt.Print("|")
	for i, col := range columns {
		fmt.Printf(" %-*s |", colWidths[i], col)
	}
	fmt.Println()
	PrintTableBorder(colWidths)

	// Data rows; missing cells render as blanks.
	for _, row := range data {
		fmt.Print("|")
		for i, col := range columns {
			val := ""
			if v, exists := row[col]; exists {
				val = fmt.Sprintf("%v", v)
			}
			fmt.Printf(" %-*s |", colWidths[i], val)
		}
		fmt.Println()
	}

	// Bottom border and row count.
	PrintTableBorder(colWidths)
	fmt.Printf("(%d rows)\n", len(data))
}
// PrintTableBorder prints one horizontal border line such as "+------+----+".
// Each segment is width+2 dashes, accounting for the one-space padding on
// either side of a cell's content.
func PrintTableBorder(columnWidths []int) {
	fmt.Print("+")
	for _, width := range columnWidths {
		dashes := width + 2
		for dashes > 0 {
			fmt.Print("-")
			dashes--
		}
		fmt.Print("+")
	}
	fmt.Println()
}
// FormatTableData renders a sink result in tabular form when possible.
// A slice of rows or a single-row map is forwarded to PrintTableFromSlice;
// empty collections print a row count of zero; anything else is printed
// directly.
func FormatTableData(result interface{}, fieldOrder []string) {
	if rows, ok := result.([]map[string]interface{}); ok {
		if len(rows) == 0 {
			fmt.Println("(0 rows)")
			return
		}
		PrintTableFromSlice(rows, fieldOrder)
		return
	}
	if row, ok := result.(map[string]interface{}); ok {
		if len(row) == 0 {
			fmt.Println("(0 rows)")
			return
		}
		// Wrap the single record as a one-row table.
		PrintTableFromSlice([]map[string]interface{}{row}, fieldOrder)
		return
	}
	// Not tabular data: print it as-is.
	fmt.Printf("Result: %v\n", result)
}
+91
View File
@@ -0,0 +1,91 @@
/*
* Copyright 2025 The RuleGo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package table
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPrintTableFromSlice checks that table printing handles empty input,
// normal rows, and an explicit field order without panicking.
func TestPrintTableFromSlice(t *testing.T) {
	rows := []map[string]interface{}{
		{"name": "Alice", "age": 30, "city": "New York"},
		{"name": "Bob", "age": 25, "city": "Los Angeles"},
	}

	assert.NotPanics(t, func() {
		PrintTableFromSlice([]map[string]interface{}{}, nil)
	}, "空数据不应该panic")

	assert.NotPanics(t, func() {
		PrintTableFromSlice(rows, nil)
	}, "正常数据不应该panic")

	assert.NotPanics(t, func() {
		PrintTableFromSlice(rows, []string{"name", "city", "age"})
	}, "带字段顺序的数据不应该panic")
}
// TestPrintTableBorder checks border printing with normal and empty width
// lists; neither may panic.
func TestPrintTableBorder(t *testing.T) {
	assert.NotPanics(t, func() {
		PrintTableBorder([]int{5, 8, 6})
	}, "PrintTableBorder不应该panic")

	assert.NotPanics(t, func() {
		PrintTableBorder([]int{})
	}, "空宽度数组不应该panic")
}
// TestFormatTableData checks that formatting copes with every supported
// payload shape — row slices, single maps, arbitrary values, and empty
// collections — without panicking.
func TestFormatTableData(t *testing.T) {
	cases := []struct {
		msg  string
		data interface{}
	}{
		{"切片数据不应该panic", []map[string]interface{}{
			{"device": "sensor1", "temp": 25.5},
		}},
		{"map数据不应该panic", map[string]interface{}{"device": "sensor1", "temp": 25.5}},
		{"字符串数据不应该panic", "string data"},
		{"空切片数据不应该panic", []map[string]interface{}{}},
		{"空map数据不应该panic", map[string]interface{}{}},
	}
	for _, tc := range cases {
		assert.NotPanics(t, func() {
			FormatTableData(tc.data, nil)
		}, tc.msg)
	}
}
+11 -1
View File
@@ -135,12 +135,22 @@ func (sw *SessionWindow) Add(data interface{}) {
}
// Start 启动会话窗口的定时检查机制
// Start 启动会话窗口,开始定期检查过期会话
// 采用延迟初始化模式,避免在没有数据时无限等待,同时确保后续数据能正常处理
func (sw *SessionWindow) Start() {
go func() {
<-sw.initChan
// 在函数结束时关闭输出通道
defer close(sw.outputChan)
// 等待初始化完成或上下文取消
select {
case <-sw.initChan:
// 正常初始化完成,继续处理
case <-sw.ctx.Done():
// 上下文被取消,直接退出
return
}
// 定期检查过期会话
sw.tickerMu.Lock()
sw.ticker = time.NewTicker(sw.timeout / 2)
+10 -2
View File
@@ -115,13 +115,21 @@ func (sw *SlidingWindow) Add(data interface{}) {
}
// Start 启动滑动窗口,开始定时触发窗口
// 采用延迟初始化模式,避免在没有数据时无限等待,同时确保后续数据能正常处理
func (sw *SlidingWindow) Start() {
go func() {
// 等待初始化信号
<-sw.initChan
// 在函数结束时关闭输出通道。
defer close(sw.outputChan)
// 等待初始化完成或上下文取消
select {
case <-sw.initChan:
// 正常初始化完成,继续处理
case <-sw.ctx.Done():
// 上下文被取消,直接退出
return
}
for {
// 在每次循环中安全地获取timer
sw.timerMu.Lock()
+10 -1
View File
@@ -130,12 +130,21 @@ func (tw *TumblingWindow) Stop() {
}
// Start 启动滚动窗口的定时触发机制。
// 采用延迟初始化模式,避免在没有数据时无限等待,同时确保后续数据能正常处理
func (tw *TumblingWindow) Start() {
go func() {
<-tw.initChan
// 在函数结束时关闭输出通道。
defer close(tw.outputChan)
// 等待初始化完成或上下文取消
select {
case <-tw.initChan:
// 正常初始化完成,继续处理
case <-tw.ctx.Done():
// 上下文被取消,直接退出
return
}
for {
// 在每次循环中安全地获取timer
tw.timerMu.Lock()