add agent tests
This commit is contained in:
@@ -216,3 +216,280 @@ async def graphs_handler(db_provider):
|
||||
)
|
||||
await handler.create_tables()
|
||||
return handler
|
||||
|
||||
# Citation testing fixtures and utilities
|
||||
import json
|
||||
import re
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from typing import Tuple, Any, AsyncGenerator
|
||||
|
||||
from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
|
||||
from core.utils import CitationTracker, SearchResultsCollector
|
||||
from core.agent.base import R2RStreamingAgent
|
||||
|
||||
|
||||
class MockLLMProvider:
|
||||
"""Mock LLM provider for testing."""
|
||||
|
||||
def __init__(self, response_content=None, citations=None):
|
||||
self.response_content = response_content or "This is a response"
|
||||
self.citations = citations or []
|
||||
|
||||
async def aget_completion(self, messages, generation_config):
|
||||
"""Mock synchronous completion."""
|
||||
content = self.response_content
|
||||
for citation in self.citations:
|
||||
content += f" [{citation}]"
|
||||
|
||||
mock_response = MagicMock(spec=LLMChatCompletion)
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message = MagicMock()
|
||||
mock_response.choices[0].message.content = content
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
return mock_response
|
||||
|
||||
async def aget_completion_stream(self, messages, generation_config):
|
||||
"""Mock streaming completion."""
|
||||
content = self.response_content
|
||||
for citation in self.citations:
|
||||
content += f" [{citation}]"
|
||||
|
||||
# Simulate streaming by yielding one character at a time
|
||||
for i in range(len(content)):
|
||||
chunk = MagicMock(spec=LLMChatCompletionChunk)
|
||||
chunk.choices = [MagicMock()]
|
||||
chunk.choices[0].delta = MagicMock()
|
||||
chunk.choices[0].delta.content = content[i]
|
||||
chunk.choices[0].finish_reason = None
|
||||
yield chunk
|
||||
|
||||
# Final chunk with finish_reason="stop"
|
||||
final_chunk = MagicMock(spec=LLMChatCompletionChunk)
|
||||
final_chunk.choices = [MagicMock()]
|
||||
final_chunk.choices[0].delta = MagicMock()
|
||||
final_chunk.choices[0].delta.content = ""
|
||||
final_chunk.choices[0].finish_reason = "stop"
|
||||
yield final_chunk
|
||||
|
||||
|
||||
class MockPromptsHandler:
|
||||
"""Mock prompts handler for testing."""
|
||||
|
||||
async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
|
||||
"""Return a mock system prompt."""
|
||||
return "You are a helpful assistant that provides well-sourced information."
|
||||
|
||||
|
||||
class MockDatabaseProvider:
|
||||
"""Mock database provider for testing."""
|
||||
|
||||
def __init__(self):
|
||||
# Add a prompts_handler attribute to prevent AttributeError
|
||||
self.prompts_handler = MockPromptsHandler()
|
||||
|
||||
async def acreate_conversation(self, *args, **kwargs):
|
||||
return {"id": "conv_12345"}
|
||||
|
||||
async def aupdate_conversation(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
async def acreate_message(self, *args, **kwargs):
|
||||
return {"id": "msg_12345"}
|
||||
|
||||
|
||||
class MockSearchResultsCollector:
|
||||
"""Mock search results collector for testing."""
|
||||
|
||||
def __init__(self, results=None):
|
||||
self.results = results or {}
|
||||
|
||||
def find_by_short_id(self, short_id):
|
||||
return self.results.get(short_id, {
|
||||
"document_id": f"doc_{short_id}",
|
||||
"text": f"This is document text for {short_id}",
|
||||
"metadata": {"source": f"source_{short_id}"}
|
||||
})
|
||||
|
||||
|
||||
# Create a concrete implementation of R2RStreamingAgent for testing
|
||||
class MockR2RStreamingAgent(R2RStreamingAgent):
|
||||
"""Mock streaming agent for testing that implements the abstract method."""
|
||||
|
||||
# Regex pattern for citations, copied from the actual agent
|
||||
BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
|
||||
SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")
|
||||
|
||||
def _register_tools(self):
|
||||
"""Implement the abstract method with a no-op version."""
|
||||
pass
|
||||
|
||||
async def _setup(self, system_instruction=None, *args, **kwargs):
|
||||
"""Override _setup to simplify initialization and avoid external dependencies."""
|
||||
# Use a simple system message instead of fetching from database
|
||||
system_content = system_instruction or "You are a helpful assistant that provides well-sourced information."
|
||||
|
||||
# Add system message to conversation
|
||||
await self.conversation.add_message(
|
||||
Message(role="system", content=system_content)
|
||||
)
|
||||
|
||||
def _format_sse_event(self, event_type, data):
|
||||
"""Format an SSE event manually."""
|
||||
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
system_instruction: str = None,
|
||||
messages: list[Message] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Simplified version of arun that focuses on citation handling for testing.
|
||||
"""
|
||||
await self._setup(system_instruction)
|
||||
|
||||
if messages:
|
||||
for m in messages:
|
||||
await self.conversation.add_message(m)
|
||||
|
||||
# Initialize citation tracker
|
||||
citation_tracker = CitationTracker()
|
||||
citation_payloads = {}
|
||||
|
||||
# Track streaming citations for final persistence
|
||||
self.streaming_citations = []
|
||||
|
||||
# Get the LLM response with citations
|
||||
response_content = "This is a test response with citations"
|
||||
response_content += " [abc1234] [def5678]"
|
||||
|
||||
# Yield an initial message event with the start of the text
|
||||
yield self._format_sse_event("message", {"content": response_content})
|
||||
|
||||
# Manually extract and emit citation events
|
||||
# This is a simpler approach than the character-by-character approach
|
||||
citation_spans = extract_citation_spans(response_content)
|
||||
|
||||
# Process the citations
|
||||
for cid, spans in citation_spans.items():
|
||||
for span in spans:
|
||||
# Check if the span is new and record it
|
||||
if citation_tracker.is_new_span(cid, span):
|
||||
|
||||
# Look up the source document for this citation
|
||||
source_doc = self.search_results_collector.find_by_short_id(cid)
|
||||
|
||||
# Create citation payload
|
||||
citation_payload = {
|
||||
"document_id": source_doc.get("document_id", f"doc_{cid}"),
|
||||
"text": source_doc.get("text", f"This is document text for {cid}"),
|
||||
"metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
|
||||
}
|
||||
|
||||
# Store the payload by citation ID
|
||||
citation_payloads[cid] = citation_payload
|
||||
|
||||
# Track for persistence
|
||||
self.streaming_citations.append({
|
||||
"id": cid,
|
||||
"span": {"start": span[0], "end": span[1]},
|
||||
"payload": citation_payload
|
||||
})
|
||||
|
||||
# Emit citation event in the expected format
|
||||
citation_event = {
|
||||
"id": cid,
|
||||
"object": "citation",
|
||||
"span": {"start": span[0], "end": span[1]},
|
||||
"payload": citation_payload
|
||||
}
|
||||
|
||||
yield self._format_sse_event("citation", citation_event)
|
||||
|
||||
# Add assistant message with citation metadata to conversation
|
||||
await self.conversation.add_message(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=response_content,
|
||||
metadata={"citations": self.streaming_citations}
|
||||
)
|
||||
)
|
||||
|
||||
# Prepare consolidated citations for final answer
|
||||
consolidated_citations = []
|
||||
|
||||
# Group citations by ID with all their spans
|
||||
for cid, spans in citation_tracker.get_all_spans().items():
|
||||
if cid in citation_payloads:
|
||||
consolidated_citations.append({
|
||||
"id": cid,
|
||||
"object": "citation",
|
||||
"spans": [{"start": s[0], "end": s[1]} for s in spans],
|
||||
"payload": citation_payloads[cid]
|
||||
})
|
||||
|
||||
# Create and emit final answer event
|
||||
final_evt_payload = {
|
||||
"id": "msg_final",
|
||||
"object": "agent.final_answer",
|
||||
"generated_answer": response_content,
|
||||
"citations": consolidated_citations
|
||||
}
|
||||
|
||||
# Manually format the final answer event
|
||||
yield self._format_sse_event("agent.final_answer", final_evt_payload)
|
||||
|
||||
# Signal the end of the SSE stream
|
||||
yield "event: done\ndata: {}\n\n"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_streaming_agent():
|
||||
"""Create a streaming agent with mocked dependencies."""
|
||||
# Create mock config
|
||||
config = MagicMock()
|
||||
config.stream = True
|
||||
config.max_iterations = 3
|
||||
|
||||
# Create mock providers
|
||||
llm_provider = MockLLMProvider(
|
||||
response_content="This is a test response with citations",
|
||||
citations=["abc1234", "def5678"]
|
||||
)
|
||||
db_provider = MockDatabaseProvider()
|
||||
|
||||
# Create agent with mocked dependencies using our concrete implementation
|
||||
agent = MockR2RStreamingAgent(
|
||||
database_provider=db_provider,
|
||||
llm_provider=llm_provider,
|
||||
config=config,
|
||||
rag_generation_config=GenerationConfig(model="test/model")
|
||||
)
|
||||
|
||||
# Replace the search results collector with our mock
|
||||
agent.search_results_collector = MockSearchResultsCollector({
|
||||
"abc1234": {
|
||||
"document_id": "doc_abc1234",
|
||||
"text": "This is document text for abc1234",
|
||||
"metadata": {"source": "source_abc1234"}
|
||||
},
|
||||
"def5678": {
|
||||
"document_id": "doc_def5678",
|
||||
"text": "This is document text for def5678",
|
||||
"metadata": {"source": "source_def5678"}
|
||||
}
|
||||
})
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
async def collect_stream_output(stream):
|
||||
"""Collect all output from a stream into a list."""
|
||||
output = []
|
||||
async for event in stream:
|
||||
output.append(event)
|
||||
return output
|
||||
|
||||
|
||||
from core.utils import extract_citation_spans, find_new_citation_spans
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -153,7 +153,7 @@ class TestFilterTypeConversions:
|
||||
filters = {"id": None}
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
# Different implementations might handle NULL differently
|
||||
assert ("IS NULL" in sql or "= NULL" in sql or
|
||||
assert ("IS NULL" in sql or "= NULL" in sql or
|
||||
(sql.strip().endswith("= $1") and (not params or params == [None]))), "Should handle null in top-level column"
|
||||
|
||||
# Metadata field
|
||||
@@ -211,9 +211,9 @@ class TestRealWorldQueries:
|
||||
{"metadata.score": {SimplifiedFilterOperator.GT: 80}}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Basic checks
|
||||
assert " AND " in sql, "Should use AND to combine all conditions"
|
||||
assert "collection_ids &&" in sql, "Should use overlap operator for collections"
|
||||
@@ -222,13 +222,13 @@ class TestRealWorldQueries:
|
||||
assert "metadata->>'status'" in sql or "metadata->'status'" in sql, "Should compare status as text"
|
||||
assert "@>" in sql, "Should use containment for tags"
|
||||
assert "score" in sql and "numeric" in sql, "Should compare score as numeric"
|
||||
|
||||
|
||||
# Check parameters
|
||||
assert ["collection1", "collection2"] in params, "Should include collection IDs"
|
||||
assert "2021-01-01" in params, "Should include date string"
|
||||
assert "active" in params, "Should include status"
|
||||
assert "80" in params, "Should include score (as string)"
|
||||
|
||||
|
||||
# At least one parameter should be a JSON string for tags
|
||||
assert any(isinstance(p, str) and "important" in p for p in params), "Should include JSON-encoded tags"
|
||||
|
||||
@@ -246,15 +246,15 @@ class TestRealWorldQueries:
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Basic checks
|
||||
assert " OR " in sql, "Should use OR for pagination options"
|
||||
assert " AND " in sql, "Should use AND for tie-breaker"
|
||||
assert "metadata->>'created_at'" in sql, "Should reference created_at"
|
||||
assert "id <" in sql, "Should have ID comparison"
|
||||
|
||||
|
||||
# Check parameters
|
||||
assert "2023-01-01" in params, "Should include date twice"
|
||||
assert params.count("2023-01-01") == 2, "Date should appear twice (LT and EQ)"
|
||||
@@ -266,9 +266,9 @@ class TestRealWorldQueries:
|
||||
filters = {
|
||||
"metadata.level1.level2.level3.value": {SimplifiedFilterOperator.GT: 100}
|
||||
}
|
||||
|
||||
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Check JSON path navigation
|
||||
expected_path = "metadata->'level1'->'level2'->'level3'->>'value'"
|
||||
assert expected_path in sql, "Should properly navigate nested JSON path"
|
||||
@@ -285,16 +285,16 @@ class TestRealWorldQueries:
|
||||
{"metadata.content": {SimplifiedFilterOperator.ILIKE: "%search term%"}}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Check basic structure
|
||||
assert " OR " in sql, "Should use OR to combine text search conditions"
|
||||
# Different implementations can use different JSON extraction operators
|
||||
assert "metadata" in sql and "title" in sql and "ILIKE" in sql, "Should search in title with ILIKE"
|
||||
assert "metadata" in sql and "description" in sql, "Should search in description"
|
||||
assert "metadata" in sql and "content" in sql, "Should search in content"
|
||||
|
||||
|
||||
# Check parameters
|
||||
assert all("%search term%" in p for p in params), "All parameters should contain search term"
|
||||
|
||||
@@ -308,11 +308,11 @@ class TestCornerCases:
|
||||
many_conditions = []
|
||||
for i in range(50): # 50 conditions
|
||||
many_conditions.append({"metadata.field" + str(i): "value" + str(i)})
|
||||
|
||||
|
||||
filters = {SimplifiedFilterOperator.AND: many_conditions}
|
||||
|
||||
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Basic checks
|
||||
assert " AND " in sql, "Should use AND to combine conditions"
|
||||
assert len(params) == 50, "Should have 50 parameters"
|
||||
@@ -328,15 +328,15 @@ class TestCornerCases:
|
||||
{"metadata.score": {SimplifiedFilterOperator.LT: 100}}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Check structure
|
||||
assert " AND " in sql, "Should use AND to combine range bounds"
|
||||
assert "metadata->>'score'" in sql, "Should reference score field"
|
||||
assert "::numeric >=" in sql, "Should have GTE comparison"
|
||||
assert "::numeric <" in sql, "Should have LT comparison"
|
||||
|
||||
|
||||
# Check parameters
|
||||
assert "50" in params, "Should include lower bound"
|
||||
assert "100" in params, "Should include upper bound"
|
||||
@@ -348,7 +348,7 @@ class TestCornerCases:
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
# The behavior depends on implementation - could be "FALSE" or empty list handling
|
||||
assert "FALSE" in sql.upper() or ("ANY" in sql and "[]" in str(params)), "Should handle empty IN list"
|
||||
|
||||
|
||||
# Empty CONTAINS list
|
||||
filters = {"metadata.tags": {SimplifiedFilterOperator.CONTAINS: []}}
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
@@ -363,7 +363,7 @@ class TestCornerCases:
|
||||
# Should be safe because of parameterization
|
||||
assert "id =" in sql, "Should handle value normally"
|
||||
assert params == ["value'; DROP TABLE users; --"], "Should include value as parameter"
|
||||
|
||||
|
||||
# Very long string
|
||||
long_string = "x" * 1000 # 1000 character string
|
||||
filters = {"metadata.field": long_string}
|
||||
@@ -378,7 +378,7 @@ class TestCornerCases:
|
||||
sql, params = simplified_apply_filters(filters, [])
|
||||
assert "metadata->>'title'" in sql, "Should handle Unicode normally"
|
||||
assert params == ["😀 Unicode 测试"], "Should include Unicode string as parameter"
|
||||
|
||||
|
||||
# Unicode in field name (might not be supported in current implementation)
|
||||
try:
|
||||
filters = {"metadata.标题": "value"}
|
||||
@@ -590,11 +590,11 @@ class TestMetadataFilters:
|
||||
# With prefix
|
||||
filters1 = {"metadata.key": "value"}
|
||||
sql1, params1 = main_apply_filters(filters1, [])
|
||||
|
||||
|
||||
# Without prefix (should be treated as metadata)
|
||||
filters2 = {"key": "value"}
|
||||
sql2, params2 = main_apply_filters(filters2, [], top_level_columns=["id", "owner_id"])
|
||||
|
||||
|
||||
# Both should produce the same SQL
|
||||
assert "metadata" in sql2, "Field not in top_level_columns should be treated as metadata"
|
||||
assert params1 == params2, "Parameters should be the same"
|
||||
@@ -678,15 +678,15 @@ class TestEdgeCases:
|
||||
def test_filter_modes(self):
|
||||
"""Test different filter modes (where_clause, condition_only, append_only)."""
|
||||
filters = {"id": "test-id"}
|
||||
|
||||
|
||||
# Default where_clause mode
|
||||
sql1, _ = main_apply_filters(filters, [])
|
||||
assert sql1.startswith("WHERE"), "Default mode should prepend WHERE"
|
||||
|
||||
|
||||
# condition_only mode
|
||||
sql2, _ = main_apply_filters(filters, [], mode="condition_only")
|
||||
assert not sql2.startswith("WHERE"), "condition_only mode should not prepend WHERE"
|
||||
|
||||
|
||||
# append_only mode
|
||||
sql3, _ = main_apply_filters(filters, [], mode="append_only")
|
||||
assert sql3.startswith("AND"), "append_only mode should prepend AND"
|
||||
@@ -718,13 +718,13 @@ class TestEdgeCases:
|
||||
"""Test with custom top_level_columns parameter."""
|
||||
# Define a custom set of top-level columns
|
||||
custom_columns = ["id", "custom_field"]
|
||||
|
||||
|
||||
# Test a field that's in custom_columns
|
||||
filters = {"custom_field": "value"}
|
||||
sql, _ = main_apply_filters(filters, [], top_level_columns=custom_columns)
|
||||
assert "custom_field =" in sql, "Should treat custom_field as a normal column"
|
||||
assert "metadata" not in sql, "Should not treat custom_field as metadata"
|
||||
|
||||
|
||||
# Test a field that's not in custom_columns
|
||||
filters = {"other_field": "value"}
|
||||
sql, _ = main_apply_filters(filters, [], top_level_columns=custom_columns)
|
||||
@@ -734,7 +734,7 @@ class TestEdgeCases:
|
||||
"""Test with custom json_column parameter."""
|
||||
# Use a custom json column name
|
||||
custom_json = "properties"
|
||||
|
||||
|
||||
filters = {"field": "value"}
|
||||
sql, _ = main_apply_filters(filters, [], top_level_columns=["id"], json_column=custom_json)
|
||||
assert custom_json in sql, f"Should use {custom_json} instead of metadata"
|
||||
@@ -759,23 +759,23 @@ class TestComplexFilterCombinations:
|
||||
]
|
||||
}
|
||||
sql, params = main_apply_filters(filters, [])
|
||||
|
||||
|
||||
# Check for AND operator
|
||||
assert " AND " in sql, "SQL should contain AND operator"
|
||||
|
||||
|
||||
# Check for metadata handling
|
||||
assert "metadata->>'score'" in sql, "SQL should handle metadata field"
|
||||
assert "::numeric >=" in sql, "SQL should handle numeric comparison"
|
||||
|
||||
|
||||
# Check for OR operator
|
||||
assert " OR " in sql, "SQL should contain OR operator"
|
||||
|
||||
|
||||
# Check for collection_id handling
|
||||
assert "collection_ids" in sql, "SQL should handle collection_id"
|
||||
|
||||
|
||||
# Check for parent_id handling
|
||||
assert "parent_id = ANY" in sql, "SQL should handle parent_id IN condition"
|
||||
|
||||
|
||||
# Check parameters
|
||||
assert len(params) == 4, "Should have 4 parameters"
|
||||
assert "test-id" in params, "Parameters should include test-id"
|
||||
@@ -787,7 +787,7 @@ class TestComplexFilterCombinations:
|
||||
"""Test filtering with deeply nested JSON fields."""
|
||||
filters = {"metadata.level1.level2.level3.deep": {MainFilterOperator.GT: 100}}
|
||||
sql, params = main_apply_filters(filters, [])
|
||||
|
||||
|
||||
expected_path = (
|
||||
"metadata->'level1'->'level2'->'level3'->>'deep'"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user