Files
R2R/py/tests/unit/test_agent_citations.py
T
emrgnt-cmplxty 05f04e7109 add agent tests
2025-03-21 22:10:35 -07:00

1138 lines
44 KiB
Python

"""
Unit tests for citation extraction and propagation in the R2RStreamingAgent.
These tests focus specifically on citation-related functionality:
- Citation extraction from text
- Citation tracking during streaming
- Citation event emission
- Citation formatting and propagation
- Citation edge cases and validation
"""
import pytest
import asyncio
import json
import re
from unittest.mock import MagicMock, patch, AsyncMock
from typing import Dict, List, Tuple, Any, AsyncGenerator
import pytest_asyncio
from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
from core.utils import CitationTracker, extract_citations, extract_citation_spans
from core.agent.base import R2RStreamingAgent
# Import mock classes from conftest
from conftest import (
MockDatabaseProvider,
MockLLMProvider,
MockR2RStreamingAgent,
MockSearchResultsCollector,
collect_stream_output
)
class MockLLMProvider:
    """Mock LLM provider that answers with canned text plus citation markers.

    Args:
        response_content: Text of the canned answer. Falls back to
            ``"This is a response"`` when omitted or empty.
        citations: Citation IDs appended to the answer, each rendered as
            ``" [<id>]"``. Defaults to no citations.
    """

    def __init__(self, response_content=None, citations=None):
        self.response_content = response_content or "This is a response"
        self.citations = citations or []

    def _full_content(self):
        """Return the canned answer with every citation marker appended."""
        markers = "".join(f" [{citation}]" for citation in self.citations)
        return self.response_content + markers

    async def aget_completion(self, messages, generation_config):
        """Return a mocked non-streaming completion.

        ``messages`` and ``generation_config`` are accepted for interface
        compatibility and ignored. The result is a ``MagicMock`` shaped like
        ``LLMChatCompletion`` with a single choice whose message content is
        the canned answer and whose ``finish_reason`` is ``"stop"``.
        """
        mock_response = MagicMock(spec=LLMChatCompletion)
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message = MagicMock()
        mock_response.choices[0].message.content = self._full_content()
        mock_response.choices[0].finish_reason = "stop"
        return mock_response

    async def aget_completion_stream(self, messages, generation_config):
        """Yield the canned answer one character per chunk.

        ``messages`` and ``generation_config`` are ignored. Every content
        chunk has ``finish_reason`` set to ``None``; a final chunk with empty
        content and ``finish_reason == "stop"`` closes the stream.
        """
        # Simulate streaming by yielding one character at a time
        for char in self._full_content():
            chunk = MagicMock(spec=LLMChatCompletionChunk)
            chunk.choices = [MagicMock()]
            chunk.choices[0].delta = MagicMock()
            chunk.choices[0].delta.content = char
            chunk.choices[0].finish_reason = None
            yield chunk

        # Final chunk with finish_reason="stop"
        final_chunk = MagicMock(spec=LLMChatCompletionChunk)
        final_chunk.choices = [MagicMock()]
        final_chunk.choices[0].delta = MagicMock()
        final_chunk.choices[0].delta.content = ""
        final_chunk.choices[0].finish_reason = "stop"
        yield final_chunk
class MockPromptsHandler:
    """Stand-in prompts handler that always serves the same system prompt."""

    async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
        """Ignore every argument and hand back the fixed system prompt."""
        system_prompt = "You are a helpful assistant that provides well-sourced information."
        return system_prompt
class MockDatabaseProvider:
    """Database provider stub whose async methods return fixed values."""

    def __init__(self):
        # Expose a prompts_handler attribute so lookups of it do not raise
        # AttributeError.
        self.prompts_handler = MockPromptsHandler()

    async def acreate_conversation(self, *args, **kwargs):
        """Return a fixed conversation record, ignoring all arguments."""
        conversation = {"id": "conv_12345"}
        return conversation

    async def aupdate_conversation(self, *args, **kwargs):
        """Report success unconditionally, ignoring all arguments."""
        return True

    async def acreate_message(self, *args, **kwargs):
        """Return a fixed message record, ignoring all arguments."""
        message = {"id": "msg_12345"}
        return message
class MockSearchResultsCollector:
    """Search results collector stub backed by a plain dict.

    ``results`` maps short citation IDs to source-document dicts.
    """

    def __init__(self, results=None):
        self.results = results or {}

    def find_by_short_id(self, short_id):
        """Return the stored document for ``short_id``, or a synthetic one.

        Unknown IDs produce a placeholder document whose fields are derived
        from the ID itself.
        """
        placeholder = {
            "document_id": f"doc_{short_id}",
            "text": f"This is document text for {short_id}",
            "metadata": {"source": f"source_{short_id}"},
        }
        return self.results.get(short_id, placeholder)
