fix(ml): load models in separate threads (#4034)

* load models in thread

* set clip mode logs to debug level

* updated tests

* made fixtures slightly less ugly

* moved responses to json file

* formatting
Mert, 2023-09-09 05:02:44 -04:00, committed by GitHub
parent f1db257628
commit 258b98c262
9 changed files with 1683 additions and 114 deletions


@@ -1,10 +1,13 @@
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from zipfile import BadZipFile
import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel
@@ -31,6 +34,7 @@ def init_state() -> None:
    )
    # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
    log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@@ -63,14 +67,49 @@ async def predict(
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    try:
        kwargs = orjson.loads(options)
    except orjson.JSONDecodeError:
        raise HTTPException(400, f"Invalid options JSON: {options}")

-    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
+    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model.configure(**kwargs)
    outputs = await run(model, inputs)
    return ORJSONResponse(outputs)

async def run(model: InferenceModel, inputs: Any) -> Any:
-    if app.state.thread_pool is not None:
-        return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
-    else:
+    if app.state.thread_pool is None:
        return model.predict(inputs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)

async def load(model: InferenceModel) -> InferenceModel:
    if model.loaded:
        return model

    def _load() -> None:
        with app.state.locks[model.model_type]:
            model.load()

    loop = asyncio.get_running_loop()
    try:
        if app.state.thread_pool is None:
            model.load()
        else:
            await loop.run_in_executor(app.state.thread_pool, _load)
        return model
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warn(
            (
                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
                "Clearing cache and retrying."
            )
        )
        model.clear_cache()
        if app.state.thread_pool is None:
            model.load()
        else:
            await loop.run_in_executor(app.state.thread_pool, _load)
        return model
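
A hypothetical end-to-end sketch of how load and run cooperate, with a DummyModel standing in for InferenceModel and module-level thread_pool/locks in place of app.state (all stand-ins); the cache-clearing retry shown above is omitted:

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any

thread_pool = ThreadPoolExecutor(4)
locks = {"clip": threading.Lock()}

class DummyModel:
    model_type = "clip"
    loaded = False

    def load(self) -> None:
        # pretend to read weights from disk; real models do blocking I/O here
        self.loaded = True

    def predict(self, inputs: Any) -> Any:
        return {"echo": inputs}

async def load(model: DummyModel) -> DummyModel:
    if model.loaded:  # fast path: model already initialized
        return model

    def _load() -> None:
        with locks[model.model_type]:  # one load per model type at a time
            model.load()

    await asyncio.get_running_loop().run_in_executor(thread_pool, _load)
    return model

async def run(model: DummyModel, inputs: Any) -> Any:
    # keep the event loop free by pushing the blocking predict() to the pool
    return await asyncio.get_running_loop().run_in_executor(thread_pool, model.predict, inputs)

async def main() -> None:
    model = await load(DummyModel())
    print(await run(model, "hello"))

asyncio.run(main())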