1546 lines
44 KiB
Python
1546 lines
44 KiB
Python
import logging
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from io import BytesIO
|
|
from typing import Any, BinaryIO, Optional, Sequence, Tuple
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from core.base.abstractions import (
|
|
ChunkSearchResult,
|
|
Community,
|
|
DocumentResponse,
|
|
Entity,
|
|
IndexArgsHNSW,
|
|
IndexArgsIVFFlat,
|
|
IndexMeasure,
|
|
IndexMethod,
|
|
KGCreationSettings,
|
|
KGEnrichmentSettings,
|
|
KGEntityDeduplicationSettings,
|
|
Message,
|
|
Relationship,
|
|
SearchSettings,
|
|
User,
|
|
VectorEntry,
|
|
VectorTableName,
|
|
)
|
|
from core.base.api.models import CollectionResponse, GraphResponse
|
|
|
|
from ..logger import RunInfoLog
|
|
from ..logger.base import RunType
|
|
from .base import Provider, ProviderConfig
|
|
|
|
"""Base classes for knowledge graph providers."""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional, Tuple
|
|
from uuid import UUID
|
|
|
|
from ..abstractions import (
|
|
Community,
|
|
Entity,
|
|
GraphSearchSettings,
|
|
KGCreationSettings,
|
|
KGEnrichmentSettings,
|
|
KGEntityDeduplicationSettings,
|
|
KGExtraction,
|
|
Relationship,
|
|
)
|
|
from .base import ProviderConfig
|
|
|
|
logger = logging.getLogger()
|
|
|
|
|
|
class PostgresConfigurationSettings(BaseModel):
|
|
"""
|
|
Configuration settings with defaults defined by the PGVector docker image.
|
|
|
|
These settings are helpful in managing the connections to the database.
|
|
To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/
|
|
"""
|
|
|
|
checkpoint_completion_target: Optional[float] = 0.9
|
|
default_statistics_target: Optional[int] = 100
|
|
effective_io_concurrency: Optional[int] = 1
|
|
effective_cache_size: Optional[int] = 524288
|
|
huge_pages: Optional[str] = "try"
|
|
maintenance_work_mem: Optional[int] = 65536
|
|
max_connections: Optional[int] = 256
|
|
max_parallel_workers_per_gather: Optional[int] = 2
|
|
max_parallel_workers: Optional[int] = 8
|
|
max_parallel_maintenance_workers: Optional[int] = 2
|
|
max_wal_size: Optional[int] = 1024
|
|
max_worker_processes: Optional[int] = 8
|
|
min_wal_size: Optional[int] = 80
|
|
shared_buffers: Optional[int] = 16384
|
|
statement_cache_size: Optional[int] = 100
|
|
random_page_cost: Optional[float] = 4
|
|
wal_buffers: Optional[int] = 512
|
|
work_mem: Optional[int] = 4096
|
|
|
|
|
|
class DatabaseConfig(ProviderConfig):
|
|
"""A base database configuration class"""
|
|
|
|
provider: str = "postgres"
|
|
user: Optional[str] = None
|
|
password: Optional[str] = None
|
|
host: Optional[str] = None
|
|
port: Optional[int] = None
|
|
db_name: Optional[str] = None
|
|
project_name: Optional[str] = None
|
|
postgres_configuration_settings: Optional[
|
|
PostgresConfigurationSettings
|
|
] = None
|
|
default_collection_name: str = "Default"
|
|
default_collection_description: str = "Your default collection."
|
|
enable_fts: bool = False
|
|
|
|
# KG settings
|
|
batch_size: Optional[int] = 1
|
|
kg_store_path: Optional[str] = None
|
|
graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
|
|
graph_creation_settings: KGCreationSettings = KGCreationSettings()
|
|
graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
|
|
KGEntityDeduplicationSettings()
|
|
)
|
|
graph_search_settings: GraphSearchSettings = GraphSearchSettings()
|
|
|
|
def __post_init__(self):
|
|
self.validate_config()
|
|
# Capture additional fields
|
|
for key, value in self.extra_fields.items():
|
|
setattr(self, key, value)
|
|
|
|
def validate_config(self) -> None:
|
|
if self.provider not in self.supported_providers:
|
|
raise ValueError(f"Provider '{self.provider}' is not supported.")
|
|
|
|
@property
|
|
def supported_providers(self) -> list[str]:
|
|
return ["postgres"]
|
|
|
|
|
|
class DatabaseConnectionManager(ABC):
|
|
@abstractmethod
|
|
def execute_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
isolation_level: Optional[str] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def execute_many(self, query, params=None, batch_size=1000):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def fetch_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def fetchrow_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def initialize(self, pool: Any):
|
|
pass
|
|
|
|
|
|
class Handler(ABC):
|
|
def __init__(
|
|
self,
|
|
project_name: str,
|
|
connection_manager: DatabaseConnectionManager,
|
|
):
|
|
self.project_name = project_name
|
|
self.connection_manager = connection_manager
|
|
|
|
def _get_table_name(self, base_name: str) -> str:
|
|
return f"{self.project_name}.{base_name}"
|
|
|
|
@abstractmethod
|
|
def create_tables(self):
|
|
pass
|
|
|
|
|
|
class DocumentHandler(Handler):
|
|
|
|
@abstractmethod
|
|
async def upsert_documents_overview(
|
|
self,
|
|
documents_overview: DocumentResponse | list[DocumentResponse],
|
|
) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_from_documents_overview(
|
|
self,
|
|
document_id: UUID,
|
|
version: Optional[str] = None,
|
|
) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_documents_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filter_user_ids: Optional[list[UUID]] = None,
|
|
filter_document_ids: Optional[list[UUID]] = None,
|
|
filter_collection_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, Any]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_workflow_status(
|
|
self,
|
|
id: UUID | list[UUID],
|
|
status_type: str,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def set_workflow_status(
|
|
self,
|
|
id: UUID | list[UUID],
|
|
status_type: str,
|
|
status: str,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_document_ids_by_status(
|
|
self,
|
|
status_type: str,
|
|
status: str | list[str],
|
|
collection_id: Optional[UUID] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def search_documents(
|
|
self,
|
|
query_text: str,
|
|
query_embedding: Optional[list[float]] = None,
|
|
search_settings: Optional[SearchSettings] = None,
|
|
) -> list[DocumentResponse]:
|
|
pass
|
|
|
|
|
|
class CollectionsHandler(Handler):
|
|
@abstractmethod
|
|
async def collection_exists(self, collection_id: UUID) -> bool:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_collection(
|
|
self,
|
|
owner_id: UUID,
|
|
name: Optional[str] = None,
|
|
description: str = "",
|
|
collection_id: Optional[UUID] = None,
|
|
) -> CollectionResponse:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update_collection(
|
|
self,
|
|
collection_id: UUID,
|
|
name: Optional[str] = None,
|
|
description: Optional[str] = None,
|
|
) -> CollectionResponse:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_collection_relational(self, collection_id: UUID) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def documents_in_collection(
|
|
self, collection_id: UUID, offset: int, limit: int
|
|
) -> dict[str, list[DocumentResponse] | int]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_collections_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filter_user_ids: Optional[list[UUID]] = None,
|
|
filter_document_ids: Optional[list[UUID]] = None,
|
|
filter_collection_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[CollectionResponse] | int]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def assign_document_to_collection_relational(
|
|
self,
|
|
document_id: UUID,
|
|
collection_id: UUID,
|
|
) -> UUID:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def remove_document_from_collection_relational(
|
|
self, document_id: UUID, collection_id: UUID
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
class TokenHandler(Handler):
|
|
|
|
@abstractmethod
|
|
async def create_tables(self):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def blacklist_token(
|
|
self, token: str, current_time: Optional[datetime] = None
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def is_token_blacklisted(self, token: str) -> bool:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def clean_expired_blacklisted_tokens(
|
|
self,
|
|
max_age_hours: int = 7 * 24,
|
|
current_time: Optional[datetime] = None,
|
|
):
|
|
pass
|
|
|
|
|
|
class UserHandler(Handler):
|
|
TABLE_NAME = "users"
|
|
|
|
@abstractmethod
|
|
async def get_user_by_id(self, user_id: UUID) -> User:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_user_by_email(self, email: str) -> User:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_user(
|
|
self, email: str, password: str, is_superuser: bool
|
|
) -> User:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update_user(self, user: User) -> User:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_user_relational(self, user_id: UUID) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update_user_password(
|
|
self, user_id: UUID, new_hashed_password: str
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_all_users(self) -> list[User]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def store_verification_code(
|
|
self, user_id: UUID, verification_code: str, expiry: datetime
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def verify_user(self, verification_code: str) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def remove_verification_code(self, verification_code: str):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def expire_verification_code(self, user_id: UUID):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def store_reset_token(
|
|
self, user_id: UUID, reset_token: str, expiry: datetime
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_user_id_by_reset_token(
|
|
self, reset_token: str
|
|
) -> Optional[UUID]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def remove_reset_token(self, user_id: UUID):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def remove_user_from_all_collections(self, user_id: UUID):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def add_user_to_collection(
|
|
self, user_id: UUID, collection_id: UUID
|
|
) -> bool:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def remove_user_from_collection(
|
|
self, user_id: UUID, collection_id: UUID
|
|
) -> bool:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_users_in_collection(
|
|
self, collection_id: UUID, offset: int, limit: int
|
|
) -> dict[str, list[User] | int]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def mark_user_as_superuser(self, user_id: UUID):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_user_id_by_verification_code(
|
|
self, verification_code: str
|
|
) -> Optional[UUID]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def mark_user_as_verified(self, user_id: UUID):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_users_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
user_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[User] | int]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_user_validation_data(
|
|
self,
|
|
user_id: UUID,
|
|
) -> dict:
|
|
"""
|
|
Get verification data for a specific user.
|
|
This method should be called after superuser authorization has been verified.
|
|
"""
|
|
pass
|
|
|
|
|
|
class ChunkHandler(Handler):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
|
|
@abstractmethod
|
|
async def upsert(self, entry: VectorEntry) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def upsert_entries(self, entries: list[VectorEntry]) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def semantic_search(
|
|
self, query_vector: list[float], search_settings: SearchSettings
|
|
) -> list[ChunkSearchResult]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def full_text_search(
|
|
self, query_text: str, search_settings: SearchSettings
|
|
) -> list[ChunkSearchResult]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def hybrid_search(
|
|
self,
|
|
query_text: str,
|
|
query_vector: list[float],
|
|
search_settings: SearchSettings,
|
|
*args,
|
|
**kwargs,
|
|
) -> list[ChunkSearchResult]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete(
|
|
self, filters: dict[str, Any]
|
|
) -> dict[str, dict[str, str]]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def assign_document_to_collection_vector(
|
|
self, document_id: UUID, collection_id: UUID
|
|
) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def remove_document_from_collection_vector(
|
|
self, document_id: UUID, collection_id: UUID
|
|
) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_user_vector(self, user_id: UUID) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_collection_vector(self, collection_id: UUID) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def list_document_chunks(
|
|
self,
|
|
document_id: UUID,
|
|
offset: int,
|
|
limit: int,
|
|
include_vectors: bool = False,
|
|
) -> dict[str, Any]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_chunk(self, chunk_id: UUID) -> dict:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_index(
|
|
self,
|
|
table_name: Optional[VectorTableName] = None,
|
|
index_measure: IndexMeasure = IndexMeasure.cosine_distance,
|
|
index_method: IndexMethod = IndexMethod.auto,
|
|
index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
|
|
index_name: Optional[str] = None,
|
|
index_column: Optional[str] = None,
|
|
concurrently: bool = True,
|
|
) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def list_indices(
|
|
self, offset: int, limit: int, filters: Optional[dict] = None
|
|
) -> dict:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_index(
|
|
self,
|
|
index_name: str,
|
|
table_name: Optional[VectorTableName] = None,
|
|
concurrently: bool = True,
|
|
) -> None:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_semantic_neighbors(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
document_id: UUID,
|
|
chunk_id: UUID,
|
|
similarity_threshold: float = 0.5,
|
|
) -> list[dict[str, Any]]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def list_chunks(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filters: Optional[dict[str, Any]] = None,
|
|
include_vectors: bool = False,
|
|
) -> dict[str, Any]:
|
|
pass
|
|
|
|
|
|
class EntityHandler(Handler):
|
|
|
|
@abstractmethod
|
|
async def create(self, *args: Any, **kwargs: Any) -> Entity:
|
|
"""Create entities in storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get(self, *args: Any, **kwargs: Any) -> list[Entity]:
|
|
"""Get entities from storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update(self, *args: Any, **kwargs: Any) -> Entity:
|
|
"""Update entities in storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete(self, *args: Any, **kwargs: Any) -> None:
|
|
"""Delete entities from storage."""
|
|
pass
|
|
|
|
|
|
class RelationshipHandler(Handler):
|
|
@abstractmethod
|
|
async def create(self, *args: Any, **kwargs: Any) -> Relationship:
|
|
"""Add relationships to storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get(self, *args: Any, **kwargs: Any) -> list[Relationship]:
|
|
"""Get relationships from storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update(self, *args: Any, **kwargs: Any) -> Relationship:
|
|
"""Update relationships in storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete(self, *args: Any, **kwargs: Any) -> None:
|
|
"""Delete relationships from storage."""
|
|
pass
|
|
|
|
|
|
class CommunityHandler(Handler):
|
|
@abstractmethod
|
|
async def create(self, *args: Any, **kwargs: Any) -> Community:
|
|
"""Create communities in storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get(self, *args: Any, **kwargs: Any) -> list[Community]:
|
|
"""Get communities from storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update(self, *args: Any, **kwargs: Any) -> Community:
|
|
"""Update communities in storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete(self, *args: Any, **kwargs: Any) -> None:
|
|
"""Delete communities from storage."""
|
|
pass
|
|
|
|
|
|
class GraphHandler(Handler):
|
|
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
super().__init__(*args, **kwargs)
|
|
|
|
@abstractmethod
|
|
async def create(self, *args: Any, **kwargs: Any) -> GraphResponse:
|
|
"""Create graph"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update(
|
|
self,
|
|
graph_id: UUID,
|
|
name: Optional[str],
|
|
description: Optional[str],
|
|
) -> GraphResponse:
|
|
"""Update graph"""
|
|
pass
|
|
|
|
|
|
class PromptHandler(Handler):
|
|
"""Abstract base class for prompt handling operations."""
|
|
|
|
@abstractmethod
|
|
async def add_prompt(
|
|
self, name: str, template: str, input_types: dict[str, str]
|
|
) -> None:
|
|
"""Add a new prompt template to the database."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_cached_prompt(
|
|
self,
|
|
prompt_name: str,
|
|
inputs: Optional[dict[str, Any]] = None,
|
|
prompt_override: Optional[str] = None,
|
|
) -> str:
|
|
"""Retrieve and format a prompt template."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_prompt(
|
|
self,
|
|
prompt_name: str,
|
|
inputs: Optional[dict[str, Any]] = None,
|
|
prompt_override: Optional[str] = None,
|
|
) -> str:
|
|
"""Retrieve and format a prompt template."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_all_prompts(self) -> dict[str, Any]:
|
|
"""Retrieve all stored prompts."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update_prompt(
|
|
self,
|
|
name: str,
|
|
template: Optional[str] = None,
|
|
input_types: Optional[dict[str, str]] = None,
|
|
) -> None:
|
|
"""Update an existing prompt template."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_prompt(self, name: str) -> None:
|
|
"""Delete a prompt template."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_message_payload(
|
|
self,
|
|
system_prompt_name: Optional[str] = None,
|
|
system_role: str = "system",
|
|
system_inputs: dict = {},
|
|
system_prompt_override: Optional[str] = None,
|
|
task_prompt_name: Optional[str] = None,
|
|
task_role: str = "user",
|
|
task_inputs: dict = {},
|
|
task_prompt_override: Optional[str] = None,
|
|
) -> list[dict]:
|
|
"""Get the payload of a prompt."""
|
|
pass
|
|
|
|
|
|
class FileHandler(Handler):
|
|
"""Abstract base class for file handling operations."""
|
|
|
|
@abstractmethod
|
|
async def upsert_file(
|
|
self,
|
|
document_id: UUID,
|
|
file_name: str,
|
|
file_oid: int,
|
|
file_size: int,
|
|
file_type: Optional[str] = None,
|
|
) -> None:
|
|
"""Add or update a file entry in storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def store_file(
|
|
self,
|
|
document_id: UUID,
|
|
file_name: str,
|
|
file_content: BytesIO,
|
|
file_type: Optional[str] = None,
|
|
) -> None:
|
|
"""Store a new file in the database."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def retrieve_file(
|
|
self, document_id: UUID
|
|
) -> Optional[tuple[str, BinaryIO, int]]:
|
|
"""Retrieve a file from storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_file(self, document_id: UUID) -> bool:
|
|
"""Delete a file from storage."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_files_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filter_document_ids: Optional[list[UUID]] = None,
|
|
filter_file_names: Optional[list[str]] = None,
|
|
) -> list[dict]:
|
|
"""Get an overview of stored files."""
|
|
pass
|
|
|
|
|
|
class LoggingHandler(Handler):
|
|
"""Abstract base class defining the interface for logging handlers."""
|
|
|
|
@abstractmethod
|
|
async def close(self) -> None:
|
|
"""Close any open connections."""
|
|
pass
|
|
|
|
# Basic logging methods
|
|
@abstractmethod
|
|
async def log(self, run_id: UUID, key: str, value: str) -> None:
|
|
"""Log a key-value pair for a specific run."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def info_log(
|
|
self, run_id: UUID, run_type: RunType, user_id: UUID
|
|
) -> None:
|
|
"""Log run information."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_logs(
|
|
self, run_ids: list[UUID], limit_per_run: int = 10
|
|
) -> list[dict]:
|
|
"""Retrieve logs for specified run IDs."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_info_logs(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
run_type_filter: Optional[RunType] = None,
|
|
user_ids: Optional[list[UUID]] = None,
|
|
) -> list[RunInfoLog]:
|
|
"""Retrieve run information logs with filtering options."""
|
|
pass
|
|
|
|
# Conversation management methods
|
|
@abstractmethod
|
|
async def create_conversation(self) -> dict:
|
|
"""Create a new conversation and return its ID."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_conversation(self, conversation_id: str) -> None:
|
|
"""Delete a conversation and all associated data."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_conversations(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
conversation_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[dict] | int]:
|
|
"""Get an overview of conversations with pagination."""
|
|
pass
|
|
|
|
# Message management methods
|
|
@abstractmethod
|
|
async def add_message(
|
|
self,
|
|
conversation_id: str,
|
|
content: Message,
|
|
parent_id: Optional[str] = None,
|
|
metadata: Optional[dict] = None,
|
|
) -> str:
|
|
"""Add a message to a conversation."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def edit_message(
|
|
self, message_id: str, new_content: str
|
|
) -> Tuple[str, str]:
|
|
"""Edit an existing message and return new message ID and branch ID."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_conversation(
|
|
self, conversation_id: str, branch_id: Optional[str] = None
|
|
) -> list[Tuple[str, Message]]:
|
|
"""Retrieve all messages in a conversation branch."""
|
|
pass
|
|
|
|
# Branch management methods
|
|
@abstractmethod
|
|
async def get_branches(self, conversation_id: str) -> list[dict]:
|
|
"""Get an overview of all branches in a conversation."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_next_branch(self, current_branch_id: str) -> Optional[str]:
|
|
"""Get the ID of the next branch in chronological order."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_prev_branch(self, current_branch_id: str) -> Optional[str]:
|
|
"""Get the ID of the previous branch in chronological order."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def branch_at_message(self, message_id: str) -> str:
|
|
"""Create a new branch starting at a specific message."""
|
|
pass
|
|
|
|
|
|
class DatabaseProvider(Provider):
|
|
connection_manager: DatabaseConnectionManager
|
|
document_handler: DocumentHandler
|
|
collections_handler: CollectionsHandler
|
|
token_handler: TokenHandler
|
|
user_handler: UserHandler
|
|
vector_handler: ChunkHandler
|
|
entity_handler: EntityHandler
|
|
relationship_handler: RelationshipHandler
|
|
graph_handler: GraphHandler
|
|
prompt_handler: PromptHandler
|
|
file_handler: FileHandler
|
|
logging_handler: LoggingHandler
|
|
config: DatabaseConfig
|
|
project_name: str
|
|
|
|
def __init__(self, config: DatabaseConfig):
|
|
logger.info(f"Initializing DatabaseProvider with config {config}.")
|
|
super().__init__(config)
|
|
|
|
@abstractmethod
|
|
async def __aenter__(self):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def __aexit__(self, exc_type, exc, tb):
|
|
pass
|
|
|
|
# Document handler methods
|
|
async def upsert_documents_overview(
|
|
self,
|
|
documents_overview: DocumentResponse | list[DocumentResponse],
|
|
) -> None:
|
|
return await self.document_handler.upsert_documents_overview(
|
|
documents_overview
|
|
)
|
|
|
|
async def delete_from_documents_overview(
|
|
self, document_id: UUID, version: Optional[str] = None
|
|
) -> None:
|
|
return await self.document_handler.delete_from_documents_overview(
|
|
document_id, version
|
|
)
|
|
|
|
async def get_documents_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filter_user_ids: Optional[list[UUID]] = None,
|
|
filter_document_ids: Optional[list[UUID]] = None,
|
|
filter_collection_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, Any]:
|
|
return await self.document_handler.get_documents_overview(
|
|
offset=offset,
|
|
limit=limit,
|
|
filter_user_ids=filter_user_ids,
|
|
filter_document_ids=filter_document_ids,
|
|
filter_collection_ids=filter_collection_ids,
|
|
)
|
|
|
|
async def get_workflow_status(
|
|
self, id: UUID | list[UUID], status_type: str
|
|
):
|
|
return await self.document_handler.get_workflow_status(id, status_type)
|
|
|
|
async def set_workflow_status(
|
|
self,
|
|
id: UUID | list[UUID],
|
|
status_type: str,
|
|
status: str,
|
|
):
|
|
return await self.document_handler.set_workflow_status(
|
|
id, status_type, status
|
|
)
|
|
|
|
async def get_document_ids_by_status(
|
|
self,
|
|
status_type: str,
|
|
status: str | list[str],
|
|
collection_id: Optional[UUID] = None,
|
|
):
|
|
return await self.document_handler.get_document_ids_by_status(
|
|
status_type, status, collection_id
|
|
)
|
|
|
|
# Collection handler methods
|
|
async def collection_exists(self, collection_id: UUID) -> bool:
|
|
return await self.collections_handler.collection_exists(collection_id)
|
|
|
|
async def create_collection(
|
|
self,
|
|
owner_id: UUID,
|
|
name: Optional[str] = None,
|
|
description: str = "",
|
|
collection_id: Optional[UUID] = None,
|
|
) -> CollectionResponse:
|
|
return await self.collections_handler.create_collection(
|
|
owner_id=owner_id,
|
|
name=name,
|
|
description=description,
|
|
collection_id=collection_id,
|
|
)
|
|
|
|
async def update_collection(
|
|
self,
|
|
collection_id: UUID,
|
|
name: Optional[str] = None,
|
|
description: Optional[str] = None,
|
|
) -> CollectionResponse:
|
|
return await self.collections_handler.update_collection(
|
|
collection_id, name, description
|
|
)
|
|
|
|
async def delete_collection_relational(self, collection_id: UUID) -> None:
|
|
return await self.collections_handler.delete_collection_relational(
|
|
collection_id
|
|
)
|
|
|
|
async def documents_in_collection(
|
|
self, collection_id: UUID, offset: int, limit: int
|
|
) -> dict[str, list[DocumentResponse] | int]:
|
|
return await self.collections_handler.documents_in_collection(
|
|
collection_id, offset, limit
|
|
)
|
|
|
|
async def get_collections_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filter_user_ids: Optional[list[UUID]] = None,
|
|
filter_document_ids: Optional[list[UUID]] = None,
|
|
filter_collection_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[CollectionResponse] | int]:
|
|
return await self.collections_handler.get_collections_overview(
|
|
offset=offset,
|
|
limit=limit,
|
|
filter_user_ids=filter_user_ids,
|
|
filter_document_ids=filter_document_ids,
|
|
filter_collection_ids=filter_collection_ids,
|
|
)
|
|
|
|
async def assign_document_to_collection_relational(
|
|
self,
|
|
document_id: UUID,
|
|
collection_id: UUID,
|
|
) -> UUID:
|
|
return await self.collections_handler.assign_document_to_collection_relational(
|
|
document_id=document_id,
|
|
collection_id=collection_id,
|
|
)
|
|
|
|
async def remove_document_from_collection_relational(
|
|
self, document_id: UUID, collection_id: UUID
|
|
) -> None:
|
|
return await self.collections_handler.remove_document_from_collection_relational(
|
|
document_id, collection_id
|
|
)
|
|
|
|
# Token handler methods
|
|
async def blacklist_token(
|
|
self, token: str, current_time: Optional[datetime] = None
|
|
):
|
|
return await self.token_handler.blacklist_token(token, current_time)
|
|
|
|
async def is_token_blacklisted(self, token: str) -> bool:
|
|
return await self.token_handler.is_token_blacklisted(token)
|
|
|
|
async def clean_expired_blacklisted_tokens(
|
|
self,
|
|
max_age_hours: int = 7 * 24,
|
|
current_time: Optional[datetime] = None,
|
|
):
|
|
return await self.token_handler.clean_expired_blacklisted_tokens(
|
|
max_age_hours, current_time
|
|
)
|
|
|
|
# User handler methods
|
|
async def get_user_by_id(self, user_id: UUID) -> User:
|
|
return await self.user_handler.get_user_by_id(user_id)
|
|
|
|
async def get_user_by_email(self, email: str) -> User:
|
|
return await self.user_handler.get_user_by_email(email)
|
|
|
|
async def create_user(
|
|
self, email: str, password: str, is_superuser: bool = False
|
|
) -> User:
|
|
return await self.user_handler.create_user(
|
|
email=email,
|
|
password=password,
|
|
is_superuser=is_superuser,
|
|
)
|
|
|
|
async def update_user(self, user: User) -> User:
|
|
return await self.user_handler.update_user(user)
|
|
|
|
async def delete_user_relational(self, user_id: UUID) -> None:
|
|
return await self.user_handler.delete_user_relational(user_id)
|
|
|
|
async def update_user_password(
|
|
self, user_id: UUID, new_hashed_password: str
|
|
):
|
|
return await self.user_handler.update_user_password(
|
|
user_id, new_hashed_password
|
|
)
|
|
|
|
async def get_all_users(self) -> list[User]:
|
|
return await self.user_handler.get_all_users()
|
|
|
|
async def store_verification_code(
|
|
self, user_id: UUID, verification_code: str, expiry: datetime
|
|
):
|
|
return await self.user_handler.store_verification_code(
|
|
user_id, verification_code, expiry
|
|
)
|
|
|
|
async def verify_user(self, verification_code: str) -> None:
|
|
return await self.user_handler.verify_user(verification_code)
|
|
|
|
async def remove_verification_code(self, verification_code: str):
|
|
return await self.user_handler.remove_verification_code(
|
|
verification_code
|
|
)
|
|
|
|
async def expire_verification_code(self, user_id: UUID):
|
|
return await self.user_handler.expire_verification_code(user_id)
|
|
|
|
async def store_reset_token(
|
|
self, user_id: UUID, reset_token: str, expiry: datetime
|
|
):
|
|
return await self.user_handler.store_reset_token(
|
|
user_id, reset_token, expiry
|
|
)
|
|
|
|
async def get_user_id_by_reset_token(
|
|
self, reset_token: str
|
|
) -> Optional[UUID]:
|
|
return await self.user_handler.get_user_id_by_reset_token(reset_token)
|
|
|
|
async def remove_reset_token(self, user_id: UUID):
|
|
return await self.user_handler.remove_reset_token(user_id)
|
|
|
|
async def remove_user_from_all_collections(self, user_id: UUID):
|
|
return await self.user_handler.remove_user_from_all_collections(
|
|
user_id
|
|
)
|
|
|
|
async def add_user_to_collection(
|
|
self, user_id: UUID, collection_id: UUID
|
|
) -> bool:
|
|
return await self.user_handler.add_user_to_collection(
|
|
user_id, collection_id
|
|
)
|
|
|
|
async def remove_user_from_collection(
|
|
self, user_id: UUID, collection_id: UUID
|
|
) -> bool:
|
|
return await self.user_handler.remove_user_from_collection(
|
|
user_id, collection_id
|
|
)
|
|
|
|
async def get_users_in_collection(
|
|
self, collection_id: UUID, offset: int, limit: int
|
|
) -> dict[str, list[User] | int]:
|
|
return await self.user_handler.get_users_in_collection(
|
|
collection_id, offset, limit
|
|
)
|
|
|
|
async def mark_user_as_superuser(self, user_id: UUID):
|
|
return await self.user_handler.mark_user_as_superuser(user_id)
|
|
|
|
async def get_user_id_by_verification_code(
|
|
self, verification_code: str
|
|
) -> Optional[UUID]:
|
|
return await self.user_handler.get_user_id_by_verification_code(
|
|
verification_code
|
|
)
|
|
|
|
async def mark_user_as_verified(self, user_id: UUID):
|
|
return await self.user_handler.mark_user_as_verified(user_id)
|
|
|
|
async def get_users_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
user_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[User] | int]:
|
|
return await self.user_handler.get_users_overview(
|
|
offset=offset,
|
|
limit=limit,
|
|
user_ids=user_ids,
|
|
)
|
|
|
|
async def get_user_validation_data(
|
|
self,
|
|
user_id: UUID,
|
|
) -> dict:
|
|
return await self.user_handler.get_user_validation_data(
|
|
user_id=user_id
|
|
)
|
|
|
|
# Vector handler methods
|
|
async def upsert(self, entry: VectorEntry) -> None:
|
|
return await self.vector_handler.upsert(entry)
|
|
|
|
async def upsert_entries(self, entries: list[VectorEntry]) -> None:
|
|
return await self.vector_handler.upsert_entries(entries)
|
|
|
|
async def semantic_search(
|
|
self, query_vector: list[float], search_settings: SearchSettings
|
|
) -> list[ChunkSearchResult]:
|
|
return await self.vector_handler.semantic_search(
|
|
query_vector, search_settings
|
|
)
|
|
|
|
async def full_text_search(
|
|
self, query_text: str, search_settings: SearchSettings
|
|
) -> list[ChunkSearchResult]:
|
|
return await self.vector_handler.full_text_search(
|
|
query_text, search_settings
|
|
)
|
|
|
|
async def hybrid_search(
|
|
self,
|
|
query_text: str,
|
|
query_vector: list[float],
|
|
search_settings: SearchSettings,
|
|
*args,
|
|
**kwargs,
|
|
) -> list[ChunkSearchResult]:
|
|
return await self.vector_handler.hybrid_search(
|
|
query_text, query_vector, search_settings, *args, **kwargs
|
|
)
|
|
|
|
async def search_documents(
|
|
self,
|
|
query_text: str,
|
|
settings: SearchSettings,
|
|
query_embedding: Optional[list[float]] = None,
|
|
) -> list[DocumentResponse]:
|
|
return await self.document_handler.search_documents(
|
|
query_text, query_embedding, settings
|
|
)
|
|
|
|
async def delete(
|
|
self, filters: dict[str, Any]
|
|
) -> dict[str, dict[str, str]]:
|
|
result = await self.vector_handler.delete(filters)
|
|
try:
|
|
await self.entity_handler.delete(parent_id=filters["id"]["$eq"])
|
|
except Exception as e:
|
|
logger.debug(f"Attempt to delete entity failed: {e}")
|
|
try:
|
|
await self.relationship_handler.delete(
|
|
parent_id=filters["id"]["$eq"]
|
|
)
|
|
except Exception as e:
|
|
logger.debug(f"Attempt to delete relationship failed: {e}")
|
|
return result
|
|
|
|
async def assign_document_to_collection_vector(
|
|
self,
|
|
document_id: UUID,
|
|
collection_id: UUID,
|
|
) -> None:
|
|
return await self.vector_handler.assign_document_to_collection_vector(
|
|
document_id=document_id,
|
|
collection_id=collection_id,
|
|
)
|
|
|
|
async def remove_document_from_collection_vector(
|
|
self,
|
|
document_id: UUID,
|
|
collection_id: UUID,
|
|
) -> None:
|
|
return (
|
|
await self.vector_handler.remove_document_from_collection_vector(
|
|
document_id, collection_id
|
|
)
|
|
)
|
|
|
|
async def delete_user_vector(self, user_id: UUID) -> None:
|
|
return await self.vector_handler.delete_user_vector(user_id)
|
|
|
|
async def delete_collection_vector(self, collection_id: UUID) -> None:
|
|
return await self.vector_handler.delete_collection_vector(
|
|
collection_id
|
|
)
|
|
|
|
async def list_document_chunks(
|
|
self,
|
|
document_id: UUID,
|
|
offset: int,
|
|
limit: int,
|
|
include_vectors: bool = False,
|
|
) -> dict[str, Any]:
|
|
return await self.vector_handler.list_document_chunks(
|
|
document_id=document_id,
|
|
offset=offset,
|
|
limit=limit,
|
|
include_vectors=include_vectors,
|
|
)
|
|
|
|
async def get_chunk(self, chunk_id: UUID) -> dict:
|
|
return await self.vector_handler.get_chunk(chunk_id)
|
|
|
|
async def create_index(
|
|
self,
|
|
table_name: Optional[VectorTableName] = None,
|
|
index_measure: IndexMeasure = IndexMeasure.cosine_distance,
|
|
index_method: IndexMethod = IndexMethod.auto,
|
|
index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
|
|
index_name: Optional[str] = None,
|
|
index_column: Optional[str] = None,
|
|
concurrently: bool = True,
|
|
) -> None:
|
|
return await self.vector_handler.create_index(
|
|
table_name,
|
|
index_measure,
|
|
index_method,
|
|
index_arguments,
|
|
index_name,
|
|
index_column,
|
|
concurrently,
|
|
)
|
|
|
|
async def list_indices(
|
|
self, offset: int, limit: int, filters: Optional[dict] = None
|
|
) -> dict:
|
|
return await self.vector_handler.list_indices(offset, limit, filters)
|
|
|
|
async def delete_index(
|
|
self,
|
|
index_name: str,
|
|
table_name: Optional[VectorTableName] = None,
|
|
concurrently: bool = True,
|
|
) -> None:
|
|
return await self.vector_handler.delete_index(
|
|
index_name, table_name, concurrently
|
|
)
|
|
|
|
async def get_semantic_neighbors(
|
|
self,
|
|
document_id: UUID,
|
|
chunk_id: UUID,
|
|
offset: int,
|
|
limit: int,
|
|
similarity_threshold: float = 0.5,
|
|
) -> list[dict[str, Any]]:
|
|
return await self.vector_handler.get_semantic_neighbors(
|
|
offset=offset,
|
|
limit=limit,
|
|
document_id=document_id,
|
|
chunk_id=chunk_id,
|
|
similarity_threshold=similarity_threshold,
|
|
)
|
|
|
|
async def add_prompt(
|
|
self, name: str, template: str, input_types: dict[str, str]
|
|
) -> None:
|
|
return await self.prompt_handler.add_prompt(
|
|
name, template, input_types
|
|
)
|
|
|
|
async def get_cached_prompt(
|
|
self,
|
|
prompt_name: str,
|
|
inputs: Optional[dict[str, Any]] = None,
|
|
prompt_override: Optional[str] = None,
|
|
) -> str:
|
|
return await self.prompt_handler.get_cached_prompt(
|
|
prompt_name, inputs, prompt_override
|
|
)
|
|
|
|
async def get_prompt(
|
|
self,
|
|
prompt_name: str,
|
|
inputs: Optional[dict[str, Any]] = None,
|
|
prompt_override: Optional[str] = None,
|
|
) -> str:
|
|
return await self.prompt_handler.get_prompt(
|
|
prompt_name, inputs, prompt_override
|
|
)
|
|
|
|
async def get_all_prompts(self) -> dict[str, Any]:
|
|
return await self.prompt_handler.get_all_prompts()
|
|
|
|
async def update_prompt(
|
|
self,
|
|
name: str,
|
|
template: Optional[str] = None,
|
|
input_types: Optional[dict[str, str]] = None,
|
|
) -> None:
|
|
return await self.prompt_handler.update_prompt(
|
|
name, template, input_types
|
|
)
|
|
|
|
async def delete_prompt(self, name: str) -> None:
|
|
return await self.prompt_handler.delete_prompt(name)
|
|
|
|
async def upsert_file(
|
|
self,
|
|
document_id: UUID,
|
|
file_name: str,
|
|
file_oid: int,
|
|
file_size: int,
|
|
file_type: Optional[str] = None,
|
|
) -> None:
|
|
return await self.file_handler.upsert_file(
|
|
document_id, file_name, file_oid, file_size, file_type
|
|
)
|
|
|
|
async def store_file(
|
|
self,
|
|
document_id: UUID,
|
|
file_name: str,
|
|
file_content: BytesIO,
|
|
file_type: Optional[str] = None,
|
|
) -> None:
|
|
return await self.file_handler.store_file(
|
|
document_id, file_name, file_content, file_type
|
|
)
|
|
|
|
async def retrieve_file(
|
|
self, document_id: UUID
|
|
) -> Optional[tuple[str, BinaryIO, int]]:
|
|
return await self.file_handler.retrieve_file(document_id)
|
|
|
|
async def delete_file(self, document_id: UUID) -> bool:
|
|
return await self.file_handler.delete_file(document_id)
|
|
|
|
async def get_files_overview(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filter_document_ids: Optional[list[UUID]] = None,
|
|
filter_file_names: Optional[list[str]] = None,
|
|
) -> list[dict]:
|
|
return await self.file_handler.get_files_overview(
|
|
offset=offset,
|
|
limit=limit,
|
|
filter_document_ids=filter_document_ids,
|
|
filter_file_names=filter_file_names,
|
|
)
|
|
|
|
async def log(
|
|
self,
|
|
run_id: UUID,
|
|
key: str,
|
|
value: str,
|
|
) -> None:
|
|
"""Add a new log entry."""
|
|
return await self.logging_handler.log(run_id, key, value)
|
|
|
|
async def info_log(
|
|
self,
|
|
run_id: UUID,
|
|
run_type: RunType,
|
|
user_id: UUID,
|
|
) -> None:
|
|
"""Add or update a log info entry."""
|
|
return await self.logging_handler.info_log(run_id, run_type, user_id)
|
|
|
|
async def get_info_logs(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
run_type_filter: Optional[RunType] = None,
|
|
user_ids: Optional[list[UUID]] = None,
|
|
) -> list[RunInfoLog]:
|
|
"""Retrieve log info entries with filtering and pagination."""
|
|
return await self.logging_handler.get_info_logs(
|
|
offset, limit, run_type_filter, user_ids
|
|
)
|
|
|
|
async def get_logs(
|
|
self,
|
|
run_ids: list[UUID],
|
|
limit_per_run: int = 10,
|
|
) -> list[dict[str, Any]]:
|
|
"""Retrieve logs for specified run IDs with a per-run limit."""
|
|
return await self.logging_handler.get_logs(run_ids, limit_per_run)
|
|
|
|
async def create_conversation(self) -> dict:
|
|
"""Create a new conversation and return its ID and timestamp."""
|
|
return await self.logging_handler.create_conversation()
|
|
|
|
async def delete_conversation(self, conversation_id: str) -> None:
|
|
"""Delete a conversation and all associated data."""
|
|
return await self.logging_handler.delete_conversation(conversation_id)
|
|
|
|
async def get_conversations(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
conversation_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[dict] | int]:
|
|
"""Get an overview of conversations with pagination."""
|
|
return await self.logging_handler.get_conversations(
|
|
offset=offset,
|
|
limit=limit,
|
|
conversation_ids=conversation_ids,
|
|
)
|
|
|
|
async def add_message(
|
|
self,
|
|
conversation_id: str,
|
|
content: Message,
|
|
parent_id: Optional[str] = None,
|
|
metadata: Optional[dict] = None,
|
|
) -> str:
|
|
"""Add a message to a conversation."""
|
|
return await self.logging_handler.add_message(
|
|
conversation_id, content, parent_id, metadata
|
|
)
|
|
|
|
async def edit_message(
|
|
self, message_id: str, new_content: str
|
|
) -> Tuple[str, str]:
|
|
"""Edit an existing message and return new message ID and branch ID."""
|
|
return await self.logging_handler.edit_message(message_id, new_content)
|
|
|
|
async def get_conversation(
|
|
self, conversation_id: str, branch_id: Optional[str] = None
|
|
) -> list[Tuple[str, Message]]:
|
|
"""Retrieve all messages in a conversation branch."""
|
|
return await self.logging_handler.get_conversation(
|
|
conversation_id, branch_id
|
|
)
|
|
|
|
async def get_branches(self, conversation_id: str) -> list[dict]:
|
|
"""Get an overview of all branches in a conversation."""
|
|
return await self.logging_handler.get_branches(conversation_id)
|
|
|
|
async def get_next_branch(self, current_branch_id: str) -> Optional[str]:
|
|
"""Get the ID of the next branch in chronological order."""
|
|
return await self.logging_handler.get_next_branch(current_branch_id)
|
|
|
|
async def get_prev_branch(self, current_branch_id: str) -> Optional[str]:
|
|
"""Get the ID of the previous branch in chronological order."""
|
|
return await self.logging_handler.get_prev_branch(current_branch_id)
|
|
|
|
async def branch_at_message(self, message_id: str) -> str:
|
|
"""Create a new branch starting at a specific message."""
|
|
return await self.logging_handler.branch_at_message(message_id)
|
|
|
|
async def list_chunks(
|
|
self,
|
|
offset: int,
|
|
limit: int,
|
|
filters: Optional[dict[str, Any]] = None,
|
|
include_vectors: bool = False,
|
|
) -> dict[str, Any]:
|
|
return await self.vector_handler.list_chunks(
|
|
offset, limit, filters, include_vectors
|
|
)
|