feat(ml): conditionally download .armnn models (#6650)

Author: Mert
Date: 2024-01-28 10:31:59 -05:00
Committed by: GitHub
Parent: fa0913120d
Commit: a84b6f5fb1
5 changed files with 127 additions and 38 deletions


@@ -14,7 +14,7 @@ import ann.ann
 from app.models.constants import SUPPORTED_PROVIDERS
 
 from ..config import get_cache_dir, get_hf_model_name, log, settings
-from ..schemas import ModelType
+from ..schemas import ModelRuntime, ModelType
 from .ann import AnnSession
@@ -28,6 +28,7 @@ class InferenceModel(ABC):
         providers: list[str] | None = None,
         provider_options: list[dict[str, Any]] | None = None,
         sess_options: ort.SessionOptions | None = None,
+        preferred_runtime: ModelRuntime | None = None,
         **model_kwargs: Any,
     ) -> None:
         self.loaded = False
@@ -36,6 +37,7 @@ class InferenceModel(ABC):
         self.providers = providers if providers is not None else self.providers_default
         self.provider_options = provider_options if provider_options is not None else self.provider_options_default
         self.sess_options = sess_options if sess_options is not None else self.sess_options_default
+        self.preferred_runtime = preferred_runtime if preferred_runtime is not None else self.preferred_runtime_default
 
     def download(self) -> None:
         if not self.cached:
@@ -66,11 +68,13 @@ class InferenceModel(ABC):
         pass
 
     def _download(self) -> None:
+        ignore_patterns = [] if self.preferred_runtime == ModelRuntime.ARMNN else ["*.armnn"]
         snapshot_download(
             get_hf_model_name(self.model_name),
             cache_dir=self.cache_dir,
             local_dir=self.cache_dir,
             local_dir_use_symlinks=False,
+            ignore_patterns=ignore_patterns,
         )
 
     @abstractmethod
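The hunk above is the heart of the feature: .armnn weights are skipped at download time unless ARM NN is actually the preferred runtime. A minimal standalone sketch of the same idea, assuming huggingface_hub is installed; the function name, repo id, and cache path are illustrative only:

from pathlib import Path

from huggingface_hub import snapshot_download


def download_model(repo_id: str, cache_dir: Path, want_armnn: bool) -> None:
    # Mirror the ignore_patterns logic in _download: fetch everything when
    # ARM NN weights are wanted, otherwise exclude *.armnn from the snapshot.
    ignore_patterns: list[str] = [] if want_armnn else ["*.armnn"]
    snapshot_download(
        repo_id,
        cache_dir=cache_dir,
        local_dir=cache_dir,
        local_dir_use_symlinks=False,
        ignore_patterns=ignore_patterns,
    )


# Hypothetical usage: only the ONNX files land on disk for a non-ARM host.
# download_model("immich-app/ViT-B-32__openai", Path("/cache/clip"), want_armnn=False)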
@@ -100,18 +104,28 @@ class InferenceModel(ABC):
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
     def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
-        armnn_path = model_path.with_suffix(".armnn")
-        if settings.ann and ann.ann.is_available and armnn_path.is_file():
-            session = AnnSession(armnn_path)
-        elif model_path.is_file():
-            session = ort.InferenceSession(
-                model_path.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            )
-        else:
-            raise ValueError(f"the file model_path='{model_path}' does not exist")
+        if not model_path.is_file():
+            onnx_path = model_path.with_suffix(".onnx")
+            if not onnx_path.is_file():
+                raise ValueError(f"Model path '{model_path}' does not exist")
+            log.warning(
+                f"Could not find model path '{model_path}'. " f"Falling back to ONNX model path '{onnx_path}' instead.",
+            )
+            model_path = onnx_path
+
+        match model_path.suffix:
+            case ".armnn":
+                session = AnnSession(model_path)
+            case ".onnx":
+                session = ort.InferenceSession(
+                    model_path.as_posix(),
+                    sess_options=self.sess_options,
+                    providers=self.providers,
+                    provider_options=self.provider_options,
+                )
+            case _:
+                raise ValueError(f"Unsupported model file type: {model_path.suffix}")
         return session
 
     @property
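The rewritten _make_session separates two concerns: first fall back from a missing preferred file to its .onnx sibling, then dispatch on the resolved suffix. A simplified sketch of the fallback step in isolation (the function name is illustrative, not Immich's API):

from pathlib import Path


def resolve_model_path(preferred: Path) -> Path:
    # If the preferred file (e.g. model.armnn) was never downloaded,
    # fall back to the ONNX sibling that is always fetched.
    if preferred.is_file():
        return preferred
    onnx_path = preferred.with_suffix(".onnx")
    if not onnx_path.is_file():
        raise ValueError(f"Model path '{preferred}' does not exist")
    return onnx_path


# resolve_model_path(Path("/cache/clip/textual/model.armnn"))
# -> /cache/clip/textual/model.onnx when only ONNX weights were fetched.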
@@ -132,7 +146,7 @@ class InferenceModel(ABC):
     @property
     def cached(self) -> bool:
-        return self.cache_dir.exists() and any(self.cache_dir.iterdir())
+        return self.cache_dir.is_dir() and any(self.cache_dir.iterdir())
 
     @property
     def providers(self) -> list[str]:
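The cached property also tightens exists() to is_dir(): a stray regular file at the cache path would previously pass the check and then crash iterdir(). For illustration:

from pathlib import Path

p = Path("/tmp/model-cache")  # hypothetical path
p.exists()  # True for a file or a directory at this path
p.is_dir()  # True only for a directory, which is what iterdir() requires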
@@ -215,6 +229,19 @@ class InferenceModel(ABC):
         return sess_options
 
+    @property
+    def preferred_runtime(self) -> ModelRuntime:
+        return self._preferred_runtime
+
+    @preferred_runtime.setter
+    def preferred_runtime(self, preferred_runtime: ModelRuntime) -> None:
+        log.debug(f"Setting preferred runtime to {preferred_runtime}")
+        self._preferred_runtime = preferred_runtime
+
+    @property
+    def preferred_runtime_default(self) -> ModelRuntime:
+        return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
+
 
 # HF deep copies configs, so we need to make session options picklable
 class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
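The path properties in the next two files interpolate preferred_runtime straight into a filename, which only works if ModelRuntime stringifies to a file extension. A sketch of what such an enum presumably looks like (the actual definition lives in the schemas module and is not shown in this diff):

from enum import Enum


class ModelRuntime(str, Enum):
    ONNX = "onnx"
    ARMNN = "armnn"

    def __str__(self) -> str:
        # Ensure f"model.{runtime}" renders as "model.onnx" / "model.armnn"
        # regardless of Python version.
        return self.value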


@@ -81,11 +81,11 @@ class BaseCLIPEncoder(InferenceModel):
     @property
     def textual_path(self) -> Path:
-        return self.textual_dir / "model.onnx"
+        return self.textual_dir / f"model.{self.preferred_runtime}"
 
     @property
     def visual_path(self) -> Path:
-        return self.visual_dir / "model.onnx"
+        return self.visual_dir / f"model.{self.preferred_runtime}"
 
     @property
     def tokenizer_file_path(self) -> Path:

@@ -77,11 +77,11 @@ class FaceRecognizer(InferenceModel):
     @property
     def det_file(self) -> Path:
-        return self.cache_dir / "detection" / "model.onnx"
+        return self.cache_dir / "detection" / f"model.{self.preferred_runtime}"
 
     @property
     def rec_file(self) -> Path:
-        return self.cache_dir / "recognition" / "model.onnx"
+        return self.cache_dir / "recognition" / f"model.{self.preferred_runtime}"
 
     def configure(self, **model_kwargs: Any) -> None:
         self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)
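Unrelated to the runtime change but visible in the trailing context: configure uses dict.pop with a default, so an absent minScore leaves the current detection threshold untouched. The same pattern in isolation (names illustrative):

def configure(current_thresh: float, **kwargs: float) -> float:
    # Return the override if provided, otherwise keep the existing threshold.
    return kwargs.pop("minScore", current_thresh)


assert configure(0.7) == 0.7
assert configure(0.7, minScore=0.5) == 0.5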