207 lines
5.7 KiB
Python
207 lines
5.7 KiB
Python
import logging
|
|
from abc import abstractmethod
|
|
from datetime import datetime
|
|
from io import BytesIO
|
|
from typing import BinaryIO, Optional, Tuple
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from core.base.abstractions import (
|
|
ChunkSearchResult,
|
|
Community,
|
|
DocumentResponse,
|
|
Entity,
|
|
IndexArgsHNSW,
|
|
IndexArgsIVFFlat,
|
|
IndexMeasure,
|
|
IndexMethod,
|
|
KGCreationSettings,
|
|
KGEnrichmentSettings,
|
|
KGEntityDeduplicationSettings,
|
|
Message,
|
|
Relationship,
|
|
SearchSettings,
|
|
User,
|
|
VectorEntry,
|
|
VectorTableName,
|
|
)
|
|
from core.base.api.models import CollectionResponse, GraphResponse
|
|
|
|
from .base import Provider, ProviderConfig
|
|
|
|
"""Base classes for knowledge graph providers."""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional, Sequence, Tuple, Type
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from ..abstractions import (
|
|
Community,
|
|
Entity,
|
|
GraphSearchSettings,
|
|
KGCreationSettings,
|
|
KGEnrichmentSettings,
|
|
KGEntityDeduplicationSettings,
|
|
KGExtraction,
|
|
R2RSerializable,
|
|
Relationship,
|
|
)
|
|
|
|
# Module-scoped logger (rather than the root logger) so this module's output
# can be filtered and leveled independently; records still propagate to
# handlers configured on the root logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabaseConnectionManager(ABC):
|
|
@abstractmethod
|
|
def execute_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
isolation_level: Optional[str] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def execute_many(self, query, params=None, batch_size=1000):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def fetch_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def fetchrow_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def initialize(self, pool: Any):
|
|
pass
|
|
|
|
|
|
class Handler(ABC):
    """Base class for database handlers bound to a single project.

    A handler remembers the project name (used to qualify table names) and
    the connection manager it issues queries through. Subclasses must
    implement :meth:`create_tables`.
    """

    def __init__(
        self,
        project_name: str,
        connection_manager: DatabaseConnectionManager,
    ):
        self.connection_manager = connection_manager
        self.project_name = project_name

    def _get_table_name(self, base_name: str) -> str:
        """Return ``base_name`` qualified as ``"<project_name>.<base_name>"``.

        NOTE(review): presumably ``project_name`` is a PostgreSQL schema;
        neither part is quoted as an SQL identifier — confirm callers only
        pass safe names.
        """
        qualifier = self.project_name
        return f"{qualifier}.{base_name}"

    @abstractmethod
    def create_tables(self):
        """Create the handler's tables (implemented by each subclass)."""
|
|
|
|
|
|
class PostgresConfigurationSettings(BaseModel):
    """
    Configuration settings with defaults defined by the PGVector docker image.

    These settings are helpful in managing the connections to the database.
    To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/

    Every field is optional and has a default, so an instance built with no
    arguments is a complete configuration.

    NOTE(review): the integer values appear to use each setting's native
    PostgreSQL unit (8kB pages for ``shared_buffers``, ``effective_cache_size``
    and ``wal_buffers``; kB for ``work_mem`` and ``maintenance_work_mem``; MB
    for ``max_wal_size`` and ``min_wal_size``) — confirm against the code that
    applies them.
    """

    checkpoint_completion_target: Optional[float] = 0.9
    default_statistics_target: Optional[int] = 100
    effective_io_concurrency: Optional[int] = 1
    effective_cache_size: Optional[int] = 524288
    huge_pages: Optional[str] = "try"
    maintenance_work_mem: Optional[int] = 65536
    max_connections: Optional[int] = 256
    max_parallel_workers_per_gather: Optional[int] = 2
    max_parallel_workers: Optional[int] = 8
    max_parallel_maintenance_workers: Optional[int] = 2
    max_wal_size: Optional[int] = 1024
    max_worker_processes: Optional[int] = 8
    min_wal_size: Optional[int] = 80
    shared_buffers: Optional[int] = 16384
    # NOTE(review): not a PostgreSQL server setting; presumably forwarded to
    # the client driver's prepared-statement cache — confirm.
    statement_cache_size: Optional[int] = 100
    # Float literal so the default matches the declared ``float`` type.
    random_page_cost: Optional[float] = 4.0
    wal_buffers: Optional[int] = 512
    work_mem: Optional[int] = 4096
