mirror of
https://gitee.com/rulego/streamsql.git
synced 2025-07-07 00:10:55 +00:00
196 lines
5.1 KiB
Go
196 lines
5.1 KiB
Go
package rsql
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestFunctionValidator_ValidateExpression(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
expression string
|
|
expectedErrors int
|
|
errorType ErrorType
|
|
errorMessage string
|
|
}{
|
|
{
|
|
name: "Valid builtin function",
|
|
expression: "abs(temperature)",
|
|
expectedErrors: 0,
|
|
},
|
|
{
|
|
name: "Valid nested builtin functions",
|
|
expression: "sqrt(abs(temperature))",
|
|
expectedErrors: 0,
|
|
},
|
|
{
|
|
name: "Unknown function",
|
|
expression: "unknown_func(temperature)",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownFunction,
|
|
errorMessage: "unknown_func",
|
|
},
|
|
{
|
|
name: "Multiple unknown functions",
|
|
expression: "unknown1(temperature) + unknown2(humidity)",
|
|
expectedErrors: 2,
|
|
errorType: ErrorTypeUnknownFunction,
|
|
},
|
|
{
|
|
name: "Mixed valid and invalid functions",
|
|
expression: "abs(temperature) + unknown_func(humidity)",
|
|
expectedErrors: 1,
|
|
errorType: ErrorTypeUnknownFunction,
|
|
errorMessage: "unknown_func",
|
|
},
|
|
{
|
|
name: "No functions in expression",
|
|
expression: "temperature + humidity",
|
|
expectedErrors: 0,
|
|
},
|
|
{
|
|
name: "Function with complex arguments",
|
|
expression: "abs(temperature * 2 + humidity)",
|
|
expectedErrors: 0,
|
|
},
|
|
{
|
|
name: "Keyword should not be treated as function",
|
|
expression: "CASE WHEN temperature > 0 THEN 1 ELSE 0 END",
|
|
expectedErrors: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
errorRecovery := &ErrorRecovery{}
|
|
validator := NewFunctionValidator(errorRecovery)
|
|
|
|
validator.ValidateExpression(tt.expression, 0)
|
|
|
|
errors := errorRecovery.GetErrors()
|
|
assert.Equal(t, tt.expectedErrors, len(errors), "Expected %d errors, got %d", tt.expectedErrors, len(errors))
|
|
|
|
if tt.expectedErrors > 0 {
|
|
assert.Equal(t, tt.errorType, errors[0].Type, "Expected error type %v, got %v", tt.errorType, errors[0].Type)
|
|
if tt.errorMessage != "" {
|
|
assert.Contains(t, errors[0].Message, tt.errorMessage, "Error message should contain %s", tt.errorMessage)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunctionValidator_ExtractFunctionCalls(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
expression string
|
|
expected []FunctionCall
|
|
}{
|
|
{
|
|
name: "Single function",
|
|
expression: "abs(x)",
|
|
expected: []FunctionCall{
|
|
{Name: "abs", Position: 0},
|
|
},
|
|
},
|
|
{
|
|
name: "Multiple functions",
|
|
expression: "abs(x) + sqrt(y)",
|
|
expected: []FunctionCall{
|
|
{Name: "abs", Position: 0},
|
|
{Name: "sqrt", Position: 9},
|
|
},
|
|
},
|
|
{
|
|
name: "Nested functions",
|
|
expression: "sqrt(abs(x))",
|
|
expected: []FunctionCall{
|
|
{Name: "sqrt", Position: 0},
|
|
{Name: "abs", Position: 5},
|
|
},
|
|
},
|
|
{
|
|
name: "Function with spaces",
|
|
expression: "abs ( x )",
|
|
expected: []FunctionCall{
|
|
{Name: "abs", Position: 0},
|
|
},
|
|
},
|
|
{
|
|
name: "No functions",
|
|
expression: "x + y * 2",
|
|
expected: []FunctionCall{},
|
|
},
|
|
{
|
|
name: "Keywords should be filtered",
|
|
expression: "CASE(x) WHEN(y)",
|
|
expected: []FunctionCall{},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
errorRecovery := &ErrorRecovery{}
|
|
validator := NewFunctionValidator(errorRecovery)
|
|
|
|
result := validator.extractFunctionCalls(tt.expression)
|
|
assert.Equal(t, len(tt.expected), len(result), "Expected %d function calls, got %d", len(tt.expected), len(result))
|
|
|
|
for i, expected := range tt.expected {
|
|
if i < len(result) {
|
|
assert.Equal(t, expected.Name, result[i].Name, "Expected function name %s, got %s", expected.Name, result[i].Name)
|
|
assert.Equal(t, expected.Position, result[i].Position, "Expected position %d, got %d", expected.Position, result[i].Position)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunctionValidator_IsBuiltinFunction(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
funcName string
|
|
expected bool
|
|
}{
|
|
{"abs function", "abs", true},
|
|
{"ABS function (case insensitive)", "ABS", true},
|
|
{"sqrt function", "sqrt", true},
|
|
{"unknown function", "unknown_func", false},
|
|
{"empty string", "", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
errorRecovery := &ErrorRecovery{}
|
|
validator := NewFunctionValidator(errorRecovery)
|
|
|
|
result := validator.isBuiltinFunction(tt.funcName)
|
|
assert.Equal(t, tt.expected, result, "Expected %v for function %s", tt.expected, tt.funcName)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunctionValidator_IsKeyword(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
word string
|
|
expected bool
|
|
}{
|
|
{"SELECT keyword", "SELECT", true},
|
|
{"select keyword (case insensitive)", "select", true},
|
|
{"CASE keyword", "CASE", true},
|
|
{"regular identifier", "temperature", false},
|
|
{"function name", "abs", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
errorRecovery := &ErrorRecovery{}
|
|
validator := NewFunctionValidator(errorRecovery)
|
|
|
|
result := validator.isKeyword(tt.word)
|
|
assert.Equal(t, tt.expected, result, "Expected %v for word %s", tt.expected, tt.word)
|
|
})
|
|
}
|
|
} |