Files
streamsql/expr/expression.go
T
2025-05-25 18:02:37 +08:00

732 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package expr
import (
"fmt"
"math"
"strconv"
"strings"
"github.com/rulego/streamsql/functions"
)
// Expression node types. These are the values stored in ExprNode.Type.
const (
TypeNumber = "number" // numeric constant
TypeField = "field" // field reference
TypeOperator = "operator" // operator
TypeFunction = "function" // function call
TypeParenthesis = "parenthesis" // parenthesis
)
// operatorPrecedence maps every supported binary operator to its precedence.
// A larger value binds more tightly. Membership in this map is also what
// isOperator uses to decide whether a token is an operator.
var operatorPrecedence = map[string]int{
"+": 1,
"-": 1,
"*": 2,
"/": 2,
"%": 2,
"^": 3, // exponentiation
}
// ExprNode is a single node of a parsed expression tree.
type ExprNode struct {
// Type is one of the Type* constants and selects how the node is evaluated.
Type string
// Value is the token text of the node: the numeric literal, the field name,
// the operator symbol or the function name.
Value string
// Left and Right are the operands of an operator node.
Left *ExprNode
Right *ExprNode
Args []*ExprNode // arguments of a function call
}
// Expression represents an expression that can be evaluated against a data row.
type Expression struct {
// Root is the parsed expression tree; it is nil when useExprLang is true.
Root *ExprNode
useExprLang bool // whether evaluation is delegated to expr-lang/expr
exprLangExpression string // the raw expression string handed to expr-lang
}
// NewExpression builds an Expression from exprStr.
//
// The built-in tokenizer and parser are tried first. When either of them
// rejects the input, the raw string is kept and evaluation is delegated to
// expr-lang/expr instead, so this function currently never returns a non-nil
// error.
func NewExpression(exprStr string) (*Expression, error) {
	if tokens, tokErr := tokenize(exprStr); tokErr == nil {
		if root, parseErr := parseExpression(tokens); parseErr == nil {
			return &Expression{Root: root}, nil
		}
	}

	// The custom parser could not handle the input: fall back to expr-lang.
	return &Expression{
		useExprLang:        true,
		exprLangExpression: exprStr,
	}, nil
}
// Evaluate computes the numeric value of the expression for the given data.
// Expressions handled by the built-in parser are evaluated by walking the
// tree; all others are delegated to expr-lang/expr.
func (e *Expression) Evaluate(data map[string]interface{}) (float64, error) {
	if !e.useExprLang {
		return evaluateNode(e.Root, data)
	}
	return e.evaluateWithExprLang(data)
}
// evaluateWithExprLang evaluates the stored raw expression through the
// expr-lang bridge and converts the result to float64.
//
// float64, float32, int, int32 and int64 results are converted directly and a
// string result is parsed as a float. Any other result type is an error.
func (e *Expression) evaluateWithExprLang(data map[string]interface{}) (float64, error) {
	exprBridge := functions.GetExprBridge()
	result, err := exprBridge.EvaluateExpression(e.exprLangExpression, data)
	if err != nil {
		return 0, err
	}

	switch v := result.(type) {
	case float64:
		return v, nil
	case float32:
		return float64(v), nil
	case int:
		return float64(v), nil
	case int32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	case string:
		parsed, convErr := strconv.ParseFloat(v, 64)
		if convErr != nil {
			return 0, fmt.Errorf("cannot convert string result '%s' to float64", v)
		}
		return parsed, nil
	}
	return 0, fmt.Errorf("expression result type %T is not convertible to float64", result)
}
// GetFields returns the names of all fields referenced by the expression.
// Each name appears once; the order of the result is unspecified for
// expressions handled by the built-in parser.
func (e *Expression) GetFields() []string {
	if e.useExprLang {
		// Simplified handling: the raw string is scanned for identifiers
		// rather than analysed through an AST.
		return extractFieldsFromExprLang(e.exprLangExpression)
	}

	seen := map[string]bool{}
	collectFields(e.Root, seen)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	return names
}
// extractFieldsFromExprLang extracts the field names referenced by an
// expr-lang expression.
//
// This is a simplified, lexical implementation (a real one would walk the
// expr-lang AST): the expression is split into identifier-like tokens and
// every token that is an identifier, is not a number and is not a known
// function or keyword is reported as a field.
//
// The contents of single- and double-quoted string literals are ignored, so
// that `status == "active"` reports only status. Each field is returned once,
// in order of first appearance.
func extractFieldsFromExprLang(expression string) []string {
	isSeparator := func(c rune) bool {
		return !(c >= 'a' && c <= 'z') && !(c >= 'A' && c <= 'Z') && !(c >= '0' && c <= '9') && c != '_'
	}

	seen := make(map[string]bool)
	result := make([]string, 0)
	for _, token := range strings.FieldsFunc(blankStringLiterals(expression), isSeparator) {
		if seen[token] {
			continue
		}
		if isIdentifier(token) && !isNumber(token) && !isFunctionOrKeyword(token) {
			seen[token] = true
			result = append(result, token)
		}
	}
	return result
}

