diff --git a/py/tests/unit/conftest.py b/py/tests/unit/conftest.py index 027739fd4..ac74cd771 100644 --- a/py/tests/unit/conftest.py +++ b/py/tests/unit/conftest.py @@ -216,3 +216,280 @@ async def graphs_handler(db_provider): ) await handler.create_tables() return handler + +# Citation testing fixtures and utilities +import json +import re +from unittest.mock import MagicMock, AsyncMock +from typing import Tuple, Any, AsyncGenerator + +from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig +from core.utils import CitationTracker, SearchResultsCollector +from core.agent.base import R2RStreamingAgent + + +class MockLLMProvider: + """Mock LLM provider for testing.""" + + def __init__(self, response_content=None, citations=None): + self.response_content = response_content or "This is a response" + self.citations = citations or [] + + async def aget_completion(self, messages, generation_config): + """Mock synchronous completion.""" + content = self.response_content + for citation in self.citations: + content += f" [{citation}]" + + mock_response = MagicMock(spec=LLMChatCompletion) + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = content + mock_response.choices[0].finish_reason = "stop" + return mock_response + + async def aget_completion_stream(self, messages, generation_config): + """Mock streaming completion.""" + content = self.response_content + for citation in self.citations: + content += f" [{citation}]" + + # Simulate streaming by yielding one character at a time + for i in range(len(content)): + chunk = MagicMock(spec=LLMChatCompletionChunk) + chunk.choices = [MagicMock()] + chunk.choices[0].delta = MagicMock() + chunk.choices[0].delta.content = content[i] + chunk.choices[0].finish_reason = None + yield chunk + + # Final chunk with finish_reason="stop" + final_chunk = MagicMock(spec=LLMChatCompletionChunk) + final_chunk.choices = [MagicMock()] + final_chunk.choices[0].delta = MagicMock() + final_chunk.choices[0].delta.content = "" + final_chunk.choices[0].finish_reason = "stop" + yield final_chunk + + +class MockPromptsHandler: + """Mock prompts handler for testing.""" + + async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs): + """Return a mock system prompt.""" + return "You are a helpful assistant that provides well-sourced information." + + +class MockDatabaseProvider: + """Mock database provider for testing.""" + + def __init__(self): + # Add a prompts_handler attribute to prevent AttributeError + self.prompts_handler = MockPromptsHandler() + + async def acreate_conversation(self, *args, **kwargs): + return {"id": "conv_12345"} + + async def aupdate_conversation(self, *args, **kwargs): + return True + + async def acreate_message(self, *args, **kwargs): + return {"id": "msg_12345"} + + +class MockSearchResultsCollector: + """Mock search results collector for testing.""" + + def __init__(self, results=None): + self.results = results or {} + + def find_by_short_id(self, short_id): + return self.results.get(short_id, { + "document_id": f"doc_{short_id}", + "text": f"This is document text for {short_id}", + "metadata": {"source": f"source_{short_id}"} + }) + + +# Create a concrete implementation of R2RStreamingAgent for testing +class MockR2RStreamingAgent(R2RStreamingAgent): + """Mock streaming agent for testing that implements the abstract method.""" + + # Regex pattern for citations, copied from the actual agent + BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]") + SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}") + + def _register_tools(self): + """Implement the abstract method with a no-op version.""" + pass + + async def _setup(self, system_instruction=None, *args, **kwargs): + """Override _setup to simplify initialization and avoid external dependencies.""" + # Use a simple system message instead of fetching from database + system_content = system_instruction or "You are a helpful assistant that provides well-sourced information." + + # Add system message to conversation + await self.conversation.add_message( + Message(role="system", content=system_content) + ) + + def _format_sse_event(self, event_type, data): + """Format an SSE event manually.""" + return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" + + async def arun( + self, + system_instruction: str = None, + messages: list[Message] = None, + *args, + **kwargs, + ) -> AsyncGenerator[str, None]: + """ + Simplified version of arun that focuses on citation handling for testing. + """ + await self._setup(system_instruction) + + if messages: + for m in messages: + await self.conversation.add_message(m) + + # Initialize citation tracker + citation_tracker = CitationTracker() + citation_payloads = {} + + # Track streaming citations for final persistence + self.streaming_citations = [] + + # Get the LLM response with citations + response_content = "This is a test response with citations" + response_content += " [abc1234] [def5678]" + + # Yield an initial message event with the start of the text + yield self._format_sse_event("message", {"content": response_content}) + + # Manually extract and emit citation events + # This is a simpler approach than the character-by-character approach + citation_spans = extract_citation_spans(response_content) + + # Process the citations + for cid, spans in citation_spans.items(): + for span in spans: + # Check if the span is new and record it + if citation_tracker.is_new_span(cid, span): + + # Look up the source document for this citation + source_doc = self.search_results_collector.find_by_short_id(cid) + + # Create citation payload + citation_payload = { + "document_id": source_doc.get("document_id", f"doc_{cid}"), + "text": source_doc.get("text", f"This is document text for {cid}"), + "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}), + } + + # Store the payload by citation ID + citation_payloads[cid] = citation_payload + + # Track for persistence + self.streaming_citations.append({ + "id": cid, + "span": {"start": span[0], "end": span[1]}, + "payload": citation_payload + }) + + # Emit citation event in the expected format + citation_event = { + "id": cid, + "object": "citation", + "span": {"start": span[0], "end": span[1]}, + "payload": citation_payload + } + + yield self._format_sse_event("citation", citation_event) + + # Add assistant message with citation metadata to conversation + await self.conversation.add_message( + Message( + role="assistant", + content=response_content, + metadata={"citations": self.streaming_citations} + ) + ) + + # Prepare consolidated citations for final answer + consolidated_citations = [] + + # Group citations by ID with all their spans + for cid, spans in citation_tracker.get_all_spans().items(): + if cid in citation_payloads: + consolidated_citations.append({ + "id": cid, + "object": "citation", + "spans": [{"start": s[0], "end": s[1]} for s in spans], + "payload": citation_payloads[cid] + }) + + # Create and emit final answer event + final_evt_payload = { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": response_content, + "citations": consolidated_citations + } + + # Manually format the final answer event + yield self._format_sse_event("agent.final_answer", final_evt_payload) + + # Signal the end of the SSE stream + yield "event: done\ndata: {}\n\n" + + +@pytest.fixture +def mock_streaming_agent(): + """Create a streaming agent with mocked dependencies.""" + # Create mock config + config = MagicMock() + config.stream = True + config.max_iterations = 3 + + # Create mock providers + llm_provider = MockLLMProvider( + response_content="This is a test response with citations", + citations=["abc1234", "def5678"] + ) + db_provider = MockDatabaseProvider() + + # Create agent with mocked dependencies using our concrete implementation + agent = MockR2RStreamingAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Replace the search results collector with our mock + agent.search_results_collector = MockSearchResultsCollector({ + "abc1234": { + "document_id": "doc_abc1234", + "text": "This is document text for abc1234", + "metadata": {"source": "source_abc1234"} + }, + "def5678": { + "document_id": "doc_def5678", + "text": "This is document text for def5678", + "metadata": {"source": "source_def5678"} + } + }) + + return agent + + +async def collect_stream_output(stream): + """Collect all output from a stream into a list.""" + output = [] + async for event in stream: + output.append(event) + return output + + +from core.utils import extract_citation_spans, find_new_citation_spans diff --git a/py/tests/unit/test_agent.py b/py/tests/unit/test_agent.py new file mode 100644 index 000000000..0aeb8b546 --- /dev/null +++ b/py/tests/unit/test_agent.py @@ -0,0 +1,312 @@ +""" +Unit tests for the core R2RStreamingAgent functionality. + +These tests focus on the core functionality of the agent, separate from +citation-specific behavior which is tested in test_agent_citations.py. +""" + +import pytest +import asyncio +import json +import re +from unittest.mock import MagicMock, patch, AsyncMock +from typing import Dict, List, Tuple, Any, AsyncGenerator + +import pytest_asyncio + +from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig +from core.utils import CitationTracker, SearchResultsCollector, SSEFormatter +from core.agent.base import R2RStreamingAgent + +# Import mock classes from conftest +from conftest import ( + MockDatabaseProvider, + MockLLMProvider, + MockR2RStreamingAgent, + MockSearchResultsCollector, + collect_stream_output +) + + +@pytest.mark.asyncio +async def test_streaming_agent_functionality(): + """Test basic functionality of the streaming agent.""" + # Create mock config + config = MagicMock() + config.stream = True + + # Create mock providers + llm_provider = MockLLMProvider( + response_content="This is a test response", + citations=[] + ) + db_provider = MockDatabaseProvider() + + # Create mock search results collector + search_results_collector = MockSearchResultsCollector({}) + + # Create agent + agent = MockR2RStreamingAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Set the search results collector + agent.search_results_collector = search_results_collector + + # Test a simple query + messages = [Message(role="user", content="Test query")] + + # Run the agent + stream = agent.arun(messages=messages) + output = await collect_stream_output(stream) + + # Verify response + message_events = [line for line in output if 'event: message' in line] + assert len(message_events) > 0, "Message event should be emitted" + + # Verify final answer + final_answer_events = [line for line in output if 'event: agent.final_answer' in line] + assert len(final_answer_events) > 0, "Final answer event should be emitted" + + # Verify done event + done_events = [line for line in output if 'event: done' in line] + assert len(done_events) > 0, "Done event should be emitted" + + +@pytest.mark.asyncio +async def test_agent_handles_multiple_messages(): + """Test agent handles conversation with multiple messages.""" + # Create mock config + config = MagicMock() + config.stream = True + + # Create mock providers + llm_provider = MockLLMProvider( + response_content="This is a response to multiple messages", + citations=[] + ) + db_provider = MockDatabaseProvider() + + # Create mock search results collector + search_results = { + "abc1234": { + "document_id": "doc_abc1234", + "text": "This is document text for abc1234", + "metadata": {"source": "source_abc1234"} + }, + "def5678": { + "document_id": "doc_def5678", + "text": "This is document text for def5678", + "metadata": {"source": "source_def5678"} + } + } + search_results_collector = MockSearchResultsCollector(search_results) + + # Create agent + agent = MockR2RStreamingAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Set the search results collector + agent.search_results_collector = search_results_collector + + # Test with multiple messages + messages = [ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="First question"), + Message(role="assistant", content="First answer"), + Message(role="user", content="Follow-up question") + ] + + # Run the agent + stream = agent.arun(messages=messages) + output = await collect_stream_output(stream) + + # Verify response + message_events = [line for line in output if 'event: message' in line] + assert len(message_events) > 0, "Message event should be emitted" + + # After running, check that conversation has the new assistant response + # Note: MockR2RStreamingAgent._setup adds a default system message + # and then our messages are added, plus the agent's response + assert len(agent.conversation.messages) == 6, "Conversation should have correct number of messages" + + # The last message should be the assistant's response + assert agent.conversation.messages[-1].role == "assistant", "Last message should be from assistant" + + # We should have two system messages (default + our custom one) + system_messages = [m for m in agent.conversation.messages if m.role == "system"] + assert len(system_messages) == 2, "Should have two system messages" + + +@pytest.mark.asyncio +async def test_agent_event_format(): + """Test the format of events emitted by the agent.""" + # Create mock config + config = MagicMock() + config.stream = True + + # Create mock providers + llm_provider = MockLLMProvider( + response_content="This is a test of event formatting", + citations=[] + ) + db_provider = MockDatabaseProvider() + + # Create mock search results collector + search_results_collector = MockSearchResultsCollector({}) + + # Create agent + agent = MockR2RStreamingAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Set the search results collector + agent.search_results_collector = search_results_collector + + # Test a simple query + messages = [Message(role="user", content="Test query")] + + # Run the agent + stream = agent.arun(messages=messages) + output = await collect_stream_output(stream) + + # Check message event format + message_events = [line for line in output if 'event: message' in line] + assert len(message_events) > 0, "Message event should be emitted" + + data_part = message_events[0].split('data: ')[1] if 'data: ' in message_events[0] else "" + try: + data = json.loads(data_part) + assert "content" in data, "Message event should include content" + except json.JSONDecodeError: + assert False, "Message event data should be valid JSON" + + # Check final answer event format + final_answer_events = [line for line in output if 'event: agent.final_answer' in line] + assert len(final_answer_events) > 0, "Final answer event should be emitted" + + data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else "" + try: + data = json.loads(data_part) + assert "id" in data, "Final answer event should include ID" + assert "object" in data, "Final answer event should include object type" + assert "generated_answer" in data, "Final answer event should include generated answer" + except json.JSONDecodeError: + assert False, "Final answer event data should be valid JSON" + + +@pytest.mark.asyncio +async def test_final_answer_event_format(): + """Test that the final answer event has the expected format and content.""" + # Create mock config + config = MagicMock() + config.stream = True + + # Create mock providers + llm_provider = MockLLMProvider( + response_content="This is a test final answer", + citations=[] + ) + db_provider = MockDatabaseProvider() + + # Create mock search results collector + search_results_collector = MockSearchResultsCollector({}) + + # Create agent + agent = MockR2RStreamingAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Set the search results collector + agent.search_results_collector = search_results_collector + + # Test a simple query + messages = [Message(role="user", content="Test query")] + + # Run the agent + stream = agent.arun(messages=messages) + output = await collect_stream_output(stream) + + # Extract and verify final answer event + final_answer_events = [line for line in output if 'event: agent.final_answer' in line] + assert len(final_answer_events) > 0, "Final answer event should be emitted" + + data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else "" + try: + data = json.loads(data_part) + assert data["id"] == "msg_final", "Final answer ID should be msg_final" + assert data["object"] == "agent.final_answer", "Final answer object should be agent.final_answer" + assert "generated_answer" in data, "Final answer should include generated_answer" + assert "citations" in data, "Final answer should include citations field" + except json.JSONDecodeError: + assert False, "Final answer event data should be valid JSON" + + +@pytest.mark.asyncio +async def test_conversation_message_format(): + """Test that the conversation includes properly formatted assistant messages.""" + # Create mock config + config = MagicMock() + config.stream = True + + # Create mock providers + llm_provider = MockLLMProvider( + response_content="This is a test message", + citations=[] + ) + db_provider = MockDatabaseProvider() + + # Create mock search results collector + search_results = { + "abc1234": { + "document_id": "doc_abc1234", + "text": "This is document text for abc1234", + "metadata": {"source": "source_abc1234"} + }, + "def5678": { + "document_id": "doc_def5678", + "text": "This is document text for def5678", + "metadata": {"source": "source_def5678"} + } + } + search_results_collector = MockSearchResultsCollector(search_results) + + # Create agent + agent = MockR2RStreamingAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Set the search results collector + agent.search_results_collector = search_results_collector + + # Test a simple query + messages = [Message(role="user", content="Test query")] + + # Run the agent + stream = agent.arun(messages=messages) + await collect_stream_output(stream) + + # Get the last message from the conversation + last_message = agent.conversation.messages[-1] + + # Verify message format - note that MockR2RStreamingAgent uses a hardcoded response + assert last_message.role == "assistant", "Last message should be from assistant" + assert "This is a test response with citations" in last_message.content, "Message content should include response" + assert "metadata" in last_message.dict(), "Message should include metadata" + assert "citations" in last_message.metadata, "Message metadata should include citations" diff --git a/py/tests/unit/test_agent_citations.py b/py/tests/unit/test_agent_citations.py index e7c6f58ef..6e905af35 100644 --- a/py/tests/unit/test_agent_citations.py +++ b/py/tests/unit/test_agent_citations.py @@ -1,3 +1,14 @@ +""" +Unit tests for citation extraction and propagation in the R2RStreamingAgent. + +These tests focus specifically on citation-related functionality: +- Citation extraction from text +- Citation tracking during streaming +- Citation event emission +- Citation formatting and propagation +- Citation edge cases and validation +""" + import pytest import asyncio import json @@ -5,10 +16,21 @@ import re from unittest.mock import MagicMock, patch, AsyncMock from typing import Dict, List, Tuple, Any, AsyncGenerator +import pytest_asyncio + from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig -from core.utils import CitationTracker, SearchResultsCollector, SSEFormatter, find_new_citation_spans, extract_citation_spans +from core.utils import CitationTracker, extract_citations, extract_citation_spans from core.agent.base import R2RStreamingAgent +# Import mock classes from conftest +from conftest import ( + MockDatabaseProvider, + MockLLMProvider, + MockR2RStreamingAgent, + MockSearchResultsCollector, + collect_stream_output +) + class MockLLMProvider: """Mock LLM provider for testing.""" @@ -179,7 +201,7 @@ class MockR2RStreamingAgent(R2RStreamingAgent): "payload": citation_payload }) - # Emit citation event in the expected format + # Emit citation event citation_event = { "id": cid, "object": "citation", @@ -279,7 +301,6 @@ def test_extract_citations_from_response(): response_text = "This is a response with a citation [abc1234]." # Use the utility function directly - from core.utils import extract_citations citations = extract_citations(response_text) assert "abc1234" in citations, "Citation should be extracted from response" @@ -417,91 +438,45 @@ async def test_multiple_citations_for_same_source(mock_streaming_agent): # Yield the message event yield custom_agent._format_sse_event("message", {"content": response_content}) - # Define citation spans explicitly - first_span = (45, 54) # Span for first [abc1234] - second_span = (70, 79) # Span for second [abc1234] + # Manually extract and emit citation events + # This is a simpler approach than the character-by-character approach + citation_spans = extract_citation_spans(response_content) - # Process the first citation instance - citation_tracker.is_new_span("abc1234", first_span) + # Process the citations + for cid, spans in citation_spans.items(): + for span in spans: + # Mark as processed in the tracker + citation_tracker.is_new_span(cid, span) - # Look up the source document - source_doc = custom_agent.search_results_collector.find_by_short_id("abc1234") + # Look up the source document for this citation + source_doc = custom_agent.search_results_collector.find_by_short_id(cid) - # Create citation payload - citation_payload = { - "document_id": source_doc.get("document_id", "doc_abc1234"), - "text": source_doc.get("text", "This is document text for abc1234"), - "metadata": source_doc.get("metadata", {"source": "source_abc1234"}), - } + # Create citation payload + citation_payload = { + "document_id": source_doc.get("document_id", f"doc_{cid}"), + "text": source_doc.get("text", f"This is document text for {cid}"), + "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}), + } - # Store the payload - citation_payloads["abc1234"] = citation_payload + # Store the payload + citation_payloads[cid] = citation_payload - # Track for first citation - first_citation = { - "id": "abc1234", - "span": {"start": first_span[0], "end": first_span[1]}, - "payload": citation_payload - } - custom_agent.streaming_citations.append(first_citation) + # Track for persistence + custom_agent.streaming_citations.append({ + "id": cid, + "span": {"start": span[0], "end": span[1]}, + "payload": citation_payload + }) - # Emit first citation event - first_citation_event = { - "id": "abc1234", - "object": "citation", - "span": {"start": first_span[0], "end": first_span[1]}, - "payload": citation_payload - } + # Emit citation event + citation_event = { + "id": cid, + "object": "citation", + "span": {"start": span[0], "end": span[1]}, + "payload": citation_payload + } - yield custom_agent._format_sse_event("citation", first_citation_event) - - # Process the second citation instance with the same ID but different span - citation_tracker.is_new_span("abc1234", second_span) - - # Track for second citation - second_citation = { - "id": "abc1234", - "span": {"start": second_span[0], "end": second_span[1]}, - "payload": citation_payload - } - custom_agent.streaming_citations.append(second_citation) - - # Emit second citation event - second_citation_event = { - "id": "abc1234", - "object": "citation", - "span": {"start": second_span[0], "end": second_span[1]}, - "payload": citation_payload - } - - yield custom_agent._format_sse_event("citation", second_citation_event) - - # Also add a different citation ID for completeness - citation_tracker.is_new_span("def5678", (90, 99)) - - source_doc_def = custom_agent.search_results_collector.find_by_short_id("def5678") - - citation_payload_def = { - "document_id": source_doc_def.get("document_id", "doc_def5678"), - "text": source_doc_def.get("text", "This is document text for def5678"), - "metadata": source_doc_def.get("metadata", {"source": "source_def5678"}), - } - - citation_payloads["def5678"] = citation_payload_def - - third_citation = { - "id": "def5678", - "span": {"start": 90, "end": 99}, - "payload": citation_payload_def - } - custom_agent.streaming_citations.append(third_citation) - - yield custom_agent._format_sse_event("citation", { - "id": "def5678", - "object": "citation", - "span": {"start": 90, "end": 99}, - "payload": citation_payload_def - }) + yield custom_agent._format_sse_event("citation", citation_event) # Add assistant message with citation metadata to conversation await custom_agent.conversation.add_message( @@ -533,7 +508,6 @@ async def test_multiple_citations_for_same_source(mock_streaming_agent): "citations": consolidated_citations } - # Emit final answer event yield custom_agent._format_sse_event("agent.final_answer", final_evt_payload) # Signal the end of the SSE stream @@ -543,7 +517,7 @@ async def test_multiple_citations_for_same_source(mock_streaming_agent): with patch.object(custom_agent, 'arun', custom_arun): messages = [Message(role="user", content="Test query")] - # Run the agent with the modified arun + # Run the agent with overlapping citations stream = custom_agent.arun(messages=messages) output = await collect_stream_output(stream) @@ -650,7 +624,7 @@ async def test_citation_event_format(mock_streaming_agent): assert 'data: ' in event, "Event should have data payload" # Parse the data payload - data_part = event.split('data: ')[1] if 'data: ' in event else "{}" + data_part = event.split('data: ')[1] if 'data: ' in event else event try: data = json.loads(data_part) @@ -693,7 +667,7 @@ async def test_final_answer_event_format(mock_streaming_agent): assert 'data: ' in event, "Event should have data payload" # Parse the data payload - data_part = event.split('data: ')[1] if 'data: ' in event else "{}" + data_part = event.split('data: ')[1] if 'data: ' in event else event try: data = json.loads(data_part) @@ -747,7 +721,7 @@ async def test_overlapping_citation_handling(): rag_generation_config=GenerationConfig(model="test/model") ) - # Replace search results collector + # Replace the search results collector with our mock agent.search_results_collector = MockSearchResultsCollector({ "abc1234": { "document_id": "doc_abc1234", @@ -912,18 +886,17 @@ async def test_robustness_against_citation_variations(mock_streaming_agent): """ # Use the extract_citations function directly to see what would be detected - from core.utils import extract_citations - extracted = extract_citations(response_text) + citations = extract_citations(response_text) # There should be at least two different citation IDs - unique_citations = set(extracted) + unique_citations = set(citations) assert len(unique_citations) >= 2, "Should extract at least two different citation IDs" assert "abc1234" in unique_citations, "Should extract abc1234" assert "def5678" in unique_citations, "Should extract def5678" # Count occurrences of each citation counts = {} - for cid in extracted: + for cid in citations: counts[cid] = counts.get(cid, 0) + 1 # Each citation should be found the correct number of times based on the text @@ -963,8 +936,6 @@ class TestCitationEdgeCases: ]) def test_citation_extraction_cases(self, test_case): """Test citation extraction with various edge cases.""" - from core.utils import extract_citations - text = test_case["text"] expected = test_case["expected_citations"] @@ -978,3 +949,189 @@ class TestCitationEdgeCases: if expected: for expected_citation in expected: assert expected_citation in actual, f"Expected citation {expected_citation} not found" + +@pytest.mark.asyncio +async def test_citation_handling_with_empty_response(): + """Test how the agent handles responses with no citations.""" + # Create a custom R2RStreamingAgent with no citations + + # Custom agent class for testing empty citations + class EmptyResponseAgent(MockR2RStreamingAgent): + async def arun( + self, + system_instruction: str = None, + messages: list[Message] = None, + *args, + **kwargs, + ) -> AsyncGenerator[str, None]: + """Custom arun with no citations in the response.""" + await self._setup(system_instruction) + + if messages: + for m in messages: + await self.conversation.add_message(m) + + # Initialize citation tracker + citation_tracker = CitationTracker() + + # Empty response with no citations + response_content = "This is a response with no citations." + + # Yield an initial message event with the start of the text + yield self._format_sse_event("message", {"content": response_content}) + + # No citation spans to extract + citation_spans = extract_citation_spans(response_content) + + # Should be empty + assert len(citation_spans) == 0, "No citation spans should be found" + + # Add assistant message to conversation (with no citation metadata) + await self.conversation.add_message( + Message( + role="assistant", + content=response_content, + metadata={"citations": []} + ) + ) + + # Create and emit final answer event + final_evt_payload = { + "id": "msg_final", + "object": "agent.final_answer", + "generated_answer": response_content, + "citations": [] + } + + yield self._format_sse_event("agent.final_answer", final_evt_payload) + yield "event: done\ndata: {}\n\n" + + # Create the agent with empty citation response + config = MagicMock() + config.stream = True + + llm_provider = MockLLMProvider( + response_content="This is a response with no citations.", + citations=[] + ) + + db_provider = MockDatabaseProvider() + + # Create the custom agent + agent = EmptyResponseAgent( + database_provider=db_provider, + llm_provider=llm_provider, + config=config, + rag_generation_config=GenerationConfig(model="test/model") + ) + + # Test a simple query + messages = [Message(role="user", content="Query with no citations")] + + # Run the agent + stream = agent.arun(messages=messages) + output = await collect_stream_output(stream) + + # Verify no citation events were emitted + citation_events = [line for line in output if 'event: citation' in line] + assert len(citation_events) == 0, "No citation events should be emitted" + + # Parse the final answer event to check citations + final_answer_events = [line for line in output if 'event: agent.final_answer' in line] + assert len(final_answer_events) > 0, "Final answer event should be emitted" + + data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else "" + + # Parse final answer data + try: + data = json.loads(data_part) + assert 'citations' in data, "Final answer event should include citations field" + assert len(data['citations']) == 0, "Citations list should be empty" + except json.JSONDecodeError: + assert False, "Final answer event data should be valid JSON" + +@pytest.mark.asyncio +async def test_citation_sanitization(): + """Test that citation IDs are properly sanitized before processing.""" + # Since extract_citations uses a strict regex pattern [A-Za-z0-9]{7,8}, + # we should test with valid citation formats + text = "Citation with surrounding text[abc1234]and [def5678]with no spaces." + + # Extract citations + citations = extract_citations(text) + + # Check if citations are properly extracted + assert "abc1234" in citations, "Citation abc1234 should be extracted" + assert "def5678" in citations, "Citation def5678 should be extracted" + + # Test with spaces - these should NOT be extracted based on the implementation + text_with_spaces = "Citation with [abc1234 ] and [ def5678] spaces." + citations_with_spaces = extract_citations(text_with_spaces) + + # The current implementation doesn't extract citations with spaces inside the brackets + assert len(citations_with_spaces) == 0 or "abc1234" not in citations_with_spaces, "Citations with spaces should not be extracted with current implementation" + +@pytest.mark.asyncio +async def test_citation_tracking_state_persistence(): + """Test that the CitationTracker correctly maintains state across multiple calls.""" + tracker = CitationTracker() + + # Record some initial spans + tracker.is_new_span("abc1234", (10, 18)) + tracker.is_new_span("def5678", (30, 38)) + + # Check if spans are correctly stored + all_spans = tracker.get_all_spans() + assert "abc1234" in all_spans, "Citation abc1234 should be tracked" + assert "def5678" in all_spans, "Citation def5678 should be tracked" + assert all_spans["abc1234"] == [(10, 18)], "Span positions should match" + + # Add another span for an existing citation + tracker.is_new_span("abc1234", (50, 58)) + + # Check if the new span was added + all_spans = tracker.get_all_spans() + assert len(all_spans["abc1234"]) == 2, "Citation abc1234 should have 2 spans" + assert (50, 58) in all_spans["abc1234"], "New span should be added" + +def test_citation_span_uniqueness(): + """Test that CitationTracker correctly identifies duplicate spans.""" + tracker = CitationTracker() + + # Record a span + tracker.is_new_span("abc1234", (10, 18)) + + # Check if the same span is recognized as not new + assert not tracker.is_new_span("abc1234", (10, 18)), "Duplicate span should not be considered new" + + # Check if different span for same citation is recognized as new + assert tracker.is_new_span("abc1234", (20, 28)), "Different span should be considered new" + + # Check if same span for different citation is recognized as new + assert tracker.is_new_span("def5678", (10, 18)), "Same span for different citation should be considered new" + +def test_citation_with_punctuation(): + """Test extraction of citations with surrounding punctuation.""" + text = "Citations with punctuation: ([abc1234]), [def5678]!, and [ghi9012]." + + # Extract citations + citations = extract_citations(text) + + # Check if all citations are extracted correctly + assert "abc1234" in citations, "Citation abc1234 should be extracted" + assert "def5678" in citations, "Citation def5678 should be extracted" + assert "ghi9012" in citations, "Citation ghi9012 should be extracted" + +def test_citation_extraction_with_invalid_formats(): + """Test that invalid citation formats are not extracted.""" + text = "Invalid citation formats: [123], [abcdef], [abc123456789], and valid [abc1234]." + + # Extract citations + citations = extract_citations(text) + + # Check that only valid citations are extracted + assert len(citations) == 1, "Only one valid citation should be extracted" + assert "abc1234" in citations, "Only valid citation abc1234 should be extracted" + assert "123" not in citations, "Invalid citation [123] should not be extracted" + assert "abcdef" not in citations, "Invalid citation [abcdef] should not be extracted" + assert "abc123456789" not in citations, "Invalid citation [abc123456789] should not be extracted" diff --git a/py/tests/unit/test_filters.py b/py/tests/unit/test_filters.py index 8adf6f666..f7f588df6 100644 --- a/py/tests/unit/test_filters.py +++ b/py/tests/unit/test_filters.py @@ -153,7 +153,7 @@ class TestFilterTypeConversions: filters = {"id": None} sql, params = simplified_apply_filters(filters, []) # Different implementations might handle NULL differently - assert ("IS NULL" in sql or "= NULL" in sql or + assert ("IS NULL" in sql or "= NULL" in sql or (sql.strip().endswith("= $1") and (not params or params == [None]))), "Should handle null in top-level column" # Metadata field @@ -211,9 +211,9 @@ class TestRealWorldQueries: {"metadata.score": {SimplifiedFilterOperator.GT: 80}} ] } - + sql, params = simplified_apply_filters(filters, []) - + # Basic checks assert " AND " in sql, "Should use AND to combine all conditions" assert "collection_ids &&" in sql, "Should use overlap operator for collections" @@ -222,13 +222,13 @@ class TestRealWorldQueries: assert "metadata->>'status'" in sql or "metadata->'status'" in sql, "Should compare status as text" assert "@>" in sql, "Should use containment for tags" assert "score" in sql and "numeric" in sql, "Should compare score as numeric" - + # Check parameters assert ["collection1", "collection2"] in params, "Should include collection IDs" assert "2021-01-01" in params, "Should include date string" assert "active" in params, "Should include status" assert "80" in params, "Should include score (as string)" - + # At least one parameter should be a JSON string for tags assert any(isinstance(p, str) and "important" in p for p in params), "Should include JSON-encoded tags" @@ -246,15 +246,15 @@ class TestRealWorldQueries: } ] } - + sql, params = simplified_apply_filters(filters, []) - + # Basic checks assert " OR " in sql, "Should use OR for pagination options" assert " AND " in sql, "Should use AND for tie-breaker" assert "metadata->>'created_at'" in sql, "Should reference created_at" assert "id <" in sql, "Should have ID comparison" - + # Check parameters assert "2023-01-01" in params, "Should include date twice" assert params.count("2023-01-01") == 2, "Date should appear twice (LT and EQ)" @@ -266,9 +266,9 @@ class TestRealWorldQueries: filters = { "metadata.level1.level2.level3.value": {SimplifiedFilterOperator.GT: 100} } - + sql, params = simplified_apply_filters(filters, []) - + # Check JSON path navigation expected_path = "metadata->'level1'->'level2'->'level3'->>'value'" assert expected_path in sql, "Should properly navigate nested JSON path" @@ -285,16 +285,16 @@ class TestRealWorldQueries: {"metadata.content": {SimplifiedFilterOperator.ILIKE: "%search term%"}} ] } - + sql, params = simplified_apply_filters(filters, []) - + # Check basic structure assert " OR " in sql, "Should use OR to combine text search conditions" # Different implementations can use different JSON extraction operators assert "metadata" in sql and "title" in sql and "ILIKE" in sql, "Should search in title with ILIKE" assert "metadata" in sql and "description" in sql, "Should search in description" assert "metadata" in sql and "content" in sql, "Should search in content" - + # Check parameters assert all("%search term%" in p for p in params), "All parameters should contain search term" @@ -308,11 +308,11 @@ class TestCornerCases: many_conditions = [] for i in range(50): # 50 conditions many_conditions.append({"metadata.field" + str(i): "value" + str(i)}) - + filters = {SimplifiedFilterOperator.AND: many_conditions} - + sql, params = simplified_apply_filters(filters, []) - + # Basic checks assert " AND " in sql, "Should use AND to combine conditions" assert len(params) == 50, "Should have 50 parameters" @@ -328,15 +328,15 @@ class TestCornerCases: {"metadata.score": {SimplifiedFilterOperator.LT: 100}} ] } - + sql, params = simplified_apply_filters(filters, []) - + # Check structure assert " AND " in sql, "Should use AND to combine range bounds" assert "metadata->>'score'" in sql, "Should reference score field" assert "::numeric >=" in sql, "Should have GTE comparison" assert "::numeric <" in sql, "Should have LT comparison" - + # Check parameters assert "50" in params, "Should include lower bound" assert "100" in params, "Should include upper bound" @@ -348,7 +348,7 @@ class TestCornerCases: sql, params = simplified_apply_filters(filters, []) # The behavior depends on implementation - could be "FALSE" or empty list handling assert "FALSE" in sql.upper() or ("ANY" in sql and "[]" in str(params)), "Should handle empty IN list" - + # Empty CONTAINS list filters = {"metadata.tags": {SimplifiedFilterOperator.CONTAINS: []}} sql, params = simplified_apply_filters(filters, []) @@ -363,7 +363,7 @@ class TestCornerCases: # Should be safe because of parameterization assert "id =" in sql, "Should handle value normally" assert params == ["value'; DROP TABLE users; --"], "Should include value as parameter" - + # Very long string long_string = "x" * 1000 # 1000 character string filters = {"metadata.field": long_string} @@ -378,7 +378,7 @@ class TestCornerCases: sql, params = simplified_apply_filters(filters, []) assert "metadata->>'title'" in sql, "Should handle Unicode normally" assert params == ["😀 Unicode 测试"], "Should include Unicode string as parameter" - + # Unicode in field name (might not be supported in current implementation) try: filters = {"metadata.标题": "value"} @@ -590,11 +590,11 @@ class TestMetadataFilters: # With prefix filters1 = {"metadata.key": "value"} sql1, params1 = main_apply_filters(filters1, []) - + # Without prefix (should be treated as metadata) filters2 = {"key": "value"} sql2, params2 = main_apply_filters(filters2, [], top_level_columns=["id", "owner_id"]) - + # Both should produce the same SQL assert "metadata" in sql2, "Field not in top_level_columns should be treated as metadata" assert params1 == params2, "Parameters should be the same" @@ -678,15 +678,15 @@ class TestEdgeCases: def test_filter_modes(self): """Test different filter modes (where_clause, condition_only, append_only).""" filters = {"id": "test-id"} - + # Default where_clause mode sql1, _ = main_apply_filters(filters, []) assert sql1.startswith("WHERE"), "Default mode should prepend WHERE" - + # condition_only mode sql2, _ = main_apply_filters(filters, [], mode="condition_only") assert not sql2.startswith("WHERE"), "condition_only mode should not prepend WHERE" - + # append_only mode sql3, _ = main_apply_filters(filters, [], mode="append_only") assert sql3.startswith("AND"), "append_only mode should prepend AND" @@ -718,13 +718,13 @@ class TestEdgeCases: """Test with custom top_level_columns parameter.""" # Define a custom set of top-level columns custom_columns = ["id", "custom_field"] - + # Test a field that's in custom_columns filters = {"custom_field": "value"} sql, _ = main_apply_filters(filters, [], top_level_columns=custom_columns) assert "custom_field =" in sql, "Should treat custom_field as a normal column" assert "metadata" not in sql, "Should not treat custom_field as metadata" - + # Test a field that's not in custom_columns filters = {"other_field": "value"} sql, _ = main_apply_filters(filters, [], top_level_columns=custom_columns) @@ -734,7 +734,7 @@ class TestEdgeCases: """Test with custom json_column parameter.""" # Use a custom json column name custom_json = "properties" - + filters = {"field": "value"} sql, _ = main_apply_filters(filters, [], top_level_columns=["id"], json_column=custom_json) assert custom_json in sql, f"Should use {custom_json} instead of metadata" @@ -759,23 +759,23 @@ class TestComplexFilterCombinations: ] } sql, params = main_apply_filters(filters, []) - + # Check for AND operator assert " AND " in sql, "SQL should contain AND operator" - + # Check for metadata handling assert "metadata->>'score'" in sql, "SQL should handle metadata field" assert "::numeric >=" in sql, "SQL should handle numeric comparison" - + # Check for OR operator assert " OR " in sql, "SQL should contain OR operator" - + # Check for collection_id handling assert "collection_ids" in sql, "SQL should handle collection_id" - + # Check for parent_id handling assert "parent_id = ANY" in sql, "SQL should handle parent_id IN condition" - + # Check parameters assert len(params) == 4, "Should have 4 parameters" assert "test-id" in params, "Parameters should include test-id" @@ -787,7 +787,7 @@ class TestComplexFilterCombinations: """Test filtering with deeply nested JSON fields.""" filters = {"metadata.level1.level2.level3.deep": {MainFilterOperator.GT: 100}} sql, params = main_apply_filters(filters, []) - + expected_path = ( "metadata->'level1'->'level2'->'level3'->>'deep'" )