refactor:重构聚合器API

This commit is contained in:
rulego-team
2025-06-15 21:52:11 +08:00
parent 615935e667
commit e362565e4b
3 changed files with 143 additions and 110 deletions

View File

@ -19,14 +19,20 @@ type Aggregator interface {
RegisterExpression(field, expression string, fields []string, evaluator func(data interface{}) (interface{}, error)) RegisterExpression(field, expression string, fields []string, evaluator func(data interface{}) (interface{}, error))
} }
// AggregationField describes the configuration of a single aggregated field:
// which input field is read, which aggregation is applied to it, and under
// which name the result is stored.
type AggregationField struct {
// Name of the input field (e.g. "temperature"). The value "*" is treated
// by GroupAggregator.Add as count(*): 1 is added without reading a field.
InputField string // Input field name (e.g. "temperature").
AggregateType AggregateType // Aggregation type (e.g. Sum, Avg).
// When left empty, NewGroupAggregator and GroupAggregator.Add fall back
// to InputField as the output name.
OutputAlias string // Output alias (e.g. "temp_sum").
}
type GroupAggregator struct { type GroupAggregator struct {
fieldMap map[string]AggregateType aggregationFields []AggregationField
groupFields []string groupFields []string
aggregators map[string]AggregatorFunction aggregators map[string]AggregatorFunction
groups map[string]map[string]AggregatorFunction groups map[string]map[string]AggregatorFunction
mu sync.RWMutex mu sync.RWMutex
context map[string]interface{} context map[string]interface{}
fieldAlias map[string]string
// 表达式计算器 // 表达式计算器
expressions map[string]*ExpressionEvaluator expressions map[string]*ExpressionEvaluator
} }
@ -39,69 +45,25 @@ type ExpressionEvaluator struct {
evaluateFunc func(data interface{}) (interface{}, error) evaluateFunc func(data interface{}) (interface{}, error)
} }
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator { // NewGroupAggregator 创建新的分组聚合器
func NewGroupAggregator(groupFields []string, aggregationFields []AggregationField) *GroupAggregator {
aggregators := make(map[string]AggregatorFunction) aggregators := make(map[string]AggregatorFunction)
// 处理两种可能的调用模式: // 为每个聚合字段创建聚合器
// 1. SQL解析模式fieldMap是输出字段名->聚合类型fieldAlias是输出字段名->输入字段名 for _, field := range aggregationFields {
// 2. 直接测试模式fieldMap是输入字段名->聚合类型fieldAlias是输入字段名->输出字段名 if field.OutputAlias == "" {
// 如果没有指定别名,使用输入字段名
// 创建最终的映射 field.OutputAlias = field.InputField
finalFieldMap := make(map[string]AggregateType)
finalFieldAlias := make(map[string]string)
// 简化的检测逻辑:
// 在直接测试模式中fieldAlias 的值通常包含 "_sum", "_avg" 等后缀
// 在SQL解析模式中fieldAlias 的值是实际的数据字段名(如 "temperature"
isSQLMode := false
if len(fieldAlias) > 0 {
// 检查是否有任何 fieldAlias 的值看起来像 SQL 解析模式(不包含聚合后缀)
for _, aliasValue := range fieldAlias {
// 如果值不包含典型的聚合后缀可能是SQL模式
if !strings.Contains(aliasValue, "_sum") &&
!strings.Contains(aliasValue, "_avg") &&
!strings.Contains(aliasValue, "_min") &&
!strings.Contains(aliasValue, "_max") &&
!strings.Contains(aliasValue, "_count") {
isSQLMode = true
break
}
} }
} aggregators[field.OutputAlias] = CreateBuiltinAggregator(field.AggregateType)
if isSQLMode {
// SQL解析模式fieldMap是输出字段名->聚合类型fieldAlias是输出字段名->输入字段名
finalFieldMap = fieldMap
finalFieldAlias = fieldAlias
} else {
// 直接测试模式fieldMap是输入字段名->聚合类型fieldAlias是输入字段名->输出字段名
for inputField, aggType := range fieldMap {
outputField := inputField // 默认输出字段名等于输入字段名
// fieldAlias提供了输入字段名 -> 输出别名的映射
if alias, exists := fieldAlias[inputField]; exists {
outputField = alias
}
finalFieldMap[outputField] = aggType
finalFieldAlias[outputField] = inputField
}
}
// 创建聚合器
for outputField := range finalFieldMap {
aggregators[outputField] = CreateBuiltinAggregator(finalFieldMap[outputField])
} }
return &GroupAggregator{ return &GroupAggregator{
fieldMap: finalFieldMap, // 输出字段名 -> 聚合类型 aggregationFields: aggregationFields,
groupFields: groupFields, groupFields: groupFields,
aggregators: aggregators, aggregators: aggregators,
groups: make(map[string]map[string]AggregatorFunction), groups: make(map[string]map[string]AggregatorFunction),
fieldAlias: finalFieldAlias, // 输出字段名 -> 输入字段名 expressions: make(map[string]*ExpressionEvaluator),
expressions: make(map[string]*ExpressionEvaluator),
} }
} }
@ -224,38 +186,38 @@ func (ga *GroupAggregator) Add(data interface{}) error {
} }
// 为每个字段创建聚合器实例 // 为每个字段创建聚合器实例
for field, agg := range ga.aggregators { for outputAlias, agg := range ga.aggregators {
if _, exists := ga.groups[key][field]; !exists { if _, exists := ga.groups[key][outputAlias]; !exists {
ga.groups[key][field] = agg.New() ga.groups[key][outputAlias] = agg.New()
} }
} }
for field := range ga.fieldMap { // 处理每个聚合字段
for _, aggField := range ga.aggregationFields {
outputAlias := aggField.OutputAlias
if outputAlias == "" {
outputAlias = aggField.InputField
}
// 检查是否有表达式计算器 // 检查是否有表达式计算器
if expr, hasExpr := ga.expressions[field]; hasExpr { if expr, hasExpr := ga.expressions[outputAlias]; hasExpr {
result, err := expr.evaluateFunc(data) result, err := expr.evaluateFunc(data)
if err != nil { if err != nil {
continue continue
} }
if groupAgg, exists := ga.groups[key][field]; exists { if groupAgg, exists := ga.groups[key][outputAlias]; exists {
groupAgg.Add(result) groupAgg.Add(result)
} }
continue continue
} }
// 获取实际的输入字段名 inputField := aggField.InputField
// field现在是输出字段名可能是别名需要找到对应的输入字段名
inputFieldName := field
// 在聚合器内部fieldAlias的映射方向是输出字段名 -> 输入字段名
if mappedField, exists := ga.fieldAlias[field]; exists {
inputFieldName = mappedField
}
// 特殊处理count(*)的情况 // 特殊处理count(*)的情况
if inputFieldName == "*" { if inputField == "*" {
// 对于count(*)直接添加1不需要获取具体字段值 // 对于count(*)直接添加1不需要获取具体字段值
if groupAgg, exists := ga.groups[key][field]; exists { if groupAgg, exists := ga.groups[key][outputAlias]; exists {
groupAgg.Add(1) groupAgg.Add(1)
} }
continue continue
@ -264,16 +226,16 @@ func (ga *GroupAggregator) Add(data interface{}) error {
// 获取字段值 // 获取字段值
var f reflect.Value var f reflect.Value
if v.Kind() == reflect.Map { if v.Kind() == reflect.Map {
keyVal := reflect.ValueOf(inputFieldName) keyVal := reflect.ValueOf(inputField)
f = v.MapIndex(keyVal) f = v.MapIndex(keyVal)
} else { } else {
f = v.FieldByName(inputFieldName) f = v.FieldByName(inputField)
} }
if !f.IsValid() { if !f.IsValid() {
// 尝试从context中获取 // 尝试从context中获取
if ga.context != nil { if ga.context != nil {
if groupAgg, exists := ga.groups[key][field]; exists { if groupAgg, exists := ga.groups[key][outputAlias]; exists {
if contextAgg, ok := groupAgg.(ContextAggregator); ok { if contextAgg, ok := groupAgg.(ContextAggregator); ok {
contextKey := contextAgg.GetContextKey() contextKey := contextAgg.GetContextKey()
if val, exists := ga.context[contextKey]; exists { if val, exists := ga.context[contextKey]; exists {
@ -286,21 +248,21 @@ func (ga *GroupAggregator) Add(data interface{}) error {
} }
fieldVal := f.Interface() fieldVal := f.Interface()
aggType := ga.fieldMap[field] aggType := aggField.AggregateType
// 动态检查是否需要数值转换 // 动态检查是否需要数值转换
if ga.isNumericAggregator(aggType) { if ga.isNumericAggregator(aggType) {
// 对于数值聚合函数,尝试转换为数值类型 // 对于数值聚合函数,尝试转换为数值类型
if numVal, err := cast.ToFloat64E(fieldVal); err == nil { if numVal, err := cast.ToFloat64E(fieldVal); err == nil {
if groupAgg, exists := ga.groups[key][field]; exists { if groupAgg, exists := ga.groups[key][outputAlias]; exists {
groupAgg.Add(numVal) groupAgg.Add(numVal)
} }
} else { } else {
return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputFieldName, fieldVal, aggType) return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputField, fieldVal, aggType)
} }
} else { } else {
// 对于非数值聚合函数,直接传递原始值 // 对于非数值聚合函数,直接传递原始值
if groupAgg, exists := ga.groups[key][field]; exists { if groupAgg, exists := ga.groups[key][outputAlias]; exists {
groupAgg.Add(fieldVal) groupAgg.Add(fieldVal)
} }
} }

View File

@ -15,13 +15,17 @@ type testData struct {
func TestGroupAggregator_MultiFieldSum(t *testing.T) { func TestGroupAggregator_MultiFieldSum(t *testing.T) {
agg := NewGroupAggregator( agg := NewGroupAggregator(
[]string{"Device"}, []string{"Device"},
map[string]AggregateType{ []AggregationField{
"temperature": Sum, {
"humidity": Sum, InputField: "temperature",
}, AggregateType: Sum,
map[string]string{ OutputAlias: "temperature_sum",
"temperature": "temperature_sum", },
"humidity": "humidity_sum", {
InputField: "humidity",
AggregateType: Sum,
OutputAlias: "humidity_sum",
},
}, },
) )
@ -48,11 +52,12 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) {
func TestGroupAggregator_SingleField(t *testing.T) { func TestGroupAggregator_SingleField(t *testing.T) {
agg := NewGroupAggregator( agg := NewGroupAggregator(
[]string{"Device"}, []string{"Device"},
map[string]AggregateType{ []AggregationField{
"temperature": Sum, {
}, InputField: "temperature",
map[string]string{ AggregateType: Sum,
"temperature": "temperature_sum", OutputAlias: "temperature_sum",
},
}, },
) )
@ -76,17 +81,27 @@ func TestGroupAggregator_SingleField(t *testing.T) {
func TestGroupAggregator_MultipleAggregators(t *testing.T) { func TestGroupAggregator_MultipleAggregators(t *testing.T) {
agg := NewGroupAggregator( agg := NewGroupAggregator(
[]string{"Device"}, []string{"Device"},
map[string]AggregateType{ []AggregationField{
"temperature": Sum, {
"humidity": Avg, InputField: "temperature",
"presure": Max, AggregateType: Sum,
"PM10": Min, OutputAlias: "temperature_sum",
}, },
map[string]string{ {
"temperature": "temperature_sum", InputField: "humidity",
"humidity": "humidity_avg", AggregateType: Avg,
"presure": "presure_max", OutputAlias: "humidity_avg",
"PM10": "PM10_min", },
{
InputField: "presure",
AggregateType: Max,
OutputAlias: "presure_max",
},
{
InputField: "PM10",
AggregateType: Min,
OutputAlias: "PM10_min",
},
}, },
) )
@ -112,3 +127,33 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
results, _ := agg.GetResults() results, _ := agg.GetResults()
assert.ElementsMatch(t, expected, results) assert.ElementsMatch(t, expected, results)
} }
// TestGroupAggregator_NoAlias verifies that an AggregationField without an
// OutputAlias is reported under its InputField name in the results.
func TestGroupAggregator_NoAlias(t *testing.T) {
	agg := NewGroupAggregator(
		[]string{"Device"},
		[]AggregationField{
			{
				InputField:    "temperature",
				AggregateType: Sum,
				// OutputAlias is intentionally left empty; the aggregator
				// should fall back to InputField.
			},
		},
	)

	// Named "rows" rather than "testData" so the local does not shadow the
	// package-level testData type declared in this test file.
	rows := []map[string]interface{}{
		{"Device": "dd", "temperature": 10.0},
		{"Device": "dd", "temperature": 15.0},
	}
	for _, row := range rows {
		// Add returns an error (e.g. on a failed numeric conversion); fail
		// here instead of letting it surface as a confusing result mismatch.
		assert.NoError(t, agg.Add(row))
	}

	expected := []map[string]interface{}{
		{"Device": "dd", "temperature": 25.0},
	}

	// NOTE(review): the second return value is still ignored as in the
	// original; its type is not visible here — check it if it is an error.
	results, _ := agg.GetResults()
	assert.ElementsMatch(t, expected, results)
}

View File

@ -268,6 +268,30 @@ func (s *Stream) RegisterFilter(conditionStr string) error {
return nil return nil
} }
// convertToAggregationFields converts the legacy map-based configuration into
// the AggregationField slice expected by aggregator.NewGroupAggregator.
//
// selectFields maps an output alias to its aggregation type. fieldAlias maps
// an output alias to the input field it reads; when an alias has no entry in
// fieldAlias, the input field name is taken to be the alias itself.
//
// The returned slice has one entry per key of selectFields, in map iteration
// order (i.e. unspecified). It is nil when selectFields is empty.
func convertToAggregationFields(selectFields map[string]aggregator.AggregateType, fieldAlias map[string]string) []aggregator.AggregationField {
	if len(selectFields) == 0 {
		return nil
	}

	// The final length is known up front, so allocate once instead of
	// growing the slice through repeated appends.
	fields := make([]aggregator.AggregationField, 0, len(selectFields))
	for outputAlias, aggType := range selectFields {
		inputField, ok := fieldAlias[outputAlias]
		if !ok {
			// No alias mapping: assume the input field is named like the output.
			inputField = outputAlias
		}
		fields = append(fields, aggregator.AggregationField{
			InputField:    inputField,
			AggregateType: aggType,
			OutputAlias:   outputAlias,
		})
	}
	return fields
}
func (s *Stream) Start() { func (s *Stream) Start() {
// 启动处理协程 // 启动处理协程
go s.process() go s.process()
@ -276,7 +300,9 @@ func (s *Stream) Start() {
func (s *Stream) process() { func (s *Stream) process() {
// 初始化聚合器,用于窗口模式 // 初始化聚合器,用于窗口模式
if s.config.NeedWindow { if s.config.NeedWindow {
s.aggregator = aggregator.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias) // 转换为新的AggregationField格式
aggregationFields := convertToAggregationFields(s.config.SelectFields, s.config.FieldAlias)
s.aggregator = aggregator.NewGroupAggregator(s.config.GroupFields, aggregationFields)
// 为表达式字段创建计算器 // 为表达式字段创建计算器
for field, fieldExpr := range s.config.FieldExpressions { for field, fieldExpr := range s.config.FieldExpressions {