mirror of
https://gitee.com/rulego/streamsql.git
synced 2026-05-06 10:55:47 +00:00
Merge pull request #2 from dimon-83/main
增加window_start()、window_end()函数支持
This commit is contained in:
+81
-35
@@ -3,26 +3,29 @@ package aggregator
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type AggregateType string
|
||||
|
||||
const (
|
||||
Sum AggregateType = "sum"
|
||||
Count AggregateType = "count"
|
||||
Avg AggregateType = "avg"
|
||||
Max AggregateType = "max"
|
||||
Min AggregateType = "min"
|
||||
StdDev AggregateType = "stddev"
|
||||
Median AggregateType = "median"
|
||||
Percentile AggregateType = "percentile"
|
||||
Sum AggregateType = "sum"
|
||||
Count AggregateType = "count"
|
||||
Avg AggregateType = "avg"
|
||||
Max AggregateType = "max"
|
||||
Min AggregateType = "min"
|
||||
StdDev AggregateType = "stddev"
|
||||
Median AggregateType = "median"
|
||||
Percentile AggregateType = "percentile"
|
||||
WindowStart AggregateType = "window_start"
|
||||
WindowEnd AggregateType = "window_end"
|
||||
)
|
||||
|
||||
type AggregatorFunction interface {
|
||||
New() AggregatorFunction
|
||||
Add(value float64)
|
||||
Result() float64
|
||||
Add(value interface{})
|
||||
Result() interface{}
|
||||
}
|
||||
|
||||
type SumAggregator struct {
|
||||
@@ -33,11 +36,12 @@ func (s *SumAggregator) New() AggregatorFunction {
|
||||
return &SumAggregator{}
|
||||
}
|
||||
|
||||
func (s *SumAggregator) Add(v float64) {
|
||||
s.value += v
|
||||
func (s *SumAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
s.value += vv
|
||||
}
|
||||
|
||||
func (s *SumAggregator) Result() float64 {
|
||||
func (s *SumAggregator) Result() interface{} {
|
||||
return s.value
|
||||
}
|
||||
|
||||
@@ -49,11 +53,11 @@ func (s *CountAggregator) New() AggregatorFunction {
|
||||
return &CountAggregator{}
|
||||
}
|
||||
|
||||
func (c *CountAggregator) Add(_ float64) {
|
||||
func (c *CountAggregator) Add(_ interface{}) {
|
||||
c.count++
|
||||
}
|
||||
|
||||
func (c *CountAggregator) Result() float64 {
|
||||
func (c *CountAggregator) Result() interface{} {
|
||||
return float64(c.count)
|
||||
}
|
||||
|
||||
@@ -66,12 +70,13 @@ func (a *AvgAggregator) New() AggregatorFunction {
|
||||
return &AvgAggregator{}
|
||||
}
|
||||
|
||||
func (a *AvgAggregator) Add(v float64) {
|
||||
a.sum += v
|
||||
func (a *AvgAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
a.sum += vv
|
||||
a.count++
|
||||
}
|
||||
|
||||
func (a *AvgAggregator) Result() float64 {
|
||||
func (a *AvgAggregator) Result() interface{} {
|
||||
if a.count == 0 {
|
||||
return 0
|
||||
}
|
||||
@@ -117,6 +122,10 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction {
|
||||
return &MedianAggregator{}
|
||||
case Percentile:
|
||||
return &PercentileAggregator{p: 0.95}
|
||||
case WindowStart:
|
||||
return &WindowStartAggregator{}
|
||||
case WindowEnd:
|
||||
return &WindowEndAggregator{}
|
||||
default:
|
||||
panic("unsupported aggregator type: " + aggType)
|
||||
}
|
||||
@@ -150,11 +159,12 @@ func (m *MedianAggregator) New() AggregatorFunction {
|
||||
return &MedianAggregator{}
|
||||
}
|
||||
|
||||
func (m *MedianAggregator) Add(val float64) {
|
||||
m.values = append(m.values, val)
|
||||
func (m *MedianAggregator) Add(val interface{}) {
|
||||
var vv float64 = ConvertToFloat64(val, 0)
|
||||
m.values = append(m.values, vv)
|
||||
}
|
||||
|
||||
func (m *MedianAggregator) Result() float64 {
|
||||
func (m *MedianAggregator) Result() interface{} {
|
||||
sort.Float64s(m.values)
|
||||
return m.values[len(m.values)/2]
|
||||
}
|
||||
@@ -168,8 +178,9 @@ func (p *PercentileAggregator) New() AggregatorFunction {
|
||||
return &PercentileAggregator{}
|
||||
}
|
||||
|
||||
func (p *PercentileAggregator) Add(v float64) {
|
||||
p.values = append(p.values, v)
|
||||
func (p *PercentileAggregator) Add(v interface{}) {
|
||||
vv := ConvertToFloat64(v, 0)
|
||||
p.values = append(p.values, vv)
|
||||
}
|
||||
|
||||
type MinAggregator struct {
|
||||
@@ -183,14 +194,15 @@ func (s *MinAggregator) New() AggregatorFunction {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MinAggregator) Add(v float64) {
|
||||
if m.first || v < m.value {
|
||||
m.value = v
|
||||
func (m *MinAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, math.MaxFloat64)
|
||||
if m.first || vv < m.value {
|
||||
m.value = vv
|
||||
m.first = false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MinAggregator) Result() float64 {
|
||||
func (m *MinAggregator) Result() interface{} {
|
||||
return m.value
|
||||
}
|
||||
|
||||
@@ -203,22 +215,24 @@ func (m *MaxAggregator) New() AggregatorFunction {
|
||||
return &MaxAggregator{}
|
||||
}
|
||||
|
||||
func (m *MaxAggregator) Add(v float64) {
|
||||
if m.first || v > m.value {
|
||||
m.value = v
|
||||
func (m *MaxAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
if m.first || vv > m.value {
|
||||
m.value = vv
|
||||
m.first = false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MaxAggregator) Result() float64 {
|
||||
func (m *MaxAggregator) Result() interface{} {
|
||||
return m.value
|
||||
}
|
||||
|
||||
func (s *StdDevAggregator) Add(v float64) {
|
||||
s.values = append(s.values, v)
|
||||
func (s *StdDevAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
s.values = append(s.values, vv)
|
||||
}
|
||||
|
||||
func (s *StdDevAggregator) Result() float64 {
|
||||
func (s *StdDevAggregator) Result() interface{} {
|
||||
if len(s.values) < 2 {
|
||||
return 0
|
||||
}
|
||||
@@ -230,7 +244,7 @@ func (s *StdDevAggregator) Result() float64 {
|
||||
return math.Sqrt(sum / float64(len(s.values)-1))
|
||||
}
|
||||
|
||||
func (p *PercentileAggregator) Result() float64 {
|
||||
func (p *PercentileAggregator) Result() interface{} {
|
||||
if len(p.values) == 0 {
|
||||
return 0
|
||||
}
|
||||
@@ -246,3 +260,35 @@ func calculateAverage(values []float64) float64 {
|
||||
}
|
||||
return sum / float64(len(values))
|
||||
}
|
||||
|
||||
func ConvertToFloat64(v interface{}, defaultVal float64) float64 {
|
||||
var vv float64 = defaultVal
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
vv = val
|
||||
case float32:
|
||||
vv = float64(val)
|
||||
case int:
|
||||
vv = float64(val)
|
||||
case int32:
|
||||
vv = float64(val)
|
||||
case int64:
|
||||
vv = float64(val)
|
||||
case uint:
|
||||
vv = float64(val)
|
||||
case uint32:
|
||||
vv = float64(val)
|
||||
case uint64:
|
||||
vv = float64(val)
|
||||
case string:
|
||||
// 处理字符串类型的转换
|
||||
if floatValue, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
vv = floatValue
|
||||
} else {
|
||||
panic("unsupported type for sum aggregator")
|
||||
}
|
||||
default:
|
||||
panic("unsupported type for sum aggregator")
|
||||
}
|
||||
return vv
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package aggregator
|
||||
|
||||
type ContextAggregator interface {
|
||||
GetContextKey() string
|
||||
}
|
||||
|
||||
type WindowStartAggregator struct {
|
||||
val interface{}
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) New() AggregatorFunction {
|
||||
return &WindowStartAggregator{}
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) Add(val interface{}) {
|
||||
w.val = val
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) Result() interface{} {
|
||||
return w.val
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) GetContextKey() string {
|
||||
return "window_start"
|
||||
}
|
||||
|
||||
type WindowEndAggregator struct {
|
||||
val interface{}
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) New() AggregatorFunction {
|
||||
return &WindowEndAggregator{}
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) Add(val interface{}) {
|
||||
w.val = val
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) Result() interface{} {
|
||||
return w.val
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) GetContextKey() string {
|
||||
return "window_end"
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
type Aggregator interface {
|
||||
Add(data interface{}) error
|
||||
Put(key string, val interface{}) error
|
||||
GetResults() ([]map[string]interface{}, error)
|
||||
Reset()
|
||||
}
|
||||
@@ -19,9 +20,11 @@ type GroupAggregator struct {
|
||||
aggregators map[string]AggregatorFunction
|
||||
groups map[string]map[string]AggregatorFunction
|
||||
mu sync.RWMutex
|
||||
context map[string]interface{}
|
||||
fieldAlias map[string]string
|
||||
}
|
||||
|
||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator {
|
||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
|
||||
aggregators := make(map[string]AggregatorFunction)
|
||||
|
||||
for field, aggType := range fieldMap {
|
||||
@@ -33,8 +36,20 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType)
|
||||
groupFields: groupFields,
|
||||
aggregators: aggregators,
|
||||
groups: make(map[string]map[string]AggregatorFunction),
|
||||
fieldAlias: fieldAlias,
|
||||
}
|
||||
}
|
||||
|
||||
func (ga *GroupAggregator) Put(key string, val interface{}) error {
|
||||
ga.mu.Lock() // 获取写锁
|
||||
defer ga.mu.Unlock() // 确保函数返回时释放锁
|
||||
if ga.context == nil {
|
||||
ga.context = make(map[string]interface{})
|
||||
}
|
||||
ga.context[key] = val
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
ga.mu.Lock() // 获取写锁
|
||||
defer ga.mu.Unlock() // 确保函数返回时释放锁
|
||||
@@ -114,6 +129,19 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
if !f.IsValid() {
|
||||
//return fmt.Errorf("field %s not found", field)
|
||||
//fmt.Printf("field %s not found in %v \n ", field, data)
|
||||
|
||||
// 尝试从context中获取
|
||||
if ga.context != nil {
|
||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
||||
if _, ok := groupAgg.(ContextAggregator); ok {
|
||||
key := groupAgg.(ContextAggregator).GetContextKey()
|
||||
if val, exists := ga.context[key]; exists {
|
||||
groupAgg.Add(val)
|
||||
//fmt.Printf("add agg group by %s:%s , %v \n", key, field, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -151,7 +179,20 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
|
||||
group[field] = fields[i]
|
||||
}
|
||||
for field, agg := range aggregators {
|
||||
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
|
||||
if _, ok := agg.(ContextAggregator); ok {
|
||||
if alias, ok := ga.fieldAlias[field]; ok {
|
||||
group[alias] = agg.Result()
|
||||
} else {
|
||||
group[field] = agg.Result()
|
||||
}
|
||||
} else {
|
||||
if alias, ok := ga.fieldAlias[field]; ok {
|
||||
group[alias] = agg.Result()
|
||||
} else {
|
||||
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
result = append(result, group)
|
||||
}
|
||||
|
||||
@@ -19,13 +19,17 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) {
|
||||
"Data1": Sum,
|
||||
"Data2": Sum,
|
||||
},
|
||||
map[string]string{
|
||||
"Data1": "Data1_sum",
|
||||
"Data2": "Data2_sum",
|
||||
},
|
||||
)
|
||||
|
||||
testData := []map[string]interface{}{
|
||||
{"device": "aa", "data1": 20, "data2": 30},
|
||||
{"device": "aa", "data1": 21, "data2": 0},
|
||||
{"device": "bb", "data1": 15, "data2": 20},
|
||||
{"device": "bb", "data1": 16, "data2": 20},
|
||||
{"Device": "aa", "Data1": 20, "Data2": 30},
|
||||
{"Device": "aa", "Data1": 21, "Data2": 0},
|
||||
{"Device": "bb", "Data1": 15, "Data2": 20},
|
||||
{"Device": "bb", "Data1": 16, "Data2": 20},
|
||||
}
|
||||
|
||||
for _, d := range testData {
|
||||
@@ -47,6 +51,9 @@ func TestGroupAggregator_SingleField(t *testing.T) {
|
||||
map[string]AggregateType{
|
||||
"Data1": Sum,
|
||||
},
|
||||
map[string]string{
|
||||
"Data1": "Data1_sum",
|
||||
},
|
||||
)
|
||||
|
||||
testData := []map[string]interface{}{
|
||||
@@ -75,6 +82,12 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
||||
"Data3": Max,
|
||||
"Data4": Min,
|
||||
},
|
||||
map[string]string{
|
||||
"Data1": "Data1_sum",
|
||||
"Data2": "Data2_avg",
|
||||
"Data3": "Data3_max",
|
||||
"Data4": "Data4_min",
|
||||
},
|
||||
)
|
||||
|
||||
testData := []map[string]interface{}{
|
||||
@@ -88,7 +101,7 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"Device": "cc",
|
||||
"Device": "cc",
|
||||
"Data1_sum": 30.0,
|
||||
"Data2_avg": 5.0,
|
||||
"Data3_max": 12.0,
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WindowConfig WindowConfig
|
||||
GroupFields []string
|
||||
SelectFields map[string]aggregator.AggregateType
|
||||
FieldAlias map[string]string
|
||||
}
|
||||
type WindowConfig struct {
|
||||
Type string
|
||||
Params map[string]interface{}
|
||||
TsProp string
|
||||
TimeUnit time.Duration
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type RowEvent interface {
|
||||
GetTimestamp() time.Time
|
||||
}
|
||||
|
||||
type Row struct {
|
||||
Timestamp time.Time
|
||||
Data interface{}
|
||||
Slot *TimeSlot
|
||||
}
|
||||
|
||||
// GetTimestamp 获取时间戳
|
||||
func (r *Row) GetTimestamp() time.Time {
|
||||
return r.Timestamp
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type TimeSlot struct {
|
||||
Start *time.Time
|
||||
End *time.Time
|
||||
}
|
||||
|
||||
func NewTimeSlot(start, end *time.Time) *TimeSlot {
|
||||
return &TimeSlot{
|
||||
Start: start,
|
||||
End: end,
|
||||
}
|
||||
}
|
||||
|
||||
// Hash 生成槽位的哈希值
|
||||
func (ts TimeSlot) Hash() uint64 {
|
||||
// 将开始时间和结束时间转换为 Unix 时间戳(纳秒级)
|
||||
startNano := ts.Start.UnixNano()
|
||||
endNano := ts.End.UnixNano()
|
||||
|
||||
// 使用简单但高效的哈希算法
|
||||
// 将两个时间戳组合成一个唯一的哈希值
|
||||
hash := uint64(startNano)
|
||||
hash = (hash << 32) | (hash >> 32)
|
||||
hash = hash ^ uint64(endNano)
|
||||
|
||||
return hash
|
||||
}
|
||||
|
||||
// Contains 检查给定时间是否在槽位范围内
|
||||
func (ts TimeSlot) Contains(t time.Time) bool {
|
||||
return (t.Equal(*ts.Start) || t.After(*ts.Start)) &&
|
||||
t.Before(*ts.End)
|
||||
}
|
||||
|
||||
func (ts *TimeSlot) GetStartTime() *time.Time {
|
||||
if ts == nil || ts.Start == nil {
|
||||
return nil
|
||||
}
|
||||
return ts.Start
|
||||
}
|
||||
|
||||
func (ts *TimeSlot) GetEndTime() *time.Time {
|
||||
if ts == nil || ts.End == nil {
|
||||
return nil
|
||||
}
|
||||
return ts.End
|
||||
}
|
||||
|
||||
func (ts *TimeSlot) WindowStart() int64 {
|
||||
if ts == nil || ts.Start == nil {
|
||||
return 0
|
||||
}
|
||||
return ts.Start.UnixNano()
|
||||
}
|
||||
|
||||
func (ts *TimeSlot) WindowEnd() int64 {
|
||||
if ts == nil || ts.End == nil {
|
||||
return 0
|
||||
}
|
||||
return ts.End.UnixNano()
|
||||
}
|
||||
+48
-19
@@ -2,12 +2,13 @@ package rsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/rulego/streamsql/window"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/rulego/streamsql/window"
|
||||
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/stream"
|
||||
)
|
||||
|
||||
type SelectStatement struct {
|
||||
@@ -21,15 +22,18 @@ type SelectStatement struct {
|
||||
type Field struct {
|
||||
Expression string
|
||||
Alias string
|
||||
AggType string
|
||||
}
|
||||
|
||||
type WindowDefinition struct {
|
||||
Type string
|
||||
Params []interface{}
|
||||
Type string
|
||||
Params []interface{}
|
||||
TsProp string
|
||||
TimeUnit time.Duration
|
||||
}
|
||||
|
||||
// ToStreamConfig 将AST转换为Stream配置
|
||||
func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
|
||||
func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
|
||||
if s.Source == "" {
|
||||
return nil, "", fmt.Errorf("missing FROM clause")
|
||||
}
|
||||
@@ -49,15 +53,18 @@ func (s *SelectStatement) ToStreamConfig() (*stream.Config, string, error) {
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("解析窗口参数失败: %w", err)
|
||||
}
|
||||
|
||||
aggs, fields := buildSelectFields(s.Fields)
|
||||
// 构建Stream配置
|
||||
config := stream.Config{
|
||||
WindowConfig: stream.WindowConfig{
|
||||
Type: windowType,
|
||||
Params: params,
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: windowType,
|
||||
Params: params,
|
||||
TsProp: s.Window.TsProp,
|
||||
TimeUnit: s.Window.TimeUnit,
|
||||
},
|
||||
GroupFields: extractGroupFields(s),
|
||||
SelectFields: buildSelectFields(s.Fields),
|
||||
SelectFields: aggs,
|
||||
FieldAlias: fields,
|
||||
}
|
||||
|
||||
return &config, s.Condition, nil
|
||||
@@ -73,28 +80,50 @@ func extractGroupFields(s *SelectStatement) []string {
|
||||
return fields
|
||||
}
|
||||
|
||||
func buildSelectFields(fields []Field) map[string]aggregator.AggregateType {
|
||||
func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) {
|
||||
selectFields := make(map[string]aggregator.AggregateType)
|
||||
fieldMap = make(map[string]string)
|
||||
for _, f := range fields {
|
||||
if alias := f.Alias; alias != "" {
|
||||
selectFields[alias] = parseAggregateType(f.Expression)
|
||||
t, n := parseAggregateType(f.Expression)
|
||||
if n != "" {
|
||||
selectFields[n] = t
|
||||
fieldMap[n] = alias
|
||||
} else {
|
||||
selectFields[alias] = t
|
||||
}
|
||||
}
|
||||
}
|
||||
return selectFields
|
||||
return selectFields, fieldMap
|
||||
}
|
||||
|
||||
func parseAggregateType(expr string) aggregator.AggregateType {
|
||||
func parseAggregateType(expr string) (aggType aggregator.AggregateType, name string) {
|
||||
if strings.Contains(expr, "avg(") {
|
||||
return "avg"
|
||||
return "avg", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "sum(") {
|
||||
return "sum"
|
||||
return "sum", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "max(") {
|
||||
return "max"
|
||||
return "max", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "min(") {
|
||||
return "min"
|
||||
return "min", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "window_start(") {
|
||||
return "window_start", "window_start"
|
||||
}
|
||||
if strings.Contains(expr, "window_end(") {
|
||||
return "window_end", "window_end"
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func extractAggField(expr string) string {
|
||||
start := strings.Index(expr, "(")
|
||||
end := strings.LastIndex(expr, ")")
|
||||
if start >= 0 && end > start {
|
||||
return strings.TrimSpace(expr[start+1 : end])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ const (
|
||||
TokenSliding
|
||||
TokenCounting
|
||||
TokenSession
|
||||
TokenWITH
|
||||
TokenTimestamp
|
||||
TokenTimeUnit
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
@@ -204,6 +207,12 @@ func (l *Lexer) lookupIdent(ident string) Token {
|
||||
return Token{Type: TokenCounting, Value: ident}
|
||||
case "SESSIONWINDOW":
|
||||
return Token{Type: TokenSession, Value: ident}
|
||||
case "WITH":
|
||||
return Token{Type: TokenWITH, Value: ident}
|
||||
case "TIMESTAMP":
|
||||
return Token{Type: TokenTimestamp, Value: ident}
|
||||
case "TIMEUNIT":
|
||||
return Token{Type: TokenTimeUnit, Value: ident}
|
||||
default:
|
||||
return Token{Type: TokenIdent, Value: ident}
|
||||
}
|
||||
|
||||
+78
-3
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
@@ -39,6 +40,10 @@ func (p *Parser) Parse() (*SelectStatement, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.parseWith(stmt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stmt, nil
|
||||
}
|
||||
func (p *Parser) parseSelect(stmt *SelectStatement) error {
|
||||
@@ -125,9 +130,14 @@ func (p *Parser) parseWindowFunction(stmt *SelectStatement, winType string) erro
|
||||
params = append(params, convertValue(valTok.Value))
|
||||
}
|
||||
|
||||
stmt.Window = WindowDefinition{
|
||||
Type: winType,
|
||||
Params: params,
|
||||
if &stmt.Window != nil {
|
||||
stmt.Window.Params = params
|
||||
stmt.Window.Type = winType
|
||||
} else {
|
||||
stmt.Window = WindowDefinition{
|
||||
Type: winType,
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -185,3 +195,68 @@ func (p *Parser) parseGroupBy(stmt *SelectStatement) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseWith(stmt *SelectStatement) error {
|
||||
p.lexer.NextToken() // 跳过(
|
||||
for p.lexer.peekChar() != ')' {
|
||||
valTok := p.lexer.NextToken()
|
||||
if valTok.Type == TokenRParen || valTok.Type == TokenEOF {
|
||||
break
|
||||
}
|
||||
if valTok.Type == TokenComma {
|
||||
continue
|
||||
}
|
||||
|
||||
if valTok.Type == TokenTimestamp {
|
||||
next := p.lexer.NextToken()
|
||||
if next.Type == TokenEQ {
|
||||
next = p.lexer.NextToken()
|
||||
if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") {
|
||||
next.Value = strings.Trim(next.Value, "'")
|
||||
}
|
||||
// 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition
|
||||
if stmt.Window.Type == "" {
|
||||
stmt.Window = WindowDefinition{
|
||||
TsProp: next.Value,
|
||||
}
|
||||
} else {
|
||||
stmt.Window.TsProp = next.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
if valTok.Type == TokenTimeUnit {
|
||||
timeUnit := time.Minute
|
||||
next := p.lexer.NextToken()
|
||||
if next.Type == TokenEQ {
|
||||
next = p.lexer.NextToken()
|
||||
if strings.HasPrefix(next.Value, "'") && strings.HasSuffix(next.Value, "'") {
|
||||
next.Value = strings.Trim(next.Value, "'")
|
||||
}
|
||||
switch next.Value {
|
||||
case "dd":
|
||||
timeUnit = 24 * time.Hour
|
||||
case "hh":
|
||||
timeUnit = time.Hour
|
||||
case "mi":
|
||||
timeUnit = time.Minute
|
||||
case "ss":
|
||||
timeUnit = time.Second
|
||||
case "ms":
|
||||
timeUnit = time.Millisecond
|
||||
default:
|
||||
|
||||
}
|
||||
// 检查Window是否已初始化,如果未初始化则创建新的WindowDefinition
|
||||
if stmt.Window.Type == "" {
|
||||
stmt.Window = WindowDefinition{
|
||||
TimeUnit: timeUnit,
|
||||
}
|
||||
} else {
|
||||
stmt.Window.TimeUnit = timeUnit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+30
-9
@@ -1,24 +1,25 @@
|
||||
package rsql
|
||||
|
||||
import (
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/stream"
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/model"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected *stream.Config
|
||||
expected *model.Config
|
||||
condition string
|
||||
}{
|
||||
{
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s')",
|
||||
expected: &stream.Config{
|
||||
WindowConfig: stream.WindowConfig{
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{
|
||||
"size": 10 * time.Second,
|
||||
@@ -33,8 +34,8 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
{
|
||||
sql: "select max(score) as max_score, min(age) as min_age from Sensor group by type, SlidingWindow('20s', '5s')",
|
||||
expected: &stream.Config{
|
||||
WindowConfig: stream.WindowConfig{
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "sliding",
|
||||
Params: map[string]interface{}{
|
||||
"size": 20 * time.Second,
|
||||
@@ -43,12 +44,29 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
GroupFields: []string{"type"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"max_score": "max",
|
||||
"min_age": "min",
|
||||
"score": "max",
|
||||
"age": "min",
|
||||
},
|
||||
},
|
||||
condition: "",
|
||||
},
|
||||
{
|
||||
sql: "select deviceId, avg(temperature/10) as aa from Input where deviceId='aa' group by deviceId, TumblingWindow('10s') with (TIMESTAMP='ts') ",
|
||||
expected: &model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{
|
||||
"size": 10 * time.Second,
|
||||
},
|
||||
TsProp: "ts",
|
||||
},
|
||||
GroupFields: []string{"deviceId"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"aa": "avg",
|
||||
},
|
||||
},
|
||||
condition: "deviceId == 'aa'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -64,6 +82,9 @@ func TestParseSQL(t *testing.T) {
|
||||
assert.Equal(t, tt.expected.GroupFields, config.GroupFields)
|
||||
assert.Equal(t, tt.expected.SelectFields, config.SelectFields)
|
||||
assert.Equal(t, tt.condition, cond)
|
||||
if tt.expected.WindowConfig.TsProp != "" {
|
||||
assert.Equal(t, tt.expected.WindowConfig.TsProp, config.WindowConfig.TsProp)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestWindowParamParsing(t *testing.T) {
|
||||
|
||||
+14
-17
@@ -2,36 +2,27 @@ package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
aggregator2 "github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/rulego/streamsql/parser"
|
||||
"github.com/rulego/streamsql/window"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WindowConfig WindowConfig
|
||||
GroupFields []string
|
||||
SelectFields map[string]aggregator2.AggregateType
|
||||
}
|
||||
|
||||
type WindowConfig struct {
|
||||
Type string
|
||||
Params map[string]interface{}
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
dataChan chan interface{}
|
||||
filter parser.Condition
|
||||
Window window.Window
|
||||
aggregator aggregator2.Aggregator
|
||||
config Config
|
||||
config model.Config
|
||||
sinks []func(interface{})
|
||||
resultChan chan interface{} // 结果通道
|
||||
}
|
||||
|
||||
func NewStream(config Config) (*Stream, error) {
|
||||
win, err := window.CreateWindow(config.WindowConfig.Type, config.WindowConfig.Params)
|
||||
func NewStream(config model.Config) (*Stream, error) {
|
||||
win, err := window.CreateWindow(config.WindowConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -44,6 +35,9 @@ func NewStream(config Config) (*Stream, error) {
|
||||
}
|
||||
|
||||
func (s *Stream) RegisterFilter(condition string) error {
|
||||
if strings.TrimSpace(condition) == "" {
|
||||
return nil
|
||||
}
|
||||
filter, err := parser.NewExprCondition(condition)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compile filter error: %w", err)
|
||||
@@ -57,7 +51,7 @@ func (s *Stream) Start() {
|
||||
}
|
||||
|
||||
func (s *Stream) process() {
|
||||
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields)
|
||||
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias)
|
||||
|
||||
// 启动窗口处理协程
|
||||
s.Window.Start()
|
||||
@@ -76,10 +70,13 @@ func (s *Stream) process() {
|
||||
case batch := <-s.Window.OutputChan():
|
||||
// 处理窗口批数据
|
||||
for _, item := range batch {
|
||||
if err := s.aggregator.Add(item); err != nil {
|
||||
s.aggregator.Put("window_start", item.Slot.WindowStart())
|
||||
s.aggregator.Put("window_end", item.Slot.WindowEnd())
|
||||
if err := s.aggregator.Add(item.Data); err != nil {
|
||||
fmt.Printf("aggregate error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取并发送聚合结果
|
||||
if results, err := s.aggregator.GetResults(); err == nil {
|
||||
// 发送结果到结果通道和 Sink 函数
|
||||
@@ -108,5 +105,5 @@ func (s *Stream) GetResultsChan() <-chan interface{} {
|
||||
}
|
||||
|
||||
func NewStreamProcessor() (*Stream, error) {
|
||||
return NewStream(Config{})
|
||||
return NewStream(model.Config{})
|
||||
}
|
||||
|
||||
+99
-6
@@ -7,13 +7,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamProcess(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{"size": time.Second},
|
||||
},
|
||||
@@ -77,8 +78,8 @@ func TestStreamProcess(t *testing.T) {
|
||||
|
||||
// 不设置过滤器
|
||||
func TestStreamWithoutFilter(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "sliding",
|
||||
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
|
||||
},
|
||||
@@ -158,8 +159,8 @@ func TestStreamWithoutFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIncompleteStreamProcess(t *testing.T) {
|
||||
config := Config{
|
||||
WindowConfig: WindowConfig{
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "tumbling",
|
||||
Params: map[string]interface{}{"size": time.Second},
|
||||
},
|
||||
@@ -222,3 +223,95 @@ func TestIncompleteStreamProcess(t *testing.T) {
|
||||
assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001)
|
||||
assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001)
|
||||
}
|
||||
|
||||
func TestWindowSlotAgg(t *testing.T) {
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "sliding",
|
||||
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
|
||||
TsProp: "ts",
|
||||
},
|
||||
GroupFields: []string{"device"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"age": aggregator.Max,
|
||||
"score": aggregator.Min,
|
||||
"start": aggregator.WindowStart,
|
||||
"end": aggregator.WindowEnd,
|
||||
},
|
||||
}
|
||||
|
||||
strm, err := NewStream(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
strm.Start()
|
||||
// Add data every 500ms
|
||||
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
|
||||
|
||||
testData := []interface{}{
|
||||
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "ts": baseTime},
|
||||
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "ts": baseTime.Add(1 * time.Second)},
|
||||
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "ts": baseTime},
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
}
|
||||
|
||||
// 捕获结果
|
||||
resultChan := make(chan interface{})
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
// 等待 3 秒触发窗口
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var actual interface{}
|
||||
select {
|
||||
case actual = <-resultChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Timeout waiting for results")
|
||||
}
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"device": "aa",
|
||||
"age_max": 10.0,
|
||||
"score_min": 100.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
{
|
||||
"device": "bb",
|
||||
"age_max": 3.0,
|
||||
"score_min": 300.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
}
|
||||
|
||||
assert.IsType(t, []map[string]interface{}{}, actual)
|
||||
resultSlice, ok := actual.([]map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Len(t, resultSlice, 2)
|
||||
for _, expectedResult := range expected {
|
||||
found := false
|
||||
for _, resultMap := range resultSlice {
|
||||
//if resultMap, ok := result.(map[string]interface{}); ok {
|
||||
if resultMap["device"] == expectedResult["device"] {
|
||||
assert.InEpsilon(t, expectedResult["age_max"].(float64), resultMap["age_max"].(float64), 0.0001)
|
||||
assert.InEpsilon(t, expectedResult["score_min"].(float64), resultMap["score_min"].(float64), 0.0001)
|
||||
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
|
||||
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
|
||||
found = true
|
||||
break
|
||||
}
|
||||
//}
|
||||
}
|
||||
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
||||
+74
-2
@@ -1,13 +1,85 @@
|
||||
package streamsql
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamsql(t *testing.T) {
|
||||
streamsql := New()
|
||||
var rsql = ""
|
||||
var rsql = "SELECT device,max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream group by device,SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
|
||||
testData := []interface{}{
|
||||
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime},
|
||||
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)},
|
||||
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime},
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
}
|
||||
// 捕获结果
|
||||
resultChan := make(chan interface{})
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
// 等待 3 秒触发窗口
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var actual interface{}
|
||||
select {
|
||||
case actual = <-resultChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Timeout waiting for results")
|
||||
}
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"device": "aa",
|
||||
"max_age": 10.0,
|
||||
"min_score": 100.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
{
|
||||
"device": "bb",
|
||||
"max_age": 3.0,
|
||||
"min_score": 300.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
}
|
||||
|
||||
assert.IsType(t, []map[string]interface{}{}, actual)
|
||||
resultSlice, ok := actual.([]map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, resultSlice, 2)
|
||||
for _, expectedResult := range expected {
|
||||
found := false
|
||||
for _, resultMap := range resultSlice {
|
||||
//if resultMap, ok := result.(map[string]interface{}); ok {
|
||||
if resultMap["device"] == expectedResult["device"] {
|
||||
assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001)
|
||||
assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001)
|
||||
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
|
||||
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
|
||||
found = true
|
||||
break
|
||||
}
|
||||
//}
|
||||
}
|
||||
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package timex
|
||||
|
||||
import "time"
|
||||
|
||||
// AlignTimeToWindow 将时间对齐到窗口的起始时间。
|
||||
func AlignTimeToWindow(t time.Time, size time.Duration) time.Time {
|
||||
offset := t.UnixNano() % int64(size)
|
||||
return t.Add(time.Duration(-offset))
|
||||
}
|
||||
|
||||
// AlignTime 将时间对齐到指定的时间单位。 roundUp 为 true 时向上截断,为 false 时向下截断。
|
||||
func AlignTime(t time.Time, timeUnit time.Duration, roundUp bool) time.Time {
|
||||
trunc := t.Truncate(timeUnit)
|
||||
if !roundUp {
|
||||
return trunc.Add(timeUnit)
|
||||
}
|
||||
return trunc
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright 2021 EMQ Technologies Co., Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package timex
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAlignTimeToWindow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input time.Time
|
||||
size time.Duration
|
||||
expected time.Time
|
||||
}{
|
||||
{
|
||||
name: "对齐到1分钟窗口",
|
||||
input: time.Date(2024, 1, 1, 12, 35, 56, 789000000, time.UTC),
|
||||
size: 3 * time.Minute,
|
||||
expected: time.Date(2024, 1, 1, 12, 33, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "对齐到5分钟窗口",
|
||||
input: time.Date(2024, 1, 1, 12, 37, 56, 789000000, time.UTC),
|
||||
size: 5 * time.Minute,
|
||||
expected: time.Date(2024, 1, 1, 12, 35, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "对齐到1小时窗口",
|
||||
input: time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC),
|
||||
size: time.Hour,
|
||||
expected: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "对齐到1天窗口",
|
||||
input: time.Date(2024, 1, 1, 12, 34, 56, 789000000, time.UTC),
|
||||
size: 24 * time.Hour,
|
||||
expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "零时刻对齐测试",
|
||||
input: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
size: time.Hour,
|
||||
expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := AlignTimeToWindow(tt.input, tt.size)
|
||||
if !got.Equal(tt.expected) {
|
||||
t.Errorf("AlignTimeToWindow() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+84
-22
@@ -2,56 +2,90 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
timex "github.com/rulego/streamsql/utils"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var _ Window = (*CountingWindow)(nil)
|
||||
|
||||
type CountingWindow struct {
|
||||
config model.WindowConfig
|
||||
threshold int
|
||||
count int
|
||||
mu sync.Mutex
|
||||
callback func([]interface{})
|
||||
dataBuffer []interface{}
|
||||
outputChan chan []interface{}
|
||||
callback func([]model.Row)
|
||||
dataBuffer []model.Row
|
||||
outputChan chan []model.Row
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
ticker *time.Ticker
|
||||
triggerChan chan struct{}
|
||||
}
|
||||
|
||||
func NewCountingWindow(threshold int, callback func([]interface{})) *CountingWindow {
|
||||
func NewCountingWindow(config model.WindowConfig) (*CountingWindow, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &CountingWindow{
|
||||
threshold := cast.ToInt(config.Params["count"])
|
||||
if threshold <= 0 {
|
||||
return nil, fmt.Errorf("threshold must be a positive integer")
|
||||
}
|
||||
|
||||
cw := &CountingWindow{
|
||||
threshold: threshold,
|
||||
dataBuffer: make([]interface{}, 0, threshold),
|
||||
outputChan: make(chan []interface{}, 10),
|
||||
dataBuffer: make([]model.Row, 0, threshold),
|
||||
outputChan: make(chan []model.Row, 10),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
callback: callback,
|
||||
triggerChan: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
if callback, ok := config.Params["callback"].(func([]model.Row)); ok {
|
||||
cw.SetCallback(callback)
|
||||
}
|
||||
return cw, nil
|
||||
}
|
||||
|
||||
func (cw *CountingWindow) Add(data interface{}) {
|
||||
cw.mu.Lock()
|
||||
cw.dataBuffer = append(cw.dataBuffer, data)
|
||||
defer cw.mu.Unlock()
|
||||
// 将数据添加到窗口的数据列表中
|
||||
t := GetTimestamp(data, cw.config.TsProp)
|
||||
row := model.Row{
|
||||
Data: data,
|
||||
Timestamp: t,
|
||||
}
|
||||
cw.dataBuffer = append(cw.dataBuffer, row)
|
||||
cw.count++
|
||||
shouldTrigger := cw.count >= cw.threshold
|
||||
cw.mu.Unlock()
|
||||
|
||||
if shouldTrigger {
|
||||
cw.mu.Lock()
|
||||
v := append([]interface{}{}, cw.dataBuffer...)
|
||||
cw.mu.Unlock()
|
||||
|
||||
slot := cw.createSlot(cw.dataBuffer[:cw.threshold])
|
||||
for _, r := range cw.dataBuffer[:cw.threshold] {
|
||||
// 由于Row是值类型,这里需要通过指针来修改Slot字段
|
||||
(&r).Slot = slot
|
||||
}
|
||||
data := cw.dataBuffer[:cw.threshold]
|
||||
if len(cw.dataBuffer) > cw.threshold {
|
||||
remaining := len(cw.dataBuffer) - cw.threshold
|
||||
newBuffer := make([]model.Row, remaining, cw.threshold)
|
||||
copy(newBuffer, cw.dataBuffer[cw.threshold:])
|
||||
cw.dataBuffer = newBuffer
|
||||
} else {
|
||||
cw.dataBuffer = make([]model.Row, 0, cw.threshold)
|
||||
}
|
||||
go func() {
|
||||
cw.mu.Lock()
|
||||
if cw.callback != nil {
|
||||
cw.callback(v)
|
||||
cw.callback(data)
|
||||
}
|
||||
cw.outputChan <- v
|
||||
cw.Reset()
|
||||
cw.outputChan <- data
|
||||
cw.count = 0
|
||||
//cw.Reset()
|
||||
cw.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -66,7 +100,7 @@ func (cw *CountingWindow) Start() {
|
||||
for {
|
||||
select {
|
||||
case <-cw.ticker.C:
|
||||
cw.Trigger()
|
||||
//cw.Trigger()
|
||||
case <-cw.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -82,7 +116,17 @@ func (cw *CountingWindow) Trigger() {
|
||||
defer cw.mu.Unlock()
|
||||
|
||||
if cw.callback != nil && len(cw.dataBuffer) > 0 {
|
||||
cw.callback(cw.dataBuffer)
|
||||
var resultData []model.Row
|
||||
if len(cw.dataBuffer) > cw.threshold {
|
||||
resultData = cw.dataBuffer[:cw.threshold]
|
||||
} else {
|
||||
resultData = cw.dataBuffer
|
||||
}
|
||||
slot := cw.createSlot(resultData)
|
||||
for _, r := range resultData {
|
||||
r.Slot = slot
|
||||
}
|
||||
cw.callback(resultData)
|
||||
}
|
||||
cw.Reset()
|
||||
}()
|
||||
@@ -95,9 +139,27 @@ func (cw *CountingWindow) Reset() {
|
||||
cw.dataBuffer = cw.dataBuffer[:0]
|
||||
}
|
||||
|
||||
func (cw *CountingWindow) OutputChan() <-chan []interface{} {
|
||||
func (cw *CountingWindow) OutputChan() <-chan []model.Row {
|
||||
return cw.outputChan
|
||||
}
|
||||
func (cw *CountingWindow) GetResults() []interface{} {
|
||||
return append([]interface{}{}, cw.dataBuffer...)
|
||||
|
||||
// func (cw *CountingWindow) GetResults() []interface{} {
|
||||
// return append([]mode.Row, cw.dataBuffer...)
|
||||
// }
|
||||
|
||||
// createSlot 创建一个新的时间槽位
|
||||
func (cw *CountingWindow) createSlot(data []model.Row) *model.TimeSlot {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
} else if len(data) < cw.threshold {
|
||||
start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true)
|
||||
end := timex.AlignTime(data[len(cw.dataBuffer)-1].Timestamp, cw.config.TimeUnit, false)
|
||||
slot := model.NewTimeSlot(&start, &end)
|
||||
return slot
|
||||
} else {
|
||||
start := timex.AlignTime(data[0].Timestamp, cw.config.TimeUnit, true)
|
||||
end := timex.AlignTime(data[cw.threshold-1].Timestamp, cw.config.TimeUnit, false)
|
||||
slot := model.NewTimeSlot(&start, &end)
|
||||
return slot
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -14,8 +16,13 @@ func TestCountingWindow(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
// Test case 1: Normal operation
|
||||
cw := NewCountingWindow(3, func(results []interface{}) {
|
||||
t.Logf("Received results: %v", results)
|
||||
cw, _ := NewCountingWindow(model.WindowConfig{
|
||||
Params: map[string]interface{}{
|
||||
"count": 3,
|
||||
"callback": func(results []interface{}) {
|
||||
t.Logf("Received results: %v", results)
|
||||
},
|
||||
},
|
||||
})
|
||||
go cw.Start()
|
||||
|
||||
@@ -27,31 +34,35 @@ func TestCountingWindow(t *testing.T) {
|
||||
// Trigger one more element to check threshold
|
||||
cw.Add(3)
|
||||
|
||||
results := make(chan []interface{})
|
||||
go func() {
|
||||
for res := range cw.OutputChan() {
|
||||
results <- res
|
||||
}
|
||||
}()
|
||||
resultsChan := cw.OutputChan()
|
||||
//results := make(chan []model.Row)
|
||||
// go func() {
|
||||
// for res := range cw.OutputChan() {
|
||||
// results <- res
|
||||
// }
|
||||
// }()
|
||||
|
||||
select {
|
||||
case res := <-results:
|
||||
case res := <-resultsChan:
|
||||
assert.Len(t, res, 3)
|
||||
assert.Contains(t, res, 0)
|
||||
assert.Contains(t, res, 1)
|
||||
assert.Contains(t, res, 2)
|
||||
assert.Equal(t, 0, res[0].Data, "第一个元素应该是0")
|
||||
assert.Equal(t, 1, res[1].Data, "第二个元素应该是1")
|
||||
assert.Equal(t, 2, res[2].Data, "第三个元素应该是2")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("No results received within timeout")
|
||||
}
|
||||
|
||||
assert.Len(t, cw.dataBuffer, 1)
|
||||
// Test case 2: Reset
|
||||
cw.Reset()
|
||||
assert.Len(t, cw.dataBuffer, 0)
|
||||
}
|
||||
|
||||
func TestCountingWindowBadThreshold(t *testing.T) {
|
||||
_, err := CreateWindow("counting", map[string]interface{}{
|
||||
"count": 0,
|
||||
_, err := CreateWindow(model.WindowConfig{
|
||||
Type: "counting",
|
||||
Params: map[string]interface{}{
|
||||
"count": 0,
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
+42
-31
@@ -2,7 +2,10 @@ package window
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/cast"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -14,47 +17,55 @@ const (
|
||||
|
||||
type Window interface {
|
||||
Add(item interface{})
|
||||
GetResults() []interface{}
|
||||
//GetResults() []interface{}
|
||||
Reset()
|
||||
Start()
|
||||
OutputChan() <-chan []interface{}
|
||||
SetCallback(callback func([]interface{}))
|
||||
OutputChan() <-chan []model.Row
|
||||
SetCallback(callback func([]model.Row))
|
||||
Trigger()
|
||||
}
|
||||
|
||||
func CreateWindow(windowType string, params map[string]interface{}) (Window, error) {
|
||||
switch windowType {
|
||||
func CreateWindow(config model.WindowConfig) (Window, error) {
|
||||
switch config.Type {
|
||||
case TypeTumbling:
|
||||
size, err := cast.ToDurationE(params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
|
||||
}
|
||||
return NewTumblingWindow(size), nil
|
||||
return NewTumblingWindow(config)
|
||||
case TypeSliding:
|
||||
size, err := cast.ToDurationE(params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
|
||||
}
|
||||
slide, err := cast.ToDurationE(params["slide"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
|
||||
}
|
||||
return NewSlidingWindow(size, slide), nil
|
||||
return NewSlidingWindow(config)
|
||||
case TypeCounting:
|
||||
count := cast.ToInt(params["count"])
|
||||
if count <= 0 {
|
||||
return nil, fmt.Errorf("count must be a positive integer")
|
||||
}
|
||||
cw := NewCountingWindow(count, nil)
|
||||
if callback, ok := params["callback"].(func([]interface{})); ok {
|
||||
cw.SetCallback(callback)
|
||||
}
|
||||
return cw, nil
|
||||
return NewCountingWindow(config)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported window type: %s", windowType)
|
||||
return nil, fmt.Errorf("unsupported window type: %s", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *CountingWindow) SetCallback(callback func([]interface{})) {
|
||||
func (cw *CountingWindow) SetCallback(callback func([]model.Row)) {
|
||||
cw.callback = callback
|
||||
}
|
||||
|
||||
// GetTimestamp 从数据中获取时间戳。
|
||||
func GetTimestamp(data interface{}, tsProp string) time.Time {
|
||||
if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok {
|
||||
return ts.GetTimestamp()
|
||||
} else if tsProp != "" {
|
||||
v := reflect.ValueOf(data)
|
||||
|
||||
// 处理不同类型
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
// 如果是结构体,使用反射获取字段值
|
||||
if f := v.FieldByName(tsProp); f.IsValid() {
|
||||
if t, ok := f.Interface().(time.Time); ok {
|
||||
return t
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
// 如果是map,直接通过key获取值
|
||||
if v.Type().Key().Kind() == reflect.String {
|
||||
if value := v.MapIndex(reflect.ValueOf(tsProp)); value.IsValid() {
|
||||
return value.Interface().(time.Time)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
+72
-42
@@ -2,8 +2,13 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
timex "github.com/rulego/streamsql/utils"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// 确保 SlidingWindow 结构体实现了 Window 接口
|
||||
@@ -11,12 +16,14 @@ var _ Window = (*SlidingWindow)(nil)
|
||||
|
||||
// TimedData 用于包装数据和时间戳
|
||||
type TimedData struct {
|
||||
Data interface{}
|
||||
Timestamp time.Time
|
||||
Data interface{}
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// SlidingWindow 表示一个滑动窗口,用于按时间范围处理数据
|
||||
type SlidingWindow struct {
|
||||
// config 窗口的配置信息
|
||||
config model.WindowConfig
|
||||
// 窗口的总大小,即窗口覆盖的时间范围
|
||||
size time.Duration
|
||||
// 窗口每次滑动的时间间隔
|
||||
@@ -24,32 +31,42 @@ type SlidingWindow struct {
|
||||
// 用于保护数据并发访问的互斥锁
|
||||
mu sync.Mutex
|
||||
// 存储窗口内的数据
|
||||
data []TimedData
|
||||
data []model.Row
|
||||
// 用于输出窗口内数据的通道
|
||||
outputChan chan []interface{}
|
||||
outputChan chan []model.Row
|
||||
// 当窗口触发时执行的回调函数
|
||||
callback func([]interface{})
|
||||
callback func([]model.Row)
|
||||
// 用于控制窗口生命周期的上下文
|
||||
ctx context.Context
|
||||
// 用于取消上下文的函数
|
||||
cancelFunc context.CancelFunc
|
||||
// 用于定时触发窗口的定时器
|
||||
timer *time.Timer
|
||||
timer *time.Timer
|
||||
currentSlot *model.TimeSlot
|
||||
}
|
||||
|
||||
// NewSlidingWindow 创建一个新的滑动窗口实例
|
||||
// 参数 size 表示窗口的总大小,slide 表示窗口每次滑动的时间间隔
|
||||
func NewSlidingWindow(size, slide time.Duration) *SlidingWindow {
|
||||
func NewSlidingWindow(config model.WindowConfig) (*SlidingWindow, error) {
|
||||
// 创建一个可取消的上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
size, err := cast.ToDurationE(config.Params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for sliding window: %v", err)
|
||||
}
|
||||
slide, err := cast.ToDurationE(config.Params["slide"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid slide for sliding window: %v", err)
|
||||
}
|
||||
return &SlidingWindow{
|
||||
config: config,
|
||||
size: size,
|
||||
slide: slide,
|
||||
outputChan: make(chan []interface{}, 10),
|
||||
outputChan: make(chan []model.Row, 10),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
data: make([]TimedData, 0),
|
||||
}
|
||||
data: make([]model.Row, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add 向滑动窗口中添加数据
|
||||
@@ -58,19 +75,18 @@ func (sw *SlidingWindow) Add(data interface{}) {
|
||||
// 加锁以保证数据的并发安全
|
||||
sw.mu.Lock()
|
||||
defer sw.mu.Unlock()
|
||||
|
||||
var timestamp time.Time
|
||||
if ts, ok := data.(interface{ GetTimestamp() time.Time }); ok {
|
||||
timestamp = ts.GetTimestamp()
|
||||
} else {
|
||||
timestamp = time.Now()
|
||||
}
|
||||
|
||||
// 将数据添加到窗口的数据列表中
|
||||
sw.data = append(sw.data, TimedData{
|
||||
Data: data,
|
||||
Timestamp: timestamp,
|
||||
})
|
||||
// 将数据添加到窗口的数据列表中
|
||||
t := GetTimestamp(data, sw.config.TsProp)
|
||||
if sw.currentSlot == nil {
|
||||
sw.currentSlot = sw.createSlot(t)
|
||||
}
|
||||
go func() {
|
||||
row := model.Row{
|
||||
Data: data,
|
||||
Timestamp: t,
|
||||
}
|
||||
sw.data = append(sw.data, row)
|
||||
}()
|
||||
}
|
||||
|
||||
// Start 启动滑动窗口,开始定时触发窗口
|
||||
@@ -106,19 +122,25 @@ func (sw *SlidingWindow) Trigger() {
|
||||
}
|
||||
|
||||
// 计算截止时间,即当前时间减去窗口的总大小
|
||||
cutoff := time.Now().Add(-sw.size)
|
||||
var newData []TimedData
|
||||
// 遍历窗口内的数据,只保留在截止时间之后的数据
|
||||
next := sw.NextSlot()
|
||||
// 保留下一个窗口的数据
|
||||
tms := next.Start.Add(-sw.size)
|
||||
tme := next.End.Add(sw.size)
|
||||
temp := model.NewTimeSlot(&tms, &tme)
|
||||
newData := make([]model.Row, 0)
|
||||
for _, item := range sw.data {
|
||||
if item.Timestamp.After(cutoff) {
|
||||
if temp.Contains(item.Timestamp) {
|
||||
newData = append(newData, item)
|
||||
}
|
||||
}
|
||||
|
||||
// 提取出 Data 字段组成 []interface{} 类型的数据
|
||||
resultData := make([]interface{}, 0, len(newData))
|
||||
for _, item := range newData {
|
||||
resultData = append(resultData, item.Data)
|
||||
resultData := make([]model.Row, 0)
|
||||
for _, item := range sw.data {
|
||||
if sw.currentSlot.Contains(item.Timestamp) {
|
||||
item.Slot = sw.currentSlot
|
||||
resultData = append(resultData, item)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果设置了回调函数,则执行回调函数
|
||||
@@ -128,6 +150,7 @@ func (sw *SlidingWindow) Trigger() {
|
||||
|
||||
// 更新窗口内的数据
|
||||
sw.data = newData
|
||||
sw.currentSlot = next
|
||||
// 将新的数据发送到输出通道
|
||||
sw.outputChan <- resultData
|
||||
}
|
||||
@@ -139,28 +162,35 @@ func (sw *SlidingWindow) Reset() {
|
||||
defer sw.mu.Unlock()
|
||||
// 清空窗口内的数据
|
||||
sw.data = nil
|
||||
sw.currentSlot = nil
|
||||
}
|
||||
|
||||
// OutputChan 返回滑动窗口的输出通道
|
||||
func (sw *SlidingWindow) OutputChan() <-chan []interface{} {
|
||||
func (sw *SlidingWindow) OutputChan() <-chan []model.Row {
|
||||
return sw.outputChan
|
||||
}
|
||||
|
||||
// SetCallback 设置滑动窗口触发时执行的回调函数
|
||||
// 参数 callback 表示要设置的回调函数
|
||||
func (sw *SlidingWindow) SetCallback(callback func([]interface{})) {
|
||||
func (sw *SlidingWindow) SetCallback(callback func([]model.Row)) {
|
||||
sw.callback = callback
|
||||
}
|
||||
|
||||
// GetResults 获取滑动窗口内的当前数据
|
||||
func (sw *SlidingWindow) GetResults() []interface{} {
|
||||
// 加锁以保证数据的并发安全
|
||||
sw.mu.Lock()
|
||||
defer sw.mu.Unlock()
|
||||
// 提取出 Data 字段组成 []interface{} 类型的数据
|
||||
resultData := make([]interface{}, 0, len(sw.data))
|
||||
for _, item := range sw.data {
|
||||
resultData = append(resultData, item.Data)
|
||||
func (sw *SlidingWindow) NextSlot() *model.TimeSlot {
|
||||
if sw.currentSlot == nil {
|
||||
return nil
|
||||
}
|
||||
return resultData
|
||||
start := sw.currentSlot.Start.Add(sw.slide)
|
||||
end := sw.currentSlot.End.Add(sw.slide)
|
||||
next := model.NewTimeSlot(&start, &end)
|
||||
return next
|
||||
}
|
||||
|
||||
// createSlot 创建一个新的时间槽位
|
||||
func (sw *SlidingWindow) createSlot(t time.Time) *model.TimeSlot {
|
||||
// 创建一个新的时间槽位
|
||||
start := timex.AlignTimeToWindow(t, sw.size)
|
||||
end := start.Add(sw.size)
|
||||
slot := model.NewTimeSlot(&start, &end)
|
||||
return slot
|
||||
}
|
||||
|
||||
+103
-17
@@ -2,42 +2,128 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
timex "github.com/rulego/streamsql/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSlidingWindow(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sw := NewSlidingWindow(2*time.Second, 1*time.Second)
|
||||
sw.SetCallback(func(results []interface{}) {
|
||||
sw, _ := NewSlidingWindow(model.WindowConfig{
|
||||
Params: map[string]interface{}{
|
||||
"size": "2s",
|
||||
"slide": "1s",
|
||||
},
|
||||
TsProp: "Ts",
|
||||
TimeUnit: time.Second,
|
||||
})
|
||||
sw.SetCallback(func(results []model.Row) {
|
||||
t.Logf("Received results: %v", results)
|
||||
})
|
||||
sw.Start()
|
||||
|
||||
// 添加数据
|
||||
now := time.Now()
|
||||
sw.Add(now.Add(-3 * time.Second))
|
||||
sw.Add(now.Add(-2 * time.Second))
|
||||
sw.Add(now.Add(-1 * time.Second))
|
||||
sw.Add(now)
|
||||
t_3 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 56, 789000000, time.UTC), tag: "1"}
|
||||
t_2 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 57, 789000000, time.UTC), tag: "2"}
|
||||
t_1 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 58, 789000000, time.UTC), tag: "3"}
|
||||
t_0 := TestDate{Ts: time.Date(2025, 4, 7, 16, 46, 59, 789000000, time.UTC), tag: "4"}
|
||||
|
||||
sw.Add(t_3)
|
||||
sw.Add(t_2)
|
||||
sw.Add(t_1)
|
||||
sw.Add(t_0)
|
||||
|
||||
// 等待一段时间,触发窗口
|
||||
time.Sleep(3 * time.Second)
|
||||
//time.Sleep(3 * time.Second)
|
||||
|
||||
// 检查结果
|
||||
resultsChan := sw.OutputChan()
|
||||
var results []interface{}
|
||||
select {
|
||||
case results = <-resultsChan:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("No results received within timeout")
|
||||
var results []model.Row
|
||||
|
||||
for {
|
||||
select {
|
||||
case results = <-resultsChan:
|
||||
raw := make([]TestDate, 0)
|
||||
for _, row := range results {
|
||||
raw = append(raw, row.Data.(TestDate))
|
||||
}
|
||||
|
||||
// 获取当前窗口的时间范围
|
||||
windowStart := results[0].Slot.Start
|
||||
windowEnd := results[0].Slot.End
|
||||
t.Logf("Window range: %v - %v", windowStart, windowEnd)
|
||||
|
||||
// 检查窗口内的数据
|
||||
expectedData := make([]TestDate, 0)
|
||||
|
||||
if windowStart.Before(t_3.Ts) && windowEnd.After(t_2.Ts) {
|
||||
expectedData = []TestDate{t_3, t_2}
|
||||
start := timex.AlignTimeToWindow(t_3.Ts, sw.size)
|
||||
assert.Equal(t, start, windowStart)
|
||||
assert.Equal(t, start.Add(sw.size), windowEnd)
|
||||
} else if windowStart.Before(t_2.Ts) && windowEnd.After(t_1.Ts) {
|
||||
expectedData = []TestDate{t_2, t_1}
|
||||
start := timex.AlignTimeToWindow(t_2.Ts, sw.size)
|
||||
assert.Equal(t, start, windowStart)
|
||||
assert.Equal(t, start.Add(sw.size), windowEnd)
|
||||
} else if windowStart.Before(t_1.Ts) && windowEnd.After(t_0.Ts) {
|
||||
expectedData = []TestDate{t_1, t_0}
|
||||
start := timex.AlignTimeToWindow(t_1.Ts, sw.size)
|
||||
assert.Equal(t, start, windowStart)
|
||||
assert.Equal(t, start.Add(sw.size), windowEnd)
|
||||
} else {
|
||||
expectedData = []TestDate{t_0}
|
||||
start := timex.AlignTimeToWindow(t_0.Ts, sw.size)
|
||||
assert.Equal(t, start, windowStart)
|
||||
assert.Equal(t, start.Add(sw.size), windowEnd)
|
||||
}
|
||||
|
||||
// 验证窗口数据
|
||||
assert.Equal(t, len(expectedData), len(raw), "窗口数据数量不匹配")
|
||||
for _, expected := range expectedData {
|
||||
assert.Contains(t, raw, expected, "窗口缺少预期数据")
|
||||
}
|
||||
default:
|
||||
// 通道为空时退出
|
||||
goto END
|
||||
}
|
||||
}
|
||||
|
||||
END:
|
||||
// 预期结果:保留最近 2 秒内的数据
|
||||
assert.Len(t, results, 2)
|
||||
assert.Contains(t, results, now.Add(-1*time.Second))
|
||||
assert.Contains(t, results, now)
|
||||
assert.Len(t, results, 0)
|
||||
}
|
||||
|
||||
type TestDate struct {
|
||||
Ts time.Time
|
||||
tag string
|
||||
}
|
||||
|
||||
type TestDate2 struct {
|
||||
ts time.Time
|
||||
}
|
||||
|
||||
func (d TestDate2) GetTimestamp() time.Time {
|
||||
return d.ts
|
||||
}
|
||||
|
||||
func TestGetTimestamp(t *testing.T) {
|
||||
t_0 := time.Now()
|
||||
data := map[string]interface{}{"device": "aa", "age": 15.0, "score": 100, "ts": t_0}
|
||||
t_1 := GetTimestamp(data, "ts")
|
||||
|
||||
data_1 := TestDate{Ts: t_0}
|
||||
t_2 := GetTimestamp(data_1, "Ts")
|
||||
|
||||
data_2 := TestDate2{ts: t_0}
|
||||
t_3 := GetTimestamp(data_2, "")
|
||||
|
||||
assert.Equal(t, t_0, t_1)
|
||||
assert.Equal(t, t_0, t_2)
|
||||
assert.Equal(t, t_0, t_3)
|
||||
}
|
||||
|
||||
+88
-26
@@ -3,8 +3,13 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
timex "github.com/rulego/streamsql/utils"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// 确保 TumblingWindow 结构体实现了 Window 接口。
|
||||
@@ -12,35 +17,43 @@ var _ Window = (*TumblingWindow)(nil)
|
||||
|
||||
// TumblingWindow 表示一个滚动窗口,用于在固定时间间隔内收集数据并触发处理。
|
||||
type TumblingWindow struct {
|
||||
// config 是窗口的配置信息。
|
||||
config model.WindowConfig
|
||||
// size 是滚动窗口的时间大小,即窗口的持续时间。
|
||||
size time.Duration
|
||||
// mu 用于保护对窗口数据的并发访问。
|
||||
mu sync.Mutex
|
||||
// data 存储窗口内收集的数据。
|
||||
data []interface{}
|
||||
data []model.Row
|
||||
// outputChan 是一个通道,用于在窗口触发时发送数据。
|
||||
outputChan chan []interface{}
|
||||
outputChan chan []model.Row
|
||||
// callback 是一个可选的回调函数,在窗口触发时调用。
|
||||
callback func([]interface{})
|
||||
callback func([]model.Row)
|
||||
// ctx 用于控制窗口的生命周期。
|
||||
ctx context.Context
|
||||
// cancelFunc 用于取消窗口的操作。
|
||||
cancelFunc context.CancelFunc
|
||||
// timer 用于定时触发窗口。
|
||||
timer *time.Timer
|
||||
timer *time.Timer
|
||||
currentSlot *model.TimeSlot
|
||||
}
|
||||
|
||||
// NewTumblingWindow 创建一个新的滚动窗口实例。
|
||||
// 参数 size 是窗口的时间大小。
|
||||
func NewTumblingWindow(size time.Duration) *TumblingWindow {
|
||||
func NewTumblingWindow(config model.WindowConfig) (*TumblingWindow, error) {
|
||||
// 创建一个可取消的上下文。
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
size, err := cast.ToDurationE(config.Params["size"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid size for tumbling window: %v", err)
|
||||
}
|
||||
return &TumblingWindow{
|
||||
config: config,
|
||||
size: size,
|
||||
outputChan: make(chan []interface{}, 10),
|
||||
outputChan: make(chan []model.Row, 10),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add 向滚动窗口添加数据。
|
||||
@@ -50,7 +63,33 @@ func (tw *TumblingWindow) Add(data interface{}) {
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
// 将数据追加到窗口的数据列表中。
|
||||
tw.data = append(tw.data, data)
|
||||
if tw.currentSlot == nil {
|
||||
tw.currentSlot = tw.createSlot(GetTimestamp(data, tw.config.TsProp))
|
||||
}
|
||||
go func() {
|
||||
row := model.Row{
|
||||
Data: data,
|
||||
Timestamp: GetTimestamp(data, tw.config.TsProp),
|
||||
}
|
||||
tw.data = append(tw.data, row)
|
||||
}()
|
||||
}
|
||||
|
||||
func (sw *TumblingWindow) createSlot(t time.Time) *model.TimeSlot {
|
||||
// 创建一个新的时间槽位
|
||||
start := timex.AlignTimeToWindow(t, sw.size)
|
||||
end := start.Add(sw.size)
|
||||
slot := model.NewTimeSlot(&start, &end)
|
||||
return slot
|
||||
}
|
||||
|
||||
func (sw *TumblingWindow) NextSlot() *model.TimeSlot {
|
||||
if sw.currentSlot == nil {
|
||||
return nil
|
||||
}
|
||||
start := sw.currentSlot.End
|
||||
end := sw.currentSlot.End.Add(sw.size)
|
||||
return model.NewTimeSlot(start, &end)
|
||||
}
|
||||
|
||||
// Stop 停止滚动窗口的操作。
|
||||
@@ -87,16 +126,38 @@ func (tw *TumblingWindow) Trigger() {
|
||||
// 加锁以确保并发安全。
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
|
||||
// 如果设置了回调函数,则调用它。
|
||||
if tw.callback != nil {
|
||||
tw.callback(tw.data)
|
||||
// 计算下一个窗口槽位
|
||||
next := tw.NextSlot()
|
||||
// 保留下一个窗口的数据
|
||||
tms := next.Start.Add(-tw.size)
|
||||
tme := next.End.Add(tw.size)
|
||||
temp := model.NewTimeSlot(&tms, &tme)
|
||||
newData := make([]model.Row, 0)
|
||||
for _, item := range tw.data {
|
||||
if temp.Contains(item.Timestamp) {
|
||||
newData = append(newData, item)
|
||||
}
|
||||
}
|
||||
|
||||
// 将窗口数据发送到输出通道。
|
||||
tw.outputChan <- append([]interface{}{}, tw.data...)
|
||||
// 重置窗口数据。
|
||||
tw.data = nil
|
||||
// 提取出当前窗口数据
|
||||
resultData := make([]model.Row, 0)
|
||||
for _, item := range tw.data {
|
||||
if tw.currentSlot.Contains(item.Timestamp) {
|
||||
item.Slot = tw.currentSlot
|
||||
resultData = append(resultData, item)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果设置了回调函数,则执行回调函数
|
||||
if tw.callback != nil {
|
||||
tw.callback(resultData)
|
||||
}
|
||||
|
||||
// 更新窗口内的数据
|
||||
tw.data = newData
|
||||
tw.currentSlot = next
|
||||
// 将新的数据发送到输出通道
|
||||
tw.outputChan <- resultData
|
||||
}
|
||||
|
||||
// Reset 重置滚动窗口的数据。
|
||||
@@ -106,24 +167,25 @@ func (tw *TumblingWindow) Reset() {
|
||||
defer tw.mu.Unlock()
|
||||
// 清空窗口数据。
|
||||
tw.data = nil
|
||||
tw.currentSlot = nil
|
||||
}
|
||||
|
||||
// OutputChan 返回一个只读通道,用于接收窗口触发时的数据。
|
||||
func (tw *TumblingWindow) OutputChan() <-chan []interface{} {
|
||||
func (tw *TumblingWindow) OutputChan() <-chan []model.Row {
|
||||
return tw.outputChan
|
||||
}
|
||||
|
||||
// SetCallback 设置滚动窗口触发时的回调函数。
|
||||
// 参数 callback 是要设置的回调函数。
|
||||
func (tw *TumblingWindow) SetCallback(callback func([]interface{})) {
|
||||
func (tw *TumblingWindow) SetCallback(callback func([]model.Row)) {
|
||||
tw.callback = callback
|
||||
}
|
||||
|
||||
// GetResults 获取当前滚动窗口中的数据副本。
|
||||
func (tw *TumblingWindow) GetResults() []interface{} {
|
||||
// 加锁以确保并发安全。
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
// 返回窗口数据的副本。
|
||||
return append([]interface{}{}, tw.data...)
|
||||
}
|
||||
// // GetResults 获取当前滚动窗口中的数据副本。
|
||||
// func (tw *TumblingWindow) GetResults() []interface{} {
|
||||
// // 加锁以确保并发安全。
|
||||
// tw.mu.Lock()
|
||||
// defer tw.mu.Unlock()
|
||||
// // 返回窗口数据的副本。
|
||||
// return append([]interface{}{}, tw.data...)
|
||||
// }
|
||||
|
||||
@@ -2,62 +2,102 @@ package window
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/require"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rulego/streamsql/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTumblingWindow(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tw := NewTumblingWindow(2 * time.Second)
|
||||
tw.SetCallback(func(results []interface{}) {
|
||||
tw, _ := NewTumblingWindow(model.WindowConfig{
|
||||
Type: "TumblingWindow",
|
||||
Params: map[string]interface{}{"size": "2s"},
|
||||
TsProp: "Ts",
|
||||
})
|
||||
tw.SetCallback(func(results []model.Row) {
|
||||
// Process results
|
||||
})
|
||||
go tw.Start()
|
||||
|
||||
// Add data every 500ms
|
||||
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
|
||||
// 添加测试数据
|
||||
for i := 0; i < 5; i++ {
|
||||
tw.Add(i)
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
data := TestDate{
|
||||
Ts: baseTime.Add(time.Duration(i) * 1100 * time.Millisecond),
|
||||
tag: fmt.Sprintf("%d", i),
|
||||
}
|
||||
tw.Add(data)
|
||||
}
|
||||
|
||||
// Check output channel
|
||||
// 收集窗口结果
|
||||
resultsChan := tw.OutputChan()
|
||||
var results []interface{}
|
||||
select {
|
||||
case results = <-resultsChan:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("No results received within timeout")
|
||||
var all [][]model.Row = make([][]model.Row, 0)
|
||||
|
||||
// 收集所有窗口数据
|
||||
COLLECT:
|
||||
for {
|
||||
select {
|
||||
case results := <-resultsChan:
|
||||
all = append(all, results)
|
||||
if len(all) >= 3 {
|
||||
break COLLECT
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that data is sent every 2 seconds
|
||||
require.Len(t, results, 2)
|
||||
require.Equal(t, []interface{}{0, 1}, results)
|
||||
// 验证窗口数据
|
||||
require.Len(t, all, 3, "应该有3个时间窗口的数据")
|
||||
|
||||
// Verify next batch
|
||||
select {
|
||||
case results = <-resultsChan:
|
||||
require.Len(t, results, 2)
|
||||
require.Equal(t, []interface{}{2, 3}, results)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("No results received within timeout")
|
||||
// 验证每个窗口的数据
|
||||
expectedWindows := []struct {
|
||||
size int
|
||||
tags []string
|
||||
startIdx int
|
||||
}{
|
||||
{size: 2, tags: []string{"0", "1"}, startIdx: 0},
|
||||
{size: 2, tags: []string{"2", "3"}, startIdx: 1},
|
||||
{size: 1, tags: []string{"4"}, startIdx: 2},
|
||||
}
|
||||
|
||||
//time.Sleep(1100 * time.Millisecond)
|
||||
//results = <-resultsChan
|
||||
for i, window := range all {
|
||||
expected := expectedWindows[i]
|
||||
require.Len(t, window, expected.size, "窗口 %d 数据数量不匹配", i)
|
||||
|
||||
// 验证数据内容
|
||||
for _, row := range window {
|
||||
require.Contains(t, expected.tags, row.Data.(TestDate).tag)
|
||||
}
|
||||
|
||||
// 验证时间槽
|
||||
startTime := baseTime.Add(time.Duration(i*2) * time.Second)
|
||||
endTime := startTime.Add(2 * time.Second)
|
||||
require.True(t, window[0].Slot.Start.Equal(startTime) &&
|
||||
window[0].Slot.End.Equal(endTime),
|
||||
"窗口 %d 时间槽边界不正确", i)
|
||||
}
|
||||
|
||||
// Verify reset and final batch
|
||||
tw.Reset()
|
||||
tw.Add(99)
|
||||
tw.Add(TestDate{
|
||||
Ts: baseTime.Add(time.Duration(99) * 1100 * time.Millisecond),
|
||||
tag: fmt.Sprintf("%d", 99),
|
||||
})
|
||||
// time.Sleep(1100 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case results = <-resultsChan:
|
||||
case results := <-resultsChan:
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, []interface{}{99}, results)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("No results received within timeout")
|
||||
require.Equal(t, "99", results[0].Data.(TestDate).tag)
|
||||
startTime := baseTime.Add(108 * time.Second)
|
||||
endTime := baseTime.Add(110 * time.Second)
|
||||
require.True(t, results[0].Slot.Start.Equal(startTime) && results[0].Slot.End.Equal(endTime))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user