From b481b5a675a24332739c69b0b67d576d0ea1b4c7 Mon Sep 17 00:00:00 2001 From: rulego-team Date: Fri, 13 Jun 2025 21:38:27 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aggregator/group_aggregator.go | 249 +++++++++++++++++------------- functions/custom_example.go | 6 +- functions/functions_conversion.go | 2 - stream/persistence.go | 2 +- stream/stream.go | 4 +- utils/cast/cast.go | 89 +++++++---- 6 files changed, 206 insertions(+), 146 deletions(-) diff --git a/aggregator/group_aggregator.go b/aggregator/group_aggregator.go index 23e9d68..e9f641e 100644 --- a/aggregator/group_aggregator.go +++ b/aggregator/group_aggregator.go @@ -42,17 +42,16 @@ type ExpressionEvaluator struct { func NewGroupAggregator(groupFields []string, fieldMap map[string]AggregateType, fieldAlias map[string]string) *GroupAggregator { aggregators := make(map[string]AggregatorFunction) - // 重新组织 fieldMap 和 fieldAlias - // 测试中:fieldMap: {"temperature": Sum}, fieldAlias: {"temperature": "temperature_sum"} - // 这意味着:输入字段"temperature",聚合类型Sum,输出别名"temperature_sum" + // 重新组织映射关系 + // fieldMap: 输入字段名 -> 聚合类型 + // fieldAlias: 输入字段名 -> 输出别名 + // 需要转换为:输出别名 -> 聚合类型,输出别名 -> 输入字段名 - // 创建新的映射:输出字段名 -> 聚合类型 - newFieldMap := make(map[string]AggregateType) - // 创建新的别名映射:输出字段名 -> 输入字段名 - newFieldAlias := make(map[string]string) + newFieldMap := make(map[string]AggregateType) // 输出字段名 -> 聚合类型 + newFieldAlias := make(map[string]string) // 输出字段名 -> 输入字段名 for inputField, aggType := range fieldMap { - outputField := inputField // 默认输出字段名等于输入字段名 + outputField := inputField // 默认输出字段名 = 输入字段名 if alias, exists := fieldAlias[inputField]; exists { outputField = alias // 如果有别名,使用别名作为输出字段名 } @@ -143,49 +142,61 @@ func (ga *GroupAggregator) isNumericAggregator(aggType AggregateType) bool { return false } -func (ga *GroupAggregator) Add(data interface{}) error { - ga.mu.Lock() - defer ga.mu.Unlock() - var v reflect.Value - +// prepareDataValue 准备数据的反射值 +func (ga *GroupAggregator) prepareDataValue(data interface{}) reflect.Value { switch data.(type) { case map[string]interface{}: - dataMap := data.(map[string]interface{}) - v = reflect.ValueOf(dataMap) + return reflect.ValueOf(data.(map[string]interface{})) default: - v = reflect.ValueOf(data) + v := reflect.ValueOf(data) if v.Kind() == reflect.Ptr { v = v.Elem() } + return v } +} +// buildGroupKey 构建分组键 +func (ga *GroupAggregator) buildGroupKey(v reflect.Value) (string, error) { key := "" for _, field := range ga.groupFields { - var f reflect.Value - - if v.Kind() == reflect.Map { - keyVal := reflect.ValueOf(field) - f = v.MapIndex(keyVal) - } else { - f = v.FieldByName(field) + fieldValue, err := ga.getFieldValue(v, field) + if err != nil { + return "", err } - if !f.IsValid() { - return fmt.Errorf("field %s not found", field) + if fieldValue == nil { + return "", fmt.Errorf("field %s has nil value", field) } - keyVal := f.Interface() - if keyVal == nil { - return fmt.Errorf("field %s has nil value", field) - } - - if str, ok := keyVal.(string); ok { + if str, ok := fieldValue.(string); ok { key += fmt.Sprintf("%s|", str) } else { - key += fmt.Sprintf("%v|", keyVal) + key += fmt.Sprintf("%v|", fieldValue) } } + return key, nil +} +// getFieldValue 获取字段值 +func (ga *GroupAggregator) getFieldValue(v reflect.Value, fieldName string) (interface{}, error) { + var f reflect.Value + if v.Kind() == reflect.Map { + keyVal := reflect.ValueOf(fieldName) + f = v.MapIndex(keyVal) + } else { + f = v.FieldByName(fieldName) + } + + if !f.IsValid() { + return nil, fmt.Errorf("field %s not found", fieldName) + } + + return f.Interface(), nil +} + +// ensureAggregators 确保聚合器实例存在 +func (ga *GroupAggregator) ensureAggregators(key string) { if _, exists := ga.groups[key]; !exists { ga.groups[key] = make(map[string]AggregatorFunction) } @@ -196,80 +207,110 @@ func (ga *GroupAggregator) Add(data interface{}) error { ga.groups[key][field] = agg.New() } } +} + +// processFieldAggregation 处理字段聚合 +func (ga *GroupAggregator) processFieldAggregation(key, field string, data interface{}, v reflect.Value) error { + // 检查是否有表达式计算器 + if expr, hasExpr := ga.expressions[field]; hasExpr { + return ga.processExpressionField(key, field, expr, data) + } + + // 获取实际的输入字段名 + inputFieldName := ga.getInputFieldName(field) + + // 特殊处理count(*)的情况 + if inputFieldName == "*" { + return ga.addValueToAggregator(key, field, 1) + } + + // 获取字段值并处理 + return ga.processRegularField(key, field, inputFieldName, v) +} + +// processExpressionField 处理表达式字段 +func (ga *GroupAggregator) processExpressionField(key, field string, expr *ExpressionEvaluator, data interface{}) error { + result, err := expr.evaluateFunc(data) + if err != nil { + return nil // 继续处理其他字段 + } + return ga.addValueToAggregator(key, field, result) +} + +// getInputFieldName 获取输入字段名 +func (ga *GroupAggregator) getInputFieldName(field string) string { + if mappedField, exists := ga.fieldAlias[field]; exists { + return mappedField + } + return field +} + +// processRegularField 处理常规字段 +func (ga *GroupAggregator) processRegularField(key, field, inputFieldName string, v reflect.Value) error { + fieldVal, err := ga.getFieldValue(v, inputFieldName) + if err != nil { + // 尝试从context中获取 + return ga.tryContextAggregation(key, field) + } + + aggType := ga.fieldMap[field] + if ga.isNumericAggregator(aggType) { + return ga.processNumericField(key, field, inputFieldName, fieldVal, aggType) + } + + return ga.addValueToAggregator(key, field, fieldVal) +} + +// tryContextAggregation 尝试从context中获取值进行聚合 +func (ga *GroupAggregator) tryContextAggregation(key, field string) error { + if ga.context == nil { + return nil + } + + if groupAgg, exists := ga.groups[key][field]; exists { + if contextAgg, ok := groupAgg.(ContextAggregator); ok { + contextKey := contextAgg.GetContextKey() + if val, exists := ga.context[contextKey]; exists { + groupAgg.Add(val) + } + } + } + return nil +} + +// processNumericField 处理数值字段 +func (ga *GroupAggregator) processNumericField(key, field, inputFieldName string, fieldVal interface{}, aggType AggregateType) error { + numVal, err := cast.ToFloat64E(fieldVal) + if err != nil { + return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputFieldName, fieldVal, aggType) + } + return ga.addValueToAggregator(key, field, numVal) +} + +// addValueToAggregator 向聚合器添加值 +func (ga *GroupAggregator) addValueToAggregator(key, field string, value interface{}) error { + if groupAgg, exists := ga.groups[key][field]; exists { + groupAgg.Add(value) + } + return nil +} + +func (ga *GroupAggregator) Add(data interface{}) error { + ga.mu.Lock() + defer ga.mu.Unlock() + + v := ga.prepareDataValue(data) + + key, err := ga.buildGroupKey(v) + if err != nil { + return err + } + + ga.ensureAggregators(key) for field := range ga.fieldMap { - // 检查是否有表达式计算器 - if expr, hasExpr := ga.expressions[field]; hasExpr { - result, err := expr.evaluateFunc(data) - if err != nil { - continue - } - - if groupAgg, exists := ga.groups[key][field]; exists { - groupAgg.Add(result) - } - continue - } - - // 获取实际的输入字段名 - // field现在是输出字段名(可能是别名),需要找到对应的输入字段名 - inputFieldName := field - if mappedField, exists := ga.fieldAlias[field]; exists { - // 如果field是别名,获取实际输入字段名 - inputFieldName = mappedField - } - - // 特殊处理count(*)的情况 - if inputFieldName == "*" { - // 对于count(*),直接添加1,不需要获取具体字段值 - if groupAgg, exists := ga.groups[key][field]; exists { - groupAgg.Add(1) - } - continue - } - - // 获取字段值 - var f reflect.Value - if v.Kind() == reflect.Map { - keyVal := reflect.ValueOf(inputFieldName) - f = v.MapIndex(keyVal) - } else { - f = v.FieldByName(inputFieldName) - } - - if !f.IsValid() { - // 尝试从context中获取 - if ga.context != nil { - if groupAgg, exists := ga.groups[key][field]; exists { - if contextAgg, ok := groupAgg.(ContextAggregator); ok { - contextKey := contextAgg.GetContextKey() - if val, exists := ga.context[contextKey]; exists { - groupAgg.Add(val) - } - } - } - } - continue - } - - fieldVal := f.Interface() - aggType := ga.fieldMap[field] - - // 动态检查是否需要数值转换 - if ga.isNumericAggregator(aggType) { - // 对于数值聚合函数,尝试转换为数值类型 - if numVal, err := cast.ToFloat64E(fieldVal); err == nil { - if groupAgg, exists := ga.groups[key][field]; exists { - groupAgg.Add(numVal) - } - } else { - return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputFieldName, fieldVal, aggType) - } - } else { - // 对于非数值聚合函数,直接传递原始值 - if groupAgg, exists := ga.groups[key][field]; exists { - groupAgg.Add(fieldVal) - } + if err := ga.processFieldAggregation(key, field, data, v); err != nil { + return err } } diff --git a/functions/custom_example.go b/functions/custom_example.go index 3880952..d8bc6fc 100644 --- a/functions/custom_example.go +++ b/functions/custom_example.go @@ -241,11 +241,11 @@ func (f *CustomGeometricMeanFunction) Clone() AggregatorFunction { // RegisterCustomFunctions 注册自定义函数的示例 func RegisterCustomFunctions() { // 注册自定义聚合函数 - Register(NewCustomProductFunction()) - Register(NewCustomGeometricMeanFunction()) + _ = Register(NewCustomProductFunction()) + _ = Register(NewCustomGeometricMeanFunction()) // 注册自定义分析函数 - Register(NewCustomMovingAverageFunction(5)) // 5个值的移动平均 + _ = Register(NewCustomMovingAverageFunction(5)) // 5个值的移动平均 // 注册适配器 RegisterAggregatorAdapter("product") diff --git a/functions/functions_conversion.go b/functions/functions_conversion.go index ae74673..1068daa 100644 --- a/functions/functions_conversion.go +++ b/functions/functions_conversion.go @@ -353,8 +353,6 @@ func (f *ChrFunction) Execute(ctx *FunctionContext, args []interface{}) (interfa return string(rune(code)), nil } - - // UrlEncodeFunction URL编码函数 type UrlEncodeFunction struct { *BaseFunction diff --git a/stream/persistence.go b/stream/persistence.go index 7977f13..5301c4c 100644 --- a/stream/persistence.go +++ b/stream/persistence.go @@ -279,7 +279,7 @@ func (pm *PersistenceManager) flushPendingData() { // 同步到磁盘 if pm.currentFile != nil { - pm.currentFile.Sync() + _ = pm.currentFile.Sync() } logger.Info("Flushed %d pending data records to disk", len(dataToWrite)) diff --git a/stream/stream.go b/stream/stream.go index e5d7462..c78f886 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -345,8 +345,8 @@ func (s *Stream) process() { for batch := range s.Window.OutputChan() { // 处理窗口批数据 for _, item := range batch { - s.aggregator.Put("window_start", item.Slot.WindowStart()) - s.aggregator.Put("window_end", item.Slot.WindowEnd()) + _ = s.aggregator.Put("window_start", item.Slot.WindowStart()) + _ = s.aggregator.Put("window_end", item.Slot.WindowEnd()) if err := s.aggregator.Add(item.Data); err != nil { logger.Error("aggregate error: %v", err) } diff --git a/utils/cast/cast.go b/utils/cast/cast.go index f1b1359..650dcb1 100644 --- a/utils/cast/cast.go +++ b/utils/cast/cast.go @@ -237,44 +237,41 @@ func ToString(input interface{}) string { return v } -// ToStringE converts an interface{} to string with error handling. -// Returns the converted string value and nil error if successful. -// Returns empty string and an error if conversion fails. -func ToStringE(input interface{}) (string, error) { - if input == nil { - return "", nil - } +// convertNumericToString 将数字类型转换为字符串 +func convertNumericToString(input interface{}) (string, bool) { switch v := input.(type) { - case string: - return v, nil - case bool: - return strconv.FormatBool(v), nil case float64: - ft := input.(float64) - return strconv.FormatFloat(ft, 'f', -1, 64), nil + return strconv.FormatFloat(v, 'f', -1, 64), true case float32: - ft := input.(float32) - return strconv.FormatFloat(float64(ft), 'f', -1, 32), nil + return strconv.FormatFloat(float64(v), 'f', -1, 32), true case int: - return strconv.Itoa(v), nil + return strconv.Itoa(v), true case uint: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case int8: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case uint8: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case int16: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case uint16: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case int32: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case uint32: - return strconv.Itoa(int(v)), nil + return strconv.Itoa(int(v)), true case int64: - return strconv.FormatInt(v, 10), nil + return strconv.FormatInt(v, 10), true case uint64: - return strconv.FormatUint(v, 10), nil + return strconv.FormatUint(v, 10), true + default: + return "", false + } +} + +// convertComplexToString 将复杂类型转换为字符串 +func convertComplexToString(input interface{}) (string, error) { + switch v := input.(type) { case []byte: return string(v), nil case fmt.Stringer: @@ -287,17 +284,41 @@ func ToStringE(input interface{}) (string, error) { for k, value := range v { convertedInput[fmt.Sprintf("%v", k)] = value } - if newValue, err := json.Marshal(convertedInput); err == nil { - return string(newValue), nil - } else { - return "", err - } + return marshalToString(convertedInput) default: - if newValue, err := json.Marshal(input); err == nil { - return string(newValue), nil - } else { - return "", err + return marshalToString(input) + } +} + +// marshalToString 通过JSON序列化转换为字符串 +func marshalToString(input interface{}) (string, error) { + if newValue, err := json.Marshal(input); err == nil { + return string(newValue), nil + } else { + return "", err + } +} + +// ToStringE converts an interface{} to string with error handling. +// Returns the converted string value and nil error if successful. +// Returns empty string and an error if conversion fails. +func ToStringE(input interface{}) (string, error) { + if input == nil { + return "", nil + } + + switch v := input.(type) { + case string: + return v, nil + case bool: + return strconv.FormatBool(v), nil + default: + // 尝试数字类型转换 + if str, ok := convertNumericToString(input); ok { + return str, nil } + // 尝试复杂类型转换 + return convertComplexToString(input) } }