Files
R2R/py/tests/unit/conftest.py
T
emrgnt-cmplxty 05f04e7109 add agent tests
2025-03-21 22:10:35 -07:00

496 lines
16 KiB
Python

# tests/unit/conftest.py
import os
import pytest
from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
from core.providers import NaClCryptoConfig, NaClCryptoProvider
from core.providers.database.postgres import (
PostgresChunksHandler,
PostgresCollectionsHandler,
PostgresConversationsHandler,
PostgresDatabaseProvider,
PostgresDocumentsHandler,
PostgresGraphsHandler,
PostgresLimitsHandler,
PostgresPromptsHandler,
)
from core.providers.database.users import ( # Make sure this import is correct
PostgresUserHandler, )
# Connection string for the Postgres instance used by these tests; override it
# with the TEST_DB_CONNECTION_STRING environment variable.
TEST_DB_CONNECTION_STRING = os.getenv(
    "TEST_DB_CONNECTION_STRING",
    "postgresql://postgres:postgres@localhost:5432/test_db",
)
@pytest.fixture
async def db_provider():
    """Yield an initialized PostgresDatabaseProvider for the test database.

    The provider connects with TEST_DB_CONNECTION_STRING under the
    "test_project" project name, using 4-dimensional FP32 vectors. It is
    closed after the requesting test has finished with it.
    """
    crypto = NaClCryptoProvider(NaClCryptoConfig(app={}))
    pool_settings = {
        "max_connections": 10,
        "statement_cache_size": 100,
    }
    config = DatabaseConfig(
        app=AppConfig(project_name="test_project"),
        provider="postgres",
        connection_string=TEST_DB_CONNECTION_STRING,
        postgres_configuration_settings=pool_settings,
        project_name="test_project",
    )
    provider = PostgresDatabaseProvider(
        config,
        4,  # vector dimension
        crypto,
        VectorQuantizationType.FP32,
    )
    await provider.initialize()

    yield provider

    # Teardown: release the provider's resources.
    await provider.close()
@pytest.fixture
def crypto_provider():
    """Return a standalone NaClCryptoProvider built from an empty app config."""
    crypto_config = NaClCryptoConfig(app={})
    return NaClCryptoProvider(crypto_config)
@pytest.fixture
async def chunks_handler(db_provider):
    """Return a PostgresChunksHandler wired to ``db_provider``, tables created.

    Project name, connection manager, vector dimension and quantization type
    are all taken from the shared ``db_provider`` fixture.
    """
    chunks = PostgresChunksHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
    )
    await chunks.create_tables()
    return chunks
@pytest.fixture
async def collections_handler(db_provider):
    """Return a PostgresCollectionsHandler wired to ``db_provider``, tables created."""
    collections = PostgresCollectionsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        config=db_provider.config,
    )
    await collections.create_tables()
    return collections
@pytest.fixture
async def conversations_handler(db_provider):
    """Return a PostgresConversationsHandler wired to ``db_provider``, tables created."""
    # This handler takes its two arguments positionally:
    # (project_name, connection_manager).
    conversations = PostgresConversationsHandler(
        db_provider.project_name,
        db_provider.connection_manager,
    )
    await conversations.create_tables()
    return conversations
@pytest.fixture
async def documents_handler(db_provider):
    """Return a PostgresDocumentsHandler wired to ``db_provider``, tables created."""
    documents = PostgresDocumentsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
    )
    await documents.create_tables()
    return documents
@pytest.fixture
async def graphs_handler(db_provider):
    """Return a PostgresGraphsHandler with no collections handler, tables created.

    NOTE(review): this module defines a second fixture with the same name
    further down; that later definition rebinds ``graphs_handler``, so this
    one is presumably never the fixture pytest hands to tests. Confirm and
    consolidate.
    """
    graphs = PostgresGraphsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
        # No collections handler is supplied here; depend on the
        # collections_handler fixture instead if one turns out to be needed.
        collections_handler=None,
    )
    await graphs.create_tables()
    return graphs
