重构:修改别名映射

This commit is contained in:
dimon
2025-04-08 17:38:34 +08:00
parent 67c6a91dbb
commit b9b38565a0
11 changed files with 415 additions and 60 deletions

View File

@ -3,26 +3,29 @@ package aggregator
import (
"math"
"sort"
"strconv"
"sync"
)
type AggregateType string
const (
Sum AggregateType = "sum"
Count AggregateType = "count"
Avg AggregateType = "avg"
Max AggregateType = "max"
Min AggregateType = "min"
StdDev AggregateType = "stddev"
Median AggregateType = "median"
Percentile AggregateType = "percentile"
Sum AggregateType = "sum"
Count AggregateType = "count"
Avg AggregateType = "avg"
Max AggregateType = "max"
Min AggregateType = "min"
StdDev AggregateType = "stddev"
Median AggregateType = "median"
Percentile AggregateType = "percentile"
WindowStart AggregateType = "window_start"
WindowEnd AggregateType = "window_end"
)
type AggregatorFunction interface {
New() AggregatorFunction
Add(value float64)
Result() float64
Add(value interface{})
Result() interface{}
}
type SumAggregator struct {
@ -33,11 +36,12 @@ func (s *SumAggregator) New() AggregatorFunction {
return &SumAggregator{}
}
func (s *SumAggregator) Add(v float64) {
s.value += v
func (s *SumAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
s.value += vv
}
func (s *SumAggregator) Result() float64 {
func (s *SumAggregator) Result() interface{} {
return s.value
}
@ -49,11 +53,11 @@ func (s *CountAggregator) New() AggregatorFunction {
return &CountAggregator{}
}
func (c *CountAggregator) Add(_ float64) {
func (c *CountAggregator) Add(_ interface{}) {
c.count++
}
func (c *CountAggregator) Result() float64 {
func (c *CountAggregator) Result() interface{} {
return float64(c.count)
}
@ -66,12 +70,13 @@ func (a *AvgAggregator) New() AggregatorFunction {
return &AvgAggregator{}
}
func (a *AvgAggregator) Add(v float64) {
a.sum += v
func (a *AvgAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
a.sum += vv
a.count++
}
func (a *AvgAggregator) Result() float64 {
func (a *AvgAggregator) Result() interface{} {
if a.count == 0 {
return 0
}
@ -117,6 +122,10 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction {
return &MedianAggregator{}
case Percentile:
return &PercentileAggregator{p: 0.95}
case WindowStart:
return &WindowStartAggregator{}
case WindowEnd:
return &WindowEndAggregator{}
default:
panic("unsupported aggregator type: " + aggType)
}
@ -150,11 +159,12 @@ func (m *MedianAggregator) New() AggregatorFunction {
return &MedianAggregator{}
}
func (m *MedianAggregator) Add(val float64) {
m.values = append(m.values, val)
func (m *MedianAggregator) Add(val interface{}) {
var vv float64 = ConvertToFloat64(val, 0)
m.values = append(m.values, vv)
}
func (m *MedianAggregator) Result() float64 {
func (m *MedianAggregator) Result() interface{} {
sort.Float64s(m.values)
return m.values[len(m.values)/2]
}
@ -168,8 +178,9 @@ func (p *PercentileAggregator) New() AggregatorFunction {
return &PercentileAggregator{}
}
func (p *PercentileAggregator) Add(v float64) {
p.values = append(p.values, v)
func (p *PercentileAggregator) Add(v interface{}) {
vv := ConvertToFloat64(v, 0)
p.values = append(p.values, vv)
}
type MinAggregator struct {
@ -183,14 +194,15 @@ func (s *MinAggregator) New() AggregatorFunction {
}
}
func (m *MinAggregator) Add(v float64) {
if m.first || v < m.value {
m.value = v
func (m *MinAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, math.MaxFloat64)
if m.first || vv < m.value {
m.value = vv
m.first = false
}
}
func (m *MinAggregator) Result() float64 {
func (m *MinAggregator) Result() interface{} {
return m.value
}
@ -203,22 +215,24 @@ func (m *MaxAggregator) New() AggregatorFunction {
return &MaxAggregator{}
}
func (m *MaxAggregator) Add(v float64) {
if m.first || v > m.value {
m.value = v
func (m *MaxAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
if m.first || vv > m.value {
m.value = vv
m.first = false
}
}
func (m *MaxAggregator) Result() float64 {
func (m *MaxAggregator) Result() interface{} {
return m.value
}
func (s *StdDevAggregator) Add(v float64) {
s.values = append(s.values, v)
func (s *StdDevAggregator) Add(v interface{}) {
var vv float64 = ConvertToFloat64(v, 0)
s.values = append(s.values, vv)
}
func (s *StdDevAggregator) Result() float64 {
func (s *StdDevAggregator) Result() interface{} {
if len(s.values) < 2 {
return 0
}
@ -230,7 +244,7 @@ func (s *StdDevAggregator) Result() float64 {
return math.Sqrt(sum / float64(len(s.values)-1))
}
func (p *PercentileAggregator) Result() float64 {
func (p *PercentileAggregator) Result() interface{} {
if len(p.values) == 0 {
return 0
}
@ -246,3 +260,35 @@ func calculateAverage(values []float64) float64 {
}
return sum / float64(len(values))
}
// ConvertToFloat64 converts a dynamically typed value to float64.
//
// Accepted inputs are float32/float64, every built-in signed and unsigned
// integer width, and strings that strconv.ParseFloat can parse. It is shared
// by all numeric aggregators, so its panic messages describe the conversion
// rather than naming one particular aggregator.
//
// defaultVal is retained for signature compatibility. Every input either
// converts successfully or panics, so defaultVal never reaches the caller.
//
// It panics with a string value when v is a string that is not a valid
// float, or when v has any other unsupported type (including nil).
func ConvertToFloat64(v interface{}, defaultVal float64) float64 {
	switch val := v.(type) {
	case float64:
		return val
	case float32:
		return float64(val)
	case int:
		return float64(val)
	case int8:
		return float64(val)
	case int16:
		return float64(val)
	case int32:
		return float64(val)
	case int64:
		return float64(val)
	case uint:
		return float64(val)
	case uint8:
		return float64(val)
	case uint16:
		return float64(val)
	case uint32:
		return float64(val)
	case uint64:
		return float64(val)
	case string:
		// Numeric strings are accepted; anything else is a caller error.
		parsed, err := strconv.ParseFloat(val, 64)
		if err != nil {
			panic("cannot convert string " + strconv.Quote(val) + " to float64: " + err.Error())
		}
		return parsed
	default:
		panic("unsupported value type for float64 conversion")
	}
}

View File

@ -0,0 +1,45 @@
package aggregator
// ContextAggregator is implemented by aggregators that expose a context key
// through GetContextKey. In this file the window-start and window-end
// aggregators implement it, returning "window_start" and "window_end".
type ContextAggregator interface {
// GetContextKey returns the key identifying this aggregator's context value.
GetContextKey() string
}
// WindowStartAggregator remembers the most recent value passed to Add and
// returns it unchanged from Result. Its context key is "window_start".
type WindowStartAggregator struct {
	val interface{}
}

// New returns a fresh WindowStartAggregator with no stored value.
func (agg *WindowStartAggregator) New() AggregatorFunction {
	return new(WindowStartAggregator)
}

// Add replaces the stored value with val; earlier values are discarded.
func (agg *WindowStartAggregator) Add(val interface{}) {
	agg.val = val
}

// Result returns the last value given to Add, or nil if Add was never called.
func (agg *WindowStartAggregator) Result() interface{} {
	return agg.val
}

// GetContextKey returns the context key this aggregator reads: "window_start".
func (agg *WindowStartAggregator) GetContextKey() string {
	const contextKey = "window_start"
	return contextKey
}
// WindowEndAggregator remembers the most recent value passed to Add and
// returns it unchanged from Result. Its context key is "window_end".
type WindowEndAggregator struct {
	val interface{}
}

// New returns a fresh WindowEndAggregator with no stored value.
func (agg *WindowEndAggregator) New() AggregatorFunction {
	return new(WindowEndAggregator)
}

// Add replaces the stored value with val; earlier values are discarded.
func (agg *WindowEndAggregator) Add(val interface{}) {
	agg.val = val
}

// Result returns the last value given to Add, or nil if Add was never called.
func (agg *WindowEndAggregator) Result() interface{} {
	return agg.val
}

// GetContextKey returns the context key this aggregator reads: "window_end".
func (agg *WindowEndAggregator) GetContextKey() string {
	const contextKey = "window_end"
	return contextKey
}

View File

@ -9,6 +9,7 @@ import (
type Aggregator interface {
Add(data interface{}) error
Put(key string, val interface{}) error
GetResults() ([]map[string]interface{}, error)
Reset()
}
@ -19,9 +20,11 @@ type GroupAggregator struct {
aggregators map[string]AggregatorFunction
groups map[string]map[string]AggregatorFunction
mu sync.RWMutex
context map[string]interface{}
fieldAlias map[string]string
}
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType) *GroupAggregator {
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
aggregators := make(map[string]AggregatorFunction)
for field, aggType := range fieldMap {
@ -33,8 +36,20 @@ func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType)
groupFields: groupFields,
aggregators: aggregators,
groups: make(map[string]map[string]AggregatorFunction),
fieldAlias: fieldAlias,
}
}
// Put stores val under key in the aggregator's context map, creating the map
// on first use. The write lock is held for the whole call. It always returns
// nil.
func (ga *GroupAggregator) Put(key string, val interface{}) error {
	ga.mu.Lock()         // take the write lock
	defer ga.mu.Unlock() // release it when the function returns

	ctx := ga.context
	if ctx == nil {
		// Lazily allocate the map so a zero-value context is usable.
		ctx = make(map[string]interface{})
		ga.context = ctx
	}
	ctx[key] = val
	return nil
}
func (ga *GroupAggregator) Add(data interface{}) error {
ga.mu.Lock() // 获取写锁
defer ga.mu.Unlock() // 确保函数返回时释放锁
@ -114,6 +129,19 @@ func (ga *GroupAggregator) Add(data interface{}) error {
if !f.IsValid() {
//return fmt.Errorf("field %s not found", field)
//fmt.Printf("field %s not found in %v \n ", field, data)
// 尝试从context中获取
if ga.context != nil {
if groupAgg, exists := ga.groups[key][field]; exists {
if _, ok := groupAgg.(ContextAggregator); ok {
key := groupAgg.(ContextAggregator).GetContextKey()
if val, exists := ga.context[key]; exists {
groupAgg.Add(val)
//fmt.Printf("add agg group by %s:%s , %v \n", key, field, value)
}
}
}
}
continue
}
@ -151,7 +179,20 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
group[field] = fields[i]
}
for field, agg := range aggregators {
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
if _, ok := agg.(ContextAggregator); ok {
if alias, ok := ga.fieldAlias[field]; ok {
group[alias] = agg.Result()
} else {
group[field] = agg.Result()
}
} else {
if alias, ok := ga.fieldAlias[field]; ok {
group[alias] = agg.Result()
} else {
group[field+"_"+string(ga.fieldMap[field])] = agg.Result()
}
}
}
result = append(result, group)
}

View File

@ -19,13 +19,17 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) {
"Data1": Sum,
"Data2": Sum,
},
map[string]string{
"Data1": "Data1_sum",
"Data2": "Data2_sum",
},
)
testData := []map[string]interface{}{
{"device": "aa", "data1": 20, "data2": 30},
{"device": "aa", "data1": 21, "data2": 0},
{"device": "bb", "data1": 15, "data2": 20},
{"device": "bb", "data1": 16, "data2": 20},
{"Device": "aa", "Data1": 20, "Data2": 30},
{"Device": "aa", "Data1": 21, "Data2": 0},
{"Device": "bb", "Data1": 15, "Data2": 20},
{"Device": "bb", "Data1": 16, "Data2": 20},
}
for _, d := range testData {
@ -47,6 +51,9 @@ func TestGroupAggregator_SingleField(t *testing.T) {
map[string]AggregateType{
"Data1": Sum,
},
map[string]string{
"Data1": "Data1_sum",
},
)
testData := []map[string]interface{}{
@ -75,6 +82,12 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
"Data3": Max,
"Data4": Min,
},
map[string]string{
"Data1": "Data1_sum",
"Data2": "Data2_avg",
"Data3": "Data3_max",
"Data4": "Data4_min",
},
)
testData := []map[string]interface{}{
@ -88,7 +101,7 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
expected := []map[string]interface{}{
{
"Device": "cc",
"Device": "cc",
"Data1_sum": 30.0,
"Data2_avg": 5.0,
"Data3_max": 12.0,

View File

@ -3,13 +3,14 @@ package model
import (
"time"
aggregator2 "github.com/rulego/streamsql/aggregator"
"github.com/rulego/streamsql/aggregator"
)
type Config struct {
WindowConfig WindowConfig
GroupFields []string
SelectFields map[string]aggregator2.AggregateType
SelectFields map[string]aggregator.AggregateType
FieldAlias map[string]string
}
type WindowConfig struct {
Type string

View File

@ -50,3 +50,17 @@ func (ts *TimeSlot) GetEndTime() *time.Time {
}
return ts.End
}
// WindowStart returns the slot's Start time as Unix nanoseconds. It returns 0
// when the receiver or its Start field is nil.
func (ts *TimeSlot) WindowStart() int64 {
	if ts != nil && ts.Start != nil {
		return ts.Start.UnixNano()
	}
	return 0
}
// WindowEnd returns the slot's End time as Unix nanoseconds. It returns 0
// when the receiver or its End field is nil.
func (ts *TimeSlot) WindowEnd() int64 {
	if ts != nil && ts.End != nil {
		return ts.End.UnixNano()
	}
	return 0
}

View File

@ -22,6 +22,7 @@ type SelectStatement struct {
type Field struct {
Expression string
Alias string
AggType string
}
type WindowDefinition struct {
@ -52,7 +53,7 @@ func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
if err != nil {
return nil, "", fmt.Errorf("解析窗口参数失败: %w", err)
}
aggs, fields := buildSelectFields(s.Fields)
// 构建Stream配置
config := model.Config{
WindowConfig: model.WindowConfig{
@ -62,7 +63,8 @@ func (s *SelectStatement) ToStreamConfig() (*model.Config, string, error) {
TimeUnit: s.Window.TimeUnit,
},
GroupFields: extractGroupFields(s),
SelectFields: buildSelectFields(s.Fields),
SelectFields: aggs,
FieldAlias: fields,
}
return &config, s.Condition, nil
@ -78,28 +80,50 @@ func extractGroupFields(s *SelectStatement) []string {
return fields
}
func buildSelectFields(fields []Field) map[string]aggregator.AggregateType {
func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) {
selectFields := make(map[string]aggregator.AggregateType)
fieldMap = make(map[string]string)
for _, f := range fields {
if alias := f.Alias; alias != "" {
selectFields[alias] = parseAggregateType(f.Expression)
t, n := parseAggregateType(f.Expression)
if n != "" {
selectFields[n] = t
fieldMap[n] = alias
} else {
selectFields[alias] = t
}
}
}
return selectFields
return selectFields, fieldMap
}
func parseAggregateType(expr string) aggregator.AggregateType {
func parseAggregateType(expr string) (aggType aggregator.AggregateType, name string) {
if strings.Contains(expr, "avg(") {
return "avg"
return "avg", extractAggField(expr)
}
if strings.Contains(expr, "sum(") {
return "sum"
return "sum", extractAggField(expr)
}
if strings.Contains(expr, "max(") {
return "max"
return "max", extractAggField(expr)
}
if strings.Contains(expr, "min(") {
return "min"
return "min", extractAggField(expr)
}
if strings.Contains(expr, "window_start(") {
return "window_start", "window_start"
}
if strings.Contains(expr, "window_end(") {
return "window_end", "window_end"
}
return "", ""
}
// extractAggField returns the whitespace-trimmed text between the first "("
// and the last ")" of expr, e.g. "age" for "avg( age )". It returns "" when
// there is no "(" or when the last ")" does not come after it.
func extractAggField(expr string) string {
	open := strings.Index(expr, "(")
	if open < 0 {
		return ""
	}
	closing := strings.LastIndex(expr, ")")
	if closing <= open {
		return ""
	}
	inner := expr[open+1 : closing]
	return strings.TrimSpace(inner)
}

View File

@ -44,8 +44,8 @@ func TestParseSQL(t *testing.T) {
},
GroupFields: []string{"type"},
SelectFields: map[string]aggregator.AggregateType{
"max_score": "max",
"min_age": "min",
"score": "max",
"age": "min",
},
},
condition: "",

View File

@ -2,6 +2,7 @@ package stream
import (
"fmt"
"strings"
"time"
aggregator2 "github.com/rulego/streamsql/aggregator"
@ -34,6 +35,9 @@ func NewStream(config model.Config) (*Stream, error) {
}
func (s *Stream) RegisterFilter(condition string) error {
if strings.TrimSpace(condition) == "" {
return nil
}
filter, err := parser.NewExprCondition(condition)
if err != nil {
return fmt.Errorf("compile filter error: %w", err)
@ -47,7 +51,7 @@ func (s *Stream) Start() {
}
func (s *Stream) process() {
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields)
s.aggregator = aggregator2.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias)
// 启动窗口处理协程
s.Window.Start()
@ -66,10 +70,13 @@ func (s *Stream) process() {
case batch := <-s.Window.OutputChan():
// 处理窗口批数据
for _, item := range batch {
if err := s.aggregator.Add(item); err != nil {
s.aggregator.Put("window_start", item.Slot.WindowStart())
s.aggregator.Put("window_end", item.Slot.WindowEnd())
if err := s.aggregator.Add(item.Data); err != nil {
fmt.Printf("aggregate error: %v\n", err)
}
}
// 获取并发送聚合结果
if results, err := s.aggregator.GetResults(); err == nil {
// 发送结果到结果通道和 Sink 函数

View File

@ -223,3 +223,95 @@ func TestIncompleteStreamProcess(t *testing.T) {
assert.InEpsilon(t, expected["age_avg"].(float64), resultMap[0]["age_avg"].(float64), 0.0001)
assert.InDelta(t, expected["score_sum"].(float64), resultMap[0]["score_sum"].(float64), 0.0001)
}
// TestWindowSlotAgg runs a stream configured with a 2s/1s sliding window
// grouped by "device" and checks that the first emitted result contains, per
// device, the max "age", the min "score", and the window start/end values
// produced by the WindowStart and WindowEnd aggregate types.
func TestWindowSlotAgg(t *testing.T) {
config := model.Config{
WindowConfig: model.WindowConfig{
Type: "sliding",
Params: map[string]interface{}{"size": 2 * time.Second, "slide": 1 * time.Second},
TsProp: "ts",
},
GroupFields: []string{"device"},
SelectFields: map[string]aggregator.AggregateType{
"age": aggregator.Max,
"score": aggregator.Min,
"start": aggregator.WindowStart,
"end": aggregator.WindowEnd,
},
}
strm, err := NewStream(config)
require.NoError(t, err)
strm.Start()
// Feed all records back-to-back; each carries an explicit "ts" of either
// baseTime or baseTime+1s.
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
testData := []interface{}{
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "ts": baseTime},
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "ts": baseTime.Add(1 * time.Second)},
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "ts": baseTime},
}
for _, data := range testData {
strm.AddData(data)
}
// Capture results: the sink forwards every result to resultChan.
resultChan := make(chan interface{})
strm.AddSink(func(result interface{}) {
resultChan <- result
})
// Wait 3 seconds so the window can fire.
time.Sleep(3 * time.Second)
// Take the first result, or fail if none arrives within 10 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var actual interface{}
select {
case actual = <-resultChan:
cancel()
case <-ctx.Done():
t.Fatal("Timeout waiting for results")
}
// One expected row per device; start/end are Unix nanoseconds for
// baseTime and baseTime+2s.
expected := []map[string]interface{}{
{
"device": "aa",
"age_max": 10.0,
"score_min": 100.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
{
"device": "bb",
"age_max": 3.0,
"score_min": 300.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
}
assert.IsType(t, []map[string]interface{}{}, actual)
resultSlice, ok := actual.([]map[string]interface{})
require.True(t, ok)
assert.Len(t, resultSlice, 2)
// Match rows by device because the result order is not asserted.
for _, expectedResult := range expected {
found := false
for _, resultMap := range resultSlice {
//if resultMap, ok := result.(map[string]interface{}); ok {
if resultMap["device"] == expectedResult["device"] {
assert.InEpsilon(t, expectedResult["age_max"].(float64), resultMap["age_max"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["score_min"].(float64), resultMap["score_min"].(float64), 0.0001)
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
found = true
break
}
//}
}
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
}
}

View File

@ -1,13 +1,85 @@
package streamsql
import (
"github.com/stretchr/testify/assert"
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStreamsql executes a SQL statement that selects max(age), min(score),
// window_start() and window_end() with aliases over a 2s/1s sliding window
// grouped by device, then checks that the first emitted result uses the alias
// names as output keys and carries the expected values per device.
func TestStreamsql(t *testing.T) {
streamsql := New()
// NOTE(review): rsql is declared twice below. This listing is a diff
// rendering, so the empty declaration is presumably the removed line —
// confirm against the actual file.
var rsql = ""
var rsql = "SELECT device,max(age) as max_age,min(score) as min_score,window_start() as start,window_end() as end FROM stream group by device,SlidingWindow('2s','1s') with (TIMESTAMP='Ts',TIMEUNIT='ss')"
err := streamsql.Execute(rsql)
assert.Nil(t, err)
strm := streamsql.stream
// Each record carries an explicit "Ts" of either baseTime or baseTime+1s.
baseTime := time.Date(2025, 4, 7, 16, 46, 0, 0, time.UTC)
testData := []interface{}{
map[string]interface{}{"device": "aa", "age": 5.0, "score": 100, "Ts": baseTime},
map[string]interface{}{"device": "aa", "age": 10.0, "score": 200, "Ts": baseTime.Add(1 * time.Second)},
map[string]interface{}{"device": "bb", "age": 3.0, "score": 300, "Ts": baseTime},
}
for _, data := range testData {
strm.AddData(data)
}
// Capture results: the sink forwards every result to resultChan.
resultChan := make(chan interface{})
strm.AddSink(func(result interface{}) {
resultChan <- result
})
// Wait 3 seconds so the window can fire.
time.Sleep(3 * time.Second)
// Take the first result, or fail if none arrives within 10 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var actual interface{}
select {
case actual = <-resultChan:
cancel()
case <-ctx.Done():
t.Fatal("Timeout waiting for results")
}
// One expected row per device, keyed by the SQL aliases; start/end are
// Unix nanoseconds for baseTime and baseTime+2s.
expected := []map[string]interface{}{
{
"device": "aa",
"max_age": 10.0,
"min_score": 100.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
{
"device": "bb",
"max_age": 3.0,
"min_score": 300.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
}
assert.IsType(t, []map[string]interface{}{}, actual)
resultSlice, ok := actual.([]map[string]interface{})
require.True(t, ok)
assert.Len(t, resultSlice, 2)
// Match rows by device because the result order is not asserted.
for _, expectedResult := range expected {
found := false
for _, resultMap := range resultSlice {
//if resultMap, ok := result.(map[string]interface{}); ok {
if resultMap["device"] == expectedResult["device"] {
assert.InEpsilon(t, expectedResult["max_age"].(float64), resultMap["max_age"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["min_score"].(float64), resultMap["min_score"].(float64), 0.0001)
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
found = true
break
}
//}
}
assert.True(t, found, fmt.Sprintf("Expected result for device %v not found", expectedResult["device"]))
}
}