e7db62e6bd
* Add scheduler and vacuum * Lint * Refactor test workflow, add mock scheduler to tests * 0 3 * * * * Missing quote in toml * Add maintenance service mock
217 lines
8.2 KiB
Python
217 lines
8.2 KiB
Python
# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
|
|
import logging
|
|
import os
|
|
from enum import Enum
|
|
from typing import Any, Optional
|
|
|
|
import toml
|
|
from pydantic import BaseModel
|
|
|
|
from ..base.abstractions import GenerationConfig
|
|
from ..base.agent.agent import RAGAgentConfig # type: ignore
|
|
from ..base.providers import AppConfig
|
|
from ..base.providers.auth import AuthConfig
|
|
from ..base.providers.crypto import CryptoConfig
|
|
from ..base.providers.database import DatabaseConfig
|
|
from ..base.providers.email import EmailConfig
|
|
from ..base.providers.embedding import EmbeddingConfig
|
|
from ..base.providers.ingestion import IngestionConfig
|
|
from ..base.providers.llm import CompletionConfig
|
|
from ..base.providers.orchestration import OrchestrationConfig
|
|
from ..base.providers.scheduler import SchedulerConfig
|
|
from ..base.utils import deep_update
|
|
|
|
# Module-scoped logger: records are attributed to this module's name (so they
# can be filtered per module) and still propagate to handlers on the root logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class R2RConfig:
    """Aggregate R2R configuration assembled from TOML data.

    The default TOML file is always loaded first and the caller-supplied
    values are applied on top of it with ``deep_update``.  Every section
    listed in ``REQUIRED_KEYS`` is then checked and stored, and the sections
    that have a config class are converted with that class's ``create``.

    Class attributes:
        current_file_path: directory containing this module.
        config_dir_root: directory scanned for named ``*.toml`` configs.
        default_config_path: path of the default ``r2r.toml`` file.
        CONFIG_OPTIONS: maps a config name (TOML file name without the
            ``.toml`` suffix) to its path; ``"default"`` maps to ``None``,
            which ``from_toml`` resolves to ``default_config_path``.
        REQUIRED_KEYS: keys each section must contain when its provider
            is set.
    """

    current_file_path = os.path.dirname(__file__)
    config_dir_root = os.path.join(current_file_path, "..", "configs")
    default_config_path = os.path.join(
        current_file_path, "..", "..", "r2r", "r2r.toml"
    )

    CONFIG_OPTIONS: dict[str, Optional[str]] = {}
    # Kept as a plain loop: a comprehension in a class body cannot see the
    # class-level name ``config_dir_root`` from its element expression.
    for file_ in os.listdir(config_dir_root):
        if file_.endswith(".toml"):
            CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join(
                config_dir_root, file_
            )
    # Assigned after the scan so "default" always resolves to the default path.
    CONFIG_OPTIONS["default"] = None

    REQUIRED_KEYS: dict[str, list[str]] = {
        "app": [],
        "completion": ["provider"],
        "crypto": ["provider"],
        "email": ["provider"],
        "auth": ["provider"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        "completion_embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        # TODO - deprecated, remove
        "ingestion": ["provider"],
        "logging": ["provider", "log_table"],
        "database": ["provider"],
        "agent": ["generation_config"],
        "orchestration": ["provider"],
        "scheduler": ["provider"],
    }

    app: AppConfig
    auth: AuthConfig
    completion: CompletionConfig
    crypto: CryptoConfig
    database: DatabaseConfig
    embedding: EmbeddingConfig
    completion_embedding: EmbeddingConfig
    email: EmailConfig
    ingestion: IngestionConfig
    agent: RAGAgentConfig
    orchestration: OrchestrationConfig
    scheduler: SchedulerConfig

    def __init__(self, config_data: dict[str, Any]):
        """Build the configuration from the defaults plus ``config_data``.

        :param config_data: dictionary of configuration parameters, applied
            on top of the default configuration with ``deep_update``
        :raises ValueError: if a section listed in ``REQUIRED_KEYS`` is
            absent from the merged configuration, or if a section whose
            provider is set lacks one of its required keys
        """
        # Load the default configuration and override it with the passed one
        merged = deep_update(self.load_default_config(), config_data)

        # Validate and set the configuration
        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            # Fail with a descriptive error instead of a bare KeyError when
            # a whole section is absent.
            if section not in merged:
                raise ValueError(f"Missing '{section}' section in config")
            # Required keys are only enforced when a provider is set
            if self._provider_is_set(merged[section]):
                self._validate_config_section(merged, section, keys)
            setattr(self, section, merged[section])

        # Turn the raw section dicts into config objects.  ``app`` goes first
        # because every other ``create`` call receives it.
        self.app = AppConfig.create(**self.app)  # type: ignore
        self.auth = AuthConfig.create(**self.auth, app=self.app)  # type: ignore
        self.completion = CompletionConfig.create(
            **self.completion, app=self.app
        )  # type: ignore
        self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
        self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
        self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
        self.completion_embedding = EmbeddingConfig.create(
            **self.completion_embedding, app=self.app
        )  # type: ignore
        self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app)  # type: ignore
        self.agent = RAGAgentConfig.create(**self.agent, app=self.app)  # type: ignore
        self.orchestration = OrchestrationConfig.create(
            **self.orchestration, app=self.app
        )  # type: ignore
        self.scheduler = SchedulerConfig.create(**self.scheduler, app=self.app)  # type: ignore

        IngestionConfig.set_default(**self.ingestion.model_dump())

        # override GenerationConfig defaults
        if self.completion.generation_config:
            GenerationConfig.set_default(
                **self.completion.generation_config.model_dump()
            )

    @staticmethod
    def _provider_is_set(section_data: Any) -> bool:
        """Return True when the section names a provider that is not null-like.

        ``None`` and the strings ``"None"`` and ``"null"`` all count as
        "no provider".
        """
        if "provider" not in section_data:
            return False
        return section_data["provider"] not in (None, "None", "null")

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ):
        """Check that ``section`` exists and contains every key in ``keys``.

        :param config_data: full configuration dictionary
        :param section: name of the section to check
        :param keys: keys that must be present in that section
        :raises ValueError: if the section or any required key is missing
        """
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        if missing_keys := [
            key for key in keys if key not in config_data[section]
        ]:
            raise ValueError(
                f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
            )

    @classmethod
    def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Create a configuration from a TOML file.

        :param config_path: path of the TOML file; ``None`` selects
            ``default_config_path``
        :return: configuration built from the file's contents
        """
        if config_path is None:
            config_path = R2RConfig.default_config_path

        # Load configuration from TOML file
        with open(config_path, encoding="utf-8") as f:
            config_data = toml.load(f)

        return cls(config_data)

    def to_toml(self) -> str:
        """Serialize every ``REQUIRED_KEYS`` section to a TOML string."""
        config_data = {}
        for section in R2RConfig.REQUIRED_KEYS:
            section_data = self._serialize_config(getattr(self, section))
            if isinstance(section_data, dict):
                # Remove app from nested configs before serializing
                section_data.pop("app", None)
            config_data[section] = section_data
        return toml.dumps(config_data)

    @classmethod
    def load_default_config(cls) -> dict:
        """Read ``default_config_path`` and return its parsed TOML content."""
        with open(R2RConfig.default_config_path, encoding="utf-8") as f:
            return toml.load(f)

    @staticmethod
    def _serialize_config(config_section: Any):
        """Serialize config section while excluding internal state.

        Dicts and pydantic models become dicts without their ``app`` entry,
        lists and tuples become lists, enums become their value, and any
        other object is returned unchanged.
        """
        if isinstance(config_section, dict):
            return {
                R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                for k, v in config_section.items()
                if k != "app"  # Exclude app from serialization
            }
        if isinstance(config_section, (list, tuple)):
            return [
                R2RConfig._serialize_config(item) for item in config_section
            ]
        if isinstance(config_section, Enum):
            return config_section.value
        if isinstance(config_section, BaseModel):
            data = config_section.model_dump(exclude_none=True)
            data.pop("app", None)  # Remove app from the serialized data
            return R2RConfig._serialize_config(data)
        return config_section

    @staticmethod
    def _serialize_key(key: Any) -> str:
        """Return a string dict key: an enum's value, or ``str(key)``."""
        return key.value if isinstance(key, Enum) else str(key)

    @classmethod
    def load(
        cls,
        config_name: Optional[str] = None,
        config_path: Optional[str] = None,
    ) -> "R2RConfig":
        """Load a configuration selected by path or by name.

        The ``R2R_CONFIG_PATH`` and ``R2R_CONFIG_NAME`` environment variables
        take precedence over the corresponding arguments, and a path takes
        precedence over a name.  With nothing specified, the ``"default"``
        config is loaded.

        :param config_name: key of ``CONFIG_OPTIONS`` to load
        :param config_path: path of a TOML file to load
        :return: the loaded configuration
        :raises ValueError: if both arguments are given, or if the resolved
            config name is not a key of ``CONFIG_OPTIONS``
        """
        if config_path and config_name:
            raise ValueError(
                f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
            )

        if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
            return cls.from_toml(config_path)

        config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
        if config_name not in R2RConfig.CONFIG_OPTIONS:
            raise ValueError(f"Invalid config name: {config_name}")
        return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])