@pytest.fixture
async def limits_handler(db_provider):
    """Return a PostgresLimitsHandler with an empty request-log table.

    Tables are created first, then the handler's ``request_log`` table is
    truncated so each test starts with no logged requests.
    """
    conn = db_provider.connection_manager
    handler = PostgresLimitsHandler(
        project_name=db_provider.project_name,
        connection_manager=conn,
        config=db_provider.config,
    )
    await handler.create_tables()

    # Clear any request log rows left behind by earlier tests.
    truncate_sql = f"TRUNCATE {handler._get_table_name('request_log')};"
    await conn.execute_query(truncate_sql)
    return handler
@pytest.fixture
async def users_handler(db_provider, crypto_provider):
    """Return a PostgresUserHandler with empty user and API-key tables.

    Tables are created first; the ``users`` and ``users_api_keys`` tables are
    then truncated (with CASCADE) so every test starts from a clean slate.
    """
    conn = db_provider.connection_manager
    handler = PostgresUserHandler(
        project_name=db_provider.project_name,
        connection_manager=conn,
        crypto_provider=crypto_provider,
    )
    await handler.create_tables()

    # Wipe users first, then their API keys, matching the original order.
    truncate_users = f"TRUNCATE {handler._get_table_name('users')} CASCADE;"
    await conn.execute_query(truncate_users)
    truncate_keys = (
        f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;")
    await conn.execute_query(truncate_keys)
    return handler
@pytest.fixture
async def prompt_handler(db_provider):
    """Return a PostgresPromptsHandler wired to ``db_provider``, tables created.

    No local prompt directory is configured (``prompt_directory=None``).
    """
    prompts = PostgresPromptsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        # Point this at a directory to use local prompt files instead.
        prompt_directory=None,
    )
    # Per the original note, this also performs the initial prompt load.
    await prompts.create_tables()
    return prompts
@pytest.fixture
async def graphs_handler(db_provider):
    """Return a PostgresGraphsHandler after ensuring ``collection_ids`` exists.

    Before the handler is built, a ``collection_ids UUID[]`` column is added
    to the project's ``graphs_entities`` table if it is missing. The handler
    is created without a collections handler and its tables are then created.
    """
    project_name = db_provider.project_name
    conn = db_provider.connection_manager

    # NOTE(review): this ALTER runs before handler.create_tables(), so it
    # relies on "graphs_entities" already existing -- presumably created by
    # db_provider.initialize(); confirm.
    create_col_sql = f"""
ALTER TABLE "{project_name}"."graphs_entities"
ADD COLUMN IF NOT EXISTS collection_ids UUID[] DEFAULT '{{}}';
"""
    await conn.execute_query(create_col_sql)

    graphs = PostgresGraphsHandler(
        project_name=project_name,
        connection_manager=conn,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
        collections_handler=None,
    )
    await graphs.create_tables()
    return graphs
# Citation testing fixtures and utilities
import json
import re
from unittest.mock import MagicMock, AsyncMock
from typing import Tuple, Any, AsyncGenerator
from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
from core.utils import CitationTracker, SearchResultsCollector
from core.agent.base import R2RStreamingAgent
class MockLLMProvider:
    """Stand-in LLM provider that answers with canned text plus citations.

    The reply is ``response_content`` followed by `` [<citation>]`` for every
    entry in ``citations``.
    """

    def __init__(self, response_content=None, citations=None):
        # Falsy arguments (None, "", []) fall back to the defaults.
        self.response_content = response_content or "This is a response"
        self.citations = citations or []

    def _cited_text(self):
        """Return the response text with each citation appended in brackets."""
        suffix = "".join(f" [{citation}]" for citation in self.citations)
        return self.response_content + suffix

    @staticmethod
    def _stream_chunk(text, finish_reason):
        """Build one mocked streaming chunk carrying ``text`` in its delta."""
        delta = MagicMock()
        delta.content = text
        choice = MagicMock()
        choice.delta = delta
        choice.finish_reason = finish_reason
        chunk = MagicMock(spec=LLMChatCompletionChunk)
        chunk.choices = [choice]
        return chunk

    async def aget_completion(self, messages, generation_config):
        """Return a mocked non-streaming completion; both arguments are ignored."""
        message = MagicMock()
        message.content = self._cited_text()
        choice = MagicMock()
        choice.message = message
        choice.finish_reason = "stop"
        completion = MagicMock(spec=LLMChatCompletion)
        completion.choices = [choice]
        return completion

    async def aget_completion_stream(self, messages, generation_config):
        """Yield mocked streaming chunks; both arguments are ignored.

        The text is streamed one character per chunk with no finish reason,
        followed by a final empty chunk whose finish reason is "stop".
        """
        for character in self._cited_text():
            yield self._stream_chunk(character, None)
        yield self._stream_chunk("", "stop")
