Files
streamsql/rsql/function_validator_test.go
2025-06-11 18:46:32 +08:00

196 lines
5.1 KiB
Go

package rsql
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFunctionValidator_ValidateExpression(t *testing.T) {
tests := []struct {
name string
expression string
expectedErrors int
errorType ErrorType
errorMessage string
}{
{
name: "Valid builtin function",
expression: "abs(temperature)",
expectedErrors: 0,
},
{
name: "Valid nested builtin functions",
expression: "sqrt(abs(temperature))",
expectedErrors: 0,
},
{
name: "Unknown function",
expression: "unknown_func(temperature)",
expectedErrors: 1,
errorType: ErrorTypeUnknownFunction,
errorMessage: "unknown_func",
},
{
name: "Multiple unknown functions",
expression: "unknown1(temperature) + unknown2(humidity)",
expectedErrors: 2,
errorType: ErrorTypeUnknownFunction,
},
{
name: "Mixed valid and invalid functions",
expression: "abs(temperature) + unknown_func(humidity)",
expectedErrors: 1,
errorType: ErrorTypeUnknownFunction,
errorMessage: "unknown_func",
},
{
name: "No functions in expression",
expression: "temperature + humidity",
expectedErrors: 0,
},
{
name: "Function with complex arguments",
expression: "abs(temperature * 2 + humidity)",
expectedErrors: 0,
},
{
name: "Keyword should not be treated as function",
expression: "CASE WHEN temperature > 0 THEN 1 ELSE 0 END",
expectedErrors: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errorRecovery := &ErrorRecovery{}
validator := NewFunctionValidator(errorRecovery)
validator.ValidateExpression(tt.expression, 0)
errors := errorRecovery.GetErrors()
assert.Equal(t, tt.expectedErrors, len(errors), "Expected %d errors, got %d", tt.expectedErrors, len(errors))
if tt.expectedErrors > 0 {
assert.Equal(t, tt.errorType, errors[0].Type, "Expected error type %v, got %v", tt.errorType, errors[0].Type)
if tt.errorMessage != "" {
assert.Contains(t, errors[0].Message, tt.errorMessage, "Error message should contain %s", tt.errorMessage)
}
}
})
}
}
func TestFunctionValidator_ExtractFunctionCalls(t *testing.T) {
tests := []struct {
name string
expression string
expected []FunctionCall
}{
{
name: "Single function",
expression: "abs(x)",
expected: []FunctionCall{
{Name: "abs", Position: 0},
},
},
{
name: "Multiple functions",
expression: "abs(x) + sqrt(y)",
expected: []FunctionCall{
{Name: "abs", Position: 0},
{Name: "sqrt", Position: 9},
},
},
{
name: "Nested functions",
expression: "sqrt(abs(x))",
expected: []FunctionCall{
{Name: "sqrt", Position: 0},
{Name: "abs", Position: 5},
},
},
{
name: "Function with spaces",
expression: "abs ( x )",
expected: []FunctionCall{
{Name: "abs", Position: 0},
},
},
{
name: "No functions",
expression: "x + y * 2",
expected: []FunctionCall{},
},
{
name: "Keywords should be filtered",
expression: "CASE(x) WHEN(y)",
expected: []FunctionCall{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errorRecovery := &ErrorRecovery{}
validator := NewFunctionValidator(errorRecovery)
result := validator.extractFunctionCalls(tt.expression)
assert.Equal(t, len(tt.expected), len(result), "Expected %d function calls, got %d", len(tt.expected), len(result))
for i, expected := range tt.expected {
if i < len(result) {
assert.Equal(t, expected.Name, result[i].Name, "Expected function name %s, got %s", expected.Name, result[i].Name)
assert.Equal(t, expected.Position, result[i].Position, "Expected position %d, got %d", expected.Position, result[i].Position)
}
}
})
}
}
func TestFunctionValidator_IsBuiltinFunction(t *testing.T) {
tests := []struct {
name string
funcName string
expected bool
}{
{"abs function", "abs", true},
{"ABS function (case insensitive)", "ABS", true},
{"sqrt function", "sqrt", true},
{"unknown function", "unknown_func", false},
{"empty string", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errorRecovery := &ErrorRecovery{}
validator := NewFunctionValidator(errorRecovery)
result := validator.isBuiltinFunction(tt.funcName)
assert.Equal(t, tt.expected, result, "Expected %v for function %s", tt.expected, tt.funcName)
})
}
}
func TestFunctionValidator_IsKeyword(t *testing.T) {
tests := []struct {
name string
word string
expected bool
}{
{"SELECT keyword", "SELECT", true},
{"select keyword (case insensitive)", "select", true},
{"CASE keyword", "CASE", true},
{"regular identifier", "temperature", false},
{"function name", "abs", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errorRecovery := &ErrorRecovery{}
validator := NewFunctionValidator(errorRecovery)
result := validator.isKeyword(tt.word)
assert.Equal(t, tt.expected, result, "Expected %v for word %s", tt.expected, tt.word)
})
}
}