mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-10 09:41:37 +00:00
Merge pull request #1 from dimon-83/main
feat: add RWMutex to GroupAggregator
This commit is contained in:
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Aggregator interface {
|
||||
@ -17,6 +18,7 @@ type GroupAggregator struct {
|
||||
groupFields []string
|
||||
aggregators map[string]AggregatorFunction
|
||||
groups map[string]map[string]AggregatorFunction
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator {
|
||||
@ -34,6 +36,8 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType)
|
||||
}
|
||||
}
|
||||
func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
ga.mu.Lock() // 获取写锁
|
||||
defer ga.mu.Unlock() // 确保函数返回时释放锁
|
||||
var v reflect.Value
|
||||
|
||||
switch data.(type) {
|
||||
@ -85,9 +89,13 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
|
||||
if _, exists := ga.groups[key]; !exists {
|
||||
ga.groups[key] = make(map[string]AggregatorFunction)
|
||||
for field, agg := range ga.aggregators {
|
||||
}
|
||||
// field级别的聚合可以分批创建
|
||||
for field, agg := range ga.aggregators {
|
||||
if _, exists := ga.groups[key][field]; !exists {
|
||||
// 创建新的聚合器实例
|
||||
ga.groups[key][field] = agg.New()
|
||||
//fmt.Printf("groups by %s : %v \n", key, ga.groups[key])
|
||||
}
|
||||
}
|
||||
|
||||
@ -104,14 +112,15 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
}
|
||||
|
||||
if !f.IsValid() {
|
||||
return fmt.Errorf("field %s not found", field)
|
||||
//return fmt.Errorf("field %s not found", field)
|
||||
//fmt.Printf("field %s not found in %v \n ", field, data)
|
||||
continue
|
||||
}
|
||||
|
||||
fieldVal := f.Interface()
|
||||
var value float64
|
||||
switch vType := fieldVal.(type) {
|
||||
case float64:
|
||||
|
||||
value = vType
|
||||
case int, int32, int64:
|
||||
value = float64(vType.(int))
|
||||
@ -122,6 +131,9 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
}
|
||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
||||
groupAgg.Add(value)
|
||||
//fmt.Printf("add agg group by %s:%s , %v \n", key, field, value)
|
||||
} else {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -129,6 +141,8 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
}
|
||||
|
||||
func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
|
||||
ga.mu.RLock() // 获取读锁,允许并发读取
|
||||
defer ga.mu.RUnlock() // 确保函数返回时释放锁
|
||||
result := make([]map[string]interface{}, 0, len(ga.groups))
|
||||
for key, aggregators := range ga.groups {
|
||||
group := make(map[string]interface{})
|
||||
@ -145,5 +159,7 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
|
||||
}
|
||||
|
||||
func (ga *GroupAggregator) Reset() {
|
||||
ga.mu.Lock() // 获取写锁
|
||||
defer ga.mu.Unlock() // 确保函数返回时释放锁
|
||||
ga.groups = make(map[string]map[string]AggregatorFunction)
|
||||
}
|
||||
|
@ -71,6 +71,7 @@ func (s *Stream) process() {
|
||||
case data := <-s.dataChan:
|
||||
if s.filter == nil || s.filter.Evaluate(data) {
|
||||
s.Window.Add(data)
|
||||
// fmt.Printf("add data to win : %v \n", data)
|
||||
}
|
||||
case batch := <-s.Window.OutputChan():
|
||||
// 处理窗口批数据
|
||||
|
@ -156,3 +156,69 @@ func TestStreamWithoutFilter(t *testing.T) {
|
||||
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncompleteStreamProcess(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{"size": time.Second},
|
||||
},
|
||||
GroupFields: []string{"device"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"age": aggregator.Avg,
|
||||
"score": aggregator.Sum,
|
||||
},
|
||||
}
|
||||
|
||||
strm, err := NewStream(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = strm.RegisterFilter("device == 'aa' && age > 10")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 添加 Sink 函数来捕获结果
|
||||
resultChan := make(chan interface{})
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
|
||||
strm.Start()
|
||||
|
||||
// 准备测试数据
|
||||
testData := []interface{}{
|
||||
map[string]interface{}{"device": "aa", "age": 15.0},
|
||||
map[string]interface{}{"device": "aa", "score": 100},
|
||||
map[string]interface{}{"device": "aa", "age": 20.0},
|
||||
map[string]interface{}{"device": "aa", "score": 200},
|
||||
map[string]interface{}{"device": "bb", "age": 25.0, "score": 300},
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
}
|
||||
|
||||
// 等待结果,并设置超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var actual interface{}
|
||||
select {
|
||||
case actual = <-strm.GetResultsChan():
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("No results received within 5 seconds")
|
||||
}
|
||||
|
||||
// 预期结果:只有 device='aa' 且 age>10 的数据会被聚合
|
||||
expected := map[string]interface{}{
|
||||
"device": "aa",
|
||||
"age_avg": 17.5, // (15+20)/2
|
||||
"score_sum": 300.0, // 100+200
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
assert.IsType(t, []map[string]interface{}{}, actual)
|
||||
resultMap := actual.([]map[string]interface{})
|
||||
assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001)
|
||||
assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001)
|
||||
}
|
||||
|
Reference in New Issue
Block a user