From 6fc3e345277306480030f550b676e6f9a0f42039 Mon Sep 17 00:00:00 2001 From: dimon Date: Thu, 3 Apr 2025 17:33:02 +0800 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=20WITH=20?= =?UTF-8?q?=E5=AD=90=E5=8F=A5=E5=8A=9F=E8=83=BD=20-=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E9=80=9A=E8=BF=87=20TIMESTAMP=20=E6=8C=87=E5=AE=9A=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E6=88=B3=E5=B1=9E=E6=80=A7=E5=90=8D=20-=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E9=80=9A=E8=BF=87=20TIMEUNIT=20=E6=8C=87=E5=AE=9A?= =?UTF-8?q?=E6=97=B6=E9=97=B4=E5=8D=95=E4=BD=8D=EF=BC=88ms/ss/mm/hh/dd?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/model.go | 19 ++++++++ model/row.go | 19 ++++++++ rsql/ast.go | 23 ++++++---- rsql/lexer.go | 9 ++++ rsql/parser.go | 81 ++++++++++++++++++++++++++++++++-- rsql/parser_test.go | 35 ++++++++++++--- stream/stream.go | 20 +++------ stream/stream_test.go | 13 +++--- window/counting_window.go | 20 +++++++-- window/counting_window_test.go | 20 ++++++--- window/factory.go | 74 +++++++++++++++++++------------ window/sliding_window.go | 41 ++++++++++------- window/sliding_window_test.go | 59 +++++++++++++++++++++---- window/tumbling_window.go | 15 ++++++- window/tumbling_window_test.go | 9 +++- 15 files changed, 352 insertions(+), 105 deletions(-) create mode 100644 model/model.go create mode 100644 model/row.go diff --git a/model/model.go b/model/model.go new file mode 100644 index 0000000..869674b --- /dev/null +++ b/model/model.go @@ -0,0 +1,19 @@ +package model + +import ( + "time" + + aggregator2 "github.com/rulego/streamsql/aggregator" +) + +type Config struct { + WindowConfig WindowConfig + GroupFields []string + SelectFields map[string]aggregator2.AggregateType +} +type WindowConfig struct { + Type string + Params map[string]interface{} + TsProp string + TimeUnit time.Duration +} diff --git a/model/row.go b/model/row.go new file mode 100644 index 0000000..f724296 --- /dev/null +++ b/model/row.go @@ -0,0 +1,19 @@ +package model + +import ( + "time" +) + +type RowEvent interface { + GetTimestamp() time.Time +} + +type Row struct { + Timestamp time.Time + Data interface{} +} + +// GetTimestamp 获取时间戳 +func (r *Row) GetTimestamp() time.Time { + return r.Timestamp +} diff --git a/rsql/ast.go b/rsql/ast.go index 99cffa4..c028afb 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -2,12 +2,13 @@ package rsql import ( "fmt" - "github.com/rulego/streamsql/window" "strings" "time" + "github.com/rulego/streamsql/model" + "github.com/rulego/streamsql/window" + "github.com/rulego/streamsql/aggregator" - "github.com/rulego/streamsql/stream" ) type SelectStatement struct { @@ -24,12 +25,14 @@ type Field struct { } type WindowDefinition struct { - Type string - Params []interface{} + Type string + Params []interface{} + TsProp string + TimeUnit time.Duration } // ToStreamConfig 将AST转换为Stream配置 -func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) { +func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) { if s.Source == "" { return nil, "", fmt.Errorf("missing FROM clause") } @@ -51,10 +54,12 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) { } // 构建Stream配置 - config := stream.Config{ - WindowConfig: stream.WindowConfig{ - Type: windowType, - Params: params, + config := model.Config{ + WindowConfig: model.WindowConfig{ + Type: windowType, + Params: params, + TsProp: s.Window.TsProp, + TimeUnit: s.Window.TimeUnit, }, GroupFields: extractGroupFields(s), SelectFields: buildSelectFields(s.Fields), diff --git a/rsql/lexer.go b/rsql/lexer.go index 7ade707..299b6aa 100644 --- a/rsql/lexer.go +++ b/rsql/lexer.go @@ -34,6 +34,9 @@ const ( TokenSliding TokenCounting TokenSession + TokenWITH + TokenTimestamp + TokenTimeUnit ) type Token struct { @@ -204,6 +207,12 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenCounting, Value: ident} case "SESSIONWINDOW": return Token{Type: TokenSession, Value: ident} + case "WITH": + return Token{Type: TokenWITH, Value: ident} + case "TIMESTAMP": + return Token{Type: TokenTimestamp, Value: ident} + case "TIMEUNIT": + return Token{Type: TokenTimeUnit, Value: ident} default: return Token{Type: TokenIdent, Value: ident} } diff --git a/rsql/parser.go b/rsql/parser.go index c3696e3..01a451c 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -4,6 +4,7 @@ import ( "errors" "strconv" "strings" + "time" ) type Parser struct { @@ -39,6 +40,10 @@ func (p *Parser) Parse() (*SelectStatement, error) { return nil, err } + if err := p.parseWith(stmt); err != nil { + return nil, err + } + return stmt, nil } func (p *Parser) parseSelect(stmt *SelectStatement) error { @@ -125,9 +130,14 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro params = append(params, convertValue(valTok.Value)) } - stmt.Window = WindowDefinition{ - Type: winType, - Params: params, + if &stmt.Window != nil { + stmt.Window.Params = params + stmt.Window.Type = winType + } else { + stmt.Window = WindowDefinition{ + Type: winType, + Params: params, + } } return nil } @@ -185,3 +195,68 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error { } return nil } + +func (p *Parser) parseWith(stmt *SelectStatement) error { + p.lexer.NextToken() // 跳过( + for p.lexer.peekChar() != ')' { + valTok := p.lexer.NextToken() + if valTok.Type == TokenRParen || valTok.Type == TokenEOF { + break + } + if valTok.Type == TokenComma { + continue + } + + if valTok.Type == TokenTimestamp { + next := p.lexer.NextToken() + if next.Type == TokenEQ { + next = p.lexer.NextToken() + if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") { + next.Value = strings.Trim(next.Value, "'") + } + // 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition + if stmt.Window.Type == "" { + stmt.Window = WindowDefinition{ + TsProp: next.Value, + } + } else { + stmt.Window.TsProp = next.Value + } + } + } + if valTok.Type == TokenTimeUnit { + timeUnit := time.Minute + next := p.lexer.NextToken() + if next.Type == TokenEQ { + next = p.lexer.NextToken() + if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") { + next.Value = strings.Trim(next.Value, "'") + } + switch next.Value { + case "dd": + timeUnit = 24 * time.Hour + case "hh": + timeUnit = time.Hour + case "mi": + timeUnit = time.Minute + case "ss": + timeUnit = time.Second + case "ms": + timeUnit = time.Millisecond + default: + + } + // 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition + if stmt.Window.Type == "" { + stmt.Window = WindowDefinition{ + TimeUnit: timeUnit, + } + } else { + stmt.Window.TimeUnit = timeUnit + } + } + } + } + + return nil +} diff --git a/rsql/parser_test.go b/rsql/parser_test.go index 2027545..7fa96c1 100644 --- a/rsql/parser_test.go +++ b/rsql/parser_test.go @@ -1,24 +1,25 @@ package rsql import ( - "github.com/rulego/streamsql/aggregator" "testing" "time" - "github.com/rulego/streamsql/stream" + "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/assert" ) func TestParseSQL(t *testing.T) { tests := []struct { sql string - expected *stream.Config + expected *model.Config condition string }{ { sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')", - expected: &stream.Config{ - WindowConfig: stream.WindowConfig{ + expected: &model.Config{ + WindowConfig: model.WindowConfig{ Type: "tumbling", Params: map[string]interface{}{ "size": 10 * time.Second, @@ -33,8 +34,8 @@ func TestParseSQL(t *testing.T) { }, { sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')", - expected: &stream.Config{ - WindowConfig: stream.WindowConfig{ + expected: &model.Config{ + WindowConfig: model.WindowConfig{ Type: "sliding", Params: map[string]interface{}{ "size": 20 * time.Second, @@ -49,6 +50,23 @@ func TestParseSQL(t *testing.T) { }, condition: "", }, + { + sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ", + expected: &model.Config{ + WindowConfig: model.WindowConfig{ + Type: "tumbling", + Params: map[string]interface{}{ + "size": 10 * time.Second, + }, + TsProp: "ts", + }, + GroupFields: []string{"deviceId"}, + SelectFields: map[string]aggregator.AggregateType{ + "aa": "avg", + }, + }, + condition: "deviceId == 'aa'", + }, } for _, tt := range tests { @@ -64,6 +82,9 @@ func TestParseSQL(t *testing.T) { assert.Equal(t, tt.expected.GroupFields, config.GroupFields) assert.Equal(t, tt.expected.SelectFields, config.SelectFields) assert.Equal(t, tt.condition, cond) + if tt.expected.WindowConfig.TsProp != "" { + assert.Equal(t, tt.expected.WindowConfig.TsProp, config.WindowConfig.TsProp) + } } } func TestWindowParamParsing(t *testing.T) { diff --git a/stream/stream.go b/stream/stream.go index 5cf2dcc..c5d6faa 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -5,33 +5,23 @@ import ( "time" aggregator2 "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/model" "github.com/rulego/streamsql/parser" "github.com/rulego/streamsql/window" ) -type Config struct { - WindowConfig WindowConfig - GroupFields []string - SelectFields map[string]aggregator2.AggregateType -} - -type WindowConfig struct { - Type string - Params map[string]interface{} -} - type Stream struct { dataChan chan interface{} filter parser.Condition Window window.Window aggregator aggregator2.Aggregator - config Config + config model.Config sinks []func(interface{}) resultChan chan interface{} // 结果通道 } -func NewStream(config Config) (*Stream, error) { - win, err := window.CreateWindow(config.WindowConfig.Type, config.WindowConfig.Params) +func NewStream(config model.Config) (*Stream, error) { + win, err := window.CreateWindow(config.WindowConfig) if err != nil { return nil, err } @@ -108,5 +98,5 @@ func (s *Stream) GetResultsChan() <-chan interface{} { } func NewStreamProcessor() (*Stream, error) { - return NewStream(Config{}) + return NewStream(model.Config{}) } diff --git a/stream/stream_test.go b/stream/stream_test.go index 8655778..b6e0d94 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -7,13 +7,14 @@ import ( "time" "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestStreamProcess(t *testing.T) { - config := Config{ - WindowConfig: WindowConfig{ + config := model.Config{ + WindowConfig: model.WindowConfig{ Type: "tumbling", Params: map[string]interface{}{"size": time.Second}, }, @@ -77,8 +78,8 @@ func TestStreamProcess(t *testing.T) { // 不设置过滤器 func TestStreamWithoutFilter(t *testing.T) { - config := Config{ - WindowConfig: WindowConfig{ + config := model.Config{ + WindowConfig: model.WindowConfig{ Type: "sliding", Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second}, }, @@ -158,8 +159,8 @@ func TestStreamWithoutFilter(t *testing.T) { } func TestIncompleteStreamProcess(t *testing.T) { - config := Config{ - WindowConfig: WindowConfig{ + config := model.Config{ + WindowConfig: model.WindowConfig{ Type: "tumbling", Params: map[string]interface{}{"size": time.Second}, }, diff --git a/window/counting_window.go b/window/counting_window.go index b6f2837..424f29f 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -2,13 +2,18 @@ package window import ( "context" + "fmt" "sync" "time" + + "github.com/rulego/streamsql/model" + "github.com/spf13/cast" ) var _ Window = (*CountingWindow)(nil) type CountingWindow struct { + config model.WindowConfig threshold int count int mu sync.Mutex @@ -21,17 +26,26 @@ type CountingWindow struct { triggerChan chan struct{} } -func NewCountingWindow(threshold int, callback func([]interface{})) *CountingWindow { +func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) { ctx, cancel := context.WithCancel(context.Background()) - return &CountingWindow{ + threshold := cast.ToInt(config.Params["count"]) + if threshold <= 0 { + return nil, fmt.Errorf("threshold must be a positive integer") + } + + cw := &CountingWindow{ threshold: threshold, dataBuffer: make([]interface{}, 0, threshold), outputChan: make(chan []interface{}, 10), ctx: ctx, cancelFunc: cancel, - callback: callback, triggerChan: make(chan struct{}, 1), } + + if callback, ok := config.Params["callback"].(func([]interface{})); ok { + cw.SetCallback(callback) + } + return cw, nil } func (cw *CountingWindow) Add(data interface{}) { diff --git a/window/counting_window_test.go b/window/counting_window_test.go index e1b121a..50ef2bb 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -2,10 +2,12 @@ package window import ( "context" - "github.com/stretchr/testify/require" "testing" "time" + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) @@ -14,8 +16,13 @@ func TestCountingWindow(t *testing.T) { defer cancel() // Test case 1: Normal operation - cw := NewCountingWindow(3, func(results []interface{}) { - t.Logf("Received results: %v", results) + cw, _ := NewCountingWindow(model.WindowConfig{ + Params: map[string]interface{}{ + "count": 3, + "callback": func(results []interface{}) { + t.Logf("Received results: %v", results) + }, + }, }) go cw.Start() @@ -50,8 +57,11 @@ func TestCountingWindow(t *testing.T) { } func TestCountingWindowBadThreshold(t *testing.T) { - _, err := CreateWindow("counting", map[string]interface{}{ - "count": 0, + _, err := CreateWindow(model.WindowConfig{ + Type: "counting", + Params: map[string]interface{}{ + "count": 0, + }, }) require.Error(t, err) } diff --git a/window/factory.go b/window/factory.go index d58cce9..d3cccce 100644 --- a/window/factory.go +++ b/window/factory.go @@ -2,7 +2,10 @@ package window import ( "fmt" - "github.com/spf13/cast" + "reflect" + "time" + + "github.com/rulego/streamsql/model" ) const ( @@ -22,39 +25,56 @@ type Window interface { Trigger() } -func CreateWindow(windowType string, params map[string]interface{}) (Window, error) { - switch windowType { +func CreateWindow(config model.WindowConfig) (Window, error) { + switch config.Type { case TypeTumbling: - size, err := cast.ToDurationE(params["size"]) - if err != nil { - return nil, fmt.Errorf("invalid size for tumbling window: %v", err) - } - return NewTumblingWindow(size), nil + return NewTumblingWindow(config) case TypeSliding: - size, err := cast.ToDurationE(params["size"]) - if err != nil { - return nil, fmt.Errorf("invalid size for sliding window: %v", err) - } - slide, err := cast.ToDurationE(params["slide"]) - if err != nil { - return nil, fmt.Errorf("invalid slide for sliding window: %v", err) - } - return NewSlidingWindow(size, slide), nil + return NewSlidingWindow(config) case TypeCounting: - count := cast.ToInt(params["count"]) - if count <= 0 { - return nil, fmt.Errorf("count must be a positive integer") - } - cw := NewCountingWindow(count, nil) - if callback, ok := params["callback"].(func([]interface{})); ok { - cw.SetCallback(callback) - } - return cw, nil + return NewCountingWindow(config) default: - return nil, fmt.Errorf("unsupported window type: %s", windowType) + return nil, fmt.Errorf("unsupported window type: %s", config.Type) } } func (cw *CountingWindow) SetCallback(callback func([]interface{})) { cw.callback = callback } + +// GetTimestamp 从数据中获取时间戳。 +func GetTimestamp(data interface{}, tsProp string) time.Time { + if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok { + return ts.GetTimestamp() + } else if tsProp != "" { + v := reflect.ValueOf(data) + + // 处理不同类型 + switch v.Kind() { + case reflect.Struct: + // 如果是结构体,使用反射获取字段值 + if f := v.FieldByName(tsProp); f.IsValid() { + if t, ok := f.Interface().(time.Time); ok { + return t + } + } + case reflect.Map: + // 如果是map,直接通过key获取值 + if v.Type().Key().Kind() == reflect.String { + if value := v.MapIndex(reflect.ValueOf(tsProp)); value.IsValid() { + return value.Interface().(time.Time) + } + } + } + } + return time.Now() +} + +// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。 +func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time { + trunc := t.Truncate(timeUnit) + if !roundUp { + return trunc.Add(timeUnit) + } + return trunc +} diff --git a/window/sliding_window.go b/window/sliding_window.go index a807816..39f549e 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -2,8 +2,12 @@ package window import ( "context" + "fmt" "sync" "time" + + "github.com/rulego/streamsql/model" + "github.com/spf13/cast" ) // 确保 SlidingWindow 结构体实现了 Window 接口 @@ -11,12 +15,14 @@ var _ Window = (*SlidingWindow)(nil) // TimedData 用于包装数据和时间戳 type TimedData struct { - Data interface{} - Timestamp time.Time + Data interface{} + Timestamp time.Time } // SlidingWindow 表示一个滑动窗口,用于按时间范围处理数据 type SlidingWindow struct { + // config 窗口的配置信息 + config model.WindowConfig // 窗口的总大小,即窗口覆盖的时间范围 size time.Duration // 窗口每次滑动的时间间隔 @@ -39,17 +45,26 @@ type SlidingWindow struct { // NewSlidingWindow 创建一个新的滑动窗口实例 // 参数 size 表示窗口的总大小,slide 表示窗口每次滑动的时间间隔 -func NewSlidingWindow(size, slide time.Duration) *SlidingWindow { +func NewSlidingWindow(config model.WindowConfig) (*SlidingWindow, error) { // 创建一个可取消的上下文 ctx, cancel := context.WithCancel(context.Background()) + size, err := cast.ToDurationE(config.Params["size"]) + if err != nil { + return nil, fmt.Errorf("invalid size for sliding window: %v", err) + } + slide, err := cast.ToDurationE(config.Params["slide"]) + if err != nil { + return nil, fmt.Errorf("invalid slide for sliding window: %v", err) + } return &SlidingWindow{ + config: config, size: size, slide: slide, outputChan: make(chan []interface{}, 10), ctx: ctx, cancelFunc: cancel, data: make([]TimedData, 0), - } + }, nil } // Add 向滑动窗口中添加数据 @@ -58,19 +73,11 @@ func (sw *SlidingWindow) Add(data interface{}) { // 加锁以保证数据的并发安全 sw.mu.Lock() defer sw.mu.Unlock() - - var timestamp time.Time - if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok { - timestamp = ts.GetTimestamp() - } else { - timestamp = time.Now() - } - - // 将数据添加到窗口的数据列表中 - sw.data = append(sw.data, TimedData{ - Data: data, - Timestamp: timestamp, - }) + // 将数据添加到窗口的数据列表中 + sw.data = append(sw.data, TimedData{ + Data: data, + Timestamp: GetTimestamp(data, sw.config.TsProp), + }) } // Start 启动滑动窗口,开始定时触发窗口 diff --git a/window/sliding_window_test.go b/window/sliding_window_test.go index 2a08e8f..9bfd674 100644 --- a/window/sliding_window_test.go +++ b/window/sliding_window_test.go @@ -2,16 +2,24 @@ package window import ( "context" - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/assert" ) func TestSlidingWindow(t *testing.T) { _, cancel := context.WithCancel(context.Background()) defer cancel() - sw := NewSlidingWindow(2*time.Second, 1*time.Second) + sw, _ := NewSlidingWindow(model.WindowConfig{ + Params: map[string]interface{}{ + "size": "2s", + "slide": "1s", + }, + TsProp: "Ts", + }) sw.SetCallback(func(results []interface{}) { t.Logf("Received results: %v", results) }) @@ -19,10 +27,15 @@ func TestSlidingWindow(t *testing.T) { // 添加数据 now := time.Now() - sw.Add(now.Add(-3 * time.Second)) - sw.Add(now.Add(-2 * time.Second)) - sw.Add(now.Add(-1 * time.Second)) - sw.Add(now) + t_3 := TestDate{Ts: now.Add(-3 * time.Second)} + t_2 := TestDate{Ts: now.Add(-2 * time.Second)} + t_1 := TestDate{Ts: now.Add(-1 * time.Second)} + t_0 := TestDate{Ts: now} + + sw.Add(t_3) + sw.Add(t_2) + sw.Add(t_1) + sw.Add(t_0) // 等待一段时间,触发窗口 time.Sleep(3 * time.Second) @@ -32,12 +45,40 @@ func TestSlidingWindow(t *testing.T) { var results []interface{} select { case results = <-resultsChan: - case <-time.After(1 * time.Second): + case <-time.After(100 * time.Second): t.Fatal("No results received within timeout") } // 预期结果:保留最近 2 秒内的数据 assert.Len(t, results, 2) - assert.Contains(t, results, now.Add(-1*time.Second)) - assert.Contains(t, results, now) + assert.Contains(t, results, t_1) + assert.Contains(t, results, t_0) +} + +type TestDate struct { + Ts time.Time +} + +type TestDate2 struct { + ts time.Time +} + +func (d TestDate2) GetTimestamp() time.Time { + return d.ts +} + +func TestGetTimestamp(t *testing.T) { + t_0 := time.Now() + data := map[string]interface{}{"device": "aa", "age": 15.0, "score": 100, "ts": t_0} + t_1 := GetTimestamp(data, "ts") + + data_1 := TestDate{Ts: t_0} + t_2 := GetTimestamp(data_1, "Ts") + + data_2 := TestDate2{ts: t_0} + t_3 := GetTimestamp(data_2, "") + + assert.Equal(t, t_0, t_1) + assert.Equal(t, t_0, t_2) + assert.Equal(t, t_0, t_3) } diff --git a/window/tumbling_window.go b/window/tumbling_window.go index 3a6e87c..cb79f50 100644 --- a/window/tumbling_window.go +++ b/window/tumbling_window.go @@ -3,8 +3,12 @@ package window import ( "context" + "fmt" "sync" "time" + + "github.com/rulego/streamsql/model" + "github.com/spf13/cast" ) // 确保 TumblingWindow 结构体实现了 Window 接口。 @@ -12,6 +16,8 @@ var _ Window = (*TumblingWindow)(nil) // TumblingWindow 表示一个滚动窗口,用于在固定时间间隔内收集数据并触发处理。 type TumblingWindow struct { + // config 是窗口的配置信息。 + config model.WindowConfig // size 是滚动窗口的时间大小,即窗口的持续时间。 size time.Duration // mu 用于保护对窗口数据的并发访问。 @@ -32,15 +38,20 @@ type TumblingWindow struct { // NewTumblingWindow 创建一个新的滚动窗口实例。 // 参数 size 是窗口的时间大小。 -func NewTumblingWindow(size time.Duration) *TumblingWindow { +func NewTumblingWindow(config model.WindowConfig) (*TumblingWindow, error) { // 创建一个可取消的上下文。 ctx, cancel := context.WithCancel(context.Background()) + size, err := cast.ToDurationE(config.Params["size"]) + if err != nil { + return nil, fmt.Errorf("invalid size for tumbling window: %v", err) + } return &TumblingWindow{ + config: config, size: size, outputChan: make(chan []interface{}, 10), ctx: ctx, cancelFunc: cancel, - } + }, nil } // Add 向滚动窗口添加数据。 diff --git a/window/tumbling_window_test.go b/window/tumbling_window_test.go index a9300b4..70d42b0 100644 --- a/window/tumbling_window_test.go +++ b/window/tumbling_window_test.go @@ -2,16 +2,21 @@ package window import ( "context" - "github.com/stretchr/testify/require" "testing" "time" + + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/require" ) func TestTumblingWindow(t *testing.T) { _, cancel := context.WithCancel(context.Background()) defer cancel() - tw := NewTumblingWindow(2 * time.Second) + tw, _ := NewTumblingWindow(model.WindowConfig{ + Type: "TumblingWindow", + Params: map[string]interface{}{"size": "2s"}, + }) tw.SetCallback(func(results []interface{}) { // Process results }) From 20313690d363ad1f558c2b4376a8d6e8aef62b0d Mon Sep 17 00:00:00 2001 From: dimon Date: Mon, 7 Apr 2025 17:27:58 +0800 Subject: [PATCH 2/5] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0window=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E6=A7=BDtimeslot=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/row.go | 1 + model/timeslot.go | 52 ++++++++++++++++++ utils/time.go | 8 +++ utils/time_test.go | 69 ++++++++++++++++++++++++ window/counting_window.go | 37 ++++++------- window/counting_window_test.go | 12 +++-- window/factory.go | 8 +-- window/sliding_window.go | 63 ++++++++++++++++------ window/sliding_window_test.go | 65 +++++++++++++++++------ window/tumbling_window.go | 96 +++++++++++++++++++++++++--------- window/tumbling_window_test.go | 4 +- 11 files changed, 330 insertions(+), 85 deletions(-) create mode 100644 model/timeslot.go create mode 100644 utils/time.go create mode 100644 utils/time_test.go diff --git a/model/row.go b/model/row.go index f724296..25e1f59 100644 --- a/model/row.go +++ b/model/row.go @@ -11,6 +11,7 @@ type RowEvent interface { type Row struct { Timestamp time.Time Data interface{} + Slot *TimeSlot } // GetTimestamp 获取时间戳 diff --git a/model/timeslot.go b/model/timeslot.go new file mode 100644 index 0000000..f6fb4b6 --- /dev/null +++ b/model/timeslot.go @@ -0,0 +1,52 @@ +package model + +import ( + "time" +) + +type TimeSlot struct { + Start *time.Time + End *time.Time +} + +func NewTimeSlot(start, end *time.Time) *TimeSlot { + return &TimeSlot{ + Start: start, + End: end, + } +} + +// Hash 生成槽位的哈希值 +func (ts TimeSlot) Hash() uint64 { + // 将开始时间和结束时间转换为 Unix 时间戳(纳秒级) + startNano := ts.Start.UnixNano() + endNano := ts.End.UnixNano() + + // 使用简单但高效的哈希算法 + // 将两个时间戳组合成一个唯一的哈希值 + hash := uint64(startNano) + hash = (hash << 32) | (hash >> 32) + hash = hash ^ uint64(endNano) + + return hash +} + +// Contains 检查给定时间是否在槽位范围内 +func (ts TimeSlot) Contains(t time.Time) bool { + return (t.Equal(*ts.Start) || t.After(*ts.Start)) && + (t.Equal(*ts.End) || t.Before(*ts.End)) +} + +func (ts *TimeSlot) GetStartTime() *time.Time { + if ts == nil || ts.Start == nil { + return nil + } + return ts.Start +} + +func (ts *TimeSlot) GetEndTime() *time.Time { + if ts == nil || ts.End == nil { + return nil + } + return ts.End +} diff --git a/utils/time.go b/utils/time.go new file mode 100644 index 0000000..2288b50 --- /dev/null +++ b/utils/time.go @@ -0,0 +1,8 @@ +package timex + +import "time" + +func AlignTimeToWindow(t time.Time, size time.Duration) time.Time { + offset := t.UnixNano() % int64(size) + return t.Add(time.Duration(-offset)) +} diff --git a/utils/time_test.go b/utils/time_test.go new file mode 100644 index 0000000..1db6904 --- /dev/null +++ b/utils/time_test.go @@ -0,0 +1,69 @@ +// Copyright 2021 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timex + +import ( + "testing" + "time" +) + +func TestAlignTimeToWindow(t *testing.T) { + tests := []struct { + name string + input time.Time + size time.Duration + expected time.Time + }{ + { + name: "对齐到1分钟窗口", + input: time.Date(2024, 1, 1, 12, 35, 56, 789000000, time.UTC), + size: 3 * time.Minute, + expected: time.Date(2024, 1, 1, 12, 33, 0, 0, time.UTC), + }, + { + name: "对齐到5分钟窗口", + input: time.Date(2024, 1, 1, 12, 37, 56, 789000000, time.UTC), + size: 5 * time.Minute, + expected: time.Date(2024, 1, 1, 12, 35, 0, 0, time.UTC), + }, + { + name: "对齐到1小时窗口", + input: time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC), + size: time.Hour, + expected: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), + }, + { + name: "对齐到1天窗口", + input: time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC), + size: 24 * time.Hour, + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "零时刻对齐测试", + input: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + size: time.Hour, + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AlignTimeToWindow(tt.input, tt.size) + if !got.Equal(tt.expected) { + t.Errorf("AlignTimeToWindow() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/window/counting_window.go b/window/counting_window.go index 424f29f..4a16a45 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -17,9 +17,9 @@ type CountingWindow struct { threshold int count int mu sync.Mutex - callback func([]interface{}) - dataBuffer []interface{} - outputChan chan []interface{} + callback func([]model.Row) + dataBuffer []model.Row + outputChan chan []model.Row ctx context.Context cancelFunc context.CancelFunc ticker *time.Ticker @@ -35,14 +35,14 @@ func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) { cw := &CountingWindow{ threshold: threshold, - dataBuffer: make([]interface{}, 0, threshold), - outputChan: make(chan []interface{}, 10), + dataBuffer: make([]model.Row, 0, threshold), + outputChan: make(chan []model.Row, 10), ctx: ctx, cancelFunc: cancel, triggerChan: make(chan struct{}, 1), } - if callback, ok := config.Params["callback"].(func([]interface{})); ok { + if callback, ok := config.Params["callback"].(func([]model.Row)); ok { cw.SetCallback(callback) } return cw, nil @@ -50,21 +50,21 @@ func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) { func (cw *CountingWindow) Add(data interface{}) { cw.mu.Lock() - cw.dataBuffer = append(cw.dataBuffer, data) + defer cw.mu.Unlock() + row := model.Row{ + Data: data, + Timestamp: GetTimestamp(data, cw.config.TsProp), + } + cw.dataBuffer = append(cw.dataBuffer, row) cw.count++ shouldTrigger := cw.count >= cw.threshold - cw.mu.Unlock() if shouldTrigger { - cw.mu.Lock() - v := append([]interface{}{}, cw.dataBuffer...) - cw.mu.Unlock() - go func() { if cw.callback != nil { - cw.callback(v) + cw.callback(cw.dataBuffer) } - cw.outputChan <- v + cw.outputChan <- cw.dataBuffer cw.Reset() }() } @@ -109,9 +109,10 @@ func (cw *CountingWindow) Reset() { cw.dataBuffer = cw.dataBuffer[:0] } -func (cw *CountingWindow) OutputChan() <-chan []interface{} { +func (cw *CountingWindow) OutputChan() <-chan []model.Row { return cw.outputChan } -func (cw *CountingWindow) GetResults() []interface{} { - return append([]interface{}{}, cw.dataBuffer...) -} + +// func (cw *CountingWindow) GetResults() []interface{} { +// return append([]mode.Row, cw.dataBuffer...) +// } diff --git a/window/counting_window_test.go b/window/counting_window_test.go index 50ef2bb..5d1ae02 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -34,7 +34,7 @@ func TestCountingWindow(t *testing.T) { // Trigger one more element to check threshold cw.Add(3) - results := make(chan []interface{}) + results := make(chan []model.Row) go func() { for res := range cw.OutputChan() { results <- res @@ -44,9 +44,13 @@ func TestCountingWindow(t *testing.T) { select { case res := <-results: assert.Len(t, res, 3) - assert.Contains(t, res, 0) - assert.Contains(t, res, 1) - assert.Contains(t, res, 2) + raw := make([]interface{}, len(res)) + for _, row := range res { + raw = append(raw, row.Data) + } + assert.Contains(t, raw, 0) + assert.Contains(t, raw, 1) + assert.Contains(t, raw, 2) case <-time.After(2 * time.Second): t.Error("No results received within timeout") } diff --git a/window/factory.go b/window/factory.go index d3cccce..483e81c 100644 --- a/window/factory.go +++ b/window/factory.go @@ -17,11 +17,11 @@ const ( type Window interface { Add(item interface{}) - GetResults() []interface{} + //GetResults() []interface{} Reset() Start() - OutputChan() <-chan []interface{} - SetCallback(callback func([]interface{})) + OutputChan() <-chan []model.Row + SetCallback(callback func([]model.Row)) Trigger() } @@ -38,7 +38,7 @@ func CreateWindow(config model.WindowConfig) (Window, error) { } } -func (cw *CountingWindow) SetCallback(callback func([]interface{})) { +func (cw *CountingWindow) SetCallback(callback func([]model.Row)) { cw.callback = callback } diff --git a/window/sliding_window.go b/window/sliding_window.go index 39f549e..4a700aa 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -7,6 +7,7 @@ import ( "time" "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" "github.com/spf13/cast" ) @@ -30,17 +31,19 @@ type SlidingWindow struct { // 用于保护数据并发访问的互斥锁 mu sync.Mutex // 存储窗口内的数据 - data []TimedData + data []model.Row // 用于输出窗口内数据的通道 - outputChan chan []interface{} + outputChan chan []model.Row // 当窗口触发时执行的回调函数 - callback func([]interface{}) + callback func([]model.Row) // 用于控制窗口生命周期的上下文 ctx context.Context // 用于取消上下文的函数 cancelFunc context.CancelFunc // 用于定时触发窗口的定时器 - timer *time.Timer + timer *time.Timer + startSlot *model.TimeSlot + currentSlot *model.TimeSlot } // NewSlidingWindow 创建一个新的滑动窗口实例 @@ -60,10 +63,10 @@ func NewSlidingWindow(config model.WindowConfig) (*SlidingWindow, error) { config: config, size: size, slide: slide, - outputChan: make(chan []interface{}, 10), + outputChan: make(chan []model.Row, 10), ctx: ctx, cancelFunc: cancel, - data: make([]TimedData, 0), + data: make([]model.Row, 0), }, nil } @@ -74,10 +77,34 @@ func (sw *SlidingWindow) Add(data interface{}) { sw.mu.Lock() defer sw.mu.Unlock() // 将数据添加到窗口的数据列表中 - sw.data = append(sw.data, TimedData{ + + if sw.startSlot == nil { + sw.startSlot = sw.createSlot(GetTimestamp(data, sw.config.TsProp)) + sw.currentSlot = sw.startSlot + } + row := model.Row{ Data: data, Timestamp: GetTimestamp(data, sw.config.TsProp), - }) + } + sw.data = append(sw.data, row) +} + +func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot { + // 创建一个新的时间槽位 + start := timex.AlignTimeToWindow(t, sw.size) + end := start.Add(sw.size) + slot := model.NewTimeSlot(&start, &end) + return slot +} + +func (sw *SlidingWindow) NextSlot() *model.TimeSlot { + if sw.currentSlot == nil { + return nil + } + start := sw.currentSlot.Start.Add(sw.slide) + end := sw.currentSlot.End.Add(sw.slide) + next := model.NewTimeSlot(&start, &end) + return next } // Start 启动滑动窗口,开始定时触发窗口 @@ -113,19 +140,22 @@ func (sw *SlidingWindow) Trigger() { } // 计算截止时间,即当前时间减去窗口的总大小 - cutoff := time.Now().Add(-sw.size) - var newData []TimedData + next := sw.NextSlot() + var newData []model.Row // 遍历窗口内的数据,只保留在截止时间之后的数据 for _, item := range sw.data { - if item.Timestamp.After(cutoff) { + if next.Contains(item.Timestamp) { newData = append(newData, item) } } // 提取出 Data 字段组成 []interface{} 类型的数据 - resultData := make([]interface{}, 0, len(newData)) - for _, item := range newData { - resultData = append(resultData, item.Data) + resultData := make([]model.Row, 0) + for _, item := range sw.data { + if sw.currentSlot.Contains(item.Timestamp) { + item.Slot = sw.currentSlot + resultData = append(resultData, item) + } } // 如果设置了回调函数,则执行回调函数 @@ -135,6 +165,7 @@ func (sw *SlidingWindow) Trigger() { // 更新窗口内的数据 sw.data = newData + sw.currentSlot = next // 将新的数据发送到输出通道 sw.outputChan <- resultData } @@ -149,13 +180,13 @@ func (sw *SlidingWindow) Reset() { } // OutputChan 返回滑动窗口的输出通道 -func (sw *SlidingWindow) OutputChan() <-chan []interface{} { +func (sw *SlidingWindow) OutputChan() <-chan []model.Row { return sw.outputChan } // SetCallback 设置滑动窗口触发时执行的回调函数 // 参数 callback 表示要设置的回调函数 -func (sw *SlidingWindow) SetCallback(callback func([]interface{})) { +func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) { sw.callback = callback } diff --git a/window/sliding_window_test.go b/window/sliding_window_test.go index 9bfd674..5b4e0a6 100644 --- a/window/sliding_window_test.go +++ b/window/sliding_window_test.go @@ -18,19 +18,19 @@ func TestSlidingWindow(t *testing.T) { "size": "2s", "slide": "1s", }, - TsProp: "Ts", + TsProp: "Ts", + TimeUnit: time.Second, }) - sw.SetCallback(func(results []interface{}) { + sw.SetCallback(func(results []model.Row) { t.Logf("Received results: %v", results) }) sw.Start() // 添加数据 - now := time.Now() - t_3 := TestDate{Ts: now.Add(-3 * time.Second)} - t_2 := TestDate{Ts: now.Add(-2 * time.Second)} - t_1 := TestDate{Ts: now.Add(-1 * time.Second)} - t_0 := TestDate{Ts: now} + t_3 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 56, 789000000, time.UTC), tag: "1"} + t_2 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 57, 789000000, time.UTC), tag: "2"} + t_1 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 58, 789000000, time.UTC), tag: "3"} + t_0 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 59, 789000000, time.UTC), tag: "4"} sw.Add(t_3) sw.Add(t_2) @@ -38,25 +38,56 @@ func TestSlidingWindow(t *testing.T) { sw.Add(t_0) // 等待一段时间,触发窗口 - time.Sleep(3 * time.Second) + //time.Sleep(3 * time.Second) // 检查结果 resultsChan := sw.OutputChan() - var results []interface{} - select { - case results = <-resultsChan: - case <-time.After(100 * time.Second): - t.Fatal("No results received within timeout") + var results []model.Row + + for { + select { + case results = <-resultsChan: + raw := make([]TestDate, 0) + for _, row := range results { + raw = append(raw, row.Data.(TestDate)) + } + + // 获取当前窗口的时间范围 + windowStart := results[0].Slot.Start + windowEnd := results[0].Slot.End + t.Logf("Window range: %v - %v", windowStart, windowEnd) + + // 检查窗口内的数据 + expectedData := make([]TestDate, 0) + if windowStart.Before(t_3.Ts) && windowEnd.After(t_2.Ts) { + expectedData = []TestDate{t_3, t_2} + } else if windowStart.Before(t_2.Ts) && windowEnd.After(t_1.Ts) { + expectedData = []TestDate{t_2, t_1} + } else if windowStart.Before(t_1.Ts) && windowEnd.After(t_0.Ts) { + expectedData = []TestDate{t_1, t_0} + } else { + expectedData = []TestDate{t_0} + } + + // 验证窗口数据 + assert.Equal(t, len(expectedData), len(raw), "窗口数据数量不匹配") + for _, expected := range expectedData { + assert.Contains(t, raw, expected, "窗口缺少预期数据") + } + default: + // 通道为空时退出 + goto END + } } +END: // 预期结果:保留最近 2 秒内的数据 - assert.Len(t, results, 2) - assert.Contains(t, results, t_1) - assert.Contains(t, results, t_0) + assert.Len(t, results, 0) } type TestDate struct { - Ts time.Time + Ts time.Time + tag string } type TestDate2 struct { diff --git a/window/tumbling_window.go b/window/tumbling_window.go index cb79f50..a0f45c1 100644 --- a/window/tumbling_window.go +++ b/window/tumbling_window.go @@ -8,6 +8,7 @@ import ( "time" "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" "github.com/spf13/cast" ) @@ -23,17 +24,19 @@ type TumblingWindow struct { // mu 用于保护对窗口数据的并发访问。 mu sync.Mutex // data 存储窗口内收集的数据。 - data []interface{} + data []model.Row // outputChan 是一个通道,用于在窗口触发时发送数据。 - outputChan chan []interface{} + outputChan chan []model.Row // callback 是一个可选的回调函数,在窗口触发时调用。 - callback func([]interface{}) + callback func([]model.Row) // ctx 用于控制窗口的生命周期。 ctx context.Context // cancelFunc 用于取消窗口的操作。 cancelFunc context.CancelFunc // timer 用于定时触发窗口。 - timer *time.Timer + timer *time.Timer + startSlot *model.TimeSlot + currentSlot *model.TimeSlot } // NewTumblingWindow 创建一个新的滚动窗口实例。 @@ -48,7 +51,7 @@ func NewTumblingWindow(config model.WindowConfig) (*TumblingWindow, error) { return &TumblingWindow{ config: config, size: size, - outputChan: make(chan []interface{}, 10), + outputChan: make(chan []model.Row, 10), ctx: ctx, cancelFunc: cancel, }, nil @@ -61,7 +64,33 @@ func (tw *TumblingWindow) Add(data interface{}) { tw.mu.Lock() defer tw.mu.Unlock() // 将数据追加到窗口的数据列表中。 - tw.data = append(tw.data, data) + if tw.startSlot == nil { + tw.startSlot = tw.createSlot(GetTimestamp(data, tw.config.TsProp)) + tw.currentSlot = tw.startSlot + } + row := model.Row{ + Data: data, + Timestamp: GetTimestamp(data, tw.config.TsProp), + } + tw.data = append(tw.data, row) +} + +func (sw *TumblingWindow) createSlot(t time.Time) *model.TimeSlot { + // 创建一个新的时间槽位 + start := timex.AlignTimeToWindow(t, sw.size) + end := start.Add(sw.size) + slot := model.NewTimeSlot(&start, &end) + return slot +} + +func (sw *TumblingWindow) NextSlot() *model.TimeSlot { + if sw.currentSlot == nil { + return nil + } + start := sw.currentSlot.End + end := sw.currentSlot.End.Add(sw.size) + next := model.NewTimeSlot(start, &end) + return next } // Stop 停止滚动窗口的操作。 @@ -98,16 +127,35 @@ func (tw *TumblingWindow) Trigger() { // 加锁以确保并发安全。 tw.mu.Lock() defer tw.mu.Unlock() - - // 如果设置了回调函数,则调用它。 - if tw.callback != nil { - tw.callback(tw.data) + // 计算截止时间,即当前时间减去窗口的总大小 + next := tw.NextSlot() + var newData []model.Row + // 遍历窗口内的数据,只保留在截止时间之后的数据 + for _, item := range tw.data { + if next.Contains(item.Timestamp) { + newData = append(newData, item) + } } - // 将窗口数据发送到输出通道。 - tw.outputChan <- append([]interface{}{}, tw.data...) - // 重置窗口数据。 - tw.data = nil + // 提取出 Data 字段组成 []interface{} 类型的数据 + resultData := make([]model.Row, 0) + for _, item := range tw.data { + if tw.currentSlot.Contains(item.Timestamp) { + item.Slot = tw.currentSlot + resultData = append(resultData, item) + } + } + + // 如果设置了回调函数,则执行回调函数 + if tw.callback != nil { + tw.callback(resultData) + } + + // 更新窗口内的数据 + tw.data = newData + tw.currentSlot = next + // 将新的数据发送到输出通道 + tw.outputChan <- resultData } // Reset 重置滚动窗口的数据。 @@ -120,21 +168,21 @@ func (tw *TumblingWindow) Reset() { } // OutputChan 返回一个只读通道,用于接收窗口触发时的数据。 -func (tw *TumblingWindow) OutputChan() <-chan []interface{} { +func (tw *TumblingWindow) OutputChan() <-chan []model.Row { return tw.outputChan } // SetCallback 设置滚动窗口触发时的回调函数。 // 参数 callback 是要设置的回调函数。 -func (tw *TumblingWindow) SetCallback(callback func([]interface{})) { +func (tw *TumblingWindow) SetCallback(callback func([]model.Row)) { tw.callback = callback } -// GetResults 获取当前滚动窗口中的数据副本。 -func (tw *TumblingWindow) GetResults() []interface{} { - // 加锁以确保并发安全。 - tw.mu.Lock() - defer tw.mu.Unlock() - // 返回窗口数据的副本。 - return append([]interface{}{}, tw.data...) -} +// // GetResults 获取当前滚动窗口中的数据副本。 +// func (tw *TumblingWindow) GetResults() []interface{} { +// // 加锁以确保并发安全。 +// tw.mu.Lock() +// defer tw.mu.Unlock() +// // 返回窗口数据的副本。 +// return append([]interface{}{}, tw.data...) +// } diff --git a/window/tumbling_window_test.go b/window/tumbling_window_test.go index 70d42b0..8aea937 100644 --- a/window/tumbling_window_test.go +++ b/window/tumbling_window_test.go @@ -17,7 +17,7 @@ func TestTumblingWindow(t *testing.T) { Type: "TumblingWindow", Params: map[string]interface{}{"size": "2s"}, }) - tw.SetCallback(func(results []interface{}) { + tw.SetCallback(func(results []model.Row) { // Process results }) go tw.Start() @@ -30,7 +30,7 @@ func TestTumblingWindow(t *testing.T) { // Check output channel resultsChan := tw.OutputChan() - var results []interface{} + var results []model.Row select { case results = <-resultsChan: case <-time.After(3 * time.Second): From b28c810c65c9f265f6da441fab06b51e2e414ae4 Mon Sep 17 00:00:00 2001 From: dexter Date: Mon, 7 Apr 2025 21:44:07 +0800 Subject: [PATCH 3/5] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9counting=20window?= =?UTF-8?q?=20=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/timeslot.go | 2 +- utils/time.go | 10 +++++ window/counting_window.go | 69 +++++++++++++++++++++++++++------- window/counting_window_test.go | 27 ++++++------- window/factory.go | 9 ----- window/sliding_window.go | 54 ++++++++++---------------- 6 files changed, 98 insertions(+), 73 deletions(-) diff --git a/model/timeslot.go b/model/timeslot.go index f6fb4b6..0ff5043 100644 --- a/model/timeslot.go +++ b/model/timeslot.go @@ -34,7 +34,7 @@ func (ts TimeSlot) Hash() uint64 { // Contains 检查给定时间是否在槽位范围内 func (ts TimeSlot) Contains(t time.Time) bool { return (t.Equal(*ts.Start) || t.After(*ts.Start)) && - (t.Equal(*ts.End) || t.Before(*ts.End)) + t.Before(*ts.End) } func (ts *TimeSlot) GetStartTime() *time.Time { diff --git a/utils/time.go b/utils/time.go index 2288b50..8c83f33 100644 --- a/utils/time.go +++ b/utils/time.go @@ -2,7 +2,17 @@ package timex import "time" +// AlignTimeToWindow 将时间对齐到窗口的起始时间。 func AlignTimeToWindow(t time.Time, size time.Duration) time.Time { offset := t.UnixNano() % int64(size) return t.Add(time.Duration(-offset)) } + +// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。 +func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time { + trunc := t.Truncate(timeUnit) + if !roundUp { + return trunc.Add(timeUnit) + } + return trunc +} diff --git a/window/counting_window.go b/window/counting_window.go index 4a16a45..5cd70eb 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -7,6 +7,7 @@ import ( "time" "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" "github.com/spf13/cast" ) @@ -51,20 +52,33 @@ func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) { func (cw *CountingWindow) Add(data interface{}) { cw.mu.Lock() defer cw.mu.Unlock() + // 将数据添加到窗口的数据列表中 + t := GetTimestamp(data, cw.config.TsProp) row := model.Row{ Data: data, - Timestamp: GetTimestamp(data, cw.config.TsProp), + Timestamp: t, } cw.dataBuffer = append(cw.dataBuffer, row) cw.count++ - shouldTrigger := cw.count >= cw.threshold + shouldTrigger := cw.count == cw.threshold if shouldTrigger { + slot := cw.createSlot(cw.dataBuffer[:cw.threshold]) + for _, r := range cw.dataBuffer[:cw.threshold] { + // 由于Row是值类型,这里需要通过指针来修改Slot字段 + (&r).Slot = slot + } + data := cw.dataBuffer[:cw.threshold] + if len(cw.dataBuffer) > cw.threshold { + cw.dataBuffer = cw.dataBuffer[cw.threshold:] + } else { + cw.dataBuffer = make([]model.Row, 0, cw.threshold) + } go func() { if cw.callback != nil { - cw.callback(cw.dataBuffer) + cw.callback(data) } - cw.outputChan <- cw.dataBuffer + cw.outputChan <- data cw.Reset() }() } @@ -89,24 +103,34 @@ func (cw *CountingWindow) Start() { } func (cw *CountingWindow) Trigger() { - cw.triggerChan <- struct{}{} + // cw.triggerChan <- struct{}{} - go func() { - cw.mu.Lock() - defer cw.mu.Unlock() + // go func() { + // cw.mu.Lock() + // defer cw.mu.Unlock() - if cw.callback != nil && len(cw.dataBuffer) > 0 { - cw.callback(cw.dataBuffer) - } - cw.Reset() - }() + // if cw.callback != nil && len(cw.dataBuffer) > 0 { + // var resultData []model.Row + // if len(cw.dataBuffer) > cw.threshold { + // resultData = cw.dataBuffer[:cw.threshold] + // } else { + // resultData = cw.dataBuffer + // } + // slot := cw.createSlot(resultData) + // for _, r := range resultData { + // r.Slot = slot + // } + // cw.callback(resultData) + // } + // cw.Reset() + // }() } func (cw *CountingWindow) Reset() { cw.mu.Lock() defer cw.mu.Unlock() cw.count = 0 - cw.dataBuffer = cw.dataBuffer[:0] + cw.dataBuffer = cw.dataBuffer[0:] } func (cw *CountingWindow) OutputChan() <-chan []model.Row { @@ -116,3 +140,20 @@ func (cw *CountingWindow) OutputChan() <-chan []model.Row { // func (cw *CountingWindow) GetResults() []interface{} { // return append([]mode.Row, cw.dataBuffer...) // } + +// createSlot 创建一个新的时间槽位 +func (cw *CountingWindow) createSlot(data []model.Row) *model.TimeSlot { + if len(data) == 0 { + return nil + } else if len(data) < cw.threshold { + start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true) + end := timex.AlignTime(data[len(cw.dataBuffer)-1].Timestamp, cw.config.TimeUnit, false) + slot := model.NewTimeSlot(&start, &end) + return slot + } else { + start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true) + end := timex.AlignTime(data[cw.threshold-1].Timestamp, cw.config.TimeUnit, false) + slot := model.NewTimeSlot(&start, &end) + return slot + } +} diff --git a/window/counting_window_test.go b/window/counting_window_test.go index 5d1ae02..ea36f6d 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -34,30 +34,27 @@ func TestCountingWindow(t *testing.T) { // Trigger one more element to check threshold cw.Add(3) - results := make(chan []model.Row) - go func() { - for res := range cw.OutputChan() { - results <- res - } - }() + resultsChan := cw.OutputChan() + //results := make(chan []model.Row) + // go func() { + // for res := range cw.OutputChan() { + // results <- res + // } + // }() select { - case res := <-results: + case res := <-resultsChan: assert.Len(t, res, 3) - raw := make([]interface{}, len(res)) - for _, row := range res { - raw = append(raw, row.Data) - } - assert.Contains(t, raw, 0) - assert.Contains(t, raw, 1) - assert.Contains(t, raw, 2) + assert.Equal(t, 0, res[0].Data, "第一个元素应该是0") + assert.Equal(t, 1, res[1].Data, "第二个元素应该是1") + assert.Equal(t, 2, res[2].Data, "第三个元素应该是2") case <-time.After(2 * time.Second): t.Error("No results received within timeout") } // Test case 2: Reset cw.Reset() - assert.Len(t, cw.dataBuffer, 0) + assert.Len(t, cw.dataBuffer, 1) } func TestCountingWindowBadThreshold(t *testing.T) { diff --git a/window/factory.go b/window/factory.go index 483e81c..916c221 100644 --- a/window/factory.go +++ b/window/factory.go @@ -69,12 +69,3 @@ func GetTimestamp(data interface{}, tsProp string) time.Time { } return time.Now() } - -// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。 -func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time { - trunc := t.Truncate(timeUnit) - if !roundUp { - return trunc.Add(timeUnit) - } - return trunc -} diff --git a/window/sliding_window.go b/window/sliding_window.go index 4a700aa..13c79ff 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -42,7 +42,6 @@ type SlidingWindow struct { cancelFunc context.CancelFunc // 用于定时触发窗口的定时器 timer *time.Timer - startSlot *model.TimeSlot currentSlot *model.TimeSlot } @@ -77,36 +76,17 @@ func (sw *SlidingWindow) Add(data interface{}) { sw.mu.Lock() defer sw.mu.Unlock() // 将数据添加到窗口的数据列表中 - - if sw.startSlot == nil { - sw.startSlot = sw.createSlot(GetTimestamp(data, sw.config.TsProp)) - sw.currentSlot = sw.startSlot + t := GetTimestamp(data, sw.config.TsProp) + if sw.currentSlot == nil { + sw.currentSlot = sw.createSlot(t) } row := model.Row{ Data: data, - Timestamp: GetTimestamp(data, sw.config.TsProp), + Timestamp: t, } sw.data = append(sw.data, row) } -func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot { - // 创建一个新的时间槽位 - start := timex.AlignTimeToWindow(t, sw.size) - end := start.Add(sw.size) - slot := model.NewTimeSlot(&start, &end) - return slot -} - -func (sw *SlidingWindow) NextSlot() *model.TimeSlot { - if sw.currentSlot == nil { - return nil - } - start := sw.currentSlot.Start.Add(sw.slide) - end := sw.currentSlot.End.Add(sw.slide) - next := model.NewTimeSlot(&start, &end) - return next -} - // Start 启动滑动窗口,开始定时触发窗口 func (sw *SlidingWindow) Start() { go func() { @@ -190,15 +170,21 @@ func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) { sw.callback = callback } -// GetResults 获取滑动窗口内的当前数据 -func (sw *SlidingWindow) GetResults() []interface{} { - // 加锁以保证数据的并发安全 - sw.mu.Lock() - defer sw.mu.Unlock() - // 提取出 Data 字段组成 []interface{} 类型的数据 - resultData := make([]interface{}, 0, len(sw.data)) - for _, item := range sw.data { - resultData = append(resultData, item.Data) +func (sw *SlidingWindow) NextSlot() *model.TimeSlot { + if sw.currentSlot == nil { + return nil } - return resultData + start := sw.currentSlot.Start.Add(sw.slide) + end := sw.currentSlot.End.Add(sw.slide) + next := model.NewTimeSlot(&start, &end) + return next +} + +// createSlot 创建一个新的时间槽位 +func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot { + // 创建一个新的时间槽位 + start := timex.AlignTimeToWindow(t, sw.size) + end := start.Add(sw.size) + slot := model.NewTimeSlot(&start, &end) + return slot } From 67c6a91dbba69df9f2beb82a88efa01d10190881 Mon Sep 17 00:00:00 2001 From: dimon Date: Tue, 8 Apr 2025 11:33:37 +0800 Subject: [PATCH 4/5] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=8D=95?= =?UTF-8?q?=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- window/counting_window.go | 54 +++++++++++---------- window/counting_window_test.go | 4 +- window/sliding_window.go | 22 +++++---- window/sliding_window_test.go | 14 ++++++ window/tumbling_window.go | 35 +++++++------- window/tumbling_window_test.go | 85 ++++++++++++++++++++++++---------- 6 files changed, 139 insertions(+), 75 deletions(-) diff --git a/window/counting_window.go b/window/counting_window.go index 5cd70eb..a327c7d 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -60,7 +60,7 @@ func (cw *CountingWindow) Add(data interface{}) { } cw.dataBuffer = append(cw.dataBuffer, row) cw.count++ - shouldTrigger := cw.count == cw.threshold + shouldTrigger := cw.count >= cw.threshold if shouldTrigger { slot := cw.createSlot(cw.dataBuffer[:cw.threshold]) @@ -70,16 +70,22 @@ func (cw *CountingWindow) Add(data interface{}) { } data := cw.dataBuffer[:cw.threshold] if len(cw.dataBuffer) > cw.threshold { - cw.dataBuffer = cw.dataBuffer[cw.threshold:] + remaining := len(cw.dataBuffer) - cw.threshold + newBuffer := make([]model.Row, remaining, cw.threshold) + copy(newBuffer, cw.dataBuffer[cw.threshold:]) + cw.dataBuffer = newBuffer } else { cw.dataBuffer = make([]model.Row, 0, cw.threshold) } go func() { + cw.mu.Lock() if cw.callback != nil { cw.callback(data) } cw.outputChan <- data - cw.Reset() + cw.count = 0 + //cw.Reset() + cw.mu.Unlock() }() } } @@ -94,7 +100,7 @@ func (cw *CountingWindow) Start() { for { select { case <-cw.ticker.C: - cw.Trigger() + //cw.Trigger() case <-cw.ctx.Done(): return } @@ -103,34 +109,34 @@ func (cw *CountingWindow) Start() { } func (cw *CountingWindow) Trigger() { - // cw.triggerChan <- struct{}{} + cw.triggerChan <- struct{}{} - // go func() { - // cw.mu.Lock() - // defer cw.mu.Unlock() + go func() { + cw.mu.Lock() + defer cw.mu.Unlock() - // if cw.callback != nil && len(cw.dataBuffer) > 0 { - // var resultData []model.Row - // if len(cw.dataBuffer) > cw.threshold { - // resultData = cw.dataBuffer[:cw.threshold] - // } else { - // resultData = cw.dataBuffer - // } - // slot := cw.createSlot(resultData) - // for _, r := range resultData { - // r.Slot = slot - // } - // cw.callback(resultData) - // } - // cw.Reset() - // }() + if cw.callback != nil && len(cw.dataBuffer) > 0 { + var resultData []model.Row + if len(cw.dataBuffer) > cw.threshold { + resultData = cw.dataBuffer[:cw.threshold] + } else { + resultData = cw.dataBuffer + } + slot := cw.createSlot(resultData) + for _, r := range resultData { + r.Slot = slot + } + cw.callback(resultData) + } + cw.Reset() + }() } func (cw *CountingWindow) Reset() { cw.mu.Lock() defer cw.mu.Unlock() cw.count = 0 - cw.dataBuffer = cw.dataBuffer[0:] + cw.dataBuffer = cw.dataBuffer[:0] } func (cw *CountingWindow) OutputChan() <-chan []model.Row { diff --git a/window/counting_window_test.go b/window/counting_window_test.go index ea36f6d..d0c0ca5 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -51,10 +51,10 @@ func TestCountingWindow(t *testing.T) { case <-time.After(2 * time.Second): t.Error("No results received within timeout") } - + assert.Len(t, cw.dataBuffer, 1) // Test case 2: Reset cw.Reset() - assert.Len(t, cw.dataBuffer, 1) + assert.Len(t, cw.dataBuffer, 0) } func TestCountingWindowBadThreshold(t *testing.T) { diff --git a/window/sliding_window.go b/window/sliding_window.go index 13c79ff..194be63 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -80,11 +80,13 @@ func (sw *SlidingWindow) Add(data interface{}) { if sw.currentSlot == nil { sw.currentSlot = sw.createSlot(t) } - row := model.Row{ - Data: data, - Timestamp: t, - } - sw.data = append(sw.data, row) + go func() { + row := model.Row{ + Data: data, + Timestamp: t, + } + sw.data = append(sw.data, row) + }() } // Start 启动滑动窗口,开始定时触发窗口 @@ -121,10 +123,13 @@ func (sw *SlidingWindow) Trigger() { // 计算截止时间,即当前时间减去窗口的总大小 next := sw.NextSlot() - var newData []model.Row - // 遍历窗口内的数据,只保留在截止时间之后的数据 + // 保留下一个窗口的数据 + tms := next.Start.Add(-sw.size) + tme := next.End.Add(sw.size) + temp := model.NewTimeSlot(&tms, &tme) + newData := make([]model.Row, 0) for _, item := range sw.data { - if next.Contains(item.Timestamp) { + if temp.Contains(item.Timestamp) { newData = append(newData, item) } } @@ -157,6 +162,7 @@ func (sw *SlidingWindow) Reset() { defer sw.mu.Unlock() // 清空窗口内的数据 sw.data = nil + sw.currentSlot = nil } // OutputChan 返回滑动窗口的输出通道 diff --git a/window/sliding_window_test.go b/window/sliding_window_test.go index 5b4e0a6..b605540 100644 --- a/window/sliding_window_test.go +++ b/window/sliding_window_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" "github.com/stretchr/testify/assert" ) @@ -59,14 +60,27 @@ func TestSlidingWindow(t *testing.T) { // 检查窗口内的数据 expectedData := make([]TestDate, 0) + if windowStart.Before(t_3.Ts) && windowEnd.After(t_2.Ts) { expectedData = []TestDate{t_3, t_2} + start := timex.AlignTimeToWindow(t_3.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) } else if windowStart.Before(t_2.Ts) && windowEnd.After(t_1.Ts) { expectedData = []TestDate{t_2, t_1} + start := timex.AlignTimeToWindow(t_2.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) } else if windowStart.Before(t_1.Ts) && windowEnd.After(t_0.Ts) { expectedData = []TestDate{t_1, t_0} + start := timex.AlignTimeToWindow(t_1.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) } else { expectedData = []TestDate{t_0} + start := timex.AlignTimeToWindow(t_0.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) } // 验证窗口数据 diff --git a/window/tumbling_window.go b/window/tumbling_window.go index a0f45c1..9e15046 100644 --- a/window/tumbling_window.go +++ b/window/tumbling_window.go @@ -35,7 +35,6 @@ type TumblingWindow struct { cancelFunc context.CancelFunc // timer 用于定时触发窗口。 timer *time.Timer - startSlot *model.TimeSlot currentSlot *model.TimeSlot } @@ -64,15 +63,16 @@ func (tw *TumblingWindow) Add(data interface{}) { tw.mu.Lock() defer tw.mu.Unlock() // 将数据追加到窗口的数据列表中。 - if tw.startSlot == nil { - tw.startSlot = tw.createSlot(GetTimestamp(data, tw.config.TsProp)) - tw.currentSlot = tw.startSlot + if tw.currentSlot == nil { + tw.currentSlot = tw.createSlot(GetTimestamp(data, tw.config.TsProp)) } - row := model.Row{ - Data: data, - Timestamp: GetTimestamp(data, tw.config.TsProp), - } - tw.data = append(tw.data, row) + go func() { + row := model.Row{ + Data: data, + Timestamp: GetTimestamp(data, tw.config.TsProp), + } + tw.data = append(tw.data, row) + }() } func (sw *TumblingWindow) createSlot(t time.Time) *model.TimeSlot { @@ -89,8 +89,7 @@ func (sw *TumblingWindow) NextSlot() *model.TimeSlot { } start := sw.currentSlot.End end := sw.currentSlot.End.Add(sw.size) - next := model.NewTimeSlot(start, &end) - return next + return model.NewTimeSlot(start, &end) } // Stop 停止滚动窗口的操作。 @@ -127,17 +126,20 @@ func (tw *TumblingWindow) Trigger() { // 加锁以确保并发安全。 tw.mu.Lock() defer tw.mu.Unlock() - // 计算截止时间,即当前时间减去窗口的总大小 + // 计算下一个窗口槽位 next := tw.NextSlot() - var newData []model.Row - // 遍历窗口内的数据,只保留在截止时间之后的数据 + // 保留下一个窗口的数据 + tms := next.Start.Add(-tw.size) + tme := next.End.Add(tw.size) + temp := model.NewTimeSlot(&tms, &tme) + newData := make([]model.Row, 0) for _, item := range tw.data { - if next.Contains(item.Timestamp) { + if temp.Contains(item.Timestamp) { newData = append(newData, item) } } - // 提取出 Data 字段组成 []interface{} 类型的数据 + // 提取出当前窗口数据 resultData := make([]model.Row, 0) for _, item := range tw.data { if tw.currentSlot.Contains(item.Timestamp) { @@ -165,6 +167,7 @@ func (tw *TumblingWindow) Reset() { defer tw.mu.Unlock() // 清空窗口数据。 tw.data = nil + tw.currentSlot = nil } // OutputChan 返回一个只读通道,用于接收窗口触发时的数据。 diff --git a/window/tumbling_window_test.go b/window/tumbling_window_test.go index 8aea937..7e9b6f7 100644 --- a/window/tumbling_window_test.go +++ b/window/tumbling_window_test.go @@ -2,6 +2,7 @@ package window import ( "context" + "fmt" "testing" "time" @@ -16,6 +17,7 @@ func TestTumblingWindow(t *testing.T) { tw, _ := NewTumblingWindow(model.WindowConfig{ Type: "TumblingWindow", Params: map[string]interface{}{"size": "2s"}, + TsProp: "Ts", }) tw.SetCallback(func(results []model.Row) { // Process results @@ -23,46 +25,79 @@ func TestTumblingWindow(t *testing.T) { go tw.Start() // Add data every 500ms + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + // 添加测试数据 for i := 0; i < 5; i++ { - tw.Add(i) - time.Sleep(1100 * time.Millisecond) + data := TestDate{ + Ts: baseTime.Add(time.Duration(i) * 1100 * time.Millisecond), + tag: fmt.Sprintf("%d", i), + } + tw.Add(data) } - // Check output channel + // 收集窗口结果 resultsChan := tw.OutputChan() - var results []model.Row - select { - case results = <-resultsChan: - case <-time.After(3 * time.Second): - t.Fatal("No results received within timeout") + var all [][]model.Row = make([][]model.Row, 0) + + // 收集所有窗口数据 +COLLECT: + for { + select { + case results := <-resultsChan: + all = append(all, results) + if len(all) >= 3 { + break COLLECT + } + default: + } } - // Verify that data is sent every 2 seconds - require.Len(t, results, 2) - require.Equal(t, []interface{}{0, 1}, results) + // 验证窗口数据 + require.Len(t, all, 3, "应该有3个时间窗口的数据") - // Verify next batch - select { - case results = <-resultsChan: - require.Len(t, results, 2) - require.Equal(t, []interface{}{2, 3}, results) - case <-time.After(3 * time.Second): - t.Fatal("No results received within timeout") + // 验证每个窗口的数据 + expectedWindows := []struct { + size int + tags []string + startIdx int + }{ + {size: 2, tags: []string{"0", "1"}, startIdx: 0}, + {size: 2, tags: []string{"2", "3"}, startIdx: 1}, + {size: 1, tags: []string{"4"}, startIdx: 2}, } - //time.Sleep(1100 * time.Millisecond) - //results = <-resultsChan + for i, window := range all { + expected := expectedWindows[i] + require.Len(t, window, expected.size, "窗口 %d 数据数量不匹配", i) + + // 验证数据内容 + for _, row := range window { + require.Contains(t, expected.tags, row.Data.(TestDate).tag) + } + + // 验证时间槽 + startTime := baseTime.Add(time.Duration(i*2) * time.Second) + endTime := startTime.Add(2 * time.Second) + require.True(t, window[0].Slot.Start.Equal(startTime) && + window[0].Slot.End.Equal(endTime), + "窗口 %d 时间槽边界不正确", i) + } // Verify reset and final batch tw.Reset() - tw.Add(99) + tw.Add(TestDate{ + Ts: baseTime.Add(time.Duration(99) * 1100 * time.Millisecond), + tag: fmt.Sprintf("%d", 99), + }) + // time.Sleep(1100 * time.Millisecond) cancel() select { - case results = <-resultsChan: + case results := <-resultsChan: require.Len(t, results, 1) - require.Equal(t, []interface{}{99}, results) - case <-time.After(3 * time.Second): - t.Fatal("No results received within timeout") + require.Equal(t, "99", results[0].Data.(TestDate).tag) + startTime := baseTime.Add(108 * time.Second) + endTime := baseTime.Add(110 * time.Second) + require.True(t, results[0].Slot.Start.Equal(startTime) && results[0].Slot.End.Equal(endTime)) } } From b9b38565a0899e4f656f0e2e3cda513cc582680a Mon Sep 17 00:00:00 2001 From: dimon Date: Tue, 8 Apr 2025 17:38:34 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=9A=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E5=88=AB=E5=90=8D=E6=98=A0=E5=B0=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aggregator/builtin.go | 116 +++++++++++++++++++--------- aggregator/context_aggregator.go | 45 +++++++++++ aggregator/group_aggregator.go | 45 ++++++++++- aggregator/group_aggregator_test.go | 23 ++++-- model/model.go | 5 +- model/timeslot.go | 14 ++++ rsql/ast.go | 44 ++++++++--- rsql/parser_test.go | 4 +- stream/stream.go | 11 ++- stream/stream_test.go | 92 ++++++++++++++++++++++ streamsql_test.go | 76 +++++++++++++++++- 11 files changed, 415 insertions(+), 60 deletions(-) create mode 100644 aggregator/context_aggregator.go diff --git a/aggregator/builtin.go b/aggregator/builtin.go index 69445ae..2790524 100644 --- a/aggregator/builtin.go +++ b/aggregator/builtin.go @@ -3,26 +3,29 @@ package aggregator import ( "math" "sort" + "strconv" "sync" ) type AggregateType string const ( - Sum AggregateType = "sum" - Count AggregateType = "count" - Avg AggregateType = "avg" - Max AggregateType = "max" - Min AggregateType = "min" - StdDev AggregateType = "stddev" - Median AggregateType = "median" - Percentile AggregateType = "percentile" + Sum AggregateType = "sum" + Count AggregateType = "count" + Avg AggregateType = "avg" + Max AggregateType = "max" + Min AggregateType = "min" + StdDev AggregateType = "stddev" + Median AggregateType = "median" + Percentile AggregateType = "percentile" + WindowStart AggregateType = "window_start" + WindowEnd AggregateType = "window_end" ) type AggregatorFunction interface { New() AggregatorFunction - Add(value float64) - Result() float64 + Add(value interface{}) + Result() interface{} } type SumAggregator struct { @@ -33,11 +36,12 @@ func (s *SumAggregator) New() AggregatorFunction { return &SumAggregator{} } -func (s *SumAggregator) Add(v float64) { - s.value += v +func (s *SumAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + s.value += vv } -func (s *SumAggregator) Result() float64 { +func (s *SumAggregator) Result() interface{} { return s.value } @@ -49,11 +53,11 @@ func (s *CountAggregator) New() AggregatorFunction { return &CountAggregator{} } -func (c *CountAggregator) Add(_ float64) { +func (c *CountAggregator) Add(_ interface{}) { c.count++ } -func (c *CountAggregator) Result() float64 { +func (c *CountAggregator) Result() interface{} { return float64(c.count) } @@ -66,12 +70,13 @@ func (a *AvgAggregator) New() AggregatorFunction { return &AvgAggregator{} } -func (a *AvgAggregator) Add(v float64) { - a.sum += v +func (a *AvgAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + a.sum += vv a.count++ } -func (a *AvgAggregator) Result() float64 { +func (a *AvgAggregator) Result() interface{} { if a.count == 0 { return 0 } @@ -117,6 +122,10 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction { return &MedianAggregator{} case Percentile: return &PercentileAggregator{p: 0.95} + case WindowStart: + return &WindowStartAggregator{} + case WindowEnd: + return &WindowEndAggregator{} default: panic("unsupported aggregator type: " + aggType) } @@ -150,11 +159,12 @@ func (m *MedianAggregator) New() AggregatorFunction { return &MedianAggregator{} } -func (m *MedianAggregator) Add(val float64) { - m.values = append(m.values, val) +func (m *MedianAggregator) Add(val interface{}) { + var vv float64 = ConvertToFloat64(val, 0) + m.values = append(m.values, vv) } -func (m *MedianAggregator) Result() float64 { +func (m *MedianAggregator) Result() interface{} { sort.Float64s(m.values) return m.values[len(m.values)/2] } @@ -168,8 +178,9 @@ func (p *PercentileAggregator) New() AggregatorFunction { return &PercentileAggregator{} } -func (p *PercentileAggregator) Add(v float64) { - p.values = append(p.values, v) +func (p *PercentileAggregator) Add(v interface{}) { + vv := ConvertToFloat64(v, 0) + p.values = append(p.values, vv) } type MinAggregator struct { @@ -183,14 +194,15 @@ func (s *MinAggregator) New() AggregatorFunction { } } -func (m *MinAggregator) Add(v float64) { - if m.first || v < m.value { - m.value = v +func (m *MinAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, math.MaxFloat64) + if m.first || vv < m.value { + m.value = vv m.first = false } } -func (m *MinAggregator) Result() float64 { +func (m *MinAggregator) Result() interface{} { return m.value } @@ -203,22 +215,24 @@ func (m *MaxAggregator) New() AggregatorFunction { return &MaxAggregator{} } -func (m *MaxAggregator) Add(v float64) { - if m.first || v > m.value { - m.value = v +func (m *MaxAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + if m.first || vv > m.value { + m.value = vv m.first = false } } -func (m *MaxAggregator) Result() float64 { +func (m *MaxAggregator) Result() interface{} { return m.value } -func (s *StdDevAggregator) Add(v float64) { - s.values = append(s.values, v) +func (s *StdDevAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + s.values = append(s.values, vv) } -func (s *StdDevAggregator) Result() float64 { +func (s *StdDevAggregator) Result() interface{} { if len(s.values) < 2 { return 0 } @@ -230,7 +244,7 @@ func (s *StdDevAggregator) Result() float64 { return math.Sqrt(sum / float64(len(s.values)-1)) } -func (p *PercentileAggregator) Result() float64 { +func (p *PercentileAggregator) Result() interface{} { if len(p.values) == 0 { return 0 } @@ -246,3 +260,35 @@ func calculateAverage(values []float64) float64 { } return sum / float64(len(values)) } + +func ConvertToFloat64(v interface{}, defaultVal float64) float64 { + var vv float64 = defaultVal + switch val := v.(type) { + case float64: + vv = val + case float32: + vv = float64(val) + case int: + vv = float64(val) + case int32: + vv = float64(val) + case int64: + vv = float64(val) + case uint: + vv = float64(val) + case uint32: + vv = float64(val) + case uint64: + vv = float64(val) + case string: + // 处理字符串类型的转换 + if floatValue, err := strconv.ParseFloat(val, 64); err == nil { + vv = floatValue + } else { + panic("unsupported type for sum aggregator") + } + default: + panic("unsupported type for sum aggregator") + } + return vv +} diff --git a/aggregator/context_aggregator.go b/aggregator/context_aggregator.go new file mode 100644 index 0000000..e0f4780 --- /dev/null +++ b/aggregator/context_aggregator.go @@ -0,0 +1,45 @@ +package aggregator + +type ContextAggregator interface { + GetContextKey() string +} + +type WindowStartAggregator struct { + val interface{} +} + +func (w *WindowStartAggregator) New() AggregatorFunction { + return &WindowStartAggregator{} +} + +func (w *WindowStartAggregator) Add(val interface{}) { + w.val = val +} + +func (w *WindowStartAggregator) Result() interface{} { + return w.val +} + +func (w *WindowStartAggregator) GetContextKey() string { + return "window_start" +} + +type WindowEndAggregator struct { + val interface{} +} + +func (w *WindowEndAggregator) New() AggregatorFunction { + return &WindowEndAggregator{} +} + +func (w *WindowEndAggregator) Add(val interface{}) { + w.val = val +} + +func (w *WindowEndAggregator) Result() interface{} { + return w.val +} + +func (w *WindowEndAggregator) GetContextKey() string { + return "window_end" +} diff --git a/aggregator/group_aggregator.go b/aggregator/group_aggregator.go index 211ffea..0402541 100644 --- a/aggregator/group_aggregator.go +++ b/aggregator/group_aggregator.go @@ -9,6 +9,7 @@ import ( type Aggregator interface { Add(data interface{}) error + Put(key string, val interface{}) error GetResults() ([]map[string]interface{}, error) Reset() } @@ -19,9 +20,11 @@ type GroupAggregator struct { aggregators map[string]AggregatorFunction groups map[string]map[string]AggregatorFunction mu sync.RWMutex + context map[string]interface{} + fieldAlias map[string]string } -func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator { +func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator { aggregators := make(map[string]AggregatorFunction) for field, aggType := range fieldMap { @@ -33,8 +36,20 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) groupFields: groupFields, aggregators: aggregators, groups: make(map[string]map[string]AggregatorFunction), + fieldAlias: fieldAlias, } } + +func (ga *GroupAggregator) Put(key string, val interface{}) error { + ga.mu.Lock() // 获取写锁 + defer ga.mu.Unlock() // 确保函数返回时释放锁 + if ga.context == nil { + ga.context = make(map[string]interface{}) + } + ga.context[key] = val + return nil +} + func (ga *GroupAggregator) Add(data interface{}) error { ga.mu.Lock() // 获取写锁 defer ga.mu.Unlock() // 确保函数返回时释放锁 @@ -114,6 +129,19 @@ func (ga *GroupAggregator) Add(data interface{}) error { if !f.IsValid() { //return fmt.Errorf("field %s not found", field) //fmt.Printf("field %s not found in %v \n ", field, data) + + // 尝试从context中获取 + if ga.context != nil { + if groupAgg, exists := ga.groups[key][field]; exists { + if _, ok := groupAgg.(ContextAggregator); ok { + key := groupAgg.(ContextAggregator).GetContextKey() + if val, exists := ga.context[key]; exists { + groupAgg.Add(val) + //fmt.Printf("add agg group by %s:%s , %v \n", key, field, value) + } + } + } + } continue } @@ -151,7 +179,20 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) { group[field] = fields[i] } for field, agg := range aggregators { - group[field+"_"+string(ga.fieldMap[field])] = agg.Result() + if _, ok := agg.(ContextAggregator); ok { + if alias, ok := ga.fieldAlias[field]; ok { + group[alias] = agg.Result() + } else { + group[field] = agg.Result() + } + } else { + if alias, ok := ga.fieldAlias[field]; ok { + group[alias] = agg.Result() + } else { + group[field+"_"+string(ga.fieldMap[field])] = agg.Result() + } + } + } result = append(result, group) } diff --git a/aggregator/group_aggregator_test.go b/aggregator/group_aggregator_test.go index e7279fa..4c8c395 100644 --- a/aggregator/group_aggregator_test.go +++ b/aggregator/group_aggregator_test.go @@ -19,13 +19,17 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) { "Data1": Sum, "Data2": Sum, }, + map[string]string{ + "Data1": "Data1_sum", + "Data2": "Data2_sum", + }, ) testData := []map[string]interface{}{ - {"device": "aa", "data1": 20, "data2": 30}, - {"device": "aa", "data1": 21, "data2": 0}, - {"device": "bb", "data1": 15, "data2": 20}, - {"device": "bb", "data1": 16, "data2": 20}, + {"Device": "aa", "Data1": 20, "Data2": 30}, + {"Device": "aa", "Data1": 21, "Data2": 0}, + {"Device": "bb", "Data1": 15, "Data2": 20}, + {"Device": "bb", "Data1": 16, "Data2": 20}, } for _, d := range testData { @@ -47,6 +51,9 @@ func TestGroupAggregator_SingleField(t *testing.T) { map[string]AggregateType{ "Data1": Sum, }, + map[string]string{ + "Data1": "Data1_sum", + }, ) testData := []map[string]interface{}{ @@ -75,6 +82,12 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) { "Data3": Max, "Data4": Min, }, + map[string]string{ + "Data1": "Data1_sum", + "Data2": "Data2_avg", + "Data3": "Data3_max", + "Data4": "Data4_min", + }, ) testData := []map[string]interface{}{ @@ -88,7 +101,7 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) { expected := []map[string]interface{}{ { - "Device": "cc", + "Device": "cc", "Data1_sum": 30.0, "Data2_avg": 5.0, "Data3_max": 12.0, diff --git a/model/model.go b/model/model.go index 869674b..fceab5a 100644 --- a/model/model.go +++ b/model/model.go @@ -3,13 +3,14 @@ package model import ( "time" - aggregator2 "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/aggregator" ) type Config struct { WindowConfig WindowConfig GroupFields []string - SelectFields map[string]aggregator2.AggregateType + SelectFields map[string]aggregator.AggregateType + FieldAlias map[string]string } type WindowConfig struct { Type string diff --git a/model/timeslot.go b/model/timeslot.go index 0ff5043..951490f 100644 --- a/model/timeslot.go +++ b/model/timeslot.go @@ -50,3 +50,17 @@ func (ts *TimeSlot) GetEndTime() *time.Time { } return ts.End } + +func (ts *TimeSlot) WindowStart() int64 { + if ts == nil || ts.Start == nil { + return 0 + } + return ts.Start.UnixNano() +} + +func (ts *TimeSlot) WindowEnd() int64 { + if ts == nil || ts.End == nil { + return 0 + } + return ts.End.UnixNano() +} diff --git a/rsql/ast.go b/rsql/ast.go index c028afb..f9da831 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -22,6 +22,7 @@ type SelectStatement struct { type Field struct { Expression string Alias string + AggType string } type WindowDefinition struct { @@ -52,7 +53,7 @@ func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) { if err != nil { return nil, "", fmt.Errorf("解析窗口参数失败: %w", err) } - + aggs, fields := buildSelectFields(s.Fields) // 构建Stream配置 config := model.Config{ WindowConfig: model.WindowConfig{ @@ -62,7 +63,8 @@ func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) { TimeUnit: s.Window.TimeUnit, }, GroupFields: extractGroupFields(s), - SelectFields: buildSelectFields(s.Fields), + SelectFields: aggs, + FieldAlias: fields, } return &config, s.Condition, nil @@ -78,28 +80,50 @@ func extractGroupFields(s *SelectStatement) []string { return fields } -func buildSelectFields(fields []Field) map[string]aggregator.AggregateType { +func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) { selectFields := make(map[string]aggregator.AggregateType) + fieldMap = make(map[string]string) for _, f := range fields { if alias := f.Alias; alias != "" { - selectFields[alias] = parseAggregateType(f.Expression) + t, n := parseAggregateType(f.Expression) + if n != "" { + selectFields[n] = t + fieldMap[n] = alias + } else { + selectFields[alias] = t + } } } - return selectFields + return selectFields, fieldMap } -func parseAggregateType(expr string) aggregator.AggregateType { +func parseAggregateType(expr string) (aggType aggregator.AggregateType, name string) { if strings.Contains(expr, "avg(") { - return "avg" + return "avg", extractAggField(expr) } if strings.Contains(expr, "sum(") { - return "sum" + return "sum", extractAggField(expr) } if strings.Contains(expr, "max(") { - return "max" + return "max", extractAggField(expr) } if strings.Contains(expr, "min(") { - return "min" + return "min", extractAggField(expr) + } + if strings.Contains(expr, "window_start(") { + return "window_start", "window_start" + } + if strings.Contains(expr, "window_end(") { + return "window_end", "window_end" + } + return "", "" +} + +func extractAggField(expr string) string { + start := strings.Index(expr, "(") + end := strings.LastIndex(expr, ")") + if start >= 0 && end > start { + return strings.TrimSpace(expr[start+1 : end]) } return "" } diff --git a/rsql/parser_test.go b/rsql/parser_test.go index 7fa96c1..b5f9de2 100644 --- a/rsql/parser_test.go +++ b/rsql/parser_test.go @@ -44,8 +44,8 @@ func TestParseSQL(t *testing.T) { }, GroupFields: []string{"type"}, SelectFields: map[string]aggregator.AggregateType{ - "max_score": "max", - "min_age": "min", + "score": "max", + "age": "min", }, }, condition: "", diff --git a/stream/stream.go b/stream/stream.go index c5d6faa..c8ba6ee 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -2,6 +2,7 @@ package stream import ( "fmt" + "strings" "time" aggregator2 "github.com/rulego/streamsql/aggregator" @@ -34,6 +35,9 @@ func NewStream(config model.Config) (*Stream, error) { } func (s *Stream) RegisterFilter(condition string) error { + if strings.TrimSpace(condition) == "" { + return nil + } filter, err := parser.NewExprCondition(condition) if err != nil { return fmt.Errorf("compile filter error: %w", err) @@ -47,7 +51,7 @@ func (s *Stream) Start() { } func (s *Stream) process() { - s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields) + s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias) // 启动窗口处理协程 s.Window.Start() @@ -66,10 +70,13 @@ func (s *Stream) process() { case batch := <-s.Window.OutputChan(): // 处理窗口批数据 for _, item := range batch { - if err := s.aggregator.Add(item); err != nil { + s.aggregator.Put("window_start", item.Slot.WindowStart()) + s.aggregator.Put("window_end", item.Slot.WindowEnd()) + if err := s.aggregator.Add(item.Data); err != nil { fmt.Printf("aggregate error: %v\n", err) } } + // 获取并发送聚合结果 if results, err := s.aggregator.GetResults(); err == nil { // 发送结果到结果通道和 Sink 函数 diff --git a/stream/stream_test.go b/stream/stream_test.go index b6e0d94..2c85945 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -223,3 +223,95 @@ func TestIncompleteStreamProcess(t *testing.T) { assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001) assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001) } + +func TestWindowSlotAgg(t *testing.T) { + config := model.Config{ + WindowConfig: model.WindowConfig{ + Type: "sliding", + Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second}, + TsProp: "ts", + }, + GroupFields: []string{"device"}, + SelectFields: map[string]aggregator.AggregateType{ + "age": aggregator.Max, + "score": aggregator.Min, + "start": aggregator.WindowStart, + "end": aggregator.WindowEnd, + }, + } + + strm, err := NewStream(config) + require.NoError(t, err) + + strm.Start() + // Add data every 500ms + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + + testData := []interface{}{ + map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "ts": baseTime}, + map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "ts": baseTime.Add(1 * time.Second)}, + map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "ts": baseTime}, + } + + for _, data := range testData { + strm.AddData(data) + } + + // 捕获结果 + resultChan := make(chan interface{}) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + // 等待 3 秒触发窗口 + time.Sleep(3 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("Timeout waiting for results") + } + + expected := []map[string]interface{}{ + { + "device": "aa", + "age_max": 10.0, + "score_min": 100.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + { + "device": "bb", + "age_max": 3.0, + "score_min": 300.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + } + + assert.IsType(t, []map[string]interface{}{}, actual) + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok) + + assert.Len(t, resultSlice, 2) + for _, expectedResult := range expected { + found := false + for _, resultMap := range resultSlice { + //if resultMap, ok := result.(map[string]interface{}); ok { + if resultMap["device"] == expectedResult["device"] { + assert.InEpsilon(t, expectedResult["age_max"].(float64), resultMap["age_max"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["score_min"].(float64), resultMap["score_min"].(float64), 0.0001) + assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64)) + assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64)) + found = true + break + } + //} + } + assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"])) + } +} diff --git a/streamsql_test.go b/streamsql_test.go index 899ed31..47a7295 100644 --- a/streamsql_test.go +++ b/streamsql_test.go @@ -1,13 +1,85 @@ package streamsql import ( - "github.com/stretchr/testify/assert" + "context" + "fmt" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStreamsql(t *testing.T) { streamsql := New() - var rsql = "" + var rsql = "SELECT device,max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream group by device,SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')" err := streamsql.Execute(rsql) assert.Nil(t, err) + strm := streamsql.stream + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + testData := []interface{}{ + map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime}, + map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)}, + map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime}, + } + + for _, data := range testData { + strm.AddData(data) + } + // 捕获结果 + resultChan := make(chan interface{}) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + // 等待 3 秒触发窗口 + time.Sleep(3 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("Timeout waiting for results") + } + + expected := []map[string]interface{}{ + { + "device": "aa", + "max_age": 10.0, + "min_score": 100.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + { + "device": "bb", + "max_age": 3.0, + "min_score": 300.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + } + + assert.IsType(t, []map[string]interface{}{}, actual) + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok) + assert.Len(t, resultSlice, 2) + for _, expectedResult := range expected { + found := false + for _, resultMap := range resultSlice { + //if resultMap, ok := result.(map[string]interface{}); ok { + if resultMap["device"] == expectedResult["device"] { + assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001) + assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64)) + assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64)) + found = true + break + } + //} + } + assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"])) + } }