Use Numpy arrays in data adapters for JAX backend. (#20007)

- Slicing and splitting are more efficient with NumPy arrays than with JAX arrays on CPU.
- Converting arrays to `jax.Array` places them on the first device by default, which is incorrect when distribution is used and slows it down.

Also:
- Keep JAX BCOO tensors on CPU within data adapters.
- Use `np.split` instead of `jax.numpy.split` in `distribution_lib`, per the JAX distribution guide.
This commit is contained in:
hertschuh 2024-07-17 21:12:03 -07:00 committed by GitHub
parent 902f9da309
commit 808b78618e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 49 additions and 53 deletions

@ -142,7 +142,7 @@ def distribute_data_input(inputs, layout):
f"{num_split}" f"{num_split}"
) )
global_batch_size = per_process_batch_size * jax.process_count() global_batch_size = per_process_batch_size * jax.process_count()
per_replica_batches = jax.numpy.split(inputs, num_split, axis=0) per_replica_batches = np.split(inputs, num_split, axis=0)
elif mesh_rank == 2: elif mesh_rank == 2:
# Data+Model parallel # Data+Model parallel
# In this case, we need to check if the mesh batch dim shape is large # In this case, we need to check if the mesh batch dim shape is large

@ -243,8 +243,6 @@ class ArrayDataAdapter(DataAdapter):
return dataset.prefetch(tf.data.AUTOTUNE) return dataset.prefetch(tf.data.AUTOTUNE)
def get_jax_iterator(self): def get_jax_iterator(self):
from keras.src.backend.jax.core import convert_to_tensor
inputs = array_slicing.convert_to_sliceable( inputs = array_slicing.convert_to_sliceable(
self._inputs, target_backend="jax" self._inputs, target_backend="jax"
) )
@ -252,7 +250,6 @@ class ArrayDataAdapter(DataAdapter):
def slice_and_convert_to_jax(sliceable, indices=None): def slice_and_convert_to_jax(sliceable, indices=None):
x = sliceable[indices] x = sliceable[indices]
x = sliceable.convert_to_jax_compatible(x) x = sliceable.convert_to_jax_compatible(x)
x = convert_to_tensor(x)
return x return x
return self._get_iterator(slice_and_convert_to_jax, inputs) return self._get_iterator(slice_and_convert_to_jax, inputs)

@ -92,7 +92,7 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"): if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"):
expected_class = jax_sparse.JAXSparse expected_class = jax_sparse.JAXSparse
else: else:
expected_class = jax.Array expected_class = np.ndarray
elif backend.backend() == "torch": elif backend.backend() == "torch":
it = adapter.get_torch_dataloader() it = adapter.get_torch_dataloader()
expected_class = torch.Tensor expected_class = torch.Tensor

