forked from dark_thunder/immich
feat(ml): ARMNN acceleration (#5667)
* feat(ml): ARMNN acceleration for CLIP * wrap ANN as ONNX-Session * strict typing * normalize ARMNN CLIP embedding * mutex to handle concurrent execution * make inputs contiguous * fine-grained locking; concurrent network execution --------- Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:

committed by
GitHub

parent
29747437f6
commit
753292956e
@ -10,8 +10,11 @@ import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
from typing_extensions import Buffer
|
||||
|
||||
import ann.ann
|
||||
|
||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||
from ..schemas import ModelType
|
||||
from .ann import AnnSession
|
||||
|
||||
|
||||
class InferenceModel(ABC):
|
||||
@ -138,6 +141,21 @@ class InferenceModel(ABC):
|
||||
self.cache_dir.unlink()
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
|
||||
armnn_path = model_path.with_suffix(".armnn")
|
||||
if settings.ann and ann.ann.is_available and armnn_path.is_file():
|
||||
session = AnnSession(armnn_path)
|
||||
elif model_path.is_file():
|
||||
session = ort.InferenceSession(
|
||||
model_path.as_posix(),
|
||||
sess_options=self.sess_options,
|
||||
providers=self.providers,
|
||||
provider_options=self.provider_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"the file model_path='{model_path}' does not exist")
|
||||
return session
|
||||
|
||||
|
||||
# HF deep copies configs, so we need to make session options picklable
|
||||
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
|
||||
|
Reference in New Issue
Block a user