first pass commit
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
from ..config import R2RConfig
|
||||
from .builder import R2RBuilder
|
||||
from .factory import (
|
||||
R2RAgentFactory,
|
||||
R2RPipeFactory,
|
||||
R2RPipelineFactory,
|
||||
R2RProviderFactory,
|
||||
)
|
||||
from .factory import R2RProviderFactory
|
||||
|
||||
__all__ = [
|
||||
# Builder
|
||||
@@ -14,7 +9,4 @@ __all__ = [
|
||||
"R2RConfig",
|
||||
# Factory
|
||||
"R2RProviderFactory",
|
||||
"R2RPipeFactory",
|
||||
"R2RPipelineFactory",
|
||||
"R2RAgentFactory",
|
||||
]
|
||||
|
||||
@@ -1,26 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
|
||||
from core.agent import R2RRAGAgent
|
||||
from core.base import (
|
||||
AsyncPipe,
|
||||
AuthProvider,
|
||||
CompletionProvider,
|
||||
CryptoProvider,
|
||||
DatabaseProvider,
|
||||
EmbeddingProvider,
|
||||
OrchestrationProvider,
|
||||
RunManager,
|
||||
)
|
||||
from core.main.abstractions import R2RServices
|
||||
from core.main.services.auth_service import AuthService
|
||||
from core.main.services.graph_service import GraphService
|
||||
from core.main.services.ingestion_service import IngestionService
|
||||
from core.main.services.management_service import ManagementService
|
||||
from core.main.services.retrieval_service import RetrievalService
|
||||
from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
|
||||
|
||||
from ..abstractions import R2RProviders
|
||||
from ..abstractions import R2RProviders, R2RServices
|
||||
from ..api.v3.chunks_router import ChunksRouter
|
||||
from ..api.v3.collections_router import CollectionsRouter
|
||||
from ..api.v3.conversations_router import ConversationsRouter
|
||||
@@ -34,53 +15,22 @@ from ..api.v3.system_router import SystemRouter
|
||||
from ..api.v3.users_router import UsersRouter
|
||||
from ..app import R2RApp
|
||||
from ..config import R2RConfig
|
||||
from .factory import (
|
||||
R2RAgentFactory,
|
||||
R2RPipeFactory,
|
||||
R2RPipelineFactory,
|
||||
R2RProviderFactory,
|
||||
)
|
||||
from ..services.auth_service import AuthService
|
||||
from ..services.graph_service import GraphService
|
||||
from ..services.ingestion_service import IngestionService
|
||||
from ..services.management_service import ManagementService
|
||||
from ..services.retrieval_service import RetrievalService
|
||||
from .factory import R2RProviderFactory
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class R2RBuilder:
|
||||
_SERVICES = ["auth", "ingestion", "management", "retrieval", "graph"]
|
||||
|
||||
def __init__(self, config: R2RConfig):
|
||||
self.config = config
|
||||
|
||||
def _create_pipes(
|
||||
self,
|
||||
pipe_factory: type[R2RPipeFactory],
|
||||
providers: R2RProviders,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
return pipe_factory(self.config, providers).create_pipes(
|
||||
overrides={}, *args, **kwargs
|
||||
)
|
||||
|
||||
def _create_pipelines(
|
||||
self,
|
||||
pipeline_factory: type[R2RPipelineFactory],
|
||||
providers: R2RProviders,
|
||||
pipes: Any,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
return pipeline_factory(
|
||||
self.config, providers, pipes
|
||||
).create_pipelines(*args, **kwargs)
|
||||
|
||||
def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
|
||||
services = ["auth", "ingestion", "management", "retrieval", "graph"]
|
||||
service_instances = {}
|
||||
|
||||
for service_type in services:
|
||||
service_class = globals()[f"{service_type.capitalize()}Service"]
|
||||
service_instances[service_type] = service_class(**service_params)
|
||||
|
||||
return R2RServices(**service_instances)
|
||||
|
||||
async def _create_providers(
|
||||
self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
|
||||
) -> Any:
|
||||
@@ -89,35 +39,18 @@ class R2RBuilder:
|
||||
|
||||
async def build(self, *args, **kwargs) -> R2RApp:
|
||||
provider_factory = R2RProviderFactory
|
||||
pipe_factory = R2RPipeFactory
|
||||
pipeline_factory = R2RPipelineFactory
|
||||
|
||||
try:
|
||||
providers = await self._create_providers(
|
||||
provider_factory, *args, **kwargs
|
||||
)
|
||||
pipes = self._create_pipes(
|
||||
pipe_factory, providers, *args, **kwargs
|
||||
)
|
||||
pipelines = self._create_pipelines(
|
||||
pipeline_factory, providers, pipes, *args, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating providers, pipes, or pipelines: {e}")
|
||||
logger.error(f"Error {e} while creating R2RProviders.")
|
||||
raise
|
||||
|
||||
assistant_factory = R2RAgentFactory(self.config, providers, pipelines)
|
||||
agents = assistant_factory.create_agents(*args, **kwargs)
|
||||
|
||||
run_manager = RunManager()
|
||||
|
||||
service_params = {
|
||||
"config": self.config,
|
||||
"providers": providers,
|
||||
"pipes": pipes,
|
||||
"pipelines": pipelines,
|
||||
"agents": agents,
|
||||
"run_manager": run_manager,
|
||||
}
|
||||
|
||||
services = self._create_services(service_params)
|
||||
@@ -172,5 +105,22 @@ class R2RBuilder:
|
||||
return R2RApp(
|
||||
config=self.config,
|
||||
orchestration_provider=providers.orchestration,
|
||||
services=services,
|
||||
**routers,
|
||||
)
|
||||
|
||||
async def _create_providers(
|
||||
self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
|
||||
) -> R2RProviders:
|
||||
factory = provider_factory(self.config)
|
||||
return await factory.create_providers(*args, **kwargs)
|
||||
|
||||
def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
|
||||
services = R2RBuilder._SERVICES
|
||||
service_instances = {}
|
||||
|
||||
for service_type in services:
|
||||
service_class = globals()[f"{service_type.capitalize()}Service"]
|
||||
service_instances[service_type] = service_class(**service_params)
|
||||
|
||||
return R2RServices(**service_instances)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user