90baabee5d
* Add stop_gradient to keras_core.backend * Use `.requires_grad_(False)` instead of `.detach()` * Add test for stop_gradient and add a StopGradient Op * Return a spec instead of list in StopGradient.compute_output_spec * Use models.Sequential and self.assertEqual in tests * Use keras_core.backend.convert_to_tensor instead of tf in doc example * Don't use tf and remove the noqa comment
66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
from keras_core import backend
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.backend import KerasTensor
|
|
from keras_core.backend import any_symbolic_tensors
|
|
from keras_core.operations.operation import Operation
|
|
|
|
|
|
class Resize(Operation):
    """Operation that resizes an image or a batch of images.

    Concrete inputs are resized by `backend.image.resize`; for symbolic
    (`KerasTensor`) inputs this class also infers the output spec.

    Args:
        size: Target spatial size. Converted to a tuple that must contain
            exactly two entries, which replace the two spatial dimensions
            of the input.
        method: Interpolation method name, forwarded unchanged to
            `backend.image.resize`. Defaults to `"bilinear"`.
        antialias: Antialiasing flag, forwarded unchanged to
            `backend.image.resize`.
        data_format: `"channels_last"` means the channel axis is the last
            axis. Any other value is treated as channels-first, i.e. the
            two spatial axes are the last two axes.

    Raises:
        ValueError: If `size` does not contain exactly two entries.
    """

    def __init__(
        self,
        size,
        method="bilinear",
        antialias=False,
        data_format="channels_last",
    ):
        super().__init__()
        self.size = tuple(size)
        # `compute_output_spec` overwrites exactly two spatial dimensions,
        # so any other length would yield a silently wrong output spec.
        if len(self.size) != 2:
            raise ValueError(
                "Expected `size` to contain exactly 2 values "
                f"(height, width). Received: size={size}"
            )
        self.method = method
        self.antialias = antialias
        self.data_format = data_format

    def call(self, image):
        """Resize `image` with the configured options via the backend."""
        return backend.image.resize(
            image,
            self.size,
            method=self.method,
            antialias=self.antialias,
            data_format=self.data_format,
        )

    def compute_output_spec(self, image):
        """Return a `KerasTensor` describing the resized output.

        Only the two spatial dimensions change. Batch and channel
        dimensions, as well as the dtype, are copied from `image`.

        Args:
            image: Symbolic input of rank 3 (single image) or rank 4
                (batch of images).

        Returns:
            A `KerasTensor` with the resized shape and `image.dtype`.

        Raises:
            ValueError: If `image` is neither rank 3 nor rank 4.
        """
        image_shape = list(image.shape)
        if len(image_shape) not in (3, 4):
            raise ValueError(
                "Invalid input rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"image.shape={image.shape}"
            )
        # Spatial axes are counted from the end, so the same indices work
        # for unbatched (rank 3) and batched (rank 4) inputs. `data_format`
        # must be honoured for both ranks: a channels-first single image
        # is (C, H, W), not (H, W, C).
        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        image_shape[height_axis] = self.size[0]
        image_shape[width_axis] = self.size[1]
        return KerasTensor(tuple(image_shape), dtype=image.dtype)
|
|
|
|
|
|
@keras_core_export("keras_core.operations.image.resize")
def resize(
    image, size, method="bilinear", antialias=False, data_format="channels_last"
):
    """Resize an image (or a batch of images) to `size`.

    Symbolic inputs are routed through the `Resize` operation so that the
    output shape can be inferred. Every other input is handed directly to
    `backend.image.resize`.

    Args:
        image: Image tensor, or a symbolic `KerasTensor`. The symbolic path
            accepts rank 3 (single image) or rank 4 (batch of images).
        size: Target spatial size; its entries replace the two spatial
            dimensions of `image`.
        method: Interpolation method name, forwarded to the backend.
            Defaults to `"bilinear"`.
        antialias: Antialiasing flag, forwarded to the backend.
        data_format: `"channels_last"` (default) when channels are the last
            axis; any other value means the channel axis comes before the
            spatial axes.

    Returns:
        The resized image. For symbolic inputs this is a `KerasTensor`
        produced by `Resize.symbolic_call`.
    """
    options = {
        "method": method,
        "antialias": antialias,
        "data_format": data_format,
    }
    if not any_symbolic_tensors((image,)):
        # Eager/concrete input: compute the result right away.
        return backend.image.resize(image, size, **options)
    # Symbolic input: build the op so output shape and dtype are inferred.
    resize_op = Resize(size, **options)
    return resize_op.symbolic_call(image)
|