555 lines
19 KiB
Python
555 lines
19 KiB
Python
import json
|
|
import uuid
|
|
from typing import Any, Generator, Optional
|
|
|
|
from shared.api.models import (
|
|
WrappedAgentResponse,
|
|
WrappedEmbeddingResponse,
|
|
WrappedLLMChatCompletion,
|
|
WrappedRAGResponse,
|
|
WrappedSearchResponse,
|
|
)
|
|
|
|
from ..models import (
|
|
AgentEvent,
|
|
CitationData,
|
|
CitationEvent,
|
|
Delta,
|
|
DeltaPayload,
|
|
FinalAnswerData,
|
|
FinalAnswerEvent,
|
|
GenerationConfig,
|
|
Message,
|
|
MessageData,
|
|
MessageDelta,
|
|
MessageEvent,
|
|
SearchMode,
|
|
SearchResultsData,
|
|
SearchResultsEvent,
|
|
SearchSettings,
|
|
ThinkingData,
|
|
ThinkingEvent,
|
|
ToolCallData,
|
|
ToolCallEvent,
|
|
ToolResultData,
|
|
ToolResultEvent,
|
|
UnknownEvent,
|
|
)
|
|
|
|
|
|
def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]:
|
|
"""
|
|
Convert a raw SSE event dict into a typed Pydantic model.
|
|
|
|
Example raw dict:
|
|
{
|
|
"event": "message",
|
|
"data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}"
|
|
}
|
|
"""
|
|
event_type = raw.get("event", "unknown")
|
|
|
|
# If event_type == "done", we usually return None to signal the SSE stream is finished.
|
|
if event_type == "done":
|
|
return None
|
|
|
|
# The SSE "data" is JSON-encoded, so parse it
|
|
data_str = raw.get("data", "")
|
|
try:
|
|
data_obj = json.loads(data_str)
|
|
except json.JSONDecodeError as e:
|
|
# You can decide whether to raise or return UnknownEvent
|
|
raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e
|
|
|
|
# Now branch on event_type to build the right Pydantic model
|
|
if event_type == "search_results":
|
|
return SearchResultsEvent(
|
|
event=event_type,
|
|
data=SearchResultsData(**data_obj),
|
|
)
|
|
elif event_type == "message":
|
|
# Parse nested delta structure manually before creating MessageData
|
|
if "delta" in data_obj and isinstance(data_obj["delta"], dict):
|
|
delta_dict = data_obj["delta"]
|
|
|
|
# Convert content items to MessageDelta objects
|
|
if "content" in delta_dict and isinstance(
|
|
delta_dict["content"], list
|
|
):
|
|
parsed_content = []
|
|
for item in delta_dict["content"]:
|
|
if isinstance(item, dict):
|
|
# Parse payload to DeltaPayload
|
|
if "payload" in item and isinstance(
|
|
item["payload"], dict
|
|
):
|
|
payload_dict = item["payload"]
|
|
item["payload"] = DeltaPayload(**payload_dict)
|
|
parsed_content.append(MessageDelta(**item))
|
|
|
|
# Replace with parsed content
|
|
delta_dict["content"] = parsed_content
|
|
|
|
# Create properly typed Delta object
|
|
data_obj["delta"] = Delta(**delta_dict)
|
|
|
|
return MessageEvent(
|
|
event=event_type,
|
|
data=MessageData(**data_obj),
|
|
)
|
|
elif event_type == "citation":
|
|
return CitationEvent(event=event_type, data=CitationData(**data_obj))
|
|
elif event_type == "tool_call":
|
|
return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj))
|
|
elif event_type == "tool_result":
|
|
return ToolResultEvent(
|
|
event=event_type, data=ToolResultData(**data_obj)
|
|
)
|
|
elif event_type == "thinking":
|
|
# Parse nested delta structure manually before creating ThinkingData
|
|
if "delta" in data_obj and isinstance(data_obj["delta"], dict):
|
|
delta_dict = data_obj["delta"]
|
|
|
|
# Convert content items to MessageDelta objects
|
|
if "content" in delta_dict and isinstance(
|
|
delta_dict["content"], list
|
|
):
|
|
parsed_content = []
|
|
for item in delta_dict["content"]:
|
|
if isinstance(item, dict):
|
|
# Parse payload to DeltaPayload
|
|
if "payload" in item and isinstance(
|
|
item["payload"], dict
|
|
):
|
|
payload_dict = item["payload"]
|
|
item["payload"] = DeltaPayload(**payload_dict)
|
|
parsed_content.append(MessageDelta(**item))
|
|
|
|
# Replace with parsed content
|
|
delta_dict["content"] = parsed_content
|
|
|
|
# Create properly typed Delta object
|
|
data_obj["delta"] = Delta(**delta_dict)
|
|
|
|
return ThinkingEvent(
|
|
event=event_type,
|
|
data=ThinkingData(**data_obj),
|
|
)
|
|
elif event_type == "final_answer":
|
|
return FinalAnswerEvent(
|
|
event=event_type, data=FinalAnswerData(**data_obj)
|
|
)
|
|
else:
|
|
# Fallback if it doesn't match any known event
|
|
return UnknownEvent(
|
|
event=event_type,
|
|
data=data_obj,
|
|
)
|
|
|
|
|
|
def search_arg_parser(
|
|
query: str,
|
|
search_mode: Optional[str | SearchMode] = "custom",
|
|
search_settings: Optional[dict | SearchSettings] = None,
|
|
) -> dict:
|
|
if search_mode and not isinstance(search_mode, str):
|
|
search_mode = search_mode.value
|
|
|
|
if search_settings and not isinstance(search_settings, dict):
|
|
search_settings = search_settings.model_dump()
|
|
|
|
data: dict[str, Any] = {
|
|
"query": query,
|
|
"search_settings": search_settings,
|
|
}
|
|
if search_mode:
|
|
data["search_mode"] = search_mode
|
|
|
|
return data
|
|
|
|
|
|
def completion_arg_parser(
|
|
messages: list[dict | Message],
|
|
generation_config: Optional[dict | GenerationConfig] = None,
|
|
) -> dict:
|
|
# FIXME: Needs a proper return type
|
|
cast_messages: list[Message] = [
|
|
Message(**msg) if isinstance(msg, dict) else msg for msg in messages
|
|
]
|
|
|
|
if generation_config and not isinstance(generation_config, dict):
|
|
generation_config = generation_config.model_dump()
|
|
|
|
data: dict[str, Any] = {
|
|
"messages": [msg.model_dump() for msg in cast_messages],
|
|
"generation_config": generation_config,
|
|
}
|
|
return data
|
|
|
|
|
|
def embedding_arg_parser(
|
|
text: str,
|
|
) -> dict:
|
|
data: dict[str, Any] = {
|
|
"text": text,
|
|
}
|
|
return data
|
|
|
|
|
|
def rag_arg_parser(
|
|
query: str,
|
|
rag_generation_config: Optional[dict | GenerationConfig] = None,
|
|
search_mode: Optional[str | SearchMode] = "custom",
|
|
search_settings: Optional[dict | SearchSettings] = None,
|
|
task_prompt: Optional[str] = None,
|
|
include_title_if_available: Optional[bool] = False,
|
|
include_web_search: Optional[bool] = False,
|
|
) -> dict:
|
|
if rag_generation_config and not isinstance(rag_generation_config, dict):
|
|
rag_generation_config = rag_generation_config.model_dump()
|
|
if search_settings and not isinstance(search_settings, dict):
|
|
search_settings = search_settings.model_dump()
|
|
|
|
data: dict[str, Any] = {
|
|
"query": query,
|
|
"rag_generation_config": rag_generation_config,
|
|
"search_settings": search_settings,
|
|
"task_prompt": task_prompt,
|
|
"include_title_if_available": include_title_if_available,
|
|
"include_web_search": include_web_search,
|
|
}
|
|
if search_mode:
|
|
data["search_mode"] = search_mode
|
|
return data
|
|
|
|
|
|
def agent_arg_parser(
|
|
message: Optional[dict | Message] = None,
|
|
rag_generation_config: Optional[dict | GenerationConfig] = None,
|
|
research_generation_config: Optional[dict | GenerationConfig] = None,
|
|
search_mode: Optional[str | SearchMode] = "custom",
|
|
search_settings: Optional[dict | SearchSettings] = None,
|
|
task_prompt: Optional[str] = None,
|
|
include_title_if_available: Optional[bool] = True,
|
|
conversation_id: Optional[str | uuid.UUID] = None,
|
|
max_tool_context_length: Optional[int] = None,
|
|
use_system_context: Optional[bool] = True,
|
|
rag_tools: Optional[list[str]] = None,
|
|
research_tools: Optional[list[str]] = None,
|
|
tools: Optional[list[str]] = None, # For backward compatibility
|
|
mode: Optional[str] = "rag",
|
|
needs_initial_conversation_name: Optional[bool] = None,
|
|
) -> dict:
|
|
if rag_generation_config and not isinstance(rag_generation_config, dict):
|
|
rag_generation_config = rag_generation_config.model_dump()
|
|
if research_generation_config and not isinstance(
|
|
research_generation_config, dict
|
|
):
|
|
research_generation_config = research_generation_config.model_dump()
|
|
if search_settings and not isinstance(search_settings, dict):
|
|
search_settings = search_settings.model_dump()
|
|
|
|
data: dict[str, Any] = {
|
|
"rag_generation_config": rag_generation_config or {},
|
|
"search_settings": search_settings,
|
|
"task_prompt": task_prompt,
|
|
"include_title_if_available": include_title_if_available,
|
|
"conversation_id": (str(conversation_id) if conversation_id else None),
|
|
"max_tool_context_length": max_tool_context_length,
|
|
"use_system_context": use_system_context,
|
|
"mode": mode,
|
|
}
|
|
|
|
# Handle generation configs based on mode
|
|
if research_generation_config and mode == "research":
|
|
data["research_generation_config"] = research_generation_config
|
|
|
|
# Handle tool configurations
|
|
if rag_tools:
|
|
data["rag_tools"] = rag_tools
|
|
if research_tools:
|
|
data["research_tools"] = research_tools
|
|
if tools: # Backward compatibility
|
|
data["tools"] = tools
|
|
|
|
if search_mode:
|
|
data["search_mode"] = search_mode
|
|
|
|
if needs_initial_conversation_name:
|
|
data["needs_initial_conversation_name"] = (
|
|
needs_initial_conversation_name
|
|
)
|
|
|
|
if message:
|
|
cast_message: Message = (
|
|
Message(**message) if isinstance(message, dict) else message
|
|
)
|
|
data["message"] = cast_message.model_dump()
|
|
return data
|
|
|
|
|
|
class RetrievalSDK:
|
|
"""SDK for interacting with documents in the v3 API."""
|
|
|
|
def __init__(self, client):
|
|
self.client = client
|
|
|
|
def search(
|
|
self,
|
|
query: str,
|
|
search_mode: Optional[str | SearchMode] = "custom",
|
|
search_settings: Optional[dict | SearchSettings] = None,
|
|
) -> WrappedSearchResponse:
|
|
"""Conduct a vector and/or graph search.
|
|
|
|
Args:
|
|
query (str): The query to search for.
|
|
search_settings (Optional[dict, SearchSettings]]): Vector search settings.
|
|
|
|
Returns:
|
|
WrappedSearchResponse
|
|
"""
|
|
|
|
response_dict = self.client._make_request(
|
|
"POST",
|
|
"retrieval/search",
|
|
json=search_arg_parser(
|
|
query=query,
|
|
search_mode=search_mode,
|
|
search_settings=search_settings,
|
|
),
|
|
version="v3",
|
|
)
|
|
|
|
return WrappedSearchResponse(**response_dict)
|
|
|
|
def completion(
|
|
self,
|
|
messages: list[dict | Message],
|
|
generation_config: Optional[dict | GenerationConfig] = None,
|
|
) -> WrappedLLMChatCompletion:
|
|
cast_messages: list[Message] = [
|
|
Message(**msg) if isinstance(msg, dict) else msg
|
|
for msg in messages
|
|
]
|
|
|
|
if generation_config and not isinstance(generation_config, dict):
|
|
generation_config = generation_config.model_dump()
|
|
|
|
data: dict[str, Any] = {
|
|
"messages": [msg.model_dump() for msg in cast_messages],
|
|
"generation_config": generation_config,
|
|
}
|
|
response_dict = self.client._make_request(
|
|
"POST",
|
|
"retrieval/completion",
|
|
json=completion_arg_parser(messages, generation_config),
|
|
version="v3",
|
|
)
|
|
|
|
return WrappedLLMChatCompletion(**response_dict)
|
|
|
|
def embedding(
|
|
self,
|
|
text: str,
|
|
) -> WrappedEmbeddingResponse:
|
|
response_dict = self.client._make_request(
|
|
"POST",
|
|
"retrieval/embedding",
|
|
data=embedding_arg_parser(text),
|
|
version="v3",
|
|
)
|
|
|
|
return WrappedEmbeddingResponse(**response_dict)
|
|
|
|
def rag(
|
|
self,
|
|
query: str,
|
|
rag_generation_config: Optional[dict | GenerationConfig] = None,
|
|
search_mode: Optional[str | SearchMode] = "custom",
|
|
search_settings: Optional[dict | SearchSettings] = None,
|
|
task_prompt: Optional[str] = None,
|
|
include_title_if_available: Optional[bool] = False,
|
|
include_web_search: Optional[bool] = False,
|
|
) -> (
|
|
WrappedRAGResponse
|
|
| Generator[
|
|
ThinkingEvent
|
|
| SearchResultsEvent
|
|
| MessageEvent
|
|
| CitationEvent
|
|
| FinalAnswerEvent
|
|
| ToolCallEvent
|
|
| ToolResultEvent
|
|
| UnknownEvent
|
|
| None,
|
|
None,
|
|
None,
|
|
]
|
|
):
|
|
"""Conducts a Retrieval Augmented Generation (RAG) search with the
|
|
given query.
|
|
|
|
Args:
|
|
query (str): The query to search for.
|
|
rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
|
|
search_settings (Optional[dict | SearchSettings]): Vector search settings.
|
|
task_prompt (Optional[str]): Task prompt override.
|
|
include_title_if_available (Optional[bool]): Include the title if available.
|
|
|
|
Returns:
|
|
WrappedRAGResponse | AsyncGenerator[RAGResponse, None]: The RAG response
|
|
"""
|
|
data = rag_arg_parser(
|
|
query=query,
|
|
rag_generation_config=rag_generation_config,
|
|
search_mode=search_mode,
|
|
search_settings=search_settings,
|
|
task_prompt=task_prompt,
|
|
include_title_if_available=include_title_if_available,
|
|
include_web_search=include_web_search,
|
|
)
|
|
rag_generation_config = data.get("rag_generation_config")
|
|
if rag_generation_config and rag_generation_config.get( # type: ignore
|
|
"stream", False
|
|
):
|
|
raw_stream = self.client._make_streaming_request(
|
|
"POST",
|
|
"retrieval/rag",
|
|
json=data,
|
|
version="v3",
|
|
)
|
|
# Wrap the raw stream to parse each event
|
|
return (parse_retrieval_event(event) for event in raw_stream)
|
|
|
|
response_dict = self.client._make_request(
|
|
"POST",
|
|
"retrieval/rag",
|
|
json=data,
|
|
version="v3",
|
|
)
|
|
|
|
return WrappedRAGResponse(**response_dict)
|
|
|
|
def agent(
|
|
self,
|
|
message: Optional[dict | Message] = None,
|
|
rag_generation_config: Optional[dict | GenerationConfig] = None,
|
|
research_generation_config: Optional[dict | GenerationConfig] = None,
|
|
search_mode: Optional[str | SearchMode] = "custom",
|
|
search_settings: Optional[dict | SearchSettings] = None,
|
|
task_prompt: Optional[str] = None,
|
|
include_title_if_available: Optional[bool] = True,
|
|
conversation_id: Optional[str | uuid.UUID] = None,
|
|
max_tool_context_length: Optional[int] = None,
|
|
use_system_context: Optional[bool] = True,
|
|
# Tool configurations
|
|
rag_tools: Optional[list[str]] = None,
|
|
research_tools: Optional[list[str]] = None,
|
|
tools: Optional[list[str]] = None, # For backward compatibility
|
|
mode: Optional[str] = "rag",
|
|
needs_initial_conversation_name: Optional[bool] = None,
|
|
) -> (
|
|
WrappedAgentResponse
|
|
| Generator[
|
|
ThinkingEvent
|
|
| SearchResultsEvent
|
|
| MessageEvent
|
|
| CitationEvent
|
|
| FinalAnswerEvent
|
|
| ToolCallEvent
|
|
| ToolResultEvent
|
|
| UnknownEvent
|
|
| None,
|
|
None,
|
|
None,
|
|
]
|
|
):
|
|
"""Performs a single turn in a conversation with a RAG agent.
|
|
|
|
Args:
|
|
message (Optional[dict | Message]): The message to send to the agent.
|
|
rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
|
|
research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
|
|
search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
|
|
search_settings (Optional[dict | SearchSettings]): Vector search settings.
|
|
task_prompt (Optional[str]): Task prompt override.
|
|
include_title_if_available (Optional[bool]): Include the title if available.
|
|
conversation_id (Optional[str | uuid.UUID]): ID of the conversation for maintaining context.
|
|
max_tool_context_length (Optional[int]): Maximum context length for tool replies.
|
|
use_system_context (Optional[bool]): Whether to use system context in the prompt.
|
|
rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
|
|
Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions".
|
|
research_tools (Optional[list[str]]): List of tools to enable for Research mode.
|
|
Available tools: "rag", "reasoning", "critique", "python_executor".
|
|
tools (Optional[list[str]]): Deprecated. List of tools to execute.
|
|
mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis.
|
|
Defaults to "rag".
|
|
|
|
Returns:
|
|
WrappedAgentResponse | AsyncGenerator[AgentEvent, None]: The agent response.
|
|
"""
|
|
data = agent_arg_parser(
|
|
message=message,
|
|
rag_generation_config=rag_generation_config,
|
|
research_generation_config=research_generation_config,
|
|
search_mode=search_mode,
|
|
search_settings=search_settings,
|
|
task_prompt=task_prompt,
|
|
include_title_if_available=include_title_if_available,
|
|
conversation_id=conversation_id,
|
|
max_tool_context_length=max_tool_context_length,
|
|
use_system_context=use_system_context,
|
|
rag_tools=rag_tools,
|
|
research_tools=research_tools,
|
|
tools=tools,
|
|
mode=mode,
|
|
needs_initial_conversation_name=needs_initial_conversation_name,
|
|
)
|
|
|
|
# Determine if streaming is enabled
|
|
if search_mode:
|
|
data["search_mode"] = search_mode
|
|
|
|
if message:
|
|
cast_message: Message = (
|
|
Message(**message) if isinstance(message, dict) else message
|
|
)
|
|
data["message"] = cast_message.model_dump()
|
|
|
|
is_stream = False
|
|
if mode != "research":
|
|
if rag_generation_config:
|
|
if isinstance(rag_generation_config, dict):
|
|
is_stream = rag_generation_config.get( # type: ignore
|
|
"stream", False
|
|
)
|
|
else:
|
|
is_stream = rag_generation_config.stream
|
|
else:
|
|
if research_generation_config:
|
|
if isinstance(research_generation_config, dict):
|
|
is_stream = research_generation_config.get( # type: ignore
|
|
"stream", False
|
|
)
|
|
else:
|
|
is_stream = research_generation_config.stream
|
|
|
|
if is_stream:
|
|
raw_stream = self.client._make_streaming_request(
|
|
"POST",
|
|
"retrieval/agent",
|
|
json=data,
|
|
version="v3",
|
|
)
|
|
return (parse_retrieval_event(event) for event in raw_stream)
|
|
|
|
response_dict = self.client._make_request(
|
|
"POST",
|
|
"retrieval/agent",
|
|
json=data,
|
|
version="v3",
|
|
)
|
|
|
|
return WrappedAgentResponse(**response_dict)
|