refactor(ml): model downloading (#3545)

* download facial recognition models

* download hf models

* simplified logic

* updated `predict` for facial recognition

* ensure download method is called

* fixed repo_id for clip

* fixed download destination

* use st's own `snapshot_download`

* conditional download

* fixed predict method

* check if loaded

* minor fixes

* updated mypy overrides

* added pytest-mock

* updated tests

* updated lock
This commit is contained in:
Mert
2023-08-05 22:45:13 -04:00
committed by GitHub
parent 2f26a7edae
commit c73832bd9c
10 changed files with 350 additions and 274 deletions

View File

@ -1,6 +1,7 @@
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from PIL.Image import Image
from transformers.pipelines import pipeline
@ -22,14 +23,19 @@ class ImageClassifier(InferenceModel):
self.min_score = min_score
super().__init__(model_name, cache_dir, **model_kwargs)
def load(self, **model_kwargs: Any) -> None:
def _download(self, **model_kwargs: Any) -> None:
snapshot_download(
cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
)
def _load(self, **model_kwargs: Any) -> None:
self.model = pipeline(
self.model_type.value,
self.model_name,
model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
)
def predict(self, image: Image) -> list[str]:
def _predict(self, image: Image) -> list[str]:
predictions: list[dict[str, Any]] = self.model(image) # type: ignore
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]