mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-05-18 16:00:30 +00:00
feat: 增强 WITH 子句功能
- 支持通过 TIMESTAMP 指定时间戳属性名 - 支持通过 TIMEUNIT 指定时间单位(ms/ss/mm/hh/dd)
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
aggregator2 "github.com/rulego/streamsql/aggregator"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WindowConfig WindowConfig
|
||||
GroupFields []string
|
||||
SelectFields map[string]aggregator2.AggregateType
|
||||
}
|
||||
type WindowConfig struct {
|
||||
Type string
|
||||
Params map[string]interface{}
|
||||
TsProp string
|
||||
TimeUnit time.Duration
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type RowEvent interface {
|
||||
GetTimestamp() time.Time
|
||||
}
|
||||
|
||||
type Row struct {
|
||||
Timestamp time.Time
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
// GetTimestamp 获取时间戳
|
||||
func (r *Row) GetTimestamp() time.Time {
|
||||
return r.Timestamp
|
||||
}
|
||||
+14
-9
@@ -2,12 +2,13 @@ package rsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/rulego/streamsql/window"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/rulego/streamsql/window"
|
||||
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/stream"
|
||||
)
|
||||
|
||||
type SelectStatement struct {
|
||||
@@ -24,12 +25,14 @@ type Field struct {
|
||||
}
|
||||
|
||||
type WindowDefinition struct {
|
||||
Type string
|
||||
Params []interface{}
|
||||
Type string
|
||||
Params []interface{}
|
||||
TsProp string
|
||||
TimeUnit time.Duration
|
||||
}
|
||||
|
||||
// ToStreamConfig 将AST转换为Stream配置
|
||||
func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
|
||||
func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
|
||||
if s.Source == "" {
|
||||
return nil, "", fmt.Errorf("missing FROM clause")
|
||||
}
|
||||
@@ -51,10 +54,12 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
|
||||
}
|
||||
|
||||
// 构建Stream配置
|
||||
config := stream.Config{
|
||||
WindowConfig: stream.WindowConfig{
|
||||
Type: windowType,
|
||||
Params: params,
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: windowType,
|
||||
Params: params,
|
||||
TsProp: s.Window.TsProp,
|
||||
TimeUnit: s.Window.TimeUnit,
|
||||
},
|
||||
GroupFields: extractGroupFields(s),
|
||||
SelectFields: buildSelectFields(s.Fields),
|
||||
|
||||
@@ -34,6 +34,9 @@ const (
|
||||
TokenSliding
|
||||
TokenCounting
|
||||
TokenSession
|
||||
TokenWITH
|
||||
TokenTimestamp
|
||||
TokenTimeUnit
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
@@ -204,6 +207,12 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenCounting, Value: ident}
|
||||
case "SESSIONWINDOW":
|
||||
return Token{Type: TokenSession, Value: ident}
|
||||
case "WITH":
|
||||
return Token{Type: TokenWITH, Value: ident}
|
||||
case "TIMESTAMP":
|
||||
return Token{Type: TokenTimestamp, Value: ident}
|
||||
case "TIMEUNIT":
|
||||
return Token{Type: TokenTimeUnit, Value: ident}
|
||||
default:
|
||||
return Token{Type: TokenIdent, Value: ident}
|
||||
}
|
||||
|
||||
+78
-3
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
@@ -39,6 +40,10 @@ func (p *Parser) Parse() (*SelectStatement, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.parseWith(stmt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
}
|
||||
func (p *Parser) parseSelect(stmt *SelectStatement) error {
|
||||
@@ -125,9 +130,14 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro
|
||||
params = append(params, convertValue(valTok.Value))
|
||||
}
|
||||
|
||||
stmt.Window = WindowDefinition{
|
||||
Type: winType,
|
||||
Params: params,
|
||||
if &stmt.Window != nil {
|
||||
stmt.Window.Params = params
|
||||
stmt.Window.Type = winType
|
||||
} else {
|
||||
stmt.Window = WindowDefinition{
|
||||
Type: winType,
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -185,3 +195,68 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseWith(stmt *SelectStatement) error {
|
||||
p.lexer.NextToken() // 跳过(
|
||||
for p.lexer.peekChar() != ')' {
|
||||
valTok := p.lexer.NextToken()
|
||||
if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
|
||||
break
|
||||
}
|
||||
if valTok.Type == TokenComma {
|
||||
continue
|
||||
}
|
||||
|
||||
if valTok.Type == TokenTimestamp {
|
||||
next := p.lexer.NextToken()
|
||||
if next.Type == TokenEQ {
|
||||
next = p.lexer.NextToken()
|
||||
if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") {
|
||||
next.Value = strings.Trim(next.Value, "'")
|
||||
}
|
||||
// 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition
|
||||
if stmt.Window.Type == "" {
|
||||
stmt.Window = WindowDefinition{
|
||||
TsProp: next.Value,
|
||||
}
|
||||
} else {
|
||||
stmt.Window.TsProp = next.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
if valTok.Type == TokenTimeUnit {
|
||||
timeUnit := time.Minute
|
||||
next := p.lexer.NextToken()
|
||||
if next.Type == TokenEQ {
|
||||
next = p.lexer.NextToken()
|
||||
if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") {
|
||||
next.Value = strings.Trim(next.Value, "'")
|
||||
}
|
||||
switch next.Value {
|
||||
case "dd":
|
||||
timeUnit = 24 * time.Hour
|
||||
case "hh":
|
||||
timeUnit = time.Hour
|
||||
case "mi":
|
||||
timeUnit = time.Minute
|
||||
case "ss":
|
||||
timeUnit = time.Second
|
||||
case "ms":
|
||||
timeUnit = time.Millisecond
|
||||
default:
|
||||
|
||||
}
|
||||
// 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition
|
||||
if stmt.Window.Type == "" {
|
||||
stmt.Window = WindowDefinition{
|
||||
TimeUnit: timeUnit,
|
||||
}
|
||||
} else {
|
||||
stmt.Window.TimeUnit = timeUnit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+28
-7
@@ -1,24 +1,25 @@
|
||||
package rsql
|
||||
|
||||
import (
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/stream"
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/model"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected *stream.Config
|
||||
expected *model.Config
|
||||
condition string
|
||||
}{
|
||||
{
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')",
|
||||
expected: &stream.Config{
|
||||
WindowConfig: stream.WindowConfig{
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{
|
||||
"size": 10 * time.Second,
|
||||
@@ -33,8 +34,8 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
{
|
||||
sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')",
|
||||
expected: &stream.Config{
|
||||
WindowConfig: stream.WindowConfig{
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "sliding",
|
||||
Params: map[string]interface{}{
|
||||
"size": 20 * time.Second,
|
||||
@@ -49,6 +50,23 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
condition: "",
|
||||
},
|
||||
{
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ",
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{
|
||||
"size": 10 * time.Second,
|
||||
},
|
||||
TsProp: "ts",
|
||||
},
|
||||
GroupFields: []string{"deviceId"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"aa": "avg",
|
||||
},
|
||||
},
|
||||
condition: "deviceId == 'aa'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -64,6 +82,9 @@ func TestParseSQL(t *testing.T) {
|
||||
assert.Equal(t, tt.expected.GroupFields, config.GroupFields)
|
||||
assert.Equal(t, tt.expected.SelectFields, config.SelectFields)
|
||||
assert.Equal(t, tt.condition, cond)
|
||||
if tt.expected.WindowConfig.TsProp != "" {
|
||||
assert.Equal(t, tt.expected.WindowConfig.TsProp, config.WindowConfig.TsProp)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestWindowParamParsing(t *testing.T) {
|
||||
|
||||
+5
-15
@@ -5,33 +5,23 @@ import (
|
||||
"time"
|
||||
|
||||
aggregator2 "github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/rulego/streamsql/parser"
|
||||
"github.com/rulego/streamsql/window"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WindowConfig WindowConfig
|
||||
GroupFields []string
|
||||
SelectFields map[string]aggregator2.AggregateType
|
||||
}
|
||||
|
||||
type WindowConfig struct {
|
||||
Type string
|
||||
Params map[string]interface{}
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
dataChan chan interface{}
|
||||
filter parser.Condition
|
||||
Window window.Window
|
||||
aggregator aggregator2.Aggregator
|
||||
config Config
|
||||
config model.Config
|
||||
sinks []func(interface{})
|
||||
resultChan chan interface{} // 结果通道
|
||||
}
|
||||
|
||||
func NewStream(config Config) (*Stream, error) {
|
||||
win, err := window.CreateWindow(config.WindowConfig.Type, config.WindowConfig.Params)
|
||||
func NewStream(config model.Config) (*Stream, error) {
|
||||
win, err := window.CreateWindow(config.WindowConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -108,5 +98,5 @@ func (s *Stream) GetResultsChan() <-chan interface{} {
|
||||
}
|
||||
|
||||
func NewStreamProcessor() (*Stream, error) {
|
||||
return NewStream(Config{})
|
||||
return NewStream(model.Config{})
|
||||
}
|
||||
|
||||
@@ -7,13 +7,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamProcess(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{"size": time.Second},
|
||||
},
|
||||
@@ -77,8 +78,8 @@ func TestStreamProcess(t *testing.T) {
|
||||
|
||||
// 不设置过滤器
|
||||
func TestStreamWithoutFilter(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "sliding",
|
||||
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
|
||||
},
|
||||
@@ -158,8 +159,8 @@ func TestStreamWithoutFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIncompleteStreamProcess(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{"size": time.Second},
|
||||
},
|
||||
|
||||
@@ -2,13 +2,18 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var _ Window = (*CountingWindow)(nil)
|
||||
|
||||
type CountingWindow struct {
|
||||
config model.WindowConfig
|
||||
threshold int
|
||||
count int
|
||||
mu sync.Mutex
|
||||
@@ -21,17 +26,26 @@ type CountingWindow struct {
|
||||
triggerChan chan struct{}
|
||||
}
|
||||
|
||||
func NewCountingWindow(threshold int, callback func([]interface{})) *CountingWindow {
|
||||
func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &CountingWindow{
|
||||
threshold := cast.ToInt(config.Params["count"])
|
||||
if threshold <= 0 {
|
||||
return nil, fmt.Errorf("threshold must be a positive integer")
|
||||
}
|
||||
|
||||
cw := &CountingWindow{
|
||||
threshold: threshold,
|
||||
dataBuffer: make([]interface{}, 0, threshold),
|
||||
outputChan: make(chan []interface{}, 10),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
callback: callback,
|
||||
triggerChan: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
if callback, ok := config.Params["callback"].(func([]interface{})); ok {
|
||||
cw.SetCallback(callback)
|
||||
}
|
||||
return cw, nil
|
||||
}
|
||||
|
||||
func (cw *CountingWindow) Add(data interface{}) {
|
||||
|
||||
@@ -2,10 +2,12 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -14,8 +16,13 @@ func TestCountingWindow(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
// Test case 1: Normal operation
|
||||
cw := NewCountingWindow(3, func(results []interface{}) {
|
||||
t.Logf("Received results: %v", results)
|
||||
cw, _ := NewCountingWindow(model.WindowConfig{
|
||||
Params: map[string]interface{}{
|
||||
"count": 3,
|
||||
"callback": func(results []interface{}) {
|
||||
t.Logf("Received results: %v", results)
|
||||
},
|
||||
},
|
||||
})
|
||||
go cw.Start()
|
||||
|
||||
@@ -50,8 +57,11 @@ func TestCountingWindow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCountingWindowBadThreshold(t *testing.T) {
|
||||
_, err := CreateWindow("counting", map[string]interface{}{
|
||||
"count": 0,
|
||||
_, err := CreateWindow(model.WindowConfig{
|
||||
Type: "counting",
|
||||
Params: map[string]interface{}{
|
||||
"count": 0,
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
+47
-27
@@ -2,7 +2,10 @@ package window
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/cast"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,39 +25,56 @@ type Window interface {
|
||||
Trigger()
|
||||
}
|
||||
|
||||
func CreateWindow(windowType string, params map[string]interface{}) (Window, error) {
|
||||
switch windowType {
|
||||
func CreateWindow(config model.WindowConfig) (Window, error) {
|
||||
switch config.Type {
|
||||
case TypeTumbling:
|
||||
size, err := cast.ToDurationE(params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
|
||||
}
|
||||
return NewTumblingWindow(size), nil
|
||||
return NewTumblingWindow(config)
|
||||
case TypeSliding:
|
||||
size, err := cast.ToDurationE(params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
|
||||
}
|
||||
slide, err := cast.ToDurationE(params["slide"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
|
||||
}
|
||||
return NewSlidingWindow(size, slide), nil
|
||||
return NewSlidingWindow(config)
|
||||
case TypeCounting:
|
||||
count := cast.ToInt(params["count"])
|
||||
if count <= 0 {
|
||||
return nil, fmt.Errorf("count must be a positive integer")
|
||||
}
|
||||
cw := NewCountingWindow(count, nil)
|
||||
if callback, ok := params["callback"].(func([]interface{})); ok {
|
||||
cw.SetCallback(callback)
|
||||
}
|
||||
return cw, nil
|
||||
return NewCountingWindow(config)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported window type: %s", windowType)
|
||||
return nil, fmt.Errorf("unsupported window type: %s", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *CountingWindow) SetCallback(callback func([]interface{})) {
|
||||
cw.callback = callback
|
||||
}
|
||||
|
||||
// GetTimestamp 从数据中获取时间戳。
|
||||
func GetTimestamp(data interface{}, tsProp string) time.Time {
|
||||
if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok {
|
||||
return ts.GetTimestamp()
|
||||
} else if tsProp != "" {
|
||||
v := reflect.ValueOf(data)
|
||||
|
||||
// 处理不同类型
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
// 如果是结构体,使用反射获取字段值
|
||||
if f := v.FieldByName(tsProp); f.IsValid() {
|
||||
if t, ok := f.Interface().(time.Time); ok {
|
||||
return t
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
// 如果是map,直接通过key获取值
|
||||
if v.Type().Key().Kind() == reflect.String {
|
||||
if value := v.MapIndex(reflect.ValueOf(tsProp)); value.IsValid() {
|
||||
return value.Interface().(time.Time)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。
|
||||
func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time {
|
||||
trunc := t.Truncate(timeUnit)
|
||||
if !roundUp {
|
||||
return trunc.Add(timeUnit)
|
||||
}
|
||||
return trunc
|
||||
}
|
||||
|
||||
+24
-17
@@ -2,8 +2,12 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// 确保 SlidingWindow 结构体实现了 Window 接口
|
||||
@@ -11,12 +15,14 @@ var _ Window = (*SlidingWindow)(nil)
|
||||
|
||||
// TimedData 用于包装数据和时间戳
|
||||
type TimedData struct {
|
||||
Data interface{}
|
||||
Timestamp time.Time
|
||||
Data interface{}
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// SlidingWindow 表示一个滑动窗口,用于按时间范围处理数据
|
||||
type SlidingWindow struct {
|
||||
// config 窗口的配置信息
|
||||
config model.WindowConfig
|
||||
// 窗口的总大小,即窗口覆盖的时间范围
|
||||
size time.Duration
|
||||
// 窗口每次滑动的时间间隔
|
||||
@@ -39,17 +45,26 @@ type SlidingWindow struct {
|
||||
|
||||
// NewSlidingWindow 创建一个新的滑动窗口实例
|
||||
// 参数 size 表示窗口的总大小,slide 表示窗口每次滑动的时间间隔
|
||||
func NewSlidingWindow(size, slide time.Duration) *SlidingWindow {
|
||||
func NewSlidingWindow(config model.WindowConfig) (*SlidingWindow, error) {
|
||||
// 创建一个可取消的上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
size, err := cast.ToDurationE(config.Params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
|
||||
}
|
||||
slide, err := cast.ToDurationE(config.Params["slide"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
|
||||
}
|
||||
return &SlidingWindow{
|
||||
config: config,
|
||||
size: size,
|
||||
slide: slide,
|
||||
outputChan: make(chan []interface{}, 10),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
data: make([]TimedData, 0),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add 向滑动窗口中添加数据
|
||||
@@ -58,19 +73,11 @@ func (sw *SlidingWindow) Add(data interface{}) {
|
||||
// 加锁以保证数据的并发安全
|
||||
sw.mu.Lock()
|
||||
defer sw.mu.Unlock()
|
||||
|
||||
var timestamp time.Time
|
||||
if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok {
|
||||
timestamp = ts.GetTimestamp()
|
||||
} else {
|
||||
timestamp = time.Now()
|
||||
}
|
||||
|
||||
// 将数据添加到窗口的数据列表中
|
||||
sw.data = append(sw.data, TimedData{
|
||||
Data: data,
|
||||
Timestamp: timestamp,
|
||||
})
|
||||
// 将数据添加到窗口的数据列表中
|
||||
sw.data = append(sw.data, TimedData{
|
||||
Data: data,
|
||||
Timestamp: GetTimestamp(data, sw.config.TsProp),
|
||||
})
|
||||
}
|
||||
|
||||
// Start 启动滑动窗口,开始定时触发窗口
|
||||
|
||||
@@ -2,16 +2,24 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSlidingWindow(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sw := NewSlidingWindow(2*time.Second, 1*time.Second)
|
||||
sw, _ := NewSlidingWindow(model.WindowConfig{
|
||||
Params: map[string]interface{}{
|
||||
"size": "2s",
|
||||
"slide": "1s",
|
||||
},
|
||||
TsProp: "Ts",
|
||||
})
|
||||
sw.SetCallback(func(results []interface{}) {
|
||||
t.Logf("Received results: %v", results)
|
||||
})
|
||||
@@ -19,10 +27,15 @@ func TestSlidingWindow(t *testing.T) {
|
||||
|
||||
// 添加数据
|
||||
now := time.Now()
|
||||
sw.Add(now.Add(-3 * time.Second))
|
||||
sw.Add(now.Add(-2 * time.Second))
|
||||
sw.Add(now.Add(-1 * time.Second))
|
||||
sw.Add(now)
|
||||
t_3 := TestDate{Ts: now.Add(-3 * time.Second)}
|
||||
t_2 := TestDate{Ts: now.Add(-2 * time.Second)}
|
||||
t_1 := TestDate{Ts: now.Add(-1 * time.Second)}
|
||||
t_0 := TestDate{Ts: now}
|
||||
|
||||
sw.Add(t_3)
|
||||
sw.Add(t_2)
|
||||
sw.Add(t_1)
|
||||
sw.Add(t_0)
|
||||
|
||||
// 等待一段时间,触发窗口
|
||||
time.Sleep(3 * time.Second)
|
||||
@@ -32,12 +45,40 @@ func TestSlidingWindow(t *testing.T) {
|
||||
var results []interface{}
|
||||
select {
|
||||
case results = <-resultsChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-time.After(100 * time.Second):
|
||||
t.Fatal("No results received within timeout")
|
||||
}
|
||||
|
||||
// 预期结果:保留最近 2 秒内的数据
|
||||
assert.Len(t, results, 2)
|
||||
assert.Contains(t, results, now.Add(-1*time.Second))
|
||||
assert.Contains(t, results, now)
|
||||
assert.Contains(t, results, t_1)
|
||||
assert.Contains(t, results, t_0)
|
||||
}
|
||||
|
||||
type TestDate struct {
|
||||
Ts time.Time
|
||||
}
|
||||
|
||||
type TestDate2 struct {
|
||||
ts time.Time
|
||||
}
|
||||
|
||||
func (d TestDate2) GetTimestamp() time.Time {
|
||||
return d.ts
|
||||
}
|
||||
|
||||
func TestGetTimestamp(t *testing.T) {
|
||||
t_0 := time.Now()
|
||||
data := map[string]interface{}{"device": "aa", "age": 15.0, "score": 100, "ts": t_0}
|
||||
t_1 := GetTimestamp(data, "ts")
|
||||
|
||||
data_1 := TestDate{Ts: t_0}
|
||||
t_2 := GetTimestamp(data_1, "Ts")
|
||||
|
||||
data_2 := TestDate2{ts: t_0}
|
||||
t_3 := GetTimestamp(data_2, "")
|
||||
|
||||
assert.Equal(t, t_0, t_1)
|
||||
assert.Equal(t, t_0, t_2)
|
||||
assert.Equal(t, t_0, t_3)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,12 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// 确保 TumblingWindow 结构体实现了 Window 接口。
|
||||
@@ -12,6 +16,8 @@ var _ Window = (*TumblingWindow)(nil)
|
||||
|
||||
// TumblingWindow 表示一个滚动窗口,用于在固定时间间隔内收集数据并触发处理。
|
||||
type TumblingWindow struct {
|
||||
// config 是窗口的配置信息。
|
||||
config model.WindowConfig
|
||||
// size 是滚动窗口的时间大小,即窗口的持续时间。
|
||||
size time.Duration
|
||||
// mu 用于保护对窗口数据的并发访问。
|
||||
@@ -32,15 +38,20 @@ type TumblingWindow struct {
|
||||
|
||||
// NewTumblingWindow 创建一个新的滚动窗口实例。
|
||||
// 参数 size 是窗口的时间大小。
|
||||
func NewTumblingWindow(size time.Duration) *TumblingWindow {
|
||||
func NewTumblingWindow(config model.WindowConfig) (*TumblingWindow, error) {
|
||||
// 创建一个可取消的上下文。
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
size, err := cast.ToDurationE(config.Params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
|
||||
}
|
||||
return &TumblingWindow{
|
||||
config: config,
|
||||
size: size,
|
||||
outputChan: make(chan []interface{}, 10),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add 向滚动窗口添加数据。
|
||||
|
||||
@@ -2,16 +2,21 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTumblingWindow(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tw := NewTumblingWindow(2 * time.Second)
|
||||
tw, _ := NewTumblingWindow(model.WindowConfig{
|
||||
Type: "TumblingWindow",
|
||||
Params: map[string]interface{}{"size": "2s"},
|
||||
})
|
||||
tw.SetCallback(func(results []interface{}) {
|
||||
// Process results
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user