This commit is contained in:
emrgnt-cmplxty
2024-12-17 18:13:38 -08:00
parent fbe8337268
commit 563771862c
29 changed files with 242 additions and 233 deletions
+5 -5
View File
@@ -173,7 +173,7 @@ __all__ = [
"IngestionService",
"ManagementService",
"RetrievalService",
"KgService",
"GraphService",
## PARSERS
# Media parsers
"AudioParser",
@@ -213,14 +213,14 @@ __all__ = [
## PIPES
"SearchPipe",
"EmbeddingPipe",
"KGExtractionPipe",
"GraphExtractionPipe",
"ParsingPipe",
"QueryTransformPipe",
"SearchRAGPipe",
"StreamingSearchRAGPipe",
"RAGPipe",
"StreamingRAGPipe",
"VectorSearchPipe",
"VectorStoragePipe",
"KGStoragePipe",
"GraphStoragePipe",
"MultiSearchPipe",
## PROVIDERS
# Auth
+9 -9
View File
@@ -10,15 +10,7 @@ from shared.api.models.base import (
WrappedBooleanResponse,
WrappedGenericMessageResponse,
)
from shared.api.models.ingestion.responses import (
IngestionResponse,
UpdateResponse,
WrappedIngestionResponse,
WrappedListVectorIndicesResponse,
WrappedMetadataUpdateResponse,
WrappedUpdateResponse,
)
from shared.api.models.kg.responses import ( # TODO: Need to review anything above this
from shared.api.models.graph.responses import ( # TODO: Need to review anything above this
Community,
Entity,
GraphResponse,
@@ -32,6 +24,14 @@ from shared.api.models.kg.responses import ( # TODO: Need to review anything ab
WrappedRelationshipResponse,
WrappedRelationshipsResponse,
)
from shared.api.models.ingestion.responses import (
IngestionResponse,
UpdateResponse,
WrappedIngestionResponse,
WrappedListVectorIndicesResponse,
WrappedMetadataUpdateResponse,
WrappedUpdateResponse,
)
from shared.api.models.management.responses import ( # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed
AnalyticsResponse,
ChunkResponse,
+1 -1
View File
@@ -30,5 +30,5 @@ __all__ = [
"IngestionService",
"ManagementService",
"RetrievalService",
"KgService",
"GraphService",
]
+32 -17
View File
@@ -4,9 +4,24 @@ from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel
from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
from core.base.pipes import AsyncPipe
from core.database import PostgresDatabaseProvider
from core.pipelines import RAGPipeline, SearchPipeline
from core.pipes import (
EmbeddingPipe,
GraphClusteringPipe,
GraphCommunitySummaryPipe,
GraphDeduplicationPipe,
GraphDeduplicationSummaryPipe,
GraphDescriptionPipe,
GraphExtractionPipe,
GraphSearchSearchPipe,
GraphStoragePipe,
ParsingPipe,
RAGPipe,
StreamingRAGPipe,
VectorSearchPipe,
VectorStoragePipe,
)
from core.providers import (
AsyncSMTPEmailProvider,
ConsoleMockEmailProvider,
@@ -26,8 +41,8 @@ from core.providers import (
if TYPE_CHECKING:
from core.main.services.auth_service import AuthService
from core.main.services.graph_service import GraphService
from core.main.services.ingestion_service import IngestionService
from core.main.services.kg_service import KgService
from core.main.services.management_service import ManagementService
from core.main.services.retrieval_service import RetrievalService
@@ -54,20 +69,20 @@ class R2RProviders(BaseModel):
class R2RPipes(BaseModel):
parsing_pipe: AsyncPipe
embedding_pipe: AsyncPipe
kg_search_pipe: AsyncPipe
kg_relationships_extraction_pipe: AsyncPipe
kg_storage_pipe: AsyncPipe
kg_entity_description_pipe: AsyncPipe
kg_clustering_pipe: AsyncPipe
kg_entity_deduplication_pipe: AsyncPipe
kg_entity_deduplication_summary_pipe: AsyncPipe
kg_community_summary_pipe: AsyncPipe
rag_pipe: AsyncPipe
streaming_rag_pipe: AsyncPipe
vector_storage_pipe: AsyncPipe
vector_search_pipe: AsyncPipe
parsing_pipe: ParsingPipe
embedding_pipe: EmbeddingPipe
graph_search_pipe: GraphSearchSearchPipe
graph_extraction_pipe: GraphExtractionPipe
graph_storage_pipe: GraphStoragePipe
graph_description_pipe: GraphDescriptionPipe
graph_clustering_pipe: GraphClusteringPipe
graph_deduplication_pipe: GraphDeduplicationPipe
graph_deduplication_summary_pipe: GraphDeduplicationSummaryPipe
graph_community_summary_pipe: GraphCommunitySummaryPipe
rag_pipe: RAGPipe
streaming_rag_pipe: StreamingRAGPipe
vector_storage_pipe: VectorStoragePipe
vector_search_pipe: VectorSearchPipe
class Config:
arbitrary_types_allowed = True
@@ -96,4 +111,4 @@ class R2RServices:
ingestion: Optional["IngestionService"] = None
management: Optional["ManagementService"] = None
retrieval: Optional["RetrievalService"] = None
kg: Optional["KgService"] = None
graph: Optional["GraphService"] = None
+1 -1
View File
@@ -1144,7 +1144,7 @@ class CollectionsRouter(BaseRouterV3):
from core.main.orchestration import simple_kg_factory
logger.info("Running extract-triples without orchestration.")
simple_kg = simple_kg_factory(self.services.kg)
simple_kg = simple_kg_factory(self.services.graph)
await simple_kg["extract-triples"](workflow_input) # type: ignore
return { # type: ignore
"message": "Graph created successfully.",
+2 -2
View File
@@ -1359,7 +1359,7 @@ class DocumentsRouter(BaseRouterV3):
"message": "Estimate retrieved successfully",
"task_id": None,
"id": id,
"estimate": await self.services.kg.get_creation_estimate(
"estimate": await self.services.graph.get_creation_estimate(
document_id=id,
graph_creation_settings=server_graph_creation_settings,
),
@@ -1379,7 +1379,7 @@ class DocumentsRouter(BaseRouterV3):
from core.main.orchestration import simple_kg_factory
logger.info("Running extract-triples without orchestration.")
simple_kg = simple_kg_factory(self.services.kg)
simple_kg = simple_kg_factory(self.services.graph)
await simple_kg["extract-triples"](workflow_input)
return { # type: ignore
"message": "Graph created successfully.",
+21 -21
View File
@@ -65,7 +65,7 @@ class GraphRouter(BaseRouterV3):
self.providers.orchestration.register_workflows(
Workflow.KG,
self.services.kg,
self.services.graph,
workflow_messages,
)
@@ -119,7 +119,7 @@ class GraphRouter(BaseRouterV3):
# Return cost estimate if requested
if run_type == KGRunType.ESTIMATE:
return await self.services.kg.get_deduplication_estimate(
return await self.services.graph.get_deduplication_estimate(
collection_id, server_settings
)
@@ -136,7 +136,7 @@ class GraphRouter(BaseRouterV3):
else:
from core.main.orchestration import simple_kg_factory
simple_kg = simple_kg_factory(self.services.kg)
simple_kg = simple_kg_factory(self.services.graph)
await simple_kg["entity-deduplication"](workflow_input)
return { # type: ignore
"message": "Entity deduplication completed successfully.",
@@ -223,7 +223,7 @@ class GraphRouter(BaseRouterV3):
graph_uuids = [UUID(graph_id) for graph_id in collection_ids]
list_graphs_response = await self.services.kg.list_graphs(
list_graphs_response = await self.services.graph.list_graphs(
# user_ids=requesting_user_id,
graph_ids=graph_uuids,
offset=offset,
@@ -302,7 +302,7 @@ class GraphRouter(BaseRouterV3):
403,
)
list_graphs_response = await self.services.kg.list_graphs(
list_graphs_response = await self.services.graph.list_graphs(
# user_ids=None,
graph_ids=[collection_id],
offset=0,
@@ -394,7 +394,7 @@ class GraphRouter(BaseRouterV3):
from core.main.orchestration import simple_kg_factory
logger.info("Running build-communities without orchestration.")
simple_kg = simple_kg_factory(self.services.kg)
simple_kg = simple_kg_factory(self.services.graph)
await simple_kg["build-communities"](workflow_input)
return {
"message": "Graph communities created successfully.",
@@ -476,7 +476,7 @@ class GraphRouter(BaseRouterV3):
403,
)
await self.services.kg.reset_graph_v3(id=collection_id)
await self.services.graph.reset_graph_v3(id=collection_id)
# await _pull(collection_id, auth_user)
return GenericBooleanResponse(success=True) # type: ignore
@@ -562,7 +562,7 @@ class GraphRouter(BaseRouterV3):
403,
)
return await self.services.kg.update_graph( # type: ignore
return await self.services.graph.update_graph( # type: ignore
collection_id,
name=name,
description=description,
@@ -637,7 +637,7 @@ class GraphRouter(BaseRouterV3):
403,
)
entities, count = await self.services.kg.get_entities(
entities, count = await self.services.graph.get_entities(
parent_id=collection_id,
offset=offset,
limit=limit,
@@ -682,7 +682,7 @@ class GraphRouter(BaseRouterV3):
403,
)
return await self.services.kg.create_entity(
return await self.services.graph.create_entity(
name=name,
description=description,
parent_id=collection_id,
@@ -744,7 +744,7 @@ class GraphRouter(BaseRouterV3):
"The currently authenticated user does not have access to the collection associated with the given graph.",
403,
)
return await self.services.kg.create_relationship(
return await self.services.graph.create_relationship(
subject=subject,
subject_id=subject_id,
predicate=predicate,
@@ -874,7 +874,7 @@ class GraphRouter(BaseRouterV3):
403,
)
return await self.services.kg.update_entity(
return await self.services.graph.update_entity(
entity_id=entity_id,
name=name,
category=category,
@@ -954,7 +954,7 @@ class GraphRouter(BaseRouterV3):
403,
)
await self.services.kg.delete_entity(
await self.services.graph.delete_entity(
parent_id=collection_id,
entity_id=entity_id,
)
@@ -1033,7 +1033,7 @@ class GraphRouter(BaseRouterV3):
403,
)
relationships, count = await self.services.kg.get_relationships(
relationships, count = await self.services.graph.get_relationships(
parent_id=collection_id,
offset=offset,
limit=limit,
@@ -1178,7 +1178,7 @@ class GraphRouter(BaseRouterV3):
403,
)
return await self.services.kg.update_relationship(
return await self.services.graph.update_relationship(
relationship_id=relationship_id,
subject=subject,
subject_id=subject_id,
@@ -1261,7 +1261,7 @@ class GraphRouter(BaseRouterV3):
403,
)
await self.services.kg.delete_relationship(
await self.services.graph.delete_relationship(
parent_id=collection_id,
relationship_id=relationship_id,
)
@@ -1368,7 +1368,7 @@ class GraphRouter(BaseRouterV3):
403,
)
return await self.services.kg.create_community(
return await self.services.graph.create_community(
parent_id=collection_id,
name=name,
summary=summary,
@@ -1449,7 +1449,7 @@ class GraphRouter(BaseRouterV3):
403,
)
communities, count = await self.services.kg.get_communities(
communities, count = await self.services.graph.get_communities(
parent_id=collection_id,
offset=offset,
limit=limit,
@@ -1611,7 +1611,7 @@ class GraphRouter(BaseRouterV3):
403,
)
await self.services.kg.delete_community(
await self.services.graph.delete_community(
parent_id=collection_id,
community_id=community_id,
)
@@ -1703,7 +1703,7 @@ class GraphRouter(BaseRouterV3):
403,
)
return await self.services.kg.update_community(
return await self.services.graph.update_community(
community_id=community_id,
name=name,
summary=summary,
@@ -1801,7 +1801,7 @@ class GraphRouter(BaseRouterV3):
403,
)
list_graphs_response = await self.services.kg.list_graphs(
list_graphs_response = await self.services.graph.list_graphs(
# user_ids=None,
graph_ids=[collection_id],
offset=0,
+2 -2
View File
@@ -14,8 +14,8 @@ from core.base import (
)
from core.main.abstractions import R2RServices
from core.main.services.auth_service import AuthService
from core.main.services.graph_service import GraphService
from core.main.services.ingestion_service import IngestionService
from core.main.services.kg_service import KgService
from core.main.services.management_service import ManagementService
from core.main.services.retrieval_service import RetrievalService
from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
@@ -85,7 +85,7 @@ class R2RBuilder:
ingestion=service_instances["ingestion"],
management=service_instances["management"],
retrieval=service_instances["retrieval"],
kg=service_instances["kg"],
graph=service_instances["graph"],
)
async def _create_providers(
+62 -68
View File
@@ -364,18 +364,18 @@ class R2RPipeFactory:
self,
parsing_pipe_override: Optional[AsyncPipe] = None,
embedding_pipe_override: Optional[AsyncPipe] = None,
kg_relationships_extraction_pipe_override: Optional[AsyncPipe] = None,
kg_storage_pipe_override: Optional[AsyncPipe] = None,
kg_search_pipe_override: Optional[AsyncPipe] = None,
graph_extraction_pipe_override: Optional[AsyncPipe] = None,
graph_storage_pipe_override: Optional[AsyncPipe] = None,
graph_search_pipe_override: Optional[AsyncPipe] = None,
vector_storage_pipe_override: Optional[AsyncPipe] = None,
vector_search_pipe_override: Optional[AsyncPipe] = None,
rag_pipe_override: Optional[AsyncPipe] = None,
streaming_rag_pipe_override: Optional[AsyncPipe] = None,
kg_entity_description_pipe: Optional[AsyncPipe] = None,
kg_clustering_pipe: Optional[AsyncPipe] = None,
kg_entity_deduplication_pipe: Optional[AsyncPipe] = None,
kg_entity_deduplication_summary_pipe: Optional[AsyncPipe] = None,
kg_community_summary_pipe: Optional[AsyncPipe] = None,
graph_description_pipe: Optional[AsyncPipe] = None,
graph_clustering_pipe: Optional[AsyncPipe] = None,
graph_deduplication_pipe: Optional[AsyncPipe] = None,
graph_deduplication_summary_pipe: Optional[AsyncPipe] = None,
graph_community_summary_pipe: Optional[AsyncPipe] = None,
*args,
**kwargs,
) -> R2RPipes:
@@ -388,32 +388,30 @@ class R2RPipeFactory:
),
embedding_pipe=embedding_pipe_override
or self.create_embedding_pipe(*args, **kwargs),
kg_relationships_extraction_pipe=kg_relationships_extraction_pipe_override
or self.create_kg_relationships_extraction_pipe(*args, **kwargs),
kg_storage_pipe=kg_storage_pipe_override
or self.create_kg_storage_pipe(*args, **kwargs),
graph_extraction_pipe=graph_extraction_pipe_override
or self.create_graph_extraction_pipe(*args, **kwargs),
graph_storage_pipe=graph_storage_pipe_override
or self.create_graph_storage_pipe(*args, **kwargs),
vector_storage_pipe=vector_storage_pipe_override
or self.create_vector_storage_pipe(*args, **kwargs),
vector_search_pipe=vector_search_pipe_override
or self.create_vector_search_pipe(*args, **kwargs),
kg_search_pipe=kg_search_pipe_override
or self.create_kg_search_pipe(*args, **kwargs),
graph_search_pipe=graph_search_pipe_override
or self.create_graph_search_pipe(*args, **kwargs),
rag_pipe=rag_pipe_override
or self.create_rag_pipe(*args, **kwargs),
streaming_rag_pipe=streaming_rag_pipe_override
or self.create_rag_pipe(True, *args, **kwargs),
kg_entity_description_pipe=kg_entity_description_pipe
or self.create_kg_entity_description_pipe(*args, **kwargs),
kg_clustering_pipe=kg_clustering_pipe
or self.create_kg_clustering_pipe(*args, **kwargs),
kg_entity_deduplication_pipe=kg_entity_deduplication_pipe
or self.create_kg_entity_deduplication_pipe(*args, **kwargs),
kg_entity_deduplication_summary_pipe=kg_entity_deduplication_summary_pipe
or self.create_kg_entity_deduplication_summary_pipe(
*args, **kwargs
),
kg_community_summary_pipe=kg_community_summary_pipe
or self.create_kg_community_summary_pipe(*args, **kwargs),
graph_description_pipe=graph_description_pipe
or self.create_graph_description_pipe(*args, **kwargs),
graph_clustering_pipe=graph_clustering_pipe
or self.create_graph_clustering_pipe(*args, **kwargs),
graph_deduplication_pipe=graph_deduplication_pipe
or self.create_graph_deduplication_pipe(*args, **kwargs),
graph_deduplication_summary_pipe=graph_deduplication_summary_pipe
or self.create_graph_deduplication_summary_pipe(*args, **kwargs),
graph_community_summary_pipe=graph_community_summary_pipe
or self.create_graph_community_summary_pipe(*args, **kwargs),
)
def create_parsing_pipe(self, *args, **kwargs) -> Any:
@@ -525,29 +523,27 @@ class R2RPipeFactory:
config=AsyncPipe.PipeConfig(name="routing_search_pipe"),
)
def create_kg_relationships_extraction_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGExtractionPipe
def create_graph_extraction_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphExtractionPipe
return KGExtractionPipe(
return GraphExtractionPipe(
llm_provider=self.providers.llm,
database_provider=self.providers.database,
config=AsyncPipe.PipeConfig(
name="kg_relationships_extraction_pipe"
),
config=AsyncPipe.PipeConfig(name="graph_extraction_pipe"),
)
def create_kg_storage_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGStoragePipe
def create_graph_storage_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphStoragePipe
return KGStoragePipe(
return GraphStoragePipe(
database_provider=self.providers.database,
config=AsyncPipe.PipeConfig(name="kg_storage_pipe"),
config=AsyncPipe.PipeConfig(name="graph_storage_pipe"),
)
def create_kg_search_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGSearchSearchPipe
def create_graph_search_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphSearchSearchPipe
return KGSearchSearchPipe(
return GraphSearchSearchPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
@@ -558,9 +554,9 @@ class R2RPipeFactory:
def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
if stream:
from core.pipes import StreamingSearchRAGPipe
from core.pipes import StreamingRAGPipe
return StreamingSearchRAGPipe(
return StreamingRAGPipe(
llm_provider=self.providers.llm,
database_provider=self.providers.database,
config=GeneratorPipe.PipeConfig(
@@ -568,9 +564,9 @@ class R2RPipeFactory:
),
)
else:
from core.pipes import SearchRAGPipe
from core.pipes import RAGPipe
return SearchRAGPipe(
return RAGPipe(
llm_provider=self.providers.llm,
database_provider=self.providers.database,
config=GeneratorPipe.PipeConfig(
@@ -578,67 +574,65 @@ class R2RPipeFactory:
),
)
def create_kg_entity_description_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGEntityDescriptionPipe
def create_graph_description_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphDescriptionPipe
return KGEntityDescriptionPipe(
return GraphDescriptionPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
config=AsyncPipe.PipeConfig(name="kg_entity_description_pipe"),
config=AsyncPipe.PipeConfig(name="graph_description_pipe"),
)
def create_kg_clustering_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGClusteringPipe
def create_graph_clustering_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphClusteringPipe
return KGClusteringPipe(
return GraphClusteringPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
config=AsyncPipe.PipeConfig(name="kg_clustering_pipe"),
config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
)
def create_kg_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGEntityDeduplicationSummaryPipe
from core.pipes import GraphDeduplicationSummaryPipe
return KGEntityDeduplicationSummaryPipe(
return GraphDeduplicationSummaryPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
config=AsyncPipe.PipeConfig(name="kg_deduplication_summary_pipe"),
)
def create_kg_community_summary_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGCommunitySummaryPipe
def create_graph_community_summary_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphCommunitySummaryPipe
return KGCommunitySummaryPipe(
return GraphCommunitySummaryPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
config=AsyncPipe.PipeConfig(name="kg_community_summary_pipe"),
config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
)
def create_kg_entity_deduplication_pipe(self, *args, **kwargs) -> Any:
from core.pipes import KGEntityDeduplicationPipe
def create_graph_deduplication_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphDeduplicationPipe
return KGEntityDeduplicationPipe(
return GraphDeduplicationPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
config=AsyncPipe.PipeConfig(name="kg_entity_deduplication_pipe"),
config=AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
)
def create_kg_entity_deduplication_summary_pipe(
self, *args, **kwargs
) -> Any:
from core.pipes import KGEntityDeduplicationSummaryPipe
def create_graph_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
from core.pipes import GraphDeduplicationSummaryPipe
return KGEntityDeduplicationSummaryPipe(
return GraphDeduplicationSummaryPipe(
database_provider=self.providers.database,
llm_provider=self.providers.llm,
embedding_provider=self.providers.embedding,
config=AsyncPipe.PipeConfig(
name="kg_entity_deduplication_summary_pipe"
name="graph_deduplication_summary_pipe"
),
)
@@ -664,7 +658,7 @@ class R2RPipelineFactory:
self.pipes.vector_search_pipe, vector_search_pipe=True
)
search_pipeline.add_pipe(
self.pipes.kg_search_pipe, kg_search_pipe=True
self.pipes.graph_search_pipe, graph_search_pipe=True
)
return search_pipeline
@@ -11,7 +11,7 @@ from core import GenerationConfig
from core.base import OrchestrationProvider, R2RException
from core.base.abstractions import KGEnrichmentStatus, KGExtractionStatus
from ...services import KgService
from ...services import GraphService
logger = logging.getLogger()
from typing import TYPE_CHECKING
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
def hatchet_kg_factory(
orchestration_provider: OrchestrationProvider, service: KgService
orchestration_provider: OrchestrationProvider, service: GraphService
) -> dict[str, "Hatchet.Workflow"]:
def convert_to_dict(input_data):
@@ -124,7 +124,7 @@ def hatchet_kg_factory(
@orchestration_provider.workflow(name="kg-extract", timeout="360m")
class KGExtractDescribeEmbedWorkflow:
def __init__(self, kg_service: KgService):
def __init__(self, kg_service: GraphService):
self.kg_service = kg_service
@orchestration_provider.concurrency( # type: ignore
@@ -273,7 +273,7 @@ def hatchet_kg_factory(
except Exception as e:
pass
def __init__(self, kg_service: KgService):
def __init__(self, kg_service: GraphService):
self.kg_service = kg_service
@orchestration_provider.step(retries=1)
@@ -392,7 +392,7 @@ def hatchet_kg_factory(
name="entity-deduplication", timeout="360m"
)
class EntityDeduplicationWorkflow:
def __init__(self, kg_service: KgService):
def __init__(self, kg_service: GraphService):
self.kg_service = kg_service
@orchestration_provider.step(retries=0, timeout="360m")
@@ -460,7 +460,7 @@ def hatchet_kg_factory(
name="kg-entity-deduplication-summary", timeout="360m"
)
class EntityDeduplicationSummaryWorkflow:
def __init__(self, kg_service: KgService):
def __init__(self, kg_service: GraphService):
self.kg_service = kg_service
@orchestration_provider.step(retries=0, timeout="360m")
@@ -490,7 +490,7 @@ def hatchet_kg_factory(
@orchestration_provider.workflow(name="build-communities", timeout="360m")
class EnrichGraphWorkflow:
def __init__(self, kg_service: KgService):
def __init__(self, kg_service: GraphService):
self.kg_service = kg_service
@orchestration_provider.step(retries=1, parents=[], timeout="360m")
@@ -642,7 +642,7 @@ def hatchet_kg_factory(
name="kg-community-summary", timeout="360m"
)
class KGCommunitySummaryWorkflow:
def __init__(self, kg_service: KgService):
def __init__(self, kg_service: GraphService):
self.kg_service = kg_service
@orchestration_provider.concurrency( # type: ignore
@@ -6,12 +6,12 @@ import uuid
from core import GenerationConfig, R2RException
from core.base.abstractions import KGEnrichmentStatus
from ...services import KgService
from ...services import GraphService
logger = logging.getLogger()
def simple_kg_factory(service: KgService):
def simple_kg_factory(service: GraphService):
def get_input_data_dict(input_data):
for key, value in input_data.items():
+2 -2
View File
@@ -1,6 +1,6 @@
from .auth_service import AuthService
from .graph_service import GraphService
from .ingestion_service import IngestionService, IngestionServiceAdapter
from .kg_service import KgService
from .management_service import ManagementService
from .retrieval_service import RetrievalService
@@ -9,6 +9,6 @@ __all__ = [
"IngestionService",
"IngestionServiceAdapter",
"ManagementService",
"KgService",
"GraphService",
"RetrievalService",
]
@@ -46,9 +46,9 @@ async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
return results
# TODO - Fix naming convention to read `KGService` instead of `KgService`
# TODO - Fix naming convention to read `KGService` instead of `GraphService`
# this will require a minor change in how services are registered.
class KgService(Service):
class GraphService(Service):
def __init__(
self,
config: R2RConfig,
@@ -90,8 +90,8 @@ class KgService(Service):
status=KGExtractionStatus.PROCESSING,
)
relationships = await self.pipes.kg_relationships_extraction_pipe.run(
input=self.pipes.kg_relationships_extraction_pipe.Input(
relationships = await self.pipes.graph_extraction_pipe.run(
input=self.pipes.graph_extraction_pipe.Input(
message={
"document_id": document_id,
"generation_config": generation_config,
@@ -110,8 +110,10 @@ class KgService(Service):
f"KGService: Finished processing document {document_id} for KG extraction"
)
result_gen = await self.pipes.kg_storage_pipe.run(
input=self.pipes.kg_storage_pipe.Input(message=relationships),
result_gen = await self.pipes.graph_storage_pipe.run(
input=self.pipes.graph_storage_pipe.Input(
message=relationships
),
state=None,
run_manager=self.run_manager,
)
@@ -525,8 +527,8 @@ class KgService(Service):
f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
)
node_descriptions = await self.pipes.kg_entity_description_pipe.run(
input=self.pipes.kg_entity_description_pipe.Input(
node_descriptions = await self.pipes.graph_description_pipe.run(
input=self.pipes.graph_description_pipe.Input(
message={
"offset": i * 256,
"limit": 256,
@@ -571,8 +573,8 @@ class KgService(Service):
f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
)
clustering_result = await self.pipes.kg_clustering_pipe.run(
input=self.pipes.kg_clustering_pipe.Input(
clustering_result = await self.pipes.graph_clustering_pipe.run(
input=self.pipes.graph_clustering_pipe.Input(
message={
"collection_id": collection_id,
"generation_config": generation_config,
@@ -597,8 +599,8 @@ class KgService(Service):
# graph_id: UUID | None,
**kwargs,
):
summary_results = await self.pipes.kg_community_summary_pipe.run(
input=self.pipes.kg_community_summary_pipe.Input(
summary_results = await self.pipes.graph_community_summary_pipe.run(
input=self.pipes.graph_community_summary_pipe.Input(
message={
"offset": offset,
"limit": limit,
@@ -716,8 +718,8 @@ class KgService(Service):
generation_config: GenerationConfig,
**kwargs,
):
deduplication_results = await self.pipes.kg_entity_deduplication_pipe.run(
input=self.pipes.kg_entity_deduplication_pipe.Input(
deduplication_results = await self.pipes.graph_deduplication_pipe.run(
input=self.pipes.graph_deduplication_pipe.Input(
message={
"collection_id": collection_id,
"graph_id": graph_id,
@@ -747,8 +749,8 @@ class KgService(Service):
logger.info(
f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
)
deduplication_summary_results = await self.pipes.kg_entity_deduplication_summary_pipe.run(
input=self.pipes.kg_entity_deduplication_summary_pipe.Input(
deduplication_summary_results = await self.pipes.graph_deduplication_summary_pipe.run(
input=self.pipes.graph_deduplication_summary_pipe.Input(
message={
"collection_id": collection_id,
"offset": offset,
@@ -780,7 +782,7 @@ class KgService(Service):
start_time = time.time()
logger.info(
f"KGExtractionPipe: Processing document {document_id} for KG extraction",
f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
)
# Then create the extractions from the results
@@ -835,7 +837,7 @@ class KgService(Service):
return
logger.info(
f"KGExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
)
# sort the extractions accroding to chunk_order field in metadata in ascending order
@@ -851,7 +853,7 @@ class KgService(Service):
]
logger.info(
f"KGExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
)
tasks = [
@@ -873,7 +875,7 @@ class KgService(Service):
total_tasks = len(tasks)
logger.info(
f"KGExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
)
for completed_task in asyncio.as_completed(tasks):
@@ -882,7 +884,7 @@ class KgService(Service):
completed_tasks += 1
if completed_tasks % 100 == 0:
logger.info(
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
)
except Exception as e:
logger.error(f"Error in Extracting KG Relationships: {e}")
@@ -892,7 +894,7 @@ class KgService(Service):
)
logger.info(
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
)
async def _extract_kg(
@@ -1044,7 +1046,7 @@ class KgService(Service):
)
logger.info(
f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
)
return KGExtraction(
+8 -8
View File
@@ -25,7 +25,7 @@ class SearchPipeline(AsyncPipeline):
super().__init__(run_manager)
self._parsing_pipe: Optional[AsyncPipe] = None
self._vector_search_pipeline: Optional[AsyncPipeline] = None
self._kg_search_pipeline: Optional[AsyncPipeline] = None
self._graph_search_pipeline: Optional[AsyncPipeline] = None
async def run( # type: ignore
self,
@@ -68,7 +68,7 @@ class SearchPipeline(AsyncPipeline):
)
)
kg_task = asyncio.create_task(
self._kg_search_pipeline.run(
self._graph_search_pipeline.run(
dequeue_requests(kg_queue),
request_state,
stream,
@@ -93,22 +93,22 @@ class SearchPipeline(AsyncPipeline):
self,
pipe: AsyncPipe,
add_upstream_outputs: Optional[list[dict[str, str]]] = None,
kg_search_pipe: bool = False,
graph_search_pipe: bool = False,
vector_search_pipe: bool = False,
*args,
**kwargs,
) -> None:
logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")
if kg_search_pipe:
if not self._kg_search_pipeline:
self._kg_search_pipeline = AsyncPipeline()
if not self._kg_search_pipeline:
if graph_search_pipe:
if not self._graph_search_pipeline:
self._graph_search_pipeline = AsyncPipeline()
if not self._graph_search_pipeline:
raise ValueError(
"KG search pipeline not found"
) # for type hinting
self._kg_search_pipeline.add_pipe(
self._graph_search_pipeline.add_pipe(
pipe, add_upstream_outputs, *args, **kwargs
)
elif vector_search_pipe:
+20 -20
View File
@@ -3,39 +3,39 @@ from .abstractions.search_pipe import SearchPipe
from .ingestion.embedding_pipe import EmbeddingPipe
from .ingestion.parsing_pipe import ParsingPipe
from .ingestion.vector_storage_pipe import VectorStoragePipe
from .kg.clustering import KGClusteringPipe
from .kg.community_summary import KGCommunitySummaryPipe
from .kg.deduplication import KGEntityDeduplicationPipe
from .kg.deduplication_summary import KGEntityDeduplicationSummaryPipe
from .kg.description import KGEntityDescriptionPipe
from .kg.extraction import KGExtractionPipe
from .kg.storage import KGStoragePipe
from .kg.clustering import GraphClusteringPipe
from .kg.community_summary import GraphCommunitySummaryPipe
from .kg.deduplication import GraphDeduplicationPipe
from .kg.deduplication_summary import GraphDeduplicationSummaryPipe
from .kg.description import GraphDescriptionPipe
from .kg.extraction import GraphExtractionPipe
from .kg.storage import GraphStoragePipe
from .retrieval.chunk_search_pipe import VectorSearchPipe
from .retrieval.kg_search_pipe import KGSearchSearchPipe
from .retrieval.graph_search_pipe import GraphSearchSearchPipe
from .retrieval.multi_search import MultiSearchPipe
from .retrieval.query_transform_pipe import QueryTransformPipe
from .retrieval.routing_search_pipe import RoutingSearchPipe
from .retrieval.search_rag_pipe import SearchRAGPipe
from .retrieval.streaming_rag_pipe import StreamingSearchRAGPipe
from .retrieval.search_rag_pipe import RAGPipe
from .retrieval.streaming_rag_pipe import StreamingRAGPipe
__all__ = [
"SearchPipe",
"GeneratorPipe",
"EmbeddingPipe",
"KGExtractionPipe",
"KGSearchSearchPipe",
"KGEntityDescriptionPipe",
"GraphExtractionPipe",
"GraphSearchSearchPipe",
"GraphDescriptionPipe",
"ParsingPipe",
"QueryTransformPipe",
"SearchRAGPipe",
"StreamingSearchRAGPipe",
"RAGPipe",
"StreamingRAGPipe",
"VectorSearchPipe",
"VectorStoragePipe",
"KGStoragePipe",
"KGClusteringPipe",
"GraphStoragePipe",
"GraphClusteringPipe",
"MultiSearchPipe",
"KGCommunitySummaryPipe",
"GraphCommunitySummaryPipe",
"RoutingSearchPipe",
"KGEntityDeduplicationPipe",
"KGEntityDeduplicationSummaryPipe",
"GraphDeduplicationPipe",
"GraphDeduplicationSummaryPipe",
]
+1 -1
View File
@@ -15,7 +15,7 @@ from core.database import PostgresDatabaseProvider
logger = logging.getLogger()
class KGClusteringPipe(AsyncPipe):
class GraphClusteringPipe(AsyncPipe):
"""
Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm.
"""
+7 -7
View File
@@ -21,7 +21,7 @@ from ...database.postgres import PostgresDatabaseProvider
logger = logging.getLogger()
class KGCommunitySummaryPipe(AsyncPipe):
class GraphCommunitySummaryPipe(AsyncPipe):
"""
Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm.
"""
@@ -40,7 +40,7 @@ class KGCommunitySummaryPipe(AsyncPipe):
"""
super().__init__(
config=config
or AsyncPipe.PipeConfig(name="kg_community_summary_pipe"),
or AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
)
self.database_provider = database_provider
self.llm_provider = llm_provider
@@ -210,7 +210,7 @@ class KGCommunitySummaryPipe(AsyncPipe):
except Exception as e:
if attempt == 2:
logger.error(
f"KGCommunitySummaryPipe: Error generating community summary for community {community_id}: {e}"
f"GraphCommunitySummaryPipe: Error generating community summary for community {community_id}: {e}"
)
return {
"community_id": community_id,
@@ -265,7 +265,7 @@ class KGCommunitySummaryPipe(AsyncPipe):
# check which community summaries exist and don't run them again
logger.info(
f"KGCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
)
all_entities, _ = (
@@ -335,12 +335,12 @@ class KGCommunitySummaryPipe(AsyncPipe):
completed_community_summary_jobs += 1
if completed_community_summary_jobs % 50 == 0:
logger.info(
f"KGCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds"
f"GraphCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds"
)
if "error" in summary:
logger.error(
f"KGCommunitySummaryPipe: Error generating community summary for community {summary['community_id']}: {summary['error']}"
f"GraphCommunitySummaryPipe: Error generating community summary for community {summary['community_id']}: {summary['error']}"
)
total_errors += 1
continue
@@ -349,5 +349,5 @@ class KGCommunitySummaryPipe(AsyncPipe):
if total_errors > 0:
raise ValueError(
f"KGCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures."
f"GraphCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures."
)
+7 -7
View File
@@ -18,7 +18,7 @@ from core.providers import (
logger = logging.getLogger()
class KGEntityDeduplicationPipe(AsyncPipe):
class GraphDeduplicationPipe(AsyncPipe):
def __init__(
self,
config: AsyncPipe.PipeConfig,
@@ -33,7 +33,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
):
super().__init__(
config=config
or AsyncPipe.PipeConfig(name="kg_entity_deduplication_pipe"),
or AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
)
self.database_provider = database_provider
self.llm_provider = llm_provider
@@ -69,7 +69,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
entities = await self._get_entities(graph_id, collection_id)
logger.info(
f"KGEntityDeduplicationPipe: Got {len(entities)} entities for {graph_id or collection_id}"
f"GraphDeduplicationPipe: Got {len(entities)} entities for {graph_id or collection_id}"
)
# deduplicate entities by name
@@ -129,7 +129,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
]
logger.info(
f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
f"GraphDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
)
await self.database_provider.graphs_handler.add_entities(
@@ -171,7 +171,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
embeddings = [entity.description_embedding for entity in entities]
logger.info(
f"KGEntityDeduplicationPipe: Running DBSCAN clustering on {len(embeddings)} embeddings"
f"GraphDeduplicationPipe: Running DBSCAN clustering on {len(embeddings)} embeddings"
)
# TODO: make eps a config, make it very strict for now
clustering = DBSCAN(eps=0.1, min_samples=2, metric="cosine").fit(
@@ -183,7 +183,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
n_noise = list(labels).count(-1)
logger.info(
f"KGEntityDeduplicationPipe: Found {n_clusters} clusters and {n_noise} noise points"
f"GraphDeduplicationPipe: Found {n_clusters} clusters and {n_noise} noise points"
)
# for all labels in the same cluster, we can deduplicate them by name
@@ -236,7 +236,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
)
logger.info(
f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
f"GraphDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
)
await self.database_provider.graphs_handler.add_entities(
deduplicated_entities_list,
+1 -1
View File
@@ -18,7 +18,7 @@ from core.providers import ( # PostgresDatabaseProvider,
logger = logging.getLogger()
class KGEntityDeduplicationSummaryPipe(AsyncPipe[Any]):
class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
class Input(AsyncPipe.Input):
message: dict
+5 -5
View File
@@ -15,7 +15,7 @@ from ...database.postgres import PostgresDatabaseProvider
logger = logging.getLogger()
class KGEntityDescriptionPipe(AsyncPipe):
class GraphDescriptionPipe(AsyncPipe):
"""
The pipe takes input a list of nodes and extracts description from them.
"""
@@ -147,7 +147,7 @@ class KGEntityDescriptionPipe(AsyncPipe):
logger = input.message["logger"]
logger.info(
f"KGEntityDescriptionPipe: Getting entity map for document {document_id}",
f"GraphDescriptionPipe: Getting entity map for document {document_id}",
)
entity_map = (
@@ -158,7 +158,7 @@ class KGEntityDescriptionPipe(AsyncPipe):
total_entities = len(entity_map)
logger.info(
f"KGEntityDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds",
f"GraphDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds",
)
workflows = []
@@ -182,11 +182,11 @@ class KGEntityDescriptionPipe(AsyncPipe):
for result in asyncio.as_completed(workflows):
if completed_entities % 100 == 0:
logger.info(
f"KGEntityDescriptionPipe: Completed {completed_entities+1} of {total_entities} entities for document {document_id}",
f"GraphDescriptionPipe: Completed {completed_entities+1} of {total_entities} entities for document {document_id}",
)
yield await result
completed_entities += 1
logger.info(
f"KGEntityDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds",
f"GraphDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds",
)
+9 -11
View File
@@ -32,7 +32,7 @@ class ClientError(Exception):
pass
class KGExtractionPipe(AsyncPipe[dict]):
class GraphExtractionPipe(AsyncPipe[dict]):
"""
Extracts knowledge graph information from document extractions.
"""
@@ -54,9 +54,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
):
super().__init__(
config=config
or AsyncPipe.PipeConfig(
name="default_kg_relationships_extraction_pipe"
),
or AsyncPipe.PipeConfig(name="default_graph_extraction_pipe"),
)
self.database_provider = database_provider
self.llm_provider = llm_provider
@@ -198,7 +196,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
# add metadata to entities and relationships
logger.info(
f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}",
f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}",
)
return KGExtraction(
@@ -233,7 +231,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
logger = input.message.get("logger", logging.getLogger())
logger.info(
f"KGExtractionPipe: Processing document {document_id} for KG extraction",
f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
)
# Then create the extractions from the results
@@ -277,7 +275,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
return
logger.info(
f"KGExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds",
)
# sort the extractions accroding to chunk_order field in metadata in ascending order
@@ -293,7 +291,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
]
logger.info(
f"KGExtractionPipe: Extracting KG Relationships for document and created {len(extractions_groups)} tasks, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(extractions_groups)} tasks, time from start: {time.time() - start_time:.2f} seconds",
)
tasks = [
@@ -315,7 +313,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
total_tasks = len(tasks)
logger.info(
f"KGExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
)
for completed_task in asyncio.as_completed(tasks):
@@ -324,7 +322,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
completed_tasks += 1
if completed_tasks % 100 == 0:
logger.info(
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
)
except Exception as e:
logger.error(f"Error in Extracting KG Relationships: {e}")
@@ -334,5 +332,5 @@ class KGExtractionPipe(AsyncPipe[dict]):
)
logger.info(
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
)
+2 -2
View File
@@ -10,7 +10,7 @@ from core.database import PostgresDatabaseProvider
logger = logging.getLogger()
class KGStoragePipe(AsyncPipe):
class GraphStoragePipe(AsyncPipe):
# TODO - Apply correct type hints to storage messages
class Input(AsyncPipe.Input):
message: AsyncGenerator[list[Any], None]
@@ -27,7 +27,7 @@ class KGStoragePipe(AsyncPipe):
Initializes the async knowledge graph storage pipe with necessary components and configurations.
"""
logger.info(
f"Initializing an `KGStoragePipe` to store knowledge graph extractions in a graph database."
f"Initializing an `GraphStoragePipe` to store knowledge graph extractions in a graph database."
)
super().__init__(
@@ -24,7 +24,7 @@ from ..abstractions.generator_pipe import GeneratorPipe
logger = logging.getLogger()
class KGSearchSearchPipe(GeneratorPipe):
class GraphSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
"""
+1 -1
View File
@@ -14,7 +14,7 @@ from core.base.abstractions import GenerationConfig, RAGCompletion
from ..abstractions.generator_pipe import GeneratorPipe
class SearchRAGPipe(GeneratorPipe):
class RAGPipe(GeneratorPipe):
class Input(AsyncPipe.Input):
message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
@@ -17,7 +17,7 @@ from ..abstractions.generator_pipe import GeneratorPipe
logger = logging.getLogger()
class StreamingSearchRAGPipe(GeneratorPipe):
class StreamingRAGPipe(GeneratorPipe):
CHUNK_SEARCH_STREAM_MARKER = (
"search" # TODO - change this to vector_search in next major release
)
@@ -72,7 +72,7 @@ class StreamingSearchRAGPipe(GeneratorPipe):
for chunk in self.llm_provider.get_completion_stream(
messages=messages, generation_config=rag_generation_config
):
chunk_txt = StreamingSearchRAGPipe._process_chunk(chunk)
chunk_txt = StreamingRAGPipe._process_chunk(chunk)
response += chunk_txt
yield chunk_txt
+1 -1
View File
@@ -2,7 +2,7 @@ from typing import Any, Optional
from uuid import UUID
from shared.api.models.base import WrappedBooleanResponse
from shared.api.models.kg.responses import (
from shared.api.models.graph.responses import (
WrappedCommunitiesResponse,
WrappedCommunityResponse,
WrappedEntitiesResponse,
+5 -5
View File
@@ -10,17 +10,17 @@ from shared.api.models.base import (
WrappedBooleanResponse,
WrappedGenericMessageResponse,
)
from shared.api.models.graph.responses import (
GraphResponse,
WrappedGraphResponse,
WrappedGraphsResponse,
)
from shared.api.models.ingestion.responses import (
IngestionResponse,
WrappedIngestionResponse,
WrappedMetadataUpdateResponse,
WrappedUpdateResponse,
)
from shared.api.models.kg.responses import (
GraphResponse,
WrappedGraphResponse,
WrappedGraphsResponse,
)
from shared.api.models.management.responses import (
AnalyticsResponse,
ChunkResponse,