Allow dynamic shapes for JAX.

Francois Chollet 2023-05-20 19:01:01 -07:00
parent 524bcfcd57
commit 2bab1b1923
17 changed files with 97 additions and 114 deletions

@@ -11,7 +11,6 @@ class KerasTensor:
shape = backend.standardize_shape(
shape,
allow_dynamic_batch_size=backend.DYNAMIC_BATCH_SIZE_OK,
allow_all_dynamic=backend.DYNAMIC_SHAPES_OK,
)
self.shape = shape

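The hunk above swaps the per-tensor "dynamic batch size only" check for a blanket "any dimension may be dynamic" check. A minimal sketch of what a validator driven by such a flag could look like, assuming only the `standardize_shape` name and `allow_all_dynamic` keyword visible in the hunk (the body is illustrative, not the keras_core implementation):

```python
# Illustrative only: the real backend.standardize_shape body may differ.
def standardize_shape(shape, allow_all_dynamic=True):
    shape = tuple(shape)
    for i, dim in enumerate(shape):
        if dim is None:
            if not allow_all_dynamic:
                raise ValueError(
                    f"Dynamic dimensions are not supported by this "
                    f"backend (index {i}). Received: shape={shape}"
                )
        elif not isinstance(dim, int) or dim < 0:
            raise ValueError(
                f"Shape entries must be non-negative ints or None. "
                f"Received: shape={shape}"
            )
    return shape
```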
@@ -4,7 +4,6 @@ from keras_core.backend.jax import math
from keras_core.backend.jax import nn
from keras_core.backend.jax import numpy
from keras_core.backend.jax import random
from keras_core.backend.jax.core import DYNAMIC_BATCH_SIZE_OK
from keras_core.backend.jax.core import DYNAMIC_SHAPES_OK
from keras_core.backend.jax.core import Variable
from keras_core.backend.jax.core import cast

@@ -7,8 +7,7 @@ from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
DYNAMIC_SHAPES_OK = False # Dynamic shapes NG
DYNAMIC_BATCH_SIZE_OK = True
DYNAMIC_SHAPES_OK = True
class Variable(KerasVariable):
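With the JAX backend now advertising `DYNAMIC_SHAPES_OK = True`, the separate `DYNAMIC_BATCH_SIZE_OK` flag is redundant and is deleted across all three backends in the hunks below. A short sketch of how the remaining flag is consumed (imports match those used by the tests later in this diff):

```python
from keras_core import backend
from keras_core.backend.common.keras_tensor import KerasTensor

if backend.DYNAMIC_SHAPES_OK:
    # None is now legal on any axis, not just the leading batch axis.
    x = KerasTensor((None, None, 3))
```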
@@ -59,67 +58,115 @@ def name_scope(name):
# Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs):
with StatelessScope():
dynamic_batch_map = {}
magic_number = 3
def convert_keras_tensor_to_jax(x):
if isinstance(x, KerasTensor):
shape = x.shape
if shape and x.shape[0] is None:
shape = list(shape)
shape[0] = magic_number
dynamic_batch = True
else:
dynamic_batch = False
jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
dynamic_batch_map[jax_tensor] = dynamic_batch
return jax_tensor
return x
all_input_ktensors = []
built_in_types = (type(None), int, float, str, bool, complex, bytes)
# First, separate symbolic args from other args
static_args = []
maybe_symbolic_args = []
static_kwargs = {}
maybe_symbolic_kwargs = {}
for arg in args:
if isinstance(arg, built_in_types):
static_args.append(arg)
else:
maybe_symbolic_args.append(arg)
static_kwargs = {}
maybe_symbolic_kwargs = {}
for (
k,
arg,
) in kwargs.items():
if isinstance(arg, built_in_types):
static_kwargs[k] = arg
for k, v in kwargs.items():
if isinstance(v, built_in_types):
static_kwargs[k] = v
else:
maybe_symbolic_kwargs[k] = arg
maybe_symbolic_kwargs[k] = v
# Second, identify all ktensors
def index_all_ktensors(x):
if isinstance(x, KerasTensor):
all_input_ktensors.append(x)
return x
# Third, find out if there are dynamic shapes
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
index_all_ktensors, (maybe_symbolic_args, maybe_symbolic_kwargs)
)
none_count = 0
for x in all_input_ktensors:
for d in x.shape:
if d is None:
none_count += 1
def convert_keras_tensor_to_jax(x, fill_value=None):
if isinstance(x, KerasTensor):
shape = list(x.shape)
if fill_value:
for i, e in enumerate(shape):
if e is None:
shape[i] = fill_value
jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
return jax_tensor
return x
def wrapped_fn(*args, **kwargs):
return fn(*args, *static_args, **kwargs, **static_kwargs)
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
convert_keras_tensor_to_jax,
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
_, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*maybe_symbolic_args, **maybe_symbolic_kwargs
)
jax_out = None
if none_count:
try:
ms_args_1, ms_kwargs_1 = nest.map_structure(
lambda x: convert_keras_tensor_to_jax(x, fill_value=83),
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
_, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*ms_args_1, **ms_kwargs_1
)
ms_args_2, ms_kwargs_2 = nest.map_structure(
lambda x: convert_keras_tensor_to_jax(x, fill_value=89),
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
_, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*ms_args_2, **ms_kwargs_2
)
flat_out_1 = nest.flatten(jax_out_1)
flat_out_2 = nest.flatten(jax_out_2)
flat_out = []
for x1, x2 in zip(flat_out_1, flat_out_2):
if isinstance(x1, jax.ShapeDtypeStruct):
if not isinstance(x2, jax.ShapeDtypeStruct):
raise ValueError("Indeterministic output ordering.")
shape = list(x1.shape)
for i, e in enumerate(x2.shape):
if e != shape[i]:
shape[i] = None
flat_out.append(
jax.ShapeDtypeStruct(shape, dtype=x1.dtype)
)
else:
flat_out.append(x1)
jax_out = nest.pack_sequence_as(jax_out_1, flat_out)
except:
# Errors can happen when the filled dimensions
# are not compatible with the function
# (or when the function contains a bug).
# In such cases we don't want to confuse users
# with random filled dimensions and the like,
# so we rerun a pass on the dynamic shapes,
# which will likely error out when JAX tries to
# validate shapes as fully static.
# The error message will be much easier to understand.
pass
if jax_out is None:
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
convert_keras_tensor_to_jax,
(maybe_symbolic_args, maybe_symbolic_kwargs),
)
_, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*maybe_symbolic_args, **maybe_symbolic_kwargs
)
def convert_jax_spec_to_keras_tensor(x):
if isinstance(x, jax.ShapeDtypeStruct):
if dynamic_batch_map.get(x, False):
shape = list(x.shape)
if shape[0] != magic_number:
raise ValueError(
f"Function {fn} appears to change the "
"batch size of its input. This is not "
"allowed when used in conjunction with "
"dynamic batch sizes. Consider using "
"a static batch size here."
)
shape[0] = None
return KerasTensor(x.shape, x.dtype)
return x

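This hunk is the core of the change: instead of substituting a `magic_number` for the batch dimension and checking that it survives, the new code traces `fn` twice, filling every unknown dimension with two different concrete values (83, then 89); any output dimension that differs between the two traces must depend on a dynamic input dimension and is reported as `None`. A standalone sketch of the same trick, using `jax.eval_shape` rather than the `jax.make_jaxpr(..., return_shape=True)` call in the diff (the function and argument names here are hypothetical):

```python
import jax
import jax.numpy as jnp


def infer_output_shape(fn, input_shape):
    """Two-pass dynamic-shape inference sketch (not the keras_core code)."""

    def trace(fill_value):
        # Substitute a concrete stand-in for each unknown (None) dimension.
        concrete = tuple(fill_value if d is None else d for d in input_shape)
        spec = jax.ShapeDtypeStruct(concrete, jnp.float32)
        return jax.eval_shape(fn, spec)  # abstract trace: no real compute

    out_1 = trace(83)
    out_2 = trace(89)
    # Axes that changed with the fill value depended on a dynamic input
    # axis, so report them as dynamic (None); agreeing axes are static.
    return tuple(
        None if d1 != d2 else d1 for d1, d2 in zip(out_1.shape, out_2.shape)
    )


print(infer_output_shape(lambda x: jnp.matmul(x, x.T), (None, 3)))  # (None, None)
print(infer_output_shape(lambda x: x.sum(axis=0), (None, 3)))       # (3,)
```

The two fill values only need to be distinct and unlikely to collide with real static sizes. A dimension that coincidentally agrees across both traces (e.g. after clamping) would be misreported as static, which is why the diff wraps the whole attempt in a `try`/`except` and falls back to a plain static-shape trace whose error message is easier to act on.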
@@ -4,7 +4,6 @@ from keras_core.backend.tensorflow import math
from keras_core.backend.tensorflow import nn
from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random
from keras_core.backend.tensorflow.core import DYNAMIC_BATCH_SIZE_OK
from keras_core.backend.tensorflow.core import DYNAMIC_SHAPES_OK
from keras_core.backend.tensorflow.core import Variable
from keras_core.backend.tensorflow.core import cast

@@ -7,7 +7,6 @@ from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.naming import auto_name
DYNAMIC_SHAPES_OK = True
DYNAMIC_BATCH_SIZE_OK = True
class Variable(KerasVariable, tf.__internal__.types.Tensor):

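For TensorFlow nothing changes functionally: `DYNAMIC_SHAPES_OK` was already `True` because graph tracing has first-class symbolic dimensions. A standalone illustration of that capability (not part of the commit):

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def square(x):
    return x**2


# The symbolic None dimension survives tracing unchanged.
print(square.get_concrete_function().output_shapes)  # (None, 3)
```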
@@ -4,7 +4,6 @@ from keras_core.backend.torch import math
from keras_core.backend.torch import nn
from keras_core.backend.torch import numpy
from keras_core.backend.torch import random
from keras_core.backend.torch.core import DYNAMIC_BATCH_SIZE_OK
from keras_core.backend.torch.core import DYNAMIC_SHAPES_OK
from keras_core.backend.torch.core import Variable
from keras_core.backend.torch.core import cast

@@ -4,7 +4,6 @@ from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype
DYNAMIC_SHAPES_OK = True
DYNAMIC_BATCH_SIZE_OK = True
TORCH_DTYPES = {

@@ -197,7 +197,7 @@ def one_hot(x, num_classes, axis=-1):
new_axes_order = list(range(dims))
new_axes_order[axis] = -1 # Shifts output to axis position
# Shift remaining axes with offset by 1 since output moved to `axis`.
for ax in range(axis+1, dims):
for ax in range(axis + 1, dims):
new_axes_order[ax] -= 1
output = output.permute(new_axes_order)
return output

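The hunk above is only a formatting fix (spaces around `+`), but the permutation it sits in is worth a worked example, since `torch.nn.functional.one_hot` always appends the class axis last. A small standalone run-through (the tensor values and `axis` choice are assumed):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0, 2], [1, 0]])       # shape (2, 2)
output = F.one_hot(x, num_classes=3)     # shape (2, 2, 3): classes go last
axis = 1                                  # where we want the class axis
dims = output.dim()                       # 3
new_axes_order = list(range(dims))        # [0, 1, 2]
new_axes_order[axis] = -1                 # put the last (class) axis here
for ax in range(axis + 1, dims):
    new_axes_order[ax] -= 1               # shift displaced axes down by one
output = output.permute(new_axes_order)   # [0, -1, 1] -> shape (2, 3, 2)
print(output.shape)                       # torch.Size([2, 3, 2])
```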
@@ -30,8 +30,8 @@ class DropoutTest(testing.TestCase):
self.assertAllClose(np.max(outputs), 2.0)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
backend.backend() == "jax",
reason="JAX does not support dynamic shapes",
)
def test_dropout_partial_noise_shape_dynamic(self):
inputs = np.ones((20, 5, 10))

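Note that this test keeps a skip, now keyed on the backend name: symbolic inference with dynamic shapes works on JAX after this commit, but actually executing dropout with a partially dynamic noise shape still does not. What a partial noise shape means, sketched in plain NumPy (the rate and shapes are assumptions mirroring the `(20, 5, 10)` input in the hunk):

```python
import numpy as np

rng = np.random.default_rng(0)
inputs = np.ones((20, 5, 10))
rate = 0.5
# Noise shape (20, 1, 10): sample one mask value per (batch, feature)
# pair and broadcast it across axis 1, dropping the same units there.
mask = rng.uniform(size=(20, 1, 10)) >= rate
outputs = inputs * mask / (1.0 - rate)  # inverted-dropout scaling
assert np.allclose(outputs[:, 0, :], outputs[:, 1, :])  # shared mask
```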
@@ -1,7 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
from keras_core.testing import test_case
@@ -52,10 +50,6 @@ class SpatialDropoutTest(test_case.TestCase):
input_shape=(2, 3, 4, 4, 5),
)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_spatial_dropout_1D_dynamic(self):
inputs = layers.Input((3, 2))
layer = layers.SpatialDropout1D(0.5)
@@ -67,10 +61,6 @@ class SpatialDropoutTest(test_case.TestCase):
outputs = layer(inputs, training=True)
self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_spatial_dropout_2D_dynamic(self):
inputs = layers.Input((3, 2, 4))
layer = layers.SpatialDropout2D(0.5)
@@ -82,10 +72,6 @@ class SpatialDropoutTest(test_case.TestCase):
outputs = layer(inputs, training=True)
self.assertAllClose(outputs[:, 0, 0, :], outputs[:, 1, 1, :])
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_spatial_dropout_3D_dynamic(self):
inputs = layers.Input((3, 2, 4, 2))
layer = layers.SpatialDropout3D(0.5)

@@ -48,10 +48,6 @@ class UpSamplingTest(testing.TestCase):
layers.UpSampling1D(size=3)(np.ones((2, 1, 5))), np.ones((2, 3, 5))
)
@pytest.mark.skipif(
not backend.DYNAMIC_BATCH_SIZE_OK,
reason="Backend does not support dynamic batch sizes",
)
def test_upsampling_1d_with_dynamic_batch_size(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3))

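With the batch-size flag gone, the test above runs unconditionally on every backend. Mirroring it as plain usage (a sketch assuming a working keras_core install):

```python
from keras_core import layers
from keras_core.backend.common.keras_tensor import KerasTensor

# Calling a layer on a KerasTensor performs shape inference only; the
# dynamic batch dimension passes straight through, the length doubles.
x = KerasTensor([None, 2, 3])
y = layers.UpSampling1D(size=2)(x)
print(y.shape)  # (None, 4, 3)
```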
@@ -1,19 +1,9 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import testing
from keras_core.operations import core
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class CoreOpsDynamicShapeTest(testing.TestCase):
pass
class CoreOpsStaticShapeTest(testing.TestCase):
def test_scatter(self):
# Requires dtype

@@ -1,7 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import testing
from keras_core.backend.common import keras_tensor
from keras_core.operations import function
@@ -42,10 +40,6 @@ class FunctionTest(testing.TestCase):
self.assertAllClose(y_val[0], np.ones((2, 3)) * 6)
self.assertAllClose(y_val[1], np.ones((2, 3)) * 4)
@pytest.mark.skipif(
not backend.DYNAMIC_BATCH_SIZE_OK,
reason="Test only valid if dynamic batch sizes are supported",
)
def test_dynamic_shape_inference(self):
x = keras_tensor.KerasTensor((None, 3))
y = x**2

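The body of `test_dynamic_shape_inference` is cut off by the diff context. A plausible continuation, assuming `function.Function(x, y)` and its `compute_output_spec` method work as used elsewhere in keras_core (hypothetical sketch, not the actual test):

```python
from keras_core.backend.common import keras_tensor
from keras_core.operations import function

x = keras_tensor.KerasTensor((None, 3))
y = x**2
fn = function.Function(x, y)
# Feeding a concrete batch size should specialize the None dimension.
out = fn.compute_output_spec(keras_tensor.KerasTensor((4, 3)))
print(out.shape)  # expected: (4, 3)
```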
@@ -1,18 +1,12 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
from keras_core import backend
from keras_core import testing
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.operations import image as kimage
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class ImageOpsDynamicShapeTest(testing.TestCase):
def test_resize(self):
x = KerasTensor([None, 20, 20, 3])

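Again the class-level skip goes away because dynamic-shape inference now works on every backend. The resize rule the test exercises, as a usage sketch (the target size is an assumption; the test body is truncated by the diff context):

```python
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.operations import image as kimage

x = KerasTensor([None, 20, 20, 3])
# Spatial axes take the requested size; the dynamic batch axis stays None.
out = kimage.resize(x, size=(25, 25))
print(out.shape)  # expected: (None, 25, 25, 3)
```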
@@ -8,10 +8,6 @@ from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.operations import math as kmath
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class MathOpsDynamicShapeTest(testing.TestCase):
def test_segment_sum(self):
data = KerasTensor((None, 4), dtype="float32")

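Same pattern for the math ops: the `segment_sum` test starts from a fully dynamic leading axis. A usage sketch of the shape rule involved, assuming the `num_segments` keyword accepted by `keras_core.operations.math.segment_sum` (the value 5 is made up):

```python
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.operations import math as kmath

data = KerasTensor((None, 4), dtype="float32")
segment_ids = KerasTensor((None,), dtype="int32")
# With num_segments given, the dynamic leading axis is replaced by it.
out = kmath.segment_sum(data, segment_ids, num_segments=5)
print(out.shape)  # expected: (5, 4)
```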
@@ -1,17 +1,11 @@
import numpy as np
import pytest
import tensorflow as tf
from keras_core import backend
from keras_core import testing
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.operations import nn as knn
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class NNOpsDynamicShapeTest(testing.TestCase):
def test_relu(self):
x = KerasTensor([None, 2, 3])

@@ -11,10 +11,6 @@ from keras_core.operations import numpy as knp
np_config.enable_numpy_behavior()
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):
def test_add(self):
x = KerasTensor((None, 3))
@@ -623,10 +619,6 @@ class NumpyTwoInputOpsStaticShapeTest(testing.TestCase):
self.assertEqual(knp.where(condition, x, y).shape, (2, 3))
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
def test_mean(self):
x = KerasTensor([None, 3])