class MockPromptsHandler:
    """Prompts handler stub that serves one fixed system prompt."""

    async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
        """Return the same system prompt regardless of the arguments given."""
        system_prompt = "You are a helpful assistant that provides well-sourced information."
        return system_prompt
class MockDatabaseProvider:
    """Database provider stub that returns fixed values for conversation calls."""

    def __init__(self):
        # The agent reads `prompts_handler` from its database provider, so
        # expose a stub to avoid an AttributeError.
        self.prompts_handler = MockPromptsHandler()

    async def acreate_conversation(self, *args, **kwargs):
        """Return a fixed conversation record; all arguments are ignored."""
        conversation = {"id": "conv_12345"}
        return conversation

    async def aupdate_conversation(self, *args, **kwargs):
        """Report success unconditionally; all arguments are ignored."""
        return True

    async def acreate_message(self, *args, **kwargs):
        """Return a fixed message record; all arguments are ignored."""
        message = {"id": "msg_12345"}
        return message
class MockSearchResultsCollector:
    """Search results collector stub backed by a plain mapping.

    ``results`` maps a citation short ID to its source document; unknown IDs
    get a synthesized placeholder document.
    """

    def __init__(self, results=None):
        self.results = results or {}

    def find_by_short_id(self, short_id):
        """Return the stored document for ``short_id``, or a placeholder.

        The placeholder carries ``document_id``, ``text`` and ``metadata``
        values derived from ``short_id``.
        """
        if short_id in self.results:
            return self.results[short_id]
        placeholder = {
            "document_id": f"doc_{short_id}",
            "text": f"This is document text for {short_id}",
            "metadata": {"source": f"source_{short_id}"},
        }
        return placeholder
