diff --git a/py/tests/unit/conftest.py b/py/tests/unit/conftest.py index a9d1b2e81..3ecae670e 100644 --- a/py/tests/unit/conftest.py +++ b/py/tests/unit/conftest.py @@ -3,6 +3,7 @@ import os import json from datetime import datetime import uuid +import yaml import pytest @@ -749,10 +750,26 @@ class MockPostgresPromptsHandler: return list(self.prompts.values()) async def load_prompts_from_yaml(self, yaml_content): - """Load prompts from a YAML string.""" - # Just a stub for testing - pass - + """Load prompts from YAML content.""" + try: + data = yaml.safe_load(yaml_content) + if not data or "prompts" not in data: + return + + for prompt_data in data["prompts"]: + name = prompt_data.get("name") + template = prompt_data.get("template") + input_types = prompt_data.get("input_types", {}) + + if name and template: + await self.add_prompt( + name=name, + template=template, + input_types=input_types, + ) + except Exception as e: + raise ValueError(f"Failed to load prompts from YAML: {str(e)}") + async def get_cached_prompt(self, prompt_name, inputs=None, bypass_cache=False): """Get a formatted prompt, using cache if available.""" if inputs is None: