diff --git a/rsql/ast.go b/rsql/ast.go index b71bd27..c0e9f88 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -17,6 +17,7 @@ import ( type SelectStatement struct { Fields []Field Distinct bool + SelectAll bool // 新增:标识是否是SELECT *查询 Source string Condition string Window WindowDefinition @@ -92,13 +93,18 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { // 如果没有聚合函数,收集简单字段 if !hasAggregation { - for _, field := range s.Fields { - fieldName := field.Expression - if field.Alias != "" { - // 如果有别名,用别名作为字段名 - simpleFields = append(simpleFields, fieldName+":"+field.Alias) - } else { - simpleFields = append(simpleFields, fieldName) + // 如果是SELECT *查询,设置特殊标记 + if s.SelectAll { + simpleFields = append(simpleFields, "*") + } else { + for _, field := range s.Fields { + fieldName := field.Expression + if field.Alias != "" { + // 如果有别名,用别名作为字段名 + simpleFields = append(simpleFields, fieldName+":"+field.Alias) + } else { + simpleFields = append(simpleFields, fieldName) + } } } logger.Debug("收集简单字段: %v", simpleFields) diff --git a/rsql/parser.go b/rsql/parser.go index 34810df..7aaed56 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -212,6 +212,23 @@ func (p *Parser) parseSelect(stmt *SelectStatement) error { currentToken = p.lexer.NextToken() // 消费 DISTINCT,移动到下一个 token } + // 检查是否是SELECT *查询 + if currentToken.Type == TokenIdent && currentToken.Value == "*" { + stmt.SelectAll = true + // 添加一个特殊的字段标记SELECT * + stmt.Fields = append(stmt.Fields, Field{Expression: "*"}) + + // 消费*token并检查下一个token + currentToken = p.lexer.NextToken() + + // 如果下一个token是FROM或EOF,则完成SELECT *解析 + if currentToken.Type == TokenFROM || currentToken.Type == TokenEOF { + return nil + } + + // 如果不是FROM/EOF,继续正常的字段解析流程 + } + // 设置最大字段数量限制,防止无限循环 maxFields := 100 fieldCount := 0 diff --git a/stream/stream.go b/stream/stream.go index 5b8ce2e..3894ac8 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -573,6 +573,18 @@ func (s *Stream) processDirectData(data interface{}) { // 如果指定了字段,只保留这些字段 if len(s.config.SimpleFields) > 0 { for _, fieldSpec := range s.config.SimpleFields { + // 处理SELECT *的特殊情况 + if fieldSpec == "*" { + // SELECT *:返回所有字段,但跳过已经通过表达式字段处理的字段 + for k, v := range dataMap { + // 如果该字段已经通过表达式字段处理,则跳过,保持表达式计算结果 + if _, isExpression := s.config.FieldExpressions[k]; !isExpression { + result[k] = v + } + } + continue + } + // 处理别名 parts := strings.Split(fieldSpec, ":") fieldName := parts[0] diff --git a/stream/stream_test.go b/stream/stream_test.go index 51e2e4a..cda3049 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -754,3 +754,219 @@ func TestStreamsqlPersistenceConfigPassing(t *testing.T) { t.Logf("持久化配置验证通过: %+v", stats) } + +func TestSelectStarWithExpressionFields(t *testing.T) { + config := types.Config{ + NeedWindow: false, + SimpleFields: []string{"*"}, // SELECT * + FieldExpressions: map[string]types.FieldExpression{ + "name": { + Expression: "UPPER(name)", + Fields: []string{"name"}, + }, + "full_info": { + Expression: "CONCAT(name, ' - ', status)", + Fields: []string{"name", "status"}, + }, + }, + } + + stream, err := NewStream(config) + if err != nil { + t.Fatalf("Failed to create stream: %v", err) + } + defer stream.Stop() + + // 收集结果 - 使用sync.Mutex防止数据竞争 + var mu sync.Mutex + var results []interface{} + stream.AddSink(func(result interface{}) { + mu.Lock() + defer mu.Unlock() + results = append(results, result) + }) + + stream.Start() + + // 添加测试数据 + testData := map[string]interface{}{ + "name": "john", + "status": "active", + "age": 25, + } + + stream.AddData(testData) + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证结果 - 使用互斥锁保护读取 + mu.Lock() + resultsLen := len(results) + var resultData map[string]interface{} + if resultsLen > 0 { + resultData = results[0].([]map[string]interface{})[0] + } + mu.Unlock() + + if resultsLen != 1 { + t.Fatalf("Expected 1 result, got %d", resultsLen) + } + + // 验证表达式字段的结果没有被覆盖 + if resultData["name"] != "JOHN" { + t.Errorf("Expected name to be 'JOHN' (uppercase), got %v", resultData["name"]) + } + + if resultData["full_info"] != "john - active" { + t.Errorf("Expected full_info to be 'john - active', got %v", resultData["full_info"]) + } + + // 验证原始字段仍然存在 + if resultData["status"] != "active" { + t.Errorf("Expected status to be 'active', got %v", resultData["status"]) + } + + if resultData["age"] != 25 { + t.Errorf("Expected age to be 25, got %v", resultData["age"]) + } +} + +func TestSelectStarWithExpressionFieldsOverride(t *testing.T) { + // 测试表达式字段名与原始字段名相同的情况 + config := types.Config{ + NeedWindow: false, + SimpleFields: []string{"*"}, // SELECT * + FieldExpressions: map[string]types.FieldExpression{ + "name": { + Expression: "UPPER(name)", + Fields: []string{"name"}, + }, + "age": { + Expression: "age * 2", + Fields: []string{"age"}, + }, + }, + } + + stream, err := NewStream(config) + if err != nil { + t.Fatalf("Failed to create stream: %v", err) + } + defer stream.Stop() + + // 收集结果 - 使用sync.Mutex防止数据竞争 + var mu sync.Mutex + var results []interface{} + stream.AddSink(func(result interface{}) { + mu.Lock() + defer mu.Unlock() + results = append(results, result) + }) + + stream.Start() + + // 添加测试数据 + testData := map[string]interface{}{ + "name": "alice", + "age": 30, + "status": "active", + } + + stream.AddData(testData) + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证结果 - 使用互斥锁保护读取 + mu.Lock() + resultsLen := len(results) + var resultData map[string]interface{} + if resultsLen > 0 { + resultData = results[0].([]map[string]interface{})[0] + } + mu.Unlock() + + if resultsLen != 1 { + t.Fatalf("Expected 1 result, got %d", resultsLen) + } + + // 验证表达式字段的结果覆盖了原始字段 + if resultData["name"] != "ALICE" { + t.Errorf("Expected name to be 'ALICE' (expression result), got %v", resultData["name"]) + } + + // 检查age表达式的结果(可能是int或float64类型) + ageResult := resultData["age"] + if ageResult != 60 && ageResult != 60.0 { + t.Errorf("Expected age to be 60 (expression result), got %v (type: %T)", resultData["age"], resultData["age"]) + } + + // 验证没有表达式的字段保持原值 + if resultData["status"] != "active" { + t.Errorf("Expected status to be 'active', got %v", resultData["status"]) + } +} + +func TestSelectStarWithoutExpressionFields(t *testing.T) { + // 测试没有表达式字段时SELECT *的行为 + config := types.Config{ + NeedWindow: false, + SimpleFields: []string{"*"}, // SELECT * + } + + stream, err := NewStream(config) + if err != nil { + t.Fatalf("Failed to create stream: %v", err) + } + defer stream.Stop() + + // 收集结果 - 使用sync.Mutex防止数据竞争 + var mu sync.Mutex + var results []interface{} + stream.AddSink(func(result interface{}) { + mu.Lock() + defer mu.Unlock() + results = append(results, result) + }) + + stream.Start() + + // 添加测试数据 + testData := map[string]interface{}{ + "name": "bob", + "age": 35, + "status": "inactive", + } + + stream.AddData(testData) + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证结果 - 使用互斥锁保护读取 + mu.Lock() + resultsLen := len(results) + var resultData map[string]interface{} + if resultsLen > 0 { + resultData = results[0].([]map[string]interface{})[0] + } + mu.Unlock() + + if resultsLen != 1 { + t.Fatalf("Expected 1 result, got %d", resultsLen) + } + + // 验证所有原始字段都被保留 + if resultData["name"] != "bob" { + t.Errorf("Expected name to be 'bob', got %v", resultData["name"]) + } + + if resultData["age"] != 35 { + t.Errorf("Expected age to be 35, got %v", resultData["age"]) + } + + if resultData["status"] != "inactive" { + t.Errorf("Expected status to be 'inactive', got %v", resultData["status"]) + } +} diff --git a/streamsql_test.go b/streamsql_test.go index f2c0aa3..94eec1a 100644 --- a/streamsql_test.go +++ b/streamsql_test.go @@ -350,7 +350,7 @@ func TestStreamsqlLimit(t *testing.T) { streamsql := New() defer streamsql.Stop() - var rsql = "SELECT device, temperature FROM stream LIMIT 2" + var rsql = "SELECT * FROM stream LIMIT 2" err := streamsql.Execute(rsql) assert.Nil(t, err) strm := streamsql.stream @@ -2591,3 +2591,286 @@ func TestExprFunctionsWithStreamSQLFunctions(t *testing.T) { } } } + +// TestSelectAllFeature 专门测试SELECT *功能 +func TestSelectAllFeature(t *testing.T) { + // 测试场景1:基本SELECT *查询 + t.Run("基本SELECT *查询", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := map[string]interface{}{ + "device": "sensor001", + "temperature": 25.5, + "humidity": 60, + "location": "room1", + "status": "active", + } + + // 发送数据 + strm.AddData(testData) + + // 等待结果 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + select { + case result := <-resultChan: + // 验证结果 + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + require.Len(t, resultSlice, 1, "应该只有一条结果") + + item := resultSlice[0] + // 验证所有原始字段都存在 + assert.Equal(t, "sensor001", item["device"], "device字段应该正确") + assert.Equal(t, 25.5, item["temperature"], "temperature字段应该正确") + assert.Equal(t, 60, item["humidity"], "humidity字段应该正确") + assert.Equal(t, "room1", item["location"], "location字段应该正确") + assert.Equal(t, "active", item["status"], "status字段应该正确") + + // 验证字段数量 + assert.Len(t, item, 5, "应该包含所有5个字段") + + cancel() + case <-ctx.Done(): + t.Fatal("测试超时,未收到结果") + } + }) + + // 测试场景2:SELECT * + WHERE条件 + t.Run("SELECT * + WHERE条件", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream WHERE temperature > 20" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"device": "sensor1", "temperature": 25.0, "humidity": 60}, // 应该被包含 + {"device": "sensor2", "temperature": 15.0, "humidity": 70}, // 应该被过滤掉 + {"device": "sensor3", "temperature": 30.0, "humidity": 50}, // 应该被包含 + } + + var results []interface{} + var resultsMutex sync.Mutex + + // 发送数据 + for _, data := range testData { + strm.AddData(data) + + // 立即检查结果 + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + select { + case result := <-resultChan: + resultsMutex.Lock() + results = append(results, result) + resultsMutex.Unlock() + cancel() + case <-ctx.Done(): + cancel() + // 对于不满足条件的数据,超时是正常的 + } + } + + // 验证结果 + resultsMutex.Lock() + finalResultCount := len(results) + resultsCopy := make([]interface{}, len(results)) + copy(resultsCopy, results) + resultsMutex.Unlock() + + assert.Equal(t, 2, finalResultCount, "应该有2条记录满足条件") + + // 验证结果内容 + deviceFound := make(map[string]bool) + for _, result := range resultsCopy { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + require.Len(t, resultSlice, 1, "每个结果应该只有一条记录") + + item := resultSlice[0] + device, _ := item["device"].(string) + temp, _ := item["temperature"].(float64) + + // 验证温度条件 + assert.Greater(t, temp, 20.0, "温度应该大于20") + + // 记录找到的设备 + deviceFound[device] = true + + // 验证所有字段都存在 + assert.Contains(t, item, "device", "应该包含device字段") + assert.Contains(t, item, "temperature", "应该包含temperature字段") + assert.Contains(t, item, "humidity", "应该包含humidity字段") + } + + // 验证正确的设备被包含 + assert.True(t, deviceFound["sensor1"], "sensor1应该被包含") + assert.True(t, deviceFound["sensor3"], "sensor3应该被包含") + assert.False(t, deviceFound["sensor2"], "sensor2不应该被包含") + }) + + // 测试场景3:SELECT * + LIMIT + t.Run("SELECT * + LIMIT", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream LIMIT 2" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"device": "sensor1", "temperature": 25.0}, + {"device": "sensor2", "temperature": 26.0}, + {"device": "sensor3", "temperature": 27.0}, + {"device": "sensor4", "temperature": 28.0}, + } + + var results []interface{} + var resultsMutex sync.Mutex + + // 发送数据 + for _, data := range testData { + strm.AddData(data) + + // 立即检查结果 + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + select { + case result := <-resultChan: + resultsMutex.Lock() + results = append(results, result) + resultsMutex.Unlock() + cancel() + case <-ctx.Done(): + cancel() + } + } + + // 验证结果 + resultsMutex.Lock() + finalResultCount := len(results) + resultsCopy := make([]interface{}, len(results)) + copy(resultsCopy, results) + resultsMutex.Unlock() + + assert.GreaterOrEqual(t, finalResultCount, 2, "应该至少有2条结果") + + // 验证结果内容 + for _, result := range resultsCopy { + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + + // 验证LIMIT限制:每个batch最多2条记录 + assert.LessOrEqual(t, len(resultSlice), 2, "每个batch最多2条记录") + assert.Greater(t, len(resultSlice), 0, "应该有结果") + + // 验证字段 + for _, item := range resultSlice { + assert.Contains(t, item, "device", "结果应包含device字段") + assert.Contains(t, item, "temperature", "结果应包含temperature字段") + } + } + }) + + // 测试场景4:SELECT * with嵌套字段 + t.Run("SELECT * with嵌套字段", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() + + var rsql = "SELECT * FROM stream" + err := streamsql.Execute(rsql) + assert.Nil(t, err) + strm := streamsql.stream + + // 创建结果接收通道 + resultChan := make(chan interface{}, 10) + + // 添加结果接收器 + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + + // 添加带嵌套字段的测试数据 + testData := map[string]interface{}{ + "device": "sensor001", + "metrics": map[string]interface{}{ + "temperature": 25.5, + "humidity": 60, + }, + "location": map[string]interface{}{ + "building": "A", + "room": "101", + }, + } + + // 发送数据 + strm.AddData(testData) + + // 等待结果 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + select { + case result := <-resultChan: + // 验证结果 + resultSlice, ok := result.([]map[string]interface{}) + require.True(t, ok, "结果应该是[]map[string]interface{}类型") + require.Len(t, resultSlice, 1, "应该只有一条结果") + + item := resultSlice[0] + // 验证顶级字段 + assert.Equal(t, "sensor001", item["device"], "device字段应该正确") + + // 验证嵌套字段结构被保留 + metrics, ok := item["metrics"].(map[string]interface{}) + assert.True(t, ok, "metrics应该是map类型") + assert.Equal(t, 25.5, metrics["temperature"], "嵌套temperature字段应该正确") + assert.Equal(t, 60, metrics["humidity"], "嵌套humidity字段应该正确") + + location, ok := item["location"].(map[string]interface{}) + assert.True(t, ok, "location应该是map类型") + assert.Equal(t, "A", location["building"], "嵌套building字段应该正确") + assert.Equal(t, "101", location["room"], "嵌套room字段应该正确") + + cancel() + case <-ctx.Done(): + t.Fatal("测试超时,未收到结果") + } + }) +}