// Mirror of https://gitee.com/rulego/streamsql.git
// Synced 2026-05-11 13:07:03 +00:00 (367 lines, 11 KiB, Go).

package aggregator
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/rulego/streamsql/functions"
|
|
"github.com/rulego/streamsql/utils/cast"
|
|
"github.com/rulego/streamsql/utils/fieldpath"
|
|
)
|
|
|
|
// Aggregator is the contract for components that accumulate incoming records
// and report aggregated rows. GroupAggregator in this file provides all of
// these methods.
type Aggregator interface {
	// Add feeds a single record into the aggregator.
	Add(data interface{}) error

	// Put stores val under key for later use by the aggregator.
	Put(key string, val interface{}) error

	// GetResults returns the aggregated output, one map per result row.
	GetResults() ([]map[string]interface{}, error)

	// Reset discards the state accumulated so far.
	Reset()

	// RegisterExpression attaches an expression evaluator to field.
	// expression is the expression text, fields lists the fields it
	// references, and evaluator computes its value for a record.
	RegisterExpression(
		field, expression string,
		fields []string,
		evaluator func(data interface{}) (interface{}, error),
	)
}
|
|
|
|
// AggregationField defines configuration for a single aggregation field
|
|
type AggregationField struct {
|
|
InputField string // Input field name (e.g., "temperature")
|
|
AggregateType AggregateType // Aggregation type (e.g., Sum, Avg)
|
|
OutputAlias string // Output alias (e.g., "temp_sum")
|
|
}
|
|
|
|
type GroupAggregator struct {
|
|
aggregationFields []AggregationField
|
|
groupFields []string
|
|
aggregators map[string]AggregatorFunction
|
|
groups map[string]map[string]AggregatorFunction
|
|
mu sync.RWMutex
|
|
context map[string]interface{}
|
|
// Expression evaluators
|
|
expressions map[string]*ExpressionEvaluator
|
|
}
|
|
|
|
// ExpressionEvaluator bundles an expression with the function that
// evaluates it against a record.
type ExpressionEvaluator struct {
	// Expression is the complete expression text.
	Expression string

	// Field is the primary field name the expression is registered under.
	Field string

	// Fields lists every field referenced by the expression.
	Fields []string

	// evaluateFunc computes the expression value for one record.
	evaluateFunc func(data interface{}) (interface{}, error)
}
|
|
|
|
// NewGroupAggregator creates a new group aggregator
|
|
func NewGroupAggregator(groupFields []string, aggregationFields []AggregationField) *GroupAggregator {
|
|
aggregators := make(map[string]AggregatorFunction)
|
|
|
|
// Create aggregator for each aggregation field
|
|
for i := range aggregationFields {
|
|
if aggregationFields[i].OutputAlias == "" {
|
|
// If no alias specified, use input field name
|
|
aggregationFields[i].OutputAlias = aggregationFields[i].InputField
|
|
}
|
|
aggregators[aggregationFields[i].OutputAlias] = CreateBuiltinAggregator(aggregationFields[i].AggregateType)
|
|
}
|
|
|
|
return &GroupAggregator{
|
|
aggregationFields: aggregationFields,
|
|
groupFields: groupFields,
|
|
aggregators: aggregators,
|
|
groups: make(map[string]map[string]AggregatorFunction),
|
|
expressions: make(map[string]*ExpressionEvaluator),
|
|
}
|
|
}
|
|
|
|
// RegisterExpression registers expression evaluator
|
|
func (ga *GroupAggregator) RegisterExpression(field, expression string, fields []string, evaluator func(data interface{}) (interface{}, error)) {
|
|
ga.mu.Lock()
|
|
defer ga.mu.Unlock()
|
|
|
|
ga.expressions[field] = &ExpressionEvaluator{
|
|
Expression: expression,
|
|
Field: field,
|
|
Fields: fields,
|
|
evaluateFunc: evaluator,
|
|
}
|
|
}
|
|
|
|
func (ga *GroupAggregator) Put(key string, val interface{}) error {
|
|
ga.mu.Lock()
|
|
defer ga.mu.Unlock()
|
|
if ga.context == nil {
|
|
ga.context = make(map[string]interface{})
|
|
}
|
|
ga.context[key] = val
|
|
return nil
|
|
}
|
|
|
|
// isNumericAggregator checks if aggregator requires numeric type input
|
|
func (ga *GroupAggregator) isNumericAggregator(aggType AggregateType) bool {
|
|
// Dynamically check function type through functions module
|
|
if fn, exists := functions.Get(string(aggType)); exists {
|
|
switch fn.GetType() {
|
|
case functions.TypeMath:
|
|
// Math functions usually require numeric input
|
|
return true
|
|
case functions.TypeAggregation:
|
|
// Check if it's a numeric aggregation function
|
|
switch string(aggType) {
|
|
case functions.SumStr, functions.AvgStr, functions.MinStr, functions.MaxStr, functions.CountStr,
|
|
functions.StdDevStr, functions.MedianStr, functions.PercentileStr,
|
|
functions.VarStr, functions.VarSStr, functions.StdDevSStr:
|
|
return true
|
|
case functions.CollectStr, functions.MergeAggStr, functions.DeduplicateStr, functions.LastValueStr:
|
|
// These functions can handle any type
|
|
return false
|
|
default:
|
|
// For unknown aggregation functions, try to check function name patterns
|
|
funcName := string(aggType)
|
|
if strings.Contains(funcName, functions.SumStr) || strings.Contains(funcName, functions.AvgStr) ||
|
|
strings.Contains(funcName, functions.MinStr) || strings.Contains(funcName, functions.MaxStr) ||
|
|
strings.Contains(funcName, functions.StdStr) || strings.Contains(funcName, functions.VarStr) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
case functions.TypeAnalytical:
|
|
// Analytical functions can usually handle any type
|
|
return false
|
|
default:
|
|
// For other types of functions, conservatively assume no numeric conversion needed
|
|
return false
|
|
}
|
|
}
|
|
|
|
// If function doesn't exist, judge by name pattern
|
|
funcName := string(aggType)
|
|
if strings.Contains(funcName, functions.SumStr) || strings.Contains(funcName, functions.AvgStr) ||
|
|
strings.Contains(funcName, functions.MinStr) || strings.Contains(funcName, functions.MaxStr) ||
|
|
strings.Contains(funcName, functions.CountStr) || strings.Contains(funcName, functions.StdStr) ||
|
|
strings.Contains(funcName, functions.VarStr) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// shouldAllowNullValues 判断聚合函数是否应该允许NULL值
|
|
func (ga *GroupAggregator) shouldAllowNullValues(aggType AggregateType) bool {
|
|
// FIRST_VALUE和LAST_VALUE函数应该允许NULL值,因为它们需要记录第一个/最后一个值,即使是NULL
|
|
return aggType == FirstValue || aggType == LastValue
|
|
}
|
|
|
|
func (ga *GroupAggregator) Add(data interface{}) error {
|
|
ga.mu.Lock()
|
|
defer ga.mu.Unlock()
|
|
|
|
// 检查数据是否为nil
|
|
if data == nil {
|
|
return fmt.Errorf("data cannot be nil")
|
|
}
|
|
|
|
var v reflect.Value
|
|
|
|
switch data.(type) {
|
|
case map[string]interface{}:
|
|
dataMap := data.(map[string]interface{})
|
|
v = reflect.ValueOf(dataMap)
|
|
default:
|
|
v = reflect.ValueOf(data)
|
|
if v.Kind() == reflect.Ptr {
|
|
v = v.Elem()
|
|
}
|
|
// 检查是否为支持的数据类型
|
|
if v.Kind() != reflect.Struct && v.Kind() != reflect.Map {
|
|
return fmt.Errorf("unsupported data type: %T, expected struct or map", data)
|
|
}
|
|
}
|
|
|
|
key := ""
|
|
for _, field := range ga.groupFields {
|
|
var fieldVal interface{}
|
|
var found bool
|
|
|
|
// Check if it's a nested field
|
|
if fieldpath.IsNestedField(field) {
|
|
fieldVal, found = fieldpath.GetNestedField(data, field)
|
|
} else {
|
|
// Original field access logic
|
|
var f reflect.Value
|
|
if v.Kind() == reflect.Map {
|
|
keyVal := reflect.ValueOf(field)
|
|
f = v.MapIndex(keyVal)
|
|
} else {
|
|
f = v.FieldByName(field)
|
|
}
|
|
|
|
if f.IsValid() {
|
|
fieldVal = f.Interface()
|
|
found = true
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
return fmt.Errorf("field %s not found", field)
|
|
}
|
|
|
|
if fieldVal == nil {
|
|
return fmt.Errorf("field %s has nil value", field)
|
|
}
|
|
|
|
if str, ok := fieldVal.(string); ok {
|
|
key += fmt.Sprintf("%s|", str)
|
|
} else {
|
|
key += fmt.Sprintf("%v|", fieldVal)
|
|
}
|
|
}
|
|
|
|
if _, exists := ga.groups[key]; !exists {
|
|
ga.groups[key] = make(map[string]AggregatorFunction)
|
|
}
|
|
|
|
// Create aggregator instances for each field
|
|
for outputAlias, agg := range ga.aggregators {
|
|
if _, exists := ga.groups[key][outputAlias]; !exists {
|
|
ga.groups[key][outputAlias] = agg.New()
|
|
}
|
|
}
|
|
|
|
// Process each aggregation field
|
|
for _, aggField := range ga.aggregationFields {
|
|
outputAlias := aggField.OutputAlias
|
|
if outputAlias == "" {
|
|
outputAlias = aggField.InputField
|
|
}
|
|
|
|
// Check if there's an expression evaluator
|
|
if expr, hasExpr := ga.expressions[outputAlias]; hasExpr {
|
|
result, err := expr.evaluateFunc(data)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
|
groupAgg.Add(result)
|
|
}
|
|
continue
|
|
}
|
|
|
|
inputField := aggField.InputField
|
|
|
|
// Special handling for count(*) case
|
|
if inputField == "*" {
|
|
// For count(*), directly add 1 without getting specific field value
|
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
|
groupAgg.Add(1)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Get field value - supports nested fields
|
|
var fieldVal interface{}
|
|
var found bool
|
|
|
|
if fieldpath.IsNestedField(inputField) {
|
|
fieldVal, found = fieldpath.GetNestedField(data, inputField)
|
|
} else {
|
|
// Original field access logic
|
|
var f reflect.Value
|
|
if v.Kind() == reflect.Map {
|
|
keyVal := reflect.ValueOf(inputField)
|
|
f = v.MapIndex(keyVal)
|
|
} else {
|
|
f = v.FieldByName(inputField)
|
|
}
|
|
|
|
if f.IsValid() {
|
|
fieldVal = f.Interface()
|
|
found = true
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
// Try to get from context
|
|
if ga.context != nil {
|
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
|
if contextAgg, ok := groupAgg.(ContextAggregator); ok {
|
|
contextKey := contextAgg.GetContextKey()
|
|
if val, exists := ga.context[contextKey]; exists {
|
|
groupAgg.Add(val)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
aggType := aggField.AggregateType
|
|
|
|
// Skip nil values for most aggregation functions, but allow FIRST_VALUE and LAST_VALUE to handle them
|
|
if fieldVal == nil && !ga.shouldAllowNullValues(aggType) {
|
|
continue
|
|
}
|
|
|
|
// Special handling for Count aggregator - it can handle any type
|
|
if aggType == Count {
|
|
// Count can handle any non-null value
|
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
|
groupAgg.Add(fieldVal)
|
|
}
|
|
} else if ga.isNumericAggregator(aggType) {
|
|
// For numeric aggregation functions, try to convert to numeric type
|
|
if numVal, err := cast.ToFloat64E(fieldVal); err == nil {
|
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
|
|
|
groupAgg.Add(numVal)
|
|
}
|
|
} else {
|
|
return fmt.Errorf("cannot convert field %s value %v to numeric type for aggregator %s", inputField, fieldVal, aggType)
|
|
}
|
|
} else {
|
|
// For non-numeric aggregation functions, pass original value directly
|
|
if groupAgg, exists := ga.groups[key][outputAlias]; exists {
|
|
|
|
groupAgg.Add(fieldVal)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) {
|
|
ga.mu.RLock()
|
|
defer ga.mu.RUnlock()
|
|
|
|
// 如果既没有分组字段又没有聚合字段,但有数据被添加过,返回一个空的结果行
|
|
if len(ga.aggregationFields) == 0 && len(ga.groupFields) == 0 {
|
|
if len(ga.groups) > 0 {
|
|
return []map[string]interface{}{{}}, nil
|
|
}
|
|
return []map[string]interface{}{}, nil
|
|
}
|
|
|
|
result := make([]map[string]interface{}, 0, len(ga.groups))
|
|
for key, aggregators := range ga.groups {
|
|
group := make(map[string]interface{})
|
|
fields := strings.Split(key, "|")
|
|
for i, field := range ga.groupFields {
|
|
if i < len(fields) {
|
|
group[field] = fields[i]
|
|
}
|
|
}
|
|
for field, agg := range aggregators {
|
|
result := agg.Result()
|
|
group[field] = result
|
|
// Debug: log aggregator results (can be removed in production)
|
|
// if strings.HasPrefix(field, "__") {
|
|
// fmt.Printf("Aggregator %s result: %v (%T)\n", field, result, result)
|
|
// }
|
|
}
|
|
result = append(result, group)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (ga *GroupAggregator) Reset() {
|
|
ga.mu.Lock()
|
|
defer ga.mu.Unlock()
|
|
ga.groups = make(map[string]map[string]AggregatorFunction)
|
|
}
|