Use NumPy arrays in data adapters for JAX backend. (#20007)
- Slicing and splitting are more efficient with NumPy arrays on CPU than with JAX arrays on CPU.
- Converting to `jax.Array` sends the arrays to the first device by default, which is incorrect when distribution is used and slows distribution down.
Also:
- Keep JAX BCOO tensors on CPU within data adapters.
- Use `np.split` instead of `jax.numpy.split` in distribution_lib, per the JAX distribution guide.
This commit is contained in:
parent
902f9da309
commit
808b78618e
@ -142,7 +142,7 @@ def distribute_data_input(inputs, layout):
|
||||
f"{num_split}"
|
||||
)
|
||||
global_batch_size = per_process_batch_size * jax.process_count()
|
||||
per_replica_batches = jax.numpy.split(inputs, num_split, axis=0)
|
||||
per_replica_batches = np.split(inputs, num_split, axis=0)
|
||||
elif mesh_rank == 2:
|
||||
# Data+Model parallel
|
||||
# In this case, we need to check if the mesh batch dim shape is large
|
||||
|
@ -243,8 +243,6 @@ class ArrayDataAdapter(DataAdapter):
|
||||
return dataset.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
def get_jax_iterator(self):
|
||||
from keras.src.backend.jax.core import convert_to_tensor
|
||||
|
||||
inputs = array_slicing.convert_to_sliceable(
|
||||
self._inputs, target_backend="jax"
|
||||
)
|
||||
@ -252,7 +250,6 @@ class ArrayDataAdapter(DataAdapter):
|
||||
def slice_and_convert_to_jax(sliceable, indices=None):
|
||||
x = sliceable[indices]
|
||||
x = sliceable.convert_to_jax_compatible(x)
|
||||
x = convert_to_tensor(x)
|
||||
return x
|
||||
|
||||
return self._get_iterator(slice_and_convert_to_jax, inputs)
|
||||
|
@ -92,7 +92,7 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
|
||||
if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"):
|
||||
expected_class = jax_sparse.JAXSparse
|
||||
else:
|
||||
expected_class = jax.Array
|
||||
expected_class = np.ndarray
|
||||
elif backend.backend() == "torch":
|
||||
it = adapter.get_torch_dataloader()
|
||||
expected_class = torch.Tensor
|
||||
|
@ -91,8 +91,7 @@ class Sliceable:
|
||||
def convert_to_jax_compatible(cls, x):
|
||||
"""Convert a tensor to something that the JAX backend can consume.
|
||||
|
||||
This can be a `JAX` array, NumPy array or any other type of tensor that
|
||||
`keras.backend.jax.core.convert_to_tensor()` can consume.
|
||||
This can be a `JAX` array, `JAXSparse` or a NumPy array.
|
||||
Only called after slicing using `__getitem__`.
|
||||
Used to convert sparse tensors and densify ragged tensors.
|
||||
|
||||
@ -147,7 +146,7 @@ class TensorflowSliceable(Sliceable):
|
||||
class TensorflowRaggedSliceable(TensorflowSliceable):
|
||||
@classmethod
|
||||
def convert_to_jax_compatible(cls, x):
|
||||
return x.to_tensor()
|
||||
return cls.convert_to_numpy(x)
|
||||
|
||||
@classmethod
|
||||
def convert_to_torch_compatible(cls, x):
|
||||
@ -180,7 +179,7 @@ class TensorflowSparseSliceable(TensorflowSliceable):
|
||||
return tf_sparse.sparse_to_dense(x)
|
||||
|
||||
|
||||
class JaxSliceable(Sliceable):
|
||||
class JaxSparseSliceable(Sliceable):
|
||||
def __getitem__(self, indices):
|
||||
return self.array[indices, ...]
|
||||
|
||||
@ -190,8 +189,6 @@ class JaxSliceable(Sliceable):
|
||||
|
||||
return convert_to_numpy(x)
|
||||
|
||||
|
||||
class JaxSparseSliceable(JaxSliceable):
|
||||
@classmethod
|
||||
def convert_to_tf_dataset_compatible(cls, array):
|
||||
return to_tensorflow_sparse_wrapper(
|
||||
@ -386,7 +383,8 @@ def convert_to_sliceable(arrays, target_backend=None):
|
||||
if data_adapter_utils.is_jax_sparse(x):
|
||||
sliceable_class = JaxSparseSliceable
|
||||
else:
|
||||
sliceable_class = JaxSliceable
|
||||
x = np.asarray(x)
|
||||
sliceable_class = NumpySliceable
|
||||
elif data_adapter_utils.is_torch_tensor(x):
|
||||
sliceable_class = TorchSliceable
|
||||
elif pandas is not None and isinstance(x, pandas.DataFrame):
|
||||
@ -433,14 +431,14 @@ def convert_to_sliceable(arrays, target_backend=None):
|
||||
if target_backend == "tensorflow":
|
||||
return sliceable_class.convert_to_tf_dataset_compatible(x)
|
||||
|
||||
# With dense arrays, with JAX as either input or output, it is faster to
|
||||
# use NumPy as an intermediary representation, so wrap input array in a
|
||||
# NumPy array, which should not use extra memory. For the input case,
|
||||
# see https://github.com/google/jax/issues/1276 for an explanation of
|
||||
# With dense arrays and JAX as output, it is faster to use NumPy as an
|
||||
# intermediary representation, so wrap input array in a NumPy array,
|
||||
# which should not use extra memory.
|
||||
# See https://github.com/google/jax/issues/1276 for an explanation of
|
||||
# why slicing a NumPy array is faster than slicing a JAX array.
|
||||
if sliceable_class == JaxSliceable or (
|
||||
target_backend == "jax"
|
||||
and sliceable_class in (TensorflowSliceable, TorchSliceable)
|
||||
if target_backend == "jax" and sliceable_class in (
|
||||
TensorflowSliceable,
|
||||
TorchSliceable,
|
||||
):
|
||||
x = np.asarray(x)
|
||||
sliceable_class = NumpySliceable
|
||||
|
@ -30,7 +30,8 @@ class DataAdapter:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_jax_iterator(self):
|
||||
"""Get a Python iterable for the `DataAdapter`, that yields JAX arrays.
|
||||
"""Get a Python iterable for the `DataAdapter`, that yields arrays that
|
||||
can be fed to JAX. NumPy arrays are preferred for performance.
|
||||
|
||||
Returns:
|
||||
A Python iterator.
|
||||
|
@ -176,10 +176,21 @@ def get_tensor_spec(batches):
|
||||
|
||||
|
||||
def get_jax_iterator(iterable):
|
||||
from keras.src.backend.jax.core import convert_to_tensor
|
||||
import jax
|
||||
import jax.experimental.sparse as jax_sparse
|
||||
|
||||
def convert_to_jax_compatible(x):
|
||||
if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):
|
||||
return x
|
||||
elif is_scipy_sparse(x):
|
||||
return scipy_sparse_to_jax_sparse(x)
|
||||
elif is_tensorflow_sparse(x):
|
||||
return tf_sparse_to_jax_sparse(x)
|
||||
else:
|
||||
return np.asarray(x)
|
||||
|
||||
for batch in iterable:
|
||||
yield tree.map_structure(convert_to_tensor, batch)
|
||||
yield tree.map_structure(convert_to_jax_compatible, batch)
|
||||
|
||||
|
||||
def get_numpy_iterator(iterable):
|
||||
@ -289,17 +300,21 @@ def scipy_sparse_to_tf_sparse(x):
|
||||
|
||||
|
||||
def scipy_sparse_to_jax_sparse(x):
|
||||
import jax
|
||||
import jax.experimental.sparse as jax_sparse
|
||||
|
||||
return jax_sparse.BCOO.from_scipy_sparse(x)
|
||||
with jax.default_device(jax.local_devices(backend="cpu")[0]):
|
||||
return jax_sparse.BCOO.from_scipy_sparse(x)
|
||||
|
||||
|
||||
def tf_sparse_to_jax_sparse(x):
|
||||
import jax
|
||||
import jax.experimental.sparse as jax_sparse
|
||||
|
||||
values = np.asarray(x.values)
|
||||
indices = np.asarray(x.indices)
|
||||
return jax_sparse.BCOO((values, indices), shape=x.shape)
|
||||
with jax.default_device(jax.local_devices(backend="cpu")[0]):
|
||||
return jax_sparse.BCOO((values, indices), shape=x.shape)
|
||||
|
||||
|
||||
def jax_sparse_to_tf_sparse(x):
|
||||
|
@ -26,17 +26,7 @@ class GeneratorDataAdapter(DataAdapter):
|
||||
return data_adapter_utils.get_numpy_iterator(self.generator)
|
||||
|
||||
def get_jax_iterator(self):
|
||||
from keras.src.backend.jax.core import convert_to_tensor
|
||||
|
||||
def convert_to_jax(x):
|
||||
if data_adapter_utils.is_scipy_sparse(x):
|
||||
return data_adapter_utils.scipy_sparse_to_jax_sparse(x)
|
||||
elif data_adapter_utils.is_tensorflow_sparse(x):
|
||||
return data_adapter_utils.tf_sparse_to_jax_sparse(x)
|
||||
return convert_to_tensor(x)
|
||||
|
||||
for batch in self.generator:
|
||||
yield tree.map_structure(convert_to_jax, batch)
|
||||
return data_adapter_utils.get_jax_iterator(self.generator)
|
||||
|
||||
def get_tf_dataset(self):
|
||||
from keras.src.utils.module_utils import tensorflow as tf
|
||||
|
@ -73,7 +73,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase):
|
||||
expected_class = tf.Tensor
|
||||
elif backend.backend() == "jax":
|
||||
it = adapter.get_jax_iterator()
|
||||
expected_class = jax.Array
|
||||
expected_class = (
|
||||
jax.Array if generator_type == "jax" else np.ndarray
|
||||
)
|
||||
elif backend.backend() == "torch":
|
||||
it = adapter.get_torch_dataloader()
|
||||
expected_class = torch.Tensor
|
||||
|
@ -177,7 +177,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
|
||||
expected_class = tf.Tensor
|
||||
elif backend.backend() == "jax":
|
||||
it = adapter.get_jax_iterator()
|
||||
expected_class = jax.Array
|
||||
expected_class = jax.Array if dataset_type == "jax" else np.ndarray
|
||||
elif backend.backend() == "torch":
|
||||
it = adapter.get_torch_dataloader()
|
||||
expected_class = torch.Tensor
|
||||
|
@ -41,20 +41,15 @@ class TFDatasetAdapter(DataAdapter):
|
||||
yield tree.map_structure(convert_to_numpy, batch)
|
||||
|
||||
def get_jax_iterator(self):
|
||||
import jax.experimental.sparse as jax_sparse
|
||||
|
||||
from keras.src.backend.jax.core import convert_to_tensor
|
||||
from keras.src.backend.tensorflow.core import convert_to_numpy
|
||||
from keras.src.utils.module_utils import tensorflow as tf
|
||||
|
||||
def convert_to_jax(x):
|
||||
# We use numpy as an intermediary because the conversion
|
||||
# tf -> numpy -> jax is more than 2x faster than tf -> jax.
|
||||
if isinstance(x, tf.SparseTensor):
|
||||
values = convert_to_numpy(x.values)
|
||||
indices = convert_to_numpy(x.indices)
|
||||
return jax_sparse.BCOO((values, indices), shape=x.shape)
|
||||
return convert_to_tensor(convert_to_numpy(x))
|
||||
return data_adapter_utils.tf_sparse_to_jax_sparse(x)
|
||||
else:
|
||||
# We use numpy as an intermediary because it is faster.
|
||||
return convert_to_numpy(x)
|
||||
|
||||
for batch in self._dataset:
|
||||
yield tree.map_structure(convert_to_jax, batch)
|
||||
|
@ -32,7 +32,7 @@ class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase):
|
||||
expected_class = tf.Tensor
|
||||
elif backend.backend() == "jax":
|
||||
it = adapter.get_jax_iterator()
|
||||
expected_class = jax.Array
|
||||
expected_class = np.ndarray
|
||||
elif backend.backend() == "torch":
|
||||
it = adapter.get_torch_dataloader()
|
||||
expected_class = torch.Tensor
|
||||
|
@ -39,9 +39,8 @@ class TorchDataLoaderAdapter(DataAdapter):
|
||||
)
|
||||
|
||||
def get_jax_iterator(self):
|
||||
# We use numpy as an intermediary because the conversion
|
||||
# torch -> numpy -> jax is faster than torch -> jax.
|
||||
return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator())
|
||||
# We use numpy as an intermediary because it is faster.
|
||||
return self.get_numpy_iterator()
|
||||
|
||||
def get_tf_dataset(self):
|
||||
from keras.src.utils.module_utils import tensorflow as tf
|
||||
|
@ -1,6 +1,5 @@
|
||||
import math
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
@ -35,7 +34,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
|
||||
expected_class = tf.Tensor
|
||||
elif backend.backend() == "jax":
|
||||
it = adapter.get_jax_iterator()
|
||||
expected_class = jax.Array
|
||||
expected_class = np.ndarray
|
||||
elif backend.backend() == "torch":
|
||||
it = adapter.get_torch_dataloader()
|
||||
expected_class = torch.Tensor
|
||||
@ -104,7 +103,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
|
||||
expected_class = tf.Tensor
|
||||
elif backend.backend() == "jax":
|
||||
it = adapter.get_jax_iterator()
|
||||
expected_class = jax.Array
|
||||
expected_class = np.ndarray
|
||||
elif backend.backend() == "torch":
|
||||
it = adapter.get_torch_dataloader()
|
||||
expected_class = torch.Tensor
|
||||
|
Loading…
Reference in New Issue
Block a user