// blankStringLiterals returns expression with every byte of each '...' and
// "..." literal (quotes included) replaced by a space, so that words inside
// literals are not mistaken for identifiers. A backslash inside a literal
// escapes the following byte. An unterminated literal extends to the end of
// the input.
func blankStringLiterals(expression string) string {
	var b strings.Builder
	b.Grow(len(expression))

	var quote byte // the opening quote while inside a literal, 0 otherwise
	for i := 0; i < len(expression); i++ {
		ch := expression[i]
		switch {
		case quote == 0:
			if ch == '"' || ch == '\'' {
				quote = ch
				b.WriteByte(' ')
			} else {
				b.WriteByte(ch)
			}
		case ch == '\\' && i+1 < len(expression):
			// Skip the escaped byte so an escaped quote does not end the literal.
			b.WriteString("  ")
			i++
		case ch == quote:
			quote = 0
			b.WriteByte(' ')
		default:
			b.WriteByte(' ')
		}
	}
	return b.String()
}
// isFunctionOrKeyword reports whether token is a reserved word of the
// expression language (compared case-insensitively) or the name of a function
// that the expr bridge can resolve.
func isFunctionOrKeyword(token string) bool {
	switch strings.ToLower(token) {
	case "and", "or", "not", "true", "false", "nil", "null",
		"if", "else", "then", "in", "contains", "matches":
		return true
	}

	// Not a keyword: check whether it is a registered function.
	exprBridge := functions.GetExprBridge()
	_, registered, _ := exprBridge.ResolveFunction(token)
	return registered
}
// collectFields records in fields the name of every field node found in the
// tree rooted at node, following Left, Right and Args. A nil node is ignored.
func collectFields(node *ExprNode, fields map[string]bool) {
	pending := []*ExprNode{node}
	for len(pending) > 0 {
		last := len(pending) - 1
		current := pending[last]
		pending = pending[:last]

		if current == nil {
			continue
		}
		if current.Type == TypeField {
			fields[current.Value] = true
		}
		pending = append(pending, current.Left, current.Right)
		pending = append(pending, current.Args...)
	}
}
// evaluateNode computes the numeric value of the tree rooted at node.
//
//   - Number nodes are parsed with strconv.ParseFloat.
//   - Field nodes are looked up in data; the value must be a built-in integer
//     or float type, or a string that parses as a float.
//   - Operator nodes evaluate both operands and apply + - * / % ^. Division
//     and modulo by zero are reported as errors.
//   - Function nodes are dispatched to the function registry when the name is
//     registered there, otherwise to the legacy built-in functions.
//
// An error is returned for a nil node, a missing or non-numeric field, an
// unknown operator, node type or function, or a non-numeric function result.
func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error) {
	if node == nil {
		return 0, fmt.Errorf("null expression node")
	}

	switch node.Type {
	case TypeNumber:
		return strconv.ParseFloat(node.Value, 64)

	case TypeField:
		val, ok := data[node.Value]
		if !ok {
			return 0, fmt.Errorf("field %s not found in data", node.Value)
		}
		if f, ok := numericToFloat64(val); ok {
			return f, nil
		}
		// Numeric strings are accepted for fields as well.
		if strVal, ok := val.(string); ok {
			if f, err := strconv.ParseFloat(strVal, 64); err == nil {
				return f, nil
			}
		}
		return 0, fmt.Errorf("cannot convert field %s value to number", node.Value)

	case TypeOperator:
		left, err := evaluateNode(node.Left, data)
		if err != nil {
			return 0, err
		}
		right, err := evaluateNode(node.Right, data)
		if err != nil {
			return 0, err
		}
		switch node.Value {
		case "+":
			return left + right, nil
		case "-":
			return left - right, nil
		case "*":
			return left * right, nil
		case "/":
			if right == 0 {
				return 0, fmt.Errorf("division by zero")
			}
			return left / right, nil
		case "%":
			if right == 0 {
				return 0, fmt.Errorf("modulo by zero")
			}
			return math.Mod(left, right), nil
		case "^":
			return math.Pow(left, right), nil
		default:
			return 0, fmt.Errorf("unknown operator: %s", node.Value)
		}

	case TypeFunction:
		fn, exists := functions.Get(node.Value)
		if !exists {
			// Fall back to the legacy built-ins (kept for backward compatibility).
			return evaluateBuiltinFunction(node, data)
		}

		// Registered functions receive their already-evaluated arguments.
		args := make([]interface{}, len(node.Args))
		for i, arg := range node.Args {
			val, err := evaluateNode(arg, data)
			if err != nil {
				return 0, err
			}
			args[i] = val
		}

		ctx := &functions.FunctionContext{
			Data: data,
		}
		result, err := fn.Execute(ctx, args)
		if err != nil {
			return 0, err
		}
		if f, ok := numericToFloat64(result); ok {
			return f, nil
		}
		return 0, fmt.Errorf("function %s returned non-numeric value", node.Value)
	}

	return 0, fmt.Errorf("unknown node type: %s", node.Type)
}

