Allow dynamic shapes for JAX.
This commit is contained in:
parent
524bcfcd57
commit
2bab1b1923
@ -11,7 +11,6 @@ class KerasTensor:
|
||||
|
||||
shape = backend.standardize_shape(
|
||||
shape,
|
||||
allow_dynamic_batch_size=backend.DYNAMIC_BATCH_SIZE_OK,
|
||||
allow_all_dynamic=backend.DYNAMIC_SHAPES_OK,
|
||||
)
|
||||
self.shape = shape
|
||||
|
@ -4,7 +4,6 @@ from keras_core.backend.jax import math
|
||||
from keras_core.backend.jax import nn
|
||||
from keras_core.backend.jax import numpy
|
||||
from keras_core.backend.jax import random
|
||||
from keras_core.backend.jax.core import DYNAMIC_BATCH_SIZE_OK
|
||||
from keras_core.backend.jax.core import DYNAMIC_SHAPES_OK
|
||||
from keras_core.backend.jax.core import Variable
|
||||
from keras_core.backend.jax.core import cast
|
||||
|
@ -7,8 +7,7 @@ from keras_core.backend.common import standardize_dtype
|
||||
from keras_core.backend.common.keras_tensor import KerasTensor
|
||||
from keras_core.backend.common.stateless_scope import StatelessScope
|
||||
|
||||
DYNAMIC_SHAPES_OK = False # Dynamic shapes NG
|
||||
DYNAMIC_BATCH_SIZE_OK = True
|
||||
DYNAMIC_SHAPES_OK = True
|
||||
|
||||
|
||||
class Variable(KerasVariable):
|
||||
@ -59,46 +58,105 @@ def name_scope(name):
|
||||
# Shape / dtype inference util
|
||||
def compute_output_spec(fn, *args, **kwargs):
|
||||
with StatelessScope():
|
||||
dynamic_batch_map = {}
|
||||
magic_number = 3
|
||||
|
||||
def convert_keras_tensor_to_jax(x):
|
||||
if isinstance(x, KerasTensor):
|
||||
shape = x.shape
|
||||
if shape and x.shape[0] is None:
|
||||
shape = list(shape)
|
||||
shape[0] = magic_number
|
||||
dynamic_batch = True
|
||||
else:
|
||||
dynamic_batch = False
|
||||
|
||||
jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
|
||||
dynamic_batch_map[jax_tensor] = dynamic_batch
|
||||
return jax_tensor
|
||||
return x
|
||||
|
||||
all_input_ktensors = []
|
||||
built_in_types = (type(None), int, float, str, bool, complex, bytes)
|
||||
|
||||
# First, separate symbolic args from other args
|
||||
static_args = []
|
||||
maybe_symbolic_args = []
|
||||
static_kwargs = {}
|
||||
maybe_symbolic_kwargs = {}
|
||||
for arg in args:
|
||||
if isinstance(arg, built_in_types):
|
||||
static_args.append(arg)
|
||||
else:
|
||||
maybe_symbolic_args.append(arg)
|
||||
static_kwargs = {}
|
||||
maybe_symbolic_kwargs = {}
|
||||
for (
|
||||
k,
|
||||
arg,
|
||||
) in kwargs.items():
|
||||
if isinstance(arg, built_in_types):
|
||||
static_kwargs[k] = arg
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, built_in_types):
|
||||
static_kwargs[k] = v
|
||||
else:
|
||||
maybe_symbolic_kwargs[k] = arg
|
||||
maybe_symbolic_kwargs[k] = v
|
||||
|
||||
# Second, identify all ktensors
|
||||
def index_all_ktensors(x):
|
||||
if isinstance(x, KerasTensor):
|
||||
all_input_ktensors.append(x)
|
||||
return x
|
||||
|
||||
# Third, find out if there are dynamic shapes
|
||||
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
|
||||
index_all_ktensors, (maybe_symbolic_args, maybe_symbolic_kwargs)
|
||||
)
|
||||
none_count = 0
|
||||
for x in all_input_ktensors:
|
||||
for d in x.shape:
|
||||
if d is None:
|
||||
none_count += 1
|
||||
|
||||
def convert_keras_tensor_to_jax(x, fill_value=None):
|
||||
if isinstance(x, KerasTensor):
|
||||
shape = list(x.shape)
|
||||
if fill_value:
|
||||
for i, e in enumerate(shape):
|
||||
if e is None:
|
||||
shape[i] = fill_value
|
||||
jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
|
||||
return jax_tensor
|
||||
return x
|
||||
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
return fn(*args, *static_args, **kwargs, **static_kwargs)
|
||||
|
||||
jax_out = None
|
||||
if none_count:
|
||||
try:
|
||||
ms_args_1, ms_kwargs_1 = nest.map_structure(
|
||||
lambda x: convert_keras_tensor_to_jax(x, fill_value=83),
|
||||
(maybe_symbolic_args, maybe_symbolic_kwargs),
|
||||
)
|
||||
_, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
|
||||
*ms_args_1, **ms_kwargs_1
|
||||
)
|
||||
|
||||
ms_args_2, ms_kwargs_2 = nest.map_structure(
|
||||
lambda x: convert_keras_tensor_to_jax(x, fill_value=89),
|
||||
(maybe_symbolic_args, maybe_symbolic_kwargs),
|
||||
)
|
||||
_, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
|
||||
*ms_args_2, **ms_kwargs_2
|
||||
)
|
||||
|
||||
flat_out_1 = nest.flatten(jax_out_1)
|
||||
flat_out_2 = nest.flatten(jax_out_2)
|
||||
|
||||
flat_out = []
|
||||
for x1, x2 in zip(flat_out_1, flat_out_2):
|
||||
if isinstance(x1, jax.ShapeDtypeStruct):
|
||||
if not isinstance(x2, jax.ShapeDtypeStruct):
|
||||
raise ValueError("Indeterministic output ordering.")
|
||||
shape = list(x1.shape)
|
||||
for i, e in enumerate(x2.shape):
|
||||
if e != shape[i]:
|
||||
shape[i] = None
|
||||
flat_out.append(
|
||||
jax.ShapeDtypeStruct(shape, dtype=x1.dtype)
|
||||
)
|
||||
else:
|
||||
flat_out.append(x1)
|
||||
jax_out = nest.pack_sequence_as(jax_out_1, flat_out)
|
||||
except:
|
||||
# Errors can happen when the filled dimensions
|
||||
# are not compatible with the function
|
||||
# (or when the function contains a bug).
|
||||
# In such cases we don't want to confuse users
|
||||
# with random filled dimensions and the like,
|
||||
# so we rerun a pass on the dynamic shapes,
|
||||
# which will likely error out when JAX tries to
|
||||
# validate shapes as fully static.
|
||||
# The error message will be much easier to understand.
|
||||
pass
|
||||
|
||||
if jax_out is None:
|
||||
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
|
||||
convert_keras_tensor_to_jax,
|
||||
(maybe_symbolic_args, maybe_symbolic_kwargs),
|
||||
@ -109,17 +167,6 @@ def compute_output_spec(fn, *args, **kwargs):
|
||||
|
||||
def convert_jax_spec_to_keras_tensor(x):
|
||||
if isinstance(x, jax.ShapeDtypeStruct):
|
||||
if dynamic_batch_map.get(x, False):
|
||||
shape = list(x.shape)
|
||||
if shape[0] != magic_number:
|
||||
raise ValueError(
|
||||
f"Function {fn} appears to change the "
|
||||
"batch size of its input. This is not "
|
||||
"allowed when used in conjunction with "
|
||||
"dynamic batch sizes. Consider using "
|
||||
"a static batch size here."
|
||||
)
|
||||
shape[0] = None
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
return x
|
||||
|
||||
|
@ -4,7 +4,6 @@ from keras_core.backend.tensorflow import math
|
||||
from keras_core.backend.tensorflow import nn
|
||||
from keras_core.backend.tensorflow import numpy
|
||||
from keras_core.backend.tensorflow import random
|
||||
from keras_core.backend.tensorflow.core import DYNAMIC_BATCH_SIZE_OK
|
||||
from keras_core.backend.tensorflow.core import DYNAMIC_SHAPES_OK
|
||||
from keras_core.backend.tensorflow.core import Variable
|
||||
from keras_core.backend.tensorflow.core import cast
|
||||
|
@ -7,7 +7,6 @@ from keras_core.backend.common.stateless_scope import StatelessScope
|
||||
from keras_core.utils.naming import auto_name
|
||||
|
||||
DYNAMIC_SHAPES_OK = True
|
||||
DYNAMIC_BATCH_SIZE_OK = True
|
||||
|
||||
|
||||
class Variable(KerasVariable, tf.__internal__.types.Tensor):
|
||||
|
@ -4,7 +4,6 @@ from keras_core.backend.torch import math
|
||||
from keras_core.backend.torch import nn
|
||||
from keras_core.backend.torch import numpy
|
||||
from keras_core.backend.torch import random
|
||||
from keras_core.backend.torch.core import DYNAMIC_BATCH_SIZE_OK
|
||||
from keras_core.backend.torch.core import DYNAMIC_SHAPES_OK
|
||||
from keras_core.backend.torch.core import Variable
|
||||
from keras_core.backend.torch.core import cast
|
||||
|
@ -4,7 +4,6 @@ from keras_core.backend.common import KerasVariable
|
||||
from keras_core.backend.common import standardize_dtype
|
||||
|
||||
DYNAMIC_SHAPES_OK = True
|
||||
DYNAMIC_BATCH_SIZE_OK = True
|
||||
|
||||
|
||||
TORCH_DTYPES = {
|
||||
|
@ -30,8 +30,8 @@ class DropoutTest(testing.TestCase):
|
||||
self.assertAllClose(np.max(outputs), 2.0)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
backend.backend() == "jax",
|
||||
reason="JAX does not support dynamic shapes",
|
||||
)
|
||||
def test_dropout_partial_noise_shape_dynamic(self):
|
||||
inputs = np.ones((20, 5, 10))
|
||||
|
@ -1,7 +1,5 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core.testing import test_case
|
||||
|
||||
@ -52,10 +50,6 @@ class SpatialDropoutTest(test_case.TestCase):
|
||||
input_shape=(2, 3, 4, 4, 5),
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
def test_spatial_dropout_1D_dynamic(self):
|
||||
inputs = layers.Input((3, 2))
|
||||
layer = layers.SpatialDropout1D(0.5)
|
||||
@ -67,10 +61,6 @@ class SpatialDropoutTest(test_case.TestCase):
|
||||
outputs = layer(inputs, training=True)
|
||||
self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
def test_spatial_dropout_2D_dynamic(self):
|
||||
inputs = layers.Input((3, 2, 4))
|
||||
layer = layers.SpatialDropout2D(0.5)
|
||||
@ -82,10 +72,6 @@ class SpatialDropoutTest(test_case.TestCase):
|
||||
outputs = layer(inputs, training=True)
|
||||
self.assertAllClose(outputs[:, 0, 0, :], outputs[:, 1, 1, :])
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
def test_spatial_dropout_3D_dynamic(self):
|
||||
inputs = layers.Input((3, 2, 4, 2))
|
||||
layer = layers.SpatialDropout3D(0.5)
|
||||
|
@ -48,10 +48,6 @@ class UpSamplingTest(testing.TestCase):
|
||||
layers.UpSampling1D(size=3)(np.ones((2, 1, 5))), np.ones((2, 3, 5))
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_BATCH_SIZE_OK,
|
||||
reason="Backend does not support dynamic batch sizes",
|
||||
)
|
||||
def test_upsampling_1d_with_dynamic_batch_size(self):
|
||||
x = KerasTensor([None, 2, 3])
|
||||
self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3))
|
||||
|
@ -1,19 +1,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import testing
|
||||
from keras_core.operations import core
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class CoreOpsDynamicShapeTest(testing.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class CoreOpsStaticShapeTest(testing.TestCase):
|
||||
def test_scatter(self):
|
||||
# Requires dtype
|
||||
|
@ -1,7 +1,5 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import testing
|
||||
from keras_core.backend.common import keras_tensor
|
||||
from keras_core.operations import function
|
||||
@ -42,10 +40,6 @@ class FunctionTest(testing.TestCase):
|
||||
self.assertAllClose(y_val[0], np.ones((2, 3)) * 6)
|
||||
self.assertAllClose(y_val[1], np.ones((2, 3)) * 4)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_BATCH_SIZE_OK,
|
||||
reason="Test only valid if dynamic batch sizes are supported",
|
||||
)
|
||||
def test_dynamic_shape_inference(self):
|
||||
x = keras_tensor.KerasTensor((None, 3))
|
||||
y = x**2
|
||||
|
@ -1,18 +1,12 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import testing
|
||||
from keras_core.backend.common.keras_tensor import KerasTensor
|
||||
from keras_core.operations import image as kimage
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class ImageOpsDynamicShapeTest(testing.TestCase):
|
||||
def test_resize(self):
|
||||
x = KerasTensor([None, 20, 20, 3])
|
||||
|
@ -8,10 +8,6 @@ from keras_core.backend.common.keras_tensor import KerasTensor
|
||||
from keras_core.operations import math as kmath
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class MathOpsDynamicShapeTest(testing.TestCase):
|
||||
def test_segment_sum(self):
|
||||
data = KerasTensor((None, 4), dtype="float32")
|
||||
|
@ -1,17 +1,11 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import testing
|
||||
from keras_core.backend.common.keras_tensor import KerasTensor
|
||||
from keras_core.operations import nn as knn
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class NNOpsDynamicShapeTest(testing.TestCase):
|
||||
def test_relu(self):
|
||||
x = KerasTensor([None, 2, 3])
|
||||
|
@ -11,10 +11,6 @@ from keras_core.operations import numpy as knp
|
||||
np_config.enable_numpy_behavior()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):
|
||||
def test_add(self):
|
||||
x = KerasTensor((None, 3))
|
||||
@ -623,10 +619,6 @@ class NumpyTwoInputOpsStaticShapeTest(testing.TestCase):
|
||||
self.assertEqual(knp.where(condition, x, y).shape, (2, 3))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not backend.DYNAMIC_SHAPES_OK,
|
||||
reason="Backend does not support dynamic shapes",
|
||||
)
|
||||
class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
|
||||
def test_mean(self):
|
||||
x = KerasTensor([None, 3])
|
||||
|
Loading…
Reference in New Issue
Block a user