add agent tests

This commit is contained in:
emrgnt-cmplxty
2025-03-21 22:10:35 -07:00
parent ae279a670b
commit 05f04e7109
4 changed files with 875 additions and 129 deletions
+277
View File
@@ -216,3 +216,280 @@ async def graphs_handler(db_provider):
)
await handler.create_tables()
return handler
# Citation testing fixtures and utilities
import json
import re
from unittest.mock import MagicMock, AsyncMock
from typing import Tuple, Any, AsyncGenerator
from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
from core.utils import CitationTracker, SearchResultsCollector
from core.agent.base import R2RStreamingAgent
class MockLLMProvider:
    """Mock LLM provider that answers with canned text plus citation markers.

    Args:
        response_content: Text of the canned answer. Falsy values fall back
            to ``"This is a response"``.
        citations: Citation IDs appended to the answer, in order, as
            ``" [id]"`` markers. Falsy values mean no citations.
    """

    def __init__(self, response_content=None, citations=None):
        self.response_content = response_content or "This is a response"
        self.citations = citations or []

    def _build_content(self):
        """Return the canned answer followed by one ``" [id]"`` per citation."""
        markers = "".join(f" [{citation}]" for citation in self.citations)
        return self.response_content + markers

    async def aget_completion(self, messages, generation_config):
        """Return a mocked non-streaming completion.

        ``messages`` and ``generation_config`` are accepted for interface
        compatibility and otherwise ignored.

        Returns:
            A ``MagicMock`` specced on ``LLMChatCompletion`` with a single
            choice whose message content is the canned answer and whose
            ``finish_reason`` is ``"stop"``.
        """
        mock_response = MagicMock(spec=LLMChatCompletion)
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message = MagicMock()
        mock_response.choices[0].message.content = self._build_content()
        mock_response.choices[0].finish_reason = "stop"
        return mock_response

    async def aget_completion_stream(self, messages, generation_config):
        """Yield the canned answer as a mocked stream, one character per chunk.

        ``messages`` and ``generation_config`` are accepted for interface
        compatibility and otherwise ignored.

        Yields:
            ``MagicMock`` chunks specced on ``LLMChatCompletionChunk``. Every
            content chunk has ``finish_reason`` set to ``None``; a final chunk
            with empty content and ``finish_reason == "stop"`` ends the stream.
        """
        # Simulate streaming by yielding one character at a time
        for char in self._build_content():
            chunk = MagicMock(spec=LLMChatCompletionChunk)
            chunk.choices = [MagicMock()]
            chunk.choices[0].delta = MagicMock()
            chunk.choices[0].delta.content = char
            chunk.choices[0].finish_reason = None
            yield chunk

        # Final chunk with finish_reason="stop"
        final_chunk = MagicMock(spec=LLMChatCompletionChunk)
        final_chunk.choices = [MagicMock()]
        final_chunk.choices[0].delta = MagicMock()
        final_chunk.choices[0].delta.content = ""
        final_chunk.choices[0].finish_reason = "stop"
        yield final_chunk
class MockPromptsHandler:
    """Stand-in prompts handler that always serves the same system prompt."""

    async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
        """Return the fixed system prompt; every argument is ignored."""
        system_prompt = (
            "You are a helpful assistant that provides well-sourced information."
        )
        return system_prompt
class MockDatabaseProvider:
    """Stand-in database provider whose async methods return fixed values."""

    def __init__(self):
        # Callers expect a prompts_handler attribute; without one they would
        # hit an AttributeError.
        self.prompts_handler = MockPromptsHandler()

    async def acreate_conversation(self, *args, **kwargs):
        """Return a fixed conversation record; arguments are ignored."""
        conversation = {"id": "conv_12345"}
        return conversation

    async def aupdate_conversation(self, *args, **kwargs):
        """Report success unconditionally; arguments are ignored."""
        return True

    async def acreate_message(self, *args, **kwargs):
        """Return a fixed message record; arguments are ignored."""
        message = {"id": "msg_12345"}
        return message
class MockSearchResultsCollector:
    """Stand-in search results collector backed by a plain mapping.

    Args:
        results: Mapping of short ID to result record. Falsy values are
            replaced by an empty dict.
    """

    def __init__(self, results=None):
        self.results = results if results else {}

    def find_by_short_id(self, short_id):
        """Return the stored record for ``short_id`` or a synthesized one.

        Unknown IDs yield a placeholder record whose ``document_id``,
        ``text`` and ``metadata`` values are derived from ``short_id``.
        """
        if short_id in self.results:
            return self.results[short_id]

        placeholder = {
            "document_id": f"doc_{short_id}",
            "text": f"This is document text for {short_id}",
            "metadata": {"source": f"source_{short_id}"},
        }
        return placeholder
