feat: 修改counting window 逻辑

This commit is contained in:
dexter
2025-04-07 21:44:07 +08:00
parent 20313690d3
commit b28c810c65
6 changed files with 98 additions and 73 deletions

View File

@ -34,7 +34,7 @@ func (ts TimeSlot) Hash() uint64 {
// Contains 检查给定时间是否在槽位范围内 // Contains 检查给定时间是否在槽位范围内
func (ts TimeSlot) Contains(t time.Time) bool { func (ts TimeSlot) Contains(t time.Time) bool {
return (t.Equal(*ts.Start) || t.After(*ts.Start)) && return (t.Equal(*ts.Start) || t.After(*ts.Start)) &&
(t.Equal(*ts.End) || t.Before(*ts.End)) t.Before(*ts.End)
} }
func (ts *TimeSlot) GetStartTime() *time.Time { func (ts *TimeSlot) GetStartTime() *time.Time {

View File

@ -2,7 +2,17 @@ package timex
import "time" import "time"
// AlignTimeToWindow 将时间对齐到窗口的起始时间。
func AlignTimeToWindow(t time.Time, size time.Duration) time.Time { func AlignTimeToWindow(t time.Time, size time.Duration) time.Time {
offset := t.UnixNano() % int64(size) offset := t.UnixNano() % int64(size)
return t.Add(time.Duration(-offset)) return t.Add(time.Duration(-offset))
} }
// AlignTime aligns t to a multiple of timeUnit, based on t.Truncate(timeUnit).
//
// When roundUp is true the truncated time is returned unchanged. When roundUp
// is false the truncated time advanced by exactly one timeUnit is returned,
// even if t was already aligned.
//
// NOTE(review): this mapping reads as the opposite of what the parameter name
// suggests; the behavior is kept as is — confirm the intended semantics with
// the call sites before changing it.
func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time {
	aligned := t.Truncate(timeUnit)
	if roundUp {
		return aligned
	}
	return aligned.Add(timeUnit)
}

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/rulego/streamsql/model" "github.com/rulego/streamsql/model"
timex "github.com/rulego/streamsql/utils"
"github.com/spf13/cast" "github.com/spf13/cast"
) )
@ -51,20 +52,33 @@ func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) {
func (cw *CountingWindow) Add(data interface{}) { func (cw *CountingWindow) Add(data interface{}) {
cw.mu.Lock() cw.mu.Lock()
defer cw.mu.Unlock() defer cw.mu.Unlock()
// 将数据添加到窗口的数据列表中
t := GetTimestamp(data, cw.config.TsProp)
row := model.Row{ row := model.Row{
Data: data, Data: data,
Timestamp: GetTimestamp(data, cw.config.TsProp), Timestamp: t,
} }
cw.dataBuffer = append(cw.dataBuffer, row) cw.dataBuffer = append(cw.dataBuffer, row)
cw.count++ cw.count++
shouldTrigger := cw.count >= cw.threshold shouldTrigger := cw.count == cw.threshold
if shouldTrigger { if shouldTrigger {
slot := cw.createSlot(cw.dataBuffer[:cw.threshold])
for _, r := range cw.dataBuffer[:cw.threshold] {
// 由于Row是值类型这里需要通过指针来修改Slot字段
(&r).Slot = slot
}
data := cw.dataBuffer[:cw.threshold]
if len(cw.dataBuffer) > cw.threshold {
cw.dataBuffer = cw.dataBuffer[cw.threshold:]
} else {
cw.dataBuffer = make([]model.Row, 0, cw.threshold)
}
go func() { go func() {
if cw.callback != nil { if cw.callback != nil {
cw.callback(cw.dataBuffer) cw.callback(data)
} }
cw.outputChan <- cw.dataBuffer cw.outputChan <- data
cw.Reset() cw.Reset()
}() }()
} }
@ -89,24 +103,34 @@ func (cw *CountingWindow) Start() {
} }
func (cw *CountingWindow) Trigger() { func (cw *CountingWindow) Trigger() {
cw.triggerChan <- struct{}{} // cw.triggerChan <- struct{}{}
go func() { // go func() {
cw.mu.Lock() // cw.mu.Lock()
defer cw.mu.Unlock() // defer cw.mu.Unlock()
if cw.callback != nil && len(cw.dataBuffer) > 0 { // if cw.callback != nil && len(cw.dataBuffer) > 0 {
cw.callback(cw.dataBuffer) // var resultData []model.Row
} // if len(cw.dataBuffer) > cw.threshold {
cw.Reset() // resultData = cw.dataBuffer[:cw.threshold]
}() // } else {
// resultData = cw.dataBuffer
// }
// slot := cw.createSlot(resultData)
// for _, r := range resultData {
// r.Slot = slot
// }
// cw.callback(resultData)
// }
// cw.Reset()
// }()
} }
func (cw *CountingWindow) Reset() { func (cw *CountingWindow) Reset() {
cw.mu.Lock() cw.mu.Lock()
defer cw.mu.Unlock() defer cw.mu.Unlock()
cw.count = 0 cw.count = 0
cw.dataBuffer = cw.dataBuffer[:0] cw.dataBuffer = cw.dataBuffer[0:]
} }
func (cw *CountingWindow) OutputChan() <-chan []model.Row { func (cw *CountingWindow) OutputChan() <-chan []model.Row {
@ -116,3 +140,20 @@ func (cw *CountingWindow) OutputChan() <-chan []model.Row {
// func (cw *CountingWindow) GetResults() []interface{} { // func (cw *CountingWindow) GetResults() []interface{} {
// return append([]mode.Row, cw.dataBuffer...) // return append([]mode.Row, cw.dataBuffer...)
// } // }
// createSlot builds the time slot that covers the rows of one counting window.
//
// It returns nil when data is empty. Otherwise the slot start is derived from
// the first row's timestamp and the slot end from the last row that belongs to
// the window: the final element of data when fewer than cw.threshold rows are
// given, or the row at index cw.threshold-1 otherwise. Both bounds are aligned
// to cw.config.TimeUnit through timex.AlignTime (start with roundUp=true, end
// with roundUp=false).
func (cw *CountingWindow) createSlot(data []model.Row) *model.TimeSlot {
	if len(data) == 0 {
		return nil
	}

	// Index relative to data itself. Using len(cw.dataBuffer) here would read
	// the wrong row, or panic, whenever data and cw.dataBuffer differ in length.
	last := len(data) - 1
	if len(data) >= cw.threshold {
		last = cw.threshold - 1
	}

	start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true)
	end := timex.AlignTime(data[last].Timestamp, cw.config.TimeUnit, false)
	return model.NewTimeSlot(&start, &end)
}

View File

@ -34,30 +34,27 @@ func TestCountingWindow(t *testing.T) {
// Trigger one more element to check threshold // Trigger one more element to check threshold
cw.Add(3) cw.Add(3)
results := make(chan []model.Row) resultsChan := cw.OutputChan()
go func() { //results := make(chan []model.Row)
for res := range cw.OutputChan() { // go func() {
results <- res // for res := range cw.OutputChan() {
} // results <- res
}() // }
// }()
select { select {
case res := <-results: case res := <-resultsChan:
assert.Len(t, res, 3) assert.Len(t, res, 3)
raw := make([]interface{}, len(res)) assert.Equal(t, 0, res[0].Data, "第一个元素应该是0")
for _, row := range res { assert.Equal(t, 1, res[1].Data, "第二个元素应该是1")
raw = append(raw, row.Data) assert.Equal(t, 2, res[2].Data, "第三个元素应该是2")
}
assert.Contains(t, raw, 0)
assert.Contains(t, raw, 1)
assert.Contains(t, raw, 2)
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Error("No results received within timeout") t.Error("No results received within timeout")
} }
// Test case 2: Reset // Test case 2: Reset
cw.Reset() cw.Reset()
assert.Len(t, cw.dataBuffer, 0) assert.Len(t, cw.dataBuffer, 1)
} }
func TestCountingWindowBadThreshold(t *testing.T) { func TestCountingWindowBadThreshold(t *testing.T) {

View File

@ -69,12 +69,3 @@ func GetTimestamp(data interface{}, tsProp string) time.Time {
} }
return time.Now() return time.Now()
} }
// AlignTime aligns t to a multiple of timeUnit, based on t.Truncate(timeUnit).
//
// With roundUp set to true the truncated time is returned as is; with roundUp
// set to false one full timeUnit is added to the truncated time before it is
// returned (also when t already sits on a boundary).
//
// NOTE(review): the flag name suggests the reverse mapping; behavior is left
// untouched here — confirm the intended semantics before changing it.
func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time {
	floor := t.Truncate(timeUnit)
	if roundUp {
		return floor
	}
	return floor.Add(timeUnit)
}

View File

@ -42,7 +42,6 @@ type SlidingWindow struct {
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
// 用于定时触发窗口的定时器 // 用于定时触发窗口的定时器
timer *time.Timer timer *time.Timer
startSlot *model.TimeSlot
currentSlot *model.TimeSlot currentSlot *model.TimeSlot
} }
@ -77,36 +76,17 @@ func (sw *SlidingWindow) Add(data interface{}) {
sw.mu.Lock() sw.mu.Lock()
defer sw.mu.Unlock() defer sw.mu.Unlock()
// 将数据添加到窗口的数据列表中 // 将数据添加到窗口的数据列表中
t := GetTimestamp(data, sw.config.TsProp)
if sw.startSlot == nil { if sw.currentSlot == nil {
sw.startSlot = sw.createSlot(GetTimestamp(data, sw.config.TsProp)) sw.currentSlot = sw.createSlot(t)
sw.currentSlot = sw.startSlot
} }
row := model.Row{ row := model.Row{
Data: data, Data: data,
Timestamp: GetTimestamp(data, sw.config.TsProp), Timestamp: t,
} }
sw.data = append(sw.data, row) sw.data = append(sw.data, row)
} }
// createSlot returns a time slot of length sw.size whose start is t aligned
// with timex.AlignTimeToWindow(t, sw.size) and whose end is that start plus
// sw.size.
func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot {
	slotStart := timex.AlignTimeToWindow(t, sw.size)
	slotEnd := slotStart.Add(sw.size)
	return model.NewTimeSlot(&slotStart, &slotEnd)
}
// NextSlot returns a new slot whose start and end are the current slot's
// bounds shifted forward by sw.slide. It returns nil when there is no current
// slot, and it does not reassign sw.currentSlot.
func (sw *SlidingWindow) NextSlot() *model.TimeSlot {
	cur := sw.currentSlot
	if cur == nil {
		return nil
	}
	shiftedStart := cur.Start.Add(sw.slide)
	shiftedEnd := cur.End.Add(sw.slide)
	return model.NewTimeSlot(&shiftedStart, &shiftedEnd)
}
// Start 启动滑动窗口,开始定时触发窗口 // Start 启动滑动窗口,开始定时触发窗口
func (sw *SlidingWindow) Start() { func (sw *SlidingWindow) Start() {
go func() { go func() {
@ -190,15 +170,21 @@ func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) {
sw.callback = callback sw.callback = callback
} }
// GetResults 获取滑动窗口内的当前数据 func (sw *SlidingWindow) NextSlot() *model.TimeSlot {
func (sw *SlidingWindow) GetResults() []interface{} { if sw.currentSlot == nil {
// 加锁以保证数据的并发安全 return nil
sw.mu.Lock()
defer sw.mu.Unlock()
// 提取出 Data 字段组成 []interface{} 类型的数据
resultData := make([]interface{}, 0, len(sw.data))
for _, item := range sw.data {
resultData = append(resultData, item.Data)
} }
return resultData start := sw.currentSlot.Start.Add(sw.slide)
end := sw.currentSlot.End.Add(sw.slide)
next := model.NewTimeSlot(&start, &end)
return next
}
// createSlot 创建一个新的时间槽位
func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot {
// 创建一个新的时间槽位
start := timex.AlignTimeToWindow(t, sw.size)
end := start.Add(sw.size)
slot := model.NewTimeSlot(&start, &end)
return slot
} }