package aggregator import ( "fmt" "regexp" "strconv" "strings" "sync" "github.com/rulego/streamsql/functions" ) // Configuration constants for post-aggregation processing const ( // PlaceholderPrefix defines the prefix for aggregation field placeholders PlaceholderPrefix = "__" // PlaceholderSuffix defines the suffix for aggregation field placeholders PlaceholderSuffix = "__" // HashMultiplier is used for generating unique hash values for function calls HashMultiplier = 31 // MaxFunctionNameLength defines the maximum allowed length for function names MaxFunctionNameLength = 100 // MaxExpressionDepth defines the maximum nesting depth for expression parsing MaxExpressionDepth = 50 ) var ( // funcCallRegex is a compiled regex for function calls, cached for performance funcCallRegex = regexp.MustCompile(`(?i)([a-z_]+)\s*\(`) // placeholderRegex is a compiled regex for placeholder detection placeholderRegex = regexp.MustCompile(`^` + regexp.QuoteMeta(PlaceholderPrefix) + `.*` + regexp.QuoteMeta(PlaceholderSuffix) + `$`) ) // PostAggregationExpression represents an expression that needs to be evaluated after aggregation type PostAggregationExpression struct { OutputField string // 输出字段名 Expression string // 表达式模板,如 "__first_value_0__ - __last_value_1__" RequiredAggFields []string // 依赖的聚合字段,如 ["__first_value_0__", "__last_value_1__"] OriginalExpr string // 原始表达式,用于调试 processor *PostAggregationProcessor // 处理器引用 } // Evaluate 评估后聚合表达式 func (pae *PostAggregationExpression) Evaluate(data map[string]interface{}) (interface{}, error) { if pae == nil { return nil, fmt.Errorf("post-aggregation expression is nil") } if pae.processor == nil { return nil, fmt.Errorf("post-aggregation processor not initialized") } if strings.TrimSpace(pae.Expression) == "" { return nil, fmt.Errorf("expression cannot be empty") } if data == nil { return nil, fmt.Errorf("evaluation data cannot be nil") } return pae.processor.evaluateExpression(pae.Expression, data) } // PostAggregationProcessor handles expressions that contain aggregation functions type PostAggregationProcessor struct { expressions []PostAggregationExpression mu sync.RWMutex exprBridge *functions.ExprBridge fieldsCache map[string][]string } // NewPostAggregationProcessor creates a new post-aggregation processor func NewPostAggregationProcessor() *PostAggregationProcessor { return &PostAggregationProcessor{ expressions: make([]PostAggregationExpression, 0), exprBridge: functions.GetExprBridge(), fieldsCache: make(map[string][]string), } } // AddExpression adds a post-aggregation expression func (p *PostAggregationProcessor) AddExpression(outputField, originalExpr string, aggFields []string, exprTemplate string) { p.mu.Lock() defer p.mu.Unlock() expr := PostAggregationExpression{ OutputField: outputField, Expression: exprTemplate, RequiredAggFields: aggFields, OriginalExpr: originalExpr, processor: p, } p.expressions = append(p.expressions, expr) p.fieldsCache[outputField] = aggFields } // ProcessResults processes aggregation results and evaluates post-aggregation expressions func (p *PostAggregationProcessor) ProcessResults(results []map[string]interface{}) ([]map[string]interface{}, error) { p.mu.RLock() defer p.mu.RUnlock() if len(p.expressions) == 0 { return results, nil } // Pre-allocate cleanup fields map to avoid repeated allocations fieldsToCleanup := make(map[string]bool, len(p.expressions)*2) // Collect all placeholder fields that need cleanup for j := range p.expressions { expr := &p.expressions[j] p.markPlaceholderFields(expr.RequiredAggFields, fieldsToCleanup) } // Process each result row for i := range results { result := results[i] for j := range p.expressions { expr := &p.expressions[j] // Fast path: check required fields presence allPresent := p.checkRequiredFields(result, expr.RequiredAggFields) if !allPresent { result[expr.OutputField] = nil continue } // Evaluate expression exprResult, err := p.evaluateExpressionFast(expr.Expression, result) if err != nil { result[expr.OutputField] = nil } else { result[expr.OutputField] = exprResult } } // Batch cleanup of placeholder fields for field := range fieldsToCleanup { delete(result, field) } } return results, nil } // checkRequiredFields checks if all required fields are present in the result func (p *PostAggregationProcessor) checkRequiredFields(result map[string]interface{}, requiredFields []string) bool { for _, field := range requiredFields { if _, exists := result[field]; !exists { return false } } return true } // markPlaceholderFields marks placeholder fields for cleanup func (p *PostAggregationProcessor) markPlaceholderFields(requiredFields []string, fieldsToCleanup map[string]bool) { for _, field := range requiredFields { if placeholderRegex.MatchString(field) { fieldsToCleanup[field] = true } } } // evaluateExpressionFast evaluates an expression using cached bridge func (p *PostAggregationProcessor) evaluateExpressionFast(expression string, data map[string]interface{}) (interface{}, error) { result, err := p.exprBridge.EvaluateExpression(expression, data) if err != nil { return nil, err } return p.unwrapNestedSlices(result), nil } // evaluateExpression evaluates an expression using aggregated values func (p *PostAggregationProcessor) evaluateExpression(expression string, data map[string]interface{}) (interface{}, error) { return p.evaluateExpressionFast(expression, data) } // unwrapNestedSlices recursively unwraps nested empty slices to get the actual value func (p *PostAggregationProcessor) unwrapNestedSlices(value interface{}) interface{} { if value == nil { return nil } // Check if it's a slice if slice, ok := value.([]interface{}); ok { // If it's an empty slice or contains only nil, return nil if len(slice) == 0 { return nil } // If it contains only one element, recursively unwrap it if len(slice) == 1 { return p.unwrapNestedSlices(slice[0]) } // If it contains multiple elements, return as is return slice } // For non-slice values, return as is return value } // ParseComplexAggregationExpression parses expressions containing multiple aggregation functions // Returns the list of required aggregation fields and the expression template // 该函数将包含聚合函数的复杂表达式分解为: // 1. 后聚合表达式模板(聚合函数被占位符替换) // 2. 需要预先计算的聚合字段信息列表 // 3. 错误信息(如果解析失败) // // 示例: // // 输入: "SUM(price) + AVG(quantity) * 2" // 输出: 表达式模板 "__SUM_123__ + __AVG_456__ * 2" // 聚合字段 [{FieldName: "__SUM_123__", FunctionName: "SUM", Arguments: ["price"]}, ...] func ParseComplexAggregationExpression(expr string) (aggFields []AggregationFieldInfo, exprTemplate string, err error) { exprTemplate = expr // 使用递归方法解析嵌套函数调用 aggFields, exprTemplate = parseNestedFunctions(expr, make([]AggregationFieldInfo, 0)) return aggFields, exprTemplate, nil } // parseNestedFunctions 递归解析嵌套函数调用 func parseNestedFunctions(expr string, aggFields []AggregationFieldInfo) ([]AggregationFieldInfo, string) { return parseNestedFunctionsWithDepth(expr, aggFields, 0) } // findFunctionCalls 查找表达式中的所有函数调用 func findFunctionCalls(expr string) [][]int { return funcCallRegex.FindAllStringSubmatchIndex(expr, -1) } // generatePlaceholder 为函数调用生成唯一占位符 func generatePlaceholder(funcName, fullFuncCall string) string { callHash := uint32(0) for i := 0; i < len(fullFuncCall); i++ { callHash = callHash*HashMultiplier + uint32(fullFuncCall[i]) } return PlaceholderPrefix + funcName + "_" + strconv.FormatUint(uint64(callHash), 10) + PlaceholderSuffix } // parseNestedFunctionsWithDepth 递归解析嵌套函数调用,支持深度控制 func parseNestedFunctionsWithDepth(expr string, aggFields []AggregationFieldInfo, depth int) ([]AggregationFieldInfo, string) { if depth > MaxExpressionDepth { return aggFields, expr } isTopLevelSingleAggregation := (depth == 0 && isTopLevelAggregationFunction(expr)) matches := findFunctionCalls(expr) if len(matches) == 0 { return aggFields, expr } // 从右到左处理,避免索引偏移问题 for i := len(matches) - 1; i >= 0; i-- { match := matches[i] funcStart := match[0] funcName := strings.ToLower(expr[match[2]:match[3]]) parenStart := match[3] parenEnd := findMatchingParen(expr, parenStart) if parenEnd == -1 { continue } fullFuncCall := expr[funcStart : parenEnd+1] funcParam := expr[parenStart+1 : parenEnd] if fn, exists := functions.Get(funcName); exists { switch fn.GetType() { case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: if isTopLevelSingleAggregation && i == 0 { innerAggFields, processedParam := parseNestedFunctionsWithDepth(funcParam, aggFields, depth+1) aggFields = innerAggFields expr = expr[:parenStart+1] + processedParam + expr[parenEnd:] continue } placeholder := generatePlaceholder(funcName, fullFuncCall) inputField := funcParam if strings.Contains(funcParam, ",") && fn.GetMinArgs() > 1 { if commaIdx := strings.Index(funcParam, ","); commaIdx > 0 { inputField = strings.TrimSpace(funcParam[:commaIdx]) } } aggFields = append(aggFields, AggregationFieldInfo{ FuncName: funcName, InputField: inputField, Placeholder: placeholder, AggType: AggregateType(funcName), FullCall: fullFuncCall, }) expr = expr[:funcStart] + placeholder + expr[parenEnd+1:] default: innerAggFields, processedParam := parseNestedFunctionsWithDepth(funcParam, aggFields, depth+1) aggFields = innerAggFields expr = expr[:parenStart+1] + processedParam + expr[parenEnd:] } } } return aggFields, expr } // isTopLevelAggregationFunction 检查表达式是否是顶层的单一聚合函数调用 func isTopLevelAggregationFunction(expr string) bool { // 提取最外层的函数名 funcName := extractOutermostFunctionName(expr) if funcName == "" { return false } // 检查是否是聚合函数 if fn, exists := functions.Get(funcName); exists { switch fn.GetType() { case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: return true } } return false } // extractOutermostFunctionName 提取最外层的函数名 func extractOutermostFunctionName(expr string) string { expr = strings.TrimSpace(expr) // 查找第一个左括号 parenIndex := strings.Index(expr, "(") if parenIndex == -1 { return "" } // 提取函数名 funcName := strings.TrimSpace(expr[:parenIndex]) // 检查函数名是否有效(只包含字母、数字、下划线) for _, char := range funcName { if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9') || char == '_') { return "" } } return funcName } // findMatchingParen 找到匹配的右括号 func findMatchingParen(s string, start int) int { if start >= len(s) || s[start] != '(' { return -1 } count := 1 for i := start + 1; i < len(s); i++ { switch s[i] { case '(': count++ case ')': count-- if count == 0 { return i } } } return -1 // 未找到匹配的右括号 } // AggregationFieldInfo holds information about an aggregation function in an expression type AggregationFieldInfo struct { FuncName string // 函数名,如 "first_value" InputField string // 输入字段,如 "displayNum" Placeholder string // 占位符,如 "__first_value_0__" AggType AggregateType // 聚合类型 FullCall string // 完整函数调用,如 "NTH_VALUE(value, 2)" } // Enhanced GroupAggregator with post-aggregation support type EnhancedGroupAggregator struct { *GroupAggregator postProcessor *PostAggregationProcessor } // NewEnhancedGroupAggregator creates a new enhanced group aggregator with post-aggregation support func NewEnhancedGroupAggregator(groupFields []string, aggregationFields []AggregationField) *EnhancedGroupAggregator { baseAggregator := NewGroupAggregator(groupFields, aggregationFields) return &EnhancedGroupAggregator{ GroupAggregator: baseAggregator, postProcessor: NewPostAggregationProcessor(), } } // AddPostAggregationExpression adds an expression that needs post-aggregation processing func (ega *EnhancedGroupAggregator) AddPostAggregationExpression(outputField, originalExpr string, requiredFields []AggregationFieldInfo) error { // Validate input parameters if strings.TrimSpace(originalExpr) == "" { return fmt.Errorf("expression cannot be empty") } // Check for malformed expressions (basic validation) if strings.Count(originalExpr, "(") != strings.Count(originalExpr, ")") { return fmt.Errorf("malformed expression: mismatched parentheses") } // Validate required fields contain valid function names for _, field := range requiredFields { if field.FuncName != "" { if _, exists := functions.Get(field.FuncName); !exists { return fmt.Errorf("invalid function name: %s", field.FuncName) } } } // Add individual aggregation fields to the base aggregator (only if not already exists) for _, field := range requiredFields { // For parameterized functions, always recreate the aggregator with correct parameters // even if it already exists (it was created with default parameters) // A function is considered parameterized if it needs multiple parameters to configure its behavior isParameterized := false if fn, exists := functions.Get(field.FuncName); exists { minArgs := fn.GetMinArgs() maxArgs := fn.GetMaxArgs() // Function is parameterized if: // 1. It requires more than 1 parameter (minArgs > 1), OR // 2. It has optional parameters that can configure its behavior (maxArgs > minArgs && minArgs >= 1) isParameterized = minArgs > 1 || (maxArgs > minArgs && minArgs >= 1) } // Check if field already exists in aggregationFields to avoid duplicates fieldExistsInAggFields := false for _, existingField := range ega.GroupAggregator.aggregationFields { if existingField.OutputAlias == field.Placeholder { fieldExistsInAggFields = true break } } // Check if input field is an expression (contains function calls) isInputExpression := strings.Contains(field.InputField, "(") && strings.Contains(field.InputField, ")") // If input expression itself contains aggregation calls, skip creating an aggregator for this field // Use dynamic function registry instead of hardcoded list containsAggCall := func(s string) bool { lower := strings.ToLower(s) // Extract potential function names from the expression for i := 0; i < len(lower); i++ { if lower[i] >= 'a' && lower[i] <= 'z' { // Find the end of the function name j := i for j < len(lower) && (lower[j] >= 'a' && lower[j] <= 'z' || lower[j] == '_') { j++ } // Check if it's followed by '(' and is an aggregator function if j < len(lower) && lower[j] == '(' { funcName := lower[i:j] if functions.IsAggregatorFunction(funcName) { return true } } i = j } else { i++ } } return false } // Check if expression is already registered hasExpressionRegistered := false if ega.GroupAggregator.expressions != nil { _, hasExpressionRegistered = ega.GroupAggregator.expressions[field.Placeholder] } // For parameterized functions, always recreate the aggregator with correct parameters // For non-parameterized functions, only add if field doesn't exist // For expression fields, always ensure expression is registered shouldProcess := (!fieldExistsInAggFields && !isParameterized) || isParameterized || (isInputExpression && !hasExpressionRegistered) if isInputExpression && containsAggCall(field.InputField) { shouldProcess = false } if shouldProcess { // Debug: log field creation (can be removed in production) // fmt.Printf("Creating aggregator for field: %s (%s) -> %s\n", field.FuncName, field.InputField, field.Placeholder) // Create aggregation field aggField := AggregationField{ InputField: field.InputField, AggregateType: field.AggType, OutputAlias: field.Placeholder, } // Add to aggregation fields (only if not duplicate) if !fieldExistsInAggFields { ega.GroupAggregator.aggregationFields = append(ega.GroupAggregator.aggregationFields, aggField) } // If input field is an expression, register expression evaluator (only if it does not depend on aggregation) if isInputExpression && !containsAggCall(field.InputField) { bridge := functions.GetExprBridge() ega.GroupAggregator.RegisterExpression( field.Placeholder, field.InputField, []string{}, // Will be populated by expression parsing func(data interface{}) (interface{}, error) { if dataMap, ok := data.(map[string]interface{}); ok { result, err := bridge.EvaluateExpression(field.InputField, dataMap) return result, err } return nil, fmt.Errorf("unsupported data type: %T", data) }, ) } // Create aggregator instance // For parameterized functions, create with parameters only when multiple top-level args are present if isParameterized && hasMultipleTopLevelArgs(field.FullCall) { aggregator := ega.createParameterizedAggregator(field) if aggregator != nil { ega.GroupAggregator.aggregators[field.Placeholder] = aggregator } else { // Fallback to simple aggregator ega.GroupAggregator.aggregators[field.Placeholder] = CreateBuiltinAggregator(field.AggType) } } else { ega.GroupAggregator.aggregators[field.Placeholder] = CreateBuiltinAggregator(field.AggType) } } } // Extract required field names var requiredFieldNames []string for _, field := range requiredFields { requiredFieldNames = append(requiredFieldNames, field.Placeholder) } // Build expression template by replacing each full function call with its placeholder // This preserves any outer non-aggregation functions (e.g., CEIL(__avg__)) and ensures // placeholders exactly match the ones created earlier for requiredFields. exprTemplate := originalExpr for _, field := range requiredFields { exprTemplate = strings.ReplaceAll(exprTemplate, field.FullCall, field.Placeholder) } // Detect aggregators whose input expressions themselves contain aggregation calls // Use dynamic function registry instead of hardcoded list containsAggCall := func(s string) bool { lower := strings.ToLower(s) // Extract potential function names from the expression for i := 0; i < len(lower); i++ { if lower[i] >= 'a' && lower[i] <= 'z' { // Find the end of the function name j := i for j < len(lower) && (lower[j] >= 'a' && lower[j] <= 'z' || lower[j] == '_') { j++ } // Check if it's followed by '(' and is an aggregator function if j < len(lower) && lower[j] == '(' { funcName := lower[i:j] if functions.IsAggregatorFunction(funcName) { return true } } i = j } else { i++ } } return false } // Adjust template and required fields: drop outer aggregators that wrap other aggregations adjustedTemplate := exprTemplate var adjustedRequired []AggregationFieldInfo for _, field := range requiredFields { if containsAggCall(field.InputField) { // Transform the input by replacing inner full calls with placeholders transformed := field.InputField for _, inner := range requiredFields { if inner.FullCall != field.FullCall { transformed = strings.ReplaceAll(transformed, inner.FullCall, inner.Placeholder) } } // Replace the placeholder of this outer aggregator back to the transformed expression adjustedTemplate = strings.ReplaceAll(adjustedTemplate, field.Placeholder, transformed) // Do NOT keep this outer aggregator in required list (it will not be created) continue } adjustedRequired = append(adjustedRequired, field) } requiredFields = adjustedRequired // Add to post-processor ega.postProcessor.AddExpression(outputField, originalExpr, requiredFieldNames, adjustedTemplate) return nil } // GetResults returns results with post-aggregation expressions evaluated func (ega *EnhancedGroupAggregator) GetResults() ([]map[string]interface{}, error) { // Get base aggregation results results, err := ega.GroupAggregator.GetResults() if err != nil { return nil, err } // Process post-aggregation expressions return ega.postProcessor.ProcessResults(results) } // createParameterizedAggregator creates aggregator with parameters for complex functions // 使用新的接口方法替代硬编码实现 func (ega *EnhancedGroupAggregator) createParameterizedAggregator(field AggregationFieldInfo) AggregatorFunction { // Parse function call to extract parameters args, err := ega.parseFunctionCall(field.FullCall) if err != nil { return nil } // Use the new interface method to create parameterized aggregator aggFunc, err := functions.CreateParameterizedAggregator(field.FuncName, args) if err != nil { return nil } // Wrap with WindowFunctionWrapper for compatibility return &WindowFunctionWrapper{aggFunc: aggFunc} } // hasMultipleTopLevelArgs returns true if the function call has more than one top-level argument func hasMultipleTopLevelArgs(funcCall string) bool { // Check if this is a function call with parentheses (starts with identifier followed by parentheses) start := strings.Index(funcCall, "(") end := strings.LastIndex(funcCall, ")") var params string var isDirectArgList bool // Only treat as function call if it starts with an identifier and has matching parentheses if start > 0 && end != -1 && end > start && end == len(funcCall)-1 { // Check if everything before the first '(' is a valid identifier (function name) funcName := strings.TrimSpace(funcCall[:start]) if isValidIdentifier(funcName) { // Function call format: func(args) - extract only the arguments inside parentheses params = funcCall[start+1 : end] isDirectArgList = false } else { // Direct argument list format: arg1, arg2 params = strings.TrimSpace(funcCall) isDirectArgList = true } } else { // Direct argument list format: arg1, arg2 params = strings.TrimSpace(funcCall) if params == "" { return false } isDirectArgList = true } params = strings.TrimSpace(params) if params == "" { return false } // For direct argument lists, special case: if the entire params is wrapped in parentheses // and has no top-level commas, it's a single complex argument if isDirectArgList && strings.HasPrefix(params, "(") && strings.HasSuffix(params, ")") { // Check if this is a complete parenthesized expression level := 0 for i, ch := range params { if ch == '(' { level++ } else if ch == ')' { level-- if level == 0 && i == len(params)-1 { // This is a single complete parenthesized expression return false } } } } level := 0 count := 0 inString := false stringChar := byte(0) for i := 0; i < len(params); i++ { ch := params[i] // Handle string literals if !inString && (ch == '\'' || ch == '"') { inString = true stringChar = ch continue } if inString && ch == stringChar { inString = false stringChar = 0 continue } // Skip processing if inside string if inString { continue } switch ch { case '(': level++ case ')': if level > 0 { level-- } case ',': if level == 0 { count++ } } } // If we found any commas at top level, we have multiple arguments return count > 0 } // isValidIdentifier checks if a string is a valid identifier (function name) func isValidIdentifier(s string) bool { if len(s) == 0 || len(s) > MaxFunctionNameLength { return false } // First character must be letter or underscore if !isValidIdentifierStart(s[0]) { return false } // Remaining characters must be letters, digits, or underscores for i := 1; i < len(s); i++ { if !isValidIdentifierChar(s[i]) { return false } } return true } // isValidIdentifierStart checks if a character can be used as the start of an identifier func isValidIdentifierStart(c byte) bool { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_' } // isValidIdentifierChar checks if a character can be used in an identifier func isValidIdentifierChar(c byte) bool { return isValidIdentifierStart(c) || (c >= '0' && c <= '9') } // parseFunctionCall parses a function call string and returns the arguments func (ega *EnhancedGroupAggregator) parseFunctionCall(funcCall string) ([]interface{}, error) { // Find the parentheses start := strings.Index(funcCall, "(") end := strings.LastIndex(funcCall, ")") if start == -1 || end == -1 { return nil, fmt.Errorf("invalid function call format: %s", funcCall) } // Extract parameters string paramsStr := strings.TrimSpace(funcCall[start+1 : end]) if paramsStr == "" { return []interface{}{}, nil } // Split parameters by comma paramStrs := strings.Split(paramsStr, ",") args := make([]interface{}, len(paramStrs)) for i, paramStr := range paramStrs { paramStr = strings.TrimSpace(paramStr) // Try to parse as number first if val, err := strconv.Atoi(paramStr); err == nil { args[i] = val } else if val, err := strconv.ParseFloat(paramStr, 64); err == nil { args[i] = val } else { // Treat as string (remove quotes if present) if (strings.HasPrefix(paramStr, "'") && strings.HasSuffix(paramStr, "'")) || (strings.HasPrefix(paramStr, "\"") && strings.HasSuffix(paramStr, "\"")) { args[i] = paramStr[1 : len(paramStr)-1] } else { args[i] = paramStr } } } return args, nil } // WindowFunctionWrapper wraps window functions to make them compatible with LegacyAggregatorFunction type WindowFunctionWrapper struct { aggFunc functions.AggregatorFunction } func (w *WindowFunctionWrapper) New() AggregatorFunction { return &WindowFunctionWrapper{aggFunc: w.aggFunc.New()} } func (w *WindowFunctionWrapper) Add(value interface{}) { w.aggFunc.Add(value) } func (w *WindowFunctionWrapper) Result() interface{} { return w.aggFunc.Result() } func (w *WindowFunctionWrapper) Reset() { w.aggFunc.Reset() } func (w *WindowFunctionWrapper) Clone() AggregatorFunction { return &WindowFunctionWrapper{aggFunc: w.aggFunc.Clone()} } // Interface compliance check var _ Aggregator = (*EnhancedGroupAggregator)(nil)