# Create a concrete implementation of R2RStreamingAgent for testing
class MockR2RStreamingAgent(R2RStreamingAgent):
    """Concrete R2RStreamingAgent for tests, emitting a canned cited answer."""

    # Citation regexes (noted by the original author as copied from the real agent)
    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
    SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")

    def _register_tools(self):
        """Satisfy the abstract hook; this mock registers no tools."""
        pass

    async def _setup(self, system_instruction=None, *args, **kwargs):
        """Seed the conversation with a system message.

        Uses ``system_instruction`` when it is truthy, otherwise a fixed
        default, so no prompt has to be fetched from the database.
        """
        if system_instruction:
            system_content = system_instruction
        else:
            system_content = "You are a helpful assistant that provides well-sourced information."

        system_message = Message(role="system", content=system_content)
        await self.conversation.add_message(system_message)

    def _format_sse_event(self, event_type, data):
        """Render ``data`` as JSON inside a single SSE ``event``/``data`` frame."""
        encoded = json.dumps(data)
        return f"event: {event_type}\ndata: {encoded}\n\n"

    async def arun(
        self,
        system_instruction: str = None,
        messages: list[Message] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream a fixed answer containing two citations as SSE events.

        Emits, in order: one ``message`` event, one ``citation`` event per
        span the tracker reports as new, one ``agent.final_answer`` event
        with citations grouped by ID, and a closing ``done`` event. The
        assistant message, with its citation metadata, is added to the
        conversation before the final answer is emitted.
        """
        await self._setup(system_instruction)
        for incoming in messages or []:
            await self.conversation.add_message(incoming)

        tracker = CitationTracker()
        payload_by_id = {}
        # Citations seen while streaming, kept for persistence on the message
        self.streaming_citations = []

        answer_text = "This is a test response with citations [abc1234] [def5678]"
        yield self._format_sse_event("message", {"content": answer_text})

        # Spans are extracted from the full text in one pass rather than
        # character by character.
        for cid, spans in extract_citation_spans(answer_text).items():
            for span in spans:
                # Skip spans the tracker has already recorded
                if not tracker.is_new_span(cid, span):
                    continue

                source_doc = self.search_results_collector.find_by_short_id(cid)
                payload = {
                    "document_id": source_doc.get("document_id", f"doc_{cid}"),
                    "text": source_doc.get("text", f"This is document text for {cid}"),
                    "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                }
                payload_by_id[cid] = payload

                self.streaming_citations.append({
                    "id": cid,
                    "span": {"start": span[0], "end": span[1]},
                    "payload": payload,
                })

                yield self._format_sse_event("citation", {
                    "id": cid,
                    "object": "citation",
                    "span": {"start": span[0], "end": span[1]},
                    "payload": payload,
                })

        assistant_message = Message(
            role="assistant",
            content=answer_text,
            metadata={"citations": self.streaming_citations},
        )
        await self.conversation.add_message(assistant_message)

        # One entry per citation ID, carrying every span tracked for it
        consolidated = [
            {
                "id": cid,
                "object": "citation",
                "spans": [{"start": s[0], "end": s[1]} for s in tracked_spans],
                "payload": payload_by_id[cid],
            }
            for cid, tracked_spans in tracker.get_all_spans().items()
            if cid in payload_by_id
        ]

        yield self._format_sse_event("agent.final_answer", {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": answer_text,
            "citations": consolidated,
        })

        # Signal the end of the SSE stream
        yield "event: done\ndata: {}\n\n"
@pytest.fixture
def mock_streaming_agent():
    """Build a MockR2RStreamingAgent wired to mock providers and known sources."""
    agent_config = MagicMock(stream=True, max_iterations=3)

    agent = MockR2RStreamingAgent(
        database_provider=MockDatabaseProvider(),
        llm_provider=MockLLMProvider(
            response_content="This is a test response with citations",
            citations=["abc1234", "def5678"],
        ),
        config=agent_config,
        rag_generation_config=GenerationConfig(model="test/model"),
    )

    # Swap in a collector that knows both cited documents
    known_sources = {
        "abc1234": {
            "document_id": "doc_abc1234",
            "text": "This is document text for abc1234",
            "metadata": {"source": "source_abc1234"},
        },
        "def5678": {
            "document_id": "doc_def5678",
            "text": "This is document text for def5678",
            "metadata": {"source": "source_def5678"},
        },
    }
    agent.search_results_collector = MockSearchResultsCollector(known_sources)
    return agent
async def collect_stream_output(stream):
    """Drain an async iterable and return its items as a list, in order."""
    return [event async for event in stream]
def test_extract_citations_from_response():
    """A bracketed short ID in response text is returned by extract_citations."""
    # Exercise the utility directly, without an agent
    extracted = extract_citations("This is a response with a citation [abc1234].")
    assert "abc1234" in extracted, "Citation should be extracted from response"
@pytest.mark.asyncio
async def test_streaming_agent_citation_extraction(mock_streaming_agent):
    """The stream contains a citation event for each ID in the canned answer."""
    user_messages = [Message(role="user", content="Test query")]
    output = await collect_stream_output(
        mock_streaming_agent.arun(messages=user_messages)
    )

    # Keep only the SSE frames that carry citation events
    citation_events = [frame for frame in output if 'event: citation' in frame]
    assert len(citation_events) > 0, "Citation events should be emitted"

    # Both known IDs must show up somewhere in those frames
    for cid in ('abc1234', 'def5678'):
        found = any(cid in frame for frame in citation_events)
        assert found, f"Citation {cid} should be found in stream output"
@pytest.mark.asyncio
async def test_citation_tracker_during_streaming(mock_streaming_agent):
    """The agent consults CitationTracker.is_new_span while streaming."""
    # autospec=True keeps the real method signature on the replacement
    with patch('core.utils.CitationTracker.is_new_span', autospec=True) as mock_is_new_span:
        # Report every span as new so each citation is processed
        mock_is_new_span.return_value = True

        user_messages = [Message(role="user", content="Test query")]
        await collect_stream_output(
            mock_streaming_agent.arun(messages=user_messages)
        )

        assert mock_is_new_span.call_count > 0, "is_new_span should be called to track citation spans"
@pytest.mark.asyncio
async def test_final_answer_includes_consolidated_citations(mock_streaming_agent):
    """Test that the final answer includes consolidated citations.

    Every ``agent.final_answer`` event must carry a JSON payload with a
    non-empty ``citations`` list containing a known citation ID. A payload
    that is missing or is not valid JSON fails the test instead of being
    skipped.
    """
    messages = [Message(role="user", content="Test query")]

    # Run the agent
    stream = mock_streaming_agent.arun(messages=messages)
    output = await collect_stream_output(stream)

    # Look for final answer event in the output
    final_answer_events = [
        line for line in output
        if 'event: agent.final_answer' in line
    ]
    assert len(final_answer_events) > 0, "Final answer event should be emitted"

    for event in final_answer_events:
        assert 'data: ' in event, "Final answer event should have data payload"
        # maxsplit=1 keeps the whole payload even if it contains 'data: '
        data_part = event.split('data: ', 1)[1]
        try:
            data = json.loads(data_part)
        except json.JSONDecodeError:
            pytest.fail(f"Final answer event data is not valid JSON: {data_part}")

        assert 'citations' in data, "Final answer event should have 'citations'"
        assert len(data['citations']) > 0, "Final answer should include citations"
        citation_ids = [citation.get('id') for citation in data['citations']]
        assert 'abc1234' in citation_ids or 'def5678' in citation_ids, "Known citation IDs should be included"
@pytest.mark.asyncio
async def test_conversation_message_includes_citation_metadata(mock_streaming_agent):
    """An assistant message carrying citation metadata is added to the conversation."""
    with patch.object(
        mock_streaming_agent.conversation,
        'add_message',
        wraps=mock_streaming_agent.conversation.add_message,
    ) as mock_add_message:
        user_messages = [Message(role="user", content="Test query")]
        await collect_stream_output(
            mock_streaming_agent.arun(messages=user_messages)
        )

        # Count add_message calls whose first positional argument is an
        # assistant Message with a 'citations' entry in its metadata
        citation_calls = 0
        for recorded_call in mock_add_message.call_args_list:
            positional, _keywords = recorded_call
            if not (positional and isinstance(positional[0], Message)):
                continue
            candidate = positional[0]
            if candidate.role == 'assistant' and candidate.metadata and 'citations' in candidate.metadata:
                citation_calls += 1

        assert citation_calls > 0, "At least one assistant message should include citation metadata"
@pytest.mark.asyncio
async def test_multiple_citations_for_same_source(mock_streaming_agent):
    """Test handling of multiple citations for the same source document.

    The agent's ``arun`` is replaced by a generator whose answer cites
    ``abc1234`` twice. The test expects one citation event per occurrence
    and a single consolidated entry with at least two spans in the final
    answer. A final answer that is missing, unparsable, or lacks the
    ``abc1234`` entry fails the test.
    """
    # Tracker used directly by custom_arun through the closure, so no
    # patching of the CitationTracker class is needed.
    citation_tracker = CitationTracker()
    custom_agent = mock_streaming_agent

    async def custom_arun(*args, **kwargs):
        """Custom arun that includes repeated citations for the same source."""
        # Setup like the original
        await custom_agent._setup(kwargs.get('system_instruction'))

        messages = kwargs.get('messages', [])
        if messages:
            for m in messages:
                await custom_agent.conversation.add_message(m)

        # Initialize payloads dict for tracking
        citation_payloads = {}

        # Track streaming citations for final persistence
        custom_agent.streaming_citations = []

        # Create text with multiple citations to the same source
        response_content = "This text has multiple citations to the same source: [abc1234] and again here [abc1234]."

        # Yield the message event
        yield custom_agent._format_sse_event("message", {"content": response_content})

        # Extract all spans from the full text in one pass
        citation_spans = extract_citation_spans(response_content)

        # Process the citations
        for cid, spans in citation_spans.items():
            for span in spans:
                # Record the span in the tracker; the result is not needed
                # because every occurrence is emitted here
                citation_tracker.is_new_span(cid, span)

                # Look up the source document for this citation
                source_doc = custom_agent.search_results_collector.find_by_short_id(cid)

                # Create citation payload
                citation_payload = {
                    "document_id": source_doc.get("document_id", f"doc_{cid}"),
                    "text": source_doc.get("text", f"This is document text for {cid}"),
                    "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                }

                # Store the payload
                citation_payloads[cid] = citation_payload

                # Track for persistence
                custom_agent.streaming_citations.append({
                    "id": cid,
                    "span": {"start": span[0], "end": span[1]},
                    "payload": citation_payload
                })

                # Emit citation event
                citation_event = {
                    "id": cid,
                    "object": "citation",
                    "span": {"start": span[0], "end": span[1]},
                    "payload": citation_payload
                }
                yield custom_agent._format_sse_event("citation", citation_event)

        # Add assistant message with citation metadata to conversation
        await custom_agent.conversation.add_message(
            Message(
                role="assistant",
                content=response_content,
                metadata={"citations": custom_agent.streaming_citations}
            )
        )

        # Group citations by ID with all their spans
        consolidated_citations = []
        for cid, spans in citation_tracker.get_all_spans().items():
            if cid in citation_payloads:
                consolidated_citations.append({
                    "id": cid,
                    "object": "citation",
                    "spans": [{"start": s[0], "end": s[1]} for s in spans],
                    "payload": citation_payloads[cid]
                })

        # Create and emit final answer event
        final_evt_payload = {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": response_content,
            "citations": consolidated_citations
        }
        yield custom_agent._format_sse_event("agent.final_answer", final_evt_payload)

        # Signal the end of the SSE stream
        yield "event: done\ndata: {}\n\n"

    # Apply the custom arun method
    with patch.object(custom_agent, 'arun', custom_arun):
        messages = [Message(role="user", content="Test query")]

        # Run the agent with repeated citations
        stream = custom_agent.arun(messages=messages)
        output = await collect_stream_output(stream)

    # Count citation events for abc1234
    citation_abc_events = [
        line for line in output
        if 'event: citation' in line and 'abc1234' in line
    ]

    # One event per occurrence of [abc1234] in the response text
    assert len(citation_abc_events) >= 2, "Should emit multiple citation events for the same source"

    # Check the final answer to ensure spans were consolidated
    final_answer_events = [
        line for line in output
        if 'event: agent.final_answer' in line
    ]
    assert final_answer_events, "Final answer event should be emitted"

    for event in final_answer_events:
        assert 'data: ' in event, "Final answer event should have data payload"
        data_part = event.split('data: ', 1)[1]
        try:
            data = json.loads(data_part)
        except json.JSONDecodeError:
            pytest.fail(f"Final answer event data is not valid JSON: {data_part}")

        assert 'citations' in data, "Final answer event should have 'citations'"

        # Find the citation for abc1234
        abc_citation = next(
            (citation for citation in data['citations'] if citation.get('id') == 'abc1234'),
            None,
        )
        assert abc_citation is not None, "abc1234 should be in final answer citations"

        # It should have multiple spans
        assert abc_citation.get('spans') and len(abc_citation['spans']) >= 2, "Citation should have multiple spans consolidated"
@pytest.mark.asyncio
async def test_citation_consolidation_logic(mock_streaming_agent):
    """Test that citation consolidation properly groups spans by citation ID.

    A tracker pre-populated with spans is injected into the agent. The final
    answer must then group those spans under their citation IDs and must
    leave out IDs for which the agent never built a payload.
    """
    citation_tracker = CitationTracker()

    # Add spans for multiple citations
    citation_tracker.is_new_span("abc1234", (10, 20))
    citation_tracker.is_new_span("abc1234", (30, 40))
    citation_tracker.is_new_span("def5678", (50, 60))
    citation_tracker.is_new_span("ghi9012", (70, 80))
    citation_tracker.is_new_span("ghi9012", (90, 100))

    # The agent class is defined in this module and resolves CitationTracker
    # from this module's globals, so the name must be patched here; patching
    # core.utils.CitationTracker would not affect it.
    with patch(f"{__name__}.CitationTracker", return_value=citation_tracker):
        messages = [Message(role="user", content="Test query")]

        # Run the agent
        stream = mock_streaming_agent.arun(messages=messages)
        output = await collect_stream_output(stream)

    # Look for the final answer event
    final_answer_events = [
        line for line in output
        if 'event: agent.final_answer' in line
    ]
    assert final_answer_events, "Final answer event should be emitted"

    # Verify consolidation in final answer
    for event in final_answer_events:
        assert 'data: ' in event, "Final answer event should have data payload"
        data_part = event.split('data: ', 1)[1]
        try:
            data = json.loads(data_part)
        except json.JSONDecodeError:
            pytest.fail(f"Final answer event data is not valid JSON: {data_part}")

        assert 'citations' in data, "Final answer event should have 'citations'"

        # There should be at least 2 citations (from our mock agent implementation)
        assert len(data['citations']) >= 2, "Should include multiple citation objects"

        spans_by_id = {
            citation.get('id'): citation.get('spans', [])
            for citation in data['citations']
        }

        # Pre-populated spans for abc1234 must appear alongside the streamed one
        assert 'abc1234' in spans_by_id, "abc1234 should be in final answer citations"
        abc_spans = spans_by_id['abc1234']
        assert len(abc_spans) >= 1, "Citation abc1234 should have spans"
        assert {"start": 10, "end": 20} in abc_spans, "Pre-populated span should be consolidated"
        assert {"start": 30, "end": 40} in abc_spans, "Pre-populated span should be consolidated"

        # ghi9012 is tracked but never cited in the response, so the agent
        # builds no payload for it and it must not be reported
        assert 'ghi9012' not in spans_by_id, "Citations without a payload should be omitted"
@pytest.mark.asyncio
async def test_citation_event_format(mock_streaming_agent):
    """Each citation event carries id, object, span and payload fields."""
    user_messages = [Message(role="user", content="Test query")]
    output = await collect_stream_output(
        mock_streaming_agent.arun(messages=user_messages)
    )

    citation_events = [frame for frame in output if 'event: citation' in frame]
    assert len(citation_events) > 0, "Citation events should be emitted"

    for event in citation_events:
        # Frame must name the event type and carry a data line
        assert 'event: citation' in event, "Event type should be 'citation'"
        assert 'data: ' in event, "Event should have data payload"

        data_part = event.split('data: ')[1] if 'data: ' in event else event
        try:
            data = json.loads(data_part)
        except json.JSONDecodeError:
            pytest.fail(f"Citation event data is not valid JSON: {data_part}")

        # Top-level fields
        assert 'id' in data, "Citation event should have an 'id'"
        assert 'object' in data and data['object'] == 'citation', "Event object should be 'citation'"
        assert 'span' in data, "Citation event should have a 'span'"
        assert 'start' in data['span'] and 'end' in data['span'], "Span should have 'start' and 'end'"
        assert 'payload' in data, "Citation event should have a 'payload'"

        # Payload fields
        payload = data['payload']
        assert 'document_id' in payload, "Payload should have 'document_id'"
        assert 'text' in payload, "Payload should have 'text'"
        assert 'metadata' in payload, "Payload should have 'metadata'"
@pytest.mark.asyncio
async def test_final_answer_event_format(mock_streaming_agent):
    """The final answer event and each citation inside it have the expected fields."""
    user_messages = [Message(role="user", content="Test query")]
    output = await collect_stream_output(
        mock_streaming_agent.arun(messages=user_messages)
    )

    final_answer_events = [
        frame for frame in output if 'event: agent.final_answer' in frame
    ]
    assert len(final_answer_events) > 0, "Final answer event should be emitted"

    for event in final_answer_events:
        assert 'event: agent.final_answer' in event, "Event type should be 'agent.final_answer'"
        assert 'data: ' in event, "Event should have data payload"

        data_part = event.split('data: ')[1] if 'data: ' in event else event
        try:
            data = json.loads(data_part)
        except json.JSONDecodeError:
            pytest.fail(f"Final answer event data is not valid JSON: {data_part}")

        # Top-level fields
        assert 'id' in data, "Final answer event should have an 'id'"
        assert 'object' in data and data['object'] == 'agent.final_answer', "Event object should be 'agent.final_answer'"
        assert 'generated_answer' in data, "Final answer event should have a 'generated_answer'"
        assert 'citations' in data, "Final answer event should have 'citations'"

        for citation in data['citations']:
            # Citation envelope
            assert 'id' in citation, "Citation should have an 'id'"
            assert 'object' in citation and citation['object'] == 'citation', "Citation object should be 'citation'"
            assert 'spans' in citation, "Citation should have 'spans'"
            assert 'payload' in citation, "Citation should have a 'payload'"

            # Every span needs both endpoints
            for span in citation['spans']:
                assert 'start' in span, "Span should have 'start'"
                assert 'end' in span, "Span should have 'end'"

            # Payload fields
            payload = citation['payload']
            assert 'document_id' in payload, "Payload should have 'document_id'"
            assert 'text' in payload, "Payload should have 'text'"
            assert 'metadata' in payload, "Payload should have 'metadata'"
@pytest.mark.asyncio
async def test_overlapping_citation_handling():
    """Test that overlapping citations are handled correctly.

    The agent's ``arun`` is replaced by a generator that emits two citations
    with hand-written spans. Both IDs must appear in the citation events and
    in the final answer. A final answer that is missing or unparsable fails
    the test.
    """
    # Create a custom agent configuration
    config = MagicMock()
    config.stream = True
    config.max_iterations = 3

    # Create providers
    llm_provider = MockLLMProvider(
        response_content="This is a test response with overlapping citations",
        citations=["abc1234", "def5678"]
    )
    db_provider = MockDatabaseProvider()

    # Create agent
    agent = MockR2RStreamingAgent(
        database_provider=db_provider,
        llm_provider=llm_provider,
        config=config,
        rag_generation_config=GenerationConfig(model="test/model")
    )

    # Replace the search results collector with our mock
    agent.search_results_collector = MockSearchResultsCollector({
        "abc1234": {
            "document_id": "doc_abc1234",
            "text": "This is document text for abc1234",
            "metadata": {"source": "source_abc1234"}
        },
        "def5678": {
            "document_id": "doc_def5678",
            "text": "This is document text for def5678",
            "metadata": {"source": "source_def5678"}
        }
    })

    async def custom_arun(*args, **kwargs):
        """Custom arun that emits citations with hand-written spans."""
        # Setup like the original
        await agent._setup(kwargs.get('system_instruction'))

        messages = kwargs.get('messages', [])
        if messages:
            for m in messages:
                await agent.conversation.add_message(m)

        # Initialize citation tracker
        citation_tracker = CitationTracker()
        citation_payloads = {}

        # Track streaming citations for final persistence
        agent.streaming_citations = []

        response_content = "This text has overlapping citations [abc1234] part of which [def5678] overlap."

        # Yield the message event
        yield agent._format_sse_event("message", {"content": response_content})

        # Spans are defined by hand rather than extracted with a regex.
        # NOTE(review): these spans do not overlap each other and do not line
        # up with the markers in response_content ("[abc1234]" starts at
        # index 36, "[def5678]" at index 60). The assertions below check
        # citation IDs only, so the values are kept unchanged; confirm the
        # intended spans if span positions are ever asserted.
        citation_spans = {
            "abc1234": [(30, 39)],
            "def5678": [(55, 64)]
        }

        # Process the citations
        for cid, spans in citation_spans.items():
            for span in spans:
                # Record the span in the tracker; every span is emitted
                citation_tracker.is_new_span(cid, span)

                # Look up the source document for this citation
                source_doc = agent.search_results_collector.find_by_short_id(cid)

                # Create citation payload
                citation_payload = {
                    "document_id": source_doc.get("document_id", f"doc_{cid}"),
                    "text": source_doc.get("text", f"This is document text for {cid}"),
                    "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                }

                # Store the payload by citation ID
                citation_payloads[cid] = citation_payload

                # Track for persistence
                agent.streaming_citations.append({
                    "id": cid,
                    "span": {"start": span[0], "end": span[1]},
                    "payload": citation_payload
                })

                # Emit citation event
                citation_event = {
                    "id": cid,
                    "object": "citation",
                    "span": {"start": span[0], "end": span[1]},
                    "payload": citation_payload
                }
                yield agent._format_sse_event("citation", citation_event)

        # Add assistant message with citation metadata to conversation
        await agent.conversation.add_message(
            Message(
                role="assistant",
                content=response_content,
                metadata={"citations": agent.streaming_citations}
            )
        )

        # Group citations by ID with all their spans
        consolidated_citations = []
        for cid, spans in citation_tracker.get_all_spans().items():
            if cid in citation_payloads:
                consolidated_citations.append({
                    "id": cid,
                    "object": "citation",
                    "spans": [{"start": s[0], "end": s[1]} for s in spans],
                    "payload": citation_payloads[cid]
                })

        # Create and emit final answer event
        final_evt_payload = {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": response_content,
            "citations": consolidated_citations
        }
        yield agent._format_sse_event("agent.final_answer", final_evt_payload)

        # Signal the end of the SSE stream
        yield "event: done\ndata: {}\n\n"

    # Replace the arun method
    with patch.object(agent, 'arun', custom_arun):
        messages = [Message(role="user", content="Test query")]

        # Run the agent with the hand-written citation spans
        stream = agent.arun(messages=messages)
        output = await collect_stream_output(stream)

    # Check that both citations were emitted
    citation_abc = any('abc1234' in event for event in output if 'event: citation' in event)
    citation_def = any('def5678' in event for event in output if 'event: citation' in event)
    assert citation_abc, "Citation abc1234 should be emitted"
    assert citation_def, "Citation def5678 should be emitted"

    # Check the final answer for both citations
    final_answer_events = [
        line for line in output
        if 'event: agent.final_answer' in line
    ]
    assert final_answer_events, "Final answer event should be emitted"

    for event in final_answer_events:
        assert 'data: ' in event, "Final answer event should have data payload"
        data_part = event.split('data: ', 1)[1]
        try:
            data = json.loads(data_part)
        except json.JSONDecodeError:
            pytest.fail(f"Final answer event data is not valid JSON: {data_part}")

        assert 'citations' in data, "Final answer event should have 'citations'"
        citation_ids = [citation.get('id') for citation in data['citations']]
        assert 'abc1234' in citation_ids, "abc1234 should be in final answer citations"
        assert 'def5678' in citation_ids, "def5678 should be in final answer citations"
@pytest.mark.asyncio
async def test_robustness_against_citation_variations(mock_streaming_agent):
    """extract_citations finds IDs across several surrounding-text variations."""
    # Sample text mixing plain, adjacent and punctuation-wrapped citations
    response_text = """
This text has different citation variations:
1. Standard citation: [abc1234]
2. Another citation: [def5678]
3. Adjacent citations: [abc1234][def5678]
4. Special characters around citation: ([abc1234]) or "[def5678]".
"""
    # Call the extraction helper directly to see what it detects
    citations = extract_citations(response_text)

    # Both distinct IDs must be present
    unique_citations = set(citations)
    assert len(unique_citations) >= 2, "Should extract at least two different citation IDs"
    assert "abc1234" in unique_citations, "Should extract abc1234"
    assert "def5678" in unique_citations, "Should extract def5678"

    # Tally how often each ID was extracted
    counts = {
        cid: sum(1 for found in citations if found == cid)
        for cid in unique_citations
    }

    # Each ID occurs several times in the text above
    assert counts.get("abc1234", 0) >= 2, "abc1234 should appear at least twice"
    assert counts.get("def5678", 0) >= 2, "def5678 should appear at least twice"
class TestCitationEdgeCases:
    """Parameterized edge-case scenarios for citation extraction."""

    @pytest.mark.parametrize("test_case", [
        # Test case 1: Empty text
        {"text": "", "expected_citations": []},
        # Test case 2: Text with no citations
        {"text": "This text has no citations.", "expected_citations": []},
        # Test case 3: Adjacent citations
        {"text": "Adjacent citations [abc1234][def5678]", "expected_citations": ["abc1234", "def5678"]},
        # Test case 4: Repeated citations
        {"text": "Repeated [abc1234] citation [abc1234]", "expected_citations": ["abc1234", "abc1234"]},
        # Test case 5: Citation at beginning
        {"text": "[abc1234] at beginning", "expected_citations": ["abc1234"]},
        # Test case 6: Citation at end
        {"text": "At end [abc1234]", "expected_citations": ["abc1234"]},
        # Test case 7: Mixed valid and invalid citations
        {"text": "Valid [abc1234] and invalid [ab123] citations", "expected_citations": ["abc1234"]},
        # Test case 8: Citations with punctuation
        {"text": "Citations with punctuation: ([abc1234]), [def5678]!", "expected_citations": ["abc1234", "def5678"]}
    ])
    def test_citation_extraction_cases(self, test_case):
        """extract_citations returns the expected IDs for each scenario."""
        expected = test_case["expected_citations"]
        actual = extract_citations(test_case["text"])

        # Same number of citations as expected
        assert len(actual) == len(expected), f"Expected {len(expected)} citations, got {len(actual)}"

        # Every expected ID is present; order is not checked
        for expected_citation in expected:
            assert expected_citation in actual, f"Expected citation {expected_citation} not found"
@pytest.mark.asyncio
async def test_citation_handling_with_empty_response():
    """Test how the agent handles responses with no citations.

    A subclass of the mock agent streams an answer without citation markers.
    The stream must contain no citation events, and the final answer must
    carry an empty ``citations`` list.
    """

    class EmptyResponseAgent(MockR2RStreamingAgent):
        """Mock agent whose answer contains no citation markers."""

        async def arun(
            self,
            system_instruction: str = None,
            messages: list[Message] = None,
            *args,
            **kwargs,
        ) -> AsyncGenerator[str, None]:
            """Custom arun with no citations in the response."""
            await self._setup(system_instruction)

            if messages:
                for m in messages:
                    await self.conversation.add_message(m)

            # Response with no citations
            response_content = "This is a response with no citations."

            # Yield an initial message event with the full text
            yield self._format_sse_event("message", {"content": response_content})

            # Extraction should find nothing in this text
            citation_spans = extract_citation_spans(response_content)
            assert len(citation_spans) == 0, "No citation spans should be found"

            # Add assistant message to conversation (with no citation metadata)
            await self.conversation.add_message(
                Message(
                    role="assistant",
                    content=response_content,
                    metadata={"citations": []}
                )
            )

            # Create and emit final answer event
            final_evt_payload = {
                "id": "msg_final",
                "object": "agent.final_answer",
                "generated_answer": response_content,
                "citations": []
            }
            yield self._format_sse_event("agent.final_answer", final_evt_payload)
            yield "event: done\ndata: {}\n\n"

    # Create the agent with empty citation response
    config = MagicMock()
    config.stream = True

    llm_provider = MockLLMProvider(
        response_content="This is a response with no citations.",
        citations=[]
    )
    db_provider = MockDatabaseProvider()

    agent = EmptyResponseAgent(
        database_provider=db_provider,
        llm_provider=llm_provider,
        config=config,
        rag_generation_config=GenerationConfig(model="test/model")
    )

    # Test a simple query
    messages = [Message(role="user", content="Query with no citations")]

    # Run the agent
    stream = agent.arun(messages=messages)
    output = await collect_stream_output(stream)

    # Verify no citation events were emitted
    citation_events = [line for line in output if 'event: citation' in line]
    assert len(citation_events) == 0, "No citation events should be emitted"

    # Parse the final answer event to check citations
    final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
    assert len(final_answer_events) > 0, "Final answer event should be emitted"

    final_event = final_answer_events[0]
    assert 'data: ' in final_event, "Final answer event should have data payload"
    data_part = final_event.split('data: ', 1)[1]

    try:
        data = json.loads(data_part)
    except json.JSONDecodeError:
        # pytest.fail still fails the test when assertions are disabled (-O)
        pytest.fail("Final answer event data should be valid JSON")

    assert 'citations' in data, "Final answer event should include citations field"
    assert len(data['citations']) == 0, "Citations list should be empty"
@pytest.mark.asyncio
async def test_citation_sanitization():
    """Test that only tightly-bracketed citation IDs are extracted.

    Per the note kept from the original test, ``extract_citations`` is
    expected to use a strict pattern (``[A-Za-z0-9]{7,8}`` between square
    brackets). IDs whose brackets directly abut surrounding text should be
    found, while IDs padded with whitespace inside the brackets should not.
    """
    # Valid IDs: no whitespace inside the brackets, text touching the brackets.
    text = "Citation with surrounding text[abc1234]and [def5678]with no spaces."
    citations = extract_citations(text)

    assert "abc1234" in citations, "Citation abc1234 should be extracted"
    assert "def5678" in citations, "Citation def5678 should be extracted"

    # Whitespace inside the brackets (trailing in the first ID, leading in the
    # second) should prevent a match.
    text_with_spaces = "Citation with [abc1234 ] and [ def5678] spaces."
    citations_with_spaces = extract_citations(text_with_spaces)

    # Check BOTH padded IDs explicitly. The previous form
    # (`len(...) == 0 or "abc1234" not in ...`) only ever verified the first
    # ID, so a regression that extracted "def5678" would have gone unnoticed.
    for citation_id in ("abc1234", "def5678"):
        assert citation_id not in citations_with_spaces, "Citations with spaces should not be extracted with current implementation"
@pytest.mark.asyncio
async def test_citation_tracking_state_persistence():
    """Verify that a CitationTracker accumulates spans across successive calls."""
    citation_tracker = CitationTracker()

    # Seed the tracker with one span for each of two citations.
    for cid, span in (("abc1234", (10, 18)), ("def5678", (30, 38))):
        citation_tracker.is_new_span(cid, span)

    # Both citations must be present, and the first one must hold its span.
    recorded = citation_tracker.get_all_spans()
    assert "abc1234" in recorded, "Citation abc1234 should be tracked"
    assert "def5678" in recorded, "Citation def5678 should be tracked"
    assert recorded["abc1234"] == [(10, 18)], "Span positions should match"

    # Register a second, distinct span for an already-known citation.
    later_span = (50, 58)
    citation_tracker.is_new_span("abc1234", later_span)

    # The tracker should now report both spans for that citation.
    recorded = citation_tracker.get_all_spans()
    assert len(recorded["abc1234"]) == 2, "Citation abc1234 should have 2 spans"
    assert later_span in recorded["abc1234"], "New span should be added"
def test_citation_span_uniqueness():
    """Verify that CitationTracker reports a (citation, span) pair as new only once."""
    citation_tracker = CitationTracker()
    first_span = (10, 18)

    # Register the span once; the return value of this first call is not checked.
    citation_tracker.is_new_span("abc1234", first_span)

    # Same citation, same span: must be reported as already seen.
    repeated = citation_tracker.is_new_span("abc1234", first_span)
    assert not repeated, "Duplicate span should not be considered new"

    # Same citation, different span: counts as new.
    other_span_same_citation = citation_tracker.is_new_span("abc1234", (20, 28))
    assert other_span_same_citation, "Different span should be considered new"

    # Different citation, same span: also counts as new.
    same_span_other_citation = citation_tracker.is_new_span("def5678", first_span)
    assert same_span_other_citation, "Same span for different citation should be considered new"
def test_citation_with_punctuation():
    """Check that citations adjacent to punctuation marks are still extracted."""
    sample = "Citations with punctuation: ([abc1234]), [def5678]!, and [ghi9012]."
    found = extract_citations(sample)

    # Each expected ID is paired with the failure message used for it.
    expectations = (
        ("abc1234", "Citation abc1234 should be extracted"),
        ("def5678", "Citation def5678 should be extracted"),
        ("ghi9012", "Citation ghi9012 should be extracted"),
    )
    for citation_id, failure_message in expectations:
        assert citation_id in found, failure_message
def test_citation_extraction_with_invalid_formats():
    """Check that malformed bracketed tokens are ignored and only the valid ID is returned."""
    sample = "Invalid citation formats: [123], [abcdef], [abc123456789], and valid [abc1234]."
    found = extract_citations(sample)

    # Exactly one ID should survive extraction, and it must be the valid one.
    assert len(found) == 1, "Only one valid citation should be extracted"
    assert "abc1234" in found, "Only valid citation abc1234 should be extracted"

    # Each token the test treats as invalid, paired with its failure message.
    rejected = (
        ("123", "Invalid citation [123] should not be extracted"),
        ("abcdef", "Invalid citation [abcdef] should not be extracted"),
        ("abc123456789", "Invalid citation [abc123456789] should not be extracted"),
    )
    for token, failure_message in rejected:
        assert token not in found, failure_message