keras/keras_core/operations/image_test.py
Neel Kovelamudi b14eaaf5da Adds torch image resizing and torchvision dependency. (#220)
* Add PyTorch numpy functionality

* Add dtype conversion

* Partial fix for PyTorch numpy tests

* small logic fix

* Revert numpy_test

* Add tensor conversion to numpy

* Fix some arithmetic tests

* Fix some torch functions for numpy compatibility

* Fix pytorch ops for numpy compatibility, add TODOs

* Fix formatting

* Implement nits and fix dtype standardization

* Add pytest skipif decorator and fix nits

* Fix formatting and rename dtypes map

* Split tests by backend

* Merge space

* Fix dtype issues from new type checking

* Implement torch.full and torch.full_like numpy compatible

* Implements logspace and linspace with tensor support for start and stop

* Replace len of shape with ndim

* Fix formatting

* Implement torch.trace

* Implement eye k diagonal arg

* Implement torch.tri

* Fix formatting issues

* Fix torch.take dimensionality

* Add split functionality

* Revert torch.eye implementation to prevent conflict

* Implement all padding modes

* Adds torch image resizing and torchvision dependency.

* Fix conditional syntax

* Make torchvision import optional
2023-05-30 16:19:01 +00:00

104 lines
3.7 KiB
Python

import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import backend
from keras_core import testing
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.operations import image as kimage
class ImageOpsDynamicShapeTest(testing.TestCase):
    """Shape inference of `kimage.resize` for symbolic inputs with unknown dims."""

    def test_resize(self):
        # Each entry: (symbolic input shape, expected output shape) when
        # resizing to 15x15. An unknown leading batch dim is preserved; unknown
        # spatial dims are replaced by the requested size.
        shape_cases = [
            ([None, 20, 20, 3], (None, 15, 15, 3)),
            ([None, None, 3], (15, 15, 3)),
        ]
        for input_shape, expected_shape in shape_cases:
            symbolic = KerasTensor(input_shape)
            resized = kimage.resize(symbolic, size=(15, 15))
            self.assertEqual(resized.shape, expected_shape)
class ImageOpsStaticShapeTest(testing.TestCase):
    """Shape inference of `kimage.resize` for a fully-defined symbolic shape."""

    def test_resize(self):
        image = KerasTensor([20, 20, 3])
        # Spatial dims take the requested size; the channel dim is unchanged.
        self.assertEqual(
            kimage.resize(image, size=(15, 15)).shape,
            (15, 15, 3),
        )
class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
    """Numerical comparison of `kimage.resize` against `tf.image.resize`."""

    @parameterized.parameters(
        [
            ("bilinear", True, "channels_last"),
            ("nearest", True, "channels_last"),
            ("lanczos3", True, "channels_last"),
            ("lanczos5", True, "channels_last"),
            ("bicubic", True, "channels_last"),
            ("bilinear", False, "channels_last"),
            ("nearest", False, "channels_last"),
            ("lanczos3", False, "channels_last"),
            ("lanczos5", False, "channels_last"),
            ("bicubic", False, "channels_last"),
            ("bilinear", True, "channels_first"),
        ]
    )
    def test_resize(self, method, antialias, data_format):
        """Check unbatched and batched resizing for one parameter combination.

        Args:
            method: Interpolation method name passed to both implementations.
            antialias: Whether anti-aliasing is enabled.
            data_format: "channels_last" or "channels_first".
        """
        if backend.backend() == "torch":
            if "lanczos" in method:
                self.skipTest(
                    "Resizing with Lanczos interpolation is "
                    "not supported by the PyTorch backend. "
                    f"Received: method={method}."
                )
            if method == "bicubic" and not antialias:
                self.skipTest(
                    "Resizing with Bicubic interpolation in "
                    "PyTorch backend produces noise. Please "
                    "turn on anti-aliasing. "
                    f"Received: method={method}, "
                    f"antialias={antialias}."
                )

        # A fixed seed makes any tolerance failure reproducible; the unseeded
        # global RNG previously produced different inputs on every run.
        rng = np.random.default_rng(1337)
        if data_format == "channels_first":
            unbatched_shape, batched_shape = (3, 50, 50), (2, 3, 50, 50)
        else:
            unbatched_shape, batched_shape = (50, 50, 3), (2, 50, 50, 3)

        # Unbatched case
        self._assert_resize_matches_tf(
            rng.random(unbatched_shape) * 255, method, antialias, data_format
        )
        # Batched case
        self._assert_resize_matches_tf(
            rng.random(batched_shape) * 255, method, antialias, data_format
        )

    def _assert_resize_matches_tf(self, x, method, antialias, data_format):
        """Resize `x` to 25x25 with `kimage.resize` and compare to TF.

        `tf.image.resize` takes channels-last images. For channels-first data,
        the reference is therefore computed on a channels-last copy of `x`.
        Its result is then moved back to channels-first before comparison.
        `np.moveaxis` on axis -3 works for both rank-3 (unbatched) and rank-4
        (batched) inputs.
        """
        out = kimage.resize(
            x,
            size=(25, 25),
            method=method,
            antialias=antialias,
            data_format=data_format,
        )
        ref_input = x
        if data_format == "channels_first":
            ref_input = np.moveaxis(x, -3, -1)
        # Materialize as a NumPy array so the axis move below is pure NumPy.
        ref_out = np.asarray(
            tf.image.resize(
                ref_input, size=(25, 25), method=method, antialias=antialias
            )
        )
        if data_format == "channels_first":
            ref_out = np.moveaxis(ref_out, -1, -3)
        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
        self.assertAllClose(ref_out, out, atol=0.3)