Merge pull request #2 from dimon-83/main

增加window_start()、window_end()函数支持
This commit is contained in:
Whki
2025-04-08 21:26:24 +08:00
committed by GitHub
23 changed files with 1216 additions and 280 deletions
+81 -35
View File
@@ -3,26 +3,29 @@ package aggregator
import (
"math"
"sort"
"strconv"
"sync"
)
type AggregateType string
const (
Sum AggregateType = "sum"
Count AggregateType = "count"
Avg AggregateType = "avg"
Max AggregateType = "max"
Min AggregateType = "min"
StdDev AggregateType = "stddev"
Median AggregateType = "median"
Percentile AggregateType = "percentile"
Sum AggregateType = "sum"
Count AggregateType = "count"
Avg AggregateType = "avg"
Max AggregateType = "max"
Min AggregateType = "min"
StdDev AggregateType = "stddev"
Median AggregateType = "median"
Percentile AggregateType = "percentile"
WindowStart AggregateType = "window_start"
WindowEnd AggregateType = "window_end"
)
type AggregatorFunction interface {
New() AggregatorFunction
Add(value float64)
Result() float64
Add(value interface{})
Result() interface{}
}
type SumAggregator struct {
@@ -33,11 +36,12 @@ func (s *SumAggregator) New() AggregatorFunction {
return &SumAggregator{}
}
func (s *SumAggregator) Add(v float64) {
s.value += v
func (s *SumAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
s.value += vv
}
func (s *SumAggregator) Result() float64 {
func (s *SumAggregator) Result() interface{} {
return s.value
}
@@ -49,11 +53,11 @@ func (s *CountAggregator) New() AggregatorFunction {
return &CountAggregator{}
}
func (c *CountAggregator) Add(_ float64) {
func (c *CountAggregator) Add(_ interface{}) {
c.count++
}
func (c *CountAggregator) Result() float64 {
func (c *CountAggregator) Result() interface{} {
return float64(c.count)
}
@@ -66,12 +70,13 @@ func (a *AvgAggregator) New() AggregatorFunction {
return &AvgAggregator{}
}
func (a *AvgAggregator) Add(v float64) {
a.sum += v
func (a *AvgAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
a.sum += vv
a.count++
}
func (a *AvgAggregator) Result() float64 {
func (a *AvgAggregator) Result() interface{} {
if a.count == 0 {
return 0
}
@@ -117,6 +122,10 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction {
return &MedianAggregator{}
case Percentile:
return &PercentileAggregator{p: 0.95}
case WindowStart:
return &WindowStartAggregator{}
case WindowEnd:
return &WindowEndAggregator{}
default:
panic("unsupported aggregator type: " + aggType)
}
@@ -150,11 +159,12 @@ func (m *MedianAggregator) New() AggregatorFunction {
return &MedianAggregator{}
}
func (m *MedianAggregator) Add(val float64) {
m.values = append(m.values, val)
func (m *MedianAggregator) Add(val interface{}) {
var vv float64 = ConvertToFloat64(val, 0)
m.values = append(m.values, vv)
}
func (m *MedianAggregator) Result() float64 {
func (m *MedianAggregator) Result() interface{} {
sort.Float64s(m.values)
return m.values[len(m.values)/2]
}
@@ -168,8 +178,9 @@ func (p *PercentileAggregator) New() AggregatorFunction {
return &PercentileAggregator{}
}
func (p *PercentileAggregator) Add(v float64) {
p.values = append(p.values, v)
func (p *PercentileAggregator) Add(v interface{}) {
vv := ConvertToFloat64(v, 0)
p.values = append(p.values, vv)
}
type MinAggregator struct {
@@ -183,14 +194,15 @@ func (s *MinAggregator) New() AggregatorFunction {
}
}
func (m *MinAggregator) Add(v float64) {
if m.first || v < m.value {
m.value = v
func (m *MinAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, math.MaxFloat64)
if m.first || vv < m.value {
m.value = vv
m.first = false
}
}
func (m *MinAggregator) Result() float64 {
func (m *MinAggregator) Result() interface{} {
return m.value
}
@@ -203,22 +215,24 @@ func (m *MaxAggregator) New() AggregatorFunction {
return &MaxAggregator{}
}
func (m *MaxAggregator) Add(v float64) {
if m.first || v > m.value {
m.value = v
func (m *MaxAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
if m.first || vv > m.value {
m.value = vv
m.first = false
}
}
func (m *MaxAggregator) Result() float64 {
func (m *MaxAggregator) Result() interface{} {
return m.value
}
func (s *StdDevAggregator) Add(v float64) {
s.values = append(s.values, v)
func (s *StdDevAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
s.values = append(s.values, vv)
}
func (s *StdDevAggregator) Result() float64 {
func (s *StdDevAggregator) Result() interface{} {
if len(s.values) < 2 {
return 0
}
@@ -230,7 +244,7 @@ func (s *StdDevAggregator) Result() float64 {
return math.Sqrt(sum / float64(len(s.values)-1))
}
func (p *PercentileAggregator) Result() float64 {
func (p *PercentileAggregator) Result() interface{} {
if len(p.values) == 0 {
return 0
}
@@ -246,3 +260,35 @@ func calculateAverage(values []float64) float64 {
}
return sum / float64(len(values))
}
// ConvertToFloat64 converts a dynamically typed numeric value to float64.
//
// Supported inputs are float32/float64, every signed and unsigned integer
// width, and strings that strconv.ParseFloat accepts as a 64-bit float.
// Any other input (including nil and unparsable strings) panics.
//
// defaultVal is only the initial value of the result: every supported input
// overwrites it and every unsupported input panics, so as written it is never
// returned. The parameter is kept so existing callers continue to compile.
func ConvertToFloat64(v interface{}, defaultVal float64) float64 {
	result := defaultVal
	switch val := v.(type) {
	case float64:
		result = val
	case float32:
		result = float64(val)
	case int:
		result = float64(val)
	case int8:
		result = float64(val)
	case int16:
		result = float64(val)
	case int32:
		result = float64(val)
	case int64:
		result = float64(val)
	case uint:
		result = float64(val)
	case uint8:
		result = float64(val)
	case uint16:
		result = float64(val)
	case uint32:
		result = float64(val)
	case uint64:
		result = float64(val)
	case string:
		// Numeric strings are accepted; anything else is a caller error.
		parsed, err := strconv.ParseFloat(val, 64)
		if err != nil {
			panic("cannot convert string " + strconv.Quote(val) + " to float64 for aggregation")
		}
		result = parsed
	default:
		// This helper is shared by all aggregators, so the message must not
		// single out one of them.
		panic("unsupported value type for aggregation")
	}
	return result
}
+45
View File
@@ -0,0 +1,45 @@
package aggregator
// ContextAggregator is implemented by aggregators whose input is taken from a
// named context value rather than from a field of the incoming record.
type ContextAggregator interface {
// GetContextKey returns the name of the context entry the aggregator consumes.
GetContextKey() string
}
// WindowStartAggregator remembers the most recent value handed to Add and is
// fed from the "window_start" context entry.
type WindowStartAggregator struct {
	val interface{}
}

// New returns a fresh, empty WindowStartAggregator.
func (ws *WindowStartAggregator) New() AggregatorFunction {
	return new(WindowStartAggregator)
}

// Add stores val, replacing any previously stored value.
func (ws *WindowStartAggregator) Add(val interface{}) {
	ws.val = val
}

// Result returns the last value passed to Add, or nil if Add was never called.
func (ws *WindowStartAggregator) Result() interface{} {
	return ws.val
}

// GetContextKey returns the context entry name this aggregator reads from.
func (ws *WindowStartAggregator) GetContextKey() string {
	const key = "window_start"
	return key
}
// WindowEndAggregator remembers the most recent value handed to Add and is
// fed from the "window_end" context entry.
type WindowEndAggregator struct {
	val interface{}
}

// New returns a fresh, empty WindowEndAggregator.
func (we *WindowEndAggregator) New() AggregatorFunction {
	return new(WindowEndAggregator)
}

// Add stores val, replacing any previously stored value.
func (we *WindowEndAggregator) Add(val interface{}) {
	we.val = val
}

// Result returns the last value passed to Add, or nil if Add was never called.
func (we *WindowEndAggregator) Result() interface{} {
	return we.val
}

// GetContextKey returns the context entry name this aggregator reads from.
func (we *WindowEndAggregator) GetContextKey() string {
	const key = "window_end"
	return key
}
+43 -2
View File
@@ -9,6 +9,7 @@ import (
type Aggregator interface {
Add(data interface{}) error
Put(key string, val interface{}) error
GetResults() ([]map[string]interface{}, error)
Reset()
}
@@ -19,9 +20,11 @@ type GroupAggregator struct {
aggregators map[string]AggregatorFunction
groups map[string]map[string]AggregatorFunction
mu sync.RWMutex
context map[string]interface{}
fieldAlias map[string]string
}
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator {
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
aggregators := make(map[string]AggregatorFunction)
for field, aggType := range fieldMap {
@@ -33,8 +36,20 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType)
groupFields: groupFields,
aggregators: aggregators,
groups: make(map[string]map[string]AggregatorFunction),
fieldAlias: fieldAlias,
}
}
// Put stores val under key in the aggregator's context map, creating the map
// on first use. It always returns nil.
func (ga *GroupAggregator) Put(key string, val interface{}) error {
	// Take the write lock for the whole update; it is released on return.
	ga.mu.Lock()
	defer ga.mu.Unlock()

	ctx := ga.context
	if ctx == nil {
		ctx = make(map[string]interface{})
		ga.context = ctx
	}
	ctx[key] = val
	return nil
}
func (ga *GroupAggregator) Add(data interface{}) error {
ga.mu.Lock() // 获取写锁
defer ga.mu.Unlock() // 确保函数返回时释放锁
@@ -114,6 +129,19 @@ func (ga *GroupAggregator) Add(data interface{}) error {
if !f.IsValid() {
//return fmt.Errorf("field %s not found", field)
//fmt.Printf("field %s not found in %v \n ", field, data)
// 尝试从context中获取
if ga.context != nil {
if groupAgg, exists := ga.groups[key][field]; exists {
if _, ok := groupAgg.(ContextAggregator); ok {
key := groupAgg.(ContextAggregator).GetContextKey()
if val, exists := ga.context[key]; exists {
groupAgg.Add(val)
//fmt.Printf("add agg group by %s:%s , %v \n", key, field, value)
}
}
}
}
continue
}
@@ -151,7 +179,20 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
group[field] = fields[i]
}
for field, agg := range aggregators {
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
if _, ok := agg.(ContextAggregator); ok {
if alias, ok := ga.fieldAlias[field]; ok {
group[alias] = agg.Result()
} else {
group[field] = agg.Result()
}
} else {
if alias, ok := ga.fieldAlias[field]; ok {
group[alias] = agg.Result()
} else {
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
}
}
}
result = append(result, group)
}
+18 -5
View File
@@ -19,13 +19,17 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) {
"Data1": Sum,
"Data2": Sum,
},
map[string]string{
"Data1": "Data1_sum",
"Data2": "Data2_sum",
},
)
testData := []map[string]interface{}{
{"device": "aa", "data1": 20, "data2": 30},
{"device": "aa", "data1": 21, "data2": 0},
{"device": "bb", "data1": 15, "data2": 20},
{"device": "bb", "data1": 16, "data2": 20},
{"Device": "aa", "Data1": 20, "Data2": 30},
{"Device": "aa", "Data1": 21, "Data2": 0},
{"Device": "bb", "Data1": 15, "Data2": 20},
{"Device": "bb", "Data1": 16, "Data2": 20},
}
for _, d := range testData {
@@ -47,6 +51,9 @@ func TestGroupAggregator_SingleField(t *testing.T) {
map[string]AggregateType{
"Data1": Sum,
},
map[string]string{
"Data1": "Data1_sum",
},
)
testData := []map[string]interface{}{
@@ -75,6 +82,12 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
"Data3": Max,
"Data4": Min,
},
map[string]string{
"Data1": "Data1_sum",
"Data2": "Data2_avg",
"Data3": "Data3_max",
"Data4": "Data4_min",
},
)
testData := []map[string]interface{}{
@@ -88,7 +101,7 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
expected := []map[string]interface{}{
{
"Device": "cc",
"Device": "cc",
"Data1_sum": 30.0,
"Data2_avg": 5.0,
"Data3_max": 12.0,
+20
View File
@@ -0,0 +1,20 @@
package model
import (
"time"
"github.com/rulego/streamsql/aggregator"
)
// Config bundles the settings used to build a stream: the window definition,
// the grouping fields, and the aggregations to compute.
type Config struct {
// WindowConfig selects and parameterizes the window.
WindowConfig WindowConfig
// GroupFields lists the record fields used to group data for aggregation.
GroupFields []string
// SelectFields maps a field name to the aggregate type applied to it.
SelectFields map[string]aggregator.AggregateType
// FieldAlias maps a field name to the output name used for its result.
FieldAlias map[string]string
}
// WindowConfig describes which window to create and how to configure it.
type WindowConfig struct {
// Type is the window kind, e.g. "tumbling" or "sliding".
Type string
// Params holds window-specific settings such as "size", "slide" or "count".
Params map[string]interface{}
// TsProp is the name of the record property that carries the timestamp.
TsProp string
// TimeUnit is the granularity used when aligning window slot boundaries.
TimeUnit time.Duration
}
+20
View File
@@ -0,0 +1,20 @@
package model
import (
"time"
)
// RowEvent is implemented by values that can report the time they belong to.
type RowEvent interface {
// GetTimestamp returns the event's timestamp.
GetTimestamp() time.Time
}
// Row wraps one record travelling through a window together with its
// timestamp and the time slot it was assigned to.
type Row struct {
// Timestamp is the time associated with the record.
Timestamp time.Time
// Data is the original record payload.
Data interface{}
// Slot is the window time slot assigned to the row; it may be nil.
Slot *TimeSlot
}
// GetTimestamp returns the row's timestamp.
func (r *Row) GetTimestamp() time.Time {
return r.Timestamp
}
+66
View File
@@ -0,0 +1,66 @@
package model
import (
"time"
)
// TimeSlot describes a window's time range by its start and end instants.
// Either bound may be nil when it is unknown.
type TimeSlot struct {
	Start *time.Time
	End   *time.Time
}

// NewTimeSlot builds a TimeSlot from the given bounds. The pointers are
// stored as-is, not copied.
func NewTimeSlot(start, end *time.Time) *TimeSlot {
	return &TimeSlot{
		Start: start,
		End:   end,
	}
}

// Hash combines the start and end Unix-nanosecond timestamps into one value:
// the start is rotated by 32 bits and XORed with the end. A nil bound counts
// as 0, matching WindowStart and WindowEnd, instead of panicking.
func (ts TimeSlot) Hash() uint64 {
	start := uint64(ts.WindowStart())
	end := uint64(ts.WindowEnd())
	return (start<<32 | start>>32) ^ end
}

// Contains reports whether t lies in the half-open range [Start, End).
// A slot with a nil bound contains nothing.
func (ts TimeSlot) Contains(t time.Time) bool {
	if ts.Start == nil || ts.End == nil {
		return false
	}
	return !t.Before(*ts.Start) && t.Before(*ts.End)
}

// GetStartTime returns the start bound, or nil for a nil slot or unset start.
func (ts *TimeSlot) GetStartTime() *time.Time {
	if ts == nil {
		return nil
	}
	return ts.Start
}

// GetEndTime returns the end bound, or nil for a nil slot or unset end.
func (ts *TimeSlot) GetEndTime() *time.Time {
	if ts == nil {
		return nil
	}
	return ts.End
}

// WindowStart returns the start bound as Unix nanoseconds, or 0 when the slot
// or its start is nil.
func (ts *TimeSlot) WindowStart() int64 {
	if ts == nil || ts.Start == nil {
		return 0
	}
	return ts.Start.UnixNano()
}

// WindowEnd returns the end bound as Unix nanoseconds, or 0 when the slot or
// its end is nil.
func (ts *TimeSlot) WindowEnd() int64 {
	if ts == nil || ts.End == nil {
		return 0
	}
	return ts.End.UnixNano()
}
+48 -19
View File
@@ -2,12 +2,13 @@ package rsql
import (
"fmt"
"github.com/rulego/streamsql/window"
"strings"
"time"
"github.com/rulego/streamsql/model"
"github.com/rulego/streamsql/window"
"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/stream"
)
type SelectStatement struct {
@@ -21,15 +22,18 @@ type SelectStatement struct {
type Field struct {
Expression string
Alias string
AggType string
}
type WindowDefinition struct {
Type string
Params []interface{}
Type string
Params []interface{}
TsProp string
TimeUnit time.Duration
}
// ToStreamConfig 将AST转换为Stream配置
func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
if s.Source == "" {
return nil, "", fmt.Errorf("missing FROM clause")
}
@@ -49,15 +53,18 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
if err != nil {
return nil, "", fmt.Errorf("解析窗口参数失败: %w", err)
}
aggs, fields := buildSelectFields(s.Fields)
// 构建Stream配置
config := stream.Config{
WindowConfig: stream.WindowConfig{
Type: windowType,
Params: params,
config := model.Config{
WindowConfig: model.WindowConfig{
Type: windowType,
Params: params,
TsProp: s.Window.TsProp,
TimeUnit: s.Window.TimeUnit,
},
GroupFields: extractGroupFields(s),
SelectFields: buildSelectFields(s.Fields),
SelectFields: aggs,
FieldAlias: fields,
}
return &config, s.Condition, nil
@@ -73,28 +80,50 @@ func extractGroupFields(s *SelectStatement) []string {
return fields
}
func buildSelectFields(fields []Field) map[string]aggregator.AggregateType {
func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) {
selectFields := make(map[string]aggregator.AggregateType)
fieldMap = make(map[string]string)
for _, f := range fields {
if alias := f.Alias; alias != "" {
selectFields[alias] = parseAggregateType(f.Expression)
t, n := parseAggregateType(f.Expression)
if n != "" {
selectFields[n] = t
fieldMap[n] = alias
} else {
selectFields[alias] = t
}
}
}
return selectFields
return selectFields, fieldMap
}
func parseAggregateType(expr string) aggregator.AggregateType {
func parseAggregateType(expr string) (aggType aggregator.AggregateType, name string) {
if strings.Contains(expr, "avg(") {
return "avg"
return "avg", extractAggField(expr)
}
if strings.Contains(expr, "sum(") {
return "sum"
return "sum", extractAggField(expr)
}
if strings.Contains(expr, "max(") {
return "max"
return "max", extractAggField(expr)
}
if strings.Contains(expr, "min(") {
return "min"
return "min", extractAggField(expr)
}
if strings.Contains(expr, "window_start(") {
return "window_start", "window_start"
}
if strings.Contains(expr, "window_end(") {
return "window_end", "window_end"
}
return "", ""
}
// extractAggField returns the trimmed text between the first "(" and the last
// ")" of expr. It returns "" when no such pair exists in that order.
func extractAggField(expr string) string {
	open := strings.Index(expr, "(")
	if open < 0 {
		return ""
	}
	closing := strings.LastIndex(expr, ")")
	if closing <= open {
		return ""
	}
	inner := expr[open+1 : closing]
	return strings.TrimSpace(inner)
}
+9
View File
@@ -34,6 +34,9 @@ const (
TokenSliding
TokenCounting
TokenSession
TokenWITH
TokenTimestamp
TokenTimeUnit
)
type Token struct {
@@ -204,6 +207,12 @@ func (l *Lexer) lookupIdent(ident string) Token {
return Token{Type: TokenCounting, Value: ident}
case "SESSIONWINDOW":
return Token{Type: TokenSession, Value: ident}
case "WITH":
return Token{Type: TokenWITH, Value: ident}
case "TIMESTAMP":
return Token{Type: TokenTimestamp, Value: ident}
case "TIMEUNIT":
return Token{Type: TokenTimeUnit, Value: ident}
default:
return Token{Type: TokenIdent, Value: ident}
}
+78 -3
View File
@@ -4,6 +4,7 @@ import (
"errors"
"strconv"
"strings"
"time"
)
type Parser struct {
@@ -39,6 +40,10 @@ func (p *Parser) Parse() (*SelectStatement, error) {
return nil, err
}
if err := p.parseWith(stmt); err != nil {
return nil, err
}
return stmt, nil
}
func (p *Parser) parseSelect(stmt *SelectStatement) error {
@@ -125,9 +130,14 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro
params = append(params, convertValue(valTok.Value))
}
stmt.Window = WindowDefinition{
Type: winType,
Params: params,
if &stmt.Window != nil {
stmt.Window.Params = params
stmt.Window.Type = winType
} else {
stmt.Window = WindowDefinition{
Type: winType,
Params: params,
}
}
return nil
}
@@ -185,3 +195,68 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
}
return nil
}
// parseWith parses the trailing WITH (...) option list and copies the
// recognized options onto stmt.Window.
//
// Recognized options:
//   - TIMESTAMP='<name>' sets stmt.Window.TsProp.
//   - TIMEUNIT='<unit>' sets stmt.Window.TimeUnit; dd, hh, mi, ss and ms are
//     understood, and any other unit falls back to one minute.
//
// Surrounding single quotes are stripped from option values. Tokens that are
// not recognized are skipped, and the function always returns nil.
//
// Each option is assigned to its own field, so options never overwrite each
// other or a window definition parsed earlier. Previously, when no window
// type was set, every option replaced the whole WindowDefinition, so
// TIMESTAMP and TIMEUNIT erased one another.
func (p *Parser) parseWith(stmt *SelectStatement) error {
	// Consume the clause's leading token before scanning the options.
	p.lexer.NextToken()

	// readValue consumes "= <value>" and returns the value without its
	// surrounding single quotes; ok is false when no "=" follows the option.
	readValue := func() (value string, ok bool) {
		if p.lexer.NextToken().Type != TokenEQ {
			return "", false
		}
		value = p.lexer.NextToken().Value
		if strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'") {
			value = strings.Trim(value, "'")
		}
		return value, true
	}

	for p.lexer.peekChar() != ')' {
		tok := p.lexer.NextToken()
		if tok.Type == TokenRParen || tok.Type == TokenEOF {
			break
		}

		switch tok.Type {
		case TokenTimestamp:
			if name, ok := readValue(); ok {
				stmt.Window.TsProp = name
			}
		case TokenTimeUnit:
			unit, ok := readValue()
			if !ok {
				continue
			}
			// Unknown units keep the one-minute default.
			timeUnit := time.Minute
			switch unit {
			case "dd":
				timeUnit = 24 * time.Hour
			case "hh":
				timeUnit = time.Hour
			case "mi":
				timeUnit = time.Minute
			case "ss":
				timeUnit = time.Second
			case "ms":
				timeUnit = time.Millisecond
			}
			stmt.Window.TimeUnit = timeUnit
		}
		// Commas and any other tokens are simply skipped.
	}

	return nil
}
+30 -9
View File
@@ -1,24 +1,25 @@
package rsql
import (
"github.com/rulego/streamsql/aggregator"
"testing"
"time"
"github.com/rulego/streamsql/stream"
"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/model"
"github.com/stretchr/testify/assert"
)
func TestParseSQL(t *testing.T) {
tests := []struct {
sql string
expected *stream.Config
expected *model.Config
condition string
}{
{
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')",
expected: &stream.Config{
WindowConfig: stream.WindowConfig{
expected: &model.Config{
WindowConfig: model.WindowConfig{
Type: "tumbling",
Params: map[string]interface{}{
"size": 10 * time.Second,
@@ -33,8 +34,8 @@ func TestParseSQL(t *testing.T) {
},
{
sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')",
expected: &stream.Config{
WindowConfig: stream.WindowConfig{
expected: &model.Config{
WindowConfig: model.WindowConfig{
Type: "sliding",
Params: map[string]interface{}{
"size": 20 * time.Second,
@@ -43,12 +44,29 @@ func TestParseSQL(t *testing.T) {
},
GroupFields: []string{"type"},
SelectFields: map[string]aggregator.AggregateType{
"max_score": "max",
"min_age": "min",
"score": "max",
"age": "min",
},
},
condition: "",
},
{
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ",
expected: &model.Config{
WindowConfig: model.WindowConfig{
Type: "tumbling",
Params: map[string]interface{}{
"size": 10 * time.Second,
},
TsProp: "ts",
},
GroupFields: []string{"deviceId"},
SelectFields: map[string]aggregator.AggregateType{
"aa": "avg",
},
},
condition: "deviceId == 'aa'",
},
}
for _, tt := range tests {
@@ -64,6 +82,9 @@ func TestParseSQL(t *testing.T) {
assert.Equal(t, tt.expected.GroupFields, config.GroupFields)
assert.Equal(t, tt.expected.SelectFields, config.SelectFields)
assert.Equal(t, tt.condition, cond)
if tt.expected.WindowConfig.TsProp != "" {
assert.Equal(t, tt.expected.WindowConfig.TsProp, config.WindowConfig.TsProp)
}
}
}
func TestWindowParamParsing(t *testing.T) {
+14 -17
View File
@@ -2,36 +2,27 @@ package stream
import (
"fmt"
"strings"
"time"
aggregator2 "github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/model"
"github.com/rulego/streamsql/parser"
"github.com/rulego/streamsql/window"
)
type Config struct {
WindowConfig WindowConfig
GroupFields []string
SelectFields map[string]aggregator2.AggregateType
}
type WindowConfig struct {
Type string
Params map[string]interface{}
}
type Stream struct {
dataChan chan interface{}
filter parser.Condition
Window window.Window
aggregator aggregator2.Aggregator
config Config
config model.Config
sinks []func(interface{})
resultChan chan interface{} // 结果通道
}
func NewStream(config Config) (*Stream, error) {
win, err := window.CreateWindow(config.WindowConfig.Type, config.WindowConfig.Params)
func NewStream(config model.Config) (*Stream, error) {
win, err := window.CreateWindow(config.WindowConfig)
if err != nil {
return nil, err
}
@@ -44,6 +35,9 @@ func NewStream(config Config) (*Stream, error) {
}
func (s *Stream) RegisterFilter(condition string) error {
if strings.TrimSpace(condition) == "" {
return nil
}
filter, err := parser.NewExprCondition(condition)
if err != nil {
return fmt.Errorf("compile filter error: %w", err)
@@ -57,7 +51,7 @@ func (s *Stream) Start() {
}
func (s *Stream) process() {
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields)
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias)
// 启动窗口处理协程
s.Window.Start()
@@ -76,10 +70,13 @@ func (s *Stream) process() {
case batch := <-s.Window.OutputChan():
// 处理窗口批数据
for _, item := range batch {
if err := s.aggregator.Add(item); err != nil {
s.aggregator.Put("window_start", item.Slot.WindowStart())
s.aggregator.Put("window_end", item.Slot.WindowEnd())
if err := s.aggregator.Add(item.Data); err != nil {
fmt.Printf("aggregate error: %v\n", err)
}
}
// 获取并发送聚合结果
if results, err := s.aggregator.GetResults(); err == nil {
// 发送结果到结果通道和 Sink 函数
@@ -108,5 +105,5 @@ func (s *Stream) GetResultsChan() <-chan interface{} {
}
func NewStreamProcessor() (*Stream, error) {
return NewStream(Config{})
return NewStream(model.Config{})
}
+99 -6
View File
@@ -7,13 +7,14 @@ import (
"time"
"github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStreamProcess(t *testing.T) {
config := Config{
WindowConfig: WindowConfig{
config := model.Config{
WindowConfig: model.WindowConfig{
Type: "tumbling",
Params: map[string]interface{}{"size": time.Second},
},
@@ -77,8 +78,8 @@ func TestStreamProcess(t *testing.T) {
// 不设置过滤器
func TestStreamWithoutFilter(t *testing.T) {
config := Config{
WindowConfig: WindowConfig{
config := model.Config{
WindowConfig: model.WindowConfig{
Type: "sliding",
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
},
@@ -158,8 +159,8 @@ func TestStreamWithoutFilter(t *testing.T) {
}
func TestIncompleteStreamProcess(t *testing.T) {
config := Config{
WindowConfig: WindowConfig{
config := model.Config{
WindowConfig: model.WindowConfig{
Type: "tumbling",
Params: map[string]interface{}{"size": time.Second},
},
@@ -222,3 +223,95 @@ func TestIncompleteStreamProcess(t *testing.T) {
assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001)
assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001)
}
// TestWindowSlotAgg runs a sliding-window stream with max/min aggregations plus
// window_start/window_end and checks each device's result, including the
// expected slot boundaries.
func TestWindowSlotAgg(t *testing.T) {
config := model.Config{
WindowConfig: model.WindowConfig{
Type: "sliding",
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
TsProp: "ts",
},
GroupFields: []string{"device"},
SelectFields: map[string]aggregator.AggregateType{
"age": aggregator.Max,
"score": aggregator.Min,
"start": aggregator.WindowStart,
"end": aggregator.WindowEnd,
},
}
strm, err := NewStream(config)
require.NoError(t, err)
strm.Start()
// Feed all test records into the stream back-to-back (no delay between them).
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
testData := []interface{}{
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "ts": baseTime},
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "ts": baseTime.Add(1 * time.Second)},
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "ts": baseTime},
}
for _, data := range testData {
strm.AddData(data)
}
// Capture results delivered to the sink.
resultChan := make(chan interface{})
strm.AddSink(func(result interface{}) {
resultChan <- result
})
// Wait 3 seconds so the window can fire.
time.Sleep(3 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var actual interface{}
select {
case actual = <-resultChan:
cancel()
case <-ctx.Done():
t.Fatal("Timeout waiting for results")
}
expected := []map[string]interface{}{
{
"device": "aa",
"age_max": 10.0,
"score_min": 100.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
{
"device": "bb",
"age_max": 3.0,
"score_min": 300.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
}
assert.IsType(t, []map[string]interface{}{}, actual)
resultSlice, ok := actual.([]map[string]interface{})
require.True(t, ok)
assert.Len(t, resultSlice, 2)
// Match each expected group to an actual result by its device value.
for _, expectedResult := range expected {
found := false
for _, resultMap := range resultSlice {
//if resultMap, ok := result.(map[string]interface{}); ok {
if resultMap["device"] == expectedResult["device"] {
assert.InEpsilon(t, expectedResult["age_max"].(float64), resultMap["age_max"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["score_min"].(float64), resultMap["score_min"].(float64), 0.0001)
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
found = true
break
}
//}
}
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
}
}
+74 -2
View File
@@ -1,13 +1,85 @@
package streamsql
import (
"github.com/stretchr/testify/assert"
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStreamsql(t *testing.T) {
streamsql := New()
var rsql = ""
var rsql = "SELECT device,max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream group by device,SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')"
err := streamsql.Execute(rsql)
assert.Nil(t, err)
strm := streamsql.stream
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
testData := []interface{}{
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime},
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)},
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime},
}
for _, data := range testData {
strm.AddData(data)
}
// 捕获结果
resultChan := make(chan interface{})
strm.AddSink(func(result interface{}) {
resultChan <- result
})
// 等待 3 秒触发窗口
time.Sleep(3 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var actual interface{}
select {
case actual = <-resultChan:
cancel()
case <-ctx.Done():
t.Fatal("Timeout waiting for results")
}
expected := []map[string]interface{}{
{
"device": "aa",
"max_age": 10.0,
"min_score": 100.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
{
"device": "bb",
"max_age": 3.0,
"min_score": 300.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
}
assert.IsType(t, []map[string]interface{}{}, actual)
resultSlice, ok := actual.([]map[string]interface{})
require.True(t, ok)
assert.Len(t, resultSlice, 2)
for _, expectedResult := range expected {
found := false
for _, resultMap := range resultSlice {
//if resultMap, ok := result.(map[string]interface{}); ok {
if resultMap["device"] == expectedResult["device"] {
assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001)
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
found = true
break
}
//}
}
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
}
}
+18
View File
@@ -0,0 +1,18 @@
package timex
import "time"
// AlignTimeToWindow aligns t down to the start of the window of the given
// size that contains it. Windows are counted from the Unix epoch.
//
// A zero size returns t unchanged instead of dividing by zero, and a negative
// size is treated as its absolute value. Times before the Unix epoch are also
// aligned downward: Go's % keeps the sign of the dividend, so the negative
// remainder is normalized before it is subtracted.
func AlignTimeToWindow(t time.Time, size time.Duration) time.Time {
	step := int64(size)
	if step == 0 {
		return t
	}
	if step < 0 {
		step = -step
	}
	offset := t.UnixNano() % step
	if offset < 0 {
		offset += step
	}
	return t.Add(-time.Duration(offset))
}
// AlignTime aligns t to a multiple of timeUnit using time.Time.Truncate.
//
// NOTE: the flag behaves opposite to what its name suggests. When roundUp is
// true the truncated (earlier) boundary is returned; when roundUp is false
// the truncated boundary plus one timeUnit (the later boundary) is returned.
func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time {
	boundary := t.Truncate(timeUnit)
	if roundUp {
		return boundary
	}
	return boundary.Add(timeUnit)
}
+69
View File
@@ -0,0 +1,69 @@
// Copyright 2021 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package timex
import (
"testing"
"time"
)
// TestAlignTimeToWindow checks that AlignTimeToWindow moves a timestamp back
// to the start of its window for several window sizes, and leaves an already
// aligned timestamp unchanged.
func TestAlignTimeToWindow(t *testing.T) {
	tests := []struct {
		name     string
		input    time.Time
		size     time.Duration
		expected time.Time
	}{
		{
			// The label now matches the 3-minute size this case exercises.
			name:     "对齐到3分钟窗口",
			input:    time.Date(2024, 1, 1, 12, 35, 56, 789000000, time.UTC),
			size:     3 * time.Minute,
			expected: time.Date(2024, 1, 1, 12, 33, 0, 0, time.UTC),
		},
		{
			name:     "对齐到5分钟窗口",
			input:    time.Date(2024, 1, 1, 12, 37, 56, 789000000, time.UTC),
			size:     5 * time.Minute,
			expected: time.Date(2024, 1, 1, 12, 35, 0, 0, time.UTC),
		},
		{
			name:     "对齐到1小时窗口",
			input:    time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC),
			size:     time.Hour,
			expected: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
		},
		{
			name:     "对齐到1天窗口",
			input:    time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC),
			size:     24 * time.Hour,
			expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:     "零时刻对齐测试",
			input:    time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
			size:     time.Hour,
			expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := AlignTimeToWindow(tt.input, tt.size)
			if !got.Equal(tt.expected) {
				t.Errorf("AlignTimeToWindow() = %v, want %v", got, tt.expected)
			}
		})
	}
}
+84 -22
View File
@@ -2,56 +2,90 @@ package window
import (
"context"
"fmt"
"sync"
"time"
"github.com/rulego/streamsql/model"
timex "github.com/rulego/streamsql/utils"
"github.com/spf13/cast"
)
var _ Window = (*CountingWindow)(nil)
type CountingWindow struct {
config model.WindowConfig
threshold int
count int
mu sync.Mutex
callback func([]interface{})
dataBuffer []interface{}
outputChan chan []interface{}
callback func([]model.Row)
dataBuffer []model.Row
outputChan chan []model.Row
ctx context.Context
cancelFunc context.CancelFunc
ticker *time.Ticker
triggerChan chan struct{}
}
func NewCountingWindow(threshold int, callback func([]interface{})) *CountingWindow {
func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) {
ctx, cancel := context.WithCancel(context.Background())
return &CountingWindow{
threshold := cast.ToInt(config.Params["count"])
if threshold <= 0 {
return nil, fmt.Errorf("threshold must be a positive integer")
}
cw := &CountingWindow{
threshold: threshold,
dataBuffer: make([]interface{}, 0, threshold),
outputChan: make(chan []interface{}, 10),
dataBuffer: make([]model.Row, 0, threshold),
outputChan: make(chan []model.Row, 10),
ctx: ctx,
cancelFunc: cancel,
callback: callback,
triggerChan: make(chan struct{}, 1),
}
if callback, ok := config.Params["callback"].(func([]model.Row)); ok {
cw.SetCallback(callback)
}
return cw, nil
}
// Add wraps data in a model.Row stamped via GetTimestamp, appends it to the
// buffer and, once count reaches threshold, hands the first threshold rows to
// the callback and outputChan from a new goroutine.
// NOTE(review): this span is a rendered diff with both versions interleaved
// (both `defer cw.mu.Unlock()` and plain `cw.mu.Unlock()` calls are visible) —
// confirm the actual lock sequence against the real file.
func (cw *CountingWindow) Add(data interface{}) {
cw.mu.Lock()
cw.dataBuffer = append(cw.dataBuffer, data)
defer cw.mu.Unlock()
// Append the incoming item to the window's buffer as a timestamped row.
t := GetTimestamp(data, cw.config.TsProp)
row := model.Row{
Data: data,
Timestamp: t,
}
cw.dataBuffer = append(cw.dataBuffer, row)
cw.count++
shouldTrigger := cw.count >= cw.threshold
cw.mu.Unlock()
if shouldTrigger {
cw.mu.Lock()
v := append([]interface{}{}, cw.dataBuffer...)
cw.mu.Unlock()
slot := cw.createSlot(cw.dataBuffer[:cw.threshold])
for _, r := range cw.dataBuffer[:cw.threshold] {
// Author's note: Row is a value type, so the Slot field is set through a pointer.
// NOTE(review): r is the range loop's own copy of the element, so this
// write does not reach the row stored in cw.dataBuffer — TODO confirm
// and assign through the slice index instead.
(&r).Slot = slot
}
// data aliases the first threshold rows of the old backing array; the
// buffer itself is replaced below so later appends do not overwrite them.
data := cw.dataBuffer[:cw.threshold]
if len(cw.dataBuffer) > cw.threshold {
remaining := len(cw.dataBuffer) - cw.threshold
newBuffer := make([]model.Row, remaining, cw.threshold)
copy(newBuffer, cw.dataBuffer[cw.threshold:])
cw.dataBuffer = newBuffer
} else {
cw.dataBuffer = make([]model.Row, 0, cw.threshold)
}
// NOTE(review): the goroutine takes cw.mu and sends on outputChan while
// holding it; the constructor gives that channel a capacity of 10, so a slow
// consumer could block the send with the mutex held — confirm.
go func() {
cw.mu.Lock()
if cw.callback != nil {
cw.callback(v)
cw.callback(data)
}
cw.outputChan <- v
cw.Reset()
cw.outputChan <- data
cw.count = 0
//cw.Reset()
cw.mu.Unlock()
}()
}
}
@@ -66,7 +100,7 @@ func (cw *CountingWindow) Start() {
for {
select {
case <-cw.ticker.C:
cw.Trigger()
//cw.Trigger()
case <-cw.ctx.Done():
return
}
@@ -82,7 +116,17 @@ func (cw *CountingWindow) Trigger() {
defer cw.mu.Unlock()
if cw.callback != nil && len(cw.dataBuffer) > 0 {
cw.callback(cw.dataBuffer)
var resultData []model.Row
if len(cw.dataBuffer) > cw.threshold {
resultData = cw.dataBuffer[:cw.threshold]
} else {
resultData = cw.dataBuffer
}
slot := cw.createSlot(resultData)
for _, r := range resultData {
r.Slot = slot
}
cw.callback(resultData)
}
cw.Reset()
}()
@@ -95,9 +139,27 @@ func (cw *CountingWindow) Reset() {
cw.dataBuffer = cw.dataBuffer[:0]
}
// OutputChan returns the receive-only view of the channel that emitted batches
// are sent on.
// NOTE(review): rendered diff — both the []interface{} and the []model.Row
// signature lines are visible here.
func (cw *CountingWindow) OutputChan() <-chan []interface{} {
func (cw *CountingWindow) OutputChan() <-chan []model.Row {
return cw.outputChan
}
func (cw *CountingWindow) GetResults() []interface{} {
return append([]interface{}{}, cw.dataBuffer...)
// func (cw *CountingWindow) GetResults() []interface{} {
// return append([]mode.Row, cw.dataBuffer...)
// }
// createSlot builds the time slot spanned by one batch of rows.
//
// The slot runs from the timestamp of the first row to the timestamp of the
// last row that belongs to the batch, each passed through timex.AlignTime with
// the configured TimeUnit (the start with the flag set to true, the end with
// false). When data holds at least threshold rows the batch ends at index
// threshold-1; otherwise it ends at the last row of data.
//
// It returns nil for an empty batch.
func (cw *CountingWindow) createSlot(data []model.Row) *model.TimeSlot {
	if len(data) == 0 {
		return nil
	}

	// Index into data by its own length: cw.dataBuffer may be longer or
	// shorter than the slice the caller passed in.
	last := len(data) - 1
	if len(data) >= cw.threshold {
		last = cw.threshold - 1
	}

	start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true)
	end := timex.AlignTime(data[last].Timestamp, cw.config.TimeUnit, false)
	return model.NewTimeSlot(&start, &end)
}
+27 -16
View File
@@ -2,10 +2,12 @@ package window
import (
"context"
"github.com/stretchr/testify/require"
"testing"
"time"
"github.com/rulego/streamsql/model"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)
@@ -14,8 +16,13 @@ func TestCountingWindow(t *testing.T) {
defer cancel()
// Test case 1: Normal operation
cw := NewCountingWindow(3, func(results []interface{}) {
t.Logf("Received results: %v", results)
cw, _ := NewCountingWindow(model.WindowConfig{
Params: map[string]interface{}{
"count": 3,
"callback": func(results []interface{}) {
t.Logf("Received results: %v", results)
},
},
})
go cw.Start()
@@ -27,31 +34,35 @@ func TestCountingWindow(t *testing.T) {
// Trigger one more element to check threshold
cw.Add(3)
results := make(chan []interface{})
go func() {
for res := range cw.OutputChan() {
results <- res
}
}()
resultsChan := cw.OutputChan()
//results := make(chan []model.Row)
// go func() {
// for res := range cw.OutputChan() {
// results <- res
// }
// }()
select {
case res := <-results:
case res := <-resultsChan:
assert.Len(t, res, 3)
assert.Contains(t, res, 0)
assert.Contains(t, res, 1)
assert.Contains(t, res, 2)
assert.Equal(t, 0, res[0].Data, "第一个元素应该是0")
assert.Equal(t, 1, res[1].Data, "第二个元素应该是1")
assert.Equal(t, 2, res[2].Data, "第三个元素应该是2")
case <-time.After(2 * time.Second):
t.Error("No results received within timeout")
}
assert.Len(t, cw.dataBuffer, 1)
// Test case 2: Reset
cw.Reset()
assert.Len(t, cw.dataBuffer, 0)
}
// TestCountingWindowBadThreshold checks that CreateWindow returns an error for
// a counting window configured with a count of 0.
// NOTE(review): rendered diff — both the (type, params) call form and the
// model.WindowConfig call form of CreateWindow are visible below.
func TestCountingWindowBadThreshold(t *testing.T) {
_, err := CreateWindow("counting", map[string]interface{}{
"count": 0,
_, err := CreateWindow(model.WindowConfig{
Type: "counting",
Params: map[string]interface{}{
"count": 0,
},
})
require.Error(t, err)
}
+42 -31
View File
@@ -2,7 +2,10 @@ package window
import (
"fmt"
"github.com/spf13/cast"
"reflect"
"time"
"github.com/rulego/streamsql/model"
)
const (
@@ -14,47 +17,55 @@ const (
// Window is the contract shared by the tumbling, sliding and counting window
// implementations built by CreateWindow.
// NOTE(review): rendered diff — GetResults, OutputChan and SetCallback are
// listed in both their []interface{} form and their []model.Row / commented-out
// form.
type Window interface {
// Add feeds one item into the window.
Add(item interface{})
GetResults() []interface{}
//GetResults() []interface{}
// Reset clears the window's buffered state.
Reset()
// Start begins the window's background loop.
Start()
OutputChan() <-chan []interface{}
SetCallback(callback func([]interface{}))
// OutputChan exposes the channel that emitted batches are sent on.
OutputChan() <-chan []model.Row
// SetCallback installs the function invoked with each emitted batch.
SetCallback(callback func([]model.Row))
// Trigger forces the window to emit.
Trigger()
}
// CreateWindow is the factory for Window implementations: it dispatches on the
// window type (TypeTumbling, TypeSliding, TypeCounting) and returns an error
// for any other type.
// NOTE(review): rendered diff — the (windowType, params) version, which parses
// size/slide/count inline, is interleaved with the (config) version, which
// delegates parsing to each constructor.
func CreateWindow(windowType string, params map[string]interface{}) (Window, error) {
switch windowType {
func CreateWindow(config model.WindowConfig) (Window, error) {
switch config.Type {
case TypeTumbling:
size, err := cast.ToDurationE(params["size"])
if err != nil {
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
}
return NewTumblingWindow(size), nil
return NewTumblingWindow(config)
case TypeSliding:
size, err := cast.ToDurationE(params["size"])
if err != nil {
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
}
slide, err := cast.ToDurationE(params["slide"])
if err != nil {
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
}
return NewSlidingWindow(size, slide), nil
return NewSlidingWindow(config)
case TypeCounting:
count := cast.ToInt(params["count"])
if count <= 0 {
return nil, fmt.Errorf("count must be a positive integer")
}
cw := NewCountingWindow(count, nil)
if callback, ok := params["callback"].(func([]interface{})); ok {
cw.SetCallback(callback)
}
return cw, nil
return NewCountingWindow(config)
default:
return nil, fmt.Errorf("unsupported window type: %s", windowType)
return nil, fmt.Errorf("unsupported window type: %s", config.Type)
}
}
// SetCallback stores the function that is invoked with each emitted batch.
// The assignment is not guarded by cw.mu.
// NOTE(review): rendered diff — both the []interface{} and the []model.Row
// signature lines are visible here.
func (cw *CountingWindow) SetCallback(callback func([]interface{})) {
func (cw *CountingWindow) SetCallback(callback func([]model.Row)) {
cw.callback = callback
}
// GetTimestamp extracts the event time from data.
//
// Lookup order:
//  1. If data implements GetTimestamp() time.Time, that method's result is used.
//  2. Otherwise, if tsProp is non-empty, data is inspected by reflection:
//     a struct is searched for a field named tsProp, and a map with a
//     string-kinded key type is searched for the key tsProp. The value is used
//     only when it is a time.Time.
//  3. In every other case the current time is returned.
//
// The function never panics on a mismatch: an unexported struct field, a map
// value of another type, or a missing field/key all fall back to time.Now().
func GetTimestamp(data interface{}, tsProp string) time.Time {
	if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok {
		return ts.GetTimestamp()
	}
	if tsProp == "" {
		return time.Now()
	}

	v := reflect.ValueOf(data)
	switch v.Kind() {
	case reflect.Struct:
		// CanInterface guards against unexported fields, for which
		// Interface() would panic.
		if f := v.FieldByName(tsProp); f.IsValid() && f.CanInterface() {
			if t, ok := f.Interface().(time.Time); ok {
				return t
			}
		}
	case reflect.Map:
		keyType := v.Type().Key()
		if keyType.Kind() == reflect.String {
			// Convert the lookup key to the map's own key type so maps keyed
			// by a named string type are handled instead of panicking.
			key := reflect.ValueOf(tsProp).Convert(keyType)
			if value := v.MapIndex(key); value.IsValid() {
				if t, ok := value.Interface().(time.Time); ok {
					return t
				}
			}
		}
	}
	return time.Now()
}
+72 -42
View File
@@ -2,8 +2,13 @@ package window
import (
"context"
"fmt"
"sync"
"time"
"github.com/rulego/streamsql/model"
timex "github.com/rulego/streamsql/utils"
"github.com/spf13/cast"
)
// 确保 SlidingWindow 结构体实现了 Window 接口
@@ -11,12 +16,14 @@ var _ Window = (*SlidingWindow)(nil)
// TimedData pairs a data item with its timestamp.
// NOTE(review): rendered diff — the two fields are listed twice because the
// change only touched their formatting.
type TimedData struct {
Data interface{}
Timestamp time.Time
Data interface{} // the wrapped item
Timestamp time.Time // the time associated with the item
}
// SlidingWindow 表示一个滑动窗口,用于按时间范围处理数据
type SlidingWindow struct {
// config 窗口的配置信息
config model.WindowConfig
// 窗口的总大小,即窗口覆盖的时间范围
size time.Duration
// 窗口每次滑动的时间间隔
@@ -24,32 +31,42 @@ type SlidingWindow struct {
// 用于保护数据并发访问的互斥锁
mu sync.Mutex
// 存储窗口内的数据
data []TimedData
data []model.Row
// 用于输出窗口内数据的通道
outputChan chan []interface{}
outputChan chan []model.Row
// 当窗口触发时执行的回调函数
callback func([]interface{})
callback func([]model.Row)
// 用于控制窗口生命周期的上下文
ctx context.Context
// 用于取消上下文的函数
cancelFunc context.CancelFunc
// 用于定时触发窗口的定时器
timer *time.Timer
timer *time.Timer
currentSlot *model.TimeSlot
}
// NewSlidingWindow creates a new sliding window instance.
// The window size and slide interval are parsed from config.Params["size"] and
// config.Params["slide"] with cast.ToDurationE; a parse failure is returned as
// an error.
// NOTE(review): rendered diff — both the (size, slide) signature and the
// (config) signature are visible, as are both forms of the struct literal.
// NOTE(review): the context is created before the parameters are validated,
// and cancel is not called on the error returns.
func NewSlidingWindow(size, slide time.Duration) *SlidingWindow {
func NewSlidingWindow(config model.WindowConfig) (*SlidingWindow, error) {
// Create a cancellable context.
ctx, cancel := context.WithCancel(context.Background())
size, err := cast.ToDurationE(config.Params["size"])
if err != nil {
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
}
slide, err := cast.ToDurationE(config.Params["slide"])
if err != nil {
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
}
return &SlidingWindow{
config: config,
size: size,
slide: slide,
outputChan: make(chan []interface{}, 10),
outputChan: make(chan []model.Row, 10),
ctx: ctx,
cancelFunc: cancel,
data: make([]TimedData, 0),
}
data: make([]model.Row, 0),
}, nil
}
// Add 向滑动窗口中添加数据
@@ -58,19 +75,18 @@ func (sw *SlidingWindow) Add(data interface{}) {
// 加锁以保证数据的并发安全
sw.mu.Lock()
defer sw.mu.Unlock()
var timestamp time.Time
if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok {
timestamp = ts.GetTimestamp()
} else {
timestamp = time.Now()
}
// 将数据添加到窗口的数据列表中
sw.data = append(sw.data, TimedData{
Data: data,
Timestamp: timestamp,
})
// 将数据添加到窗口的数据列表中
t := GetTimestamp(data, sw.config.TsProp)
if sw.currentSlot == nil {
sw.currentSlot = sw.createSlot(t)
}
go func() {
row := model.Row{
Data: data,
Timestamp: t,
}
sw.data = append(sw.data, row)
}()
}
// Start 启动滑动窗口,开始定时触发窗口
@@ -106,19 +122,25 @@ func (sw *SlidingWindow) Trigger() {
}
// 计算截止时间,即当前时间减去窗口的总大小
cutoff := time.Now().Add(-sw.size)
var newData []TimedData
// 遍历窗口内的数据,只保留在截止时间之后的数据
next := sw.NextSlot()
// 保留下一个窗口的数据
tms := next.Start.Add(-sw.size)
tme := next.End.Add(sw.size)
temp := model.NewTimeSlot(&tms, &tme)
newData := make([]model.Row, 0)
for _, item := range sw.data {
if item.Timestamp.After(cutoff) {
if temp.Contains(item.Timestamp) {
newData = append(newData, item)
}
}
// 提取出 Data 字段组成 []interface{} 类型的数据
resultData := make([]interface{}, 0, len(newData))
for _, item := range newData {
resultData = append(resultData, item.Data)
resultData := make([]model.Row, 0)
for _, item := range sw.data {
if sw.currentSlot.Contains(item.Timestamp) {
item.Slot = sw.currentSlot
resultData = append(resultData, item)
}
}
// 如果设置了回调函数,则执行回调函数
@@ -128,6 +150,7 @@ func (sw *SlidingWindow) Trigger() {
// 更新窗口内的数据
sw.data = newData
sw.currentSlot = next
// 将新的数据发送到输出通道
sw.outputChan <- resultData
}
@@ -139,28 +162,35 @@ func (sw *SlidingWindow) Reset() {
defer sw.mu.Unlock()
// 清空窗口内的数据
sw.data = nil
sw.currentSlot = nil
}
// OutputChan returns the sliding window's output channel as receive-only.
// NOTE(review): rendered diff — both the []interface{} and the []model.Row
// signature lines are visible here.
func (sw *SlidingWindow) OutputChan() <-chan []interface{} {
func (sw *SlidingWindow) OutputChan() <-chan []model.Row {
return sw.outputChan
}
// SetCallback sets the function executed when the sliding window triggers.
// callback is the function to install; the assignment is not guarded by sw.mu.
// NOTE(review): rendered diff — both the []interface{} and the []model.Row
// signature lines are visible here.
func (sw *SlidingWindow) SetCallback(callback func([]interface{})) {
func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) {
sw.callback = callback
}
// NOTE(review): rendered diff — this span interleaves two different methods:
// GetResults (returns a copy of the Data fields of the buffered items under
// the lock) and NextSlot (returns the current slot shifted forward by one
// slide interval, or nil when there is no current slot).
// GetResults returns the data currently held in the sliding window.
func (sw *SlidingWindow) GetResults() []interface{} {
// Lock to keep access to the data concurrency-safe.
sw.mu.Lock()
defer sw.mu.Unlock()
// Collect the Data fields into a []interface{}.
resultData := make([]interface{}, 0, len(sw.data))
for _, item := range sw.data {
resultData = append(resultData, item.Data)
func (sw *SlidingWindow) NextSlot() *model.TimeSlot {
if sw.currentSlot == nil {
return nil
}
return resultData
// Both bounds move by slide, so the slot keeps its width.
start := sw.currentSlot.Start.Add(sw.slide)
end := sw.currentSlot.End.Add(sw.slide)
next := model.NewTimeSlot(&start, &end)
return next
}
// createSlot returns the time slot that contains t: it starts at t aligned by
// timex.AlignTimeToWindow to the window size and ends one window size later.
func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot {
	var (
		slotStart = timex.AlignTimeToWindow(t, sw.size)
		slotEnd   = slotStart.Add(sw.size)
	)
	return model.NewTimeSlot(&slotStart, &slotEnd)
}
+103 -17
View File
@@ -2,42 +2,128 @@ package window
import (
"context"
"github.com/stretchr/testify/assert"
"testing"
"time"
"github.com/rulego/streamsql/model"
timex "github.com/rulego/streamsql/utils"
"github.com/stretchr/testify/assert"
)
// TestSlidingWindow feeds four items, one second apart, into a sliding window
// with a 2s size and 1s slide, then drains the output channel and checks each
// emitted batch against the slot boundaries computed with
// timex.AlignTimeToWindow.
// NOTE(review): this span is a rendered diff with both versions of the test
// interleaved (both the time.Time-based and the TestDate-based versions are
// visible).
// NOTE(review): the drain loop below exits through the `default` case as soon
// as the channel is empty and nothing waits for the window to fire, so it may
// observe no batches at all — TODO confirm the intended synchronisation.
func TestSlidingWindow(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
sw := NewSlidingWindow(2*time.Second, 1*time.Second)
sw.SetCallback(func(results []interface{}) {
sw, _ := NewSlidingWindow(model.WindowConfig{
Params: map[string]interface{}{
"size": "2s",
"slide": "1s",
},
TsProp: "Ts",
TimeUnit: time.Second,
})
sw.SetCallback(func(results []model.Row) {
t.Logf("Received results: %v", results)
})
sw.Start()
// Add data.
now := time.Now()
sw.Add(now.Add(-3 * time.Second))
sw.Add(now.Add(-2 * time.Second))
sw.Add(now.Add(-1 * time.Second))
sw.Add(now)
t_3 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 56, 789000000, time.UTC), tag: "1"}
t_2 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 57, 789000000, time.UTC), tag: "2"}
t_1 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 58, 789000000, time.UTC), tag: "3"}
t_0 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 59, 789000000, time.UTC), tag: "4"}
sw.Add(t_3)
sw.Add(t_2)
sw.Add(t_1)
sw.Add(t_0)
// Wait a while so the window triggers.
time.Sleep(3 * time.Second)
//time.Sleep(3 * time.Second)
// Check the results.
resultsChan := sw.OutputChan()
var results []interface{}
select {
case results = <-resultsChan:
case <-time.After(1 * time.Second):
t.Fatal("No results received within timeout")
var results []model.Row
for {
select {
case results = <-resultsChan:
raw := make([]TestDate, 0)
for _, row := range results {
raw = append(raw, row.Data.(TestDate))
}
// Read the time range of the current window.
windowStart := results[0].Slot.Start
windowEnd := results[0].Slot.End
t.Logf("Window range: %v - %v", windowStart, windowEnd)
// Work out which items the window should contain.
expectedData := make([]TestDate, 0)
if windowStart.Before(t_3.Ts) && windowEnd.After(t_2.Ts) {
expectedData = []TestDate{t_3, t_2}
start := timex.AlignTimeToWindow(t_3.Ts, sw.size)
assert.Equal(t, start, windowStart)
assert.Equal(t, start.Add(sw.size), windowEnd)
} else if windowStart.Before(t_2.Ts) && windowEnd.After(t_1.Ts) {
expectedData = []TestDate{t_2, t_1}
start := timex.AlignTimeToWindow(t_2.Ts, sw.size)
assert.Equal(t, start, windowStart)
assert.Equal(t, start.Add(sw.size), windowEnd)
} else if windowStart.Before(t_1.Ts) && windowEnd.After(t_0.Ts) {
expectedData = []TestDate{t_1, t_0}
start := timex.AlignTimeToWindow(t_1.Ts, sw.size)
assert.Equal(t, start, windowStart)
assert.Equal(t, start.Add(sw.size), windowEnd)
} else {
expectedData = []TestDate{t_0}
start := timex.AlignTimeToWindow(t_0.Ts, sw.size)
assert.Equal(t, start, windowStart)
assert.Equal(t, start.Add(sw.size), windowEnd)
}
// Verify the window contents.
assert.Equal(t, len(expectedData), len(raw), "窗口数据数量不匹配")
for _, expected := range expectedData {
assert.Contains(t, raw, expected, "窗口缺少预期数据")
}
default:
// Leave the loop once the channel is empty.
goto END
}
}
END:
// Expected result: only the data from the last 2 seconds is kept.
assert.Len(t, results, 2)
assert.Contains(t, results, now.Add(-1*time.Second))
assert.Contains(t, results, now)
assert.Len(t, results, 0)
}
// TestDate is a test fixture whose timestamp lives in an exported field, so
// GetTimestamp can find it by name (the tests pass "Ts" as the property).
type TestDate struct {
Ts time.Time // event time, looked up by field name
tag string // label the tests use to tell items apart
}
// TestDate2 is a test fixture that exposes its timestamp through a
// GetTimestamp method instead of an exported field.
type TestDate2 struct {
	ts time.Time
}

// GetTimestamp returns the timestamp stored in the fixture.
func (td TestDate2) GetTimestamp() time.Time {
	stamp := td.ts
	return stamp
}
func TestGetTimestamp(t *testing.T) {
t_0 := time.Now()
data := map[string]interface{}{"device": "aa", "age": 15.0, "score": 100, "ts": t_0}
t_1 := GetTimestamp(data, "ts")
data_1 := TestDate{Ts: t_0}
t_2 := GetTimestamp(data_1, "Ts")
data_2 := TestDate2{ts: t_0}
t_3 := GetTimestamp(data_2, "")
assert.Equal(t, t_0, t_1)
assert.Equal(t, t_0, t_2)
assert.Equal(t, t_0, t_3)
}
+88 -26
View File
@@ -3,8 +3,13 @@ package window
import (
"context"
"fmt"
"sync"
"time"
"github.com/rulego/streamsql/model"
timex "github.com/rulego/streamsql/utils"
"github.com/spf13/cast"
)
// 确保 TumblingWindow 结构体实现了 Window 接口。
@@ -12,35 +17,43 @@ var _ Window = (*TumblingWindow)(nil)
// TumblingWindow is a tumbling window: it collects data over fixed time
// intervals and triggers processing.
// NOTE(review): rendered diff — data, outputChan, callback and timer are each
// listed twice ([]interface{} form and []model.Row form, or a pure formatting
// change).
type TumblingWindow struct {
// config is the window's configuration.
config model.WindowConfig
// size is the duration covered by one window.
size time.Duration
// mu protects concurrent access to the window data.
mu sync.Mutex
// data holds the items collected in the window.
data []interface{}
data []model.Row
// outputChan is the channel that data is sent on when the window triggers.
outputChan chan []interface{}
outputChan chan []model.Row
// callback is an optional function called when the window triggers.
callback func([]interface{})
callback func([]model.Row)
// ctx controls the window's lifecycle.
ctx context.Context
// cancelFunc cancels the window's operation.
cancelFunc context.CancelFunc
// timer fires the window periodically.
timer *time.Timer
timer *time.Timer
// currentSlot is the slot of the window currently being filled; nil until
// the first item is added or after Reset.
currentSlot *model.TimeSlot
}
// NewTumblingWindow creates a new tumbling window instance.
// The window size is parsed from config.Params["size"] with cast.ToDurationE; a
// parse failure is returned as an error.
// NOTE(review): rendered diff — both the (size) signature and the (config)
// signature are visible, as are both closings of the struct literal.
// NOTE(review): the context is created before the size is validated, and
// cancel is not called on the error return.
func NewTumblingWindow(size time.Duration) *TumblingWindow {
func NewTumblingWindow(config model.WindowConfig) (*TumblingWindow, error) {
// Create a cancellable context.
ctx, cancel := context.WithCancel(context.Background())
size, err := cast.ToDurationE(config.Params["size"])
if err != nil {
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
}
return &TumblingWindow{
config: config,
size: size,
outputChan: make(chan []interface{}, 10),
outputChan: make(chan []model.Row, 10),
ctx: ctx,
cancelFunc: cancel,
}
}, nil
}
// Add 向滚动窗口添加数据。
@@ -50,7 +63,33 @@ func (tw *TumblingWindow) Add(data interface{}) {
tw.mu.Lock()
defer tw.mu.Unlock()
// 将数据追加到窗口的数据列表中。
tw.data = append(tw.data, data)
if tw.currentSlot == nil {
tw.currentSlot = tw.createSlot(GetTimestamp(data, tw.config.TsProp))
}
go func() {
row := model.Row{
Data: data,
Timestamp: GetTimestamp(data, tw.config.TsProp),
}
tw.data = append(tw.data, row)
}()
}
// createSlot returns the time slot that contains t: it starts at t aligned by
// timex.AlignTimeToWindow to the window size and ends one window size later.
func (tw *TumblingWindow) createSlot(t time.Time) *model.TimeSlot {
	slotStart := timex.AlignTimeToWindow(t, tw.size)
	slotEnd := slotStart.Add(tw.size)
	return model.NewTimeSlot(&slotStart, &slotEnd)
}
// NextSlot returns the slot that directly follows the current one: it starts at
// the current slot's End and lasts one window size. It returns nil when there
// is no current slot.
//
// The current slot's End value is handed to model.NewTimeSlot as-is rather than
// through a fresh copy.
func (tw *TumblingWindow) NextSlot() *model.TimeSlot {
	current := tw.currentSlot
	if current == nil {
		return nil
	}
	nextEnd := current.End.Add(tw.size)
	return model.NewTimeSlot(current.End, &nextEnd)
}
// Stop 停止滚动窗口的操作。
@@ -87,16 +126,38 @@ func (tw *TumblingWindow) Trigger() {
// 加锁以确保并发安全。
tw.mu.Lock()
defer tw.mu.Unlock()
// 如果设置了回调函数,则调用它。
if tw.callback != nil {
tw.callback(tw.data)
// 计算下一个窗口槽位
next := tw.NextSlot()
// 保留下一个窗口的数据
tms := next.Start.Add(-tw.size)
tme := next.End.Add(tw.size)
temp := model.NewTimeSlot(&tms, &tme)
newData := make([]model.Row, 0)
for _, item := range tw.data {
if temp.Contains(item.Timestamp) {
newData = append(newData, item)
}
}
// 将窗口数据发送到输出通道。
tw.outputChan <- append([]interface{}{}, tw.data...)
// 重置窗口数据。
tw.data = nil
// 提取出当前窗口数据
resultData := make([]model.Row, 0)
for _, item := range tw.data {
if tw.currentSlot.Contains(item.Timestamp) {
item.Slot = tw.currentSlot
resultData = append(resultData, item)
}
}
// 如果设置了回调函数,则执行回调函数
if tw.callback != nil {
tw.callback(resultData)
}
// 更新窗口内的数据
tw.data = newData
tw.currentSlot = next
// 将新的数据发送到输出通道
tw.outputChan <- resultData
}
// Reset 重置滚动窗口的数据。
@@ -106,24 +167,25 @@ func (tw *TumblingWindow) Reset() {
defer tw.mu.Unlock()
// 清空窗口数据。
tw.data = nil
tw.currentSlot = nil
}
// OutputChan returns a receive-only channel that delivers the data emitted when
// the window triggers.
// NOTE(review): rendered diff — both the []interface{} and the []model.Row
// signature lines are visible here.
func (tw *TumblingWindow) OutputChan() <-chan []interface{} {
func (tw *TumblingWindow) OutputChan() <-chan []model.Row {
return tw.outputChan
}
// SetCallback sets the function called when the tumbling window triggers.
// callback is the function to install; the assignment is not guarded by tw.mu.
// NOTE(review): rendered diff — both the []interface{} and the []model.Row
// signature lines are visible here.
func (tw *TumblingWindow) SetCallback(callback func([]interface{})) {
func (tw *TumblingWindow) SetCallback(callback func([]model.Row)) {
tw.callback = callback
}
// GetResults returns a copy of the data currently held in the tumbling window.
// NOTE(review): rendered diff — the live definition is followed by the same
// method in commented-out form.
func (tw *TumblingWindow) GetResults() []interface{} {
// Lock to keep the access concurrency-safe.
tw.mu.Lock()
defer tw.mu.Unlock()
// Return a copy of the window data.
return append([]interface{}{}, tw.data...)
}
// // GetResults returns a copy of the data currently held in the tumbling window.
// func (tw *TumblingWindow) GetResults() []interface{} {
// // Lock to keep the access concurrency-safe.
// tw.mu.Lock()
// defer tw.mu.Unlock()
// // Return a copy of the window data.
// return append([]interface{}{}, tw.data...)
// }
+68 -28
View File
@@ -2,62 +2,102 @@ package window
import (
"context"
"github.com/stretchr/testify/require"
"fmt"
"testing"
"time"
"github.com/rulego/streamsql/model"
"github.com/stretchr/testify/require"
)
// TestTumblingWindow feeds five items spaced 1100ms apart (by timestamp) into a
// 2s tumbling window, collects three emitted batches and checks their sizes,
// tags and slot boundaries, then resets the window and checks one more batch.
// NOTE(review): this span is a rendered diff with both versions of the test
// interleaved (both the integer-based and the TestDate-based versions are
// visible).
// NOTE(review): the COLLECT loop has an empty `default` case and no timeout, so
// it spins until three batches arrive — TODO confirm this cannot hang.
func TestTumblingWindow(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
tw := NewTumblingWindow(2 * time.Second)
tw.SetCallback(func(results []interface{}) {
tw, _ := NewTumblingWindow(model.WindowConfig{
Type: "TumblingWindow",
Params: map[string]interface{}{"size": "2s"},
TsProp: "Ts",
})
tw.SetCallback(func(results []model.Row) {
// Process results
})
go tw.Start()
// Add data every 500ms
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
// Add the test data.
for i := 0; i < 5; i++ {
tw.Add(i)
time.Sleep(1100 * time.Millisecond)
data := TestDate{
Ts: baseTime.Add(time.Duration(i) * 1100 * time.Millisecond),
tag: fmt.Sprintf("%d", i),
}
tw.Add(data)
}
// Check output channel
// Collect the window results.
resultsChan := tw.OutputChan()
var results []interface{}
select {
case results = <-resultsChan:
case <-time.After(3 * time.Second):
t.Fatal("No results received within timeout")
var all [][]model.Row = make([][]model.Row, 0)
// Collect the data of every window.
COLLECT:
for {
select {
case results := <-resultsChan:
all = append(all, results)
if len(all) >= 3 {
break COLLECT
}
default:
}
}
// Verify that data is sent every 2 seconds
require.Len(t, results, 2)
require.Equal(t, []interface{}{0, 1}, results)
// Verify the window data.
require.Len(t, all, 3, "应该有3个时间窗口的数据")
// Verify next batch
select {
case results = <-resultsChan:
require.Len(t, results, 2)
require.Equal(t, []interface{}{2, 3}, results)
case <-time.After(3 * time.Second):
t.Fatal("No results received within timeout")
// Verify the data of each window.
expectedWindows := []struct {
size int
tags []string
startIdx int
}{
{size: 2, tags: []string{"0", "1"}, startIdx: 0},
{size: 2, tags: []string{"2", "3"}, startIdx: 1},
{size: 1, tags: []string{"4"}, startIdx: 2},
}
//time.Sleep(1100 * time.Millisecond)
//results = <-resultsChan
for i, window := range all {
expected := expectedWindows[i]
require.Len(t, window, expected.size, "窗口 %d 数据数量不匹配", i)
// Verify the data contents.
for _, row := range window {
require.Contains(t, expected.tags, row.Data.(TestDate).tag)
}
// Verify the time slot: window i is expected to cover
// [baseTime+2i seconds, baseTime+2i+2 seconds).
startTime := baseTime.Add(time.Duration(i*2) * time.Second)
endTime := startTime.Add(2 * time.Second)
require.True(t, window[0].Slot.Start.Equal(startTime) &&
window[0].Slot.End.Equal(endTime),
"窗口 %d 时间槽边界不正确", i)
}
// Verify reset and final batch
tw.Reset()
tw.Add(99)
tw.Add(TestDate{
Ts: baseTime.Add(time.Duration(99) * 1100 * time.Millisecond),
tag: fmt.Sprintf("%d", 99),
})
// time.Sleep(1100 * time.Millisecond)
cancel()
select {
case results = <-resultsChan:
case results := <-resultsChan:
require.Len(t, results, 1)
require.Equal(t, []interface{}{99}, results)
case <-time.After(3 * time.Second):
t.Fatal("No results received within timeout")
require.Equal(t, "99", results[0].Data.(TestDate).tag)
startTime := baseTime.Add(108 * time.Second)
endTime := baseTime.Add(110 * time.Second)
require.True(t, results[0].Slot.Start.Equal(startTime) && results[0].Slot.End.Equal(endTime))
}
}