refactor(ml): model sessions (#10559)

Mert
2024-06-25 12:00:24 -04:00
committed by GitHub
parent 6538ad8de7
commit 6356c28f64
11 changed files with 529 additions and 375 deletions

View File

@@ -52,8 +52,6 @@ class Ann(metaclass=_Singleton):
def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
if not is_available:
raise RuntimeError("libann is not available!")
if tuning_file and not exists(tuning_file):
raise ValueError("tuning_file must point to an existing (possibly empty) file!")
if tuning_level == 0 and tuning_file is None:
raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
if tuning_level < 0 or tuning_level > 3:
@@ -67,6 +65,12 @@ class Ann(metaclass=_Singleton):
self.input_shapes: dict[int, tuple[tuple[int], ...]] = {}
self.ann: int | None = None
self.new()
if self.tuning_file is not None:
# make sure tuning file exists (without clearing contents)
# once filled, the tuning file reduces the cost/time of the first
# inference after model load by 10s of seconds
open(self.tuning_file, "a").close()
def new(self) -> None:
if self.ann is None:
@@ -95,17 +99,19 @@ class Ann(metaclass=_Singleton):
model_path: str,
fast_math: bool = True,
fp16: bool = False,
save_cached_network: bool = False,
cached_network_path: str | None = None,
) -> int:
if not model_path.endswith((".armnn", ".tflite", ".onnx")):
raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
if not exists(model_path):
raise ValueError("model_path must point to an existing file!")
save_cached_network = False
if cached_network_path is not None and not exists(cached_network_path):
raise ValueError("cached_network_path must point to an existing (possibly empty) file!")
if save_cached_network and cached_network_path is None:
raise ValueError("save_cached_network is True, cached_network_path must be specified!")
save_cached_network = True
# create empty model cache file
open(cached_network_path, "a").close()
net_id: int = libann.load(
self.ann,
model_path.encode(),
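
Taken together, the reworked flow means callers no longer pre-create the tuning or cache files. A minimal usage sketch, assuming an ARM device where libann is available; the paths are illustrative:

    from ann.ann import Ann

    # __init__ now touches the tuning file itself when one is configured
    ann = Ann(tuning_level=3, tuning_file="/cache/gpu-tuning.ann")
    # load() ensures the cache file exists and saves the compiled network
    # when that file was missing
    net_id = ann.load("/cache/model.armnn", cached_network_path="/cache/model.anncache")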

View File

@@ -8,6 +8,8 @@ from fastapi.testclient import TestClient
from numpy.typing import NDArray
from PIL import Image
from app.config import log
from .main import app
@@ -96,12 +98,77 @@ def clip_tokenizer_cfg() -> dict[str, Any]:
@pytest.fixture(scope="function")
def providers(request: pytest.FixtureRequest) -> Iterator[dict[str, Any]]:
def providers(request: pytest.FixtureRequest) -> Iterator[mock.Mock]:
marker = request.node.get_closest_marker("providers")
if marker is None:
raise ValueError("Missing marker 'providers'")
providers = marker.args[0]
with mock.patch("app.models.base.ort.get_available_providers") as mocked:
with mock.patch("app.sessions.ort.ort.get_available_providers") as mocked:
mocked.return_value = providers
yield providers
@pytest.fixture(scope="function")
def ort_pybind() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ort.ort.capi._pybind_state") as mocked:
yield mocked
@pytest.fixture(scope="function")
def ov_device_ids(request: pytest.FixtureRequest, ort_pybind: mock.Mock) -> Iterator[mock.Mock]:
marker = request.node.get_closest_marker("ov_device_ids")
if marker is None:
raise ValueError("Missing marker 'ov_device_ids'")
ort_pybind.get_available_openvino_device_ids.return_value = marker.args[0]
return ort_pybind
@pytest.fixture(scope="function")
def ort_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ort.ort.InferenceSession") as mocked:
yield mocked
@pytest.fixture(scope="function")
def ann_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.ann.Ann") as mocked:
yield mocked
@pytest.fixture(scope="function")
def rmtree() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.rmtree", autospec=True) as mocked:
mocked.avoids_symlink_attacks = True
yield mocked
@pytest.fixture(scope="function")
def path() -> Iterator[mock.Mock]:
path = mock.MagicMock()
path.exists.return_value = True
path.is_dir.return_value = True
path.is_file.return_value = True
path.with_suffix.return_value = path
path.return_value = path
with mock.patch("app.models.base.Path", return_value=path) as mocked:
yield mocked
@pytest.fixture(scope="function")
def info() -> Iterator[mock.Mock]:
with mock.patch.object(log, "info") as mocked:
yield mocked
@pytest.fixture(scope="function")
def warning() -> Iterator[mock.Mock]:
with mock.patch.object(log, "warning") as mocked:
yield mocked
@pytest.fixture(scope="function")
def snapshot_download() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.snapshot_download") as mocked:
yield mocked
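
A hypothetical test showing how the new fixtures and markers compose; the provider list, device ids, and assertion are illustrative, not taken from the commit:

    import pytest

    from app.sessions.ort import OrtSession

    @pytest.mark.providers(["OpenVINOExecutionProvider", "CPUExecutionProvider"])
    @pytest.mark.ov_device_ids(["GPU.0", "CPU"])
    def test_keeps_openvino_when_gpu_present(providers, ov_device_ids, ort_session) -> None:
        # get_available_providers, the OpenVINO device query and InferenceSession
        # are all mocked by the fixtures above, so no real runtime is needed
        session = OrtSession("model.onnx")
        assert "OpenVINOExecutionProvider" in session.providers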

View File

@@ -5,15 +5,14 @@ from pathlib import Path
from shutil import rmtree
from typing import Any, ClassVar
import onnxruntime as ort
from huggingface_hub import snapshot_download
import ann.ann
from app.models.constants import SUPPORTED_PROVIDERS
from app.sessions.ort import OrtSession
from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
from .ann import AnnSession
from ..sessions.ann import AnnSession
class InferenceModel(ABC):
@@ -24,20 +23,17 @@ class InferenceModel(ABC):
self,
model_name: str,
cache_dir: Path | str | None = None,
providers: list[str] | None = None,
provider_options: list[dict[str, Any]] | None = None,
sess_options: ort.SessionOptions | None = None,
preferred_format: ModelFormat | None = None,
session: ModelSession | None = None,
**model_kwargs: Any,
) -> None:
self.loaded = False
self.loaded = session is not None
self.load_attempts = 0
self.model_name = clean_name(model_name)
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
self.providers = providers if providers is not None else self.providers_default
self.provider_options = provider_options if provider_options is not None else self.provider_options_default
self.sess_options = sess_options if sess_options is not None else self.sess_options_default
self.preferred_format = preferred_format if preferred_format is not None else self.preferred_format_default
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
self.model_format = preferred_format if preferred_format is not None else self._model_format_default
if session is not None:
self.session = session
def download(self) -> None:
if not self.cached:
@@ -70,7 +66,7 @@ class InferenceModel(ABC):
pass
def _download(self) -> None:
ignore_patterns = [] if self.preferred_format == ModelFormat.ARMNN else ["*.armnn"]
ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"]
snapshot_download(
f"immich-app/{clean_name(self.model_name)}",
cache_dir=self.cache_dir,
@@ -105,26 +101,11 @@ class InferenceModel(ABC):
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _make_session(self, model_path: Path) -> ModelSession:
if not model_path.is_file():
onnx_path = model_path.with_suffix(".onnx")
if not onnx_path.is_file():
raise ValueError(f"Model path '{model_path}' does not exist")
log.warning(
f"Could not find model path '{model_path}'. " f"Falling back to ONNX model path '{onnx_path}' instead.",
)
model_path = onnx_path
match model_path.suffix:
case ".armnn":
session = AnnSession(model_path)
session: ModelSession = AnnSession(model_path)
case ".onnx":
session = ort.InferenceSession(
model_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
session = OrtSession(model_path)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
@@ -135,7 +116,7 @@ class InferenceModel(ABC):
@property
def model_path(self) -> Path:
return self.model_dir / f"model.{self.preferred_format}"
return self.model_dir / f"model.{self.model_format}"
@property
def model_task(self) -> ModelTask:
@@ -154,7 +135,7 @@ class InferenceModel(ABC):
self._cache_dir = cache_dir
@property
def cache_dir_default(self) -> Path:
def _cache_dir_default(self) -> Path:
return settings.cache_folder / self.model_task.value / self.model_name
@property
@@ -162,95 +143,18 @@ class InferenceModel(ABC):
return self.model_path.is_file()
@property
def providers(self) -> list[str]:
return self._providers
@providers.setter
def providers(self, providers: list[str]) -> None:
log.info(
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
)
self._providers = providers
@property
def providers_default(self) -> list[str]:
available_providers = set(ort.get_available_providers())
log.debug(f"Available ORT providers: {available_providers}")
if (openvino := "OpenVINOExecutionProvider") in available_providers:
device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
log.debug(f"Available OpenVINO devices: {device_ids}")
gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
if not gpu_devices:
log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
available_providers.remove(openvino)
return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]
@property
def provider_options(self) -> list[dict[str, Any]]:
return self._provider_options
@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options
@property
def provider_options_default(self) -> list[dict[str, Any]]:
options = []
for provider in self.providers:
match provider:
case "CPUExecutionProvider" | "CUDAExecutionProvider":
option = {"arena_extend_strategy": "kSameAsRequested"}
case "OpenVINOExecutionProvider":
option = {"device_type": "GPU_FP32", "cache_dir": (self.cache_dir / "openvino").as_posix()}
case _:
option = {}
options.append(option)
return options
@property
def sess_options(self) -> ort.SessionOptions:
return self._sess_options
@sess_options.setter
def sess_options(self, sess_options: ort.SessionOptions) -> None:
log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
self._sess_options = sess_options
@property
def sess_options_default(self) -> ort.SessionOptions:
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
# avoid thread contention between models
if settings.model_inter_op_threads > 0:
sess_options.inter_op_num_threads = settings.model_inter_op_threads
# these defaults work well for CPU, but bottleneck GPU
elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.inter_op_num_threads = 1
if settings.model_intra_op_threads > 0:
sess_options.intra_op_num_threads = settings.model_intra_op_threads
elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.intra_op_num_threads = 2
if sess_options.inter_op_num_threads > 1:
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
return sess_options
@property
def preferred_format(self) -> ModelFormat:
def model_format(self) -> ModelFormat:
return self._preferred_format
@preferred_format.setter
def preferred_format(self, preferred_format: ModelFormat) -> None:
@model_format.setter
def model_format(self, preferred_format: ModelFormat) -> None:
log.debug(f"Setting preferred format to {preferred_format}")
self._preferred_format = preferred_format
@property
def preferred_format_default(self) -> ModelFormat:
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
def _model_format_default(self) -> ModelFormat:
prefer_ann = ann.ann.is_available and settings.ann
ann_exists = (self.model_dir / "model.armnn").is_file()
if prefer_ann and not ann_exists:
log.warning(f"ARM NN is available, but '{self.model_name}' does not support ARM NN. Falling back to ONNX.")
return ModelFormat.ARMNN if prefer_ann and ann_exists else ModelFormat.ONNX
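
The new session parameter lets callers inject a preconstructed session instead of loading one from disk, which is what the conftest fixtures above rely on. A sketch; the import path and model name are assumptions, and the mock stands in for anything satisfying the ModelSession protocol:

    from unittest import mock

    from app.models.facial_recognition import FaceRecognizer  # assumed module path

    model = FaceRecognizer("buffalo_l", session=mock.Mock())
    assert model.loaded  # loaded is now derived from whether a session was supplied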

View File

@@ -3,7 +3,6 @@ from typing import Any
import numpy as np
import onnx
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX
from insightface.utils.face_align import norm_crop
from numpy.typing import NDArray
@@ -13,7 +12,8 @@ from PIL import Image
from app.config import clean_name, log
from app.models.base import InferenceModel
from app.models.transforms import decode_cv2
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelSession, ModelTask, ModelType
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
from app.sessions import has_batch_axis
class FaceRecognizer(InferenceModel):
@@ -27,13 +27,14 @@ class FaceRecognizer(InferenceModel):
cache_dir: Path | str | None = None,
**model_kwargs: Any,
) -> None:
self.min_score = model_kwargs.pop("minScore", min_score)
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
self.min_score = model_kwargs.pop("minScore", min_score)
self.batch = self.model_format == ModelFormat.ONNX
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if not self._has_batch_dim(session):
self._add_batch_dim(self.model_path)
if self.model_format == ModelFormat.ONNX and not has_batch_axis(session):
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(
self.model_path.with_suffix(".onnx").as_posix(),
@@ -47,9 +48,20 @@ class FaceRecognizer(InferenceModel):
if faces["boxes"].shape[0] == 0:
return []
inputs = decode_cv2(inputs)
embeddings: NDArray[np.float32] = self.model.get_feat(self._crop(inputs, faces))
cropped_faces = self._crop(inputs, faces)
embeddings = self._predict_batch(cropped_faces) if self.batch else self._predict_single(cropped_faces)
return self.postprocess(faces, embeddings)
def _predict_batch(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
return embeddings
def _predict_single(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
embeddings: list[NDArray[np.float32]] = []
for face in cropped_faces:
embeddings.append(self.model.get_feat(face))
return np.concatenate(embeddings, axis=0)
def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
return [
{
@@ -63,11 +75,8 @@ class FaceRecognizer(InferenceModel):
def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
return [norm_crop(image, landmark) for landmark in faces["landmarks"]]
def _has_batch_dim(self, session: ort.InferenceSession) -> bool:
return not isinstance(session, ort.InferenceSession) or session.get_inputs()[0].shape[0] == "batch"
def _add_batch_dim(self, model_path: Path) -> None:
log.debug(f"Adding batch dimension to model {model_path}")
def _add_batch_axis(self, model_path: Path) -> None:
log.debug(f"Adding batch axis to model {model_path}")
proto = onnx.load(model_path)
static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
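
For context, rewriting a fixed leading dimension into a symbolic one boils down to editing the graph's value infos; a sketch using the standard onnx API, not the commit's exact _add_batch_axis body:

    import onnx

    proto = onnx.load("model.onnx")
    for value_info in (*proto.graph.input, *proto.graph.output):
        # replace the fixed leading dimension with a named, dynamic axis
        value_info.type.tensor_type.shape.dim[0].dim_param = "batch"
    onnx.save(proto, "model.onnx")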

View File

@@ -54,6 +54,14 @@ class ModelSource(StrEnum):
ModelIdentity = tuple[ModelType, ModelTask]
class SessionNode(Protocol):
@property
def name(self) -> str | None: ...
@property
def shape(self) -> tuple[int, ...]: ...
class ModelSession(Protocol):
def run(
self,
@@ -62,6 +70,10 @@ class ModelSession(Protocol):
run_options: Any = None,
) -> list[npt.NDArray[np.float32]]: ...
def get_inputs(self) -> list[SessionNode]: ...
def get_outputs(self) -> list[SessionNode]: ...
class HasProfiling(Protocol):
profiling: dict[str, float]
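
Since ModelSession is a Protocol, OrtSession and AnnSession satisfy it structurally without inheriting from anything. A minimal illustration with a hypothetical stub; the run signature is assumed to mirror the one above:

    from typing import Any

    import numpy as np
    from numpy import typing as npt

    from app.schemas import ModelSession, SessionNode

    class NoopSession:
        def run(
            self,
            output_names: list[str] | None,
            input_feed: dict[str, npt.NDArray[np.float32]] | dict[str, npt.NDArray[np.int32]],
            run_options: Any = None,
        ) -> list[npt.NDArray[np.float32]]:
            return []

        def get_inputs(self) -> list[SessionNode]:
            return []

        def get_outputs(self) -> list[SessionNode]:
            return []

    session: ModelSession = NoopSession()  # type-checks with no explicit inheritance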

View File

@@ -0,0 +1,5 @@
from app.schemas import ModelSession
def has_batch_axis(session: ModelSession) -> bool:
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0
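
ORT reports dynamic ONNX axes as strings (e.g. "batch"), while fixed shapes are plain ints, so any non-int or negative leading dimension counts as a batch axis. Illustrated with hypothetical stand-in nodes:

    from types import SimpleNamespace

    dynamic = SimpleNamespace(get_inputs=lambda: [SimpleNamespace(shape=("batch", 3, 112, 112))])
    fixed = SimpleNamespace(get_inputs=lambda: [SimpleNamespace(shape=(1, 3, 112, 112))])
    assert has_batch_axis(dynamic)  # symbolic leading axis
    assert not has_batch_axis(fixed)  # fixed leading axis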

View File

@@ -7,6 +7,7 @@ import numpy as np
from numpy.typing import NDArray
from ann.ann import Ann
from app.schemas import SessionNode
from ..config import log, settings
@@ -16,27 +17,15 @@ class AnnSession:
Wrapper for ANN to be drop-in replacement for ONNX session.
"""
def __init__(self, model_path: Path):
tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
with tuning_file.open(mode="a"):
# make sure tuning file exists (without clearing contents)
# once filled, the tuning file reduces the cost/time of the first
# inference after model load by 10s of seconds
pass
self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())
log.info("Loading ANN model %s ...", model_path)
cache_file = model_path.with_suffix(".anncache")
save = False
if not cache_file.is_file():
save = True
with cache_file.open(mode="a"):
# create empty model cache file
pass
def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
self.model_path = model_path
self.cache_dir = cache_dir
self.ann = Ann(tuning_level=3, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
log.info("Loading ANN model %s ...", model_path)
self.model = self.ann.load(
model_path.as_posix(),
save_cached_network=save,
cached_network_path=cache_file.as_posix(),
cached_network_path=model_path.with_suffix(".anncache").as_posix(),
)
log.info("Loaded ANN model with ID %d", self.model)
@@ -45,11 +34,11 @@ class AnnSession:
log.info("Unloaded ANN model %d", self.model)
self.ann.destroy()
def get_inputs(self) -> list[AnnNode]:
def get_inputs(self) -> list[SessionNode]:
shapes = self.ann.input_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def get_outputs(self) -> list[AnnNode]:
def get_outputs(self) -> list[SessionNode]:
shapes = self.ann.output_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
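
With the file handling pushed down into Ann, constructing a session becomes a one-liner; hypothetical usage with an illustrative model path:

    from pathlib import Path

    from app.sessions.ann import AnnSession

    session = AnnSession(Path("/cache/facial-recognition/buffalo_l/model.armnn"))
    print([node.shape for node in session.get_inputs()])  # shapes come straight from libann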

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray
from app.models.constants import SUPPORTED_PROVIDERS
from app.schemas import SessionNode
from ..config import log, settings
class OrtSession:
def __init__(
self,
model_path: Path | str,
providers: list[str] | None = None,
provider_options: list[dict[str, Any]] | None = None,
sess_options: ort.SessionOptions | None = None,
):
self.model_path = Path(model_path)
self.providers = providers if providers is not None else self._providers_default
self.provider_options = provider_options if provider_options is not None else self._provider_options_default
self.sess_options = sess_options if sess_options is not None else self._sess_options_default
self.session = ort.InferenceSession(
self.model_path.as_posix(),
providers=self.providers,
provider_options=self.provider_options,
sess_options=self.sess_options,
)
def get_inputs(self) -> list[SessionNode]:
inputs: list[SessionNode] = self.session.get_inputs()
return inputs
def get_outputs(self) -> list[SessionNode]:
outputs: list[SessionNode] = self.session.get_outputs()
return outputs
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
outputs: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
return outputs
@property
def providers(self) -> list[str]:
return self._providers
@providers.setter
def providers(self, providers: list[str]) -> None:
log.info(f"Setting execution providers to {providers}, in descending order of preference")
self._providers = providers
@property
def _providers_default(self) -> list[str]:
available_providers = set(ort.get_available_providers())
log.debug(f"Available ORT providers: {available_providers}")
if (openvino := "OpenVINOExecutionProvider") in available_providers:
device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
log.debug(f"Available OpenVINO devices: {device_ids}")
gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
if not gpu_devices:
log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
available_providers.remove(openvino)
return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]
@property
def provider_options(self) -> list[dict[str, Any]]:
return self._provider_options
@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options
@property
def _provider_options_default(self) -> list[dict[str, Any]]:
options = []
for provider in self.providers:
match provider:
case "CPUExecutionProvider" | "CUDAExecutionProvider":
option = {"arena_extend_strategy": "kSameAsRequested"}
case "OpenVINOExecutionProvider":
option = {"device_type": "GPU_FP32", "cache_dir": (self.model_path.parent / "openvino").as_posix()}
case _:
option = {}
options.append(option)
return options
@property
def sess_options(self) -> ort.SessionOptions:
return self._sess_options
@sess_options.setter
def sess_options(self, sess_options: ort.SessionOptions) -> None:
log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
self._sess_options = sess_options
@property
def _sess_options_default(self) -> ort.SessionOptions:
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
# avoid thread contention between models
if settings.model_inter_op_threads > 0:
sess_options.inter_op_num_threads = settings.model_inter_op_threads
# these defaults work well for CPU, but bottleneck GPU
elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.inter_op_num_threads = 1
if settings.model_intra_op_threads > 0:
sess_options.intra_op_num_threads = settings.model_intra_op_threads
elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.intra_op_num_threads = 2
if sess_options.inter_op_num_threads > 1:
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
return sess_options
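
Hypothetical usage of the extracted session; the model path and input name are illustrative, and the defaults above apply whenever an argument is omitted:

    import numpy as np

    from app.sessions.ort import OrtSession

    session = OrtSession("model.onnx", providers=["CPUExecutionProvider"])
    outputs = session.run(None, {"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})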

File diff suppressed because it is too large

View File

@@ -97,4 +97,4 @@ line-length = 120
target-version = ['py311']
[tool.pytest.ini_options]
markers = ["providers"]
markers = ["providers", "ov_device_ids"]