201 lines
5.9 KiB
Python
201 lines
5.9 KiB
Python
"""Base classes for database providers."""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional, Sequence
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from core.base.abstractions import (
|
|
GraphSearchSettings,
|
|
KGCreationSettings,
|
|
KGEnrichmentSettings,
|
|
)
|
|
|
|
from .base import Provider, ProviderConfig
|
|
|
|
# Module-level logger named after this module so its output can be filtered
# and leveled independently; records still propagate to the root logger's
# handlers.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabaseConnectionManager(ABC):
|
|
@abstractmethod
|
|
def execute_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
isolation_level: Optional[str] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def execute_many(self, query, params=None, batch_size=1000):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def fetch_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def fetchrow_query(
|
|
self,
|
|
query: str,
|
|
params: Optional[dict[str, Any] | Sequence[Any]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def initialize(self, pool: Any):
|
|
pass
|
|
|
|
|
|
class Handler(ABC):
    """Abstract base for components that own tables for one project.

    Keeps the project name (used as a qualifying prefix for table names,
    presumably the schema) and the shared connection manager. Subclasses
    must implement ``create_tables``.
    """

    def __init__(
        self,
        project_name: str,
        connection_manager: DatabaseConnectionManager,
    ):
        self.connection_manager = connection_manager
        self.project_name = project_name

    def _get_table_name(self, base_name: str) -> str:
        """Return ``base_name`` qualified as ``"<project_name>.<base_name>"``."""
        return "{}.{}".format(self.project_name, base_name)

    @abstractmethod
    def create_tables(self):
        """Create the tables this handler manages."""
        ...
|
|
|
|
|
|
class PostgresConfigurationSettings(BaseModel):
    """
    Configuration settings with defaults defined by the PGVector docker image.

    These settings are helpful in managing the connections to the database.
    To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/

    NOTE(review): the integer sizes appear to follow PostgreSQL's native unit
    for each parameter (8kB pages for shared_buffers, effective_cache_size and
    wal_buffers; kB for work_mem and maintenance_work_mem; MB for
    min_wal_size/max_wal_size), e.g. 16384 x 8kB = 128MB. Confirm against the
    code that applies these values before changing them.
    """

    checkpoint_completion_target: Optional[float] = 0.9
    default_statistics_target: Optional[int] = 100
    effective_io_concurrency: Optional[int] = 1
    effective_cache_size: Optional[int] = 524288
    huge_pages: Optional[str] = "try"
    maintenance_work_mem: Optional[int] = 65536
    max_connections: Optional[int] = 256
    max_parallel_workers_per_gather: Optional[int] = 2
    max_parallel_workers: Optional[int] = 8
    max_parallel_maintenance_workers: Optional[int] = 2
    max_wal_size: Optional[int] = 1024
    max_worker_processes: Optional[int] = 8
    min_wal_size: Optional[int] = 80
    shared_buffers: Optional[int] = 16384
    # NOTE(review): not a PostgreSQL server parameter; presumably a client
    # driver (prepared-statement cache) setting. Confirm where it is used.
    statement_cache_size: Optional[int] = 100
    # Float literal so the default matches the declared type; pydantic does
    # not validate defaults, so an int literal would stay an int.
    random_page_cost: Optional[float] = 4.0
    wal_buffers: Optional[int] = 512
    work_mem: Optional[int] = 4096
|
|
|
|
|
|
class LimitSettings(BaseModel):
    """Rate-limit thresholds; a ``None`` field means "not set at this level"."""

    global_per_min: Optional[int] = None
    route_per_min: Optional[int] = None
    monthly_limit: Optional[int] = None

    def merge_with_defaults(
        self, defaults: "LimitSettings"
    ) -> "LimitSettings":
        """Return a new LimitSettings, filling unset fields from ``defaults``.

        A field counts as unset only when it is ``None``. An explicit ``0`` is
        kept; the previous ``a or b`` form treated 0 as falsy and silently
        replaced it with the default.

        Args:
            defaults: Settings supplying values for fields that are None here.

        Returns:
            A new LimitSettings. Neither ``self`` nor ``defaults`` is modified.
        """

        def pick(own: Optional[int], fallback: Optional[int]) -> Optional[int]:
            # Prefer this level's value whenever it was set, including 0.
            return own if own is not None else fallback

        return LimitSettings(
            global_per_min=pick(self.global_per_min, defaults.global_per_min),
            route_per_min=pick(self.route_per_min, defaults.route_per_min),
            monthly_limit=pick(self.monthly_limit, defaults.monthly_limit),
        )
|
|
|
|
|
|
class DatabaseConfig(ProviderConfig):
    """A base database configuration class"""

    provider: str = "postgres"
    user: Optional[str] = None
    password: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    db_name: Optional[str] = None
    project_name: Optional[str] = None
    postgres_configuration_settings: Optional[
        PostgresConfigurationSettings
    ] = None
    default_collection_name: str = "Default"
    default_collection_description: str = "Your default collection."
    collection_summary_system_prompt: str = "default_system"
    collection_summary_task_prompt: str = "default_collection_summary"
    enable_fts: bool = False

    # Graph settings
    batch_size: Optional[int] = 1
    kg_store_path: Optional[str] = None
    graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
    graph_creation_settings: KGCreationSettings = KGCreationSettings()
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    # Rate limits
    limits: LimitSettings = LimitSettings(
        global_per_min=60, route_per_min=20, monthly_limit=10000
    )
    route_limits: dict[str, LimitSettings] = {}
    user_limits: dict[UUID, LimitSettings] = {}

    def __post_init__(self):
        """Validate the provider and copy ``extra_fields`` onto the instance."""
        # NOTE(review): ``__post_init__`` is a dataclass hook. If ProviderConfig
        # is a pydantic model, this is presumably never called automatically.
        # Confirm before relying on it.
        self.validate_config()
        # Capture additional fields
        for key, value in self.extra_fields.items():
            setattr(self, key, value)

    def validate_config(self) -> None:
        """Raise ValueError if ``provider`` is not in ``supported_providers``."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Provider names accepted by ``validate_config``."""
        return ["postgres"]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
        """Build a config from ``data``, then apply its ``limits`` section.

        ``data["limits"]`` may contain ``global_per_min``, ``route_per_min``
        and ``monthly_limit``; missing keys fall back to 60 / 20 / 10000.
        It may also contain a ``routes`` mapping of route string to
        LimitSettings keyword arguments. A missing or null ``limits`` or
        ``routes`` section is treated as empty.

        Args:
            data: Raw configuration mapping.

        Returns:
            The instance produced by the parent ``from_dict`` (presumably
            defined on ProviderConfig), with ``limits`` and ``route_limits``
            set.
        """
        instance = super().from_dict(data)

        # ``or {}`` also covers an explicit null section, which ``.get(k, {})``
        # would pass through as None and then crash on ``.get``.
        limits_data = data.get("limits") or {}
        instance.limits = LimitSettings(
            global_per_min=limits_data.get("global_per_min", 60),
            route_per_min=limits_data.get("route_per_min", 20),
            monthly_limit=limits_data.get("monthly_limit", 10000),
        )

        route_limits_data = limits_data.get("routes") or {}
        parsed_routes = {
            route_str: LimitSettings(**route_cfg)
            for route_str, route_cfg in route_limits_data.items()
        }
        # Assign a fresh dict instead of mutating in place, so a class-level
        # default dict can never end up shared between instances.
        instance.route_limits = {**instance.route_limits, **parsed_routes}

        return instance
|
|
|
|
|
|
class DatabaseProvider(Provider):
    """Abstract base for database providers.

    Declares the ``connection_manager``, ``config`` and ``project_name``
    attributes and requires the async context-manager protocol
    (``__aenter__``/``__aexit__``).
    """

    connection_manager: DatabaseConnectionManager
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        """Log non-secret connection details, then delegate to Provider.

        Args:
            config: Database configuration for this provider.
        """
        # Log only non-secret fields: rendering the whole config would also
        # write its ``password`` field into the logs. Lazy %-args avoid
        # formatting when INFO is disabled.
        logger.info(
            "Initializing DatabaseProvider "
            "(provider=%s, host=%s, port=%s, db_name=%s).",
            config.provider,
            config.host,
            config.port,
            config.db_name,
        )
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        """Enter the provider's async context."""
        pass

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Exit the provider's async context."""
        pass
|