Add sharding constraint for the output tensor in the model (#18536)
* WIP for shard the intermediate tensor. * Add unit test for sharding constraint * Remove unused variable. * Address review comments. * Address review comments. * Address review comments.
This commit is contained in:
parent
b811a37498
commit
c57e454f20
@ -35,7 +35,7 @@ class Variable(KerasVariable):
|
||||
|
||||
def _direct_assign(self, value):
|
||||
if getattr(self, "_layout", None) is not None:
|
||||
value = distribution_lib.distribute_value(value, self._layout)
|
||||
value = distribution_lib.distribute_variable(value, self._layout)
|
||||
self._value = value
|
||||
|
||||
def _convert_to_tensor(self, value, dtype=None):
|
||||
|
@ -8,6 +8,8 @@ with other backends in the future.
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from keras.utils import jax_utils
|
||||
|
||||
|
||||
def list_devices(device_type=None):
|
||||
"""Return all the available devices based on the device type.
|
||||
@ -27,20 +29,50 @@ def list_devices(device_type=None):
|
||||
return [f"{device.device_kind}:{device.id}" for device in jax_devices]
|
||||
|
||||
|
||||
def distribute_value(value, tensor_layout):
|
||||
"""Distribute the value based on the layout.
|
||||
def distribute_variable(value, layout):
|
||||
"""Create a distributed variable for JAX.
|
||||
|
||||
Since JAX doesn't have a variable class, this will just return a `jax.Array`
|
||||
with the corresponding layout/sharding specified.
|
||||
|
||||
Note that this function should be used in eager context, not in jitted
|
||||
function.
|
||||
|
||||
Args:
|
||||
value: the initial value of the variable.
|
||||
layout: `TensorLayout` for the created variable, or a
|
||||
`jax.sharding.Sharding` instance.
|
||||
|
||||
Returns:
|
||||
jax.Array which is the distributed variable.
|
||||
"""
|
||||
if not isinstance(layout, jax.sharding.Sharding):
|
||||
layout = _to_jax_layout(layout)
|
||||
return jax.device_put(value, layout)
|
||||
|
||||
|
||||
def distribute_tensor(tensor, layout):
|
||||
"""Distribute the tensor based on the layout.
|
||||
|
||||
Note that this function can be used both in eager context, or within a
|
||||
jitted function.
|
||||
|
||||
Args:
|
||||
value: `jax.Array` that need to be distributed.
|
||||
tensor_layout: `TensorLayout` for the distribution information, or a
|
||||
layout: `TensorLayout` for the distribution information, or a
|
||||
`jax.sharding.Sharding` instance.
|
||||
|
||||
Returns:
|
||||
Distributed value.
|
||||
"""
|
||||
if not isinstance(tensor_layout, jax.sharding.Sharding):
|
||||
tensor_layout = _to_jax_layout(tensor_layout)
|
||||
return jax.device_put(value, tensor_layout)
|
||||
if not isinstance(layout, jax.sharding.Sharding):
|
||||
layout = _to_jax_layout(layout)
|
||||
|
||||
# TODO(scottzhu): This might not be a cheap check, we should consider
|
||||
# have some proper JAX API for doing this check.
|
||||
if jax_utils.is_in_jax_tracing_scope():
|
||||
return jax.lax.with_sharding_constraint(tensor, layout)
|
||||
return jax.device_put(tensor, layout)
|
||||
|
||||
|
||||
def _to_jax_device(device_id):
|
||||
|
@ -846,7 +846,7 @@ class JAXTrainer(base_trainer.Trainer):
|
||||
|
||||
def distribute_single_value(d):
|
||||
layout = distribution.get_data_layout(d.shape)
|
||||
return jax_distribution_lib.distribute_value(d, layout)
|
||||
return jax_distribution_lib.distribute_tensor(d, layout)
|
||||
|
||||
return jax.tree_util.tree_map(distribute_single_value, data)
|
||||
else:
|
||||
|
@ -14,6 +14,7 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from keras.api_export import keras_export
|
||||
from keras.backend import KerasTensor
|
||||
from keras.backend import distribution_lib
|
||||
from keras.backend.common import global_state
|
||||
|
||||
@ -170,6 +171,7 @@ class Distribution:
|
||||
|
||||
1. Distribute the model variables to a `DeviceMesh`.
|
||||
2. Distribute the input data to a `DeviceMesh`.
|
||||
3. Distribute an intermediate state tensor in the model.
|
||||
|
||||
It can create a context scope so that the framework to properly detect the
|
||||
`Distribution` and distribute the variable/data accordingly.
|
||||
@ -205,6 +207,19 @@ class Distribution:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_tensor_layout(self, path):
|
||||
"""Retrieve the `TensorLayout` for the intermediate tensor.
|
||||
|
||||
Args:
|
||||
path: a string path for the correspoding tensor.
|
||||
|
||||
return:
|
||||
The `TensorLayout` for the intermediate tensor, which can be used
|
||||
by `backend.relayout()` to reshard the tensor. Could also return
|
||||
None.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def scope(self):
|
||||
"""Context manager to make the `Distribution` current."""
|
||||
@ -296,6 +311,10 @@ class DataParallel(Distribution):
|
||||
variable_shard_spec = [None] * len(variable.shape)
|
||||
return TensorLayout(variable_shard_spec, self.device_mesh)
|
||||
|
||||
def get_tensor_layout(self, path):
|
||||
# For data parallel training, the intermediate state is not changed.
|
||||
return None
|
||||
|
||||
|
||||
@keras_export("keras.distribution.ModelParallel")
|
||||
class ModelParallel(Distribution):
|
||||
@ -393,6 +412,9 @@ class ModelParallel(Distribution):
|
||||
variable_shard_spec = [None] * len(variable.shape)
|
||||
return TensorLayout(variable_shard_spec, self.device_mesh)
|
||||
|
||||
def get_tensor_layout(self, path):
|
||||
return self._layout_map[path]
|
||||
|
||||
|
||||
@keras_export("keras.distribution.LayoutMap")
|
||||
class LayoutMap(collections.abc.MutableMapping):
|
||||
@ -507,6 +529,28 @@ class LayoutMap(collections.abc.MutableMapping):
|
||||
LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__
|
||||
|
||||
|
||||
@keras_export("keras.distribution.distribute_tensor")
|
||||
def distribute_tensor(tensor, layout):
|
||||
"""Change the layout of a Tensor value in the jit function execution.
|
||||
|
||||
Note that this might not work outside of the jitted function for certain
|
||||
backend. To change the layout of a value eagerly, please use
|
||||
`backend.distribution_lib.distribute_value`.
|
||||
|
||||
Args:
|
||||
tensor: a Tensor to change the layout.
|
||||
layout: `TensorLayout` to be applied on the value.
|
||||
|
||||
Returns:
|
||||
a new value with the specified tensor layout.
|
||||
"""
|
||||
if isinstance(tensor, KerasTensor):
|
||||
# keras tensor is only used for building functional model, and can't be
|
||||
# used to alter layout/sharding.
|
||||
return tensor
|
||||
return distribution_lib.distribute_tensor(tensor, layout)
|
||||
|
||||
|
||||
@keras_export("keras.distribution.distribution")
|
||||
def distribution():
|
||||
"""Retrieve the current distribution from global context."""
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test for distribution_lib.py."""
|
||||
|
||||
import functools
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
@ -192,6 +193,15 @@ class DataParallelDistributionTest(testing.TestCase):
|
||||
self.assertIs(variable_layout.device_mesh, self.device_mesh)
|
||||
self.assertEqual(variable_layout.axes, (None,))
|
||||
|
||||
def test_get_tensor_layout(self):
|
||||
distribution = distribution_lib.DataParallel(
|
||||
device_mesh=self.device_mesh
|
||||
)
|
||||
|
||||
path = "path/to/tensor"
|
||||
tensor_layout = distribution.get_tensor_layout(path)
|
||||
self.assertIsNone(tensor_layout)
|
||||
|
||||
|
||||
class ModelParallelDistributionTest(testing.TestCase):
|
||||
def setUp(self):
|
||||
@ -239,6 +249,22 @@ class ModelParallelDistributionTest(testing.TestCase):
|
||||
self.assertIs(data_layout.device_mesh, self.device_mesh)
|
||||
self.assertEqual(data_layout.axes, ("data", None, None))
|
||||
|
||||
def test_get_tensor_layout(self):
|
||||
layout_map = distribution_lib.LayoutMap(self.device_mesh)
|
||||
layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"])
|
||||
layout_map[".*bias"] = distribution_lib.TensorLayout(["model"])
|
||||
layout_map["/model/layer/tensor"] = ("data", None)
|
||||
|
||||
distribution = distribution_lib.ModelParallel(
|
||||
self.device_mesh, layout_map, batch_dim_name="data"
|
||||
)
|
||||
layout = distribution.get_tensor_layout("/model/layer/tensor")
|
||||
self.assertIs(layout.device_mesh, self.device_mesh)
|
||||
self.assertEqual(layout.axes, ("data", None))
|
||||
|
||||
layout = distribution.get_tensor_layout("/model/layer/other_tensor")
|
||||
self.assertIsNone(layout)
|
||||
|
||||
|
||||
class LayoutMapTest(testing.TestCase):
|
||||
def setUp(self):
|
||||
@ -362,6 +388,30 @@ class JaxDistributionLibTest(testing.TestCase):
|
||||
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
|
||||
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
|
||||
|
||||
def test_distribute_tensor(self):
|
||||
jax_mesh = jax.sharding.Mesh(
|
||||
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
|
||||
)
|
||||
|
||||
inputs = jax.numpy.array(np.random.normal(size=(16, 8)))
|
||||
target_layout = jax.sharding.NamedSharding(
|
||||
jax_mesh, jax.sharding.PartitionSpec("batch", None)
|
||||
)
|
||||
|
||||
@functools.partial(jax.jit, static_argnames="target_layout")
|
||||
def test_function(inputs, target_layout):
|
||||
return distribution_lib.distribute_tensor(inputs, target_layout)
|
||||
|
||||
result = test_function(inputs, target_layout)
|
||||
# Note that the returned tensor has a different sharding implementation
|
||||
# which is GSPMDSharding, but it should be equivalent as the target
|
||||
# layout specified.
|
||||
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
|
||||
|
||||
# Test without jit
|
||||
result = distribution_lib.distribute_tensor(inputs, target_layout)
|
||||
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))
|
||||
|
||||
def test_to_jax_mesh(self):
|
||||
devices = [f"cpu:{i}" for i in range(8)]
|
||||
shape = (4, 2)
|
||||
@ -499,6 +549,78 @@ class JaxDistributionLibTest(testing.TestCase):
|
||||
model.compile(loss="mse")
|
||||
model.fit(inputs, labels)
|
||||
|
||||
def test_e2e_model_parallel_with_output_sharding(self):
|
||||
shape = (4, 2)
|
||||
axis_names = ["batch", "model"]
|
||||
device_mesh = distribution_lib.DeviceMesh(
|
||||
shape, axis_names, backend_dlib.list_devices()
|
||||
)
|
||||
|
||||
layout_map = distribution_lib.LayoutMap(device_mesh)
|
||||
layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
|
||||
[None, "model"]
|
||||
)
|
||||
layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])
|
||||
# Force the dense layer output to be batch parallel only, and not
|
||||
# sharded on model dimension.
|
||||
layout_map[".*dense.*output"] = ("batch", None)
|
||||
|
||||
distribution = distribution_lib.ModelParallel(
|
||||
device_mesh, layout_map, batch_dim_name="batch"
|
||||
)
|
||||
sharding_capture = ShardingCaptureLayer()
|
||||
with distribution.scope():
|
||||
inputs = layers.Input(shape=[28, 28, 1])
|
||||
y = layers.Flatten()(inputs)
|
||||
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
|
||||
y = sharding_capture(y)
|
||||
y = layers.Dropout(0.4)(y)
|
||||
y = layers.Dense(units=10, activation="softmax")(y)
|
||||
model = models.Model(inputs=inputs, outputs=y)
|
||||
|
||||
for weight in model.weights:
|
||||
if "kernel" in weight.name:
|
||||
self.assertEqual(weight._value.sharding.spec, (None, "model"))
|
||||
elif "bias" in weight.name:
|
||||
self.assertEqual(weight._value.sharding.spec, ("model",))
|
||||
else:
|
||||
self.assertTrue(weight._value.sharding.is_fully_replicated)
|
||||
|
||||
inputs = np.random.normal(size=(32, 28, 28, 1))
|
||||
labels = np.random.normal(size=(32, 10))
|
||||
|
||||
with distribution.scope():
|
||||
model.compile(loss="mse")
|
||||
model.fit(inputs, labels)
|
||||
|
||||
# Note that the intermediate_tensor_layout is only captured during the
|
||||
# actual training, and not at the model building time.
|
||||
intermediate_tensor_layout = jax.sharding.NamedSharding(
|
||||
backend_dlib._to_jax_mesh(distribution.device_mesh),
|
||||
jax.sharding.PartitionSpec("batch", None),
|
||||
)
|
||||
self.assertTrue(
|
||||
sharding_capture.captured_input_sharding.is_equivalent_to(
|
||||
intermediate_tensor_layout, ndim=2
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ShardingCaptureLayer(layers.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.captured_input_sharding = None
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, inputs):
|
||||
jax.debug.inspect_array_sharding(
|
||||
inputs, callback=lambda x: self.capture_input_sharding(x)
|
||||
)
|
||||
return inputs
|
||||
|
||||
def capture_input_sharding(self, sharding):
|
||||
self.captured_input_sharding = sharding
|
||||
|
||||
|
||||
# @pytest.mark.skipif(
|
||||
# backend.backend() != "tensorflow",
|
||||
|
@ -30,6 +30,8 @@ from keras import utils
|
||||
from keras.api_export import keras_export
|
||||
from keras.backend import KerasTensor
|
||||
from keras.backend.common import global_state
|
||||
from keras.backend.common.name_scope import current_path
|
||||
from keras.distribution import distribution_lib
|
||||
from keras.layers import input_spec
|
||||
from keras.metrics.metric import Metric
|
||||
from keras.ops.operation import Operation
|
||||
@ -808,6 +810,19 @@ class Layer(BackendLayer, Operation):
|
||||
outputs = super().__call__(*args, **kwargs)
|
||||
else:
|
||||
outputs = super().__call__(*args, **kwargs)
|
||||
# Change the layout for the layer output if needed.
|
||||
# This is useful for relayout intermediate tensor in the model
|
||||
# to achieve the optimal performance.
|
||||
distribution = distribution_lib.distribution()
|
||||
if distribution is not None:
|
||||
current_layer_path = current_path()
|
||||
current_layer_path += "/output"
|
||||
layout = distribution.get_tensor_layout(current_layer_path)
|
||||
if layout:
|
||||
outputs = distribution_lib.distribute_tensor(
|
||||
outputs, layout
|
||||
)
|
||||
|
||||
if not self.built:
|
||||
self.built = True
|
||||
# Record activity regularizer loss.
|
||||
|
Loading…
Reference in New Issue
Block a user