// numericToFloat64 converts a value of any built-in integer or floating-point
// type to float64. The second result is false for every other type.
func numericToFloat64(v interface{}) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int8:
		return float64(n), true
	case int16:
		return float64(n), true
	case int32:
		return float64(n), true
	case int64:
		return float64(n), true
	case uint:
		return float64(n), true
	case uint8:
		return float64(n), true
	case uint16:
		return float64(n), true
	case uint32:
		return float64(n), true
	case uint64:
		return float64(n), true
	}
	return 0, false
}
// evaluateBuiltinFunction evaluates the legacy built-in math functions (kept
// for backward compatibility). The function name is matched
// case-insensitively.
//
// Supported functions, each taking exactly one argument: abs, sqrt, sin, cos,
// tan, floor, ceil, round. An error is returned for an unknown function, a
// wrong argument count, a failing argument expression, or sqrt of a negative
// number.
func evaluateBuiltinFunction(node *ExprNode, data map[string]interface{}) (float64, error) {
	name := strings.ToLower(node.Value)
	fn, ok := builtinUnaryFunc(name)
	if !ok {
		return 0, fmt.Errorf("unknown function: %s", node.Value)
	}
	if len(node.Args) != 1 {
		return 0, fmt.Errorf("%s function requires exactly 1 argument", name)
	}

	arg, err := evaluateNode(node.Args[0], data)
	if err != nil {
		return 0, err
	}
	// math.Sqrt would yield NaN; report the domain error explicitly instead.
	if name == "sqrt" && arg < 0 {
		return 0, fmt.Errorf("sqrt of negative number")
	}
	return fn(arg), nil
}

