Dev minor (#1574)
* add add-hoc rerank implementation to embedding, add async rerank (#1572) * add HF defaults * Feature/add document summary to ingestion (#1573) * adds document summary to ingestion pipeline * cleanup impl * new hybrid document search * implement hybrid document search * Feature/add document summary to ingestion (#1575) * adds document summary to ingestion pipeline * cleanup impl * new hybrid document search * implement hybrid document search * add migration script * make the summary change non-breaking (#1576) * make the summary change non-breaking * rollbk * up * Feature/tweak downgrade logic (#1577) * tweak downgrade * fix js sdk * fix js sdk * fix upgrade logic * up
This commit is contained in:
Generated
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "r2r-js",
|
||||
"version": "0.3.15",
|
||||
"version": "0.3.16",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "r2r-js",
|
||||
"version": "0.3.15",
|
||||
"version": "0.3.16",
|
||||
"description": "",
|
||||
"main": "dist/index.js",
|
||||
"browser": "dist/index.browser.js",
|
||||
|
||||
+7
-24
@@ -1921,41 +1921,24 @@ export class r2rClient {
|
||||
/**
|
||||
* Search over documents.
|
||||
* @param query The query to search for.
|
||||
* @param settings Settings for the document search.
|
||||
* @param vector_search_settings Settings for the document search.
|
||||
* @returns A promise that resolves to the response from the server.
|
||||
*/
|
||||
@feature("searchDocuments")
|
||||
async searchDocuments(
|
||||
query: string,
|
||||
settings?: {
|
||||
searchOverMetadata?: boolean;
|
||||
metadataKeys?: string[];
|
||||
searchOverBody?: boolean;
|
||||
filters?: Record<string, any>;
|
||||
searchFilters?: Record<string, any>;
|
||||
offset?: number;
|
||||
limit?: number;
|
||||
titleWeight?: number;
|
||||
metadataWeight?: number;
|
||||
},
|
||||
vector_search_settings?: VectorSearchSettings | Record<string, any>,
|
||||
): Promise<any> {
|
||||
this._ensureAuthenticated();
|
||||
|
||||
const json_data: Record<string, any> = {
|
||||
query,
|
||||
settings: {
|
||||
search_over_metadata: settings?.searchOverMetadata ?? true,
|
||||
metadata_keys: settings?.metadataKeys ?? ["title"],
|
||||
search_over_body: settings?.searchOverBody ?? false,
|
||||
filters: settings?.filters ?? {},
|
||||
search_filters: settings?.searchFilters ?? {},
|
||||
offset: settings?.offset ?? 0,
|
||||
limit: settings?.limit ?? 10,
|
||||
title_weight: settings?.titleWeight ?? 0.5,
|
||||
metadata_weight: settings?.metadataWeight ?? 0.5,
|
||||
},
|
||||
vector_search_settings,
|
||||
};
|
||||
|
||||
Object.keys(json_data).forEach(
|
||||
(key) => json_data[key] === undefined && delete json_data[key],
|
||||
);
|
||||
|
||||
return await this._makeRequest("POST", "search_documents", {
|
||||
data: json_data,
|
||||
});
|
||||
|
||||
@@ -84,6 +84,7 @@ async def upgrade(schema, revision):
|
||||
click.echo(
|
||||
f"Running database upgrade for schema {schema or 'default'}..."
|
||||
)
|
||||
print(f"Upgrading revision = {revision}")
|
||||
command = f"upgrade {revision}" if revision else "upgrade"
|
||||
result = await run_alembic_command(command, schema_name=schema)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from cli.utils.timer import timer
|
||||
@click.option(
|
||||
"--query", prompt="Enter your search query", help="The search query"
|
||||
)
|
||||
# VectorSearchSettings
|
||||
# SearchSettings
|
||||
@click.option(
|
||||
"--use-vector-search",
|
||||
is_flag=True,
|
||||
|
||||
@@ -333,6 +333,10 @@ services:
|
||||
# Ollama
|
||||
- OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
|
||||
|
||||
# Huggingface
|
||||
- HUGGINGFACE_API_BASE=${HUGGINGFACE_API_BASE:-http://host.docker.internal:8080}
|
||||
- HUGGINGFACE_API_KEY=${HUGGINGFACE_API_KEY}
|
||||
|
||||
# Unstructured
|
||||
- UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
|
||||
- UNSTRUCTURED_API_URL=${UNSTRUCTURED_API_URL:-https://api.unstructured.io/general/v0/general}
|
||||
|
||||
+1
-2
@@ -78,8 +78,7 @@ __all__ = [
|
||||
"KGSearchResult",
|
||||
"KGSearchSettings",
|
||||
"VectorSearchResult",
|
||||
"VectorSearchSettings",
|
||||
"DocumentSearchSettings",
|
||||
"SearchSettings",
|
||||
"HybridSearchSettings",
|
||||
# User abstractions
|
||||
"Token",
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.base import (
|
||||
from core.base.abstractions import (
|
||||
AggregateSearchResult,
|
||||
KGSearchSettings,
|
||||
VectorSearchSettings,
|
||||
SearchSettings,
|
||||
)
|
||||
from core.base.agent import AgentConfig, Tool
|
||||
from core.base.providers import CompletionProvider
|
||||
@@ -57,7 +57,7 @@ class RAGAgentMixin:
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
vector_search_settings: VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings,
|
||||
kg_search_settings: KGSearchSettings,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
@@ -46,8 +46,7 @@ __all__ = [
|
||||
"KGSearchResult",
|
||||
"KGSearchSettings",
|
||||
"VectorSearchResult",
|
||||
"VectorSearchSettings",
|
||||
"DocumentSearchSettings",
|
||||
"SearchSettings",
|
||||
"HybridSearchSettings",
|
||||
# KG abstractions
|
||||
"KGCreationSettings",
|
||||
|
||||
@@ -51,7 +51,6 @@ from shared.abstractions.llm import (
|
||||
from shared.abstractions.prompt import Prompt
|
||||
from shared.abstractions.search import (
|
||||
AggregateSearchResult,
|
||||
DocumentSearchSettings,
|
||||
HybridSearchSettings,
|
||||
KGCommunityResult,
|
||||
KGEntityResult,
|
||||
@@ -61,8 +60,8 @@ from shared.abstractions.search import (
|
||||
KGSearchResult,
|
||||
KGSearchResultType,
|
||||
KGSearchSettings,
|
||||
SearchSettings,
|
||||
VectorSearchResult,
|
||||
VectorSearchSettings,
|
||||
)
|
||||
from shared.abstractions.user import Token, TokenData, UserStats
|
||||
from shared.abstractions.vector import (
|
||||
@@ -130,8 +129,7 @@ __all__ = [
|
||||
"KGGlobalResult",
|
||||
"KGSearchSettings",
|
||||
"VectorSearchResult",
|
||||
"VectorSearchSettings",
|
||||
"DocumentSearchSettings",
|
||||
"SearchSettings",
|
||||
"HybridSearchSettings",
|
||||
# KG abstractions
|
||||
"KGCreationSettings",
|
||||
|
||||
@@ -27,7 +27,6 @@ from core.base import (
|
||||
)
|
||||
from core.base.abstractions import (
|
||||
DocumentInfo,
|
||||
DocumentSearchSettings,
|
||||
IndexArgsHNSW,
|
||||
IndexArgsIVFFlat,
|
||||
IndexMeasure,
|
||||
@@ -35,11 +34,11 @@ from core.base.abstractions import (
|
||||
KGCreationSettings,
|
||||
KGEnrichmentSettings,
|
||||
KGEntityDeduplicationSettings,
|
||||
SearchSettings,
|
||||
UserStats,
|
||||
VectorEntry,
|
||||
VectorQuantizationType,
|
||||
VectorSearchResult,
|
||||
VectorSearchSettings,
|
||||
VectorTableName,
|
||||
)
|
||||
from core.base.api.models import (
|
||||
@@ -256,6 +255,15 @@ class DocumentHandler(Handler):
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search_documents(
|
||||
self,
|
||||
query_text: str,
|
||||
query_embedding: Optional[list[float]] = None,
|
||||
search_settings: Optional[SearchSettings] = None,
|
||||
) -> list[DocumentInfo]:
|
||||
pass
|
||||
|
||||
|
||||
class CollectionHandler(Handler):
|
||||
@abstractmethod
|
||||
@@ -511,28 +519,22 @@ class VectorHandler(Handler):
|
||||
|
||||
@abstractmethod
|
||||
async def semantic_search(
|
||||
self, query_vector: list[float], search_settings: VectorSearchSettings
|
||||
self, query_vector: list[float], search_settings: SearchSettings
|
||||
) -> list[VectorSearchResult]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def full_text_search(
|
||||
self, query_text: str, search_settings: VectorSearchSettings
|
||||
self, query_text: str, search_settings: SearchSettings
|
||||
) -> list[VectorSearchResult]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search_documents(
|
||||
self, query_text: str, settings: DocumentSearchSettings
|
||||
) -> list[dict]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query_text: str,
|
||||
query_vector: list[float],
|
||||
search_settings: VectorSearchSettings,
|
||||
search_settings: SearchSettings,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[VectorSearchResult]:
|
||||
@@ -1404,14 +1406,14 @@ class DatabaseProvider(Provider):
|
||||
return await self.vector_handler.upsert_entries(entries)
|
||||
|
||||
async def semantic_search(
|
||||
self, query_vector: list[float], search_settings: VectorSearchSettings
|
||||
self, query_vector: list[float], search_settings: SearchSettings
|
||||
) -> list[VectorSearchResult]:
|
||||
return await self.vector_handler.semantic_search(
|
||||
query_vector, search_settings
|
||||
)
|
||||
|
||||
async def full_text_search(
|
||||
self, query_text: str, search_settings: VectorSearchSettings
|
||||
self, query_text: str, search_settings: SearchSettings
|
||||
) -> list[VectorSearchResult]:
|
||||
return await self.vector_handler.full_text_search(
|
||||
query_text, search_settings
|
||||
@@ -1421,7 +1423,7 @@ class DatabaseProvider(Provider):
|
||||
self,
|
||||
query_text: str,
|
||||
query_vector: list[float],
|
||||
search_settings: VectorSearchSettings,
|
||||
search_settings: SearchSettings,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[VectorSearchResult]:
|
||||
@@ -1430,9 +1432,14 @@ class DatabaseProvider(Provider):
|
||||
)
|
||||
|
||||
async def search_documents(
|
||||
self, query_text: str, settings: DocumentSearchSettings
|
||||
) -> list[dict]:
|
||||
return await self.vector_handler.search_documents(query_text, settings)
|
||||
self,
|
||||
query_text: str,
|
||||
settings: SearchSettings,
|
||||
query_embedding: Optional[list[float]] = None,
|
||||
) -> list[DocumentInfo]:
|
||||
return await self.document_handler.search_documents(
|
||||
query_text, query_embedding, settings
|
||||
)
|
||||
|
||||
async def delete(
|
||||
self, filters: dict[str, Any]
|
||||
|
||||
@@ -25,8 +25,7 @@ class EmbeddingConfig(ProviderConfig):
|
||||
base_model: str
|
||||
base_dimension: int
|
||||
rerank_model: Optional[str] = None
|
||||
rerank_dimension: Optional[int] = None
|
||||
rerank_transformer_type: Optional[str] = None
|
||||
rerank_url: Optional[str] = None
|
||||
batch_size: int = 1
|
||||
prefixes: Optional[dict[str, str]] = None
|
||||
add_title_as_prefix: bool = True
|
||||
@@ -38,6 +37,10 @@ class EmbeddingConfig(ProviderConfig):
|
||||
VectorQuantizationSettings()
|
||||
)
|
||||
|
||||
## deprecated
|
||||
rerank_dimension: Optional[int] = None
|
||||
rerank_transformer_type: Optional[str] = None
|
||||
|
||||
def validate_config(self) -> None:
|
||||
if self.provider not in self.supported_providers:
|
||||
raise ValueError(f"Provider '{self.provider}' is not supported.")
|
||||
@@ -171,6 +174,16 @@ class EmbeddingProvider(Provider):
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def arerank(
|
||||
self,
|
||||
query: str,
|
||||
results: list[VectorSearchResult],
|
||||
stage: PipeStage = PipeStage.RERANK,
|
||||
limit: int = 10,
|
||||
):
|
||||
pass
|
||||
|
||||
def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
|
||||
self.prefixes = {}
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ class IngestionConfig(ProviderConfig):
|
||||
chunk_enrichment_settings: ChunkEnrichmentSettings = (
|
||||
ChunkEnrichmentSettings()
|
||||
)
|
||||
|
||||
extra_parsers: dict[str, str] = {}
|
||||
|
||||
audio_transcription_model: str = "openai/whisper-1"
|
||||
@@ -29,6 +28,12 @@ class IngestionConfig(ProviderConfig):
|
||||
vision_pdf_prompt_name: str = "vision_pdf"
|
||||
vision_pdf_model: str = "openai/gpt-4-mini"
|
||||
|
||||
skip_document_summary: bool = False
|
||||
document_summary_system_prompt: str = "default_system"
|
||||
document_summary_task_prompt: str = "default_summary"
|
||||
chunks_for_document_summary: int = 128
|
||||
document_summary_model: str = "openai/gpt-4o-mini"
|
||||
|
||||
@property
|
||||
def supported_providers(self) -> list[str]:
|
||||
return ["r2r", "unstructured_local", "unstructured_api"]
|
||||
|
||||
@@ -66,6 +66,8 @@ new_after_n_chars = 512
|
||||
max_characters = 1_024
|
||||
combine_under_n_chars = 128
|
||||
overlap = 20
|
||||
chunks_for_document_summary = 16
|
||||
document_summary_model = "ollama/llama3.1"
|
||||
|
||||
[orchestration]
|
||||
provider = "hatchet"
|
||||
|
||||
@@ -67,3 +67,6 @@ vision_pdf_model = "ollama/llama3.2-vision"
|
||||
|
||||
[ingestion.extra_parsers]
|
||||
pdf = "zerox"
|
||||
|
||||
chunks_for_document_summary = 16
|
||||
document_summary_model = "ollama/llama3.1"
|
||||
|
||||
@@ -40,5 +40,6 @@ audio_transcription_model="azure/whisper-1"
|
||||
vision_img_model = "azure/gpt-4o-mini"
|
||||
vision_pdf_model = "azure/gpt-4o-mini"
|
||||
|
||||
document_summary_model = "azure/gpt-4o-mini"
|
||||
[ingestion.chunk_enrichment_settings]
|
||||
generation_config = { model = "azure/gpt-4o-mini" }
|
||||
|
||||
@@ -11,10 +11,10 @@ from fastapi import (
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
UploadFile,
|
||||
HTTPException,
|
||||
)
|
||||
from pydantic import Json
|
||||
|
||||
|
||||
@@ -337,7 +337,7 @@ class ManagementRouter(BaseRouter):
|
||||
document_ids: list[str] = Query([]),
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(
|
||||
100,
|
||||
1_000,
|
||||
ge=-1,
|
||||
description="Number of items to return. Use -1 to return all items.",
|
||||
),
|
||||
|
||||
@@ -8,12 +8,11 @@ from fastapi import Body, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from core.base import (
|
||||
DocumentSearchSettings,
|
||||
GenerationConfig,
|
||||
KGSearchSettings,
|
||||
Message,
|
||||
R2RException,
|
||||
VectorSearchSettings,
|
||||
SearchSettings,
|
||||
)
|
||||
from core.base.api.models import (
|
||||
WrappedCompletionResponse,
|
||||
@@ -58,7 +57,7 @@ class RetrievalRouter(BaseRouter):
|
||||
def _select_filters(
|
||||
self,
|
||||
auth_user: Any,
|
||||
search_settings: Union[VectorSearchSettings, KGSearchSettings],
|
||||
search_settings: Union[SearchSettings, KGSearchSettings],
|
||||
) -> dict[str, Any]:
|
||||
selected_collections = {
|
||||
str(cid) for cid in set(search_settings.selected_collection_ids)
|
||||
@@ -111,8 +110,8 @@ class RetrievalRouter(BaseRouter):
|
||||
query: str = Body(
|
||||
..., description=search_descriptions.get("query")
|
||||
),
|
||||
settings: DocumentSearchSettings = Body(
|
||||
default_factory=DocumentSearchSettings,
|
||||
settings: SearchSettings = Body(
|
||||
default_factory=SearchSettings,
|
||||
description="Settings for document search",
|
||||
),
|
||||
auth_user=Depends(self.service.providers.auth.auth_wrapper),
|
||||
@@ -127,8 +126,14 @@ class RetrievalRouter(BaseRouter):
|
||||
Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
|
||||
"""
|
||||
|
||||
query_embedding = (
|
||||
await self.service.providers.embedding.async_get_embedding(
|
||||
query
|
||||
)
|
||||
)
|
||||
results = await self.service.search_documents(
|
||||
query=query,
|
||||
query_embedding=query_embedding,
|
||||
settings=settings,
|
||||
)
|
||||
return results
|
||||
@@ -142,8 +147,8 @@ class RetrievalRouter(BaseRouter):
|
||||
query: str = Body(
|
||||
..., description=search_descriptions.get("query")
|
||||
),
|
||||
vector_search_settings: VectorSearchSettings = Body(
|
||||
default_factory=VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings = Body(
|
||||
default_factory=SearchSettings,
|
||||
description=search_descriptions.get("vector_search_settings"),
|
||||
),
|
||||
kg_search_settings: KGSearchSettings = Body(
|
||||
@@ -187,8 +192,8 @@ class RetrievalRouter(BaseRouter):
|
||||
@self.base_endpoint
|
||||
async def rag_app(
|
||||
query: str = Body(..., description=rag_descriptions.get("query")),
|
||||
vector_search_settings: VectorSearchSettings = Body(
|
||||
default_factory=VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings = Body(
|
||||
default_factory=SearchSettings,
|
||||
description=rag_descriptions.get("vector_search_settings"),
|
||||
),
|
||||
kg_search_settings: KGSearchSettings = Body(
|
||||
@@ -261,8 +266,8 @@ class RetrievalRouter(BaseRouter):
|
||||
description=agent_descriptions.get("messages"),
|
||||
deprecated=True,
|
||||
),
|
||||
vector_search_settings: VectorSearchSettings = Body(
|
||||
default_factory=VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings = Body(
|
||||
default_factory=SearchSettings,
|
||||
description=agent_descriptions.get("vector_search_settings"),
|
||||
),
|
||||
kg_search_settings: KGSearchSettings = Body(
|
||||
@@ -358,7 +363,27 @@ class RetrievalRouter(BaseRouter):
|
||||
This endpoint uses the language model to generate completions for the provided messages.
|
||||
The generation process can be customized using the generation_config parameter.
|
||||
"""
|
||||
print("messages = ", messages)
|
||||
|
||||
return await self.service.completion(
|
||||
messages=messages,
|
||||
messages=[message.to_dict() for message in messages],
|
||||
generation_config=generation_config,
|
||||
)
|
||||
|
||||
@self.router.post("/embedding")
|
||||
@self.base_endpoint
|
||||
async def embedding(
|
||||
content: str = Body(..., description="The content to embed"),
|
||||
auth_user=Depends(self.service.providers.auth.auth_wrapper),
|
||||
response_model=WrappedCompletionResponse,
|
||||
):
|
||||
"""
|
||||
Generate completions for a list of messages.
|
||||
|
||||
This endpoint uses the language model to generate completions for the provided messages.
|
||||
The generation process can be customized using the generation_config parameter.
|
||||
"""
|
||||
|
||||
return await self.service.providers.embedding.async_get_embedding(
|
||||
text=content
|
||||
)
|
||||
|
||||
+2
-2
@@ -1,11 +1,11 @@
|
||||
from typing import Union
|
||||
|
||||
from core.base import R2RException
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from core.base import R2RException
|
||||
from core.providers import (
|
||||
HatchetOrchestrationProvider,
|
||||
SimpleOrchestrationProvider,
|
||||
|
||||
@@ -5,10 +5,11 @@ from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from core.base import R2RException
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from core.base import R2RException
|
||||
|
||||
from .assembly import R2RBuilder, R2RConfig
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
|
||||
from fastapi import HTTPException
|
||||
from hatchet_sdk import ConcurrencyLimitStrategy, Context
|
||||
from litellm import AuthenticationError
|
||||
|
||||
@@ -103,6 +103,14 @@ def hatchet_ingestion_factory(
|
||||
# document_info_dict = context.step_output("parse")["document_info"]
|
||||
# document_info = DocumentInfo(**document_info_dict)
|
||||
|
||||
await service.update_document_status(
|
||||
document_info, status=IngestionStatus.AUGMENTING
|
||||
)
|
||||
await service.augment_document_info(
|
||||
document_info,
|
||||
[extraction.to_dict() for extraction in extractions],
|
||||
)
|
||||
|
||||
await self.ingestion_service.update_document_status(
|
||||
document_info,
|
||||
status=IngestionStatus.EMBEDDING,
|
||||
|
||||
@@ -2,9 +2,9 @@ import asyncio
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from litellm import AuthenticationError
|
||||
|
||||
from fastapi import HTTPException
|
||||
from core.base import DocumentExtraction, R2RException, increment_version
|
||||
from core.utils import (
|
||||
generate_default_user_collection_id,
|
||||
@@ -44,6 +44,11 @@ def simple_ingestion_factory(service: IngestionService):
|
||||
async for extraction in extractions_generator
|
||||
]
|
||||
|
||||
await service.update_document_status(
|
||||
document_info, status=IngestionStatus.AUGMENTING
|
||||
)
|
||||
await service.augment_document_info(document_info, extractions)
|
||||
|
||||
await service.update_document_status(
|
||||
document_info, status=IngestionStatus.EMBEDDING
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.base import (
|
||||
@@ -12,6 +13,7 @@ from core.base import (
|
||||
DocumentExtraction,
|
||||
DocumentInfo,
|
||||
DocumentType,
|
||||
GenerationConfig,
|
||||
IngestionStatus,
|
||||
R2RException,
|
||||
RawChunk,
|
||||
@@ -221,6 +223,43 @@ class IngestionService(Service):
|
||||
ingestion_config=ingestion_config,
|
||||
)
|
||||
|
||||
async def augment_document_info(
|
||||
self,
|
||||
document_info: DocumentInfo,
|
||||
chunked_documents: list[dict],
|
||||
) -> None:
|
||||
if not self.config.ingestion.skip_document_summary:
|
||||
document = f"Document Title: {document_info.title}\n"
|
||||
if document_info.metadata != {}:
|
||||
document += f"Document Metadata: {json.dumps(document_info.metadata)}\n"
|
||||
|
||||
document += "Document Text:\n"
|
||||
for chunk in chunked_documents[
|
||||
0 : self.config.ingestion.chunks_for_document_summary
|
||||
]:
|
||||
document += chunk["data"]
|
||||
|
||||
messages = await self.providers.database.prompt_handler.get_message_payload(
|
||||
system_prompt_name=self.config.ingestion.document_summary_system_prompt,
|
||||
task_prompt_name=self.config.ingestion.document_summary_task_prompt,
|
||||
task_inputs={"document": document},
|
||||
)
|
||||
response = await self.providers.llm.aget_completion(
|
||||
messages=messages,
|
||||
generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
|
||||
)
|
||||
|
||||
document_info.summary = response.choices[0].message.content # type: ignore
|
||||
|
||||
if not document_info.summary:
|
||||
raise ValueError("Expected a generated response.")
|
||||
|
||||
embedding = await self.providers.embedding.async_get_embedding(
|
||||
text=document_info.summary,
|
||||
)
|
||||
document_info.summary_embedding = embedding
|
||||
return
|
||||
|
||||
async def embed_document(
|
||||
self,
|
||||
chunked_documents: list[dict],
|
||||
|
||||
@@ -3,6 +3,7 @@ import math
|
||||
import time
|
||||
from typing import AsyncGenerator, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.base import KGExtractionStatus, RunManager
|
||||
|
||||
@@ -3,18 +3,19 @@ import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core import R2RStreamingRAGAgent
|
||||
from core.base import (
|
||||
DocumentSearchSettings,
|
||||
DocumentInfo,
|
||||
EmbeddingPurpose,
|
||||
GenerationConfig,
|
||||
KGSearchSettings,
|
||||
Message,
|
||||
R2RException,
|
||||
RunManager,
|
||||
VectorSearchSettings,
|
||||
SearchSettings,
|
||||
manage_run,
|
||||
to_async_generator,
|
||||
)
|
||||
@@ -55,7 +56,7 @@ class RetrievalService(Service):
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
|
||||
vector_search_settings: SearchSettings = SearchSettings(),
|
||||
kg_search_settings: KGSearchSettings = KGSearchSettings(),
|
||||
*args,
|
||||
**kwargs,
|
||||
@@ -121,12 +122,14 @@ class RetrievalService(Service):
|
||||
async def search_documents(
|
||||
self,
|
||||
query: str,
|
||||
settings: DocumentSearchSettings,
|
||||
) -> list[dict]:
|
||||
settings: SearchSettings,
|
||||
query_embedding: Optional[list[float]] = None,
|
||||
) -> list[DocumentInfo]:
|
||||
|
||||
return await self.providers.database.search_documents(
|
||||
query_text=query,
|
||||
settings=settings,
|
||||
query_embedding=query_embedding,
|
||||
)
|
||||
|
||||
@telemetry_event("Completion")
|
||||
@@ -149,7 +152,7 @@ class RetrievalService(Service):
|
||||
self,
|
||||
query: str,
|
||||
rag_generation_config: GenerationConfig,
|
||||
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
|
||||
vector_search_settings: SearchSettings = SearchSettings(),
|
||||
kg_search_settings: KGSearchSettings = KGSearchSettings(),
|
||||
*args,
|
||||
**kwargs,
|
||||
@@ -247,7 +250,7 @@ class RetrievalService(Service):
|
||||
async def agent(
|
||||
self,
|
||||
rag_generation_config: GenerationConfig,
|
||||
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
|
||||
vector_search_settings: SearchSettings = SearchSettings(),
|
||||
kg_search_settings: KGSearchSettings = KGSearchSettings(),
|
||||
task_prompt_override: Optional[str] = None,
|
||||
include_title_if_available: Optional[bool] = False,
|
||||
@@ -422,7 +425,7 @@ class RetrievalServiceAdapter:
|
||||
@staticmethod
|
||||
def prepare_search_input(
|
||||
query: str,
|
||||
vector_search_settings: VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings,
|
||||
kg_search_settings: KGSearchSettings,
|
||||
user: UserResponse,
|
||||
) -> dict:
|
||||
@@ -437,7 +440,7 @@ class RetrievalServiceAdapter:
|
||||
def parse_search_input(data: dict):
|
||||
return {
|
||||
"query": data["query"],
|
||||
"vector_search_settings": VectorSearchSettings.from_dict(
|
||||
"vector_search_settings": SearchSettings.from_dict(
|
||||
data["vector_search_settings"]
|
||||
),
|
||||
"kg_search_settings": KGSearchSettings.from_dict(
|
||||
@@ -449,7 +452,7 @@ class RetrievalServiceAdapter:
|
||||
@staticmethod
|
||||
def prepare_rag_input(
|
||||
query: str,
|
||||
vector_search_settings: VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings,
|
||||
kg_search_settings: KGSearchSettings,
|
||||
rag_generation_config: GenerationConfig,
|
||||
task_prompt_override: Optional[str],
|
||||
@@ -468,7 +471,7 @@ class RetrievalServiceAdapter:
|
||||
def parse_rag_input(data: dict):
|
||||
return {
|
||||
"query": data["query"],
|
||||
"vector_search_settings": VectorSearchSettings.from_dict(
|
||||
"vector_search_settings": SearchSettings.from_dict(
|
||||
data["vector_search_settings"]
|
||||
),
|
||||
"kg_search_settings": KGSearchSettings.from_dict(
|
||||
@@ -484,7 +487,7 @@ class RetrievalServiceAdapter:
|
||||
@staticmethod
|
||||
def prepare_agent_input(
|
||||
message: Message,
|
||||
vector_search_settings: VectorSearchSettings,
|
||||
vector_search_settings: SearchSettings,
|
||||
kg_search_settings: KGSearchSettings,
|
||||
rag_generation_config: GenerationConfig,
|
||||
task_prompt_override: Optional[str],
|
||||
@@ -509,7 +512,7 @@ class RetrievalServiceAdapter:
|
||||
def parse_agent_input(data: dict):
|
||||
return {
|
||||
"message": Message.from_dict(data["message"]),
|
||||
"vector_search_settings": VectorSearchSettings.from_dict(
|
||||
"vector_search_settings": SearchSettings.from_dict(
|
||||
data["vector_search_settings"]
|
||||
),
|
||||
"kg_search_settings": KGSearchSettings.from_dict(
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Optional
|
||||
from ..base.abstractions import (
|
||||
GenerationConfig,
|
||||
KGSearchSettings,
|
||||
VectorSearchSettings,
|
||||
SearchSettings,
|
||||
)
|
||||
from ..base.logger.base import RunType
|
||||
from ..base.logger.run_manager import RunManager, manage_run
|
||||
@@ -34,7 +34,7 @@ class RAGPipeline(AsyncPipeline):
|
||||
input: Any,
|
||||
state: Optional[AsyncState],
|
||||
run_manager: Optional[RunManager] = None,
|
||||
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
|
||||
vector_search_settings: SearchSettings = SearchSettings(),
|
||||
kg_search_settings: KGSearchSettings = KGSearchSettings(),
|
||||
rag_generation_config: GenerationConfig = GenerationConfig(),
|
||||
*args: Any,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Optional
|
||||
from ..base.abstractions import (
|
||||
AggregateSearchResult,
|
||||
KGSearchSettings,
|
||||
VectorSearchSettings,
|
||||
SearchSettings,
|
||||
)
|
||||
from ..base.logger.run_manager import RunManager, manage_run
|
||||
from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
|
||||
@@ -35,7 +35,7 @@ class SearchPipeline(AsyncPipeline):
|
||||
state: Optional[AsyncState],
|
||||
stream: bool = False,
|
||||
run_manager: Optional[RunManager] = None,
|
||||
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
|
||||
vector_search_settings: SearchSettings = SearchSettings(),
|
||||
kg_search_settings: KGSearchSettings = KGSearchSettings(),
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.base import AsyncState
|
||||
|
||||
@@ -5,6 +5,7 @@ Pipe to tune the prompt for the KG model.
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.base import (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user