chore(ml): use strict mypy (#5001)

* improved typing

* improved export typing

* strict mypy & check export folder

* formatting

* add formatting checks for export folder

* re-added init call
This commit is contained in:
Mert
2023-11-13 11:18:46 -05:00
committed by GitHub
parent 9fa9ad05b1
commit 935f471ccb
10 changed files with 70 additions and 55 deletions

View File

@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel):
provider_options=self.provider_options,
)
def _predict(self, image_or_text: Image.Image | str) -> list[float]:
def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
if isinstance(image_or_text, bytes):
image_or_text = Image.open(BytesIO(image_or_text))
@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel):
if self.mode == "text":
raise TypeError("Cannot encode image as text-only model")
outputs = self.vision_model.run(None, self.transform(image_or_text))
outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
case str():
if self.mode == "vision":
raise TypeError("Cannot encode text as vision-only model")
outputs = self.text_model.run(None, self.tokenize(image_or_text))
outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
case _:
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
return outputs[0][0].tolist()
return outputs
@abstractmethod
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
@cached_property
def model_cfg(self) -> dict[str, Any]:
return json.load(self.model_cfg_path.open())
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
return model_cfg
@cached_property
def preprocess_cfg(self) -> dict[str, Any]:
return json.load(self.preprocess_cfg_path.open())
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
return preprocess_cfg
class MCLIPEncoder(OpenCLIPEncoder):