@ -91,8 +91,7 @@ class Sliceable:
def convert_to_jax_compatible(cls, x): def convert_to_jax_compatible(cls, x):
"""Convert a tensor to something that the JAX backend can consume. """Convert a tensor to something that the JAX backend can consume.
This can be a `JAX` array, NumPy array or any other type of tensor that This can be a `JAX` array, `JAXSparse` or a NumPy array.
`keras.backend.jax.core.convert_to_tensor()` can consume.
Only called after slicing using `__getitem__`. Only called after slicing using `__getitem__`.
Used to convert sparse tensors and densify ragged tensors. Used to convert sparse tensors and densify ragged tensors.
@ -147,7 +146,7 @@ class TensorflowSliceable(Sliceable):
class TensorflowRaggedSliceable(TensorflowSliceable): class TensorflowRaggedSliceable(TensorflowSliceable):
@classmethod @classmethod
def convert_to_jax_compatible(cls, x): def convert_to_jax_compatible(cls, x):
return x.to_tensor() return cls.convert_to_numpy(x)
@classmethod @classmethod
def convert_to_torch_compatible(cls, x): def convert_to_torch_compatible(cls, x):
@ -180,7 +179,7 @@ class TensorflowSparseSliceable(TensorflowSliceable):
return tf_sparse.sparse_to_dense(x) return tf_sparse.sparse_to_dense(x)
class JaxSliceable(Sliceable): class JaxSparseSliceable(Sliceable):
def __getitem__(self, indices): def __getitem__(self, indices):
return self.array[indices, ...] return self.array[indices, ...]
@ -190,8 +189,6 @@ class JaxSliceable(Sliceable):
return convert_to_numpy(x) return convert_to_numpy(x)
class JaxSparseSliceable(JaxSliceable):
@classmethod @classmethod
def convert_to_tf_dataset_compatible(cls, array): def convert_to_tf_dataset_compatible(cls, array):
return to_tensorflow_sparse_wrapper( return to_tensorflow_sparse_wrapper(
@ -386,7 +383,8 @@ def convert_to_sliceable(arrays, target_backend=None):
if data_adapter_utils.is_jax_sparse(x): if data_adapter_utils.is_jax_sparse(x):
sliceable_class = JaxSparseSliceable sliceable_class = JaxSparseSliceable
else: else:
sliceable_class = JaxSliceable x = np.asarray(x)
sliceable_class = NumpySliceable
elif data_adapter_utils.is_torch_tensor(x): elif data_adapter_utils.is_torch_tensor(x):
sliceable_class = TorchSliceable sliceable_class = TorchSliceable
elif pandas is not None and isinstance(x, pandas.DataFrame): elif pandas is not None and isinstance(x, pandas.DataFrame):
@ -433,14 +431,14 @@ def convert_to_sliceable(arrays, target_backend=None):
if target_backend == "tensorflow": if target_backend == "tensorflow":
return sliceable_class.convert_to_tf_dataset_compatible(x) return sliceable_class.convert_to_tf_dataset_compatible(x)
# With dense arrays, with JAX as either input or output, it is faster to # With dense arrays and JAX as output, it is faster to use NumPy as an
# use NumPy as an intermediary representation, so wrap input array in a # intermediary representation, so wrap input array in a NumPy array,
# NumPy array, which should not use extra memory. For the input case, # which should not use extra memory.
# see https://github.com/google/jax/issues/1276 for an explanation of # See https://github.com/google/jax/issues/1276 for an explanation of
# why slicing a NumPy array is faster than slicing a JAX array. # why slicing a NumPy array is faster than slicing a JAX array.
if sliceable_class == JaxSliceable or ( if target_backend == "jax" and sliceable_class in (
target_backend == "jax" TensorflowSliceable,
and sliceable_class in (TensorflowSliceable, TorchSliceable) TorchSliceable,
): ):
x = np.asarray(x) x = np.asarray(x)
sliceable_class = NumpySliceable sliceable_class = NumpySliceable

@ -30,7 +30,8 @@ class DataAdapter:
raise NotImplementedError raise NotImplementedError
def get_jax_iterator(self): def get_jax_iterator(self):
"""Get a Python iterable for the `DataAdapter`, that yields JAX arrays. """Get a Python iterable for the `DataAdapter`, that yields arrays that
can be fed to JAX. NumPy arrays are preferred for performance.
Returns: Returns:
A Python iterator. A Python iterator.

@ -176,10 +176,21 @@ def get_tensor_spec(batches):
def get_jax_iterator(iterable): def get_jax_iterator(iterable):
from keras.src.backend.jax.core import convert_to_tensor import jax
import jax.experimental.sparse as jax_sparse
def convert_to_jax_compatible(x):
if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):
return x
elif is_scipy_sparse(x):
return scipy_sparse_to_jax_sparse(x)
elif is_tensorflow_sparse(x):
return tf_sparse_to_jax_sparse(x)
else:
return np.asarray(x)
for batch in iterable: for batch in iterable:
yield tree.map_structure(convert_to_tensor, batch) yield tree.map_structure(convert_to_jax_compatible, batch)
def get_numpy_iterator(iterable): def get_numpy_iterator(iterable):
@ -289,17 +300,21 @@ def scipy_sparse_to_tf_sparse(x):
def scipy_sparse_to_jax_sparse(x): def scipy_sparse_to_jax_sparse(x):
import jax
import jax.experimental.sparse as jax_sparse import jax.experimental.sparse as jax_sparse
return jax_sparse.BCOO.from_scipy_sparse(x) with jax.default_device(jax.local_devices(backend="cpu")[0]):
return jax_sparse.BCOO.from_scipy_sparse(x)
def tf_sparse_to_jax_sparse(x): def tf_sparse_to_jax_sparse(x):
import jax
import jax.experimental.sparse as jax_sparse import jax.experimental.sparse as jax_sparse
values = np.asarray(x.values) values = np.asarray(x.values)
indices = np.asarray(x.indices) indices = np.asarray(x.indices)
return jax_sparse.BCOO((values, indices), shape=x.shape) with jax.default_device(jax.local_devices(backend="cpu")[0]):
return jax_sparse.BCOO((values, indices), shape=x.shape)
def jax_sparse_to_tf_sparse(x): def jax_sparse_to_tf_sparse(x):

@ -26,17 +26,7 @@ class GeneratorDataAdapter(DataAdapter):
return data_adapter_utils.get_numpy_iterator(self.generator) return data_adapter_utils.get_numpy_iterator(self.generator)
def get_jax_iterator(self): def get_jax_iterator(self):
from keras.src.backend.jax.core import convert_to_tensor return data_adapter_utils.get_jax_iterator(self.generator)
def convert_to_jax(x):
if data_adapter_utils.is_scipy_sparse(x):
return data_adapter_utils.scipy_sparse_to_jax_sparse(x)
elif data_adapter_utils.is_tensorflow_sparse(x):
return data_adapter_utils.tf_sparse_to_jax_sparse(x)
return convert_to_tensor(x)
for batch in self.generator:
yield tree.map_structure(convert_to_jax, batch)
def get_tf_dataset(self): def get_tf_dataset(self):
from keras.src.utils.module_utils import tensorflow as tf from keras.src.utils.module_utils import tensorflow as tf

@ -73,7 +73,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase):
expected_class = tf.Tensor expected_class = tf.Tensor
elif backend.backend() == "jax": elif backend.backend() == "jax":
it = adapter.get_jax_iterator() it = adapter.get_jax_iterator()
expected_class = jax.Array expected_class = (
jax.Array if generator_type == "jax" else np.ndarray
)
elif backend.backend() == "torch": elif backend.backend() == "torch":
it = adapter.get_torch_dataloader() it = adapter.get_torch_dataloader()
expected_class = torch.Tensor expected_class = torch.Tensor

