Merge pull request #10 from rulego/dev

fix:修复嵌套函数错误
This commit is contained in:
Whki
2025-05-28 19:21:50 +08:00
committed by GitHub
7 changed files with 378 additions and 50 deletions
+87 -1
View File
@@ -2,12 +2,12 @@ package functions
import (
"fmt"
"github.com/rulego/streamsql/utils/cast"
"strconv"
"strings"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
"github.com/rulego/streamsql/utils/cast"
)
// ExprBridge 桥接 StreamSQL 函数系统与 expr-lang/expr
@@ -109,6 +109,11 @@ func (bridge *ExprBridge) CompileExpressionWithStreamSQLFunctions(expression str
// EvaluateExpression 评估表达式,自动选择最合适的引擎
func (bridge *ExprBridge) EvaluateExpression(expression string, data map[string]interface{}) (interface{}, error) {
// 首先检查是否是CONCAT函数调用
if strings.HasPrefix(strings.ToUpper(expression), "CONCAT(") {
return bridge.evaluateConcatFunction(expression, data)
}
// 首先检查是否包含字符串拼接模式
if bridge.isStringConcatenationExpression(expression, data) {
result, err := bridge.evaluateStringConcatenation(expression, data)
@@ -177,6 +182,87 @@ func (bridge *ExprBridge) fallbackToCustomExpr(expression string, data map[strin
return nil, fmt.Errorf("unable to evaluate expression: %s, string concat error: %v, numeric error: %v", expression, err, err)
}
// evaluateConcatFunction 处理CONCAT函数调用
func (bridge *ExprBridge) evaluateConcatFunction(expression string, data map[string]interface{}) (interface{}, error) {
// 提取CONCAT函数的参数
start := strings.Index(expression, "(")
end := strings.LastIndex(expression, ")")
if start == -1 || end == -1 || end <= start {
return nil, fmt.Errorf("invalid CONCAT function syntax: %s", expression)
}
// 获取参数字符串
paramsStr := strings.TrimSpace(expression[start+1 : end])
if paramsStr == "" {
return "", nil // 空参数返回空字符串
}
// 解析参数
params := bridge.parseParameters(paramsStr)
var result strings.Builder
for _, param := range params {
param = strings.TrimSpace(param)
// 处理字符串字面量
if (strings.HasPrefix(param, "'") && strings.HasSuffix(param, "'")) ||
(strings.HasPrefix(param, "\"") && strings.HasSuffix(param, "\"")) {
// 去掉引号
literal := param[1 : len(param)-1]
result.WriteString(literal)
} else {
// 处理字段引用
if value, exists := data[param]; exists {
strValue := cast.ToString(value)
result.WriteString(strValue)
} else {
return nil, fmt.Errorf("field %s not found in data", param)
}
}
}
return result.String(), nil
}
// parseParameters 解析函数参数,正确处理引号内的逗号
func (bridge *ExprBridge) parseParameters(paramsStr string) []string {
var params []string
var current strings.Builder
inQuotes := false
quoteChar := byte(0)
for i := 0; i < len(paramsStr); i++ {
ch := paramsStr[i]
if !inQuotes {
if ch == '\'' || ch == '"' {
inQuotes = true
quoteChar = ch
current.WriteByte(ch)
} else if ch == ',' {
// 参数分隔符
params = append(params, current.String())
current.Reset()
} else {
current.WriteByte(ch)
}
} else {
if ch == quoteChar {
inQuotes = false
quoteChar = 0
}
current.WriteByte(ch)
}
}
// 添加最后一个参数
if current.Len() > 0 {
params = append(params, current.String())
}
return params
}
// evaluateStringConcatenation 处理字符串拼接表达式
func (bridge *ExprBridge) evaluateStringConcatenation(expression string, data map[string]interface{}) (interface{}, error) {
// 检查是否是字符串拼接表达式 (包含 + 和字符串字面量)
+105
View File
@@ -425,11 +425,13 @@ func (f *PercentileFunction) Execute(ctx *FunctionContext, args []interface{}) (
// CollectFunction 收集函数 - 获取当前窗口所有消息的列值组成的数组
type CollectFunction struct {
*BaseFunction
values []interface{}
}
func NewCollectFunction() *CollectFunction {
return &CollectFunction{
BaseFunction: NewBaseFunction("collect", TypeAggregation, "聚合函数", "收集所有值组成数组", 1, -1),
values: make([]interface{}, 0),
}
}
@@ -444,14 +446,47 @@ func (f *CollectFunction) Execute(ctx *FunctionContext, args []interface{}) (int
return result, nil
}
// 实现AggregatorFunction接口
func (f *CollectFunction) New() AggregatorFunction {
return &CollectFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, 0),
}
}
func (f *CollectFunction) Add(value interface{}) {
f.values = append(f.values, value)
}
func (f *CollectFunction) Result() interface{} {
result := make([]interface{}, len(f.values))
copy(result, f.values)
return result
}
func (f *CollectFunction) Reset() {
f.values = make([]interface{}, 0)
}
func (f *CollectFunction) Clone() AggregatorFunction {
newFunc := &CollectFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, len(f.values)),
}
copy(newFunc.values, f.values)
return newFunc
}
// LastValueFunction 最后值函数 - 返回组中最后一行的值
type LastValueFunction struct {
*BaseFunction
lastValue interface{}
}
func NewLastValueFunction() *LastValueFunction {
return &LastValueFunction{
BaseFunction: NewBaseFunction("last_value", TypeAggregation, "聚合函数", "返回最后一个值", 1, -1),
lastValue: nil,
}
}
@@ -467,14 +502,43 @@ func (f *LastValueFunction) Execute(ctx *FunctionContext, args []interface{}) (i
return args[len(args)-1], nil
}
// 实现AggregatorFunction接口
func (f *LastValueFunction) New() AggregatorFunction {
return &LastValueFunction{
BaseFunction: f.BaseFunction,
lastValue: nil,
}
}
func (f *LastValueFunction) Add(value interface{}) {
f.lastValue = value
}
func (f *LastValueFunction) Result() interface{} {
return f.lastValue
}
func (f *LastValueFunction) Reset() {
f.lastValue = nil
}
func (f *LastValueFunction) Clone() AggregatorFunction {
return &LastValueFunction{
BaseFunction: f.BaseFunction,
lastValue: f.lastValue,
}
}
// MergeAggFunction 合并聚合函数 - 将组中的值合并为单个值
type MergeAggFunction struct {
*BaseFunction
values []interface{}
}
func NewMergeAggFunction() *MergeAggFunction {
return &MergeAggFunction{
BaseFunction: NewBaseFunction("merge_agg", TypeAggregation, "聚合函数", "合并所有值", 1, -1),
values: make([]interface{}, 0),
}
}
@@ -498,6 +562,47 @@ func (f *MergeAggFunction) Execute(ctx *FunctionContext, args []interface{}) (in
return result.String(), nil
}
// 实现AggregatorFunction接口
func (f *MergeAggFunction) New() AggregatorFunction {
return &MergeAggFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, 0),
}
}
func (f *MergeAggFunction) Add(value interface{}) {
f.values = append(f.values, value)
}
func (f *MergeAggFunction) Result() interface{} {
if len(f.values) == 0 {
return nil
}
// 尝试合并为字符串
var result strings.Builder
for i, arg := range f.values {
if i > 0 {
result.WriteString(",")
}
result.WriteString(cast.ToString(arg))
}
return result.String()
}
func (f *MergeAggFunction) Reset() {
f.values = make([]interface{}, 0)
}
func (f *MergeAggFunction) Clone() AggregatorFunction {
newFunc := &MergeAggFunction{
BaseFunction: f.BaseFunction,
values: make([]interface{}, len(f.values)),
}
copy(newFunc.values, f.values)
return newFunc
}
// StdDevSFunction 样本标准差函数
type StdDevSFunction struct {
*BaseFunction
+4
View File
@@ -220,10 +220,14 @@ func (f *ExpressionAggregatorFunction) New() AggregatorFunction {
func (f *ExpressionAggregatorFunction) Add(value interface{}) {
// 对于表达式聚合器,保存最后一个计算结果
// 表达式的计算结果应该是每个数据项的计算结果
f.lastResult = value
}
func (f *ExpressionAggregatorFunction) Result() interface{} {
// 对于表达式聚合器,返回最后一个计算结果
// 注意:对于字符串函数如CONCAT,每个数据项都会产生一个结果
// 在窗口聚合中,我们返回最后一个计算的结果
return f.lastResult
}
+67 -5
View File
@@ -325,6 +325,29 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s
// 对于复杂表达式,包括多参数函数调用
expression = fieldExpr
// 对于CONCAT等字符串函数,直接保存完整表达式
if strings.ToLower(funcName) == "concat" {
// 智能解析CONCAT函数的参数来提取字段名
var fields []string
params := parseSmartParameters(fieldExpr)
for _, param := range params {
param = strings.TrimSpace(param)
// 如果参数不是字符串常量(不被引号包围),则认为是字段名
if !((strings.HasPrefix(param, "'") && strings.HasSuffix(param, "'")) ||
(strings.HasPrefix(param, "\"") && strings.HasSuffix(param, "\""))) {
if isIdentifier(param) {
fields = append(fields, param)
}
}
}
if len(fields) > 0 {
// 对于CONCAT函数,保存完整的函数调用作为表达式
return fields[0], funcName + "(" + fieldExpr + ")", fields
}
// 如果没有找到字段,返回空字段名但保留表达式
return "", funcName + "(" + fieldExpr + ")", nil
}
// 使用表达式引擎解析
parsedExpr, err := expr.NewExpression(fieldExpr)
if err != nil {
@@ -370,23 +393,62 @@ func extractAggFieldWithExpression(exprStr string, funcName string) (fieldName s
return fieldExpr, expression, nil
}
// parseSmartParameters 智能解析函数参数,正确处理引号内的逗号
func parseSmartParameters(paramsStr string) []string {
var params []string
var current strings.Builder
inQuotes := false
quoteChar := byte(0)
for i := 0; i < len(paramsStr); i++ {
ch := paramsStr[i]
if !inQuotes {
if ch == '\'' || ch == '"' {
inQuotes = true
quoteChar = ch
current.WriteByte(ch)
} else if ch == ',' {
// 参数分隔符
params = append(params, current.String())
current.Reset()
} else {
current.WriteByte(ch)
}
} else {
if ch == quoteChar {
inQuotes = false
quoteChar = 0
}
current.WriteByte(ch)
}
}
// 添加最后一个参数
if current.Len() > 0 {
params = append(params, current.String())
}
return params
}
// isIdentifier 检查字符串是否是有效的标识符
func isIdentifier(s string) bool {
if len(s) == 0 {
return false
}
// 第一个字符必须是字母或下划线
if !((s[0] >= 'a' && s[0] <= 'z') || (s[0] >= 'A' && s[0] <= 'Z') || s[0] == '_') {
return false
}
// 其余字符必须是字母、数字或下划线
for i := 1; i < len(s); i++ {
if !((s[i] >= 'a' && s[i] <= 'z') || (s[i] >= 'A' && s[i] <= 'Z') ||
(s[i] >= '0' && s[i] <= '9') || s[i] == '_') {
char := s[i]
if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
(char >= '0' && char <= '9') || char == '_') {
return false
}
}
return true
}
+19 -10
View File
@@ -90,6 +90,11 @@ func (l *Lexer) NextToken() Token {
l.readChar()
return Token{Type: TokenSlash, Value: "/"}
case '=':
if l.peekChar() == '=' {
l.readChar()
l.readChar()
return Token{Type: TokenEQ, Value: "=="}
}
l.readChar()
return Token{Type: TokenEQ, Value: "="}
case '>':
@@ -114,6 +119,10 @@ func (l *Lexer) NextToken() Token {
l.readChar()
return Token{Type: TokenNE, Value: "!="}
}
case '\'':
return Token{Type: TokenString, Value: l.readString()}
case '"':
return Token{Type: TokenString, Value: l.readString()}
}
if isLetter(l.ch) {
@@ -125,10 +134,6 @@ func (l *Lexer) NextToken() Token {
return Token{Type: TokenNumber, Value: l.readNumber()}
}
if l.ch == '\'' {
return Token{Type: TokenString, Value: l.readString()}
}
l.readChar()
return Token{Type: TokenEOF}
}
@@ -188,16 +193,20 @@ func (l *Lexer) readNumber() string {
}
func (l *Lexer) readString() string {
l.readChar() // 跳过开头单引号
pos := l.pos
quoteChar := l.ch // 记录引号类型(单引号或双引号)
startPos := l.pos // 记录开始位置(包含引号)
l.readChar() // 跳过开头引号
for l.ch != '\'' && l.ch != 0 {
for l.ch != quoteChar && l.ch != 0 {
l.readChar()
}
str := l.input[pos:l.pos]
l.readChar() // 跳过结尾引号
return str
if l.ch == quoteChar {
l.readChar() // 跳过结尾引号
}
// 返回包含引号的完整字符串
return l.input[startPos:l.pos]
}
func (l *Lexer) skipWhitespace() {
+12 -4
View File
@@ -171,9 +171,13 @@ func (p *Parser) parseWhere(stmt *SelectStatement) error {
case TokenIdent, TokenNumber:
conditions = append(conditions, tok.Value)
case TokenString:
conditions = append(conditions, "'"+tok.Value+"'")
conditions = append(conditions, tok.Value)
case TokenEQ:
conditions = append(conditions, "==")
if tok.Value == "=" {
conditions = append(conditions, "==")
} else {
conditions = append(conditions, tok.Value)
}
case TokenAND:
conditions = append(conditions, "&&")
case TokenOR:
@@ -431,9 +435,13 @@ func (p *Parser) parseHaving(stmt *SelectStatement) error {
case TokenIdent, TokenNumber:
conditions = append(conditions, tok.Value)
case TokenString:
conditions = append(conditions, "'"+tok.Value+"'")
conditions = append(conditions, tok.Value)
case TokenEQ:
conditions = append(conditions, "==")
if tok.Value == "=" {
conditions = append(conditions, "==")
} else {
conditions = append(conditions, tok.Value)
}
case TokenAND:
conditions = append(conditions, "&&")
case TokenOR:
+84 -30
View File
@@ -23,6 +23,7 @@ func TestStreamProcess(t *testing.T) {
"temperature": aggregator.Avg,
"humidity": aggregator.Sum,
},
NeedWindow: true,
}
strm, err := NewStream(config)
@@ -64,16 +65,41 @@ func TestStreamProcess(t *testing.T) {
// 预期结果:只有 device='aa' 且 temperature>10 的数据会被聚合
expected := map[string]interface{}{
"device": "aa",
"temperature_avg": 27.5, // (25+30)/2
"humidity_sum": 115.0, // 60+55
"device": "aa",
"temperature": 27.5, // (25+30)/2
"humidity": 115.0, // 60+55
}
// 验证结果
t.Logf("Received result: %+v (type: %T)", actual, actual)
if actual == nil {
t.Fatal("Received nil result")
}
assert.IsType(t, []map[string]interface{}{}, actual)
t.Logf("Type assertion successful")
resultMap := actual.([]map[string]interface{})
assert.InEpsilon(t, expected["temperature_avg"].(float64), resultMap[0]["temperature_avg"].(float64), 0.0001)
assert.InDelta(t, expected["humidity_sum"].(float64), resultMap[0]["humidity_sum"].(float64), 0.0001)
t.Logf("Result map length: %d", len(resultMap))
if len(resultMap) > 0 {
t.Logf("First result: %+v", resultMap[0])
// 检查temperature字段
if tempAvg, ok := resultMap[0]["temperature"]; ok {
t.Logf("temperature: %+v (type: %T)", tempAvg, tempAvg)
assert.InEpsilon(t, expected["temperature"].(float64), tempAvg.(float64), 0.0001)
} else {
t.Fatal("temperature field not found in result")
}
// 检查humidity字段
if humSum, ok := resultMap[0]["humidity"]; ok {
t.Logf("humidity: %+v (type: %T)", humSum, humSum)
assert.InDelta(t, expected["humidity"].(float64), humSum.(float64), 0.0001)
} else {
t.Fatal("humidity field not found in result")
}
} else {
t.Fatal("No results in result map")
}
}
// 不设置过滤器
@@ -88,6 +114,7 @@ func TestStreamWithoutFilter(t *testing.T) {
"temperature": aggregator.Max,
"humidity": aggregator.Min,
},
NeedWindow: true,
}
strm, err := NewStream(config)
@@ -126,14 +153,14 @@ func TestStreamWithoutFilter(t *testing.T) {
expected := []map[string]interface{}{
{
"device": "aa",
"temperature_max": 30.0,
"humidity_min": 55.0,
"device": "aa",
"temperature": 30.0,
"humidity": 55.0,
},
{
"device": "bb",
"temperature_max": 22.0,
"humidity_min": 70.0,
"device": "bb",
"temperature": 22.0,
"humidity": 70.0,
},
}
@@ -146,8 +173,8 @@ func TestStreamWithoutFilter(t *testing.T) {
found := false
for _, resultMap := range resultSlice {
if resultMap["device"] == expectedResult["device"] {
assert.InEpsilon(t, expectedResult["temperature_max"].(float64), resultMap["temperature_max"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["humidity_min"].(float64), resultMap["humidity_min"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["temperature"].(float64), resultMap["temperature"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["humidity"].(float64), resultMap["humidity"].(float64), 0.0001)
found = true
break
}
@@ -167,6 +194,7 @@ func TestIncompleteStreamProcess(t *testing.T) {
"temperature": aggregator.Avg,
"humidity": aggregator.Sum,
},
NeedWindow: true,
}
strm, err := NewStream(config)
@@ -210,16 +238,41 @@ func TestIncompleteStreamProcess(t *testing.T) {
// 预期结果:只有 device='aa' 且 temperature>10 的数据会被聚合
expected := map[string]interface{}{
"device": "aa",
"temperature_avg": 27.5, // (25+30)/2
"humidity_sum": 115.0, // 60+55
"device": "aa",
"temperature": 27.5, // (25+30)/2
"humidity": 115.0, // 60+55
}
// 验证结果
t.Logf("Received result: %+v (type: %T)", actual, actual)
if actual == nil {
t.Fatal("Received nil result")
}
assert.IsType(t, []map[string]interface{}{}, actual)
t.Logf("Type assertion successful")
resultMap := actual.([]map[string]interface{})
assert.InEpsilon(t, expected["temperature_avg"].(float64), resultMap[0]["temperature_avg"].(float64), 0.0001)
assert.InDelta(t, expected["humidity_sum"].(float64), resultMap[0]["humidity_sum"].(float64), 0.0001)
t.Logf("Result map length: %d", len(resultMap))
if len(resultMap) > 0 {
t.Logf("First result: %+v", resultMap[0])
// 检查temperature字段
if tempAvg, ok := resultMap[0]["temperature"]; ok {
t.Logf("temperature: %+v (type: %T)", tempAvg, tempAvg)
assert.InEpsilon(t, expected["temperature"].(float64), tempAvg.(float64), 0.0001)
} else {
t.Fatal("temperature field not found in result")
}
// 检查humidity字段
if humSum, ok := resultMap[0]["humidity"]; ok {
t.Logf("humidity: %+v (type: %T)", humSum, humSum)
assert.InDelta(t, expected["humidity"].(float64), humSum.(float64), 0.0001)
} else {
t.Fatal("humidity field not found in result")
}
} else {
t.Fatal("No results in result map")
}
}
func TestWindowSlotAgg(t *testing.T) {
@@ -236,6 +289,7 @@ func TestWindowSlotAgg(t *testing.T) {
"start": aggregator.WindowStart,
"end": aggregator.WindowEnd,
},
NeedWindow: true,
}
strm, err := NewStream(config)
@@ -276,18 +330,18 @@ func TestWindowSlotAgg(t *testing.T) {
expected := []map[string]interface{}{
{
"device": "aa",
"temperature_max": 30.0,
"humidity_min": 55.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
"device": "aa",
"temperature": 30.0,
"humidity": 55.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
{
"device": "bb",
"temperature_max": 22.0,
"humidity_min": 70.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
"device": "bb",
"temperature": 22.0,
"humidity": 70.0,
"start": baseTime.UnixNano(),
"end": baseTime.Add(2 * time.Second).UnixNano(),
},
}
@@ -300,8 +354,8 @@ func TestWindowSlotAgg(t *testing.T) {
found := false
for _, resultMap := range resultSlice {
if resultMap["device"] == expectedResult["device"] {
assert.InEpsilon(t, expectedResult["temperature_max"].(float64), resultMap["temperature_max"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["humidity_min"].(float64), resultMap["humidity_min"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["temperature"].(float64), resultMap["temperature"].(float64), 0.0001)
assert.InEpsilon(t, expectedResult["humidity"].(float64), resultMap["humidity"].(float64), 0.0001)
assert.Equal(t, expectedResult["start"].(int64), resultMap["start"].(int64))
assert.Equal(t, expectedResult["end"].(int64), resultMap["end"].(int64))
found = true