// builtinUnaryFunc returns the math implementation of the built-in
// single-argument function with the given lower-case name.
func builtinUnaryFunc(name string) (func(float64) float64, bool) {
	switch name {
	case "abs":
		return math.Abs, true
	case "sqrt":
		return math.Sqrt, true
	case "sin":
		return math.Sin, true
	case "cos":
		return math.Cos, true
	case "tan":
		return math.Tan, true
	case "floor":
		return math.Floor, true
	case "ceil":
		return math.Ceil, true
	case "round":
		return math.Round, true
	}
	return nil, false
}
// tokenize splits an expression string into a list of tokens.
//
// Recognized tokens are:
//   - numbers: digits with at most one decimal point, optionally starting
//     with the point (".5");
//   - identifiers (field or function names): a letter or underscore followed
//     by letters, digits and underscores, matching what isIdentifier accepts;
//   - the single-character tokens + - * / % ^ ( ) and ,.
//
// Spaces, tabs and line breaks are skipped. An error is returned for an empty
// (or all-whitespace) expression and for any other character; the reported
// position is relative to the trimmed expression.
func tokenize(expr string) ([]string, error) {
	expr = strings.TrimSpace(expr)
	if expr == "" {
		return nil, fmt.Errorf("empty expression")
	}

	tokens := make([]string, 0)
	for i := 0; i < len(expr); {
		ch := expr[i]
		switch {
		case ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r':
			i++

		case isDigit(ch) || (ch == '.' && i+1 < len(expr) && isDigit(expr[i+1])):
			start := i
			hasDot := ch == '.'
			for i++; i < len(expr); i++ {
				if expr[i] == '.' && !hasDot {
					hasDot = true
					continue
				}
				if !isDigit(expr[i]) {
					break
				}
			}
			tokens = append(tokens, expr[start:i])

		case isLetter(ch) || ch == '_':
			start := i
			for i++; i < len(expr) && (isLetter(expr[i]) || isDigit(expr[i]) || expr[i] == '_'); i++ {
			}
			tokens = append(tokens, expr[start:i])

		case strings.IndexByte("+-*/%^(),", ch) >= 0:
			tokens = append(tokens, string(ch))
			i++

		default:
			return nil, fmt.Errorf("unexpected character: %c at position %d", ch, i)
		}
	}
	return tokens, nil
}
// parseExpression builds an expression tree from a token list using the
// shunting-yard algorithm.
//
// Operands (numbers, fields, function calls) are pushed on an output stack
// and operators on an operator stack. Whenever an operator has to leave the
// operator stack it is combined with the two topmost operands into a single
// operator node. Operators of equal precedence are reduced before the new one
// is pushed, so every operator, including ^, associates to the left.
//
// An identifier directly followed by "(" is parsed as a function call whose
// arguments are handled by parseFunctionArgs. A comma outside a function
// argument list is skipped.
//
// An error is returned for an empty token list, mismatched parentheses, an
// operator lacking operands, an unexpected token, or when the tokens do not
// reduce to exactly one tree.
func parseExpression(tokens []string) (*ExprNode, error) {
	if len(tokens) == 0 {
		return nil, fmt.Errorf("empty token list")
	}

	output := make([]*ExprNode, 0)
	operators := make([]string, 0)

	// reduce pops the top operator and its two operands and pushes the
	// resulting operator node. The caller guarantees operators is non-empty.
	reduce := func() error {
		op := operators[len(operators)-1]
		operators = operators[:len(operators)-1]
		if len(output) < 2 {
			return fmt.Errorf("not enough operands for operator: %s", op)
		}
		left, right := output[len(output)-2], output[len(output)-1]
		output = append(output[:len(output)-2], &ExprNode{
			Type:  TypeOperator,
			Value: op,
			Left:  left,
			Right: right,
		})
		return nil
	}

	for i := 0; i < len(tokens); {
		token := tokens[i]
		switch {
		case isNumber(token):
			output = append(output, &ExprNode{Type: TypeNumber, Value: token})
			i++

		case isIdentifier(token):
			if i+1 < len(tokens) && tokens[i+1] == "(" {
				// Function call: parse the arguments that follow "name(".
				args, next, err := parseFunctionArgs(tokens, i+2)
				if err != nil {
					return nil, err
				}
				output = append(output, &ExprNode{Type: TypeFunction, Value: token, Args: args})
				i = next
			} else {
				output = append(output, &ExprNode{Type: TypeField, Value: token})
				i++
			}

		case token == "(":
			operators = append(operators, token)
			i++

		case token == ")":
			// Reduce everything back to the matching "(".
			for len(operators) > 0 && operators[len(operators)-1] != "(" {
				if err := reduce(); err != nil {
					return nil, err
				}
			}
			if len(operators) == 0 {
				return nil, fmt.Errorf("mismatched parentheses")
			}
			operators = operators[:len(operators)-1] // drop the "("
			i++

		case isOperator(token):
			for len(operators) > 0 && operators[len(operators)-1] != "(" &&
				operatorPrecedence[operators[len(operators)-1]] >= operatorPrecedence[token] {
				if err := reduce(); err != nil {
					return nil, err
				}
			}
			operators = append(operators, token)
			i++

		case token == ",":
			// Commas are consumed by parseFunctionArgs; skip any seen here.
			i++

		default:
			return nil, fmt.Errorf("unexpected token: %s", token)
		}
	}

	// Reduce whatever operators are left.
	for len(operators) > 0 {
		if operators[len(operators)-1] == "(" {
			return nil, fmt.Errorf("mismatched parentheses")
		}
		if err := reduce(); err != nil {
			return nil, err
		}
	}

	if len(output) != 1 {
		return nil, fmt.Errorf("invalid expression")
	}
	return output[0], nil
}
// parseFunctionArgs parses the argument list of a function call.
//
// startIndex is the index of the first token after the opening parenthesis.
// Arguments are separated by commas at nesting depth zero and each one is
// parsed with parseExpression; empty arguments are skipped. On success the
// parsed arguments and the index of the token following the closing
// parenthesis are returned. An error is returned when the closing parenthesis
// is missing or an argument fails to parse.
func parseFunctionArgs(tokens []string, startIndex int) ([]*ExprNode, int, error) {
	args := make([]*ExprNode, 0)
	pos := startIndex

	// Empty argument list: "name()".
	if pos < len(tokens) && tokens[pos] == ")" {
		return args, pos + 1, nil
	}

	for pos < len(tokens) {
		// Advance pos to the "," or ")" that terminates the current argument.
		argStart := pos
		depth := 0
	scan:
		for ; pos < len(tokens); pos++ {
			switch tokens[pos] {
			case "(":
				depth++
			case ")":
				depth--
				if depth < 0 {
					break scan
				}
			case ",":
				if depth == 0 {
					break scan
				}
			}
		}

		if pos > argStart {
			arg, err := parseExpression(tokens[argStart:pos])
			if err != nil {
				return nil, 0, err
			}
			args = append(args, arg)
		}

		if pos >= len(tokens) {
			return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
		}
		switch tokens[pos] {
		case ")":
			return args, pos + 1, nil
		case ",":
			pos++
		default:
			return nil, 0, fmt.Errorf("unexpected token in function arguments: %s", tokens[pos])
		}
	}

	return nil, 0, fmt.Errorf("unexpected end of tokens in function arguments")
}
// Helper functions.

// isDigit reports whether ch is an ASCII decimal digit.
func isDigit(ch byte) bool {
	return '0' <= ch && ch <= '9'
}
// isLetter reports whether ch is an ASCII letter (a-z or A-Z).
func isLetter(ch byte) bool {
	switch {
	case 'a' <= ch && ch <= 'z':
		return true
	case 'A' <= ch && ch <= 'Z':
		return true
	}
	return false
}
// isNumber reports whether s is accepted by strconv.ParseFloat as a 64-bit
// floating-point number.
func isNumber(s string) bool {
	if _, err := strconv.ParseFloat(s, 64); err != nil {
		return false
	}
	return true
}
// isIdentifier reports whether s is a non-empty string that starts with an
// ASCII letter or underscore and continues with ASCII letters, digits or
// underscores only.
func isIdentifier(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case isLetter(c) || c == '_':
			// Allowed at every position.
		case i > 0 && isDigit(c):
			// Digits are allowed everywhere except the first position.
		default:
			return false
		}
	}
	return true
}
// isOperator reports whether s is one of the operators listed in
// operatorPrecedence.
func isOperator(s string) bool {
	if _, known := operatorPrecedence[s]; known {
		return true
	}
	return false
}