mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-04 07:09:21 +00:00
重构:修改别名映射
This commit is contained in:
@ -3,26 +3,29 @@ package aggregator
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type AggregateType string
|
||||
|
||||
const (
|
||||
Sum AggregateType = "sum"
|
||||
Count AggregateType = "count"
|
||||
Avg AggregateType = "avg"
|
||||
Max AggregateType = "max"
|
||||
Min AggregateType = "min"
|
||||
StdDev AggregateType = "stddev"
|
||||
Median AggregateType = "median"
|
||||
Percentile AggregateType = "percentile"
|
||||
Sum AggregateType = "sum"
|
||||
Count AggregateType = "count"
|
||||
Avg AggregateType = "avg"
|
||||
Max AggregateType = "max"
|
||||
Min AggregateType = "min"
|
||||
StdDev AggregateType = "stddev"
|
||||
Median AggregateType = "median"
|
||||
Percentile AggregateType = "percentile"
|
||||
WindowStart AggregateType = "window_start"
|
||||
WindowEnd AggregateType = "window_end"
|
||||
)
|
||||
|
||||
type AggregatorFunction interface {
|
||||
New() AggregatorFunction
|
||||
Add(value float64)
|
||||
Result() float64
|
||||
Add(value interface{})
|
||||
Result() interface{}
|
||||
}
|
||||
|
||||
type SumAggregator struct {
|
||||
@ -33,11 +36,12 @@ func (s *SumAggregator) New() AggregatorFunction {
|
||||
return &SumAggregator{}
|
||||
}
|
||||
|
||||
func (s *SumAggregator) Add(v float64) {
|
||||
s.value += v
|
||||
func (s *SumAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
s.value += vv
|
||||
}
|
||||
|
||||
func (s *SumAggregator) Result() float64 {
|
||||
func (s *SumAggregator) Result() interface{} {
|
||||
return s.value
|
||||
}
|
||||
|
||||
@ -49,11 +53,11 @@ func (s *CountAggregator) New() AggregatorFunction {
|
||||
return &CountAggregator{}
|
||||
}
|
||||
|
||||
func (c *CountAggregator) Add(_ float64) {
|
||||
func (c *CountAggregator) Add(_ interface{}) {
|
||||
c.count++
|
||||
}
|
||||
|
||||
func (c *CountAggregator) Result() float64 {
|
||||
func (c *CountAggregator) Result() interface{} {
|
||||
return float64(c.count)
|
||||
}
|
||||
|
||||
@ -66,12 +70,13 @@ func (a *AvgAggregator) New() AggregatorFunction {
|
||||
return &AvgAggregator{}
|
||||
}
|
||||
|
||||
func (a *AvgAggregator) Add(v float64) {
|
||||
a.sum += v
|
||||
func (a *AvgAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
a.sum += vv
|
||||
a.count++
|
||||
}
|
||||
|
||||
func (a *AvgAggregator) Result() float64 {
|
||||
func (a *AvgAggregator) Result() interface{} {
|
||||
if a.count == 0 {
|
||||
return 0
|
||||
}
|
||||
@ -117,6 +122,10 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction {
|
||||
return &MedianAggregator{}
|
||||
case Percentile:
|
||||
return &PercentileAggregator{p: 0.95}
|
||||
case WindowStart:
|
||||
return &WindowStartAggregator{}
|
||||
case WindowEnd:
|
||||
return &WindowEndAggregator{}
|
||||
default:
|
||||
panic("unsupported aggregator type: " + aggType)
|
||||
}
|
||||
@ -150,11 +159,12 @@ func (m *MedianAggregator) New() AggregatorFunction {
|
||||
return &MedianAggregator{}
|
||||
}
|
||||
|
||||
func (m *MedianAggregator) Add(val float64) {
|
||||
m.values = append(m.values, val)
|
||||
func (m *MedianAggregator) Add(val interface{}) {
|
||||
var vv float64 = ConvertToFloat64(val, 0)
|
||||
m.values = append(m.values, vv)
|
||||
}
|
||||
|
||||
func (m *MedianAggregator) Result() float64 {
|
||||
func (m *MedianAggregator) Result() interface{} {
|
||||
sort.Float64s(m.values)
|
||||
return m.values[len(m.values)/2]
|
||||
}
|
||||
@ -168,8 +178,9 @@ func (p *PercentileAggregator) New() AggregatorFunction {
|
||||
return &PercentileAggregator{}
|
||||
}
|
||||
|
||||
func (p *PercentileAggregator) Add(v float64) {
|
||||
p.values = append(p.values, v)
|
||||
func (p *PercentileAggregator) Add(v interface{}) {
|
||||
vv := ConvertToFloat64(v, 0)
|
||||
p.values = append(p.values, vv)
|
||||
}
|
||||
|
||||
type MinAggregator struct {
|
||||
@ -183,14 +194,15 @@ func (s *MinAggregator) New() AggregatorFunction {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MinAggregator) Add(v float64) {
|
||||
if m.first || v < m.value {
|
||||
m.value = v
|
||||
func (m *MinAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, math.MaxFloat64)
|
||||
if m.first || vv < m.value {
|
||||
m.value = vv
|
||||
m.first = false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MinAggregator) Result() float64 {
|
||||
func (m *MinAggregator) Result() interface{} {
|
||||
return m.value
|
||||
}
|
||||
|
||||
@ -203,22 +215,24 @@ func (m *MaxAggregator) New() AggregatorFunction {
|
||||
return &MaxAggregator{}
|
||||
}
|
||||
|
||||
func (m *MaxAggregator) Add(v float64) {
|
||||
if m.first || v > m.value {
|
||||
m.value = v
|
||||
func (m *MaxAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
if m.first || vv > m.value {
|
||||
m.value = vv
|
||||
m.first = false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MaxAggregator) Result() float64 {
|
||||
func (m *MaxAggregator) Result() interface{} {
|
||||
return m.value
|
||||
}
|
||||
|
||||
func (s *StdDevAggregator) Add(v float64) {
|
||||
s.values = append(s.values, v)
|
||||
func (s *StdDevAggregator) Add(v interface{}) {
|
||||
var vv float64 = ConvertToFloat64(v, 0)
|
||||
s.values = append(s.values, vv)
|
||||
}
|
||||
|
||||
func (s *StdDevAggregator) Result() float64 {
|
||||
func (s *StdDevAggregator) Result() interface{} {
|
||||
if len(s.values) < 2 {
|
||||
return 0
|
||||
}
|
||||
@ -230,7 +244,7 @@ func (s *StdDevAggregator) Result() float64 {
|
||||
return math.Sqrt(sum / float64(len(s.values)-1))
|
||||
}
|
||||
|
||||
func (p *PercentileAggregator) Result() float64 {
|
||||
func (p *PercentileAggregator) Result() interface{} {
|
||||
if len(p.values) == 0 {
|
||||
return 0
|
||||
}
|
||||
@ -246,3 +260,35 @@ func calculateAverage(values []float64) float64 {
|
||||
}
|
||||
return sum / float64(len(values))
|
||||
}
|
||||
|
||||
func ConvertToFloat64(v interface{}, defaultVal float64) float64 {
|
||||
var vv float64 = defaultVal
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
vv = val
|
||||
case float32:
|
||||
vv = float64(val)
|
||||
case int:
|
||||
vv = float64(val)
|
||||
case int32:
|
||||
vv = float64(val)
|
||||
case int64:
|
||||
vv = float64(val)
|
||||
case uint:
|
||||
vv = float64(val)
|
||||
case uint32:
|
||||
vv = float64(val)
|
||||
case uint64:
|
||||
vv = float64(val)
|
||||
case string:
|
||||
// 处理字符串类型的转换
|
||||
if floatValue, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
vv = floatValue
|
||||
} else {
|
||||
panic("unsupported type for sum aggregator")
|
||||
}
|
||||
default:
|
||||
panic("unsupported type for sum aggregator")
|
||||
}
|
||||
return vv
|
||||
}
|
||||
|
45
aggregator/context_aggregator.go
Normal file
45
aggregator/context_aggregator.go
Normal file
@ -0,0 +1,45 @@
|
||||
package aggregator
|
||||
|
||||
type ContextAggregator interface {
|
||||
GetContextKey() string
|
||||
}
|
||||
|
||||
type WindowStartAggregator struct {
|
||||
val interface{}
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) New() AggregatorFunction {
|
||||
return &WindowStartAggregator{}
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) Add(val interface{}) {
|
||||
w.val = val
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) Result() interface{} {
|
||||
return w.val
|
||||
}
|
||||
|
||||
func (w *WindowStartAggregator) GetContextKey() string {
|
||||
return "window_start"
|
||||
}
|
||||
|
||||
type WindowEndAggregator struct {
|
||||
val interface{}
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) New() AggregatorFunction {
|
||||
return &WindowEndAggregator{}
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) Add(val interface{}) {
|
||||
w.val = val
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) Result() interface{} {
|
||||
return w.val
|
||||
}
|
||||
|
||||
func (w *WindowEndAggregator) GetContextKey() string {
|
||||
return "window_end"
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
type Aggregator interface {
|
||||
Add(data interface{}) error
|
||||
Put(key string, val interface{}) error
|
||||
GetResults() ([]map[string]interface{}, error)
|
||||
Reset()
|
||||
}
|
||||
@ -19,9 +20,11 @@ type GroupAggregator struct {
|
||||
aggregators map[string]AggregatorFunction
|
||||
groups map[string]map[string]AggregatorFunction
|
||||
mu sync.RWMutex
|
||||
context map[string]interface{}
|
||||
fieldAlias map[string]string
|
||||
}
|
||||
|
||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator {
|
||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
|
||||
aggregators := make(map[string]AggregatorFunction)
|
||||
|
||||
for field, aggType := range fieldMap {
|
||||
@ -33,8 +36,20 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType)
|
||||
groupFields: groupFields,
|
||||
aggregators: aggregators,
|
||||
groups: make(map[string]map[string]AggregatorFunction),
|
||||
fieldAlias: fieldAlias,
|
||||
}
|
||||
}
|
||||
|
||||
func (ga *GroupAggregator) Put(key string, val interface{}) error {
|
||||
ga.mu.Lock() // 获取写锁
|
||||
defer ga.mu.Unlock() // 确保函数返回时释放锁
|
||||
if ga.context == nil {
|
||||
ga.context = make(map[string]interface{})
|
||||
}
|
||||
ga.context[key] = val
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
ga.mu.Lock() // 获取写锁
|
||||
defer ga.mu.Unlock() // 确保函数返回时释放锁
|
||||
@ -114,6 +129,19 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
||||
if !f.IsValid() {
|
||||
//return fmt.Errorf("field %s not found", field)
|
||||
//fmt.Printf("field %s not found in %v \n ", field, data)
|
||||
|
||||
// 尝试从context中获取
|
||||
if ga.context != nil {
|
||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
||||
if _, ok := groupAgg.(ContextAggregator); ok {
|
||||
key := groupAgg.(ContextAggregator).GetContextKey()
|
||||
if val, exists := ga.context[key]; exists {
|
||||
groupAgg.Add(val)
|
||||
//fmt.Printf("add agg group by %s:%s , %v \n", key, field, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@ -151,7 +179,20 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
|
||||
group[field] = fields[i]
|
||||
}
|
||||
for field, agg := range aggregators {
|
||||
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
|
||||
if _, ok := agg.(ContextAggregator); ok {
|
||||
if alias, ok := ga.fieldAlias[field]; ok {
|
||||
group[alias] = agg.Result()
|
||||
} else {
|
||||
group[field] = agg.Result()
|
||||
}
|
||||
} else {
|
||||
if alias, ok := ga.fieldAlias[field]; ok {
|
||||
group[alias] = agg.Result()
|
||||
} else {
|
||||
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
result = append(result, group)
|
||||
}
|
||||
|
@ -19,13 +19,17 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) {
|
||||
"Data1": Sum,
|
||||
"Data2": Sum,
|
||||
},
|
||||
map[string]string{
|
||||
"Data1": "Data1_sum",
|
||||
"Data2": "Data2_sum",
|
||||
},
|
||||
)
|
||||
|
||||
testData := []map[string]interface{}{
|
||||
{"device": "aa", "data1": 20, "data2": 30},
|
||||
{"device": "aa", "data1": 21, "data2": 0},
|
||||
{"device": "bb", "data1": 15, "data2": 20},
|
||||
{"device": "bb", "data1": 16, "data2": 20},
|
||||
{"Device": "aa", "Data1": 20, "Data2": 30},
|
||||
{"Device": "aa", "Data1": 21, "Data2": 0},
|
||||
{"Device": "bb", "Data1": 15, "Data2": 20},
|
||||
{"Device": "bb", "Data1": 16, "Data2": 20},
|
||||
}
|
||||
|
||||
for _, d := range testData {
|
||||
@ -47,6 +51,9 @@ func TestGroupAggregator_SingleField(t *testing.T) {
|
||||
map[string]AggregateType{
|
||||
"Data1": Sum,
|
||||
},
|
||||
map[string]string{
|
||||
"Data1": "Data1_sum",
|
||||
},
|
||||
)
|
||||
|
||||
testData := []map[string]interface{}{
|
||||
@ -75,6 +82,12 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
||||
"Data3": Max,
|
||||
"Data4": Min,
|
||||
},
|
||||
map[string]string{
|
||||
"Data1": "Data1_sum",
|
||||
"Data2": "Data2_avg",
|
||||
"Data3": "Data3_max",
|
||||
"Data4": "Data4_min",
|
||||
},
|
||||
)
|
||||
|
||||
testData := []map[string]interface{}{
|
||||
@ -88,7 +101,7 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"Device": "cc",
|
||||
"Device": "cc",
|
||||
"Data1_sum": 30.0,
|
||||
"Data2_avg": 5.0,
|
||||
"Data3_max": 12.0,
|
||||
|
@ -3,13 +3,14 @@ package model
|
||||
import (
|
||||
"time"
|
||||
|
||||
aggregator2 "github.com/rulego/streamsql/aggregator"
|
||||
"github.com/rulego/streamsql/aggregator"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WindowConfig WindowConfig
|
||||
GroupFields []string
|
||||
SelectFields map[string]aggregator2.AggregateType
|
||||
SelectFields map[string]aggregator.AggregateType
|
||||
FieldAlias map[string]string
|
||||
}
|
||||
type WindowConfig struct {
|
||||
Type string
|
||||
|
@ -50,3 +50,17 @@ func (ts *TimeSlot) GetEndTime() *time.Time {
|
||||
}
|
||||
return ts.End
|
||||
}
|
||||
|
||||
func (ts *TimeSlot) WindowStart() int64 {
|
||||
if ts == nil || ts.Start == nil {
|
||||
return 0
|
||||
}
|
||||
return ts.Start.UnixNano()
|
||||
}
|
||||
|
||||
func (ts *TimeSlot) WindowEnd() int64 {
|
||||
if ts == nil || ts.End == nil {
|
||||
return 0
|
||||
}
|
||||
return ts.End.UnixNano()
|
||||
}
|
||||
|
44
rsql/ast.go
44
rsql/ast.go
@ -22,6 +22,7 @@ type SelectStatement struct {
|
||||
type Field struct {
|
||||
Expression string
|
||||
Alias string
|
||||
AggType string
|
||||
}
|
||||
|
||||
type WindowDefinition struct {
|
||||
@ -52,7 +53,7 @@ func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("解析窗口参数失败: %w", err)
|
||||
}
|
||||
|
||||
aggs, fields := buildSelectFields(s.Fields)
|
||||
// 构建Stream配置
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
@ -62,7 +63,8 @@ func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
|
||||
TimeUnit: s.Window.TimeUnit,
|
||||
},
|
||||
GroupFields: extractGroupFields(s),
|
||||
SelectFields: buildSelectFields(s.Fields),
|
||||
SelectFields: aggs,
|
||||
FieldAlias: fields,
|
||||
}
|
||||
|
||||
return &config, s.Condition, nil
|
||||
@ -78,28 +80,50 @@ func extractGroupFields(s *SelectStatement) []string {
|
||||
return fields
|
||||
}
|
||||
|
||||
func buildSelectFields(fields []Field) map[string]aggregator.AggregateType {
|
||||
func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) {
|
||||
selectFields := make(map[string]aggregator.AggregateType)
|
||||
fieldMap = make(map[string]string)
|
||||
for _, f := range fields {
|
||||
if alias := f.Alias; alias != "" {
|
||||
selectFields[alias] = parseAggregateType(f.Expression)
|
||||
t, n := parseAggregateType(f.Expression)
|
||||
if n != "" {
|
||||
selectFields[n] = t
|
||||
fieldMap[n] = alias
|
||||
} else {
|
||||
selectFields[alias] = t
|
||||
}
|
||||
}
|
||||
}
|
||||
return selectFields
|
||||
return selectFields, fieldMap
|
||||
}
|
||||
|
||||
func parseAggregateType(expr string) aggregator.AggregateType {
|
||||
func parseAggregateType(expr string) (aggType aggregator.AggregateType, name string) {
|
||||
if strings.Contains(expr, "avg(") {
|
||||
return "avg"
|
||||
return "avg", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "sum(") {
|
||||
return "sum"
|
||||
return "sum", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "max(") {
|
||||
return "max"
|
||||
return "max", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "min(") {
|
||||
return "min"
|
||||
return "min", extractAggField(expr)
|
||||
}
|
||||
if strings.Contains(expr, "window_start(") {
|
||||
return "window_start", "window_start"
|
||||
}
|
||||
if strings.Contains(expr, "window_end(") {
|
||||
return "window_end", "window_end"
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func extractAggField(expr string) string {
|
||||
start := strings.Index(expr, "(")
|
||||
end := strings.LastIndex(expr, ")")
|
||||
if start >= 0 && end > start {
|
||||
return strings.TrimSpace(expr[start+1 : end])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
@ -44,8 +44,8 @@ func TestParseSQL(t *testing.T) {
|
||||
},
|
||||
GroupFields: []string{"type"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"max_score": "max",
|
||||
"min_age": "min",
|
||||
"score": "max",
|
||||
"age": "min",
|
||||
},
|
||||
},
|
||||
condition: "",
|
||||
|
@ -2,6 +2,7 @@ package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
aggregator2 "github.com/rulego/streamsql/aggregator"
|
||||
@ -34,6 +35,9 @@ func NewStream(config model.Config) (*Stream, error) {
|
||||
}
|
||||
|
||||
func (s *Stream) RegisterFilter(condition string) error {
|
||||
if strings.TrimSpace(condition) == "" {
|
||||
return nil
|
||||
}
|
||||
filter, err := parser.NewExprCondition(condition)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compile filter error: %w", err)
|
||||
@ -47,7 +51,7 @@ func (s *Stream) Start() {
|
||||
}
|
||||
|
||||
func (s *Stream) process() {
|
||||
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields)
|
||||
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias)
|
||||
|
||||
// 启动窗口处理协程
|
||||
s.Window.Start()
|
||||
@ -66,10 +70,13 @@ func (s *Stream) process() {
|
||||
case batch := <-s.Window.OutputChan():
|
||||
// 处理窗口批数据
|
||||
for _, item := range batch {
|
||||
if err := s.aggregator.Add(item); err != nil {
|
||||
s.aggregator.Put("window_start", item.Slot.WindowStart())
|
||||
s.aggregator.Put("window_end", item.Slot.WindowEnd())
|
||||
if err := s.aggregator.Add(item.Data); err != nil {
|
||||
fmt.Printf("aggregate error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取并发送聚合结果
|
||||
if results, err := s.aggregator.GetResults(); err == nil {
|
||||
// 发送结果到结果通道和 Sink 函数
|
||||
|
@ -223,3 +223,95 @@ func TestIncompleteStreamProcess(t *testing.T) {
|
||||
assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001)
|
||||
assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001)
|
||||
}
|
||||
|
||||
func TestWindowSlotAgg(t *testing.T) {
|
||||
config := model.Config{
|
||||
WindowConfig: model.WindowConfig{
|
||||
Type: "sliding",
|
||||
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
|
||||
TsProp: "ts",
|
||||
},
|
||||
GroupFields: []string{"device"},
|
||||
SelectFields: map[string]aggregator.AggregateType{
|
||||
"age": aggregator.Max,
|
||||
"score": aggregator.Min,
|
||||
"start": aggregator.WindowStart,
|
||||
"end": aggregator.WindowEnd,
|
||||
},
|
||||
}
|
||||
|
||||
strm, err := NewStream(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
strm.Start()
|
||||
// Add data every 500ms
|
||||
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
|
||||
|
||||
testData := []interface{}{
|
||||
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "ts": baseTime},
|
||||
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "ts": baseTime.Add(1 * time.Second)},
|
||||
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "ts": baseTime},
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
}
|
||||
|
||||
// 捕获结果
|
||||
resultChan := make(chan interface{})
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
// 等待 3 秒触发窗口
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var actual interface{}
|
||||
select {
|
||||
case actual = <-resultChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Timeout waiting for results")
|
||||
}
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"device": "aa",
|
||||
"age_max": 10.0,
|
||||
"score_min": 100.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
{
|
||||
"device": "bb",
|
||||
"age_max": 3.0,
|
||||
"score_min": 300.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
}
|
||||
|
||||
assert.IsType(t, []map[string]interface{}{}, actual)
|
||||
resultSlice, ok := actual.([]map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Len(t, resultSlice, 2)
|
||||
for _, expectedResult := range expected {
|
||||
found := false
|
||||
for _, resultMap := range resultSlice {
|
||||
//if resultMap, ok := result.(map[string]interface{}); ok {
|
||||
if resultMap["device"] == expectedResult["device"] {
|
||||
assert.InEpsilon(t, expectedResult["age_max"].(float64), resultMap["age_max"].(float64), 0.0001)
|
||||
assert.InEpsilon(t, expectedResult["score_min"].(float64), resultMap["score_min"].(float64), 0.0001)
|
||||
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
|
||||
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
|
||||
found = true
|
||||
break
|
||||
}
|
||||
//}
|
||||
}
|
||||
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
@ -1,13 +1,85 @@
|
||||
package streamsql
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamsql(t *testing.T) {
|
||||
streamsql := New()
|
||||
var rsql = ""
|
||||
var rsql = "SELECT device,max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream group by device,SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')"
|
||||
err := streamsql.Execute(rsql)
|
||||
assert.Nil(t, err)
|
||||
strm := streamsql.stream
|
||||
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
|
||||
testData := []interface{}{
|
||||
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime},
|
||||
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)},
|
||||
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime},
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
strm.AddData(data)
|
||||
}
|
||||
// 捕获结果
|
||||
resultChan := make(chan interface{})
|
||||
strm.AddSink(func(result interface{}) {
|
||||
resultChan <- result
|
||||
})
|
||||
// 等待 3 秒触发窗口
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var actual interface{}
|
||||
select {
|
||||
case actual = <-resultChan:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Timeout waiting for results")
|
||||
}
|
||||
|
||||
expected := []map[string]interface{}{
|
||||
{
|
||||
"device": "aa",
|
||||
"max_age": 10.0,
|
||||
"min_score": 100.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
{
|
||||
"device": "bb",
|
||||
"max_age": 3.0,
|
||||
"min_score": 300.0,
|
||||
"start": baseTime.UnixNano(),
|
||||
"end": baseTime.Add(2 * time.Second).UnixNano(),
|
||||
},
|
||||
}
|
||||
|
||||
assert.IsType(t, []map[string]interface{}{}, actual)
|
||||
resultSlice, ok := actual.([]map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, resultSlice, 2)
|
||||
for _, expectedResult := range expected {
|
||||
found := false
|
||||
for _, resultMap := range resultSlice {
|
||||
//if resultMap, ok := result.(map[string]interface{}); ok {
|
||||
if resultMap["device"] == expectedResult["device"] {
|
||||
assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001)
|
||||
assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001)
|
||||
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
|
||||
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
|
||||
found = true
|
||||
break
|
||||
}
|
||||
//}
|
||||
}
|
||||
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user