forked from GiteaTest2015/streamsql
test:增加测试用例
This commit is contained in:
+185
-2
@@ -125,6 +125,11 @@ func validateBasicSyntax(exprStr string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// 检查表达式开头和结尾的运算符
|
||||
if err := checkExpressionStartEnd(trimmed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查连续运算符
|
||||
if err := checkConsecutiveOperators(trimmed); err != nil {
|
||||
return err
|
||||
@@ -133,6 +138,27 @@ func validateBasicSyntax(exprStr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkExpressionStartEnd checks if expression starts or ends with an operator
|
||||
func checkExpressionStartEnd(expr string) error {
|
||||
operators := []string{"+", "*", "/", "%", "^", "==", "!=", ">=", "<=", ">", "<"}
|
||||
|
||||
// 检查表达式开头(允许负号,因为它是合法的负数表示)
|
||||
for _, op := range operators {
|
||||
if strings.HasPrefix(expr, op) {
|
||||
return fmt.Errorf("expression cannot start with operator")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查表达式结尾
|
||||
for _, op := range operators {
|
||||
if strings.HasSuffix(expr, op) {
|
||||
return fmt.Errorf("expression cannot end with operator")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkConsecutiveOperators checks for consecutive operators
|
||||
func checkConsecutiveOperators(expr string) error {
|
||||
// Simplified consecutive operator check: look for obvious double operator patterns
|
||||
@@ -191,6 +217,20 @@ func checkConsecutiveOperators(expr string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// 特殊处理:如果当前是幂运算符(^),下一个是负号,且负号后跟数字,则允许
|
||||
if currentOp == "^" && nextPos < len(expr) && expr[nextPos] == '-' {
|
||||
// 检查负号后是否跟数字
|
||||
digitPos := nextPos + 1
|
||||
for digitPos < len(expr) && (expr[digitPos] == ' ' || expr[digitPos] == '\t') {
|
||||
digitPos++
|
||||
}
|
||||
if digitPos < len(expr) && expr[digitPos] >= '0' && expr[digitPos] <= '9' {
|
||||
// 这是幂运算符后跟负数,允许通过
|
||||
i = nextPos // 跳过到负号位置
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 检查其他连续运算符
|
||||
for _, op := range operators {
|
||||
if nextPos+len(op) <= len(expr) && expr[nextPos:nextPos+len(op)] == op {
|
||||
@@ -453,7 +493,31 @@ func evaluateNode(node *ExprNode, data map[string]interface{}) (float64, error)
|
||||
return 0, fmt.Errorf("field '%s' not found", fieldName)
|
||||
|
||||
case TypeOperator:
|
||||
// Calculate values of left and right sub-expressions
|
||||
// Check if this is a comparison operator
|
||||
if isComparisonOperator(node.Value) {
|
||||
// For comparison operators, use evaluateNodeValue to get original types
|
||||
leftValue, err := evaluateNodeValue(node.Left, data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rightValue, err := evaluateNodeValue(node.Right, data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Perform comparison and convert boolean to number
|
||||
result, err := compareValues(leftValue, rightValue, node.Value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if result {
|
||||
return 1.0, nil
|
||||
}
|
||||
return 0.0, nil
|
||||
}
|
||||
|
||||
// For arithmetic operators, calculate numeric values
|
||||
left, err := evaluateNode(node.Left, data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -640,6 +704,107 @@ func evaluateBuiltinFunction(node *ExprNode, data map[string]interface{}) (float
|
||||
}
|
||||
return math.Round(arg), nil
|
||||
|
||||
case "pow":
|
||||
if len(node.Args) != 2 {
|
||||
return 0, fmt.Errorf("pow function requires exactly 2 arguments")
|
||||
}
|
||||
base, err := evaluateNode(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
exponent, err := evaluateNode(node.Args[1], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return math.Pow(base, exponent), nil
|
||||
|
||||
case "max":
|
||||
if len(node.Args) < 1 {
|
||||
return 0, fmt.Errorf("max function requires at least 1 argument")
|
||||
}
|
||||
maxVal, err := evaluateNode(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := 1; i < len(node.Args); i++ {
|
||||
arg, err := evaluateNode(node.Args[i], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if arg > maxVal {
|
||||
maxVal = arg
|
||||
}
|
||||
}
|
||||
return maxVal, nil
|
||||
|
||||
case "min":
|
||||
if len(node.Args) < 1 {
|
||||
return 0, fmt.Errorf("min function requires at least 1 argument")
|
||||
}
|
||||
minVal, err := evaluateNode(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := 1; i < len(node.Args); i++ {
|
||||
arg, err := evaluateNode(node.Args[i], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if arg < minVal {
|
||||
minVal = arg
|
||||
}
|
||||
}
|
||||
return minVal, nil
|
||||
|
||||
case "log":
|
||||
if len(node.Args) != 1 {
|
||||
return 0, fmt.Errorf("log function requires exactly 1 argument")
|
||||
}
|
||||
arg, err := evaluateNode(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if arg <= 0 {
|
||||
return 0, fmt.Errorf("log of non-positive number")
|
||||
}
|
||||
return math.Log(arg), nil
|
||||
|
||||
case "log10":
|
||||
if len(node.Args) != 1 {
|
||||
return 0, fmt.Errorf("log10 function requires exactly 1 argument")
|
||||
}
|
||||
arg, err := evaluateNode(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if arg <= 0 {
|
||||
return 0, fmt.Errorf("log10 of non-positive number")
|
||||
}
|
||||
return math.Log10(arg), nil
|
||||
|
||||
case "exp":
|
||||
if len(node.Args) != 1 {
|
||||
return 0, fmt.Errorf("exp function requires exactly 1 argument")
|
||||
}
|
||||
arg, err := evaluateNode(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return math.Exp(arg), nil
|
||||
|
||||
case "len":
|
||||
if len(node.Args) != 1 {
|
||||
return 0, fmt.Errorf("len function requires exactly 1 argument")
|
||||
}
|
||||
// Use evaluateNodeValue to get the original value
|
||||
arg, err := evaluateNodeValue(node.Args[0], data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Convert to string and get length
|
||||
strVal := fmt.Sprintf("%v", arg)
|
||||
return float64(len(strVal)), nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown function: %s", node.Value)
|
||||
}
|
||||
@@ -957,8 +1122,14 @@ func likeMatch(text, pattern string, textIndex, patternIndex int) bool {
|
||||
func convertToFloat(val interface{}) (float64, error) {
|
||||
switch v := val.(type) {
|
||||
case float64:
|
||||
if math.IsNaN(v) {
|
||||
return 0, fmt.Errorf("NaN value detected")
|
||||
}
|
||||
return v, nil
|
||||
case float32:
|
||||
if math.IsNaN(float64(v)) {
|
||||
return 0, fmt.Errorf("NaN value detected")
|
||||
}
|
||||
return float64(v), nil
|
||||
case int:
|
||||
return float64(v), nil
|
||||
@@ -966,8 +1137,20 @@ func convertToFloat(val interface{}) (float64, error) {
|
||||
return float64(v), nil
|
||||
case int64:
|
||||
return float64(v), nil
|
||||
case bool:
|
||||
if v {
|
||||
return 1.0, nil
|
||||
}
|
||||
return 0.0, nil
|
||||
case string:
|
||||
return strconv.ParseFloat(v, 64)
|
||||
f, err := strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if math.IsNaN(f) {
|
||||
return 0, fmt.Errorf("NaN value detected")
|
||||
}
|
||||
return f, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("cannot convert %T to float64", val)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user