Merge pull request #1 from dimon-83/main

feat: add RWMutex to GroupAggregator
This commit is contained in:
Whki
2025-03-13 11:53:59 +08:00
committed by GitHub
3 changed files with 86 additions and 3 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
)
type Aggregator interface {
@ -17,6 +18,7 @@ type GroupAggregator struct {
groupFields []string
aggregators map[string]AggregatorFunction
groups map[string]map[string]AggregatorFunction
mu sync.RWMutex
}
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator {
@ -34,6 +36,8 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType)
}
}
func (ga *GroupAggregator) Add(data interface{}) error {
ga.mu.Lock() // 获取写锁
defer ga.mu.Unlock() // 确保函数返回时释放锁
var v reflect.Value
switch data.(type) {
@ -85,9 +89,13 @@ func (ga *GroupAggregator) Add(data interface{}) error {
if _, exists := ga.groups[key]; !exists {
ga.groups[key] = make(map[string]AggregatorFunction)
for field, agg := range ga.aggregators {
}
// field级别的聚合可以分批创建
for field, agg := range ga.aggregators {
if _, exists := ga.groups[key][field]; !exists {
// 创建新的聚合器实例
ga.groups[key][field] = agg.New()
//fmt.Printf("groups by %s : %v \n", key, ga.groups[key])
}
}
@ -104,14 +112,15 @@ func (ga *GroupAggregator) Add(data interface{}) error {
}
if !f.IsValid() {
return fmt.Errorf("field %s not found", field)
//return fmt.Errorf("field %s not found", field)
//fmt.Printf("field %s not found in %v \n ", field, data)
continue
}
fieldVal := f.Interface()
var value float64
switch vType := fieldVal.(type) {
case float64:
value = vType
case int, int32, int64:
value = float64(vType.(int))
@ -122,6 +131,9 @@ func (ga *GroupAggregator) Add(data interface{}) error {
}
if groupAgg, exists := ga.groups[key][field]; exists {
groupAgg.Add(value)
//fmt.Printf("add agg group by %s:%s , %v \n", key, field, value)
} else {
}
}
@ -129,6 +141,8 @@ func (ga *GroupAggregator) Add(data interface{}) error {
}
func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
ga.mu.RLock() // 获取读锁,允许并发读取
defer ga.mu.RUnlock() // 确保函数返回时释放锁
result := make([]map[string]interface{}, 0, len(ga.groups))
for key, aggregators := range ga.groups {
group := make(map[string]interface{})
@ -145,5 +159,7 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
}
// Reset discards every accumulated group by swapping the groups map for a
// fresh, empty one. The swap happens while holding the write lock on ga.mu.
func (ga *GroupAggregator) Reset() {
	// Build the replacement before locking so the critical section is only
	// the pointer swap.
	empty := make(map[string]map[string]AggregatorFunction)

	ga.mu.Lock()
	ga.groups = empty
	ga.mu.Unlock()
}

View File

@ -71,6 +71,7 @@ func (s *Stream) process() {
case data := <-s.dataChan:
if s.filter == nil || s.filter.Evaluate(data) {
s.Window.Add(data)
// fmt.Printf("add data to win : %v \n", data)
}
case batch := <-s.Window.OutputChan():
// 处理窗口批数据

View File

@ -156,3 +156,69 @@ func TestStreamWithoutFilter(t *testing.T) {
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
}
}
// TestIncompleteStreamProcess feeds the stream records that each carry only
// part of the aggregated fields (some have "age", some have "score") and
// checks that the first emitted window result still contains both the age
// average and the score sum for the group.
func TestIncompleteStreamProcess(t *testing.T) {
	config := Config{
		WindowConfig: WindowConfig{
			Type:   "tumbling",
			Params: map[string]interface{}{"size": time.Second},
		},
		GroupFields: []string{"device"},
		SelectFields: map[string]aggregator.AggregateType{
			"age":   aggregator.Avg,
			"score": aggregator.Sum,
		},
	}
	strm, err := NewStream(config)
	require.NoError(t, err)
	err = strm.RegisterFilter("device == 'aa' && age > 10")
	require.NoError(t, err)

	// Register a sink so the sink path is exercised. The send is non-blocking
	// on a buffered channel: this test reads its result from GetResultsChan,
	// so a blocking send here would never be received and would stall
	// whichever goroutine invokes the sink.
	sinkChan := make(chan interface{}, 1)
	strm.AddSink(func(result interface{}) {
		select {
		case sinkChan <- result:
		default:
		}
	})
	strm.Start()

	// Test data: records deliberately omit one of the aggregated fields.
	testData := []interface{}{
		map[string]interface{}{"device": "aa", "age": 15.0},
		map[string]interface{}{"device": "aa", "score": 100},
		map[string]interface{}{"device": "aa", "age": 20.0},
		map[string]interface{}{"device": "aa", "score": 200},
		map[string]interface{}{"device": "bb", "age": 25.0, "score": 300},
	}
	for _, data := range testData {
		strm.AddData(data)
	}

	// Wait for the first result, bounded by a timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	var actual interface{}
	select {
	case actual = <-strm.GetResultsChan():
	case <-ctx.Done():
		t.Fatal("No results received within 5 seconds")
	}

	// Expected result: only the data with device='aa' is aggregated.
	expected := map[string]interface{}{
		"device":    "aa",
		"age_avg":   17.5,  // (15+20)/2
		"score_sum": 300.0, // 100+200
	}

	// Validate the shape with require so that a wrong type, an empty result
	// or a missing field fails the test instead of panicking on the
	// assertions below.
	resultMap, ok := actual.([]map[string]interface{})
	require.True(t, ok, "unexpected result type %T", actual)
	require.NotEmpty(t, resultMap)

	first := resultMap[0]
	ageAvg, ok := first["age_avg"].(float64)
	require.True(t, ok, "age_avg missing or not a float64: %v", first["age_avg"])
	scoreSum, ok := first["score_sum"].(float64)
	require.True(t, ok, "score_sum missing or not a float64: %v", first["score_sum"])

	assert.Equal(t, expected["device"], first["device"])
	assert.InEpsilon(t, expected["age_avg"].(float64), ageAvg, 0.0001)
	assert.InDelta(t, expected["score_sum"].(float64), scoreSum, 0.0001)
}