mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-06-30 05:19:58 +00:00
feat: Add query syntax for all fields (select *)
This commit is contained in:
20
rsql/ast.go
20
rsql/ast.go
@ -17,6 +17,7 @@ import (
|
||||
type SelectStatement struct {
|
||||
Fields []Field
|
||||
Distinct bool
|
||||
SelectAll bool // 新增:标识是否是SELECT *查询
|
||||
Source string
|
||||
Condition string
|
||||
Window WindowDefinition
|
||||
@ -92,13 +93,18 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) {
|
||||
|
||||
// 如果没有聚合函数,收集简单字段
|
||||
if !hasAggregation {
|
||||
for _, field := range s.Fields {
|
||||
fieldName := field.Expression
|
||||
if field.Alias != "" {
|
||||
// 如果有别名,用别名作为字段名
|
||||
simpleFields = append(simpleFields, fieldName+":"+field.Alias)
|
||||
} else {
|
||||
simpleFields = append(simpleFields, fieldName)
|
||||
// 如果是SELECT *查询,设置特殊标记
|
||||
if s.SelectAll {
|
||||
simpleFields = append(simpleFields, "*")
|
||||
} else {
|
||||
for _, field := range s.Fields {
|
||||
fieldName := field.Expression
|
||||
if field.Alias != "" {
|
||||
// 如果有别名,用别名作为字段名
|
||||
simpleFields = append(simpleFields, fieldName+":"+field.Alias)
|
||||
} else {
|
||||
simpleFields = append(simpleFields, fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Debug("收集简单字段: %v", simpleFields)
|
||||
|
@ -212,6 +212,23 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error {
|
||||
currentToken = p.lexer.NextToken() // 消费 DISTINCT,移动到下一个 token
|
||||
}
|
||||
|
||||
// 检查是否是SELECT *查询
|
||||
if currentToken.Type == TokenIdent && currentToken.Value == "*" {
|
||||
stmt.SelectAll = true
|
||||
// 添加一个特殊的字段标记SELECT *
|
||||
stmt.Fields = append(stmt.Fields, Field{Expression: "*"})
|
||||
|
||||
// 消费*token并检查下一个token
|
||||
currentToken = p.lexer.NextToken()
|
||||
|
||||
// 如果下一个token是FROM或EOF,则完成SELECT *解析
|
||||
if currentToken.Type == TokenFROM || currentToken.Type == TokenEOF {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果不是FROM/EOF,继续正常的字段解析流程
|
||||
}
|
||||
|
||||
// 设置最大字段数量限制,防止无限循环
|
||||
maxFields := 100
|
||||
fieldCount := 0
|
||||
|
@ -573,6 +573,18 @@ func (s *Stream) processDirectData(data interface{}) {
|
||||
// 如果指定了字段,只保留这些字段
|
||||
if len(s.config.SimpleFields) > 0 {
|
||||
for _, fieldSpec := range s.config.SimpleFields {
|
||||
// 处理SELECT *的特殊情况
|
||||
if fieldSpec == "*" {
|
||||
// SELECT *:返回所有字段,但跳过已经通过表达式字段处理的字段
|
||||
for k, v := range dataMap {
|
||||
// 如果该字段已经通过表达式字段处理,则跳过,保持表达式计算结果
|
||||
if _, isExpression := s.config.FieldExpressions[k]; !isExpression {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理别名
|
||||
parts := strings.Split(fieldSpec, ":")
|
||||
fieldName := parts[0]
|
||||
|
@ -754,3 +754,219 @@ func TestStreamsqlPersistenceConfigPassing(t *testing.T) {
|
||||
|
||||
t.Logf("持久化配置验证通过: %+v", stats)
|
||||
}
|
||||
|
||||
func TestSelectStarWithExpressionFields(t *testing.T) {
|
||||
config := types.Config{
|
||||
NeedWindow: false,
|
||||
SimpleFields: []string{"*"}, // SELECT *
|
||||
FieldExpressions: map[string]types.FieldExpression{
|
||||
"name": {
|
||||
Expression: "UPPER(name)",
|
||||
Fields: []string{"name"},
|
||||
},
|
||||
"full_info": {
|
||||
Expression: "CONCAT(name, ' - ', status)",
|
||||
Fields: []string{"name", "status"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := NewStream(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create stream: %v", err)
|
||||
}
|
||||
defer stream.Stop()
|
||||
|
||||
// 收集结果 - 使用sync.Mutex防止数据竞争
|
||||
var mu sync.Mutex
|
||||
var results []interface{}
|
||||
stream.AddSink(func(result interface{}) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
results = append(results, result)
|
||||
})
|
||||
|
||||
stream.Start()
|
||||
|
||||
// 添加测试数据
|
||||
testData := map[string]interface{}{
|
||||
"name": "john",
|
||||
"status": "active",
|
||||
"age": 25,
|
||||
}
|
||||
|
||||
stream.AddData(testData)
|
||||
|
||||
// 等待处理完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证结果 - 使用互斥锁保护读取
|
||||
mu.Lock()
|
||||
resultsLen := len(results)
|
||||
var resultData map[string]interface{}
|
||||
if resultsLen > 0 {
|
||||
resultData = results[0].([]map[string]interface{})[0]
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if resultsLen != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", resultsLen)
|
||||
}
|
||||
|
||||
// 验证表达式字段的结果没有被覆盖
|
||||
if resultData["name"] != "JOHN" {
|
||||
t.Errorf("Expected name to be 'JOHN' (uppercase), got %v", resultData["name"])
|
||||
}
|
||||
|
||||
if resultData["full_info"] != "john - active" {
|
||||
t.Errorf("Expected full_info to be 'john - active', got %v", resultData["full_info"])
|
||||
}
|
||||
|
||||
// 验证原始字段仍然存在
|
||||
if resultData["status"] != "active" {
|
||||
t.Errorf("Expected status to be 'active', got %v", resultData["status"])
|
||||
}
|
||||
|
||||
if resultData["age"] != 25 {
|
||||
t.Errorf("Expected age to be 25, got %v", resultData["age"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectStarWithExpressionFieldsOverride(t *testing.T) {
|
||||
// 测试表达式字段名与原始字段名相同的情况
|
||||
config := types.Config{
|
||||
NeedWindow: false,
|
||||
SimpleFields: []string{"*"}, // SELECT *
|
||||
FieldExpressions: map[string]types.FieldExpression{
|
||||
"name": {
|
||||
Expression: "UPPER(name)",
|
||||
Fields: []string{"name"},
|
||||
},
|
||||
"age": {
|
||||
Expression: "age * 2",
|
||||
Fields: []string{"age"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := NewStream(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create stream: %v", err)
|
||||
}
|
||||
defer stream.Stop()
|
||||
|
||||
// 收集结果 - 使用sync.Mutex防止数据竞争
|
||||
var mu sync.Mutex
|
||||
var results []interface{}
|
||||
stream.AddSink(func(result interface{}) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
results = append(results, result)
|
||||
})
|
||||
|
||||
stream.Start()
|
||||
|
||||
// 添加测试数据
|
||||
testData := map[string]interface{}{
|
||||
"name": "alice",
|
||||
"age": 30,
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
stream.AddData(testData)
|
||||
|
||||
// 等待处理完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证结果 - 使用互斥锁保护读取
|
||||
mu.Lock()
|
||||
resultsLen := len(results)
|
||||
var resultData map[string]interface{}
|
||||
if resultsLen > 0 {
|
||||
resultData = results[0].([]map[string]interface{})[0]
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if resultsLen != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", resultsLen)
|
||||
}
|
||||
|
||||
// 验证表达式字段的结果覆盖了原始字段
|
||||
if resultData["name"] != "ALICE" {
|
||||
t.Errorf("Expected name to be 'ALICE' (expression result), got %v", resultData["name"])
|
||||
}
|
||||
|
||||
// 检查age表达式的结果(可能是int或float64类型)
|
||||
ageResult := resultData["age"]
|
||||
if ageResult != 60 && ageResult != 60.0 {
|
||||
t.Errorf("Expected age to be 60 (expression result), got %v (type: %T)", resultData["age"], resultData["age"])
|
||||
}
|
||||
|
||||
// 验证没有表达式的字段保持原值
|
||||
if resultData["status"] != "active" {
|
||||
t.Errorf("Expected status to be 'active', got %v", resultData["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectStarWithoutExpressionFields(t *testing.T) {
|
||||
// 测试没有表达式字段时SELECT *的行为
|
||||
config := types.Config{
|
||||
NeedWindow: false,
|
||||
SimpleFields: []string{"*"}, // SELECT *
|
||||
}
|
||||
|
||||
stream, err := NewStream(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create stream: %v", err)
|
||||
}
|
||||
defer stream.Stop()
|
||||
|
||||
// 收集结果 - 使用sync.Mutex防止数据竞争
|
||||
var mu sync.Mutex
|
||||
var results []interface{}
|
||||
stream.AddSink(func(result interface{}) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
results = append(results, result)
|
||||
})
|
||||
|
||||
stream.Start()
|
||||
|
||||
// 添加测试数据
|
||||
testData := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"age": 35,
|
||||
"status": "inactive",
|
||||
}
|
||||
|
||||
stream.AddData(testData)
|
||||
|
||||
// 等待处理完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证结果 - 使用互斥锁保护读取
|
||||
mu.Lock()
|
||||
resultsLen := len(results)
|
||||
var resultData map[string]interface{}
|
||||
if resultsLen > 0 {
|
||||
resultData = results[0].([]map[string]interface{})[0]
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if resultsLen != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", resultsLen)
|
||||
}
|
||||
|
||||
// 验证所有原始字段都被保留
|
||||
if resultData["name"] != "bob" {
|
||||
t.Errorf("Expected name to be 'bob', got %v", resultData["name"])
|
||||
}
|
||||
|
||||
if resultData["age"] != 35 {
|
||||
t.Errorf("Expected age to be 35, got %v", resultData["age"])
|
||||
}
|
||||
|
||||
if resultData["status"] != "inactive" {
|
||||
t.Errorf("Expected status to be 'inactive', got %v", resultData["status"])
|
||||
}
|
||||
}
|
||||
|
@ -350,7 +350,7 @@ func TestStreamsqlLimit(t *testing.T) {
|
||||
streamsql := New()
|
||||
defer streamsql.Stop()
|
||||
|
||||
var rsql = "SELECT device, temperature FROM stream LIMIT 2"
|
||||
var rsql = "SELECT * FROM stream LIMIT 2"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
@ -2591,3 +2591,286 @@ func TestExprFunctionsWithStreamSQLFunctions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectAllFeature 专门测试SELECT *功能
|
||||
func TestSelectAllFeature(t *testing.T) {
|
||||
// 测试场景1:基本SELECT *查询
|
||||
t.Run("基本SELECT *查询", func(t *testing.T) {
|
||||
streamsql := New()
|
||||
defer streamsql.Stop()
|
||||
|
||||
var rsql = "SELECT * FROM stream"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
|
||||
// 创建结果接收通道
|
||||
resultChan := make(chan interface{}, 10)
|
||||
|
||||
// 添加结果接收器
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
// 添加测试数据
|
||||
testData := map[string]interface{}{
|
||||
"device": "sensor001",
|
||||
"temperature": 25.5,
|
||||
"humidity": 60,
|
||||
"location": "room1",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
// 发送数据
|
||||
strm.AddData(testData)
|
||||
|
||||
// 等待结果
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
// 验证结果
|
||||
resultSlice, ok := result.([]map[string]interface{})
|
||||
require.True(t, ok, "结果应该是[]map[string]interface{}类型")
|
||||
require.Len(t, resultSlice, 1, "应该只有一条结果")
|
||||
|
||||
item := resultSlice[0]
|
||||
// 验证所有原始字段都存在
|
||||
assert.Equal(t, "sensor001", item["device"], "device字段应该正确")
|
||||
assert.Equal(t, 25.5, item["temperature"], "temperature字段应该正确")
|
||||
assert.Equal(t, 60, item["humidity"], "humidity字段应该正确")
|
||||
assert.Equal(t, "room1", item["location"], "location字段应该正确")
|
||||
assert.Equal(t, "active", item["status"], "status字段应该正确")
|
||||
|
||||
// 验证字段数量
|
||||
assert.Len(t, item, 5, "应该包含所有5个字段")
|
||||
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("测试超时,未收到结果")
|
||||
}
|
||||
})
|
||||
|
||||
// 测试场景2:SELECT * + WHERE条件
|
||||
t.Run("SELECT * + WHERE条件", func(t *testing.T) {
|
||||
streamsql := New()
|
||||
defer streamsql.Stop()
|
||||
|
||||
var rsql = "SELECT * FROM stream WHERE temperature > 20"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
|
||||
// 创建结果接收通道
|
||||
resultChan := make(chan interface{}, 10)
|
||||
|
||||
// 添加结果接收器
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
// 添加测试数据
|
||||
testData := []map[string]interface{}{
|
||||
{"device": "sensor1", "temperature": 25.0, "humidity": 60}, // 应该被包含
|
||||
{"device": "sensor2", "temperature": 15.0, "humidity": 70}, // 应该被过滤掉
|
||||
{"device": "sensor3", "temperature": 30.0, "humidity": 50}, // 应该被包含
|
||||
}
|
||||
|
||||
var results []interface{}
|
||||
var resultsMutex sync.Mutex
|
||||
|
||||
// 发送数据
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
|
||||
// 立即检查结果
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
resultsMutex.Lock()
|
||||
results = append(results, result)
|
||||
resultsMutex.Unlock()
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
// 对于不满足条件的数据,超时是正常的
|
||||
}
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
resultsMutex.Lock()
|
||||
finalResultCount := len(results)
|
||||
resultsCopy := make([]interface{}, len(results))
|
||||
copy(resultsCopy, results)
|
||||
resultsMutex.Unlock()
|
||||
|
||||
assert.Equal(t, 2, finalResultCount, "应该有2条记录满足条件")
|
||||
|
||||
// 验证结果内容
|
||||
deviceFound := make(map[string]bool)
|
||||
for _, result := range resultsCopy {
|
||||
resultSlice, ok := result.([]map[string]interface{})
|
||||
require.True(t, ok, "结果应该是[]map[string]interface{}类型")
|
||||
require.Len(t, resultSlice, 1, "每个结果应该只有一条记录")
|
||||
|
||||
item := resultSlice[0]
|
||||
device, _ := item["device"].(string)
|
||||
temp, _ := item["temperature"].(float64)
|
||||
|
||||
// 验证温度条件
|
||||
assert.Greater(t, temp, 20.0, "温度应该大于20")
|
||||
|
||||
// 记录找到的设备
|
||||
deviceFound[device] = true
|
||||
|
||||
// 验证所有字段都存在
|
||||
assert.Contains(t, item, "device", "应该包含device字段")
|
||||
assert.Contains(t, item, "temperature", "应该包含temperature字段")
|
||||
assert.Contains(t, item, "humidity", "应该包含humidity字段")
|
||||
}
|
||||
|
||||
// 验证正确的设备被包含
|
||||
assert.True(t, deviceFound["sensor1"], "sensor1应该被包含")
|
||||
assert.True(t, deviceFound["sensor3"], "sensor3应该被包含")
|
||||
assert.False(t, deviceFound["sensor2"], "sensor2不应该被包含")
|
||||
})
|
||||
|
||||
// 测试场景3:SELECT * + LIMIT
|
||||
t.Run("SELECT * + LIMIT", func(t *testing.T) {
|
||||
streamsql := New()
|
||||
defer streamsql.Stop()
|
||||
|
||||
var rsql = "SELECT * FROM stream LIMIT 2"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
|
||||
// 创建结果接收通道
|
||||
resultChan := make(chan interface{}, 10)
|
||||
|
||||
// 添加结果接收器
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
// 添加测试数据
|
||||
testData := []map[string]interface{}{
|
||||
{"device": "sensor1", "temperature": 25.0},
|
||||
{"device": "sensor2", "temperature": 26.0},
|
||||
{"device": "sensor3", "temperature": 27.0},
|
||||
{"device": "sensor4", "temperature": 28.0},
|
||||
}
|
||||
|
||||
var results []interface{}
|
||||
var resultsMutex sync.Mutex
|
||||
|
||||
// 发送数据
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
|
||||
// 立即检查结果
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
resultsMutex.Lock()
|
||||
results = append(results, result)
|
||||
resultsMutex.Unlock()
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
resultsMutex.Lock()
|
||||
finalResultCount := len(results)
|
||||
resultsCopy := make([]interface{}, len(results))
|
||||
copy(resultsCopy, results)
|
||||
resultsMutex.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, finalResultCount, 2, "应该至少有2条结果")
|
||||
|
||||
// 验证结果内容
|
||||
for _, result := range resultsCopy {
|
||||
resultSlice, ok := result.([]map[string]interface{})
|
||||
require.True(t, ok, "结果应该是[]map[string]interface{}类型")
|
||||
|
||||
// 验证LIMIT限制:每个batch最多2条记录
|
||||
assert.LessOrEqual(t, len(resultSlice), 2, "每个batch最多2条记录")
|
||||
assert.Greater(t, len(resultSlice), 0, "应该有结果")
|
||||
|
||||
// 验证字段
|
||||
for _, item := range resultSlice {
|
||||
assert.Contains(t, item, "device", "结果应包含device字段")
|
||||
assert.Contains(t, item, "temperature", "结果应包含temperature字段")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 测试场景4:SELECT * with嵌套字段
|
||||
t.Run("SELECT * with嵌套字段", func(t *testing.T) {
|
||||
streamsql := New()
|
||||
defer streamsql.Stop()
|
||||
|
||||
var rsql = "SELECT * FROM stream"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
|
||||
// 创建结果接收通道
|
||||
resultChan := make(chan interface{}, 10)
|
||||
|
||||
// 添加结果接收器
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
// 添加带嵌套字段的测试数据
|
||||
testData := map[string]interface{}{
|
||||
"device": "sensor001",
|
||||
"metrics": map[string]interface{}{
|
||||
"temperature": 25.5,
|
||||
"humidity": 60,
|
||||
},
|
||||
"location": map[string]interface{}{
|
||||
"building": "A",
|
||||
"room": "101",
|
||||
},
|
||||
}
|
||||
|
||||
// 发送数据
|
||||
strm.AddData(testData)
|
||||
|
||||
// 等待结果
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
// 验证结果
|
||||
resultSlice, ok := result.([]map[string]interface{})
|
||||
require.True(t, ok, "结果应该是[]map[string]interface{}类型")
|
||||
require.Len(t, resultSlice, 1, "应该只有一条结果")
|
||||
|
||||
item := resultSlice[0]
|
||||
// 验证顶级字段
|
||||
assert.Equal(t, "sensor001", item["device"], "device字段应该正确")
|
||||
|
||||
// 验证嵌套字段结构被保留
|
||||
metrics, ok := item["metrics"].(map[string]interface{})
|
||||
assert.True(t, ok, "metrics应该是map类型")
|
||||
assert.Equal(t, 25.5, metrics["temperature"], "嵌套temperature字段应该正确")
|
||||
assert.Equal(t, 60, metrics["humidity"], "嵌套humidity字段应该正确")
|
||||
|
||||
location, ok := item["location"].(map[string]interface{})
|
||||
assert.True(t, ok, "location应该是map类型")
|
||||
assert.Equal(t, "A", location["building"], "嵌套building字段应该正确")
|
||||
assert.Equal(t, "101", location["room"], "嵌套room字段应该正确")
|
||||
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("测试超时,未收到结果")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user