diff --git a/aggregator/builtin.go b/aggregator/builtin.go index 69445ae..2790524 100644 --- a/aggregator/builtin.go +++ b/aggregator/builtin.go @@ -3,26 +3,29 @@ package aggregator import ( "math" "sort" + "strconv" "sync" ) type AggregateType string const ( - Sum AggregateType = "sum" - Count AggregateType = "count" - Avg AggregateType = "avg" - Max AggregateType = "max" - Min AggregateType = "min" - StdDev AggregateType = "stddev" - Median AggregateType = "median" - Percentile AggregateType = "percentile" + Sum AggregateType = "sum" + Count AggregateType = "count" + Avg AggregateType = "avg" + Max AggregateType = "max" + Min AggregateType = "min" + StdDev AggregateType = "stddev" + Median AggregateType = "median" + Percentile AggregateType = "percentile" + WindowStart AggregateType = "window_start" + WindowEnd AggregateType = "window_end" ) type AggregatorFunction interface { New() AggregatorFunction - Add(value float64) - Result() float64 + Add(value interface{}) + Result() interface{} } type SumAggregator struct { @@ -33,11 +36,12 @@ func (s *SumAggregator) New() AggregatorFunction { return &SumAggregator{} } -func (s *SumAggregator) Add(v float64) { - s.value += v +func (s *SumAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + s.value += vv } -func (s *SumAggregator) Result() float64 { +func (s *SumAggregator) Result() interface{} { return s.value } @@ -49,11 +53,11 @@ func (s *CountAggregator) New() AggregatorFunction { return &CountAggregator{} } -func (c *CountAggregator) Add(_ float64) { +func (c *CountAggregator) Add(_ interface{}) { c.count++ } -func (c *CountAggregator) Result() float64 { +func (c *CountAggregator) Result() interface{} { return float64(c.count) } @@ -66,12 +70,13 @@ func (a *AvgAggregator) New() AggregatorFunction { return &AvgAggregator{} } -func (a *AvgAggregator) Add(v float64) { - a.sum += v +func (a *AvgAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + a.sum += vv a.count++ } -func (a *AvgAggregator) Result() float64 { +func (a *AvgAggregator) Result() interface{} { if a.count == 0 { return 0 } @@ -117,6 +122,10 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction { return &MedianAggregator{} case Percentile: return &PercentileAggregator{p: 0.95} + case WindowStart: + return &WindowStartAggregator{} + case WindowEnd: + return &WindowEndAggregator{} default: panic("unsupported aggregator type: " + aggType) } @@ -150,11 +159,12 @@ func (m *MedianAggregator) New() AggregatorFunction { return &MedianAggregator{} } -func (m *MedianAggregator) Add(val float64) { - m.values = append(m.values, val) +func (m *MedianAggregator) Add(val interface{}) { + var vv float64 = ConvertToFloat64(val, 0) + m.values = append(m.values, vv) } -func (m *MedianAggregator) Result() float64 { +func (m *MedianAggregator) Result() interface{} { sort.Float64s(m.values) return m.values[len(m.values)/2] } @@ -168,8 +178,9 @@ func (p *PercentileAggregator) New() AggregatorFunction { return &PercentileAggregator{} } -func (p *PercentileAggregator) Add(v float64) { - p.values = append(p.values, v) +func (p *PercentileAggregator) Add(v interface{}) { + vv := ConvertToFloat64(v, 0) + p.values = append(p.values, vv) } type MinAggregator struct { @@ -183,14 +194,15 @@ func (s *MinAggregator) New() AggregatorFunction { } } -func (m *MinAggregator) Add(v float64) { - if m.first || v < m.value { - m.value = v +func (m *MinAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, math.MaxFloat64) + if m.first || vv < m.value { + m.value = vv m.first = false } } -func (m *MinAggregator) Result() float64 { +func (m *MinAggregator) Result() interface{} { return m.value } @@ -203,22 +215,24 @@ func (m *MaxAggregator) New() AggregatorFunction { return &MaxAggregator{} } -func (m *MaxAggregator) Add(v float64) { - if m.first || v > m.value { - m.value = v +func (m *MaxAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + if m.first || vv > m.value { + m.value = vv m.first = false } } -func (m *MaxAggregator) Result() float64 { +func (m *MaxAggregator) Result() interface{} { return m.value } -func (s *StdDevAggregator) Add(v float64) { - s.values = append(s.values, v) +func (s *StdDevAggregator) Add(v interface{}) { + var vv float64 = ConvertToFloat64(v, 0) + s.values = append(s.values, vv) } -func (s *StdDevAggregator) Result() float64 { +func (s *StdDevAggregator) Result() interface{} { if len(s.values) < 2 { return 0 } @@ -230,7 +244,7 @@ func (s *StdDevAggregator) Result() float64 { return math.Sqrt(sum / float64(len(s.values)-1)) } -func (p *PercentileAggregator) Result() float64 { +func (p *PercentileAggregator) Result() interface{} { if len(p.values) == 0 { return 0 } @@ -246,3 +260,35 @@ func calculateAverage(values []float64) float64 { } return sum / float64(len(values)) } + +func ConvertToFloat64(v interface{}, defaultVal float64) float64 { + var vv float64 = defaultVal + switch val := v.(type) { + case float64: + vv = val + case float32: + vv = float64(val) + case int: + vv = float64(val) + case int32: + vv = float64(val) + case int64: + vv = float64(val) + case uint: + vv = float64(val) + case uint32: + vv = float64(val) + case uint64: + vv = float64(val) + case string: + // 处理字符串类型的转换 + if floatValue, err := strconv.ParseFloat(val, 64); err == nil { + vv = floatValue + } else { + panic("unsupported type for sum aggregator") + } + default: + panic("unsupported type for sum aggregator") + } + return vv +} diff --git a/aggregator/context_aggregator.go b/aggregator/context_aggregator.go new file mode 100644 index 0000000..e0f4780 --- /dev/null +++ b/aggregator/context_aggregator.go @@ -0,0 +1,45 @@ +package aggregator + +type ContextAggregator interface { + GetContextKey() string +} + +type WindowStartAggregator struct { + val interface{} +} + +func (w *WindowStartAggregator) New() AggregatorFunction { + return &WindowStartAggregator{} +} + +func (w *WindowStartAggregator) Add(val interface{}) { + w.val = val +} + +func (w *WindowStartAggregator) Result() interface{} { + return w.val +} + +func (w *WindowStartAggregator) GetContextKey() string { + return "window_start" +} + +type WindowEndAggregator struct { + val interface{} +} + +func (w *WindowEndAggregator) New() AggregatorFunction { + return &WindowEndAggregator{} +} + +func (w *WindowEndAggregator) Add(val interface{}) { + w.val = val +} + +func (w *WindowEndAggregator) Result() interface{} { + return w.val +} + +func (w *WindowEndAggregator) GetContextKey() string { + return "window_end" +} diff --git a/aggregator/group_aggregator.go b/aggregator/group_aggregator.go index 211ffea..0402541 100644 --- a/aggregator/group_aggregator.go +++ b/aggregator/group_aggregator.go @@ -9,6 +9,7 @@ import ( type Aggregator interface { Add(data interface{}) error + Put(key string, val interface{}) error GetResults() ([]map[string]interface{}, error) Reset() } @@ -19,9 +20,11 @@ type GroupAggregator struct { aggregators map[string]AggregatorFunction groups map[string]map[string]AggregatorFunction mu sync.RWMutex + context map[string]interface{} + fieldAlias map[string]string } -func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator { +func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator { aggregators := make(map[string]AggregatorFunction) for field, aggType := range fieldMap { @@ -33,8 +36,20 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) groupFields: groupFields, aggregators: aggregators, groups: make(map[string]map[string]AggregatorFunction), + fieldAlias: fieldAlias, } } + +func (ga *GroupAggregator) Put(key string, val interface{}) error { + ga.mu.Lock() // 获取写锁 + defer ga.mu.Unlock() // 确保函数返回时释放锁 + if ga.context == nil { + ga.context = make(map[string]interface{}) + } + ga.context[key] = val + return nil +} + func (ga *GroupAggregator) Add(data interface{}) error { ga.mu.Lock() // 获取写锁 defer ga.mu.Unlock() // 确保函数返回时释放锁 @@ -114,6 +129,19 @@ func (ga *GroupAggregator) Add(data interface{}) error { if !f.IsValid() { //return fmt.Errorf("field %s not found", field) //fmt.Printf("field %s not found in %v \n ", field, data) + + // 尝试从context中获取 + if ga.context != nil { + if groupAgg, exists := ga.groups[key][field]; exists { + if _, ok := groupAgg.(ContextAggregator); ok { + key := groupAgg.(ContextAggregator).GetContextKey() + if val, exists := ga.context[key]; exists { + groupAgg.Add(val) + //fmt.Printf("add agg group by %s:%s , %v \n", key, field, value) + } + } + } + } continue } @@ -151,7 +179,20 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) { group[field] = fields[i] } for field, agg := range aggregators { - group[field+"_"+string(ga.fieldMap[field])] = agg.Result() + if _, ok := agg.(ContextAggregator); ok { + if alias, ok := ga.fieldAlias[field]; ok { + group[alias] = agg.Result() + } else { + group[field] = agg.Result() + } + } else { + if alias, ok := ga.fieldAlias[field]; ok { + group[alias] = agg.Result() + } else { + group[field+"_"+string(ga.fieldMap[field])] = agg.Result() + } + } + } result = append(result, group) } diff --git a/aggregator/group_aggregator_test.go b/aggregator/group_aggregator_test.go index e7279fa..4c8c395 100644 --- a/aggregator/group_aggregator_test.go +++ b/aggregator/group_aggregator_test.go @@ -19,13 +19,17 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) { "Data1": Sum, "Data2": Sum, }, + map[string]string{ + "Data1": "Data1_sum", + "Data2": "Data2_sum", + }, ) testData := []map[string]interface{}{ - {"device": "aa", "data1": 20, "data2": 30}, - {"device": "aa", "data1": 21, "data2": 0}, - {"device": "bb", "data1": 15, "data2": 20}, - {"device": "bb", "data1": 16, "data2": 20}, + {"Device": "aa", "Data1": 20, "Data2": 30}, + {"Device": "aa", "Data1": 21, "Data2": 0}, + {"Device": "bb", "Data1": 15, "Data2": 20}, + {"Device": "bb", "Data1": 16, "Data2": 20}, } for _, d := range testData { @@ -47,6 +51,9 @@ func TestGroupAggregator_SingleField(t *testing.T) { map[string]AggregateType{ "Data1": Sum, }, + map[string]string{ + "Data1": "Data1_sum", + }, ) testData := []map[string]interface{}{ @@ -75,6 +82,12 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) { "Data3": Max, "Data4": Min, }, + map[string]string{ + "Data1": "Data1_sum", + "Data2": "Data2_avg", + "Data3": "Data3_max", + "Data4": "Data4_min", + }, ) testData := []map[string]interface{}{ @@ -88,7 +101,7 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) { expected := []map[string]interface{}{ { - "Device": "cc", + "Device": "cc", "Data1_sum": 30.0, "Data2_avg": 5.0, "Data3_max": 12.0, diff --git a/model/model.go b/model/model.go new file mode 100644 index 0000000..fceab5a --- /dev/null +++ b/model/model.go @@ -0,0 +1,20 @@ +package model + +import ( + "time" + + "github.com/rulego/streamsql/aggregator" +) + +type Config struct { + WindowConfig WindowConfig + GroupFields []string + SelectFields map[string]aggregator.AggregateType + FieldAlias map[string]string +} +type WindowConfig struct { + Type string + Params map[string]interface{} + TsProp string + TimeUnit time.Duration +} diff --git a/model/row.go b/model/row.go new file mode 100644 index 0000000..25e1f59 --- /dev/null +++ b/model/row.go @@ -0,0 +1,20 @@ +package model + +import ( + "time" +) + +type RowEvent interface { + GetTimestamp() time.Time +} + +type Row struct { + Timestamp time.Time + Data interface{} + Slot *TimeSlot +} + +// GetTimestamp 获取时间戳 +func (r *Row) GetTimestamp() time.Time { + return r.Timestamp +} diff --git a/model/timeslot.go b/model/timeslot.go new file mode 100644 index 0000000..951490f --- /dev/null +++ b/model/timeslot.go @@ -0,0 +1,66 @@ +package model + +import ( + "time" +) + +type TimeSlot struct { + Start *time.Time + End *time.Time +} + +func NewTimeSlot(start, end *time.Time) *TimeSlot { + return &TimeSlot{ + Start: start, + End: end, + } +} + +// Hash 生成槽位的哈希值 +func (ts TimeSlot) Hash() uint64 { + // 将开始时间和结束时间转换为 Unix 时间戳(纳秒级) + startNano := ts.Start.UnixNano() + endNano := ts.End.UnixNano() + + // 使用简单但高效的哈希算法 + // 将两个时间戳组合成一个唯一的哈希值 + hash := uint64(startNano) + hash = (hash << 32) | (hash >> 32) + hash = hash ^ uint64(endNano) + + return hash +} + +// Contains 检查给定时间是否在槽位范围内 +func (ts TimeSlot) Contains(t time.Time) bool { + return (t.Equal(*ts.Start) || t.After(*ts.Start)) && + t.Before(*ts.End) +} + +func (ts *TimeSlot) GetStartTime() *time.Time { + if ts == nil || ts.Start == nil { + return nil + } + return ts.Start +} + +func (ts *TimeSlot) GetEndTime() *time.Time { + if ts == nil || ts.End == nil { + return nil + } + return ts.End +} + +func (ts *TimeSlot) WindowStart() int64 { + if ts == nil || ts.Start == nil { + return 0 + } + return ts.Start.UnixNano() +} + +func (ts *TimeSlot) WindowEnd() int64 { + if ts == nil || ts.End == nil { + return 0 + } + return ts.End.UnixNano() +} diff --git a/rsql/ast.go b/rsql/ast.go index 99cffa4..f9da831 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -2,12 +2,13 @@ package rsql import ( "fmt" - "github.com/rulego/streamsql/window" "strings" "time" + "github.com/rulego/streamsql/model" + "github.com/rulego/streamsql/window" + "github.com/rulego/streamsql/aggregator" - "github.com/rulego/streamsql/stream" ) type SelectStatement struct { @@ -21,15 +22,18 @@ type SelectStatement struct { type Field struct { Expression string Alias string + AggType string } type WindowDefinition struct { - Type string - Params []interface{} + Type string + Params []interface{} + TsProp string + TimeUnit time.Duration } // ToStreamConfig 将AST转换为Stream配置 -func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) { +func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) { if s.Source == "" { return nil, "", fmt.Errorf("missing FROM clause") } @@ -49,15 +53,18 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) { if err != nil { return nil, "", fmt.Errorf("解析窗口参数失败: %w", err) } - + aggs, fields := buildSelectFields(s.Fields) // 构建Stream配置 - config := stream.Config{ - WindowConfig: stream.WindowConfig{ - Type: windowType, - Params: params, + config := model.Config{ + WindowConfig: model.WindowConfig{ + Type: windowType, + Params: params, + TsProp: s.Window.TsProp, + TimeUnit: s.Window.TimeUnit, }, GroupFields: extractGroupFields(s), - SelectFields: buildSelectFields(s.Fields), + SelectFields: aggs, + FieldAlias: fields, } return &config, s.Condition, nil @@ -73,28 +80,50 @@ func extractGroupFields(s *SelectStatement) []string { return fields } -func buildSelectFields(fields []Field) map[string]aggregator.AggregateType { +func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) { selectFields := make(map[string]aggregator.AggregateType) + fieldMap = make(map[string]string) for _, f := range fields { if alias := f.Alias; alias != "" { - selectFields[alias] = parseAggregateType(f.Expression) + t, n := parseAggregateType(f.Expression) + if n != "" { + selectFields[n] = t + fieldMap[n] = alias + } else { + selectFields[alias] = t + } } } - return selectFields + return selectFields, fieldMap } -func parseAggregateType(expr string) aggregator.AggregateType { +func parseAggregateType(expr string) (aggType aggregator.AggregateType, name string) { if strings.Contains(expr, "avg(") { - return "avg" + return "avg", extractAggField(expr) } if strings.Contains(expr, "sum(") { - return "sum" + return "sum", extractAggField(expr) } if strings.Contains(expr, "max(") { - return "max" + return "max", extractAggField(expr) } if strings.Contains(expr, "min(") { - return "min" + return "min", extractAggField(expr) + } + if strings.Contains(expr, "window_start(") { + return "window_start", "window_start" + } + if strings.Contains(expr, "window_end(") { + return "window_end", "window_end" + } + return "", "" +} + +func extractAggField(expr string) string { + start := strings.Index(expr, "(") + end := strings.LastIndex(expr, ")") + if start >= 0 && end > start { + return strings.TrimSpace(expr[start+1 : end]) } return "" } diff --git a/rsql/lexer.go b/rsql/lexer.go index 7ade707..299b6aa 100644 --- a/rsql/lexer.go +++ b/rsql/lexer.go @@ -34,6 +34,9 @@ const ( TokenSliding TokenCounting TokenSession + TokenWITH + TokenTimestamp + TokenTimeUnit ) type Token struct { @@ -204,6 +207,12 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenCounting, Value: ident} case "SESSIONWINDOW": return Token{Type: TokenSession, Value: ident} + case "WITH": + return Token{Type: TokenWITH, Value: ident} + case "TIMESTAMP": + return Token{Type: TokenTimestamp, Value: ident} + case "TIMEUNIT": + return Token{Type: TokenTimeUnit, Value: ident} default: return Token{Type: TokenIdent, Value: ident} } diff --git a/rsql/parser.go b/rsql/parser.go index c3696e3..01a451c 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -4,6 +4,7 @@ import ( "errors" "strconv" "strings" + "time" ) type Parser struct { @@ -39,6 +40,10 @@ func (p *Parser) Parse() (*SelectStatement, error) { return nil, err } + if err := p.parseWith(stmt); err != nil { + return nil, err + } + return stmt, nil } func (p *Parser) parseSelect(stmt *SelectStatement) error { @@ -125,9 +130,14 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro params = append(params, convertValue(valTok.Value)) } - stmt.Window = WindowDefinition{ - Type: winType, - Params: params, + if &stmt.Window != nil { + stmt.Window.Params = params + stmt.Window.Type = winType + } else { + stmt.Window = WindowDefinition{ + Type: winType, + Params: params, + } } return nil } @@ -185,3 +195,68 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error { } return nil } + +func (p *Parser) parseWith(stmt *SelectStatement) error { + p.lexer.NextToken() // 跳过( + for p.lexer.peekChar() != ')' { + valTok := p.lexer.NextToken() + if valTok.Type == TokenRParen || valTok.Type == TokenEOF { + break + } + if valTok.Type == TokenComma { + continue + } + + if valTok.Type == TokenTimestamp { + next := p.lexer.NextToken() + if next.Type == TokenEQ { + next = p.lexer.NextToken() + if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") { + next.Value = strings.Trim(next.Value, "'") + } + // 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition + if stmt.Window.Type == "" { + stmt.Window = WindowDefinition{ + TsProp: next.Value, + } + } else { + stmt.Window.TsProp = next.Value + } + } + } + if valTok.Type == TokenTimeUnit { + timeUnit := time.Minute + next := p.lexer.NextToken() + if next.Type == TokenEQ { + next = p.lexer.NextToken() + if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") { + next.Value = strings.Trim(next.Value, "'") + } + switch next.Value { + case "dd": + timeUnit = 24 * time.Hour + case "hh": + timeUnit = time.Hour + case "mi": + timeUnit = time.Minute + case "ss": + timeUnit = time.Second + case "ms": + timeUnit = time.Millisecond + default: + + } + // 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition + if stmt.Window.Type == "" { + stmt.Window = WindowDefinition{ + TimeUnit: timeUnit, + } + } else { + stmt.Window.TimeUnit = timeUnit + } + } + } + } + + return nil +} diff --git a/rsql/parser_test.go b/rsql/parser_test.go index 2027545..b5f9de2 100644 --- a/rsql/parser_test.go +++ b/rsql/parser_test.go @@ -1,24 +1,25 @@ package rsql import ( - "github.com/rulego/streamsql/aggregator" "testing" "time" - "github.com/rulego/streamsql/stream" + "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/assert" ) func TestParseSQL(t *testing.T) { tests := []struct { sql string - expected *stream.Config + expected *model.Config condition string }{ { sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')", - expected: &stream.Config{ - WindowConfig: stream.WindowConfig{ + expected: &model.Config{ + WindowConfig: model.WindowConfig{ Type: "tumbling", Params: map[string]interface{}{ "size": 10 * time.Second, @@ -33,8 +34,8 @@ func TestParseSQL(t *testing.T) { }, { sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')", - expected: &stream.Config{ - WindowConfig: stream.WindowConfig{ + expected: &model.Config{ + WindowConfig: model.WindowConfig{ Type: "sliding", Params: map[string]interface{}{ "size": 20 * time.Second, @@ -43,12 +44,29 @@ func TestParseSQL(t *testing.T) { }, GroupFields: []string{"type"}, SelectFields: map[string]aggregator.AggregateType{ - "max_score": "max", - "min_age": "min", + "score": "max", + "age": "min", }, }, condition: "", }, + { + sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ", + expected: &model.Config{ + WindowConfig: model.WindowConfig{ + Type: "tumbling", + Params: map[string]interface{}{ + "size": 10 * time.Second, + }, + TsProp: "ts", + }, + GroupFields: []string{"deviceId"}, + SelectFields: map[string]aggregator.AggregateType{ + "aa": "avg", + }, + }, + condition: "deviceId == 'aa'", + }, } for _, tt := range tests { @@ -64,6 +82,9 @@ func TestParseSQL(t *testing.T) { assert.Equal(t, tt.expected.GroupFields, config.GroupFields) assert.Equal(t, tt.expected.SelectFields, config.SelectFields) assert.Equal(t, tt.condition, cond) + if tt.expected.WindowConfig.TsProp != "" { + assert.Equal(t, tt.expected.WindowConfig.TsProp, config.WindowConfig.TsProp) + } } } func TestWindowParamParsing(t *testing.T) { diff --git a/stream/stream.go b/stream/stream.go index 5cf2dcc..c8ba6ee 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -2,36 +2,27 @@ package stream import ( "fmt" + "strings" "time" aggregator2 "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/model" "github.com/rulego/streamsql/parser" "github.com/rulego/streamsql/window" ) -type Config struct { - WindowConfig WindowConfig - GroupFields []string - SelectFields map[string]aggregator2.AggregateType -} - -type WindowConfig struct { - Type string - Params map[string]interface{} -} - type Stream struct { dataChan chan interface{} filter parser.Condition Window window.Window aggregator aggregator2.Aggregator - config Config + config model.Config sinks []func(interface{}) resultChan chan interface{} // 结果通道 } -func NewStream(config Config) (*Stream, error) { - win, err := window.CreateWindow(config.WindowConfig.Type, config.WindowConfig.Params) +func NewStream(config model.Config) (*Stream, error) { + win, err := window.CreateWindow(config.WindowConfig) if err != nil { return nil, err } @@ -44,6 +35,9 @@ func NewStream(config Config) (*Stream, error) { } func (s *Stream) RegisterFilter(condition string) error { + if strings.TrimSpace(condition) == "" { + return nil + } filter, err := parser.NewExprCondition(condition) if err != nil { return fmt.Errorf("compile filter error: %w", err) @@ -57,7 +51,7 @@ func (s *Stream) Start() { } func (s *Stream) process() { - s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields) + s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias) // 启动窗口处理协程 s.Window.Start() @@ -76,10 +70,13 @@ func (s *Stream) process() { case batch := <-s.Window.OutputChan(): // 处理窗口批数据 for _, item := range batch { - if err := s.aggregator.Add(item); err != nil { + s.aggregator.Put("window_start", item.Slot.WindowStart()) + s.aggregator.Put("window_end", item.Slot.WindowEnd()) + if err := s.aggregator.Add(item.Data); err != nil { fmt.Printf("aggregate error: %v\n", err) } } + // 获取并发送聚合结果 if results, err := s.aggregator.GetResults(); err == nil { // 发送结果到结果通道和 Sink 函数 @@ -108,5 +105,5 @@ func (s *Stream) GetResultsChan() <-chan interface{} { } func NewStreamProcessor() (*Stream, error) { - return NewStream(Config{}) + return NewStream(model.Config{}) } diff --git a/stream/stream_test.go b/stream/stream_test.go index 8655778..2c85945 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -7,13 +7,14 @@ import ( "time" "github.com/rulego/streamsql/aggregator" + "github.com/rulego/streamsql/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestStreamProcess(t *testing.T) { - config := Config{ - WindowConfig: WindowConfig{ + config := model.Config{ + WindowConfig: model.WindowConfig{ Type: "tumbling", Params: map[string]interface{}{"size": time.Second}, }, @@ -77,8 +78,8 @@ func TestStreamProcess(t *testing.T) { // 不设置过滤器 func TestStreamWithoutFilter(t *testing.T) { - config := Config{ - WindowConfig: WindowConfig{ + config := model.Config{ + WindowConfig: model.WindowConfig{ Type: "sliding", Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second}, }, @@ -158,8 +159,8 @@ func TestStreamWithoutFilter(t *testing.T) { } func TestIncompleteStreamProcess(t *testing.T) { - config := Config{ - WindowConfig: WindowConfig{ + config := model.Config{ + WindowConfig: model.WindowConfig{ Type: "tumbling", Params: map[string]interface{}{"size": time.Second}, }, @@ -222,3 +223,95 @@ func TestIncompleteStreamProcess(t *testing.T) { assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001) assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001) } + +func TestWindowSlotAgg(t *testing.T) { + config := model.Config{ + WindowConfig: model.WindowConfig{ + Type: "sliding", + Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second}, + TsProp: "ts", + }, + GroupFields: []string{"device"}, + SelectFields: map[string]aggregator.AggregateType{ + "age": aggregator.Max, + "score": aggregator.Min, + "start": aggregator.WindowStart, + "end": aggregator.WindowEnd, + }, + } + + strm, err := NewStream(config) + require.NoError(t, err) + + strm.Start() + // Add data every 500ms + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + + testData := []interface{}{ + map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "ts": baseTime}, + map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "ts": baseTime.Add(1 * time.Second)}, + map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "ts": baseTime}, + } + + for _, data := range testData { + strm.AddData(data) + } + + // 捕获结果 + resultChan := make(chan interface{}) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + // 等待 3 秒触发窗口 + time.Sleep(3 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("Timeout waiting for results") + } + + expected := []map[string]interface{}{ + { + "device": "aa", + "age_max": 10.0, + "score_min": 100.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + { + "device": "bb", + "age_max": 3.0, + "score_min": 300.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + } + + assert.IsType(t, []map[string]interface{}{}, actual) + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok) + + assert.Len(t, resultSlice, 2) + for _, expectedResult := range expected { + found := false + for _, resultMap := range resultSlice { + //if resultMap, ok := result.(map[string]interface{}); ok { + if resultMap["device"] == expectedResult["device"] { + assert.InEpsilon(t, expectedResult["age_max"].(float64), resultMap["age_max"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["score_min"].(float64), resultMap["score_min"].(float64), 0.0001) + assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64)) + assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64)) + found = true + break + } + //} + } + assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"])) + } +} diff --git a/streamsql_test.go b/streamsql_test.go index 899ed31..47a7295 100644 --- a/streamsql_test.go +++ b/streamsql_test.go @@ -1,13 +1,85 @@ package streamsql import ( - "github.com/stretchr/testify/assert" + "context" + "fmt" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStreamsql(t *testing.T) { streamsql := New() - var rsql = "" + var rsql = "SELECT device,max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream group by device,SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')" err := streamsql.Execute(rsql) assert.Nil(t, err) + strm := streamsql.stream + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + testData := []interface{}{ + map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime}, + map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)}, + map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime}, + } + + for _, data := range testData { + strm.AddData(data) + } + // 捕获结果 + resultChan := make(chan interface{}) + strm.AddSink(func(result interface{}) { + resultChan <- result + }) + // 等待 3 秒触发窗口 + time.Sleep(3 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var actual interface{} + select { + case actual = <-resultChan: + cancel() + case <-ctx.Done(): + t.Fatal("Timeout waiting for results") + } + + expected := []map[string]interface{}{ + { + "device": "aa", + "max_age": 10.0, + "min_score": 100.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + { + "device": "bb", + "max_age": 3.0, + "min_score": 300.0, + "start": baseTime.UnixNano(), + "end": baseTime.Add(2 * time.Second).UnixNano(), + }, + } + + assert.IsType(t, []map[string]interface{}{}, actual) + resultSlice, ok := actual.([]map[string]interface{}) + require.True(t, ok) + assert.Len(t, resultSlice, 2) + for _, expectedResult := range expected { + found := false + for _, resultMap := range resultSlice { + //if resultMap, ok := result.(map[string]interface{}); ok { + if resultMap["device"] == expectedResult["device"] { + assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001) + assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001) + assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64)) + assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64)) + found = true + break + } + //} + } + assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"])) + } } diff --git a/utils/time.go b/utils/time.go new file mode 100644 index 0000000..8c83f33 --- /dev/null +++ b/utils/time.go @@ -0,0 +1,18 @@ +package timex + +import "time" + +// AlignTimeToWindow 将时间对齐到窗口的起始时间。 +func AlignTimeToWindow(t time.Time, size time.Duration) time.Time { + offset := t.UnixNano() % int64(size) + return t.Add(time.Duration(-offset)) +} + +// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。 +func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time { + trunc := t.Truncate(timeUnit) + if !roundUp { + return trunc.Add(timeUnit) + } + return trunc +} diff --git a/utils/time_test.go b/utils/time_test.go new file mode 100644 index 0000000..1db6904 --- /dev/null +++ b/utils/time_test.go @@ -0,0 +1,69 @@ +// Copyright 2021 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timex + +import ( + "testing" + "time" +) + +func TestAlignTimeToWindow(t *testing.T) { + tests := []struct { + name string + input time.Time + size time.Duration + expected time.Time + }{ + { + name: "对齐到1分钟窗口", + input: time.Date(2024, 1, 1, 12, 35, 56, 789000000, time.UTC), + size: 3 * time.Minute, + expected: time.Date(2024, 1, 1, 12, 33, 0, 0, time.UTC), + }, + { + name: "对齐到5分钟窗口", + input: time.Date(2024, 1, 1, 12, 37, 56, 789000000, time.UTC), + size: 5 * time.Minute, + expected: time.Date(2024, 1, 1, 12, 35, 0, 0, time.UTC), + }, + { + name: "对齐到1小时窗口", + input: time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC), + size: time.Hour, + expected: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), + }, + { + name: "对齐到1天窗口", + input: time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC), + size: 24 * time.Hour, + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "零时刻对齐测试", + input: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + size: time.Hour, + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AlignTimeToWindow(tt.input, tt.size) + if !got.Equal(tt.expected) { + t.Errorf("AlignTimeToWindow() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/window/counting_window.go b/window/counting_window.go index b6f2837..a327c7d 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -2,56 +2,90 @@ package window import ( "context" + "fmt" "sync" "time" + + "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" + "github.com/spf13/cast" ) var _ Window = (*CountingWindow)(nil) type CountingWindow struct { + config model.WindowConfig threshold int count int mu sync.Mutex - callback func([]interface{}) - dataBuffer []interface{} - outputChan chan []interface{} + callback func([]model.Row) + dataBuffer []model.Row + outputChan chan []model.Row ctx context.Context cancelFunc context.CancelFunc ticker *time.Ticker triggerChan chan struct{} } -func NewCountingWindow(threshold int, callback func([]interface{})) *CountingWindow { +func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) { ctx, cancel := context.WithCancel(context.Background()) - return &CountingWindow{ + threshold := cast.ToInt(config.Params["count"]) + if threshold <= 0 { + return nil, fmt.Errorf("threshold must be a positive integer") + } + + cw := &CountingWindow{ threshold: threshold, - dataBuffer: make([]interface{}, 0, threshold), - outputChan: make(chan []interface{}, 10), + dataBuffer: make([]model.Row, 0, threshold), + outputChan: make(chan []model.Row, 10), ctx: ctx, cancelFunc: cancel, - callback: callback, triggerChan: make(chan struct{}, 1), } + + if callback, ok := config.Params["callback"].(func([]model.Row)); ok { + cw.SetCallback(callback) + } + return cw, nil } func (cw *CountingWindow) Add(data interface{}) { cw.mu.Lock() - cw.dataBuffer = append(cw.dataBuffer, data) + defer cw.mu.Unlock() + // 将数据添加到窗口的数据列表中 + t := GetTimestamp(data, cw.config.TsProp) + row := model.Row{ + Data: data, + Timestamp: t, + } + cw.dataBuffer = append(cw.dataBuffer, row) cw.count++ shouldTrigger := cw.count >= cw.threshold - cw.mu.Unlock() if shouldTrigger { - cw.mu.Lock() - v := append([]interface{}{}, cw.dataBuffer...) - cw.mu.Unlock() - + slot := cw.createSlot(cw.dataBuffer[:cw.threshold]) + for _, r := range cw.dataBuffer[:cw.threshold] { + // 由于Row是值类型,这里需要通过指针来修改Slot字段 + (&r).Slot = slot + } + data := cw.dataBuffer[:cw.threshold] + if len(cw.dataBuffer) > cw.threshold { + remaining := len(cw.dataBuffer) - cw.threshold + newBuffer := make([]model.Row, remaining, cw.threshold) + copy(newBuffer, cw.dataBuffer[cw.threshold:]) + cw.dataBuffer = newBuffer + } else { + cw.dataBuffer = make([]model.Row, 0, cw.threshold) + } go func() { + cw.mu.Lock() if cw.callback != nil { - cw.callback(v) + cw.callback(data) } - cw.outputChan <- v - cw.Reset() + cw.outputChan <- data + cw.count = 0 + //cw.Reset() + cw.mu.Unlock() }() } } @@ -66,7 +100,7 @@ func (cw *CountingWindow) Start() { for { select { case <-cw.ticker.C: - cw.Trigger() + //cw.Trigger() case <-cw.ctx.Done(): return } @@ -82,7 +116,17 @@ func (cw *CountingWindow) Trigger() { defer cw.mu.Unlock() if cw.callback != nil && len(cw.dataBuffer) > 0 { - cw.callback(cw.dataBuffer) + var resultData []model.Row + if len(cw.dataBuffer) > cw.threshold { + resultData = cw.dataBuffer[:cw.threshold] + } else { + resultData = cw.dataBuffer + } + slot := cw.createSlot(resultData) + for _, r := range resultData { + r.Slot = slot + } + cw.callback(resultData) } cw.Reset() }() @@ -95,9 +139,27 @@ func (cw *CountingWindow) Reset() { cw.dataBuffer = cw.dataBuffer[:0] } -func (cw *CountingWindow) OutputChan() <-chan []interface{} { +func (cw *CountingWindow) OutputChan() <-chan []model.Row { return cw.outputChan } -func (cw *CountingWindow) GetResults() []interface{} { - return append([]interface{}{}, cw.dataBuffer...) + +// func (cw *CountingWindow) GetResults() []interface{} { +// return append([]mode.Row, cw.dataBuffer...) +// } + +// createSlot 创建一个新的时间槽位 +func (cw *CountingWindow) createSlot(data []model.Row) *model.TimeSlot { + if len(data) == 0 { + return nil + } else if len(data) < cw.threshold { + start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true) + end := timex.AlignTime(data[len(cw.dataBuffer)-1].Timestamp, cw.config.TimeUnit, false) + slot := model.NewTimeSlot(&start, &end) + return slot + } else { + start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true) + end := timex.AlignTime(data[cw.threshold-1].Timestamp, cw.config.TimeUnit, false) + slot := model.NewTimeSlot(&start, &end) + return slot + } } diff --git a/window/counting_window_test.go b/window/counting_window_test.go index e1b121a..d0c0ca5 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -2,10 +2,12 @@ package window import ( "context" - "github.com/stretchr/testify/require" "testing" "time" + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) @@ -14,8 +16,13 @@ func TestCountingWindow(t *testing.T) { defer cancel() // Test case 1: Normal operation - cw := NewCountingWindow(3, func(results []interface{}) { - t.Logf("Received results: %v", results) + cw, _ := NewCountingWindow(model.WindowConfig{ + Params: map[string]interface{}{ + "count": 3, + "callback": func(results []interface{}) { + t.Logf("Received results: %v", results) + }, + }, }) go cw.Start() @@ -27,31 +34,35 @@ func TestCountingWindow(t *testing.T) { // Trigger one more element to check threshold cw.Add(3) - results := make(chan []interface{}) - go func() { - for res := range cw.OutputChan() { - results <- res - } - }() + resultsChan := cw.OutputChan() + //results := make(chan []model.Row) + // go func() { + // for res := range cw.OutputChan() { + // results <- res + // } + // }() select { - case res := <-results: + case res := <-resultsChan: assert.Len(t, res, 3) - assert.Contains(t, res, 0) - assert.Contains(t, res, 1) - assert.Contains(t, res, 2) + assert.Equal(t, 0, res[0].Data, "第一个元素应该是0") + assert.Equal(t, 1, res[1].Data, "第二个元素应该是1") + assert.Equal(t, 2, res[2].Data, "第三个元素应该是2") case <-time.After(2 * time.Second): t.Error("No results received within timeout") } - + assert.Len(t, cw.dataBuffer, 1) // Test case 2: Reset cw.Reset() assert.Len(t, cw.dataBuffer, 0) } func TestCountingWindowBadThreshold(t *testing.T) { - _, err := CreateWindow("counting", map[string]interface{}{ - "count": 0, + _, err := CreateWindow(model.WindowConfig{ + Type: "counting", + Params: map[string]interface{}{ + "count": 0, + }, }) require.Error(t, err) } diff --git a/window/factory.go b/window/factory.go index d58cce9..916c221 100644 --- a/window/factory.go +++ b/window/factory.go @@ -2,7 +2,10 @@ package window import ( "fmt" - "github.com/spf13/cast" + "reflect" + "time" + + "github.com/rulego/streamsql/model" ) const ( @@ -14,47 +17,55 @@ const ( type Window interface { Add(item interface{}) - GetResults() []interface{} + //GetResults() []interface{} Reset() Start() - OutputChan() <-chan []interface{} - SetCallback(callback func([]interface{})) + OutputChan() <-chan []model.Row + SetCallback(callback func([]model.Row)) Trigger() } -func CreateWindow(windowType string, params map[string]interface{}) (Window, error) { - switch windowType { +func CreateWindow(config model.WindowConfig) (Window, error) { + switch config.Type { case TypeTumbling: - size, err := cast.ToDurationE(params["size"]) - if err != nil { - return nil, fmt.Errorf("invalid size for tumbling window: %v", err) - } - return NewTumblingWindow(size), nil + return NewTumblingWindow(config) case TypeSliding: - size, err := cast.ToDurationE(params["size"]) - if err != nil { - return nil, fmt.Errorf("invalid size for sliding window: %v", err) - } - slide, err := cast.ToDurationE(params["slide"]) - if err != nil { - return nil, fmt.Errorf("invalid slide for sliding window: %v", err) - } - return NewSlidingWindow(size, slide), nil + return NewSlidingWindow(config) case TypeCounting: - count := cast.ToInt(params["count"]) - if count <= 0 { - return nil, fmt.Errorf("count must be a positive integer") - } - cw := NewCountingWindow(count, nil) - if callback, ok := params["callback"].(func([]interface{})); ok { - cw.SetCallback(callback) - } - return cw, nil + return NewCountingWindow(config) default: - return nil, fmt.Errorf("unsupported window type: %s", windowType) + return nil, fmt.Errorf("unsupported window type: %s", config.Type) } } -func (cw *CountingWindow) SetCallback(callback func([]interface{})) { +func (cw *CountingWindow) SetCallback(callback func([]model.Row)) { cw.callback = callback } + +// GetTimestamp 从数据中获取时间戳。 +func GetTimestamp(data interface{}, tsProp string) time.Time { + if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok { + return ts.GetTimestamp() + } else if tsProp != "" { + v := reflect.ValueOf(data) + + // 处理不同类型 + switch v.Kind() { + case reflect.Struct: + // 如果是结构体,使用反射获取字段值 + if f := v.FieldByName(tsProp); f.IsValid() { + if t, ok := f.Interface().(time.Time); ok { + return t + } + } + case reflect.Map: + // 如果是map,直接通过key获取值 + if v.Type().Key().Kind() == reflect.String { + if value := v.MapIndex(reflect.ValueOf(tsProp)); value.IsValid() { + return value.Interface().(time.Time) + } + } + } + } + return time.Now() +} diff --git a/window/sliding_window.go b/window/sliding_window.go index a807816..194be63 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -2,8 +2,13 @@ package window import ( "context" + "fmt" "sync" "time" + + "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" + "github.com/spf13/cast" ) // 确保 SlidingWindow 结构体实现了 Window 接口 @@ -11,12 +16,14 @@ var _ Window = (*SlidingWindow)(nil) // TimedData 用于包装数据和时间戳 type TimedData struct { - Data interface{} - Timestamp time.Time + Data interface{} + Timestamp time.Time } // SlidingWindow 表示一个滑动窗口,用于按时间范围处理数据 type SlidingWindow struct { + // config 窗口的配置信息 + config model.WindowConfig // 窗口的总大小,即窗口覆盖的时间范围 size time.Duration // 窗口每次滑动的时间间隔 @@ -24,32 +31,42 @@ type SlidingWindow struct { // 用于保护数据并发访问的互斥锁 mu sync.Mutex // 存储窗口内的数据 - data []TimedData + data []model.Row // 用于输出窗口内数据的通道 - outputChan chan []interface{} + outputChan chan []model.Row // 当窗口触发时执行的回调函数 - callback func([]interface{}) + callback func([]model.Row) // 用于控制窗口生命周期的上下文 ctx context.Context // 用于取消上下文的函数 cancelFunc context.CancelFunc // 用于定时触发窗口的定时器 - timer *time.Timer + timer *time.Timer + currentSlot *model.TimeSlot } // NewSlidingWindow 创建一个新的滑动窗口实例 // 参数 size 表示窗口的总大小,slide 表示窗口每次滑动的时间间隔 -func NewSlidingWindow(size, slide time.Duration) *SlidingWindow { +func NewSlidingWindow(config model.WindowConfig) (*SlidingWindow, error) { // 创建一个可取消的上下文 ctx, cancel := context.WithCancel(context.Background()) + size, err := cast.ToDurationE(config.Params["size"]) + if err != nil { + return nil, fmt.Errorf("invalid size for sliding window: %v", err) + } + slide, err := cast.ToDurationE(config.Params["slide"]) + if err != nil { + return nil, fmt.Errorf("invalid slide for sliding window: %v", err) + } return &SlidingWindow{ + config: config, size: size, slide: slide, - outputChan: make(chan []interface{}, 10), + outputChan: make(chan []model.Row, 10), ctx: ctx, cancelFunc: cancel, - data: make([]TimedData, 0), - } + data: make([]model.Row, 0), + }, nil } // Add 向滑动窗口中添加数据 @@ -58,19 +75,18 @@ func (sw *SlidingWindow) Add(data interface{}) { // 加锁以保证数据的并发安全 sw.mu.Lock() defer sw.mu.Unlock() - - var timestamp time.Time - if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok { - timestamp = ts.GetTimestamp() - } else { - timestamp = time.Now() - } - - // 将数据添加到窗口的数据列表中 - sw.data = append(sw.data, TimedData{ - Data: data, - Timestamp: timestamp, - }) + // 将数据添加到窗口的数据列表中 + t := GetTimestamp(data, sw.config.TsProp) + if sw.currentSlot == nil { + sw.currentSlot = sw.createSlot(t) + } + go func() { + row := model.Row{ + Data: data, + Timestamp: t, + } + sw.data = append(sw.data, row) + }() } // Start 启动滑动窗口,开始定时触发窗口 @@ -106,19 +122,25 @@ func (sw *SlidingWindow) Trigger() { } // 计算截止时间,即当前时间减去窗口的总大小 - cutoff := time.Now().Add(-sw.size) - var newData []TimedData - // 遍历窗口内的数据,只保留在截止时间之后的数据 + next := sw.NextSlot() + // 保留下一个窗口的数据 + tms := next.Start.Add(-sw.size) + tme := next.End.Add(sw.size) + temp := model.NewTimeSlot(&tms, &tme) + newData := make([]model.Row, 0) for _, item := range sw.data { - if item.Timestamp.After(cutoff) { + if temp.Contains(item.Timestamp) { newData = append(newData, item) } } // 提取出 Data 字段组成 []interface{} 类型的数据 - resultData := make([]interface{}, 0, len(newData)) - for _, item := range newData { - resultData = append(resultData, item.Data) + resultData := make([]model.Row, 0) + for _, item := range sw.data { + if sw.currentSlot.Contains(item.Timestamp) { + item.Slot = sw.currentSlot + resultData = append(resultData, item) + } } // 如果设置了回调函数,则执行回调函数 @@ -128,6 +150,7 @@ func (sw *SlidingWindow) Trigger() { // 更新窗口内的数据 sw.data = newData + sw.currentSlot = next // 将新的数据发送到输出通道 sw.outputChan <- resultData } @@ -139,28 +162,35 @@ func (sw *SlidingWindow) Reset() { defer sw.mu.Unlock() // 清空窗口内的数据 sw.data = nil + sw.currentSlot = nil } // OutputChan 返回滑动窗口的输出通道 -func (sw *SlidingWindow) OutputChan() <-chan []interface{} { +func (sw *SlidingWindow) OutputChan() <-chan []model.Row { return sw.outputChan } // SetCallback 设置滑动窗口触发时执行的回调函数 // 参数 callback 表示要设置的回调函数 -func (sw *SlidingWindow) SetCallback(callback func([]interface{})) { +func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) { sw.callback = callback } -// GetResults 获取滑动窗口内的当前数据 -func (sw *SlidingWindow) GetResults() []interface{} { - // 加锁以保证数据的并发安全 - sw.mu.Lock() - defer sw.mu.Unlock() - // 提取出 Data 字段组成 []interface{} 类型的数据 - resultData := make([]interface{}, 0, len(sw.data)) - for _, item := range sw.data { - resultData = append(resultData, item.Data) +func (sw *SlidingWindow) NextSlot() *model.TimeSlot { + if sw.currentSlot == nil { + return nil } - return resultData + start := sw.currentSlot.Start.Add(sw.slide) + end := sw.currentSlot.End.Add(sw.slide) + next := model.NewTimeSlot(&start, &end) + return next +} + +// createSlot 创建一个新的时间槽位 +func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot { + // 创建一个新的时间槽位 + start := timex.AlignTimeToWindow(t, sw.size) + end := start.Add(sw.size) + slot := model.NewTimeSlot(&start, &end) + return slot } diff --git a/window/sliding_window_test.go b/window/sliding_window_test.go index 2a08e8f..b605540 100644 --- a/window/sliding_window_test.go +++ b/window/sliding_window_test.go @@ -2,42 +2,128 @@ package window import ( "context" - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" + "github.com/stretchr/testify/assert" ) func TestSlidingWindow(t *testing.T) { _, cancel := context.WithCancel(context.Background()) defer cancel() - sw := NewSlidingWindow(2*time.Second, 1*time.Second) - sw.SetCallback(func(results []interface{}) { + sw, _ := NewSlidingWindow(model.WindowConfig{ + Params: map[string]interface{}{ + "size": "2s", + "slide": "1s", + }, + TsProp: "Ts", + TimeUnit: time.Second, + }) + sw.SetCallback(func(results []model.Row) { t.Logf("Received results: %v", results) }) sw.Start() // 添加数据 - now := time.Now() - sw.Add(now.Add(-3 * time.Second)) - sw.Add(now.Add(-2 * time.Second)) - sw.Add(now.Add(-1 * time.Second)) - sw.Add(now) + t_3 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 56, 789000000, time.UTC), tag: "1"} + t_2 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 57, 789000000, time.UTC), tag: "2"} + t_1 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 58, 789000000, time.UTC), tag: "3"} + t_0 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 59, 789000000, time.UTC), tag: "4"} + + sw.Add(t_3) + sw.Add(t_2) + sw.Add(t_1) + sw.Add(t_0) // 等待一段时间,触发窗口 - time.Sleep(3 * time.Second) + //time.Sleep(3 * time.Second) // 检查结果 resultsChan := sw.OutputChan() - var results []interface{} - select { - case results = <-resultsChan: - case <-time.After(1 * time.Second): - t.Fatal("No results received within timeout") + var results []model.Row + + for { + select { + case results = <-resultsChan: + raw := make([]TestDate, 0) + for _, row := range results { + raw = append(raw, row.Data.(TestDate)) + } + + // 获取当前窗口的时间范围 + windowStart := results[0].Slot.Start + windowEnd := results[0].Slot.End + t.Logf("Window range: %v - %v", windowStart, windowEnd) + + // 检查窗口内的数据 + expectedData := make([]TestDate, 0) + + if windowStart.Before(t_3.Ts) && windowEnd.After(t_2.Ts) { + expectedData = []TestDate{t_3, t_2} + start := timex.AlignTimeToWindow(t_3.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) + } else if windowStart.Before(t_2.Ts) && windowEnd.After(t_1.Ts) { + expectedData = []TestDate{t_2, t_1} + start := timex.AlignTimeToWindow(t_2.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) + } else if windowStart.Before(t_1.Ts) && windowEnd.After(t_0.Ts) { + expectedData = []TestDate{t_1, t_0} + start := timex.AlignTimeToWindow(t_1.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) + } else { + expectedData = []TestDate{t_0} + start := timex.AlignTimeToWindow(t_0.Ts, sw.size) + assert.Equal(t, start, windowStart) + assert.Equal(t, start.Add(sw.size), windowEnd) + } + + // 验证窗口数据 + assert.Equal(t, len(expectedData), len(raw), "窗口数据数量不匹配") + for _, expected := range expectedData { + assert.Contains(t, raw, expected, "窗口缺少预期数据") + } + default: + // 通道为空时退出 + goto END + } } +END: // 预期结果:保留最近 2 秒内的数据 - assert.Len(t, results, 2) - assert.Contains(t, results, now.Add(-1*time.Second)) - assert.Contains(t, results, now) + assert.Len(t, results, 0) +} + +type TestDate struct { + Ts time.Time + tag string +} + +type TestDate2 struct { + ts time.Time +} + +func (d TestDate2) GetTimestamp() time.Time { + return d.ts +} + +func TestGetTimestamp(t *testing.T) { + t_0 := time.Now() + data := map[string]interface{}{"device": "aa", "age": 15.0, "score": 100, "ts": t_0} + t_1 := GetTimestamp(data, "ts") + + data_1 := TestDate{Ts: t_0} + t_2 := GetTimestamp(data_1, "Ts") + + data_2 := TestDate2{ts: t_0} + t_3 := GetTimestamp(data_2, "") + + assert.Equal(t, t_0, t_1) + assert.Equal(t, t_0, t_2) + assert.Equal(t, t_0, t_3) } diff --git a/window/tumbling_window.go b/window/tumbling_window.go index 3a6e87c..9e15046 100644 --- a/window/tumbling_window.go +++ b/window/tumbling_window.go @@ -3,8 +3,13 @@ package window import ( "context" + "fmt" "sync" "time" + + "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" + "github.com/spf13/cast" ) // 确保 TumblingWindow 结构体实现了 Window 接口。 @@ -12,35 +17,43 @@ var _ Window = (*TumblingWindow)(nil) // TumblingWindow 表示一个滚动窗口,用于在固定时间间隔内收集数据并触发处理。 type TumblingWindow struct { + // config 是窗口的配置信息。 + config model.WindowConfig // size 是滚动窗口的时间大小,即窗口的持续时间。 size time.Duration // mu 用于保护对窗口数据的并发访问。 mu sync.Mutex // data 存储窗口内收集的数据。 - data []interface{} + data []model.Row // outputChan 是一个通道,用于在窗口触发时发送数据。 - outputChan chan []interface{} + outputChan chan []model.Row // callback 是一个可选的回调函数,在窗口触发时调用。 - callback func([]interface{}) + callback func([]model.Row) // ctx 用于控制窗口的生命周期。 ctx context.Context // cancelFunc 用于取消窗口的操作。 cancelFunc context.CancelFunc // timer 用于定时触发窗口。 - timer *time.Timer + timer *time.Timer + currentSlot *model.TimeSlot } // NewTumblingWindow 创建一个新的滚动窗口实例。 // 参数 size 是窗口的时间大小。 -func NewTumblingWindow(size time.Duration) *TumblingWindow { +func NewTumblingWindow(config model.WindowConfig) (*TumblingWindow, error) { // 创建一个可取消的上下文。 ctx, cancel := context.WithCancel(context.Background()) + size, err := cast.ToDurationE(config.Params["size"]) + if err != nil { + return nil, fmt.Errorf("invalid size for tumbling window: %v", err) + } return &TumblingWindow{ + config: config, size: size, - outputChan: make(chan []interface{}, 10), + outputChan: make(chan []model.Row, 10), ctx: ctx, cancelFunc: cancel, - } + }, nil } // Add 向滚动窗口添加数据。 @@ -50,7 +63,33 @@ func (tw *TumblingWindow) Add(data interface{}) { tw.mu.Lock() defer tw.mu.Unlock() // 将数据追加到窗口的数据列表中。 - tw.data = append(tw.data, data) + if tw.currentSlot == nil { + tw.currentSlot = tw.createSlot(GetTimestamp(data, tw.config.TsProp)) + } + go func() { + row := model.Row{ + Data: data, + Timestamp: GetTimestamp(data, tw.config.TsProp), + } + tw.data = append(tw.data, row) + }() +} + +func (sw *TumblingWindow) createSlot(t time.Time) *model.TimeSlot { + // 创建一个新的时间槽位 + start := timex.AlignTimeToWindow(t, sw.size) + end := start.Add(sw.size) + slot := model.NewTimeSlot(&start, &end) + return slot +} + +func (sw *TumblingWindow) NextSlot() *model.TimeSlot { + if sw.currentSlot == nil { + return nil + } + start := sw.currentSlot.End + end := sw.currentSlot.End.Add(sw.size) + return model.NewTimeSlot(start, &end) } // Stop 停止滚动窗口的操作。 @@ -87,16 +126,38 @@ func (tw *TumblingWindow) Trigger() { // 加锁以确保并发安全。 tw.mu.Lock() defer tw.mu.Unlock() - - // 如果设置了回调函数,则调用它。 - if tw.callback != nil { - tw.callback(tw.data) + // 计算下一个窗口槽位 + next := tw.NextSlot() + // 保留下一个窗口的数据 + tms := next.Start.Add(-tw.size) + tme := next.End.Add(tw.size) + temp := model.NewTimeSlot(&tms, &tme) + newData := make([]model.Row, 0) + for _, item := range tw.data { + if temp.Contains(item.Timestamp) { + newData = append(newData, item) + } } - // 将窗口数据发送到输出通道。 - tw.outputChan <- append([]interface{}{}, tw.data...) - // 重置窗口数据。 - tw.data = nil + // 提取出当前窗口数据 + resultData := make([]model.Row, 0) + for _, item := range tw.data { + if tw.currentSlot.Contains(item.Timestamp) { + item.Slot = tw.currentSlot + resultData = append(resultData, item) + } + } + + // 如果设置了回调函数,则执行回调函数 + if tw.callback != nil { + tw.callback(resultData) + } + + // 更新窗口内的数据 + tw.data = newData + tw.currentSlot = next + // 将新的数据发送到输出通道 + tw.outputChan <- resultData } // Reset 重置滚动窗口的数据。 @@ -106,24 +167,25 @@ func (tw *TumblingWindow) Reset() { defer tw.mu.Unlock() // 清空窗口数据。 tw.data = nil + tw.currentSlot = nil } // OutputChan 返回一个只读通道,用于接收窗口触发时的数据。 -func (tw *TumblingWindow) OutputChan() <-chan []interface{} { +func (tw *TumblingWindow) OutputChan() <-chan []model.Row { return tw.outputChan } // SetCallback 设置滚动窗口触发时的回调函数。 // 参数 callback 是要设置的回调函数。 -func (tw *TumblingWindow) SetCallback(callback func([]interface{})) { +func (tw *TumblingWindow) SetCallback(callback func([]model.Row)) { tw.callback = callback } -// GetResults 获取当前滚动窗口中的数据副本。 -func (tw *TumblingWindow) GetResults() []interface{} { - // 加锁以确保并发安全。 - tw.mu.Lock() - defer tw.mu.Unlock() - // 返回窗口数据的副本。 - return append([]interface{}{}, tw.data...) -} +// // GetResults 获取当前滚动窗口中的数据副本。 +// func (tw *TumblingWindow) GetResults() []interface{} { +// // 加锁以确保并发安全。 +// tw.mu.Lock() +// defer tw.mu.Unlock() +// // 返回窗口数据的副本。 +// return append([]interface{}{}, tw.data...) +// } diff --git a/window/tumbling_window_test.go b/window/tumbling_window_test.go index a9300b4..7e9b6f7 100644 --- a/window/tumbling_window_test.go +++ b/window/tumbling_window_test.go @@ -2,62 +2,102 @@ package window import ( "context" - "github.com/stretchr/testify/require" + "fmt" "testing" "time" + + "github.com/rulego/streamsql/model" + "github.com/stretchr/testify/require" ) func TestTumblingWindow(t *testing.T) { _, cancel := context.WithCancel(context.Background()) defer cancel() - tw := NewTumblingWindow(2 * time.Second) - tw.SetCallback(func(results []interface{}) { + tw, _ := NewTumblingWindow(model.WindowConfig{ + Type: "TumblingWindow", + Params: map[string]interface{}{"size": "2s"}, + TsProp: "Ts", + }) + tw.SetCallback(func(results []model.Row) { // Process results }) go tw.Start() // Add data every 500ms + baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC) + // 添加测试数据 for i := 0; i < 5; i++ { - tw.Add(i) - time.Sleep(1100 * time.Millisecond) + data := TestDate{ + Ts: baseTime.Add(time.Duration(i) * 1100 * time.Millisecond), + tag: fmt.Sprintf("%d", i), + } + tw.Add(data) } - // Check output channel + // 收集窗口结果 resultsChan := tw.OutputChan() - var results []interface{} - select { - case results = <-resultsChan: - case <-time.After(3 * time.Second): - t.Fatal("No results received within timeout") + var all [][]model.Row = make([][]model.Row, 0) + + // 收集所有窗口数据 +COLLECT: + for { + select { + case results := <-resultsChan: + all = append(all, results) + if len(all) >= 3 { + break COLLECT + } + default: + } } - // Verify that data is sent every 2 seconds - require.Len(t, results, 2) - require.Equal(t, []interface{}{0, 1}, results) + // 验证窗口数据 + require.Len(t, all, 3, "应该有3个时间窗口的数据") - // Verify next batch - select { - case results = <-resultsChan: - require.Len(t, results, 2) - require.Equal(t, []interface{}{2, 3}, results) - case <-time.After(3 * time.Second): - t.Fatal("No results received within timeout") + // 验证每个窗口的数据 + expectedWindows := []struct { + size int + tags []string + startIdx int + }{ + {size: 2, tags: []string{"0", "1"}, startIdx: 0}, + {size: 2, tags: []string{"2", "3"}, startIdx: 1}, + {size: 1, tags: []string{"4"}, startIdx: 2}, } - //time.Sleep(1100 * time.Millisecond) - //results = <-resultsChan + for i, window := range all { + expected := expectedWindows[i] + require.Len(t, window, expected.size, "窗口 %d 数据数量不匹配", i) + + // 验证数据内容 + for _, row := range window { + require.Contains(t, expected.tags, row.Data.(TestDate).tag) + } + + // 验证时间槽 + startTime := baseTime.Add(time.Duration(i*2) * time.Second) + endTime := startTime.Add(2 * time.Second) + require.True(t, window[0].Slot.Start.Equal(startTime) && + window[0].Slot.End.Equal(endTime), + "窗口 %d 时间槽边界不正确", i) + } // Verify reset and final batch tw.Reset() - tw.Add(99) + tw.Add(TestDate{ + Ts: baseTime.Add(time.Duration(99) * 1100 * time.Millisecond), + tag: fmt.Sprintf("%d", 99), + }) + // time.Sleep(1100 * time.Millisecond) cancel() select { - case results = <-resultsChan: + case results := <-resultsChan: require.Len(t, results, 1) - require.Equal(t, []interface{}{99}, results) - case <-time.After(3 * time.Second): - t.Fatal("No results received within timeout") + require.Equal(t, "99", results[0].Data.(TestDate).tag) + startTime := baseTime.Add(108 * time.Second) + endTime := baseTime.Add(110 * time.Second) + require.True(t, results[0].Slot.Start.Equal(startTime) && results[0].Slot.End.Equal(endTime)) } }