# Create a concrete implementation of R2RStreamingAgent for testing
class MockR2RStreamingAgent(R2RStreamingAgent):
    """Concrete ``R2RStreamingAgent`` for tests that streams a canned answer.

    ``arun`` never calls the LLM provider: it emits a hard-coded answer
    containing the citation markers ``[abc1234]`` and ``[def5678]`` and the
    SSE events derived from it.
    """

    # Regex pattern for citations, copied from the actual agent
    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
    SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")

    def _register_tools(self):
        """Satisfy the abstract hook without registering anything."""
        return None

    async def _setup(self, system_instruction=None, *args, **kwargs):
        """Seed the conversation with a system message.

        Uses ``system_instruction`` when it is truthy and a fixed default
        prompt otherwise, so no prompt has to be fetched from a database.
        """
        if system_instruction:
            prompt_text = system_instruction
        else:
            prompt_text = "You are a helpful assistant that provides well-sourced information."

        system_message = Message(role="system", content=prompt_text)
        await self.conversation.add_message(system_message)

    def _format_sse_event(self, event_type, data):
        """Render one SSE frame with ``data`` serialized as JSON."""
        return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"

    async def arun(
        self,
        system_instruction: str = None,
        messages: list[Message] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream a canned, cited answer as SSE strings.

        Yields, in order: one ``message`` event with the full answer text,
        one ``citation`` event per span the tracker reports as new, one
        ``agent.final_answer`` event with the citations grouped by ID, and a
        closing ``done`` event. The per-span citations are also stored on
        ``self.streaming_citations`` and attached as metadata to the
        assistant message added to the conversation.
        """
        await self._setup(system_instruction)

        for incoming in messages or []:
            await self.conversation.add_message(incoming)

        tracker = CitationTracker()
        payload_by_id = {}

        # Per-span citation records, kept for the persisted assistant message
        self.streaming_citations = []

        # The answer is hard-coded; the LLM provider is not consulted
        answer_text = "This is a test response with citations" + " [abc1234] [def5678]"

        # The whole answer goes out as a single message event
        yield self._format_sse_event("message", {"content": answer_text})

        # Citations are extracted from the complete text in one pass rather
        # than character by character
        for cid, spans in extract_citation_spans(answer_text).items():
            for span in spans:
                if not citation_is_new(tracker, cid, span):
                    continue

                # Resolve the cited source and fill in any missing fields
                source_doc = self.search_results_collector.find_by_short_id(cid)
                payload = {
                    "document_id": source_doc.get("document_id", f"doc_{cid}"),
                    "text": source_doc.get("text", f"This is document text for {cid}"),
                    "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                }
                payload_by_id[cid] = payload

                start, end = span[0], span[1]
                self.streaming_citations.append(
                    {
                        "id": cid,
                        "span": {"start": start, "end": end},
                        "payload": payload,
                    }
                )

                yield self._format_sse_event(
                    "citation",
                    {
                        "id": cid,
                        "object": "citation",
                        "span": {"start": start, "end": end},
                        "payload": payload,
                    },
                )

        # Persist the assistant turn together with its citation metadata
        assistant_message = Message(
            role="assistant",
            content=answer_text,
            metadata={"citations": self.streaming_citations},
        )
        await self.conversation.add_message(assistant_message)

        # One entry per citation ID, carrying every span tracked for it
        consolidated = [
            {
                "id": cid,
                "object": "citation",
                "spans": [{"start": s[0], "end": s[1]} for s in spans],
                "payload": payload_by_id[cid],
            }
            for cid, spans in tracker.get_all_spans().items()
            if cid in payload_by_id
        ]

        final_answer = {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": answer_text,
            "citations": consolidated,
        }
        yield self._format_sse_event("agent.final_answer", final_answer)

        # Signal the end of the SSE stream
        yield "event: done\ndata: {}\n\n"


def citation_is_new(tracker, cid, span):
    """Ask ``tracker`` whether ``span`` is new for ``cid``."""
    return tracker.is_new_span(cid, span)
@pytest.fixture
def mock_streaming_agent():
    """Build a ``MockR2RStreamingAgent`` wired entirely to mock dependencies.

    The LLM provider and the search results collector are both seeded with
    the citation IDs ``abc1234`` and ``def5678``.
    """
    citation_ids = ("abc1234", "def5678")

    # Keyword arguments to MagicMock are set as attributes on the mock
    agent_config = MagicMock(stream=True, max_iterations=3)

    mock_llm = MockLLMProvider(
        response_content="This is a test response with citations",
        citations=list(citation_ids),
    )
    mock_db = MockDatabaseProvider()

    # The concrete subclass is used because the base agent is abstract
    streaming_agent = MockR2RStreamingAgent(
        database_provider=mock_db,
        llm_provider=mock_llm,
        config=agent_config,
        rag_generation_config=GenerationConfig(model="test/model"),
    )

    # Swap in a collector that knows both cited documents
    canned_results = {
        short_id: {
            "document_id": f"doc_{short_id}",
            "text": f"This is document text for {short_id}",
            "metadata": {"source": f"source_{short_id}"},
        }
        for short_id in citation_ids
    }
    streaming_agent.search_results_collector = MockSearchResultsCollector(canned_results)
    return streaming_agent
async def collect_stream_output(stream):
    """Drain an async iterable and return its items as a list, in order."""
    return [event async for event in stream]
from core.utils import extract_citation_spans, find_new_citation_spans
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+37 -37
View File
@@ -153,7 +153,7 @@ class TestFilterTypeConversions:
filters = {"id": None}
sql, params = simplified_apply_filters(filters, [])
# Different implementations might handle NULL differently
assert ("IS NULL" in sql or "= NULL" in sql or
assert ("IS NULL" in sql or "= NULL" in sql or
(sql.strip().endswith("= $1") and (not params or params == [None]))), "Should handle null in top-level column"
# Metadata field
@@ -211,9 +211,9 @@ class TestRealWorldQueries:
{"metadata.score": {SimplifiedFilterOperator.GT: 80}}
]
}
sql, params = simplified_apply_filters(filters, [])
# Basic checks
assert " AND " in sql, "Should use AND to combine all conditions"
assert "collection_ids &&" in sql, "Should use overlap operator for collections"
@@ -222,13 +222,13 @@ class TestRealWorldQueries:
assert "metadata->>'status'" in sql or "metadata->'status'" in sql, "Should compare status as text"
assert "@>" in sql, "Should use containment for tags"
assert "score" in sql and "numeric" in sql, "Should compare score as numeric"
# Check parameters
assert ["collection1", "collection2"] in params, "Should include collection IDs"
assert "2021-01-01" in params, "Should include date string"
assert "active" in params, "Should include status"
assert "80" in params, "Should include score (as string)"
# At least one parameter should be a JSON string for tags
assert any(isinstance(p, str) and "important" in p for p in params), "Should include JSON-encoded tags"
@@ -246,15 +246,15 @@ class TestRealWorldQueries:
}
]
}
sql, params = simplified_apply_filters(filters, [])
# Basic checks
assert " OR " in sql, "Should use OR for pagination options"
assert " AND " in sql, "Should use AND for tie-breaker"
assert "metadata->>'created_at'" in sql, "Should reference created_at"
assert "id <" in sql, "Should have ID comparison"
# Check parameters
assert "2023-01-01" in params, "Should include date twice"
assert params.count("2023-01-01") == 2, "Date should appear twice (LT and EQ)"
@@ -266,9 +266,9 @@ class TestRealWorldQueries:
filters = {
"metadata.level1.level2.level3.value": {SimplifiedFilterOperator.GT: 100}
}
sql, params = simplified_apply_filters(filters, [])
# Check JSON path navigation
expected_path = "metadata->'level1'->'level2'->'level3'->>'value'"
assert expected_path in sql, "Should properly navigate nested JSON path"
@@ -285,16 +285,16 @@ class TestRealWorldQueries:
{"metadata.content": {SimplifiedFilterOperator.ILIKE: "%search term%"}}
]
}
sql, params = simplified_apply_filters(filters, [])
# Check basic structure
assert " OR " in sql, "Should use OR to combine text search conditions"
# Different implementations can use different JSON extraction operators
assert "metadata" in sql and "title" in sql and "ILIKE" in sql, "Should search in title with ILIKE"
assert "metadata" in sql and "description" in sql, "Should search in description"
assert "metadata" in sql and "content" in sql, "Should search in content"
# Check parameters
assert all("%search term%" in p for p in params), "All parameters should contain search term"
@@ -308,11 +308,11 @@ class TestCornerCases:
many_conditions = []
for i in range(50): # 50 conditions
many_conditions.append({"metadata.field" + str(i): "value" + str(i)})
filters = {SimplifiedFilterOperator.AND: many_conditions}
sql, params = simplified_apply_filters(filters, [])
# Basic checks
assert " AND " in sql, "Should use AND to combine conditions"
assert len(params) == 50, "Should have 50 parameters"
@@ -328,15 +328,15 @@ class TestCornerCases:
{"metadata.score": {SimplifiedFilterOperator.LT: 100}}
]
}
sql, params = simplified_apply_filters(filters, [])
# Check structure
assert " AND " in sql, "Should use AND to combine range bounds"
assert "metadata->>'score'" in sql, "Should reference score field"
assert "::numeric >=" in sql, "Should have GTE comparison"
assert "::numeric <" in sql, "Should have LT comparison"
# Check parameters
assert "50" in params, "Should include lower bound"
assert "100" in params, "Should include upper bound"
@@ -348,7 +348,7 @@ class TestCornerCases:
sql, params = simplified_apply_filters(filters, [])
# The behavior depends on implementation - could be "FALSE" or empty list handling
assert "FALSE" in sql.upper() or ("ANY" in sql and "[]" in str(params)), "Should handle empty IN list"
# Empty CONTAINS list
filters = {"metadata.tags": {SimplifiedFilterOperator.CONTAINS: []}}
sql, params = simplified_apply_filters(filters, [])
@@ -363,7 +363,7 @@ class TestCornerCases:
# Should be safe because of parameterization
assert "id =" in sql, "Should handle value normally"
assert params == ["value'; DROP TABLE users; --"], "Should include value as parameter"
# Very long string
long_string = "x" * 1000 # 1000 character string
filters = {"metadata.field": long_string}
@@ -378,7 +378,7 @@ class TestCornerCases:
sql, params = simplified_apply_filters(filters, [])
assert "metadata->>'title'" in sql, "Should handle Unicode normally"
assert params == ["😀 Unicode 测试"], "Should include Unicode string as parameter"
# Unicode in field name (might not be supported in current implementation)
try:
filters = {"metadata.标题": "value"}
@@ -590,11 +590,11 @@ class TestMetadataFilters:
# With prefix
filters1 = {"metadata.key": "value"}
sql1, params1 = main_apply_filters(filters1, [])
# Without prefix (should be treated as metadata)
filters2 = {"key": "value"}
sql2, params2 = main_apply_filters(filters2, [], top_level_columns=["id", "owner_id"])
# Both should produce the same SQL
assert "metadata" in sql2, "Field not in top_level_columns should be treated as metadata"
assert params1 == params2, "Parameters should be the same"
@@ -678,15 +678,15 @@ class TestEdgeCases:
def test_filter_modes(self):
"""Test different filter modes (where_clause, condition_only, append_only)."""
filters = {"id": "test-id"}
# Default where_clause mode
sql1, _ = main_apply_filters(filters, [])
assert sql1.startswith("WHERE"), "Default mode should prepend WHERE"
# condition_only mode
sql2, _ = main_apply_filters(filters, [], mode="condition_only")
assert not sql2.startswith("WHERE"), "condition_only mode should not prepend WHERE"
# append_only mode
sql3, _ = main_apply_filters(filters, [], mode="append_only")
assert sql3.startswith("AND"), "append_only mode should prepend AND"
@@ -718,13 +718,13 @@ class TestEdgeCases:
"""Test with custom top_level_columns parameter."""
# Define a custom set of top-level columns
custom_columns = ["id", "custom_field"]
# Test a field that's in custom_columns
filters = {"custom_field": "value"}
sql, _ = main_apply_filters(filters, [], top_level_columns=custom_columns)
assert "custom_field =" in sql, "Should treat custom_field as a normal column"
assert "metadata" not in sql, "Should not treat custom_field as metadata"
# Test a field that's not in custom_columns
filters = {"other_field": "value"}
sql, _ = main_apply_filters(filters, [], top_level_columns=custom_columns)
@@ -734,7 +734,7 @@ class TestEdgeCases:
"""Test with custom json_column parameter."""
# Use a custom json column name
custom_json = "properties"
filters = {"field": "value"}
sql, _ = main_apply_filters(filters, [], top_level_columns=["id"], json_column=custom_json)
assert custom_json in sql, f"Should use {custom_json} instead of metadata"
@@ -759,23 +759,23 @@ class TestComplexFilterCombinations:
]
}
sql, params = main_apply_filters(filters, [])
# Check for AND operator
assert " AND " in sql, "SQL should contain AND operator"
# Check for metadata handling
assert "metadata->>'score'" in sql, "SQL should handle metadata field"
assert "::numeric >=" in sql, "SQL should handle numeric comparison"
# Check for OR operator
assert " OR " in sql, "SQL should contain OR operator"
# Check for collection_id handling
assert "collection_ids" in sql, "SQL should handle collection_id"
# Check for parent_id handling
assert "parent_id = ANY" in sql, "SQL should handle parent_id IN condition"
# Check parameters
assert len(params) == 4, "Should have 4 parameters"
assert "test-id" in params, "Parameters should include test-id"
@@ -787,7 +787,7 @@ class TestComplexFilterCombinations:
"""Test filtering with deeply nested JSON fields."""
filters = {"metadata.level1.level2.level3.deep": {MainFilterOperator.GT: 100}}
sql, params = main_apply_filters(filters, [])
expected_path = (
"metadata->'level1'->'level2'->'level3'->>'deep'"
)