feat(ml): backend takes image over HTTP (#2783)

* using pydantic BaseSettings (see the config sketch after the diff)

* ML API takes image file as input

* keeping image in memory

* reducing duplicate code

* using bytes instead of UploadFile & other small code improvements

* removed form-multipart, using the raw HTTP body instead (see the client example below)

* format code
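
For API consumers, the upshot of this change is that the vision endpoints now read the image from the raw request body rather than from a multipart form or a file path. A minimal client sketch, assuming the service's default port (3003); the host, `Content-Type`, and `photo.jpg` path are illustrative:

```python
# Hypothetical client call: send the image as raw bytes in the HTTP body.
import requests

with open("photo.jpg", "rb") as f:
    image_bytes = f.read()

resp = requests.post(
    "http://localhost:3003/image-classifier/tag-image",
    data=image_bytes,  # raw body, not files={...} multipart
    headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
print(resp.json())  # tag labels for the image
```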

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
Author: Zeeshan Khan
Date: 2023-06-17 22:49:19 -05:00
Committed by: GitHub
parent 3e804f16df
commit 34201be74c
8 changed files with 116 additions and 80 deletions


@@ -1,4 +1,5 @@
 import os
+import io
 from typing import Any
 from cache import ModelCache
@@ -9,52 +10,44 @@ from schemas import (
     MessageResponse,
     TextModelRequest,
     TextResponse,
-    VisionModelRequest,
 )
 import uvicorn
 from PIL import Image
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Depends, Body
 from models import get_model, run_classification, run_facial_recognition

-classification_model = os.getenv(
-    "MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50"
-)
-clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32")
-clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32")
-facial_recognition_model = os.getenv(
-    "MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l"
-)
-min_tag_score = float(os.getenv("MACHINE_LEARNING_MIN_TAG_SCORE", 0.9))
-eager_startup = (
-    os.getenv("MACHINE_LEARNING_EAGER_STARTUP", "true") == "true"
-)  # loads all models at startup
-model_ttl = int(os.getenv("MACHINE_LEARNING_MODEL_TTL", 300))
+from config import settings

 _model_cache = None

 app = FastAPI()

 @app.on_event("startup")
 async def startup_event() -> None:
     global _model_cache
-    _model_cache = ModelCache(ttl=model_ttl, revalidate=True)
+    _model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
     models = [
-        (classification_model, "image-classification"),
-        (clip_image_model, "clip"),
-        (clip_text_model, "clip"),
-        (facial_recognition_model, "facial-recognition"),
+        (settings.classification_model, "image-classification"),
+        (settings.clip_image_model, "clip"),
+        (settings.clip_text_model, "clip"),
+        (settings.facial_recognition_model, "facial-recognition"),
     ]
     # Get all models
     for model_name, model_type in models:
-        if eager_startup:
+        if settings.eager_startup:
             await _model_cache.get_cached_model(model_name, model_type)
         else:
             get_model(model_name, model_type)

+def dep_model_cache():
+    if _model_cache is None:
+        raise HTTPException(status_code=500, detail="Unable to load model.")
+
+def dep_input_image(image: bytes = Body(...)) -> Image:
+    return Image.open(io.BytesIO(image))

 @app.get("/", response_model=MessageResponse)
 async def root() -> dict[str, str]:
     return {"message": "Immich ML"}
@@ -65,29 +58,36 @@ def ping() -> str:
     return "pong"

-@app.post("/image-classifier/tag-image", response_model=TagResponse, status_code=200)
-async def image_classification(payload: VisionModelRequest) -> list[str]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-    model = await _model_cache.get_cached_model(
-        classification_model, "image-classification"
-    )
-    labels = run_classification(model, payload.image_path, min_tag_score)
-    return labels
+@app.post(
+    "/image-classifier/tag-image",
+    response_model=TagResponse,
+    status_code=200,
+    dependencies=[Depends(dep_model_cache)],
+)
+async def image_classification(
+    image: Image = Depends(dep_input_image)
+) -> list[str]:
+    try:
+        model = await _model_cache.get_cached_model(
+            settings.classification_model, "image-classification"
+        )
+        labels = run_classification(model, image, settings.min_tag_score)
+    except Exception as ex:
+        raise HTTPException(status_code=500, detail=str(ex))
+    else:
+        return labels

 @app.post(
     "/sentence-transformer/encode-image",
     response_model=EmbeddingResponse,
     status_code=200,
+    dependencies=[Depends(dep_model_cache)],
 )
-async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-    model = await _model_cache.get_cached_model(clip_image_model, "clip")
-    image = Image.open(payload.image_path)
+async def clip_encode_image(
+    image: Image = Depends(dep_input_image)
+) -> list[float]:
+    model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
     embedding = model.encode(image).tolist()
     return embedding
@@ -96,33 +96,38 @@ async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
     "/sentence-transformer/encode-text",
     response_model=EmbeddingResponse,
     status_code=200,
+    dependencies=[Depends(dep_model_cache)],
 )
-async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-    model = await _model_cache.get_cached_model(clip_text_model, "clip")
+async def clip_encode_text(
+    payload: TextModelRequest
+) -> list[float]:
+    model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
     embedding = model.encode(payload.text).tolist()
     return embedding

 @app.post(
-    "/facial-recognition/detect-faces", response_model=FaceResponse, status_code=200
+    "/facial-recognition/detect-faces",
+    response_model=FaceResponse,
+    status_code=200,
+    dependencies=[Depends(dep_model_cache)],
 )
-async def facial_recognition(payload: VisionModelRequest) -> list[dict[str, Any]]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
+async def facial_recognition(
+    image: bytes = Body(...),
+) -> list[dict[str, Any]]:
     model = await _model_cache.get_cached_model(
-        facial_recognition_model, "facial-recognition"
+        settings.facial_recognition_model, "facial-recognition"
     )
-    faces = run_facial_recognition(model, payload.image_path)
+    faces = run_facial_recognition(model, image)
     return faces

 if __name__ == "__main__":
-    host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0")
-    port = int(os.getenv("MACHINE_LEARNING_PORT", 3003))
     is_dev = os.getenv("NODE_ENV") == "development"
-    uvicorn.run("main:app", host=host, port=port, reload=is_dev, workers=1)
+    uvicorn.run(
+        "main:app",
+        host=settings.host,
+        port=settings.port,
+        reload=is_dev,
+        workers=settings.workers,
+    )
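
The `config.py` module added by this commit (imported above as `from config import settings`) is not shown in this file's diff. A plausible sketch of it, with field names taken from the `settings.*` usages above and defaults from the removed `os.getenv` calls; the `workers` default is an assumption based on the old hard-coded `workers=1`:

```python
# Sketch of the new settings module (pydantic v1 style, matching the
# commit's June 2023 timeframe). Not verbatim from the commit.
from pydantic import BaseSettings


class Settings(BaseSettings):
    classification_model: str = "microsoft/resnet-50"
    clip_image_model: str = "clip-ViT-B-32"
    clip_text_model: str = "clip-ViT-B-32"
    facial_recognition_model: str = "buffalo_l"
    min_tag_score: float = 0.9
    eager_startup: bool = True  # load all models at startup
    model_ttl: int = 300
    host: str = "0.0.0.0"
    port: int = 3003
    workers: int = 1  # assumption: old code hard-coded workers=1

    class Config:
        env_prefix = "MACHINE_LEARNING_"  # matches the removed env-var names


settings = Settings()
```

With `env_prefix`, the previous environment variables (e.g. `MACHINE_LEARNING_MODEL_TTL`) keep working unchanged; pydantic reads each one into the matching lowercase field.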