diff --git a/model/timeslot.go b/model/timeslot.go index f6fb4b6..0ff5043 100644 --- a/model/timeslot.go +++ b/model/timeslot.go @@ -34,7 +34,7 @@ func (ts TimeSlot) Hash() uint64 { // Contains 检查给定时间是否在槽位范围内 func (ts TimeSlot) Contains(t time.Time) bool { return (t.Equal(*ts.Start) || t.After(*ts.Start)) && - (t.Equal(*ts.End) || t.Before(*ts.End)) + t.Before(*ts.End) } func (ts *TimeSlot) GetStartTime() *time.Time { diff --git a/utils/time.go b/utils/time.go index 2288b50..8c83f33 100644 --- a/utils/time.go +++ b/utils/time.go @@ -2,7 +2,17 @@ package timex import "time" +// AlignTimeToWindow 将时间对齐到窗口的起始时间。 func AlignTimeToWindow(t time.Time, size time.Duration) time.Time { offset := t.UnixNano() % int64(size) return t.Add(time.Duration(-offset)) } + +// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。 +func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time { + trunc := t.Truncate(timeUnit) + if !roundUp { + return trunc.Add(timeUnit) + } + return trunc +} diff --git a/window/counting_window.go b/window/counting_window.go index 4a16a45..5cd70eb 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -7,6 +7,7 @@ import ( "time" "github.com/rulego/streamsql/model" + timex "github.com/rulego/streamsql/utils" "github.com/spf13/cast" ) @@ -51,20 +52,33 @@ func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) { func (cw *CountingWindow) Add(data interface{}) { cw.mu.Lock() defer cw.mu.Unlock() + // 将数据添加到窗口的数据列表中 + t := GetTimestamp(data, cw.config.TsProp) row := model.Row{ Data: data, - Timestamp: GetTimestamp(data, cw.config.TsProp), + Timestamp: t, } cw.dataBuffer = append(cw.dataBuffer, row) cw.count++ - shouldTrigger := cw.count >= cw.threshold + shouldTrigger := cw.count == cw.threshold if shouldTrigger { + slot := cw.createSlot(cw.dataBuffer[:cw.threshold]) + for _, r := range cw.dataBuffer[:cw.threshold] { + // 由于Row是值类型,这里需要通过指针来修改Slot字段 + (&r).Slot = slot + } + data := cw.dataBuffer[:cw.threshold] + if len(cw.dataBuffer) > cw.threshold { + cw.dataBuffer = cw.dataBuffer[cw.threshold:] + } else { + cw.dataBuffer = make([]model.Row, 0, cw.threshold) + } go func() { if cw.callback != nil { - cw.callback(cw.dataBuffer) + cw.callback(data) } - cw.outputChan <- cw.dataBuffer + cw.outputChan <- data cw.Reset() }() } @@ -89,24 +103,34 @@ func (cw *CountingWindow) Start() { } func (cw *CountingWindow) Trigger() { - cw.triggerChan <- struct{}{} + // cw.triggerChan <- struct{}{} - go func() { - cw.mu.Lock() - defer cw.mu.Unlock() + // go func() { + // cw.mu.Lock() + // defer cw.mu.Unlock() - if cw.callback != nil && len(cw.dataBuffer) > 0 { - cw.callback(cw.dataBuffer) - } - cw.Reset() - }() + // if cw.callback != nil && len(cw.dataBuffer) > 0 { + // var resultData []model.Row + // if len(cw.dataBuffer) > cw.threshold { + // resultData = cw.dataBuffer[:cw.threshold] + // } else { + // resultData = cw.dataBuffer + // } + // slot := cw.createSlot(resultData) + // for _, r := range resultData { + // r.Slot = slot + // } + // cw.callback(resultData) + // } + // cw.Reset() + // }() } func (cw *CountingWindow) Reset() { cw.mu.Lock() defer cw.mu.Unlock() cw.count = 0 - cw.dataBuffer = cw.dataBuffer[:0] + cw.dataBuffer = cw.dataBuffer[0:] } func (cw *CountingWindow) OutputChan() <-chan []model.Row { @@ -116,3 +140,20 @@ func (cw *CountingWindow) OutputChan() <-chan []model.Row { // func (cw *CountingWindow) GetResults() []interface{} { // return append([]mode.Row, cw.dataBuffer...) // } + +// createSlot 创建一个新的时间槽位 +func (cw *CountingWindow) createSlot(data []model.Row) *model.TimeSlot { + if len(data) == 0 { + return nil + } else if len(data) < cw.threshold { + start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true) + end := timex.AlignTime(data[len(cw.dataBuffer)-1].Timestamp, cw.config.TimeUnit, false) + slot := model.NewTimeSlot(&start, &end) + return slot + } else { + start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true) + end := timex.AlignTime(data[cw.threshold-1].Timestamp, cw.config.TimeUnit, false) + slot := model.NewTimeSlot(&start, &end) + return slot + } +} diff --git a/window/counting_window_test.go b/window/counting_window_test.go index 5d1ae02..ea36f6d 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -34,30 +34,27 @@ func TestCountingWindow(t *testing.T) { // Trigger one more element to check threshold cw.Add(3) - results := make(chan []model.Row) - go func() { - for res := range cw.OutputChan() { - results <- res - } - }() + resultsChan := cw.OutputChan() + //results := make(chan []model.Row) + // go func() { + // for res := range cw.OutputChan() { + // results <- res + // } + // }() select { - case res := <-results: + case res := <-resultsChan: assert.Len(t, res, 3) - raw := make([]interface{}, len(res)) - for _, row := range res { - raw = append(raw, row.Data) - } - assert.Contains(t, raw, 0) - assert.Contains(t, raw, 1) - assert.Contains(t, raw, 2) + assert.Equal(t, 0, res[0].Data, "第一个元素应该是0") + assert.Equal(t, 1, res[1].Data, "第二个元素应该是1") + assert.Equal(t, 2, res[2].Data, "第三个元素应该是2") case <-time.After(2 * time.Second): t.Error("No results received within timeout") } // Test case 2: Reset cw.Reset() - assert.Len(t, cw.dataBuffer, 0) + assert.Len(t, cw.dataBuffer, 1) } func TestCountingWindowBadThreshold(t *testing.T) { diff --git a/window/factory.go b/window/factory.go index 483e81c..916c221 100644 --- a/window/factory.go +++ b/window/factory.go @@ -69,12 +69,3 @@ func GetTimestamp(data interface{}, tsProp string) time.Time { } return time.Now() } - -// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。 -func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time { - trunc := t.Truncate(timeUnit) - if !roundUp { - return trunc.Add(timeUnit) - } - return trunc -} diff --git a/window/sliding_window.go b/window/sliding_window.go index 4a700aa..13c79ff 100644 --- a/window/sliding_window.go +++ b/window/sliding_window.go @@ -42,7 +42,6 @@ type SlidingWindow struct { cancelFunc context.CancelFunc // 用于定时触发窗口的定时器 timer *time.Timer - startSlot *model.TimeSlot currentSlot *model.TimeSlot } @@ -77,36 +76,17 @@ func (sw *SlidingWindow) Add(data interface{}) { sw.mu.Lock() defer sw.mu.Unlock() // 将数据添加到窗口的数据列表中 - - if sw.startSlot == nil { - sw.startSlot = sw.createSlot(GetTimestamp(data, sw.config.TsProp)) - sw.currentSlot = sw.startSlot + t := GetTimestamp(data, sw.config.TsProp) + if sw.currentSlot == nil { + sw.currentSlot = sw.createSlot(t) } row := model.Row{ Data: data, - Timestamp: GetTimestamp(data, sw.config.TsProp), + Timestamp: t, } sw.data = append(sw.data, row) } -func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot { - // 创建一个新的时间槽位 - start := timex.AlignTimeToWindow(t, sw.size) - end := start.Add(sw.size) - slot := model.NewTimeSlot(&start, &end) - return slot -} - -func (sw *SlidingWindow) NextSlot() *model.TimeSlot { - if sw.currentSlot == nil { - return nil - } - start := sw.currentSlot.Start.Add(sw.slide) - end := sw.currentSlot.End.Add(sw.slide) - next := model.NewTimeSlot(&start, &end) - return next -} - // Start 启动滑动窗口,开始定时触发窗口 func (sw *SlidingWindow) Start() { go func() { @@ -190,15 +170,21 @@ func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) { sw.callback = callback } -// GetResults 获取滑动窗口内的当前数据 -func (sw *SlidingWindow) GetResults() []interface{} { - // 加锁以保证数据的并发安全 - sw.mu.Lock() - defer sw.mu.Unlock() - // 提取出 Data 字段组成 []interface{} 类型的数据 - resultData := make([]interface{}, 0, len(sw.data)) - for _, item := range sw.data { - resultData = append(resultData, item.Data) +func (sw *SlidingWindow) NextSlot() *model.TimeSlot { + if sw.currentSlot == nil { + return nil } - return resultData + start := sw.currentSlot.Start.Add(sw.slide) + end := sw.currentSlot.End.Add(sw.slide) + next := model.NewTimeSlot(&start, &end) + return next +} + +// createSlot 创建一个新的时间槽位 +func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot { + // 创建一个新的时间槽位 + start := timex.AlignTimeToWindow(t, sw.size) + end := start.Add(sw.size) + slot := model.NewTimeSlot(&start, &end) + return slot }