fix(ml): load models in separate threads (#4034)
* load models in thread
* set clip model logs to debug level
* updated tests
* made fixtures slightly less ugly
* moved responses to json file
* formatting
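The change boils down to one pattern: blocking calls such as model.load() and model.predict() are handed to a ThreadPoolExecutor via run_in_executor, so the asyncio event loop stays free to accept other requests. Below is a minimal, self-contained sketch of that pattern; DummyModel and its sleep-based load()/predict() are hypothetical stand-ins, not Immich code.

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


class DummyModel:
    """Hypothetical stand-in for InferenceModel; load() and predict() block the calling thread."""

    def load(self) -> None:
        time.sleep(1.0)  # simulate reading model weights from disk

    def predict(self, inputs: str) -> str:
        time.sleep(0.5)  # simulate blocking inference
        return f"result for {inputs!r}"


thread_pool = ThreadPoolExecutor(max_workers=4)


async def predict(model: DummyModel, inputs: str) -> str:
    # Offload the blocking call to a worker thread so the event loop keeps serving requests.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_pool, model.predict, inputs)


async def main() -> None:
    model = DummyModel()
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(thread_pool, model.load)  # load in a worker thread too
    print(await predict(model, "hello"))


if __name__ == "__main__":
    asyncio.run(main())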
@@ -1,10 +1,13 @@
 import asyncio
+import threading
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
+from zipfile import BadZipFile

 import orjson
 from fastapi import FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 from starlette.formparsers import MultiPartParser

 from app.models.base import InferenceModel
@@ -31,6 +34,7 @@ def init_state() -> None:
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
+    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")

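The new app.state.locks dict gives each ModelType its own threading.Lock, so two requests cannot load the same kind of model at the same time, while different model types can still load in parallel. A plain threading.Lock (rather than asyncio.Lock) is the right primitive here because the guarded code runs inside executor threads, not on the event loop. A rough sketch of the idea, with a hypothetical ModelType enum and load function standing in for the real ones:

import threading
from enum import Enum
from typing import Callable


class ModelType(str, Enum):
    # Hypothetical values; the real enum lives in the app's schema module.
    CLIP = "clip"
    FACIAL_RECOGNITION = "facial-recognition"


# One lock per model type: loads of the same type are serialized,
# while other types can still be loaded concurrently by other worker threads.
locks = {model_type: threading.Lock() for model_type in ModelType}


def load_in_thread(model_type: ModelType, load_fn: Callable[[], None]) -> None:
    # threading.Lock rather than asyncio.Lock: this function is meant to run
    # inside a ThreadPoolExecutor worker, outside the event loop.
    with locks[model_type]:
        load_fn()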
@@ -63,14 +67,49 @@ async def predict(
         inputs = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
+    try:
+        kwargs = orjson.loads(options)
+    except orjson.JSONDecodeError:
+        raise HTTPException(400, f"Invalid options JSON: {options}")

-    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
+    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model.configure(**kwargs)
     outputs = await run(model, inputs)
     return ORJSONResponse(outputs)


 async def run(model: InferenceModel, inputs: Any) -> Any:
-    if app.state.thread_pool is not None:
-        return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
-    else:
+    if app.state.thread_pool is None:
         return model.predict(inputs)
+
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+
+
+async def load(model: InferenceModel) -> InferenceModel:
+    if model.loaded:
+        return model
+
+    def _load() -> None:
+        with app.state.locks[model.model_type]:
+            model.load()
+
+    loop = asyncio.get_running_loop()
+    try:
+        if app.state.thread_pool is None:
+            model.load()
+        else:
+            await loop.run_in_executor(app.state.thread_pool, _load)
+        return model
+    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
+        log.warn(
+            (
+                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. "
+                "Clearing cache and retrying."
+            )
+        )
+        model.clear_cache()
+        if app.state.thread_pool is None:
+            model.load()
+        else:
+            await loop.run_in_executor(app.state.thread_pool, _load)
+        return model
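The new load() helper also keeps the recovery behaviour: if a cached model file turns out to be corrupt (OSError, InvalidProtobuf, BadZipFile, NoSuchFile), the cache is cleared and the load is attempted once more. A condensed sketch of that retry flow, using a hypothetical CorruptModelError and stand-in callables rather than the real InferenceModel API:

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

log = logging.getLogger(__name__)
thread_pool = ThreadPoolExecutor(max_workers=2)


class CorruptModelError(Exception):
    """Hypothetical stand-in for the corrupt-model exceptions caught in the diff."""


async def load_with_retry(load_fn: Callable[[], None], clear_cache_fn: Callable[[], None]) -> None:
    loop = asyncio.get_running_loop()
    try:
        # Run the blocking load in a worker thread, as in the diff above.
        await loop.run_in_executor(thread_pool, load_fn)
    except CorruptModelError:
        # Clear the on-disk cache once and retry; a second failure propagates to the caller.
        log.warning("Failed to load model. Clearing cache and retrying.")
        clear_cache_fn()
        await loop.run_in_executor(thread_pool, load_fn)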