Split the jax distribution backend test into a separate file. (#18587)
* Split the jax distribution backend test into a separate file.
* Update test script.
* Remove unused imports.

parent 6bafb3b099
commit 3d92f71e0d
keras/backend/jax/distribution_lib_test.py (new file, 270 lines added)

@@ -0,0 +1,270 @@
"""Test for distribution_lib.py."""

import functools
import os

import jax
import numpy as np
import pytest

from keras import backend
from keras import layers
from keras import models
from keras import testing
from keras.backend import distribution_lib as backend_dlib
from keras.distribution import distribution_lib

if backend.backend() == "jax":
    # Due to https://github.com/google/jax/issues/17188, we can't
    # override the XLA flag after the JAX backend init. We have to
    # run this at top level to let JAX pick the flag value.
    xla_flags = os.getenv("XLA_FLAGS") or ""
    # Don't override user-specified device count, or other XLA flags.
    if "xla_force_host_platform_device_count" not in xla_flags:
        os.environ["XLA_FLAGS"] = (
            xla_flags + " --xla_force_host_platform_device_count=8"
        )

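The flag above only takes effect if it is set before JAX initializes its backends, which is why the test applies it at module import time. A minimal standalone sketch of the same workaround (illustrative only, not part of the commit):

import os

# Must be set before anything triggers JAX backend initialization
# (see https://github.com/google/jax/issues/17188).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
)

import jax  # noqa: E402

# On a CPU-only host this should now report 8 virtual CPU devices.
print(len(jax.devices("cpu")))
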
@pytest.mark.skipif(
    backend.backend() != "jax",
    reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
    def test_list_devices(self):
        self.assertEqual(len(distribution_lib.list_devices()), 8)
        self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
        self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)

    def test_distribute_tensor(self):
        jax_mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(2, 4), ("batch", "model")
        )

        inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
        target_layout = jax.sharding.NamedSharding(
            jax_mesh, jax.sharding.PartitionSpec("batch", None)
        )

        @functools.partial(jax.jit, static_argnames="target_layout")
        def test_function(inputs, target_layout):
            return distribution_lib.distribute_tensor(inputs, target_layout)

        result = test_function(inputs, target_layout)
        # Note that the returned tensor has a different sharding
        # implementation (GSPMDSharding), but it should be equivalent to
        # the specified target layout.
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

        # Test without jit
        result = distribution_lib.distribute_tensor(inputs, target_layout)
        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

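For context, on the JAX backend `distribution_lib.distribute_tensor` presumably reduces to placing the array with the requested `NamedSharding`. A rough pure-JAX sketch of the same layout, assuming the 2x4 mesh used above (illustrative, not the library's internal code):

import jax
import numpy as np

devices = np.array(jax.devices()).reshape(2, 4)
mesh = jax.sharding.Mesh(devices, ("batch", "model"))
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch", None)
)

x = jax.device_put(np.random.normal(size=(16, 8)), sharding)
# PartitionSpec("batch", None) splits only the first axis across the
# 2 "batch" rows, so each device holds an (8, 8) slice of the (16, 8) array.
print(x.addressable_shards[0].data.shape)  # -> (8, 8)
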
    def test_to_jax_mesh(self):
        devices = [f"cpu:{i}" for i in range(8)]
        shape = (4, 2)
        axis_names = ["batch", "model"]

        mesh = distribution_lib.DeviceMesh(shape, axis_names, devices)
        jax_mesh = backend_dlib._to_jax_mesh(mesh)

        self.assertIsInstance(jax_mesh, jax.sharding.Mesh)
        self.assertEqual(jax_mesh.devices.shape, shape)
        self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

    def test_to_jax_layout(self):
        axes = ["data", None]
        mesh = distribution_lib.DeviceMesh(
            (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
        )
        layout = distribution_lib.TensorLayout(axes, mesh)
        jax_sharding = backend_dlib._to_jax_layout(layout)
        jax_mesh = backend_dlib._to_jax_mesh(mesh)
        self.assertEqual(
            jax_sharding,
            jax.sharding.NamedSharding(
                jax_mesh, jax.sharding.PartitionSpec("data", None)
            ),
        )

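`_to_jax_mesh` and `_to_jax_layout` are the backend's converters from the backend-agnostic `DeviceMesh` / `TensorLayout` to JAX sharding objects. A simplified, hypothetical version of what such a conversion can look like (attribute names are assumed; the real implementation, presumably in keras/backend/jax/distribution_lib.py, also resolves device strings like "cpu:0" to `jax.Device` objects and may differ in other ways):

import jax
import numpy as np


def to_jax_mesh_sketch(device_mesh):
    # Hypothetical helper: reshape the flat device list into the mesh
    # shape and name the axes. Assumes device_mesh.devices already holds
    # jax.Device objects.
    devices = np.array(device_mesh.devices).reshape(device_mesh.shape)
    return jax.sharding.Mesh(devices, tuple(device_mesh.axis_names))


def to_jax_layout_sketch(tensor_layout):
    # Hypothetical helper: a TensorLayout's axes become a PartitionSpec
    # over the converted mesh.
    if tensor_layout.device_mesh is None:
        raise ValueError("Cannot create sharding when device mesh is not set")
    jax_mesh = to_jax_mesh_sketch(tensor_layout.device_mesh)
    return jax.sharding.NamedSharding(
        jax_mesh, jax.sharding.PartitionSpec(*tensor_layout.axes)
    )
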
    def test_validation_for_device_mesh(self):
        axes = ["data", None]
        layout = distribution_lib.TensorLayout(axes, device_mesh=None)

        with self.assertRaisesRegex(
            ValueError, "Cannot create sharding when device mesh is not set"
        ):
            backend_dlib._to_jax_layout(layout)

    def test_variable_assignment_reuse_layout(self):
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
            shape, axis_names, backend_dlib.list_devices()
        )
        layout_map = distribution_lib.LayoutMap(device_mesh)
        layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
            [None, "model"]
        )
        layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])

        distribution = distribution_lib.ModelParallel(
            device_mesh, layout_map, batch_dim_name="batch"
        )

        with distribution.scope():
            dense_layer = layers.Dense(8)
            dense_layer.build((16, 16))

        self.assertEqual(
            dense_layer.kernel._value.sharding.spec, (None, "model")
        )
        self.assertEqual(dense_layer.bias._value.sharding.spec, ("model",))

        # Assign a numpy value to the dense layer to mimic model weight
        # loading.
        new_kernel = np.random.normal(size=(16, 8))
        new_bias = np.random.normal(size=(8))
        dense_layer.kernel.assign(new_kernel)
        dense_layer.bias.assign(new_bias)

        # Make sure the loaded values still use the layout they were
        # initialized with, even outside of the distribution scope.
        self.assertEqual(
            dense_layer.kernel._value.sharding.spec, (None, "model")
        )
        self.assertEqual(dense_layer.bias._value.sharding.spec, ("model",))

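The `LayoutMap` keys above are regular expressions matched against variable paths, so `.*dense.*kernel` picks up the kernel of any layer whose path contains "dense", and unmatched variables fall back to a replicated layout. A small sketch of that matching idea (the example paths are assumed for illustration; this loop is not the library's actual resolution code):

import re

# Variable paths as they might appear on the layers above (assumed).
paths = ["dense/kernel", "dense/bias", "dropout/seed_generator/state"]

rules = {
    ".*dense.*kernel": (None, "model"),
    ".*dense.*bias": ("model",),
}

for path in paths:
    layout = next(
        (spec for pattern, spec in rules.items() if re.search(pattern, path)),
        None,  # no match -> keep the variable fully replicated
    )
    print(path, "->", layout)
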
    def test_e2e_data_parallel_model(self):
        distribution = distribution_lib.DataParallel(
            devices=backend_dlib.list_devices()
        )

        with distribution.scope():
            inputs = layers.Input(shape=[28, 28, 1])
            y = layers.Flatten()(inputs)
            y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
            y = layers.Dropout(0.4)(y)
            y = layers.Dense(units=10, activation="softmax")(y)
            model = models.Model(inputs=inputs, outputs=y)

        # Make sure all the weights are properly sharded.
        for weight in model.weights:
            self.assertTrue(weight._value.sharding.is_fully_replicated)

        inputs = np.random.normal(size=(32, 28, 28, 1))
        labels = np.random.normal(size=(32, 10))

        with distribution.scope():
            model.compile(loss="mse")
            model.fit(inputs, labels)

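Under `DataParallel` the weights stay fully replicated (hence the `is_fully_replicated` check) while the input batch is split along the "batch" mesh axis. A rough JAX-level sketch of that input layout on the 8 simulated CPU devices (mesh and axis names here are assumptions for illustration, not taken from the Keras implementation):

import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ("batch",))
batch_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch")
)

x = jax.device_put(np.zeros((32, 28, 28, 1)), batch_sharding)
# 32 samples over 8 devices -> 4 samples per device.
print(x.addressable_shards[0].data.shape)  # -> (4, 28, 28, 1)
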
    def test_e2e_model_parallel_model(self):
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
            shape, axis_names, backend_dlib.list_devices()
        )

        layout_map = distribution_lib.LayoutMap(device_mesh)
        layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
            [None, "model"]
        )
        layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])

        distribution = distribution_lib.ModelParallel(
            device_mesh, layout_map, batch_dim_name="batch"
        )
        with distribution.scope():
            inputs = layers.Input(shape=[28, 28, 1])
            y = layers.Flatten()(inputs)
            y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
            y = layers.Dropout(0.4)(y)
            y = layers.Dense(units=10, activation="softmax")(y)
            model = models.Model(inputs=inputs, outputs=y)

        for weight in model.weights:
            if "kernel" in weight.name:
                self.assertEqual(weight._value.sharding.spec, (None, "model"))
            elif "bias" in weight.name:
                self.assertEqual(weight._value.sharding.spec, ("model",))
            else:
                self.assertTrue(weight._value.sharding.is_fully_replicated)

        inputs = np.random.normal(size=(32, 28, 28, 1))
        labels = np.random.normal(size=(32, 10))

        with distribution.scope():
            model.compile(loss="mse")
            model.fit(inputs, labels)

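With the (4, 2) mesh, the "model" axis spans 2 devices, so a `(None, "model")` kernel layout halves the last dimension per device and a `("model",)` bias layout halves the bias vector. Rough expected shard shapes for the model above (shapes follow from the layer definitions; the per-device arithmetic is an illustrative sketch):

model_parallel_size = 2            # "model" axis of the (4, 2) mesh
flat_features = 28 * 28 * 1        # output of Flatten -> 784

dense_1_kernel = (flat_features, 200)   # Dense(200), use_bias=False
dense_2_kernel = (200, 10)              # Dense(10)
dense_2_bias = (10,)

# (None, "model") -> only the last axis is split across 2 devices.
print((dense_1_kernel[0], dense_1_kernel[1] // model_parallel_size))  # (784, 100)
print((dense_2_kernel[0], dense_2_kernel[1] // model_parallel_size))  # (200, 5)
# ("model",) -> the bias vector is split across 2 devices.
print((dense_2_bias[0] // model_parallel_size,))  # (5,)
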
    def test_e2e_model_parallel_with_output_sharding(self):
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
            shape, axis_names, backend_dlib.list_devices()
        )

        layout_map = distribution_lib.LayoutMap(device_mesh)
        layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
            [None, "model"]
        )
        layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])
        # Force the dense layer output to be batch parallel only, and not
        # sharded on the model dimension.
        layout_map[".*dense.*output"] = ("batch", None)

        distribution = distribution_lib.ModelParallel(
            device_mesh, layout_map, batch_dim_name="batch"
        )
        sharding_capture = ShardingCaptureLayer()
        with distribution.scope():
            inputs = layers.Input(shape=[28, 28, 1])
            y = layers.Flatten()(inputs)
            y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
            y = sharding_capture(y)
            y = layers.Dropout(0.4)(y)
            y = layers.Dense(units=10, activation="softmax")(y)
            model = models.Model(inputs=inputs, outputs=y)

        for weight in model.weights:
            if "kernel" in weight.name:
                self.assertEqual(weight._value.sharding.spec, (None, "model"))
            elif "bias" in weight.name:
                self.assertEqual(weight._value.sharding.spec, ("model",))
            else:
                self.assertTrue(weight._value.sharding.is_fully_replicated)

        inputs = np.random.normal(size=(32, 28, 28, 1))
        labels = np.random.normal(size=(32, 10))

        with distribution.scope():
            model.compile(loss="mse")
            model.fit(inputs, labels)

        # Note that the intermediate tensor layout is only captured during
        # actual training, not at model building time.
        intermediate_tensor_layout = jax.sharding.NamedSharding(
            backend_dlib._to_jax_mesh(distribution.device_mesh),
            jax.sharding.PartitionSpec("batch", None),
        )
        self.assertTrue(
            sharding_capture.captured_input_sharding.is_equivalent_to(
                intermediate_tensor_layout, ndim=2
            )
        )


class ShardingCaptureLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.captured_input_sharding = None
        self.supports_masking = True

    def call(self, inputs):
        jax.debug.inspect_array_sharding(
            inputs, callback=lambda x: self.capture_input_sharding(x)
        )
        return inputs

    def capture_input_sharding(self, sharding):
        self.captured_input_sharding = sharding
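`ShardingCaptureLayer` relies on `jax.debug.inspect_array_sharding`, which registers a callback that only fires with the concrete sharding when the surrounding computation is actually compiled and run; that is why the capture happens during `fit()` rather than at model build time. A minimal standalone sketch of the same mechanism:

import jax
import jax.numpy as jnp

captured = []


@jax.jit
def f(x):
    # The callback receives the sharding chosen for `x` when the jitted
    # function runs.
    jax.debug.inspect_array_sharding(x, callback=captured.append)
    return x * 2


f(jnp.ones((8, 8)))
print(captured[0])  # a jax.sharding.Sharding instance
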
keras/distribution/distribution_lib_test.py

@@ -1,31 +1,13 @@
"""Test for distribution_lib.py."""

import functools
import os
from unittest import mock

import jax
import numpy as np
import pytest

from keras import backend
from keras import layers
from keras import models
from keras import testing
from keras.backend import distribution_lib as backend_dlib
from keras.distribution import distribution_lib

if backend.backend() == "jax":
    # Due to https://github.com/google/jax/issues/17188, we can't
    # override the XLA flag after the JAX backend init. We have to
    # run this at top level to let JAX pick the flag value.
    xla_flags = os.getenv("XLA_FLAGS") or ""
    # Don't override user-specified device count, or other XLA flags.
    if "xla_force_host_platform_device_count" not in xla_flags:
        os.environ["XLA_FLAGS"] = (
            xla_flags + " --xla_force_host_platform_device_count=8"
        )


class DeviceMeshTest(testing.TestCase):
    def test_mesh_creation(self):
@@ -381,250 +363,6 @@ class LayoutMapTest(testing.TestCase):
        self.assertEqual(values, [self.sharded_2d, self.sharded_1d])


# @pytest.mark.skipif(
#     backend.backend() != "tensorflow",
#     reason="Backend specific test",
@@ -49,12 +49,12 @@ then
    # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
    # TODO: FAILED keras/layers/preprocessing/feature_space_test.py::FeatureSpaceTest::test_saving
    # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
    # TODO: FAILED keras/distribution/distribution_lib_test.py
    # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
    pytest keras --ignore keras/applications \
        --ignore keras/layers/merging/merging_test.py \
        --ignore keras/layers/preprocessing/feature_space_test.py \
        --ignore keras/trainers/data_adapters/py_dataset_adapter_test.py \
        --ignore keras/distribution/distribution_lib_test.py \
        --ignore keras/backend/jax/distribution_lib_test.py \
        --cov=keras
fi
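To run just the new backend test module locally, outside the CI script above, something along these lines should work; the `KERAS_BACKEND` variable and the CPU-only assumption are mine, not part of the commit:

import os

# Select the JAX backend before keras is imported by the tests.
os.environ["KERAS_BACKEND"] = "jax"

import pytest  # noqa: E402

raise SystemExit(
    pytest.main(["keras/backend/jax/distribution_lib_test.py", "-v"])
)
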