# Create a concrete implementation of R2RStreamingAgent for testing
class MockR2RStreamingAgent(R2RStreamingAgent):
    """Concrete R2RStreamingAgent for tests, emitting a canned SSE stream.

    ``arun`` does not consult the LLM provider: it streams a hard-coded
    answer containing two bracketed citations and emits the matching
    citation and final-answer events.
    """

    # Citation regexes; per the original note these mirror the real agent.
    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
    SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")

    def _register_tools(self):
        """Satisfy the abstract hook; the mock registers no tools."""
        pass

    async def _setup(self, system_instruction=None, *args, **kwargs):
        """Seed the conversation with a system message.

        Uses ``system_instruction`` when it is truthy, otherwise a fixed
        default prompt, so no prompt has to be fetched from the database.
        """
        prompt_text = system_instruction or "You are a helpful assistant that provides well-sourced information."
        system_message = Message(role="system", content=prompt_text)
        await self.conversation.add_message(system_message)

    def _format_sse_event(self, event_type, data):
        """Render ``data`` as JSON inside a server-sent event of ``event_type``."""
        return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"

    async def arun(
        self,
        system_instruction: str = None,
        messages: list[Message] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream a fixed answer and its citation events as SSE strings.

        Event order: one "message" event with the full answer, one "citation"
        event per newly seen citation span, one "agent.final_answer" event
        with the citations grouped by ID, then a closing "done" event.

        Side effects: the system message, any supplied ``messages`` and the
        assistant answer are added to ``self.conversation``, and
        ``self.streaming_citations`` is reset and filled with the emitted
        citations.
        """
        await self._setup(system_instruction)
        for incoming in messages or []:
            await self.conversation.add_message(incoming)

        tracker = CitationTracker()
        payload_by_id = {}
        # Kept on the instance so the citations can be persisted with the answer.
        self.streaming_citations = []

        # Hard-coded answer; the configured LLM provider is not called.
        answer_text = (
            "This is a test response with citations"
            " [abc1234] [def5678]"
        )

        # The whole answer goes out in a single message event.
        yield self._format_sse_event("message", {"content": answer_text})

        # Extract every citation span up front instead of scanning the text
        # character by character.
        spans_by_id = extract_citation_spans(answer_text)
        for cid, spans in spans_by_id.items():
            for span in spans:
                # Skip spans the tracker has already seen. is_new_span
                # presumably also records the span, since the spans are read
                # back through get_all_spans() below -- TODO confirm.
                if not tracker.is_new_span(cid, span):
                    continue

                source_doc = self.search_results_collector.find_by_short_id(cid)
                payload = {
                    "document_id": source_doc.get("document_id", f"doc_{cid}"),
                    "text": source_doc.get("text", f"This is document text for {cid}"),
                    "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                }
                payload_by_id[cid] = payload

                # Record for persistence alongside the assistant message.
                self.streaming_citations.append({
                    "id": cid,
                    "span": {"start": span[0], "end": span[1]},
                    "payload": payload,
                })

                event_body = {
                    "id": cid,
                    "object": "citation",
                    "span": {"start": span[0], "end": span[1]},
                    "payload": payload,
                }
                yield self._format_sse_event("citation", event_body)

        # Persist the assistant answer together with its citation metadata.
        assistant_message = Message(
            role="assistant",
            content=answer_text,
            metadata={"citations": self.streaming_citations},
        )
        await self.conversation.add_message(assistant_message)

        # One entry per citation ID, carrying all of its spans.
        consolidated = [
            {
                "id": cid,
                "object": "citation",
                "spans": [{"start": s[0], "end": s[1]} for s in spans],
                "payload": payload_by_id[cid],
            }
            for cid, spans in tracker.get_all_spans().items()
            if cid in payload_by_id
        ]

        final_answer = {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": answer_text,
            "citations": consolidated,
        }
        yield self._format_sse_event("agent.final_answer", final_answer)

        # Signal the end of the SSE stream.
        yield "event: done\ndata: {}\n\n"
@pytest.fixture
def mock_streaming_agent():
    """Return a MockR2RStreamingAgent whose dependencies are all mocked.

    The agent gets a MagicMock config (streaming on, 3 max iterations), mock
    LLM and database providers, and a search results collector pre-loaded
    with documents for the citation IDs "abc1234" and "def5678".
    """
    agent_config = MagicMock()
    agent_config.stream = True
    agent_config.max_iterations = 3

    llm = MockLLMProvider(
        response_content="This is a test response with citations",
        citations=["abc1234", "def5678"],
    )
    database = MockDatabaseProvider()

    agent = MockR2RStreamingAgent(
        database_provider=database,
        llm_provider=llm,
        config=agent_config,
        rag_generation_config=GenerationConfig(model="test/model"),
    )

    # Swap in a collector that knows the two cited documents.
    known_documents = {
        "abc1234": {
            "document_id": "doc_abc1234",
            "text": "This is document text for abc1234",
            "metadata": {"source": "source_abc1234"},
        },
        "def5678": {
            "document_id": "doc_def5678",
            "text": "This is document text for def5678",
            "metadata": {"source": "source_def5678"},
        },
    }
    agent.search_results_collector = MockSearchResultsCollector(known_documents)
    return agent
async def collect_stream_output(stream):
    """Drain the async iterable ``stream`` and return its items as a list."""
    return [event async for event in stream]
from core.utils import extract_citation_spans, find_new_citation_spans