@ -177,7 +177,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
expected_class = tf.Tensor expected_class = tf.Tensor
elif backend.backend() == "jax": elif backend.backend() == "jax":
it = adapter.get_jax_iterator() it = adapter.get_jax_iterator()
expected_class = jax.Array expected_class = jax.Array if dataset_type == "jax" else np.ndarray
elif backend.backend() == "torch": elif backend.backend() == "torch":
it = adapter.get_torch_dataloader() it = adapter.get_torch_dataloader()
expected_class = torch.Tensor expected_class = torch.Tensor

@ -41,20 +41,15 @@ class TFDatasetAdapter(DataAdapter):
yield tree.map_structure(convert_to_numpy, batch) yield tree.map_structure(convert_to_numpy, batch)
def get_jax_iterator(self): def get_jax_iterator(self):
import jax.experimental.sparse as jax_sparse
from keras.src.backend.jax.core import convert_to_tensor
from keras.src.backend.tensorflow.core import convert_to_numpy from keras.src.backend.tensorflow.core import convert_to_numpy
from keras.src.utils.module_utils import tensorflow as tf from keras.src.utils.module_utils import tensorflow as tf
def convert_to_jax(x): def convert_to_jax(x):
# We use numpy as an intermediary because the conversion
# tf -> numpy -> jax is more than 2x faster than tf -> jax.
if isinstance(x, tf.SparseTensor): if isinstance(x, tf.SparseTensor):
values = convert_to_numpy(x.values) return data_adapter_utils.tf_sparse_to_jax_sparse(x)
indices = convert_to_numpy(x.indices) else:
return jax_sparse.BCOO((values, indices), shape=x.shape) # We use numpy as an intermediary because it is faster.
return convert_to_tensor(convert_to_numpy(x)) return convert_to_numpy(x)
for batch in self._dataset: for batch in self._dataset:
yield tree.map_structure(convert_to_jax, batch) yield tree.map_structure(convert_to_jax, batch)

@ -32,7 +32,7 @@ class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase):
expected_class = tf.Tensor expected_class = tf.Tensor
elif backend.backend() == "jax": elif backend.backend() == "jax":
it = adapter.get_jax_iterator() it = adapter.get_jax_iterator()
expected_class = jax.Array expected_class = np.ndarray
elif backend.backend() == "torch": elif backend.backend() == "torch":
it = adapter.get_torch_dataloader() it = adapter.get_torch_dataloader()
expected_class = torch.Tensor expected_class = torch.Tensor

@ -39,9 +39,8 @@ class TorchDataLoaderAdapter(DataAdapter):
) )
def get_jax_iterator(self): def get_jax_iterator(self):
# We use numpy as an intermediary because the conversion # We use numpy as an intermediary because it is faster.
# torch -> numpy -> jax is faster than torch -> jax. return self.get_numpy_iterator()
return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator())
def get_tf_dataset(self): def get_tf_dataset(self):
from keras.src.utils.module_utils import tensorflow as tf from keras.src.utils.module_utils import tensorflow as tf

@ -1,6 +1,5 @@
import math import math
import jax
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
@ -35,7 +34,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
expected_class = tf.Tensor expected_class = tf.Tensor
elif backend.backend() == "jax": elif backend.backend() == "jax":
it = adapter.get_jax_iterator() it = adapter.get_jax_iterator()
expected_class = jax.Array expected_class = np.ndarray
elif backend.backend() == "torch": elif backend.backend() == "torch":
it = adapter.get_torch_dataloader() it = adapter.get_torch_dataloader()
expected_class = torch.Tensor expected_class = torch.Tensor
@ -104,7 +103,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
expected_class = tf.Tensor expected_class = tf.Tensor
elif backend.backend() == "jax": elif backend.backend() == "jax":
it = adapter.get_jax_iterator() it = adapter.get_jax_iterator()
expected_class = jax.Array expected_class = np.ndarray
elif backend.backend() == "torch": elif backend.backend() == "torch":
it = adapter.get_torch_dataloader() it = adapter.get_torch_dataloader()
expected_class = torch.Tensor expected_class = torch.Tensor