a
This commit is contained in:
+5
-5
@@ -173,7 +173,7 @@ __all__ = [
|
||||
"IngestionService",
|
||||
"ManagementService",
|
||||
"RetrievalService",
|
||||
"KgService",
|
||||
"GraphService",
|
||||
## PARSERS
|
||||
# Media parsers
|
||||
"AudioParser",
|
||||
@@ -213,14 +213,14 @@ __all__ = [
|
||||
## PIPES
|
||||
"SearchPipe",
|
||||
"EmbeddingPipe",
|
||||
"KGExtractionPipe",
|
||||
"GraphExtractionPipe",
|
||||
"ParsingPipe",
|
||||
"QueryTransformPipe",
|
||||
"SearchRAGPipe",
|
||||
"StreamingSearchRAGPipe",
|
||||
"RAGPipe",
|
||||
"StreamingRAGPipe",
|
||||
"VectorSearchPipe",
|
||||
"VectorStoragePipe",
|
||||
"KGStoragePipe",
|
||||
"GraphStoragePipe",
|
||||
"MultiSearchPipe",
|
||||
## PROVIDERS
|
||||
# Auth
|
||||
|
||||
@@ -10,15 +10,7 @@ from shared.api.models.base import (
|
||||
WrappedBooleanResponse,
|
||||
WrappedGenericMessageResponse,
|
||||
)
|
||||
from shared.api.models.ingestion.responses import (
|
||||
IngestionResponse,
|
||||
UpdateResponse,
|
||||
WrappedIngestionResponse,
|
||||
WrappedListVectorIndicesResponse,
|
||||
WrappedMetadataUpdateResponse,
|
||||
WrappedUpdateResponse,
|
||||
)
|
||||
from shared.api.models.kg.responses import ( # TODO: Need to review anything above this
|
||||
from shared.api.models.graph.responses import ( # TODO: Need to review anything above this
|
||||
Community,
|
||||
Entity,
|
||||
GraphResponse,
|
||||
@@ -32,6 +24,14 @@ from shared.api.models.kg.responses import ( # TODO: Need to review anything ab
|
||||
WrappedRelationshipResponse,
|
||||
WrappedRelationshipsResponse,
|
||||
)
|
||||
from shared.api.models.ingestion.responses import (
|
||||
IngestionResponse,
|
||||
UpdateResponse,
|
||||
WrappedIngestionResponse,
|
||||
WrappedListVectorIndicesResponse,
|
||||
WrappedMetadataUpdateResponse,
|
||||
WrappedUpdateResponse,
|
||||
)
|
||||
from shared.api.models.management.responses import ( # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed
|
||||
AnalyticsResponse,
|
||||
ChunkResponse,
|
||||
|
||||
@@ -30,5 +30,5 @@ __all__ = [
|
||||
"IngestionService",
|
||||
"ManagementService",
|
||||
"RetrievalService",
|
||||
"KgService",
|
||||
"GraphService",
|
||||
]
|
||||
|
||||
@@ -4,9 +4,24 @@ from typing import TYPE_CHECKING, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
|
||||
from core.base.pipes import AsyncPipe
|
||||
from core.database import PostgresDatabaseProvider
|
||||
from core.pipelines import RAGPipeline, SearchPipeline
|
||||
from core.pipes import (
|
||||
EmbeddingPipe,
|
||||
GraphClusteringPipe,
|
||||
GraphCommunitySummaryPipe,
|
||||
GraphDeduplicationPipe,
|
||||
GraphDeduplicationSummaryPipe,
|
||||
GraphDescriptionPipe,
|
||||
GraphExtractionPipe,
|
||||
GraphSearchSearchPipe,
|
||||
GraphStoragePipe,
|
||||
ParsingPipe,
|
||||
RAGPipe,
|
||||
StreamingRAGPipe,
|
||||
VectorSearchPipe,
|
||||
VectorStoragePipe,
|
||||
)
|
||||
from core.providers import (
|
||||
AsyncSMTPEmailProvider,
|
||||
ConsoleMockEmailProvider,
|
||||
@@ -26,8 +41,8 @@ from core.providers import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.main.services.auth_service import AuthService
|
||||
from core.main.services.graph_service import GraphService
|
||||
from core.main.services.ingestion_service import IngestionService
|
||||
from core.main.services.kg_service import KgService
|
||||
from core.main.services.management_service import ManagementService
|
||||
from core.main.services.retrieval_service import RetrievalService
|
||||
|
||||
@@ -54,20 +69,20 @@ class R2RProviders(BaseModel):
|
||||
|
||||
|
||||
class R2RPipes(BaseModel):
|
||||
parsing_pipe: AsyncPipe
|
||||
embedding_pipe: AsyncPipe
|
||||
kg_search_pipe: AsyncPipe
|
||||
kg_relationships_extraction_pipe: AsyncPipe
|
||||
kg_storage_pipe: AsyncPipe
|
||||
kg_entity_description_pipe: AsyncPipe
|
||||
kg_clustering_pipe: AsyncPipe
|
||||
kg_entity_deduplication_pipe: AsyncPipe
|
||||
kg_entity_deduplication_summary_pipe: AsyncPipe
|
||||
kg_community_summary_pipe: AsyncPipe
|
||||
rag_pipe: AsyncPipe
|
||||
streaming_rag_pipe: AsyncPipe
|
||||
vector_storage_pipe: AsyncPipe
|
||||
vector_search_pipe: AsyncPipe
|
||||
parsing_pipe: ParsingPipe
|
||||
embedding_pipe: EmbeddingPipe
|
||||
graph_search_pipe: GraphSearchSearchPipe
|
||||
graph_extraction_pipe: GraphExtractionPipe
|
||||
graph_storage_pipe: GraphStoragePipe
|
||||
graph_description_pipe: GraphDescriptionPipe
|
||||
graph_clustering_pipe: GraphClusteringPipe
|
||||
graph_deduplication_pipe: GraphDeduplicationPipe
|
||||
graph_deduplication_summary_pipe: GraphDeduplicationSummaryPipe
|
||||
graph_community_summary_pipe: GraphCommunitySummaryPipe
|
||||
rag_pipe: RAGPipe
|
||||
streaming_rag_pipe: StreamingRAGPipe
|
||||
vector_storage_pipe: VectorStoragePipe
|
||||
vector_search_pipe: VectorSearchPipe
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -96,4 +111,4 @@ class R2RServices:
|
||||
ingestion: Optional["IngestionService"] = None
|
||||
management: Optional["ManagementService"] = None
|
||||
retrieval: Optional["RetrievalService"] = None
|
||||
kg: Optional["KgService"] = None
|
||||
graph: Optional["GraphService"] = None
|
||||
|
||||
@@ -1144,7 +1144,7 @@ class CollectionsRouter(BaseRouterV3):
|
||||
from core.main.orchestration import simple_kg_factory
|
||||
|
||||
logger.info("Running extract-triples without orchestration.")
|
||||
simple_kg = simple_kg_factory(self.services.kg)
|
||||
simple_kg = simple_kg_factory(self.services.graph)
|
||||
await simple_kg["extract-triples"](workflow_input) # type: ignore
|
||||
return { # type: ignore
|
||||
"message": "Graph created successfully.",
|
||||
|
||||
@@ -1359,7 +1359,7 @@ class DocumentsRouter(BaseRouterV3):
|
||||
"message": "Estimate retrieved successfully",
|
||||
"task_id": None,
|
||||
"id": id,
|
||||
"estimate": await self.services.kg.get_creation_estimate(
|
||||
"estimate": await self.services.graph.get_creation_estimate(
|
||||
document_id=id,
|
||||
graph_creation_settings=server_graph_creation_settings,
|
||||
),
|
||||
@@ -1379,7 +1379,7 @@ class DocumentsRouter(BaseRouterV3):
|
||||
from core.main.orchestration import simple_kg_factory
|
||||
|
||||
logger.info("Running extract-triples without orchestration.")
|
||||
simple_kg = simple_kg_factory(self.services.kg)
|
||||
simple_kg = simple_kg_factory(self.services.graph)
|
||||
await simple_kg["extract-triples"](workflow_input)
|
||||
return { # type: ignore
|
||||
"message": "Graph created successfully.",
|
||||
|
||||
@@ -65,7 +65,7 @@ class GraphRouter(BaseRouterV3):
|
||||
|
||||
self.providers.orchestration.register_workflows(
|
||||
Workflow.KG,
|
||||
self.services.kg,
|
||||
self.services.graph,
|
||||
workflow_messages,
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class GraphRouter(BaseRouterV3):
|
||||
|
||||
# Return cost estimate if requested
|
||||
if run_type == KGRunType.ESTIMATE:
|
||||
return await self.services.kg.get_deduplication_estimate(
|
||||
return await self.services.graph.get_deduplication_estimate(
|
||||
collection_id, server_settings
|
||||
)
|
||||
|
||||
@@ -136,7 +136,7 @@ class GraphRouter(BaseRouterV3):
|
||||
else:
|
||||
from core.main.orchestration import simple_kg_factory
|
||||
|
||||
simple_kg = simple_kg_factory(self.services.kg)
|
||||
simple_kg = simple_kg_factory(self.services.graph)
|
||||
await simple_kg["entity-deduplication"](workflow_input)
|
||||
return { # type: ignore
|
||||
"message": "Entity deduplication completed successfully.",
|
||||
@@ -223,7 +223,7 @@ class GraphRouter(BaseRouterV3):
|
||||
|
||||
graph_uuids = [UUID(graph_id) for graph_id in collection_ids]
|
||||
|
||||
list_graphs_response = await self.services.kg.list_graphs(
|
||||
list_graphs_response = await self.services.graph.list_graphs(
|
||||
# user_ids=requesting_user_id,
|
||||
graph_ids=graph_uuids,
|
||||
offset=offset,
|
||||
@@ -302,7 +302,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
list_graphs_response = await self.services.kg.list_graphs(
|
||||
list_graphs_response = await self.services.graph.list_graphs(
|
||||
# user_ids=None,
|
||||
graph_ids=[collection_id],
|
||||
offset=0,
|
||||
@@ -394,7 +394,7 @@ class GraphRouter(BaseRouterV3):
|
||||
from core.main.orchestration import simple_kg_factory
|
||||
|
||||
logger.info("Running build-communities without orchestration.")
|
||||
simple_kg = simple_kg_factory(self.services.kg)
|
||||
simple_kg = simple_kg_factory(self.services.graph)
|
||||
await simple_kg["build-communities"](workflow_input)
|
||||
return {
|
||||
"message": "Graph communities created successfully.",
|
||||
@@ -476,7 +476,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
await self.services.kg.reset_graph_v3(id=collection_id)
|
||||
await self.services.graph.reset_graph_v3(id=collection_id)
|
||||
# await _pull(collection_id, auth_user)
|
||||
return GenericBooleanResponse(success=True) # type: ignore
|
||||
|
||||
@@ -562,7 +562,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
return await self.services.kg.update_graph( # type: ignore
|
||||
return await self.services.graph.update_graph( # type: ignore
|
||||
collection_id,
|
||||
name=name,
|
||||
description=description,
|
||||
@@ -637,7 +637,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
entities, count = await self.services.kg.get_entities(
|
||||
entities, count = await self.services.graph.get_entities(
|
||||
parent_id=collection_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -682,7 +682,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
return await self.services.kg.create_entity(
|
||||
return await self.services.graph.create_entity(
|
||||
name=name,
|
||||
description=description,
|
||||
parent_id=collection_id,
|
||||
@@ -744,7 +744,7 @@ class GraphRouter(BaseRouterV3):
|
||||
"The currently authenticated user does not have access to the collection associated with the given graph.",
|
||||
403,
|
||||
)
|
||||
return await self.services.kg.create_relationship(
|
||||
return await self.services.graph.create_relationship(
|
||||
subject=subject,
|
||||
subject_id=subject_id,
|
||||
predicate=predicate,
|
||||
@@ -874,7 +874,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
return await self.services.kg.update_entity(
|
||||
return await self.services.graph.update_entity(
|
||||
entity_id=entity_id,
|
||||
name=name,
|
||||
category=category,
|
||||
@@ -954,7 +954,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
await self.services.kg.delete_entity(
|
||||
await self.services.graph.delete_entity(
|
||||
parent_id=collection_id,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
@@ -1033,7 +1033,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
relationships, count = await self.services.kg.get_relationships(
|
||||
relationships, count = await self.services.graph.get_relationships(
|
||||
parent_id=collection_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -1178,7 +1178,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
return await self.services.kg.update_relationship(
|
||||
return await self.services.graph.update_relationship(
|
||||
relationship_id=relationship_id,
|
||||
subject=subject,
|
||||
subject_id=subject_id,
|
||||
@@ -1261,7 +1261,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
await self.services.kg.delete_relationship(
|
||||
await self.services.graph.delete_relationship(
|
||||
parent_id=collection_id,
|
||||
relationship_id=relationship_id,
|
||||
)
|
||||
@@ -1368,7 +1368,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
return await self.services.kg.create_community(
|
||||
return await self.services.graph.create_community(
|
||||
parent_id=collection_id,
|
||||
name=name,
|
||||
summary=summary,
|
||||
@@ -1449,7 +1449,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
communities, count = await self.services.kg.get_communities(
|
||||
communities, count = await self.services.graph.get_communities(
|
||||
parent_id=collection_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
@@ -1611,7 +1611,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
await self.services.kg.delete_community(
|
||||
await self.services.graph.delete_community(
|
||||
parent_id=collection_id,
|
||||
community_id=community_id,
|
||||
)
|
||||
@@ -1703,7 +1703,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
return await self.services.kg.update_community(
|
||||
return await self.services.graph.update_community(
|
||||
community_id=community_id,
|
||||
name=name,
|
||||
summary=summary,
|
||||
@@ -1801,7 +1801,7 @@ class GraphRouter(BaseRouterV3):
|
||||
403,
|
||||
)
|
||||
|
||||
list_graphs_response = await self.services.kg.list_graphs(
|
||||
list_graphs_response = await self.services.graph.list_graphs(
|
||||
# user_ids=None,
|
||||
graph_ids=[collection_id],
|
||||
offset=0,
|
||||
|
||||
@@ -14,8 +14,8 @@ from core.base import (
|
||||
)
|
||||
from core.main.abstractions import R2RServices
|
||||
from core.main.services.auth_service import AuthService
|
||||
from core.main.services.graph_service import GraphService
|
||||
from core.main.services.ingestion_service import IngestionService
|
||||
from core.main.services.kg_service import KgService
|
||||
from core.main.services.management_service import ManagementService
|
||||
from core.main.services.retrieval_service import RetrievalService
|
||||
from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
|
||||
@@ -85,7 +85,7 @@ class R2RBuilder:
|
||||
ingestion=service_instances["ingestion"],
|
||||
management=service_instances["management"],
|
||||
retrieval=service_instances["retrieval"],
|
||||
kg=service_instances["kg"],
|
||||
graph=service_instances["graph"],
|
||||
)
|
||||
|
||||
async def _create_providers(
|
||||
|
||||
@@ -364,18 +364,18 @@ class R2RPipeFactory:
|
||||
self,
|
||||
parsing_pipe_override: Optional[AsyncPipe] = None,
|
||||
embedding_pipe_override: Optional[AsyncPipe] = None,
|
||||
kg_relationships_extraction_pipe_override: Optional[AsyncPipe] = None,
|
||||
kg_storage_pipe_override: Optional[AsyncPipe] = None,
|
||||
kg_search_pipe_override: Optional[AsyncPipe] = None,
|
||||
graph_extraction_pipe_override: Optional[AsyncPipe] = None,
|
||||
graph_storage_pipe_override: Optional[AsyncPipe] = None,
|
||||
graph_search_pipe_override: Optional[AsyncPipe] = None,
|
||||
vector_storage_pipe_override: Optional[AsyncPipe] = None,
|
||||
vector_search_pipe_override: Optional[AsyncPipe] = None,
|
||||
rag_pipe_override: Optional[AsyncPipe] = None,
|
||||
streaming_rag_pipe_override: Optional[AsyncPipe] = None,
|
||||
kg_entity_description_pipe: Optional[AsyncPipe] = None,
|
||||
kg_clustering_pipe: Optional[AsyncPipe] = None,
|
||||
kg_entity_deduplication_pipe: Optional[AsyncPipe] = None,
|
||||
kg_entity_deduplication_summary_pipe: Optional[AsyncPipe] = None,
|
||||
kg_community_summary_pipe: Optional[AsyncPipe] = None,
|
||||
graph_description_pipe: Optional[AsyncPipe] = None,
|
||||
graph_clustering_pipe: Optional[AsyncPipe] = None,
|
||||
graph_deduplication_pipe: Optional[AsyncPipe] = None,
|
||||
graph_deduplication_summary_pipe: Optional[AsyncPipe] = None,
|
||||
graph_community_summary_pipe: Optional[AsyncPipe] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> R2RPipes:
|
||||
@@ -388,32 +388,30 @@ class R2RPipeFactory:
|
||||
),
|
||||
embedding_pipe=embedding_pipe_override
|
||||
or self.create_embedding_pipe(*args, **kwargs),
|
||||
kg_relationships_extraction_pipe=kg_relationships_extraction_pipe_override
|
||||
or self.create_kg_relationships_extraction_pipe(*args, **kwargs),
|
||||
kg_storage_pipe=kg_storage_pipe_override
|
||||
or self.create_kg_storage_pipe(*args, **kwargs),
|
||||
graph_extraction_pipe=graph_extraction_pipe_override
|
||||
or self.create_graph_extraction_pipe(*args, **kwargs),
|
||||
graph_storage_pipe=graph_storage_pipe_override
|
||||
or self.create_graph_storage_pipe(*args, **kwargs),
|
||||
vector_storage_pipe=vector_storage_pipe_override
|
||||
or self.create_vector_storage_pipe(*args, **kwargs),
|
||||
vector_search_pipe=vector_search_pipe_override
|
||||
or self.create_vector_search_pipe(*args, **kwargs),
|
||||
kg_search_pipe=kg_search_pipe_override
|
||||
or self.create_kg_search_pipe(*args, **kwargs),
|
||||
graph_search_pipe=graph_search_pipe_override
|
||||
or self.create_graph_search_pipe(*args, **kwargs),
|
||||
rag_pipe=rag_pipe_override
|
||||
or self.create_rag_pipe(*args, **kwargs),
|
||||
streaming_rag_pipe=streaming_rag_pipe_override
|
||||
or self.create_rag_pipe(True, *args, **kwargs),
|
||||
kg_entity_description_pipe=kg_entity_description_pipe
|
||||
or self.create_kg_entity_description_pipe(*args, **kwargs),
|
||||
kg_clustering_pipe=kg_clustering_pipe
|
||||
or self.create_kg_clustering_pipe(*args, **kwargs),
|
||||
kg_entity_deduplication_pipe=kg_entity_deduplication_pipe
|
||||
or self.create_kg_entity_deduplication_pipe(*args, **kwargs),
|
||||
kg_entity_deduplication_summary_pipe=kg_entity_deduplication_summary_pipe
|
||||
or self.create_kg_entity_deduplication_summary_pipe(
|
||||
*args, **kwargs
|
||||
),
|
||||
kg_community_summary_pipe=kg_community_summary_pipe
|
||||
or self.create_kg_community_summary_pipe(*args, **kwargs),
|
||||
graph_description_pipe=graph_description_pipe
|
||||
or self.create_graph_description_pipe(*args, **kwargs),
|
||||
graph_clustering_pipe=graph_clustering_pipe
|
||||
or self.create_graph_clustering_pipe(*args, **kwargs),
|
||||
graph_deduplication_pipe=graph_deduplication_pipe
|
||||
or self.create_graph_deduplication_pipe(*args, **kwargs),
|
||||
graph_deduplication_summary_pipe=graph_deduplication_summary_pipe
|
||||
or self.create_graph_deduplication_summary_pipe(*args, **kwargs),
|
||||
graph_community_summary_pipe=graph_community_summary_pipe
|
||||
or self.create_graph_community_summary_pipe(*args, **kwargs),
|
||||
)
|
||||
|
||||
def create_parsing_pipe(self, *args, **kwargs) -> Any:
|
||||
@@ -525,29 +523,27 @@ class R2RPipeFactory:
|
||||
config=AsyncPipe.PipeConfig(name="routing_search_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_relationships_extraction_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGExtractionPipe
|
||||
def create_graph_extraction_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphExtractionPipe
|
||||
|
||||
return KGExtractionPipe(
|
||||
return GraphExtractionPipe(
|
||||
llm_provider=self.providers.llm,
|
||||
database_provider=self.providers.database,
|
||||
config=AsyncPipe.PipeConfig(
|
||||
name="kg_relationships_extraction_pipe"
|
||||
),
|
||||
config=AsyncPipe.PipeConfig(name="graph_extraction_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_storage_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGStoragePipe
|
||||
def create_graph_storage_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphStoragePipe
|
||||
|
||||
return KGStoragePipe(
|
||||
return GraphStoragePipe(
|
||||
database_provider=self.providers.database,
|
||||
config=AsyncPipe.PipeConfig(name="kg_storage_pipe"),
|
||||
config=AsyncPipe.PipeConfig(name="graph_storage_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_search_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGSearchSearchPipe
|
||||
def create_graph_search_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphSearchSearchPipe
|
||||
|
||||
return KGSearchSearchPipe(
|
||||
return GraphSearchSearchPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
@@ -558,9 +554,9 @@ class R2RPipeFactory:
|
||||
|
||||
def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
|
||||
if stream:
|
||||
from core.pipes import StreamingSearchRAGPipe
|
||||
from core.pipes import StreamingRAGPipe
|
||||
|
||||
return StreamingSearchRAGPipe(
|
||||
return StreamingRAGPipe(
|
||||
llm_provider=self.providers.llm,
|
||||
database_provider=self.providers.database,
|
||||
config=GeneratorPipe.PipeConfig(
|
||||
@@ -568,9 +564,9 @@ class R2RPipeFactory:
|
||||
),
|
||||
)
|
||||
else:
|
||||
from core.pipes import SearchRAGPipe
|
||||
from core.pipes import RAGPipe
|
||||
|
||||
return SearchRAGPipe(
|
||||
return RAGPipe(
|
||||
llm_provider=self.providers.llm,
|
||||
database_provider=self.providers.database,
|
||||
config=GeneratorPipe.PipeConfig(
|
||||
@@ -578,67 +574,65 @@ class R2RPipeFactory:
|
||||
),
|
||||
)
|
||||
|
||||
def create_kg_entity_description_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGEntityDescriptionPipe
|
||||
def create_graph_description_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphDescriptionPipe
|
||||
|
||||
return KGEntityDescriptionPipe(
|
||||
return GraphDescriptionPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
config=AsyncPipe.PipeConfig(name="kg_entity_description_pipe"),
|
||||
config=AsyncPipe.PipeConfig(name="graph_description_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_clustering_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGClusteringPipe
|
||||
def create_graph_clustering_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphClusteringPipe
|
||||
|
||||
return KGClusteringPipe(
|
||||
return GraphClusteringPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
config=AsyncPipe.PipeConfig(name="kg_clustering_pipe"),
|
||||
config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGEntityDeduplicationSummaryPipe
|
||||
from core.pipes import GraphDeduplicationSummaryPipe
|
||||
|
||||
return KGEntityDeduplicationSummaryPipe(
|
||||
return GraphDeduplicationSummaryPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
config=AsyncPipe.PipeConfig(name="kg_deduplication_summary_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_community_summary_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGCommunitySummaryPipe
|
||||
def create_graph_community_summary_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphCommunitySummaryPipe
|
||||
|
||||
return KGCommunitySummaryPipe(
|
||||
return GraphCommunitySummaryPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
config=AsyncPipe.PipeConfig(name="kg_community_summary_pipe"),
|
||||
config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_entity_deduplication_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import KGEntityDeduplicationPipe
|
||||
def create_graph_deduplication_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphDeduplicationPipe
|
||||
|
||||
return KGEntityDeduplicationPipe(
|
||||
return GraphDeduplicationPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
config=AsyncPipe.PipeConfig(name="kg_entity_deduplication_pipe"),
|
||||
config=AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
|
||||
)
|
||||
|
||||
def create_kg_entity_deduplication_summary_pipe(
|
||||
self, *args, **kwargs
|
||||
) -> Any:
|
||||
from core.pipes import KGEntityDeduplicationSummaryPipe
|
||||
def create_graph_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
|
||||
from core.pipes import GraphDeduplicationSummaryPipe
|
||||
|
||||
return KGEntityDeduplicationSummaryPipe(
|
||||
return GraphDeduplicationSummaryPipe(
|
||||
database_provider=self.providers.database,
|
||||
llm_provider=self.providers.llm,
|
||||
embedding_provider=self.providers.embedding,
|
||||
config=AsyncPipe.PipeConfig(
|
||||
name="kg_entity_deduplication_summary_pipe"
|
||||
name="graph_deduplication_summary_pipe"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -664,7 +658,7 @@ class R2RPipelineFactory:
|
||||
self.pipes.vector_search_pipe, vector_search_pipe=True
|
||||
)
|
||||
search_pipeline.add_pipe(
|
||||
self.pipes.kg_search_pipe, kg_search_pipe=True
|
||||
self.pipes.graph_search_pipe, graph_search_pipe=True
|
||||
)
|
||||
|
||||
return search_pipeline
|
||||
|
||||
@@ -11,7 +11,7 @@ from core import GenerationConfig
|
||||
from core.base import OrchestrationProvider, R2RException
|
||||
from core.base.abstractions import KGEnrichmentStatus, KGExtractionStatus
|
||||
|
||||
from ...services import KgService
|
||||
from ...services import GraphService
|
||||
|
||||
logger = logging.getLogger()
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def hatchet_kg_factory(
|
||||
orchestration_provider: OrchestrationProvider, service: KgService
|
||||
orchestration_provider: OrchestrationProvider, service: GraphService
|
||||
) -> dict[str, "Hatchet.Workflow"]:
|
||||
|
||||
def convert_to_dict(input_data):
|
||||
@@ -124,7 +124,7 @@ def hatchet_kg_factory(
|
||||
|
||||
@orchestration_provider.workflow(name="kg-extract", timeout="360m")
|
||||
class KGExtractDescribeEmbedWorkflow:
|
||||
def __init__(self, kg_service: KgService):
|
||||
def __init__(self, kg_service: GraphService):
|
||||
self.kg_service = kg_service
|
||||
|
||||
@orchestration_provider.concurrency( # type: ignore
|
||||
@@ -273,7 +273,7 @@ def hatchet_kg_factory(
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def __init__(self, kg_service: KgService):
|
||||
def __init__(self, kg_service: GraphService):
|
||||
self.kg_service = kg_service
|
||||
|
||||
@orchestration_provider.step(retries=1)
|
||||
@@ -392,7 +392,7 @@ def hatchet_kg_factory(
|
||||
name="entity-deduplication", timeout="360m"
|
||||
)
|
||||
class EntityDeduplicationWorkflow:
|
||||
def __init__(self, kg_service: KgService):
|
||||
def __init__(self, kg_service: GraphService):
|
||||
self.kg_service = kg_service
|
||||
|
||||
@orchestration_provider.step(retries=0, timeout="360m")
|
||||
@@ -460,7 +460,7 @@ def hatchet_kg_factory(
|
||||
name="kg-entity-deduplication-summary", timeout="360m"
|
||||
)
|
||||
class EntityDeduplicationSummaryWorkflow:
|
||||
def __init__(self, kg_service: KgService):
|
||||
def __init__(self, kg_service: GraphService):
|
||||
self.kg_service = kg_service
|
||||
|
||||
@orchestration_provider.step(retries=0, timeout="360m")
|
||||
@@ -490,7 +490,7 @@ def hatchet_kg_factory(
|
||||
|
||||
@orchestration_provider.workflow(name="build-communities", timeout="360m")
|
||||
class EnrichGraphWorkflow:
|
||||
def __init__(self, kg_service: KgService):
|
||||
def __init__(self, kg_service: GraphService):
|
||||
self.kg_service = kg_service
|
||||
|
||||
@orchestration_provider.step(retries=1, parents=[], timeout="360m")
|
||||
@@ -642,7 +642,7 @@ def hatchet_kg_factory(
|
||||
name="kg-community-summary", timeout="360m"
|
||||
)
|
||||
class KGCommunitySummaryWorkflow:
|
||||
def __init__(self, kg_service: KgService):
|
||||
def __init__(self, kg_service: GraphService):
|
||||
self.kg_service = kg_service
|
||||
|
||||
@orchestration_provider.concurrency( # type: ignore
|
||||
|
||||
@@ -6,12 +6,12 @@ import uuid
|
||||
from core import GenerationConfig, R2RException
|
||||
from core.base.abstractions import KGEnrichmentStatus
|
||||
|
||||
from ...services import KgService
|
||||
from ...services import GraphService
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def simple_kg_factory(service: KgService):
|
||||
def simple_kg_factory(service: GraphService):
|
||||
|
||||
def get_input_data_dict(input_data):
|
||||
for key, value in input_data.items():
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .auth_service import AuthService
|
||||
from .graph_service import GraphService
|
||||
from .ingestion_service import IngestionService, IngestionServiceAdapter
|
||||
from .kg_service import KgService
|
||||
from .management_service import ManagementService
|
||||
from .retrieval_service import RetrievalService
|
||||
|
||||
@@ -9,6 +9,6 @@ __all__ = [
|
||||
"IngestionService",
|
||||
"IngestionServiceAdapter",
|
||||
"ManagementService",
|
||||
"KgService",
|
||||
"GraphService",
|
||||
"RetrievalService",
|
||||
]
|
||||
|
||||
@@ -46,9 +46,9 @@ async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
|
||||
return results
|
||||
|
||||
|
||||
# TODO - Fix naming convention to read `KGService` instead of `KgService`
|
||||
# TODO - Fix naming convention to read `KGService` instead of `GraphService`
|
||||
# this will require a minor change in how services are registered.
|
||||
class KgService(Service):
|
||||
class GraphService(Service):
|
||||
def __init__(
|
||||
self,
|
||||
config: R2RConfig,
|
||||
@@ -90,8 +90,8 @@ class KgService(Service):
|
||||
status=KGExtractionStatus.PROCESSING,
|
||||
)
|
||||
|
||||
relationships = await self.pipes.kg_relationships_extraction_pipe.run(
|
||||
input=self.pipes.kg_relationships_extraction_pipe.Input(
|
||||
relationships = await self.pipes.graph_extraction_pipe.run(
|
||||
input=self.pipes.graph_extraction_pipe.Input(
|
||||
message={
|
||||
"document_id": document_id,
|
||||
"generation_config": generation_config,
|
||||
@@ -110,8 +110,10 @@ class KgService(Service):
|
||||
f"KGService: Finished processing document {document_id} for KG extraction"
|
||||
)
|
||||
|
||||
result_gen = await self.pipes.kg_storage_pipe.run(
|
||||
input=self.pipes.kg_storage_pipe.Input(message=relationships),
|
||||
result_gen = await self.pipes.graph_storage_pipe.run(
|
||||
input=self.pipes.graph_storage_pipe.Input(
|
||||
message=relationships
|
||||
),
|
||||
state=None,
|
||||
run_manager=self.run_manager,
|
||||
)
|
||||
@@ -525,8 +527,8 @@ class KgService(Service):
|
||||
f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
|
||||
)
|
||||
|
||||
node_descriptions = await self.pipes.kg_entity_description_pipe.run(
|
||||
input=self.pipes.kg_entity_description_pipe.Input(
|
||||
node_descriptions = await self.pipes.graph_description_pipe.run(
|
||||
input=self.pipes.graph_description_pipe.Input(
|
||||
message={
|
||||
"offset": i * 256,
|
||||
"limit": 256,
|
||||
@@ -571,8 +573,8 @@ class KgService(Service):
|
||||
f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
|
||||
)
|
||||
|
||||
clustering_result = await self.pipes.kg_clustering_pipe.run(
|
||||
input=self.pipes.kg_clustering_pipe.Input(
|
||||
clustering_result = await self.pipes.graph_clustering_pipe.run(
|
||||
input=self.pipes.graph_clustering_pipe.Input(
|
||||
message={
|
||||
"collection_id": collection_id,
|
||||
"generation_config": generation_config,
|
||||
@@ -597,8 +599,8 @@ class KgService(Service):
|
||||
# graph_id: UUID | None,
|
||||
**kwargs,
|
||||
):
|
||||
summary_results = await self.pipes.kg_community_summary_pipe.run(
|
||||
input=self.pipes.kg_community_summary_pipe.Input(
|
||||
summary_results = await self.pipes.graph_community_summary_pipe.run(
|
||||
input=self.pipes.graph_community_summary_pipe.Input(
|
||||
message={
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
@@ -716,8 +718,8 @@ class KgService(Service):
|
||||
generation_config: GenerationConfig,
|
||||
**kwargs,
|
||||
):
|
||||
deduplication_results = await self.pipes.kg_entity_deduplication_pipe.run(
|
||||
input=self.pipes.kg_entity_deduplication_pipe.Input(
|
||||
deduplication_results = await self.pipes.graph_deduplication_pipe.run(
|
||||
input=self.pipes.graph_deduplication_pipe.Input(
|
||||
message={
|
||||
"collection_id": collection_id,
|
||||
"graph_id": graph_id,
|
||||
@@ -747,8 +749,8 @@ class KgService(Service):
|
||||
logger.info(
|
||||
f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
|
||||
)
|
||||
deduplication_summary_results = await self.pipes.kg_entity_deduplication_summary_pipe.run(
|
||||
input=self.pipes.kg_entity_deduplication_summary_pipe.Input(
|
||||
deduplication_summary_results = await self.pipes.graph_deduplication_summary_pipe.run(
|
||||
input=self.pipes.graph_deduplication_summary_pipe.Input(
|
||||
message={
|
||||
"collection_id": collection_id,
|
||||
"offset": offset,
|
||||
@@ -780,7 +782,7 @@ class KgService(Service):
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Processing document {document_id} for KG extraction",
|
||||
f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
|
||||
)
|
||||
|
||||
# Then create the extractions from the results
|
||||
@@ -835,7 +837,7 @@ class KgService(Service):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
# sort the extractions accroding to chunk_order field in metadata in ascending order
|
||||
@@ -851,7 +853,7 @@ class KgService(Service):
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
tasks = [
|
||||
@@ -873,7 +875,7 @@ class KgService(Service):
|
||||
total_tasks = len(tasks)
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
|
||||
f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
|
||||
)
|
||||
|
||||
for completed_task in asyncio.as_completed(tasks):
|
||||
@@ -882,7 +884,7 @@ class KgService(Service):
|
||||
completed_tasks += 1
|
||||
if completed_tasks % 100 == 0:
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
|
||||
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Extracting KG Relationships: {e}")
|
||||
@@ -892,7 +894,7 @@ class KgService(Service):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
async def _extract_kg(
|
||||
@@ -1044,7 +1046,7 @@ class KgService(Service):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
|
||||
f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
|
||||
)
|
||||
|
||||
return KGExtraction(
|
||||
@@ -25,7 +25,7 @@ class SearchPipeline(AsyncPipeline):
|
||||
super().__init__(run_manager)
|
||||
self._parsing_pipe: Optional[AsyncPipe] = None
|
||||
self._vector_search_pipeline: Optional[AsyncPipeline] = None
|
||||
self._kg_search_pipeline: Optional[AsyncPipeline] = None
|
||||
self._graph_search_pipeline: Optional[AsyncPipeline] = None
|
||||
|
||||
async def run( # type: ignore
|
||||
self,
|
||||
@@ -68,7 +68,7 @@ class SearchPipeline(AsyncPipeline):
|
||||
)
|
||||
)
|
||||
kg_task = asyncio.create_task(
|
||||
self._kg_search_pipeline.run(
|
||||
self._graph_search_pipeline.run(
|
||||
dequeue_requests(kg_queue),
|
||||
request_state,
|
||||
stream,
|
||||
@@ -93,22 +93,22 @@ class SearchPipeline(AsyncPipeline):
|
||||
self,
|
||||
pipe: AsyncPipe,
|
||||
add_upstream_outputs: Optional[list[dict[str, str]]] = None,
|
||||
kg_search_pipe: bool = False,
|
||||
graph_search_pipe: bool = False,
|
||||
vector_search_pipe: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")
|
||||
|
||||
if kg_search_pipe:
|
||||
if not self._kg_search_pipeline:
|
||||
self._kg_search_pipeline = AsyncPipeline()
|
||||
if not self._kg_search_pipeline:
|
||||
if graph_search_pipe:
|
||||
if not self._graph_search_pipeline:
|
||||
self._graph_search_pipeline = AsyncPipeline()
|
||||
if not self._graph_search_pipeline:
|
||||
raise ValueError(
|
||||
"KG search pipeline not found"
|
||||
) # for type hinting
|
||||
|
||||
self._kg_search_pipeline.add_pipe(
|
||||
self._graph_search_pipeline.add_pipe(
|
||||
pipe, add_upstream_outputs, *args, **kwargs
|
||||
)
|
||||
elif vector_search_pipe:
|
||||
|
||||
+20
-20
@@ -3,39 +3,39 @@ from .abstractions.search_pipe import SearchPipe
|
||||
from .ingestion.embedding_pipe import EmbeddingPipe
|
||||
from .ingestion.parsing_pipe import ParsingPipe
|
||||
from .ingestion.vector_storage_pipe import VectorStoragePipe
|
||||
from .kg.clustering import KGClusteringPipe
|
||||
from .kg.community_summary import KGCommunitySummaryPipe
|
||||
from .kg.deduplication import KGEntityDeduplicationPipe
|
||||
from .kg.deduplication_summary import KGEntityDeduplicationSummaryPipe
|
||||
from .kg.description import KGEntityDescriptionPipe
|
||||
from .kg.extraction import KGExtractionPipe
|
||||
from .kg.storage import KGStoragePipe
|
||||
from .kg.clustering import GraphClusteringPipe
|
||||
from .kg.community_summary import GraphCommunitySummaryPipe
|
||||
from .kg.deduplication import GraphDeduplicationPipe
|
||||
from .kg.deduplication_summary import GraphDeduplicationSummaryPipe
|
||||
from .kg.description import GraphDescriptionPipe
|
||||
from .kg.extraction import GraphExtractionPipe
|
||||
from .kg.storage import GraphStoragePipe
|
||||
from .retrieval.chunk_search_pipe import VectorSearchPipe
|
||||
from .retrieval.kg_search_pipe import KGSearchSearchPipe
|
||||
from .retrieval.graph_search_pipe import GraphSearchSearchPipe
|
||||
from .retrieval.multi_search import MultiSearchPipe
|
||||
from .retrieval.query_transform_pipe import QueryTransformPipe
|
||||
from .retrieval.routing_search_pipe import RoutingSearchPipe
|
||||
from .retrieval.search_rag_pipe import SearchRAGPipe
|
||||
from .retrieval.streaming_rag_pipe import StreamingSearchRAGPipe
|
||||
from .retrieval.search_rag_pipe import RAGPipe
|
||||
from .retrieval.streaming_rag_pipe import StreamingRAGPipe
|
||||
|
||||
__all__ = [
|
||||
"SearchPipe",
|
||||
"GeneratorPipe",
|
||||
"EmbeddingPipe",
|
||||
"KGExtractionPipe",
|
||||
"KGSearchSearchPipe",
|
||||
"KGEntityDescriptionPipe",
|
||||
"GraphExtractionPipe",
|
||||
"GraphSearchSearchPipe",
|
||||
"GraphDescriptionPipe",
|
||||
"ParsingPipe",
|
||||
"QueryTransformPipe",
|
||||
"SearchRAGPipe",
|
||||
"StreamingSearchRAGPipe",
|
||||
"RAGPipe",
|
||||
"StreamingRAGPipe",
|
||||
"VectorSearchPipe",
|
||||
"VectorStoragePipe",
|
||||
"KGStoragePipe",
|
||||
"KGClusteringPipe",
|
||||
"GraphStoragePipe",
|
||||
"GraphClusteringPipe",
|
||||
"MultiSearchPipe",
|
||||
"KGCommunitySummaryPipe",
|
||||
"GraphCommunitySummaryPipe",
|
||||
"RoutingSearchPipe",
|
||||
"KGEntityDeduplicationPipe",
|
||||
"KGEntityDeduplicationSummaryPipe",
|
||||
"GraphDeduplicationPipe",
|
||||
"GraphDeduplicationSummaryPipe",
|
||||
]
|
||||
|
||||
@@ -15,7 +15,7 @@ from core.database import PostgresDatabaseProvider
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGClusteringPipe(AsyncPipe):
|
||||
class GraphClusteringPipe(AsyncPipe):
|
||||
"""
|
||||
Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm.
|
||||
"""
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...database.postgres import PostgresDatabaseProvider
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGCommunitySummaryPipe(AsyncPipe):
|
||||
class GraphCommunitySummaryPipe(AsyncPipe):
|
||||
"""
|
||||
Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm.
|
||||
"""
|
||||
@@ -40,7 +40,7 @@ class KGCommunitySummaryPipe(AsyncPipe):
|
||||
"""
|
||||
super().__init__(
|
||||
config=config
|
||||
or AsyncPipe.PipeConfig(name="kg_community_summary_pipe"),
|
||||
or AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
|
||||
)
|
||||
self.database_provider = database_provider
|
||||
self.llm_provider = llm_provider
|
||||
@@ -210,7 +210,7 @@ class KGCommunitySummaryPipe(AsyncPipe):
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error(
|
||||
f"KGCommunitySummaryPipe: Error generating community summary for community {community_id}: {e}"
|
||||
f"GraphCommunitySummaryPipe: Error generating community summary for community {community_id}: {e}"
|
||||
)
|
||||
return {
|
||||
"community_id": community_id,
|
||||
@@ -265,7 +265,7 @@ class KGCommunitySummaryPipe(AsyncPipe):
|
||||
|
||||
# check which community summaries exist and don't run them again
|
||||
logger.info(
|
||||
f"KGCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
|
||||
f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
|
||||
)
|
||||
|
||||
all_entities, _ = (
|
||||
@@ -335,12 +335,12 @@ class KGCommunitySummaryPipe(AsyncPipe):
|
||||
completed_community_summary_jobs += 1
|
||||
if completed_community_summary_jobs % 50 == 0:
|
||||
logger.info(
|
||||
f"KGCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds"
|
||||
f"GraphCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds"
|
||||
)
|
||||
|
||||
if "error" in summary:
|
||||
logger.error(
|
||||
f"KGCommunitySummaryPipe: Error generating community summary for community {summary['community_id']}: {summary['error']}"
|
||||
f"GraphCommunitySummaryPipe: Error generating community summary for community {summary['community_id']}: {summary['error']}"
|
||||
)
|
||||
total_errors += 1
|
||||
continue
|
||||
@@ -349,5 +349,5 @@ class KGCommunitySummaryPipe(AsyncPipe):
|
||||
|
||||
if total_errors > 0:
|
||||
raise ValueError(
|
||||
f"KGCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures."
|
||||
f"GraphCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures."
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.providers import (
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
class GraphDeduplicationPipe(AsyncPipe):
|
||||
def __init__(
|
||||
self,
|
||||
config: AsyncPipe.PipeConfig,
|
||||
@@ -33,7 +33,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
):
|
||||
super().__init__(
|
||||
config=config
|
||||
or AsyncPipe.PipeConfig(name="kg_entity_deduplication_pipe"),
|
||||
or AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
|
||||
)
|
||||
self.database_provider = database_provider
|
||||
self.llm_provider = llm_provider
|
||||
@@ -69,7 +69,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
entities = await self._get_entities(graph_id, collection_id)
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDeduplicationPipe: Got {len(entities)} entities for {graph_id or collection_id}"
|
||||
f"GraphDeduplicationPipe: Got {len(entities)} entities for {graph_id or collection_id}"
|
||||
)
|
||||
|
||||
# deduplicate entities by name
|
||||
@@ -129,7 +129,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
|
||||
f"GraphDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
|
||||
)
|
||||
|
||||
await self.database_provider.graphs_handler.add_entities(
|
||||
@@ -171,7 +171,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
embeddings = [entity.description_embedding for entity in entities]
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDeduplicationPipe: Running DBSCAN clustering on {len(embeddings)} embeddings"
|
||||
f"GraphDeduplicationPipe: Running DBSCAN clustering on {len(embeddings)} embeddings"
|
||||
)
|
||||
# TODO: make eps a config, make it very strict for now
|
||||
clustering = DBSCAN(eps=0.1, min_samples=2, metric="cosine").fit(
|
||||
@@ -183,7 +183,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
|
||||
n_noise = list(labels).count(-1)
|
||||
logger.info(
|
||||
f"KGEntityDeduplicationPipe: Found {n_clusters} clusters and {n_noise} noise points"
|
||||
f"GraphDeduplicationPipe: Found {n_clusters} clusters and {n_noise} noise points"
|
||||
)
|
||||
|
||||
# for all labels in the same cluster, we can deduplicate them by name
|
||||
@@ -236,7 +236,7 @@ class KGEntityDeduplicationPipe(AsyncPipe):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
|
||||
f"GraphDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {graph_id}"
|
||||
)
|
||||
await self.database_provider.graphs_handler.add_entities(
|
||||
deduplicated_entities_list,
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.providers import ( # PostgresDatabaseProvider,
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGEntityDeduplicationSummaryPipe(AsyncPipe[Any]):
|
||||
class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
|
||||
|
||||
class Input(AsyncPipe.Input):
|
||||
message: dict
|
||||
|
||||
@@ -15,7 +15,7 @@ from ...database.postgres import PostgresDatabaseProvider
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGEntityDescriptionPipe(AsyncPipe):
|
||||
class GraphDescriptionPipe(AsyncPipe):
|
||||
"""
|
||||
The pipe takes input a list of nodes and extracts description from them.
|
||||
"""
|
||||
@@ -147,7 +147,7 @@ class KGEntityDescriptionPipe(AsyncPipe):
|
||||
logger = input.message["logger"]
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDescriptionPipe: Getting entity map for document {document_id}",
|
||||
f"GraphDescriptionPipe: Getting entity map for document {document_id}",
|
||||
)
|
||||
|
||||
entity_map = (
|
||||
@@ -158,7 +158,7 @@ class KGEntityDescriptionPipe(AsyncPipe):
|
||||
total_entities = len(entity_map)
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
workflows = []
|
||||
@@ -182,11 +182,11 @@ class KGEntityDescriptionPipe(AsyncPipe):
|
||||
for result in asyncio.as_completed(workflows):
|
||||
if completed_entities % 100 == 0:
|
||||
logger.info(
|
||||
f"KGEntityDescriptionPipe: Completed {completed_entities+1} of {total_entities} entities for document {document_id}",
|
||||
f"GraphDescriptionPipe: Completed {completed_entities+1} of {total_entities} entities for document {document_id}",
|
||||
)
|
||||
yield await result
|
||||
completed_entities += 1
|
||||
|
||||
logger.info(
|
||||
f"KGEntityDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ class ClientError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KGExtractionPipe(AsyncPipe[dict]):
|
||||
class GraphExtractionPipe(AsyncPipe[dict]):
|
||||
"""
|
||||
Extracts knowledge graph information from document extractions.
|
||||
"""
|
||||
@@ -54,9 +54,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
):
|
||||
super().__init__(
|
||||
config=config
|
||||
or AsyncPipe.PipeConfig(
|
||||
name="default_kg_relationships_extraction_pipe"
|
||||
),
|
||||
or AsyncPipe.PipeConfig(name="default_graph_extraction_pipe"),
|
||||
)
|
||||
self.database_provider = database_provider
|
||||
self.llm_provider = llm_provider
|
||||
@@ -198,7 +196,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
# add metadata to entities and relationships
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}",
|
||||
f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}",
|
||||
)
|
||||
|
||||
return KGExtraction(
|
||||
@@ -233,7 +231,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
logger = input.message.get("logger", logging.getLogger())
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Processing document {document_id} for KG extraction",
|
||||
f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
|
||||
)
|
||||
|
||||
# Then create the extractions from the results
|
||||
@@ -277,7 +275,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
# sort the extractions accroding to chunk_order field in metadata in ascending order
|
||||
@@ -293,7 +291,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Extracting KG Relationships for document and created {len(extractions_groups)} tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(extractions_groups)} tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
tasks = [
|
||||
@@ -315,7 +313,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
total_tasks = len(tasks)
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
|
||||
f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
|
||||
)
|
||||
|
||||
for completed_task in asyncio.as_completed(tasks):
|
||||
@@ -324,7 +322,7 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
completed_tasks += 1
|
||||
if completed_tasks % 100 == 0:
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
|
||||
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Extracting KG Relationships: {e}")
|
||||
@@ -334,5 +332,5 @@ class KGExtractionPipe(AsyncPipe[dict]):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.database import PostgresDatabaseProvider
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGStoragePipe(AsyncPipe):
|
||||
class GraphStoragePipe(AsyncPipe):
|
||||
# TODO - Apply correct type hints to storage messages
|
||||
class Input(AsyncPipe.Input):
|
||||
message: AsyncGenerator[list[Any], None]
|
||||
@@ -27,7 +27,7 @@ class KGStoragePipe(AsyncPipe):
|
||||
Initializes the async knowledge graph storage pipe with necessary components and configurations.
|
||||
"""
|
||||
logger.info(
|
||||
f"Initializing an `KGStoragePipe` to store knowledge graph extractions in a graph database."
|
||||
f"Initializing an `GraphStoragePipe` to store knowledge graph extractions in a graph database."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
|
||||
+1
-1
@@ -24,7 +24,7 @@ from ..abstractions.generator_pipe import GeneratorPipe
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class KGSearchSearchPipe(GeneratorPipe):
|
||||
class GraphSearchSearchPipe(GeneratorPipe):
|
||||
"""
|
||||
Embeds and stores documents using a specified embedding model and database.
|
||||
"""
|
||||
@@ -14,7 +14,7 @@ from core.base.abstractions import GenerationConfig, RAGCompletion
|
||||
from ..abstractions.generator_pipe import GeneratorPipe
|
||||
|
||||
|
||||
class SearchRAGPipe(GeneratorPipe):
|
||||
class RAGPipe(GeneratorPipe):
|
||||
class Input(AsyncPipe.Input):
|
||||
message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from ..abstractions.generator_pipe import GeneratorPipe
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class StreamingSearchRAGPipe(GeneratorPipe):
|
||||
class StreamingRAGPipe(GeneratorPipe):
|
||||
CHUNK_SEARCH_STREAM_MARKER = (
|
||||
"search" # TODO - change this to vector_search in next major release
|
||||
)
|
||||
@@ -72,7 +72,7 @@ class StreamingSearchRAGPipe(GeneratorPipe):
|
||||
for chunk in self.llm_provider.get_completion_stream(
|
||||
messages=messages, generation_config=rag_generation_config
|
||||
):
|
||||
chunk_txt = StreamingSearchRAGPipe._process_chunk(chunk)
|
||||
chunk_txt = StreamingRAGPipe._process_chunk(chunk)
|
||||
response += chunk_txt
|
||||
yield chunk_txt
|
||||
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@ from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from shared.api.models.base import WrappedBooleanResponse
|
||||
from shared.api.models.kg.responses import (
|
||||
from shared.api.models.graph.responses import (
|
||||
WrappedCommunitiesResponse,
|
||||
WrappedCommunityResponse,
|
||||
WrappedEntitiesResponse,
|
||||
|
||||
@@ -10,17 +10,17 @@ from shared.api.models.base import (
|
||||
WrappedBooleanResponse,
|
||||
WrappedGenericMessageResponse,
|
||||
)
|
||||
from shared.api.models.graph.responses import (
|
||||
GraphResponse,
|
||||
WrappedGraphResponse,
|
||||
WrappedGraphsResponse,
|
||||
)
|
||||
from shared.api.models.ingestion.responses import (
|
||||
IngestionResponse,
|
||||
WrappedIngestionResponse,
|
||||
WrappedMetadataUpdateResponse,
|
||||
WrappedUpdateResponse,
|
||||
)
|
||||
from shared.api.models.kg.responses import (
|
||||
GraphResponse,
|
||||
WrappedGraphResponse,
|
||||
WrappedGraphsResponse,
|
||||
)
|
||||
from shared.api.models.management.responses import (
|
||||
AnalyticsResponse,
|
||||
ChunkResponse,
|
||||
|
||||
Reference in New Issue
Block a user