496 lines
16 KiB
Python
496 lines
16 KiB
Python
# tests/conftest.py
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
|
|
from core.providers import NaClCryptoConfig, NaClCryptoProvider
|
|
from core.providers.database.postgres import (
|
|
PostgresChunksHandler,
|
|
PostgresCollectionsHandler,
|
|
PostgresConversationsHandler,
|
|
PostgresDatabaseProvider,
|
|
PostgresDocumentsHandler,
|
|
PostgresGraphsHandler,
|
|
PostgresLimitsHandler,
|
|
PostgresPromptsHandler,
|
|
)
|
|
from core.providers.database.users import ( # Make sure this import is correct
|
|
PostgresUserHandler, )
|
|
|
|
TEST_DB_CONNECTION_STRING = os.environ.get(
|
|
"TEST_DB_CONNECTION_STRING",
|
|
"postgresql://postgres:postgres@localhost:5432/test_db",
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
async def db_provider():
|
|
crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
|
|
db_config = DatabaseConfig(
|
|
app=AppConfig(project_name="test_project"),
|
|
provider="postgres",
|
|
connection_string=TEST_DB_CONNECTION_STRING,
|
|
postgres_configuration_settings={
|
|
"max_connections": 10,
|
|
"statement_cache_size": 100,
|
|
},
|
|
project_name="test_project",
|
|
)
|
|
|
|
dimension = 4
|
|
quantization_type = VectorQuantizationType.FP32
|
|
|
|
db_provider = PostgresDatabaseProvider(db_config, dimension,
|
|
crypto_provider, quantization_type)
|
|
|
|
await db_provider.initialize()
|
|
yield db_provider
|
|
# Teardown logic if needed
|
|
await db_provider.close()
|
|
|
|
|
|
@pytest.fixture
|
|
def crypto_provider():
|
|
# Provide a crypto provider fixture if needed separately
|
|
return NaClCryptoProvider(NaClCryptoConfig(app={}))
|
|
|
|
|
|
@pytest.fixture
|
|
async def chunks_handler(db_provider):
|
|
dimension = db_provider.dimension
|
|
quantization_type = db_provider.quantization_type
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
|
|
handler = PostgresChunksHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
dimension=dimension,
|
|
quantization_type=quantization_type,
|
|
)
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def collections_handler(db_provider):
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
config = db_provider.config
|
|
|
|
handler = PostgresCollectionsHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
config=config,
|
|
)
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def conversations_handler(db_provider):
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
|
|
handler = PostgresConversationsHandler(project_name, connection_manager)
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def documents_handler(db_provider):
|
|
dimension = db_provider.dimension
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
|
|
handler = PostgresDocumentsHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
dimension=dimension,
|
|
)
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def graphs_handler(db_provider):
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
dimension = db_provider.dimension
|
|
quantization_type = db_provider.quantization_type
|
|
|
|
# If collections_handler is needed, you can depend on the collections_handler fixture
|
|
# or pass None if it's optional.
|
|
handler = PostgresGraphsHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
dimension=dimension,
|
|
quantization_type=quantization_type,
|
|
collections_handler=
|
|
None, # if needed, or await collections_handler fixture
|
|
)
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def limits_handler(db_provider):
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
config = db_provider.config
|
|
|
|
handler = PostgresLimitsHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
config=config,
|
|
)
|
|
await handler.create_tables()
|
|
# Optionally truncate
|
|
await connection_manager.execute_query(
|
|
f"TRUNCATE {handler._get_table_name('request_log')};")
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def users_handler(db_provider, crypto_provider):
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
|
|
handler = PostgresUserHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
crypto_provider=crypto_provider,
|
|
)
|
|
await handler.create_tables()
|
|
|
|
# Optionally clean up users table before each test
|
|
await connection_manager.execute_query(
|
|
f"TRUNCATE {handler._get_table_name('users')} CASCADE;")
|
|
await connection_manager.execute_query(
|
|
f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;")
|
|
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def prompt_handler(db_provider):
|
|
"""Returns an instance of PostgresPromptsHandler, creating the necessary
|
|
tables first."""
|
|
# from core.providers.database.postgres_prompts import PostgresPromptsHandler
|
|
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
|
|
handler = PostgresPromptsHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
# You can specify a local prompt directory if desired
|
|
prompt_directory=None,
|
|
)
|
|
# Create necessary tables and do initial prompt load
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
|
|
@pytest.fixture
|
|
async def graphs_handler(db_provider):
|
|
project_name = db_provider.project_name
|
|
connection_manager = db_provider.connection_manager
|
|
dimension = db_provider.dimension
|
|
quantization_type = db_provider.quantization_type
|
|
|
|
# Optionally ensure 'collection_ids' column exists on your table(s), e.g.:
|
|
create_col_sql = f"""
|
|
ALTER TABLE "{project_name}"."graphs_entities"
|
|
ADD COLUMN IF NOT EXISTS collection_ids UUID[] DEFAULT '{{}}';
|
|
"""
|
|
await connection_manager.execute_query(create_col_sql)
|
|
|
|
handler = PostgresGraphsHandler(
|
|
project_name=project_name,
|
|
connection_manager=connection_manager,
|
|
dimension=dimension,
|
|
quantization_type=quantization_type,
|
|
collections_handler=None,
|
|
)
|
|
await handler.create_tables()
|
|
return handler
|
|
|
|
# Citation testing fixtures and utilities
|
|
import json
|
|
import re
|
|
from unittest.mock import MagicMock, AsyncMock
|
|
from typing import Tuple, Any, AsyncGenerator
|
|
|
|
from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
|
|
from core.utils import CitationTracker, SearchResultsCollector
|
|
from core.agent.base import R2RStreamingAgent
|
|
|
|
|
|
class MockLLMProvider:
|
|
"""Mock LLM provider for testing."""
|
|
|
|
def __init__(self, response_content=None, citations=None):
|
|
self.response_content = response_content or "This is a response"
|
|
self.citations = citations or []
|
|
|
|
async def aget_completion(self, messages, generation_config):
|
|
"""Mock synchronous completion."""
|
|
content = self.response_content
|
|
for citation in self.citations:
|
|
content += f" [{citation}]"
|
|
|
|
mock_response = MagicMock(spec=LLMChatCompletion)
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message = MagicMock()
|
|
mock_response.choices[0].message.content = content
|
|
mock_response.choices[0].finish_reason = "stop"
|
|
return mock_response
|
|
|
|
async def aget_completion_stream(self, messages, generation_config):
|
|
"""Mock streaming completion."""
|
|
content = self.response_content
|
|
for citation in self.citations:
|
|
content += f" [{citation}]"
|
|
|
|
# Simulate streaming by yielding one character at a time
|
|
for i in range(len(content)):
|
|
chunk = MagicMock(spec=LLMChatCompletionChunk)
|
|
chunk.choices = [MagicMock()]
|
|
chunk.choices[0].delta = MagicMock()
|
|
chunk.choices[0].delta.content = content[i]
|
|
chunk.choices[0].finish_reason = None
|
|
yield chunk
|
|
|
|
# Final chunk with finish_reason="stop"
|
|
final_chunk = MagicMock(spec=LLMChatCompletionChunk)
|
|
final_chunk.choices = [MagicMock()]
|
|
final_chunk.choices[0].delta = MagicMock()
|
|
final_chunk.choices[0].delta.content = ""
|
|
final_chunk.choices[0].finish_reason = "stop"
|
|
yield final_chunk
|
|
|
|
|
|
class MockPromptsHandler:
|
|
"""Mock prompts handler for testing."""
|
|
|
|
async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
|
|
"""Return a mock system prompt."""
|
|
return "You are a helpful assistant that provides well-sourced information."
|
|
|
|
|
|
class MockDatabaseProvider:
|
|
"""Mock database provider for testing."""
|
|
|
|
def __init__(self):
|
|
# Add a prompts_handler attribute to prevent AttributeError
|
|
self.prompts_handler = MockPromptsHandler()
|
|
|
|
async def acreate_conversation(self, *args, **kwargs):
|
|
return {"id": "conv_12345"}
|
|
|
|
async def aupdate_conversation(self, *args, **kwargs):
|
|
return True
|
|
|
|
async def acreate_message(self, *args, **kwargs):
|
|
return {"id": "msg_12345"}
|
|
|
|
|
|
class MockSearchResultsCollector:
|
|
"""Mock search results collector for testing."""
|
|
|
|
def __init__(self, results=None):
|
|
self.results = results or {}
|
|
|
|
def find_by_short_id(self, short_id):
|
|
return self.results.get(short_id, {
|
|
"document_id": f"doc_{short_id}",
|
|
"text": f"This is document text for {short_id}",
|
|
"metadata": {"source": f"source_{short_id}"}
|
|
})
|
|
|
|
|
|
# Create a concrete implementation of R2RStreamingAgent for testing
|
|
class MockR2RStreamingAgent(R2RStreamingAgent):
|
|
"""Mock streaming agent for testing that implements the abstract method."""
|
|
|
|
# Regex pattern for citations, copied from the actual agent
|
|
BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
|
|
SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")
|
|
|
|
def _register_tools(self):
|
|
"""Implement the abstract method with a no-op version."""
|
|
pass
|
|
|
|
async def _setup(self, system_instruction=None, *args, **kwargs):
|
|
"""Override _setup to simplify initialization and avoid external dependencies."""
|
|
# Use a simple system message instead of fetching from database
|
|
system_content = system_instruction or "You are a helpful assistant that provides well-sourced information."
|
|
|
|
# Add system message to conversation
|
|
await self.conversation.add_message(
|
|
Message(role="system", content=system_content)
|
|
)
|
|
|
|
def _format_sse_event(self, event_type, data):
|
|
"""Format an SSE event manually."""
|
|
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
|
|
|
async def arun(
|
|
self,
|
|
system_instruction: str = None,
|
|
messages: list[Message] = None,
|
|
*args,
|
|
**kwargs,
|
|
) -> AsyncGenerator[str, None]:
|
|
"""
|
|
Simplified version of arun that focuses on citation handling for testing.
|
|
"""
|
|
await self._setup(system_instruction)
|
|
|
|
if messages:
|
|
for m in messages:
|
|
await self.conversation.add_message(m)
|
|
|
|
# Initialize citation tracker
|
|
citation_tracker = CitationTracker()
|
|
citation_payloads = {}
|
|
|
|
# Track streaming citations for final persistence
|
|
self.streaming_citations = []
|
|
|
|
# Get the LLM response with citations
|
|
response_content = "This is a test response with citations"
|
|
response_content += " [abc1234] [def5678]"
|
|
|
|
# Yield an initial message event with the start of the text
|
|
yield self._format_sse_event("message", {"content": response_content})
|
|
|
|
# Manually extract and emit citation events
|
|
# This is a simpler approach than the character-by-character approach
|
|
citation_spans = extract_citation_spans(response_content)
|
|
|
|
# Process the citations
|
|
for cid, spans in citation_spans.items():
|
|
for span in spans:
|
|
# Check if the span is new and record it
|
|
if citation_tracker.is_new_span(cid, span):
|
|
|
|
# Look up the source document for this citation
|
|
source_doc = self.search_results_collector.find_by_short_id(cid)
|
|
|
|
# Create citation payload
|
|
citation_payload = {
|
|
"document_id": source_doc.get("document_id", f"doc_{cid}"),
|
|
"text": source_doc.get("text", f"This is document text for {cid}"),
|
|
"metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
|
|
}
|
|
|
|
# Store the payload by citation ID
|
|
citation_payloads[cid] = citation_payload
|
|
|
|
# Track for persistence
|
|
self.streaming_citations.append({
|
|
"id": cid,
|
|
"span": {"start": span[0], "end": span[1]},
|
|
"payload": citation_payload
|
|
})
|
|
|
|
# Emit citation event in the expected format
|
|
citation_event = {
|
|
"id": cid,
|
|
"object": "citation",
|
|
"span": {"start": span[0], "end": span[1]},
|
|
"payload": citation_payload
|
|
}
|
|
|
|
yield self._format_sse_event("citation", citation_event)
|
|
|
|
# Add assistant message with citation metadata to conversation
|
|
await self.conversation.add_message(
|
|
Message(
|
|
role="assistant",
|
|
content=response_content,
|
|
metadata={"citations": self.streaming_citations}
|
|
)
|
|
)
|
|
|
|
# Prepare consolidated citations for final answer
|
|
consolidated_citations = []
|
|
|
|
# Group citations by ID with all their spans
|
|
for cid, spans in citation_tracker.get_all_spans().items():
|
|
if cid in citation_payloads:
|
|
consolidated_citations.append({
|
|
"id": cid,
|
|
"object": "citation",
|
|
"spans": [{"start": s[0], "end": s[1]} for s in spans],
|
|
"payload": citation_payloads[cid]
|
|
})
|
|
|
|
# Create and emit final answer event
|
|
final_evt_payload = {
|
|
"id": "msg_final",
|
|
"object": "agent.final_answer",
|
|
"generated_answer": response_content,
|
|
"citations": consolidated_citations
|
|
}
|
|
|
|
# Manually format the final answer event
|
|
yield self._format_sse_event("agent.final_answer", final_evt_payload)
|
|
|
|
# Signal the end of the SSE stream
|
|
yield "event: done\ndata: {}\n\n"
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_streaming_agent():
|
|
"""Create a streaming agent with mocked dependencies."""
|
|
# Create mock config
|
|
config = MagicMock()
|
|
config.stream = True
|
|
config.max_iterations = 3
|
|
|
|
# Create mock providers
|
|
llm_provider = MockLLMProvider(
|
|
response_content="This is a test response with citations",
|
|
citations=["abc1234", "def5678"]
|
|
)
|
|
db_provider = MockDatabaseProvider()
|
|
|
|
# Create agent with mocked dependencies using our concrete implementation
|
|
agent = MockR2RStreamingAgent(
|
|
database_provider=db_provider,
|
|
llm_provider=llm_provider,
|
|
config=config,
|
|
rag_generation_config=GenerationConfig(model="test/model")
|
|
)
|
|
|
|
# Replace the search results collector with our mock
|
|
agent.search_results_collector = MockSearchResultsCollector({
|
|
"abc1234": {
|
|
"document_id": "doc_abc1234",
|
|
"text": "This is document text for abc1234",
|
|
"metadata": {"source": "source_abc1234"}
|
|
},
|
|
"def5678": {
|
|
"document_id": "doc_def5678",
|
|
"text": "This is document text for def5678",
|
|
"metadata": {"source": "source_def5678"}
|
|
}
|
|
})
|
|
|
|
return agent
|
|
|
|
|
|
async def collect_stream_output(stream):
|
|
"""Collect all output from a stream into a list."""
|
|
output = []
|
|
async for event in stream:
|
|
output.append(event)
|
|
return output
|
|
|
|
|
|
from core.utils import extract_citation_spans, find_new_citation_spans
|