|
|
|
|
|
|
class DatabaseConfig(ProviderConfig):
    """Configuration for a database provider.

    Holds connection details, collection defaults and knowledge-graph
    settings. :meth:`validate_config` accepts only the ``"postgres"``
    provider.
    """

    # Connection details
    provider: str = "postgres"
    user: Optional[str] = None
    password: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    db_name: Optional[str] = None
    project_name: Optional[str] = None
    postgres_configuration_settings: Optional[
        PostgresConfigurationSettings
    ] = None

    # Collection defaults and collection-summary prompt names
    default_collection_name: str = "Default"
    default_collection_description: str = "Your default collection."
    collection_summary_system_prompt: str = "default_system"
    collection_summary_task_prompt: str = "default_collection_summary"
    # presumably toggles full-text search support — confirm
    enable_fts: bool = False

    # KG settings
    batch_size: Optional[int] = 1
    kg_store_path: Optional[str] = None
    graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
    graph_creation_settings: KGCreationSettings = KGCreationSettings()
    graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
        KGEntityDeduplicationSettings()
    )
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    def __post_init__(self):
        """Validate the provider, then copy extra fields onto the instance.

        NOTE(review): ``__post_init__`` is the dataclass hook name. Whether it
        is ever invoked depends on ``ProviderConfig`` (not visible here);
        pydantic models call ``model_post_init`` instead — confirm this runs.
        ``extra_fields`` is presumably a dict supplied by ``ProviderConfig``.
        """
        self.validate_config()
        extras = self.extra_fields
        for attr_name, attr_value in extras.items():
            setattr(self, attr_name, attr_value)

    def validate_config(self) -> None:
        """Raise ``ValueError`` unless ``provider`` is a supported provider."""
        is_supported = self.provider in self.supported_providers
        if not is_supported:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Provider identifiers accepted by :meth:`validate_config`."""
        accepted = ["postgres"]
        return accepted
|
|
|
|
|
|
class DatabaseProvider(Provider):
    """Abstract base for database providers.

    Concrete providers must implement the async context-manager protocol
    (``__aenter__`` / ``__aexit__``).
    """

    connection_manager: DatabaseConnectionManager
    # NOTE(review): the handler attributes below are commented out; presumably
    # concrete providers declare them. Kept for reference.
    # documents_handler: DocumentHandler
    # collections_handler: CollectionsHandler
    # token_handler: TokenHandler
    # users_handler: UserHandler
    # chunks_handler: ChunkHandler
    # entity_handler: EntityHandler
    # relationship_handler: RelationshipHandler
    # graphs_handler: GraphHandler
    # prompts_handler: PromptHandler
    # files_handler: FileHandler
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        """Log the non-secret connection details, then delegate to ``Provider``.

        Args:
            config: The database configuration for this provider.
        """
        # Log only selected, non-secret fields. Formatting the whole config
        # could include ``DatabaseConfig.password`` (and any extra fields) in
        # the log output. Lazy %-args avoid formatting when INFO is disabled.
        logger.info(
            "Initializing DatabaseProvider with provider=%s host=%s port=%s "
            "db_name=%s project_name=%s.",
            config.provider,
            config.host,
            config.port,
            config.db_name,
            config.project_name,
        )
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        """Enter the provider's async context (implemented by subclasses)."""

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Exit the provider's async context (implemented by subclasses)."""
|