From fb6353eebc18292420afbd0154c8b1a6bd33b22b Mon Sep 17 00:00:00 2001 From: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 19 Feb 2025 15:18:12 -0800 Subject: [PATCH] Ruff Formatting and Tests Fixes (#1988) * Check in * Add code quality checks to CI/CD * Fix workflow --- .github/workflows/quality.yml | 28 ++ .gitignore | 1 + py/core/__init__.py | 1 - py/core/agent/base.py | 6 +- py/core/agent/rag.py | 106 ++--- py/core/base/abstractions/__init__.py | 1 + py/core/base/agent/agent.py | 4 +- py/core/base/providers/auth.py | 4 +- py/core/base/providers/base.py | 7 +- py/core/base/providers/crypto.py | 28 +- py/core/base/providers/database.py | 11 +- py/core/main/api/v3/base_router.py | 23 +- py/core/main/api/v3/chunks_router.py | 89 ++-- py/core/main/api/v3/collections_router.py | 345 ++++++-------- py/core/main/api/v3/conversations_router.py | 206 +++------ py/core/main/api/v3/documents_router.py | 378 ++++++--------- py/core/main/api/v3/graph_router.py | 326 +++++-------- py/core/main/api/v3/indices_router.py | 91 ++-- py/core/main/api/v3/prompts_router.py | 120 ++--- py/core/main/api/v3/retrieval_router.py | 156 +++---- py/core/main/api/v3/system_router.py | 54 +-- py/core/main/api/v3/users_router.py | 437 +++++++----------- py/core/main/app_entry.py | 1 - py/core/main/assembly/factory.py | 10 +- py/core/main/config.py | 2 +- .../orchestration/hatchet/graph_workflow.py | 12 +- .../hatchet/ingestion_workflow.py | 8 +- .../simple/ingestion_workflow.py | 24 +- py/core/main/services/auth_service.py | 27 +- py/core/main/services/graph_service.py | 72 ++- py/core/main/services/ingestion_service.py | 68 ++- py/core/main/services/management_service.py | 29 +- py/core/main/services/retrieval_service.py | 61 ++- py/core/parsers/media/audio_parser.py | 4 +- py/core/parsers/media/doc_parser.py | 4 +- py/core/parsers/media/img_parser.py | 3 +- py/core/parsers/media/odt_parser.py | 2 +- py/core/parsers/media/pdf_parser.py | 18 +- 
py/core/parsers/media/ppt_parser.py | 2 +- py/core/parsers/media/rtf_parser.py | 2 +- py/core/parsers/structured/__init__.py | 1 + py/core/parsers/structured/epub_parser.py | 2 +- py/core/parsers/structured/json_parser.py | 5 +- py/core/parsers/structured/msg_parser.py | 1 - py/core/parsers/structured/org_parser.py | 2 +- py/core/parsers/structured/p7s_parser.py | 7 +- py/core/parsers/structured/rst_parser.py | 2 +- py/core/parsers/structured/tiff_parser.py | 4 +- py/core/providers/auth/jwt.py | 6 +- py/core/providers/auth/r2r_auth.py | 33 +- py/core/providers/auth/supabase.py | 15 +- py/core/providers/crypto/bcrypt.py | 8 +- py/core/providers/crypto/nacl.py | 15 +- py/core/providers/database/base.py | 9 +- py/core/providers/database/chunks.py | 51 +- py/core/providers/database/collections.py | 37 +- py/core/providers/database/conversations.py | 31 +- py/core/providers/database/documents.py | 45 +- py/core/providers/database/files.py | 6 +- py/core/providers/database/filters.py | 19 +- py/core/providers/database/graphs.py | 122 +++-- py/core/providers/database/limits.py | 21 +- py/core/providers/database/prompts_handler.py | 54 ++- py/core/providers/database/users.py | 34 +- py/core/providers/email/console_mock.py | 27 +- py/core/providers/email/sendgrid.py | 4 +- py/core/providers/email/smtp.py | 8 +- py/core/providers/embeddings/litellm.py | 8 +- py/core/providers/embeddings/ollama.py | 4 +- py/core/providers/ingestion/r2r/base.py | 1 - py/core/providers/llm/anthropic.py | 35 +- py/core/providers/llm/openai.py | 3 +- py/core/providers/llm/r2r_llm.py | 34 +- py/core/providers/orchestration/hatchet.py | 2 +- py/core/telemetry/posthog.py | 7 +- py/core/utils/logging_config.py | 7 +- py/core/utils/serper.py | 5 +- ...2fac23e4d91b_migrate_to_document_search.py | 37 +- .../3efc7b3b1b3d_add_total_tokens_count.py | 32 +- ...70560f406_add_limits_overrides_to_users.py | 5 +- .../8077140e1e99_v3_api_database_revision.py | 5 +- ...cf6a8a4_add_user_and_document_count_to_.py | 
11 +- .../d342e632358a_migrate_to_asyncpg.py | 24 +- py/pyproject.toml | 1 + py/sdk/asnyc_methods/__init__.py | 40 +- py/sdk/asnyc_methods/chunks.py | 22 +- py/sdk/asnyc_methods/collections.py | 39 +- py/sdk/asnyc_methods/conversations.py | 29 +- py/sdk/asnyc_methods/documents.py | 56 +-- py/sdk/asnyc_methods/graphs.py | 65 +-- py/sdk/asnyc_methods/indices.py | 13 +- py/sdk/asnyc_methods/prompts.py | 20 +- py/sdk/asnyc_methods/retrieval.py | 17 +- py/sdk/asnyc_methods/system.py | 14 +- py/sdk/asnyc_methods/users.py | 75 ++- py/sdk/async_client.py | 4 +- py/sdk/sync_methods/__init__.py | 40 +- py/sdk/sync_methods/chunks.py | 22 +- py/sdk/sync_methods/collections.py | 39 +- py/sdk/sync_methods/conversations.py | 29 +- py/sdk/sync_methods/documents.py | 59 +-- py/sdk/sync_methods/graphs.py | 65 +-- py/sdk/sync_methods/indices.py | 13 +- py/sdk/sync_methods/prompts.py | 20 +- py/sdk/sync_methods/retrieval.py | 17 +- py/sdk/sync_methods/system.py | 14 +- py/sdk/sync_methods/users.py | 75 ++- py/shared/abstractions/__init__.py | 1 + py/shared/abstractions/document.py | 11 +- py/shared/abstractions/exception.py | 8 +- py/shared/abstractions/graph.py | 24 +- py/shared/abstractions/llm.py | 1 - py/shared/abstractions/search.py | 30 +- py/shared/abstractions/vector.py | 27 +- py/shared/api/models/management/responses.py | 1 - py/shared/api/models/retrieval/responses.py | 57 ++- py/shared/utils/base_utils.py | 96 ++-- py/shared/utils/splitter/text.py | 140 +++--- py/tests/integration/conftest.py | 24 +- py/tests/integration/test_base.py | 5 +- py/tests/integration/test_chunks.py | 128 +++-- .../integration/test_collection_id_filter.py | 92 ++-- py/tests/integration/test_collections.py | 88 ++-- .../test_collections_users_interaction.py | 245 ++++------ py/tests/integration/test_conversations.py | 68 ++- py/tests/integration/test_documents.py | 197 ++++---- py/tests/integration/test_filters.py | 99 ++-- py/tests/integration/test_graphs.py | 114 +++-- 
py/tests/integration/test_indices.py | 9 +- py/tests/integration/test_ingestion.py | 29 +- py/tests/integration/test_retrieval.py | 377 ++++++++++----- .../integration/test_retrieval_advanced.py | 56 +-- py/tests/integration/test_system.py | 12 - py/tests/integration/test_users.py | 116 ++--- py/tests/scaling/loadTester.py | 25 +- py/tests/unit/conftest.py | 25 +- py/tests/unit/test_chunks.py | 121 +++-- py/tests/unit/test_citations.py | 257 +++++----- py/tests/unit/test_collections.py | 88 ++-- py/tests/unit/test_config.py | 135 +++--- py/tests/unit/test_conversations.py | 22 +- py/tests/unit/test_documents.py | 61 +-- py/tests/unit/test_graphs.py | 187 ++++---- py/tests/unit/test_limits.py | 162 +++---- py/tests/unit/test_prompts.py | 57 +-- py/tests/unit/test_routes.py | 72 +-- 146 files changed, 3310 insertions(+), 4305 deletions(-) create mode 100644 .github/workflows/quality.yml diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml new file mode 100644 index 000000000..be5530c68 --- /dev/null +++ b/.github/workflows/quality.yml @@ -0,0 +1,28 @@ +name: Code Quality Checks + +on: + push: + branches: [ '**' ] + pull_request: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit + + - name: Run pre-commit hooks (excluding mypy) + env: + SKIP: mypy + run: | + pre-commit run --all-files diff --git a/.gitignore b/.gitignore index b33fdcc7e..801f614c1 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ go.work.sum .vscode/ .python-version +.ruff_cache/ diff --git a/py/core/__init__.py b/py/core/__init__.py index d1cbe310c..a719136e3 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -29,7 +29,6 @@ logger.propagate = False logging.getLogger("httpx").setLevel(logging.WARNING) 
logging.getLogger("LiteLLM").setLevel(logging.WARNING) - __all__ = [ "R2RAgent", "R2RStreamingAgent", diff --git a/py/core/agent/base.py b/py/core/agent/base.py index ccdaff3df..0856a92e9 100644 --- a/py/core/agent/base.py +++ b/py/core/agent/base.py @@ -326,8 +326,8 @@ class R2RStreamingReasoningAgent(R2RStreamingAgent): *args, **kwargs, ) -> AsyncGenerator[str, None]: - """ - Revised processing for the reasoning agent. + """Revised processing for the reasoning agent. + This version: 1. Accumulates tool calls in a list (each with a unique internal_id). 2. When finish_reason == "tool_calls", it records the tool calls in the conversation, @@ -420,7 +420,7 @@ class R2RStreamingReasoningAgent(R2RStreamingAgent): await self.conversation.add_message(assistant_msg) # Execute tool calls in parallel - for idx, tool_call in pending_tool_calls.items(): + for _idx, tool_call in pending_tool_calls.items(): if inside_thoughts: yield "" yield "" diff --git a/py/core/agent/rag.py b/py/core/agent/rag.py index bf5bf9e5f..6a9cfb881 100644 --- a/py/core/agent/rag.py +++ b/py/core/agent/rag.py @@ -36,10 +36,11 @@ COMPUTE_FAILURE = "I failed to reach a conclusion with my allowed comp class SearchResultsCollector: - """ - Collects search results in the form (source_type, result_obj, aggregator_index). - aggregator_index increments globally so that the nth item appended - is always aggregator_index == n, across the entire conversation. + """Collects search results in the form (source_type, result_obj, + aggregator_index). + + aggregator_index increments globally so that the nth item appended is + always aggregator_index == n, across the entire conversation. """ def __init__(self): @@ -48,10 +49,9 @@ class SearchResultsCollector: self._next_index = 1 # 1-based indexing def add_aggregate_result(self, agg: "AggregateSearchResult"): - """ - Flatten the chunk_search_results, graph_search_results, web_search_results, - and context_document_results, each assigned a unique aggregator index. 
- """ + """Flatten the chunk_search_results, graph_search_results, + web_search_results, and context_document_results, each assigned a + unique aggregator index.""" if agg.chunk_search_results: for c in agg.chunk_search_results: self._results_in_order.append(("chunk", c, self._next_index)) @@ -75,10 +75,8 @@ class SearchResultsCollector: self._next_index += 1 def get_all_results(self) -> list[Tuple[str, Any, int]]: - """ - Return list of (source_type, result_obj, aggregator_index), - in the order appended. - """ + """Return list of (source_type, result_obj, aggregator_index), in the + order appended.""" return self._results_in_order @@ -87,14 +85,13 @@ def num_tokens(text, model="gpt-4o"): encoding = tiktoken.encoding_for_model(model) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") - """Return the number of tokens used by a list of messages for both user and assistant.""" return len(encoding.encode(text, disallowed_special=())) class RAGAgentMixin: - """ - A Mixin for adding local_search, web_search, and content tools + """A Mixin for adding local_search, web_search, and content tools. + to your R2R Agents. This allows your agent to: - call local_search_method (semantic/hybrid search) - call content_method (fetch entire doc/chunk structures) @@ -121,10 +118,8 @@ class RAGAgentMixin: super().__init__(*args, **kwargs) def _register_tools(self): - """ - Called by the base agent to register all requested tools - from self.config.tools. - """ + """Called by the base agent to register all requested tools from + self.config.tools.""" if not self.config.tools: return for tool_name in set(self.config.tools): @@ -141,10 +136,8 @@ class RAGAgentMixin: # Local Search Tool def local_search(self) -> Tool: - """ - Tool to do a semantic/hybrid search on the local knowledge base - using self.local_search_method. 
- """ + """Tool to do a semantic/hybrid search on the local knowledge base + using self.local_search_method.""" return Tool( name="local_search", description=( @@ -172,9 +165,10 @@ class RAGAgentMixin: *args, **kwargs, ) -> AggregateSearchResult: - """ - Calls the passed-in `local_search_method(query, search_settings)`. - Expects either an AggregateSearchResult or a dict with chunk_search_results, etc. + """Calls the passed-in `local_search_method(query, search_settings)`. + + Expects either an AggregateSearchResult or a dict with + chunk_search_results, etc. """ if not self.local_search_method: raise ValueError( @@ -203,9 +197,10 @@ class RAGAgentMixin: # 2) Local Context def content(self) -> Tool: - """ - Tool to fetch entire documents from the local database. Typically used if the agent needs - deeper or more structured context from documents, not just chunk-level hits. + """Tool to fetch entire documents from the local database. + + Typically used if the agent needs deeper or more structured context + from documents, not just chunk-level hits. """ if "gemini" in self.rag_generation_config.model: tool = Tool( @@ -269,9 +264,10 @@ class RAGAgentMixin: *args, **kwargs, ) -> AggregateSearchResult: - """ - Calls the passed-in `content_method(filters, options)` to fetch - doc+chunk structures. Typically returns a list of dicts: + """Calls the passed-in `content_method(filters, options)` to fetch + doc+chunk structures. + + Typically returns a list of dicts: [ { 'document': {...}, 'chunks': [ {...}, {...}, ... ] }, ... @@ -395,9 +391,8 @@ class RAGAgentMixin: class R2RRAGAgent(RAGAgentMixin, R2RAgent): - """ - Non-streaming RAG Agent that supports local_search, content, web_search. 
- """ + """Non-streaming RAG Agent that supports local_search, content, + web_search.""" def __init__( self, @@ -438,9 +433,8 @@ class R2RRAGAgent(RAGAgentMixin, R2RAgent): class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent): - """ - Streaming-capable RAG Agent that supports local_search, content, web_search. - """ + """Streaming-capable RAG Agent that supports local_search, content, + web_search.""" def __init__( self, @@ -484,9 +478,8 @@ class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent): class R2RStreamingReasoningRAGAgent(RAGAgentMixin, R2RStreamingReasoningAgent): - """ - Streaming-capable RAG Agent that supports local_search, content, web_search. - """ + """Streaming-capable RAG Agent that supports local_search, content, + web_search.""" def __init__( self, @@ -574,8 +567,9 @@ class R2RXMLToolsStreamingReasoningRAGAgent(R2RStreamingReasoningRAGAgent): *args, **kwargs, ) -> AsyncGenerator[str, None]: - """ - Iterative approach with chain-of-thought wrapped in ... each iteration. + """Iterative approach with chain-of-thought wrapped in + ... each iteration. + 1) In each iteration (up to max_steps): a) Call _generate_thinking_response(conversation_context). b) Stream chain-of-thought tokens *inline* but enclosed by .... @@ -717,8 +711,8 @@ class R2RXMLToolsStreamingReasoningRAGAgent(R2RStreamingReasoningRAGAgent): return def _build_single_user_prompt(self, conversation_msgs: list[dict]) -> str: - """ - Converts system+user+assistant messages into a single text prompt. + """Converts system+user+assistant messages into a single text prompt. + Overridable if you want a different style. """ system_msgs = [] @@ -737,8 +731,7 @@ class R2RXMLToolsStreamingReasoningRAGAgent(R2RStreamingReasoningRAGAgent): @staticmethod def _parse_tool_calls(text: str) -> list[dict]: - """ - Parse tool calls from XML-like text. + """Parse tool calls from XML-like text. 
This function locates blocks (or, if not present, the entire text) and then extracts all blocks within. It patches incomplete tags and @@ -885,9 +878,8 @@ class R2RXMLToolsStreamingReasoningRAGAgent(R2RStreamingReasoningRAGAgent): class GeminiXMLToolsStreamingReasoningRAGAgent( R2RXMLToolsStreamingReasoningRAGAgent ): - """ - A Gemini-based implementation that uses the `XMLToolsStreamingRAGAgentBase`. - """ + """A Gemini-based implementation that uses the + `XMLToolsStreamingRAGAgentBase`.""" def __init__( self, @@ -920,8 +912,9 @@ class GeminiXMLToolsStreamingReasoningRAGAgent( *args, **kwargs, ) -> AsyncGenerator[str, None]: - """ - Iterative approach with chain-of-thought wrapped in ... each iteration. + """Iterative approach with chain-of-thought wrapped in + ... each iteration. + 1) In each iteration (up to max_steps): a) Call _generate_thinking_response(conversation_context). b) Stream chain-of-thought tokens *inline* but enclosed by .... @@ -1096,8 +1089,8 @@ class GeminiXMLToolsStreamingReasoningRAGAgent( initial_delay: float = 1.0, **kwargs, ) -> AsyncGenerator[tuple[bool, str], None]: - """ - Generate thinking response with retry logic for handling transient failures. + """Generate thinking response with retry logic for handling transient + failures. Args: user_prompt: The prompt to send to Gemini @@ -1159,9 +1152,8 @@ class GeminiXMLToolsStreamingReasoningRAGAgent( return def _parse_action_blocks(self, text: str) -> list[dict]: - """ - Find ... blocks in 'text' using simple regex, - then parse out blocks within each . + """Find ... blocks in 'text' using simple regex, then + parse out blocks within each . 
Returns a list of dicts, each with: { diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 93b213ef2..157ba9757 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -100,6 +100,7 @@ __all__ = [ "R2RException", # Graph abstractions "Entity", + "Graph", "Community", "StoreType", "GraphExtraction", diff --git a/py/core/base/agent/agent.py b/py/core/base/agent/agent.py index 234113188..1fad4ce2e 100644 --- a/py/core/base/agent/agent.py +++ b/py/core/base/agent/agent.py @@ -138,7 +138,7 @@ class Agent(ABC): last_message["role"] in ["tool", "function"] and last_message["content"] != "" and "ollama" in self.rag_generation_config.model - or self.config.include_tools == False + or not self.config.include_tools ): return GenerationConfig( **self.rag_generation_config.model_dump( @@ -207,7 +207,7 @@ class Agent(ABC): raise R2RException( message=f"Error parsing function arguments: {e}, agent likely produced invalid tool inputs.", status_code=400, - ) + ) from e merged_kwargs = {**kwargs, **function_args} raw_result = await tool.results_function(*args, **merged_kwargs) diff --git a/py/core/base/providers/auth.py b/py/core/base/providers/auth.py index 1ece483e5..8673b03dc 100644 --- a/py/core/base/providers/auth.py +++ b/py/core/base/providers/auth.py @@ -262,7 +262,7 @@ class AuthProvider(Provider, ABC): ) raise R2RException( status_code=401, message="Authentication failed" - ) + ) from None except Exception as e: logger.error(f"WebSocket error during auth: {e}") @@ -271,7 +271,7 @@ class AuthProvider(Provider, ABC): ) raise R2RException( status_code=401, message="Authentication failed" - ) + ) from None return _websocket_auth_wrapper diff --git a/py/core/base/providers/base.py b/py/core/base/providers/base.py index 962fca1d5..f90869cdb 100644 --- a/py/core/base/providers/base.py +++ b/py/core/base/providers/base.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class InnerConfig(BaseModel, 
ABC): - """A base provider configuration class""" + """A base provider configuration class.""" class Config: populate_by_name = True @@ -78,7 +78,7 @@ class AppConfig(InnerConfig): class ProviderConfig(BaseModel, ABC): - """A base provider configuration class""" + """A base provider configuration class.""" app: AppConfig # Add an app_config field extra_fields: dict[str, Any] = {} @@ -122,7 +122,8 @@ class ProviderConfig(BaseModel, ABC): class Provider(ABC): - """A base provider class to provide a common interface for all providers.""" + """A base provider class to provide a common interface for all + providers.""" def __init__(self, config: ProviderConfig, *args, **kwargs): if config: diff --git a/py/core/base/providers/crypto.py b/py/core/base/providers/crypto.py index b9ce1fecd..bdf794b01 100644 --- a/py/core/base/providers/crypto.py +++ b/py/core/base/providers/crypto.py @@ -27,14 +27,16 @@ class CryptoProvider(Provider, ABC): @abstractmethod def get_password_hash(self, password: str) -> str: - """Hash a plaintext password using a secure password hashing algorithm (e.g., Argon2i).""" + """Hash a plaintext password using a secure password hashing algorithm + (e.g., Argon2i).""" pass @abstractmethod def verify_password( self, plain_password: str, hashed_password: str ) -> bool: - """Verify that a plaintext password matches the given hashed password.""" + """Verify that a plaintext password matches the given hashed + password.""" pass @abstractmethod @@ -44,8 +46,7 @@ class CryptoProvider(Provider, ABC): @abstractmethod def generate_signing_keypair(self) -> Tuple[str, str, str]: - """ - Generate a new Ed25519 signing keypair for request signing. + """Generate a new Ed25519 signing keypair for request signing. Returns: A tuple of (key_id, private_key, public_key). 
@@ -57,20 +58,21 @@ class CryptoProvider(Provider, ABC): @abstractmethod def sign_request(self, private_key: str, data: str) -> str: - """Sign request data with an Ed25519 private key, returning the signature.""" + """Sign request data with an Ed25519 private key, returning the + signature.""" pass @abstractmethod def verify_request_signature( self, public_key: str, signature: str, data: str ) -> bool: - """Verify a request signature using the corresponding Ed25519 public key.""" + """Verify a request signature using the corresponding Ed25519 public + key.""" pass @abstractmethod def generate_api_key(self) -> Tuple[str, str]: - """ - Generate a new API key for a user. + """Generate a new API key for a user. Returns: A tuple (key_id, raw_api_key): @@ -81,8 +83,8 @@ class CryptoProvider(Provider, ABC): @abstractmethod def hash_api_key(self, raw_api_key: str) -> str: - """ - Hash a raw API key for secure storage in the database. + """Hash a raw API key for secure storage in the database. + Use strong parameters suitable for long-term secrets. """ pass @@ -94,8 +96,7 @@ class CryptoProvider(Provider, ABC): @abstractmethod def generate_secure_token(self, data: dict, expiry: datetime) -> str: - """ - Generate a secure, signed token (e.g., JWT) embedding claims. + """Generate a secure, signed token (e.g., JWT) embedding claims. Args: data: The claims to include in the token. @@ -108,8 +109,7 @@ class CryptoProvider(Provider, ABC): @abstractmethod def verify_secure_token(self, token: str) -> Optional[dict]: - """ - Verify a secure token (e.g., JWT). + """Verify a secure token (e.g., JWT). Args: token: The token string to verify. 
diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index ebc995870..1af1cc0db 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -71,11 +71,12 @@ class Handler(ABC): class PostgresConfigurationSettings(BaseModel): - """ - Configuration settings with defaults defined by the PGVector docker image. + """Configuration settings with defaults defined by the PGVector docker + image. - These settings are helpful in managing the connections to the database. - To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/ + These settings are helpful in managing the connections to the database. To + tune these settings for a specific deployment, see + https://pgtune.leopard.in.ua/ """ checkpoint_completion_target: Optional[float] = 0.9 @@ -114,7 +115,7 @@ class LimitSettings(BaseModel): class DatabaseConfig(ProviderConfig): - """A base database configuration class""" + """A base database configuration class.""" provider: str = "postgres" user: Optional[str] = None diff --git a/py/core/main/api/v3/base_router.py b/py/core/main/api/v3/base_router.py index 58098c8c5..e217a98e1 100644 --- a/py/core/main/api/v3/base_router.py +++ b/py/core/main/api/v3/base_router.py @@ -79,9 +79,8 @@ class BaseRouterV3: @classmethod def build_router(cls, engine): - """ - Class method for building a router instance (if you have a standard pattern). - """ + """Class method for building a router instance (if you have a standard + pattern).""" return cls(engine).router def _register_workflows(self): @@ -92,14 +91,12 @@ class BaseRouterV3: @abstractmethod def _setup_routes(self): - """ - Subclasses override this to define actual endpoints. - """ + """Subclasses override this to define actual endpoints.""" pass def set_rate_limiting(self): - """ - Adds a yield-based dependency for rate limiting each request. + """Adds a yield-based dependency for rate limiting each request. 
+ Checks the limits, then logs the request if the check passes. """ @@ -107,10 +104,10 @@ class BaseRouterV3: request: Request, auth_user=Depends(self.providers.auth.auth_wrapper()), ): - """ - 1) Fetch the user from the DB (including .limits_overrides). - 2) Pass it to limits_handler.check_limits. - 3) After the endpoint completes, call limits_handler.log_request. + """1) Fetch the user from the DB (including .limits_overrides). + + 2) Pass it to limits_handler.check_limits. 3) After the endpoint + completes, call limits_handler.log_request. """ # If the user is superuser, skip checks if auth_user.is_superuser: @@ -135,7 +132,7 @@ class BaseRouterV3: ) except ValueError as e: # If check_limits raises ValueError -> 429 Too Many Requests - raise HTTPException(status_code=429, detail=str(e)) + raise HTTPException(status_code=429, detail=str(e)) from e request.state.user_id = user_id request.state.route = route diff --git a/py/core/main/api/v3/chunks_router.py b/py/core/main/api/v3/chunks_router.py index 3bf398622..cfd8e14b8 100644 --- a/py/core/main/api/v3/chunks_router.py +++ b/py/core/main/api/v3/chunks_router.py @@ -47,8 +47,7 @@ class ChunksRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -58,8 +57,7 @@ class ChunksRouter(BaseRouterV3): "limit": 10 } ) - """ - ), + """), } ] }, @@ -73,8 +71,7 @@ class ChunksRouter(BaseRouterV3): auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedVectorSearchResponse: # type: ignore # TODO - Deduplicate this code by sharing the code on the retrieval router - """ - Perform a semantic search query over all stored chunks. + """Perform a semantic search query over all stored chunks. This endpoint allows for complex filtering of search results using PostgreSQL-based queries. Filters can be applied to various fields such as document_id, and internal metadata values. 
@@ -102,21 +99,18 @@ class ChunksRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() response = client.chunks.retrieve( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -128,8 +122,7 @@ class ChunksRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -139,11 +132,11 @@ class ChunksRouter(BaseRouterV3): id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedChunkResponse: - """ - Get a specific chunk by its ID. + """Get a specific chunk by its ID. - Returns the chunk's content, metadata, and associated document/collection information. - Users can only retrieve chunks they own or have access to through collections. + Returns the chunk's content, metadata, and associated + document/collection information. Users can only retrieve chunks + they own or have access to through collections. """ chunk = await self.services.ingestion.get_chunk(id) if not chunk: @@ -173,8 +166,7 @@ class ChunksRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -185,13 +177,11 @@ class ChunksRouter(BaseRouterV3): "metadata": {"key": "new value"} } ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -205,8 +195,7 @@ class ChunksRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -218,11 +207,11 @@ class ChunksRouter(BaseRouterV3): # TODO: Run with orchestration? 
auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedChunkResponse: - """ - Update an existing chunk's content and/or metadata. + """Update an existing chunk's content and/or metadata. - The chunk's vectors will be automatically recomputed based on the new content. - Users can only update chunks they own unless they are superusers. + The chunk's vectors will be automatically recomputed based on the + new content. Users can only update chunks they own unless they are + superusers. """ # Get the existing chunk to get its chunk_id existing_chunk = await self.services.ingestion.get_chunk( @@ -266,21 +255,18 @@ class ChunksRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() response = client.chunks.delete( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -292,8 +278,7 @@ class ChunksRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -303,12 +288,11 @@ class ChunksRouter(BaseRouterV3): id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete a specific chunk by ID. + """Delete a specific chunk by ID. - This permanently removes the chunk and its associated vector embeddings. - The parent document remains unchanged. Users can only delete chunks they - own unless they are superusers. + This permanently removes the chunk and its associated vector + embeddings. The parent document remains unchanged. Users can only + delete chunks they own unless they are superusers. 
""" # Get the existing chunk to get its chunk_id existing_chunk = await self.services.ingestion.get_chunk(id) @@ -339,8 +323,7 @@ class ChunksRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -350,13 +333,11 @@ class ChunksRouter(BaseRouterV3): offset=0, limit=10, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -371,8 +352,7 @@ class ChunksRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -398,15 +378,14 @@ class ChunksRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedChunksResponse: - """ - List chunks with pagination support. + """List chunks with pagination support. Returns a paginated list of chunks that the user has access to. Results can be filtered and sorted based on various parameters. Vector embeddings are only included if specifically requested. - Regular users can only list chunks they own or have access to through - collections. Superusers can list all chunks in the system. + Regular users can only list chunks they own or have access to + through collections. Superusers can list all chunks in the system. 
""" # Build filters filters = {} diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 2bdc6466c..c664ff92e 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -1,5 +1,6 @@ import logging import textwrap +from enum import Enum from typing import Optional from uuid import UUID @@ -30,9 +31,6 @@ from .base_router import BaseRouterV3 logger = logging.getLogger() -from enum import Enum - - class CollectionAction(str, Enum): VIEW = "view" EDIT = "edit" @@ -45,8 +43,8 @@ class CollectionAction(str, Enum): async def authorize_collection_action( auth_user, collection_id: UUID, action: CollectionAction, services ) -> bool: - """ - Authorize a user's action on a given collection based on: + """Authorize a user's action on a given collection based on: + - If user is superuser (admin): Full access. - If user is owner of the collection: Full access. - If user is a member of the collection (in `collection_ids`): VIEW only. 
@@ -103,8 +101,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -114,13 +111,11 @@ class CollectionsRouter(BaseRouterV3): name="My New Collection", description="This is a sample collection" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -133,19 +128,16 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/collections" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{"name": "My New Collection", "description": "This is a sample collection"}' - """ - ), + """), }, ] }, @@ -158,11 +150,12 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCollectionResponse: - """ - Create a new collection and automatically add the creating user to it. + """Create a new collection and automatically add the creating user + to it. - This endpoint allows authenticated users to create a new collection with a specified name - and optional description. The user creating the collection is automatically added as a member. + This endpoint allows authenticated users to create a new collection + with a specified name and optional description. The user creating + the collection is automatically added as a member. 
""" user_collections_count = ( await self.services.management.collections_overview( @@ -198,8 +191,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -210,13 +202,11 @@ class CollectionsRouter(BaseRouterV3): columns=["id", "name", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -230,21 +220,18 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/collections/export" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -263,9 +250,7 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export collections as a CSV file. 
- """ + """Export collections as a CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -298,8 +283,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -309,13 +293,11 @@ class CollectionsRouter(BaseRouterV3): offset=0, limit=10, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -325,17 +307,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/collections?offset=0&limit=10&name=Sample" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -359,13 +338,15 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCollectionsResponse: - """ - Returns a paginated list of collections the authenticated user has access to. + """Returns a paginated list of collections the authenticated user + has access to. - Results can be filtered by providing specific collection IDs. Regular users will only see - collections they own or have access to. Superusers can see all collections. + Results can be filtered by providing specific collection IDs. + Regular users will only see collections they own or have access to. + Superusers can see all collections. - The collections are returned in order of last modification, with most recent first. + The collections are returned in order of last modification, with + most recent first. 
""" requesting_user_id = ( None if auth_user.is_superuser else [auth_user.id] @@ -399,21 +380,18 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -423,17 +401,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -445,11 +420,11 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCollectionResponse: - """ - Get details of a specific collection. + """Get details of a specific collection. - This endpoint retrieves detailed information about a single collection identified by its UUID. - The user must have access to the collection to view its details. + This endpoint retrieves detailed information about a single + collection identified by its UUID. The user must have access to the + collection to view its details. 
""" await authorize_collection_action( auth_user, id, CollectionAction.VIEW, self.services @@ -480,8 +455,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -492,13 +466,11 @@ class CollectionsRouter(BaseRouterV3): name="Updated Collection Name", description="Updated description" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -512,19 +484,16 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{"name": "Updated Collection Name", "description": "Updated description"}' - """ - ), + """), }, ] }, @@ -547,11 +516,11 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCollectionResponse: - """ - Update an existing collection's configuration. + """Update an existing collection's configuration. - This endpoint allows updating the name and description of an existing collection. - The user must have appropriate permissions to modify the collection. + This endpoint allows updating the name and description of an + existing collection. The user must have appropriate permissions to + modify the collection. """ await authorize_collection_action( auth_user, id, CollectionAction.EDIT, self.services @@ -578,21 +547,18 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -602,17 +568,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -625,12 +588,12 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete an existing collection. + """Delete an existing collection. - This endpoint allows deletion of a collection identified by its UUID. - The user must have appropriate permissions to delete the collection. - Deleting a collection removes all associations but does not delete the documents within it. + This endpoint allows deletion of a collection identified by its + UUID. The user must have appropriate permissions to delete the + collection. Deleting a collection removes all associations but does + not delete the documents within it. 
""" if id == generate_default_user_collection_id(auth_user.id): raise R2RException( @@ -652,8 +615,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -663,13 +625,11 @@ class CollectionsRouter(BaseRouterV3): "123e4567-e89b-12d3-a456-426614174000", "456e789a-b12c-34d5-e678-901234567890" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -682,17 +642,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -703,9 +660,7 @@ class CollectionsRouter(BaseRouterV3): document_id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Add a document to a collection. 
- """ + """Add a document to a collection.""" await authorize_collection_action( auth_user, id, CollectionAction.ADD_DOCUMENT, self.services ) @@ -724,8 +679,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -736,13 +690,11 @@ class CollectionsRouter(BaseRouterV3): offset=0, limit=10, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -752,17 +704,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents?offset=0&limit=10" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -785,11 +734,12 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedDocumentsResponse: - """ - Get all documents in a collection with pagination and sorting options. + """Get all documents in a collection with pagination and sorting + options. - This endpoint retrieves a paginated list of documents associated with a specific collection. - It supports sorting options to customize the order of returned documents. + This endpoint retrieves a paginated list of documents associated + with a specific collection. It supports sorting options to + customize the order of returned documents. 
""" await authorize_collection_action( auth_user, id, CollectionAction.VIEW, self.services @@ -815,8 +765,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -826,13 +775,11 @@ class CollectionsRouter(BaseRouterV3): "123e4567-e89b-12d3-a456-426614174000", "456e789a-b12c-34d5-e678-901234567890" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -845,17 +792,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/documents/456e789a-b12c-34d5-e678-901234567890" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -871,11 +815,11 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Remove a document from a collection. + """Remove a document from a collection. - This endpoint removes the association between a document and a collection. - It does not delete the document itself. The user must have permissions to modify the collection. + This endpoint removes the association between a document and a + collection. It does not delete the document itself. The user must + have permissions to modify the collection. 
""" await authorize_collection_action( auth_user, id, CollectionAction.REMOVE_DOCUMENT, self.services @@ -893,8 +837,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -905,13 +848,11 @@ class CollectionsRouter(BaseRouterV3): offset=0, limit=10, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -923,17 +864,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users?offset=0&limit=10" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -956,11 +894,12 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedUsersResponse: - """ - Get all users in a collection with pagination and sorting options. + """Get all users in a collection with pagination and sorting + options. - This endpoint retrieves a paginated list of users who have access to a specific collection. - It supports sorting options to customize the order of returned users. + This endpoint retrieves a paginated list of users who have access + to a specific collection. It supports sorting options to customize + the order of returned users. 
""" await authorize_collection_action( auth_user, id, CollectionAction.VIEW, self.services @@ -986,8 +925,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -997,13 +935,11 @@ class CollectionsRouter(BaseRouterV3): "123e4567-e89b-12d3-a456-426614174000", "789a012b-c34d-5e6f-g789-012345678901" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1016,17 +952,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1041,11 +974,11 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Add a user to a collection. + """Add a user to a collection. - This endpoint grants a user access to a specific collection. - The authenticated user must have admin permissions for the collection to add new users. + This endpoint grants a user access to a specific collection. The + authenticated user must have admin permissions for the collection + to add new users. 
""" await authorize_collection_action( auth_user, id, CollectionAction.MANAGE_USERS, self.services @@ -1064,8 +997,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1075,13 +1007,11 @@ class CollectionsRouter(BaseRouterV3): "123e4567-e89b-12d3-a456-426614174000", "789a012b-c34d-5e6f-g789-012345678901" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1094,17 +1024,14 @@ class CollectionsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/collections/123e4567-e89b-12d3-a456-426614174000/users/789a012b-c34d-5e6f-g789-012345678901" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1119,11 +1046,11 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Remove a user from a collection. + """Remove a user from a collection. - This endpoint revokes a user's access to a specific collection. - The authenticated user must have admin permissions for the collection to remove users. + This endpoint revokes a user's access to a specific collection. The + authenticated user must have admin permissions for the collection + to remove users. 
""" await authorize_collection_action( auth_user, id, CollectionAction.MANAGE_USERS, self.services @@ -1144,8 +1071,7 @@ class CollectionsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1154,8 +1080,7 @@ class CollectionsRouter(BaseRouterV3): result = client.documents.extract( id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1" ) - """ - ), + """), }, ], }, @@ -1176,11 +1101,11 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Extracts entities and relationships from a document. - The entities and relationships extraction process involves: - 1. Parsing documents into semantic chunks - 2. Extracting entities and relationships using LLMs + """Extracts entities and relationships from a document. + + The entities and relationships extraction process involves: + 1. Parsing documents into semantic chunks + 2. Extracting entities and relationships using LLMs """ await authorize_collection_action( auth_user, id, CollectionAction.EDIT, self.services @@ -1248,10 +1173,10 @@ class CollectionsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCollectionResponse: - """ - Retrieve a collection by its (owner_id, name) combination. - The authenticated user can only fetch collections they own, - or, if superuser, from anyone. + """Retrieve a collection by its (owner_id, name) combination. + + The authenticated user can only fetch collections they own, or, if + superuser, from anyone. 
""" if auth_user.is_superuser: if not owner_id: diff --git a/py/core/main/api/v3/conversations_router.py b/py/core/main/api/v3/conversations_router.py index de58baf5d..7d01435e4 100644 --- a/py/core/main/api/v3/conversations_router.py +++ b/py/core/main/api/v3/conversations_router.py @@ -40,21 +40,18 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) result = client.conversations.create() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -64,17 +61,14 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/conversations" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -86,10 +80,10 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedConversationResponse: - """ - Create a new conversation. + """Create a new conversation. - This endpoint initializes a new conversation for the authenticated user. + This endpoint initializes a new conversation for the authenticated + user. 
""" user_id = auth_user.id @@ -106,8 +100,7 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -117,13 +110,11 @@ class ConversationsRouter(BaseRouterV3): offset=0, limit=10, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -133,17 +124,14 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/conversations?offset=0&limit=10" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -167,10 +155,10 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedConversationsResponse: - """ - List conversations with pagination and sorting options. + """List conversations with pagination and sorting options. - This endpoint returns a paginated list of conversations for the authenticated user. + This endpoint returns a paginated list of conversations for the + authenticated user. 
""" requesting_user_id = ( None if auth_user.is_superuser else [auth_user.id] @@ -200,8 +188,7 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -212,13 +199,11 @@ class ConversationsRouter(BaseRouterV3): columns=["id", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -232,21 +217,18 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/conversations/export" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -265,9 +247,7 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export conversations as a downloadable CSV file. 
- """ + """Export conversations as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -300,8 +280,7 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -312,13 +291,11 @@ class ConversationsRouter(BaseRouterV3): columns=["id", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -332,21 +309,18 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/conversations/export_messages" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -365,9 +339,7 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export conversations as a downloadable CSV file. 
- """ + """Export conversations as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -400,8 +372,7 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -410,13 +381,11 @@ class ConversationsRouter(BaseRouterV3): result = client.conversations.get( "123e4567-e89b-12d3-a456-426614174000" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -428,17 +397,14 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -450,10 +416,10 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedConversationMessagesResponse: - """ - Get details of a specific conversation. + """Get details of a specific conversation. - This endpoint retrieves detailed information about a single conversation identified by its UUID. + This endpoint retrieves detailed information about a single + conversation identified by its UUID. """ requesting_user_id = ( None if auth_user.is_superuser else [auth_user.id] @@ -473,21 +439,18 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -500,19 +463,16 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -d '{"name": "new_name"}' - """ - ), + """), }, ] }, @@ -530,10 +490,10 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedConversationResponse: - """ - Update an existing conversation. + """Update an existing conversation. - This endpoint updates the name of an existing conversation identified by its UUID. + This endpoint updates the name of an existing conversation + identified by its UUID. """ return await self.services.management.update_conversation( conversation_id=id, @@ -548,21 +508,18 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -574,17 +531,14 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -597,8 +551,7 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete an existing conversation. + """Delete an existing conversation. This endpoint deletes a conversation identified by its UUID. """ @@ -620,8 +573,7 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -634,13 +586,11 @@ class ConversationsRouter(BaseRouterV3): parent_id="parent_message_id", metadata={"key": "value"} ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -655,19 +605,16 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" \\ -d '{"content": "Hello, world!", "parent_id": "parent_message_id", "metadata": {"key": "value"}}' - """ - ), + """), }, ] }, @@ -691,8 +638,7 @@ class ConversationsRouter(BaseRouterV3): ), 
auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedMessageResponse: - """ - Add a new message to a conversation. + """Add a new message to a conversation. This endpoint adds a new message to an existing conversation. """ @@ -716,8 +662,7 @@ class ConversationsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -728,13 +673,11 @@ class ConversationsRouter(BaseRouterV3): "message_id_to_update", content="Updated content" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -748,19 +691,16 @@ class ConversationsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages/message_id_to_update" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" \\ -d '{"content": "Updated content"}' - """ - ), + """), }, ] }, @@ -781,10 +721,10 @@ class ConversationsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedMessageResponse: - """ - Update an existing message in a conversation. + """Update an existing message in a conversation. - This endpoint updates the content of an existing message in a conversation. + This endpoint updates the content of an existing message in a + conversation. 
""" return await self.services.management.edit_message( message_id=message_id, diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index d46aa4cbf..96b01d3af 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -94,10 +94,9 @@ class DocumentsRouter(BaseRouterV3): search_mode: SearchMode, search_settings: Optional[SearchSettings], ) -> SearchSettings: - """ - Prepare the effective search settings based on the provided search_mode, - optional user-overrides in search_settings, and applied filters. - """ + """Prepare the effective search settings based on the provided + search_mode, optional user-overrides in search_settings, and applied + filters.""" if search_mode != SearchMode.custom: # Start from mode defaults @@ -195,8 +194,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -207,13 +205,11 @@ class DocumentsRouter(BaseRouterV3): metadata={"metadata_1":"some random metadata"}, id=None ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -226,21 +222,18 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/documents" \\ -H "Content-Type: multipart/form-data" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -F "file=@pg_essay_1.html;type=text/html" \\ -F 'metadata={}' \\ -F 'id=null' - """ - ), + """), }, ] }, @@ -578,8 +571,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -590,13 +582,11 
@@ class DocumentsRouter(BaseRouterV3): columns=["id", "title", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -610,21 +600,18 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/documents/export" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -643,9 +630,7 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export documents as a downloadable CSV file. - """ + """Export documents as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -679,24 +664,20 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" client.documents.download_zip( document_ids=["uuid1", "uuid2"], start_date="2024-01-01", end_date="2024-12-31" ) - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents/download_zip?document_ids=uuid1,uuid2&start_date=2024-01-01&end_date=2024-12-31" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -717,8 +698,8 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> StreamingResponse: - """ - Export multiple documents as a zip file. Documents can be filtered by IDs and/or date range. + """Export multiple documents as a zip file. Documents can be + filtered by IDs and/or date range. 
The endpoint allows downloading: - Specific documents by providing their IDs @@ -780,8 +761,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -791,13 +771,11 @@ class DocumentsRouter(BaseRouterV3): limit=10, offset=0 ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -810,17 +788,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -848,13 +823,15 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedDocumentsResponse: - """ - Returns a paginated list of documents the authenticated user has access to. + """Returns a paginated list of documents the authenticated user has + access to. - Results can be filtered by providing specific document IDs. Regular users will only see - documents they own or have access to through collections. Superusers can see all documents. + Results can be filtered by providing specific document IDs. Regular + users will only see documents they own or have access to through + collections. Superusers can see all documents. - The documents are returned in order of last modification, with most recent first. + The documents are returned in order of last modification, with most + recent first. 
""" requesting_user_id = ( None if auth_user.is_superuser else [auth_user.id] @@ -894,8 +871,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -904,13 +880,11 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.retrieve( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -922,17 +896,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -945,8 +916,8 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedDocumentResponse: - """ - Retrieves detailed information about a specific document by its ID. + """Retrieves detailed information about a specific document by its + ID. This endpoint returns the document's metadata, status, and system information. It does not return the document's content - use the `/documents/{id}/download` endpoint for that. 
@@ -982,8 +953,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -992,13 +962,11 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.list_chunks( id="32b6a70f-a995-5c51-85d2-834f06283a1e" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1010,17 +978,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/chunks" \\ -H "Authorization: Bearer YOUR_API_KEY"\ - """ - ), + """), }, ] }, @@ -1048,16 +1013,16 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedChunksResponse: - """ - Retrieves the text chunks that were generated from a document during ingestion. - Chunks represent semantic sections of the document and are used for retrieval - and analysis. + """Retrieves the text chunks that were generated from a document + during ingestion. Chunks represent semantic sections of the + document and are used for retrieval and analysis. - Users can only access chunks from documents they own or have access to through - collections. Vector embeddings are only included if specifically requested. + Users can only access chunks from documents they own or have access + to through collections. Vector embeddings are only included if + specifically requested. - Results are returned in chunk sequence order, representing their position in - the original document. + Results are returned in chunk sequence order, representing their + position in the original document. 
""" list_document_chunks = ( await self.services.management.list_document_chunks( @@ -1108,8 +1073,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1118,13 +1082,11 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.download( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1136,17 +1098,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/download" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1156,20 +1115,20 @@ class DocumentsRouter(BaseRouterV3): id: str = Path(..., description="Document ID"), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> StreamingResponse: - """ - Downloads the original file content of a document. + """Downloads the original file content of a document. - For uploaded files, returns the original file with its proper MIME type. - For text-only documents, returns the content as plain text. + For uploaded files, returns the original file with its proper MIME + type. For text-only documents, returns the content as plain text. - Users can only download documents they own or have access to through collections. + Users can only download documents they own or have access to + through collections. """ try: document_uuid = UUID(id) except ValueError: raise R2RException( status_code=422, message="Invalid document ID format." 
- ) + ) from None # Retrieve the document's information documents_overview_response = ( @@ -1253,25 +1212,21 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) response = client.documents.delete_by_filter( filters={"document_type": {"$eq": "txt"}} ) - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/documents/by-filter?filters=%7B%22document_type%22%3A%7B%22%24eq%22%3A%22text%22%7D%2C%22created_at%22%3A%7B%22%24lt%22%3A%222023-01-01T00%3A00%3A00Z%22%7D%7D" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1283,8 +1238,9 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete documents based on provided filters. Allowed operators + """Delete documents based on provided filters. + + Allowed operators include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. Deletion requests are limited to a user's own documents. 
@@ -1309,8 +1265,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1319,13 +1274,11 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.delete( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1337,17 +1290,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1357,8 +1307,9 @@ class DocumentsRouter(BaseRouterV3): id: UUID = Path(..., description="Document ID"), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete a specific document. All chunks corresponding to the document are deleted, and all other references to the document are removed. + """Delete a specific document. All chunks corresponding to the + document are deleted, and all other references to the document are + removed. NOTE - Deletions do not yet impact the knowledge graph or other derived data. This feature is planned for a future release. 
""" @@ -1384,8 +1335,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1394,13 +1344,11 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.list_collections( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", offset=0, limit=10 ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1412,17 +1360,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/collections" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1443,9 +1388,9 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCollectionsResponse: - """ - Retrieves all collections that contain the specified document. This endpoint is restricted - to superusers only and provides a system-wide view of document organization. + """Retrieves all collections that contain the specified document. + This endpoint is restricted to superusers only and provides a + system-wide view of document organization. Collections are used to organize documents and manage access control. A document can belong to multiple collections, and users can access documents through collection membership. 
@@ -1481,8 +1426,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1491,8 +1435,7 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.extract( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, ], }, @@ -1513,8 +1456,7 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Extracts entities and relationships from a document. + """Extracts entities and relationships from a document. The entities and relationships extraction process involves: @@ -1606,8 +1548,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1615,11 +1556,11 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.deduplicate( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), + }, + { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1631,15 +1572,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), + }, + { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/deduplicate" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ], }, @@ -1660,9 +1600,7 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Deduplicates entities from a document. 
- """ + """Deduplicates entities from a document.""" settings = settings.model_dump() if settings else None # type: ignore documents_overview_response = ( @@ -1745,8 +1683,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1755,8 +1692,7 @@ class DocumentsRouter(BaseRouterV3): response = client.documents.extract( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, ], }, @@ -1784,14 +1720,16 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedEntitiesResponse: - """ - Retrieves the entities that were extracted from a document. These represent - important semantic elements like people, places, organizations, concepts, etc. + """Retrieves the entities that were extracted from a document. + These represent important semantic elements like people, places, + organizations, concepts, etc. - Users can only access entities from documents they own or have access to through - collections. Entity embeddings are only included if specifically requested. + Users can only access entities from documents they own or have + access to through collections. Entity embeddings are only included + if specifically requested. - Results are returned in the order they were extracted from the document. + Results are returned in the order they were extracted from the + document. 
""" # if ( # not auth_user.is_superuser @@ -1844,8 +1782,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -1857,13 +1794,11 @@ class DocumentsRouter(BaseRouterV3): columns=["id", "title", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -1878,21 +1813,18 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -1915,9 +1847,7 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export documents as a downloadable CSV file. 
- """ + """Export documents as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -1951,8 +1881,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1963,13 +1892,11 @@ class DocumentsRouter(BaseRouterV3): offset=0, limit=100 ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1983,17 +1910,14 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/relationships" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -2025,14 +1949,16 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedRelationshipsResponse: - """ - Retrieves the relationships between entities that were extracted from a document. These represent - connections and interactions between entities found in the text. + """Retrieves the relationships between entities that were extracted + from a document. These represent connections and interactions + between entities found in the text. - Users can only access relationships from documents they own or have access to through - collections. Results can be filtered by entity names and relationship types. + Users can only access relationships from documents they own or have + access to through collections. Results can be filtered by entity + names and relationship types. - Results are returned in the order they were extracted from the document. + Results are returned in the order they were extracted from the + document. 
""" # if ( # not auth_user.is_superuser @@ -2086,8 +2012,7 @@ class DocumentsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -2099,13 +2024,11 @@ class DocumentsRouter(BaseRouterV3): columns=["id", "title", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -2120,21 +2043,18 @@ class DocumentsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/documents/export_entities" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -2157,9 +2077,7 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export documents as a downloadable CSV file. - """ + """Export documents as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -2214,8 +2132,8 @@ class DocumentsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedDocumentSearchResponse: - """ - Perform a search query on the automatically generated document summaries in the system. + """Perform a search query on the automatically generated document + summaries in the system. This endpoint allows for complex filtering of search results using PostgreSQL-based queries. Filters can be applied to various fields such as document_id, and internal metadata values. 
diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 8cfc57d87..7481c9327 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -78,7 +78,8 @@ class GraphRouter(BaseRouterV3): async def _get_collection_id( self, collection_id: Optional[UUID], auth_user ) -> UUID: - """Helper method to get collection ID, using default if none provided""" + """Helper method to get collection ID, using default if none + provided.""" if collection_id is None: return generate_default_user_collection_id(auth_user.id) return collection_id @@ -141,13 +142,15 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGraphsResponse: - """ - Returns a paginated list of graphs the authenticated user has access to. + """Returns a paginated list of graphs the authenticated user has + access to. - Results can be filtered by providing specific graph IDs. Regular users will only see - graphs they own or have access to. Superusers can see all graphs. + Results can be filtered by providing specific graph IDs. Regular + users will only see graphs they own or have access to. Superusers + can see all graphs. - The graphs are returned in order of last modification, with most recent first. + The graphs are returned in order of last modification, with most + recent first. 
""" requesting_user_id = ( None if auth_user.is_superuser else [auth_user.id] @@ -175,8 +178,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -184,13 +186,11 @@ class GraphRouter(BaseRouterV3): response = client.graphs.get( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -202,16 +202,13 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), + -H "Authorization: Bearer YOUR_API_KEY" """), }, ] }, @@ -221,9 +218,7 @@ class GraphRouter(BaseRouterV3): collection_id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGraphResponse: - """ - Retrieves detailed information about a specific graph by ID. - """ + """Retrieves detailed information about a specific graph by ID.""" if ( # not auth_user.is_superuser collection_id not in auth_user.collection_ids @@ -257,8 +252,8 @@ class GraphRouter(BaseRouterV3): run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Creates communities in the graph by analyzing entity relationships and similarities. + """Creates communities in the graph by analyzing entity + relationships and similarities. Communities are created through the following process: 1. 
Analyzes entity relationships and metadata to build a similarity graph @@ -352,8 +347,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -361,13 +355,11 @@ class GraphRouter(BaseRouterV3): response = client.graphs.reset( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -379,16 +371,13 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/reset" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), + -H "Authorization: Bearer YOUR_API_KEY" """), }, ] }, @@ -398,13 +387,13 @@ class GraphRouter(BaseRouterV3): collection_id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Deletes a graph and all its associated data. + """Deletes a graph and all its associated data. - This endpoint permanently removes the specified graph along with all - entities and relationships that belong to only this graph. - The original source entities and relationships extracted from underlying documents are not deleted - and are managed through the document lifecycle. + This endpoint permanently removes the specified graph along with + all entities and relationships that belong to only this graph. The + original source entities and relationships extracted from + underlying documents are not deleted and are managed through the + document lifecycle. 
""" if not auth_user.is_superuser: raise R2RException("Only superusers can reset a graph", 403) @@ -431,8 +420,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -444,13 +432,11 @@ class GraphRouter(BaseRouterV3): "name": "New Name", "description": "New Description" } - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -464,8 +450,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -484,11 +469,11 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGraphResponse: - """ - Update an existing graphs's configuration. + """Update an existing graphs's configuration. - This endpoint allows updating the name and description of an existing collection. - The user must have appropriate permissions to modify the collection. + This endpoint allows updating the name and description of an + existing collection. The user must have appropriate permissions to + modify the collection. """ if not auth_user.is_superuser: raise R2RException( @@ -517,21 +502,18 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
response = client.graphs.list_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -543,8 +525,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ], }, @@ -596,8 +577,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -609,13 +589,11 @@ class GraphRouter(BaseRouterV3): columns=["id", "title", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -630,21 +608,18 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/graphs/export_entities" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -667,9 +642,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export documents as a downloadable CSV file. 
- """ + """Export documents as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -810,8 +783,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -823,13 +795,11 @@ class GraphRouter(BaseRouterV3): columns=["id", "title", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -844,21 +814,18 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/graphs/export_relationships" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -881,9 +848,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export documents as a downloadable CSV file. 
- """ + """Export documents as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -916,8 +881,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -927,13 +891,11 @@ class GraphRouter(BaseRouterV3): collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -946,8 +908,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1041,8 +1002,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1052,13 +1012,11 @@ class GraphRouter(BaseRouterV3): collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1071,8 +1029,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1119,21 +1076,18 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1145,8 +1099,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ], }, @@ -1170,9 +1123,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedRelationshipsResponse: - """ - Lists all relationships in the graph with pagination support. - """ + """Lists all relationships in the graph with pagination support.""" if ( # not auth_user.is_superuser collection_id not in auth_user.collection_ids @@ -1200,8 +1151,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1211,13 +1161,11 @@ class GraphRouter(BaseRouterV3): collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1230,8 +1178,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ], }, @@ -1345,8 +1292,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1356,13 +1302,11 @@ class GraphRouter(BaseRouterV3): collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1375,8 +1319,7 @@ class 
GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ], }, @@ -1423,8 +1366,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1438,13 +1380,11 @@ class GraphRouter(BaseRouterV3): rating=5, rating_explanation="This is a rating explanation", ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1461,8 +1401,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1486,8 +1425,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCommunityResponse: - """ - Creates a new community in the graph. + """Creates a new community in the graph. While communities are typically built automatically via the /graphs/{id}/communities/build endpoint, this endpoint allows you to manually create your own communities. @@ -1532,21 +1470,18 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1558,8 +1493,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1583,9 +1517,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCommunitiesResponse: - """ - Lists all communities in the graph with pagination support. 
- """ + """Lists all communities in the graph with pagination support.""" if ( # not auth_user.is_superuser collection_id not in auth_user.collection_ids @@ -1613,21 +1545,18 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1639,8 +1568,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1657,9 +1585,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCommunityResponse: - """ - Retrieves a specific community by its ID. - """ + """Retrieves a specific community by its ID.""" if ( # not auth_user.is_superuser collection_id not in auth_user.collection_ids @@ -1690,8 +1616,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1701,13 +1626,11 @@ class GraphRouter(BaseRouterV3): collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1720,8 +1643,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1769,8 +1691,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -1782,13 
+1703,11 @@ class GraphRouter(BaseRouterV3): columns=["id", "title", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -1803,27 +1722,24 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/graphs/export_communities" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, ) @self.base_endpoint - async def export_relationships( + async def export_communities( background_tasks: BackgroundTasks, collection_id: UUID = Path( ..., @@ -1840,9 +1756,7 @@ class GraphRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export documents as a downloadable CSV file. 
- """ + """Export documents as a downloadable CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -1876,8 +1790,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1891,13 +1804,11 @@ class GraphRouter(BaseRouterV3): "description": "Tech companies and products" } } - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1916,8 +1827,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1933,9 +1843,7 @@ class GraphRouter(BaseRouterV3): rating_explanation: Optional[str] = Body(None), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedCommunityResponse: - """ - Updates an existing community in the graph. - """ + """Updates an existing community in the graph.""" if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids @@ -1970,8 +1878,7 @@ class GraphRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1979,13 +1886,11 @@ class GraphRouter(BaseRouterV3): response = client.graphs.pull( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1997,8 +1902,7 @@ class GraphRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -2017,8 +1921,8 @@ class GraphRouter(BaseRouterV3): # ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Adds documents to a graph by copying their entities and relationships. 
+ """Adds documents to a graph by copying their entities and + relationships. This endpoint: 1. Copies document entities to the graphs_entities table diff --git a/py/core/main/api/v3/indices_router.py b/py/core/main/api/v3/indices_router.py index 0661d403b..84f9a0215 100644 --- a/py/core/main/api/v3/indices_router.py +++ b/py/core/main/api/v3/indices_router.py @@ -38,8 +38,7 @@ class IndicesRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -78,13 +77,11 @@ class IndicesRouter(BaseRouterV3): "concurrently": True } ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -109,13 +106,11 @@ class IndicesRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" # Create HNSW Index curl -X POST "https://api.example.com/indices" \\ -H "Content-Type: application/json" \\ @@ -155,8 +150,7 @@ class IndicesRouter(BaseRouterV3): "concurrently": true } }' - """ - ), + """), }, ] }, @@ -170,9 +164,12 @@ class IndicesRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Create a new vector similarity search index in over the target table. Allowed tables include 'vectors', 'entity', 'document_collections'. - Vectors correspond to the chunks of text that are indexed for similarity search, whereas entity and document_collections are created during knowledge graph construction. + """Create a new vector similarity search index in over the target + table. Allowed tables include 'vectors', 'entity', + 'document_collections'. 
Vectors correspond to the chunks of text + that are indexed for similarity search, whereas entity and + document_collections are created during knowledge graph + construction. This endpoint creates a database index optimized for efficient similarity search over vector embeddings. It supports two main indexing methods: @@ -238,8 +235,7 @@ class IndicesRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -249,13 +245,11 @@ class IndicesRouter(BaseRouterV3): offset=0, limit=10 ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -268,13 +262,11 @@ class IndicesRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/indices?offset=0&limit=10" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" @@ -283,8 +275,7 @@ class IndicesRouter(BaseRouterV3): curl -X GET "https://api.example.com/indices?offset=0&limit=10&filters={\"table_name\":\"vectors\"}" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" - """ - ), + """), }, ] }, @@ -305,8 +296,8 @@ class IndicesRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedVectorIndicesResponse: - """ - List existing vector similarity search indices with pagination support. + """List existing vector similarity search indices with pagination + support. 
Returns details about each index including: - Name and table name @@ -345,21 +336,18 @@ class IndicesRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # Get detailed information about a specific index index = client.indices.retrieve("index_1") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -374,17 +362,14 @@ class IndicesRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/indices/vectors/index_1" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -400,8 +385,7 @@ class IndicesRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedVectorIndexResponse: - """ - Get detailed information about a specific vector index. + """Get detailed information about a specific vector index. 
Returns comprehensive information about the index including: - Configuration details (method, measure, parameters) @@ -501,8 +485,7 @@ class IndicesRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -513,13 +496,11 @@ class IndicesRouter(BaseRouterV3): table_name="vectors", run_with_orchestration=True ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -534,18 +515,15 @@ class IndicesRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/indices/index_1" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -566,8 +544,7 @@ class IndicesRouter(BaseRouterV3): # run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Delete an existing vector similarity search index. + """Delete an existing vector similarity search index. This endpoint removes the specified index from the database. 
Important considerations: diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index 363c8dc90..55512143c 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -35,8 +35,7 @@ class PromptsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -47,13 +46,11 @@ class PromptsRouter(BaseRouterV3): template="Hello, {name}!", input_types={"name": "string"} ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -67,19 +64,16 @@ class PromptsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/prompts" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" \\ -d '{"name": "greeting_prompt", "template": "Hello, {name}!", "input_types": {"name": "string"}}' - """ - ), + """), }, ] }, @@ -96,10 +90,10 @@ class PromptsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Create a new prompt with the given configuration. + """Create a new prompt with the given configuration. - This endpoint allows superusers to create a new prompt with a specified name, template, and input types. + This endpoint allows superusers to create a new prompt with a + specified name, template, and input types. """ if not auth_user.is_superuser: raise R2RException( @@ -119,21 +113,18 @@ class PromptsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
result = client.prompts.list() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -143,17 +134,14 @@ class PromptsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/v3/prompts" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -162,10 +150,10 @@ class PromptsRouter(BaseRouterV3): async def get_prompts( auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedPromptsResponse: - """ - List all available prompts. + """List all available prompts. - This endpoint retrieves a list of all prompts in the system. Only superusers can access this endpoint. + This endpoint retrieves a list of all prompts in the system. Only + superusers can access this endpoint. """ if not auth_user.is_superuser: raise R2RException( @@ -191,8 +179,7 @@ class PromptsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -203,13 +190,11 @@ class PromptsRouter(BaseRouterV3): inputs={"name": "John"}, prompt_override="Hi, {name}!" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -223,17 +208,14 @@ class PromptsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/prompts/greeting_prompt?inputs=%7B%22name%22%3A%22John%22%7D&prompt_override=Hi%2C%20%7Bname%7D!" 
\\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -249,11 +231,12 @@ class PromptsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedPromptResponse: - """ - Get a specific prompt by name, optionally with inputs and override. + """Get a specific prompt by name, optionally with inputs and + override. - This endpoint retrieves a specific prompt and allows for optional inputs and template override. - Only superusers can access this endpoint. + This endpoint retrieves a specific prompt and allows for optional + inputs and template override. Only superusers can access this + endpoint. """ if not auth_user.is_superuser: raise R2RException( @@ -273,8 +256,7 @@ class PromptsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -285,13 +267,11 @@ class PromptsRouter(BaseRouterV3): template="Greetings, {name}!", input_types={"name": "string", "age": "integer"} ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -305,19 +285,16 @@ class PromptsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X PUT "https://api.example.com/v3/prompts/greeting_prompt" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" \\ -d '{"template": "Greetings, {name}!", "input_types": {"name": "string", "age": "integer"}}' - """ - ), + """), }, ] }, @@ -334,10 +311,10 @@ class PromptsRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedGenericMessageResponse: - """ - Update an existing prompt's template and/or input types. + """Update an existing prompt's template and/or input types. 
- This endpoint allows superusers to update the template and input types of an existing prompt. + This endpoint allows superusers to update the template and input + types of an existing prompt. """ if not auth_user.is_superuser: raise R2RException( @@ -357,21 +334,18 @@ class PromptsRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) result = client.prompts.delete("greeting_prompt") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -383,17 +357,14 @@ class PromptsRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/v3/prompts/greeting_prompt" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -403,8 +374,7 @@ class PromptsRouter(BaseRouterV3): name: str = Path(..., description="Prompt name"), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete a prompt by name. + """Delete a prompt by name. This endpoint allows superusers to delete an existing prompt. """ diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 79326f29a..d21257993 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -60,10 +60,9 @@ class RetrievalRouter(BaseRouterV3): search_mode: SearchMode, search_settings: Optional[SearchSettings], ) -> SearchSettings: - """ - Prepare the effective search settings based on the provided search_mode, - optional user-overrides in search_settings, and applied filters. 
- """ + """Prepare the effective search settings based on the provided + search_mode, optional user-overrides in search_settings, and applied + filters.""" if search_mode != SearchMode.custom: # Start from mode defaults effective_settings = SearchSettings.get_default(search_mode.value) @@ -91,8 +90,7 @@ class RetrievalRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -125,13 +123,11 @@ class RetrievalRouter(BaseRouterV3): "chunk_settings": {"index_measure": "l2_distance"} } ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -147,13 +143,11 @@ class RetrievalRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/retrieval/search" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ @@ -164,8 +158,7 @@ class RetrievalRouter(BaseRouterV3): use_semantic_search: true } }' - """ - ), + """), }, ] }, @@ -198,8 +191,8 @@ class RetrievalRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedSearchResponse: - """ - Perform a search query against vector and/or graph-based databases. + """Perform a search query against vector and/or graph-based + databases. **Search Modes:** - `basic`: Defaults to semantic search. Simple and easy to use. 
@@ -250,8 +243,7 @@ class RetrievalRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -273,13 +265,11 @@ class RetrievalRouter(BaseRouterV3): "max_tokens": 150 } ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -303,13 +293,11 @@ class RetrievalRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/retrieval/rag" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ @@ -329,8 +317,7 @@ class RetrievalRouter(BaseRouterV3): max_tokens: 150 } }' - """ - ), + """), }, ] }, @@ -372,8 +359,7 @@ class RetrievalRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedRAGResponse: - """ - Execute a RAG (Retrieval-Augmented Generation) query. + """Execute a RAG (Retrieval-Augmented Generation) query. This endpoint combines search results with language model generation. 
It supports the same filtering capabilities as the search endpoint, @@ -425,8 +411,7 @@ class RetrievalRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -456,13 +441,11 @@ class RetrievalRouter(BaseRouterV3): include_title_if_available=True, conversation_id="550e8400-e29b-41d4-a716-446655440000" # Optional for conversation continuity ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -495,13 +478,11 @@ class RetrievalRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/retrieval/agent" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ @@ -524,8 +505,7 @@ class RetrievalRouter(BaseRouterV3): "include_title_if_available": true, "conversation_id": "550e8400-e29b-41d4-a716-446655440000" }' - """ - ), + """), }, ] }, @@ -591,8 +571,8 @@ class RetrievalRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedAgentResponse: - """ - Engage with an intelligent RAG-powered conversational agent for complex information retrieval and analysis. + """Engage with an intelligent RAG-powered conversational agent for + complex information retrieval and analysis. This advanced endpoint combines retrieval-augmented generation (RAG) with a conversational AI agent to provide detailed, context-aware responses based on your document collection. 
The agent can: @@ -665,7 +645,7 @@ class RetrievalRouter(BaseRouterV3): else: return response except Exception as e: - raise R2RException(str(e), 500) + raise R2RException(str(e), 500) from e @self.router.post( "/retrieval/reasoning_agent", @@ -675,8 +655,7 @@ class RetrievalRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -694,13 +673,11 @@ class RetrievalRouter(BaseRouterV3): } conversation_id="550e8400-e29b-41d4-a716-446655440000" # Optional for conversation continuity ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -721,13 +698,11 @@ class RetrievalRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/retrieval/agent" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ @@ -738,8 +713,7 @@ class RetrievalRouter(BaseRouterV3): }, "conversation_id": "550e8400-e29b-41d4-a716-446655440000" }' - """ - ), + """), }, ] }, @@ -768,8 +742,8 @@ class RetrievalRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedAgentResponse: - """ - Engage with an intelligent RAG-powered conversational agent for complex information retrieval and analysis. + """Engage with an intelligent RAG-powered conversational agent for + complex information retrieval and analysis. This advanced endpoint combines retrieval-augmented generation (RAG) with a conversational AI agent to provide detailed, context-aware responses based on your document collection. 
The agent can: @@ -843,7 +817,7 @@ class RetrievalRouter(BaseRouterV3): else: return response except Exception as e: - raise R2RException(str(e), 500) + raise R2RException(str(e), 500) from e @self.router.post( "/retrieval/completion", @@ -853,8 +827,7 @@ class RetrievalRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -874,13 +847,11 @@ class RetrievalRouter(BaseRouterV3): "stream": False } ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -903,13 +874,11 @@ class RetrievalRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/retrieval/completion" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ @@ -927,8 +896,7 @@ class RetrievalRouter(BaseRouterV3): "stream": false } }' - """ - ), + """), }, ] }, @@ -967,14 +935,15 @@ class RetrievalRouter(BaseRouterV3): auth_user=Depends(self.providers.auth.auth_wrapper()), response_model=WrappedCompletionResponse, ) -> WrappedLLMChatCompletion: - """ - Generate completions for a list of messages. + """Generate completions for a list of messages. - This endpoint uses the language model to generate completions for the provided messages. - The generation process can be customized using the generation_config parameter. + This endpoint uses the language model to generate completions for + the provided messages. The generation process can be customized + using the generation_config parameter. - The messages list should contain alternating user and assistant messages, with an optional - system message at the start. Each message should have a 'role' and 'content'. 
+ The messages list should contain alternating user and assistant + messages, with an optional system message at the start. Each + message should have a 'role' and 'content'. """ return await self.services.retrieval.completion( @@ -990,8 +959,7 @@ class RetrievalRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1000,13 +968,11 @@ class RetrievalRouter(BaseRouterV3): result = client.retrieval.embedding( text="Who is Aristotle?", ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1018,21 +984,18 @@ class RetrievalRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/retrieval/embedding" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{ "text": "Who is Aristotle?", }' - """ - ), + """), }, ] }, @@ -1045,11 +1008,12 @@ class RetrievalRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedEmbeddingResponse: - """ - Generate embeddings for the provided text using the specified model. + """Generate embeddings for the provided text using the specified + model. - This endpoint uses the language model to generate embeddings for the provided text. - The model parameter specifies the model to use for generating embeddings. + This endpoint uses the language model to generate embeddings for + the provided text. The model parameter specifies the model to use + for generating embeddings. 
""" return await self.services.retrieval.embedding( diff --git a/py/core/main/api/v3/system_router.py b/py/core/main/api/v3/system_router.py index 27cff1d06..682be7509 100644 --- a/py/core/main/api/v3/system_router.py +++ b/py/core/main/api/v3/system_router.py @@ -36,21 +36,18 @@ class SystemRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) result = client.system.health() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -60,18 +57,15 @@ class SystemRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/health"\\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ - """ - ), + """), }, ] }, @@ -87,21 +81,18 @@ class SystemRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) 
result = client.system.settings() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -111,18 +102,15 @@ class SystemRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/system/settings" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ - """ - ), + """), }, ] }, @@ -145,21 +133,18 @@ class SystemRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # when using auth, do client.login(...) result = client.system.status() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -169,18 +154,15 @@ class SystemRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/system/status" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ - """ - ), + """), }, ] }, diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index b9565d439..d1181813d 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -60,21 +60,18 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() new_user = client.users.create( email="jane.doe@example.com", password="secure_password123" - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { 
r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -87,20 +84,17 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users" \\ -H "Content-Type: application/json" \\ -d '{ "email": "jane.doe@example.com", "password": "secure_password123" - }'""" - ), + }'"""), }, ] }, @@ -160,8 +154,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient("http://localhost:7272") @@ -172,13 +165,11 @@ class UsersRouter(BaseRouterV3): columns=["id", "name", "created_at"], include_header=True, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient("http://localhost:7272"); @@ -192,21 +183,18 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "http://127.0.0.1:7272/v3/users/export" \ -H "Authorization: Bearer YOUR_API_KEY" \ -H "Content-Type: application/json" \ -H "Accept: text/csv" \ -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \ --output export.csv - """ - ), + """), }, ] }, @@ -225,9 +213,7 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> FileResponse: - """ - Export users as a CSV file. 
- """ + """Export users as a CSV file.""" if not auth_user.is_superuser: raise R2RException( @@ -260,21 +246,18 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() tokens = client.users.verify_email( email="jane.doe@example.com", verification_code="1lklwal!awdclm" - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -287,18 +270,15 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/login" \\ -H "Content-Type: application/x-www-form-urlencoded" \\ -d "email=jane.doe@example.com&verification_code=1lklwal!awdclm" - """ - ), + """), }, ] }, @@ -337,20 +317,17 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() tokens = client.users.send_verification_email( email="jane.doe@example.com", - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -362,18 +339,15 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/send-verification-email" \\ -H "Content-Type: application/x-www-form-urlencoded" \\ -d "email=jane.doe@example.com" - """ - ), + """), }, ] }, @@ -407,8 +381,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = 
R2RClient() @@ -416,13 +389,11 @@ class UsersRouter(BaseRouterV3): email="jane.doe@example.com", password="secure_password123" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -435,18 +406,15 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/login" \\ -H "Content-Type: application/x-www-form-urlencoded" \\ -d "username=jane.doe@example.com&password=secure_password123" - """ - ), + """), }, ] }, @@ -467,20 +435,17 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # client.login(...) result = client.users.logout() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -490,17 +455,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/logout" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -521,21 +483,18 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() # client.login(...) 
new_tokens = client.users.refresh_token() - # New tokens are automatically stored in the client""" - ), + # New tokens are automatically stored in the client"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -545,19 +504,16 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/refresh-token" \\ -H "Content-Type: application/json" \\ -d '{ "refresh_token": "YOUR_REFRESH_TOKEN" - }'""" - ), + }'"""), }, ] }, @@ -580,8 +536,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -590,13 +545,11 @@ class UsersRouter(BaseRouterV3): result = client.users.change_password( current_password="old_password123", new_password="new_secure_password456" - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -609,21 +562,18 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/change-password" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" \\ -d '{ "current_password": "old_password123", "new_password": "new_secure_password456" - }'""" - ), + }'"""), }, ] }, @@ -650,20 +600,17 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() result = client.users.request_password_reset( email="jane.doe@example.com" - )""" - ), + )"""), }, { "lang": 
"JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -675,19 +622,16 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/request-password-reset" \\ -H "Content-Type: application/json" \\ -d '{ "email": "jane.doe@example.com" - }'""" - ), + }'"""), }, ] }, @@ -710,21 +654,18 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() result = client.users.reset_password( reset_token="reset_token_received_via_email", new_password="new_secure_password789" - )""" - ), + )"""), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -737,20 +678,17 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/v3/users/reset-password" \\ -H "Content-Type: application/json" \\ -d '{ "reset_token": "reset_token_received_via_email", "new_password": "new_secure_password789" - }'""" - ), + }'"""), }, ] }, @@ -774,8 +712,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -786,13 +723,11 @@ class UsersRouter(BaseRouterV3): offset=0, limit=100, ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -802,17 +737,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": 
"Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/users?offset=0&limit=100&username=john&email=john@example.com&is_active=true&is_superuser=false" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -843,8 +775,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedUsersResponse: - """ - List all users with pagination and filtering options. + """List all users with pagination and filtering options. + Only accessible by superusers. """ @@ -873,8 +805,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -882,13 +813,11 @@ class UsersRouter(BaseRouterV3): # Get user details users = client.users.me() - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -898,17 +827,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/users/me" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -917,9 +843,8 @@ class UsersRouter(BaseRouterV3): async def get_current_user( auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedUserResponse: - """ - Get detailed information about the currently authenticated user. 
- """ + """Get detailed information about the currently authenticated + user.""" return auth_user @self.router.get( @@ -930,8 +855,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -941,13 +865,11 @@ class UsersRouter(BaseRouterV3): users = client.users.retrieve( id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -959,17 +881,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -981,9 +900,10 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedUserResponse: - """ - Get detailed information about a specific user. - Users can only access their own information unless they are superusers. + """Get detailed information about a specific user. + + Users can only access their own information unless they are + superusers. 
""" if not auth_user.is_superuser and auth_user.id != id: raise R2RException( @@ -1009,8 +929,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1018,13 +937,11 @@ class UsersRouter(BaseRouterV3): # Delete user client.users.delete(id="550e8400-e29b-41d4-a716-446655440000", password="secure_password123") - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1037,8 +954,7 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, ] }, @@ -1057,8 +973,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete a specific user. + """Delete a specific user. + Users can only delete their own account unless they are superusers. """ if not auth_user.is_superuser and auth_user.id != id: @@ -1083,8 +999,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1096,13 +1011,11 @@ class UsersRouter(BaseRouterV3): offset=0, limit=100 ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1116,17 +1029,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections?offset=0&limit=100" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1149,9 +1059,10 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), 
) -> WrappedCollectionsResponse: - """ - Get all collections associated with a specific user. - Users can only access their own collections unless they are superusers. + """Get all collections associated with a specific user. + + Users can only access their own collections unless they are + superusers. """ if auth_user.id != id and not auth_user.is_superuser: raise R2RException( @@ -1178,8 +1089,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1190,13 +1100,11 @@ class UsersRouter(BaseRouterV3): id="550e8400-e29b-41d4-a716-446655440000", collection_id="750e8400-e29b-41d4-a716-446655440000" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1209,17 +1117,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1254,8 +1159,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1266,13 +1170,11 @@ class UsersRouter(BaseRouterV3): id="550e8400-e29b-41d4-a716-446655440000", collection_id="750e8400-e29b-41d4-a716-446655440000" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1285,17 +1187,14 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ 
+ "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/collections/750e8400-e29b-41d4-a716-446655440000" \\ -H "Authorization: Bearer YOUR_API_KEY" - """ - ), + """), }, ] }, @@ -1310,8 +1209,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Remove a user from a collection. + """Remove a user from a collection. + Requires either superuser status or access to the collection. """ if auth_user.id != id and not auth_user.is_superuser: @@ -1334,8 +1233,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1346,13 +1244,11 @@ class UsersRouter(BaseRouterV3): "550e8400-e29b-41d4-a716-446655440000", name="John Doe" ) - """ - ), + """), }, { "lang": "JavaScript", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" const { r2rClient } = require("r2r-js"); const client = new r2rClient(); @@ -1365,13 +1261,11 @@ class UsersRouter(BaseRouterV3): } main(); - """ - ), + """), }, { "lang": "Shell", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" \\ @@ -1379,8 +1273,7 @@ class UsersRouter(BaseRouterV3): "id": "550e8400-e29b-41d4-a716-446655440000", "name": "John Doe", }' - """ - ), + """), }, ] }, @@ -1407,10 +1300,11 @@ class UsersRouter(BaseRouterV3): metadata: dict[str, str | None] | None = None, auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedUserResponse: - """ - Update user information. - Users can only update their own information unless they are superusers. - Superuser status can only be modified by existing superusers. + """Update user information. 
+ + Users can only update their own information unless they are + superusers. Superuser status can only be modified by existing + superusers. """ if is_superuser is not None and not auth_user.is_superuser: @@ -1453,8 +1347,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1466,18 +1359,15 @@ class UsersRouter(BaseRouterV3): description="API key for accessing the app", ) # result["api_key"] contains the newly created API key - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\ -H "Authorization: Bearer YOUR_API_TOKEN" \\ -d '{"name": "My API Key", "description": "API key for accessing the app"}' - """ - ), + """), }, ] }, @@ -1495,8 +1385,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedAPIKeyResponse: - """ - Create a new API key for the specified user. + """Create a new API key for the specified user. + Only superusers or the user themselves may create an API key. 
""" if auth_user.id != id and not auth_user.is_superuser: @@ -1518,8 +1408,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient client = R2RClient() @@ -1528,17 +1417,14 @@ class UsersRouter(BaseRouterV3): keys = client.users.list_api_keys( id="550e8400-e29b-41d4-a716-446655440000" ) - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\ -H "Authorization: Bearer YOUR_API_TOKEN" - """ - ), + """), }, ] }, @@ -1550,8 +1436,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedAPIKeysResponse: - """ - List all API keys for the specified user. + """List all API keys for the specified user. + Only superusers or the user themselves may list the API keys. """ if auth_user.id != id and not auth_user.is_superuser: @@ -1575,8 +1461,7 @@ class UsersRouter(BaseRouterV3): "x-codeSamples": [ { "lang": "Python", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" from r2r import R2RClient from uuid import UUID @@ -1587,17 +1472,14 @@ class UsersRouter(BaseRouterV3): id="550e8400-e29b-41d4-a716-446655440000", key_id="d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" ) - """ - ), + """), }, { "lang": "cURL", - "source": textwrap.dedent( - """ + "source": textwrap.dedent(""" curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys/d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" \\ -H "Authorization: Bearer YOUR_API_TOKEN" - """ - ), + """), }, ] }, @@ -1610,8 +1492,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedBooleanResponse: - """ - Delete a specific API key for the specified user. + """Delete a specific API key for the specified user. 
+ Only superusers or the user themselves may delete the API key. """ if auth_user.id != id and not auth_user.is_superuser: @@ -1691,9 +1573,8 @@ class UsersRouter(BaseRouterV3): ), auth_user=Depends(self.providers.auth.auth_wrapper()), ) -> WrappedLimitsResponse: - """ - Return the system default limits, user-level overrides, and final "effective" limit settings - for the specified user. + """Return the system default limits, user-level overrides, and + final "effective" limit settings for the specified user. Only superusers or the user themself may fetch these values. """ @@ -1712,9 +1593,7 @@ class UsersRouter(BaseRouterV3): @self.router.get("/users/oauth/google/authorize") @self.base_endpoint async def google_authorize() -> WrappedGenericMessageResponse: - """ - Redirect user to Google's OAuth 2.0 consent screen. - """ + """Redirect user to Google's OAuth 2.0 consent screen.""" state = "some_random_string_or_csrf_token" # Usually you store a random state in session/Redis scope = "openid email profile" @@ -1736,8 +1615,8 @@ class UsersRouter(BaseRouterV3): async def google_callback( code: str = Query(...), state: str = Query(...) ) -> WrappedLoginResponse: - """ - Google's callback that will receive the `code` and `state`. + """Google's callback that will receive the `code` and `state`. + We then exchange code for tokens, verify, and log the user in. """ # 1. Exchange `code` for tokens @@ -1770,7 +1649,7 @@ class UsersRouter(BaseRouterV3): raise HTTPException( status_code=400, detail=f"Token verification failed: {str(e)}", - ) + ) from e # id_info will contain "sub", "email", etc. google_id = id_info["sub"] @@ -1787,9 +1666,7 @@ class UsersRouter(BaseRouterV3): @self.router.get("/users/oauth/github/authorize") @self.base_endpoint async def github_authorize() -> WrappedGenericMessageResponse: - """ - Redirect user to GitHub's OAuth consent screen. 
- """ + """Redirect user to GitHub's OAuth consent screen.""" state = "some_random_string_or_csrf_token" scope = "read:user user:email" @@ -1807,11 +1684,9 @@ class UsersRouter(BaseRouterV3): async def github_callback( code: str = Query(...), state: str = Query(...) ) -> WrappedLoginResponse: - """ - GitHub callback route to exchange code for an access_token, - then fetch user info from GitHub's API, - then do the same 'oauth-based' login or registration. - """ + """GitHub callback route to exchange code for an access_token, then + fetch user info from GitHub's API, then do the same 'oauth-based' + login or registration.""" # 1. Exchange code for access_token token_resp = requests.post( "https://github.com/login/oauth/access_token", diff --git a/py/core/main/app_entry.py b/py/core/main/app_entry.py index 930548436..83cefc8bd 100644 --- a/py/core/main/app_entry.py +++ b/py/core/main/app_entry.py @@ -14,7 +14,6 @@ from .assembly import R2RBuilder, R2RConfig logger, log_file = configure_logging() - # Global scheduler scheduler = AsyncIOScheduler() diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index a6d1ed9fc..5c285b0f8 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -15,11 +15,6 @@ from core.base import ( IngestionConfig, OrchestrationConfig, ) - -from ..abstractions import R2RProviders -from ..config import R2RConfig - -logger = logging.getLogger() from core.providers import ( AnthropicCompletionProvider, AsyncSMTPEmailProvider, @@ -47,6 +42,11 @@ from core.providers import ( UnstructuredIngestionProvider, ) +from ..abstractions import R2RProviders +from ..config import R2RConfig + +logger = logging.getLogger() + class R2RProviderFactory: def __init__(self, config: R2RConfig): diff --git a/py/core/main/config.py b/py/core/main/config.py index 101c24d2b..2a0306205 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -167,7 +167,7 @@ class R2RConfig: @staticmethod def 
_serialize_config(config_section: Any) -> dict: - """Serialize config section while excluding internal state""" + """Serialize config section while excluding internal state.""" if isinstance(config_section, dict): return { R2RConfig._serialize_key(k): R2RConfig._serialize_config(v) diff --git a/py/core/main/orchestration/hatchet/graph_workflow.py b/py/core/main/orchestration/hatchet/graph_workflow.py index 48a367640..26de8bcc1 100644 --- a/py/core/main/orchestration/hatchet/graph_workflow.py +++ b/py/core/main/orchestration/hatchet/graph_workflow.py @@ -5,6 +5,7 @@ import logging import math import time import uuid +from typing import TYPE_CHECKING from hatchet_sdk import ConcurrencyLimitStrategy, Context @@ -17,20 +18,19 @@ from core.base.abstractions import ( from ...services import GraphService -logger = logging.getLogger() -from typing import TYPE_CHECKING - if TYPE_CHECKING: from hatchet_sdk import Hatchet +logger = logging.getLogger() + def hatchet_graph_search_results_factory( orchestration_provider: OrchestrationProvider, service: GraphService ) -> dict[str, "Hatchet.Workflow"]: def convert_to_dict(input_data): - """ - Converts input data back to a plain dictionary format, handling special cases like UUID and GenerationConfig. - This is the inverse of get_input_data_dict. + """Converts input data back to a plain dictionary format, handling + special cases like UUID and GenerationConfig. This is the inverse of + get_input_data_dict. 
Args: input_data: Dictionary containing the input data with potentially special types diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index f5760e831..2e839c2a6 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -296,12 +296,12 @@ def hatchet_ingestion_factory( raise R2RException( status_code=401, message="Authentication error: Invalid API key or credentials.", - ) + ) from None except Exception as e: raise HTTPException( status_code=500, detail=f"Error during ingestion: {str(e)}", - ) + ) from e @orchestration_provider.failure() async def on_failure(self, context: Context) -> None: @@ -605,7 +605,7 @@ def hatchet_ingestion_factory( raise HTTPException( status_code=500, detail=f"Error during chunk update: {str(e)}", - ) + ) from e @orchestration_provider.failure() async def on_failure(self, context: Context) -> None: @@ -692,7 +692,7 @@ def hatchet_ingestion_factory( raise HTTPException( status_code=500, detail=f"Error during document metadata update: {str(e)}", - ) + ) from e @orchestration_provider.failure() async def on_failure(self, context: Context) -> None: diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py index 5e8806a75..a055ea34d 100644 --- a/py/core/main/orchestration/simple/ingestion_workflow.py +++ b/py/core/main/orchestration/simple/ingestion_workflow.py @@ -236,7 +236,7 @@ def simple_ingestion_factory(service: IngestionService): raise R2RException( status_code=401, message="Authentication error: Invalid API key or credentials.", - ) + ) from e except Exception as e: if document_info is not None: await service.update_document_status( @@ -248,7 +248,7 @@ def simple_ingestion_factory(service: IngestionService): raise raise HTTPException( status_code=500, detail=f"Error during ingestion: {str(e)}" - ) + ) from e 
async def update_files(input_data): from core.main import IngestionServiceAdapter @@ -267,12 +267,12 @@ def simple_ingestion_factory(service: IngestionService): if not file_datas: raise R2RException( status_code=400, message="No files provided for update." - ) + ) from None if len(document_ids) != len(file_datas): raise R2RException( status_code=400, message="Number of ids does not match number of files.", - ) + ) from None documents_overview = ( await service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. @@ -287,7 +287,7 @@ def simple_ingestion_factory(service: IngestionService): raise R2RException( status_code=404, message="One or more documents not found.", - ) + ) from None results = [] @@ -333,7 +333,7 @@ def simple_ingestion_factory(service: IngestionService): raise R2RException( status_code=501, message="Automatic extraction not yet implemented for `simple` ingestion workflows.", - ) + ) from None async def ingest_chunks(input_data): document_info = None @@ -461,7 +461,7 @@ def simple_ingestion_factory(service: IngestionService): raise R2RException( status_code=501, message="Automatic extraction not yet implemented for `simple` ingestion workflows.", - ) + ) from None except Exception as e: logger.error( @@ -478,7 +478,7 @@ def simple_ingestion_factory(service: IngestionService): raise HTTPException( status_code=500, detail=f"Error during chunk ingestion: {str(e)}", - ) + ) from e async def update_chunk(input_data): from core.main import IngestionServiceAdapter @@ -511,7 +511,7 @@ def simple_ingestion_factory(service: IngestionService): raise HTTPException( status_code=500, detail=f"Error during chunk update: {str(e)}", - ) + ) from e async def create_vector_index(input_data): try: @@ -531,7 +531,7 @@ def simple_ingestion_factory(service: IngestionService): raise HTTPException( status_code=500, detail=f"Error during vector index creation: 
{str(e)}", - ) + ) from e async def delete_vector_index(input_data): try: @@ -553,7 +553,7 @@ def simple_ingestion_factory(service: IngestionService): raise HTTPException( status_code=500, detail=f"Error during vector index deletion: {str(e)}", - ) + ) from e async def update_document_metadata(input_data): try: @@ -585,7 +585,7 @@ def simple_ingestion_factory(service: IngestionService): raise HTTPException( status_code=500, detail=f"Error during document metadata update: {str(e)}", - ) + ) from e return { "ingest-files": ingest_files, diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index 7a74033e3..980b574b6 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -236,9 +236,10 @@ class AuthService(Service): self, user_id: UUID, ) -> dict: - """ - Get only the verification code data for a specific user. - This method should be called after superuser authorization has been verified. + """Get only the verification code data for a specific user. + + This method should be called after superuser authorization has been + verified. """ verification_data = await self.providers.database.users_handler.get_user_validation_data( user_id=user_id @@ -257,9 +258,10 @@ class AuthService(Service): self, user_id: UUID, ) -> dict: - """ - Get only the verification code data for a specific user. - This method should be called after superuser authorization has been verified. + """Get only the verification code data for a specific user. + + This method should be called after superuser authorization has been + verified. """ verification_data = await self.providers.database.users_handler.get_user_validation_data( user_id=user_id @@ -275,8 +277,7 @@ class AuthService(Service): @telemetry_event("SendResetEmail") async def send_reset_email(self, email: str) -> dict: - """ - Generate a new verification code and send a reset email to the user. 
+ """Generate a new verification code and send a reset email to the user. Returns the verification code for testing/sandbox environments. Args: @@ -290,8 +291,8 @@ class AuthService(Service): async def create_user_api_key( self, user_id: UUID, name: Optional[str], description: Optional[str] ) -> dict: - """ - Generate a new API key for the user with optional name and description. + """Generate a new API key for the user with optional name and + description. Args: user_id (UUID): The ID of the user @@ -306,8 +307,7 @@ class AuthService(Service): ) async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool: - """ - Delete the API key for the user. + """Delete the API key for the user. Args: user_id (UUID): The ID of the user @@ -321,8 +321,7 @@ class AuthService(Service): ) async def list_user_api_keys(self, user_id: UUID) -> list[dict]: - """ - List all API keys for the user. + """List all API keys for the user. Args: user_id (UUID): The ID of the user diff --git a/py/core/main/services/graph_service.py b/py/core/main/services/graph_service.py index 181c6dfd5..fdee8e76c 100644 --- a/py/core/main/services/graph_service.py +++ b/py/core/main/services/graph_service.py @@ -38,9 +38,7 @@ MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128 async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]: - """ - Collects all results from an async generator into a list. - """ + """Collects all results from an async generator into a list.""" results = [] async for res in result_gen: results.append(res) @@ -384,8 +382,8 @@ class GraphService(Service): batch_size: int = 256, **kwargs, ): - """ - A new implementation of the old GraphDescriptionPipe logic inline. No references to pipe objects. + """A new implementation of the old GraphDescriptionPipe logic inline. + No references to pipe objects. 
We: 1) Count how many entities are in the document @@ -449,8 +447,9 @@ class GraphService(Service): limit: int, max_description_input_length: int, ) -> AsyncGenerator[str, None]: - """ - Core logic that replaces GraphDescriptionPipe._run_logic for a particular document/batch. + """Core logic that replaces GraphDescriptionPipe._run_logic for a + particular document/batch. + Yields entity-names or some textual result as each entity is updated. """ start_time = time.time() @@ -503,17 +502,16 @@ class GraphService(Service): document_id: UUID, max_description_input_length: int, ) -> str: - """ - Adapted from the old process_entity function in GraphDescriptionPipe. + """Adapted from the old process_entity function in + GraphDescriptionPipe. + If entity has no description, call an LLM to create one, then store it. Returns the name of the top entity (or could store more details). """ def truncate_info(info_list: list[str], max_length: int) -> str: - """ - Shuffles lines of info to try to keep them distinct, then accumulates - until hitting max_length. - """ + """Shuffles lines of info to try to keep them distinct, then + accumulates until hitting max_length.""" random.shuffle(info_list) truncated_info = "" current_length = 0 @@ -626,9 +624,8 @@ class GraphService(Service): generation_config: GenerationConfig, leiden_params: dict, ) -> dict: - """ - The actual clustering logic (previously in GraphClusteringPipe.cluster_graph_search_results). - """ + """The actual clustering logic (previously in + GraphClusteringPipe.cluster_graph_search_results).""" clustering_mode = ( self.config.database.graph_creation_settings.clustering_mode ) @@ -650,9 +647,10 @@ class GraphService(Service): leiden_params: Optional[dict] = None, **kwargs, ): - """ - Replacement for the old GraphCommunitySummaryPipe logic. Summarizes communities after clustering. - Returns an async generator or you can collect into a list. + """Replacement for the old GraphCommunitySummaryPipe logic. 
+ + Summarizes communities after clustering. Returns an async generator or + you can collect into a list. """ logger.info( f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}" @@ -677,8 +675,9 @@ class GraphService(Service): collection_id: UUID, leiden_params: dict, ) -> AsyncGenerator[dict, None]: - """ - Does the community summary logic from GraphCommunitySummaryPipe._run_logic. + """Does the community summary logic from + GraphCommunitySummaryPipe._run_logic. + Yields each summary dictionary as it completes. """ start_time = time.time() @@ -920,9 +919,8 @@ class GraphService(Service): relationships: list[Relationship], max_summary_input_length: int, ) -> str: - """ - Gathers the entity/relationship text, tries not to exceed `max_summary_input_length`. - """ + """Gathers the entity/relationship text, tries not to exceed + `max_summary_input_length`.""" # Group them by entity.name entity_map: dict[str, dict] = {} for e in entities: @@ -992,9 +990,8 @@ class GraphService(Service): *args: Any, **kwargs: Any, ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]: - """ - The original “extract Graph from doc” logic, but inlined instead of referencing a pipe. - """ + """The original “extract Graph from doc” logic, but inlined instead of + referencing a pipe.""" start_time = time.time() logger.info( @@ -1107,10 +1104,8 @@ class GraphService(Service): task_id: Optional[int] = None, total_tasks: Optional[int] = None, ) -> GraphExtraction: - """ - (Equivalent to _extract_graph_search_results in old code.) - Merges chunk data, calls LLM, parses XML, returns GraphExtraction object. - """ + """(Equivalent to _extract_graph_search_results in old code.) 
Merges + chunk data, calls LLM, parses XML, returns GraphExtraction object.""" combined_extraction: str = " ".join([c.data for c in chunks if c.data]) # Possibly get doc-level summary @@ -1177,9 +1172,8 @@ class GraphService(Service): async def _parse_graph_search_results_extraction_xml( self, response_str: str, chunks: list[DocumentChunk] ) -> tuple[list[Entity], list[Relationship]]: - """ - Helper to parse the LLM's XML format, handle edge cases/cleanup, produce Entities/Relationships. - """ + """Helper to parse the LLM's XML format, handle edge cases/cleanup, + produce Entities/Relationships.""" def sanitize_xml(r: str) -> str: # Remove markdown fences @@ -1197,10 +1191,10 @@ class GraphService(Service): wrapped = f"{cleaned_xml}" try: root = ET.fromstring(wrapped) - except ET.ParseError as e: + except ET.ParseError: raise R2RException( - f"Failed to parse XML: {e}\nData: {wrapped[:1000]}...", 400 - ) + f"Failed to parse XML:\nData: {wrapped[:1000]}...", 400 + ) from None entities_elems = root.findall(".//entity") if ( @@ -1279,9 +1273,7 @@ class GraphService(Service): self, graph_search_results_extractions: list[GraphExtraction], ): - """ - Stores a batch of knowledge graph extractions in the DB. - """ + """Stores a batch of knowledge graph extractions in the DB.""" for extraction in graph_search_results_extractions: # Map name->id after creation entities_id_map = {} diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 3c65995a9..9bc615051 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -41,10 +41,8 @@ STARTING_VERSION = "v0" class IngestionService: - """ - A refactored IngestionService that inlines all pipe logic for parsing, - embedding, and vector storage directly in its methods. 
- """ + """A refactored IngestionService that inlines all pipe logic for parsing, + embedding, and vector storage directly in its methods.""" def __init__( self, @@ -66,9 +64,11 @@ class IngestionService: *args: Any, **kwargs: Any, ) -> dict: - """ - Pre-ingests a file by creating or validating the DocumentResponse entry. - Does not actually parse/ingest the content. (See parse_file() for that step.) + """Pre-ingests a file by creating or validating the DocumentResponse + entry. + + Does not actually parse/ingest the content. (See parse_file() for that + step.) """ try: if not file_data: @@ -137,7 +137,7 @@ class IngestionService: except Exception as e: raise HTTPException( status_code=500, detail=f"Error during ingestion: {str(e)}" - ) + ) from e def create_document_info_from_file( self, @@ -210,11 +210,9 @@ class IngestionService: document_info: DocumentResponse, ingestion_config: dict | None, ) -> AsyncGenerator[DocumentChunk, None]: - """ - Inline replacement for the old parsing_pipe.run(...) - Reads the file content from the DB, calls the ingestion provider to parse, - and yields DocumentChunk objects. - """ + """Inline replacement for the old parsing_pipe.run(...) Reads the file + content from the DB, calls the ingestion provider to parse, and yields + DocumentChunk objects.""" version = document_info.version or "v0" ingestion_config_override = ingestion_config or {} @@ -278,14 +276,14 @@ class IngestionService: error_message=e.message, document_id=document_info.id, status_code=e.status_code, - ) + ) from None except Exception as e: if isinstance(e, R2RException): raise raise R2RDocumentProcessingError( document_id=document_info.id, error_message=f"Error parsing document: {str(e)}", - ) + ) from e async def augment_document_info( self, @@ -336,8 +334,8 @@ class IngestionService: chunked_documents: list[dict], embedding_batch_size: int = 8, ) -> AsyncGenerator[VectorEntry, None]: - """ - Inline replacement for the old embedding_pipe.run(...). 
+ """Inline replacement for the old embedding_pipe.run(...). + Batches the embedding calls and yields VectorEntry objects. """ if not chunked_documents: @@ -423,10 +421,10 @@ class IngestionService: embeddings: Sequence[dict | VectorEntry], storage_batch_size: int = 128, ) -> AsyncGenerator[str, None]: - """ - Inline replacement for the old vector_storage_pipe.run(...). - Batches up the vector entries, enforces usage limits, stores them, - and yields a success/error string (or you could yield a StorageResult). + """Inline replacement for the old vector_storage_pipe.run(...). + + Batches up the vector entries, enforces usage limits, stores them, and + yields a success/error string (or you could yield a StorageResult). """ if not embeddings: return @@ -520,10 +518,8 @@ class IngestionService: async def finalize_ingestion( self, document_info: DocumentResponse ) -> None: - """ - Called at the end of a successful ingestion pipeline to - set the document status to SUCCESS or similar final steps. - """ + """Called at the end of a successful ingestion pipeline to set the + document status to SUCCESS or similar final steps.""" async def empty_generator(): yield document_info @@ -566,9 +562,8 @@ class IngestionService: *args: Any, **kwargs: Any, ) -> DocumentResponse: - """ - Directly ingest user-provided text chunks (rather than from a file). - """ + """Directly ingest user-provided text chunks (rather than from a + file).""" if not chunks: raise R2RException( status_code=400, message="No chunks provided for ingestion." @@ -619,9 +614,8 @@ class IngestionService: *args: Any, **kwargs: Any, ) -> dict: - """ - Update an individual chunk's text and metadata, re-embed, and re-store it. 
- """ + """Update an individual chunk's text and metadata, re-embed, and re- + store it.""" # Verify chunk exists and user has access existing_chunks = ( await self.providers.database.chunks_handler.list_document_chunks( @@ -696,9 +690,9 @@ class IngestionService: chunk_enrichment_settings: ChunkEnrichmentSettings, list_document_chunks: list[dict], ) -> VectorEntry: - """ - Helper for chunk_enrichment. Leverages an LLM to rewrite or expand chunk text, - then re-embeds it. + """Helper for chunk_enrichment. + + Leverages an LLM to rewrite or expand chunk text, then re-embeds it. """ preceding_chunks = [ list_document_chunks[idx]["text"] @@ -782,10 +776,8 @@ class IngestionService: document_summary: str | None, chunk_enrichment_settings: ChunkEnrichmentSettings, ) -> int: - """ - Example function that modifies chunk text via an LLM then re-embeds - and re-stores all chunks for the given document. - """ + """Example function that modifies chunk text via an LLM then re-embeds + and re-stores all chunks for the given document.""" list_document_chunks = ( await self.providers.database.chunks_handler.list_document_chunks( document_id=document_id, diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index d58bdaa8c..5b6267eea 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -73,9 +73,8 @@ class ManagementService(Service): self, filters: dict[str, Any], ): - """ - Delete chunks matching the given filters. If any documents are now empty - (i.e., have no remaining chunks), delete those documents as well. + """Delete chunks matching the given filters. If any documents are now + empty (i.e., have no remaining chunks), delete those documents as well. Args: filters (dict[str, Any]): Filters specifying which chunks to delete. 
@@ -90,8 +89,9 @@ class ManagementService(Service): def transform_chunk_id_to_id( filters: dict[str, Any], ) -> dict[str, Any]: - """ - Example transformation function if your filters use `chunk_id` instead of `id`. + """Example transformation function if your filters use `chunk_id` + instead of `id`. + Recursively transform `chunk_id` to `id`. """ if isinstance(filters, dict): @@ -562,7 +562,7 @@ class ManagementService(Service): name=name, description=description, ) - graph_result = await self.providers.database.graphs_handler.create( + await self.providers.database.graphs_handler.create( collection_id=result.id, name=name, description=description, @@ -706,7 +706,7 @@ class ManagementService(Service): ) return f"Prompt '{name}' added successfully." # type: ignore except ValueError as e: - raise R2RException(status_code=400, message=str(e)) + raise R2RException(status_code=400, message=str(e)) from e @telemetry_event("GetPrompt") async def get_cached_prompt( @@ -726,7 +726,7 @@ class ManagementService(Service): ) } except ValueError as e: - raise R2RException(status_code=404, message=str(e)) + raise R2RException(status_code=404, message=str(e)) from e @telemetry_event("GetPrompt") async def get_prompt( @@ -742,7 +742,7 @@ class ManagementService(Service): prompt_override=prompt_override, ) except ValueError as e: - raise R2RException(status_code=404, message=str(e)) + raise R2RException(status_code=404, message=str(e)) from e @telemetry_event("GetAllPrompts") async def get_all_prompts(self) -> dict[str, Prompt]: @@ -761,7 +761,7 @@ class ManagementService(Service): ) return f"Prompt '{name}' updated successfully." 
# type: ignore except ValueError as e: - raise R2RException(status_code=404, message=str(e)) + raise R2RException(status_code=404, message=str(e)) from e @telemetry_event("DeletePrompt") async def delete_prompt(self, name: str) -> dict: @@ -769,7 +769,7 @@ class ManagementService(Service): await self.providers.database.prompts_handler.delete_prompt(name) return {"message": f"Prompt '{name}' deleted successfully."} except ValueError as e: - raise R2RException(status_code=404, message=str(e)) + raise R2RException(status_code=404, message=str(e)) from e @telemetry_event("GetConversation") async def get_conversation( @@ -890,9 +890,9 @@ class ManagementService(Service): async def get_max_upload_size_by_type( self, user_id: UUID, file_type_or_ext: str ) -> int: - """ - Return the maximum allowed upload size (in bytes) for the given user's file type/extension. - Respects user-level overrides if present, falling back to the system config. + """Return the maximum allowed upload size (in bytes) for the given + user's file type/extension. Respects user-level overrides if present, + falling back to the system config. ```json { @@ -907,7 +907,6 @@ class ManagementService(Service): } } ``` - """ # 1. 
Normalize extension ext = file_type_or_ext.lower().lstrip(".") diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 1d249e3c0..7079cedd4 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -8,6 +8,7 @@ from datetime import datetime from typing import Any, Optional from uuid import UUID +import tiktoken from fastapi import HTTPException from core import ( @@ -52,9 +53,6 @@ from .base import Service logger = logging.getLogger() -import tiktoken - - def convert_nonserializable_objects(obj): if isinstance(obj, dict): new_obj = {} @@ -124,7 +122,8 @@ def tokens_count_for_message(message, encoding): def num_tokens_from_messages(messages, model="gpt-4o"): - """Return the number of tokens used by a list of messages for both user and assistant.""" + """Return the number of tokens used by a list of messages for both user and + assistant.""" try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -132,7 +131,7 @@ def num_tokens_from_messages(messages, model="gpt-4o"): encoding = tiktoken.get_encoding("cl100k_base") tokens = 0 - for i, message in enumerate(messages): + for i, _message in enumerate(messages): tokens += tokens_count_for_message(messages[i], encoding) tokens += 3 # every reply is primed with assistant @@ -158,8 +157,9 @@ class RetrievalService(Service): *args, **kwargs, ) -> AggregateSearchResult: - """ - Replaces your pipeline-based `SearchPipeline.run(...)` with a single method. + """Replaces your pipeline-based `SearchPipeline.run(...)` with a single + method. + Does parallel vector + graph search, returning an aggregated result. 
""" @@ -230,12 +230,10 @@ class RetrievalService(Service): query: str, search_settings: SearchSettings, ) -> list[ChunkSearchResult]: - """ - Equivalent to your old VectorSearchPipe.search, but simplified: - • embed query - • do fulltext, semantic, or hybrid search - • optional re-rank - • return list of ChunkSearchResult + """Equivalent to your old VectorSearchPipe.search, but simplified: + + • embed query • do fulltext, semantic, or hybrid search • optional re- + rank • return list of ChunkSearchResult """ # If chunk search is disabled, just return empty if not search_settings.chunk_settings.enabled: @@ -343,7 +341,7 @@ class RetrievalService(Service): if isinstance(metadata, str): try: metadata = json.loads(metadata) - except: + except Exception: pass # store @@ -392,7 +390,7 @@ class RetrievalService(Service): if isinstance(metadata, str): try: metadata = json.loads(metadata) - except: + except Exception: pass results.append( @@ -439,7 +437,7 @@ class RetrievalService(Service): if isinstance(metadata, str): try: metadata = json.loads(metadata) - except: + except Exception: pass results.append( @@ -612,11 +610,11 @@ class RetrievalService(Service): raise HTTPException( status_code=502, detail="Server not reachable or returned an invalid response", - ) + ) from e raise HTTPException( status_code=500, detail=f"Internal RAG Error - {str(e)}", - ) + ) from e async def stream_rag_response( self, @@ -641,10 +639,10 @@ class RetrievalService(Service): raise HTTPException( status_code=502, detail="Server not reachable or returned an invalid response", - ) + ) from e raise HTTPException( status_code=500, detail=f"Internal RAG Error - {str(e)}" - ) + ) from e return stream_response() @@ -1060,19 +1058,18 @@ class RetrievalService(Service): raise HTTPException( status_code=502, detail="Server not reachable or returned an invalid response", - ) + ) from e raise HTTPException( status_code=500, detail=f"Internal Server Error - {str(e)}", - ) + ) from e async def get_context( 
self, filters: dict[str, Any], options: dict[str, Any], ) -> list[dict[str, Any]]: - """ - Return an ordered list of documents (with minimal overview fields), + """Return an ordered list of documents (with minimal overview fields), plus all associated chunks in ascending chunk order. Only the filters: owner_id, collection_ids, and document_id @@ -1174,10 +1171,8 @@ class RetrievalService(Service): max_summary_length: int = 128, limit: int = 1000, ) -> str: - """ - Fetches documents matching the given filters and returns a formatted string - enumerating them. - """ + """Fetches documents matching the given filters and returns a formatted + string enumerating them.""" # We only want up to `limit` documents for brevity docs_data = await self.providers.database.documents_handler.get_documents_overview( offset=0, @@ -1213,10 +1208,8 @@ class RetrievalService(Service): filter_collection_ids: Optional[list[UUID]] = None, limit: int = 5, ) -> str: - """ - Fetches collections matching the given filters and returns a formatted string - enumerating them. 
- """ + """Fetches collections matching the given filters and returns a + formatted string enumerating them.""" coll_data = await self.providers.database.collections_handler.get_collections_overview( offset=0, limit=limit, @@ -1309,7 +1302,9 @@ class RetrievalServiceAdapter: try: user_data = json.loads(user_data) except json.JSONDecodeError: - raise ValueError(f"Invalid user data format: {user_data}") + raise ValueError( + f"Invalid user data format: {user_data}" + ) from None return User.from_dict(user_data) @staticmethod diff --git a/py/core/parsers/media/audio_parser.py b/py/core/parsers/media/audio_parser.py index a66e03c5a..4dfc01537 100644 --- a/py/core/parsers/media/audio_parser.py +++ b/py/core/parsers/media/audio_parser.py @@ -32,8 +32,8 @@ class AudioParser(AsyncParser[bytes]): async def ingest( # type: ignore self, data: bytes, **kwargs ) -> AsyncGenerator[str, None]: - """ - Ingest audio data and yield a transcription using Whisper via LiteLLM. + """Ingest audio data and yield a transcription using Whisper via + LiteLLM. 
Args: data: Raw audio bytes diff --git a/py/core/parsers/media/doc_parser.py b/py/core/parsers/media/doc_parser.py index 429c247e8..5b49e2cc2 100644 --- a/py/core/parsers/media/doc_parser.py +++ b/py/core/parsers/media/doc_parser.py @@ -68,7 +68,7 @@ class DOCParser(AsyncParser[str | bytes]): yield paragraph.strip() except Exception as e: - raise ValueError(f"Error processing DOC file: {str(e)}") + raise ValueError(f"Error processing DOC file: {str(e)}") from e finally: ole.close() file_obj.close() @@ -89,7 +89,7 @@ class DOCParser(AsyncParser[str | bytes]): return text except Exception as e: - raise ValueError(f"Error extracting text: {str(e)}") + raise ValueError(f"Error extracting text: {str(e)}") from e def _clean_text(self, text: str) -> list[str]: """Clean and split the extracted text into paragraphs.""" diff --git a/py/core/parsers/media/img_parser.py b/py/core/parsers/media/img_parser.py index 63c043db6..b98da006d 100644 --- a/py/core/parsers/media/img_parser.py +++ b/py/core/parsers/media/img_parser.py @@ -51,7 +51,8 @@ class ImageParser(AsyncParser[str | bytes]): try: header = data[:32] # Get first 32 bytes return any(pattern in header for pattern in heic_patterns) - except: + except Exception as e: + logger.error(f"Error checking for HEIC format: {str(e)}") return False async def _convert_heic_to_jpeg(self, data: bytes) -> bytes: diff --git a/py/core/parsers/media/odt_parser.py b/py/core/parsers/media/odt_parser.py index 07d349e55..cb1464649 100644 --- a/py/core/parsers/media/odt_parser.py +++ b/py/core/parsers/media/odt_parser.py @@ -55,6 +55,6 @@ class ODTParser(AsyncParser[str | bytes]): yield text.strip() except Exception as e: - raise ValueError(f"Error processing ODT file: {str(e)}") + raise ValueError(f"Error processing ODT file: {str(e)}") from e finally: file_obj.close() diff --git a/py/core/parsers/media/pdf_parser.py b/py/core/parsers/media/pdf_parser.py index 9bcf3b0bb..8e02233d5 100644 --- a/py/core/parsers/media/pdf_parser.py +++ 
b/py/core/parsers/media/pdf_parser.py @@ -46,9 +46,8 @@ class VLMPDFParser(AsyncParser[str | bytes]): async def convert_pdf_to_images( self, data: str | bytes ) -> list[Image.Image]: - """ - Convert PDF pages to images asynchronously using in-memory conversion. - """ + """Convert PDF pages to images asynchronously using in-memory + conversion.""" logger.info("Starting PDF conversion to images.") start_time = time.perf_counter() options = { @@ -71,16 +70,18 @@ class VLMPDFParser(AsyncParser[str | bytes]): f"PDF conversion completed in {elapsed:.2f} seconds, total pages: {len(images)}" ) return images - except PDFInfoNotInstalledError: + except PDFInfoNotInstalledError as e: logger.error( "PDFInfoNotInstalledError encountered during PDF conversion." ) - raise PopplerNotFoundError() + raise PopplerNotFoundError() from e except Exception as err: logger.error( f"Error converting PDF to images: {err} type: {type(err)}" ) - raise PDFParsingError(f"Failed to process PDF: {str(err)}", err) + raise PDFParsingError( + f"Failed to process PDF: {str(err)}", err + ) from err async def process_page( self, image: Image.Image, page_num: int @@ -147,8 +148,9 @@ class VLMPDFParser(AsyncParser[str | bytes]): async def ingest( self, data: str | bytes, maintain_order: bool = True, **kwargs ) -> AsyncGenerator[dict[str, str | int], None]: - """ - Ingest PDF data and yield the text description for each page using the vision model. + """Ingest PDF data and yield the text description for each page using + the vision model. + (This version yields a string per page rather than a dictionary.) 
""" ingest_start = time.perf_counter() diff --git a/py/core/parsers/media/ppt_parser.py b/py/core/parsers/media/ppt_parser.py index 922b6525d..c8bbaa554 100644 --- a/py/core/parsers/media/ppt_parser.py +++ b/py/core/parsers/media/ppt_parser.py @@ -83,6 +83,6 @@ class PPTParser(AsyncParser[str | bytes]): current_position += 1 except Exception as e: - raise ValueError(f"Error processing PPT file: {str(e)}") + raise ValueError(f"Error processing PPT file: {str(e)}") from e finally: ole.close() diff --git a/py/core/parsers/media/rtf_parser.py b/py/core/parsers/media/rtf_parser.py index 6aa02fc96..6be120762 100644 --- a/py/core/parsers/media/rtf_parser.py +++ b/py/core/parsers/media/rtf_parser.py @@ -42,4 +42,4 @@ class RTFParser(AsyncParser[str | bytes]): yield paragraph.strip() except Exception as e: - raise ValueError(f"Error processing RTF file: {str(e)}") + raise ValueError(f"Error processing RTF file: {str(e)}") from e diff --git a/py/core/parsers/structured/__init__.py b/py/core/parsers/structured/__init__.py index 097354501..5c2ba89ec 100644 --- a/py/core/parsers/structured/__init__.py +++ b/py/core/parsers/structured/__init__.py @@ -3,6 +3,7 @@ from .csv_parser import CSVParser, CSVParserAdvanced from .eml_parser import EMLParser from .epub_parser import EPUBParser from .json_parser import JSONParser + # from .msg_parser import MSGParser from .org_parser import ORGParser from .p7s_parser import P7SParser diff --git a/py/core/parsers/structured/epub_parser.py b/py/core/parsers/structured/epub_parser.py index c99c719ce..ff51fb86b 100644 --- a/py/core/parsers/structured/epub_parser.py +++ b/py/core/parsers/structured/epub_parser.py @@ -113,7 +113,7 @@ class EPUBParser(AsyncParser[str | bytes]): except Exception as e: logger.error(f"Error processing EPUB file: {str(e)}") - raise ValueError(f"Error processing EPUB file: {str(e)}") + raise ValueError(f"Error processing EPUB file: {str(e)}") from e finally: try: file_obj.close() diff --git 
a/py/core/parsers/structured/json_parser.py b/py/core/parsers/structured/json_parser.py index 08db317ba..3948e4de5 100644 --- a/py/core/parsers/structured/json_parser.py +++ b/py/core/parsers/structured/json_parser.py @@ -28,8 +28,7 @@ class JSONParser(AsyncParser[str | bytes]): async def ingest( self, data: str | bytes, *args, **kwargs ) -> AsyncGenerator[str, None]: - """ - Ingest JSON data and yield a formatted text representation. + """Ingest JSON data and yield a formatted text representation. :param data: The JSON data to parse. :param kwargs: Additional keyword arguments. @@ -48,7 +47,7 @@ class JSONParser(AsyncParser[str | bytes]): raise R2RException( message=f"Failed to parse JSON data, likely due to invalid JSON: {str(e)}", status_code=400, - ) + ) from e chunk_size = kwargs.get("chunk_size") if chunk_size and isinstance(chunk_size, int): diff --git a/py/core/parsers/structured/msg_parser.py b/py/core/parsers/structured/msg_parser.py index 4db67585a..93753c8ff 100644 --- a/py/core/parsers/structured/msg_parser.py +++ b/py/core/parsers/structured/msg_parser.py @@ -10,7 +10,6 @@ # IngestionConfig, # ) - # class MSGParser(AsyncParser[str | bytes]): # """Parser for MSG (Outlook Message) files.""" diff --git a/py/core/parsers/structured/org_parser.py b/py/core/parsers/structured/org_parser.py index b8b37ef8f..2ea3f8574 100644 --- a/py/core/parsers/structured/org_parser.py +++ b/py/core/parsers/structured/org_parser.py @@ -67,6 +67,6 @@ class ORGParser(AsyncParser[str | bytes]): yield content.strip() except Exception as e: - raise ValueError(f"Error processing ORG file: {str(e)}") + raise ValueError(f"Error processing ORG file: {str(e)}") from e finally: file_obj.close() diff --git a/py/core/parsers/structured/p7s_parser.py b/py/core/parsers/structured/p7s_parser.py index 6f1693c71..84983494c 100644 --- a/py/core/parsers/structured/p7s_parser.py +++ b/py/core/parsers/structured/p7s_parser.py @@ -125,7 +125,7 @@ class P7SParser(AsyncParser[str | bytes]): except 
Exception as e: raise ValueError( f"Failed to decode base64 PKCS#7 signature: {str(e)}" - ) + ) from e # If we reach here, no PKCS#7 part was found raise ValueError( "No application/x-pkcs7-signature part found in the MIME message." @@ -144,7 +144,8 @@ class P7SParser(AsyncParser[str | bytes]): async def ingest( self, data: str | bytes, **kwargs ) -> AsyncGenerator[str, None]: - """Ingest an S/MIME message and extract the PKCS#7 signature information.""" + """Ingest an S/MIME message and extract the PKCS#7 signature + information.""" # If data is a string, it might be base64 encoded, or it might be the raw MIME text. # We should assume it's raw MIME text here because the input includes MIME headers. if isinstance(data, str): @@ -174,4 +175,4 @@ class P7SParser(AsyncParser[str | bytes]): yield f"Certificate {i}: No detailed information extracted." except Exception as e: - raise ValueError(f"Error processing P7S file: {str(e)}") + raise ValueError(f"Error processing P7S file: {str(e)}") from e diff --git a/py/core/parsers/structured/rst_parser.py b/py/core/parsers/structured/rst_parser.py index b073cd686..763906552 100644 --- a/py/core/parsers/structured/rst_parser.py +++ b/py/core/parsers/structured/rst_parser.py @@ -55,4 +55,4 @@ class RSTParser(AsyncParser[str | bytes]): yield paragraph.strip() except Exception as e: - raise ValueError(f"Error processing RST file: {str(e)}") + raise ValueError(f"Error processing RST file: {str(e)}") from e diff --git a/py/core/parsers/structured/tiff_parser.py b/py/core/parsers/structured/tiff_parser.py index 7bec6d1d6..046e6736f 100644 --- a/py/core/parsers/structured/tiff_parser.py +++ b/py/core/parsers/structured/tiff_parser.py @@ -48,7 +48,7 @@ class TIFFParser(AsyncParser[str | bytes]): tiff_image.save(output_buffer, format="JPEG", quality=95) return output_buffer.getvalue() except Exception as e: - raise ValueError(f"Error converting TIFF to JPEG: {str(e)}") + raise ValueError(f"Error converting TIFF to JPEG: {str(e)}") from 
e async def ingest( self, data: str | bytes, **kwargs @@ -102,4 +102,4 @@ class TIFFParser(AsyncParser[str | bytes]): raise ValueError("No response content") except Exception as e: - raise ValueError(f"Error processing TIFF file: {str(e)}") + raise ValueError(f"Error processing TIFF file: {str(e)}") from e diff --git a/py/core/providers/auth/jwt.py b/py/core/providers/auth/jwt.py index 2a7e96826..08f85e6df 100644 --- a/py/core/providers/auth/jwt.py +++ b/py/core/providers/auth/jwt.py @@ -74,11 +74,11 @@ class JwtAuthProvider(AuthProvider): logger.info(f"JWT verification failed: {e}") raise R2RException( status_code=401, message="Invalid JWT token", detail=e - ) + ) from e if user: # Create user in database if not exists try: - existingUser = await self.database_provider.users_handler.get_user_by_email( + await self.database_provider.users_handler.get_user_by_email( user.get("email") ) # TODO do we want to update user info here based on what's in the token? @@ -95,7 +95,7 @@ class JwtAuthProvider(AuthProvider): logger.error(f"Error creating user: {e}") raise R2RException( status_code=500, message="Failed to create user" - ) + ) from e return TokenData( email=user.get("email"), token_type="bearer", diff --git a/py/core/providers/auth/r2r_auth.py b/py/core/providers/auth/r2r_auth.py index f4d935e1c..762884ce3 100644 --- a/py/core/providers/auth/r2r_auth.py +++ b/py/core/providers/auth/r2r_auth.py @@ -29,9 +29,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") def normalize_email(email: str) -> str: - """ - Normalizes an email address by converting it to lowercase. - This ensures consistent email handling throughout the application. + """Normalizes an email address by converting it to lowercase. This ensures + consistent email handling throughout the application. 
Args: email: The email address to normalize @@ -146,16 +145,16 @@ class R2RAuthProvider(AuthProvider): ) async def authenticate_api_key(self, api_key: str) -> User: - """ - Authenticate using an API key of the form "public_key.raw_key". + """Authenticate using an API key of the form "public_key.raw_key". + Returns a User if successful, or raises R2RException if not. """ try: key_id, raw_key = api_key.split(".", 1) - except ValueError: + except ValueError as e: raise R2RException( status_code=401, message="Invalid API key format" - ) + ) from e key_record = ( await self.database_provider.users_handler.get_api_key_record( @@ -181,9 +180,7 @@ class R2RAuthProvider(AuthProvider): return user async def user(self, token: str = Depends(oauth2_scheme)) -> User: - """ - Attempt to authenticate via JWT first, then fallback to API key. - """ + """Attempt to authenticate via JWT first, then fallback to API key.""" # Try JWT auth try: token_data = await self.decode_token(token=token) @@ -260,7 +257,7 @@ class R2RAuthProvider(AuthProvider): owner_id=new_user.id, ) ) - graph_result = await self.database_provider.graphs_handler.create( + await self.database_provider.graphs_handler.create( collection_id=default_collection.id, name=default_collection.name, description=default_collection.description, @@ -620,10 +617,12 @@ class R2RAuthProvider(AuthProvider): async def oauth_callback_handler( self, provider: str, oauth_id: str, email: str ) -> dict[str, Token]: - """ - Handles a login/registration flow for OAuth providers (e.g., Google or GitHub). + """Handles a login/registration flow for OAuth providers (e.g., Google + or GitHub). + :param provider: "google" or "github" - :param oauth_id: The unique ID from the OAuth provider (e.g. Google's 'sub') + :param oauth_id: The unique ID from the OAuth provider (e.g. Google's + 'sub') :param email: The user's email from the provider, if available. 
:return: dict with access_token and refresh_token """ @@ -641,7 +640,7 @@ class R2RAuthProvider(AuthProvider): status_code=401, message="User already exists and is not linked to Google account", ) - except: + except Exception: # Create new user user = await self.register( email=normalize_email(email) @@ -661,7 +660,7 @@ class R2RAuthProvider(AuthProvider): status_code=401, message="User already exists and is not linked to Github account", ) - except: + except Exception: # Create new user user = await self.register( email=normalize_email(email) @@ -676,7 +675,7 @@ class R2RAuthProvider(AuthProvider): # If no user found or creation fails raise R2RException( status_code=401, message="Could not create or fetch user" - ) + ) from None # If user is inactive, etc. if not user.is_active: diff --git a/py/core/providers/auth/supabase.py b/py/core/providers/auth/supabase.py index 7c4e88ad8..5fc0e0bfb 100644 --- a/py/core/providers/auth/supabase.py +++ b/py/core/providers/auth/supabase.py @@ -23,7 +23,6 @@ from ..database import PostgresDatabaseProvider logger = logging.getLogger() - logger = logging.getLogger() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -83,7 +82,7 @@ class SupabaseAuthProvider(AuthProvider): ) -> User: # type: ignore # Use Supabase client to create a new user - if user := self.supabase.auth.sign_up(email=email, password=password): + if self.supabase.auth.sign_up(email=email, password=password): raise R2RException( status_code=400, message="Supabase provider implementation is still under construction", @@ -104,9 +103,7 @@ class SupabaseAuthProvider(AuthProvider): self, email: str, verification_code: str ) -> dict[str, str]: # Use Supabase client to verify email - if response := self.supabase.auth.verify_email( - email, verification_code - ): + if self.supabase.auth.verify_email(email, verification_code): return {"message": "Email verified successfully"} else: raise R2RException( @@ -181,9 +178,7 @@ class SupabaseAuthProvider(AuthProvider): 
self, user: User, current_password: str, new_password: str ) -> dict[str, str]: # Use Supabase client to update user password - if response := self.supabase.auth.update( - user.id, {"password": new_password} - ): + if self.supabase.auth.update(user.id, {"password": new_password}): return {"message": "Password changed successfully"} else: raise R2RException( @@ -192,7 +187,7 @@ class SupabaseAuthProvider(AuthProvider): async def request_password_reset(self, email: str) -> dict[str, str]: # Use Supabase client to send password reset email - if response := self.supabase.auth.send_password_reset_email(email): + if self.supabase.auth.send_password_reset_email(email): return { "message": "If the email exists, a reset link has been sent" } @@ -205,7 +200,7 @@ class SupabaseAuthProvider(AuthProvider): self, reset_token: str, new_password: str ) -> dict[str, str]: # Use Supabase client to reset password with token - if response := self.supabase.auth.reset_password_for_email( + if self.supabase.auth.reset_password_for_email( reset_token, new_password ): return {"message": "Password reset successfully"} diff --git a/py/core/providers/crypto/bcrypt.py b/py/core/providers/crypto/bcrypt.py index 9d5d8e09e..9c39977c1 100644 --- a/py/core/providers/crypto/bcrypt.py +++ b/py/core/providers/crypto/bcrypt.py @@ -41,7 +41,7 @@ class BcryptCryptoConfig(CryptoConfig): try: # First try to decode as base64 (new format) stored_hash = base64.b64decode(hashed_password.encode("utf-8")) - except: + except Exception: # If that fails, treat as raw bcrypt hash (old format) stored_hash = hashed_password.encode("utf-8") @@ -86,7 +86,7 @@ class BCryptCryptoProvider(CryptoProvider, ABC): stored_hash = base64.b64decode(hashed_password.encode("utf-8")) if not stored_hash.startswith(b"$2b$"): # Valid bcrypt hash prefix stored_hash = hashed_password.encode("utf-8") - except: + except Exception: # Otherwise raw bcrypt hash (old format) stored_hash = hashed_password.encode("utf-8") @@ -131,7 +131,9 @@ 
class BCryptCryptoProvider(CryptoProvider, ABC): signature = signing_key.sign(data.encode()) return base64.b64encode(signature.signature).decode() except Exception as e: - raise ValueError(f"Invalid private key or signing error: {str(e)}") + raise ValueError( + f"Invalid private key or signing error: {str(e)}" + ) from e def verify_request_signature( self, public_key: str, signature: str, data: str diff --git a/py/core/providers/crypto/nacl.py b/py/core/providers/crypto/nacl.py index e11e38f8c..63232565d 100644 --- a/py/core/providers/crypto/nacl.py +++ b/py/core/providers/crypto/nacl.py @@ -19,7 +19,8 @@ DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Repla def encode_bytes_readable(random_bytes: bytes, chars: str) -> str: - """Convert random bytes to a readable string using the given character set.""" + """Convert random bytes to a readable string using the given character + set.""" # Each byte gives us 8 bits of randomness # We use modulo to map each byte to our character set result = [] @@ -122,7 +123,9 @@ class NaClCryptoProvider(CryptoProvider): signature = signing_key.sign(data.encode()) return base64.b64encode(signature.signature).decode() except Exception as e: - raise ValueError(f"Invalid private key or signing error: {str(e)}") + raise ValueError( + f"Invalid private key or signing error: {str(e)}" + ) from e def verify_request_signature( self, public_key: str, signature: str, data: str @@ -137,8 +140,8 @@ class NaClCryptoProvider(CryptoProvider): return False def generate_secure_token(self, data: dict, expiry: datetime) -> str: - """ - Generate a secure token using JWT with HS256. + """Generate a secure token using JWT with HS256. + The secret_key is used for symmetrical signing. 
""" now = datetime.now(timezone.utc) @@ -154,9 +157,7 @@ class NaClCryptoProvider(CryptoProvider): return jwt.encode(to_encode, self.secret_key, algorithm="HS256") def verify_secure_token(self, token: str) -> Optional[dict]: - """ - Verify a secure token using the shared secret_key and JWT. - """ + """Verify a secure token using the shared secret_key and JWT.""" try: payload = jwt.decode(token, self.secret_key, algorithms=["HS256"]) exp = payload.get("exp") diff --git a/py/core/providers/database/base.py b/py/core/providers/database/base.py index 4a7d15503..159af7fb0 100644 --- a/py/core/providers/database/base.py +++ b/py/core/providers/database/base.py @@ -198,8 +198,7 @@ class PostgresConnectionManager(DatabaseConnectionManager): else await conn.fetch(query) ) except asyncpg.exceptions.DuplicatePreparedStatementError: - error_msg = textwrap.dedent( - """ + error_msg = textwrap.dedent(""" Database Configuration Error Your database provider does not support statement caching. @@ -213,8 +212,7 @@ class PostgresConnectionManager(DatabaseConnectionManager): This is required when using connection poolers like PgBouncer or managed database services like Supabase. - """ - ).strip() + """).strip() raise ValueError(error_msg) from None async def fetchrow_query(self, query, params=None): @@ -229,8 +227,7 @@ class PostgresConnectionManager(DatabaseConnectionManager): @asynccontextmanager async def transaction(self, isolation_level=None): - """ - Async context manager for database transactions. + """Async context manager for database transactions. 
Args: isolation_level: Optional isolation level for the transaction diff --git a/py/core/providers/database/chunks.py b/py/core/providers/database/chunks.py index b07ab437e..8bb22d0ac 100644 --- a/py/core/providers/database/chunks.py +++ b/py/core/providers/database/chunks.py @@ -22,17 +22,17 @@ from core.base import ( VectorQuantizationType, VectorTableName, ) +from core.base.utils import _decorate_vector_type from .base import PostgresConnectionManager from .filters import apply_filters logger = logging.getLogger() -from core.base.utils import _decorate_vector_type def psql_quote_literal(value: str) -> str: - """ - Safely quote a string literal for PostgreSQL to prevent SQL injection. + """Safely quote a string literal for PostgreSQL to prevent SQL injection. + This is a simple implementation - in production, you should use proper parameterization or your database driver's quoting functions. """ @@ -50,9 +50,8 @@ def quantize_vector_to_binary( vector: list[float] | np.ndarray, threshold: float = 0.0, ) -> bytes: - """ - Quantizes a float vector to a binary vector string for PostgreSQL bit type. - Used when quantization_type is INT1. + """Quantizes a float vector to a binary vector string for PostgreSQL bit + type. Used when quantization_type is INT1. Args: vector (List[float] | np.ndarray): Input vector of floats @@ -148,9 +147,11 @@ class PostgresChunksHandler(Handler): await self.connection_manager.execute_query(query) async def upsert(self, entry: VectorEntry) -> None: - """ - Upsert function that handles vector quantization only when quantization_type is INT1. - Matches the table schema where vec_binary column only exists for INT1 quantization. + """Upsert function that handles vector quantization only when + quantization_type is INT1. + + Matches the table schema where vec_binary column only exists for INT1 + quantization. 
""" # Check the quantization type to determine which columns to use if self.quantization_type == VectorQuantizationType.INT1: @@ -216,9 +217,11 @@ class PostgresChunksHandler(Handler): ) async def upsert_entries(self, entries: list[VectorEntry]) -> None: - """ - Batch upsert function that handles vector quantization only when quantization_type is INT1. - Matches the table schema where vec_binary column only exists for INT1 quantization. + """Batch upsert function that handles vector quantization only when + quantization_type is INT1. + + Matches the table schema where vec_binary column only exists for INT1 + quantization. """ if self.quantization_type == VectorQuantizationType.INT1: bit_dim = ( @@ -293,7 +296,7 @@ class PostgresChunksHandler(Handler): search_settings.chunk_settings.index_measure ) except ValueError: - raise ValueError("Invalid index measure") + raise ValueError("Invalid index measure") from None table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME) cols = [ @@ -664,9 +667,7 @@ class PostgresChunksHandler(Handler): WHERE $1 = ANY(collection_ids) RETURNING collection_ids """ - results = await self.connection_manager.fetchrow_query( - query, (collection_id,) - ) + await self.connection_manager.fetchrow_query(query, (collection_id,)) return None async def list_document_chunks( @@ -745,8 +746,7 @@ class PostgresChunksHandler(Handler): index_column: Optional[str] = None, concurrently: bool = True, ) -> None: - """ - Creates an index for the collection. + """Creates an index for the collection. 
Note: When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step @@ -873,7 +873,7 @@ class PostgresChunksHandler(Handler): # Non-concurrent index creation can use normal query execution await self.connection_manager.execute_query(create_index_sql) except Exception as e: - raise Exception(f"Failed to create index: {e}") + raise Exception(f"Failed to create index: {e}") from e return None async def list_indices( @@ -967,8 +967,7 @@ class PostgresChunksHandler(Handler): table_name: Optional[VectorTableName] = None, concurrently: bool = True, ) -> None: - """ - Deletes a vector index. + """Deletes a vector index. Args: index_name (str): Name of the index to delete @@ -1042,7 +1041,7 @@ class PostgresChunksHandler(Handler): else: await self.connection_manager.execute_query(drop_query) except Exception as e: - raise Exception(f"Failed to delete index: {e}") + raise Exception(f"Failed to delete index: {e}") from e async def list_chunks( self, @@ -1051,8 +1050,7 @@ class PostgresChunksHandler(Handler): filters: Optional[dict[str, Any]] = None, include_vectors: bool = False, ) -> dict[str, Any]: - """ - List chunks with pagination support. + """List chunks with pagination support. Args: offset (int, optional): Number of records to skip. Defaults to 0. @@ -1118,9 +1116,8 @@ class PostgresChunksHandler(Handler): query_text: str, settings: SearchSettings, ) -> list[dict[str, Any]]: - """ - Search for documents based on their metadata fields and/or body text. - Joins with documents table to get complete document metadata. + """Search for documents based on their metadata fields and/or body + text. Joins with documents table to get complete document metadata. 
Args: query_text (str): The search query text diff --git a/py/core/providers/database/collections.py b/py/core/providers/database/collections.py index e3f42c4bc..879cbe96f 100644 --- a/py/core/providers/database/collections.py +++ b/py/core/providers/database/collections.py @@ -21,7 +21,6 @@ from core.base.abstractions import ( IngestionStatus, ) from core.base.api.models import CollectionResponse -from core.utils import generate_default_user_collection_id from .base import PostgresConnectionManager @@ -169,7 +168,7 @@ class PostgresCollectionsHandler(Handler): raise R2RException( message="Collection with this ID already exists", status_code=409, - ) + ) from None except Exception as e: raise HTTPException( status_code=500, @@ -290,8 +289,8 @@ class PostgresCollectionsHandler(Handler): async def documents_in_collection( self, collection_id: UUID, offset: int, limit: int ) -> dict[str, list[DocumentResponse] | int]: - """ - Get all documents in a specific collection with pagination. + """Get all documents in a specific collection with pagination. + Args: collection_id (UUID): The ID of the collection to get documents from. offset (int): The number of documents to skip. @@ -356,28 +355,24 @@ class PostgresCollectionsHandler(Handler): param_index = 1 if filter_user_ids: - conditions.append( - f""" + conditions.append(f""" c.id IN ( SELECT unnest(collection_ids) FROM {self.project_name}.users WHERE id = ANY(${param_index}) ) - """ - ) + """) params.append(filter_user_ids) param_index += 1 if filter_document_ids: - conditions.append( - f""" + conditions.append(f""" c.id IN ( SELECT unnest(collection_ids) FROM {self.project_name}.documents WHERE id = ANY(${param_index}) ) - """ - ) + """) params.append(filter_document_ids) param_index += 1 @@ -428,8 +423,7 @@ class PostgresCollectionsHandler(Handler): document_id: UUID, collection_id: UUID, ) -> UUID: - """ - Assign a document to a collection. + """Assign a document to a collection. 
Args: document_id (UUID): The ID of the document to assign. @@ -500,8 +494,7 @@ class PostgresCollectionsHandler(Handler): async def remove_document_from_collection_relational( self, document_id: UUID, collection_id: UUID ) -> None: - """ - Remove a document from a collection. + """Remove a document from a collection. Args: document_id (UUID): The ID of the document to remove. @@ -536,8 +529,7 @@ class PostgresCollectionsHandler(Handler): async def decrement_collection_document_count( self, collection_id: UUID, decrement_by: int = 1 ) -> None: - """ - Decrement the document count for a collection. + """Decrement the document count for a collection. Args: collection_id (UUID): The ID of the collection to update @@ -558,9 +550,8 @@ class PostgresCollectionsHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. - """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "owner_id", @@ -676,8 +667,8 @@ class PostgresCollectionsHandler(Handler): async def get_collection_by_name( self, owner_id: UUID, name: str ) -> Optional[CollectionResponse]: - """ - Fetch a collection by owner_id + name combination. + """Fetch a collection by owner_id + name combination. + Return None if not found. 
""" query = f""" diff --git a/py/core/providers/database/conversations.py b/py/core/providers/database/conversations.py index 24c292b6f..e93241f6d 100644 --- a/py/core/providers/database/conversations.py +++ b/py/core/providers/database/conversations.py @@ -17,7 +17,8 @@ from .base import PostgresConnectionManager def _json_default(obj: Any) -> str: - """Default handler for objects not serializable by the standard json encoder.""" + """Default handler for objects not serializable by the standard json + encoder.""" if isinstance(obj, datetime): # Return ISO8601 string return obj.isoformat() @@ -107,15 +108,13 @@ class PostgresConversationsHandler(Handler): param_index = 1 if filter_user_ids: - conditions.append( - f""" + conditions.append(f""" c.user_id IN ( SELECT id FROM {self.project_name}.users WHERE id = ANY(${param_index}) ) - """ - ) + """) params.append(filter_user_ids) param_index += 1 @@ -350,15 +349,13 @@ class PostgresConversationsHandler(Handler): if filter_user_ids: param_index = 2 - conditions.append( - f""" + conditions.append(f""" c.user_id IN ( SELECT id FROM {self.project_name}.users WHERE id = ANY(${param_index}) ) - """ - ) + """) params.append(filter_user_ids) query = f""" @@ -439,15 +436,13 @@ class PostgresConversationsHandler(Handler): if filter_user_ids: param_index = 2 - conditions.append( - f""" + conditions.append(f""" c.user_id IN ( SELECT id FROM {self.project_name}.users WHERE id = ANY(${param_index}) ) - """ - ) + """) params.append(filter_user_ids) conv_query = f""" @@ -482,9 +477,8 @@ class PostgresConversationsHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. 
- """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "user_id", @@ -585,9 +579,8 @@ class PostgresConversationsHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. - """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "conversation_id", diff --git a/py/core/providers/database/documents.py b/py/core/providers/database/documents.py index 39c37ae53..0f10207de 100644 --- a/py/core/providers/database/documents.py +++ b/py/core/providers/database/documents.py @@ -29,9 +29,8 @@ logger = logging.getLogger() def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]: - """ - Recursively transform filter field names by replacing 'document_id' with 'id'. - Handles nested logical operators like $and, $or, etc. + """Recursively transform filter field names by replacing 'document_id' with + 'id'. Handles nested logical operators like $and, $or, etc. Args: filters (dict[str, Any]): The original filters dictionary @@ -351,8 +350,7 @@ class PostgresDocumentsHandler(Handler): status_type: str, column_name: str, ): - """ - Get the workflow status for a given document or list of documents. + """Get the workflow status for a given document or list of documents. Args: ids (list[UUID]): The document IDs. @@ -378,8 +376,7 @@ class PostgresDocumentsHandler(Handler): status_type: str, collection_id: Optional[UUID] = None, ): - """ - Get the IDs from a given table. + """Get the IDs from a given table. Args: status (str | list[str]): The status or list of statuses to retrieve. @@ -403,8 +400,7 @@ class PostgresDocumentsHandler(Handler): status_type: str, column_name: str, ): - """ - Set the workflow status for a given document or list of documents. 
+ """Set the workflow status for a given document or list of documents. Args: ids (list[UUID]): The document IDs. @@ -421,8 +417,7 @@ class PostgresDocumentsHandler(Handler): await self.connection_manager.execute_query(query, [status, ids]) def _get_status_model(self, status_type: str): - """ - Get the status model for a given status type. + """Get the status model for a given status type. Args: status_type (str): The type of status to retrieve. @@ -444,8 +439,7 @@ class PostgresDocumentsHandler(Handler): async def get_workflow_status( self, id: UUID | list[UUID], status_type: str ): - """ - Get the workflow status for a given document or list of documents. + """Get the workflow status for a given document or list of documents. Args: id (UUID | list[UUID]): The document ID or list of document IDs. @@ -470,8 +464,7 @@ class PostgresDocumentsHandler(Handler): async def set_workflow_status( self, id: UUID | list[UUID], status_type: str, status: str ): - """ - Set the workflow status for a given document or list of documents. + """Set the workflow status for a given document or list of documents. Args: id (UUID | list[UUID]): The document ID or list of document IDs. @@ -495,8 +488,7 @@ class PostgresDocumentsHandler(Handler): status: str | list[str], collection_id: Optional[UUID] = None, ): - """ - Get the IDs for a given status. + """Get the IDs for a given status. Args: ids_key (str): The key to retrieve the IDs. @@ -522,8 +514,7 @@ class PostgresDocumentsHandler(Handler): include_summary_embedding: Optional[bool] = True, filters: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: - """ - Fetch overviews of documents with optional offset/limit pagination. + """Fetch overviews of documents with optional offset/limit pagination. 
You can use either: - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids` @@ -706,7 +697,8 @@ class PostgresDocumentsHandler(Handler): async def semantic_document_search( self, query_embedding: list[float], search_settings: SearchSettings ) -> list[DocumentResponse]: - """Search documents using semantic similarity with their summary embeddings.""" + """Search documents using semantic similarity with their summary + embeddings.""" where_clauses = ["summary_embedding IS NOT NULL"] params: list[str | int | bytes] = [str(query_embedding)] @@ -888,7 +880,8 @@ class PostgresDocumentsHandler(Handler): query_embedding: list[float], search_settings: SearchSettings, ) -> list[DocumentResponse]: - """Search documents using both semantic and full-text search with RRF fusion.""" + """Search documents using both semantic and full-text search with RRF + fusion.""" # Get more results than needed for better fusion extended_settings = copy.deepcopy(search_settings) @@ -979,9 +972,8 @@ class PostgresDocumentsHandler(Handler): query_embedding: Optional[list[float]] = None, settings: Optional[SearchSettings] = None, ) -> list[DocumentResponse]: - """ - Main search method that delegates to the appropriate search method based on settings. - """ + """Main search method that delegates to the appropriate search method + based on settings.""" if settings is None: settings = SearchSettings() @@ -1012,9 +1004,8 @@ class PostgresDocumentsHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. 
- """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "collection_ids", diff --git a/py/core/providers/database/files.py b/py/core/providers/database/files.py index f8bcff3fa..e6069a89e 100644 --- a/py/core/providers/database/files.py +++ b/py/core/providers/database/files.py @@ -243,11 +243,11 @@ class PostgresFilesHandler(Handler): if not chunk: break file_data.write(chunk) - except asyncpg.exceptions.UndefinedObjectError as e: + except asyncpg.exceptions.UndefinedObjectError: raise R2RException( status_code=404, - message=f"Failed to read large object {oid}: {e}", - ) + message=f"Failed to read large object {oid}", + ) from None finally: await conn.execute("SELECT lo_close($1)", lobject) diff --git a/py/core/providers/database/filters.py b/py/core/providers/database/filters.py index fb6a36625..5367b9204 100644 --- a/py/core/providers/database/filters.py +++ b/py/core/providers/database/filters.py @@ -168,9 +168,10 @@ class SQLFilterBuilder: @staticmethod def _psql_quote_literal(value: str) -> str: - """ - Simple quoting for demonstration. In production, use parameterized queries or - your DB driver's quoting function instead. + """Simple quoting for demonstration. + + In production, use parameterized queries or your DB driver's quoting + function instead. """ return "'" + value.replace("'", "''") + "'" @@ -195,8 +196,8 @@ class SQLFilterBuilder: return self._build_column_condition(key, op, val) def _build_parent_id_condition(self, op: str, val: Any) -> str: - """ - For 'graphs' tables, parent_id is a single UUID (not an array). + """For 'graphs' tables, parent_id is a single UUID (not an array). + We handle the same ops but in a simpler, single-UUID manner. 
""" param_idx = len(self.params) + 1 @@ -240,8 +241,8 @@ class SQLFilterBuilder: raise FilterError(f"Unsupported operator {op} for parent_id") def _build_collection_id_condition(self, op: str, val: Any) -> str: - """ - For the 'chunks' table, collection_ids is an array of UUIDs. + """For the 'chunks' table, collection_ids is an array of UUIDs. + This logic stays exactly as you had it. """ param_idx = len(self.params) + 1 @@ -441,9 +442,7 @@ class SQLFilterBuilder: def apply_filters( filters: dict, params: list[Any], mode: str = "where_clause" ) -> tuple[str, list[Any]]: - """ - Apply filters with consistent WHERE clause handling - """ + """Apply filters with consistent WHERE clause handling.""" if not filters: return "", params diff --git a/py/core/providers/database/graphs.py b/py/core/providers/database/graphs.py index ba71e9571..d7857fd3e 100644 --- a/py/core/providers/database/graphs.py +++ b/py/core/providers/database/graphs.py @@ -326,9 +326,8 @@ class PostgresEntitiesHandler(Handler): entity_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPHS, ) -> None: - """ - Delete entities from the specified store. - If entity_ids is not provided, deletes all entities for the given parent_id. + """Delete entities from the specified store. If entity_ids is not + provided, deletes all entities for the given parent_id. Args: parent_id (UUID): Parent ID (collection_id or document_id) @@ -378,10 +377,12 @@ class PostgresEntitiesHandler(Handler): parent_id: UUID, store_type: StoreType, ) -> list[list[Entity]]: - """ - Find all groups of entities that share identical names within the same parent. - Returns a list of entity groups, where each group contains entities with the same name. - For each group, includes the n most dissimilar descriptions based on cosine similarity. + """Find all groups of entities that share identical names within the + same parent. + + Returns a list of entity groups, where each group contains entities + with the same name. 
For each group, includes the n most dissimilar + descriptions based on cosine similarity. """ table_name = self._get_entity_table_for_store(store_type) @@ -425,8 +426,8 @@ class PostgresEntitiesHandler(Handler): parent_id: UUID, store_type: StoreType, ) -> list[tuple[list[Entity], Entity]]: - """ - Merge entities that share identical names. + """Merge entities that share identical names. + Returns list of tuples: (original_entities, merged_entity) """ duplicate_blocks = await self.get_duplicate_name_blocks( @@ -515,8 +516,8 @@ class PostgresEntitiesHandler(Handler): return result[0]["id"] async def _create_merged_entity(self, entities: list[Entity]) -> Entity: - """ - Create a merged entity from a list of duplicate entities. + """Create a merged entity from a list of duplicate entities. + Uses various strategies to combine fields. """ if not entities: @@ -567,9 +568,8 @@ class PostgresEntitiesHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. - """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "name", @@ -836,8 +836,7 @@ class PostgresRelationshipsHandler(Handler): relationship_types: Optional[list[str]] = None, include_metadata: bool = False, ): - """ - Get relationships from the specified store. + """Get relationships from the specified store. Args: parent_id: UUID of the parent (collection_id or document_id) @@ -1037,9 +1036,8 @@ class PostgresRelationshipsHandler(Handler): relationship_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPHS, ) -> None: - """ - Delete relationships from the specified store. - If relationship_ids is not provided, deletes all relationships for the given parent_id. + """Delete relationships from the specified store. If relationship_ids + is not provided, deletes all relationships for the given parent_id. 
Args: parent_id: UUID of the parent (collection_id or document_id) @@ -1088,9 +1086,8 @@ class PostgresRelationshipsHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. - """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "subject", @@ -1494,9 +1491,8 @@ class PostgresCommunitiesHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. - """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "collection_id", @@ -1710,12 +1706,10 @@ class PostgresGraphsHandler(Handler): raise R2RException( message="Graph with this ID already exists", status_code=409, - ) + ) from None async def reset(self, parent_id: UUID) -> None: - """ - Completely reset a graph and all associated data. - """ + """Completely reset a graph and all associated data.""" await self.entities.delete( parent_id=parent_id, store_type=StoreType.GRAPHS @@ -1853,9 +1847,8 @@ class PostgresGraphsHandler(Handler): } async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool: - """ - Add documents to the graph by copying their entities and relationships. - """ + """Add documents to the graph by copying their entities and + relationships.""" # Copy entities from document_entity to graphs_entities ENTITY_COPY_QUERY = f""" INSERT INTO {self._get_table_name("graphs_entities")} ( @@ -1973,8 +1966,7 @@ class PostgresGraphsHandler(Handler): entity_names: Optional[list[str]] = None, include_embeddings: bool = False, ) -> tuple[list[Entity], int]: - """ - Get entities for a graph. + """Get entities for a graph. 
Args: offset: Number of records to skip @@ -2058,8 +2050,7 @@ class PostgresGraphsHandler(Handler): relationship_types: Optional[list[str]] = None, include_embeddings: bool = False, ) -> tuple[list[Relationship], int]: - """ - Get relationships for a graph. + """Get relationships for a graph. Args: parent_id: UUID of the graph @@ -2137,10 +2128,10 @@ class PostgresGraphsHandler(Handler): self, entities: list[Entity], table_name: str, - conflict_columns: list[str] = [], + conflict_columns: list[str] | None = None, ) -> asyncpg.Record: - """ - Upsert entities into the entities_raw table. These are raw entities extracted from the document. + """Upsert entities into the entities_raw table. These are raw entities + extracted from the document. Args: entities: list[Entity]: list of entities to upsert @@ -2149,6 +2140,8 @@ class PostgresGraphsHandler(Handler): Returns: result: asyncpg.Record: result of the upsert operation """ + if not conflict_columns: + conflict_columns = [] cleaned_entities = [] for entity in entities: entity_dict = entity.to_dict() @@ -2187,8 +2180,7 @@ class PostgresGraphsHandler(Handler): return [Relationship(**relationship) for relationship in relationships] async def has_document(self, graph_id: UUID, document_id: UUID) -> bool: - """ - Check if a document exists in the graph's document_ids array. + """Check if a document exists in the graph's document_ids array. Args: graph_id (UUID): ID of the graph to check @@ -2227,8 +2219,7 @@ class PostgresGraphsHandler(Handler): community_ids: Optional[list[UUID]] = None, include_embeddings: bool = False, ) -> tuple[list[Community], int]: - """ - Get communities for a graph. + """Get communities for a graph. Args: collection_id: UUID of the collection @@ -2349,9 +2340,7 @@ class PostgresGraphsHandler(Handler): leiden_params: dict[str, Any], clustering_mode: str, ) -> Tuple[int, Any]: - """ - Calls the external clustering service to cluster the graph. 
- """ + """Calls the external clustering service to cluster the graph.""" offset = 0 page_size = 1000 @@ -2392,8 +2381,9 @@ class PostgresGraphsHandler(Handler): async def _call_clustering_service( self, relationships: list[Relationship], leiden_params: dict[str, Any] ) -> list[dict]: - """ - Calls the external Graspologic clustering service, sending relationships and parameters. + """Calls the external Graspologic clustering service, sending + relationships and parameters. + Expects a response with 'communities' field. """ # Convert relationships to a JSON-friendly format @@ -2429,9 +2419,10 @@ class PostgresGraphsHandler(Handler): leiden_params: dict[str, Any], clustering_mode: str = "remote", ) -> Any: - """ - Create a graph and cluster it. If clustering_mode='local', use hierarchical_leiden locally. - If clustering_mode='remote', call the external service. + """Create a graph and cluster it. + + If clustering_mode='local', use hierarchical_leiden locally. If + clustering_mode='remote', call the external service. """ if clustering_mode == "remote": @@ -2569,9 +2560,8 @@ class PostgresGraphsHandler(Handler): async def graph_search( self, query: str, **kwargs: Any ) -> AsyncGenerator[Any, None]: - """ - Perform semantic search with similarity scores while maintaining exact same structure. - """ + """Perform semantic search with similarity scores while maintaining + exact same structure.""" query_embedding = kwargs.get("query_embedding", None) if query_embedding is None: @@ -2643,8 +2633,8 @@ class PostgresGraphsHandler(Handler): def _build_filters( self, filter_dict: dict, parameters: list[Any], search_type: str ) -> str: - """ - Build a WHERE clause from a nested filter dictionary for the graph search. + """Build a WHERE clause from a nested filter dictionary for the graph + search. - If search_type == "communities", we normally filter by `collection_id`. - Otherwise (entities/relationships), we normally filter by `parent_id`. 
@@ -2797,8 +2787,10 @@ class PostgresGraphsHandler(Handler): ) return community_mapping - except ImportError as e: - raise ImportError("Please install the graspologic package.") from e + except ImportError: + raise ImportError( + "Please install the graspologic package." + ) from None async def get_existing_document_entity_chunk_ids( self, document_id: UUID @@ -2871,12 +2863,16 @@ async def _add_objects( objects: list[dict], full_table_name: str, connection_manager: PostgresConnectionManager, - conflict_columns: list[str] = [], - exclude_metadata: list[str] = [], + conflict_columns: list[str] | None = None, + exclude_metadata: list[str] | None = None, ) -> list[UUID]: - """ - Bulk insert objects into the specified table using jsonb_to_recordset. - """ + """Bulk insert objects into the specified table using + jsonb_to_recordset.""" + + if conflict_columns is None: + conflict_columns = [] + if exclude_metadata is None: + exclude_metadata = [] # Exclude specified metadata and prepare data cleaned_objects = [] diff --git a/py/core/providers/database/limits.py b/py/core/providers/database/limits.py index 5481e28ca..1029ec50e 100644 --- a/py/core/providers/database/limits.py +++ b/py/core/providers/database/limits.py @@ -48,10 +48,8 @@ class PostgresLimitsHandler(Handler): route: Optional[str], since: datetime, ) -> int: - """ - Count how many requests a user (optionally for a specific route) - has made since the given datetime. - """ + """Count how many requests a user (optionally for a specific route) has + made since the given datetime.""" if route: query = f""" SELECT COUNT(*)::int @@ -82,9 +80,10 @@ class PostgresLimitsHandler(Handler): user_id: UUID, route: Optional[str] = None, # <--- ADDED THIS ) -> int: - """ - Count the number of requests so far this month for a given user. - If route is provided, count only for that route. Otherwise, count globally. + """Count the number of requests so far this month for a given user. 
+ + If route is provided, count only for that route. Otherwise, count + globally. """ now = datetime.now(timezone.utc) start_of_month = now.replace( @@ -154,8 +153,7 @@ class PostgresLimitsHandler(Handler): return effective async def check_limits(self, user: User, route: str): - """ - Perform rate limit checks for a user on a specific route. + """Perform rate limit checks for a user on a specific route. :param user: The fully-fetched User object with .limits_overrides, etc. :param route: The route/path being accessed. @@ -212,9 +210,7 @@ class PostgresLimitsHandler(Handler): raise ValueError("Monthly rate limit exceeded") async def log_request(self, user_id: UUID, route: str): - """ - Log a successful request to the request_log table. - """ + """Log a successful request to the request_log table.""" query = f""" INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route) @@ -236,7 +232,6 @@ class PostgresLimitsHandler(Handler): # logger = logging.getLogger(__name__) - # class PostgresLimitsHandler(Handler): # TABLE_NAME = "request_log" diff --git a/py/core/providers/database/prompts_handler.py b/py/core/providers/database/prompts_handler.py index 7c032231a..dd17f0658 100644 --- a/py/core/providers/database/prompts_handler.py +++ b/py/core/providers/database/prompts_handler.py @@ -20,7 +20,7 @@ T = TypeVar("T") @dataclass class CacheEntry(Generic[T]): - """Represents a cached item with metadata""" + """Represents a cached item with metadata.""" value: T created_at: datetime @@ -29,7 +29,7 @@ class CacheEntry(Generic[T]): class Cache(Generic[T]): - """A generic cache implementation with TTL and LRU-like features""" + """A generic cache implementation with TTL and LRU-like features.""" def __init__( self, @@ -44,7 +44,7 @@ class Cache(Generic[T]): self._last_cleanup = datetime.now() def get(self, key: str) -> Optional[T]: - """Retrieve an item from cache""" + """Retrieve an item from cache.""" self._maybe_cleanup() if key not in 
self._cache: @@ -61,7 +61,7 @@ class Cache(Generic[T]): return entry.value def set(self, key: str, value: T) -> None: - """Store an item in cache""" + """Store an item in cache.""" self._maybe_cleanup() now = datetime.now() @@ -73,22 +73,22 @@ class Cache(Generic[T]): self._evict_lru() def invalidate(self, key: str) -> None: - """Remove an item from cache""" + """Remove an item from cache.""" self._cache.pop(key, None) def clear(self) -> None: - """Clear all cached items""" + """Clear all cached items.""" self._cache.clear() def _maybe_cleanup(self) -> None: - """Periodically clean up expired entries""" + """Periodically clean up expired entries.""" now = datetime.now() if now - self._last_cleanup > self._cleanup_interval: self._cleanup() self._last_cleanup = now def _cleanup(self) -> None: - """Remove expired entries""" + """Remove expired entries.""" if not self._ttl: return @@ -100,7 +100,7 @@ class Cache(Generic[T]): del self._cache[k] def _evict_lru(self) -> None: - """Remove least recently used item""" + """Remove least recently used item.""" if not self._cache: return @@ -111,7 +111,8 @@ class Cache(Generic[T]): class CacheablePromptHandler(Handler): - """Abstract base class that adds caching capabilities to prompt handlers""" + """Abstract base class that adds caching capabilities to prompt + handlers.""" def __init__( self, @@ -126,7 +127,7 @@ class CacheablePromptHandler(Handler): def _cache_key( self, prompt_name: str, inputs: Optional[dict] = None ) -> str: - """Generate a cache key for a prompt request""" + """Generate a cache key for a prompt request.""" if inputs: # Sort dict items for consistent keys sorted_inputs = sorted(inputs.items()) @@ -205,7 +206,7 @@ class CacheablePromptHandler(Handler): ) -> str: if inputs: # optional input validation if needed - for k, v in inputs.items(): + for k, _v in inputs.items(): if k not in input_types: raise ValueError( f"Unexpected input '{k}' for prompt with input types {input_types}" @@ -219,7 +220,7 @@ class 
CacheablePromptHandler(Handler): template: Optional[str] = None, input_types: Optional[dict[str, str]] = None, ) -> None: - """Public method to update a prompt with proper cache invalidation""" + """Public method to update a prompt with proper cache invalidation.""" # First invalidate all caches for this prompt self._template_cache.invalidate(name) cache_keys_to_invalidate = [ @@ -245,12 +246,12 @@ class CacheablePromptHandler(Handler): template: Optional[str] = None, input_types: Optional[dict[str, str]] = None, ) -> None: - """Implementation of prompt update logic""" + """Implementation of prompt update logic.""" pass @abstractmethod async def _get_template_info(self, prompt_name: str) -> Optional[dict]: - """Get template info with caching""" + """Get template info with caching.""" pass @@ -320,11 +321,10 @@ class PostgresPromptsHandler(CacheablePromptHandler): async def _load_prompts_from_yaml_directory( self, default_overwrite_on_diff: bool = False ) -> None: - """ - Load prompts from YAML files in the specified directory. + """Load prompts from YAML files in the specified directory. :param default_overwrite_on_diff: If a YAML prompt does not specify - 'overwrite_on_diff', we use this default. + 'overwrite_on_diff', we use this default. 
""" if not self.prompt_directory.is_dir(): logger.warning( @@ -393,7 +393,7 @@ class PostgresPromptsHandler(CacheablePromptHandler): inputs: Optional[dict[str, Any]] = None, bypass_template_cache: bool = False, ) -> str: - """Implementation of database prompt retrieval""" + """Implementation of database prompt retrieval.""" # If we're bypassing the template cache, skip the cache lookup if not bypass_template_cache: template_info = self._template_cache.get(prompt_name) @@ -433,7 +433,7 @@ class PostgresPromptsHandler(CacheablePromptHandler): return self._format_prompt(template, inputs, input_types) async def _get_template_info(self, prompt_name: str) -> Optional[dict]: # type: ignore - """Get template info with caching""" + """Get template info with caching.""" cached = self._template_cache.get(prompt_name) if cached is not None: return cached @@ -469,7 +469,8 @@ class PostgresPromptsHandler(CacheablePromptHandler): template: Optional[str] = None, input_types: Optional[dict[str, str]] = None, ) -> None: - """Implementation of database prompt update with proper connection handling""" + """Implementation of database prompt update with proper connection + handling.""" if not template and not input_types: return @@ -563,8 +564,7 @@ class PostgresPromptsHandler(CacheablePromptHandler): preserve_existing: bool = False, overwrite_on_diff: bool = False, # <-- new param ) -> None: - """ - Add or update a prompt. + """Add or update a prompt. If `preserve_existing` is True and prompt already exists, we skip updating. 
@@ -699,14 +699,18 @@ class PostgresPromptsHandler(CacheablePromptHandler): self, system_prompt_name: Optional[str] = None, system_role: str = "system", - system_inputs: dict = {}, + system_inputs: dict | None = None, system_prompt_override: Optional[str] = None, task_prompt_name: Optional[str] = None, task_role: str = "user", - task_inputs: dict = {}, + task_inputs: dict | None = None, task_prompt_override: Optional[str] = None, ) -> list[dict]: """Create a message payload from system and task prompts.""" + if system_inputs is None: + system_inputs = {} + if task_inputs is None: + task_inputs = {} if system_prompt_override: system_prompt = system_prompt_override else: diff --git a/py/core/providers/database/users.py b/py/core/providers/database/users.py index 9e8c15848..208eeaa4d 100644 --- a/py/core/providers/database/users.py +++ b/py/core/providers/database/users.py @@ -369,8 +369,7 @@ class PostgresUserHandler(Handler): merge_limits: bool = False, new_metadata: dict[str, Optional[str]] | None = None, ) -> User: - """ - Update user information including limits_overrides. + """Update user information including limits_overrides. Args: user: User object containing updated information @@ -386,7 +385,9 @@ class PostgresUserHandler(Handler): try: current_user = await self.get_user_by_id(user.id) except R2RException: - raise R2RException(status_code=404, message="User not found") + raise R2RException( + status_code=404, message="User not found" + ) from None # If the new user.google_id != current_user.google_id, check for duplicates if user.email and (user.email != current_user.email): @@ -865,9 +866,7 @@ class PostgresUserHandler(Handler): limit: int, user_ids: Optional[list[UUID]] = None, ) -> dict[str, list[User] | int]: - """ - Return users with document usage and total entries. 
- """ + """Return users with document usage and total entries.""" query = f""" WITH user_document_ids AS ( SELECT @@ -964,9 +963,10 @@ class PostgresUserHandler(Handler): self, user_id: UUID, ) -> dict: - """ - Get verification data for a specific user. - This method should be called after superuser authorization has been verified. + """Get verification data for a specific user. + + This method should be called after superuser authorization has been + verified. """ query = f""" SELECT @@ -1008,9 +1008,8 @@ class PostgresUserHandler(Handler): name: Optional[str] = None, description: Optional[str] = None, ) -> UUID: - """ - Store a new API key for a user with optional name and description. - """ + """Store a new API key for a user with optional name and + description.""" query = f""" INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (user_id, public_key, hashed_key, name, description) @@ -1027,8 +1026,8 @@ class PostgresUserHandler(Handler): return result["id"] async def get_api_key_record(self, key_id: str) -> Optional[dict]: - """ - Get API key record by 'public_key' and update 'updated_at' to now. + """Get API key record by 'public_key' and update 'updated_at' to now. + Returns { "user_id", "hashed_key" } or None if not found. """ query = f""" @@ -1103,9 +1102,8 @@ class PostgresUserHandler(Handler): filters: Optional[dict] = None, include_header: bool = True, ) -> tuple[str, IO]: - """ - Creates a CSV file from the PostgreSQL data and returns the path to the temp file. 
- """ + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" valid_columns = { "id", "email", @@ -1216,7 +1214,7 @@ class PostgresUserHandler(Handler): raise HTTPException( status_code=500, detail=f"Failed to export data: {str(e)}", - ) + ) from e async def get_user_by_google_id(self, google_id: str) -> Optional[User]: """Return a User if the google_id is found; otherwise None.""" diff --git a/py/core/providers/email/console_mock.py b/py/core/providers/email/console_mock.py index 404d3b643..459a978d8 100644 --- a/py/core/providers/email/console_mock.py +++ b/py/core/providers/email/console_mock.py @@ -7,7 +7,8 @@ logger = logging.getLogger() class ConsoleMockEmailProvider(EmailProvider): - """A simple email provider that logs emails to console, useful for testing""" + """A simple email provider that logs emails to console, useful for + testing.""" async def send_email( self, @@ -18,50 +19,43 @@ class ConsoleMockEmailProvider(EmailProvider): *args, **kwargs, ) -> None: - logger.info( - f""" + logger.info(f""" -------- Email Message -------- To: {to_email} Subject: {subject} Body: {body} ----------------------------- - """ - ) + """) async def send_verification_email( self, to_email: str, verification_code: str, *args, **kwargs ) -> None: - logger.info( - f""" + logger.info(f""" -------- Email Message -------- To: {to_email} Subject: Please verify your email address Body: Verification code: {verification_code} ----------------------------- - """ - ) + """) async def send_password_reset_email( self, to_email: str, reset_token: str, *args, **kwargs ) -> None: - logger.info( - f""" + logger.info(f""" -------- Email Message -------- To: {to_email} Subject: Password Reset Request Body: Reset token: {reset_token} ----------------------------- - """ - ) + """) async def send_password_changed_email( self, to_email: str, *args, **kwargs ) -> None: - logger.info( - f""" + logger.info(f""" -------- Email Message -------- To: {to_email} 
Subject: Your Password Has Been Changed @@ -70,5 +64,4 @@ class ConsoleMockEmailProvider(EmailProvider): For security reasons, you will need to log in again on all your devices. ----------------------------- - """ - ) + """) diff --git a/py/core/providers/email/sendgrid.py b/py/core/providers/email/sendgrid.py index ff370ad38..8b2553f14 100644 --- a/py/core/providers/email/sendgrid.py +++ b/py/core/providers/email/sendgrid.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) class SendGridEmailProvider(EmailProvider): - """Email provider implementation using SendGrid API""" + """Email provider implementation using SendGrid API.""" def __init__(self, config: EmailConfig): super().__init__(config) @@ -48,7 +48,7 @@ class SendGridEmailProvider(EmailProvider): self.docs_base_url = f"{self.frontend_url}/documentation" def _get_base_template_data(self, to_email: str) -> dict: - """Get base template data used across all email templates""" + """Get base template data used across all email templates.""" return { "user_email": to_email, "docs_url": self.docs_base_url, diff --git a/py/core/providers/email/smtp.py b/py/core/providers/email/smtp.py index 79ac66e5c..bd68ff36e 100644 --- a/py/core/providers/email/smtp.py +++ b/py/core/providers/email/smtp.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) class AsyncSMTPEmailProvider(EmailProvider): - """Email provider implementation using Brevo SMTP relay""" + """Email provider implementation using Brevo SMTP relay.""" def __init__(self, config: EmailConfig): super().__init__(config) @@ -45,7 +45,7 @@ class AsyncSMTPEmailProvider(EmailProvider): self.ssl_context = ssl.create_default_context() async def _send_email_sync(self, msg: MIMEMultipart) -> None: - """Synchronous email sending wrapped in asyncio executor""" + """Synchronous email sending wrapped in asyncio executor.""" loop = asyncio.get_running_loop() def _send(): @@ -90,10 +90,10 @@ class AsyncSMTPEmailProvider(EmailProvider): logger.info("Initializing SMTP 
connection...") async with asyncio.timeout(30): # Overall timeout await self._send_email_sync(msg) - except asyncio.TimeoutError: + except asyncio.TimeoutError as e: error_msg = "Operation timed out while trying to send email" logger.error(error_msg) - raise RuntimeError(error_msg) + raise RuntimeError(error_msg) from e except Exception as e: error_msg = f"Failed to send email: {str(e)}" logger.error(error_msg) diff --git a/py/core/providers/embeddings/litellm.py b/py/core/providers/embeddings/litellm.py index 1d32e1847..5f705c912 100644 --- a/py/core/providers/embeddings/litellm.py +++ b/py/core/providers/embeddings/litellm.py @@ -93,7 +93,7 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider): error_msg = f"Error getting embeddings: {str(e)}" logger.error(error_msg) - raise R2RException(error_msg, 400) + raise R2RException(error_msg, 400) from e def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]: texts = task["texts"] @@ -112,7 +112,7 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider): except Exception as e: error_msg = f"Error getting embeddings: {str(e)}" logger.error(error_msg) - raise R2RException(error_msg, 400) + raise R2RException(error_msg, 400) from e async def async_get_embedding( self, @@ -250,8 +250,8 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider): stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK, limit: int = 10, ) -> list[ChunkSearchResult]: - """ - Asynchronously rerank search results using the configured rerank model. + """Asynchronously rerank search results using the configured rerank + model. 
Args: query: The search query string diff --git a/py/core/providers/embeddings/ollama.py b/py/core/providers/embeddings/ollama.py index e83effcfe..297d9167c 100644 --- a/py/core/providers/embeddings/ollama.py +++ b/py/core/providers/embeddings/ollama.py @@ -71,7 +71,7 @@ class OllamaEmbeddingProvider(EmbeddingProvider): except Exception as e: error_msg = f"Error getting embeddings: {str(e)}" logger.error(error_msg) - raise R2RException(error_msg, 400) + raise R2RException(error_msg, 400) from e def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]: texts = task["texts"] @@ -91,7 +91,7 @@ class OllamaEmbeddingProvider(EmbeddingProvider): except Exception as e: error_msg = f"Error getting embeddings: {str(e)}" logger.error(error_msg) - raise R2RException(error_msg, 400) + raise R2RException(error_msg, 400) from e async def async_get_embedding( self, diff --git a/py/core/providers/ingestion/r2r/base.py b/py/core/providers/ingestion/r2r/base.py index 908cd3cd1..a0d77b3dd 100644 --- a/py/core/providers/ingestion/r2r/base.py +++ b/py/core/providers/ingestion/r2r/base.py @@ -16,7 +16,6 @@ from core.base import ( RecursiveCharacterTextSplitter, TextSplitter, ) -from core.base.abstractions import DocumentChunk from core.utils import generate_extraction_id from ...database import PostgresDatabaseProvider diff --git a/py/core/providers/llm/anthropic.py b/py/core/providers/llm/anthropic.py index 5d2de4ff9..fc2b88bce 100644 --- a/py/core/providers/llm/anthropic.py +++ b/py/core/providers/llm/anthropic.py @@ -30,8 +30,7 @@ def generate_tool_id() -> str: def openai_message_to_anthropic_block(msg: dict) -> dict: - """ - Converts a single OpenAI-style message (including function/tool calls) + """Converts a single OpenAI-style message (including function/tool calls) into one Anthropic-style message. 
Expected keys in `msg` can include: @@ -112,8 +111,7 @@ class AnthropicCompletionProvider(CompletionProvider): logger.debug("AnthropicCompletionProvider initialized successfully") def _get_base_args(self, generation_config: GenerationConfig) -> dict: - """ - Build the arguments dictionary for Anthropic's messages.create(). + """Build the arguments dictionary for Anthropic's messages.create(). Handles tool configuration according to Anthropic's schema: { @@ -175,10 +173,8 @@ class AnthropicCompletionProvider(CompletionProvider): return args def _convert_to_chat_completion(self, anthropic_msg: Message) -> dict: - """ - Convert a **non-streaming** Anthropic `Message` response into - an OpenAI-style dict. - """ + """Convert a **non-streaming** Anthropic `Message` response into an + OpenAI-style dict.""" # anthropic_msg.content is a list of blocks; gather text from "text" blocks content_text = "" if anthropic_msg.content: @@ -239,9 +235,8 @@ class AnthropicCompletionProvider(CompletionProvider): def _split_system_messages( self, messages: list[dict] ) -> (list[dict], Optional[str]): - """ - Extract the system message and properly group tool results with their calls. - """ + """Extract the system message and properly group tool results with + their calls.""" system_msg = None filtered = [] pending_tool_results = [] @@ -299,8 +294,9 @@ class AnthropicCompletionProvider(CompletionProvider): return filtered, system_msg async def _execute_task(self, task: dict[str, Any]): - """ - Async entry point. Decide if streaming or not, then call the appropriate helper. + """Async entry point. + + Decide if streaming or not, then call the appropriate helper. """ api_key = os.getenv("ANTHROPIC_API_KEY") if not api_key: @@ -357,9 +353,8 @@ class AnthropicCompletionProvider(CompletionProvider): async def _execute_task_async_streaming( self, args: dict ) -> AsyncGenerator[dict, None]: - """ - Streaming call (async): yields partial tokens in OpenAI-like SSE format. 
- """ + """Streaming call (async): yields partial tokens in OpenAI-like SSE + format.""" # The `stream=True` is typically handled by Anthropics from the original args, # but we remove it to avoid conflicts and rely on `messages.stream()`. args.pop("stream", None) @@ -390,9 +385,7 @@ class AnthropicCompletionProvider(CompletionProvider): raise def _execute_task_sync(self, task: dict[str, Any]): - """ - Synchronous entry point. - """ + """Synchronous entry point.""" messages = task["messages"] generation_config = task["generation_config"] extra_kwargs = task["kwargs"] @@ -412,9 +405,7 @@ class AnthropicCompletionProvider(CompletionProvider): return self._execute_task_sync_nonstreaming(args) def _execute_task_sync_nonstreaming(self, args: dict) -> LLMChatCompletion: - """ - Non-streaming synchronous call. - """ + """Non-streaming synchronous call.""" try: response = self.client.messages.create(**args) logger.debug("Anthropic sync non-stream call succeeded.") diff --git a/py/core/providers/llm/openai.py b/py/core/providers/llm/openai.py index d6965be99..18e8b9394 100644 --- a/py/core/providers/llm/openai.py +++ b/py/core/providers/llm/openai.py @@ -150,7 +150,8 @@ class OpenAICompletionProvider(CompletionProvider): ) def _get_client_and_model(self, model: str): - """Determine which client to use based on model prefix and return the appropriate client and model name.""" + """Determine which client to use based on model prefix and return the + appropriate client and model name.""" if model.startswith("azure/"): if not self.azure_client: raise ValueError( diff --git a/py/core/providers/llm/r2r_llm.py b/py/core/providers/llm/r2r_llm.py index 30a4240ab..02e745cbb 100644 --- a/py/core/providers/llm/r2r_llm.py +++ b/py/core/providers/llm/r2r_llm.py @@ -14,19 +14,18 @@ logger = logging.getLogger(__name__) class R2RCompletionProvider(CompletionProvider): - """ - A provider that routes to the right LLM provider (R2R): - - If `generation_config.model` starts with "anthropic/", 
call AnthropicCompletionProvider. - - If it starts with "azure-foundry/", call AzureFoundryCompletionProvider. - - If it starts with one of the other OpenAI-like prefixes ("openai/", "azure/", "deepseek/", "ollama/", "lmstudio/") - or has no prefix (e.g. "gpt-4", "gpt-3.5"), call OpenAICompletionProvider. - - Otherwise, fallback to LiteLLMCompletionProvider. + """A provider that routes to the right LLM provider (R2R): + + - If `generation_config.model` starts with "anthropic/", call AnthropicCompletionProvider. + - If it starts with "azure-foundry/", call AzureFoundryCompletionProvider. + - If it starts with one of the other OpenAI-like prefixes ("openai/", "azure/", "deepseek/", "ollama/", "lmstudio/") + or has no prefix (e.g. "gpt-4", "gpt-3.5"), call OpenAICompletionProvider. + - Otherwise, fallback to LiteLLMCompletionProvider. """ def __init__(self, config: CompletionConfig, *args, **kwargs) -> None: - """ - Initialize sub-providers for OpenAI, Anthropic, LiteLLM, and Azure Foundry. - """ + """Initialize sub-providers for OpenAI, Anthropic, LiteLLM, and Azure + Foundry.""" super().__init__(config) self.config = config @@ -51,9 +50,8 @@ class R2RCompletionProvider(CompletionProvider): def _choose_subprovider_by_model( self, model_name: str, is_streaming: bool = False ) -> CompletionProvider: - """ - Decide which underlying sub-provider to call based on the model name (prefix). - """ + """Decide which underlying sub-provider to call based on the model name + (prefix).""" # Route to Anthropic if appropriate. if model_name.startswith("anthropic/"): if not is_streaming: @@ -88,18 +86,16 @@ class R2RCompletionProvider(CompletionProvider): return self._litellm_provider async def _execute_task(self, task: dict[str, Any]): - """ - Pick the sub-provider based on model name and forward the async call. 
- """ + """Pick the sub-provider based on model name and forward the async + call.""" generation_config: GenerationConfig = task["generation_config"] model_name = generation_config.model sub_provider = self._choose_subprovider_by_model(model_name) return await sub_provider._execute_task(task) def _execute_task_sync(self, task: dict[str, Any]): - """ - Pick the sub-provider based on model name and forward the sync call. - """ + """Pick the sub-provider based on model name and forward the sync + call.""" generation_config: GenerationConfig = task["generation_config"] model_name = generation_config.model sub_provider = self._choose_subprovider_by_model(model_name) diff --git a/py/core/providers/orchestration/hatchet.py b/py/core/providers/orchestration/hatchet.py index ba12bcd03..210367777 100644 --- a/py/core/providers/orchestration/hatchet.py +++ b/py/core/providers/orchestration/hatchet.py @@ -15,7 +15,7 @@ class HatchetOrchestrationProvider(OrchestrationProvider): except ImportError: raise ImportError( "Hatchet SDK not installed. Please install it using `pip install hatchet-sdk`." - ) + ) from None root_logger = logging.getLogger() self.orchestrator = Hatchet( diff --git a/py/core/telemetry/posthog.py b/py/core/telemetry/posthog.py index 358cc1332..1b0fb537f 100644 --- a/py/core/telemetry/posthog.py +++ b/py/core/telemetry/posthog.py @@ -9,9 +9,10 @@ logger = logging.getLogger() class PosthogClient: - """ - This is a write-only project API key, so it can only create new events. It can't - read events or any of your other data stored with PostHog, so it's safe to use in public apps. + """This is a write-only project API key, so it can only create new events. + + It can't read events or any of your other data stored with PostHog, so it's + safe to use in public apps. 
""" def __init__( diff --git a/py/core/utils/logging_config.py b/py/core/utils/logging_config.py index 231aaedb1..2337d2885 100644 --- a/py/core/utils/logging_config.py +++ b/py/core/utils/logging_config.py @@ -7,9 +7,10 @@ from pathlib import Path class HTTPStatusFilter(logging.Filter): - """ - This filter inspects uvicorn.access log records. It uses record.getMessage() to retrieve - the fully formatted log message. Then it searches for HTTP status codes and adjusts the + """This filter inspects uvicorn.access log records. + + It uses record.getMessage() to retrieve the fully formatted log message. + Then it searches for HTTP status codes and adjusts the record's log level based on that status: - 4xx: WARNING - 5xx: ERROR diff --git a/py/core/utils/serper.py b/py/core/utils/serper.py index bfed0cad3..dcb842fa6 100644 --- a/py/core/utils/serper.py +++ b/py/core/utils/serper.py @@ -6,9 +6,8 @@ import os # TODO - Move process json to dedicated data processing module def process_json(json_object, indent=0): - """ - Recursively traverses the JSON object (dicts and lists) to create an unstructured text blob. - """ + """Recursively traverses the JSON object (dicts and lists) to create an + unstructured text blob.""" text_blob = "" if isinstance(json_object, dict): for key, value in json_object.items(): diff --git a/py/migrations/versions/2fac23e4d91b_migrate_to_document_search.py b/py/migrations/versions/2fac23e4d91b_migrate_to_document_search.py index 7e45237fb..d60dc00ed 100644 --- a/py/migrations/versions/2fac23e4d91b_migrate_to_document_search.py +++ b/py/migrations/versions/2fac23e4d91b_migrate_to_document_search.py @@ -1,9 +1,8 @@ -"""migrate_to_document_search +"""migrate_to_document_search.
Revision ID: 2fac23e4d91b Revises: Create Date: 2024-11-11 11:55:49.461015 - """ import asyncio @@ -44,13 +43,13 @@ class Vector(UserDefinedType): def run_async(coroutine): - """Helper function to run async code synchronously""" + """Helper function to run async code synchronously.""" with ThreadPoolExecutor() as pool: return pool.submit(asyncio.run, coroutine).result() async def async_generate_all_summaries(): - """Asynchronous function to generate summaries""" + """Asynchronous function to generate summaries.""" base_url = os.getenv("R2R_BASE_URL") if not base_url: @@ -182,12 +181,12 @@ async def async_generate_all_summaries(): def generate_all_summaries(): - """Synchronous wrapper for async_generate_all_summaries""" + """Synchronous wrapper for async_generate_all_summaries.""" return run_async(async_generate_all_summaries()) def check_if_upgrade_needed(): - """Check if the upgrade has already been applied or is needed""" + """Check if the upgrade has already been applied or is needed.""" # Get database connection connection = op.get_bind() inspector = inspect(connection) @@ -234,7 +233,7 @@ def upgrade() -> None: ) pass except json.JSONDecodeError: - raise ValueError("Invalid document_summaries.json file") + raise ValueError("Invalid document_summaries.json file") from None # Create the vector extension if it doesn't exist op.execute("CREATE EXTENSION IF NOT EXISTS vector") @@ -253,8 +252,7 @@ def upgrade() -> None: ) # Add generated column for full text search - op.execute( - f""" + op.execute(f""" ALTER TABLE {project_name}.document_info ADD COLUMN doc_search_vector tsvector GENERATED ALWAYS AS ( @@ -262,17 +260,14 @@ def upgrade() -> None: setweight(to_tsvector('english', COALESCE(summary, '')), 'B') || setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C') ) STORED; - """ - ) + """) # Create index for full text search - op.execute( - f""" + op.execute(f""" CREATE INDEX idx_doc_search_{project_name} ON 
{project_name}.document_info USING GIN (doc_search_vector); - """ - ) + """) if document_summaries: # Update existing documents with summaries and embeddings @@ -283,15 +278,13 @@ def upgrade() -> None: ) # Use plain SQL with proper escaping for PostgreSQL - op.execute( - f""" + op.execute(f""" UPDATE {project_name}.document_info SET summary = '{doc_data["summary"].replace("'", "''")}', summary_embedding = '{embedding_str}'::vector({dimension}) WHERE document_id = '{doc_id}'::uuid; - """ - ) + """) else: print( "No document summaries found, skipping update of existing documents" @@ -300,16 +293,14 @@ def upgrade() -> None: def downgrade() -> None: # First drop any dependencies on the columns we want to remove - op.execute( - f""" + op.execute(f""" -- Drop the full text search index first DROP INDEX IF EXISTS {project_name}.idx_doc_search_{project_name}; -- Drop the generated column that depends on the summary column ALTER TABLE {project_name}.document_info DROP COLUMN IF EXISTS doc_search_vector; - """ - ) + """) # Now we can safely drop the summary and embedding columns op.drop_column("document_info", "summary_embedding", schema=project_name) diff --git a/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py b/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py index b538a77a0..9e5675dc8 100644 --- a/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py +++ b/py/migrations/versions/3efc7b3b1b3d_add_total_tokens_count.py @@ -1,9 +1,8 @@ -"""add_total_tokens_to_documents +"""add_total_tokens_to_documents. Revision ID: 3efc7b3b1b3d Revises: 7eb70560f406 Create Date: 2025-01-21 14:59:00.000000 - """ import logging @@ -28,9 +27,10 @@ project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default") def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int: - """ - Count the number of tokens in the given text using tiktoken. - Default model is set to "gpt-3.5-turbo". Adjust if you prefer a different model. 
+ """Count the number of tokens in the given text using tiktoken. + + Default model is set to "gpt-3.5-turbo". Adjust if you prefer a different + model. """ try: encoding = tiktoken.encoding_for_model(model) @@ -41,7 +41,7 @@ def count_tokens_for_text(text: str, model: str = "gpt-3.5-turbo") -> int: def check_if_upgrade_needed() -> bool: - """Check if the upgrade has already been applied""" + """Check if the upgrade has already been applied.""" connection = op.get_bind() inspector = inspect(connection) @@ -114,15 +114,13 @@ def upgrade() -> None: ) # Fetch next batch of document IDs - batch_docs_query = text( - f""" + batch_docs_query = text(f""" SELECT id FROM {project_name}.documents ORDER BY id LIMIT :limit_val OFFSET :offset_val - """ - ) + """) batch_docs = connection.execute( batch_docs_query, {"limit_val": BATCH_SIZE, "offset_val": offset} ).fetchall() @@ -135,13 +133,11 @@ def upgrade() -> None: # Process each document in the batch for doc_id in doc_ids: - chunks_query = text( - f""" + chunks_query = text(f""" SELECT data FROM {project_name}.chunks WHERE document_id = :doc_id - """ - ) + """) chunk_rows = connection.execute( chunks_query, {"doc_id": doc_id} ).fetchall() @@ -154,13 +150,11 @@ def upgrade() -> None: ) # Update total_tokens for this document - update_query = text( - f""" + update_query = text(f""" UPDATE {project_name}.documents SET total_tokens = :tokcount WHERE id = :doc_id - """ - ) + """) connection.execute( update_query, {"tokcount": total_tokens, "doc_id": doc_id} ) @@ -171,7 +165,7 @@ def upgrade() -> None: def downgrade() -> None: - """Remove the total_tokens column on downgrade""" + """Remove the total_tokens column on downgrade.""" logger.info( "Dropping column 'total_tokens' from 'documents' table (downgrade)." 
) diff --git a/py/migrations/versions/7eb70560f406_add_limits_overrides_to_users.py b/py/migrations/versions/7eb70560f406_add_limits_overrides_to_users.py index 24eaecb95..a48ffc3c2 100644 --- a/py/migrations/versions/7eb70560f406_add_limits_overrides_to_users.py +++ b/py/migrations/versions/7eb70560f406_add_limits_overrides_to_users.py @@ -1,9 +1,8 @@ -"""add_limits_overrides_to_users +"""add_limits_overrides_to_users. Revision ID: 7eb70560f406 Revises: c45a9cf6a8a4 Create Date: 2025-01-03 20:27:16.139511 - """ import os @@ -23,7 +22,7 @@ project_name = os.getenv("R2R_PROJECT_NAME", "r2r_default") def check_if_upgrade_needed(): - """Check if the upgrade has already been applied""" + """Check if the upgrade has already been applied.""" connection = op.get_bind() inspector = inspect(connection) diff --git a/py/migrations/versions/8077140e1e99_v3_api_database_revision.py b/py/migrations/versions/8077140e1e99_v3_api_database_revision.py index 6e1cba286..3757843eb 100644 --- a/py/migrations/versions/8077140e1e99_v3_api_database_revision.py +++ b/py/migrations/versions/8077140e1e99_v3_api_database_revision.py @@ -1,9 +1,8 @@ -"""v3_api_database_revision +"""v3_api_database_revision. 
Revision ID: 8077140e1e99 Revises: Create Date: 2024-12-03 12:10:10.878485 - """ import os @@ -27,7 +26,7 @@ if not project_name: def check_if_upgrade_needed(): - """Check if the upgrade has already been applied or is needed""" + """Check if the upgrade has already been applied or is needed.""" connection = op.get_bind() inspector = inspect(connection) diff --git a/py/migrations/versions/c45a9cf6a8a4_add_user_and_document_count_to_.py b/py/migrations/versions/c45a9cf6a8a4_add_user_and_document_count_to_.py index 33baf0737..aceca0bf9 100644 --- a/py/migrations/versions/c45a9cf6a8a4_add_user_and_document_count_to_.py +++ b/py/migrations/versions/c45a9cf6a8a4_add_user_and_document_count_to_.py @@ -1,9 +1,8 @@ -"""Add user and document count to collection +"""Add user and document count to collection. Revision ID: c45a9cf6a8a4 Revises: 8077140e1e99 Create Date: 2024-12-10 13:28:07.798167 - """ import os @@ -27,7 +26,7 @@ if not project_name: def check_if_upgrade_needed(): - """Check if the upgrade has already been applied""" + """Check if the upgrade has already been applied.""" connection = op.get_bind() inspector = inspect(connection) @@ -67,8 +66,7 @@ def upgrade(): ) # Initialize the counts based on existing relationships - op.execute( - f""" + op.execute(f""" WITH collection_counts AS ( SELECT c.id, COUNT(DISTINCT u.id) as user_count, @@ -83,8 +81,7 @@ def upgrade(): document_count = COALESCE(cc.document_count, 0) FROM collection_counts cc WHERE c.id = cc.id - """ - ) + """) def downgrade(): diff --git a/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py b/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py index a8d0bccbb..a4ebf89ca 100644 --- a/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py +++ b/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py @@ -1,9 +1,8 @@ -"""migrate_to_asyncpg +"""migrate_to_asyncpg. 
Revision ID: d342e632358a Revises: Create Date: 2024-10-22 11:55:49.461015 - """ import os @@ -21,7 +20,6 @@ down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None - project_name = os.getenv("R2R_PROJECT_NAME") or "r2r_default" new_vector_table_name = "vectors" @@ -34,7 +32,7 @@ class Vector(UserDefinedType): def check_if_upgrade_needed(): - """Check if the upgrade has already been applied or is needed""" + """Check if the upgrade has already been applied or is needed.""" connection = op.get_bind() inspector = inspect(connection) @@ -150,8 +148,7 @@ def upgrade() -> None: # Migrate data from old table (assuming old table name is 'old_vectors') # Note: You'll need to replace 'old_schema' and 'old_vectors' with your actual names - op.execute( - f""" + op.execute(f""" INSERT INTO {project_name}.{new_vector_table_name} (extraction_id, document_id, user_id, collection_ids, vec, text, metadata) SELECT @@ -163,23 +160,18 @@ def upgrade() -> None: text, metadata FROM {project_name}.{old_vector_table_name} - """ - ) + """) # Verify data migration - op.execute( - f""" + op.execute(f""" SELECT COUNT(*) old_count FROM {project_name}.{old_vector_table_name}; SELECT COUNT(*) new_count FROM {project_name}.{new_vector_table_name}; - """ - ) + """) # If we get here, migration was successful, so drop the old table - op.execute( - f""" + op.execute(f""" DROP TABLE IF EXISTS {project_name}.{old_vector_table_name}; - """ - ) + """) def downgrade() -> None: diff --git a/py/pyproject.toml b/py/pyproject.toml index bf64615ab..62a88248f 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -175,6 +175,7 @@ exclude = ["py/tests/*"] line-length = 79 target-version = "py310" select = ["E", "F", "I", "B"] +ignore = ["B008", "B024", "B026", "E501", "F402", "F403", "F405", "F841"] [tool.ruff.format] quote-style = "double" diff --git a/py/sdk/asnyc_methods/__init__.py b/py/sdk/asnyc_methods/__init__.py 
index bda8063cc..efa520d67 100644 --- a/py/sdk/asnyc_methods/__init__.py +++ b/py/sdk/asnyc_methods/__init__.py @@ -1,23 +1,23 @@ -from .chunks import * -from .collections import * -from .conversations import * -from .documents import * -from .graphs import * -from .indices import * -from .prompts import * -from .retrieval import * -from .system import * -from .users import * +from .chunks import ChunksSDK +from .collections import CollectionsSDK +from .conversations import ConversationsSDK +from .documents import DocumentsSDK +from .graphs import GraphsSDK +from .indices import IndicesSDK +from .prompts import PromptsSDK +from .retrieval import RetrievalSDK +from .system import SystemSDK +from .users import UsersSDK __all__ = [ - "Chunks", - "Collections", - "Conversations", - "Documents", - "Graphs", - "Indices", - "Prompts", - "Retrieval", - "System", - "Users", + "ChunksSDK", + "CollectionsSDK", + "ConversationsSDK", + "DocumentsSDK", + "GraphsSDK", + "IndicesSDK", + "PromptsSDK", + "RetrievalSDK", + "SystemSDK", + "UsersSDK", ] diff --git a/py/sdk/asnyc_methods/chunks.py b/py/sdk/asnyc_methods/chunks.py index 08c1cfa7d..a64142d76 100644 --- a/py/sdk/asnyc_methods/chunks.py +++ b/py/sdk/asnyc_methods/chunks.py @@ -13,9 +13,7 @@ from ..models import SearchSettings class ChunksSDK: - """ - SDK for interacting with chunks in the v3 API. - """ + """SDK for interacting with chunks in the v3 API.""" def __init__(self, client): self.client = client @@ -24,8 +22,7 @@ class ChunksSDK: self, chunk: dict[str, str], ) -> WrappedChunkResponse: - """ - Update an existing chunk. + """Update an existing chunk. Args: chunk (dict[str, str]): Chunk to update. Should contain: @@ -47,8 +44,7 @@ class ChunksSDK: self, id: str | UUID, ) -> WrappedChunkResponse: - """ - Get a specific chunk. + """Get a specific chunk. 
Args: id (str | UUID): Chunk ID to retrieve @@ -73,8 +69,7 @@ class ChunksSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedChunksResponse: - """ - List chunks for a specific document. + """List chunks for a specific document. Args: document_id (str | UUID): Document ID to get chunks for @@ -105,8 +100,7 @@ class ChunksSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a specific chunk. + """Delete a specific chunk. Args: id (str | UUID): ID of chunk to delete @@ -130,8 +124,7 @@ class ChunksSDK: limit: Optional[int] = 100, filters: Optional[dict] = None, ) -> WrappedChunksResponse: - """ - List chunks with pagination support. + """List chunks with pagination support. Args: include_vectors (bool, optional): Include vector data in response. Defaults to False. @@ -167,8 +160,7 @@ class ChunksSDK: query: str, search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedVectorSearchResponse: - """ - Conduct a vector and/or graph search. + """Conduct a vector and/or graph search. Args: query (str): The query to search for. diff --git a/py/sdk/asnyc_methods/collections.py b/py/sdk/asnyc_methods/collections.py index 4d64840f2..a768b72ee 100644 --- a/py/sdk/asnyc_methods/collections.py +++ b/py/sdk/asnyc_methods/collections.py @@ -20,8 +20,7 @@ class CollectionsSDK: name: str, description: Optional[str] = None, ) -> WrappedCollectionResponse: - """ - Create a new collection. + """Create a new collection. Args: name (str): Name of the collection @@ -46,8 +45,7 @@ class CollectionsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCollectionsResponse: - """ - List collections with pagination and filtering options. + """List collections with pagination and filtering options. Args: ids (Optional[list[str | UUID]]): Filter collections by ids @@ -74,8 +72,7 @@ class CollectionsSDK: self, id: str | UUID, ) -> WrappedCollectionResponse: - """ - Get detailed information about a specific collection. 
+ """Get detailed information about a specific collection. Args: id (str | UUID): Collection ID to retrieve @@ -96,8 +93,7 @@ class CollectionsSDK: description: Optional[str] = None, generate_description: Optional[bool] = False, ) -> WrappedCollectionResponse: - """ - Update collection information. + """Update collection information. Args: id (str | UUID): Collection ID to update @@ -129,8 +125,7 @@ class CollectionsSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a collection. + """Delete a collection. Args: id (str | UUID): Collection ID to delete @@ -150,8 +145,7 @@ class CollectionsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedDocumentsResponse: - """ - List all documents in a collection. + """List all documents in a collection. Args: id (str | UUID): Collection ID @@ -180,8 +174,7 @@ class CollectionsSDK: id: str | UUID, document_id: str | UUID, ) -> WrappedGenericMessageResponse: - """ - Add a document to a collection. + """Add a document to a collection. Args: id (str | UUID): Collection ID @@ -203,8 +196,7 @@ class CollectionsSDK: id: str | UUID, document_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a document from a collection. + """Remove a document from a collection. Args: id (str | UUID): Collection ID @@ -227,8 +219,7 @@ class CollectionsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedUsersResponse: - """ - List all users in a collection. + """List all users in a collection. Args: id (str, UUID): Collection ID @@ -254,8 +245,7 @@ class CollectionsSDK: id: str | UUID, user_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Add a user to a collection. + """Add a user to a collection. Args: id (str | UUID): Collection ID @@ -275,8 +265,7 @@ class CollectionsSDK: id: str | UUID, user_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a user from a collection. + """Remove a user from a collection. 
Args: id (str | UUID): Collection ID @@ -299,8 +288,7 @@ class CollectionsSDK: settings: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Extract entities and relationships from documents in a collection. + """Extract entities and relationships from documents in a collection. Args: id (str | UUID): Collection ID to extract from @@ -330,8 +318,7 @@ class CollectionsSDK: async def retrieve_by_name( self, name: str, owner_id: Optional[str] = None ) -> WrappedCollectionResponse: - """ - Retrieve a collection by its name. + """Retrieve a collection by its name. For non-superusers, the backend will use the authenticated user's ID. For superusers, the caller must supply an owner_id to restrict the search. diff --git a/py/sdk/asnyc_methods/conversations.py b/py/sdk/asnyc_methods/conversations.py index ace06eb36..885f9fc51 100644 --- a/py/sdk/asnyc_methods/conversations.py +++ b/py/sdk/asnyc_methods/conversations.py @@ -22,8 +22,7 @@ class ConversationsSDK: self, name: Optional[str] = None, ) -> WrappedConversationResponse: - """ - Create a new conversation. + """Create a new conversation. Returns: WrappedConversationResponse @@ -47,8 +46,7 @@ class ConversationsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedConversationsResponse: - """ - List conversations with pagination and sorting options. + """List conversations with pagination and sorting options. Args: ids (Optional[list[str | UUID]]): List of conversation IDs to retrieve @@ -78,8 +76,7 @@ class ConversationsSDK: self, id: str | UUID, ) -> WrappedConversationMessagesResponse: - """ - Get detailed information about a specific conversation. + """Get detailed information about a specific conversation. Args: id (str | UUID): The ID of the conversation to retrieve @@ -100,8 +97,7 @@ class ConversationsSDK: id: str | UUID, name: str, ) -> WrappedConversationResponse: - """ - Update an existing conversation. 
+ """Update an existing conversation. Args: id (str | UUID): The ID of the conversation to update @@ -127,8 +123,7 @@ class ConversationsSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a conversation. + """Delete a conversation. Args: id (str | UUID): The ID of the conversation to delete @@ -152,8 +147,7 @@ class ConversationsSDK: metadata: Optional[dict] = None, parent_id: Optional[str] = None, ) -> WrappedMessageResponse: - """ - Add a new message to a conversation. + """Add a new message to a conversation. Args: id (str | UUID): The ID of the conversation to add the message to @@ -190,8 +184,7 @@ class ConversationsSDK: content: Optional[str] = None, metadata: Optional[dict] = None, ) -> WrappedMessageResponse: - """ - Update an existing message in a conversation. + """Update an existing message in a conversation. Args: id (str | UUID): The ID of the conversation containing the message @@ -221,8 +214,8 @@ class ConversationsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export conversations to a CSV file, streaming the results directly to disk. + """Export conversations to a CSV file, streaming the results directly + to disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -272,8 +265,8 @@ class ConversationsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export messages to a CSV file, streaming the results directly to disk. + """Export messages to a CSV file, streaming the results directly to + disk. Args: output_path (str | Path): Local path where the CSV file should be saved diff --git a/py/sdk/asnyc_methods/documents.py b/py/sdk/asnyc_methods/documents.py index 0a7dee9a1..f2f45a585 100644 --- a/py/sdk/asnyc_methods/documents.py +++ b/py/sdk/asnyc_methods/documents.py @@ -24,9 +24,7 @@ from ..models import IngestionMode, SearchMode, SearchSettings class DocumentsSDK: - """ - SDK for interacting with documents in the v3 API. 
- """ + """SDK for interacting with documents in the v3 API.""" def __init__(self, client): self.client = client @@ -43,8 +41,7 @@ class DocumentsSDK: ingestion_config: Optional[dict | IngestionMode] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedIngestionResponse: - """ - Create a new document from either a file or content. + """Create a new document from either a file or content. Args: file_path (Optional[str]): The file to upload, if any @@ -141,8 +138,7 @@ class DocumentsSDK: self, id: str | UUID, ) -> WrappedDocumentResponse: - """ - Get a specific document by ID. + """Get a specific document by ID. Args: id (str | UUID): ID of document to retrieve @@ -178,9 +174,7 @@ class DocumentsSDK: end_date: Optional[datetime] = None, output_path: Optional[str | Path] = None, ) -> BytesIO | None: - """ - Download multiple documents as a zip file. - """ + """Download multiple documents as a zip file.""" params: dict[str, Any] = {} if document_ids: params["document_ids"] = [str(doc_id) for doc_id in document_ids] @@ -218,8 +212,8 @@ class DocumentsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export documents to a CSV file, streaming the results directly to disk. + """Export documents to a CSV file, streaming the results directly to + disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -269,8 +263,8 @@ class DocumentsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export documents to a CSV file, streaming the results directly to disk. + """Export documents to a CSV file, streaming the results directly to + disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -321,8 +315,8 @@ class DocumentsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export document relationships to a CSV file, streaming the results directly to disk. 
+ """Export document relationships to a CSV file, streaming the results + directly to disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -369,8 +363,7 @@ class DocumentsSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a specific document. + """Delete a specific document. Args: id (str | UUID): ID of document to delete @@ -393,8 +386,7 @@ class DocumentsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedChunksResponse: - """ - Get chunks for a specific document. + """Get chunks for a specific document. Args: id (str | UUID): ID of document to retrieve chunks for @@ -426,8 +418,7 @@ class DocumentsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCollectionsResponse: - """ - List collections for a specific document. + """List collections for a specific document. Args: id (str | UUID): ID of document to retrieve collections for @@ -455,8 +446,7 @@ class DocumentsSDK: self, filters: dict, ) -> WrappedBooleanResponse: - """ - Delete documents based on filters. + """Delete documents based on filters. Args: filters (dict): Filters to apply when selecting documents to delete @@ -480,8 +470,7 @@ class DocumentsSDK: settings: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Extract entities and relationships from a document. + """Extract entities and relationships from a document. Args: id (str, UUID): ID of document to extract from @@ -512,8 +501,7 @@ class DocumentsSDK: limit: Optional[int] = 100, include_embeddings: Optional[bool] = False, ) -> WrappedEntitiesResponse: - """ - List entities extracted from a document. + """List entities extracted from a document. 
Args: id (str | UUID): ID of document to get entities from @@ -546,8 +534,7 @@ class DocumentsSDK: entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, ) -> WrappedRelationshipsResponse: - """ - List relationships extracted from a document. + """List relationships extracted from a document. Args: id (str | UUID): ID of document to get relationships from @@ -583,8 +570,7 @@ class DocumentsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedDocumentsResponse: - """ - List documents with pagination. + """List documents with pagination. Args: ids (Optional[list[str | UUID]]): Optional list of document IDs to filter by @@ -616,8 +602,7 @@ class DocumentsSDK: search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedDocumentSearchResponse: - """ - Conduct a vector and/or graph search. + """Conduct a vector and/or graph search. Args: query (str): The query to search for. @@ -650,8 +635,7 @@ class DocumentsSDK: settings: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Deduplicate entities and relationships from a document. + """Deduplicate entities and relationships from a document. Args: id (str, UUID): ID of document to extract from diff --git a/py/sdk/asnyc_methods/graphs.py b/py/sdk/asnyc_methods/graphs.py index 222122372..676aceaa1 100644 --- a/py/sdk/asnyc_methods/graphs.py +++ b/py/sdk/asnyc_methods/graphs.py @@ -17,9 +17,7 @@ from shared.api.models import ( class GraphsSDK: - """ - SDK for interacting with knowledge graphs in the v3 API. - """ + """SDK for interacting with knowledge graphs in the v3 API.""" def __init__(self, client): self.client = client @@ -30,8 +28,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedGraphsResponse: - """ - List graphs with pagination and filtering options. 
+ """List graphs with pagination and filtering options. Args: ids (Optional[list[str | UUID]]): Filter graphs by ids @@ -58,8 +55,7 @@ class GraphsSDK: self, collection_id: str | UUID, ) -> WrappedGraphResponse: - """ - Get detailed information about a specific graph. + """Get detailed information about a specific graph. Args: collection_id (str | UUID): Graph ID to retrieve @@ -77,8 +73,7 @@ class GraphsSDK: self, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Deletes a graph and all its associated data. + """Deletes a graph and all its associated data. This endpoint permanently removes the specified graph along with all entities and relationships that belong to only this graph. @@ -103,8 +98,7 @@ class GraphsSDK: name: Optional[str] = None, description: Optional[str] = None, ) -> WrappedGraphResponse: - """ - Update graph information. + """Update graph information. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -135,8 +129,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedEntitiesResponse: - """ - List entities in a graph. + """List entities in a graph. Args: collection_id (str | UUID): Graph ID to list entities from @@ -165,8 +158,7 @@ class GraphsSDK: collection_id: str | UUID, entity_id: str | UUID, ) -> WrappedEntityResponse: - """ - Get entity information in a graph. + """Get entity information in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -188,8 +180,7 @@ class GraphsSDK: collection_id: str | UUID, entity_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove an entity from a graph. + """Remove an entity from a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -210,8 +201,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedRelationshipsResponse: - """ - List relationships in a graph. + """List relationships in a graph. 
Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -240,8 +230,7 @@ class GraphsSDK: collection_id: str | UUID, relationship_id: str | UUID, ) -> WrappedRelationshipResponse: - """ - Get relationship information in a graph. + """Get relationship information in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -263,8 +252,7 @@ class GraphsSDK: collection_id: str | UUID, relationship_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a relationship from a graph. + """Remove a relationship from a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -287,8 +275,7 @@ class GraphsSDK: settings: Optional[dict] = None, run_with_orchestration: bool = True, ) -> WrappedGenericMessageResponse: - """ - Build a graph. + """Build a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -318,8 +305,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCommunitiesResponse: - """ - List communities in a graph. + """List communities in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -348,8 +334,7 @@ class GraphsSDK: collection_id: str | UUID, community_id: str | UUID, ) -> WrappedCommunityResponse: - """ - Get community information in a graph. + """Get community information in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -378,8 +363,7 @@ class GraphsSDK: level: Optional[int] = None, attributes: Optional[dict] = None, ) -> WrappedCommunityResponse: - """ - Update community information. + """Update community information. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -425,8 +409,7 @@ class GraphsSDK: collection_id: str | UUID, community_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a community from a graph. + """Remove a community from a graph. 
Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -447,8 +430,8 @@ class GraphsSDK: self, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Adds documents to a graph by copying their entities and relationships. + """Adds documents to a graph by copying their entities and + relationships. This endpoint: 1. Copies document entities to the graphs_entities table @@ -481,8 +464,7 @@ class GraphsSDK: collection_id: str | UUID, document_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Removes a document from a graph and removes any associated entities + """Removes a document from a graph and removes any associated entities. This endpoint: 1. Removes the document ID from the graph's document_ids array @@ -509,8 +491,7 @@ class GraphsSDK: category: Optional[str] = None, metadata: Optional[dict] = None, ) -> WrappedEntityResponse: - """ - Creates a new entity in the graph. + """Creates a new entity in the graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -552,8 +533,7 @@ class GraphsSDK: weight: Optional[float] = None, metadata: Optional[dict] = None, ) -> WrappedRelationshipResponse: - """ - Creates a new relationship in the graph. + """Creates a new relationship in the graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -600,8 +580,7 @@ class GraphsSDK: rating: Optional[float] = None, rating_explanation: Optional[str] = None, ) -> WrappedCommunityResponse: - """ - Creates a new community in the graph. + """Creates a new community in the graph. 
Args: collection_id (str | UUID): The collection ID corresponding to the graph diff --git a/py/sdk/asnyc_methods/indices.py b/py/sdk/asnyc_methods/indices.py index 66c7a2fcb..966023ed5 100644 --- a/py/sdk/asnyc_methods/indices.py +++ b/py/sdk/asnyc_methods/indices.py @@ -17,8 +17,7 @@ class IndicesSDK: config: dict, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Create a new vector similarity search index in the database. + """Create a new vector similarity search index in the database. Args: config (dict | IndexConfig): Configuration for the vector index. @@ -46,8 +45,8 @@ class IndicesSDK: offset: Optional[int] = 0, limit: Optional[int] = 10, ) -> WrappedVectorIndicesResponse: - """ - List existing vector similarity search indices with pagination support. + """List existing vector similarity search indices with pagination + support. Args: filters (Optional[dict]): Filter criteria for indices. @@ -77,8 +76,7 @@ class IndicesSDK: index_name: str, table_name: str = "vectors", ) -> WrappedVectorIndexResponse: - """ - Get detailed information about a specific vector index. + """Get detailed information about a specific vector index. Args: index_name (str): The name of the index to retrieve. @@ -100,8 +98,7 @@ class IndicesSDK: index_name: str, table_name: str = "vectors", ) -> WrappedGenericMessageResponse: - """ - Delete an existing vector index. + """Delete an existing vector index. Args: index_name (str): The name of the index to retrieve. diff --git a/py/sdk/asnyc_methods/prompts.py b/py/sdk/asnyc_methods/prompts.py index b5881562d..c7cdb50af 100644 --- a/py/sdk/asnyc_methods/prompts.py +++ b/py/sdk/asnyc_methods/prompts.py @@ -16,8 +16,8 @@ class PromptsSDK: async def create( self, name: str, template: str, input_types: dict ) -> WrappedGenericMessageResponse: - """ - Create a new prompt. + """Create a new prompt. 
+ Args: name (str): The name of the prompt template (str): The template string for the prompt @@ -40,8 +40,8 @@ class PromptsSDK: return WrappedGenericMessageResponse(**response_dict) async def list(self) -> WrappedPromptsResponse: - """ - List all available prompts. + """List all available prompts. + Returns: dict: List of all available prompts """ @@ -59,8 +59,8 @@ class PromptsSDK: inputs: Optional[dict] = None, prompt_override: Optional[str] = None, ) -> WrappedPromptResponse: - """ - Get a specific prompt by name, optionally with inputs and override. + """Get a specific prompt by name, optionally with inputs and override. + Args: name (str): The name of the prompt to retrieve inputs (Optional[dict]): JSON-encoded inputs for the prompt @@ -88,8 +88,8 @@ class PromptsSDK: template: Optional[str] = None, input_types: Optional[dict] = None, ) -> WrappedGenericMessageResponse: - """ - Update an existing prompt's template and/or input types. + """Update an existing prompt's template and/or input types. + Args: name (str): The name of the prompt to update template (Optional[str]): The updated template string for the prompt @@ -112,8 +112,8 @@ class PromptsSDK: return WrappedGenericMessageResponse(**response_dict) async def delete(self, name: str) -> WrappedBooleanResponse: - """ - Delete a prompt by name. + """Delete a prompt by name. + Args: name (str): The name of the prompt to delete Returns: diff --git a/py/sdk/asnyc_methods/retrieval.py b/py/sdk/asnyc_methods/retrieval.py index 90e547332..09ea60d5d 100644 --- a/py/sdk/asnyc_methods/retrieval.py +++ b/py/sdk/asnyc_methods/retrieval.py @@ -18,9 +18,7 @@ from ..models import ( class RetrievalSDK: - """ - SDK for interacting with documents in the v3 API. 
- """ + """SDK for interacting with documents in the v3 API.""" def __init__(self, client): self.client = client @@ -31,8 +29,7 @@ class RetrievalSDK: search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedSearchResponse: - """ - Conduct a vector and/or graph search. + """Conduct a vector and/or graph search. Args: query (str): The query to search for. @@ -115,8 +112,8 @@ class RetrievalSDK: task_prompt_override: Optional[str] = None, include_title_if_available: Optional[bool] = False, ) -> WrappedRAGResponse | AsyncGenerator[RAGResponse, None]: - """ - Conducts a Retrieval Augmented Generation (RAG) search with the given query. + """Conducts a Retrieval Augmented Generation (RAG) search with the + given query. Args: query (str): The query to search for. @@ -177,8 +174,7 @@ class RetrievalSDK: max_tool_context_length: Optional[int] = None, use_system_context: Optional[bool] = True, ) -> WrappedAgentResponse | AsyncGenerator[Message, None]: - """ - Performs a single turn in a conversation with a RAG agent. + """Performs a single turn in a conversation with a RAG agent. Args: message (Optional[dict | Message]): The message to send to the agent. @@ -243,8 +239,7 @@ class RetrievalSDK: tools: Optional[list[dict]] = None, max_tool_context_length: Optional[int] = None, ) -> WrappedAgentResponse | AsyncGenerator[Message, None]: - """ - Performs a single turn in a conversation with a RAG agent. + """Performs a single turn in a conversation with a RAG agent. Args: message (Optional[dict | Message]): The message to send to the agent. diff --git a/py/sdk/asnyc_methods/system.py b/py/sdk/asnyc_methods/system.py index e844fcea1..0fac5ee29 100644 --- a/py/sdk/asnyc_methods/system.py +++ b/py/sdk/asnyc_methods/system.py @@ -13,9 +13,7 @@ class SystemSDK: self.client = client async def health(self) -> WrappedGenericMessageResponse: - """ - Check the health of the R2R server. 
- """ + """Check the health of the R2R server.""" response_dict = await self.client._make_request( "GET", "health", version="v3" ) @@ -28,8 +26,7 @@ class SystemSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedLogsResponse: - """ - Get logs from the server. + """Get logs from the server. Args: run_type_filter (Optional[str]): The run type to filter by. @@ -55,8 +52,7 @@ class SystemSDK: return WrappedLogsResponse(**response_dict) async def settings(self) -> WrappedSettingsResponse: - """ - Get the configuration settings for the R2R server. + """Get the configuration settings for the R2R server. Returns: dict: The server settings. @@ -68,8 +64,8 @@ class SystemSDK: return WrappedSettingsResponse(**response_dict) async def status(self) -> WrappedServerStatsResponse: - """ - Get statistics about the server, including the start time, uptime, CPU usage, and memory usage. + """Get statistics about the server, including the start time, uptime, + CPU usage, and memory usage. Returns: dict: The server statistics. diff --git a/py/sdk/asnyc_methods/users.py b/py/sdk/asnyc_methods/users.py index 2e556bcf4..207c3cf44 100644 --- a/py/sdk/asnyc_methods/users.py +++ b/py/sdk/asnyc_methods/users.py @@ -27,8 +27,7 @@ class UsersSDK: bio: Optional[str] = None, profile_picture: Optional[str] = None, ) -> WrappedUserResponse: - """ - Register a new user. + """Register a new user. Args: email (str): User's email address @@ -62,9 +61,7 @@ class UsersSDK: async def send_verification_email( self, email: str ) -> WrappedGenericMessageResponse: - """ - Request that a verification email to a user. - """ + """Request that a verification email be sent to a user.""" response_dict = await self.client._make_request( "POST", "users/send-verification-email", @@ -77,9 +74,8 @@ class UsersSDK: async def delete( self, id: str | UUID, password: str ) -> WrappedBooleanResponse: - """ - Delete a specific user. - Users can only delete their own account unless they are superusers.
+ """Delete a specific user. Users can only delete their own account + unless they are superusers. Args: id (str | UUID): User ID to delete @@ -103,8 +99,7 @@ class UsersSDK: async def verify_email( self, email: str, verification_code: str ) -> WrappedGenericMessageResponse: - """ - Verify a user's email address. + """Verify a user's email address. Args: email (str): User's email address @@ -127,8 +122,7 @@ class UsersSDK: return WrappedGenericMessageResponse(**response_dict) async def login(self, email: str, password: str) -> WrappedLoginResponse: - """ - Log in a user. + """Log in a user. Args: email (str): User's email address @@ -202,8 +196,7 @@ class UsersSDK: async def change_password( self, current_password: str, new_password: str ) -> WrappedGenericMessageResponse: - """ - Change the user's password. + """Change the user's password. Args: current_password (str): User's current password @@ -228,8 +221,7 @@ class UsersSDK: async def request_password_reset( self, email: str ) -> WrappedGenericMessageResponse: - """ - Request a password reset. + """Request a password reset. Args: email (str): User's email address @@ -249,8 +241,7 @@ class UsersSDK: async def reset_password( self, reset_token: str, new_password: str ) -> WrappedGenericMessageResponse: - """ - Reset password using a reset token. + """Reset password using a reset token. Args: reset_token (str): Password reset token @@ -278,8 +269,7 @@ class UsersSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedUsersResponse: - """ - List users with pagination and filtering options. + """List users with pagination and filtering options. Args: offset (int, optional): Specifies the number of objects to skip. Defaults to 0. @@ -308,8 +298,7 @@ class UsersSDK: self, id: str | UUID, ) -> WrappedUserResponse: - """ - Get a specific user. + """Get a specific user. 
Args: id (str | UUID): User ID to retrieve @@ -328,8 +317,7 @@ class UsersSDK: async def me( self, ) -> WrappedUserResponse: - """ - Get detailed information about the currently authenticated user. + """Get detailed information about the currently authenticated user. Returns: dict: Detailed user information @@ -353,8 +341,7 @@ class UsersSDK: limits_overrides: dict | None = None, metadata: dict[str, str | None] | None = None, ) -> WrappedUserResponse: - """ - Update user information. + """Update user information. Args: id (str | UUID): User ID to update @@ -398,8 +385,7 @@ class UsersSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCollectionsResponse: - """ - Get all collections associated with a specific user. + """Get all collections associated with a specific user. Args: id (str | UUID): User ID to get collections for @@ -428,8 +414,7 @@ class UsersSDK: id: str | UUID, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Add a user to a collection. + """Add a user to a collection. Args: id (str | UUID): User ID to add @@ -448,8 +433,7 @@ class UsersSDK: id: str | UUID, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a user from a collection. + """Remove a user from a collection. Args: id (str | UUID): User ID to remove @@ -472,8 +456,7 @@ class UsersSDK: name: Optional[str] = None, description: Optional[str] = None, ) -> WrappedAPIKeyResponse: - """ - Create a new API key for the specified user. + """Create a new API key for the specified user. Args: id (str | UUID): User ID to create API key for @@ -502,8 +485,7 @@ class UsersSDK: self, id: str | UUID, ) -> WrappedAPIKeysResponse: - """ - List all API keys for the specified user. + """List all API keys for the specified user. Args: id (str | UUID): User ID to list API keys for @@ -524,8 +506,7 @@ class UsersSDK: id: str | UUID, key_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a specific API key for the specified user. 
+ """Delete a specific API key for the specified user. Args: id (str | UUID): User ID @@ -552,8 +533,8 @@ class UsersSDK: return WrappedLimitsResponse(**response_dict) async def oauth_google_authorize(self) -> WrappedGenericMessageResponse: - """ - Get Google OAuth 2.0 authorization URL from the server. + """Get Google OAuth 2.0 authorization URL from the server. + Returns: WrappedGenericMessageResponse """ @@ -566,8 +547,8 @@ class UsersSDK: return WrappedGenericMessageResponse(**response_dict) async def oauth_github_authorize(self) -> WrappedGenericMessageResponse: - """ - Get GitHub OAuth 2.0 authorization URL from the server. + """Get GitHub OAuth 2.0 authorization URL from the server. + Returns: WrappedGenericMessageResponse """ @@ -582,9 +563,8 @@ class UsersSDK: async def oauth_google_callback( self, code: str, state: str ) -> WrappedLoginResponse: - """ - Exchange `code` and `state` with the Google OAuth 2.0 callback route. - """ + """Exchange `code` and `state` with the Google OAuth 2.0 callback + route.""" response_dict = await self.client._make_request( "GET", "users/oauth/google/callback", @@ -597,9 +577,8 @@ class UsersSDK: async def oauth_github_callback( self, code: str, state: str ) -> WrappedLoginResponse: - """ - Exchange `code` and `state` with the GitHub OAuth 2.0 callback route. - """ + """Exchange `code` and `state` with the GitHub OAuth 2.0 callback + route.""" response_dict = await self.client._make_request( "GET", "users/oauth/github/callback", diff --git a/py/sdk/async_client.py b/py/sdk/async_client.py index 49c22037e..4593116f3 100644 --- a/py/sdk/async_client.py +++ b/py/sdk/async_client.py @@ -22,9 +22,7 @@ from .base.base_client import BaseClient class R2RAsyncClient(BaseClient): - """ - Asynchronous client for interacting with the R2R API. 
- """ + """Asynchronous client for interacting with the R2R API.""" def __init__( self, diff --git a/py/sdk/sync_methods/__init__.py b/py/sdk/sync_methods/__init__.py index bda8063cc..efa520d67 100644 --- a/py/sdk/sync_methods/__init__.py +++ b/py/sdk/sync_methods/__init__.py @@ -1,23 +1,23 @@ -from .chunks import * -from .collections import * -from .conversations import * -from .documents import * -from .graphs import * -from .indices import * -from .prompts import * -from .retrieval import * -from .system import * -from .users import * +from .chunks import ChunksSDK +from .collections import CollectionsSDK +from .conversations import ConversationsSDK +from .documents import DocumentsSDK +from .graphs import GraphsSDK +from .indices import IndicesSDK +from .prompts import PromptsSDK +from .retrieval import RetrievalSDK +from .system import SystemSDK +from .users import UsersSDK __all__ = [ - "Chunks", - "Collections", - "Conversations", - "Documents", - "Graphs", - "Indices", - "Prompts", - "Retrieval", - "System", - "Users", + "ChunksSDK", + "CollectionsSDK", + "ConversationsSDK", + "DocumentsSDK", + "GraphsSDK", + "IndicesSDK", + "PromptsSDK", + "RetrievalSDK", + "SystemSDK", + "UsersSDK", ] diff --git a/py/sdk/sync_methods/chunks.py b/py/sdk/sync_methods/chunks.py index 8f42ba2ba..b7e2124fc 100644 --- a/py/sdk/sync_methods/chunks.py +++ b/py/sdk/sync_methods/chunks.py @@ -13,9 +13,7 @@ from ..models import SearchSettings class ChunksSDK: - """ - SDK for interacting with chunks in the v3 API. - """ + """SDK for interacting with chunks in the v3 API.""" def __init__(self, client): self.client = client @@ -24,8 +22,7 @@ class ChunksSDK: self, chunk: dict[str, str], ) -> WrappedChunkResponse: - """ - Update an existing chunk. + """Update an existing chunk. Args: chunk (dict[str, str]): Chunk to update. Should contain: @@ -47,8 +44,7 @@ class ChunksSDK: self, id: str | UUID, ) -> WrappedChunkResponse: - """ - Get a specific chunk. + """Get a specific chunk. 
Args: id (str | UUID): Chunk ID to retrieve @@ -73,8 +69,7 @@ class ChunksSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedChunksResponse: - """ - List chunks for a specific document. + """List chunks for a specific document. Args: document_id (str | UUID): Document ID to get chunks for @@ -105,8 +100,7 @@ class ChunksSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a specific chunk. + """Delete a specific chunk. Args: id (str | UUID): ID of chunk to delete @@ -130,8 +124,7 @@ class ChunksSDK: limit: Optional[int] = 100, filters: Optional[dict] = None, ) -> WrappedChunksResponse: - """ - List chunks with pagination support. + """List chunks with pagination support. Args: include_vectors (bool, optional): Include vector data in response. Defaults to False. @@ -167,8 +160,7 @@ class ChunksSDK: query: str, search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedVectorSearchResponse: - """ - Conduct a vector and/or graph search. + """Conduct a vector and/or graph search. Args: query (str): The query to search for. diff --git a/py/sdk/sync_methods/collections.py b/py/sdk/sync_methods/collections.py index 4de9de4d3..69e9d7b40 100644 --- a/py/sdk/sync_methods/collections.py +++ b/py/sdk/sync_methods/collections.py @@ -20,8 +20,7 @@ class CollectionsSDK: name: str, description: Optional[str] = None, ) -> WrappedCollectionResponse: - """ - Create a new collection. + """Create a new collection. Args: name (str): Name of the collection @@ -46,8 +45,7 @@ class CollectionsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCollectionsResponse: - """ - List collections with pagination and filtering options. + """List collections with pagination and filtering options. Args: ids (Optional[list[str | UUID]]): Filter collections by ids @@ -74,8 +72,7 @@ class CollectionsSDK: self, id: str | UUID, ) -> WrappedCollectionResponse: - """ - Get detailed information about a specific collection. 
+ """Get detailed information about a specific collection. Args: id (str | UUID): Collection ID to retrieve @@ -96,8 +93,7 @@ class CollectionsSDK: description: Optional[str] = None, generate_description: Optional[bool] = False, ) -> WrappedCollectionResponse: - """ - Update collection information. + """Update collection information. Args: id (str | UUID): Collection ID to update @@ -129,8 +125,7 @@ class CollectionsSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a collection. + """Delete a collection. Args: id (str | UUID): Collection ID to delete @@ -150,8 +145,7 @@ class CollectionsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedDocumentsResponse: - """ - List all documents in a collection. + """List all documents in a collection. Args: id (str | UUID): Collection ID @@ -180,8 +174,7 @@ class CollectionsSDK: id: str | UUID, document_id: str | UUID, ) -> WrappedGenericMessageResponse: - """ - Add a document to a collection. + """Add a document to a collection. Args: id (str | UUID): Collection ID @@ -203,8 +196,7 @@ class CollectionsSDK: id: str | UUID, document_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a document from a collection. + """Remove a document from a collection. Args: id (str | UUID): Collection ID @@ -227,8 +219,7 @@ class CollectionsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedUsersResponse: - """ - List all users in a collection. + """List all users in a collection. Args: id (str, UUID): Collection ID @@ -254,8 +245,7 @@ class CollectionsSDK: id: str | UUID, user_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Add a user to a collection. + """Add a user to a collection. Args: id (str | UUID): Collection ID @@ -275,8 +265,7 @@ class CollectionsSDK: id: str | UUID, user_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a user from a collection. + """Remove a user from a collection. 
Args: id (str | UUID): Collection ID @@ -299,8 +288,7 @@ class CollectionsSDK: settings: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Extract entities and relationships from documents in a collection. + """Extract entities and relationships from documents in a collection. Args: id (str | UUID): Collection ID to extract from @@ -330,8 +318,7 @@ class CollectionsSDK: def retrieve_by_name( self, name: str, owner_id: Optional[str] = None ) -> WrappedCollectionResponse: - """ - Retrieve a collection by its name. + """Retrieve a collection by its name. For non-superusers, the backend will use the authenticated user's ID. For superusers, the caller must supply an owner_id to restrict the search. diff --git a/py/sdk/sync_methods/conversations.py b/py/sdk/sync_methods/conversations.py index c8f22f287..d3da6bb3c 100644 --- a/py/sdk/sync_methods/conversations.py +++ b/py/sdk/sync_methods/conversations.py @@ -20,8 +20,7 @@ class ConversationsSDK: self, name: Optional[str] = None, ) -> WrappedConversationResponse: - """ - Create a new conversation. + """Create a new conversation. Returns: WrappedConversationResponse @@ -45,8 +44,7 @@ class ConversationsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedConversationsResponse: - """ - List conversations with pagination and sorting options. + """List conversations with pagination and sorting options. Args: ids (Optional[list[str | UUID]]): List of conversation IDs to retrieve @@ -76,8 +74,7 @@ class ConversationsSDK: self, id: str | UUID, ) -> WrappedConversationMessagesResponse: - """ - Get detailed information about a specific conversation. + """Get detailed information about a specific conversation. Args: id (str | UUID): The ID of the conversation to retrieve @@ -98,8 +95,7 @@ class ConversationsSDK: id: str | UUID, name: str, ) -> WrappedConversationResponse: - """ - Update an existing conversation. + """Update an existing conversation. 
Args: id (str | UUID): The ID of the conversation to update @@ -125,8 +121,7 @@ class ConversationsSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a conversation. + """Delete a conversation. Args: id (str | UUID): The ID of the conversation to delete @@ -150,8 +145,7 @@ class ConversationsSDK: metadata: Optional[dict] = None, parent_id: Optional[str] = None, ) -> WrappedMessageResponse: - """ - Add a new message to a conversation. + """Add a new message to a conversation. Args: id (str | UUID): The ID of the conversation to add the message to @@ -188,8 +182,7 @@ class ConversationsSDK: content: Optional[str] = None, metadata: Optional[dict] = None, ) -> WrappedMessageResponse: - """ - Update an existing message in a conversation. + """Update an existing message in a conversation. Args: id (str | UUID): The ID of the conversation containing the message @@ -219,8 +212,8 @@ class ConversationsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export conversations to a CSV file, streaming the results directly to disk. + """Export conversations to a CSV file, streaming the results directly + to disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -270,8 +263,8 @@ class ConversationsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export messages to a CSV file, streaming the results directly to disk. + """Export messages to a CSV file, streaming the results directly to + disk. Args: output_path (str | Path): Local path where the CSV file should be saved diff --git a/py/sdk/sync_methods/documents.py b/py/sdk/sync_methods/documents.py index e66a240e6..3948d6f4c 100644 --- a/py/sdk/sync_methods/documents.py +++ b/py/sdk/sync_methods/documents.py @@ -28,9 +28,7 @@ from ..models import IngestionMode, SearchMode, SearchSettings class DocumentsSDK: - """ - SDK for interacting with documents in the v3 API. 
- """ + """SDK for interacting with documents in the v3 API.""" def __init__(self, client): self.client = client @@ -47,8 +45,7 @@ class DocumentsSDK: ingestion_config: Optional[dict | IngestionMode] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedIngestionResponse: - """ - Create a new document from either a file or content. + """Create a new document from either a file or content. Args: file_path (Optional[str]): The file to upload, if any @@ -145,8 +142,7 @@ class DocumentsSDK: self, id: str | UUID, ) -> WrappedDocumentResponse: - """ - Get a specific document by ID. + """Get a specific document by ID. Args: id (str | UUID): ID of document to retrieve @@ -182,9 +178,7 @@ class DocumentsSDK: end_date: Optional[datetime] = None, output_path: Optional[str | Path] = None, ) -> BytesIO | None: - """ - Download multiple documents as a zip file. - """ + """Download multiple documents as a zip file.""" params: dict[str, Any] = {} if document_ids: params["document_ids"] = [str(doc_id) for doc_id in document_ids] @@ -222,8 +216,8 @@ class DocumentsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export documents to a CSV file, streaming the results directly to disk. + """Export documents to a CSV file, streaming the results directly to + disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -273,8 +267,8 @@ class DocumentsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export documents to a CSV file, streaming the results directly to disk. + """Export documents to a CSV file, streaming the results directly to + disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -325,8 +319,8 @@ class DocumentsSDK: filters: Optional[dict] = None, include_header: bool = True, ) -> None: - """ - Export document relationships to a CSV file, streaming the results directly to disk. 
+ """Export document relationships to a CSV file, streaming the results + directly to disk. Args: output_path (str | Path): Local path where the CSV file should be saved @@ -373,8 +367,7 @@ class DocumentsSDK: self, id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a specific document. + """Delete a specific document. Args: id (str | UUID): ID of document to delete @@ -397,8 +390,7 @@ class DocumentsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedChunksResponse: - """ - Get chunks for a specific document. + """Get chunks for a specific document. Args: id (str | UUID): ID of document to retrieve chunks for @@ -430,8 +422,7 @@ class DocumentsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCollectionsResponse: - """ - List collections for a specific document. + """List collections for a specific document. Args: id (str | UUID): ID of document to retrieve collections for @@ -459,8 +450,7 @@ class DocumentsSDK: self, filters: dict, ) -> WrappedBooleanResponse: - """ - Delete documents based on filters. + """Delete documents based on filters. Args: filters (dict): Filters to apply when selecting documents to delete @@ -484,8 +474,7 @@ class DocumentsSDK: settings: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Extract entities and relationships from a document. + """Extract entities and relationships from a document. Args: id (str, UUID): ID of document to extract from @@ -516,8 +505,7 @@ class DocumentsSDK: limit: Optional[int] = 100, include_embeddings: Optional[bool] = False, ) -> WrappedEntitiesResponse: - """ - List entities extracted from a document. + """List entities extracted from a document. 
Args: id (str | UUID): ID of document to get entities from @@ -550,8 +538,7 @@ class DocumentsSDK: entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, ) -> WrappedRelationshipsResponse: - """ - List relationships extracted from a document. + """List relationships extracted from a document. Args: id (str | UUID): ID of document to get relationships from @@ -587,8 +574,7 @@ class DocumentsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedDocumentsResponse: - """ - List documents with pagination. + """List documents with pagination. Args: ids (Optional[list[str | UUID]]): Optional list of document IDs to filter by @@ -620,8 +606,7 @@ class DocumentsSDK: search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedDocumentSearchResponse: - """ - Conduct a vector and/or graph search. + """Conduct a vector and/or graph search. Args: query (str): The query to search for. @@ -654,8 +639,7 @@ class DocumentsSDK: settings: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Deduplicate entities and relationships from a document. + """Deduplicate entities and relationships from a document. Args: id (str, UUID): ID of document to extract from @@ -681,8 +665,7 @@ class DocumentsSDK: return WrappedGenericMessageResponse(**response_dict) def create_sample(self, hi_res: bool = False) -> WrappedIngestionResponse: - """ - Ingest a sample document into R2R. + """Ingest a sample document into R2R. This method downloads a sample file from a predefined URL, saves it as a temporary file, and ingests it using the `create` method. 
The diff --git a/py/sdk/sync_methods/graphs.py b/py/sdk/sync_methods/graphs.py index 09fa76257..d6903209b 100644 --- a/py/sdk/sync_methods/graphs.py +++ b/py/sdk/sync_methods/graphs.py @@ -17,9 +17,7 @@ from shared.api.models import ( class GraphsSDK: - """ - SDK for interacting with knowledge graphs in the v3 API. - """ + """SDK for interacting with knowledge graphs in the v3 API.""" def __init__(self, client): self.client = client @@ -30,8 +28,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedGraphsResponse: - """ - List graphs with pagination and filtering options. + """List graphs with pagination and filtering options. Args: ids (Optional[list[str | UUID]]): Filter graphs by ids @@ -58,8 +55,7 @@ class GraphsSDK: self, collection_id: str | UUID, ) -> WrappedGraphResponse: - """ - Get detailed information about a specific graph. + """Get detailed information about a specific graph. Args: collection_id (str | UUID): Graph ID to retrieve @@ -77,8 +73,7 @@ class GraphsSDK: self, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Deletes a graph and all its associated data. + """Deletes a graph and all its associated data. This endpoint permanently removes the specified graph along with all entities and relationships that belong to only this graph. @@ -103,8 +98,7 @@ class GraphsSDK: name: Optional[str] = None, description: Optional[str] = None, ) -> WrappedGraphResponse: - """ - Update graph information. + """Update graph information. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -135,8 +129,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedEntitiesResponse: - """ - List entities in a graph. + """List entities in a graph. Args: collection_id (str | UUID): Graph ID to list entities from @@ -165,8 +158,7 @@ class GraphsSDK: collection_id: str | UUID, entity_id: str | UUID, ) -> WrappedEntityResponse: - """ - Get entity information in a graph. 
+ """Get entity information in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -188,8 +180,7 @@ class GraphsSDK: collection_id: str | UUID, entity_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove an entity from a graph. + """Remove an entity from a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -212,8 +203,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedRelationshipsResponse: - """ - List relationships in a graph. + """List relationships in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -242,8 +232,7 @@ class GraphsSDK: collection_id: str | UUID, relationship_id: str | UUID, ) -> WrappedRelationshipResponse: - """ - Get relationship information in a graph. + """Get relationship information in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -265,8 +254,7 @@ class GraphsSDK: collection_id: str | UUID, relationship_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a relationship from a graph. + """Remove a relationship from a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -289,8 +277,7 @@ class GraphsSDK: settings: Optional[dict] = None, run_with_orchestration: bool = True, ) -> WrappedGenericMessageResponse: - """ - Build a graph. + """Build a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -320,8 +307,7 @@ class GraphsSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCommunitiesResponse: - """ - List communities in a graph. + """List communities in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -350,8 +336,7 @@ class GraphsSDK: collection_id: str | UUID, community_id: str | UUID, ) -> WrappedCommunityResponse: - """ - Get community information in a graph. 
+ """Get community information in a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -380,8 +365,7 @@ class GraphsSDK: level: Optional[int] = None, attributes: Optional[dict] = None, ) -> WrappedCommunityResponse: - """ - Update community information. + """Update community information. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -427,8 +411,7 @@ class GraphsSDK: collection_id: str | UUID, community_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a community from a graph. + """Remove a community from a graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -449,8 +432,8 @@ class GraphsSDK: self, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Adds documents to a graph by copying their entities and relationships. + """Adds documents to a graph by copying their entities and + relationships. This endpoint: 1. Copies document entities to the graphs_entities table @@ -483,8 +466,7 @@ class GraphsSDK: collection_id: str | UUID, document_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Removes a document from a graph and removes any associated entities + """Removes a document from a graph and removes any associated entities. This endpoint: 1. Removes the document ID from the graph's document_ids array @@ -511,8 +493,7 @@ class GraphsSDK: category: Optional[str] = None, metadata: Optional[dict] = None, ) -> WrappedEntityResponse: - """ - Creates a new entity in the graph. + """Creates a new entity in the graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -554,8 +535,7 @@ class GraphsSDK: weight: Optional[float] = None, metadata: Optional[dict] = None, ) -> WrappedRelationshipResponse: - """ - Creates a new relationship in the graph. + """Creates a new relationship in the graph. 
Args: collection_id (str | UUID): The collection ID corresponding to the graph @@ -602,8 +582,7 @@ class GraphsSDK: rating: Optional[float] = None, rating_explanation: Optional[str] = None, ) -> WrappedCommunityResponse: - """ - Creates a new community in the graph. + """Creates a new community in the graph. Args: collection_id (str | UUID): The collection ID corresponding to the graph diff --git a/py/sdk/sync_methods/indices.py b/py/sdk/sync_methods/indices.py index 88f01fe91..1db9afc43 100644 --- a/py/sdk/sync_methods/indices.py +++ b/py/sdk/sync_methods/indices.py @@ -17,8 +17,7 @@ class IndicesSDK: config: dict, run_with_orchestration: Optional[bool] = True, ) -> WrappedGenericMessageResponse: - """ - Create a new vector similarity search index in the database. + """Create a new vector similarity search index in the database. Args: config (dict | IndexConfig): Configuration for the vector index. @@ -49,8 +48,8 @@ class IndicesSDK: offset: Optional[int] = 0, limit: Optional[int] = 10, ) -> WrappedVectorIndicesResponse: - """ - List existing vector similarity search indices with pagination support. + """List existing vector similarity search indices with pagination + support. Args: filters (Optional[dict]): Filter criteria for indices. @@ -80,8 +79,7 @@ class IndicesSDK: index_name: str, table_name: str = "vectors", ) -> WrappedVectorIndexResponse: - """ - Get detailed information about a specific vector index. + """Get detailed information about a specific vector index. Args: index_name (str): The name of the index to retrieve. @@ -103,8 +101,7 @@ class IndicesSDK: index_name: str, table_name: str = "vectors", ) -> WrappedGenericMessageResponse: - """ - Delete an existing vector index. + """Delete an existing vector index. Args: index_name (str): The name of the index to retrieve. 
diff --git a/py/sdk/sync_methods/prompts.py b/py/sdk/sync_methods/prompts.py index f0e30f3d8..fc123c07e 100644 --- a/py/sdk/sync_methods/prompts.py +++ b/py/sdk/sync_methods/prompts.py @@ -16,8 +16,8 @@ class PromptsSDK: def create( self, name: str, template: str, input_types: dict ) -> WrappedGenericMessageResponse: - """ - Create a new prompt. + """Create a new prompt. + Args: name (str): The name of the prompt template (str): The template string for the prompt @@ -40,8 +40,8 @@ class PromptsSDK: return WrappedGenericMessageResponse(**response_dict) def list(self) -> WrappedPromptsResponse: - """ - List all available prompts. + """List all available prompts. + Returns: dict: List of all available prompts """ @@ -59,8 +59,8 @@ class PromptsSDK: inputs: Optional[dict] = None, prompt_override: Optional[str] = None, ) -> WrappedPromptResponse: - """ - Get a specific prompt by name, optionally with inputs and override. + """Get a specific prompt by name, optionally with inputs and override. + Args: name (str): The name of the prompt to retrieve inputs (Optional[dict]): JSON-encoded inputs for the prompt @@ -88,8 +88,8 @@ class PromptsSDK: template: Optional[str] = None, input_types: Optional[dict] = None, ) -> WrappedGenericMessageResponse: - """ - Update an existing prompt's template and/or input types. + """Update an existing prompt's template and/or input types. + Args: name (str): The name of the prompt to update template (Optional[str]): The updated template string for the prompt @@ -112,8 +112,8 @@ class PromptsSDK: return WrappedGenericMessageResponse(**response_dict) def delete(self, name: str) -> WrappedBooleanResponse: - """ - Delete a prompt by name. + """Delete a prompt by name. 
+ Args: name (str): The name of the prompt to delete Returns: diff --git a/py/sdk/sync_methods/retrieval.py b/py/sdk/sync_methods/retrieval.py index bb2974a41..7cf8f892c 100644 --- a/py/sdk/sync_methods/retrieval.py +++ b/py/sdk/sync_methods/retrieval.py @@ -19,9 +19,7 @@ from ..models import ( class RetrievalSDK: - """ - SDK for interacting with documents in the v3 API. - """ + """SDK for interacting with documents in the v3 API.""" def __init__(self, client): self.client = client @@ -32,8 +30,7 @@ class RetrievalSDK: search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> WrappedSearchResponse: - """ - Conduct a vector and/or graph search. + """Conduct a vector and/or graph search. Args: query (str): The query to search for. @@ -116,8 +113,8 @@ class RetrievalSDK: task_prompt_override: Optional[str] = None, include_title_if_available: Optional[bool] = False, ) -> WrappedRAGResponse | AsyncGenerator[RAGResponse, None]: - """ - Conducts a Retrieval Augmented Generation (RAG) search with the given query. + """Conducts a Retrieval Augmented Generation (RAG) search with the + given query. Args: query (str): The query to search for. @@ -178,8 +175,7 @@ class RetrievalSDK: max_tool_context_length: Optional[int] = None, use_extended_prompt: Optional[bool] = True, ) -> WrappedAgentResponse | AsyncGenerator[Message, None]: - """ - Performs a single turn in a conversation with a RAG agent. + """Performs a single turn in a conversation with a RAG agent. Args: message (Optional[dict | Message]): The message to send to the agent. @@ -244,8 +240,7 @@ class RetrievalSDK: tools: Optional[list[dict]] = None, max_tool_context_length: Optional[int] = None, ) -> WrappedAgentResponse | AsyncGenerator[Message, None]: - """ - Performs a single turn in a conversation with a RAG agent. + """Performs a single turn in a conversation with a RAG agent. Args: message (Optional[dict | Message]): The message to send to the agent. 
diff --git a/py/sdk/sync_methods/system.py b/py/sdk/sync_methods/system.py index 61d4e554c..9a26e5413 100644 --- a/py/sdk/sync_methods/system.py +++ b/py/sdk/sync_methods/system.py @@ -13,9 +13,7 @@ class SystemSDK: self.client = client def health(self) -> WrappedGenericMessageResponse: - """ - Check the health of the R2R server. - """ + """Check the health of the R2R server.""" response_dict = self.client._make_request( "GET", "health", version="v3" ) @@ -28,8 +26,7 @@ class SystemSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedLogsResponse: - """ - Get logs from the server. + """Get logs from the server. Args: run_type_filter (Optional[str]): The run type to filter by. @@ -55,8 +52,7 @@ class SystemSDK: return WrappedLogsResponse(**response_dict) def settings(self) -> WrappedSettingsResponse: - """ - Get the configuration settings for the R2R server. + """Get the configuration settings for the R2R server. Returns: dict: The server settings. @@ -68,8 +64,8 @@ class SystemSDK: return WrappedSettingsResponse(**response_dict) def status(self) -> WrappedServerStatsResponse: - """ - Get statistics about the server, including the start time, uptime, CPU usage, and memory usage. + """Get statistics about the server, including the start time, uptime, + CPU usage, and memory usage. Returns: dict: The server statistics. diff --git a/py/sdk/sync_methods/users.py b/py/sdk/sync_methods/users.py index 0dac687ee..3021642fe 100644 --- a/py/sdk/sync_methods/users.py +++ b/py/sdk/sync_methods/users.py @@ -27,8 +27,7 @@ class UsersSDK: bio: Optional[str] = None, profile_picture: Optional[str] = None, ) -> WrappedUserResponse: - """ - Register a new user. + """Register a new user. Args: email (str): User's email address @@ -62,9 +61,7 @@ class UsersSDK: def send_verification_email( self, email: str ) -> WrappedGenericMessageResponse: - """ - Request that a verification email to a user. 
- """ + """Request that a verification email to a user.""" response_dict = self.client._make_request( "POST", "users/send-verification-email", @@ -75,9 +72,8 @@ class UsersSDK: return WrappedGenericMessageResponse(**response_dict) def delete(self, id: str | UUID, password: str) -> WrappedBooleanResponse: - """ - Delete a specific user. - Users can only delete their own account unless they are superusers. + """Delete a specific user. Users can only delete their own account + unless they are superusers. Args: id (str | UUID): User ID to delete @@ -101,8 +97,7 @@ class UsersSDK: def verify_email( self, email: str, verification_code: str ) -> WrappedGenericMessageResponse: - """ - Verify a user's email address. + """Verify a user's email address. Args: email (str): User's email address @@ -125,8 +120,7 @@ class UsersSDK: return WrappedGenericMessageResponse(**response_dict) def login(self, email: str, password: str) -> WrappedLoginResponse: - """ - Log in a user. + """Log in a user. Args: email (str): User's email address @@ -201,8 +195,7 @@ class UsersSDK: def change_password( self, current_password: str, new_password: str ) -> WrappedGenericMessageResponse: - """ - Change the user's password. + """Change the user's password. Args: current_password (str): User's current password @@ -227,8 +220,7 @@ class UsersSDK: def request_password_reset( self, email: str ) -> WrappedGenericMessageResponse: - """ - Request a password reset. + """Request a password reset. Args: email (str): User's email address @@ -248,8 +240,7 @@ class UsersSDK: def reset_password( self, reset_token: str, new_password: str ) -> WrappedGenericMessageResponse: - """ - Reset password using a reset token. + """Reset password using a reset token. Args: reset_token (str): Password reset token @@ -277,8 +268,7 @@ class UsersSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedUsersResponse: - """ - List users with pagination and filtering options. 
+ """List users with pagination and filtering options. Args: offset (int, optional): Specifies the number of objects to skip. Defaults to 0. @@ -307,8 +297,7 @@ class UsersSDK: self, id: str | UUID, ) -> WrappedUserResponse: - """ - Get a specific user. + """Get a specific user. Args: id (str | UUID): User ID to retrieve @@ -327,8 +316,7 @@ class UsersSDK: def me( self, ) -> WrappedUserResponse: - """ - Get detailed information about the currently authenticated user. + """Get detailed information about the currently authenticated user. Returns: dict: Detailed user information @@ -352,8 +340,7 @@ class UsersSDK: limits_overrides: dict | None = None, metadata: dict[str, str | None] | None = None, ) -> WrappedUserResponse: - """ - Update user information. + """Update user information. Args: id (str | UUID): User ID to update @@ -397,8 +384,7 @@ class UsersSDK: offset: Optional[int] = 0, limit: Optional[int] = 100, ) -> WrappedCollectionsResponse: - """ - Get all collections associated with a specific user. + """Get all collections associated with a specific user. Args: id (str | UUID): User ID to get collections for @@ -427,8 +413,7 @@ class UsersSDK: id: str | UUID, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Add a user to a collection. + """Add a user to a collection. Args: id (str | UUID): User ID to add @@ -447,8 +432,7 @@ class UsersSDK: id: str | UUID, collection_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Remove a user from a collection. + """Remove a user from a collection. Args: id (str | UUID): User ID to remove @@ -471,8 +455,7 @@ class UsersSDK: name: Optional[str] = None, description: Optional[str] = None, ) -> WrappedAPIKeyResponse: - """ - Create a new API key for the specified user. + """Create a new API key for the specified user. Args: id (str | UUID): User ID to create API key for @@ -501,8 +484,7 @@ class UsersSDK: self, id: str | UUID, ) -> WrappedAPIKeysResponse: - """ - List all API keys for the specified user. 
+ """List all API keys for the specified user. Args: id (str | UUID): User ID to list API keys for @@ -523,8 +505,7 @@ class UsersSDK: id: str | UUID, key_id: str | UUID, ) -> WrappedBooleanResponse: - """ - Delete a specific API key for the specified user. + """Delete a specific API key for the specified user. Args: id (str | UUID): User ID @@ -551,8 +532,8 @@ class UsersSDK: return WrappedLimitsResponse(**response_dict) def oauth_google_authorize(self) -> WrappedGenericMessageResponse: - """ - Get Google OAuth 2.0 authorization URL from the server. + """Get Google OAuth 2.0 authorization URL from the server. + Returns: WrappedGenericMessageResponse """ @@ -565,8 +546,8 @@ class UsersSDK: return WrappedGenericMessageResponse(**response_dict) def oauth_github_authorize(self) -> WrappedGenericMessageResponse: - """ - Get GitHub OAuth 2.0 authorization URL from the server. + """Get GitHub OAuth 2.0 authorization URL from the server. + Returns: {"redirect_url": "..."} """ response_dict = self.client._make_request( @@ -580,9 +561,8 @@ class UsersSDK: def oauth_google_callback( self, code: str, state: str ) -> WrappedLoginResponse: - """ - Exchange `code` and `state` with the Google OAuth 2.0 callback route. - """ + """Exchange `code` and `state` with the Google OAuth 2.0 callback + route.""" response_dict = self.client._make_request( "GET", "users/oauth/google/callback", @@ -595,9 +575,8 @@ class UsersSDK: def oauth_github_callback( self, code: str, state: str ) -> WrappedLoginResponse: - """ - Exchange `code` and `state` with the GitHub OAuth 2.0 callback route. 
- """ + """Exchange `code` and `state` with the GitHub OAuth 2.0 callback + route.""" response_dict = self.client._make_request( "GET", "users/oauth/github/callback", diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index adf16e595..639d2d3dd 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -119,6 +119,7 @@ __all__ = [ "GraphCommunityResult", "GraphSearchSettings", "ChunkSearchSettings", + "ContextDocumentResult", "ChunkSearchResult", "SearchSettings", "select_search_filters", diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 464f65e08..07d7f6c0e 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -10,6 +10,7 @@ from uuid import UUID, uuid4 from pydantic import Field from .base import R2RSerializable +from .llm import GenerationConfig logger = logging.getLogger() @@ -194,7 +195,8 @@ class DocumentResponse(R2RSerializable): total_tokens: Optional[int] = None def convert_to_db_entry(self): - """Prepare the document info for database entry, extracting certain fields from metadata.""" + """Prepare the document info for database entry, extracting certain + fields from metadata.""" now = datetime.now() # Format the embedding properly for Postgres vector type @@ -283,13 +285,8 @@ class IngestionMode(str, Enum): custom = "custom" -from .llm import GenerationConfig - - class ChunkEnrichmentSettings(R2RSerializable): - """ - Settings for chunk enrichment. 
- """ + """Settings for chunk enrichment.""" enable_chunk_enrichment: bool = Field( default=False, diff --git a/py/shared/abstractions/exception.py b/py/shared/abstractions/exception.py index 67eaffcab..3dedfae8f 100644 --- a/py/shared/abstractions/exception.py +++ b/py/shared/abstractions/exception.py @@ -37,7 +37,7 @@ class R2RDocumentProcessingError(R2RException): class PDFParsingError(R2RException): - """Custom exception for PDF parsing errors""" + """Custom exception for PDF parsing errors.""" def __init__( self, @@ -55,8 +55,7 @@ class PopplerNotFoundError(PDFParsingError): """Specific error for when Poppler is not installed.""" def __init__(self): - installation_instructions = textwrap.dedent( - """ + installation_instructions = textwrap.dedent(""" PDF processing requires Poppler to be installed. Please install Poppler and ensure it's in your system PATH. Installing poppler: @@ -68,8 +67,7 @@ class PopplerNotFoundError(PDFParsingError): 2. Move extracted directory to desired location 3. Add bin/ directory to PATH 4. Test by running 'pdftoppm -h' in terminal - """ - ) + """) super().__init__( message=installation_instructions, status_code=422, diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index e9bc361ef..0b8fd2364 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -37,7 +37,11 @@ class Entity(R2RSerializable): class Relationship(R2RSerializable): - """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" + """A relationship between two entities. + + This is a generic relationship, and can be used to represent any type of + relationship between any two entities. 
+ """ id: Optional[UUID] = None subject: str @@ -152,13 +156,11 @@ class GraphCreationSettings(R2RSerializable): graph_extraction_prompt: str = Field( default="graph_extraction", description="The prompt to use for knowledge graph extraction.", - alias="graph_extraction", # TODO - mark deprecated & remove ) graph_entity_description_prompt: str = Field( default="graph_entity_description", description="The prompt to use for entity description generation.", - alias="graph_entity_description_prompt", # TODO - mark deprecated & remove ) entity_types: list[str] = Field( @@ -173,17 +175,20 @@ class GraphCreationSettings(R2RSerializable): chunk_merge_count: int = Field( default=2, - description="The number of extractions to merge into a single graph extraction.", + description="""The number of extractions to merge into a single graph + extraction.""", ) max_knowledge_relationships: int = Field( default=100, - description="The maximum number of knowledge relationships to extract from each chunk.", + description="""The maximum number of knowledge relationships to extract + from each chunk.""", ) max_description_input_length: int = Field( default=65536, - description="The maximum length of the description for a node in the graph.", + description="""The maximum length of the description for a node in the + graph.""", ) generation_config: Optional[GenerationConfig] = Field( @@ -202,13 +207,13 @@ class GraphEnrichmentSettings(R2RSerializable): force_graph_search_results_enrichment: bool = Field( default=False, - description="Force run the enrichment step even if graph creation is still in progress for some documents.", + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", ) graph_communities_prompt: str = Field( default="graph_communities", description="The prompt to use for knowledge graph enrichment.", - alias="graph_communities", # TODO - mark deprecated & remove ) max_summary_input_length: int = Field( @@ -232,7 
+237,8 @@ class GraphCommunitySettings(R2RSerializable): force_graph_search_results_enrichment: bool = Field( default=False, - description="Force run the enrichment step even if graph creation is still in progress for some documents.", + description="""Force run the enrichment step even if graph creation is + still in progress for some documents.""", ) graph_communities: str = Field( diff --git a/py/shared/abstractions/llm.py b/py/shared/abstractions/llm.py index 72897b292..fb84c98c6 100644 --- a/py/shared/abstractions/llm.py +++ b/py/shared/abstractions/llm.py @@ -12,7 +12,6 @@ from .base import R2RSerializable if TYPE_CHECKING: from .search import AggregateSearchResult - LLMChatCompletion = ChatCompletion LLMChatCompletionChunk = ChatCompletionChunk diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 2582b3291..977b82711 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -134,18 +134,6 @@ class GraphSearchResult(R2RSerializable): metadata: dict[str, Any] = {} score: Optional[float] = None - class Config: - json_schema_extra = { - "example": { - "content": GraphEntityResult.Config.json_schema_extra, - "result_type": "entity", - "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], - "metadata": { - "associated_query": "What is the capital of France?" - }, - } - } - def __str__(self) -> str: return f"GraphSearchResult(content={self.content}, result_type={self.result_type})" @@ -238,10 +226,8 @@ class WebSearchResponse(R2RSerializable): class ContextDocumentResult(R2RSerializable): - """ - Holds a single 'document' plus its 'chunks', exactly as your - content_method returns them, or tidied up a bit. 
- """ + """Holds a single 'document' plus its 'chunks', exactly as your + content_method returns them, or tidied up a bit.""" document: dict[str, Any] # or create a formal Document model chunks: list[str] = Field(default_factory=list) @@ -425,7 +411,8 @@ class GraphSearchSettings(R2RSerializable): class SearchSettings(R2RSerializable): - """Main search settings class that combines shared settings with specialized settings for chunks and graph.""" + """Main search settings class that combines shared settings with + specialized settings for chunks and graph.""" # Search type flags use_hybrid_search: bool = Field( @@ -473,17 +460,20 @@ class SearchSettings(R2RSerializable): ) include_scores: bool = Field( default=True, - description="Whether to include search score values in the search results", + description="""Whether to include search score values in the + search results""", ) # Search strategy and settings search_strategy: str = Field( default="vanilla", - description="Search strategy to use (e.g., 'vanilla', 'query_fusion', 'hyde')", + description="""Search strategy to use + (e.g., 'vanilla', 'query_fusion', 'hyde')""", ) hybrid_settings: HybridSearchSettings = Field( default_factory=HybridSearchSettings, - description="Settings for hybrid search (only used if `use_semantic_search` and `use_fulltext_search` are both true)", + description="""Settings for hybrid search (only used if + `use_semantic_search` and `use_fulltext_search` are both true)""", ) # Specialized settings diff --git a/py/shared/abstractions/vector.py b/py/shared/abstractions/vector.py index 7d86e4e1d..0b88a765a 100644 --- a/py/shared/abstractions/vector.py +++ b/py/shared/abstractions/vector.py @@ -14,8 +14,7 @@ class VectorType(str, Enum): class IndexMethod(str, Enum): - """ - An enum representing the index methods available. + """An enum representing the index methods available. This class currently only supports the 'ivfflat' method but may expand in the future. 
@@ -35,8 +34,8 @@ class IndexMethod(str, Enum): class IndexMeasure(str, Enum): - """ - An enum representing the types of distance measures available for indexing. + """An enum representing the types of distance measures available for + indexing. Attributes: cosine_distance (str): The cosine distance measure for indexing. @@ -78,9 +77,8 @@ class IndexMeasure(str, Enum): class IndexArgsIVFFlat(R2RSerializable): - """ - A class for arguments that can optionally be supplied to the index creation - method when building an IVFFlat type index. + """A class for arguments that can optionally be supplied to the index + creation method when building an IVFFlat type index. Attributes: nlist (int): The number of IVF centroids that the index should use @@ -90,9 +88,8 @@ class IndexArgsIVFFlat(R2RSerializable): class IndexArgsHNSW(R2RSerializable): - """ - A class for arguments that can optionally be supplied to the index creation - method when building an HNSW type index. + """A class for arguments that can optionally be supplied to the index + creation method when building an HNSW type index. Ref: https://github.com/pgvector/pgvector#index-options @@ -110,9 +107,7 @@ class IndexArgsHNSW(R2RSerializable): class VectorTableName(str, Enum): - """ - This enum represents the different tables where we store vectors. - """ + """This enum represents the different tables where we store vectors.""" CHUNKS = "chunks" ENTITIES_DOCUMENT = "documents_entities" @@ -126,8 +121,7 @@ class VectorTableName(str, Enum): class VectorQuantizationType(str, Enum): - """ - An enum representing the types of quantization available for vectors. + """An enum representing the types of quantization available for vectors. Attributes: FP32 (str): 32-bit floating point quantization. 
@@ -186,7 +180,8 @@ class Vector(R2RSerializable): class VectorEntry(R2RSerializable): - """A vector entry that can be stored directly in supported vector databases.""" + """A vector entry that can be stored directly in supported vector + databases.""" id: UUID document_id: UUID diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index 25066f612..be258a704 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -170,7 +170,6 @@ WrappedChunksResponse = PaginatedR2RResult[list[ChunkResponse]] WrappedCollectionResponse = R2RResults[CollectionResponse] WrappedCollectionsResponse = PaginatedR2RResult[list[CollectionResponse]] - # Conversation Responses WrappedConversationMessagesResponse = R2RResults[list[MessageResponse]] WrappedConversationResponse = R2RResults[ConversationResponse] diff --git a/py/shared/api/models/retrieval/responses.py b/py/shared/api/models/retrieval/responses.py index 778056f5e..b74d18789 100644 --- a/py/shared/api/models/retrieval/responses.py +++ b/py/shared/api/models/retrieval/responses.py @@ -8,7 +8,6 @@ from shared.abstractions import ( LLMChatCompletion, Message, ) -from shared.abstractions.llm import LLMChatCompletion from shared.api.models.base import R2RResults from shared.api.models.management.responses import DocumentResponse @@ -16,10 +15,10 @@ from ....abstractions import R2RSerializable class Citation(R2RSerializable): - """ - Represents a single citation reference in the RAG response. - Combines both bracket metadata (start/end offsets, snippet range) - and the mapped source fields (id, doc ID, chunk text, etc.). + """Represents a single citation reference in the RAG response. + + Combines both bracket metadata (start/end offsets, snippet range) and the + mapped source fields (id, doc ID, chunk text, etc.). 
""" # Bracket references @@ -242,7 +241,40 @@ class AgentResponse(R2RSerializable): "messages": [ { "role": "assistant", - "content": "Aristotle (384–322 BC) was an Ancient Greek philosopher and polymath whose contributions have had a profound impact on various fields of knowledge. Here are some key points about his life and work:\n\n1. **Early Life**: Aristotle was born in 384 BC in Stagira, Chalcidice, which is near modern-day Thessaloniki, Greece. His father, Nicomachus, was the personal physician to King Amyntas of Macedon, which exposed Aristotle to medical and biological knowledge from a young age [C].\n\n2. **Education and Career**: After the death of his parents, Aristotle was sent to Athens to study at Plato's Academy, where he remained for about 20 years. After Plato's death, Aristotle left Athens and eventually became the tutor of Alexander the Great [C].\n\n3. **Philosophical Contributions**: Aristotle founded the Lyceum in Athens, where he established the Peripatetic school of philosophy. His works cover a wide range of subjects, including metaphysics, ethics, politics, logic, biology, and aesthetics. His writings laid the groundwork for many modern scientific and philosophical inquiries [A].\n\n4. **Legacy**: Aristotle's influence extends beyond philosophy to the natural sciences, linguistics, economics, and psychology. His method of systematic observation and analysis has been foundational to the development of modern science [A].\n\nAristotle's comprehensive approach to knowledge and his systematic methodology have earned him a lasting legacy as one of the greatest philosophers of all time.\n\nSources:\n- [A] Aristotle's broad range of writings and influence on modern science.\n- [C] Details about Aristotle's early life and education.", + "content": """Aristotle (384–322 BC) was an Ancient + Greek philosopher and polymath whose contributions + have had a profound impact on various fields of + knowledge. 
+ Here are some key points about his life and work: + \n\n1. **Early Life**: Aristotle was born in 384 BC in + Stagira, Chalcidice, which is near modern-day + Thessaloniki, Greece. His father, Nicomachus, was the + personal physician to King Amyntas of Macedon, which + exposed Aristotle to medical and biological knowledge + from a young age [C].\n\n2. **Education and Career**: + After the death of his parents, Aristotle was sent to + Athens to study at Plato's Academy, where he remained + for about 20 years. After Plato's death, Aristotle + left Athens and eventually became the tutor of + Alexander the Great [C]. + \n\n3. **Philosophical Contributions**: Aristotle + founded the Lyceum in Athens, where he established the + Peripatetic school of philosophy. His works cover a + wide range of subjects, including metaphysics, ethics, + politics, logic, biology, and aesthetics. His writings + laid the groundwork for many modern scientific and + philosophical inquiries [A].\n\n4. **Legacy**: + Aristotle's influence extends beyond philosophy to the + natural sciences, linguistics, economics, and + psychology. 
His method of systematic observation and + analysis has been foundational to the development of + modern science [A].\n\nAristotle's comprehensive + approach to knowledge and his systematic methodology + have earned him a lasting legacy as one of the + greatest philosophers of all time.\n\nSources: + \n- [A] Aristotle's broad range of writings and + influence on modern science.\n- [C] Details about + Aristotle's early life and education.""", "name": None, "function_call": None, "tool_calls": None, @@ -257,13 +289,20 @@ class AgentResponse(R2RSerializable): "snippetEndIndex": 418, "sourceType": "chunk", "id": "e760bb76-1c6e-52eb-910d-0ce5b567011b", - "document_id": "e43864f5-a36f-548e-aacd-6f8d48b30c7f", - "owner_id": "2acb499e-8428-543b-bd85-0d9098718220", + "document_id": """ + e43864f5-a36f-548e-aacd-6f8d48b30c7f + """, + "owner_id": """ + 2acb499e-8428-543b-bd85-0d9098718220 + """, "collection_ids": [ "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09" ], "score": 0.64, - "text": "Document Title: DeepSeek_R1.pdf\n\nText: could achieve an accuracy of ...", + "text": """ + Document Title: DeepSeek_R1.pdf + \n\nText: could achieve an accuracy of ... + """, "metadata": { "title": "DeepSeek_R1.pdf", "license": "CC-BY-4.0", diff --git a/py/shared/utils/base_utils.py b/py/shared/utils/base_utils.py index 80ccc6ad7..4acc23521 100644 --- a/py/shared/utils/base_utils.py +++ b/py/shared/utils/base_utils.py @@ -31,10 +31,12 @@ def reorder_collector_to_match_final_brackets( collector: Any, # "SearchResultsCollector", final_citations: list["Citation"], ): - """ - Rebuilds collector._results_in_order so that bracket i => aggregator[i-1]. - Each citation's rawIndex indicates which aggregator item the LLM used originally. - We place that aggregator item in the new position for bracket 'index'. + """Rebuilds collector._results_in_order so that bracket i => + aggregator[i-1]. + + Each citation's rawIndex indicates which aggregator item the LLM used + originally. 
We place that aggregator item in the new position for bracket + 'index'. """ old_list = collector.get_all_results() # [(source_type, result_obj), ...] max_index = max((c.index for c in final_citations), default=0) @@ -61,9 +63,10 @@ def map_citations_to_collector( citations: list["Citation"], collector: Any, # "SearchResultsCollector" ) -> list["Citation"]: - """ - For each citation, use its 'rawIndex' to look up the aggregator item from the - collector. We then fill out the Citation’s sourceType, doc_id, text, metadata, etc. + """For each citation, use its 'rawIndex' to look up the aggregator item + from the collector. + + We then fill out the Citation’s sourceType, doc_id, text, metadata, etc. """ from ..api.models.retrieval.responses import Citation @@ -134,10 +137,12 @@ def map_citations_to_collector( def _expand_citation_span_to_sentence( full_text: str, start: int, end: int ) -> Tuple[int, int]: - """ - Return (sentence_start, sentence_end) for the sentence containing the bracket [n]. + """Return (sentence_start, sentence_end) for the sentence containing the + bracket [n]. + We define a sentence boundary as '.', '?', or '!', optionally followed by - spaces or a newline. This is a simple heuristic; you can refine it as needed. + spaces or a newline. This is a simple heuristic; you can refine it as + needed. """ sentence_enders = {".", "?", "!"} @@ -167,10 +172,11 @@ def _expand_citation_span_to_sentence( def extract_citations(text: str) -> list["Citation"]: - """ - Find bracket references like [3], [10], etc. Return a list of Citation objects - whose 'index' field is the number found in brackets, but we will later rename - that to 'rawIndex' to avoid confusion. + """Find bracket references like [3], [10], etc. + + Return a list of Citation objects whose 'index' field is the number found + in brackets, but we will later rename that to 'rawIndex' to avoid + confusion. 
""" from ..api.models.retrieval.responses import Citation @@ -203,8 +209,9 @@ def extract_citations(text: str) -> list["Citation"]: def reassign_citations_in_order( text: str, citations: list["Citation"] ) -> Tuple[str, list["Citation"]]: - """ - Sort citations by their start index, unify repeated bracket numbers, and relabel them + """Sort citations by their start index, unify repeated bracket numbers, and + relabel them. + in ascending order of first appearance. Return (new_text, new_citations). - new_citations[i].index = the new bracket number - new_citations[i].rawIndex = the original bracket number @@ -280,11 +287,11 @@ def format_search_results_for_llm( results: AggregateSearchResult, collector: Any, # SearchResultsCollector ) -> str: - """ - Instead of resetting 'source_counter' to 1, we: - - For each chunk / graph / web / contextDoc in `results`, - - Find the aggregator index from the collector, - - Print 'Source [X]:' with that aggregator index. + """Instead of resetting 'source_counter' to 1, we: + + - For each chunk / graph / web / contextDoc in `results`, + - Find the aggregator index from the collector, + - Print 'Source [X]:' with that aggregator index. """ lines = [] @@ -293,7 +300,7 @@ def format_search_results_for_llm( # in the same order. 
But let's do a "lookup aggregator index" approach: def get_aggregator_index_for_item(item): - for stype, obj, agg_index in collector.get_all_results(): + for _stype, obj, agg_index in collector.get_all_results(): if obj is item: return agg_index return None # not found, fallback @@ -415,16 +422,14 @@ def _generate_id_from_label(label) -> UUID: def generate_id(label: Optional[str] = None) -> UUID: - """ - Generates a unique run id - """ - return _generate_id_from_label(label if label != None else str(uuid4())) + """Generates a unique run id.""" + return _generate_id_from_label( + label if label is not None else str(uuid4()) + ) def generate_document_id(filename: str, user_id: UUID) -> UUID: - """ - Generates a unique document id from a given filename and user id - """ + """Generates a unique document id from a given filename and user id.""" safe_filename = filename.replace("/", "_") return _generate_id_from_label(f"{safe_filename}-{str(user_id)}") @@ -432,37 +437,28 @@ def generate_document_id(filename: str, user_id: UUID) -> UUID: def generate_extraction_id( document_id: UUID, iteration: int = 0, version: str = "0" ) -> UUID: - """ - Generates a unique extraction id from a given document id and iteration - """ + """Generates a unique extraction id from a given document id and + iteration.""" return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}") def generate_default_user_collection_id(user_id: UUID) -> UUID: - """ - Generates a unique collection id from a given user id - """ + """Generates a unique collection id from a given user id.""" return _generate_id_from_label(str(user_id)) def generate_user_id(email: str) -> UUID: - """ - Generates a unique user id from a given email - """ + """Generates a unique user id from a given email.""" return _generate_id_from_label(email) def generate_default_prompt_id(prompt_name: str) -> UUID: - """ - Generates a unique prompt id - """ + """Generates a unique prompt id.""" return 
_generate_id_from_label(prompt_name) def generate_entity_document_id() -> UUID: - """ - Generates a unique document id inserting entities into a graph - """ + """Generates a unique document id inserting entities into a graph.""" generation_time = datetime.now().isoformat() return _generate_id_from_label(f"entity-{generation_time}") @@ -484,9 +480,7 @@ def validate_uuid(uuid_str: str) -> UUID: def update_settings_from_dict(server_settings, settings_dict: dict): - """ - Updates a settings object with values from a dictionary. - """ + """Updates a settings object with values from a dictionary.""" settings = deepcopy(server_settings) for key, value in settings_dict.items(): if value is not None: @@ -512,12 +506,10 @@ def _decorate_vector_type( def _get_vector_column_str( dimension: int | float, quantization_type: VectorQuantizationType ) -> str: - """ - Returns a string representation of a vector column type. + """Returns a string representation of a vector column type. - Explicitly handles the case where the dimension is not a valid number - meant to support embedding models that do not allow for specifying - the dimension. + Explicitly handles the case where the dimension is not a valid number meant + to support embedding models that do not allow for specifying the dimension. """ if math.isnan(dimension) or dimension <= 0: vector_dim = "" # Allows for Postgres to handle any dimension diff --git a/py/shared/utils/splitter/text.py b/py/shared/utils/splitter/text.py index d8234d07e..92a7c81b4 100644 --- a/py/shared/utils/splitter/text.py +++ b/py/shared/utils/splitter/text.py @@ -2,7 +2,6 @@ # URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851 """**Text Splitters** are classes for splitting text. - **Class hierarchy:** .. code-block:: @@ -18,7 +17,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive .. 
code-block:: Document, Tokenizer, Language, LineType, HeaderType - """ # noqa: E501 from __future__ import annotations @@ -124,15 +122,13 @@ class Serializable(BaseModel, ABC): def lc_secrets(self) -> dict[str, str]: """A map of constructor argument names to secret ids. - For example, - {"openai_api_key": "OPENAI_API_KEY"} + For example, {"openai_api_key": "OPENAI_API_KEY"} """ return {} @property def lc_attributes(self) -> dict: - """ - List of attribute names that should be included in the serialized + """List of attribute names that should be included in the serialized kwargs. These attributes must be accepted by the constructor. @@ -143,8 +139,8 @@ class Serializable(BaseModel, ABC): def lc_id(cls) -> list[str]: """A unique identifier for this class for serialization purposes. - The unique identifier is a list of strings that describes the path - to the object. + The unique identifier is a list of strings that describes the path to + the object. """ return [*cls.get_lc_namespace(), cls.__name__] @@ -297,9 +293,8 @@ class SplitterDocument(Serializable): page_content: str """String text.""" metadata: dict = Field(default_factory=dict) - """Arbitrary metadata about the page content (e.g., source, relationships to other - documents, etc.). - """ + """Arbitrary metadata about the page content (e.g., source, relationships + to other documents, etc.).""" type: Literal["Document"] = "Document" def __init__(self, page_content: str, **kwargs: Any) -> None: @@ -350,7 +345,6 @@ class BaseDocumentTransformer(ABC): self, documents: Sequence[Document], **kwargs: Any ) -> Sequence[Document]: raise NotImplementedError - """ # noqa: E501 @abstractmethod @@ -390,8 +384,8 @@ def _make_spacy_pipe_for_splitting( import spacy except ImportError: raise ImportError( - "Spacy is not installed, please install it with `pip install spacy`." - ) + "Spacy is not installed, run `pip install spacy`." 
+ ) from None if pipe == "sentencizer": from spacy.lang.en import English @@ -443,9 +437,10 @@ class TextSplitter(BaseDocumentTransformer, ABC): chunk_overlap: Overlap in characters between chunks length_function: Function that measures the length of given chunks keep_separator: Whether to keep the separator in the chunks - add_start_index: If `True`, includes chunk's start index in metadata - strip_whitespace: If `True`, strips whitespace from the start and end of - every document + add_start_index: If `True`, includes chunk's start index in + metadata + strip_whitespace: If `True`, strips whitespace from the start and + end of every document """ if chunk_overlap > chunk_size: raise ValueError( @@ -570,7 +565,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): raise ValueError( "Could not import transformers python package. " "Please install it with `pip install transformers`." - ) + ) from None return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod @@ -586,11 +581,9 @@ class TextSplitter(BaseDocumentTransformer, ABC): try: import tiktoken except ImportError: - raise ImportError( - "Could not import tiktoken python package. " - "This is needed in order to calculate max_tokens_for_prompt. " - "Please install it with `pip install tiktoken`." - ) + raise ImportError("""Could not import tiktoken python package. + This is needed in order to calculate max_tokens_for_prompt. + Please install it with `pip install tiktoken`.""") from None if model is not None: enc = tiktoken.encoding_for_model(model) @@ -882,8 +875,8 @@ class ElementType(TypedDict): class HTMLHeaderTextSplitter: - """ - Splitting HTML files based on specified headers. + """Splitting HTML files based on specified headers. + Requires lxml package. """ @@ -895,9 +888,10 @@ class HTMLHeaderTextSplitter: """Create a new HTMLHeaderTextSplitter. Args: - headers_to_split_on: list of tuples of headers we want to track mapped to - (arbitrary) keys for metadata. 
Allowed header values: h1, h2, h3, h4, - h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2)]. + headers_to_split_on: list of tuples of headers we want to track + mapped to (arbitrary) keys for metadata. Allowed header values: + h1, h2, h3, h4, h5, h6 + e.g. [("h1", "Header 1"), ("h2", "Header 2)]. return_each_element: Return each element w/ associated headers. """ # Output element-by-element or aggregated into chunks w/ common headers @@ -907,10 +901,11 @@ class HTMLHeaderTextSplitter: def aggregate_elements_to_chunks( self, elements: list[ElementType] ) -> list[SplitterDocument]: - """Combine elements with common metadata into chunks + """Combine elements with common metadata into chunks. Args: - elements: HTML element content with associated identifying info and metadata + elements: HTML element content with associated identifying + info and metadata """ aggregated_chunks: list[ElementType] = [] @@ -935,7 +930,7 @@ class HTMLHeaderTextSplitter: ] def split_text_from_url(self, url: str) -> list[SplitterDocument]: - """Split HTML from web URL + """Split HTML from web URL. Args: url: web URL @@ -944,7 +939,7 @@ class HTMLHeaderTextSplitter: return self.split_text_from_file(BytesIO(r.content)) def split_text(self, text: str) -> list[SplitterDocument]: - """Split HTML text string + """Split HTML text string. Args: text: HTML text @@ -952,25 +947,26 @@ class HTMLHeaderTextSplitter: return self.split_text_from_file(StringIO(text)) def split_text_from_file(self, file: Any) -> list[SplitterDocument]: - """Split HTML file + """Split HTML file. Args: file: HTML file """ try: from lxml import etree - except ImportError as e: + except ImportError: raise ImportError( - "Unable to import lxml, please install with `pip install lxml`." - ) from e + "Unable to import lxml, run `pip install lxml`." 
+ ) from None # use lxml library to parse html document and return xml ElementTree # Explicitly encoding in utf-8 allows non-English # html files to be processed without garbled characters parser = etree.HTMLParser(encoding="utf-8") tree = etree.parse(file, parser) - # document transformation for "structure-aware" chunking is handled with xsl. - # see comments in html_chunks_with_headers.xslt for more detailed information. + # document transformation for "structure-aware" chunking is handled + # with xsl. See comments in html_chunks_with_headers.xslt for more + # detailed information. xslt_path = ( pathlib.Path(__file__).parent / "document_transformers/xsl/html_chunks_with_headers.xslt" @@ -1013,8 +1009,8 @@ class HTMLHeaderTextSplitter: ] ), metadata={ - # Add text of specified headers to metadata using header - # mapping. + # Add text of specified headers to + # metadata using header mapping. header_mapping[node.tag]: node.text for node in filter( lambda x: x.tag in header_filter, @@ -1044,13 +1040,13 @@ class Tokenizer: """Tokenizer data class.""" chunk_overlap: int - """Overlap in tokens between chunks""" + """Overlap in tokens between chunks.""" tokens_per_chunk: int - """Maximum number of tokens per chunk""" + """Maximum number of tokens per chunk.""" decode: Callable[[list[int]], str] - """ Function to decode a list of token ids to a string""" + """Function to decode a list of token ids to a string.""" encode: Callable[[str], list[int]] - """ Function to encode a string to a list of token ids""" + """Function to encode a string to a list of token ids.""" def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: @@ -1090,7 +1086,7 @@ class TokenTextSplitter(TextSplitter): "Could not import tiktoken python package. " "This is needed in order to for TokenTextSplitter. " "Please install it with `pip install tiktoken`." 
- ) + ) from None if model is not None: enc = tiktoken.encoding_for_model(model) @@ -1135,10 +1131,12 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): from sentence_transformers import SentenceTransformer except ImportError: raise ImportError( - "Could not import sentence_transformer python package. " - "This is needed in order to for SentenceTransformersTokenTextSplitter. " - "Please install it with `pip install sentence-transformers`." - ) + """Could not import sentence_transformer python package. + This is needed in order to for + SentenceTransformersTokenTextSplitter. + Please install it with `pip install sentence-transformers`. + """ + ) from None self.model = model self._model = SentenceTransformer(self.model, trust_remote_code=True) @@ -1221,8 +1219,7 @@ class Language(str, Enum): class RecursiveCharacterTextSplitter(TextSplitter): """Splitting text by recursively look at characters. - Recursively tries to split by different characters to find one - that works. + Recursively tries to split by different characters to find one that works. 
""" def __init__( @@ -1580,9 +1577,11 @@ class RecursiveCharacterTextSplitter(TextSplitter): ] elif language == Language.MARKDOWN: return [ - # First, try to split along Markdown headings (starting with level 2) + # First, try to split along Markdown headings + # (starting with level 2) "\n#{1,6} ", - # Note the alternative syntax for headings (below) is not handled here + # Note the alternative syntax for headings (below) + # is not handled here # Heading level 2 # --------------- # End of code block @@ -1593,7 +1592,8 @@ class RecursiveCharacterTextSplitter(TextSplitter): "\n___+\n", # Note that this splitter doesn't handle # horizontal lines defined - # by *three or more* of ***, ---, or ___, but this is not handled + # by *three or more* of ***, ---, or ___, + # but this is not handled "\n\n", "\n", " ", @@ -1775,9 +1775,8 @@ class NLTKTextSplitter(TextSplitter): self._tokenizer = sent_tokenize except ImportError: - raise ImportError( - "NLTK is not installed, please install it with `pip install nltk`." - ) + raise ImportError("""NLTK is not installed, please install it with + `pip install nltk`.""") from None self._separator = separator self._language = language @@ -1791,7 +1790,6 @@ class NLTKTextSplitter(TextSplitter): class SpacyTextSplitter(TextSplitter): """Splitting text using Spacy package. - Per default, Spacy's `en_core_web_sm` model is used and its default max_length is 1000000 (it is the length of maximum character this model takes which can be increased for large files). 
For a faster, @@ -1835,12 +1833,10 @@ class KonlpyTextSplitter(TextSplitter): try: from konlpy.tag import Kkma except ImportError: - raise ImportError( - """ + raise ImportError(""" Konlpy is not installed, please install it with `pip install konlpy` - """ - ) + """) from None self.kkma = Kkma() def split_text(self, text: str) -> list[str]: @@ -1914,18 +1910,22 @@ class RecursiveJsonSplitter: for i, item in enumerate(data) } else: - # Base case: the item is neither a dict nor a list, so return it unchanged + # The item is neither a dict nor a list, return unchanged return data def _json_split( self, data: dict[str, Any], - current_path: list[str] = [], - chunks: list[dict] = [{}], + current_path: list[str] | None = None, + chunks: list[dict] | None = None, ) -> list[dict]: - """ - Split json into maximum size dictionaries while preserving structure. - """ + """Split json into maximum size dictionaries while preserving + structure.""" + if current_path is None: + current_path = [] + if chunks is None: + chunks = [{}] + if isinstance(data, dict): for key, value in data.items(): new_path = current_path + [key] @@ -1953,7 +1953,7 @@ class RecursiveJsonSplitter: json_data: dict[str, Any], convert_lists: bool = False, ) -> list[dict]: - """Splits JSON into a list of JSON chunks""" + """Splits JSON into a list of JSON chunks.""" if convert_lists: chunks = self._json_split( @@ -1970,7 +1970,7 @@ class RecursiveJsonSplitter: def split_text( self, json_data: dict[str, Any], convert_lists: bool = False ) -> list[str]: - """Splits JSON into a list of JSON formatted strings""" + """Splits JSON into a list of JSON formatted strings.""" chunks = self.split_json( json_data=json_data, convert_lists=convert_lists diff --git a/py/tests/integration/conftest.py b/py/tests/integration/conftest.py index dedc1299b..bd02a5395 100644 --- a/py/tests/integration/conftest.py +++ b/py/tests/integration/conftest.py @@ -7,6 +7,7 @@ from r2r import R2RAsyncClient, R2RClient class TestConfig: + def 
__init__(self): self.base_url = "http://localhost:7272" self.index_wait_time = 1.0 @@ -47,8 +48,8 @@ async def aclient(config) -> AsyncGenerator[R2RClient, None]: @pytest.fixture async def superuser_client( - client: R2RClient, config: TestConfig -) -> AsyncGenerator[R2RClient, None]: + client: R2RClient, + config: TestConfig) -> AsyncGenerator[R2RClient, None]: """Creates a superuser client for tests requiring elevated privileges.""" await client.users.login(config.superuser_email, config.superuser_password) yield client @@ -62,6 +63,7 @@ from r2r import R2RClient, R2RException @pytest.fixture(scope="session") def config(): + class TestConfig: base_url = "http://localhost:7272" superuser_email = "admin@example.com" @@ -98,13 +100,15 @@ def test_document(client: R2RClient): @pytest.fixture(scope="session") def test_collection(client: R2RClient, test_document): - """Create a test collection with sample documents and clean up after tests.""" + """Create a test collection with sample documents and clean up after + tests.""" collection_name = f"Test Collection {uuid.uuid4()}" collection_id = client.collections.create(name=collection_name).results.id docs = [ { - "text": f"Aristotle was a Greek philosopher who studied under Plato {str(uuid.uuid4())}.", + "text": + f"Aristotle was a Greek philosopher who studied under Plato {str(uuid.uuid4())}.", "metadata": { "rating": 5, "tags": ["philosophy", "greek"], @@ -112,7 +116,8 @@ def test_collection(client: R2RClient, test_document): }, }, { - "text": f"Socrates is considered a founder of Western philosophy {str(uuid.uuid4())}.", + "text": + f"Socrates is considered a founder of Western philosophy {str(uuid.uuid4())}.", "metadata": { "rating": 3, "tags": ["philosophy", "classical"], @@ -120,7 +125,8 @@ def test_collection(client: R2RClient, test_document): }, }, { - "text": f"Rene Descartes was a French philosopher. unique_philosopher {str(uuid.uuid4())}", + "text": + f"Rene Descartes was a French philosopher. 
unique_philosopher {str(uuid.uuid4())}", "metadata": { "rating": 8, "tags": ["rationalism", "french"], @@ -128,7 +134,8 @@ def test_collection(client: R2RClient, test_document): }, }, { - "text": f"Immanuel Kant, a German philosopher, influenced Enlightenment thought {str(uuid.uuid4())}.", + "text": + f"Immanuel Kant, a German philosopher, influenced Enlightenment thought {str(uuid.uuid4())}.", "metadata": { "rating": 7, "tags": ["enlightenment", "german"], @@ -140,8 +147,7 @@ def test_collection(client: R2RClient, test_document): doc_ids = [] for doc in docs: doc_id = client.documents.create( - raw_text=doc["text"], metadata=doc["metadata"] - ).results.document_id + raw_text=doc["text"], metadata=doc["metadata"]).results.document_id doc_ids.append(doc_id) client.collections.add_document(collection_id, doc_id) client.collections.add_document(collection_id, test_document) diff --git a/py/tests/integration/test_base.py b/py/tests/integration/test_base.py index 9ad9495ea..cc60ad916 100644 --- a/py/tests/integration/test_base.py +++ b/py/tests/integration/test_base.py @@ -7,9 +7,8 @@ class BaseTest: """Base class for all test classes with common utilities.""" @staticmethod - async def cleanup_resource( - cleanup_func, resource_id: Optional[str] = None - ) -> None: + async def cleanup_resource(cleanup_func, + resource_id: Optional[str] = None) -> None: """Generic cleanup helper that won't fail the test if cleanup fails.""" if resource_id: try: diff --git a/py/tests/integration/test_chunks.py b/py/tests/integration/test_chunks.py index c5624e041..e2f5f56ab 100644 --- a/py/tests/integration/test_chunks.py +++ b/py/tests/integration/test_chunks.py @@ -9,17 +9,16 @@ from r2r import R2RAsyncClient, R2RException class AsyncR2RTestClient: - """Wrapper to ensure async operations use the correct event loop""" + """Wrapper to ensure async operations use the correct event loop.""" def __init__(self, base_url: str = "http://localhost:7272"): self.client = R2RAsyncClient(base_url) - 
async def create_document( - self, chunks: list[str], run_with_orchestration: bool = False - ): + async def create_document(self, + chunks: list[str], + run_with_orchestration: bool = False): response = await self.client.documents.create( - chunks=chunks, run_with_orchestration=run_with_orchestration - ) + chunks=chunks, run_with_orchestration=run_with_orchestration) return response.results.document_id, [] async def delete_document(self, doc_id: str): @@ -33,12 +32,15 @@ class AsyncR2RTestClient: response = await self.client.chunks.retrieve(id=chunk_id) return response.results - async def update_chunk( - self, chunk_id: str, text: str, metadata: Optional[dict] = None - ): - response = await self.client.chunks.update( - {"id": chunk_id, "text": text, "metadata": metadata or {}} - ) + async def update_chunk(self, + chunk_id: str, + text: str, + metadata: Optional[dict] = None): + response = await self.client.chunks.update({ + "id": chunk_id, + "text": text, + "metadata": metadata or {} + }) return response.results async def delete_chunk(self, chunk_id: str): @@ -47,8 +49,7 @@ class AsyncR2RTestClient: async def search_chunks(self, query: str, limit: int = 5): response = await self.client.chunks.search( - query=query, search_settings={"limit": limit} - ) + query=query, search_settings={"limit": limit}) return response.results async def register_user(self, email: str, password: str): @@ -75,8 +76,7 @@ async def test_document( uuid_1 = uuid.uuid4() uuid_2 = uuid.uuid4() doc_id, _ = await test_client.create_document( - [f"Test chunk 1_{uuid_1}", f"Test chunk 2_{uuid_2}"] - ) + [f"Test chunk 1_{uuid_1}", f"Test chunk 2_{uuid_2}"]) await asyncio.sleep(1) # Wait for ingestion chunks = await test_client.list_chunks(str(doc_id)) yield doc_id, chunks @@ -85,14 +85,14 @@ async def test_document( class TestChunks: + @pytest.mark.asyncio - async def test_create_and_list_chunks( - self, test_client: AsyncR2RTestClient, cleanup_documents - ): + async def 
test_create_and_list_chunks(self, + test_client: AsyncR2RTestClient, + cleanup_documents): # Create document with chunks doc_id, _ = await test_client.create_document( - ["Hello chunk", "World chunk"] - ) + ["Hello chunk", "World chunk"]) cleanup_documents(str(doc_id)) await asyncio.sleep(1) # Wait for ingestion @@ -101,36 +101,31 @@ class TestChunks: assert len(chunks) == 2, "Expected 2 chunks in the document" @pytest.mark.asyncio - async def test_retrieve_chunk( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_retrieve_chunk(self, test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id retrieved = await test_client.retrieve_chunk(chunk_id) assert str(retrieved.id) == str(chunk_id), "Retrieved wrong chunk ID" assert retrieved.text.split("_")[0] == "Test chunk 1", ( - "Chunk text mismatch" - ) + "Chunk text mismatch") @pytest.mark.asyncio - async def test_update_chunk( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_update_chunk(self, test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id # Update chunk - updated = await test_client.update_chunk( - str(chunk_id), "Updated text", {"version": 2} - ) + updated = await test_client.update_chunk(str(chunk_id), "Updated text", + {"version": 2}) assert updated.text == "Updated text", "Chunk text not updated" assert updated.metadata["version"] == 2, "Metadata not updated" @pytest.mark.asyncio - async def test_delete_chunk( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_delete_chunk(self, test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id @@ -144,18 +139,15 @@ class TestChunks: assert exc_info.value.status_code == 404 @pytest.mark.asyncio - async def test_search_chunks( - self, test_client: AsyncR2RTestClient, cleanup_documents - ): + async def test_search_chunks(self, test_client: 
AsyncR2RTestClient, + cleanup_documents): # Create searchable document random_1 = uuid.uuid4() random_2 = uuid.uuid4() - doc_id, _ = await test_client.create_document( - [ - f"Aristotle reference {random_1}", - f"Another piece of text {random_2}", - ] - ) + doc_id, _ = await test_client.create_document([ + f"Aristotle reference {random_1}", + f"Another piece of text {random_2}", + ]) cleanup_documents(doc_id) await asyncio.sleep(1) # Wait for indexing @@ -164,9 +156,9 @@ class TestChunks: assert len(results) > 0, "No search results found" @pytest.mark.asyncio - async def test_unauthorized_chunk_access( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_unauthorized_chunk_access(self, + test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id @@ -182,9 +174,9 @@ class TestChunks: assert exc_info.value.status_code == 403 @pytest.mark.asyncio - async def test_list_chunks_with_filters( - self, test_client: AsyncR2RTestClient, cleanup_documents - ): + async def test_list_chunks_with_filters(self, + test_client: AsyncR2RTestClient, + cleanup_documents): """Test listing chunks with owner_id filter.""" # Create and login as temporary user temp_email = f"{uuid.uuid4()}@example.com" @@ -193,15 +185,13 @@ class TestChunks: # Create a document with chunks doc_id, _ = await test_client.create_document( - ["Test chunk 1", "Test chunk 2"] - ) + ["Test chunk 1", "Test chunk 2"]) cleanup_documents(doc_id) await asyncio.sleep(1) # Wait for ingestion @pytest.mark.asyncio - async def test_list_chunks_pagination( - self, test_client: AsyncR2RTestClient - ): + async def test_list_chunks_pagination(self, + test_client: AsyncR2RTestClient): """Test chunk listing with pagination.""" # Create and login as temporary user temp_email = f"{uuid.uuid4()}@example.com" @@ -218,23 +208,20 @@ class TestChunks: # Test first page response1 = await test_client.client.chunks.list(offset=0, limit=2) - assert len(response1.results) 
== 2, ( - "Expected 2 results on first page" - ) + assert len( + response1.results) == 2, ("Expected 2 results on first page") # Test second page response2 = await test_client.client.chunks.list(offset=2, limit=2) - assert len(response2.results) == 2, ( - "Expected 2 results on second page" - ) + assert len( + response2.results) == 2, ("Expected 2 results on second page") # Verify no duplicate results ids_page1 = {str(chunk.id) for chunk in response1.results} ids_page2 = {str(chunk.id) for chunk in response2.results} assert not ids_page1.intersection(ids_page2), ( - "Found duplicate chunks across pages" - ) + "Found duplicate chunks across pages") finally: # Cleanup @@ -247,8 +234,7 @@ class TestChunks: @pytest.mark.asyncio async def test_list_chunks_with_multiple_documents( - self, test_client: AsyncR2RTestClient - ): + self, test_client: AsyncR2RTestClient): """Test listing chunks across multiple documents.""" # Create and login as temporary user temp_email = f"{uuid.uuid4()}@example.com" @@ -260,8 +246,7 @@ class TestChunks: # Create multiple documents for i in range(2): doc_id, _ = await test_client.create_document( - [f"Doc {i} chunk 1", f"Doc {i} chunk 2"] - ) + [f"Doc {i} chunk 1", f"Doc {i} chunk 2"]) doc_ids.append(doc_id) await asyncio.sleep(1) # Wait for ingestion @@ -272,11 +257,12 @@ class TestChunks: assert len(response.results) == 4, "Expected 4 total chunks" chunk_doc_ids = { - str(chunk.document_id) for chunk in response.results + str(chunk.document_id) + for chunk in response.results } - assert all(str(doc_id) in chunk_doc_ids for doc_id in doc_ids), ( - "Got chunks from wrong documents" - ) + assert all( + str(doc_id) in chunk_doc_ids + for doc_id in doc_ids), ("Got chunks from wrong documents") finally: # Cleanup diff --git a/py/tests/integration/test_collection_id_filter.py b/py/tests/integration/test_collection_id_filter.py index 69c0fb55e..d1648c54b 100644 --- a/py/tests/integration/test_collection_id_filter.py +++ 
b/py/tests/integration/test_collection_id_filter.py @@ -13,15 +13,14 @@ def setup_docs_with_collections(client: R2RClient): coll_ids = [] for i in range(3): collection_id = client.collections.create( - name=f"TestColl{i}" - ).results.id + name=f"TestColl{i}").results.id coll_ids.append(collection_id) # Create documents with different collection arrangements: # doc1: [coll1] doc1 = client.documents.create( - raw_text=f"Doc in coll1{random_suffix}", run_with_orchestration=False - ).results.document_id + raw_text=f"Doc in coll1{random_suffix}", + run_with_orchestration=False).results.document_id client.collections.add_document(coll_ids[0], doc1) # doc2: [coll1, coll2] @@ -40,8 +39,8 @@ def setup_docs_with_collections(client: R2RClient): # doc4: [coll3] doc4 = client.documents.create( - raw_text="Doc in coll3" + random_suffix, run_with_orchestration=False - ).results.document_id + raw_text="Doc in coll3" + random_suffix, + run_with_orchestration=False).results.document_id client.collections.add_document(coll_ids[2], doc4) yield {"coll_ids": coll_ids, "doc_ids": [doc1, doc2, doc3, doc4]} @@ -59,18 +58,18 @@ def setup_docs_with_collections(client: R2RClient): pass -def test_collection_id_eq_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_eq_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids # collection_id = coll_ids[0] should match doc1 and doc2 only filters = {"collection_id": {"$eq": str(coll_ids[0])}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert { str(doc1), @@ -78,9 +77,8 @@ def test_collection_id_eq_filter( } == found_ids, f"Expected 
doc1 and doc2, got {found_ids}" -def test_collection_id_ne_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_ne_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -89,9 +87,10 @@ def test_collection_id_ne_filter( # Those are doc3 (no collections) and doc4 (in coll3 only) filters = {"collection_id": {"$ne": str(coll_ids[0])}} # listed = client.documents.list(limit=10, offset=0, filters=filters)["results"] - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert coll_ids[0] not in found_ids, f"Expected no coll0, got {found_ids}" # assert { @@ -100,9 +99,8 @@ def test_collection_id_ne_filter( # } == found_ids, f"Expected doc3 and doc4, got {found_ids}" -def test_collection_id_in_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_in_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -111,9 +109,10 @@ def test_collection_id_in_filter( # doc1 in coll0, doc2 in coll0, doc4 in coll2 # doc3 is in none filters = {"collection_id": {"$in": [str(coll_ids[0]), str(coll_ids[2])]}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert { str(doc1), @@ -122,9 +121,8 @@ def test_collection_id_in_filter( } == found_ids, f"Expected doc1, doc2, 
doc4, got {found_ids}" -def test_collection_id_nin_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_nin_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -133,16 +131,16 @@ def test_collection_id_nin_filter( # doc2 belongs to coll1, so exclude doc2 # doc1, doc3, doc4 remain filters = {"collection_id": {"$nin": [str(coll_ids[1])]}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert coll_ids[1] not in found_ids, f"Expected no coll1, got {found_ids}" -def test_collection_id_contains_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_contains_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -151,9 +149,10 @@ def test_collection_id_contains_filter( # If collection_id {"$contains": "coll_ids[0]"}, docs must have coll0 in their array # That would be doc1 and doc2 only filters = {"collection_id": {"$contains": str(coll_ids[0])}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert { str(doc1), @@ -161,9 +160,8 @@ def test_collection_id_contains_filter( } == found_ids, f"Expected doc1 and doc2, got {found_ids}" -def test_collection_id_contains_multiple( - client: R2RClient, setup_docs_with_collections -): +def 
test_collection_id_contains_multiple(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -172,18 +170,20 @@ def test_collection_id_contains_multiple( # this should mean the doc's collection_ids contain ALL of these. # Only doc2 has coll0 AND coll1. doc1 only has coll0, doc3 no collections, doc4 only coll3. filters = { - "collection_id": {"$contains": [str(coll_ids[0]), str(coll_ids[1])]} + "collection_id": { + "$contains": [str(coll_ids[0]), str(coll_ids[1])] + } } - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert {str(doc2)} == found_ids, f"Expected doc2 only, got {found_ids}" -def test_delete_by_collection_id_eq( - client: R2RClient, setup_docs_with_collections -): +def test_delete_by_collection_id_eq(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc1, doc2, doc3, doc4 = setup_docs_with_collections["doc_ids"] diff --git a/py/tests/integration/test_collections.py b/py/tests/integration/test_collections.py index 78bc8503a..743922982 100644 --- a/py/tests/integration/test_collections.py +++ b/py/tests/integration/test_collections.py @@ -22,9 +22,8 @@ def test_document_2(client: R2RClient): def test_create_collection(client: R2RClient): - collection_id = client.collections.create( - name="Test Collection Creation", description="Desc" - ).results.id + collection_id = client.collections.create(name="Test Collection Creation", + description="Desc").results.id assert collection_id is not None, "No collection_id returned" # Cleanup @@ -39,11 +38,9 @@ def test_list_collections(client: R2RClient, test_collection): def 
test_retrieve_collection(client: R2RClient, test_collection): # Retrieve the collection just created retrieved = client.collections.retrieve( - test_collection["collection_id"] - ).results + test_collection["collection_id"]).results assert retrieved.id == test_collection["collection_id"], ( - "Retrieved wrong collection ID" - ) + "Retrieved wrong collection ID") def test_update_collection(client: R2RClient, test_collection): @@ -56,48 +53,37 @@ def test_update_collection(client: R2RClient, test_collection): ).results assert updated.name == updated_name, "Collection name not updated" assert updated.description == updated_desc, ( - "Collection description not updated" - ) + "Collection description not updated") -def test_add_document_to_collection( - client: R2RClient, test_collection, test_document_2 -): - client.collections.add_document( - test_collection["collection_id"], str(test_document_2) - ) +def test_add_document_to_collection(client: R2RClient, test_collection, + test_document_2): + client.collections.add_document(test_collection["collection_id"], + str(test_document_2)) docs_in_collection = client.collections.list_documents( - test_collection["collection_id"] - ).results + test_collection["collection_id"]).results found = any( - str(doc.id) == str(test_document_2) for doc in docs_in_collection - ) + str(doc.id) == str(test_document_2) for doc in docs_in_collection) assert found, "Added document not found in collection" -def test_list_documents_in_collection( - client: R2RClient, test_collection, test_document -): +def test_list_documents_in_collection(client: R2RClient, test_collection, + test_document): # Document should be in the collection already from previous test docs_in_collection = client.collections.list_documents( - test_collection["collection_id"] - ).results + test_collection["collection_id"]).results found = any( - str(doc.id) == str(test_document) for doc in docs_in_collection - ) + str(doc.id) == str(test_document) for doc in 
docs_in_collection) assert found, "Expected document not found in collection" -def test_remove_document_from_collection( - client: R2RClient, test_collection, test_document -): +def test_remove_document_from_collection(client: R2RClient, test_collection, + test_document): # Remove the document from the collection - client.collections.remove_document( - test_collection["collection_id"], test_document - ) + client.collections.remove_document(test_collection["collection_id"], + test_document) docs_in_collection = client.collections.list_documents( - test_collection["collection_id"] - ).results + test_collection["collection_id"]).results found = any(str(doc.id) == test_document for doc in docs_in_collection) assert not found, "Document still present in collection after removal" @@ -111,8 +97,7 @@ def test_remove_non_member_user_from_collection(mutable_client: R2RClient): # Create a collection by the same user collection_id = mutable_client.collections.create( - name="User Owned Collection" - ).results.id + name="User Owned Collection").results.id mutable_client.users.logout() # Create another user who will not be added to the collection @@ -147,8 +132,7 @@ def test_delete_collection(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.collections.retrieve(coll_id) assert exc_info.value.status_code == 404, ( - "Wrong error code retrieving deleted collection" - ) + "Wrong error code retrieving deleted collection") def test_add_user_to_non_existent_collection(mutable_client: R2RClient): @@ -164,12 +148,10 @@ def test_add_user_to_non_existent_collection(mutable_client: R2RClient): # (Assumes superuser credentials are already in the client fixture) fake_collection_id = str(uuid.uuid4()) # Non-existent collection ID with pytest.raises(R2RException) as exc_info: - result = mutable_client.collections.add_user( - fake_collection_id, user_id - ) + result = mutable_client.collections.add_user(fake_collection_id, + user_id) assert exc_info.value.status_code == 
404, ( - "Wrong error code for non-existent collection" - ) + "Wrong error code for non-existent collection") def test_create_collection_without_name(client: R2RClient): @@ -188,16 +170,14 @@ def test_filter_collections_by_non_existent_id(client: R2RClient): # Filter collections by an ID that does not exist random_id = str(uuid.uuid4()) resp = client.collections.list(ids=[random_id]) - assert len(resp.results) == 0, ( - "Expected no collections for a non-existent ID" - ) + assert len( + resp.results) == 0, ("Expected no collections for a non-existent ID") def test_list_documents_in_empty_collection(client: R2RClient): # Create a new collection with no documents empty_coll_id = client.collections.create( - name="Empty Collection" - ).results.id + name="Empty Collection").results.id docs = client.collections.list_documents(empty_coll_id).results assert len(docs) == 0, "Expected no documents in a new empty collection" @@ -240,8 +220,7 @@ def test_delete_non_existent_collection(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.collections.delete(fake_collection_id) assert exc_info.value.status_code == 404, ( - "Expected 404 when deleting non-existent collection" - ) + "Expected 404 when deleting non-existent collection") def test_retrieve_collection_by_name(client: R2RClient): @@ -250,19 +229,16 @@ def test_retrieve_collection_by_name(client: R2RClient): # Create a collection with the unique name created_resp = client.collections.create( - name=unique_name, description="Collection for retrieval by name test" - ) + name=unique_name, description="Collection for retrieval by name test") created = created_resp.results assert created.id is not None, ( - "Creation did not return a valid collection ID" - ) + "Creation did not return a valid collection ID") # Retrieve the collection by its name retrieved_resp = client.collections.retrieve_by_name(unique_name) retrieved = retrieved_resp.results assert retrieved.id == created.id, ( - "Retrieved collection 
does not match the created collection" - ) + "Retrieved collection does not match the created collection") # Cleanup: Delete the created collection client.collections.delete(created.id) diff --git a/py/tests/integration/test_collections_users_interaction.py b/py/tests/integration/test_collections_users_interaction.py index 06fde4caa..5e1e8c854 100644 --- a/py/tests/integration/test_collections_users_interaction.py +++ b/py/tests/integration/test_collections_users_interaction.py @@ -27,9 +27,8 @@ def normal_user_client(mutable_client: R2RClient): # Cleanup: Try deleting the normal user if exists try: mutable_client.users.login(email, password) - mutable_client.users.delete( - id=mutable_client.users.me().results.id, password=password - ) + mutable_client.users.delete(id=mutable_client.users.me().results.id, + password=password) except R2RException: pass @@ -84,33 +83,29 @@ def superuser_owned_collection(client: R2RClient): pass -def test_non_member_cannot_view_collection( - normal_user_client, superuser_owned_collection -): - """A normal user (not a member of a superuser-owned collection) tries to view it.""" +def test_non_member_cannot_view_collection(normal_user_client, + superuser_owned_collection): + """A normal user (not a member of a superuser-owned collection) tries to + view it.""" # The normal user is not added to the superuser collection, should fail with pytest.raises(R2RException) as exc_info: normal_user_client.collections.retrieve(superuser_owned_collection) assert exc_info.value.status_code == 403, ( - "Non-member should not be able to view collection." 
- ) + "Non-member should not be able to view collection.") -def test_collection_owner_can_view_collection( - normal_user_client: R2RClient, user_owned_collection -): +def test_collection_owner_can_view_collection(normal_user_client: R2RClient, + user_owned_collection): """The owner should be able to view their own collection.""" coll = normal_user_client.collections.retrieve( - user_owned_collection - ).results + user_owned_collection).results assert coll.id == user_owned_collection, ( - "Owner cannot view their own collection." - ) + "Owner cannot view their own collection.") -def test_collection_member_can_view_collection( - client, normal_user_client: R2RClient, user_owned_collection -): +def test_collection_member_can_view_collection(client, + normal_user_client: R2RClient, + user_owned_collection): """A user added to a collection should be able to view it.""" # Create another user and add them to the user's collection new_user_email = f"temp_member_{uuid.uuid4()}@test.com" @@ -141,21 +136,19 @@ def test_non_owner_member_cannot_edit_collection( another_normal_user_client: R2RClient, normal_user_client: R2RClient, ): - """A member who is not the owner should not be able to edit the collection.""" + """A member who is not the owner should not be able to edit the + collection.""" # Add another normal user to the owner's collection another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Another normal user tries to update collection with pytest.raises(R2RException) as exc_info: - another_normal_user_client.collections.update( - user_owned_collection, name="Malicious Update" - ) + another_normal_user_client.collections.update(user_owned_collection, + name="Malicious Update") assert exc_info.value.status_code == 403, ( - "Non-owner member should not be able to edit." 
- ) + "Non-owner member should not be able to edit.") def test_non_owner_member_cannot_delete_collection( @@ -163,19 +156,18 @@ def test_non_owner_member_cannot_delete_collection( another_normal_user_client: R2RClient, normal_user_client: R2RClient, ): - """A member who is not the owner should not be able to delete the collection.""" + """A member who is not the owner should not be able to delete the + collection.""" # Add the other user another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Another user tries to delete with pytest.raises(R2RException) as exc_info: another_normal_user_client.collections.delete(user_owned_collection) assert exc_info.value.status_code == 403, ( - "Non-owner member should not be able to delete." - ) + "Non-owner member should not be able to delete.") def test_non_owner_member_cannot_add_other_users( @@ -196,24 +188,20 @@ def test_non_owner_member_cannot_add_other_users( # This code snippet assumes we have these credentials available. # If not, manage credentials store in fixture creation. 
normal_user_client.users.login(normal_user_email, "normal_password") - third_user_id = normal_user_client.users.create( - third_email, third_password - ).results.id + third_user_id = normal_user_client.users.create(third_email, + third_password).results.id # Add another user as a member another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Now, another_normal_user_client tries to add the third user with pytest.raises(R2RException) as exc_info: another_normal_user_client.collections.add_user( - user_owned_collection, third_user_id - ) + user_owned_collection, third_user_id) assert exc_info.value.status_code == 403, ( - "Non-owner member should not be able to add users." - ) + "Non-owner member should not be able to add users.") def test_owner_can_remove_member_from_collection( @@ -224,47 +212,40 @@ def test_owner_can_remove_member_from_collection( """The owner should be able to remove a member from their collection.""" # Add another user to the collection another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Remove them remove_resp = normal_user_client.collections.remove_user( - user_owned_collection, another_user_id - ).results + user_owned_collection, another_user_id).results assert remove_resp.success, "Owner could not remove member." # The removed user should no longer have access with pytest.raises(R2RException) as exc_info: another_normal_user_client.collections.retrieve(user_owned_collection) assert exc_info.value.status_code == 403, ( - "Removed user still has access after removal." 
- ) + "Removed user still has access after removal.") -def test_superuser_can_access_any_collection( - client: R2RClient, user_owned_collection -): +def test_superuser_can_access_any_collection(client: R2RClient, + user_owned_collection): """A superuser should be able to view and edit any collection.""" # Superuser can view coll = client.collections.retrieve(user_owned_collection).results assert coll.id == user_owned_collection, ( - "Superuser cannot view a user collection." - ) + "Superuser cannot view a user collection.") # Superuser can also update - updated = client.collections.update( - user_owned_collection, name="Superuser Edit" - ).results + updated = client.collections.update(user_owned_collection, + name="Superuser Edit").results assert updated.name == "Superuser Edit", ( - "Superuser cannot edit collection." - ) + "Superuser cannot edit collection.") -def test_unauthenticated_cannot_access_collections( - config, user_owned_collection -): - """An unauthenticated (no login) client should not access protected endpoints.""" +def test_unauthenticated_cannot_access_collections(config, + user_owned_collection): + """An unauthenticated (no login) client should not access protected + endpoints.""" unauth_client = R2RClient(config.base_url) # we must CREATE + LOGIN as superuser is default user for unauth in basic config user_name = f"unauth_user_{uuid.uuid4()}@email.com" @@ -273,18 +254,16 @@ def test_unauthenticated_cannot_access_collections( with pytest.raises(R2RException) as exc_info: unauth_client.collections.retrieve(user_owned_collection) assert exc_info.value.status_code == 403, ( - "Unaurthorized user should get 403" - ) + "Unaurthorized user should get 403") def test_user_cannot_add_document_to_collection_they_cannot_edit( - client: R2RClient, normal_user_client: R2RClient -): - """A normal user who is just a member (not owner) of a collection should not be able to add documents.""" + client: R2RClient, normal_user_client: R2RClient): + """A normal user 
who is just a member (not owner) of a collection should + not be able to add documents.""" # Create a collection as normal user (owner) coll_id = normal_user_client.collections.create( - name="Owned by user", description="desc" - ).results.id + name="Owned by user", description="desc").results.id # Create a second user and add them as member second_email = f"second_{uuid.uuid4()}@test.com" @@ -307,23 +286,20 @@ def test_user_cannot_add_document_to_collection_they_cannot_edit( # Create a document as owner doc_id = normal_user_client.documents.create( - raw_text="Test Document" - ).results.document_id + raw_text="Test Document").results.document_id # Now second user tries to add another document (which they do not have edit rights for) second_client.users.logout() second_client.users.login(second_email, second_password) # Another doc created by second user (just for attempt) doc2_id = second_client.documents.create( - raw_text="Doc by second user" - ).results.document_id + raw_text="Doc by second user").results.document_id # Second user tries to add their doc2_id to the owner’s collection with pytest.raises(R2RException) as exc_info: second_client.collections.add_document(coll_id, doc2_id) assert exc_info.value.status_code == 403, ( - "Non-owner member should not add documents." 
- ) + "Non-owner member should not add documents.") # Cleanup normal_user_client.collections.delete(coll_id) @@ -332,18 +308,15 @@ def test_user_cannot_add_document_to_collection_they_cannot_edit( def test_user_cannot_remove_document_from_collection_they_cannot_edit( - normal_user_client: R2RClient, -): + normal_user_client: R2RClient, ): """A user who is just a member should not remove documents.""" # Create a collection coll_id = normal_user_client.collections.create( - name="Removable", description="desc" - ).results.id + name="Removable", description="desc").results.id # Create a document in it doc_id = normal_user_client.documents.create( - raw_text="Doc in coll" - ).results.document_id + raw_text="Doc in coll").results.document_id normal_user_client.collections.add_document(coll_id, doc_id) # Create another user and add as member @@ -364,16 +337,14 @@ def test_user_cannot_remove_document_from_collection_they_cannot_edit( with pytest.raises(R2RException) as exc_info: member_client.collections.remove_document(coll_id, doc_id) assert exc_info.value.status_code == 403, ( - "Member should not remove documents." - ) + "Member should not remove documents.") # Cleanup normal_user_client.collections.delete(coll_id) def test_normal_user_cannot_make_another_user_superuser( - normal_user_client: R2RClient, -): + normal_user_client: R2RClient, ): """A normal user tries to update another user to superuser, should fail.""" # Create another user email = f"regular_{uuid.uuid4()}@test.com" @@ -384,24 +355,20 @@ def test_normal_user_cannot_make_another_user_superuser( with pytest.raises(R2RException) as exc_info: normal_user_client.users.update(new_user_id, is_superuser=True) assert exc_info.value.status_code == 403, ( - "Non-superuser should not grant superuser status." 
- ) + "Non-superuser should not grant superuser status.") def test_normal_user_cannot_view_other_users_if_not_superuser( - normal_user_client: R2RClient, -): + normal_user_client: R2RClient, ): """A normal user tries to list all users, should fail.""" with pytest.raises(R2RException) as exc_info: normal_user_client.users.list() assert exc_info.value.status_code == 403, ( - "Non-superuser should not list all users." - ) + "Non-superuser should not list all users.") def test_normal_user_cannot_update_other_users_details( - normal_user_client: R2RClient, client: R2RClient -): + normal_user_client: R2RClient, client: R2RClient): """A normal user tries to update another normal user's details.""" # Create another normal user email = f"other_normal_{uuid.uuid4()}@test.com" @@ -417,8 +384,7 @@ def test_normal_user_cannot_update_other_users_details( with pytest.raises(R2RException) as exc_info: normal_user_client.users.update(another_user_id, name="Hacked Name") assert exc_info.value.status_code == 403, ( - "Non-superuser should not update another user's info." - ) + "Non-superuser should not update another user's info.") # Additional Tests for Strengthened Coverage @@ -429,22 +395,18 @@ def test_owner_cannot_promote_member_to_superuser_via_collection( normal_user_client: R2RClient, another_normal_user_client: R2RClient, ): - """ - Ensures that being a collection owner doesn't confer the right - to promote a user to superuser. 
- """ + """Ensures that being a collection owner doesn't confer the right to + promote a user to superuser.""" # Add another user to the collection another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Try to update the member's superuser status with pytest.raises(R2RException) as exc_info: normal_user_client.users.update(another_user_id, is_superuser=True) assert exc_info.value.status_code == 403, ( - "Collection owners should not grant superuser status." - ) + "Collection owners should not grant superuser status.") def test_member_cannot_view_other_users_info( @@ -452,31 +414,25 @@ def test_member_cannot_view_other_users_info( normal_user_client: R2RClient, another_normal_user_client: R2RClient, ): - """ - A member (non-owner) of a collection should not be able to retrieve other users' details - outside of their allowed scope. - """ + """A member (non-owner) of a collection should not be able to retrieve + other users' details outside of their allowed scope.""" # Add the other normal user as a member another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # As another_normal_user_client (a member), try to retrieve owner user details owner_id = normal_user_client.users.me().results.id with pytest.raises(R2RException) as exc_info: another_normal_user_client.users.retrieve(owner_id) assert exc_info.value.status_code == 403, ( - "Members should not be able to view other users' details." 
- ) + "Members should not be able to view other users' details.") -def test_unauthenticated_user_cannot_join_collection( - config, user_owned_collection -): - """ - An unauthenticated user should not be able to join or view collections. - """ +def test_unauthenticated_user_cannot_join_collection(config, + user_owned_collection): + """An unauthenticated user should not be able to join or view + collections.""" unauth_client = R2RClient(config.base_url) # we must CREATE + LOGIN as superuser is default user for unauth in basic config user_name = f"unauth_user_{uuid.uuid4()}@email.com" @@ -497,23 +453,19 @@ def test_non_owner_cannot_remove_users_they_did_not_add( normal_user_client: R2RClient, another_normal_user_client: R2RClient, ): - """ - A member who is not the owner cannot remove other members from the collection. - """ + """A member who is not the owner cannot remove other members from the + collection.""" # Add another user as a member another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Now try removing that user as another_normal_user_client with pytest.raises(R2RException) as exc_info: another_normal_user_client.collections.remove_user( - user_owned_collection, another_user_id - ) + user_owned_collection, another_user_id) assert exc_info.value.status_code == 403, ( - "Non-owner member should not remove other users." - ) + "Non-owner member should not remove other users.") def test_owner_cannot_access_deleted_member_info_after_removal( @@ -521,20 +473,16 @@ def test_owner_cannot_access_deleted_member_info_after_removal( normal_user_client: R2RClient, another_normal_user_client: R2RClient, ): - """ - After the owner removes a user from the collection, ensure that attempts to - perform collection-specific actions with that user fail. 
- """ + """After the owner removes a user from the collection, ensure that attempts + to perform collection-specific actions with that user fail.""" # Add another user to the collection another_user_id = another_normal_user_client.users.me().results.id - normal_user_client.collections.add_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.add_user(user_owned_collection, + another_user_id) # Remove them - normal_user_client.collections.remove_user( - user_owned_collection, another_user_id - ) + normal_user_client.collections.remove_user(user_owned_collection, + another_user_id) # Now, try listing collections for that removed user (as owner), # if there's an endpoint that filters by user, to ensure no special access remains. @@ -544,20 +492,15 @@ def test_owner_cannot_access_deleted_member_info_after_removal( normal_user_client.users.retrieve(another_user_id) # We expect a 403 because normal_user_client is not superuser and not that user. assert exc_info.value.status_code == 403, ( - "Owner should not access removed member's user info." - ) + "Owner should not access removed member's user info.") def test_member_cannot_add_document_to_non_existent_collection( - normal_user_client: R2RClient, -): - """ - A member tries to add a document to a collection that doesn't exist. 
- """ + normal_user_client: R2RClient, ): + """A member tries to add a document to a collection that doesn't exist.""" fake_coll_id = str(uuid.uuid4()) doc_id = normal_user_client.documents.create( - raw_text="Test Doc" - ).results.document_id + raw_text="Test Doc").results.document_id with pytest.raises(R2RException) as exc_info: normal_user_client.collections.add_document(fake_coll_id, doc_id) assert exc_info.value.status_code in [ diff --git a/py/tests/integration/test_conversations.py b/py/tests/integration/test_conversations.py index fcc598a92..056122d3d 100644 --- a/py/tests/integration/test_conversations.py +++ b/py/tests/integration/test_conversations.py @@ -35,8 +35,7 @@ def test_retrieve_conversation(client: R2RClient, test_conversation): # A new conversation might have no messages, so results should be an empty list assert isinstance(retrieved, list), "Expected list of messages" assert len(retrieved) == 0, ( - "Expected empty message list for a new conversation" - ) + "Expected empty message list for a new conversation") def test_delete_conversation(client: R2RClient): @@ -48,8 +47,7 @@ def test_delete_conversation(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.conversations.retrieve(id=conv_id) assert exc_info.value.status_code == 404, ( - "Wrong error code retrieving deleted conversation" - ) + "Wrong error code retrieving deleted conversation") def test_add_message(client: R2RClient, test_conversation): @@ -72,8 +70,7 @@ def test_retrieve_non_existent_conversation(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.conversations.retrieve(id=bad_id) assert exc_info.value.status_code == 404, ( - "Wrong error code for non-existent conversation" - ) + "Wrong error code for non-existent conversation") def test_delete_non_existent_conversation(client: R2RClient): @@ -81,8 +78,7 @@ def test_delete_non_existent_conversation(client: R2RClient): with pytest.raises(R2RException) as exc_info: 
client.conversations.delete(id=bad_id) assert exc_info.value.status_code == 404, ( - "Wrong error code for delete non-existent" - ) + "Wrong error code for delete non-existent") def test_add_message_to_non_existent_conversation(client: R2RClient): @@ -95,8 +91,7 @@ def test_add_message_to_non_existent_conversation(client: R2RClient): ) # Expected a 404 since conversation doesn't exist assert exc_info.value.status_code == 404, ( - "Wrong error code for adding message to non-existent conversation" - ) + "Wrong error code for adding message to non-existent conversation") def test_update_message(client: R2RClient, test_conversation): @@ -112,21 +107,21 @@ def test_update_message(client: R2RClient, test_conversation): id=test_conversation, message_id=original_msg_id, content="Updated content", - metadata={"new_key": "new_value"}, + metadata={ + "new_key": "new_value" + }, ).results assert update_resp.message is not None, "No message returned after update" assert update_resp.metadata is not None, ( - "No metadata returned after update" - ) + "No metadata returned after update") assert update_resp.id is not None, "No metadata returned after update" # Retrieve the conversation with the new branch updated_conv = client.conversations.retrieve(id=test_conversation).results assert updated_conv, "No conversation returned after update" assert updated_conv[0].message.content == "Updated content", ( - "Message content not updated" - ) + "Message content not updated") # found_updated = any(msg["id"] == new_message_id and msg["message"]["content"] == "Updated content" for msg in updated_conv) # assert found_updated, "Updated message not found in the new branch" @@ -134,12 +129,11 @@ def test_update_message(client: R2RClient, test_conversation): def test_update_non_existent_message(client: R2RClient, test_conversation): fake_msg_id = str(uuid.uuid4()) with pytest.raises(R2RException) as exc_info: - client.conversations.update_message( - id=test_conversation, message_id=fake_msg_id, 
content="Should fail" - ) + client.conversations.update_message(id=test_conversation, + message_id=fake_msg_id, + content="Should fail") assert exc_info.value.status_code == 404, ( - "Wrong error code for updating non-existent message" - ) + "Wrong error code for updating non-existent message") def test_add_message_with_empty_content(client: R2RClient, test_conversation): @@ -151,8 +145,7 @@ def test_add_message_with_empty_content(client: R2RClient, test_conversation): ) # Check for 400 or a relevant error code depending on server validation assert exc_info.value.status_code == 400, ( - "Wrong error code or no error for empty content message" - ) + "Wrong error code or no error for empty content message") def test_add_message_invalid_role(client: R2RClient, test_conversation): @@ -163,8 +156,7 @@ def test_add_message_invalid_role(client: R2RClient, test_conversation): role="invalid_role", ) assert exc_info.value.status_code == 400, ( - "Wrong error code or no error for invalid role" - ) + "Wrong error code or no error for invalid role") def test_add_message_to_deleted_conversation(client: R2RClient): @@ -180,19 +172,19 @@ def test_add_message_to_deleted_conversation(client: R2RClient): role="user", ) assert exc_info.value.status_code == 404, ( - "Wrong error code for adding message to deleted conversation" - ) + "Wrong error code for adding message to deleted conversation") -def test_update_message_with_additional_metadata( - client: R2RClient, test_conversation -): +def test_update_message_with_additional_metadata(client: R2RClient, + test_conversation): # Add a message with initial metadata original_msg_id = client.conversations.add_message( id=test_conversation, content="Initial content", role="user", - metadata={"initial_key": "initial_value"}, + metadata={ + "initial_key": "initial_value" + }, ).results.id # Update the message with new content and additional metadata @@ -200,7 +192,9 @@ def test_update_message_with_additional_metadata( id=test_conversation, 
message_id=original_msg_id, content="Updated content", - metadata={"new_key": "new_value"}, + metadata={ + "new_key": "new_value" + }, ).results # Retrieve the conversation from the new branch @@ -212,18 +206,14 @@ def test_update_message_with_additional_metadata( None, ) assert updated_message is not None, ( - "Updated message not found in conversation" - ) + "Updated message not found in conversation") # Check that metadata includes old keys, new keys, and 'edited': True msg_metadata = updated_message.metadata assert msg_metadata.get("initial_key") == "initial_value", ( - "Old metadata not preserved" - ) + "Old metadata not preserved") assert msg_metadata.get("new_key") == "new_value", "New metadata not added" assert msg_metadata.get("edited") is True, ( - "'edited' flag not set in metadata" - ) + "'edited' flag not set in metadata") assert updated_message.message.content == "Updated content", ( - "Message content not updated" - ) + "Message content not updated") diff --git a/py/tests/integration/test_documents.py b/py/tests/integration/test_documents.py index f9f0b8add..1519940b5 100644 --- a/py/tests/integration/test_documents.py +++ b/py/tests/integration/test_documents.py @@ -35,9 +35,8 @@ def test_create_document_with_file(client: R2RClient, cleanup_documents): def test_create_document_with_raw_text(client: R2RClient, cleanup_documents): - resp = client.documents.create( - raw_text="This is raw text content.", run_with_orchestration=False - ) + resp = client.documents.create(raw_text="This is raw text content.", + run_with_orchestration=False) results = resp.results doc_id = cleanup_documents(results.document_id) @@ -47,8 +46,7 @@ def test_create_document_with_raw_text(client: R2RClient, cleanup_documents): retrieved = client.documents.retrieve(id=doc_id) retrieved_results = retrieved.results assert retrieved_results.id == doc_id, ( - "Failed to retrieve the ingested raw text document" - ) + "Failed to retrieve the ingested raw text document") def 
test_create_document_with_chunks(client: R2RClient, cleanup_documents): @@ -65,8 +63,7 @@ def test_create_document_with_chunks(client: R2RClient, cleanup_documents): retrieved = client.documents.retrieve(id=doc_id) retrieved_results = retrieved.results assert retrieved_results.id == doc_id, ( - "Failed to retrieve the chunk-based document" - ) + "Failed to retrieve the chunk-based document") def test_create_document_different_modes(client: R2RClient, cleanup_documents): @@ -111,9 +108,8 @@ def test_download_document(client: R2RClient, test_document): def test_delete_document(client: R2RClient): # Create a doc to delete - resp = client.documents.create( - raw_text="This is a temporary doc", run_with_orchestration=False - ).results + resp = client.documents.create(raw_text="This is a temporary doc", + run_with_orchestration=False).results doc_id = resp.document_id del_resp = client.documents.delete(id=doc_id).results assert del_resp.success, "Failed to delete document" @@ -127,7 +123,9 @@ def test_delete_document_by_filter(client: R2RClient): # Create a doc with unique metadata resp = client.documents.create( raw_text="Document to be filtered out", - metadata={"to_delete": "yes"}, + metadata={ + "to_delete": "yes" + }, run_with_orchestration=False, ).results doc_id = resp.document_id @@ -139,17 +137,15 @@ def test_delete_document_by_filter(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.documents.retrieve(id=doc_id) assert exc_info.value.status_code == 404, ( - "Document still exists after filter-based deletion" - ) + "Document still exists after filter-based deletion") # @pytest.mark.skip(reason="Only if superuser-specific logic is implemented") def test_list_document_collections(client: R2RClient, test_document): # This test assumes the currently logged in user is a superuser collections = client.documents.list_collections(id=test_document).results - assert isinstance(collections, list), ( - "Document collections list is not a list" - ) + 
assert isinstance(collections, + list), ("Document collections list is not a list") # @pytest.mark.skip( @@ -157,9 +153,8 @@ def test_list_document_collections(client: R2RClient, test_document): # ) def test_extract_document(client: R2RClient, test_document): time.sleep(10) - run_resp = client.documents.extract( - id=test_document, run_with_orchestration=False - ).results + run_resp = client.documents.extract(id=test_document, + run_with_orchestration=False).results assert run_resp.message is not None, "No message after extraction run" @@ -178,11 +173,9 @@ def test_list_entities(client: R2RClient, test_document): def test_list_relationships(client: R2RClient, test_document): try: relationships = client.documents.list_relationships( - id=test_document - ).results - assert isinstance(relationships, list), ( - "Relationships response not a list" - ) + id=test_document).results + assert isinstance(relationships, + list), ("Relationships response not a list") except R2RException as e: pytest.skip(f"No relationships extracted yet: {str(e)}") @@ -191,14 +184,13 @@ def test_search_documents(client: R2RClient, test_document): # Add some delay if indexing takes time time.sleep(1) query = "Temporary" - search_results = client.documents.search( - query=query, search_mode="custom", search_settings={"limit": 5} - ) + search_results = client.documents.search(query=query, + search_mode="custom", + search_settings={"limit": 5}) assert search_results.results is not None, "Search results key not found" # We cannot guarantee a match, but at least we got a well-formed response - assert isinstance(search_results.results, list), ( - "Search results not a list" - ) + assert isinstance(search_results.results, + list), ("Search results not a list") def test_list_document_chunks(mutable_client: R2RClient, cleanup_documents): @@ -207,8 +199,7 @@ def test_list_document_chunks(mutable_client: R2RClient, cleanup_documents): mutable_client.users.login(temp_user, "password") resp = 
mutable_client.documents.create( - chunks=["C1", "C2", "C3"], run_with_orchestration=False - ).results + chunks=["C1", "C2", "C3"], run_with_orchestration=False).results doc_id = cleanup_documents(resp.document_id) chunks_resp = mutable_client.documents.list_chunks(id=doc_id) results = chunks_resp.results @@ -221,8 +212,7 @@ def test_search_documents_extended(client: R2RClient, cleanup_documents): client.documents.create( raw_text="Aristotle was a Greek philosopher.", run_with_orchestration=False, - ).results.document_id - ) + ).results.document_id) time.sleep(1) # If indexing is asynchronous search_results = client.documents.search( @@ -231,8 +221,7 @@ def test_search_documents_extended(client: R2RClient, cleanup_documents): search_settings={"limit": 1}, ) assert search_results.results is not None, ( - "No results key in search response" - ) + "No results key in search response") assert len(search_results.results) > 0, "No documents found" @@ -248,8 +237,7 @@ def test_delete_document_non_existent(client): with pytest.raises(R2RException) as exc_info: client.documents.delete(id=bad_id) assert exc_info.value.status_code == 404, ( - "Wrong error code for delete non-existent" - ) + "Wrong error code for delete non-existent") # @pytest.mark.skip(reason="If your API restricts this endpoint to superusers") @@ -264,17 +252,15 @@ def test_get_document_collections_non_superuser(client): with pytest.raises(R2RException) as exc_info: non_super_client.documents.list_collections(id=document_id) assert exc_info.value.status_code == 403, ( - "Expected 403 for non-superuser collections access" - ) + "Expected 403 for non-superuser collections access") def test_access_document_not_owned(client: R2RClient, cleanup_documents): # Create a doc as superuser doc_id = cleanup_documents( client.documents.create( - raw_text="Owner doc test", run_with_orchestration=False - ).results.document_id - ) + raw_text="Owner doc test", + run_with_orchestration=False).results.document_id) # Now try to 
access with a non-superuser non_super_client = R2RClient(client.base_url) @@ -285,13 +271,11 @@ def test_access_document_not_owned(client: R2RClient, cleanup_documents): with pytest.raises(R2RException) as exc_info: non_super_client.documents.download(id=doc_id) assert exc_info.value.status_code == 403, ( - "Wrong error code for unauthorized access" - ) + "Wrong error code for unauthorized access") -def test_list_documents_with_pagination( - mutable_client: R2RClient, cleanup_documents -): +def test_list_documents_with_pagination(mutable_client: R2RClient, + cleanup_documents): temp_user = f"{uuid.uuid4()}@me.com" mutable_client.users.create(temp_user, "password") mutable_client.users.login(temp_user, "password") @@ -299,9 +283,8 @@ def test_list_documents_with_pagination( for i in range(3): cleanup_documents( mutable_client.documents.create( - raw_text=f"Doc {i}", run_with_orchestration=False - ).results.document_id - ) + raw_text=f"Doc {i}", + run_with_orchestration=False).results.document_id) listed = mutable_client.documents.list(limit=2, offset=0) results = listed.results @@ -311,9 +294,8 @@ def test_list_documents_with_pagination( def test_ingest_invalid_chunks(client): invalid_chunks = ["Valid chunk", 12345, {"not": "a string"}] with pytest.raises(R2RException) as exc_info: - client.documents.create( - chunks=invalid_chunks, run_with_orchestration=False - ) + client.documents.create(chunks=invalid_chunks, + run_with_orchestration=False) assert exc_info.value.status_code in [ 400, 422, @@ -323,29 +305,29 @@ def test_ingest_invalid_chunks(client): def test_ingest_too_many_chunks(client): excessive_chunks = ["Chunk"] * (1024 * 100 + 1) # Just over the limit with pytest.raises(R2RException) as exc_info: - client.documents.create( - chunks=excessive_chunks, run_with_orchestration=False - ) + client.documents.create(chunks=excessive_chunks, + run_with_orchestration=False) assert exc_info.value.status_code == 400, ( - "Wrong error code for exceeding max chunks" - ) 
+ "Wrong error code for exceeding max chunks") def test_delete_by_complex_filter(client: R2RClient, cleanup_documents): doc1 = cleanup_documents( client.documents.create( raw_text="Doc with tag A", - metadata={"tag": "A"}, + metadata={ + "tag": "A" + }, run_with_orchestration=False, - ).results.document_id - ) + ).results.document_id) doc2 = cleanup_documents( client.documents.create( raw_text="Doc with tag B", - metadata={"tag": "B"}, + metadata={ + "tag": "B" + }, run_with_orchestration=False, - ).results.document_id - ) + ).results.document_id) filters = {"$or": [{"tag": {"$eq": "A"}}, {"tag": {"$eq": "B"}}]} del_resp = client.documents.delete_by_filter(filters).results @@ -356,25 +338,29 @@ def test_delete_by_complex_filter(client: R2RClient, cleanup_documents): with pytest.raises(R2RException) as exc_info: client.documents.retrieve(d_id) assert exc_info.value.status_code == 404, ( - f"Document {d_id} still exists after deletion" - ) + f"Document {d_id} still exists after deletion") def test_search_documents_no_match(client: R2RClient, cleanup_documents): doc_id = cleanup_documents( client.documents.create( raw_text="Just a random document", - metadata={"category": "unrelated"}, + metadata={ + "category": "unrelated" + }, run_with_orchestration=False, - ).results.document_id - ) + ).results.document_id) # Search for non-existent category search_results = client.documents.search( query="nonexistent category", search_mode="basic", search_settings={ - "filters": {"category": {"$eq": "doesnotexist"}}, + "filters": { + "category": { + "$eq": "doesnotexist" + } + }, "limit": 10, }, ) @@ -404,9 +390,7 @@ def test_delete_by_workflow_metadata(client: R2RClient, cleanup_documents): } }, run_with_orchestration=False, - ).results.document_id - ) - ) + ).results.document_id)) docs.append( cleanup_documents( @@ -420,9 +404,7 @@ def test_delete_by_workflow_metadata(client: R2RClient, cleanup_documents): } }, run_with_orchestration=False, - ).results.document_id - ) - ) + 
).results.document_id)) docs.append( cleanup_documents( @@ -436,15 +418,21 @@ def test_delete_by_workflow_metadata(client: R2RClient, cleanup_documents): } }, run_with_orchestration=False, - ).results.document_id - ) - ) + ).results.document_id)) # Delete drafts with no reviews filters = { "$and": [ - {"metadata.workflow.state": {"$eq": "draft"}}, - {"metadata.workflow.review_count": {"$eq": 0}}, + { + "metadata.workflow.state": { + "$eq": "draft" + } + }, + { + "metadata.workflow.review_count": { + "$eq": 0 + } + }, ] } @@ -464,9 +452,8 @@ def test_delete_by_workflow_metadata(client: R2RClient, cleanup_documents): raise -def test_delete_by_classification_metadata( - client: R2RClient, cleanup_documents -): +def test_delete_by_classification_metadata(client: R2RClient, + cleanup_documents): """Test deletion by document classification metadata.""" docs = [] try: @@ -482,9 +469,7 @@ def test_delete_by_classification_metadata( } }, run_with_orchestration=False, - ).results.document_id - ) - ) + ).results.document_id)) docs.append( cleanup_documents( @@ -498,15 +483,21 @@ def test_delete_by_classification_metadata( } }, run_with_orchestration=False, - ).results.document_id - ) - ) + ).results.document_id)) # Delete HR documents with high retention filters = { "$and": [ - {"classification.department": {"$eq": "HR"}}, - {"classification.retention_years": {"$gt": 5}}, + { + "classification.department": { + "$eq": "HR" + } + }, + { + "classification.retention_years": { + "$gt": 5 + } + }, ] } @@ -542,9 +533,7 @@ def test_delete_by_version_metadata(client: R2RClient, cleanup_documents): }, }, run_with_orchestration=False, - ).results.document_id - ) - ) + ).results.document_id)) docs.append( cleanup_documents( @@ -558,15 +547,21 @@ def test_delete_by_version_metadata(client: R2RClient, cleanup_documents): }, }, run_with_orchestration=False, - ).results.document_id - ) - ) + ).results.document_id)) # Delete deprecated documents with legacy tag filters = { "$and": [ - 
{"metadata.version_info.status": {"$eq": "deprecated"}}, - {"metadata.version_info.tags": {"$in": ["legacy"]}}, + { + "metadata.version_info.status": { + "$eq": "deprecated" + } + }, + { + "metadata.version_info.tags": { + "$in": ["legacy"] + } + }, ] } diff --git a/py/tests/integration/test_filters.py b/py/tests/integration/test_filters.py index d6ccc6c2c..1f5d9c386 100644 --- a/py/tests/integration/test_filters.py +++ b/py/tests/integration/test_filters.py @@ -18,8 +18,8 @@ def setup_docs_with_collections(client: R2RClient): # Create documents with different collection arrangements: # doc1: [coll1] doc1 = client.documents.create( - raw_text="Doc in coll1" + random_suffix, run_with_orchestration=False - ).results.document_id + raw_text="Doc in coll1" + random_suffix, + run_with_orchestration=False).results.document_id client.collections.add_document(coll_ids[0], doc1) # doc2: [coll1, coll2] @@ -38,8 +38,8 @@ def setup_docs_with_collections(client: R2RClient): # doc4: [coll3] doc4 = client.documents.create( - raw_text="Doc in coll3" + random_suffix, run_with_orchestration=False - ).results.document_id + raw_text="Doc in coll3" + random_suffix, + run_with_orchestration=False).results.document_id client.collections.add_document(coll_ids[2], doc4) yield {"coll_ids": coll_ids, "doc_ids": [doc1, doc2, doc3, doc4]} @@ -57,18 +57,18 @@ def setup_docs_with_collections(client: R2RClient): pass -def test_collection_id_eq_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_eq_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids # collection_id = coll_ids[0] should match doc1 and doc2 only filters = {"collection_id": {"$eq": str(coll_ids[0])}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = 
client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert { str(doc1), @@ -76,21 +76,20 @@ def test_collection_id_eq_filter( } == found_ids, f"Expected doc1 and doc2, got {found_ids}" -def test_collection_id_ne_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_ne_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids filters = {"collection_id": {"$ne": str(coll_ids[0])}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} - assert str(coll_ids[0]) not in found_ids, ( - f"Expected no coll0, got {found_ids}" - ) + assert str( + coll_ids[0]) not in found_ids, (f"Expected no coll0, got {found_ids}") # expected_ids = {doc3, doc4} @@ -99,9 +98,8 @@ def test_collection_id_ne_filter( # ), f"Expected {expected_ids} to be included in results, but got {found_ids}" -def test_collection_id_in_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_in_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -110,9 +108,10 @@ def test_collection_id_in_filter( # doc1 in coll0, doc2 in coll0, doc4 in coll2 # doc3 is in none filters = {"collection_id": {"$in": [str(coll_ids[0]), str(coll_ids[2])]}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + 
"filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert { str(doc1), @@ -121,32 +120,30 @@ def test_collection_id_in_filter( } == found_ids, f"Expected doc1, doc2, doc4, got {found_ids}" -def test_collection_id_nin_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_nin_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids filters = {"collection_id": {"$nin": [str(coll_ids[1])]}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} # expected_ids = {doc1, doc3, doc4} found_ids = {str(d.document_id) for d in listed} - assert str(coll_ids[1]) not in found_ids, ( - f"Expected no coll1, got {found_ids}" - ) + assert str( + coll_ids[1]) not in found_ids, (f"Expected no coll1, got {found_ids}") # assert expected_ids.issubset( # found_ids # ), f"Expected {expected_ids} to be included in results, but got {found_ids}" -def test_collection_id_contains_filter( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_contains_filter(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -155,9 +152,10 @@ def test_collection_id_contains_filter( # If collection_id {"$contains": "coll_ids[0]"}, docs must have coll0 in their array # That would be doc1 and doc2 only filters = {"collection_id": {"$contains": str(coll_ids[0])}} - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = 
client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert { str(doc1), @@ -165,9 +163,8 @@ def test_collection_id_contains_filter( } == found_ids, f"Expected doc1 and doc2, got {found_ids}" -def test_collection_id_contains_multiple( - client: R2RClient, setup_docs_with_collections -): +def test_collection_id_contains_multiple(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc_ids = setup_docs_with_collections["doc_ids"] doc1, doc2, doc3, doc4 = doc_ids @@ -176,18 +173,20 @@ def test_collection_id_contains_multiple( # this should mean the doc's collection_ids contain ALL of these. # Only doc2 has coll0 AND coll1. doc1 only has coll0, doc3 no collections, doc4 only coll3. filters = { - "collection_id": {"$contains": [str(coll_ids[0]), str(coll_ids[1])]} + "collection_id": { + "$contains": [str(coll_ids[0]), str(coll_ids[1])] + } } - listed = client.retrieval.search( - query="whoami", search_settings={"filters": filters} - ).results.chunk_search_results + listed = client.retrieval.search(query="whoami", + search_settings={ + "filters": filters + }).results.chunk_search_results found_ids = {str(d.document_id) for d in listed} assert {str(doc2)} == found_ids, f"Expected doc2 only, got {found_ids}" -def test_delete_by_collection_id_eq( - client: R2RClient, setup_docs_with_collections -): +def test_delete_by_collection_id_eq(client: R2RClient, + setup_docs_with_collections): coll_ids = setup_docs_with_collections["coll_ids"] doc1, doc2, doc3, doc4 = setup_docs_with_collections["doc_ids"] diff --git a/py/tests/integration/test_graphs.py b/py/tests/integration/test_graphs.py index 744dd3b86..41b93d0be 100644 --- a/py/tests/integration/test_graphs.py +++ b/py/tests/integration/test_graphs.py @@ -7,6 +7,7 @@ from r2r import R2RClient, R2RException @pytest.fixture(scope="session") def config(): + class 
TestConfig: base_url = "http://localhost:7272" superuser_email = "admin@example.com" @@ -25,7 +26,8 @@ def client(config): @pytest.fixture def test_collection(client): - """Create a test collection (and thus a graph) for testing, then delete it afterwards.""" + """Create a test collection (and thus a graph) for testing, then delete it + afterwards.""" collection_id = client.collections.create( name=f"Test Collection {uuid.uuid4()}", description="A sample collection for graph tests", @@ -54,21 +56,19 @@ def test_update_graph(client: R2RClient, test_collection): new_name = "Updated Test Graph Name" new_description = "Updated test description" - resp = client.graphs.update( - collection_id=collection_id, name=new_name, description=new_description - ).results + resp = client.graphs.update(collection_id=collection_id, + name=new_name, + description=new_description).results assert resp.name == new_name, "Name not updated correctly" assert resp.description == new_description, ( - "Description not updated correctly" - ) + "Description not updated correctly") def test_list_entities(client: R2RClient, test_collection): collection_id = test_collection - resp = client.graphs.list_entities( - collection_id=collection_id, limit=5 - ).results + resp = client.graphs.list_entities(collection_id=collection_id, + limit=5).results assert isinstance(resp, list), "No results array in entities response" @@ -84,17 +84,15 @@ def test_create_and_get_entity(client: R2RClient, test_collection): ).results entity_id = str(create_resp.id) - resp = client.graphs.get_entity( - collection_id=collection_id, entity_id=entity_id - ).results + resp = client.graphs.get_entity(collection_id=collection_id, + entity_id=entity_id).results assert resp.name == entity_name, "Entity name mismatch" def test_list_relationships(client: R2RClient, test_collection): collection_id = test_collection - resp = client.graphs.list_relationships( - collection_id=collection_id, limit=5 - ).results + resp = 
client.graphs.list_relationships(collection_id=collection_id, + limit=5).results assert isinstance(resp, list), "No results array in relationships response" @@ -127,59 +125,56 @@ def test_create_and_get_relationship(client: R2RClient, test_collection): # Get relationship resp = client.graphs.get_relationship( - collection_id=collection_id, relationship_id=relationship_id - ).results + collection_id=collection_id, relationship_id=relationship_id).results assert resp.predicate == "related_to", "Relationship predicate mismatch" -def test_build_communities(client: R2RClient, test_collection): - collection_id = test_collection +# def test_build_communities(client: R2RClient, test_collection): +# collection_id = test_collection - # Create two entities - entity1 = client.graphs.create_entity( - collection_id=collection_id, - name="Entity 1", - description="Entity 1 description", - ).results - entity2 = client.graphs.create_entity( - collection_id=collection_id, - name="Entity 2", - description="Entity 2 description", - ).results +# # Create two entities +# entity1 = client.graphs.create_entity( +# collection_id=collection_id, +# name="Entity 1", +# description="Entity 1 description", +# ).results +# entity2 = client.graphs.create_entity( +# collection_id=collection_id, +# name="Entity 2", +# description="Entity 2 description", +# ).results - # Create relationship - rel_resp = client.graphs.create_relationship( - collection_id=str(collection_id), - subject="Entity 1", - subject_id=entity1.id, - predicate="related_to", - object="Entity 2", - object_id=entity2.id, - description="Test relationship", - ).results - relationship_id = str(rel_resp.id) +# # Create relationship +# rel_resp = client.graphs.create_relationship( +# collection_id=str(collection_id), +# subject="Entity 1", +# subject_id=entity1.id, +# predicate="related_to", +# object="Entity 2", +# object_id=entity2.id, +# description="Test relationship", +# ).results +# relationship_id = str(rel_resp.id) - # Build 
communities - resp = client.graphs.build( - collection_id=str(collection_id), - # graph_enrichment_settings={"use_semantic_clustering": True}, - run_with_orchestration=False, - ).results +# # Build communities +# resp = client.graphs.build( +# collection_id=str(collection_id), +# # graph_enrichment_settings={"use_semantic_clustering": True}, +# run_with_orchestration=False, +# ).results - # After building, list communities - resp = client.graphs.list_communities( - collection_id=str(collection_id), limit=5 - ).results - # We cannot guarantee communities are created if no entities or special conditions apply. - # If no communities, we may skip this assert or ensure at least no error occurred. - assert isinstance(resp, list), "No communities array returned." +# # After building, list communities +# resp = client.graphs.list_communities(collection_id=str(collection_id), +# limit=5).results +# # We cannot guarantee communities are created if no entities or special conditions apply. +# # If no communities, we may skip this assert or ensure at least no error occurred. +# assert isinstance(resp, list), "No communities array returned." 
def test_list_communities(client: R2RClient, test_collection): collection_id = test_collection - resp = client.graphs.list_communities( - collection_id=collection_id, limit=5 - ).results + resp = client.graphs.list_communities(collection_id=collection_id, + limit=5).results assert isinstance(resp, list), "No results array in communities response" @@ -197,9 +192,8 @@ def test_create_and_get_community(client: R2RClient, test_collection): ).results community_id = str(create_resp.id) - resp = client.graphs.get_community( - collection_id=collection_id, community_id=community_id - ).results + resp = client.graphs.get_community(collection_id=collection_id, + community_id=community_id).results assert resp.name == community_name, "Community name mismatch" diff --git a/py/tests/integration/test_indices.py b/py/tests/integration/test_indices.py index 98354caab..0af028054 100644 --- a/py/tests/integration/test_indices.py +++ b/py/tests/integration/test_indices.py @@ -5,6 +5,7 @@ from r2r import R2RClient, R2RException @pytest.fixture(scope="session") def config(): + class TestConfig: base_url = "http://localhost:7272" superuser_email = "admin@example.com" @@ -90,9 +91,7 @@ def test_list_indices(client: R2RClient): def test_error_handling(client: R2RClient): # Try to get a non-existent index with pytest.raises(R2RException) as exc_info: - client.indices.retrieve( - index_name="nonexistent_index", table_name="chunks" - ) + client.indices.retrieve(index_name="nonexistent_index", + table_name="chunks") assert "not found" in str(exc_info.value).lower(), ( - "Unexpected error message for non-existent index" - ) + "Unexpected error message for non-existent index") diff --git a/py/tests/integration/test_ingestion.py b/py/tests/integration/test_ingestion.py index ddebb6e77..509a6a123 100644 --- a/py/tests/integration/test_ingestion.py +++ b/py/tests/integration/test_ingestion.py @@ -1,5 +1,5 @@ -""" -Tests document ingestion functionality in R2R across all supported file types and 
modes. +"""Tests document ingestion functionality in R2R across all supported file +types and modes. Supported file types include: - Documents: .doc, .docx, .odt, .pdf, .rtf, .txt @@ -42,8 +42,7 @@ def file_ingestion( raw_text: Optional[str] = None, timeout: int = 600, ) -> UUID: - """ - Test ingestion of a file with the given parameters. + """Test ingestion of a file with the given parameters. Args: client: R2RClient instance @@ -100,8 +99,7 @@ def file_ingestion( break elif ingestion_status == "failed": raise AssertionError( - f"Document ingestion failed: {retrieval_response}" - ) + f"Document ingestion failed: {retrieval_response}") except R2RException as e: if e.status_code == 404: @@ -131,6 +129,7 @@ def file_ingestion( @pytest.fixture(scope="session") def config(): + class TestConfig: base_url = "http://localhost:7272" superuser_email = "admin@example.com" @@ -162,7 +161,7 @@ def client(config): ("jpeg", "core/examples/supported_file_types/jpeg.jpeg"), ("jpg", "core/examples/supported_file_types/jpg.jpg"), ("md", "core/examples/supported_file_types/md.md"), - ("msg", "core/examples/supported_file_types/msg.msg"), + # ("msg", "core/examples/supported_file_types/msg.msg"), ("odt", "core/examples/supported_file_types/odt.odt"), ("org", "core/examples/supported_file_types/org.org"), ("p7s", "core/examples/supported_file_types/p7s.p7s"), @@ -179,9 +178,8 @@ def client(config): ("xlsx", "core/examples/supported_file_types/xlsx.xlsx"), ], ) -def test_file_type_ingestion( - client: R2RClient, file_type: str, file_path: str -): +def test_file_type_ingestion(client: R2RClient, file_type: str, + file_path: str): """Test ingestion of specific file type.""" try: @@ -207,7 +205,8 @@ def test_file_type_ingestion( ], ) def test_hires_ingestion(client: R2RClient, file_type: str, file_path: str): - """Test hi-res ingestion with complex documents containing mixed content.""" + """Test hi-res ingestion with complex documents containing mixed + content.""" if file_type == "pdf": 
try: result = file_ingestion( @@ -221,8 +220,7 @@ def test_hires_ingestion(client: R2RClient, file_type: str, file_path: str): except Exception as e: # Changed from R2RException to Exception if "PDF processing requires Poppler to be installed" in str(e): pytest.skip( - "Skipping PDF test due to missing Poppler dependency" - ) + "Skipping PDF test due to missing Poppler dependency") raise else: result = file_ingestion( @@ -266,9 +264,8 @@ def test_raw_text_ingestion(client: R2RClient): """Test ingestion of raw text content.""" text_content = "This is a test document.\nIt has multiple lines.\nTesting raw text ingestion." - response = client.documents.create( - raw_text=text_content, ingestion_mode="fast" - ) + response = client.documents.create(raw_text=text_content, + ingestion_mode="fast") assert response is not None assert response.results is not None diff --git a/py/tests/integration/test_retrieval.py b/py/tests/integration/test_retrieval.py index 96694d361..87c069f7f 100644 --- a/py/tests/integration/test_retrieval.py +++ b/py/tests/integration/test_retrieval.py @@ -8,6 +8,7 @@ from r2r import R2RClient, R2RException @pytest.fixture(scope="session") def config(): + class TestConfig: base_url = "http://localhost:7272" superuser_email = "admin@example.com" @@ -25,9 +26,8 @@ def client(config): def test_search_basic_mode(client: R2RClient): - results = client.retrieval.search( - query="Aristotle", search_mode="basic" - ).results + results = client.retrieval.search(query="Aristotle", + search_mode="basic").results assert results is not None, "No results field in search response" @@ -36,7 +36,10 @@ def test_search_advanced_mode_with_filters(client: R2RClient): results = client.retrieval.search( query="Philosophy", search_mode="advanced", - search_settings={"filters": filters, "limit": 5}, + search_settings={ + "filters": filters, + "limit": 5 + }, ).results assert results is not None, "No results in advanced mode search" @@ -45,7 +48,10 @@ def 
test_search_custom_mode(client: R2RClient): results = client.retrieval.search( query="Greek philosophers", search_mode="custom", - search_settings={"use_semantic_search": True, "limit": 3}, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, ).results assert results is not None, "No results in custom mode search" @@ -53,8 +59,14 @@ def test_search_custom_mode(client: R2RClient): def test_rag_query(client: R2RClient): results = client.retrieval.rag( query="Summarize Aristotle's contributions to logic", - rag_generation_config={"stream": False, "max_tokens": 100}, - search_settings={"use_semantic_search": True, "limit": 3}, + rag_generation_config={ + "stream": False, + "max_tokens": 100 + }, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, ).results assert results.completion is not None, "RAG response missing 'completion'" @@ -64,14 +76,22 @@ def test_rag_with_filter(client: R2RClient): # generate a random string suffix = str(uuid.uuid4()) client.documents.create( - raw_text=f"Aristotle was a Greek philosopher, contributions to philosophy were in logic, {suffix}.", + raw_text= + f"Aristotle was a Greek philosopher, contributions to philosophy were in logic, {suffix}.", metadata={"tier": "test"}, ) results = client.retrieval.rag( query="What were aristotle's contributions to philosophy?", - rag_generation_config={"stream": False, "max_tokens": 100}, + rag_generation_config={ + "stream": False, + "max_tokens": 100 + }, search_settings={ - "filters": {"metadata.tier": {"$eq": "test"}}, + "filters": { + "metadata.tier": { + "$eq": "test" + } + }, "use_semantic_search": True, "limit": 3, }, @@ -82,8 +102,14 @@ def test_rag_with_filter(client: R2RClient): def test_rag_stream_query(client: R2RClient): resp = client.retrieval.rag( query="Detail the philosophical schools Aristotle influenced", - rag_generation_config={"stream": True, "max_tokens": 50}, - search_settings={"use_semantic_search": True, "limit": 2}, + rag_generation_config={ 
+ "stream": True, + "max_tokens": 50 + }, + search_settings={ + "use_semantic_search": True, + "limit": 2 + }, ) # Consume a few chunks from the async generator @@ -105,8 +131,14 @@ def test_agent_query(client: R2RClient): msg = Message(role="user", content="What is Aristotle known for?") results = client.retrieval.agent( message=msg, - rag_generation_config={"stream": False, "max_tokens": 100}, - search_settings={"use_semantic_search": True, "limit": 3}, + rag_generation_config={ + "stream": False, + "max_tokens": 100 + }, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, ).results assert results is not None, "Agent response missing 'results'" assert len(results.messages) > 0, "No messages returned by agent" @@ -116,8 +148,14 @@ def test_agent_query_stream(client: R2RClient): msg = Message(role="user", content="Explain Aristotle's logic in steps.") resp = client.retrieval.agent( message=msg, - rag_generation_config={"stream": True, "max_tokens": 50}, - search_settings={"use_semantic_search": True, "limit": 3}, + rag_generation_config={ + "stream": True, + "max_tokens": 50 + }, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, ) def consume_stream(): @@ -134,22 +172,37 @@ def test_agent_query_stream(client: R2RClient): def test_completion(client: R2RClient): messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - {"role": "assistant", "content": "The capital of France is Paris."}, - {"role": "user", "content": "What about Italy?"}, + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is the capital of France?" + }, + { + "role": "assistant", + "content": "The capital of France is Paris." + }, + { + "role": "user", + "content": "What about Italy?" 
+ }, ] resp = client.retrieval.completion( messages, - generation_config={"max_tokens": 50, "model": "openai/gpt-4o"}, + generation_config={ + "max_tokens": 50, + "model": "openai/gpt-4o" + }, ) - assert "results" in resp, "Completion response missing 'results'" - assert "choices" in resp["results"], "No choices in completion result" + assert resp.results is not None, "Completion response missing 'results'" + assert resp.results.choices is not None, "No choices in completion result" def test_embedding(client: R2RClient): text = "Who is Aristotle?" - resp = client.retrieval.embedding(text=text)["results"] + resp = client.retrieval.embedding(text=text).results assert len(resp) > 0, "No embedding vector returned" @@ -178,37 +231,41 @@ def test_no_results_scenario(client: R2RClient): def test_pagination_limit_one(client: R2RClient): - client.documents.create( - chunks=[ - "a" + " " + str(uuid.uuid4()), - "b" + " " + str(uuid.uuid4()), - "c" + " " + str(uuid.uuid4()), - ] - ) - results = client.retrieval.search( - query="Aristotle", search_mode="basic", search_settings={"limit": 1} - ).results + client.documents.create(chunks=[ + "a" + " " + str(uuid.uuid4()), + "b" + " " + str(uuid.uuid4()), + "c" + " " + str(uuid.uuid4()), + ]) + results = client.retrieval.search(query="Aristotle", + search_mode="basic", + search_settings={ + "limit": 1 + }).results assert len(results.chunk_search_results) == 1, ( - "Expected one result with limit=1" - ) + "Expected one result with limit=1") def test_pagination_offset(client: R2RClient): resp0 = client.retrieval.search( query="Aristotle", search_mode="basic", - search_settings={"limit": 1, "offset": 0}, + search_settings={ + "limit": 1, + "offset": 0 + }, ).results resp1 = client.retrieval.search( query="Aristotle", search_mode="basic", - search_settings={"limit": 1, "offset": 1}, + search_settings={ + "limit": 1, + "offset": 1 + }, ).results - assert ( - resp0.chunk_search_results[0].text - != resp1.chunk_search_results[0].text - ), 
"Offset should return different results" + assert (resp0.chunk_search_results[0].text + != resp1.chunk_search_results[0].text + ), "Offset should return different results" def test_rag_task_prompt_override(client: R2RClient): @@ -223,14 +280,19 @@ def test_rag_task_prompt_override(client: R2RClient): """ results = client.retrieval.rag( query="Tell me about Aristotle", - rag_generation_config={"stream": False, "max_tokens": 50}, - search_settings={"use_semantic_search": True, "limit": 3}, + rag_generation_config={ + "stream": False, + "max_tokens": 50 + }, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, task_prompt_override=custom_prompt, ).results answer = results.completion assert "[END-TEST-PROMPT]" in answer, ( - "Custom prompt override not reflected in RAG answer" - ) + "Custom prompt override not reflected in RAG answer") def test_agent_conversation_id(client: R2RClient): @@ -238,24 +300,34 @@ def test_agent_conversation_id(client: R2RClient): msg = Message(role="user", content="What is Aristotle known for?") results = client.retrieval.agent( message=msg, - rag_generation_config={"stream": False, "max_tokens": 50}, - search_settings={"use_semantic_search": True, "limit": 3}, + rag_generation_config={ + "stream": False, + "max_tokens": 50 + }, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, conversation_id=str(conversation_id), ).results - assert len(results.messages) > 0, ( - "No results from agent with conversation_id" - ) + assert len( + results.messages) > 0, ("No results from agent with conversation_id") msg2 = Message(role="user", content="Can you elaborate more?") results2 = client.retrieval.agent( message=msg2, - rag_generation_config={"stream": False, "max_tokens": 50}, - search_settings={"use_semantic_search": True, "limit": 3}, + rag_generation_config={ + "stream": False, + "max_tokens": 50 + }, + search_settings={ + "use_semantic_search": True, + "limit": 3 + }, conversation_id=str(conversation_id), 
).results assert len(results2.messages) > 0, ( - "No results from agent in second turn of conversation" - ) + "No results from agent in second turn of conversation") def test_complex_filters_and_fulltext(client: R2RClient, test_collection): @@ -265,8 +337,12 @@ def test_complex_filters_and_fulltext(client: R2RClient, test_collection): # rating > 5 # include owner id and collection ids to make robust against other database interactions from other users filters = { - "rating": {"$gt": 5}, - "owner_id": {"$eq": str(user_id)}, + "rating": { + "$gt": 5 + }, + "owner_id": { + "$eq": str(user_id) + }, "collection_ids": { "$overlap": [str(test_collection["collection_id"])] }, @@ -274,17 +350,23 @@ def test_complex_filters_and_fulltext(client: R2RClient, test_collection): results = client.retrieval.search( query="a", search_mode=SearchMode.custom, - search_settings={"use_semantic_search": True, "filters": filters}, + search_settings={ + "use_semantic_search": True, + "filters": filters + }, ).results results = results.chunk_search_results assert len(results) == 2, ( - f"Expected 2 docs with rating > 5, got {len(results)}" - ) + f"Expected 2 docs with rating > 5, got {len(results)}") # category in [ancient, modern] filters = { - "metadata.category": {"$in": ["ancient", "modern"]}, - "owner_id": {"$eq": str(user_id)}, + "metadata.category": { + "$in": ["ancient", "modern"] + }, + "owner_id": { + "$eq": str(user_id) + }, "collection_ids": { "$overlap": [str(test_collection["collection_id"])] }, @@ -293,19 +375,33 @@ def test_complex_filters_and_fulltext(client: R2RClient, test_collection): results = client.retrieval.search( query="b", search_mode=SearchMode.custom, - search_settings={"use_semantic_search": True, "filters": filters}, + search_settings={ + "use_semantic_search": True, + "filters": filters + }, ).results chunk_search_results = results.chunk_search_results assert len(chunk_search_results) == 4, ( - f"Expected all 4 docs, got {len(chunk_search_results)}" - ) + 
f"Expected all 4 docs, got {len(chunk_search_results)}") # rating > 5 AND category=modern filters = { "$and": [ - {"metadata.rating": {"$gt": 5}}, - {"metadata.category": {"$eq": "modern"}}, - {"owner_id": {"$eq": str(user_id)}}, + { + "metadata.rating": { + "$gt": 5 + } + }, + { + "metadata.category": { + "$eq": "modern" + } + }, + { + "owner_id": { + "$eq": str(user_id) + } + }, { "collection_ids": { "$overlap": [str(test_collection["collection_id"])] @@ -316,7 +412,9 @@ def test_complex_filters_and_fulltext(client: R2RClient, test_collection): results = client.retrieval.search( query="d", search_mode=SearchMode.custom, - search_settings={"filters": filters}, + search_settings={ + "filters": filters + }, ).results chunk_search_results = results.chunk_search_results assert len(chunk_search_results) == 2, ( @@ -330,7 +428,9 @@ def test_complex_filters_and_fulltext(client: R2RClient, test_collection): "use_fulltext_search": True, "use_semantic_search": False, "filters": { - "owner_id": {"$eq": str(user_id)}, + "owner_id": { + "$eq": str(user_id) + }, "collection_ids": { "$overlap": [str(test_collection["collection_id"])] }, @@ -348,19 +448,34 @@ def test_complex_nested_filters(client: R2RClient, test_collection): # _setup_collection_with_documents(client) # ((category=ancient OR rating<5) AND tags contains 'philosophy') - print( - 'test_collection["collection_id"] = ', test_collection["collection_id"] - ) + print('test_collection["collection_id"] = ', + test_collection["collection_id"]) filters = { "$and": [ { "$or": [ - {"metadata.category": {"$eq": "ancient"}}, - {"metadata.rating": {"$lt": 5}}, + { + "metadata.category": { + "$eq": "ancient" + } + }, + { + "metadata.rating": { + "$lt": 5 + } + }, ] }, - {"metadata.tags": {"$contains": ["philosophy"]}}, - {"owner_id": {"$eq": str(client.users.me().results.id)}}, + { + "metadata.tags": { + "$contains": ["philosophy"] + } + }, + { + "owner_id": { + "$eq": str(client.users.me().results.id) + } + }, { 
"collection_ids": { "$overlap": [str(test_collection["collection_id"])] @@ -371,14 +486,15 @@ def test_complex_nested_filters(client: R2RClient, test_collection): results = client.retrieval.search( query="complex", - search_settings={"filters": filters}, + search_settings={ + "filters": filters + }, ).results chunk_search_results = results.chunk_search_results print("results -> ", chunk_search_results) assert len(chunk_search_results) == 2, ( - f"Expected 2 docs, got {len(chunk_search_results)}" - ) + f"Expected 2 docs, got {len(chunk_search_results)}") def test_filters_no_match(client: R2RClient): @@ -386,12 +502,13 @@ def test_filters_no_match(client: R2RClient): results = client.retrieval.search( query="noresults", search_mode="custom", - search_settings={"filters": filters}, + search_settings={ + "filters": filters + }, ).results chunk_search_results = results.chunk_search_results assert len(chunk_search_results) == 0, ( - f"Expected 0 docs, got {len(chunk_search_results)}" - ) + f"Expected 0 docs, got {len(chunk_search_results)}") def test_pagination_extremes(client: R2RClient): @@ -401,7 +518,10 @@ def test_pagination_extremes(client: R2RClient): results = client.retrieval.search( query="Aristotle", search_mode="basic", - search_settings={"limit": 10, "offset": offset}, + search_settings={ + "limit": 10, + "offset": offset + }, ).results chunk_search_results = results.chunk_search_results assert len(chunk_search_results) == 0, ( @@ -420,8 +540,7 @@ def test_full_text_stopwords(client: R2RClient): }, ) assert resp.results is not None, ( - "No results field in stopword query response" - ) + "No results field in stopword query response") def test_full_text_non_ascii(client: R2RClient): @@ -435,8 +554,7 @@ def test_full_text_non_ascii(client: R2RClient): }, ) assert resp.results is not None, ( - "No results field in non-ASCII query response" - ) + "No results field in non-ASCII query response") def test_missing_fields(client: R2RClient): @@ -444,7 +562,9 @@ def 
test_missing_fields(client: R2RClient): results = client.retrieval.search( query="missingfield", search_mode="custom", - search_settings={"filters": filters}, + search_settings={ + "filters": filters + }, ).results chunk_search_results = results.chunk_search_results assert len(chunk_search_results) == 0, ( @@ -455,12 +575,17 @@ def test_missing_fields(client: R2RClient): def test_rag_with_large_context(client: R2RClient): results = client.retrieval.rag( query="Explain the contributions of Kant in detail", - rag_generation_config={"stream": False, "max_tokens": 200}, - search_settings={"use_semantic_search": True, "limit": 10}, + rag_generation_config={ + "stream": False, + "max_tokens": 200 + }, + search_settings={ + "use_semantic_search": True, + "limit": 10 + }, ).results assert results.completion is not None, ( - "RAG large context missing 'completion'" - ) + "RAG large context missing 'completion'") completion = results.completion assert len(completion) > 0, "RAG large context returned empty answer" @@ -471,53 +596,65 @@ def test_agent_long_conversation(client: R2RClient): msg1 = Message(role="user", content="What were Aristotle's main ideas?") resp1 = client.retrieval.agent( message=msg1, - rag_generation_config={"stream": False, "max_tokens": 100}, - search_settings={"use_semantic_search": True, "limit": 5}, + rag_generation_config={ + "stream": False, + "max_tokens": 100 + }, + search_settings={ + "use_semantic_search": True, + "limit": 5 + }, conversation_id=str(conversation_id), ) assert resp1.results is not None, ( - "No results in first turn of conversation" - ) + "No results in first turn of conversation") - msg2 = Message( - role="user", content="How did these ideas influence modern philosophy?" 
- ) + msg2 = Message(role="user", + content="How did these ideas influence modern philosophy?") resp2 = client.retrieval.agent( message=msg2, - rag_generation_config={"stream": False, "max_tokens": 100}, - search_settings={"use_semantic_search": True, "limit": 5}, + rag_generation_config={ + "stream": False, + "max_tokens": 100 + }, + search_settings={ + "use_semantic_search": True, + "limit": 5 + }, conversation_id=str(conversation_id), ) assert resp2.results is not None, ( - "No results in second turn of conversation" - ) + "No results in second turn of conversation") msg3 = Message(role="user", content="Now tell me about Descartes.") resp3 = client.retrieval.agent( message=msg3, - rag_generation_config={"stream": False, "max_tokens": 100}, - search_settings={"use_semantic_search": True, "limit": 5}, + rag_generation_config={ + "stream": False, + "max_tokens": 100 + }, + search_settings={ + "use_semantic_search": True, + "limit": 5 + }, conversation_id=str(conversation_id), ) assert resp3.results is not None, ( - "No results in third turn of conversation" - ) + "No results in third turn of conversation") def test_filter_by_document_type(client: R2RClient): random_suffix = str(uuid.uuid4()) - client.documents.create( - chunks=[ - f"a {random_suffix}", - f"b {random_suffix}", - f"c {random_suffix}", - ] - ) + client.documents.create(chunks=[ + f"a {random_suffix}", + f"b {random_suffix}", + f"c {random_suffix}", + ]) filters = {"document_type": {"$eq": "txt"}} - results = client.retrieval.search( - query="a", search_settings={"filters": filters} - ).results + results = client.retrieval.search(query="a", + search_settings={ + "filters": filters + }).results chunk_search_results = results.chunk_search_results assert len(chunk_search_results) > 0, ( - "No results found for filter by document type" - ) + "No results found for filter by document type") diff --git a/py/tests/integration/test_retrieval_advanced.py b/py/tests/integration/test_retrieval_advanced.py index 
e01c53cb3..88719eccc 100644 --- a/py/tests/integration/test_retrieval_advanced.py +++ b/py/tests/integration/test_retrieval_advanced.py @@ -5,38 +5,42 @@ from r2r import R2RClient # Semantic Search Tests def test_semantic_search_with_near_duplicates(client: R2RClient): - """Test semantic search can handle and differentiate near-duplicate content""" + """Test semantic search can handle and differentiate near-duplicate + content.""" random_1 = str(uuid.uuid4()) random_2 = str(uuid.uuid4()) # Create two similar but distinct documents doc1 = client.documents.create( - raw_text=f"Aristotle was a Greek philosopher who studied logic {random_1}." + raw_text= + f"Aristotle was a Greek philosopher who studied logic {random_1}." ).results.document_id doc2 = client.documents.create( - raw_text=f"Aristotle, the Greek philosopher, studied formal logic {random_2}." + raw_text= + f"Aristotle, the Greek philosopher, studied formal logic {random_2}." ).results.document_id resp = client.retrieval.search( query="Tell me about Aristotle's work in logic", search_mode="custom", - search_settings={"use_semantic_search": True, "limit": 25}, + search_settings={ + "use_semantic_search": True, + "limit": 25 + }, ) results = resp.results.chunk_search_results # Both documents should be returned but with different scores scores = [ - r.score - for r in results + r.score for r in results if str(r.document_id) in [str(doc1), str(doc2)] ] assert len(scores) == 2, "Expected both similar documents" - assert len(set(scores)) == 2, ( - "Expected different scores for similar documents" - ) + assert len( + set(scores)) == 2, ("Expected different scores for similar documents") def test_semantic_search_multilingual(client: R2RClient): - """Test semantic search handles multilingual content""" + """Test semantic search handles multilingual content.""" # Create documents in different languages random_1 = str(uuid.uuid4()) random_2 = str(uuid.uuid4()) @@ -49,9 +53,10 @@ def 
test_semantic_search_multilingual(client: R2RClient): ] doc_ids = [] for text, lang in docs: - doc_id = client.documents.create( - raw_text=text, metadata={"language": lang} - ).results.document_id + doc_id = client.documents.create(raw_text=text, + metadata={ + "language": lang + }).results.document_id doc_ids.append(doc_id) # Query in different languages @@ -115,24 +120,25 @@ def test_semantic_search_multilingual(client: R2RClient): # RAG Tests def test_rag_context_window_limits(client: R2RClient): - """Test RAG handles documents at or near context window limits""" + """Test RAG handles documents at or near context window limits.""" # Create a document that approaches the context window limit random_1 = str(uuid.uuid4()) - large_text = ( - "Aristotle " * 1000 - ) # Adjust multiplier based on your context window + large_text = ("Aristotle " * 1000 + ) # Adjust multiplier based on your context window doc_id = client.documents.create( - raw_text=f"{large_text} {random_1}" - ).results.document_id + raw_text=f"{large_text} {random_1}").results.document_id resp = client.retrieval.rag( query="Summarize this text about Aristotle", - search_settings={"filters": {"document_id": {"$eq": str(doc_id)}}}, + search_settings={"filters": { + "document_id": { + "$eq": str(doc_id) + } + }}, rag_generation_config={"max_tokens": 100}, ) assert resp.results is not None, ( - "RAG should handle large context gracefully" - ) + "RAG should handle large context gracefully") # UNCOMMENT LATER @@ -148,7 +154,6 @@ def test_rag_context_window_limits(client: R2RClient): # ) # assert "results" in resp, "RAG should handle empty chunks gracefully" - # # Agent Tests # def test_agent_clarification_requests(client: R2RClient): # """Test agent's ability to request clarification for ambiguous queries""" @@ -168,7 +173,6 @@ def test_rag_context_window_limits(client: R2RClient): # ] # ), "Agent should request clarification for ambiguous queries" - ## TODO - uncomment later # def 
test_agent_source_citation_consistency(client: R2RClient): # """Test agent consistently cites sources across conversation turns""" @@ -200,7 +204,6 @@ def test_rag_context_window_limits(client: R2RClient): # s in sources2 for s in sources1 # ), "Follow-up should reference some original sources" - ## TODO - uncomment later # # Error Handling Tests # def test_malformed_filter_handling(client: R2RClient): @@ -223,7 +226,6 @@ def test_rag_context_window_limits(client: R2RClient): # 422, # ], f"Expected validation error for filter: {invalid_filter}" - ## TODO - Uncomment later # def test_concurrent_search_stability(client: R2RClient): # """Test system handles concurrent search requests properly""" @@ -250,7 +252,7 @@ def test_rag_context_window_limits(client: R2RClient): # Helper function for source extraction def _extract_sources(content: str) -> list[str]: - """Extract source citations from response content""" + """Extract source citations from response content.""" # This is a simplified version - implement based on your citation format import re diff --git a/py/tests/integration/test_system.py b/py/tests/integration/test_system.py index 200885d50..a1461b673 100644 --- a/py/tests/integration/test_system.py +++ b/py/tests/integration/test_system.py @@ -5,7 +5,6 @@ # from datetime import datetime # from r2r import R2RClient, R2RException, LimitSettings - # async def test_health_endpoint(aclient): # """Test health endpoint is accessible and not rate limited""" # # Health endpoint doesn't require authentication @@ -13,7 +12,6 @@ # response = await aclient.system.health() # assert response["results"]["message"] == "ok" - # async def test_system_status(aclient, config): # """Test system status endpoint returns correct data""" # # Login as superuser for system status @@ -28,7 +26,6 @@ # datetime.fromisoformat(stats["start_time"]) - # async def test_per_minute_route_limit(aclient, test_collection): # """Test route-specific per-minute limit for search endpoint""" # # Create 
and login as new user @@ -54,7 +51,6 @@ # assert "rate limit" in str(exc_info.value).lower() # await aclient.users.logout() - # async def test_global_per_minute_limit(aclient, test_collection): # """Test global per-minute limit""" # # Create and login as new user @@ -78,7 +74,6 @@ # assert "rate limit" in str(exc_info.value).lower() # await aclient.users.logout() - # async def test_global_per_minute_limit_split(aclient, test_collection): # """Test global per-minute limit""" # # Create and login as new user @@ -137,7 +132,6 @@ # # assert "monthly" in str(exc_info.value).lower() # # client.users.logout() - # async def test_non_superuser_system_access(aclient): # """Test system endpoint access control""" # # Create and login as regular user @@ -160,7 +154,6 @@ # await endpoint() # # assert exc_info.value.status_code == 403 - # async def test_limit_reset(aclient, test_collection): # """Test that per-minute limits reset after one minute""" # # Create and login as new user @@ -188,7 +181,6 @@ # ) # assert "results" in response - # ## THIS FAILS, BUT WE ARE OK WITH THIS EDGE CASE # # async def test_concurrent_requests(aclient, test_collection): # # """Test concurrent requests properly handle rate limits""" @@ -208,7 +200,6 @@ # # success_count = sum(1 for r in results if isinstance(r, dict)) # # assert success_count <= 5 # route_per_min limit - # async def test_user_specific_limits(aclient, config): # """Test user-specific limit overrides""" # # Create and login as new user @@ -237,7 +228,6 @@ # assert i >= 1 # Should fail after first request # break - # async def test_global_monthly_limit(aclient, test_collection): # """Test global monthly limit across all routes""" # test_user = f"test_user_{uuid.uuid4()}@example.com" @@ -265,7 +255,6 @@ # await aclient.users.me() # assert "monthly" in str(exc_info.value).lower() - # async def test_mixed_limits(aclient, test_collection): # """Test interaction between different types of limits""" # test_user = 
f"test_user_{uuid.uuid4()}@example.com" @@ -283,7 +272,6 @@ # await aclient.users.me() # assert "rate limit" in str(exc_info.value).lower() - # async def test_route_limit_inheritance(aclient, test_collection): # """Test that routes without specific limits inherit global limits""" # test_user = f"test_user_{uuid.uuid4()}@example.com" diff --git a/py/tests/integration/test_users.py b/py/tests/integration/test_users.py index 179f8b007..3e9ec7113 100644 --- a/py/tests/integration/test_users.py +++ b/py/tests/integration/test_users.py @@ -7,6 +7,7 @@ from r2r import R2RClient, R2RException @pytest.fixture(scope="session") def config(): + class TestConfig: base_url = "http://localhost:7272" superuser_email = "admin@example.com" @@ -30,9 +31,8 @@ def superuser_login(client: R2RClient, config): # client.users.logout() -def register_and_return_user_id( - client: R2RClient, email: str, password: str -) -> str: +def register_and_return_user_id(client: R2RClient, email: str, + password: str) -> str: return client.users.create(email, password).results.id @@ -53,8 +53,7 @@ def test_user_refresh_token(client: R2RClient): new_access_token = client.users.refresh_token().results.access_token.token assert new_access_token != old_access_token, ( - "Refresh token did not provide a new access token." - ) + "Refresh token did not provide a new access token.") def test_change_password(client: R2RClient): @@ -63,9 +62,8 @@ def test_change_password(client: R2RClient): new_password = "new_password456" register_and_return_user_id(client, random_email, old_password) client.users.login(random_email, old_password) - change_resp = client.users.change_password( - old_password, new_password - ).results + change_resp = client.users.change_password(old_password, + new_password).results assert change_resp.message is not None, "Change password failed." 
# Check old password no longer works @@ -73,8 +71,7 @@ def test_change_password(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.users.login(random_email, old_password) assert exc_info.value.status_code == 401, ( - "Old password should not work anymore." - ) + "Old password should not work anymore.") # New password should work client.users.login(random_email, new_password) @@ -82,7 +79,8 @@ def test_change_password(client: R2RClient): @pytest.mark.skip( - reason="Requires a real or mocked reset token retrieval if verification is implemented." + reason= + "Requires a real or mocked reset token retrieval if verification is implemented." ) def test_request_and_reset_password(client: R2RClient): # This test scenario assumes you can obtain a valid reset token somehow. @@ -155,36 +153,32 @@ def test_user_collections(client: R2RClient, superuser_login, config): client.users.logout() -def test_add_remove_user_from_collection( - client: R2RClient, superuser_login, config -): +def test_add_remove_user_from_collection(client: R2RClient, superuser_login, + config): random_email = f"{uuid.uuid4()}@example.com" password = "somepassword" user_id = register_and_return_user_id(client, random_email, password) # Add user to known collection add_resp = client.users.add_to_collection( - user_id, config.known_collection_id - ).results + user_id, config.known_collection_id).results assert add_resp.success, "Failed to add user to collection." # Verify collections = client.users.list_collections(user_id).results assert any( - str(col.id) == str(config.known_collection_id) for col in collections - ), "User not in collection after add." + str(col.id) == str(config.known_collection_id) + for col in collections), "User not in collection after add." 
# Remove user from collection remove_resp = client.users.remove_from_collection( - user_id, config.known_collection_id - ).results + user_id, config.known_collection_id).results assert remove_resp.success, "Failed to remove user from collection." collections_after = client.users.list_collections(user_id).results assert not any( - str(col.id) == str(config.known_collection_id) - for col in collections_after - ), "User still in collection after removal." + str(col.id) == str(config.known_collection_id) for col in + collections_after), "User still in collection after removal." client.users.logout() @@ -205,24 +199,20 @@ def test_delete_user(client: R2RClient): client.users.login(random_email, password) assert exc_info.value.status_code == 404, ( - "User still exists after deletion." - ) + "User still exists after deletion.") -def test_superuser_downgrade_permissions( - client: R2RClient, superuser_login, config -): +def test_superuser_downgrade_permissions(client: R2RClient, superuser_login, + config): user_email = f"test_super_{uuid.uuid4()}@test.com" user_password = "securepass" - new_user_id = register_and_return_user_id( - client, user_email, user_password - ) + new_user_id = register_and_return_user_id(client, user_email, + user_password) # Upgrade user to superuser upgraded_user = client.users.update(new_user_id, is_superuser=True).results assert upgraded_user.is_superuser == True, ( - "User not upgraded to superuser." - ) + "User not upgraded to superuser.") # Logout admin, login as new superuser client.users.logout() @@ -233,9 +223,8 @@ def test_superuser_downgrade_permissions( # Downgrade back to normal (re-login as original admin) client.users.logout() client.users.login(config.superuser_email, config.superuser_password) - downgraded_user = client.users.update( - new_user_id, is_superuser=False - ).results + downgraded_user = client.users.update(new_user_id, + is_superuser=False).results assert downgraded_user.is_superuser == False, "User not downgraded." 
# Now login as downgraded user and verify no superuser access @@ -244,8 +233,7 @@ def test_superuser_downgrade_permissions( with pytest.raises(R2RException) as exc_info: client.users.list() assert exc_info.value.status_code == 403, ( - "Downgraded user still has superuser privileges." - ) + "Downgraded user still has superuser privileges.") client.users.logout() @@ -276,8 +264,7 @@ def test_non_owner_delete_collection(client: R2RClient): with pytest.raises(R2RException) as exc_info: result = client.collections.delete(coll_id) assert exc_info.value.status_code == 403, ( - "Wrong error code for non-owner deletion attempt" - ) + "Wrong error code for non-owner deletion attempt") # Cleanup client.users.logout() @@ -354,8 +341,7 @@ def test_login_with_incorrect_password(client: R2RClient): with pytest.raises(R2RException) as exc_info: client.users.login(email, "wrongpass") assert exc_info.value.status_code == 401, ( - "Expected 401 when logging in with incorrect password." - ) + "Expected 401 when logging in with incorrect password.") client.users.logout() @@ -389,8 +375,7 @@ def test_verification_with_invalid_code(client: R2RClient): @pytest.mark.skip( - reason="Verification and token logic depends on implementation." 
-) + reason="Verification and token logic depends on implementation.") def test_password_reset_with_invalid_token(client: R2RClient): email = f"{uuid.uuid4()}@example.com" password = "initialpass" @@ -410,7 +395,7 @@ def test_password_reset_with_invalid_token(client: R2RClient): @pytest.fixture def user_with_api_key(client: R2RClient): - """Fixture that creates a user and returns their ID and API key details""" + """Fixture that creates a user and returns their ID and API key details.""" random_email = f"{uuid.uuid4()}@example.com" password = "api_key_test_password" user_id = client.users.create(random_email, password).results.id @@ -432,7 +417,8 @@ def user_with_api_key(client: R2RClient): def test_api_key_lifecycle(client: R2RClient): - """Test the complete lifecycle of API keys including creation, listing, and deletion""" + """Test the complete lifecycle of API keys including creation, listing, and + deletion.""" # Create user and login email = f"{uuid.uuid4()}@example.com" password = "api_key_test_password" @@ -451,8 +437,7 @@ def test_api_key_lifecycle(client: R2RClient): list_resp = client.users.list_api_keys(user_id).results assert len(list_resp) > 0, "No API keys found after creation" assert list_resp[0].key_id == key_id, ( - "Listed key ID doesn't match created key" - ) + "Listed key ID doesn't match created key") assert list_resp[0].updated_at is not None, "Updated timestamp missing" assert list_resp[0].public_key is not None, "Public key missing in list" @@ -462,15 +447,15 @@ def test_api_key_lifecycle(client: R2RClient): # Verify deletion list_resp_after = client.users.list_api_keys(user_id).results - assert not any(k.key_id == key_id for k in list_resp_after), ( - "API key still exists after deletion" - ) + assert not any( + k.key_id == key_id + for k in list_resp_after), ("API key still exists after deletion") client.users.logout() def test_api_key_authentication(client: R2RClient, user_with_api_key): - """Test using an API key for authentication""" + 
"""Test using an API key for authentication.""" user_id, api_key, _ = user_with_api_key # Create new client with API key @@ -483,7 +468,7 @@ def test_api_key_authentication(client: R2RClient, user_with_api_key): def test_api_key_permissions(client: R2RClient, user_with_api_key): - """Test API key permission restrictions""" + """Test API key permission restrictions.""" user_id, api_key, _ = user_with_api_key # Create new client with API key @@ -494,24 +479,22 @@ def test_api_key_permissions(client: R2RClient, user_with_api_key): with pytest.raises(R2RException) as exc_info: api_client.users.list() assert exc_info.value.status_code == 403, ( - "Non-superuser API key shouldn't list users" - ) + "Non-superuser API key shouldn't list users") def test_invalid_api_key(client: R2RClient): - """Test behavior with invalid API key""" + """Test behavior with invalid API key.""" api_client = R2RClient(client.base_url) api_client.set_api_key("invalid.api.key") with pytest.raises(R2RException) as exc_info: api_client.users.me() assert exc_info.value.status_code == 401, ( - "Expected 401 for invalid API key" - ) + "Expected 401 for invalid API key") def test_multiple_api_keys(client: R2RClient): - """Test creating and managing multiple API keys for a single user""" + """Test creating and managing multiple API keys for a single user.""" email = f"{uuid.uuid4()}@example.com" password = "multi_key_test_password" user_id = client.users.create(email, password).results.id @@ -532,8 +515,7 @@ def test_multiple_api_keys(client: R2RClient): client.users.delete_api_key(user_id, key_id) current_keys = client.users.list_api_keys(user_id).results assert not any(k.key_id == key_id for k in current_keys), ( - f"Key {key_id} still exists after deletion" - ) + f"Key {key_id} still exists after deletion") client.users.logout() @@ -555,7 +537,9 @@ def test_update_user_limits_overrides(client: R2RClient): "global_per_min": 10, "monthly_limit": 3000, "route_overrides": { - "/some-route": 
{"route_per_min": 5}, + "/some-route": { + "route_per_min": 5 + }, }, } client.users.update(id=fetched_user.id, limits_overrides=overrides) @@ -565,9 +549,5 @@ def test_update_user_limits_overrides(client: R2RClient): updated_user = client.users.me().results assert len(updated_user.limits_overrides) != 0 assert updated_user.limits_overrides["global_per_min"] == 10 - assert ( - updated_user.limits_overrides["route_overrides"]["/some-route"][ - "route_per_min" - ] - == 5 - ) + assert (updated_user.limits_overrides["route_overrides"]["/some-route"] + ["route_per_min"] == 5) diff --git a/py/tests/scaling/loadTester.py b/py/tests/scaling/loadTester.py index ed56fcd2c..c75f57a7a 100644 --- a/py/tests/scaling/loadTester.py +++ b/py/tests/scaling/loadTester.py @@ -45,6 +45,7 @@ class Metrics: class LoadTester: + def __init__(self, base_url: str): self.base_url = base_url self.metrics: list[Metrics] = [] @@ -54,7 +55,8 @@ class LoadTester: self.client = R2RAsyncClient(base_url) async def safe_call(self, coro, timeout, operation_desc="operation"): - """Safely call an async function with a timeout and handle exceptions.""" + """Safely call an async function with a timeout and handle + exceptions.""" try: return await asyncio.wait_for(coro, timeout=timeout) except asyncio.TimeoutError: @@ -84,11 +86,10 @@ class LoadTester: timeout=LOGIN_TIMEOUT, operation_desc=f"login user {user_email}", ) - user = ( - {"email": user_email, "password": password} - if login_result - else None - ) + user = ({ + "email": user_email, + "password": password + } if login_result else None) # Ingest documents for user files = glob("core/examples/data/*") @@ -108,7 +109,7 @@ class LoadTester: return user async def setup_users(self): - """Initialize users and their documents""" + """Initialize users and their documents.""" print("Setting up users...") setup_tasks = [] @@ -167,18 +168,16 @@ class LoadTester: end_time=end_time, status=status, duration_ms=duration_ms, - ) - ) + )) # Wait according to 
queries per second rate await asyncio.sleep(max(0, 1 / QUERIES_PER_SECOND)) def calculate_statistics(self): - """Calculate and print test statistics""" + """Calculate and print test statistics.""" durations = [m.duration_ms for m in self.metrics] successful_requests = len( - [m for m in self.metrics if m.status == "success"] - ) + [m for m in self.metrics if m.status == "success"]) failed_requests = len([m for m in self.metrics if m.status == "error"]) print("\nTest Results:") @@ -204,7 +203,7 @@ class LoadTester: ) async def run_load_test(self): - """Main load test execution""" + """Main load test execution.""" await self.setup_users() if not self.users: diff --git a/py/tests/unit/conftest.py b/py/tests/unit/conftest.py index 9db1ef7ee..027739fd4 100644 --- a/py/tests/unit/conftest.py +++ b/py/tests/unit/conftest.py @@ -16,8 +16,7 @@ from core.providers.database.postgres import ( PostgresPromptsHandler, ) from core.providers.database.users import ( # Make sure this import is correct - PostgresUserHandler, -) + PostgresUserHandler, ) TEST_DB_CONNECTION_STRING = os.environ.get( "TEST_DB_CONNECTION_STRING", @@ -42,9 +41,8 @@ async def db_provider(): dimension = 4 quantization_type = VectorQuantizationType.FP32 - db_provider = PostgresDatabaseProvider( - db_config, dimension, crypto_provider, quantization_type - ) + db_provider = PostgresDatabaseProvider(db_config, dimension, + crypto_provider, quantization_type) await db_provider.initialize() yield db_provider @@ -129,7 +127,8 @@ async def graphs_handler(db_provider): connection_manager=connection_manager, dimension=dimension, quantization_type=quantization_type, - collections_handler=None, # if needed, or await collections_handler fixture + collections_handler= + None, # if needed, or await collections_handler fixture ) await handler.create_tables() return handler @@ -149,8 +148,7 @@ async def limits_handler(db_provider): await handler.create_tables() # Optionally truncate await connection_manager.execute_query( - 
f"TRUNCATE {handler._get_table_name('request_log')};" - ) + f"TRUNCATE {handler._get_table_name('request_log')};") return handler @@ -168,20 +166,17 @@ async def users_handler(db_provider, crypto_provider): # Optionally clean up users table before each test await connection_manager.execute_query( - f"TRUNCATE {handler._get_table_name('users')} CASCADE;" - ) + f"TRUNCATE {handler._get_table_name('users')} CASCADE;") await connection_manager.execute_query( - f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;" - ) + f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;") return handler @pytest.fixture async def prompt_handler(db_provider): - """ - Returns an instance of PostgresPromptsHandler, creating the necessary tables first. - """ + """Returns an instance of PostgresPromptsHandler, creating the necessary + tables first.""" # from core.providers.database.postgres_prompts import PostgresPromptsHandler project_name = db_provider.project_name diff --git a/py/tests/unit/test_chunks.py b/py/tests/unit/test_chunks.py index 5c4b26045..a64ef63ff 100644 --- a/py/tests/unit/test_chunks.py +++ b/py/tests/unit/test_chunks.py @@ -9,17 +9,16 @@ from r2r import R2RAsyncClient, R2RException class AsyncR2RTestClient: - """Wrapper to ensure async operations use the correct event loop""" + """Wrapper to ensure async operations use the correct event loop.""" def __init__(self, base_url: str = "http://localhost:7272"): self.client = R2RAsyncClient(base_url) - async def create_document( - self, chunks: list[str], run_with_orchestration: bool = False - ): + async def create_document(self, + chunks: list[str], + run_with_orchestration: bool = False): response = await self.client.documents.create( - chunks=chunks, run_with_orchestration=run_with_orchestration - ) + chunks=chunks, run_with_orchestration=run_with_orchestration) return response.results.document_id, [] async def delete_document(self, doc_id: str) -> None: @@ -33,12 +32,15 @@ class AsyncR2RTestClient: 
response = await self.client.chunks.retrieve(id=chunk_id) return response.results - async def update_chunk( - self, chunk_id: str, text: str, metadata: Optional[dict] = None - ): - response = await self.client.chunks.update( - {"id": chunk_id, "text": text, "metadata": metadata or {}} - ) + async def update_chunk(self, + chunk_id: str, + text: str, + metadata: Optional[dict] = None): + response = await self.client.chunks.update({ + "id": chunk_id, + "text": text, + "metadata": metadata or {} + }) return response.results async def delete_chunk(self, chunk_id: str): @@ -47,8 +49,7 @@ class AsyncR2RTestClient: async def search_chunks(self, query: str, limit: int = 5): response = await self.client.chunks.search( - query=query, search_settings={"limit": limit} - ) + query=query, search_settings={"limit": limit}) return response.results @@ -76,8 +77,7 @@ async def test_document( uuid_1 = uuid.uuid4() uuid_2 = uuid.uuid4() doc_id, _ = await test_client.create_document( - [f"Test chunk 1_{uuid_1}", f"Test chunk 2_{uuid_2}"] - ) + [f"Test chunk 1_{uuid_1}", f"Test chunk 2_{uuid_2}"]) await asyncio.sleep(5) # Wait for ingestion chunks = await test_client.list_chunks(str(doc_id)) yield doc_id, chunks @@ -86,14 +86,13 @@ async def test_document( class TestChunks: + @pytest.mark.asyncio - async def test_create_and_list_chunks( - self, test_client: AsyncR2RTestClient - ): + async def test_create_and_list_chunks(self, + test_client: AsyncR2RTestClient): # Create document with chunks doc_id, _ = await test_client.create_document( - ["Hello chunk", "World chunk"] - ) + ["Hello chunk", "World chunk"]) await asyncio.sleep(1) # Wait for ingestion # List and verify chunks @@ -104,36 +103,31 @@ class TestChunks: await test_client.delete_document(doc_id) @pytest.mark.asyncio - async def test_retrieve_chunk( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_retrieve_chunk(self, test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document 
chunk_id = chunks[0].id retrieved = await test_client.retrieve_chunk(chunk_id) assert str(retrieved.id) == str(chunk_id), "Retrieved wrong chunk ID" assert retrieved.text.split("_")[0] == "Test chunk 1", ( - "Chunk text mismatch" - ) + "Chunk text mismatch") @pytest.mark.asyncio - async def test_update_chunk( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_update_chunk(self, test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id # Update chunk - updated = await test_client.update_chunk( - str(chunk_id), "Updated text", {"version": 2} - ) + updated = await test_client.update_chunk(str(chunk_id), "Updated text", + {"version": 2}) assert updated.text == "Updated text", "Chunk text not updated" assert updated.metadata["version"] == 2, "Metadata not updated" @pytest.mark.asyncio - async def test_delete_chunk( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_delete_chunk(self, test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id @@ -151,12 +145,10 @@ class TestChunks: random_1 = uuid.uuid4() random_2 = uuid.uuid4() # Create searchable document - doc_id, _ = await test_client.create_document( - [ - f"Aristotle reference {random_1}", - f"Another piece of text {random_2}", - ] - ) + doc_id, _ = await test_client.create_document([ + f"Aristotle reference {random_1}", + f"Another piece of text {random_2}", + ]) await asyncio.sleep(1) # Wait for indexing # Search @@ -167,9 +159,9 @@ class TestChunks: await test_client.delete_document(doc_id) @pytest.mark.asyncio - async def test_unauthorized_chunk_access( - self, test_client: AsyncR2RTestClient, test_document - ): + async def test_unauthorized_chunk_access(self, + test_client: AsyncR2RTestClient, + test_document): doc_id, chunks = test_document chunk_id = chunks[0].id @@ -185,9 +177,8 @@ class TestChunks: assert exc_info.value.status_code == 403 @pytest.mark.asyncio - 
async def test_list_chunks_with_filters( - self, test_client: AsyncR2RTestClient - ): + async def test_list_chunks_with_filters(self, + test_client: AsyncR2RTestClient): """Test listing chunks with owner_id filter.""" # Create and login as temporary user temp_email = f"{uuid.uuid4()}@example.com" @@ -197,8 +188,7 @@ class TestChunks: try: # Create a document with chunks doc_id, _ = await test_client.create_document( - ["Test chunk 1", "Test chunk 2"] - ) + ["Test chunk 1", "Test chunk 2"]) await asyncio.sleep(1) # Wait for ingestion # Test listing chunks (filters automatically applied on server) @@ -226,9 +216,8 @@ class TestChunks: await test_client.logout_user() @pytest.mark.asyncio - async def test_list_chunks_pagination( - self, test_client: AsyncR2RTestClient - ): + async def test_list_chunks_pagination(self, + test_client: AsyncR2RTestClient): """Test chunk listing with pagination.""" # Create and login as temporary user temp_email = f"{uuid.uuid4()}@example.com" @@ -245,23 +234,20 @@ class TestChunks: # Test first page response1 = await test_client.client.chunks.list(offset=0, limit=2) - assert len(response1.results) == 2, ( - "Expected 2 results on first page" - ) + assert len( + response1.results) == 2, ("Expected 2 results on first page") # Test second page response2 = await test_client.client.chunks.list(offset=2, limit=2) - assert len(response2.results) == 2, ( - "Expected 2 results on second page" - ) + assert len( + response2.results) == 2, ("Expected 2 results on second page") # Verify no duplicate results ids_page1 = {str(chunk.id) for chunk in response1.results} ids_page2 = {str(chunk.id) for chunk in response2.results} assert not ids_page1.intersection(ids_page2), ( - "Found duplicate chunks across pages" - ) + "Found duplicate chunks across pages") finally: # Cleanup @@ -274,8 +260,7 @@ class TestChunks: @pytest.mark.asyncio async def test_list_chunks_with_multiple_documents( - self, test_client: AsyncR2RTestClient - ): + self, test_client: 
AsyncR2RTestClient): """Test listing chunks across multiple documents.""" # Create and login as temporary user temp_email = f"{uuid.uuid4()}@example.com" @@ -287,8 +272,7 @@ class TestChunks: # Create multiple documents for i in range(2): doc_id, _ = await test_client.create_document( - [f"Doc {i} chunk 1", f"Doc {i} chunk 2"] - ) + [f"Doc {i} chunk 1", f"Doc {i} chunk 2"]) doc_ids.append(doc_id) await asyncio.sleep(5) # Wait for ingestion @@ -299,11 +283,12 @@ class TestChunks: assert len(response.results) == 4, "Expected 4 total chunks" chunk_doc_ids = { - str(chunk.document_id) for chunk in response.results + str(chunk.document_id) + for chunk in response.results } - assert all(str(doc_id) in chunk_doc_ids for doc_id in doc_ids), ( - "Got chunks from wrong documents" - ) + assert all( + str(doc_id) in chunk_doc_ids + for doc_id in doc_ids), ("Got chunks from wrong documents") finally: # Cleanup diff --git a/py/tests/unit/test_citations.py b/py/tests/unit/test_citations.py index d2250bc43..8d7af30bd 100644 --- a/py/tests/unit/test_citations.py +++ b/py/tests/unit/test_citations.py @@ -22,10 +22,8 @@ from core.base import ( @pytest.fixture def empty_aggregate(): - """ - Returns an AggregateSearchResult with no chunk search results, - no graph results, no web results, etc. - """ + """Returns an AggregateSearchResult with no chunk search results, no graph + results, no web results, etc.""" return AggregateSearchResult( chunk_search_results=[], graph_search_results=[], @@ -36,10 +34,8 @@ def empty_aggregate(): @pytest.fixture def small_aggregate(): - """ - Returns an AggregateSearchResult with, say, 3 chunk search results - so we can test out-of-range bracket references. 
- """ + """Returns an AggregateSearchResult with, say, 3 chunk search results so we + can test out-of-range bracket references.""" chunk1 = ChunkSearchResult( id=generate_id("chunk-1"), document_id=generate_id("doc-1"), @@ -76,10 +72,9 @@ def small_aggregate(): def test_no_citations_found(empty_aggregate): - """ - If the LLM text has no bracket references, we should return an empty list from the extraction, - and no changes when we attempt to reassign or map them to sources. - """ + """If the LLM text has no bracket references, we should return an empty + list from the extraction, and no changes when we attempt to reassign or map + them to sources.""" text = "This is some text without any bracket references." raw_citations = extract_citations(text) assert len(raw_citations) == 0 @@ -96,8 +91,8 @@ def test_no_citations_found(empty_aggregate): def test_single_citation_basic(empty_aggregate): - """ - A single bracket reference [1]. + """A single bracket reference [1]. + Should remain as [1] after reassigning, with snippet expanded. """ text = "This is a short sentence [1]. Another sentence." @@ -120,19 +115,19 @@ def test_single_citation_basic(empty_aggregate): def test_multiple_citations_in_order(small_aggregate): - """ - Suppose LLM text has 3 bracket references, e.g. [1], [2], [3]. - We'll confirm they remain in ascending order after reassign_citations_in_order - and confirm they map 1->chunk1, 2->chunk2, 3->chunk3 + """Suppose LLM text has 3 bracket references, e.g. [1], [2], [3]. + + We'll confirm they remain in ascending order after + reassign_citations_in_order and confirm they map 1->chunk1, 2->chunk2, + 3->chunk3 """ text = "Chunk #1 is [1]. Then chunk #2 is [2]. Finally chunk #3 is [3]." 
raw_citations = extract_citations(text) assert len(raw_citations) == 3 new_text, new_citations = reassign_citations_in_order(text, raw_citations) - assert ( - new_text == text - ) # They remain [1], [2], [3], no re-labelling needed + assert (new_text == text + ) # They remain [1], [2], [3], no re-labelling needed assert [c.index for c in new_citations] == [1, 2, 3] collector = SearchResultsCollector() @@ -146,10 +141,8 @@ def test_multiple_citations_in_order(small_aggregate): def test_descending_citations(small_aggregate): - """ - If the text references [3], [2], [1] in that order, we want them - re-labeled as [1], [2], [3] in ascending order in the final text. - """ + """If the text references [3], [2], [1] in that order, we want them re- + labeled as [1], [2], [3] in ascending order in the final text.""" text = "First mention is [3], then second mention is [2], last mention is [1]." # Extract raw_citations = extract_citations(text) @@ -161,9 +154,8 @@ def test_descending_citations(small_aggregate): assert "[1]" in new_text assert "[2]" in new_text assert "[3]" in new_text - assert ( - "[3]" not in new_text[: new_text.find("[1]")] - ) # ensure the order is correct + assert ("[3]" not in new_text[:new_text.find("[1]")] + ) # ensure the order is correct collector = SearchResultsCollector() collector.add_aggregate_result(small_aggregate) @@ -173,9 +165,8 @@ def test_descending_citations(small_aggregate): # So let's confirm that mapped[0] is indeed chunk #3: assert mapped[0].sourceType == "chunk" - assert ( - mapped[0].text == "Sample chunk text #3" - ) # or check .metadata["title"] == "Doc3.pdf" + assert (mapped[0].text == "Sample chunk text #3" + ) # or check .metadata["title"] == "Doc3.pdf" # The second bracket => aggregator #2 => chunk #2 assert mapped[1].sourceType == "chunk" @@ -187,10 +178,8 @@ def test_descending_citations(small_aggregate): def test_out_of_range_brackets(small_aggregate): - """ - If we have bracket references [1], [2], [5], but only 3 chunk 
results total, - bracket #5 should map to placeholders. - """ + """If we have bracket references [1], [2], [5], but only 3 chunk results + total, bracket #5 should map to placeholders.""" text = "We talk about chunk #1 [1], chunk #2 [2], and chunk #5 [5]." raw_citations = extract_citations(text) assert len(raw_citations) == 3 @@ -227,9 +216,9 @@ def test_out_of_range_brackets(small_aggregate): def test_zero_brackets_still_converts(small_aggregate): - """ - If the text references [0] or negative, - normally bracket references won't parse that. + """If the text references [0] or negative, normally bracket references + won't parse that. + We can confirm we skip them or treat them as is. """ text = "This is some unusual text with a bracket [0]." @@ -255,28 +244,24 @@ def test_zero_brackets_still_converts(small_aggregate): def test_snippet_extraction_basic(): - """ - If the text has a short sentence, confirm the snippet is that sentence only. - """ + """If the text has a short sentence, confirm the snippet is that sentence + only.""" text = "Hello world. This is a test [1]. Next sentence!" raw_citations = extract_citations(text) # We expect 1 bracket reference assert len(raw_citations) == 1 cit = raw_citations[0] # snippet should ideally be 'This is a test [1].' - snippet = text[cit.snippetStartIndex : cit.snippetEndIndex] + snippet = text[cit.snippetStartIndex:cit.snippetEndIndex] assert "[1]" in snippet - assert ( - "Next sentence!" not in snippet - ) # Because the code should stop at exclamation + assert ("Next sentence!" 
+ not in snippet) # Because the code should stop at exclamation def test_all_upper_bound(small_aggregate): - """ - If the text references [1], [2], [3], [4], [5] but we only have 3 chunk results, - brackets 4 and 5 should still re-label to [4], [5] or whatever the logic does, - but they won't map to a real chunk => placeholders - """ + """If the text references [1], [2], [3], [4], [5] but we only have 3 chunk + results, brackets 4 and 5 should still re-label to [4], [5] or whatever the + logic does, but they won't map to a real chunk => placeholders.""" text = "[1], [2], [3], [4], and [5] are references in ascending order." raw_citations = extract_citations(text) assert len(raw_citations) == 5 @@ -299,16 +284,15 @@ def test_all_upper_bound(small_aggregate): def test_repeated_bracket_ref_basic(empty_aggregate): - """ - If the text uses the same bracket [2] multiple times, we want them all to remain - the same bracket index after relabeling. - For instance, if rawIndex=2 is repeated 3 times, the final text might become - [1],[1],[1] (if 2 is the first unique bracket encountered). + """If the text uses the same bracket [2] multiple times, we want them all + to remain the same bracket index after relabeling. + + For instance, if rawIndex=2 is repeated 3 times, the final text might + become [1],[1],[1] (if 2 is the first unique bracket encountered). """ # The LLM text has repeated "[2]" references text = ( - "This sentence has [2]. Another mention of [2]. And yet another [2]." - ) + "This sentence has [2]. Another mention of [2]. And yet another [2].") raw_citations = extract_citations(text) # We expect 3 bracket occurrences, all rawIndex=2 @@ -322,14 +306,12 @@ def test_repeated_bracket_ref_basic(empty_aggregate): # The final text should have the *same* bracket each time (e.g. 
`[1],[1],[1]`) # because there's only one unique oldRef = 2, which gets mapped to newIndex=1 assert new_text.count("[1]") == 3, ( - f"Expected all references to become [1], got: {new_text}" - ) + f"Expected all references to become [1], got: {new_text}") # And the new_citations should all share index=1 for cit in new_citations: assert cit.index == 1, ( - f"Expected repeated bracket index=1, got: {cit.index}" - ) + f"Expected repeated bracket index=1, got: {cit.index}") # Also confirm rawIndex=2 for all occurrences assert cit.rawIndex == 2 @@ -344,10 +326,9 @@ def test_repeated_bracket_ref_basic(empty_aggregate): def test_repeated_bracket_ref_with_two_values(small_aggregate): - """ - Tests a scenario where the text references [3], [3], [1], [3]. - We want all the [3] occurrences to remain the same final bracket, - and the [1] to remain or become [1] in ascending order. + """Tests a scenario where the text references [3], [3], [1], [3]. We want + all the [3] occurrences to remain the same final bracket, and the [1] to + remain or become [1] in ascending order. The order we see them: oldRef=3, oldRef=3, oldRef=1, oldRef=3 The unique old refs are {1,3}, so final brackets might map: @@ -357,10 +338,8 @@ def test_repeated_bracket_ref_with_two_values(small_aggregate): or if the code sorts them ascending by the numeric oldRef, then oldRef=1 => newRef=1, oldRef=3 => newRef=2 """ - text = ( - "First mention is [3], second mention also [3], " - "then we have [1], and again [3]." 
- ) + text = ("First mention is [3], second mention also [3], " + "then we have [1], and again [3].") raw_citations = extract_citations(text) # The LLM has bracket #3 in positions 1,2,4; bracket #1 in position 3 assert len(raw_citations) == 4 @@ -373,9 +352,8 @@ def test_repeated_bracket_ref_with_two_values(small_aggregate): bracket_matches = [m.group() for m in re.finditer(r"\[\d+\]", new_text)] unique_brackets = set(bracket_matches) assert len(unique_brackets) == 2, ( - "Expected exactly 2 bracket values in final text. Got: " - + str(unique_brackets) - ) + "Expected exactly 2 bracket values in final text. Got: " + + str(unique_brackets)) # Check that oldRef=3 => newIndex=2 for all occurrences, oldRef=1 => newIndex=1 found_1 = [c for c in new_cits if c.rawIndex == 1] @@ -384,14 +362,13 @@ def test_repeated_bracket_ref_with_two_values(small_aggregate): f"Expected exactly one bracket occurrence for oldRef=1, got {len(found_1)}" ) assert len(found_3) == 3, ( - f"Expected 3 bracket occurrences for oldRef=3, got {len(found_3)}" - ) - assert all(c.index == found_1[0].index for c in found_1), ( - "All oldRef=1 must share the same final index" - ) - assert all(c.index == found_3[0].index for c in found_3), ( - "All oldRef=3 must share the same final index" - ) + f"Expected 3 bracket occurrences for oldRef=3, got {len(found_3)}") + assert all( + c.index == found_1[0].index + for c in found_1), ("All oldRef=1 must share the same final index") + assert all( + c.index == found_3[0].index + for c in found_3), ("All oldRef=3 must share the same final index") collector = SearchResultsCollector() collector.add_aggregate_result(small_aggregate) @@ -424,8 +401,7 @@ def test_same_bracket_in_non_sequential_text(): # The 3 mentions of oldRef=8 => same bracket, and the single mention of oldRef=2 => the other bracket unique_brackets = set(brackets_list) assert len(unique_brackets) == 2, ( - f"Expected 2 unique brackets, got: {unique_brackets}" - ) + f"Expected 2 unique brackets, got: 
{unique_brackets}") # Inside new_cits, oldRef=8 should have the same newIndex across all occurrences old8_cits = [c for c in new_cits if c.rawIndex == 8] @@ -433,29 +409,25 @@ def test_same_bracket_in_non_sequential_text(): first_new_index = old8_cits[0].index for c in old8_cits: assert c.index == first_new_index, ( - "All oldRef=8 must share the same final bracket index" - ) + "All oldRef=8 must share the same final bracket index") old2_cits = [c for c in new_cits if c.rawIndex == 2] assert len(old2_cits) == 1, ( - "Expected exactly one mention referencing oldRef=2" - ) + "Expected exactly one mention referencing oldRef=2") # That one mention has a different bracket index from the oldRef=8 group assert old2_cits[0].index != first_new_index def test_three_unique_brackets_with_duplicates(): + """A final stress test with multiple bracket references repeated, e.g.: + + [3], [10], [3], [10], [3], [4], [10] We want 3 unique oldRefs => 3,4,10 => + 3 final bracket numbers, each reused consistently. """ - A final stress test with multiple bracket references repeated, e.g.: - [3], [10], [3], [10], [3], [4], [10] - We want 3 unique oldRefs => 3,4,10 => 3 final bracket numbers, each reused consistently. - """ - text = ( - "We see bracket [3], then bracket [10], then again bracket [3], " - "yet again [10], plus a third time [3], now a new bracket [4], " - "and finally bracket [10]." 
- ) + text = ("We see bracket [3], then bracket [10], then again bracket [3], " + "yet again [10], plus a third time [3], now a new bracket [4], " + "and finally bracket [10].") raw = extract_citations(text) # oldRef=3 repeated 3 times, oldRef=10 repeated 3 times, oldRef=4 repeated once assert len(raw) == 7 @@ -469,33 +441,30 @@ def test_three_unique_brackets_with_duplicates(): unique_brackets = sorted(set(bracket_list)) # We expect exactly 3 unique bracket labels in final text: assert len(unique_brackets) == 3, ( - f"Expected 3 distinct bracket values, got {unique_brackets}" - ) + f"Expected 3 distinct bracket values, got {unique_brackets}") # Confirm each oldRef is consistently re-labeled: old3 = [c.index for c in new_cits if c.rawIndex == 3] assert len(set(old3)) == 1, ( - "All references to rawIndex=3 must share the same newIndex" - ) + "All references to rawIndex=3 must share the same newIndex") old4 = [c.index for c in new_cits if c.rawIndex == 4] assert len(set(old4)) == 1, ( - "All references to rawIndex=4 must share the same newIndex" - ) + "All references to rawIndex=4 must share the same newIndex") old10 = [c.index for c in new_cits if c.rawIndex == 10] assert len(set(old10)) == 1, ( - "All references to rawIndex=10 must share the same newIndex" - ) + "All references to rawIndex=10 must share the same newIndex") # That’s the main correctness check. We can map them to actual aggregator results if we had them, but this suffices. @pytest.fixture def mock_aggregator_results(): - """ - Return a small AggregateSearchResult with multiple items - in a known order. We'll pretend the aggregator indexes them + """Return a small AggregateSearchResult with multiple items in a known + order. + + We'll pretend the aggregator indexes them as 1..N in this same order. 
""" chunk1 = ChunkSearchResult( @@ -545,9 +514,9 @@ def mock_aggregator_results(): def test_end_to_end_citation_remapping(mock_aggregator_results): - """ - 1) We define the text that the LLM hypothetically produced. - It references aggregator item #3, #1, #5, #1, #4 in random order. + """1) We define the text that the LLM hypothetically produced. + + It references aggregator item #3, #1, #5, #1, #4 in random order. But let's pretend the aggregator only has 5 total items: 1 -> chunk1 2 -> chunk2 @@ -568,8 +537,7 @@ def test_end_to_end_citation_remapping(mock_aggregator_results): raw_llm_text = ( "Major updates: The Graph item is mentioned [3]. Then we mention chunk1 [1]. " - "Oh, we also have web2 [5]. Wait, chunk1 again [1]. Finally web1 [4]." - ) + "Oh, we also have web2 [5]. Wait, chunk1 again [1]. Finally web1 [4].") # 1) Extract bracket references raw_citations = extract_citations(raw_llm_text) @@ -578,8 +546,7 @@ def test_end_to_end_citation_remapping(mock_aggregator_results): # 2) Re-label them in ascending bracket order for display, # but store the original aggregator index in rawIndex new_text, reassigned_citations = reassign_citations_in_order( - raw_llm_text, raw_citations - ) + raw_llm_text, raw_citations) # 3) Map citations by using the aggregator's oldRef => aggregator #. # i.e. 
we look up rawIndex in the collector @@ -594,9 +561,8 @@ def test_end_to_end_citation_remapping(mock_aggregator_results): def citation_summary(c: Citation): return { "finalIndex": c.index, - "rawIndex": getattr( - c, "rawIndex", c.rawIndex - ), # Some code calls it oldIndex + "rawIndex": getattr(c, "rawIndex", + c.rawIndex), # Some code calls it oldIndex "sourceType": c.sourceType, "docId": c.document_id or "", "title": (c.metadata.get("title") if c.metadata else ""), @@ -628,8 +594,7 @@ def test_end_to_end_citation_remapping(mock_aggregator_results): # They should share same final 'index' if your code unifies repeated references final_idx_set = {c.index for c in old1_cits} assert len(final_idx_set) == 1, ( - "All references to oldRef=1 must share the same final bracket index" - ) + "All references to oldRef=1 must share the same final bracket index") # They should map to chunk1 => doc-1 for c in old1_cits: assert c.sourceType == "chunk" @@ -756,11 +721,9 @@ def test_end_to_end_mocked_aggregator(ordered_aggregate): # [("chunk", chunk1, 1), ("chunk", chunk2, 2), ("graph", graph1, 3), ("web", web1, 4), ("web", web2, 5)] # Step 2: LLM text references them out of order: aggregator #3, #1, #5, #1, #4 - raw_text = ( - "We mention aggregator #3 first [3], then aggregator #1 [1], " - "then aggregator #5 [5], then aggregator #1 again [1], " - "finally aggregator #4 [4]." 
- ) + raw_text = ("We mention aggregator #3 first [3], then aggregator #1 [1], " + "then aggregator #5 [5], then aggregator #1 again [1], " + "finally aggregator #4 [4].") # 2a) extract brackets raw_cits = extract_citations(raw_text) @@ -839,12 +802,10 @@ def test_end_to_end_mocked_aggregator(ordered_aggregate): def test_hybrid_aggregation_format(): - """ - Demonstrates combining chunk, graph, and web results in a single - AggregateSearchResult, adding them to the collector, and verifying - the LLM-format output includes each in the correct 'Vector Search Results:', - 'Graph Search Results:', and 'Web Search Results:' sections. - """ + """Demonstrates combining chunk, graph, and web results in a single + AggregateSearchResult, adding them to the collector, and verifying the LLM- + format output includes each in the correct 'Vector Search Results:', 'Graph + Search Results:', and 'Web Search Results:' sections.""" collector = SearchResultsCollector() # Build an all-in-one AggregateSearchResult @@ -858,9 +819,8 @@ def test_hybrid_aggregation_format(): metadata={"title": "DocA"}, ) graphA = GraphSearchResult( - content=GraphEntityResult( - name="GraphEntity", description="Some entity in the KG" - ), + content=GraphEntityResult(name="GraphEntity", + description="Some entity in the KG"), result_type="entity", metadata={"graphKey": "graphVal"}, score=0.90, @@ -910,10 +870,8 @@ def test_hybrid_aggregation_format(): def test_collector_multiple_calls_with_small_aggregate(small_aggregate): - """ - Demonstrate calling collector.add_aggregate_result() multiple times - and confirm aggregator indexes continue incrementing. 
- """ + """Demonstrate calling collector.add_aggregate_result() multiple times and + confirm aggregator indexes continue incrementing.""" collector = SearchResultsCollector() # First call: aggregator #1..#3 => chunk1, chunk2, chunk3 @@ -946,10 +904,8 @@ def test_collector_multiple_calls_with_small_aggregate(small_aggregate): def test_collector_out_of_range_with_small_aggregate(small_aggregate): - """ - If the LLM references [5] but we only have aggregator items #1..#3, - map_citations_to_collector() should produce sourceType="unknown". - """ + """If the LLM references [5] but we only have aggregator items #1..#3, + map_citations_to_collector() should produce sourceType="unknown".""" collector = SearchResultsCollector() collector.add_aggregate_result(small_aggregate) # aggregator #1 => doc-1, #2 => doc-2, #3 => doc-3 @@ -968,13 +924,10 @@ def test_collector_out_of_range_with_small_aggregate(small_aggregate): def test_collector_repeated_same_aggregator_with_small_aggregate( - small_aggregate, -): - """ - If the LLM text references aggregator #2 multiple times, e.g. [2], [2], [2], - then after reassign_citations_in_order, all should unify to the same final bracket, - and map to chunk2 => doc-2 in aggregator. - """ + small_aggregate, ): + """If the LLM text references aggregator #2 multiple times, e.g. [2], [2], + [2], then after reassign_citations_in_order, all should unify to the same + final bracket, and map to chunk2 => doc-2 in aggregator.""" collector = SearchResultsCollector() collector.add_aggregate_result(small_aggregate) # aggregator #1 => doc-1, #2 => doc-2, #3 => doc-3 @@ -995,11 +948,11 @@ def test_collector_repeated_same_aggregator_with_small_aggregate( def test_collector_mixed_references_small_aggregate(small_aggregate): - """ - Suppose the LLM references aggregator #3, #1, #3, #2 in random order. - That means bracket [3] => chunk #3 doc-3, bracket [1] => chunk #1 doc-1, bracket [2] => doc-2, etc. 
- Then after reassign, we confirm final text has them in ascending bracket order - while the mapped citations remain correct. + """Suppose the LLM references aggregator #3, #1, #3, #2 in random order. + + That means bracket [3] => chunk #3 doc-3, bracket [1] => chunk #1 doc-1, + bracket [2] => doc-2, etc. Then after reassign, we confirm final text has + them in ascending bracket order while the mapped citations remain correct. """ collector = SearchResultsCollector() collector.add_aggregate_result(small_aggregate) diff --git a/py/tests/unit/test_collections.py b/py/tests/unit/test_collections.py index 3525a453a..eb9c07991 100644 --- a/py/tests/unit/test_collections.py +++ b/py/tests/unit/test_collections.py @@ -34,8 +34,7 @@ async def test_create_collection_default_name(collections_handler): async def test_update_collection(collections_handler): owner_id = uuid.uuid4() coll = await collections_handler.create_collection( - owner_id=owner_id, name="Original Name", description="Original Desc" - ) + owner_id=owner_id, name="Original Name", description="Original Desc") updated = await collections_handler.update_collection( collection_id=coll.id, @@ -52,9 +51,9 @@ async def test_update_collection(collections_handler): @pytest.mark.asyncio async def test_update_collection_no_fields(collections_handler): owner_id = uuid.uuid4() - coll = await collections_handler.create_collection( - owner_id=owner_id, name="NoUpdate", description="No Update" - ) + coll = await collections_handler.create_collection(owner_id=owner_id, + name="NoUpdate", + description="No Update") with pytest.raises(R2RException) as exc: await collections_handler.update_collection(collection_id=coll.id) @@ -64,9 +63,8 @@ async def test_update_collection_no_fields(collections_handler): @pytest.mark.asyncio async def test_delete_collection_relational(collections_handler): owner_id = uuid.uuid4() - coll = await collections_handler.create_collection( - owner_id=owner_id, name="ToDelete" - ) + coll = await 
collections_handler.create_collection(owner_id=owner_id, + name="ToDelete") # Confirm existence exists = await collections_handler.collection_exists(coll.id) @@ -89,9 +87,8 @@ async def test_collection_exists(collections_handler): async def test_documents_in_collection(collections_handler, db_provider): # Create a collection owner_id = uuid.uuid4() - coll = await collections_handler.create_collection( - owner_id=owner_id, name="DocCollection" - ) + coll = await collections_handler.create_collection(owner_id=owner_id, + name="DocCollection") # Insert some documents related to this collection # We'll directly insert into the documents table for simplicity @@ -101,13 +98,12 @@ async def test_documents_in_collection(collections_handler, db_provider): VALUES ($1, $2, $3, 'txt', '{{}}', 'Test Doc', 'v1', 1234, 'pending', 'pending') """ await db_provider.connection_manager.execute_query( - insert_doc_query, [doc_id, [coll.id], owner_id] - ) + insert_doc_query, [doc_id, [coll.id], owner_id]) # Now fetch documents in collection - res = await collections_handler.documents_in_collection( - coll.id, offset=0, limit=10 - ) + res = await collections_handler.documents_in_collection(coll.id, + offset=0, + limit=10) assert len(res["results"]) == 1 assert res["total_entries"] == 1 assert res["results"][0].id == doc_id @@ -117,16 +113,13 @@ async def test_documents_in_collection(collections_handler, db_provider): @pytest.mark.asyncio async def test_get_collections_overview(collections_handler, db_provider): owner_id = uuid.uuid4() - coll1 = await collections_handler.create_collection( - owner_id=owner_id, name="Overview1" - ) - coll2 = await collections_handler.create_collection( - owner_id=owner_id, name="Overview2" - ) + coll1 = await collections_handler.create_collection(owner_id=owner_id, + name="Overview1") + coll2 = await collections_handler.create_collection(owner_id=owner_id, + name="Overview2") - overview = await collections_handler.get_collections_overview( - offset=0, 
limit=10 - ) + overview = await collections_handler.get_collections_overview(offset=0, + limit=10) # There should be at least these two ids = [c.id for c in overview["results"]] assert coll1.id in ids @@ -135,12 +128,10 @@ async def test_get_collections_overview(collections_handler, db_provider): @pytest.mark.asyncio async def test_assign_document_to_collection_relational( - collections_handler, db_provider -): + collections_handler, db_provider): owner_id = uuid.uuid4() - coll = await collections_handler.create_collection( - owner_id=owner_id, name="Assign" - ) + coll = await collections_handler.create_collection(owner_id=owner_id, + name="Assign") # Insert a doc doc_id = uuid.uuid4() @@ -148,31 +139,27 @@ async def test_assign_document_to_collection_relational( INSERT INTO {db_provider.project_name}.documents (id, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status, collection_ids) VALUES ($1, $2, 'txt', '{{}}', 'Standalone Doc', 'v1', 10, 'pending', 'pending', ARRAY[]::uuid[]) """ - await db_provider.connection_manager.execute_query( - insert_doc_query, [doc_id, owner_id] - ) + await db_provider.connection_manager.execute_query(insert_doc_query, + [doc_id, owner_id]) # Assign this doc to the collection await collections_handler.assign_document_to_collection_relational( - doc_id, coll.id - ) + doc_id, coll.id) # Verify doc is now in collection - docs = await collections_handler.documents_in_collection( - coll.id, offset=0, limit=10 - ) + docs = await collections_handler.documents_in_collection(coll.id, + offset=0, + limit=10) assert len(docs["results"]) == 1 assert docs["results"][0].id == doc_id @pytest.mark.asyncio async def test_remove_document_from_collection_relational( - collections_handler, db_provider -): + collections_handler, db_provider): owner_id = uuid.uuid4() - coll = await collections_handler.create_collection( - owner_id=owner_id, name="RemoveDoc" - ) + coll = await 
collections_handler.create_collection(owner_id=owner_id, + name="RemoveDoc") # Insert a doc already in collection doc_id = uuid.uuid4() @@ -182,17 +169,15 @@ async def test_remove_document_from_collection_relational( VALUES ($1, $2, 'txt', '{{}}'::jsonb, 'Another Doc', 'v1', 10, 'pending', 'pending', $3) """ await db_provider.connection_manager.execute_query( - insert_doc_query, [doc_id, owner_id, [coll.id]] - ) + insert_doc_query, [doc_id, owner_id, [coll.id]]) # Remove it await collections_handler.remove_document_from_collection_relational( - doc_id, coll.id - ) + doc_id, coll.id) - docs = await collections_handler.documents_in_collection( - coll.id, offset=0, limit=10 - ) + docs = await collections_handler.documents_in_collection(coll.id, + offset=0, + limit=10) assert len(docs["results"]) == 0 @@ -202,5 +187,4 @@ async def test_delete_nonexistent_collection(collections_handler): with pytest.raises(R2RException) as exc: await collections_handler.delete_collection_relational(non_existent_id) assert exc.value.status_code == 404, ( - "Should raise 404 for non-existing collection" - ) + "Should raise 404 for non-existing collection") diff --git a/py/tests/unit/test_config.py b/py/tests/unit/test_config.py index a444d7936..8ab72285b 100644 --- a/py/tests/unit/test_config.py +++ b/py/tests/unit/test_config.py @@ -22,19 +22,19 @@ def base_config(): @pytest.fixture def config_dir(): - """Get the path to the configs directory""" + """Get the path to the configs directory.""" return Path(__file__).parent.parent.parent / "core" / "configs" @pytest.fixture def all_config_files(config_dir): - """Get list of all TOML files in the configs directory""" + """Get list of all TOML files in the configs directory.""" return list(config_dir.glob("*.toml")) @pytest.fixture def all_configs(all_config_files): - """Load all config files""" + """Load all config files.""" configs = {} for config_file in all_config_files: with open(config_file) as f: @@ -50,7 +50,7 @@ def 
full_config(all_configs): @pytest.fixture def all_merged_configs(base_config, all_configs): - """Merge every override config into the base config""" + """Merge every override config into the base config.""" merged = {} for config_name, config_data in all_configs.items(): merged[config_name] = deep_update(deepcopy(base_config), config_data) @@ -59,7 +59,7 @@ def all_merged_configs(base_config, all_configs): @pytest.fixture def merged_config(base_config, full_config): - """Merge the full override config into the base config""" + """Merge the full override config into the base config.""" return deep_update(deepcopy(base_config), full_config) @@ -69,25 +69,20 @@ def merged_config(base_config, full_config): def test_base_config_loading(base_config): - """ - Test that the base config loads correctly with the new expected values. - (For example, the old 'clustering_mode' key is gone, and now we check - that keys like 'graph_entity_description_prompt' are present.) + """Test that the base config loads correctly with the new expected values. + + (For example, the old 'clustering_mode' key is gone, and now we check that + keys like 'graph_entity_description_prompt' are present.) """ config = R2RConfig(base_config) # Verify that the database graph creation settings are present and set - assert ( - config.database.graph_creation_settings.graph_entity_description_prompt - == "graph_entity_description" - ) - assert ( - config.database.graph_creation_settings.graph_extraction_prompt - == "graph_extraction" - ) - assert ( - config.database.graph_creation_settings.automatic_deduplication is True - ) + assert (config.database.graph_creation_settings. 
+ graph_entity_description_prompt == "graph_entity_description") + assert (config.database.graph_creation_settings.graph_extraction_prompt == + "graph_extraction") + assert (config.database.graph_creation_settings.automatic_deduplication + is True) # Verify other key sections assert config.ingestion.provider == "r2r" @@ -96,8 +91,8 @@ def test_base_config_loading(base_config): def test_full_config_override(full_config): - """ - Test that full.toml properly overrides the base values. + """Test that full.toml properly overrides the base values. + For example, assume the full override changes: - ingestion.provider from "r2r" to "unstructured_local" - orchestration.provider from "simple" to "hatchet" @@ -108,29 +103,22 @@ def test_full_config_override(full_config): assert config.ingestion.provider == "unstructured_local" assert config.orchestration.provider == "hatchet" # Check that a new nested key has been added - assert ( - config.database.graph_creation_settings.max_knowledge_relationships - == 100 - ) + assert (config.database.graph_creation_settings.max_knowledge_relationships + == 100) def test_nested_config_preservation(merged_config): - """ - Test that nested configuration values are preserved after merging. - """ + """Test that nested configuration values are preserved after merging.""" config = R2RConfig(merged_config) - assert ( - config.database.graph_creation_settings.max_knowledge_relationships - == 100 - ) + assert (config.database.graph_creation_settings.max_knowledge_relationships + == 100) def test_new_values_in_override(merged_config): - """ - Test that new keys in the override config are added. + """Test that new keys in the override config are added. - In the old tests we asserted values for orchestration concurrency keys. - In the new config structure these keys have been removed (or renamed). + In the old tests we asserted values for orchestration concurrency keys. In + the new config structure these keys have been removed (or renamed). 
Therefore, we now check for them only if they exist. """ config = R2RConfig(merged_config) @@ -140,35 +128,30 @@ def test_new_values_in_override(merged_config): assert config.orchestration.ingestion_concurrency_limit == 16 # Optionally, if new keys like graph_search_results_creation_concurrency_limit are defined, check them: - if hasattr( - config.orchestration, "graph_search_results_creation_concurrency_limit" - ): - assert ( - config.orchestration.graph_search_results_creation_concurrency_limit - == 32 - ) + if hasattr(config.orchestration, + "graph_search_results_creation_concurrency_limit"): + assert (config.orchestration. + graph_search_results_creation_concurrency_limit == 32) if hasattr(config.orchestration, "graph_search_results_concurrency_limit"): assert config.orchestration.graph_search_results_concurrency_limit == 8 def test_config_type_consistency(merged_config): - """ - Test that configuration values maintain their expected types. - """ + """Test that configuration values maintain their expected types.""" config = R2RConfig(merged_config) assert isinstance( - config.database.graph_creation_settings.graph_entity_description_prompt, + config.database.graph_creation_settings. + graph_entity_description_prompt, str, ) assert isinstance( - config.database.graph_creation_settings.automatic_deduplication, bool - ) + config.database.graph_creation_settings.automatic_deduplication, bool) assert isinstance(config.ingestion.chunking_strategy, str) - if hasattr( - config.database.graph_creation_settings, "max_knowledge_relationships" - ): + if hasattr(config.database.graph_creation_settings, + "max_knowledge_relationships"): assert isinstance( - config.database.graph_creation_settings.max_knowledge_relationships, + config.database.graph_creation_settings. 
+ max_knowledge_relationships, int, ) @@ -181,8 +164,8 @@ def get_config_files(): @pytest.mark.parametrize("config_file", get_config_files()) def test_config_required_keys(config_file): - """ - Test that all required sections and keys (per R2RConfig.REQUIRED_KEYS) exist. + """Test that all required sections and keys (per R2RConfig.REQUIRED_KEYS) + exist. In the new structure the 'agent' section no longer includes the key 'generation_config', so we filter that out. @@ -190,12 +173,8 @@ def test_config_required_keys(config_file): if config_file == "r2r.toml": file_path = Path(__file__).parent.parent.parent / "r2r/r2r.toml" else: - file_path = ( - Path(__file__).parent.parent.parent - / "core" - / "configs" - / config_file - ) + file_path = (Path(__file__).parent.parent.parent / "core" / "configs" / + config_file) with open(file_path) as f: config_data = toml.load(f) @@ -219,18 +198,15 @@ def test_config_required_keys(config_file): for key in keys_to_check: if isinstance(section_config, dict): assert key in section_config, ( - f"Missing required key {key} in section {section}" - ) + f"Missing required key {key} in section {section}") else: assert hasattr(section_config, key), ( - f"Missing required key {key} in section {section}" - ) + f"Missing required key {key} in section {section}") def test_serialization_roundtrip(merged_config): - """ - Test that serializing and then deserializing the config does not lose data. - """ + """Test that serializing and then deserializing the config does not lose + data.""" config = R2RConfig(merged_config) serialized = config.to_toml() @@ -238,20 +214,15 @@ def test_serialization_roundtrip(merged_config): roundtrip_config = R2RConfig(toml.loads(serialized)) # Compare a couple of key values after roundtrip. 
- assert ( - roundtrip_config.database.graph_creation_settings.graph_entity_description_prompt - == config.database.graph_creation_settings.graph_entity_description_prompt - ) - assert ( - roundtrip_config.orchestration.provider - == config.orchestration.provider - ) + assert (roundtrip_config.database.graph_creation_settings. + graph_entity_description_prompt == config.database. + graph_creation_settings.graph_entity_description_prompt) + assert (roundtrip_config.orchestration.provider == + config.orchestration.provider) def test_all_merged_configs(base_config, all_merged_configs): - """ - Test that every override file properly merges with the base config. - """ + """Test that every override file properly merges with the base config.""" for config_name, merged_data in all_merged_configs.items(): config = R2RConfig(merged_data) assert config is not None @@ -263,9 +234,7 @@ def test_all_merged_configs(base_config, all_merged_configs): def test_all_config_overrides(all_configs): - """ - Test that all configuration files can be loaded independently. 
- """ + """Test that all configuration files can be loaded independently.""" for config_name, config_data in all_configs.items(): config = R2RConfig(config_data) assert config is not None diff --git a/py/tests/unit/test_conversations.py b/py/tests/unit/test_conversations.py index 8bbe0c09f..ce9ead080 100644 --- a/py/tests/unit/test_conversations.py +++ b/py/tests/unit/test_conversations.py @@ -20,9 +20,8 @@ async def test_create_conversation(conversations_handler): @pytest.mark.asyncio async def test_create_conversation_with_user_and_name(conversations_handler): user_id = uuid.uuid4() - resp = await conversations_handler.create_conversation( - user_id=user_id, name="Test Conv" - ) + resp = await conversations_handler.create_conversation(user_id=user_id, + name="Test Conv") assert resp.id is not None assert resp.created_at is not None # There's no direct field for user_id in ConversationResponse, @@ -52,9 +51,9 @@ async def test_add_message_with_parent(conversations_handler): parent_id = parent_resp.id child_msg = Message(role="assistant", content="Child reply") - child_resp = await conversations_handler.add_message( - conv_id, child_msg, parent_id=parent_id - ) + child_resp = await conversations_handler.add_message(conv_id, + child_msg, + parent_id=parent_id) assert child_resp.id is not None assert child_resp.message.content == "Child reply" @@ -68,9 +67,8 @@ async def test_edit_message(conversations_handler): resp = await conversations_handler.add_message(conv_id, original_msg) msg_id = resp.id - updated = await conversations_handler.edit_message( - msg_id, "Edited content" - ) + updated = await conversations_handler.edit_message(msg_id, + "Edited content") assert updated["message"].content == "Edited content" assert updated["metadata"]["edited"] is True @@ -85,8 +83,7 @@ async def test_update_message_metadata(conversations_handler): msg_id = resp.id await conversations_handler.update_message_metadata( - msg_id, {"test_key": "test_value"} - ) + msg_id, 
{"test_key": "test_value"}) # Verify metadata updated full_conversation = await conversations_handler.get_conversation(conv_id) @@ -126,5 +123,4 @@ async def test_delete_conversation(conversations_handler): with pytest.raises(R2RException) as exc: await conversations_handler.get_conversation(conv_id) assert exc.value.status_code == 404, ( - "Conversation should be deleted and not found" - ) + "Conversation should be deleted and not found") diff --git a/py/tests/unit/test_documents.py b/py/tests/unit/test_documents.py index eedf13b2e..887a15c19 100644 --- a/py/tests/unit/test_documents.py +++ b/py/tests/unit/test_documents.py @@ -14,26 +14,37 @@ from core.base import ( def make_db_entry(doc: DocumentResponse): # This simulates what your real code should do: return { - "id": doc.id, - "collection_ids": doc.collection_ids, - "owner_id": doc.owner_id, - "document_type": doc.document_type.value, - "metadata": json.dumps(doc.metadata), - "title": doc.title, - "version": doc.version, - "size_in_bytes": doc.size_in_bytes, - "ingestion_status": doc.ingestion_status.value, - "extraction_status": doc.extraction_status.value, - "created_at": doc.created_at, - "updated_at": doc.updated_at, - "ingestion_attempt_number": 0, - "summary": doc.summary, + "id": + doc.id, + "collection_ids": + doc.collection_ids, + "owner_id": + doc.owner_id, + "document_type": + doc.document_type.value, + "metadata": + json.dumps(doc.metadata), + "title": + doc.title, + "version": + doc.version, + "size_in_bytes": + doc.size_in_bytes, + "ingestion_status": + doc.ingestion_status.value, + "extraction_status": + doc.extraction_status.value, + "created_at": + doc.created_at, + "updated_at": + doc.updated_at, + "ingestion_attempt_number": + 0, + "summary": + doc.summary, # If summary_embedding is a list, we can store it as a string here if needed - "summary_embedding": ( - str(doc.summary_embedding) - if doc.summary_embedding is not None - else None - ), + "summary_embedding": (str(doc.summary_embedding) 
+ if doc.summary_embedding is not None else None), } @@ -59,14 +70,12 @@ async def test_upsert_documents_overview_insert(documents_handler): # Simulate the handler call await documents_handler.upsert_documents_overview( - [doc] - ) # adjust your handler to accept list or doc + [doc]) # adjust your handler to accept list or doc # If your handler expects a db entry dict, you may need to patch handler or adapt your code # Verify res = await documents_handler.get_documents_overview( - offset=0, limit=10, filter_document_ids=[doc_id] - ) + offset=0, limit=10, filter_document_ids=[doc_id]) assert res["total_entries"] == 1 fetched_doc = res["results"][0] assert fetched_doc.id == doc_id @@ -105,8 +114,7 @@ async def test_upsert_documents_overview_update(documents_handler): # Verify update res = await documents_handler.get_documents_overview( - offset=0, limit=10, filter_document_ids=[doc_id] - ) + offset=0, limit=10, filter_document_ids=[doc_id]) fetched_doc = res["results"][0] assert fetched_doc.title == "Updated Title" assert fetched_doc.metadata["note"] == "updated" @@ -135,6 +143,5 @@ async def test_delete_document(documents_handler): await documents_handler.upsert_documents_overview([doc]) await documents_handler.delete(doc_id) res = await documents_handler.get_documents_overview( - offset=0, limit=10, filter_document_ids=[doc_id] - ) + offset=0, limit=10, filter_document_ids=[doc_id]) assert res["total_entries"] == 0 diff --git a/py/tests/unit/test_graphs.py b/py/tests/unit/test_graphs.py index 01ea94de1..dc15a7978 100644 --- a/py/tests/unit/test_graphs.py +++ b/py/tests/unit/test_graphs.py @@ -14,9 +14,9 @@ class StoreType(str, Enum): @pytest.mark.asyncio async def test_create_graph(graphs_handler): coll_id = uuid.uuid4() - resp = await graphs_handler.create( - collection_id=coll_id, name="My Graph", description="Test Graph" - ) + resp = await graphs_handler.create(collection_id=coll_id, + name="My Graph", + description="Test Graph") assert isinstance(resp, 
GraphResponse) assert resp.name == "My Graph" assert resp.collection_id == coll_id @@ -26,9 +26,8 @@ async def test_create_graph(graphs_handler): async def test_add_entities_and_relationships(graphs_handler): # Create a graph coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="TestGraph" - ) + graph_resp = await graphs_handler.create(collection_id=coll_id, + name="TestGraph") graph_id = graph_resp.id # Add an entity @@ -64,17 +63,16 @@ async def test_add_entities_and_relationships(graphs_handler): assert rel.predicate == "lives_in" # Verify entities retrieval - ents, total_ents = await graphs_handler.get_entities( - parent_id=graph_id, offset=0, limit=10 - ) + ents, total_ents = await graphs_handler.get_entities(parent_id=graph_id, + offset=0, + limit=10) assert total_ents == 2 names = [e.name for e in ents] assert "TestEntity" in names and "AnotherEntity" in names # Verify relationships retrieval rels, total_rels = await graphs_handler.get_relationships( - parent_id=graph_id, offset=0, limit=10 - ) + parent_id=graph_id, offset=0, limit=10) assert total_rels == 1 assert rels[0].predicate == "lives_in" @@ -83,9 +81,8 @@ async def test_add_entities_and_relationships(graphs_handler): async def test_delete_entities_and_relationships(graphs_handler): # Create another graph coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="DeletableGraph" - ) + graph_resp = await graphs_handler.create(collection_id=coll_id, + name="DeletableGraph") graph_id = graph_resp.id # Add entities @@ -117,9 +114,9 @@ async def test_delete_entities_and_relationships(graphs_handler): entity_ids=[e1.id], store_type=StoreType.GRAPHS, ) - ents, count = await graphs_handler.get_entities( - parent_id=graph_id, offset=0, limit=10 - ) + ents, count = await graphs_handler.get_entities(parent_id=graph_id, + offset=0, + limit=10) assert count == 1 assert ents[0].id == e2.id @@ -130,8 +127,7 @@ async def 
test_delete_entities_and_relationships(graphs_handler): store_type=StoreType.GRAPHS, ) rels, rel_count = await graphs_handler.get_relationships( - parent_id=graph_id, offset=0, limit=10 - ) + parent_id=graph_id, offset=0, limit=10) assert rel_count == 0 @@ -182,7 +178,6 @@ async def test_communities(graphs_handler): # # overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id]) # # assert overview["total_entries"] == 0, "Graph should be deleted" - # @pytest.mark.asyncio # async def test_delete_graph(graphs_handler): # # Create a graph and then delete it @@ -244,22 +239,22 @@ async def test_create_graph_defaults(graphs_handler): @pytest.mark.asyncio async def test_update_graph(graphs_handler): coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="OldName", description="OldDescription" - ) + graph_resp = await graphs_handler.create(collection_id=coll_id, + name="OldName", + description="OldDescription") graph_id = graph_resp.id # Update name and description - updated_resp = await graphs_handler.update( - collection_id=graph_id, name="NewName", description="NewDescription" - ) + updated_resp = await graphs_handler.update(collection_id=graph_id, + name="NewName", + description="NewDescription") assert updated_resp.name == "NewName" assert updated_resp.description == "NewDescription" # Retrieve and verify - overview = await graphs_handler.list_graphs( - offset=0, limit=10, filter_graph_ids=[graph_id] - ) + overview = await graphs_handler.list_graphs(offset=0, + limit=10, + filter_graph_ids=[graph_id]) assert overview["total_entries"] == 1 fetched_graph = overview["results"][0] assert fetched_graph.name == "NewName" @@ -269,16 +264,27 @@ async def test_update_graph(graphs_handler): @pytest.mark.asyncio async def test_bulk_entities(graphs_handler): coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="BulkEntities" - ) + graph_resp = await 
graphs_handler.create(collection_id=coll_id, + name="BulkEntities") graph_id = graph_resp.id # Add multiple entities entities_to_add = [ - {"name": "EntityA", "category": "CategoryA", "description": "DescA"}, - {"name": "EntityB", "category": "CategoryB", "description": "DescB"}, - {"name": "EntityC", "category": "CategoryC", "description": "DescC"}, + { + "name": "EntityA", + "category": "CategoryA", + "description": "DescA" + }, + { + "name": "EntityB", + "category": "CategoryB", + "description": "DescB" + }, + { + "name": "EntityC", + "category": "CategoryC", + "description": "DescC" + }, ] for ent in entities_to_add: await graphs_handler.entities.create( @@ -289,9 +295,9 @@ async def test_bulk_entities(graphs_handler): description=ent["description"], ) - ents, total = await graphs_handler.get_entities( - parent_id=graph_id, offset=0, limit=10 - ) + ents, total = await graphs_handler.get_entities(parent_id=graph_id, + offset=0, + limit=10) assert total == 3 fetched_names = [e.name for e in ents] for ent in entities_to_add: @@ -301,21 +307,20 @@ async def test_bulk_entities(graphs_handler): @pytest.mark.asyncio async def test_relationship_filtering(graphs_handler): coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="RelFilteringGraph" - ) + graph_resp = await graphs_handler.create(collection_id=coll_id, + name="RelFilteringGraph") graph_id = graph_resp.id # Add entities - e1 = await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node1" - ) - e2 = await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node2" - ) - e3 = await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node3" - ) + e1 = await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + name="Node1") + e2 = await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + 
name="Node2") + e3 = await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + name="Node3") # Add different relationships await graphs_handler.relationships.create( @@ -340,8 +345,7 @@ async def test_relationship_filtering(graphs_handler): # Get all relationships all_rels, all_count = await graphs_handler.get_relationships( - parent_id=graph_id, offset=0, limit=10 - ) + parent_id=graph_id, offset=0, limit=10) assert all_count == 2 # Filter by relationship_type = ["connected_to"] @@ -358,44 +362,41 @@ async def test_relationship_filtering(graphs_handler): @pytest.mark.asyncio async def test_delete_all_entities(graphs_handler): coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="DeleteAllEntities" - ) + graph_resp = await graphs_handler.create(collection_id=coll_id, + name="DeleteAllEntities") graph_id = graph_resp.id # Add some entities - await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="E1" - ) - await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="E2" - ) + await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + name="E1") + await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + name="E2") # Delete all entities without specifying IDs - await graphs_handler.entities.delete( - parent_id=graph_id, store_type=StoreType.GRAPHS - ) - ents, count = await graphs_handler.get_entities( - parent_id=graph_id, offset=0, limit=10 - ) + await graphs_handler.entities.delete(parent_id=graph_id, + store_type=StoreType.GRAPHS) + ents, count = await graphs_handler.get_entities(parent_id=graph_id, + offset=0, + limit=10) assert count == 0 @pytest.mark.asyncio async def test_delete_all_relationships(graphs_handler): coll_id = uuid.uuid4() - graph_resp = await graphs_handler.create( - collection_id=coll_id, name="DeleteAllRels" - ) + graph_resp = await 
graphs_handler.create(collection_id=coll_id, + name="DeleteAllRels") graph_id = graph_resp.id # Add two entities and a relationship - e1 = await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="E1" - ) - e2 = await graphs_handler.entities.create( - parent_id=graph_id, store_type=StoreType.GRAPHS, name="E2" - ) + e1 = await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + name="E1") + e2 = await graphs_handler.entities.create(parent_id=graph_id, + store_type=StoreType.GRAPHS, + name="E2") await graphs_handler.relationships.create( subject="E1", subject_id=e1.id, @@ -407,12 +408,10 @@ async def test_delete_all_relationships(graphs_handler): ) # Delete all relationships - await graphs_handler.relationships.delete( - parent_id=graph_id, store_type=StoreType.GRAPHS - ) + await graphs_handler.relationships.delete(parent_id=graph_id, + store_type=StoreType.GRAPHS) rels, rel_count = await graphs_handler.get_relationships( - parent_id=graph_id, offset=0, limit=10 - ) + parent_id=graph_id, offset=0, limit=10) assert rel_count == 0 @@ -421,8 +420,7 @@ async def test_error_handling_invalid_graph_id(graphs_handler): # Attempt to get a non-existent graph non_existent_id = uuid.uuid4() overview = await graphs_handler.list_graphs( - offset=0, limit=10, filter_graph_ids=[non_existent_id] - ) + offset=0, limit=10, filter_graph_ids=[non_existent_id]) assert overview["total_entries"] == 0 # Attempt to delete a non-existent graph @@ -462,20 +460,19 @@ async def test_filter_by_collection_ids_in_entities(graphs_handler): VALUES ($1, $2, $3, $4) """ await graphs_handler.connection_manager.execute_query( - insert_entity_sql, [row_id, "TestEntity", some_parent_id, None] - ) + insert_entity_sql, [row_id, "TestEntity", some_parent_id, None]) # 3) Now run your actual test search filter_dict = {"collection_ids": {"$in": [str(some_parent_id)]}} results = [] async for row in graphs_handler.graph_search( - query="anything", 
- search_type="entities", - filters=filter_dict, - limit=10, - use_fulltext_search=False, - use_hybrid_search=False, - query_embedding=[0, 0, 0, 0], + query="anything", + search_type="entities", + filters=filter_dict, + limit=10, + use_fulltext_search=False, + use_hybrid_search=False, + query_embedding=[0, 0, 0, 0], ): results.append(row) @@ -487,15 +484,13 @@ async def test_filter_by_collection_ids_in_entities(graphs_handler): DELETE FROM "{graphs_handler.project_name}"."graphs_entities" WHERE id = $1 """ await graphs_handler.connection_manager.execute_query( - delete_entity_sql, [row_id] - ) + delete_entity_sql, [row_id]) delete_graph_sql = f""" DELETE FROM "{graphs_handler.project_name}"."graphs" WHERE id = $1 """ await graphs_handler.connection_manager.execute_query( - delete_graph_sql, [some_parent_id] - ) + delete_graph_sql, [some_parent_id]) # # TODO - Fix code to pass this test. diff --git a/py/tests/unit/test_limits.py b/py/tests/unit/test_limits.py index f281aa212..43e1b01b5 100644 --- a/py/tests/unit/test_limits.py +++ b/py/tests/unit/test_limits.py @@ -11,8 +11,9 @@ from shared.abstractions import User @pytest.mark.asyncio async def test_log_request_and_count(limits_handler): - """ - Test that when we log requests, the count increments, and rate-limits are enforced. + """Test that when we log requests, the count increments, and rate-limits + are enforced. + Route-specific test using the /v3/retrieval/search endpoint limits. 
""" # Clear existing logs first @@ -51,8 +52,7 @@ async def test_log_request_and_count(limits_handler): now = datetime.now(timezone.utc) one_min_ago = now - timedelta(minutes=1) route_count = await limits_handler._count_requests( - user_id, route, one_min_ago - ) + user_id, route, one_min_ago) print(f"Route count after request {i + 1}: {route_count}") # This should pass for all 5 requests @@ -62,14 +62,12 @@ async def test_log_request_and_count(limits_handler): # Log the 6th request (over limit) await limits_handler.log_request(user_id, route) route_count = await limits_handler._count_requests( - user_id, route, one_min_ago - ) + user_id, route, one_min_ago) print(f"Route count after request 6: {route_count}") # This check should fail as we've exceeded route_per_min=5 - with pytest.raises( - ValueError, match="Per-route per-minute rate limit exceeded" - ): + with pytest.raises(ValueError, + match="Per-route per-minute rate limit exceeded"): await limits_handler.check_limits(test_user, route) finally: @@ -78,9 +76,8 @@ async def test_log_request_and_count(limits_handler): @pytest.mark.asyncio async def test_global_limit(limits_handler): - """ - Test global limit using the configured limit of 10 requests per minute - """ + """Test global limit using the configured limit of 10 requests per + minute.""" # Clear existing logs clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}" await limits_handler.connection_manager.execute_query(clear_query) @@ -98,9 +95,8 @@ async def test_global_limit(limits_handler): # Set global limit to match config: 10 requests per minute old_limits = limits_handler.config.limits - limits_handler.config.limits = LimitSettings( - global_per_min=10, monthly_limit=20 - ) + limits_handler.config.limits = LimitSettings(global_per_min=10, + monthly_limit=20) try: # Initial check should pass (no requests) @@ -115,14 +111,12 @@ async def test_global_limit(limits_handler): now = datetime.now(timezone.utc) 
one_min_ago = now - timedelta(minutes=1) global_count = await limits_handler._count_requests( - user_id, None, one_min_ago - ) + user_id, None, one_min_ago) print(f"Global count after 10 requests: {global_count}") # This should fail as we've hit global_per_min=10 - with pytest.raises( - ValueError, match="Global per-minute rate limit exceeded" - ): + with pytest.raises(ValueError, + match="Global per-minute rate limit exceeded"): await limits_handler.check_limits(test_user, route) finally: @@ -131,9 +125,8 @@ async def test_global_limit(limits_handler): @pytest.mark.asyncio async def test_monthly_limit(limits_handler): - """ - Test monthly limit using the configured limit of 20 requests per month - """ + """Test monthly limit using the configured limit of 20 requests per + month.""" # Clear existing logs clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}" await limits_handler.connection_manager.execute_query(clear_query) @@ -163,12 +156,13 @@ async def test_monthly_limit(limits_handler): # Get current month's count now = datetime.now(timezone.utc) - first_of_month = now.replace( - day=1, hour=0, minute=0, second=0, microsecond=0 - ) + first_of_month = now.replace(day=1, + hour=0, + minute=0, + second=0, + microsecond=0) monthly_count = await limits_handler._count_requests( - user_id, None, first_of_month - ) + user_id, None, first_of_month) print(f"Monthly count after 20 requests: {monthly_count}") # This should fail as we've hit monthly_limit=20 @@ -181,9 +175,7 @@ async def test_monthly_limit(limits_handler): @pytest.mark.asyncio async def test_user_level_override(limits_handler): - """ - Test user-specific override limits with debug logging - """ + """Test user-specific override limits with debug logging.""" user_id = UUID("47e53676-b478-5b3f-a409-234ca2164de5") route = "/test-route" @@ -200,15 +192,18 @@ async def test_user_level_override(limits_handler): limits_overrides={ "global_per_min": 2, "route_per_min": 1, - 
"route_overrides": {"/test-route": {"route_per_min": 1}}, + "route_overrides": { + "/test-route": { + "route_per_min": 1 + } + }, }, ) # Set default limits that should be overridden old_limits = limits_handler.config.limits - limits_handler.config.limits = LimitSettings( - global_per_min=10, monthly_limit=20 - ) + limits_handler.config.limits = LimitSettings(global_per_min=10, + monthly_limit=20) # Debug: Print current limits print(f"\nDefault limits: {limits_handler.config.limits}") @@ -226,11 +221,9 @@ async def test_user_level_override(limits_handler): now = datetime.now(timezone.utc) one_min_ago = now - timedelta(minutes=1) global_count = await limits_handler._count_requests( - user_id, None, one_min_ago - ) + user_id, None, one_min_ago) route_count = await limits_handler._count_requests( - user_id, route, one_min_ago - ) + user_id, route, one_min_ago) print("\nAfter first request:") print(f"Global count: {global_count}") print(f"Route count: {route_count}") @@ -239,9 +232,8 @@ async def test_user_level_override(limits_handler): await limits_handler.log_request(user_id, route) # This check should fail as we've hit route_per_min=1 - with pytest.raises( - ValueError, match="Per-route per-minute rate limit exceeded" - ): + with pytest.raises(ValueError, + match="Per-route per-minute rate limit exceeded"): await limits_handler.check_limits(test_user, route) finally: @@ -251,23 +243,24 @@ async def test_user_level_override(limits_handler): @pytest.mark.asyncio async def test_determine_effective_limits(limits_handler): - """ - Test that user-level overrides > route-level overrides > global defaults. + """Test that user-level overrides > route-level overrides > global + defaults. + This is a pure logic test of the 'determine_effective_limits' method. 
""" # Setup global/base defaults old_limits = limits_handler.config.limits - limits_handler.config.limits = LimitSettings( - global_per_min=10, route_per_min=5, monthly_limit=50 - ) + limits_handler.config.limits = LimitSettings(global_per_min=10, + route_per_min=5, + monthly_limit=50) # Setup route-level override route = "/some-route" old_route_limits = limits_handler.config.route_limits limits_handler.config.route_limits = { - route: LimitSettings( - global_per_min=8, route_per_min=3, monthly_limit=30 - ) + route: LimitSettings(global_per_min=8, + route_per_min=3, + monthly_limit=30) } # Setup user-level override @@ -280,7 +273,9 @@ async def test_determine_effective_limits(limits_handler): limits_overrides={ "global_per_min": 6, # should override "route_overrides": { - route: {"route_per_min": 2} # should override + route: { + "route_per_min": 2 + } # should override }, }, ) @@ -291,18 +286,15 @@ async def test_determine_effective_limits(limits_handler): # Check final / effective limits # Global limit overridden to 6 assert effective.global_per_min == 6, ( - "User-level global override not applied" - ) + "User-level global override not applied") # route_per_min should be overridden to 2 (not the route-level 3) assert effective.route_per_min == 2, ( - "User-level route override not applied" - ) + "User-level route override not applied") # monthly_limit from route-level override is 30, user didn't override it, so it should stay 30 assert effective.monthly_limit == 30, ( - "Route-level monthly override not applied" - ) + "Route-level monthly override not applied") finally: # revert changes limits_handler.config.limits = old_limits @@ -311,10 +303,8 @@ async def test_determine_effective_limits(limits_handler): @pytest.mark.asyncio async def test_separate_route_usage_is_isolated(limits_handler): - """ - Confirm that calls to /routeA do NOT increment the per-route usage for /routeB, - and vice-versa. 
- """ + """Confirm that calls to /routeA do NOT increment the per-route usage for + /routeB, and vice-versa.""" # 1) Clear existing logs clear_query = f"DELETE FROM {limits_handler._get_table_name(limits_handler.TABLE_NAME)}" await limits_handler.connection_manager.execute_query(clear_query) @@ -345,14 +335,12 @@ async def test_separate_route_usage_is_isolated(limits_handler): now = datetime.now(timezone.utc) one_min_ago = now - timedelta(minutes=1) routeA_count = await limits_handler._count_requests( - user_id, routeA, one_min_ago - ) + user_id, routeA, one_min_ago) assert routeA_count == 3, f"Expected 3 for routeA, got {routeA_count}" # 5) Check usage for routeB → Should be 0 routeB_count = await limits_handler._count_requests( - user_id, routeB, one_min_ago - ) + user_id, routeB, one_min_ago) assert routeB_count == 0, f"Expected 0 for routeB, got {routeB_count}" # 6) Insert some logs for routeB only @@ -361,17 +349,13 @@ async def test_separate_route_usage_is_isolated(limits_handler): # 7) Recheck usage routeA_count_after = await limits_handler._count_requests( - user_id, routeA, one_min_ago - ) + user_id, routeA, one_min_ago) routeB_count_after = await limits_handler._count_requests( - user_id, routeB, one_min_ago - ) + user_id, routeB, one_min_ago) assert routeA_count_after == 3, ( - f"RouteA usage changed unexpectedly: {routeA_count_after}" - ) + f"RouteA usage changed unexpectedly: {routeA_count_after}") assert routeB_count_after == 2, ( - f"RouteB usage is wrong: {routeB_count_after}" - ) + f"RouteB usage is wrong: {routeB_count_after}") # @pytest.mark.asyncio @@ -427,10 +411,8 @@ async def test_separate_route_usage_is_isolated(limits_handler): @pytest.mark.asyncio async def test_route_specific_monthly_usage(limits_handler): - """ - Confirm that monthly usage is tracked per-route - and doesn't get incremented by calls to other routes. 
- """ + """Confirm that monthly usage is tracked per-route and doesn't get + incremented by calls to other routes.""" # 1) Clear existing logs clear_query = f"DELETE FROM {limits_handler._get_table_name(limits_handler.TABLE_NAME)}" await limits_handler.connection_manager.execute_query(clear_query) @@ -454,14 +436,12 @@ async def test_route_specific_monthly_usage(limits_handler): # 4) Check monthly usage for routeA => should be 5 routeA_monthly = await limits_handler._count_monthly_requests( - user_id, routeA - ) + user_id, routeA) assert routeA_monthly == 5, f"Expected 5 for routeA, got {routeA_monthly}" # routeB => should still be 0 routeB_monthly = await limits_handler._count_monthly_requests( - user_id, routeB - ) + user_id, routeB) assert routeB_monthly == 0, f"Expected 0 for routeB, got {routeB_monthly}" # 5) Now log 3 requests for routeB @@ -470,22 +450,16 @@ async def test_route_specific_monthly_usage(limits_handler): # Re-check usage routeA_monthly_after = await limits_handler._count_monthly_requests( - user_id, routeA - ) + user_id, routeA) routeB_monthly_after = await limits_handler._count_monthly_requests( - user_id, routeB - ) + user_id, routeB) assert routeA_monthly_after == 5, ( - f"RouteA usage changed unexpectedly: {routeA_monthly_after}" - ) + f"RouteA usage changed unexpectedly: {routeA_monthly_after}") assert routeB_monthly_after == 3, ( - f"RouteB usage is wrong: {routeB_monthly_after}" - ) + f"RouteB usage is wrong: {routeB_monthly_after}") # Additionally confirm total usage across all routes - global_monthly = await limits_handler._count_monthly_requests( - user_id, route=None - ) + global_monthly = await limits_handler._count_monthly_requests(user_id, + route=None) assert global_monthly == 8, ( - f"Expected total of 8 monthly requests, got {global_monthly}" - ) + f"Expected total of 8 monthly requests, got {global_monthly}") diff --git a/py/tests/unit/test_prompts.py b/py/tests/unit/test_prompts.py index 6f61f5511..88490ffa7 100644 --- 
a/py/tests/unit/test_prompts.py +++ b/py/tests/unit/test_prompts.py @@ -7,9 +7,7 @@ import pytest @pytest.mark.asyncio async def test_add_prompt_basic(prompt_handler): - """ - Test basic addition of a new prompt. - """ + """Test basic addition of a new prompt.""" prompt_name = f"test_prompt_{uuid.uuid4()}" template = "Hello, {name}!" input_types = {"name": "str"} @@ -35,9 +33,8 @@ async def test_add_prompt_basic(prompt_handler): @pytest.mark.asyncio async def test_add_prompt_preserve_existing(prompt_handler): - """ - If preserve_existing is True, we skip overwriting even if we supply new data. - """ + """If preserve_existing is True, we skip overwriting even if we supply new + data.""" prompt_name = f"test_preserve_{uuid.uuid4()}" template_original = "Original template" input_types_original = {"param": "str"} @@ -67,10 +64,8 @@ async def test_add_prompt_preserve_existing(prompt_handler): @pytest.mark.asyncio async def test_add_prompt_overwrite_on_diff_false(prompt_handler, caplog): - """ - If overwrite_on_diff=False but there is a diff, skip updating - and log an info message. - """ + """If overwrite_on_diff=False but there is a diff, skip updating and log an + info message.""" prompt_name = f"test_diff_false_{uuid.uuid4()}" template_original = "Original template: {key}" input_types_original = {"key": "str"} @@ -99,16 +94,14 @@ async def test_add_prompt_overwrite_on_diff_false(prompt_handler, caplog): # Check logs for the skipping message assert any( - "Skipping update" in record.message for record in caplog.records - ), "Expected a skip update log message." + "Skipping update" in record.message + for record in caplog.records), "Expected a skip update log message." @pytest.mark.asyncio async def test_add_prompt_overwrite_on_diff_true(prompt_handler, caplog): - """ - If overwrite_on_diff=True and there's a diff, we overwrite existing prompt - and log a warning. 
- """ + """If overwrite_on_diff=True and there's a diff, we overwrite existing + prompt and log a warning.""" prompt_name = f"test_diff_true_{uuid.uuid4()}" template_original = "Original template: {key}" input_types_original = {"key": "str"} @@ -138,15 +131,12 @@ async def test_add_prompt_overwrite_on_diff_true(prompt_handler, caplog): # Check logs for the overwriting warning assert any( "Overwriting existing prompt" in record.message - for record in caplog.records - ), "Expected an overwrite warning message." + for record in caplog.records), "Expected an overwrite warning message." @pytest.mark.asyncio async def test_get_cached_prompt(prompt_handler): - """ - Test that get_cached_prompt uses caching properly. - """ + """Test that get_cached_prompt uses caching properly.""" prompt_name = f"test_cached_{uuid.uuid4()}" template = "Cached template: {key}" input_types = {"key": "str"} @@ -158,9 +148,8 @@ async def test_get_cached_prompt(prompt_handler): ) # First retrieval should set the cache - content_1 = await prompt_handler.get_cached_prompt( - prompt_name, {"key": "Bob"} - ) + content_1 = await prompt_handler.get_cached_prompt(prompt_name, + {"key": "Bob"}) assert "Bob" in content_1 # Modify in DB behind the scenes (simulate a change not going through add_prompt) @@ -172,23 +161,19 @@ async def test_get_cached_prompt(prompt_handler): WHERE name=$2 """ await prompt_handler.connection_manager.execute_query( - query, [new_template, prompt_name] - ) + query, [new_template, prompt_name]) # Second retrieval should still reflect the old template if the cache is not bypassed - content_2 = await prompt_handler.get_cached_prompt( - prompt_name, {"key": "Alice"} - ) + content_2 = await prompt_handler.get_cached_prompt(prompt_name, + {"key": "Alice"}) assert "Bob" not in content_2 # Just to ensure we see the difference assert "Updated in DB" not in content_2, ( - "Should not see updated text if cache is used." 
- ) + "Should not see updated text if cache is used.") assert "Cached template" in content_2 # Bypass cache - content_3 = await prompt_handler.get_cached_prompt( - prompt_name, {"key": "Alice"}, bypass_cache=True - ) + content_3 = await prompt_handler.get_cached_prompt(prompt_name, + {"key": "Alice"}, + bypass_cache=True) assert "Updated in DB" in content_3, ( - "Now we should see the new DB changes after bypassing cache." - ) + "Now we should see the new DB changes after bypassing cache.") diff --git a/py/tests/unit/test_routes.py b/py/tests/unit/test_routes.py index 4a69c3d4d..d47586274 100644 --- a/py/tests/unit/test_routes.py +++ b/py/tests/unit/test_routes.py @@ -94,10 +94,18 @@ def mock_services(): def mock_config(): config_data = { "app": {}, # AppConfig needs minimal data - "auth": {"provider": "mock"}, - "completion": {"provider": "mock"}, - "crypto": {"provider": "mock"}, - "database": {"provider": "mock"}, + "auth": { + "provider": "mock" + }, + "completion": { + "provider": "mock" + }, + "crypto": { + "provider": "mock" + }, + "database": { + "provider": "mock" + }, "embedding": { "provider": "mock", "base_model": "test", @@ -112,11 +120,22 @@ def mock_config(): "batch_size": 10, "add_title_as_prefix": True, }, - "email": {"provider": "mock"}, - "ingestion": {"provider": "mock"}, - "logging": {"provider": "mock", "log_table": "logs"}, - "agent": {"generation_config": {}}, - "orchestration": {"provider": "mock"}, + "email": { + "provider": "mock" + }, + "ingestion": { + "provider": "mock" + }, + "logging": { + "provider": "mock", + "log_table": "logs" + }, + "agent": { + "generation_config": {} + }, + "orchestration": { + "provider": "mock" + }, } return R2RConfig(config_data) @@ -129,38 +148,31 @@ def router(request, mock_providers, mock_services, mock_config): def test_all_routes_have_base_endpoint_decorator(router): for route in router.router.routes: - if ( - route.path.endswith("/stream") - or route.path.endswith("/viewer") - or "websocket" in 
str(type(route)).lower() - ): + if (route.path.endswith("/stream") or route.path.endswith("/viewer") + or "websocket" in str(type(route)).lower()): continue endpoint = route.endpoint assert hasattr(endpoint, "_is_base_endpoint"), ( - f"Route {route.path} missing @base_endpoint decorator" - ) + f"Route {route.path} missing @base_endpoint decorator") def test_all_routes_have_proper_return_type_hints(router): for route in router.router.routes: - if ( - route.path.endswith("/stream") - or "websocket" in str(type(route)).lower() - ): + if (route.path.endswith("/stream") + or "websocket" in str(type(route)).lower()): continue endpoint = route.endpoint return_type = inspect.signature(endpoint).return_annotation # Check if the type is an R2RResults by name - is_valid = isinstance(return_type, type) and ( - "R2RResults" in str(return_type) - or "PaginatedR2RResult" in str(return_type) - or return_type == FileResponse - or return_type == StreamingResponse - or return_type == _TemplateResponse - ) + is_valid = isinstance( + return_type, type) and ("R2RResults" in str(return_type) + or "PaginatedR2RResult" in str(return_type) + or return_type == FileResponse + or return_type == StreamingResponse + or return_type == _TemplateResponse) assert is_valid, ( f"Route {route.path} has invalid return type: {return_type}, expected R2RResults[...]" @@ -173,10 +185,8 @@ def test_all_routes_have_rate_limiting(router): for route in router.router.routes: print(f"Checking route: {route.path}") print(f"Dependencies: {route.dependencies}") - has_rate_limit = any( - dep.dependency == router.rate_limit_dependency - for dep in route.dependencies - ) + has_rate_limit = any(dep.dependency == router.rate_limit_dependency + for dep in route.dependencies) if not has_rate_limit: # We should require this in the future, but for now just warn warnings.warn(