mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-13 11:03:50 +00:00
refactor:重构聚合器API
This commit is contained in:
@ -19,14 +19,20 @@ type Aggregator interface {
|
|||||||
RegisterExpression(field, expression string, fields []string, evaluator func(data interface{}) (interface{}, error))
|
RegisterExpression(field, expression string, fields []string, evaluator func(data interface{}) (interface{}, error))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AggregationField describes the configuration of a single aggregation field.
|
||||||
|
type AggregationField struct {
|
||||||
|
InputField string // input field name (e.g. "temperature")
|
||||||
|
AggregateType AggregateType // aggregation type (e.g. Sum, Avg)
|
||||||
|
OutputAlias string // output alias (e.g. "temp_sum"); when empty, InputField is used as the output name
|
||||||
|
}
|
||||||
|
|
||||||
type GroupAggregator struct {
|
type GroupAggregator struct {
|
||||||
fieldMap map[string]AggregateType
|
aggregationFields []AggregationField
|
||||||
groupFields []string
|
groupFields []string
|
||||||
aggregators map[string]AggregatorFunction
|
aggregators map[string]AggregatorFunction
|
||||||
groups map[string]map[string]AggregatorFunction
|
groups map[string]map[string]AggregatorFunction
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
context map[string]interface{}
|
context map[string]interface{}
|
||||||
fieldAlias map[string]string
|
|
||||||
// 表达式计算器
|
// 表达式计算器
|
||||||
expressions map[string]*ExpressionEvaluator
|
expressions map[string]*ExpressionEvaluator
|
||||||
}
|
}
|
||||||
@ -39,69 +45,25 @@ type ExpressionEvaluator struct {
|
|||||||
evaluateFunc func(data interface{}) (interface{}, error)
|
evaluateFunc func(data interface{}) (interface{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator {
|
// NewGroupAggregator 创建新的分组聚合器
|
||||||
|
func NewGroupAggregator(groupFields []string, aggregationFields []AggregationField) *GroupAggregator {
|
||||||
aggregators := make(map[string]AggregatorFunction)
|
aggregators := make(map[string]AggregatorFunction)
|
||||||
|
|
||||||
// 处理两种可能的调用模式:
|
// 为每个聚合字段创建聚合器
|
||||||
// 1. SQL解析模式:fieldMap是输出字段名->聚合类型,fieldAlias是输出字段名->输入字段名
|
for _, field := range aggregationFields {
|
||||||
// 2. 直接测试模式:fieldMap是输入字段名->聚合类型,fieldAlias是输入字段名->输出字段名
|
if field.OutputAlias == "" {
|
||||||
|
// 如果没有指定别名,使用输入字段名
|
||||||
// 创建最终的映射
|
field.OutputAlias = field.InputField
|
||||||
finalFieldMap := make(map[string]AggregateType)
|
|
||||||
finalFieldAlias := make(map[string]string)
|
|
||||||
|
|
||||||
// 简化的检测逻辑:
|
|
||||||
// 在直接测试模式中,fieldAlias 的值通常包含 "_sum", "_avg" 等后缀
|
|
||||||
// 在SQL解析模式中,fieldAlias 的值是实际的数据字段名(如 "temperature")
|
|
||||||
|
|
||||||
isSQLMode := false
|
|
||||||
if len(fieldAlias) > 0 {
|
|
||||||
// 检查是否有任何 fieldAlias 的值看起来像 SQL 解析模式(不包含聚合后缀)
|
|
||||||
for _, aliasValue := range fieldAlias {
|
|
||||||
// 如果值不包含典型的聚合后缀,可能是SQL模式
|
|
||||||
if !strings.Contains(aliasValue, "_sum") &&
|
|
||||||
!strings.Contains(aliasValue, "_avg") &&
|
|
||||||
!strings.Contains(aliasValue, "_min") &&
|
|
||||||
!strings.Contains(aliasValue, "_max") &&
|
|
||||||
!strings.Contains(aliasValue, "_count") {
|
|
||||||
isSQLMode = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
aggregators[field.OutputAlias] = CreateBuiltinAggregator(field.AggregateType)
|
||||||
|
|
||||||
if isSQLMode {
|
|
||||||
// SQL解析模式:fieldMap是输出字段名->聚合类型,fieldAlias是输出字段名->输入字段名
|
|
||||||
finalFieldMap = fieldMap
|
|
||||||
finalFieldAlias = fieldAlias
|
|
||||||
} else {
|
|
||||||
// 直接测试模式:fieldMap是输入字段名->聚合类型,fieldAlias是输入字段名->输出字段名
|
|
||||||
for inputField, aggType := range fieldMap {
|
|
||||||
outputField := inputField // 默认输出字段名等于输入字段名
|
|
||||||
|
|
||||||
// fieldAlias提供了:输入字段名 -> 输出别名的映射
|
|
||||||
if alias, exists := fieldAlias[inputField]; exists {
|
|
||||||
outputField = alias
|
|
||||||
}
|
|
||||||
|
|
||||||
finalFieldMap[outputField] = aggType
|
|
||||||
finalFieldAlias[outputField] = inputField
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建聚合器
|
|
||||||
for outputField := range finalFieldMap {
|
|
||||||
aggregators[outputField] = CreateBuiltinAggregator(finalFieldMap[outputField])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GroupAggregator{
|
return &GroupAggregator{
|
||||||
fieldMap: finalFieldMap, // 输出字段名 -> 聚合类型
|
aggregationFields: aggregationFields,
|
||||||
groupFields: groupFields,
|
groupFields: groupFields,
|
||||||
aggregators: aggregators,
|
aggregators: aggregators,
|
||||||
groups: make(map[string]map[string]AggregatorFunction),
|
groups: make(map[string]map[string]AggregatorFunction),
|
||||||
fieldAlias: finalFieldAlias, // 输出字段名 -> 输入字段名
|
expressions: make(map[string]*ExpressionEvaluator),
|
||||||
expressions: make(map[string]*ExpressionEvaluator),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,38 +186,38 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 为每个字段创建聚合器实例
|
// 为每个字段创建聚合器实例
|
||||||
for field, agg := range ga.aggregators {
|
for outputAlias, agg := range ga.aggregators {
|
||||||
if _, exists := ga.groups[key][field]; !exists {
|
if _, exists := ga.groups[key][outputAlias]; !exists {
|
||||||
ga.groups[key][field] = agg.New()
|
ga.groups[key][outputAlias] = agg.New()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for field := range ga.fieldMap {
|
// 处理每个聚合字段
|
||||||
|
for _, aggField := range ga.aggregationFields {
|
||||||
|
outputAlias := aggField.OutputAlias
|
||||||
|
if outputAlias == "" {
|
||||||
|
outputAlias = aggField.InputField
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否有表达式计算器
|
// 检查是否有表达式计算器
|
||||||
if expr, hasExpr := ga.expressions[field]; hasExpr {
|
if expr, hasExpr := ga.expressions[outputAlias]; hasExpr {
|
||||||
result, err := expr.evaluateFunc(data)
|
result, err := expr.evaluateFunc(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
||||||
groupAgg.Add(result)
|
groupAgg.Add(result)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取实际的输入字段名
|
inputField := aggField.InputField
|
||||||
// field现在是输出字段名(可能是别名),需要找到对应的输入字段名
|
|
||||||
inputFieldName := field
|
|
||||||
// 在聚合器内部,fieldAlias的映射方向是:输出字段名 -> 输入字段名
|
|
||||||
if mappedField, exists := ga.fieldAlias[field]; exists {
|
|
||||||
inputFieldName = mappedField
|
|
||||||
}
|
|
||||||
|
|
||||||
// 特殊处理count(*)的情况
|
// 特殊处理count(*)的情况
|
||||||
if inputFieldName == "*" {
|
if inputField == "*" {
|
||||||
// 对于count(*),直接添加1,不需要获取具体字段值
|
// 对于count(*),直接添加1,不需要获取具体字段值
|
||||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
||||||
groupAgg.Add(1)
|
groupAgg.Add(1)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@ -264,16 +226,16 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
|||||||
// 获取字段值
|
// 获取字段值
|
||||||
var f reflect.Value
|
var f reflect.Value
|
||||||
if v.Kind() == reflect.Map {
|
if v.Kind() == reflect.Map {
|
||||||
keyVal := reflect.ValueOf(inputFieldName)
|
keyVal := reflect.ValueOf(inputField)
|
||||||
f = v.MapIndex(keyVal)
|
f = v.MapIndex(keyVal)
|
||||||
} else {
|
} else {
|
||||||
f = v.FieldByName(inputFieldName)
|
f = v.FieldByName(inputField)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !f.IsValid() {
|
if !f.IsValid() {
|
||||||
// 尝试从context中获取
|
// 尝试从context中获取
|
||||||
if ga.context != nil {
|
if ga.context != nil {
|
||||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
||||||
if contextAgg, ok := groupAgg.(ContextAggregator); ok {
|
if contextAgg, ok := groupAgg.(ContextAggregator); ok {
|
||||||
contextKey := contextAgg.GetContextKey()
|
contextKey := contextAgg.GetContextKey()
|
||||||
if val, exists := ga.context[contextKey]; exists {
|
if val, exists := ga.context[contextKey]; exists {
|
||||||
@ -286,21 +248,21 @@ func (ga *GroupAggregator) Add(data interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fieldVal := f.Interface()
|
fieldVal := f.Interface()
|
||||||
aggType := ga.fieldMap[field]
|
aggType := aggField.AggregateType
|
||||||
|
|
||||||
// 动态检查是否需要数值转换
|
// 动态检查是否需要数值转换
|
||||||
if ga.isNumericAggregator(aggType) {
|
if ga.isNumericAggregator(aggType) {
|
||||||
// 对于数值聚合函数,尝试转换为数值类型
|
// 对于数值聚合函数,尝试转换为数值类型
|
||||||
if numVal, err := cast.ToFloat64E(fieldVal); err == nil {
|
if numVal, err := cast.ToFloat64E(fieldVal); err == nil {
|
||||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
||||||
groupAgg.Add(numVal)
|
groupAgg.Add(numVal)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputFieldName, fieldVal, aggType)
|
return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputField, fieldVal, aggType)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 对于非数值聚合函数,直接传递原始值
|
// 对于非数值聚合函数,直接传递原始值
|
||||||
if groupAgg, exists := ga.groups[key][field]; exists {
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
||||||
groupAgg.Add(fieldVal)
|
groupAgg.Add(fieldVal)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,13 +15,17 @@ type testData struct {
|
|||||||
func TestGroupAggregator_MultiFieldSum(t *testing.T) {
|
func TestGroupAggregator_MultiFieldSum(t *testing.T) {
|
||||||
agg := NewGroupAggregator(
|
agg := NewGroupAggregator(
|
||||||
[]string{"Device"},
|
[]string{"Device"},
|
||||||
map[string]AggregateType{
|
[]AggregationField{
|
||||||
"temperature": Sum,
|
{
|
||||||
"humidity": Sum,
|
InputField: "temperature",
|
||||||
},
|
AggregateType: Sum,
|
||||||
map[string]string{
|
OutputAlias: "temperature_sum",
|
||||||
"temperature": "temperature_sum",
|
},
|
||||||
"humidity": "humidity_sum",
|
{
|
||||||
|
InputField: "humidity",
|
||||||
|
AggregateType: Sum,
|
||||||
|
OutputAlias: "humidity_sum",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -48,11 +52,12 @@ func TestGroupAggregator_MultiFieldSum(t *testing.T) {
|
|||||||
func TestGroupAggregator_SingleField(t *testing.T) {
|
func TestGroupAggregator_SingleField(t *testing.T) {
|
||||||
agg := NewGroupAggregator(
|
agg := NewGroupAggregator(
|
||||||
[]string{"Device"},
|
[]string{"Device"},
|
||||||
map[string]AggregateType{
|
[]AggregationField{
|
||||||
"temperature": Sum,
|
{
|
||||||
},
|
InputField: "temperature",
|
||||||
map[string]string{
|
AggregateType: Sum,
|
||||||
"temperature": "temperature_sum",
|
OutputAlias: "temperature_sum",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -76,17 +81,27 @@ func TestGroupAggregator_SingleField(t *testing.T) {
|
|||||||
func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
||||||
agg := NewGroupAggregator(
|
agg := NewGroupAggregator(
|
||||||
[]string{"Device"},
|
[]string{"Device"},
|
||||||
map[string]AggregateType{
|
[]AggregationField{
|
||||||
"temperature": Sum,
|
{
|
||||||
"humidity": Avg,
|
InputField: "temperature",
|
||||||
"presure": Max,
|
AggregateType: Sum,
|
||||||
"PM10": Min,
|
OutputAlias: "temperature_sum",
|
||||||
},
|
},
|
||||||
map[string]string{
|
{
|
||||||
"temperature": "temperature_sum",
|
InputField: "humidity",
|
||||||
"humidity": "humidity_avg",
|
AggregateType: Avg,
|
||||||
"presure": "presure_max",
|
OutputAlias: "humidity_avg",
|
||||||
"PM10": "PM10_min",
|
},
|
||||||
|
{
|
||||||
|
InputField: "presure",
|
||||||
|
AggregateType: Max,
|
||||||
|
OutputAlias: "presure_max",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
InputField: "PM10",
|
||||||
|
AggregateType: Min,
|
||||||
|
OutputAlias: "PM10_min",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -112,3 +127,33 @@ func TestGroupAggregator_MultipleAggregators(t *testing.T) {
|
|||||||
results, _ := agg.GetResults()
|
results, _ := agg.GetResults()
|
||||||
assert.ElementsMatch(t, expected, results)
|
assert.ElementsMatch(t, expected, results)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGroupAggregator_NoAlias checks that an aggregation field declared
// without an OutputAlias is reported under its InputField name.
func TestGroupAggregator_NoAlias(t *testing.T) {
|
||||||
|
// With no alias specified, the input field name should be used as the output field name.
|
||||||
|
agg := NewGroupAggregator(
|
||||||
|
[]string{"Device"},
|
||||||
|
[]AggregationField{
|
||||||
|
{
|
||||||
|
InputField: "temperature",
|
||||||
|
AggregateType: Sum,
|
||||||
|
// OutputAlias is left empty on purpose; InputField should be used instead.
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
testData := []map[string]interface{}{
|
||||||
|
{"Device": "dd", "temperature": 10.0},
|
||||||
|
{"Device": "dd", "temperature": 15.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range testData {
|
||||||
|
agg.Add(d) // NOTE(review): the error returned by Add is not checked here
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both rows share Device "dd", so a single group with 10.0 + 15.0 is expected.
expected := []map[string]interface{}{
|
||||||
|
{"Device": "dd", "temperature": 25.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
results, _ := agg.GetResults() // NOTE(review): second return value is discarded
|
||||||
|
assert.ElementsMatch(t, expected, results)
|
||||||
|
}
|
||||||
|
@ -268,6 +268,30 @@ func (s *Stream) RegisterFilter(conditionStr string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convertToAggregationFields converts the legacy map-based configuration into
// the AggregationField slice format.
//
// selectFields is keyed by output alias and holds the aggregate type for that
// output. fieldAlias is consulted with the same output alias to find the input
// field name. One AggregationField is produced per selectFields entry.
//
// The result is built by ranging over a map, so the order of the returned
// slice is unspecified. When selectFields is empty the returned slice is nil.
|
||||||
|
func convertToAggregationFields(selectFields map[string]aggregator.AggregateType, fieldAlias map[string]string) []aggregator.AggregationField {
|
||||||
|
var fields []aggregator.AggregationField
|
||||||
|
|
||||||
|
for outputAlias, aggType := range selectFields {
|
||||||
|
field := aggregator.AggregationField{
|
||||||
|
AggregateType: aggType,
|
||||||
|
OutputAlias: outputAlias,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the input field name that corresponds to this output alias.
|
||||||
|
if inputField, exists := fieldAlias[outputAlias]; exists {
|
||||||
|
field.InputField = inputField
|
||||||
|
} else {
|
||||||
|
// No alias mapping: assume the input field name equals the output alias.
|
||||||
|
field.InputField = outputAlias
|
||||||
|
}
|
||||||
|
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Stream) Start() {
|
func (s *Stream) Start() {
|
||||||
// 启动处理协程
|
// 启动处理协程
|
||||||
go s.process()
|
go s.process()
|
||||||
@ -276,7 +300,9 @@ func (s *Stream) Start() {
|
|||||||
func (s *Stream) process() {
|
func (s *Stream) process() {
|
||||||
// 初始化聚合器,用于窗口模式
|
// 初始化聚合器,用于窗口模式
|
||||||
if s.config.NeedWindow {
|
if s.config.NeedWindow {
|
||||||
s.aggregator = aggregator.NewGroupAggregator(s.config.GroupFields, s.config.SelectFields, s.config.FieldAlias)
|
// 转换为新的AggregationField格式
|
||||||
|
aggregationFields := convertToAggregationFields(s.config.SelectFields, s.config.FieldAlias)
|
||||||
|
s.aggregator = aggregator.NewGroupAggregator(s.config.GroupFields, aggregationFields)
|
||||||
|
|
||||||
// 为表达式字段创建计算器
|
// 为表达式字段创建计算器
|
||||||
for field, fieldExpr := range s.config.FieldExpressions {
|
for field, fieldExpr := range s.config.FieldExpressions {
|
||||||
|
Reference in New Issue
Block a user