* add ad-hoc rerank implementation to embedding, add async rerank (#1572)

* add HF defaults

* Feature/add document summary to ingestion (#1573)

* adds document summary to ingestion pipeline

* cleanup impl

* new hybrid document search

* implement hybrid document search

* Feature/add document summary to ingestion (#1575)

* adds document summary to ingestion pipeline

* cleanup impl

* new hybrid document search

* implement hybrid document search

* add migration script

* make the summary change non-breaking (#1576)

* make the summary change non-breaking

* rollback

* up

* Feature/tweak downgrade logic (#1577)

* tweak downgrade

* fix js sdk

* fix js sdk

* fix upgrade logic

* up
This commit is contained in:
emrgnt-cmplxty
2024-11-12 18:16:02 -08:00
committed by GitHub
parent 22c0e26eb0
commit 2378f58242
60 changed files with 1778 additions and 809 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.3.15",
"version": "0.3.16",
"lockfileVersion": 3,
"requires": true,
"packages": {
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.3.15",
"version": "0.3.16",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
+7 -24
View File
@@ -1921,41 +1921,24 @@ export class r2rClient {
/**
* Search over documents.
* @param query The query to search for.
* @param settings Settings for the document search.
* @param vector_search_settings Settings for the document search.
* @returns A promise that resolves to the response from the server.
*/
@feature("searchDocuments")
async searchDocuments(
query: string,
settings?: {
searchOverMetadata?: boolean;
metadataKeys?: string[];
searchOverBody?: boolean;
filters?: Record<string, any>;
searchFilters?: Record<string, any>;
offset?: number;
limit?: number;
titleWeight?: number;
metadataWeight?: number;
},
vector_search_settings?: VectorSearchSettings | Record<string, any>,
): Promise<any> {
this._ensureAuthenticated();
const json_data: Record<string, any> = {
query,
settings: {
search_over_metadata: settings?.searchOverMetadata ?? true,
metadata_keys: settings?.metadataKeys ?? ["title"],
search_over_body: settings?.searchOverBody ?? false,
filters: settings?.filters ?? {},
search_filters: settings?.searchFilters ?? {},
offset: settings?.offset ?? 0,
limit: settings?.limit ?? 10,
title_weight: settings?.titleWeight ?? 0.5,
metadata_weight: settings?.metadataWeight ?? 0.5,
},
vector_search_settings,
};
Object.keys(json_data).forEach(
(key) => json_data[key] === undefined && delete json_data[key],
);
return await this._makeRequest("POST", "search_documents", {
data: json_data,
});
+1
View File
@@ -84,6 +84,7 @@ async def upgrade(schema, revision):
click.echo(
f"Running database upgrade for schema {schema or 'default'}..."
)
print(f"Upgrading revision = {revision}")
command = f"upgrade {revision}" if revision else "upgrade"
result = await run_alembic_command(command, schema_name=schema)
+1 -1
View File
@@ -10,7 +10,7 @@ from cli.utils.timer import timer
@click.option(
"--query", prompt="Enter your search query", help="The search query"
)
# VectorSearchSettings
# SearchSettings
@click.option(
"--use-vector-search",
is_flag=True,
+4
View File
@@ -333,6 +333,10 @@ services:
# Ollama
- OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
# Huggingface
- HUGGINGFACE_API_BASE=${HUGGINGFACE_API_BASE:-http://host.docker.internal:8080}
- HUGGINGFACE_API_KEY=${HUGGINGFACE_API_KEY}
# Unstructured
- UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
- UNSTRUCTURED_API_URL=${UNSTRUCTURED_API_URL:-https://api.unstructured.io/general/v0/general}
+1 -2
View File
@@ -78,8 +78,7 @@ __all__ = [
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"SearchSettings",
"HybridSearchSettings",
# User abstractions
"Token",
+2 -2
View File
@@ -8,7 +8,7 @@ from core.base import (
from core.base.abstractions import (
AggregateSearchResult,
KGSearchSettings,
VectorSearchSettings,
SearchSettings,
)
from core.base.agent import AgentConfig, Tool
from core.base.providers import CompletionProvider
@@ -57,7 +57,7 @@ class RAGAgentMixin:
async def search(
self,
query: str,
vector_search_settings: VectorSearchSettings,
vector_search_settings: SearchSettings,
kg_search_settings: KGSearchSettings,
*args,
**kwargs,
+1 -2
View File
@@ -46,8 +46,7 @@ __all__ = [
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"SearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
+2 -4
View File
@@ -51,7 +51,6 @@ from shared.abstractions.llm import (
from shared.abstractions.prompt import Prompt
from shared.abstractions.search import (
AggregateSearchResult,
DocumentSearchSettings,
HybridSearchSettings,
KGCommunityResult,
KGEntityResult,
@@ -61,8 +60,8 @@ from shared.abstractions.search import (
KGSearchResult,
KGSearchResultType,
KGSearchSettings,
SearchSettings,
VectorSearchResult,
VectorSearchSettings,
)
from shared.abstractions.user import Token, TokenData, UserStats
from shared.abstractions.vector import (
@@ -130,8 +129,7 @@ __all__ = [
"KGGlobalResult",
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"SearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
+24 -17
View File
@@ -27,7 +27,6 @@ from core.base import (
)
from core.base.abstractions import (
DocumentInfo,
DocumentSearchSettings,
IndexArgsHNSW,
IndexArgsIVFFlat,
IndexMeasure,
@@ -35,11 +34,11 @@ from core.base.abstractions import (
KGCreationSettings,
KGEnrichmentSettings,
KGEntityDeduplicationSettings,
SearchSettings,
UserStats,
VectorEntry,
VectorQuantizationType,
VectorSearchResult,
VectorSearchSettings,
VectorTableName,
)
from core.base.api.models import (
@@ -256,6 +255,15 @@ class DocumentHandler(Handler):
):
pass
@abstractmethod
async def search_documents(
self,
query_text: str,
query_embedding: Optional[list[float]] = None,
search_settings: Optional[SearchSettings] = None,
) -> list[DocumentInfo]:
"""
Abstract document-search hook; the body here is intentionally empty.
Takes the raw query text plus an optional query embedding and optional
search settings, and is annotated to return a list of DocumentInfo.
NOTE(review): presumably the embedding enables a semantic/hybrid search over
documents while the text drives full-text search -- confirm against the
concrete handler.
"""
pass
class CollectionHandler(Handler):
@abstractmethod
@@ -511,28 +519,22 @@ class VectorHandler(Handler):
@abstractmethod
async def semantic_search(
self, query_vector: list[float], search_settings: VectorSearchSettings
self, query_vector: list[float], search_settings: SearchSettings
) -> list[VectorSearchResult]:
pass
@abstractmethod
async def full_text_search(
self, query_text: str, search_settings: VectorSearchSettings
self, query_text: str, search_settings: SearchSettings
) -> list[VectorSearchResult]:
pass
@abstractmethod
async def search_documents(
self, query_text: str, settings: DocumentSearchSettings
) -> list[dict]:
pass
@abstractmethod
async def hybrid_search(
self,
query_text: str,
query_vector: list[float],
search_settings: VectorSearchSettings,
search_settings: SearchSettings,
*args,
**kwargs,
) -> list[VectorSearchResult]:
@@ -1404,14 +1406,14 @@ class DatabaseProvider(Provider):
return await self.vector_handler.upsert_entries(entries)
async def semantic_search(
self, query_vector: list[float], search_settings: VectorSearchSettings
self, query_vector: list[float], search_settings: SearchSettings
) -> list[VectorSearchResult]:
return await self.vector_handler.semantic_search(
query_vector, search_settings
)
async def full_text_search(
self, query_text: str, search_settings: VectorSearchSettings
self, query_text: str, search_settings: SearchSettings
) -> list[VectorSearchResult]:
return await self.vector_handler.full_text_search(
query_text, search_settings
@@ -1421,7 +1423,7 @@ class DatabaseProvider(Provider):
self,
query_text: str,
query_vector: list[float],
search_settings: VectorSearchSettings,
search_settings: SearchSettings,
*args,
**kwargs,
) -> list[VectorSearchResult]:
@@ -1430,9 +1432,14 @@ class DatabaseProvider(Provider):
)
async def search_documents(
self, query_text: str, settings: DocumentSearchSettings
) -> list[dict]:
return await self.vector_handler.search_documents(query_text, settings)
self,
query_text: str,
settings: SearchSettings,
query_embedding: Optional[list[float]] = None,
) -> list[DocumentInfo]:
return await self.document_handler.search_documents(
query_text, query_embedding, settings
)
async def delete(
self, filters: dict[str, Any]
+15 -2
View File
@@ -25,8 +25,7 @@ class EmbeddingConfig(ProviderConfig):
base_model: str
base_dimension: int
rerank_model: Optional[str] = None
rerank_dimension: Optional[int] = None
rerank_transformer_type: Optional[str] = None
rerank_url: Optional[str] = None
batch_size: int = 1
prefixes: Optional[dict[str, str]] = None
add_title_as_prefix: bool = True
@@ -38,6 +37,10 @@ class EmbeddingConfig(ProviderConfig):
VectorQuantizationSettings()
)
## deprecated
rerank_dimension: Optional[int] = None
rerank_transformer_type: Optional[str] = None
def validate_config(self) -> None:
if self.provider not in self.supported_providers:
raise ValueError(f"Provider '{self.provider}' is not supported.")
@@ -171,6 +174,16 @@ class EmbeddingProvider(Provider):
):
pass
@abstractmethod
async def arerank(
self,
query: str,
results: list[VectorSearchResult],
stage: PipeStage = PipeStage.RERANK,
limit: int = 10,
):
"""
Abstract async rerank hook; the body here is intentionally empty.
Receives the query string, a list of VectorSearchResult items, a pipe stage
(defaults to PipeStage.RERANK) and a limit (defaults to 10).
NOTE(review): presumably returns the reranked results truncated to `limit`;
no return annotation is declared -- confirm against the implementations.
"""
pass
def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
self.prefixes = {}
+6 -1
View File
@@ -18,7 +18,6 @@ class IngestionConfig(ProviderConfig):
chunk_enrichment_settings: ChunkEnrichmentSettings = (
ChunkEnrichmentSettings()
)
extra_parsers: dict[str, str] = {}
audio_transcription_model: str = "openai/whisper-1"
@@ -29,6 +28,12 @@ class IngestionConfig(ProviderConfig):
vision_pdf_prompt_name: str = "vision_pdf"
vision_pdf_model: str = "openai/gpt-4-mini"
skip_document_summary: bool = False
document_summary_system_prompt: str = "default_system"
document_summary_task_prompt: str = "default_summary"
chunks_for_document_summary: int = 128
document_summary_model: str = "openai/gpt-4o-mini"
@property
def supported_providers(self) -> list[str]:
return ["r2r", "unstructured_local", "unstructured_api"]
+2
View File
@@ -66,6 +66,8 @@ new_after_n_chars = 512
max_characters = 1_024
combine_under_n_chars = 128
overlap = 20
chunks_for_document_summary = 16
document_summary_model = "ollama/llama3.1"
[orchestration]
provider = "hatchet"
+3
View File
@@ -67,3 +67,6 @@ vision_pdf_model = "ollama/llama3.2-vision"
[ingestion.extra_parsers]
pdf = "zerox"
chunks_for_document_summary = 16
document_summary_model = "ollama/llama3.1"
+1
View File
@@ -40,5 +40,6 @@ audio_transcription_model="azure/whisper-1"
vision_img_model = "azure/gpt-4o-mini"
vision_pdf_model = "azure/gpt-4o-mini"
document_summary_model = "azure/gpt-4o-mini"
[ingestion.chunk_enrichment_settings]
generation_config = { model = "azure/gpt-4o-mini" }
+1 -1
View File
@@ -11,10 +11,10 @@ from fastapi import (
Depends,
File,
Form,
HTTPException,
Path,
Query,
UploadFile,
HTTPException,
)
from pydantic import Json
+1 -1
View File
@@ -337,7 +337,7 @@ class ManagementRouter(BaseRouter):
document_ids: list[str] = Query([]),
offset: int = Query(0, ge=0),
limit: int = Query(
100,
1_000,
ge=-1,
description="Number of items to return. Use -1 to return all items.",
),
+37 -12
View File
@@ -8,12 +8,11 @@ from fastapi import Body, Depends
from fastapi.responses import StreamingResponse
from core.base import (
DocumentSearchSettings,
GenerationConfig,
KGSearchSettings,
Message,
R2RException,
VectorSearchSettings,
SearchSettings,
)
from core.base.api.models import (
WrappedCompletionResponse,
@@ -58,7 +57,7 @@ class RetrievalRouter(BaseRouter):
def _select_filters(
self,
auth_user: Any,
search_settings: Union[VectorSearchSettings, KGSearchSettings],
search_settings: Union[SearchSettings, KGSearchSettings],
) -> dict[str, Any]:
selected_collections = {
str(cid) for cid in set(search_settings.selected_collection_ids)
@@ -111,8 +110,8 @@ class RetrievalRouter(BaseRouter):
query: str = Body(
..., description=search_descriptions.get("query")
),
settings: DocumentSearchSettings = Body(
default_factory=DocumentSearchSettings,
settings: SearchSettings = Body(
default_factory=SearchSettings,
description="Settings for document search",
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
@@ -127,8 +126,14 @@ class RetrievalRouter(BaseRouter):
Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
"""
query_embedding = (
await self.service.providers.embedding.async_get_embedding(
query
)
)
results = await self.service.search_documents(
query=query,
query_embedding=query_embedding,
settings=settings,
)
return results
@@ -142,8 +147,8 @@ class RetrievalRouter(BaseRouter):
query: str = Body(
..., description=search_descriptions.get("query")
),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description=search_descriptions.get("vector_search_settings"),
),
kg_search_settings: KGSearchSettings = Body(
@@ -187,8 +192,8 @@ class RetrievalRouter(BaseRouter):
@self.base_endpoint
async def rag_app(
query: str = Body(..., description=rag_descriptions.get("query")),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description=rag_descriptions.get("vector_search_settings"),
),
kg_search_settings: KGSearchSettings = Body(
@@ -261,8 +266,8 @@ class RetrievalRouter(BaseRouter):
description=agent_descriptions.get("messages"),
deprecated=True,
),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description=agent_descriptions.get("vector_search_settings"),
),
kg_search_settings: KGSearchSettings = Body(
@@ -358,7 +363,27 @@ class RetrievalRouter(BaseRouter):
This endpoint uses the language model to generate completions for the provided messages.
The generation process can be customized using the generation_config parameter.
"""
print("messages = ", messages)
return await self.service.completion(
messages=messages,
messages=[message.to_dict() for message in messages],
generation_config=generation_config,
)
@self.router.post("/embedding")
@self.base_endpoint
async def embedding(
content: str = Body(..., description="The content to embed"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
# NOTE(review): response_model is declared as a function parameter here, not
# passed to the route decorator, and it names the completion response
# wrapper -- confirm this is intended.
response_model=WrappedCompletionResponse,
):
"""
Generate an embedding for the provided content.
The content is passed as `text` to the embedding provider's
`async_get_embedding`, and that call's result is returned directly.
"""
return await self.service.providers.embedding.async_get_embedding(
text=content
)
+2 -2
View File
@@ -1,11 +1,11 @@
from typing import Union
from core.base import R2RException
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from core.base import R2RException
from core.providers import (
HatchetOrchestrationProvider,
SimpleOrchestrationProvider,
+3 -2
View File
@@ -5,10 +5,11 @@ from contextlib import asynccontextmanager
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from core.base import R2RException
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from core.base import R2RException
from .assembly import R2RBuilder, R2RConfig
@@ -3,8 +3,8 @@ import logging
import uuid
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException
from fastapi import HTTPException
from hatchet_sdk import ConcurrencyLimitStrategy, Context
from litellm import AuthenticationError
@@ -103,6 +103,14 @@ def hatchet_ingestion_factory(
# document_info_dict = context.step_output("parse")["document_info"]
# document_info = DocumentInfo(**document_info_dict)
await service.update_document_status(
document_info, status=IngestionStatus.AUGMENTING
)
await service.augment_document_info(
document_info,
[extraction.to_dict() for extraction in extractions],
)
await self.ingestion_service.update_document_status(
document_info,
status=IngestionStatus.EMBEDDING,
@@ -2,9 +2,9 @@ import asyncio
import logging
from uuid import UUID
from fastapi import HTTPException
from litellm import AuthenticationError
from fastapi import HTTPException
from core.base import DocumentExtraction, R2RException, increment_version
from core.utils import (
generate_default_user_collection_id,
@@ -44,6 +44,11 @@ def simple_ingestion_factory(service: IngestionService):
async for extraction in extractions_generator
]
await service.update_document_status(
document_info, status=IngestionStatus.AUGMENTING
)
await service.augment_document_info(document_info, extractions)
await service.update_document_status(
document_info, status=IngestionStatus.EMBEDDING
)
@@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Optional, Sequence, Union
from uuid import UUID
from fastapi import HTTPException
from core.base import (
@@ -12,6 +13,7 @@ from core.base import (
DocumentExtraction,
DocumentInfo,
DocumentType,
GenerationConfig,
IngestionStatus,
R2RException,
RawChunk,
@@ -221,6 +223,43 @@ class IngestionService(Service):
ingestion_config=ingestion_config,
)
async def augment_document_info(
    self,
    document_info: DocumentInfo,
    chunked_documents: list[dict],
) -> None:
    """Attach an LLM-generated summary and its embedding to ``document_info``.

    Does nothing when ``ingestion.skip_document_summary`` is set. Otherwise
    builds a prompt from the document title, its metadata (when not an
    empty dict) and the ``data`` of the first
    ``ingestion.chunks_for_document_summary`` chunks, requests a completion,
    stores it on ``document_info.summary``, and stores the embedding of that
    summary on ``document_info.summary_embedding``.

    Args:
        document_info: Document record that is updated in place.
        chunked_documents: Chunk dicts; each must carry a ``"data"`` string.

    Raises:
        ValueError: If the completion comes back empty.
    """
    ingestion_config = self.config.ingestion
    if ingestion_config.skip_document_summary:
        return

    header = f"Document Title: {document_info.title}\n"
    if document_info.metadata != {}:
        header += (
            f"Document Metadata: {json.dumps(document_info.metadata)}\n"
        )
    header += "Document Text:\n"

    # Only the leading chunks are summarized, to bound the prompt size.
    leading_chunks = chunked_documents[
        : ingestion_config.chunks_for_document_summary
    ]
    document = header + "".join(chunk["data"] for chunk in leading_chunks)

    messages = (
        await self.providers.database.prompt_handler.get_message_payload(
            system_prompt_name=ingestion_config.document_summary_system_prompt,
            task_prompt_name=ingestion_config.document_summary_task_prompt,
            task_inputs={"document": document},
        )
    )

    # Use the configured summary model rather than a hard-coded one, so
    # deployments that set `document_summary_model` (e.g. local or Azure
    # configs) are respected. The config default matches the old constant.
    response = await self.providers.llm.aget_completion(
        messages=messages,
        generation_config=GenerationConfig(
            model=ingestion_config.document_summary_model
        ),
    )

    document_info.summary = response.choices[0].message.content  # type: ignore
    if not document_info.summary:
        raise ValueError("Expected a generated response.")

    document_info.summary_embedding = (
        await self.providers.embedding.async_get_embedding(
            text=document_info.summary,
        )
    )
async def embed_document(
self,
chunked_documents: list[dict],
+1
View File
@@ -3,6 +3,7 @@ import math
import time
from typing import AsyncGenerator, Optional
from uuid import UUID
from fastapi import HTTPException
from core.base import KGExtractionStatus, RunManager
+16 -13
View File
@@ -3,18 +3,19 @@ import logging
import time
from typing import Optional
from uuid import UUID
from fastapi import HTTPException
from core import R2RStreamingRAGAgent
from core.base import (
DocumentSearchSettings,
DocumentInfo,
EmbeddingPurpose,
GenerationConfig,
KGSearchSettings,
Message,
R2RException,
RunManager,
VectorSearchSettings,
SearchSettings,
manage_run,
to_async_generator,
)
@@ -55,7 +56,7 @@ class RetrievalService(Service):
async def search(
self,
query: str,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
vector_search_settings: SearchSettings = SearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
*args,
**kwargs,
@@ -121,12 +122,14 @@ class RetrievalService(Service):
async def search_documents(
self,
query: str,
settings: DocumentSearchSettings,
) -> list[dict]:
settings: SearchSettings,
query_embedding: Optional[list[float]] = None,
) -> list[DocumentInfo]:
return await self.providers.database.search_documents(
query_text=query,
settings=settings,
query_embedding=query_embedding,
)
@telemetry_event("Completion")
@@ -149,7 +152,7 @@ class RetrievalService(Service):
self,
query: str,
rag_generation_config: GenerationConfig,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
vector_search_settings: SearchSettings = SearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
*args,
**kwargs,
@@ -247,7 +250,7 @@ class RetrievalService(Service):
async def agent(
self,
rag_generation_config: GenerationConfig,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
vector_search_settings: SearchSettings = SearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
task_prompt_override: Optional[str] = None,
include_title_if_available: Optional[bool] = False,
@@ -422,7 +425,7 @@ class RetrievalServiceAdapter:
@staticmethod
def prepare_search_input(
query: str,
vector_search_settings: VectorSearchSettings,
vector_search_settings: SearchSettings,
kg_search_settings: KGSearchSettings,
user: UserResponse,
) -> dict:
@@ -437,7 +440,7 @@ class RetrievalServiceAdapter:
def parse_search_input(data: dict):
return {
"query": data["query"],
"vector_search_settings": VectorSearchSettings.from_dict(
"vector_search_settings": SearchSettings.from_dict(
data["vector_search_settings"]
),
"kg_search_settings": KGSearchSettings.from_dict(
@@ -449,7 +452,7 @@ class RetrievalServiceAdapter:
@staticmethod
def prepare_rag_input(
query: str,
vector_search_settings: VectorSearchSettings,
vector_search_settings: SearchSettings,
kg_search_settings: KGSearchSettings,
rag_generation_config: GenerationConfig,
task_prompt_override: Optional[str],
@@ -468,7 +471,7 @@ class RetrievalServiceAdapter:
def parse_rag_input(data: dict):
return {
"query": data["query"],
"vector_search_settings": VectorSearchSettings.from_dict(
"vector_search_settings": SearchSettings.from_dict(
data["vector_search_settings"]
),
"kg_search_settings": KGSearchSettings.from_dict(
@@ -484,7 +487,7 @@ class RetrievalServiceAdapter:
@staticmethod
def prepare_agent_input(
message: Message,
vector_search_settings: VectorSearchSettings,
vector_search_settings: SearchSettings,
kg_search_settings: KGSearchSettings,
rag_generation_config: GenerationConfig,
task_prompt_override: Optional[str],
@@ -509,7 +512,7 @@ class RetrievalServiceAdapter:
def parse_agent_input(data: dict):
return {
"message": Message.from_dict(data["message"]),
"vector_search_settings": VectorSearchSettings.from_dict(
"vector_search_settings": SearchSettings.from_dict(
data["vector_search_settings"]
),
"kg_search_settings": KGSearchSettings.from_dict(
+2 -2
View File
@@ -5,7 +5,7 @@ from typing import Any, Optional
from ..base.abstractions import (
GenerationConfig,
KGSearchSettings,
VectorSearchSettings,
SearchSettings,
)
from ..base.logger.base import RunType
from ..base.logger.run_manager import RunManager, manage_run
@@ -34,7 +34,7 @@ class RAGPipeline(AsyncPipeline):
input: Any,
state: Optional[AsyncState],
run_manager: Optional[RunManager] = None,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
vector_search_settings: SearchSettings = SearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
rag_generation_config: GenerationConfig = GenerationConfig(),
*args: Any,
+2 -2
View File
@@ -6,7 +6,7 @@ from typing import Any, Optional
from ..base.abstractions import (
AggregateSearchResult,
KGSearchSettings,
VectorSearchSettings,
SearchSettings,
)
from ..base.logger.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
@@ -35,7 +35,7 @@ class SearchPipeline(AsyncPipeline):
state: Optional[AsyncState],
stream: bool = False,
run_manager: Optional[RunManager] = None,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
vector_search_settings: SearchSettings = SearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
*args: Any,
**kwargs: Any,
+1
View File
@@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Union
from uuid import UUID
from fastapi import HTTPException
from core.base import AsyncState
+1
View File
@@ -5,6 +5,7 @@ Pipe to tune the prompt for the KG model.
import logging
from typing import Any
from uuid import UUID
from fastapi import HTTPException
from core.base import (

Some files were not shown because too many files have changed in this diff Show More