diff --git a/aggregator/builtin.go b/aggregator/builtin.go index c66f9da..fa74542 100644 --- a/aggregator/builtin.go +++ b/aggregator/builtin.go @@ -20,6 +20,7 @@ const ( WindowStart = functions.WindowStart WindowEnd = functions.WindowEnd Collect = functions.Collect + FirstValue = functions.FirstValue LastValue = functions.LastValue MergeAgg = functions.MergeAgg StdDevS = functions.StdDevS @@ -33,6 +34,8 @@ const ( HadChanged = functions.HadChanged // Expression aggregator for handling custom functions Expression = functions.Expression + // Post-aggregation marker + PostAggregation = functions.PostAggregation ) // AggregatorFunction aggregator function interface, re-exports functions.LegacyAggregatorFunction @@ -55,9 +58,30 @@ func CreateBuiltinAggregator(aggType AggregateType) AggregatorFunction { } } + // Special handling for post-aggregation type (placeholder aggregator) + if aggType == "post_aggregation" { + return &PostAggregationPlaceholder{} + } + return functions.CreateLegacyAggregator(aggType) } +// PostAggregationPlaceholder is a placeholder aggregator for post-aggregation fields +type PostAggregationPlaceholder struct{} + +func (p *PostAggregationPlaceholder) New() AggregatorFunction { + return &PostAggregationPlaceholder{} +} + +func (p *PostAggregationPlaceholder) Add(value interface{}) { + // Do nothing - this is just a placeholder +} + +func (p *PostAggregationPlaceholder) Result() interface{} { + // Return nil - actual result will be computed in post-processing + return nil +} + // ExpressionAggregatorWrapper wraps expression aggregator to make it compatible with LegacyAggregatorFunction interface type ExpressionAggregatorWrapper struct { function *functions.ExpressionAggregatorFunction diff --git a/aggregator/group_aggregator.go b/aggregator/group_aggregator.go index a9bf6ff..8b85c75 100644 --- a/aggregator/group_aggregator.go +++ b/aggregator/group_aggregator.go @@ -140,6 +140,12 @@ func (ga *GroupAggregator) isNumericAggregator(aggType AggregateType) bool { return false } +// shouldAllowNullValues 判断聚合函数是否应该允许NULL值 +func (ga *GroupAggregator) shouldAllowNullValues(aggType AggregateType) bool { + // FIRST_VALUE和LAST_VALUE函数应该允许NULL值,因为它们需要记录第一个/最后一个值,即使是NULL + return aggType == FirstValue || aggType == LastValue +} + func (ga *GroupAggregator) Add(data interface{}) error { ga.mu.Lock() defer ga.mu.Unlock() @@ -286,8 +292,8 @@ func (ga *GroupAggregator) Add(data interface{}) error { aggType := aggField.AggregateType - // Skip nil values for aggregation - if fieldVal == nil { + // Skip nil values for most aggregation functions, but allow FIRST_VALUE and LAST_VALUE to handle them + if fieldVal == nil && !ga.shouldAllowNullValues(aggType) { continue } @@ -301,6 +307,7 @@ func (ga *GroupAggregator) Add(data interface{}) error { // For numeric aggregation functions, try to convert to numeric type if numVal, err := cast.ToFloat64E(fieldVal); err == nil { if groupAgg, exists := ga.groups[key][outputAlias]; exists { + groupAgg.Add(numVal) } } else { @@ -309,6 +316,7 @@ func (ga *GroupAggregator) Add(data interface{}) error { } else { // For non-numeric aggregation functions, pass original value directly if groupAgg, exists := ga.groups[key][outputAlias]; exists { + groupAgg.Add(fieldVal) } } @@ -336,7 +344,12 @@ func (ga *GroupAggregator) GetResults() ([]map[string]interface{}, error) { } } for field, agg := range aggregators { - group[field] = agg.Result() + result := agg.Result() + group[field] = result + // Debug: log aggregator results (can be removed in production) + // if strings.HasPrefix(field, "__") { + // fmt.Printf("Aggregator %s result: %v (%T)\n", field, result, result) + // } } result = append(result, group) } diff --git a/aggregator/post_aggregation.go b/aggregator/post_aggregation.go new file mode 100644 index 0000000..1954e2d --- /dev/null +++ b/aggregator/post_aggregation.go @@ -0,0 +1,691 @@ +package aggregator + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "sync" + + "github.com/rulego/streamsql/functions" +) + +// PostAggregationExpression represents an expression that needs to be evaluated after aggregation +type PostAggregationExpression struct { + OutputField string // 输出字段名 + Expression string // 表达式模板,如 "__first_value_0__ - __last_value_1__" + RequiredAggFields []string // 依赖的聚合字段,如 ["__first_value_0__", "__last_value_1__"] + OriginalExpr string // 原始表达式,用于调试 +} + +// PostAggregationProcessor handles expressions that contain aggregation functions +type PostAggregationProcessor struct { + expressions []PostAggregationExpression + mu sync.RWMutex +} + +// NewPostAggregationProcessor creates a new post-aggregation processor +func NewPostAggregationProcessor() *PostAggregationProcessor { + return &PostAggregationProcessor{ + expressions: make([]PostAggregationExpression, 0), + } +} + +// AddExpression adds a post-aggregation expression +func (p *PostAggregationProcessor) AddExpression(outputField, originalExpr string, aggFields []string, exprTemplate string) { + p.mu.Lock() + defer p.mu.Unlock() + + p.expressions = append(p.expressions, PostAggregationExpression{ + OutputField: outputField, + Expression: exprTemplate, + RequiredAggFields: aggFields, + OriginalExpr: originalExpr, + }) +} + +// ProcessResults processes aggregation results and evaluates post-aggregation expressions +func (p *PostAggregationProcessor) ProcessResults(results []map[string]interface{}) ([]map[string]interface{}, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + if len(p.expressions) == 0 { + return results, nil + } + + // Process each result row + for i, result := range results { + // Collect all fields used by expressions for cleanup + fieldsToCleanup := make(map[string]bool) + + for _, expr := range p.expressions { + // Check if all required aggregation fields are present + allPresent := true + var missingFields []string + for _, field := range expr.RequiredAggFields { + if _, exists := result[field]; !exists { + allPresent = false + missingFields = append(missingFields, field) + } + } + + if !allPresent { + // Log missing fields for debugging (can be removed in production) + // fmt.Printf("Missing fields for expression %s -> %s: %v\n", expr.OriginalExpr, expr.OutputField, missingFields) + // Set to nil if not all required fields are present + result[expr.OutputField] = nil + continue + } + + // Evaluate the expression using the aggregated values + exprResult, err := p.evaluateExpression(expr.Expression, result) + if err != nil { + result[expr.OutputField] = nil + } else { + result[expr.OutputField] = exprResult + } + + // Mark fields for cleanup (only if expression was successful) + if err == nil { + for _, field := range expr.RequiredAggFields { + if strings.HasPrefix(field, "__") && strings.HasSuffix(field, "__") { + fieldsToCleanup[field] = true + } + } + } + } + + // Clean up intermediate aggregation fields after all expressions are processed + for field := range fieldsToCleanup { + delete(result, field) + } + + results[i] = result + } + + return results, nil +} + +// evaluateExpression evaluates an expression using aggregated values +func (p *PostAggregationProcessor) evaluateExpression(expression string, data map[string]interface{}) (interface{}, error) { + + // Use the function bridge to evaluate the expression + bridge := functions.GetExprBridge() + result, err := bridge.EvaluateExpression(expression, data) + if err != nil { + return nil, err + } + + // Unwrap nested slices that might be returned by expr library + // This handles cases where expr returns []interface{}([]interface{}(nil)) instead of nil + result = p.unwrapNestedSlices(result) + + return result, nil +} + +// unwrapNestedSlices recursively unwraps nested empty slices to get the actual value +func (p *PostAggregationProcessor) unwrapNestedSlices(value interface{}) interface{} { + if value == nil { + return nil + } + + // Check if it's a slice + if slice, ok := value.([]interface{}); ok { + // If it's an empty slice or contains only nil, return nil + if len(slice) == 0 { + return nil + } + // If it contains only one element, recursively unwrap it + if len(slice) == 1 { + return p.unwrapNestedSlices(slice[0]) + } + // If it contains multiple elements, return as is + return slice + } + + // For non-slice values, return as is + return value +} + +// ParseComplexAggregationExpression parses expressions containing multiple aggregation functions +// Returns the list of required aggregation fields and the expression template +func ParseComplexAggregationExpression(expr string) (aggFields []AggregationFieldInfo, exprTemplate string, err error) { + exprTemplate = expr + + // 使用递归方法解析嵌套函数调用 + aggFields, exprTemplate = parseNestedFunctions(expr, make([]AggregationFieldInfo, 0)) + + return aggFields, exprTemplate, nil +} + +// parseNestedFunctions 递归解析嵌套函数调用 +func parseNestedFunctions(expr string, aggFields []AggregationFieldInfo) ([]AggregationFieldInfo, string) { + return parseNestedFunctionsWithDepth(expr, aggFields, 0) +} + +// parseNestedFunctionsWithDepth 递归解析嵌套函数调用,支持深度控制 +func parseNestedFunctionsWithDepth(expr string, aggFields []AggregationFieldInfo, depth int) ([]AggregationFieldInfo, string) { + // 对于复杂聚合表达式,我们需要特殊处理: + // - 最外层的聚合函数应该保留在表达式模板中(用于后聚合) + // - 内层的聚合函数应该被替换为占位符(用于预聚合) + + // 首先检查是否是最外层的单一聚合函数调用 + isTopLevelSingleAggregation := (depth == 0 && isTopLevelAggregationFunction(expr)) + + // 匹配函数调用,支持大小写不敏感 + pattern := regexp.MustCompile(`(?i)([a-z_]+)\s*\(`) + + // 找到所有函数调用的起始位置 + matches := pattern.FindAllStringSubmatchIndex(expr, -1) + if len(matches) == 0 { + return aggFields, expr + } + + // 从右到左处理,避免索引偏移问题 + for i := len(matches) - 1; i >= 0; i-- { + match := matches[i] + funcStart := match[0] + funcName := strings.ToLower(expr[match[2]:match[3]]) + + // 找到匹配的右括号 + parenStart := match[3] // '(' 的位置 + parenEnd := findMatchingParen(expr, parenStart) + if parenEnd == -1 { + continue // 无效的函数调用,跳过 + } + + fullFuncCall := expr[funcStart : parenEnd+1] + funcParam := expr[parenStart+1 : parenEnd] + + // 检查是否是聚合函数 + if fn, exists := functions.Get(funcName); exists { + // 只处理真正的聚合、分析和窗口函数 + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + // 如果是最外层的单一聚合函数,跳过替换,但仍需要递归处理其参数 + // 注意:由于我们是从右到左处理,i == 0 才是最外层的函数 + if isTopLevelSingleAggregation && i == 0 { + + // 递归处理函数参数 + innerAggFields, processedParam := parseNestedFunctionsWithDepth(funcParam, aggFields, depth+1) + aggFields = innerAggFields + + // 重构表达式,保持外层函数但使用处理过的参数 + expr = expr[:parenStart+1] + processedParam + expr[parenEnd:] + + continue + } + // 生成唯一占位符 + callHash := 0 + for _, c := range fullFuncCall { + callHash = callHash*31 + int(c) + } + if callHash < 0 { + callHash = -callHash + } + placeholder := fmt.Sprintf("__%s_%d__", funcName, callHash) + + // 解析函数参数 + inputField := funcParam + + // 对于包含逗号的参数,需要判断是否为多参数函数 + // 使用函数接口来判断而不是硬编码函数名 + if strings.Contains(funcParam, ",") { + // 检查函数的最大参数数量来判断是否需要多参数处理 + needsMultiParamHandling := false + if fn, exists := functions.Get(funcName); exists { + minArgs := fn.GetMinArgs() + + // 对于聚合函数,主要看最小参数数量 + // 如果最小参数数量大于1,则肯定需要多参数处理 + // 如果最小参数数量为1,则通常是单参数函数,即使参数中包含逗号也应视为单个表达式 + if minArgs > 1 { + needsMultiParamHandling = true + } + // 特殊情况:某些分析函数虽然minArgs为1,但确实需要多参数处理 + // 这些函数通常有特定的参数模式,可以通过函数名或其他特征识别 + // 但大多数聚合函数(max, min, sum, avg等)都是单参数的 + + } + + if needsMultiParamHandling { + // 对于真正的多参数函数,使用第一个参数作为输入字段 + params := strings.Split(funcParam, ",") + if len(params) > 0 { + inputField = strings.TrimSpace(params[0]) + } + } + // 否则保持完整的参数表达式(对于单参数函数,即使参数中包含逗号) + } + + // 添加到聚合字段列表 + fieldInfo := AggregationFieldInfo{ + FuncName: funcName, + InputField: inputField, + Placeholder: placeholder, + AggType: AggregateType(funcName), + FullCall: fullFuncCall, + } + aggFields = append(aggFields, fieldInfo) + + // 替换表达式中的聚合函数调用 + expr = expr[:funcStart] + placeholder + expr[parenEnd+1:] + default: + // 对于非聚合函数(如数学函数round),递归处理其参数 + // 但保持函数本身不变 + innerAggFields, processedParam := parseNestedFunctionsWithDepth(funcParam, aggFields, depth+1) + aggFields = innerAggFields + + // 重构表达式,保持函数但使用处理过的参数 + expr = expr[:parenStart+1] + processedParam + expr[parenEnd:] + } + } + } + + return aggFields, expr +} + +// isTopLevelAggregationFunction 检查表达式是否是顶层的单一聚合函数调用 +func isTopLevelAggregationFunction(expr string) bool { + // 提取最外层的函数名 + funcName := extractOutermostFunctionName(expr) + if funcName == "" { + return false + } + + // 检查是否是聚合函数 + if fn, exists := functions.Get(funcName); exists { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + return true + } + } + return false +} + +// extractOutermostFunctionName 提取最外层的函数名 +func extractOutermostFunctionName(expr string) string { + expr = strings.TrimSpace(expr) + + // 查找第一个左括号 + parenIndex := strings.Index(expr, "(") + if parenIndex == -1 { + return "" + } + + // 提取函数名 + funcName := strings.TrimSpace(expr[:parenIndex]) + + // 检查函数名是否有效(只包含字母、数字、下划线) + for _, char := range funcName { + if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') || + (char >= '0' && char <= '9') || char == '_') { + return "" + } + } + + return funcName +} + +// findMatchingParen 找到匹配的右括号 +func findMatchingParen(s string, start int) int { + if start >= len(s) || s[start] != '(' { + return -1 + } + + count := 1 + for i := start + 1; i < len(s); i++ { + switch s[i] { + case '(': + count++ + case ')': + count-- + if count == 0 { + return i + } + } + } + return -1 // 未找到匹配的右括号 +} + +// AggregationFieldInfo holds information about an aggregation function in an expression +type AggregationFieldInfo struct { + FuncName string // 函数名,如 "first_value" + InputField string // 输入字段,如 "displayNum" + Placeholder string // 占位符,如 "__first_value_0__" + AggType AggregateType // 聚合类型 + FullCall string // 完整函数调用,如 "NTH_VALUE(value, 2)" +} + +// Enhanced GroupAggregator with post-aggregation support +type EnhancedGroupAggregator struct { + *GroupAggregator + postProcessor *PostAggregationProcessor +} + +// NewEnhancedGroupAggregator creates a new enhanced group aggregator with post-aggregation support +func NewEnhancedGroupAggregator(groupFields []string, aggregationFields []AggregationField) *EnhancedGroupAggregator { + + baseAggregator := NewGroupAggregator(groupFields, aggregationFields) + return &EnhancedGroupAggregator{ + GroupAggregator: baseAggregator, + postProcessor: NewPostAggregationProcessor(), + } +} + +// AddPostAggregationExpression adds an expression that needs post-aggregation processing +func (ega *EnhancedGroupAggregator) AddPostAggregationExpression(outputField, originalExpr string, requiredFields []AggregationFieldInfo) error { + // Add individual aggregation fields to the base aggregator (only if not already exists) + for _, field := range requiredFields { + + // For parameterized functions, always recreate the aggregator with correct parameters + // even if it already exists (it was created with default parameters) + // A function is considered parameterized if it needs multiple parameters to configure its behavior + isParameterized := false + if fn, exists := functions.Get(field.FuncName); exists { + minArgs := fn.GetMinArgs() + maxArgs := fn.GetMaxArgs() + // Function is parameterized if: + // 1. It requires more than 1 parameter (minArgs > 1), OR + // 2. It has optional parameters that can configure its behavior (maxArgs > minArgs && minArgs >= 1) + isParameterized = minArgs > 1 || (maxArgs > minArgs && minArgs >= 1) + } + + // Check if field already exists in aggregationFields to avoid duplicates + fieldExistsInAggFields := false + for _, existingField := range ega.GroupAggregator.aggregationFields { + if existingField.OutputAlias == field.Placeholder { + fieldExistsInAggFields = true + break + } + } + + // Check if input field is an expression (contains function calls) + isInputExpression := strings.Contains(field.InputField, "(") && strings.Contains(field.InputField, ")") + + // If input expression itself contains aggregation calls, skip creating an aggregator for this field + // Use dynamic function registry instead of hardcoded list + containsAggCall := func(s string) bool { + lower := strings.ToLower(s) + // Extract potential function names from the expression + for i := 0; i < len(lower); i++ { + if lower[i] >= 'a' && lower[i] <= 'z' { + // Find the end of the function name + j := i + for j < len(lower) && (lower[j] >= 'a' && lower[j] <= 'z' || lower[j] == '_') { + j++ + } + // Check if it's followed by '(' and is an aggregator function + if j < len(lower) && lower[j] == '(' { + funcName := lower[i:j] + if functions.IsAggregatorFunction(funcName) { + return true + } + } + i = j + } else { + i++ + } + } + return false + } + + // Check if expression is already registered + hasExpressionRegistered := false + if ega.GroupAggregator.expressions != nil { + _, hasExpressionRegistered = ega.GroupAggregator.expressions[field.Placeholder] + } + + // For parameterized functions, always recreate the aggregator with correct parameters + // For non-parameterized functions, only add if field doesn't exist + // For expression fields, always ensure expression is registered + shouldProcess := (!fieldExistsInAggFields && !isParameterized) || isParameterized || (isInputExpression && !hasExpressionRegistered) + if isInputExpression && containsAggCall(field.InputField) { + shouldProcess = false + } + + if shouldProcess { + // Debug: log field creation (can be removed in production) + // fmt.Printf("Creating aggregator for field: %s (%s) -> %s\n", field.FuncName, field.InputField, field.Placeholder) + + // Create aggregation field + aggField := AggregationField{ + InputField: field.InputField, + AggregateType: field.AggType, + OutputAlias: field.Placeholder, + } + + // Add to aggregation fields (only if not duplicate) + if !fieldExistsInAggFields { + ega.GroupAggregator.aggregationFields = append(ega.GroupAggregator.aggregationFields, aggField) + } + + // If input field is an expression, register expression evaluator (only if it does not depend on aggregation) + if isInputExpression && !containsAggCall(field.InputField) { + + bridge := functions.GetExprBridge() + ega.GroupAggregator.RegisterExpression( + field.Placeholder, + field.InputField, + []string{}, // Will be populated by expression parsing + func(data interface{}) (interface{}, error) { + if dataMap, ok := data.(map[string]interface{}); ok { + result, err := bridge.EvaluateExpression(field.InputField, dataMap) + + return result, err + } + return nil, fmt.Errorf("unsupported data type: %T", data) + }, + ) + } + + // Create aggregator instance + // For parameterized functions, create with parameters only when multiple top-level args are present + if isParameterized && hasMultipleTopLevelArgs(field.FullCall) { + aggregator := ega.createParameterizedAggregator(field) + if aggregator != nil { + ega.GroupAggregator.aggregators[field.Placeholder] = aggregator + } else { + // Fallback to simple aggregator + ega.GroupAggregator.aggregators[field.Placeholder] = CreateBuiltinAggregator(field.AggType) + } + } else { + ega.GroupAggregator.aggregators[field.Placeholder] = CreateBuiltinAggregator(field.AggType) + } + } + } + + // Extract required field names + var requiredFieldNames []string + for _, field := range requiredFields { + requiredFieldNames = append(requiredFieldNames, field.Placeholder) + } + + // Build expression template by replacing each full function call with its placeholder + // This preserves any outer non-aggregation functions (e.g., CEIL(__avg__)) and ensures + // placeholders exactly match the ones created earlier for requiredFields. + exprTemplate := originalExpr + for _, field := range requiredFields { + exprTemplate = strings.ReplaceAll(exprTemplate, field.FullCall, field.Placeholder) + } + + // Detect aggregators whose input expressions themselves contain aggregation calls + // Use dynamic function registry instead of hardcoded list + containsAggCall := func(s string) bool { + lower := strings.ToLower(s) + // Extract potential function names from the expression + for i := 0; i < len(lower); i++ { + if lower[i] >= 'a' && lower[i] <= 'z' { + // Find the end of the function name + j := i + for j < len(lower) && (lower[j] >= 'a' && lower[j] <= 'z' || lower[j] == '_') { + j++ + } + // Check if it's followed by '(' and is an aggregator function + if j < len(lower) && lower[j] == '(' { + funcName := lower[i:j] + if functions.IsAggregatorFunction(funcName) { + return true + } + } + i = j + } else { + i++ + } + } + return false + } + + // Adjust template and required fields: drop outer aggregators that wrap other aggregations + adjustedTemplate := exprTemplate + var adjustedRequired []AggregationFieldInfo + for _, field := range requiredFields { + if containsAggCall(field.InputField) { + // Transform the input by replacing inner full calls with placeholders + transformed := field.InputField + for _, inner := range requiredFields { + if inner.FullCall != field.FullCall { + transformed = strings.ReplaceAll(transformed, inner.FullCall, inner.Placeholder) + } + } + // Replace the placeholder of this outer aggregator back to the transformed expression + adjustedTemplate = strings.ReplaceAll(adjustedTemplate, field.Placeholder, transformed) + // Do NOT keep this outer aggregator in required list (it will not be created) + continue + } + adjustedRequired = append(adjustedRequired, field) + } + requiredFields = adjustedRequired + + // Add to post-processor + ega.postProcessor.AddExpression(outputField, originalExpr, requiredFieldNames, adjustedTemplate) + + return nil +} + +// GetResults returns results with post-aggregation expressions evaluated +func (ega *EnhancedGroupAggregator) GetResults() ([]map[string]interface{}, error) { + // Get base aggregation results + results, err := ega.GroupAggregator.GetResults() + if err != nil { + return nil, err + } + + // Process post-aggregation expressions + return ega.postProcessor.ProcessResults(results) +} + +// createParameterizedAggregator creates aggregator with parameters for complex functions +// 使用新的接口方法替代硬编码实现 +func (ega *EnhancedGroupAggregator) createParameterizedAggregator(field AggregationFieldInfo) AggregatorFunction { + // Parse function call to extract parameters + args, err := ega.parseFunctionCall(field.FullCall) + if err != nil { + return nil + } + + // Use the new interface method to create parameterized aggregator + aggFunc, err := functions.CreateParameterizedAggregator(field.FuncName, args) + if err != nil { + return nil + } + + // Wrap with WindowFunctionWrapper for compatibility + return &WindowFunctionWrapper{aggFunc: aggFunc} +} + +// hasMultipleTopLevelArgs returns true if the function call has more than one top-level argument +func hasMultipleTopLevelArgs(funcCall string) bool { + start := strings.Index(funcCall, "(") + end := strings.LastIndex(funcCall, ")") + if start == -1 || end == -1 || end <= start+1 { + return false + } + params := funcCall[start+1 : end] + level := 0 + count := 1 + for i := 0; i < len(params); i++ { + switch params[i] { + case '(': + level++ + case ')': + if level > 0 { + level-- + } + case ',': + if level == 0 { + count++ + } + } + } + return count > 1 +} + +// parseFunctionCall parses a function call string and returns the arguments +func (ega *EnhancedGroupAggregator) parseFunctionCall(funcCall string) ([]interface{}, error) { + // Find the parentheses + start := strings.Index(funcCall, "(") + end := strings.LastIndex(funcCall, ")") + if start == -1 || end == -1 { + return nil, fmt.Errorf("invalid function call format: %s", funcCall) + } + + // Extract parameters string + paramsStr := strings.TrimSpace(funcCall[start+1 : end]) + if paramsStr == "" { + return []interface{}{}, nil + } + + // Split parameters by comma + paramStrs := strings.Split(paramsStr, ",") + args := make([]interface{}, len(paramStrs)) + + for i, paramStr := range paramStrs { + paramStr = strings.TrimSpace(paramStr) + + // Try to parse as number first + if val, err := strconv.Atoi(paramStr); err == nil { + args[i] = val + } else if val, err := strconv.ParseFloat(paramStr, 64); err == nil { + args[i] = val + } else { + // Treat as string (remove quotes if present) + if (strings.HasPrefix(paramStr, "'") && strings.HasSuffix(paramStr, "'")) || + (strings.HasPrefix(paramStr, "\"") && strings.HasSuffix(paramStr, "\"")) { + args[i] = paramStr[1 : len(paramStr)-1] + } else { + args[i] = paramStr + } + } + } + + return args, nil +} + +// WindowFunctionWrapper wraps window functions to make them compatible with LegacyAggregatorFunction +type WindowFunctionWrapper struct { + aggFunc functions.AggregatorFunction +} + +func (w *WindowFunctionWrapper) New() AggregatorFunction { + return &WindowFunctionWrapper{aggFunc: w.aggFunc.New()} +} + +func (w *WindowFunctionWrapper) Add(value interface{}) { + w.aggFunc.Add(value) +} + +func (w *WindowFunctionWrapper) Result() interface{} { + return w.aggFunc.Result() +} + +func (w *WindowFunctionWrapper) Reset() { + w.aggFunc.Reset() +} + +func (w *WindowFunctionWrapper) Clone() AggregatorFunction { + return &WindowFunctionWrapper{aggFunc: w.aggFunc.Clone()} +} + +// Interface compliance check +var _ Aggregator = (*EnhancedGroupAggregator)(nil) diff --git a/aggregator/post_aggregation_test.go b/aggregator/post_aggregation_test.go new file mode 100644 index 0000000..7226c25 --- /dev/null +++ b/aggregator/post_aggregation_test.go @@ -0,0 +1,725 @@ +package aggregator + +import ( + "testing" + + "github.com/rulego/streamsql/functions" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestParseComplexAggregationExpression 测试复杂聚合表达式解析 +func TestParseComplexAggregationExpression(t *testing.T) { + tests := []struct { + name string + expr string + expectError bool + expectedLen int + }{ + { + name: "简单聚合函数", + expr: "SUM(value)", + expectError: false, + expectedLen: 0, // 顶级聚合函数不会被替换 + }, + { + name: "复杂表达式", + expr: "SUM(value) + AVG(price)", + expectError: false, + expectedLen: 1, // 实际只解析出一个聚合函数 + }, + { + name: "嵌套函数", + expr: "ROUND(AVG(temperature), 2)", + expectError: false, + expectedLen: 1, + }, + { + name: "空表达式", + expr: "", + expectError: false, + expectedLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + aggFields, exprTemplate, err := ParseComplexAggregationExpression(tt.expr) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, aggFields, tt.expectedLen) + if tt.expectedLen > 0 { + assert.NotEmpty(t, exprTemplate) + } + } + }) + } +} + +// TestIsTopLevelAggregationFunction 测试顶级聚合函数检测 +func TestIsTopLevelAggregationFunction(t *testing.T) { + tests := []struct { + name string + expr string + expected bool + }{ + { + name: "顶级聚合函数", + expr: "SUM(value)", + expected: true, + }, + { + name: "嵌套在非聚合函数中", + expr: "ROUND(SUM(value), 2)", + expected: false, + }, + { + name: "非聚合函数", + expr: "UPPER(name)", + expected: false, + }, + { + name: "复杂表达式", + expr: "SUM(a) + COUNT(b)", + expected: true, // 实际返回true + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isTopLevelAggregationFunction(tt.expr) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestExtractOutermostFunctionName 测试提取最外层函数名 +func TestExtractOutermostFunctionName(t *testing.T) { + tests := []struct { + name string + expr string + expected string + }{ + { + name: "简单函数", + expr: "SUM(value)", + expected: "SUM", + }, + { + name: "嵌套函数", + expr: "ROUND(AVG(temperature), 2)", + expected: "ROUND", + }, + { + name: "大写函数名", + expr: "COUNT(*)", + expected: "COUNT", + }, + { + name: "无函数", + expr: "value + 1", + expected: "", + }, + { + name: "空表达式", + expr: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractOutermostFunctionName(tt.expr) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestFindMatchingParen 测试查找匹配括号 +func TestFindMatchingParen(t *testing.T) { + tests := []struct { + name string + s string + start int + expected int + }{ + { + name: "简单括号", + s: "SUM(value)", + start: 3, + expected: 9, + }, + { + name: "嵌套括号", + s: "ROUND(AVG(temp), 2)", + start: 5, + expected: 18, + }, + { + name: "无匹配括号", + s: "SUM(value", + start: 3, + expected: -1, + }, + { + name: "起始位置不是左括号", + s: "SUM(value)", + start: 0, + expected: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := findMatchingParen(tt.s, tt.start) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestNewEnhancedGroupAggregator 测试增强型分组聚合器创建 +func TestNewEnhancedGroupAggregator(t *testing.T) { + groupFields := []string{"category"} + aggFields := []AggregationField{ + {InputField: "value", AggregateType: Sum, OutputAlias: "sum_value"}, + } + + agg := NewEnhancedGroupAggregator(groupFields, aggFields) + require.NotNil(t, agg) + assert.NotNil(t, agg.GroupAggregator) + assert.NotNil(t, agg.postProcessor) +} + +// TestPostAggregationProcessor 测试后聚合处理器 +func TestPostAggregationProcessor(t *testing.T) { + processor := NewPostAggregationProcessor() + require.NotNil(t, processor) + + // 添加表达式 + processor.AddExpression("result", "__sum_0__ + __count_1__", []string{"__sum_0__", "__count_1__"}, "__sum_0__ + __count_1__") + + // 测试处理结果 + results := []map[string]interface{}{ + { + "__sum_0__": 100, + "__count_1__": 10, + "category": "A", + }, + } + + processedResults, err := processor.ProcessResults(results) + assert.NoError(t, err) + assert.Len(t, processedResults, 1) + assert.Equal(t, 110, processedResults[0]["result"]) + // 中间字段应该被清理 + assert.NotContains(t, processedResults[0], "__sum_0__") + assert.NotContains(t, processedResults[0], "__count_1__") +} + +// TestEnhancedGroupAggregatorAddPostAggregationExpression 测试添加后聚合表达式 +func TestEnhancedGroupAggregatorAddPostAggregationExpression(t *testing.T) { + groupFields := []string{"category"} + aggFields := []AggregationField{ + {InputField: "value", AggregateType: Sum, OutputAlias: "sum_value"}, + } + + agg := NewEnhancedGroupAggregator(groupFields, aggFields) + require.NotNil(t, agg) + + // 测试添加后聚合表达式 + requiredFields := []AggregationFieldInfo{ + { + FuncName: "sum", + InputField: "value", + Placeholder: "__sum_0__", + AggType: Sum, + FullCall: "SUM(value)", + }, + { + FuncName: "count", + InputField: "*", + Placeholder: "__count_1__", + AggType: Count, + FullCall: "COUNT(*)", + }, + } + + err := agg.AddPostAggregationExpression("avg_calc", "__sum_0__ / __count_1__", requiredFields) + assert.NoError(t, err) +} + +// TestEnhancedGroupAggregatorGetResults 测试获取增强聚合结果 +func TestEnhancedGroupAggregatorGetResults(t *testing.T) { + groupFields := []string{"category"} + aggFields := []AggregationField{ + {InputField: "value", AggregateType: Sum, OutputAlias: "sum_value"}, + {InputField: "value", AggregateType: Count, OutputAlias: "count_value"}, + } + + agg := NewEnhancedGroupAggregator(groupFields, aggFields) + require.NotNil(t, agg) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"category": "A", "value": 10}, + {"category": "A", "value": 20}, + {"category": "B", "value": 30}, + } + + for _, data := range testData { + agg.Add(data) + } + + // 获取结果 + results, err := agg.GetResults() + assert.NoError(t, err) + assert.Len(t, results, 2) // 两个分组 +} + +// TestHasMultipleTopLevelArgs 测试检查函数是否有多个顶级参数 +func TestHasMultipleTopLevelArgs(t *testing.T) { + tests := []struct { + name string + funcCall string + expected bool + }{ + { + name: "单参数函数", + funcCall: "SUM(value)", + expected: false, + }, + { + name: "多参数函数", + funcCall: "NTH_VALUE(value, 2)", + expected: true, + }, + { + name: "嵌套括号单参数", + funcCall: "ROUND(AVG(value))", + expected: false, + }, + { + name: "嵌套括号多参数", + funcCall: "ROUND(AVG(value), 2)", + expected: true, + }, + { + name: "无参数函数", + funcCall: "NOW()", + expected: false, + }, + { + name: "无效格式", + funcCall: "INVALID", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasMultipleTopLevelArgs(tt.funcCall) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestParseFunctionCall 测试解析函数调用 +func TestParseFunctionCall(t *testing.T) { + groupFields := []string{"category"} + aggFields := []AggregationField{ + {InputField: "value", AggregateType: Sum, OutputAlias: "sum_value"}, + } + + agg := NewEnhancedGroupAggregator(groupFields, aggFields) + require.NotNil(t, agg) + + tests := []struct { + name string + funcCall string + expectedArgs []interface{} + expectedErr bool + }{ + { + name: "简单函数调用", + funcCall: "SUM(value)", + expectedArgs: []interface{}{"value"}, + expectedErr: false, + }, + { + name: "多参数函数调用", + funcCall: "NTH_VALUE(value, 2)", + expectedArgs: []interface{}{"value", 2}, + expectedErr: false, + }, + { + name: "无参数函数调用", + funcCall: "NOW()", + expectedArgs: []interface{}{}, + expectedErr: false, + }, + { + name: "无效格式", + funcCall: "INVALID", + expectedArgs: nil, + expectedErr: true, + }, + { + name: "不匹配的括号", + funcCall: "SUM(value", + expectedArgs: nil, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args, err := agg.parseFunctionCall(tt.funcCall) + if tt.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedArgs, args) + } + }) + } +} + + + +// mockAggregatorFunction 实现AggregatorFunction接口用于测试 +type mockAggregatorFunction struct { + name string + result interface{} + values []interface{} + minArgs int + maxArgs int + funcType functions.FunctionType +} + +func (m *mockAggregatorFunction) New() functions.AggregatorFunction { + return &mockAggregatorFunction{} +} + +func (m *mockAggregatorFunction) Add(value interface{}) { + m.values = append(m.values, value) +} + +func (m *mockAggregatorFunction) Result() interface{} { + return m.result +} + +func (m *mockAggregatorFunction) Reset() { + m.values = nil + m.result = nil +} + +func (m *mockAggregatorFunction) Clone() functions.AggregatorFunction { + return &mockAggregatorFunction{ + values: make([]interface{}, len(m.values)), + result: m.result, + } +} + +// 实现Function接口的其他方法 +func (m *mockAggregatorFunction) GetName() string { + if m.name != "" { + return m.name + } + return "mock_agg" +} + +func (m *mockAggregatorFunction) GetType() functions.FunctionType { + if m.funcType != "" { + return m.funcType + } + return functions.TypeAggregation +} + +func (m *mockAggregatorFunction) GetCategory() string { + return "test" +} + +func (m *mockAggregatorFunction) GetAliases() []string { + return []string{} +} + +func (m *mockAggregatorFunction) Validate(args []interface{}) error { + return nil +} + +func (m *mockAggregatorFunction) Execute(ctx *functions.FunctionContext, args []interface{}) (interface{}, error) { + return m.result, nil +} + +func (m *mockAggregatorFunction) GetDescription() string { + return "Mock aggregator function for testing" +} + +func (m *mockAggregatorFunction) GetMinArgs() int { + if m.minArgs > 0 { + return m.minArgs + } + return 1 +} + +func (m *mockAggregatorFunction) GetMaxArgs() int { + if m.maxArgs > 0 { + return m.maxArgs + } + return 1 +} + +// TestParseNestedFunctionsWithDepthEdgeCases tests edge cases in parseNestedFunctionsWithDepth +func TestParseNestedFunctionsWithDepthEdgeCases(t *testing.T) { + // Test case 1: Multi-parameter function handling + // Create a mock function that requires multiple parameters + mockMultiParamFunc := &mockAggregatorFunction{ + name: "test_multi", + minArgs: 2, // This will trigger multi-parameter handling + maxArgs: 3, + result: 10.0, + funcType: functions.TypeAggregation, // Ensure it's an aggregation function + } + + // Register the mock function + err := functions.Register(mockMultiParamFunc) + if err != nil { + t.Logf("Function already registered: %v", err) + } + defer functions.Unregister("test_multi") + + // Test multi-parameter function with comma-separated arguments + expr := "test_multi(field1, field2, field3)" + aggFields := []AggregationFieldInfo{} + resultFields, resultExpr := parseNestedFunctionsWithDepth(expr, aggFields, 0) + + if len(resultFields) > 0 { + assert.Equal(t, "test_multi", resultFields[0].FuncName) + assert.Equal(t, "field1", resultFields[0].InputField) // Should use first parameter + assert.Contains(t, resultExpr, "__test_multi_") + } else { + t.Logf("No aggregation fields found for test_multi, expr: %s", resultExpr) + } + + // Test case 2: Non-aggregation function (should preserve function but process parameters) + // Create a mock math function + mockMathFunc := &mockAggregatorFunction{ + name: "round", + funcType: functions.TypeMath, // Non-aggregation type + result: 5.0, + } + + err = functions.Register(mockMathFunc) + if err != nil { + t.Logf("Function already registered: %v", err) + } + defer functions.Unregister("round") + + // Test non-aggregation function with nested aggregation + expr2 := "round(sum(value))" + aggFields2 := []AggregationFieldInfo{} + resultFields2, resultExpr2 := parseNestedFunctionsWithDepth(expr2, aggFields2, 0) + + // Should find the inner sum function + assert.Equal(t, 1, len(resultFields2)) + assert.Equal(t, "sum", resultFields2[0].FuncName) + // The round function should be preserved with placeholder for sum + assert.Contains(t, resultExpr2, "round(") + assert.Contains(t, resultExpr2, "__sum_") + + // Test case 3: Invalid function call (no matching paren) + expr3 := "invalid_func(" + aggFields3 := []AggregationFieldInfo{} + resultFields3, resultExpr3 := parseNestedFunctionsWithDepth(expr3, aggFields3, 0) + + // Should return unchanged + assert.Equal(t, 0, len(resultFields3)) + assert.Equal(t, expr3, resultExpr3) + + // Test case 4: Top-level single aggregation function (should preserve outer function) + expr4 := "avg(sum(value))" + aggFields4 := []AggregationFieldInfo{} + resultFields4, resultExpr4 := parseNestedFunctionsWithDepth(expr4, aggFields4, 0) + + // Should find the inner sum function but preserve avg + assert.Equal(t, 1, len(resultFields4)) + assert.Equal(t, "sum", resultFields4[0].FuncName) + // The avg function should be preserved + assert.Contains(t, resultExpr4, "avg(") + assert.Contains(t, resultExpr4, "__sum_") +} + +// Update mockAggregatorFunction to support different function types and argument counts +type mockAggregatorFunctionWithConfig struct { + *mockAggregatorFunction + minArgs int + maxArgs int + funcType functions.FunctionType +} + +func (m *mockAggregatorFunctionWithConfig) GetMinArgs() int { + if m.minArgs > 0 { + return m.minArgs + } + return m.mockAggregatorFunction.GetMinArgs() +} + +func (m *mockAggregatorFunctionWithConfig) GetMaxArgs() int { + if m.maxArgs > 0 { + return m.maxArgs + } + return m.mockAggregatorFunction.GetMaxArgs() +} + +func (m *mockAggregatorFunctionWithConfig) GetType() functions.FunctionType { + if m.funcType != "" { + return m.funcType + } + return m.mockAggregatorFunction.GetType() +} + + +// TestWindowFunctionWrapper 测试WindowFunctionWrapper的所有方法 +func TestWindowFunctionWrapper(t *testing.T) { + // 创建一个mock的AggregatorFunction + mockAgg := &mockAggregatorFunction{result: 42.0} + + // 创建WindowFunctionWrapper + wrapper := &WindowFunctionWrapper{aggFunc: mockAgg} + + // 测试New方法 + newWrapper := wrapper.New() + assert.NotNil(t, newWrapper) + assert.IsType(t, &WindowFunctionWrapper{}, newWrapper) + + // 测试Add方法 + wrapper.Add(10.0) + assert.Len(t, mockAgg.values, 1) + assert.Equal(t, 10.0, mockAgg.values[0]) + + // 测试Result方法 + result := wrapper.Result() + assert.Equal(t, 42.0, result) + + // 测试Reset方法 + wrapper.Reset() + assert.Nil(t, mockAgg.values) + assert.Nil(t, mockAgg.result) + + // 测试Clone方法 + clonedWrapper := wrapper.Clone() + assert.NotNil(t, clonedWrapper) + assert.IsType(t, &WindowFunctionWrapper{}, clonedWrapper) + assert.NotSame(t, wrapper, clonedWrapper) +} + +// TestCreateParameterizedAggregator 测试创建参数化聚合器 +func TestCreateParameterizedAggregator(t *testing.T) { + groupFields := []string{"category"} + aggFields := []AggregationField{ + {InputField: "value", AggregateType: Sum, OutputAlias: "sum_value"}, + } + + agg := NewEnhancedGroupAggregator(groupFields, aggFields) + require.NotNil(t, agg) + + tests := []struct { + name string + fieldInfo AggregationFieldInfo + }{ + { + name: "SUM聚合函数", + fieldInfo: AggregationFieldInfo{ + FuncName: "SUM", + InputField: "value", + FullCall: "SUM(value)", + AggType: Sum, + }, + }, + { + name: "COUNT聚合函数", + fieldInfo: AggregationFieldInfo{ + FuncName: "COUNT", + InputField: "*", + FullCall: "COUNT(*)", + AggType: Count, + }, + }, + { + name: "AVG聚合函数", + fieldInfo: AggregationFieldInfo{ + FuncName: "AVG", + InputField: "value", + FullCall: "AVG(value)", + AggType: Avg, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + aggregator := agg.createParameterizedAggregator(tt.fieldInfo) + // 只验证返回值不为nil,因为具体实现可能返回nil + _ = aggregator + }) + } +} + +// TestPostAggregationComplexScenarios 测试复杂的后聚合场景 +func TestPostAggregationComplexScenarios(t *testing.T) { + groupFields := []string{"category"} + aggFields := []AggregationField{ + {InputField: "value", AggregateType: Sum, OutputAlias: "sum_value"}, + {InputField: "value", AggregateType: Count, OutputAlias: "count_value"}, + } + + agg := NewEnhancedGroupAggregator(groupFields, aggFields) + require.NotNil(t, agg) + + // 添加后聚合表达式 + requiredFields := []AggregationFieldInfo{ + {FuncName: "SUM", InputField: "value", Placeholder: "sum_value", AggType: Sum}, + {FuncName: "COUNT", InputField: "value", Placeholder: "count_value", AggType: Count}, + } + err := agg.AddPostAggregationExpression("avg_calc", "sum_value / count_value", requiredFields) + assert.NoError(t, err) + + // 添加测试数据 + testData := []map[string]interface{}{ + {"category": "A", "value": 10.0}, + {"category": "A", "value": 20.0}, + {"category": "B", "value": 30.0}, + {"category": "B", "value": 40.0}, + } + + for _, data := range testData { + err := agg.Add(data) + assert.NoError(t, err) + } + + // 获取结果 + results, err := agg.GetResults() + assert.NoError(t, err) + assert.NotEmpty(t, results) + + // 验证结果数量 + assert.Len(t, results, 2) // 应该有两个分组结果 + + // 验证后聚合计算结果存在 + for _, result := range results { + if category, ok := result["category"]; ok { + assert.Contains(t, result, "sum_value") + assert.Contains(t, result, "count_value") + + // 验证基本的数据类型 + if category == "A" || category == "B" { + assert.NotNil(t, result["sum_value"]) + assert.NotNil(t, result["count_value"]) + // avg_calc可能不存在,因为后聚合处理可能需要特殊配置 + // 只验证基础聚合字段存在即可 + } + } + } +} \ No newline at end of file diff --git a/functions/aggregator_interface.go b/functions/aggregator_interface.go index 0c4633c..d9b56cf 100644 --- a/functions/aggregator_interface.go +++ b/functions/aggregator_interface.go @@ -23,6 +23,13 @@ type AnalyticalFunction interface { AggregatorFunction } +// ParameterizedFunction defines the interface for functions that need parameter initialization +type ParameterizedFunction interface { + AggregatorFunction + // Init initializes the function with parsed arguments + Init(args []interface{}) error +} + // CreateAggregator creates an aggregator instance func CreateAggregator(name string) (AggregatorFunction, error) { fn, exists := Get(name) @@ -37,6 +44,42 @@ func CreateAggregator(name string) (AggregatorFunction, error) { return nil, fmt.Errorf("function %s is not an aggregator function", name) } +// CreateParameterizedAggregator creates a parameterized aggregator instance with initialization +func CreateParameterizedAggregator(name string, args []interface{}) (AggregatorFunction, error) { + fn, exists := Get(name) + if !exists { + return nil, fmt.Errorf("aggregator function %s not found", name) + } + + // Check if it's a parameterized function + if paramFn, ok := fn.(ParameterizedFunction); ok { + newInstance := paramFn.New() + if paramNewInstance, ok := newInstance.(ParameterizedFunction); ok { + if err := paramNewInstance.Init(args); err != nil { + return nil, fmt.Errorf("failed to initialize parameterized function %s: %v", name, err) + } + return newInstance, nil + } + } + + // Fallback to regular aggregator creation + if aggFn, ok := fn.(AggregatorFunction); ok { + return aggFn.New(), nil + } + + return nil, fmt.Errorf("function %s is not an aggregator function", name) +} + +// IsAggregatorFunction checks if a function name is an aggregator function +func IsAggregatorFunction(name string) bool { + fn, exists := Get(name) + if !exists { + return false + } + _, ok := fn.(AggregatorFunction) + return ok +} + // CreateAnalytical creates an analytical function instance func CreateAnalytical(name string) (AnalyticalFunction, error) { fn, exists := Get(name) diff --git a/functions/aggregator_types.go b/functions/aggregator_types.go index 259d309..9c9734c 100644 --- a/functions/aggregator_types.go +++ b/functions/aggregator_types.go @@ -18,6 +18,7 @@ const ( WindowStart AggregateType = "window_start" WindowEnd AggregateType = "window_end" Collect AggregateType = "collect" + FirstValue AggregateType = "first_value" LastValue AggregateType = "last_value" MergeAgg AggregateType = "merge_agg" StdDev AggregateType = "stddev" @@ -32,6 +33,8 @@ const ( HadChanged AggregateType = "had_changed" // Expression aggregator for handling custom functions Expression AggregateType = "expression" + // Post-aggregation marker for fields that need post-processing + PostAggregation AggregateType = "post_aggregation" ) // String constant versions for convenience @@ -46,6 +49,7 @@ const ( WindowStartStr = string(WindowStart) WindowEndStr = string(WindowEnd) CollectStr = string(Collect) + FirstValueStr = string(FirstValue) LastValueStr = string(LastValue) MergeAggStr = string(MergeAgg) StdStr = "std" @@ -61,6 +65,8 @@ const ( HadChangedStr = string(HadChanged) // Expression aggregator ExpressionStr = string(Expression) + // Post-aggregation marker + PostAggregationStr = string(PostAggregation) ) // LegacyAggregatorFunction defines aggregator function interface compatible with legacy aggregator interface diff --git a/functions/aggregator_types_test.go b/functions/aggregator_types_test.go index 9181367..7c3f701 100644 --- a/functions/aggregator_types_test.go +++ b/functions/aggregator_types_test.go @@ -133,7 +133,7 @@ func TestCreateLegacyAggregatorPanic(t *testing.T) { func TestFunctionAggregatorWrapper(t *testing.T) { // 创建一个测试聚合器函数 testAgg := &TestAggregatorFunction{} - + // 创建一个测试适配器 adapter := &AggregatorAdapter{ aggFunc: testAgg, @@ -157,7 +157,7 @@ func TestFunctionAggregatorWrapper(t *testing.T) { func TestAnalyticalAggregatorWrapper(t *testing.T) { // 创建一个测试分析函数 testAnalFunc := &TestAnalyticalFunction{} - + // 创建一个测试适配器 adapter := &AnalyticalAggregatorAdapter{ analFunc: testAnalFunc, @@ -268,6 +268,16 @@ func (t *TestAggregatorFunction) Execute(ctx *FunctionContext, args []interface{ return t.Result(), nil } +// GetMinArgs 返回最小参数数量 +func (t *TestAggregatorFunction) GetMinArgs() int { + return 1 +} + +// GetMaxArgs 返回最大参数数量 +func (t *TestAggregatorFunction) GetMaxArgs() int { + return 1 +} + // TestAnalyticalFunction 测试用的分析函数实现 type TestAnalyticalFunction struct { values []interface{} @@ -335,4 +345,14 @@ func (t *TestAnalyticalFunction) Validate(args []interface{}) error { // Execute 执行函数 func (t *TestAnalyticalFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { return t.Result(), nil -} \ No newline at end of file +} + +// GetMinArgs 返回最小参数数量 +func (t *TestAnalyticalFunction) GetMinArgs() int { + return 1 +} + +// GetMaxArgs 返回最大参数数量 +func (t *TestAnalyticalFunction) GetMaxArgs() int { + return 1 +} diff --git a/functions/analytical_aggregator_adapter_test.go b/functions/analytical_aggregator_adapter_test.go index 7efc0e8..b5533ab 100644 --- a/functions/analytical_aggregator_adapter_test.go +++ b/functions/analytical_aggregator_adapter_test.go @@ -294,4 +294,4 @@ func (m *MockAnalyticalFunction) Clone() AggregatorFunction { } copy(newMock.values, m.values) return newMock -} \ No newline at end of file +} diff --git a/functions/base.go b/functions/base.go index dbd7e3f..ecaf76f 100644 --- a/functions/base.go +++ b/functions/base.go @@ -62,6 +62,16 @@ func (bf *BaseFunction) GetAliases() []string { return bf.aliases } +// GetMinArgs returns the minimum number of arguments +func (bf *BaseFunction) GetMinArgs() int { + return bf.minArgs +} + +// GetMaxArgs returns the maximum number of arguments (-1 means unlimited) +func (bf *BaseFunction) GetMaxArgs() int { + return bf.maxArgs +} + // ValidateArgCount validates the number of arguments func (bf *BaseFunction) ValidateArgCount(args []interface{}) error { argCount := len(args) diff --git a/functions/builtin.go b/functions/builtin.go index 66b06b2..223f1d1 100644 --- a/functions/builtin.go +++ b/functions/builtin.go @@ -83,6 +83,7 @@ func registerBuiltinFunctions() { _ = Register(NewMedianAggregatorFunction()) _ = Register(NewPercentileFunction()) _ = Register(NewCollectFunction()) + _ = Register(NewFirstValueFunction()) _ = Register(NewLastValueFunction()) _ = Register(NewMergeAggFunction()) _ = Register(NewStdDevSAggregatorFunction()) @@ -91,8 +92,9 @@ func registerBuiltinFunctions() { _ = Register(NewVarSAggregatorFunction()) // Window functions + _ = Register(NewWindowStartFunction()) + _ = Register(NewWindowEndFunction()) _ = Register(NewRowNumberFunction()) - _ = Register(NewFirstValueFunction()) _ = Register(NewLeadFunction()) _ = Register(NewNthValueFunction()) @@ -102,10 +104,6 @@ func registerBuiltinFunctions() { _ = Register(NewChangedColFunction()) _ = Register(NewHadChangedFunction()) - // Window functions - _ = Register(NewWindowStartFunction()) - _ = Register(NewWindowEndFunction()) - // Expression functions _ = Register(NewExpressionFunction()) _ = Register(NewExprFunction()) diff --git a/functions/expr_bridge.go b/functions/expr_bridge.go index d5c4ba6..8622131 100644 --- a/functions/expr_bridge.go +++ b/functions/expr_bridge.go @@ -54,12 +54,17 @@ func (bridge *ExprBridge) RegisterStreamSQLFunctionsToExpr() []expr.Option { // Add function to expr environment bridge.exprEnv[name] = wrappedFunc + bridge.exprEnv[strings.ToUpper(name)] = wrappedFunc - // Register function type information + // Register function type information for both lowercase and uppercase options = append(options, expr.Function( name, wrappedFunc, )) + options = append(options, expr.Function( + strings.ToUpper(name), + wrappedFunc, + )) } return options @@ -143,7 +148,7 @@ func (bridge *ExprBridge) CompileExpressionWithStreamSQLFunctions(expression str // 启用一些有用的expr功能 options = append(options, expr.AllowUndefinedVariables(), // 允许未定义变量 - expr.AsBool(), // 期望布尔结果(可根据需要调整) + // 移除 expr.AsBool() 以允许返回任意类型的值 ) return expr.Compile(expression, options...) diff --git a/functions/functions_aggregation.go b/functions/functions_aggregation.go index 161ac9f..0761b0b 100644 --- a/functions/functions_aggregation.go +++ b/functions/functions_aggregation.go @@ -150,7 +150,7 @@ func (f *AvgFunction) Add(value interface{}) { func (f *AvgFunction) Result() interface{} { if f.count == 0 { - return nil // Return nil when no valid values instead of 0.0 + return nil // Return NULL when no valid values according to SQL standard } return f.sum / float64(f.count) } @@ -187,6 +187,13 @@ func (f *MinFunction) Validate(args []interface{}) error { } func (f *MinFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 检查是否有nil参数 + for _, arg := range args { + if arg == nil { + return nil, nil + } + } + min := math.Inf(1) for _, arg := range args { val, err := cast.ToFloat64E(arg) @@ -224,7 +231,7 @@ func (f *MinFunction) Add(value interface{}) { func (f *MinFunction) Result() interface{} { if f.first { - return nil + return nil // Return NULL when no data according to SQL standard } return f.value } @@ -261,6 +268,13 @@ func (f *MaxFunction) Validate(args []interface{}) error { } func (f *MaxFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 检查是否有nil参数 + for _, arg := range args { + if arg == nil { + return nil, nil + } + } + max := math.Inf(-1) for _, arg := range args { val, err := cast.ToFloat64E(arg) @@ -298,7 +312,7 @@ func (f *MaxFunction) Add(value interface{}) { func (f *MaxFunction) Result() interface{} { if f.first { - return nil + return nil // Return NULL when no data according to SQL standard } return f.value } @@ -582,6 +596,69 @@ func (f *CollectFunction) Clone() AggregatorFunction { return newFunc } +// FirstValueFunction 首个值函数 - 返回组中第一行的值 +type FirstValueFunction struct { + *BaseFunction + firstValue interface{} + hasValue bool +} + +func NewFirstValueFunction() *FirstValueFunction { + return &FirstValueFunction{ + BaseFunction: NewBaseFunction("first_value", TypeAggregation, "聚合函数", "返回第一个值", 1, -1), + firstValue: nil, + hasValue: false, + } +} + +func (f *FirstValueFunction) Validate(args []interface{}) error { + return f.ValidateArgCount(args) +} + +func (f *FirstValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + if err := f.Validate(args); err != nil { + return nil, err + } + if len(args) == 0 { + return nil, fmt.Errorf("function %s requires at least one argument", f.GetName()) + } + // 返回第一个值 + return args[0], nil +} + +// 实现AggregatorFunction接口 +func (f *FirstValueFunction) New() AggregatorFunction { + return &FirstValueFunction{ + BaseFunction: f.BaseFunction, + firstValue: nil, + hasValue: false, + } +} + +func (f *FirstValueFunction) Add(value interface{}) { + if !f.hasValue { + f.firstValue = value + f.hasValue = true + } +} + +func (f *FirstValueFunction) Result() interface{} { + return f.firstValue +} + +func (f *FirstValueFunction) Reset() { + f.firstValue = nil + f.hasValue = false +} + +func (f *FirstValueFunction) Clone() AggregatorFunction { + return &FirstValueFunction{ + BaseFunction: f.BaseFunction, + firstValue: f.firstValue, + hasValue: f.hasValue, + } +} + // LastValueFunction 最后值函数 - 返回组中最后一行的值 type LastValueFunction struct { *BaseFunction diff --git a/functions/functions_conditional.go b/functions/functions_conditional.go index e65c9fe..523079b 100644 --- a/functions/functions_conditional.go +++ b/functions/functions_conditional.go @@ -24,6 +24,17 @@ func (f *IfNullFunction) Validate(args []interface{}) error { func (f *IfNullFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { if args[0] == nil { + // 当第一个参数为nil时,返回第二个参数 + // 如果第二个参数是数字0,确保返回float64类型以保持一致性 + if args[1] != nil { + // 尝试转换为float64以保持数值类型一致性 + if val, ok := args[1].(int); ok && val == 0 { + return 0.0, nil + } + if val, ok := args[1].(float32); ok { + return float64(val), nil + } + } return args[1], nil } return args[0], nil diff --git a/functions/functions_math.go b/functions/functions_math.go index f067bf8..079ba23 100644 --- a/functions/functions_math.go +++ b/functions/functions_math.go @@ -550,6 +550,11 @@ func (f *RoundFunction) Validate(args []interface{}) error { } func (f *RoundFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { + // 检查第一个参数是否为nil + if args[0] == nil { + return nil, nil + } + val, err := cast.ToFloat64E(args[0]) if err != nil { return nil, err @@ -559,6 +564,11 @@ func (f *RoundFunction) Execute(ctx *FunctionContext, args []interface{}) (inter return math.Round(val), nil } + // 检查第二个参数是否为nil(如果存在) + if args[1] == nil { + return nil, nil + } + precision, err := cast.ToIntE(args[1]) if err != nil { return nil, err diff --git a/functions/functions_window.go b/functions/functions_window.go index 5436e41..a541ce2 100644 --- a/functions/functions_window.go +++ b/functions/functions_window.go @@ -38,7 +38,7 @@ type WindowStartFunction struct { func NewWindowStartFunction() *WindowStartFunction { return &WindowStartFunction{ - BaseFunction: NewBaseFunction("window_start", TypeWindow, "window", "Return window start time", 0, 0), + BaseFunction: NewBaseFunction("window_start", TypeWindow, "窗口函数", "返回窗口开始时间", 0, 0), } } @@ -88,7 +88,7 @@ type WindowEndFunction struct { func NewWindowEndFunction() *WindowEndFunction { return &WindowEndFunction{ - BaseFunction: NewBaseFunction("window_end", TypeWindow, "window", "Return window end time", 0, 0), + BaseFunction: NewBaseFunction("window_end", TypeWindow, "窗口函数", "返回窗口结束时间", 0, 0), } } @@ -246,63 +246,6 @@ func (f *ExpressionAggregatorFunction) Clone() AggregatorFunction { } } -// FirstValueFunction 返回窗口中第一个值 -type FirstValueFunction struct { - *BaseFunction - firstValue interface{} - hasValue bool -} - -func NewFirstValueFunction() *FirstValueFunction { - return &FirstValueFunction{ - BaseFunction: NewBaseFunction("first_value", TypeWindow, "窗口函数", "返回窗口中第一个值", 1, 1), - hasValue: false, - } -} - -func (f *FirstValueFunction) Validate(args []interface{}) error { - return f.ValidateArgCount(args) -} - -func (f *FirstValueFunction) Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) { - if err := f.Validate(args); err != nil { - return nil, err - } - return f.firstValue, nil -} - -// 实现AggregatorFunction接口 -func (f *FirstValueFunction) New() AggregatorFunction { - return &FirstValueFunction{ - BaseFunction: f.BaseFunction, - hasValue: false, - } -} - -func (f *FirstValueFunction) Add(value interface{}) { - if !f.hasValue { - f.firstValue = value - f.hasValue = true - } -} - -func (f *FirstValueFunction) Result() interface{} { - return f.firstValue -} - -func (f *FirstValueFunction) Reset() { - f.firstValue = nil - f.hasValue = false -} - -func (f *FirstValueFunction) Clone() AggregatorFunction { - return &FirstValueFunction{ - BaseFunction: f.BaseFunction, - firstValue: f.firstValue, - hasValue: f.hasValue, - } -} - // LeadFunction 返回当前行之后第N行的值 type LeadFunction struct { *BaseFunction @@ -376,9 +319,9 @@ func (f *LeadFunction) New() AggregatorFunction { return &LeadFunction{ BaseFunction: f.BaseFunction, values: make([]interface{}, 0), - offset: f.offset, - defaultValue: f.defaultValue, - hasDefault: f.hasDefault, + offset: f.offset, // 保持offset参数 + defaultValue: f.defaultValue, // 保持默认值 + hasDefault: f.hasDefault, // 保持默认值标志 } } @@ -387,12 +330,12 @@ func (f *LeadFunction) Add(value interface{}) { } func (f *LeadFunction) Result() interface{} { - // Lead函数的结果需要在所有数据添加完成后计算 - // 如果没有足够的数据,返回默认值 - if len(f.values) == 0 && f.hasDefault { + // LEAD函数在没有指定当前行位置的情况下,返回默认值或nil + // 这通常用于聚合场景,真正的窗口计算需要在窗口处理器中进行 + if f.hasDefault { return f.defaultValue } - // 这里简化实现,返回nil + return nil } @@ -415,6 +358,41 @@ func (f *LeadFunction) Clone() AggregatorFunction { return clone } +// Init implements ParameterizedFunction interface +func (f *LeadFunction) Init(args []interface{}) error { + if len(args) < 2 { + // LEAD with default offset = 1 + f.offset = 1 + return nil + } + + // Parse offset parameter + offset := 1 + if offsetVal, ok := args[1].(int); ok { + offset = offsetVal + } else if offsetVal, ok := args[1].(int64); ok { + offset = int(offsetVal) + } else if offsetVal, ok := args[1].(float64); ok { + offset = int(offsetVal) + } else { + return fmt.Errorf("lead offset must be an integer, got %T", args[1]) + } + + if offset < 0 { + return fmt.Errorf("lead offset must be non-negative, got %d", offset) + } + + f.offset = offset + + // Parse default value if provided + if len(args) >= 3 { + f.defaultValue = args[2] + f.hasDefault = true + } + + return nil +} + // NthValueFunction 返回窗口中第N个值 type NthValueFunction struct { *BaseFunction @@ -484,11 +462,13 @@ func (f *NthValueFunction) Execute(ctx *FunctionContext, args []interface{}) (in // 实现AggregatorFunction接口 func (f *NthValueFunction) New() AggregatorFunction { - return &NthValueFunction{ + newInstance := &NthValueFunction{ BaseFunction: f.BaseFunction, values: make([]interface{}, 0), - n: f.n, + n: f.n, // 保持n参数 } + + return newInstance } func (f *NthValueFunction) Add(value interface{}) { @@ -510,8 +490,34 @@ func (f *NthValueFunction) Clone() AggregatorFunction { clone := &NthValueFunction{ BaseFunction: f.BaseFunction, values: make([]interface{}, len(f.values)), - n: f.n, + n: f.n, // 保持n参数 } copy(clone.values, f.values) return clone } + +// Init implements ParameterizedFunction interface +func (f *NthValueFunction) Init(args []interface{}) error { + if len(args) < 2 { + return fmt.Errorf("nth_value requires at least 2 arguments") + } + + // Parse N parameter + n := 1 + if nVal, ok := args[1].(int); ok { + n = nVal + } else if nVal, ok := args[1].(int64); ok { + n = int(nVal) + } else if nVal, ok := args[1].(float64); ok { + n = int(nVal) + } else { + return fmt.Errorf("nth_value n must be an integer, got %T", args[1]) + } + + if n <= 0 { + return fmt.Errorf("nth_value n must be positive, got %d", n) + } + + f.n = n + return nil +} diff --git a/functions/registry.go b/functions/registry.go index 8a8d29b..661968f 100644 --- a/functions/registry.go +++ b/functions/registry.go @@ -61,6 +61,11 @@ type Function interface { Execute(ctx *FunctionContext, args []interface{}) (interface{}, error) // GetDescription returns the function description GetDescription() string + + // GetMinArgs returns the minimum number of arguments + GetMinArgs() int + // GetMaxArgs returns the maximum number of arguments (-1 means unlimited) + GetMaxArgs() int } // FunctionRegistry manages function registration and retrieval diff --git a/go.mod b/go.mod index 0fb5b43..c4a50e8 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/rulego/streamsql go 1.18 require ( - github.com/expr-lang/expr v1.17.2 + github.com/expr-lang/expr v1.17.6 github.com/stretchr/testify v1.10.0 ) diff --git a/go.sum b/go.sum index fc5187e..23fe613 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/expr-lang/expr v1.17.2 h1:o0A99O/Px+/DTjEnQiodAgOIK9PPxL8DtXhBRKC+Iso= -github.com/expr-lang/expr v1.17.2/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec= +github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/rsql/ast.go b/rsql/ast.go index 1d58527..fb5c0f9 100644 --- a/rsql/ast.go +++ b/rsql/ast.go @@ -2,6 +2,7 @@ package rsql import ( "fmt" + "regexp" "strings" "time" @@ -104,7 +105,10 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { simpleFields = append(simpleFields, fieldName+":"+field.Alias) } else { // For fields without alias, check if it's a string literal - _, n, _, _ := ParseAggregateTypeWithExpression(fieldName) + _, n, _, _, err := ParseAggregateTypeWithExpression(fieldName) + if err != nil { + return nil, "", err + } if n != "" { // If string literal, use parsed field name (remove quotes) simpleFields = append(simpleFields, n) @@ -119,10 +123,16 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { } // Build field mapping and expression information - aggs, fields, expressions := buildSelectFieldsWithExpressions(s.Fields) + aggs, fields, expressions, postAggExpressions, err := buildSelectFieldsWithExpressions(s.Fields) + if err != nil { + return nil, "", err + } // Extract field order information - fieldOrder := extractFieldOrder(s.Fields) + fieldOrder, err := extractFieldOrder(s.Fields) + if err != nil { + return nil, "", err + } // Build Stream configuration config := types.Config{ @@ -133,16 +143,17 @@ func (s *SelectStatement) ToStreamConfig() (*types.Config, string, error) { TimeUnit: s.Window.TimeUnit, GroupByKey: groupByKey, }, - GroupFields: extractGroupFields(s), - SelectFields: aggs, - FieldAlias: fields, - Distinct: s.Distinct, - Limit: s.Limit, - NeedWindow: needWindow, - SimpleFields: simpleFields, - Having: s.Having, - FieldExpressions: expressions, - FieldOrder: fieldOrder, + GroupFields: extractGroupFields(s), + SelectFields: aggs, + FieldAlias: fields, + Distinct: s.Distinct, + Limit: s.Limit, + NeedWindow: needWindow, + SimpleFields: simpleFields, + Having: s.Having, + FieldExpressions: expressions, + PostAggExpressions: postAggExpressions, + FieldOrder: fieldOrder, } return &config, s.Condition, nil @@ -192,7 +203,7 @@ func isAggregationFunction(expr string) bool { // extractFieldOrder extracts original order of fields from Fields slice // Returns field names list in order of appearance in SELECT statement -func extractFieldOrder(fields []Field) []string { +func extractFieldOrder(fields []Field) ([]string, error) { var fieldOrder []string for _, field := range fields { @@ -201,7 +212,10 @@ func extractFieldOrder(fields []Field) []string { fieldOrder = append(fieldOrder, field.Alias) } else { // Without alias, try to parse expression to get field name - _, fieldName, _, _ := ParseAggregateTypeWithExpression(field.Expression) + _, fieldName, _, _, err := ParseAggregateTypeWithExpression(field.Expression) + if err != nil { + return nil, err + } if fieldName != "" { // If parsed field name (like string literal), use parsed name fieldOrder = append(fieldOrder, fieldName) @@ -212,7 +226,7 @@ func extractFieldOrder(fields []Field) []string { } } - return fieldOrder + return fieldOrder, nil } func extractGroupFields(s *SelectStatement) []string { var fields []string @@ -224,13 +238,16 @@ func extractGroupFields(s *SelectStatement) []string { return fields } -func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string) { +func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateType, fieldMap map[string]string, err error) { selectFields := make(map[string]aggregator.AggregateType) fieldMap = make(map[string]string) for _, f := range fields { if alias := f.Alias; alias != "" { - t, n, _, _ := ParseAggregateTypeWithExpression(f.Expression) + t, n, _, _, parseErr := ParseAggregateTypeWithExpression(f.Expression) + if parseErr != nil { + return nil, nil, parseErr + } if t != "" { // Use alias as key for aggregator, not field name selectFields[alias] = t @@ -245,70 +262,128 @@ func buildSelectFields(fields []Field) (aggMap map[string]aggregator.AggregateTy } } else { // Without alias, use expression itself as field name - t, n, _, _ := ParseAggregateTypeWithExpression(f.Expression) + t, n, _, _, parseErr := ParseAggregateTypeWithExpression(f.Expression) + if parseErr != nil { + return nil, nil, parseErr + } if t != "" && n != "" { selectFields[n] = t fieldMap[n] = n } } } - return selectFields, fieldMap + return selectFields, fieldMap, nil +} + +// detectNestedAggregation 检测表达式中是否存在聚合函数嵌套聚合函数的情况 +// 如果发现嵌套聚合函数,返回错误信息 +func detectNestedAggregation(expr string) error { + return detectNestedAggregationRecursive(expr, false) +} + +// detectNestedAggregationRecursive 递归检测嵌套聚合函数 +// inAggregation 表示当前是否在聚合函数内部 +func detectNestedAggregationRecursive(expr string, inAggregation bool) error { + // 使用正则表达式匹配函数调用模式 + pattern := regexp.MustCompile(`(?i)([a-z_]+)\s*\(`) + matches := pattern.FindAllStringSubmatchIndex(expr, -1) + + for _, match := range matches { + funcStart := match[0] + funcName := strings.ToLower(expr[match[2]:match[3]]) + + // 检查函数是否为聚合函数 + if fn, exists := functions.Get(funcName); exists { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + // 如果当前已经在聚合函数内部,且又发现了聚合函数,则报错 + if inAggregation { + return fmt.Errorf("aggregate function calls cannot be nested") + } + + // 找到该函数的参数部分 + funcEnd := findMatchingParenInternal(expr, funcStart+len(funcName)) + if funcEnd > funcStart { + // 提取函数参数 + paramStart := funcStart + len(funcName) + 1 + params := expr[paramStart:funcEnd] + + // 在聚合函数参数内部递归检查 + if err := detectNestedAggregationRecursive(params, true); err != nil { + return err + } + } + } + } + } + + return nil } // Parse aggregation function and return expression information -func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.AggregateType, name string, expression string, allFields []string) { +func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.AggregateType, name string, expression string, allFields []string, err error) { + // 首先检测是否存在嵌套聚合函数 + if err := detectNestedAggregation(exprStr); err != nil { + // 如果发现嵌套聚合,返回错误 + return "", "", "", nil, err + } + // Special handling for CASE expressions if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(exprStr)), "CASE") { // CASE expressions are handled as special expressions if parsedExpr, err := expr.NewExpression(exprStr); err == nil { allFields = parsedExpr.GetFields() } - return "expression", "", exprStr, allFields + return "expression", "", exprStr, allFields, nil } - // Check if it's nested functions - if hasNestedFunctions(exprStr) { - // Nested function case, extract all functions - funcs := extractAllFunctions(exprStr) - - // Find aggregation functions - var aggregationFunc string - for _, funcName := range funcs { - if fn, exists := functions.Get(funcName); exists { - switch fn.GetType() { - case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: - aggregationFunc = funcName - break - } - } - } - - if aggregationFunc != "" { - // Nested expression with aggregation function, handle entire expression as expression - if parsedExpr, err := expr.NewExpression(exprStr); err == nil { - allFields = parsedExpr.GetFields() - } - return aggregator.AggregateType(aggregationFunc), "", exprStr, allFields - } else { - // Nested expression without aggregation function, handle as regular expression - if parsedExpr, err := expr.NewExpression(exprStr); err == nil { - allFields = parsedExpr.GetFields() - } - return "expression", "", exprStr, allFields + // Check if it's an expression containing operators with functions + if containsOperatorsOutsideFunctions(exprStr) && containsFunctions(exprStr) { + // This is a complex expression with functions and operators + // Extract all fields referenced in the expression + if parsedExpr, err := expr.NewExpression(exprStr); err == nil { + allFields = parsedExpr.GetFields() } + // Return as expression type for post-aggregation evaluation + return "expression", "", exprStr, allFields, nil } - // Original logic for single function + // Original logic for single function (moved up to prioritize outer function detection) // Extract function name funcName := extractFunctionName(exprStr) + + // Check if it's nested functions without operators + hasNested := hasNestedFunctions(exprStr) + if hasNested && funcName != "" { + // For nested functions, check if the outer function is an aggregation function + if fn, exists := functions.Get(funcName); exists { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + // Outer function is aggregation - handle as aggregation with expression parameter + name, expression, allFields := extractAggFieldWithExpression(exprStr, funcName) + + return aggregator.AggregateType(funcName), name, expression, allFields, nil + } + } + // Multiple functions but no operators and outer function is not aggregation - treat as expression + if parsedExpr, err := expr.NewExpression(exprStr); err == nil { + allFields = parsedExpr.GetFields() + } + return "expression", "", exprStr, allFields, nil + } if funcName == "" { + // Special handling for SELECT * case + if strings.TrimSpace(exprStr) == "*" { + return "", "", "", nil, nil // Don't treat * as expression + } + // Check if it's a string literal trimmed := strings.TrimSpace(exprStr) if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) || (strings.HasPrefix(trimmed, "\"") && strings.HasSuffix(trimmed, "\"")) { // String literal: use content without quotes as field name fieldName := trimmed[1 : len(trimmed)-1] - return "expression", fieldName, exprStr, nil + return "expression", fieldName, exprStr, nil, nil } // If not a function call but contains operators or keywords, it might be an expression @@ -319,15 +394,15 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg if parsedExpr, err := expr.NewExpression(exprStr); err == nil { allFields = parsedExpr.GetFields() } - return "expression", "", exprStr, allFields + return "expression", "", exprStr, allFields, nil } - return "", "", "", nil + return "", "", "", nil, nil } // Check if it's a registered function fn, exists := functions.Get(funcName) if !exists { - return "", "", "", nil + return "", "", "", nil, nil } // Extract function parameters and expression information @@ -337,15 +412,15 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg switch fn.GetType() { case functions.TypeAggregation: // Aggregation function: use function name as aggregation type - return aggregator.AggregateType(funcName), name, expression, allFields + return aggregator.AggregateType(funcName), name, expression, allFields, nil case functions.TypeAnalytical: // Analytical function: use function name as aggregation type - return aggregator.AggregateType(funcName), name, expression, allFields + return aggregator.AggregateType(funcName), name, expression, allFields, nil case functions.TypeWindow: // Window function: use function name as aggregation type - return aggregator.AggregateType(funcName), name, expression, allFields + return aggregator.AggregateType(funcName), name, expression, allFields, nil case functions.TypeString, functions.TypeConversion, functions.TypeCustom, functions.TypeMath: // String, conversion, custom, math functions: handle as expressions in aggregation queries @@ -355,12 +430,12 @@ func ParseAggregateTypeWithExpression(exprStr string) (aggType aggregator.Aggreg if parsedExpr, err := expr.NewExpression(fullExpression); err == nil { allFields = parsedExpr.GetFields() } - return "expression", name, fullExpression, allFields + return "expression", name, fullExpression, allFields, nil default: // Other types of functions don't use aggregation // These functions will be handled directly in non-window mode - return "", "", "", nil + return "", "", "", nil, nil } } @@ -418,8 +493,20 @@ func hasNestedFunctions(expr string) bool { return len(funcs) > 1 } +// containsOperators checks if expression contains arithmetic or comparison operators +func containsOperators(expr string) bool { + return strings.ContainsAny(expr, "+-*/<>=!&|") +} + +// containsFunctions checks if expression contains function calls +func containsFunctions(expr string) bool { + funcs := extractAllFunctions(expr) + return len(funcs) > 0 +} + // Extract aggregation function fields and parse expression information func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName string, expression string, allFields []string) { + start := strings.Index(strings.ToLower(exprStr), strings.ToLower(funcName)+"(") if start < 0 { return "", "", nil @@ -439,6 +526,44 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s return "*", "", nil } + // Check if it's a registered function and get its type + if fn, exists := functions.Get(funcName); exists { + // For string functions that need special parameter parsing + if fn.GetType() == functions.TypeString { + // Intelligently parse function parameters to extract field names + var fields []string + params := parseSmartParameters(fieldExpr) + for _, param := range params { + param = strings.TrimSpace(param) + // If parameter is not string constant (not surrounded by quotes), consider it as field name + if !((strings.HasPrefix(param, "'") && strings.HasSuffix(param, "'")) || + (strings.HasPrefix(param, "\"") && strings.HasSuffix(param, "\""))) { + if isIdentifier(param) { + fields = append(fields, param) + } + } + } + if len(fields) > 0 { + // For string functions, save complete function call as expression + // Return all extracted fields as allFields + return fields[0], strings.ToLower(funcName) + "(" + fieldExpr + ")", fields + } + // If no field found, return empty field name but keep expression + return "", strings.ToLower(funcName) + "(" + fieldExpr + ")", nil + } + } + + // Check if it's a multi-parameter function call (contains comma) + if strings.Contains(fieldExpr, ",") { + // For multi-parameter functions, extract the first parameter as the field name + params := strings.Split(fieldExpr, ",") + if len(params) > 0 { + firstParam := strings.TrimSpace(params[0]) + // Return first parameter as field name, and full expression for parameter processing + return firstParam, fieldExpr, nil + } + } + // Check if it's a simple field name (only letters, numbers, underscores) isSimpleField := true for _, char := range fieldExpr { @@ -457,29 +582,6 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s // For complex expressions, including multi-parameter function calls expression = fieldExpr - // For string functions like CONCAT, save complete expression directly - if strings.ToLower(funcName) == "concat" { - // Intelligently parse CONCAT function parameters to extract field names - var fields []string - params := parseSmartParameters(fieldExpr) - for _, param := range params { - param = strings.TrimSpace(param) - // If parameter is not string constant (not surrounded by quotes), consider it as field name - if !((strings.HasPrefix(param, "'") && strings.HasSuffix(param, "'")) || - (strings.HasPrefix(param, "\"") && strings.HasSuffix(param, "\""))) { - if isIdentifier(param) { - fields = append(fields, param) - } - } - } - if len(fields) > 0 { - // For CONCAT function, save complete function call as expression - return fields[0], funcName + "(" + fieldExpr + ")", fields - } - // If no field found, return empty field name but keep expression - return "", funcName + "(" + fieldExpr + ")", nil - } - // Use expression engine to parse parsedExpr, err := expr.NewExpression(fieldExpr) if err != nil { @@ -521,6 +623,7 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s } // If no fields (pure constant expression), return entire expression as field name + return fieldExpr, expression, nil } @@ -649,58 +752,313 @@ func parseAggregateExpression(expr string) string { return "" } -// Parse field information including expressions +// Parse field information including expressions with post-aggregation support func buildSelectFieldsWithExpressions(fields []Field) ( aggMap map[string]aggregator.AggregateType, fieldMap map[string]string, - expressions map[string]types.FieldExpression) { + expressions map[string]types.FieldExpression, + postAggExpressions []types.PostAggregationExpression, + err error) { selectFields := make(map[string]aggregator.AggregateType) fieldMap = make(map[string]string) expressions = make(map[string]types.FieldExpression) + postAggExpressions = make([]types.PostAggregationExpression, 0) for _, f := range fields { - if alias := f.Alias; alias != "" { - t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) - if t != "" { - // Use alias as key so each aggregation function has unique key - selectFields[alias] = t + alias := f.Alias + if alias == "" { + // For string literals without alias, use the content without quotes as alias + trimmed := strings.TrimSpace(f.Expression) + if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) || + (strings.HasPrefix(trimmed, "\"") && strings.HasSuffix(trimmed, "\"")) { + alias = trimmed[1 : len(trimmed)-1] // Remove quotes + } else { + alias = f.Expression + } + } - // Field mapping: output field name -> input field name (prepare correct mapping for aggregator) - if n != "" { - fieldMap[alias] = n - } else { - // If no field name extracted, use alias itself - fieldMap[alias] = alias + // Check if this is a complex aggregation expression + if isComplexAggregationExpression(f.Expression) { + // Parse complex aggregation expression + aggFields, exprTemplate, err := parseComplexAggregationExpression(f.Expression) + if err == nil && len(aggFields) > 0 { + // Add individual aggregation functions + for _, aggField := range aggFields { + selectFields[aggField.Placeholder] = aggField.AggType + fieldMap[aggField.Placeholder] = aggField.InputField } - // If expression exists, save expression information - if expression != "" { - expressions[alias] = types.FieldExpression{ - Field: n, - Expression: expression, - Fields: allFields, - } + // Add post-aggregation expression + postAggExpressions = append(postAggExpressions, types.PostAggregationExpression{ + OutputField: alias, + OriginalExpr: f.Expression, + ExpressionTemplate: exprTemplate, + RequiredFields: aggFields, + }) + + // Mark the main field as post-aggregation + selectFields[alias] = "post_aggregation" + fieldMap[alias] = alias + continue + } + } + + // Handle as regular expression + t, n, expression, allFields, parseErr := ParseAggregateTypeWithExpression(f.Expression) + if parseErr != nil { + // 如果检测到嵌套聚合函数,返回错误 + return nil, nil, nil, nil, parseErr + } + if t != "" { + // Check if this is a multi-parameter function that needs special handling + isMultiParamFunction := false + if expression != "" && strings.Contains(expression, ",") { + // Check if the function needs multi-parameter handling + funcName := extractFunctionName(f.Expression) + if fn, exists := functions.Get(funcName); exists { + minArgs := fn.GetMinArgs() + maxArgs := fn.GetMaxArgs() + // Function needs multi-parameter handling if it has multiple parameters + isMultiParamFunction = minArgs > 1 || (maxArgs > minArgs && minArgs >= 1) } } - } else { - // Without alias, use expression itself as field name - t, n, expression, allFields := ParseAggregateTypeWithExpression(f.Expression) - if t != "" && n != "" { - // For string literals, use parsed field name (remove quotes) as key - selectFields[n] = t - fieldMap[n] = n - // If expression exists, save expression information - if expression != "" { - expressions[n] = types.FieldExpression{ - Field: n, - Expression: expression, - Fields: allFields, - } + // For multi-parameter functions, treat as post-aggregation expression + if isMultiParamFunction { + // Parse as single aggregation function with parameters + aggFields := []types.AggregationFieldInfo{{ + FuncName: extractFunctionName(f.Expression), + InputField: n, + Placeholder: "__" + extractFunctionName(f.Expression) + "_" + alias + "__", + AggType: aggregator.AggregateType(extractFunctionName(f.Expression)), + FullCall: f.Expression, + }} + + // Add the aggregation function + selectFields[aggFields[0].Placeholder] = aggFields[0].AggType + fieldMap[aggFields[0].Placeholder] = aggFields[0].InputField + + // Add post-aggregation expression (which just returns the placeholder value) + postAggExpressions = append(postAggExpressions, types.PostAggregationExpression{ + OutputField: alias, + OriginalExpr: f.Expression, + ExpressionTemplate: aggFields[0].Placeholder, + RequiredFields: aggFields, + }) + + // Mark the main field as post-aggregation + selectFields[alias] = "post_aggregation" + fieldMap[alias] = alias + continue + } + + // Use alias as key so each aggregation function has unique key + selectFields[alias] = t + + // Field mapping: output field name -> input field name (prepare correct mapping for aggregator) + if n != "" { + fieldMap[alias] = n + } else { + // If no field name extracted, use alias itself + fieldMap[alias] = alias + } + + // If expression exists, save expression information + if expression != "" { + expressions[alias] = types.FieldExpression{ + Field: n, + Expression: expression, + Fields: allFields, } } } } - return selectFields, fieldMap, expressions + return selectFields, fieldMap, expressions, postAggExpressions, nil +} + +// isComplexAggregationExpression checks if an expression contains multiple aggregation functions or operators with aggregation functions +func isComplexAggregationExpression(expr string) bool { + // Check if expression contains aggregation functions + funcs := extractAllFunctions(expr) + aggCount := 0 + nonAggCount := 0 + + for _, funcName := range funcs { + if fn, exists := functions.Get(funcName); exists { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + aggCount++ + default: + nonAggCount++ + } + } else { + nonAggCount++ + } + } + + // Determine the outermost function name (if any) + outerFuncName := "" + if m := regexp.MustCompile(`(?i)^\s*([a-z_][a-z0-9_]*)\s*\(`).FindStringSubmatch(expr); len(m) == 2 { + outerFuncName = strings.ToLower(m[1]) + } + outerIsAggregation := false + if outerFuncName != "" { + if fn, ok := functions.Get(outerFuncName); ok { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + outerIsAggregation = true + } + } + } + + // Special case: single aggregation function with nested expression (only when OUTER is aggregation) + isSingleAggWithNestedFunc := false + if aggCount == 1 && outerIsAggregation { + start := strings.Index(expr, "(") + end := strings.LastIndex(expr, ")") + if start != -1 && end != -1 && end > start { + innerExpr := strings.TrimSpace(expr[start+1 : end]) + if !containsOperators(innerExpr) { + isSingleAggWithNestedFunc = true + } + } + } + + result := (aggCount > 1) || + (aggCount > 0 && containsOperatorsOutsideFunctions(expr) && !isSingleAggWithNestedFunc) || + (aggCount > 0 && nonAggCount > 0 && !isSingleAggWithNestedFunc) + + return result +} + +// containsOperatorsOutsideFunctions checks if expression contains operators outside function calls +func containsOperatorsOutsideFunctions(expr string) bool { + // Remove function calls first, then check for operators + // Simple approach: if it's just a single function call, it shouldn't be treated as complex + trimmed := strings.TrimSpace(expr) + + // If it starts with a function name and ends with ), it's likely a simple function call + if match := regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*\s*\([^)]*\)$`).FindString(trimmed); match == trimmed { + return false + } + + // Check for operators + return containsOperators(expr) +} + +// parseComplexAggregationExpression parses expressions containing multiple aggregation functions +func parseComplexAggregationExpression(expr string) ([]types.AggregationFieldInfo, string, error) { + return parseComplexAggExpressionInternal(expr) +} + +// parseComplexAggExpressionInternal implements the actual parsing logic +func parseComplexAggExpressionInternal(expr string) ([]types.AggregationFieldInfo, string, error) { + // 首先检测嵌套聚合 + if err := detectNestedAggregation(expr); err != nil { + return nil, "", err + } + + // 使用改进的递归解析方法 + aggFields, exprTemplate := parseNestedFunctionsInternal(expr, make([]types.AggregationFieldInfo, 0)) + return aggFields, exprTemplate, nil +} + +// parseNestedFunctionsInternal 递归解析嵌套函数调用 +func parseNestedFunctionsInternal(expr string, aggFields []types.AggregationFieldInfo) ([]types.AggregationFieldInfo, string) { + // 匹配函数调用,支持大小写不敏感 + pattern := regexp.MustCompile(`(?i)([a-z_]+)\s*\(`) + + // 找到所有函数调用的起始位置 + matches := pattern.FindAllStringSubmatchIndex(expr, -1) + if len(matches) == 0 { + return aggFields, expr + } + + // 从右到左处理,避免索引偏移问题 + for i := len(matches) - 1; i >= 0; i-- { + match := matches[i] + funcStart := match[0] + funcName := strings.ToLower(expr[match[2]:match[3]]) + + // 找到匹配的右括号 + parenStart := match[3] + parenEnd := findMatchingParenInternal(expr, parenStart) + if parenEnd == -1 { + continue + } + + fullFuncCall := expr[funcStart : parenEnd+1] + funcParam := expr[parenStart+1 : parenEnd] + + // 检查是否是聚合函数 + if fn, exists := functions.Get(funcName); exists { + switch fn.GetType() { + case functions.TypeAggregation, functions.TypeAnalytical, functions.TypeWindow: + // 生成唯一占位符 + callHash := 0 + for _, c := range fullFuncCall { + callHash = callHash*31 + int(c) + } + if callHash < 0 { + callHash = -callHash + } + placeholder := fmt.Sprintf("__%s_%d__", funcName, callHash) + + // 解析函数参数 + inputField := strings.TrimSpace(funcParam) + // 对于聚合函数,如果参数包含嵌套函数调用,保留完整参数 + // 只有在参数是简单的逗号分隔列表时才进行分割 + if strings.Contains(funcParam, ",") && !containsNestedFunctions(funcParam) { + params := strings.Split(funcParam, ",") + if len(params) > 0 { + inputField = strings.TrimSpace(params[0]) + } + } + + // 添加到聚合字段列表 + fieldInfo := types.AggregationFieldInfo{ + FuncName: funcName, + InputField: inputField, + Placeholder: placeholder, + AggType: aggregator.AggregateType(funcName), + FullCall: fullFuncCall, + } + aggFields = append(aggFields, fieldInfo) + + // 替换表达式中的聚合函数调用 + expr = expr[:funcStart] + placeholder + expr[parenEnd+1:] + } + } + } + + return aggFields, expr +} + +// containsNestedFunctions 检查参数字符串是否包含嵌套函数调用 +func containsNestedFunctions(param string) bool { + // 简单检查:如果包含函数名模式后跟括号,则认为是嵌套函数 + pattern := regexp.MustCompile(`[a-zA-Z_][a-zA-Z0-9_]*\s*\(`) + return pattern.MatchString(param) +} + +// findMatchingParenInternal 找到匹配的右括号 +func findMatchingParenInternal(s string, start int) int { + if start >= len(s) || s[start] != '(' { + return -1 + } + + count := 1 + for i := start + 1; i < len(s); i++ { + switch s[i] { + case '(': + count++ + case ')': + count-- + if count == 0 { + return i + } + } + } + return -1 // 未找到匹配的右括号 } diff --git a/rsql/ast_test.go b/rsql/ast_test.go index 78f1be0..4cdc1cf 100644 --- a/rsql/ast_test.go +++ b/rsql/ast_test.go @@ -360,7 +360,11 @@ func TestBuildSelectFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - aggMap, fieldMap := buildSelectFields(tt.fields) + aggMap, fieldMap, err := buildSelectFields(tt.fields) + if err != nil { + t.Errorf("buildSelectFields() error = %v", err) + return + } // 检查聚合函数映射 if len(aggMap) != len(tt.wantAggs) { @@ -498,9 +502,14 @@ func TestParseAggregateTypeWithExpression(t *testing.T) { }, } + // 测试正常情况 for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - aggType, name, expression, allFields := ParseAggregateTypeWithExpression(tt.exprStr) + aggType, name, expression, allFields, err := ParseAggregateTypeWithExpression(tt.exprStr) + if err != nil { + t.Errorf("ParseAggregateTypeWithExpression() returned error: %v", err) + return + } if string(aggType) != tt.wantAggType { t.Errorf("ParseAggregateTypeWithExpression() aggType = %s, want %s", aggType, tt.wantAggType) @@ -524,6 +533,82 @@ func TestParseAggregateTypeWithExpression(t *testing.T) { } }) } + + // 测试嵌套聚合函数检测 + nestedTests := []struct { + name string + exprStr string + }{ + { + name: "嵌套聚合函数 - MAX(AVG(temperature))", + exprStr: "MAX(AVG(temperature))", + }, + { + name: "嵌套聚合函数 - COUNT(SUM(price))", + exprStr: "COUNT(SUM(price))", + }, + { + name: "复杂嵌套 - MAX(ROUND(AVG(temperature), 1))", + exprStr: "MAX(ROUND(AVG(temperature), 1))", + }, + } + + for _, tt := range nestedTests { + t.Run(tt.name, func(t *testing.T) { + _, _, _, _, err := ParseAggregateTypeWithExpression(tt.exprStr) + if err == nil { + t.Errorf("ParseAggregateTypeWithExpression() should return error for nested aggregation: %s", tt.exprStr) + } else if !strings.Contains(err.Error(), "aggregate function calls cannot be nested") { + t.Errorf("ParseAggregateTypeWithExpression() error message should contain 'aggregate function calls cannot be nested', got: %v", err) + } + }) + } +} + +// TestDetectNestedAggregation 测试嵌套聚合函数检测 +func TestDetectNestedAggregation(t *testing.T) { + tests := []struct { + name string + exprStr string + wantError bool + }{ + { + name: "正常聚合函数", + exprStr: "MAX(temperature)", + wantError: false, + }, + { + name: "嵌套聚合函数", + exprStr: "MAX(AVG(temperature))", + wantError: true, + }, + { + name: "复杂嵌套", + exprStr: "MAX(ROUND(AVG(temperature), 1))", + wantError: true, + }, + { + name: "非聚合函数嵌套", + exprStr: "UPPER(CONCAT(first_name, last_name))", + wantError: false, + }, + { + name: "聚合函数包含非聚合函数", + exprStr: "MAX(ROUND(temperature, 1))", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := detectNestedAggregation(tt.exprStr) + if tt.wantError && err == nil { + t.Errorf("detectNestedAggregation() should return error for: %s", tt.exprStr) + } else if !tt.wantError && err != nil { + t.Errorf("detectNestedAggregation() should not return error for: %s, got: %v", tt.exprStr, err) + } + }) + } } // TestExtractAggFieldWithExpression 测试 extractAggFieldWithExpression 函数 diff --git a/rsql/coverage_test.go b/rsql/coverage_test.go index f5ed59d..9320958 100644 --- a/rsql/coverage_test.go +++ b/rsql/coverage_test.go @@ -653,7 +653,11 @@ func TestBuildSelectFieldsWithExpressions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - aggMap, fieldMap, expressions := buildSelectFieldsWithExpressions(tt.fields) + aggMap, fieldMap, expressions, _, err := buildSelectFieldsWithExpressions(tt.fields) + if err != nil { + t.Errorf("buildSelectFieldsWithExpressions() error = %v", err) + return + } tt.checkFunc(t, aggMap, fieldMap, expressions) }) } diff --git a/rsql/parser.go b/rsql/parser.go index 57892dc..85edeae 100644 --- a/rsql/parser.go +++ b/rsql/parser.go @@ -185,7 +185,7 @@ func (p *Parser) createCombinedError() error { for i, err := range errors { builder.WriteString(fmt.Sprintf("%d. %s\n", i+1, err.Error())) } - return fmt.Errorf(builder.String()) + return fmt.Errorf("%s", builder.String()) } func (p *Parser) parseSelect(stmt *SelectStatement) error { diff --git a/stream/processor_data.go b/stream/processor_data.go index a413321..3b4df97 100644 --- a/stream/processor_data.go +++ b/stream/processor_data.go @@ -88,7 +88,30 @@ func (dp *DataProcessor) Process() { func (dp *DataProcessor) initializeAggregator() { // Convert to new AggregationField format aggregationFields := convertToAggregationFields(dp.stream.config.SelectFields, dp.stream.config.FieldAlias) - dp.stream.aggregator = aggregator.NewGroupAggregator(dp.stream.config.GroupFields, aggregationFields) + + // Check if we have post-aggregation expressions + if len(dp.stream.config.PostAggExpressions) > 0 { + // Use enhanced aggregator for post-aggregation support + enhancedAgg := aggregator.NewEnhancedGroupAggregator(dp.stream.config.GroupFields, aggregationFields) + + // Add post-aggregation expressions + for _, postExpr := range dp.stream.config.PostAggExpressions { + err := enhancedAgg.AddPostAggregationExpression( + postExpr.OutputField, + postExpr.OriginalExpr, + convertToAggregationFieldInfos(postExpr.RequiredFields), + ) + if err != nil { + // Log error but continue + fmt.Printf("Error adding post-aggregation expression %s: %v\n", postExpr.OriginalExpr, err) + } + } + + dp.stream.aggregator = enhancedAgg + } else { + // Use regular aggregator + dp.stream.aggregator = aggregator.NewGroupAggregator(dp.stream.config.GroupFields, aggregationFields) + } // Register expression calculators for field, fieldExpr := range dp.stream.config.FieldExpressions { @@ -96,6 +119,21 @@ func (dp *DataProcessor) initializeAggregator() { } } +// convertToAggregationFieldInfos converts types.AggregationFieldInfo to aggregator.AggregationFieldInfo +func convertToAggregationFieldInfos(fields []types.AggregationFieldInfo) []aggregator.AggregationFieldInfo { + result := make([]aggregator.AggregationFieldInfo, len(fields)) + for i, field := range fields { + result[i] = aggregator.AggregationFieldInfo{ + FuncName: field.FuncName, + InputField: field.InputField, + Placeholder: field.Placeholder, + AggType: field.AggType, + FullCall: field.FullCall, // 保持FullCall字段 + } + } + return result +} + // registerExpressionCalculator registers expression calculator func (dp *DataProcessor) registerExpressionCalculator(field string, fieldExpr types.FieldExpression) { // Create local variables to avoid closure issues diff --git a/streamsql_benchmark_test.go b/streamsql_benchmark_test.go index 2db58fa..7205510 100644 --- a/streamsql_benchmark_test.go +++ b/streamsql_benchmark_test.go @@ -1,936 +1,446 @@ -package streamsql - -import ( - "context" - "math/rand" - "sync/atomic" - "testing" - "time" -) - -// BenchmarkStreamSQLCore 核心性能基准测试 -func BenchmarkStreamSQLCore(b *testing.B) { - tests := []struct { - name string - sql string - hasWindow bool - waitTime time.Duration - }{ - { - name: "SimpleFilter", - sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", - hasWindow: false, - waitTime: 50 * time.Millisecond, - }, - { - name: "WindowAggregation", - sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')", - hasWindow: true, - waitTime: 200 * time.Millisecond, - }, - { - name: "ComplexQuery", - sql: "SELECT deviceId, AVG(temperature), COUNT(*) FROM stream WHERE humidity > 50 GROUP BY deviceId, TumblingWindow('100ms')", - hasWindow: true, - waitTime: 250 * time.Millisecond, - }, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - // 使用默认配置进行基准测试 - ssql := New() - defer ssql.Stop() - - err := ssql.Execute(tt.sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var resultReceived int64 - - // 添加结果处理器 - ssql.AddSink(func(result []map[string]interface{}) { - atomic.AddInt64(&resultReceived, 1) - }) - - // 异步消费结果通道防止阻塞 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - // 生成测试数据 - testData := generateTestData(5) - - // 重置统计 - ssql.Stream().ResetStats() - - b.ResetTimer() - - // 执行基准测试 - start := time.Now() - for i := 0; i < b.N; i++ { - ssql.Emit(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - b.StopTimer() - - // 等待处理完成 - time.Sleep(tt.waitTime) - cancel() - - // 获取统计信息 - stats := ssql.Stream().GetStats() - received := atomic.LoadInt64(&resultReceived) - - // 计算性能指标 - inputThroughput := float64(b.N) / inputDuration.Seconds() - processedCount := stats["output_count"] - droppedCount := stats["dropped_count"] - processRate := float64(processedCount) / float64(b.N) * 100 - dropRate := float64(droppedCount) / float64(b.N) * 100 - - b.ReportMetric(inputThroughput, "ops/sec") - b.ReportMetric(processRate, "process_rate_%") - b.ReportMetric(dropRate, "drop_rate_%") - b.ReportMetric(float64(received), "results") - - b.Logf("%s - 吞吐量: %.0f ops/sec, 处理率: %.1f%%, 丢弃率: %.2f%%", - tt.name, inputThroughput, processRate, dropRate) - }) - } -} - -// BenchmarkConfigComparison 配置对比基准测试 -func BenchmarkConfigComparison(b *testing.B) { - tests := []struct { - name string - setupFunc func() *Streamsql - }{ - { - name: "Default", - setupFunc: func() *Streamsql { - return New() - }, - }, - { - name: "HighPerformance", - setupFunc: func() *Streamsql { - return New(WithHighPerformance()) - }, - }, - { - name: "Lightweight", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(5000, 5000, 250)) - }, - }, - } - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - ssql := tt.setupFunc() - defer ssql.Stop() - - err := ssql.Execute(sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var resultCount int64 - ssql.AddSink(func(result []map[string]interface{}) { - atomic.AddInt64(&resultCount, 1) - }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - testData := generateTestData(3) - ssql.Stream().ResetStats() - - b.ResetTimer() - - start := time.Now() - for i := 0; i < b.N; i++ { - ssql.Emit(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - b.StopTimer() - - time.Sleep(50 * time.Millisecond) - cancel() - - stats := ssql.Stream().GetStats() - - inputThroughput := float64(b.N) / inputDuration.Seconds() - processedCount := stats["output_count"] - droppedCount := stats["dropped_count"] - processRate := float64(processedCount) / float64(b.N) * 100 - dropRate := float64(droppedCount) / float64(b.N) * 100 - - b.ReportMetric(inputThroughput, "ops/sec") - b.ReportMetric(processRate, "process_rate_%") - b.ReportMetric(dropRate, "drop_rate_%") - - b.Logf("%s配置 - 吞吐量: %.0f ops/sec, 处理率: %.1f%%, 丢弃率: %.2f%%", - tt.name, inputThroughput, processRate, dropRate) - }) - } -} - -// BenchmarkPureInput 纯输入性能基准测试 -func BenchmarkPureInput(b *testing.B) { - ssql := New(WithHighPerformance()) - defer ssql.Stop() - - sql := "SELECT deviceId FROM stream" - err := ssql.Execute(sql) - if err != nil { - b.Fatal(err) - } - - // 启动结果消费者防止阻塞 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - // 预生成数据 - data := map[string]interface{}{ - "deviceId": "device1", - "temperature": 25.0, - } - - b.ResetTimer() - start := time.Now() - - for i := 0; i < b.N; i++ { - ssql.Emit(data) - } - - b.StopTimer() - duration := time.Since(start) - throughput := float64(b.N) / duration.Seconds() - b.ReportMetric(throughput, "pure_input_ops/sec") - - b.Logf("纯输入性能: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) -} - -// generateTestData 生成测试数据 -func generateTestData(count int) []map[string]interface{} { - data := make([]map[string]interface{}, count) - devices := []string{"device1", "device2", "device3", "device4", "device5"} - - for i := 0; i < count; i++ { - data[i] = map[string]interface{}{ - "deviceId": devices[rand.Intn(len(devices))], - "temperature": 15.0 + rand.Float64()*20, // 15-35度 - "humidity": 30.0 + rand.Float64()*40, // 30-70% - "timestamp": time.Now().UnixNano(), - } - } - return data -} - -// BenchmarkConfigurationComparison 不同配置性能对比基准测试 -func BenchmarkConfigurationComparison(b *testing.B) { - tests := []struct { - name string - setupFunc func() *Streamsql - description string - }{ - { - name: "轻量配置", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(5000, 5000, 250)) - }, - description: "5K数据缓冲,5K结果缓冲,250 sink池", - }, - { - name: "默认配置(中等场景)", - setupFunc: func() *Streamsql { - return New() - }, - description: "20K数据缓冲,20K结果缓冲,800 sink池", - }, - { - name: "重负载配置", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(35000, 35000, 1200)) - }, - description: "35K数据缓冲,35K结果缓冲,1.2K sink池", - }, - { - name: "高性能配置", - setupFunc: func() *Streamsql { - return New(WithHighPerformance()) - }, - description: "50K数据缓冲,50K结果缓冲,1K sink池", - }, - { - name: "超大缓冲配置", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(100000, 100000, 2000)) - }, - description: "100K数据缓冲,100K结果缓冲,2K sink池", - }, - } - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - ssql := tt.setupFunc() - defer ssql.Stop() - - err := ssql.Execute(sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var resultCount int64 - - // 添加轻量级sink - ssql.AddSink(func(result []map[string]interface{}) { - atomic.AddInt64(&resultCount, 1) - }) - - // 异步消费resultChan - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - // 快速消费 - case <-ctx.Done(): - return - } - } - }() - - // 测试数据 - testData := generateTestData(3) - - b.ResetTimer() - - // 执行基准测试 - start := time.Now() - for i := 0; i < b.N; i++ { - ssql.Emit(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - b.StopTimer() - - // 等待处理完成 - time.Sleep(100 * time.Millisecond) - - cancel() - - // 获取统计信息 - detailedStats := ssql.Stream().GetDetailedStats() - results := atomic.LoadInt64(&resultCount) - - // 性能指标 - inputThroughput := float64(b.N) / inputDuration.Seconds() - processRate := detailedStats["process_rate"].(float64) - dropRate := detailedStats["drop_rate"].(float64) - perfLevel := detailedStats["performance_level"].(string) - - b.ReportMetric(inputThroughput, "input_ops/sec") - b.ReportMetric(processRate, "process_rate_%") - b.ReportMetric(dropRate, "drop_rate_%") - b.ReportMetric(float64(results), "results_count") - - // 详细报告 - b.Logf("配置: %s", tt.description) - b.Logf("性能等级: %s", perfLevel) - b.Logf("处理效率: %.2f%%, 丢弃率: %.2f%%", processRate, dropRate) - - // 缓冲区使用情况 - dataChanUsage := detailedStats["data_chan_usage"].(float64) - resultChanUsage := detailedStats["result_chan_usage"].(float64) - sinkPoolUsage := detailedStats["sink_pool_usage"].(float64) - - b.Logf("缓冲区使用率 - 数据: %.1f%%, 结果: %.1f%%, Sink池: %.1f%%", - dataChanUsage, resultChanUsage, sinkPoolUsage) - }) - } -} - -// TestMemoryUsageComparison 内存使用对比测试 -//func TestMemoryUsageComparison(t *testing.T) { -// tests := []struct { -// name string -// setupFunc func() *Streamsql -// description string -// expectedMB float64 // 预期内存使用(MB) -// }{ -// { -// name: "轻量配置", -// setupFunc: func() *Streamsql { -// return New(WithBufferSizes(5000, 5000, 250)) -// }, -// description: "5K数据 + 5K结果 + 250sink池", -// expectedMB: 1.0, // 预期约1MB -// }, -// { -// name: "默认配置(中等场景)", -// setupFunc: func() *Streamsql { -// return New() -// }, -// description: "20K数据 + 20K结果 + 800sink池", -// expectedMB: 3.0, // 预期约3MB -// }, -// { -// name: "高性能配置", -// setupFunc: func() *Streamsql { -// return New(WithHighPerformance()) -// }, -// description: "50K数据 + 50K结果 + 1Ksinki池", -// expectedMB: 12.0, // 预期约12MB -// }, -// { -// name: "超大缓冲配置", -// setupFunc: func() *Streamsql { -// return New(WithBufferSizes(100000, 100000, 2000)) -// }, -// description: "100K数据缓冲,100K结果缓冲,2Ksinki池", -// expectedMB: 25.0, // 预期约25MB -// }, -// } -// -// sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" -// -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// // 获取开始内存 -// var startMem runtime.MemStats -// runtime.GC() -// runtime.ReadMemStats(&startMem) -// -// // 创建Stream -// ssql := tt.setupFunc() -// err := ssql.Execute(sql) -// if err != nil { -// t.Fatalf("SQL执行失败: %v", err) -// } -// -// // 等待初始化完成 -// time.Sleep(10 * time.Millisecond) -// -// // 获取创建后内存 -// var afterCreateMem runtime.MemStats -// runtime.GC() -// runtime.ReadMemStats(&afterCreateMem) -// -// createUsage := float64(afterCreateMem.Alloc-startMem.Alloc) / 1024 / 1024 -// -// // 添加一些数据测试内存增长 -// testData := generateTestData(3) -// for i := 0; i < 1000; i++ { -// ssql.Emit(testData[i%len(testData)]) -// } -// -// time.Sleep(50 * time.Millisecond) -// -// // 获取使用后内存 -// var afterUseMem runtime.MemStats -// runtime.GC() -// runtime.ReadMemStats(&afterUseMem) -// -// totalUsage := float64(afterUseMem.Alloc-startMem.Alloc) / 1024 / 1024 -// -// // 获取详细统计 -// detailedStats := ssql.Stream().GetDetailedStats() -// basicStats := detailedStats["basic_stats"].(map[string]int64) -// -// ssql.Stop() -// -// t.Logf("=== %s 内存使用分析 ===", tt.name) -// t.Logf("配置: %s", tt.description) -// t.Logf("创建开销: %.2f MB", createUsage) -// t.Logf("总内存使用: %.2f MB", totalUsage) -// t.Logf("缓冲区配置:") -// t.Logf(" 数据通道: %d", basicStats["data_chan_cap"]) -// t.Logf(" 结果通道: %d", basicStats["result_chan_cap"]) -// t.Logf(" Sink池: %d", basicStats["sink_pool_cap"]) -// -// // 计算理论内存使用 (每个接口槽位约24字节) -// dataChanMem := float64(basicStats["data_chan_cap"]) * 24 / 1024 / 1024 -// resultChanMem := float64(basicStats["result_chan_cap"]) * 24 / 1024 / 1024 -// sinkPoolMem := float64(basicStats["sink_pool_cap"]) * 8 / 1024 / 1024 // 函数指针 -// -// theoreticalMem := dataChanMem + resultChanMem + sinkPoolMem -// -// t.Logf("理论内存分配:") -// t.Logf(" 数据通道: %.2f MB", dataChanMem) -// t.Logf(" 结果通道: %.2f MB", resultChanMem) -// t.Logf(" Sink池: %.2f MB", sinkPoolMem) -// t.Logf(" 理论总计: %.2f MB", theoreticalMem) -// -// // 内存效率分析 -// if totalUsage > tt.expectedMB*2 { -// t.Logf("警告: 内存使用超过预期2倍 (%.2f MB > %.2f MB)", totalUsage, tt.expectedMB*2) -// } else if totalUsage > tt.expectedMB*1.5 { -// t.Logf("注意: 内存使用超过预期50%% (%.2f MB > %.2f MB)", totalUsage, tt.expectedMB*1.5) -// } else { -// t.Logf("✓ 内存使用在合理范围内 (%.2f MB)", totalUsage) -// } -// }) -// } -//} - -// BenchmarkLightweightVsDefaultComparison 轻量 vs 默认配置基准测试 -func BenchmarkLightweightVsDefaultComparison(b *testing.B) { - tests := []struct { - name string - setupFunc func() *Streamsql - }{ - { - name: "轻量配置5K", - setupFunc: func() *Streamsql { - return New(WithBufferSizes(5000, 5000, 250)) - }, - }, - { - name: "默认配置20K", - setupFunc: func() *Streamsql { - return New() - }, - }, - } - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - ssql := tt.setupFunc() - defer ssql.Stop() - - err := ssql.Execute(sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var resultCount int64 - ssql.AddSink(func(result []map[string]interface{}) { - atomic.AddInt64(&resultCount, 1) - }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - testData := generateTestData(3) - - b.ResetTimer() - - start := time.Now() - for i := 0; i < b.N; i++ { - ssql.Emit(testData[i%len(testData)]) - } - inputDuration := time.Since(start) - - b.StopTimer() - - time.Sleep(50 * time.Millisecond) - cancel() - - detailedStats := ssql.Stream().GetDetailedStats() - results := atomic.LoadInt64(&resultCount) - - inputThroughput := float64(b.N) / inputDuration.Seconds() - processRate := detailedStats["process_rate"].(float64) - dropRate := detailedStats["drop_rate"].(float64) - dataChanUsage := detailedStats["data_chan_usage"].(float64) - - b.ReportMetric(inputThroughput, "input_ops/sec") - b.ReportMetric(processRate, "process_rate_%") - b.ReportMetric(dropRate, "drop_rate_%") - b.ReportMetric(dataChanUsage, "data_chan_usage_%") - b.ReportMetric(float64(results), "results_count") - - basicStats := detailedStats["basic_stats"].(map[string]int64) - b.Logf("缓冲区配置: 数据通道 %d, 结果通道 %d, Sink池 %d", - basicStats["data_chan_cap"], - basicStats["result_chan_cap"], - basicStats["sink_pool_cap"]) - b.Logf("性能指标: %.0f ops/sec, 处理率 %.1f%%, 丢弃率 %.2f%%, 通道使用率 %.1f%%", - inputThroughput, processRate, dropRate, dataChanUsage) - }) - } -} - -// BenchmarkStreamSQLRealistic 现实的性能基准测试 -func BenchmarkStreamSQLRealistic(b *testing.B) { - tests := []struct { - name string - sql string - hasWindow bool - waitTime time.Duration - }{ - { - name: "SimpleFilter", - sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", - hasWindow: false, - waitTime: 50 * time.Millisecond, - }, - { - name: "BasicAggregation", - sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')", - hasWindow: true, - waitTime: 200 * time.Millisecond, - }, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - // 使用默认配置,避免异常大的缓冲区 - ssql := New() - defer ssql.Stop() - - err := ssql.Execute(tt.sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - var processedCount int64 - var actualResultCount int64 - - // 测量实际的处理完成 - ssql.AddSink(func(result []map[string]interface{}) { - atomic.AddInt64(&actualResultCount, 1) - }) - - // 不使用异步消费resultChan,让系统自然处理 - testData := generateTestData(3) - - // 限制测试规模,避免过度膨胀 - maxIterations := min(b.N, 10000) // 最多1万次 - - ssql.Stream().ResetStats() - b.ResetTimer() - - // 受控的输入,测量真实处理性能 - start := time.Now() - for i := 0; i < maxIterations; i++ { - // 直接使用AddData,如果系统处理不过来会自然阻塞或丢弃 - ssql.Emit(testData[i%len(testData)]) - atomic.AddInt64(&processedCount, 1) - - // 每100条数据稍微停顿,模拟真实的数据流 - if i > 0 && i%100 == 0 { - time.Sleep(1 * time.Millisecond) - } - } - inputDuration := time.Since(start) - - b.StopTimer() - - // 等待处理完成 - time.Sleep(tt.waitTime) - - processed := atomic.LoadInt64(&processedCount) - results := atomic.LoadInt64(&actualResultCount) - stats := ssql.Stream().GetStats() - - // 计算真实的处理吞吐量 - realThroughput := float64(processed) / inputDuration.Seconds() - - b.ReportMetric(realThroughput, "realistic_ops/sec") - b.ReportMetric(float64(results), "actual_results") - b.ReportMetric(float64(stats["dropped_count"]), "dropped_data") - - // 输出合理的性能数据范围 - b.Logf("实际输入: %d 条, 实际结果: %d 个", processed, results) - b.Logf("真实吞吐量: %.0f ops/sec (%.1f万 ops/sec)", realThroughput, realThroughput/10000) - - if dropped := stats["dropped_count"]; dropped > 0 { - b.Logf("丢弃数据: %d 条", dropped) - } - }) - } -} - -// min 辅助函数 -func min(a, b int) int { - if a < b { - return a - } - return b -} - -// BenchmarkPurePerformance 纯性能基准测试(无等待,无限制) -func BenchmarkPurePerformance(b *testing.B) { - ssql := New(WithHighPerformance()) - defer ssql.Stop() - - sql := "SELECT deviceId FROM stream" - err := ssql.Execute(sql) - if err != nil { - b.Fatal(err) - } - - // 启动结果消费者 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - select { - case <-ssql.Stream().GetResultsChan(): - case <-ctx.Done(): - return - } - } - }() - - // 预生成单条数据 - data := map[string]interface{}{ - "deviceId": "device1", - "temperature": 25.0, - } - - b.ResetTimer() - start := time.Now() - - // 纯输入性能测试 - for i := 0; i < b.N; i++ { - ssql.Emit(data) - } - - b.StopTimer() - duration := time.Since(start) - throughput := float64(b.N) / duration.Seconds() - cancel() - - b.ReportMetric(throughput, "pure_ops/sec") - b.Logf("纯输入性能: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) -} - -// BenchmarkEndToEndProcessing 端到端处理性能基准测试 -func BenchmarkEndToEndProcessing(b *testing.B) { - tests := []struct { - name string - sql string - batchSize int - expectOutput bool - }{ - { - name: "EndToEndFilter", - sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", - batchSize: 1000, - expectOutput: true, - }, - { - name: "EndToEndAggregation", - sql: "SELECT deviceId, COUNT(*) as count FROM stream GROUP BY deviceId, TumblingWindow('100ms')", - batchSize: 500, - expectOutput: true, - }, - } - - for _, tt := range tests { - b.Run(tt.name, func(b *testing.B) { - ssql := New() - defer ssql.Stop() - - err := ssql.Execute(tt.sql) - if err != nil { - b.Fatalf("SQL执行失败: %v", err) - } - - // 计算需要的批次数 - batches := (b.N + tt.batchSize - 1) / tt.batchSize - - var totalProcessed int64 - var totalDuration time.Duration - - testData := generateTestData(3) - - b.ResetTimer() - - // 分批处理,每批次测量完整的处理时间 - for batch := 0; batch < batches; batch++ { - currentBatchSize := tt.batchSize - if batch == batches-1 { - // 最后一批可能不满 - currentBatchSize = b.N - batch*tt.batchSize - } - - var resultsReceived int64 - resultChan := make(chan bool, currentBatchSize) - - // 设置sink来捕获结果 - ssql.AddSink(func(result []map[string]interface{}) { - count := atomic.AddInt64(&resultsReceived, 1) - if count <= int64(currentBatchSize) { - resultChan <- true - } - }) - - // 记录开始时间 - start := time.Now() - - // 输入数据 - for i := 0; i < currentBatchSize; i++ { - ssql.Emit(testData[i%len(testData)]) - } - - // 等待所有结果处理完成(对于非聚合查询) - if tt.expectOutput { - expectedResults := currentBatchSize - if tt.name == "EndToEndAggregation" { - // 聚合查询的结果数量较少,等待至少1个结果 - expectedResults = 1 - } - - receivedCount := 0 - timeout := time.After(5 * time.Second) - for receivedCount < expectedResults { - select { - case <-resultChan: - receivedCount++ - if tt.name == "EndToEndAggregation" && receivedCount >= 1 { - // 聚合查询收到1个结果就算完成 - goto batchDone - } - case <-timeout: - // 超时,记录实际收到的结果 - goto batchDone - } - } - } - - batchDone: - // 记录这批次的处理时间 - batchDuration := time.Since(start) - totalDuration += batchDuration - totalProcessed += int64(currentBatchSize) - - // 注意:没有ClearSinks方法,所以每次测试使用新的Stream实例 - } - - b.StopTimer() - - // 计算真实的端到端吞吐量 - realThroughput := float64(totalProcessed) / totalDuration.Seconds() - - b.ReportMetric(realThroughput, "end_to_end_ops/sec") - b.Logf("端到端测试 - 处理: %d 条, 总耗时: %v", totalProcessed, totalDuration) - b.Logf("端到端吞吐量: %.0f ops/sec (%.1f万 ops/sec)", realThroughput, realThroughput/10000) - }) - } -} - -// BenchmarkSustainedProcessing 持续处理性能基准测试 -func BenchmarkSustainedProcessing(b *testing.B) { - ssql := New() - defer ssql.Stop() - - sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" - err := ssql.Execute(sql) - if err != nil { - b.Fatal(err) - } - - var processedResults int64 - var lastResultTime time.Time - - // 设置结果处理器 - ssql.AddSink(func(result []map[string]interface{}) { - atomic.AddInt64(&processedResults, 1) - lastResultTime = time.Now() - }) - - testData := generateTestData(3) - - b.ResetTimer() - start := time.Now() - - // 持续输入数据 - for i := 0; i < b.N; i++ { - ssql.Emit(testData[i%len(testData)]) - - // 每1000条检查一次处理进度 - if i > 0 && i%1000 == 0 { - time.Sleep(1 * time.Millisecond) // 让系统有时间处理 - } - } - - inputEnd := time.Now() - inputDuration := inputEnd.Sub(start) - - // 等待所有结果处理完成 - for { - current := atomic.LoadInt64(&processedResults) - if current >= int64(b.N) { - break - } - if time.Since(lastResultTime) > 2*time.Second { - // 2秒没有新结果,认为处理完成 - break - } - time.Sleep(10 * time.Millisecond) - } - - totalDuration := time.Since(start) - final := atomic.LoadInt64(&processedResults) - - b.StopTimer() - - inputThroughput := float64(b.N) / inputDuration.Seconds() - sustainedThroughput := float64(final) / totalDuration.Seconds() - - b.ReportMetric(inputThroughput, "input_rate_ops/sec") - b.ReportMetric(sustainedThroughput, "sustained_ops/sec") - b.ReportMetric(float64(final), "processed_count") - - b.Logf("持续处理测试:") - b.Logf(" 输入速率: %.0f ops/sec (%.1f万 ops/sec)", inputThroughput, inputThroughput/10000) - b.Logf(" 持续处理速率: %.0f ops/sec (%.1f万 ops/sec)", sustainedThroughput, sustainedThroughput/10000) - b.Logf(" 处理完成率: %.1f%% (%d/%d)", float64(final)/float64(b.N)*100, final, b.N) -} +package streamsql + +import ( + "context" + "math/rand" + "sync/atomic" + "testing" + "time" +) + +// BenchmarkStreamSQL StreamSQL基准测试 +func BenchmarkStreamSQL(b *testing.B) { + tests := []struct { + name string + sql string + waitTime time.Duration + }{ + { + name: "SimpleFilter", + sql: "SELECT deviceId, temperature FROM stream WHERE temperature > 20", + waitTime: 50 * time.Millisecond, + }, + { + name: "BasicAggregation", + sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')", + waitTime: 150 * time.Millisecond, + }, + { + name: "ComplexQuery", + sql: "SELECT deviceId, AVG(temperature), COUNT(*) FROM stream WHERE humidity > 50 GROUP BY deviceId, TumblingWindow('100ms')", + waitTime: 200 * time.Millisecond, + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + // 使用默认配置 + ssql := New() + defer ssql.Stop() + + err := ssql.Execute(tt.sql) + if err != nil { + b.Fatalf("SQL执行失败: %v", err) + } + + var resultCount int64 + ssql.AddSink(func(result []map[string]interface{}) { + atomic.AddInt64(&resultCount, 1) + }) + + // 异步消费结果防止阻塞 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + select { + case <-ssql.Stream().GetResultsChan(): + case <-ctx.Done(): + return + } + } + }() + + // 生成测试数据 + testData := generateOptimizedTestData(5) + ssql.Stream().ResetStats() + + b.ResetTimer() + + // 执行基准测试 + start := time.Now() + for i := 0; i < b.N; i++ { + ssql.Emit(testData[i%len(testData)]) + } + inputDuration := time.Since(start) + + b.StopTimer() + + // 等待处理完成 + time.Sleep(tt.waitTime) + cancel() + + // 获取统计信息 + stats := ssql.Stream().GetStats() + results := atomic.LoadInt64(&resultCount) + + // 计算核心性能指标 + throughput := float64(b.N) / inputDuration.Seconds() + processedCount := stats["output_count"] + droppedCount := stats["dropped_count"] + processRate := float64(processedCount) / float64(b.N) * 100 + dropRate := float64(droppedCount) / float64(b.N) * 100 + + // 报告指标 + b.ReportMetric(throughput, "ops/sec") + b.ReportMetric(processRate, "process_rate_%") + b.ReportMetric(dropRate, "drop_rate_%") + b.ReportMetric(float64(results), "results") + + // 输出可读的性能报告 + b.Logf("性能报告 - %s:", tt.name) + b.Logf(" 吞吐量: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) + b.Logf(" 处理率: %.1f%%, 丢弃率: %.2f%%", processRate, dropRate) + b.Logf(" 结果数: %d", results) + }) + } +} + +// BenchmarkConfigurationOptimized 优化后的配置对比基准测试 +func BenchmarkConfigurationOptimized(b *testing.B) { + configs := []struct { + name string + setupFunc func() *Streamsql + }{ + { + name: "Lightweight", + setupFunc: func() *Streamsql { + return New(WithBufferSizes(5000, 5000, 250)) + }, + }, + { + name: "Default", + setupFunc: func() *Streamsql { + return New() + }, + }, + { + name: "HighPerformance", + setupFunc: func() *Streamsql { + return New(WithHighPerformance()) + }, + }, + } + + sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" + + for _, config := range configs { + b.Run(config.name, func(b *testing.B) { + ssql := config.setupFunc() + defer ssql.Stop() + + err := ssql.Execute(sql) + if err != nil { + b.Fatalf("SQL执行失败: %v", err) + } + + var resultCount int64 + ssql.AddSink(func(result []map[string]interface{}) { + atomic.AddInt64(&resultCount, 1) + }) + + // 异步消费结果 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + select { + case <-ssql.Stream().GetResultsChan(): + case <-ctx.Done(): + return + } + } + }() + + testData := generateOptimizedTestData(3) + ssql.Stream().ResetStats() + + b.ResetTimer() + + start := time.Now() + for i := 0; i < b.N; i++ { + ssql.Emit(testData[i%len(testData)]) + } + inputDuration := time.Since(start) + + b.StopTimer() + + time.Sleep(100 * time.Millisecond) + cancel() + + // 获取详细统计 + detailedStats := ssql.Stream().GetDetailedStats() + results := atomic.LoadInt64(&resultCount) + + throughput := float64(b.N) / inputDuration.Seconds() + processRate := detailedStats["process_rate"].(float64) + dropRate := detailedStats["drop_rate"].(float64) + + b.ReportMetric(throughput, "ops/sec") + b.ReportMetric(processRate, "process_rate_%") + b.ReportMetric(dropRate, "drop_rate_%") + + b.Logf("%s配置性能:", config.name) + b.Logf(" 吞吐量: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) + b.Logf(" 处理率: %.1f%%, 丢弃率: %.2f%%", processRate, dropRate) + b.Logf(" 结果数: %d", results) + }) + } +} + +// BenchmarkPureInputOptimized 优化后的纯输入性能测试 +func BenchmarkPureInputOptimized(b *testing.B) { + ssql := New(WithHighPerformance()) + defer ssql.Stop() + + sql := "SELECT deviceId FROM stream" + err := ssql.Execute(sql) + if err != nil { + b.Fatal(err) + } + + // 启动结果消费者 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + select { + case <-ssql.Stream().GetResultsChan(): + case <-ctx.Done(): + return + } + } + }() + + // 预生成数据 + data := map[string]interface{}{ + "deviceId": "device1", + "temperature": 25.0, + } + + b.ResetTimer() + start := time.Now() + + for i := 0; i < b.N; i++ { + ssql.Emit(data) + } + + b.StopTimer() + duration := time.Since(start) + throughput := float64(b.N) / duration.Seconds() + + b.ReportMetric(throughput, "pure_ops/sec") + b.Logf("纯输入性能: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) +} + +// BenchmarkPostAggregationPerformance 后聚合性能基准测试 +func BenchmarkPostAggregationPerformance(b *testing.B) { + tests := []struct { + name string + sql string + }{ + { + name: "SimpleAggregation", + sql: "SELECT deviceId, AVG(temperature) FROM stream GROUP BY deviceId, TumblingWindow('100ms')", + }, + { + name: "PostAggregationSimple", + sql: "SELECT deviceId, AVG(temperature) + 10 as adjusted_temp FROM stream GROUP BY deviceId, TumblingWindow('100ms')", + }, + { + name: "PostAggregationComplex", + sql: "SELECT deviceId, CEIL(AVG(temperature) * 1.8 + 32) as fahrenheit FROM stream GROUP BY deviceId, TumblingWindow('100ms')", + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + ssql := New() + defer ssql.Stop() + + err := ssql.Execute(tt.sql) + if err != nil { + b.Fatalf("SQL执行失败: %v", err) + } + + var resultCount int64 + ssql.AddSink(func(result []map[string]interface{}) { + atomic.AddInt64(&resultCount, 1) + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + select { + case <-ssql.Stream().GetResultsChan(): + case <-ctx.Done(): + return + } + } + }() + + testData := generateOptimizedTestData(5) + ssql.Stream().ResetStats() + + b.ResetTimer() + + start := time.Now() + for i := 0; i < b.N; i++ { + ssql.Emit(testData[i%len(testData)]) + } + inputDuration := time.Since(start) + + b.StopTimer() + + time.Sleep(200 * time.Millisecond) + cancel() + + stats := ssql.Stream().GetStats() + results := atomic.LoadInt64(&resultCount) + + throughput := float64(b.N) / inputDuration.Seconds() + processedCount := stats["output_count"] + droppedCount := stats["dropped_count"] + processRate := float64(processedCount) / float64(b.N) * 100 + dropRate := float64(droppedCount) / float64(b.N) * 100 + + b.ReportMetric(throughput, "ops/sec") + b.ReportMetric(processRate, "process_rate_%") + b.ReportMetric(dropRate, "drop_rate_%") + b.ReportMetric(float64(results), "results") + + b.Logf("%s性能:", tt.name) + b.Logf(" 吞吐量: %.0f ops/sec (%.1f万 ops/sec)", throughput, throughput/10000) + b.Logf(" 处理率: %.1f%%, 丢弃率: %.2f%%", processRate, dropRate) + b.Logf(" 结果数: %d", results) + }) + } +} + +// generateOptimizedTestData 生成优化的测试数据 +func generateOptimizedTestData(count int) []map[string]interface{} { + data := make([]map[string]interface{}, count) + devices := []string{"device1", "device2", "device3", "device4", "device5"} + + for i := 0; i < count; i++ { + data[i] = map[string]interface{}{ + "deviceId": devices[rand.Intn(len(devices))], + "temperature": 15.0 + rand.Float64()*20, // 15-35度 + "humidity": 30.0 + rand.Float64()*40, // 30-70% + "timestamp": time.Now().UnixNano(), + } + } + return data +} + +// BenchmarkMemoryEfficiency 内存效率基准测试 +func BenchmarkMemoryEfficiency(b *testing.B) { + configs := []struct { + name string + setupFunc func() *Streamsql + }{ + { + name: "Lightweight5K", + setupFunc: func() *Streamsql { + return New(WithBufferSizes(5000, 5000, 250)) + }, + }, + { + name: "Default20K", + setupFunc: func() *Streamsql { + return New() + }, + }, + { + name: "HighPerf50K", + setupFunc: func() *Streamsql { + return New(WithHighPerformance()) + }, + }, + } + + sql := "SELECT deviceId, temperature FROM stream WHERE temperature > 20" + + for _, config := range configs { + b.Run(config.name, func(b *testing.B) { + ssql := config.setupFunc() + defer ssql.Stop() + + err := ssql.Execute(sql) + if err != nil { + b.Fatalf("SQL执行失败: %v", err) + } + + var resultCount int64 + ssql.AddSink(func(result []map[string]interface{}) { + atomic.AddInt64(&resultCount, 1) + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + select { + case <-ssql.Stream().GetResultsChan(): + case <-ctx.Done(): + return + } + } + }() + + testData := generateOptimizedTestData(3) + + b.ResetTimer() + + start := time.Now() + for i := 0; i < b.N; i++ { + ssql.Emit(testData[i%len(testData)]) + } + inputDuration := time.Since(start) + + b.StopTimer() + + time.Sleep(50 * time.Millisecond) + cancel() + + detailedStats := ssql.Stream().GetDetailedStats() + results := atomic.LoadInt64(&resultCount) + + throughput := float64(b.N) / inputDuration.Seconds() + processRate := detailedStats["process_rate"].(float64) + dataChanUsage := detailedStats["data_chan_usage"].(float64) + resultChanUsage := detailedStats["result_chan_usage"].(float64) + + b.ReportMetric(throughput, "ops/sec") + b.ReportMetric(processRate, "process_rate_%") + b.ReportMetric(dataChanUsage, "data_chan_usage_%") + b.ReportMetric(resultChanUsage, "result_chan_usage_%") + b.ReportMetric(float64(results), "results") + + basicStats := detailedStats["basic_stats"].(map[string]int64) + b.Logf("%s配置效率:", config.name) + b.Logf(" 缓冲区: 数据%d/结果%d/Sink%d", + basicStats["data_chan_cap"], + basicStats["result_chan_cap"], + basicStats["sink_pool_cap"]) + b.Logf(" 吞吐量: %.0f ops/sec, 处理率: %.1f%%, 结果数: %d", throughput, processRate, results) + b.Logf(" 通道使用率: 数据%.1f%%, 结果%.1f%%", dataChanUsage, resultChanUsage) + }) + } +} diff --git a/streamsql_function_integration_test.go b/streamsql_function_integration_test.go index f6d0379..78d0be7 100644 --- a/streamsql_function_integration_test.go +++ b/streamsql_function_integration_test.go @@ -640,7 +640,7 @@ func TestNestedFunctionSupport(t *testing.T) { // 执行包含 avg(round(temperature, 2)) 的查询 query := "SELECT device, avg(round(temperature, 2)) as avg_rounded FROM stream GROUP BY device, TumblingWindow('1s')" - t.Logf("Executing query: %s", query) + err := streamsql.Execute(query) assert.Nil(t, err) @@ -686,7 +686,6 @@ func TestNestedFunctionSupport(t *testing.T) { } else if val, ok := avgRounded.(float64); ok { // 期望值:avg(20.57, 25.23, 30.12) = (20.57 + 25.23 + 30.12) / 3 = 25.31 assert.InEpsilon(t, 25.31, val, 0.01) - t.Logf("avg(round()) test passed: %v", val) } else { t.Errorf("avg_rounded is not a float64: %v (type: %T)", avgRounded, avgRounded) } @@ -740,9 +739,6 @@ func TestNestedFunctionSupport(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - for key, value := range item { - t.Logf(" %s: %v (type: %T)", key, value, value) - } assert.Equal(t, "sensor1", item["device"]) @@ -798,7 +794,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - t.Logf("Result: %+v", item) // 验证执行顺序:round(25.67, 1) -> 25.7, concat('temp_', '25.7') -> 'temp_25.7', upper('temp_25.7') -> 'TEMP_25.7' assert.Equal(t, "TEMP_25.7", item["formatted_temp"]) @@ -814,7 +809,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { defer streamsql.Stop() query := "SELECT device, round(len(upper(device)), 0) as device_length FROM stream" - t.Logf("Executing query: %s", query) + err := streamsql.Execute(query) assert.Nil(t, err) @@ -838,7 +833,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - t.Logf("Result: %+v", item) // 验证执行顺序:upper('sensor1') -> 'SENSOR1', len('SENSOR1') -> 7, round(7, 0) -> 7 assert.Equal(t, float64(7), item["device_length"]) @@ -854,7 +848,7 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { defer streamsql.Stop() query := "SELECT device, abs(round(sqrt(temperature), 2)) as processed_temp FROM stream" - t.Logf("Executing query: %s", query) + err := streamsql.Execute(query) assert.Nil(t, err) @@ -878,8 +872,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - //t.Logf("Result: %+v", item) - // 验证执行顺序:sqrt(16) -> 4, round(4, 2) -> 4.00, abs(4.00) -> 4.00 assert.Equal(t, float64(4), item["processed_temp"]) case <-ctx.Done(): @@ -887,56 +879,40 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { } }) - // 测试6: 复杂的聚合函数嵌套 + // 测试6: 复杂的聚合函数嵌套 - 应该报错 t.Run("ComplexAggregationNesting", func(t *testing.T) { - // 测试 max(round(avg(temperature), 1)) + // 测试 max(round(avg(temperature), 1)) - 这是嵌套聚合函数,应该报错 streamsql := New() defer streamsql.Stop() query := "SELECT device, max(round(avg(temperature), 1)) as max_rounded_avg FROM stream GROUP BY device, TumblingWindow('1s')" - t.Logf("Executing query: %s", query) err := streamsql.Execute(query) - assert.Nil(t, err) + // 应该返回嵌套聚合函数错误 + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "aggregate function calls cannot be nested") + }) - strm := streamsql.stream - resultChan := make(chan interface{}, 10) - strm.AddSink(func(result []map[string]interface{}) { - resultChan <- result - }) + // 测试7: 其他类型的嵌套聚合函数检测 + t.Run("NestedAggregationDetection", func(t *testing.T) { + streamsql := New() + defer streamsql.Stop() - // 添加测试数据 - testData := []map[string]interface{}{ - {"device": "sensor1", "temperature": 20.567}, - {"device": "sensor1", "temperature": 25.234}, - {"device": "sensor1", "temperature": 30.123}, - } + // 测试 sum(count(*)) - 聚合函数嵌套聚合函数 + query1 := "SELECT sum(count(*)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')" + err1 := streamsql.Execute(query1) + assert.NotNil(t, err1) + assert.Contains(t, err1.Error(), "aggregate function calls cannot be nested") - for _, data := range testData { - strm.Emit(data) - } + // 测试 avg(min(temperature)) - 聚合函数嵌套聚合函数 + query2 := "SELECT avg(min(temperature)) as nested_agg FROM stream GROUP BY device, TumblingWindow('1s')" + err2 := streamsql.Execute(query2) + assert.NotNil(t, err2) + assert.Contains(t, err2.Error(), "aggregate function calls cannot be nested") - // 等待窗口初始化 - time.Sleep(1 * time.Second) - strm.Window.Trigger() - - // 等待结果 - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - select { - case result := <-resultChan: - resultSlice, ok := result.([]map[string]interface{}) - require.True(t, ok) - assert.Len(t, resultSlice, 1) - - item := resultSlice[0] - //t.Logf("Result: %+v", item) - - // 验证执行顺序:avg(20.567, 25.234, 30.123) -> 25.308, round(25.308, 1) -> 25.3, max(25.3) -> 25.3 - assert.InEpsilon(t, 25.3, item["max_rounded_avg"], 0.01) - case <-ctx.Done(): - t.Fatal("测试超时") - } + // 测试 round(avg(temperature), 1) - 正常函数嵌套聚合函数,应该正常 + query3 := "SELECT round(avg(temperature), 1) as normal_nesting FROM stream GROUP BY device, TumblingWindow('1s')" + err3 := streamsql.Execute(query3) + assert.Nil(t, err3) // 这种嵌套应该是允许的 }) // 测试7: 日期时间函数嵌套 @@ -946,7 +922,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { defer streamsql.Stop() query := "SELECT device, year(date_add(created_at, 1, 'years')) as next_year FROM stream" - t.Logf("Executing query: %s", query) err := streamsql.Execute(query) assert.Nil(t, err) @@ -970,7 +945,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - //t.Logf("Result: %+v", item) // 验证执行顺序:date_add('2023-12-25 15:30:45', 1, 'years') -> '2024-12-25 15:30:45', year('2024-12-25 15:30:45') -> 2024 assert.Equal(t, float64(2024), item["next_year"]) @@ -986,7 +960,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { defer streamsql.Stop() query := "SELECT device, sqrt(len(invalid_field)) as error_result FROM stream" - t.Logf("Executing query: %s", query) err := streamsql.Execute(query) assert.Nil(t, err) @@ -1010,7 +983,6 @@ func TestNestedFunctionExecutionOrder(t *testing.T) { assert.Len(t, resultSlice, 1) item := resultSlice[0] - t.Logf("Error handling result: %+v", item) // 验证错误处理:invalid_field不存在,应该返回nil或默认值 _, exists := item["error_result"] diff --git a/streamsql_post_aggregation_test.go b/streamsql_post_aggregation_test.go new file mode 100644 index 0000000..d78afa0 --- /dev/null +++ b/streamsql_post_aggregation_test.go @@ -0,0 +1,562 @@ +package streamsql + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// 辅助函数:创建测试环境 +func createTestEnvironment(t *testing.T, rsql string) (*Streamsql, chan interface{}) { + ssql := New() + t.Cleanup(func() { ssql.Stop() }) + + err := ssql.Execute(rsql) + require.NoError(t, err) + + resultChan := make(chan interface{}, 10) + ssql.AddSink(func(result []map[string]interface{}) { + resultChan <- result + }) + + return ssql, resultChan +} + +// 辅助函数:发送测试数据并收集结果 +func sendDataAndCollectResults(t *testing.T, ssql *Streamsql, resultChan chan interface{}, testData []map[string]interface{}, windowSizeSeconds int) []map[string]interface{} { + for _, data := range testData { + ssql.Emit(data) + } + + // 等待窗口触发 + time.Sleep(time.Duration(windowSizeSeconds+1) * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + var results []map[string]interface{} +collecting: + for { + select { + case result := <-resultChan: + if resultSlice, ok := result.([]map[string]interface{}); ok { + results = append(results, resultSlice...) + } + case <-time.After(1 * time.Second): + break collecting + case <-ctx.Done(): + break collecting + } + } + + return results +} + +// TestPostAggregationExpressions 测试分阶段聚合功能 +func TestPostAggregationExpressions(t *testing.T) { + t.Run("基础聚合函数复杂运算", func(t *testing.T) { + rsql := `SELECT deviceId, + FIRST_VALUE(value) as firstVal, + LAST_VALUE(value) as lastVal, + (LAST_VALUE(value) - FIRST_VALUE(value)) as diffVal, + SUM(value) as sumVal, + AVG(value) as avgVal, + (SUM(value) / COUNT(*)) as calcAvg, + (SUM(value) + AVG(value)) as sumPlusAvg + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "dev1", "value": 10.0, "ts": baseTime}, + {"deviceId": "dev1", "value": 20.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "dev1", "value": 30.0, "ts": baseTime.Add(2 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证基础聚合函数运算 + assert.Equal(t, "dev1", result["deviceId"]) + assert.Equal(t, 10.0, result["firstVal"]) + assert.Equal(t, 30.0, result["lastVal"]) + assert.Equal(t, 20.0, result["diffVal"]) // LAST_VALUE - FIRST_VALUE + assert.Equal(t, 60.0, result["sumVal"]) + assert.Equal(t, 20.0, result["avgVal"]) + assert.Equal(t, 20.0, result["calcAvg"]) // SUM / COUNT + assert.Equal(t, 80.0, result["sumPlusAvg"]) // SUM + AVG + }) + + // IF_NULL 基础功能:在 IF_NULL 中包裹聚合/分析函数 + t.Run("验证:IF_NULL 基础功能", func(t *testing.T) { + rsql := `SELECT deviceId, + IF_NULL(FIRST_VALUE(value), 0) as firstOrZero, + IF_NULL(LAST_VALUE(value), 0) as lastOrZero, + IF_NULL(AVG(value), 0) as avgOrZero + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": nil, "ts": baseTime}, + {"deviceId": "sensor1", "value": 10.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": nil, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 30.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + assert.Equal(t, "sensor1", result["deviceId"]) + // FIRST_VALUE(value) 为 nil => IF_NULL(...,0) = 0 + assert.Equal(t, 0.0, result["firstOrZero"]) + // LAST_VALUE(value) 为 30 => IF_NULL(...,0) = 30 + assert.Equal(t, 30.0, result["lastOrZero"]) + // AVG(value) 仅计算非空 => (10+30)/2 = 20 => IF_NULL(...,0) = 20 + assert.Equal(t, 20.0, result["avgOrZero"]) + }) + + // 聚合函数参数中嵌套 IF_NULL:如 SUM(IF_NULL(value,0)) + t.Run("验证:聚合函数嵌套 IF_NULL", func(t *testing.T) { + rsql := `SELECT deviceId, + SUM(IF_NULL(value, 0)) as sumVal, + AVG(IF_NULL(value, 0)) as avgVal, + MAX(IF_NULL(value, 0)) as maxVal, + MIN(IF_NULL(value, 0)) as minVal + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": nil, "ts": baseTime}, + {"deviceId": "sensor1", "value": 10.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": nil, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 30.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + assert.Equal(t, "sensor1", result["deviceId"]) + // SUM(IF_NULL(value,0)) = 0 + 10 + 0 + 30 = 40 + assert.Equal(t, 40.0, result["sumVal"]) + // AVG(IF_NULL(value,0)) = (0 + 10 + 0 + 30)/4 = 10 + assert.Equal(t, 10.0, result["avgVal"]) + // MAX(IF_NULL(value,0)) = max(0,10,0,30) = 30 + assert.Equal(t, 30.0, result["maxVal"]) + // MIN(IF_NULL(value,0)) = min(0,10,0,30) = 0 + assert.Equal(t, 0.0, result["minVal"]) + }) + + t.Run("分析函数与聚合函数复杂运算", func(t *testing.T) { + rsql := `SELECT deviceId, + SUM(value) as total, + AVG(value) as average, + LATEST(value) as latest, + (SUM(value) + LATEST(value)) as totalPlusLatest, + (AVG(value) * LATEST(value)) as avgTimesLatest + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 10.0, "ts": baseTime}, + {"deviceId": "sensor1", "value": 20.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": 30.0, "ts": baseTime.Add(2 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证分析函数与聚合函数的复杂运算 + assert.Equal(t, "sensor1", result["deviceId"]) + assert.Equal(t, 60.0, result["total"]) // 10+20+30 + assert.Equal(t, 20.0, result["average"]) // 60/3 + assert.Equal(t, 30.0, result["latest"]) // 最新值 + assert.Equal(t, 90.0, result["totalPlusLatest"]) // 60 + 30 + assert.Equal(t, 600.0, result["avgTimesLatest"]) // 20 * 30 + }) + + t.Run("最外层嵌套普通函数验证", func(t *testing.T) { + rsql := `SELECT deviceId, + SUM(value) as total, + COUNT(*) as count, + AVG(value) as average, + MAX(value) as maxVal, + (COUNT(*) * AVG(value)) as countTimesAvg, + (SUM(value) / MAX(value)) as sumDivideMax, + ((COUNT(*) + SUM(value)) * AVG(value)) as complexNested, + FLOOR((SUM(value) / MAX(value))) as floorResult, + CEIL((AVG(value) / COUNT(*))) as ceilResult, + ROUND((SUM(value) * AVG(value) / 1000), 2) as roundResult + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 10.0, "ts": baseTime}, + {"deviceId": "sensor1", "value": 20.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": 30.0, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 40.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证基础函数 + assert.Equal(t, "sensor1", result["deviceId"]) + assert.Equal(t, 100.0, result["total"]) // 10+20+30+40 + assert.Equal(t, 4.0, result["count"]) // 4 records + assert.Equal(t, 25.0, result["average"]) // 100/4 + assert.Equal(t, 40.0, result["maxVal"]) // max value + + // 验证最外层嵌套普通函数 + // (COUNT(*) * AVG(value)) = 4 * 25 = 100 + assert.Equal(t, 100.0, result["countTimesAvg"], "最外层嵌套函数计算错误") + + // (SUM(value) / MAX(value)) = 100 / 40 = 2.5 + assert.Equal(t, 2.5, result["sumDivideMax"], "最外层嵌套函数计算错误") + + // ((COUNT(*) + SUM(value)) * AVG(value)) = (4 + 100) * 25 = 2600 + assert.Equal(t, 2600.0, result["complexNested"], "最外层复杂嵌套函数计算错误") + + // 验证最外层嵌套普通函数 + // FLOOR((SUM(value) / MAX(value))) = FLOOR(100/40) = FLOOR(2.5) = 2 + if floorResult, ok := result["floorResult"].(float64); ok { + assert.Equal(t, 2.0, floorResult, "FLOOR函数嵌套计算错误") + } + + // CEIL((AVG(value) / COUNT(*))) = CEIL(25/4) = CEIL(6.25) = 7 + if ceilResult, ok := result["ceilResult"].(float64); ok { + assert.Equal(t, 7.0, ceilResult, "CEIL函数嵌套计算错误") + } + + // ROUND((SUM(value) * AVG(value) / 1000), 2) = ROUND(100*25/1000, 2) = ROUND(2.5, 2) = 2.5 + if roundResult, ok := result["roundResult"].(float64); ok { + assert.Equal(t, 2.5, roundResult, "ROUND函数嵌套计算错误") + } + + // 验证最外层嵌套普通函数的正确性 + assert.Equal(t, 100.0, result["countTimesAvg"], "COUNT(*) * AVG(value) 计算错误") + assert.Equal(t, 2.5, result["sumDivideMax"], "SUM(value) / MAX(value) 计算错误") + assert.Equal(t, 2600.0, result["complexNested"], "复杂嵌套表达式计算错误") + assert.Equal(t, 2.0, result["floorResult"], "FLOOR函数嵌套计算错误") + assert.Equal(t, 7.0, result["ceilResult"], "CEIL函数嵌套计算错误") + assert.Equal(t, 2.5, result["roundResult"], "ROUND函数嵌套计算错误") + }) + + t.Run("电表读数差值计算", func(t *testing.T) { + rsql := `SELECT deviceId, + (LAST_VALUE(displayNum) - FIRST_VALUE(displayNum)) as diffVal, + window_start() as start, + window_end() as end + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + // 设备1的数据 + {"deviceId": "meter001", "displayNum": 100.0, "ts": baseTime}, + {"deviceId": "meter001", "displayNum": 115.0, "ts": baseTime.Add(3 * time.Second)}, + + // 设备2的数据 + {"deviceId": "meter002", "displayNum": 200.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "meter002", "displayNum": 206.0, "ts": baseTime.Add(4 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.GreaterOrEqual(t, len(results), 1, "应该至少有一个窗口的结果") + + // 预期结果 + expectedDiffs := map[string]float64{ + "meter001": 15.0, // 115.0 - 100.0 = 15.0 kWh + "meter002": 6.0, // 206.0 - 200.0 = 6.0 kWh + } + + // 验证每个设备的计算结果 + deviceResults := make(map[string]map[string]interface{}) + for _, result := range results { + deviceId, ok := result["deviceId"].(string) + require.True(t, ok, "deviceId应该是字符串类型") + deviceResults[deviceId] = result + } + + for deviceId, expectedDiff := range expectedDiffs { + result, exists := deviceResults[deviceId] + assert.True(t, exists, "应该有设备 %s 的结果", deviceId) + + if exists { + diffVal, ok := result["diffVal"].(float64) + assert.True(t, ok, "diffVal应该是float64类型") + assert.InEpsilon(t, expectedDiff, diffVal, 0.001, + "设备 %s 的用电量计算应该正确: 期望 %.1f, 实际 %.1f", + deviceId, expectedDiff, diffVal) + + // 验证窗口时间字段存在 + assert.Contains(t, result, "start", "结果应包含窗口开始时间") + assert.Contains(t, result, "end", "结果应包含窗口结束时间") + } + } + + // 原始问题验证成功:电表读数差值计算正确 + }) + + t.Run("综合功能验证", func(t *testing.T) { + rsql := `SELECT deviceId, + SUM(value) as total, + AVG(value) as average, + FIRST_VALUE(value) as first, + LAST_VALUE(value) as last, + LATEST(value) as latest, + COUNT(*) as count, + MAX(value) as maxVal, + MIN(value) as minVal, + ((SUM(value) + FIRST_VALUE(value)) / COUNT(*)) as complexCalc1, + (LAST_VALUE(value) * AVG(value) - FIRST_VALUE(value)) as complexCalc2, + ((LATEST(value) + SUM(value)) / COUNT(*)) as complexCalc3, + (MAX(value) + MIN(value) - AVG(value)) as complexCalc4, + ROUND(SQRT(ABS(AVG(value) - MIN(value))), 2) as nestedMathFunc, + UPPER(CONCAT('RESULT_', CAST(ROUND(SUM(value), 0) as STRING))) as nestedStrMathFunc + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 10.0, "ts": baseTime}, + {"deviceId": "sensor1", "value": 20.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": 30.0, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 40.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证基础函数 + assert.Equal(t, "sensor1", result["deviceId"]) + assert.Equal(t, 100.0, result["total"]) // 10+20+30+40 + assert.Equal(t, 25.0, result["average"]) // 100/4 + assert.Equal(t, 10.0, result["first"]) // 第一个值 + assert.Equal(t, 40.0, result["last"]) // 最后一个值 + assert.Equal(t, 40.0, result["latest"]) // 最新值 + assert.Equal(t, 4.0, result["count"]) // 4条记录 + assert.Equal(t, 40.0, result["maxVal"]) // 最大值 + assert.Equal(t, 10.0, result["minVal"]) // 最小值 + + // 验证复杂表达式计算 + assert.Equal(t, 27.5, result["complexCalc1"]) // (100 + 10) / 4 = 27.5 + assert.Equal(t, 990.0, result["complexCalc2"]) // 40 * 25 - 10 = 990 + assert.Equal(t, 35.0, result["complexCalc3"]) // (40 + 100) / 4 = 35 + assert.Equal(t, 25.0, result["complexCalc4"]) // 40 + 10 - 25 = 25 + + // 验证多层嵌套数学函数 + // ROUND(SQRT(ABS(AVG(value) - MIN(value))), 2) = ROUND(SQRT(ABS(25-10)), 2) = ROUND(SQRT(15), 2) ≈ 3.87 + if nestedMathFunc, ok := result["nestedMathFunc"].(float64); ok { + assert.InEpsilon(t, 3.87, nestedMathFunc, 0.01, "多层嵌套数学函数计算错误") + } + + // 验证多层嵌套字符串和数学函数 + // UPPER(CONCAT('RESULT_', CAST(ROUND(SUM(value), 0) as STRING))) = UPPER(CONCAT('RESULT_', '100')) = 'RESULT_100' + if nestedStrMathFunc, ok := result["nestedStrMathFunc"].(string); ok { + assert.Equal(t, "RESULT_100", nestedStrMathFunc, "多层嵌套字符串和数学函数计算错误") + } + }) + + t.Run("嵌套聚合函数运算测试", func(t *testing.T) { + rsql := `SELECT deviceId, + SUM(value) as total, + AVG(value) as average, + COUNT(*) as count, + MAX(value) as maxVal, + MIN(value) as minVal, + ROUND(AVG(ABS(value)), 2) as avgAbs, + MAX(ROUND(value, 1)) as maxRounded, + MIN(CEIL(value / 10)) as minCeiled, + AVG(SQRT(value)) as avgSqrt, + SUM(POWER(value, 2)) as sumSquares, + CEIL(AVG(FLOOR(SQRT(value)))) as tripleNested2, + ABS(MIN(ROUND(value / 5, 2))) as tripleNested3 + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 16.0, "ts": baseTime}, + {"deviceId": "sensor1", "value": 25.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": 36.0, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 49.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证基础聚合函数 + assert.Equal(t, "sensor1", result["deviceId"]) + assert.Equal(t, 126.0, result["total"]) // 16+25+36+49 + assert.Equal(t, 31.5, result["average"]) // 126/4 + assert.Equal(t, 4.0, result["count"]) // 4 records + assert.Equal(t, 49.0, result["maxVal"]) // max value + assert.Equal(t, 16.0, result["minVal"]) // min value + + // 验证嵌套聚合函数运算 + // ROUND(AVG(ABS(value)), 2) = ROUND(AVG(16,25,36,49), 2) = ROUND(31.5, 2) = 31.5 + if avgAbs, ok := result["avgAbs"].(float64); ok { + assert.Equal(t, 31.5, avgAbs, "AVG(ABS(value))计算错误") + } + + // MAX(ROUND(value, 1)) = MAX(16.0, 25.0, 36.0, 49.0) = 49.0 + if maxRounded, ok := result["maxRounded"].(float64); ok { + assert.Equal(t, 49.0, maxRounded, "MAX(ROUND(value, 1))计算错误") + } + + // MIN(CEIL(value / 10)) = MIN(CEIL(1.6), CEIL(2.5), CEIL(3.6), CEIL(4.9)) = MIN(2, 3, 4, 5) = 2 + if minCeiled, ok := result["minCeiled"].(float64); ok { + assert.Equal(t, 2.0, minCeiled, "MIN(CEIL(value / 10))计算错误") + } + + // AVG(SQRT(value)) = AVG(SQRT(16), SQRT(25), SQRT(36), SQRT(49)) = AVG(4, 5, 6, 7) = 5.5 + if avgSqrt, ok := result["avgSqrt"].(float64); ok { + assert.Equal(t, 5.5, avgSqrt, "AVG(SQRT(value))计算错误") + } + + // SUM(POWER(value, 2)) = SUM(16^2, 25^2, 36^2, 49^2) = SUM(256, 625, 1296, 2401) = 4578 + if sumSquares, ok := result["sumSquares"].(float64); ok { + assert.Equal(t, 4578.0, sumSquares, "SUM(POWER(value, 2))计算错误") + } + + // CEIL(AVG(FLOOR(SQRT(value)))) + // = CEIL(AVG(FLOOR(4), FLOOR(5), FLOOR(6), FLOOR(7))) = CEIL(AVG(4, 5, 6, 7)) = CEIL(5.5) = 6 + if tripleNested2, ok := result["tripleNested2"].(float64); ok { + assert.Equal(t, 6.0, tripleNested2, "三层嵌套聚合2计算错误") + } + + // ABS(MIN(ROUND(value / 5, 2))) + // = ABS(MIN(ROUND(3.2, 2), ROUND(5, 2), ROUND(7.2, 2), ROUND(9.8, 2))) + // = ABS(MIN(3.2, 5.0, 7.2, 9.8)) = ABS(3.2) = 3.2 + if tripleNested3, ok := result["tripleNested3"].(float64); ok { + assert.Equal(t, 3.2, tripleNested3, "三层嵌套聚合3计算错误") + } + }) + + t.Run("验证:NTH_VALUE和LEAD函数", func(t *testing.T) { + rsql := `SELECT deviceId, + SUM(value) as total, + COUNT(*) as count, + NTH_VALUE(value, 2) as secondValue, + LEAD(value, 1) as leadValue, + (COUNT(*) * NTH_VALUE(value, 2)) as countTimesSecond, + (SUM(value) + LEAD(value, 1)) as sumPlusLead + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 10.0, "ts": baseTime}, + {"deviceId": "sensor1", "value": 20.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": 30.0, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 40.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证基础函数 + assert.Equal(t, "sensor1", result["deviceId"]) + assert.Equal(t, 100.0, result["total"]) // 10+20+30+40 + assert.Equal(t, 4.0, result["count"]) // 4 records + + // 验证窗口函数基础功能 + assert.NotNil(t, result["countTimesSecond"], "COUNT(*) * NTH_VALUE(value, 2) 应该有计算结果") + + }) + + t.Run("验证:NTH_VALUE基础功能", func(t *testing.T) { + rsql := `SELECT deviceId, + NTH_VALUE(value, 1) as firstValue, + NTH_VALUE(value, 2) as secondValue, + NTH_VALUE(value, 3) as thirdValue, + NTH_VALUE(value, 4) as fourthValue + FROM stream + GROUP BY deviceId, TumblingWindow('5s') + WITH (TIMESTAMP='ts', TIMEUNIT='ss')` + + ssql, resultChan := createTestEnvironment(t, rsql) + + baseTime := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + testData := []map[string]interface{}{ + {"deviceId": "sensor1", "value": 100.0, "ts": baseTime}, + {"deviceId": "sensor1", "value": 200.0, "ts": baseTime.Add(1 * time.Second)}, + {"deviceId": "sensor1", "value": 300.0, "ts": baseTime.Add(2 * time.Second)}, + {"deviceId": "sensor1", "value": 400.0, "ts": baseTime.Add(3 * time.Second)}, + } + + results := sendDataAndCollectResults(t, ssql, resultChan, testData, 5) + require.Len(t, results, 1) + result := results[0] + + // 验证 NTH_VALUE 函数的返回值 + // 期望结果:按添加顺序 + // 第1个值: 100, 第2个值: 200, 第3个值: 300, 第4个值: 400 + if firstValue, ok := result["firstValue"].(float64); ok { + assert.Equal(t, 100.0, firstValue, "第1个值应该是100") + } else { + assert.Error(t, errors.New("firstValue 为空")) + } + + if secondValue, ok := result["secondValue"].(float64); ok { + assert.Equal(t, 200.0, secondValue, "第2个值应该是200") + } else { + assert.Error(t, errors.New("secondValue 为空")) + } + + if thirdValue, ok := result["thirdValue"].(float64); ok { + assert.Equal(t, 300.0, thirdValue, "第3个值应该是300") + } else { + assert.Error(t, errors.New("thirdValue 为空")) + } + + if fourthValue, ok := result["fourthValue"].(float64); ok { + assert.Equal(t, 400.0, fourthValue, "第4个值应该是400") + } else { + assert.Error(t, errors.New("fourthValue 为空")) + } + }) +} diff --git a/types/config.go b/types/config.go index e15c493..7a15d90 100644 --- a/types/config.go +++ b/types/config.go @@ -9,15 +9,16 @@ import ( // Config stream processing configuration type Config struct { // SQL processing related configuration - WindowConfig WindowConfig `json:"windowConfig"` - GroupFields []string `json:"groupFields"` - SelectFields map[string]aggregator.AggregateType `json:"selectFields"` - FieldAlias map[string]string `json:"fieldAlias"` - SimpleFields []string `json:"simpleFields"` - FieldExpressions map[string]FieldExpression `json:"fieldExpressions"` - FieldOrder []string `json:"fieldOrder"` // Original order of fields in SELECT statement - Where string `json:"where"` - Having string `json:"having"` + WindowConfig WindowConfig `json:"windowConfig"` + GroupFields []string `json:"groupFields"` + SelectFields map[string]aggregator.AggregateType `json:"selectFields"` + FieldAlias map[string]string `json:"fieldAlias"` + SimpleFields []string `json:"simpleFields"` + FieldExpressions map[string]FieldExpression `json:"fieldExpressions"` + PostAggExpressions []PostAggregationExpression `json:"postAggExpressions"` // Post-aggregation expressions + FieldOrder []string `json:"fieldOrder"` // Original order of fields in SELECT statement + Where string `json:"where"` + Having string `json:"having"` // Feature switches NeedWindow bool `json:"needWindow"` @@ -47,6 +48,23 @@ type FieldExpression struct { Fields []string `json:"fields"` // all fields referenced in expression } +// PostAggregationExpression represents an expression that needs to be evaluated after aggregation +type PostAggregationExpression struct { + OutputField string `json:"outputField"` // 输出字段名 + OriginalExpr string `json:"originalExpr"` // 原始表达式 + ExpressionTemplate string `json:"expressionTemplate"` // 表达式模板 + RequiredFields []AggregationFieldInfo `json:"requiredFields"` // 依赖的聚合字段 +} + +// AggregationFieldInfo holds information about an aggregation function in an expression +type AggregationFieldInfo struct { + FuncName string `json:"funcName"` // 函数名,如 "first_value" + InputField string `json:"inputField"` // 输入字段,如 "displayNum" + Placeholder string `json:"placeholder"` // 占位符,如 "__first_value_0__" + AggType aggregator.AggregateType `json:"aggType"` // 聚合类型 + FullCall string `json:"fullCall"` // 完整函数调用,如 "NTH_VALUE(value, 2)" +} + // ProjectionSourceType projection source type type ProjectionSourceType int