This commit is contained in:
emrgnt-cmplxty
2024-12-13 16:28:25 -08:00
parent 481d187740
commit d145346aa4
5 changed files with 38 additions and 55 deletions
@@ -1,29 +1,13 @@
name: R2R Full Python Integration Test (ubuntu)
on:
push:
branches:
- main
paths:
- 'py/**'
- '.github/workflows/**'
- 'tests/**'
pull_request:
branches:
- dev
- dev-minor
- main
paths:
- 'py/**'
- '.github/workflows/**'
- 'tests/**'
workflow_dispatch:
jobs:
integration-test:
runs-on: ubuntu-latest
timeout-minutes: 30
env:
TELEMETRY_ENABLED: 'false'
R2R_PROJECT_NAME: r2r_default
@@ -39,7 +23,7 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python and install dependencies
uses: ./.github/actions/setup-python-full
with:
@@ -47,20 +31,20 @@ jobs:
python-version: '3.12'
poetry-version: '1.7.1'
r2r-version: 'latest'
- name: Setup and start Docker
uses: ./.github/actions/setup-docker
id: docker-setup
- name: Login Docker
uses: ./.github/actions/login-docker
with:
docker_username: ${{ secrets.RAGTORICHES_DOCKER_UNAME }}
docker_password: ${{ secrets.RAGTORICHES_DOCKER_TOKEN }}
- name: Start R2R Full server
uses: ./.github/actions/start-r2r-full
- name: Wait for server to be ready
run: |
timeout=300 # 5 minutes timeout
@@ -73,7 +57,7 @@ jobs:
sleep 5
timeout=$((timeout - 5))
done
- name: Run R2R Full Python Integration Test
run: |
cd py && poetry run pytest tests/ \
@@ -82,7 +66,7 @@ jobs:
--log-cli-level=INFO \
--junit-xml=test-results/junit.xml \
--html=test-results/report.html
- name: Upload test results
if: always()
uses: actions/upload-artifact@v3
@@ -91,7 +75,7 @@ jobs:
path: |
test-results/
pytest-logs/
- name: Check for test failures
if: failure()
run: |
@@ -107,4 +91,4 @@ jobs:
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
--health-retries 5
@@ -69,7 +69,7 @@ jobs:
- name: Wait for server to be ready
run: |
timeout=180 # 3 minutes timeout
timeout=300 # 5 minutes timeout
while ! curl -s http://localhost:8000/health > /dev/null; do
if [ $timeout -le 0 ]; then
echo "Server failed to start within timeout"
+10 -11
View File
@@ -2,30 +2,29 @@
import logging
import os
import warnings
from typing import Any, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional
from ..base.abstractions import VectorQuantizationType
from ..base.providers import (
DatabaseConfig,
DatabaseProvider,
PostgresConfigurationSettings,
)
from ..base.abstractions import VectorQuantizationType
from .base import PostgresConnectionManager
from .base import PostgresConnectionManager, SemaphoreConnectionPool
from .chunks import PostgresChunksHandler
from .collections import PostgresCollectionsHandler
from .conversations import PostgresConversationsHandler
from .documents import PostgresDocumentsHandler
from .files import PostgresFilesHandler
from .graphs import (
PostgresCommunitiesHandler,
PostgresEntitiesHandler,
PostgresGraphsHandler,
PostgresRelationshipsHandler,
PostgresEntitiesHandler,
PostgresCommunitiesHandler,
)
from .prompts_handler import PostgresPromptsHandler
from .tokens import PostgresTokensHandler
from .users import PostgresUserHandler
from .chunks import PostgresChunksHandler
from .conversations import PostgresConversationsHandler
from .base import SemaphoreConnectionPool
if TYPE_CHECKING:
from ..providers.crypto import BCryptProvider
@@ -74,7 +73,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
graphs_handler: PostgresGraphsHandler
prompts_handler: PostgresPromptsHandler
files_handler: PostgresFilesHandler
conversation_handler: PostgresConversationsHandler
conversations_handler: PostgresConversationsHandler
def __init__(
self,
@@ -161,7 +160,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
self.dimension,
self.quantization_type,
)
self.conversation_handler = PostgresConversationsHandler(
self.conversations_handler = PostgresConversationsHandler(
self.project_name, self.connection_manager
)
self.entities_handler = PostgresEntitiesHandler(
@@ -229,7 +228,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
await self.communities_handler.create_tables()
await self.entities_handler.create_tables()
await self.relationships_handler.create_tables()
await self.conversation_handler.create_tables()
await self.conversations_handler.create_tables()
def _get_postgres_configuration_settings(
self, config: DatabaseConfig
+12 -10
View File
@@ -637,7 +637,7 @@ class ManagementService(Service):
) -> dict:
try:
return await self.providers.database.prompts_handler.get_prompt( # type: ignore
prompt_name=prompt_name,
name=prompt_name,
inputs=inputs,
prompt_override=prompt_override,
)
@@ -677,14 +677,14 @@ class ManagementService(Service):
conversation_id: str,
auth_user=None,
) -> Tuple[str, list[Message], list[dict]]:
return await self.providers.database.conversation_handler.get_conversation( # type: ignore
return await self.providers.database.conversations_handler.get_conversation( # type: ignore
conversation_id
)
async def verify_conversation_access(
self, conversation_id: str, user_id: UUID
) -> bool:
return await self.providers.database.conversation_handler.verify_conversation_access(
return await self.providers.database.conversations_handler.verify_conversation_access(
conversation_id, user_id
)
@@ -692,7 +692,7 @@ class ManagementService(Service):
async def create_conversation(
self, user_id: Optional[UUID] = None, auth_user=None
) -> dict:
return await self.providers.database.conversation_handler.create_conversation( # type: ignore
return await self.providers.database.conversations_handler.create_conversation( # type: ignore
user_id=user_id
)
@@ -705,7 +705,7 @@ class ManagementService(Service):
user_ids: Optional[UUID | list[UUID]] = None,
auth_user=None,
) -> dict[str, list[dict] | int]:
return await self.providers.database.conversation_handler.get_conversations_overview(
return await self.providers.database.conversations_handler.get_conversations_overview(
offset=offset,
limit=limit,
user_ids=user_ids,
@@ -721,7 +721,7 @@ class ManagementService(Service):
metadata: Optional[dict] = None,
auth_user=None,
) -> str:
return await self.providers.database.conversation_handler.add_message(
return await self.providers.database.conversations_handler.add_message(
conversation_id, content, parent_id, metadata
)
@@ -733,20 +733,22 @@ class ManagementService(Service):
additional_metadata: dict,
auth_user=None,
) -> Tuple[str, str]:
return await self.providers.database.conversation_handler.edit_message(
message_id, new_content, additional_metadata
return (
await self.providers.database.conversations_handler.edit_message(
message_id, new_content, additional_metadata
)
)
@telemetry_event("updateMessageMetadata")
async def update_message_metadata(
self, message_id: str, metadata: dict, auth_user=None
):
await self.providers.database.conversation_handler.update_message_metadata(
await self.providers.database.conversations_handler.update_message_metadata(
message_id, metadata
)
@telemetry_event("DeleteConversation")
async def delete_conversation(self, conversation_id: str, auth_user=None):
await self.providers.database.conversation_handler.delete_conversation(
await self.providers.database.conversations_handler.delete_conversation(
conversation_id
)
+5 -7
View File
@@ -307,10 +307,8 @@ class RetrievalService(Service):
if conversation_id: # Fetch the existing conversation
try:
conversation = (
await self.logging_connection.get_conversation(
conversation_id=conversation_id
)
conversation = await self.providers.database.conversations_handler.get_conversations_overview(
conversation_id=conversation_id
)
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}")
@@ -330,7 +328,7 @@ class RetrievalService(Service):
messages = messages_from_conversation + messages
else: # Create new conversation
conversation_response = (
await self.logging_connection.create_conversation()
await self.providers.database.conversations_handler.create_conversation()
)
conversation_id = conversation_response.id
@@ -347,7 +345,7 @@ class RetrievalService(Service):
# Save the new message to the conversation
parent_id = ids[-1] if ids else None
message_response = await self.logging_connection.add_message(
message_response = await self.providers.database.conversations_handler.add_message(
conversation_id=conversation_id,
content=current_message,
parent_id=parent_id,
@@ -398,7 +396,7 @@ class RetrievalService(Service):
role="assistant", content=str(results[-1])
)
await self.logging_connection.add_message(
await self.providers.database.conversations_handler.add_message(
conversation_id=conversation_id,
content=assistant_message,
parent_id=message_id,