From 040cb6ba87616858332d21794c4658d9dd97046b Mon Sep 17 00:00:00 2001 From: rulego-team Date: Sat, 14 Jun 2025 21:47:19 +0800 Subject: [PATCH] fix:window data race --- window/counting_window.go | 8 ++++---- window/counting_window_test.go | 29 +++++++++++++++++++++++++++-- window/tumbling_window.go | 1 + 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/window/counting_window.go b/window/counting_window.go index 3704177..9690a71 100644 --- a/window/counting_window.go +++ b/window/counting_window.go @@ -77,12 +77,12 @@ func (cw *CountingWindow) Start() { if shouldTrigger { // 在持有锁的情况下立即处理 slot := cw.createSlot(cw.dataBuffer[:cw.threshold]) - for i := range cw.dataBuffer[:cw.threshold] { - // 由于Row是值类型,这里需要通过指针来修改Slot字段 - cw.dataBuffer[i].Slot = slot - } data := make([]types.Row, cw.threshold) copy(data, cw.dataBuffer[:cw.threshold]) + // 设置Slot字段到复制的数据中,避免修改原始dataBuffer + for i := range data { + data[i].Slot = slot + } if len(cw.dataBuffer) > cw.threshold { remaining := len(cw.dataBuffer) - cw.threshold diff --git a/window/counting_window_test.go b/window/counting_window_test.go index 297a529..7128aa7 100644 --- a/window/counting_window_test.go +++ b/window/counting_window_test.go @@ -51,10 +51,35 @@ func TestCountingWindow(t *testing.T) { case <-time.After(2 * time.Second): t.Error("No results received within timeout") } - assert.Len(t, cw.dataBuffer, 1) + // 验证窗口状态:添加第4个数据后,第一个窗口已触发,剩余1个数据(值为3) + // 继续添加2个数据,应该再次触发 + cw.Add(4) // 添加第5个数据 + cw.Add(5) // 添加第6个数据,应该再次触发(3,4,5) + + // 等待第二次触发 + select { + case res := <-resultsChan: + assert.Len(t, res, 3) + assert.Equal(t, 3, res[0].Data, "第二批第一个元素应该是3") + assert.Equal(t, 4, res[1].Data, "第二批第二个元素应该是4") + assert.Equal(t, 5, res[2].Data, "第二批第三个元素应该是5") + case <-time.After(2 * time.Second): + t.Error("No second results received within timeout") + } + // Test case 2: Reset cw.Reset() - assert.Len(t, cw.dataBuffer, 0) + // Reset后添加数据验证重置是否成功 + cw.Add(100) + cw.Add(101) + cw.Add(102) + select { + case res := <-resultsChan: + assert.Len(t, res, 3) + assert.Equal(t, 100, res[0].Data, "重置后第一个元素应该是100") + case <-time.After(2 * time.Second): + t.Error("No results after reset received within timeout") + } } func TestCountingWindowBadThreshold(t *testing.T) { diff --git a/window/tumbling_window.go b/window/tumbling_window.go index eb50256..7cbbf3a 100644 --- a/window/tumbling_window.go +++ b/window/tumbling_window.go @@ -115,6 +115,7 @@ func (tw *TumblingWindow) Start() { select { // 当定时器到期时,触发窗口。 case <-tw.timer.C: + // 在调用Trigger前不需要额外加锁,因为Trigger方法内部已经有锁保护 tw.Trigger() // 当上下文被取消时,停止定时器并退出循环。 case <-tw.ctx.Done():