Add sharding constraint for the output tensor in the model (#18536)

* WIP for sharding the intermediate tensor.

* Add unit test for sharding constraint

* Remove unused variable.

* Address review comments.

* Address review comments.

* Address review comments.
Qianli Scott Zhu 2023-10-06 14:42:38 -07:00 committed by GitHub
parent b811a37498
commit c57e454f20
6 changed files with 221 additions and 8 deletions

@@ -35,7 +35,7 @@ class Variable(KerasVariable):
def _direct_assign(self, value):
if getattr(self, "_layout", None) is not None:
- value = distribution_lib.distribute_value(value, self._layout)
+ value = distribution_lib.distribute_variable(value, self._layout)
self._value = value
def _convert_to_tensor(self, value, dtype=None):

@@ -8,6 +8,8 @@ with other backends in the future.
import jax
import numpy as np
from keras.utils import jax_utils
def list_devices(device_type=None):
"""Return all the available devices based on the device type.
@@ -27,20 +29,50 @@ def list_devices(device_type=None):
return [f"{device.device_kind}:{device.id}" for device in jax_devices]
- def distribute_value(value, tensor_layout):
- """Distribute the value based on the layout.
+ def distribute_variable(value, layout):
+ """Create a distributed variable for JAX.
Since JAX doesn't have a variable class, this will just return a `jax.Array`
with the corresponding layout/sharding specified.
Note that this function should be used in an eager context, not within a
jitted function.
Args:
value: the initial value of the variable.
layout: `TensorLayout` for the created variable, or a
`jax.sharding.Sharding` instance.
Returns:
A `jax.Array` which is the distributed variable.
"""
if not isinstance(layout, jax.sharding.Sharding):
layout = _to_jax_layout(layout)
return jax.device_put(value, layout)
def distribute_tensor(tensor, layout):
"""Distribute the tensor based on the layout.
Note that this function can be used either in an eager context or within a
jitted function.
Args:
tensor: `jax.Array` that needs to be distributed.
- tensor_layout: `TensorLayout` for the distribution information, or a
+ layout: `TensorLayout` for the distribution information, or a
`jax.sharding.Sharding` instance.
Returns:
Distributed value.
"""
- if not isinstance(tensor_layout, jax.sharding.Sharding):
- tensor_layout = _to_jax_layout(tensor_layout)
- return jax.device_put(value, tensor_layout)
+ if not isinstance(layout, jax.sharding.Sharding):
+ layout = _to_jax_layout(layout)
# TODO(scottzhu): This might not be a cheap check; we should consider
# adding a proper JAX API for doing this check.
if jax_utils.is_in_jax_tracing_scope():
return jax.lax.with_sharding_constraint(tensor, layout)
return jax.device_put(tensor, layout)
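As a minimal standalone sketch (not part of the patch), the two code paths of `distribute_tensor` above look like this in plain JAX; the mesh axis name `"batch"` and the shapes are illustrative, and the device count is whatever JAX reports locally:

```python
import jax
import numpy as np

# Illustrative 1-D mesh over all local devices.
devices = np.array(jax.devices())
mesh = jax.sharding.Mesh(devices, ("batch",))
layout = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("batch"))

x = jax.numpy.ones((len(devices) * 2, 4))

# Eager path: what `distribute_tensor` (and `distribute_variable`) reduce to
# outside of a tracing scope.
x_eager = jax.device_put(x, layout)

# Traced path: device placement is not allowed under `jax.jit`, so the
# function emits a sharding constraint for the compiler instead.
@jax.jit
def relayout(t):
    return jax.lax.with_sharding_constraint(t, layout)

x_traced = relayout(x)
```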
def _to_jax_device(device_id):

@@ -846,7 +846,7 @@ class JAXTrainer(base_trainer.Trainer):
def distribute_single_value(d):
layout = distribution.get_data_layout(d.shape)
- return jax_distribution_lib.distribute_value(d, layout)
+ return jax_distribution_lib.distribute_tensor(d, layout)
return jax.tree_util.tree_map(distribute_single_value, data)
else:

@@ -14,6 +14,7 @@ import warnings
import numpy as np
from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend import distribution_lib
from keras.backend.common import global_state
@@ -170,6 +171,7 @@ class Distribution:
1. Distribute the model variables to a `DeviceMesh`.
2. Distribute the input data to a `DeviceMesh`.
3. Distribute an intermediate state tensor in the model.
It can create a context scope so that the framework can properly detect the
`Distribution` and distribute variables/data accordingly.
@@ -205,6 +207,19 @@ class Distribution:
"""
raise NotImplementedError()
def get_tensor_layout(self, path):
"""Retrieve the `TensorLayout` for the intermediate tensor.
Args:
path: a string path for the corresponding tensor.
Returns:
The `TensorLayout` for the intermediate tensor, which can be used by
`backend.relayout()` to reshard the tensor. Can also return None.
"""
raise NotImplementedError()
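For illustration only, a hypothetical `Distribution` subclass (not part of this patch) that satisfies the new contract with an exact-path dictionary lookup; it assumes the abstract `get_data_layout`/`get_variable_layout` methods defined earlier in this class and uses the internal `distribution_lib` module, mirroring the test imports:

```python
from keras.distribution import distribution_lib

class FixedPathDistribution(distribution_lib.Distribution):
    """Hypothetical distribution that reshards only explicitly listed paths."""

    def __init__(self, device_mesh, path_to_layout):
        super().__init__(device_mesh)
        self._path_to_layout = dict(path_to_layout)

    def get_data_layout(self, data_shape):
        # Fully replicated data, for simplicity.
        return distribution_lib.TensorLayout(
            [None] * len(data_shape), self.device_mesh
        )

    def get_variable_layout(self, variable):
        # Fully replicated variables, for simplicity.
        return distribution_lib.TensorLayout(
            [None] * len(variable.shape), self.device_mesh
        )

    def get_tensor_layout(self, path):
        # None means "leave the tensor's layout untouched".
        return self._path_to_layout.get(path)
```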
@contextlib.contextmanager
def scope(self):
"""Context manager to make the `Distribution` current."""
@@ -296,6 +311,10 @@ class DataParallel(Distribution):
variable_shard_spec = [None] * len(variable.shape)
return TensorLayout(variable_shard_spec, self.device_mesh)
def get_tensor_layout(self, path):
# For data parallel training, the intermediate state is not changed.
return None
@keras_export("keras.distribution.ModelParallel")
class ModelParallel(Distribution):
@@ -393,6 +412,9 @@ class ModelParallel(Distribution):
variable_shard_spec = [None] * len(variable.shape)
return TensorLayout(variable_shard_spec, self.device_mesh)
def get_tensor_layout(self, path):
return self._layout_map[path]
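`ModelParallel.get_tensor_layout` simply delegates to the `LayoutMap`, whose keys are matched against the tensor path as patterns. A usage sketch (assuming 8 available devices arranged as a 2x4 mesh; the path strings and axis names are illustrative):

```python
from keras.distribution import distribution_lib

devices = distribution_lib.list_devices()
mesh = distribution_lib.DeviceMesh((2, 4), ["data", "model"], devices)

layout_map = distribution_lib.LayoutMap(mesh)
# One regex rule covers the output of every layer whose path contains "dense".
layout_map[".*dense.*output"] = distribution_lib.TensorLayout(["data", None])

mp = distribution_lib.ModelParallel(mesh, layout_map, batch_dim_name="data")

layout = mp.get_tensor_layout("functional/dense_1/output")  # matches the rule
print(layout.axes)  # ("data", None)
print(mp.get_tensor_layout("functional/dense_1/kernel"))  # None: no rule matches
```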
@keras_export("keras.distribution.LayoutMap")
class LayoutMap(collections.abc.MutableMapping):
@@ -507,6 +529,28 @@ class LayoutMap(collections.abc.MutableMapping):
LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__
@keras_export("keras.distribution.distribute_tensor")
def distribute_tensor(tensor, layout):
"""Change the layout of a Tensor value in the jit function execution.
Note that this might not work outside of the jitted function for certain
backend. To change the layout of a value eagerly, please use
`backend.distribution_lib.distribute_value`.
Args:
tensor: a Tensor whose layout will be changed.
layout: `TensorLayout` to be applied to the value.
Returns:
a new value with the specified tensor layout.
"""
if isinstance(tensor, KerasTensor):
# A KerasTensor is only used for building a functional model, and can't
# be used to alter layout/sharding.
return tensor
return distribution_lib.distribute_tensor(tensor, layout)
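The exported `keras.distribution.distribute_tensor` can also be called by hand, e.g. from a custom layer's `call()`, to reshard an intermediate tensor without going through a `LayoutMap` path rule. A sketch (JAX backend assumed; the `"batch"` axis name is illustrative and must exist in the active distribution's `DeviceMesh`):

```python
from keras import distribution, layers, ops

class BatchShardedActivation(layers.Layer):
    def call(self, inputs):
        outputs = ops.relu(inputs)
        dist = distribution.distribution()
        if dist is not None:
            layout = distribution.TensorLayout(
                ["batch", None], dist.device_mesh
            )
            # During functional-model construction `outputs` is a KerasTensor
            # and is returned unchanged (see the guard above); in a jitted
            # train step this lowers to a sharding constraint.
            outputs = distribution.distribute_tensor(outputs, layout)
        return outputs
```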
@keras_export("keras.distribution.distribution")
def distribution():
"""Retrieve the current distribution from global context."""

@ -1,5 +1,6 @@
"""Test for distribution_lib.py."""
import functools
import os
from unittest import mock
@@ -192,6 +193,15 @@ class DataParallelDistributionTest(testing.TestCase):
self.assertIs(variable_layout.device_mesh, self.device_mesh)
self.assertEqual(variable_layout.axes, (None,))
def test_get_tensor_layout(self):
distribution = distribution_lib.DataParallel(
device_mesh=self.device_mesh
)
path = "path/to/tensor"
tensor_layout = distribution.get_tensor_layout(path)
self.assertIsNone(tensor_layout)
class ModelParallelDistributionTest(testing.TestCase):
def setUp(self):
@@ -239,6 +249,22 @@ class ModelParallelDistributionTest(testing.TestCase):
self.assertIs(data_layout.device_mesh, self.device_mesh)
self.assertEqual(data_layout.axes, ("data", None, None))
def test_get_tensor_layout(self):
layout_map = distribution_lib.LayoutMap(self.device_mesh)
layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"])
layout_map[".*bias"] = distribution_lib.TensorLayout(["model"])
layout_map["/model/layer/tensor"] = ("data", None)
distribution = distribution_lib.ModelParallel(
self.device_mesh, layout_map, batch_dim_name="data"
)
layout = distribution.get_tensor_layout("/model/layer/tensor")
self.assertIs(layout.device_mesh, self.device_mesh)
self.assertEqual(layout.axes, ("data", None))
layout = distribution.get_tensor_layout("/model/layer/other_tensor")
self.assertIsNone(layout)
class LayoutMapTest(testing.TestCase):
def setUp(self):
@@ -362,6 +388,30 @@ class JaxDistributionLibTest(testing.TestCase):
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
def test_distribute_tensor(self):
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
target_layout = jax.sharding.NamedSharding(
jax_mesh, jax.sharding.PartitionSpec("batch", None)
)
@functools.partial(jax.jit, static_argnames="target_layout")
def test_function(inputs, target_layout):
return distribution_lib.distribute_tensor(inputs, target_layout)
result = test_function(inputs, target_layout)
# Note that the returned tensor has a different sharding implementation
# (GSPMDSharding), but it should be equivalent to the specified target
# layout.
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
# Test without jit
result = distribution_lib.distribute_tensor(inputs, target_layout)
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
def test_to_jax_mesh(self):
devices = [f"cpu:{i}" for i in range(8)]
shape = (4, 2)
@@ -499,6 +549,78 @@ class JaxDistributionLibTest(testing.TestCase):
model.compile(loss="mse")
model.fit(inputs, labels)
def test_e2e_model_parallel_with_output_sharding(self):
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
shape, axis_names, backend_dlib.list_devices()
)
layout_map = distribution_lib.LayoutMap(device_mesh)
layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
[None, "model"]
)
layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])
# Force the dense layer output to be batch parallel only, and not
# sharded on the model dimension.
layout_map[".*dense.*output"] = ("batch", None)
distribution = distribution_lib.ModelParallel(
device_mesh, layout_map, batch_dim_name="batch"
)
sharding_capture = ShardingCaptureLayer()
with distribution.scope():
inputs = layers.Input(shape=[28, 28, 1])
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = sharding_capture(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = models.Model(inputs=inputs, outputs=y)
for weight in model.weights:
if "kernel" in weight.name:
self.assertEqual(weight._value.sharding.spec, (None, "model"))
elif "bias" in weight.name:
self.assertEqual(weight._value.sharding.spec, ("model",))
else:
self.assertTrue(weight._value.sharding.is_fully_replicated)
inputs = np.random.normal(size=(32, 28, 28, 1))
labels = np.random.normal(size=(32, 10))
with distribution.scope():
model.compile(loss="mse")
model.fit(inputs, labels)
# Note that the intermediate_tensor_layout is only captured during
# actual training, not at model building time.
intermediate_tensor_layout = jax.sharding.NamedSharding(
backend_dlib._to_jax_mesh(distribution.device_mesh),
jax.sharding.PartitionSpec("batch", None),
)
self.assertTrue(
sharding_capture.captured_input_sharding.is_equivalent_to(
intermediate_tensor_layout, ndim=2
)
)
class ShardingCaptureLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.captured_input_sharding = None
self.supports_masking = True
def call(self, inputs):
jax.debug.inspect_array_sharding(
inputs, callback=lambda x: self.capture_input_sharding(x)
)
return inputs
def capture_input_sharding(self, sharding):
self.captured_input_sharding = sharding
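The `ShardingCaptureLayer` trick above can also be used standalone when debugging: `jax.debug.inspect_array_sharding` invokes a callback with the sharding the compiler picked for a value inside a jitted function. A small sketch (mesh axis name and shapes are illustrative):

```python
import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ("batch",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("batch"))

@jax.jit
def step(x):
    x = jax.lax.with_sharding_constraint(x, sharding)
    # Prints the sharding that the compiler assigned to `x`.
    jax.debug.inspect_array_sharding(x, callback=print)
    return x * 2.0

step(jax.numpy.ones((len(jax.devices()) * 4,)))
```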
# @pytest.mark.skipif(
# backend.backend() != "tensorflow",

@@ -30,6 +30,8 @@ from keras import utils
from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend.common import global_state
from keras.backend.common.name_scope import current_path
from keras.distribution import distribution_lib
from keras.layers import input_spec
from keras.metrics.metric import Metric
from keras.ops.operation import Operation
@@ -808,6 +810,19 @@ class Layer(BackendLayer, Operation):
outputs = super().__call__(*args, **kwargs)
else:
outputs = super().__call__(*args, **kwargs)
# Change the layout of the layer output if needed.
# This is useful for re-laying out intermediate tensors in the model
# to achieve optimal performance.
distribution = distribution_lib.distribution()
if distribution is not None:
current_layer_path = current_path()
current_layer_path += "/output"
layout = distribution.get_tensor_layout(current_layer_path)
if layout:
outputs = distribution_lib.distribute_tensor(
outputs, layout
)
if not self.built:
self.built = True
# Record activity regularizer loss.
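Putting the pieces together, a condensed user-facing recipe for the hook above (JAX backend with 8 devices assumed; the regexes, axis names, and layer sizes are illustrative). The `".*output"` rule is what the `current_path() + "/output"` lookup in `Layer.__call__` ends up matching:

```python
import numpy as np
from keras import distribution, layers, models

devices = distribution.list_devices()
mesh = distribution.DeviceMesh((2, 4), ["batch", "model"], devices)

layout_map = distribution.LayoutMap(mesh)
layout_map[".*dense.*kernel"] = distribution.TensorLayout([None, "model"])
# Keep dense outputs batch-parallel instead of sharded on the model axis.
layout_map[".*dense.*output"] = distribution.TensorLayout(["batch", None])

model_parallel = distribution.ModelParallel(
    mesh, layout_map, batch_dim_name="batch"
)
with model_parallel.scope():
    inputs = layers.Input(shape=(64,))
    y = layers.Dense(128, activation="relu")(inputs)
    y = layers.Dense(8)(y)
    model = models.Model(inputs, y)
    model.compile(loss="mse")
    model.fit(
        np.random.normal(size=(32, 64)), np.random.normal(size=(32, 8))
    )
```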