Use Numpy arrays in data adapters for JAX backend. (#20007)
- Slicing and splitting are more efficient with NumPy arrays on CPU than with JAX arrays on CPU.
- Converting inputs to `jax.Array` sends them to the first device by default, which is incorrect in the distributed setting and slows distribution down.
- Keep JAX BCOO tensors on CPU within data adapters.
- Use `np.split` instead of `jax.numpy.split` in `distribution_lib`, per the JAX distribution guide.
parent 902f9da309
commit 808b78618e
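To make the first two bullets concrete, here is an illustrative micro-benchmark, not part of the commit; the shapes, iteration count, and timings are arbitrary stand-ins. Slicing a NumPy array is a cheap view, while slicing a CPU-resident `jax.Array` dispatches a JAX operation (see https://github.com/google/jax/issues/1276, also cited in the diff below):

```python
# Illustrative only: compare slicing a NumPy array with slicing a
# CPU-resident jax.Array, the hot operation in array data adapters.
import timeit

import jax.numpy as jnp
import numpy as np

x_np = np.zeros((1024, 256), dtype="float32")
x_jax = jnp.zeros((1024, 256), dtype="float32")  # commits to a device

# Each training step slices one batch out of the full array.
t_np = timeit.timeit(lambda: x_np[32:64], number=10_000)
t_jax = timeit.timeit(lambda: x_jax[32:64], number=10_000)
print(f"numpy: {t_np:.3f}s  jax: {t_jax:.3f}s")  # numpy is much cheaper
```

On a multi-device host, any conversion to `jax.Array` also commits the buffer to the default device (device 0) unless scoped otherwise, which is what the BCOO changes below address.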
@@ -142,7 +142,7 @@ def distribute_data_input(inputs, layout):
                 f"{num_split}"
             )
         global_batch_size = per_process_batch_size * jax.process_count()
-        per_replica_batches = jax.numpy.split(inputs, num_split, axis=0)
+        per_replica_batches = np.split(inputs, num_split, axis=0)
     elif mesh_rank == 2:
         # Data+Model parallel
         # In this case, we need to check if the mesh batch dim shape is large
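As a hedged sketch of what the changed line computes (made-up shapes; `num_split` stands in for the local replica count), `np.split` produces host-side NumPy views, so sharding neither moves nor copies the data:

```python
# Hypothetical shapes: a per-process batch of 8 rows, 4 local replicas.
import numpy as np

inputs = np.arange(8 * 3, dtype="float32").reshape(8, 3)
num_split = 4  # stand-in for the number of local replicas
per_replica_batches = np.split(inputs, num_split, axis=0)
assert [b.shape for b in per_replica_batches] == [(2, 3)] * num_split
# Each shard aliases `inputs`: no copy, and nothing leaves the host.
assert all(np.shares_memory(b, inputs) for b in per_replica_batches)
```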
@@ -243,8 +243,6 @@ class ArrayDataAdapter(DataAdapter):
         return dataset.prefetch(tf.data.AUTOTUNE)
 
     def get_jax_iterator(self):
-        from keras.src.backend.jax.core import convert_to_tensor
-
         inputs = array_slicing.convert_to_sliceable(
             self._inputs, target_backend="jax"
         )
@@ -252,7 +250,6 @@ class ArrayDataAdapter(DataAdapter):
         def slice_and_convert_to_jax(sliceable, indices=None):
             x = sliceable[indices]
             x = sliceable.convert_to_jax_compatible(x)
-            x = convert_to_tensor(x)
             return x
 
         return self._get_iterator(slice_and_convert_to_jax, inputs)
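Dropping the per-batch `convert_to_tensor(x)` call works because JAX converts array-likes at the operation boundary anyway; a small sketch of that behavior (the exact printed class name varies by JAX version):

```python
# JAX ops accept NumPy arrays directly, so the iterator can yield
# np.ndarray batches and defer device placement to jit/device_put.
import jax.numpy as jnp
import numpy as np

batch = np.ones((32, 4), dtype="float32")  # what the adapter now yields
out = jnp.sum(batch)  # implicit, one-time conversion inside the op
print(type(out))  # a jax.Array implementation; name varies by version
```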
@@ -92,7 +92,7 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
             if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"):
                 expected_class = jax_sparse.JAXSparse
             else:
-                expected_class = jax.Array
+                expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
@@ -91,8 +91,7 @@ class Sliceable:
     def convert_to_jax_compatible(cls, x):
         """Convert a tensor to something that the JAX backend can consume.
 
-        This can be a `JAX` array, NumPy array or any other type of tensor that
-        `keras.backend.jax.core.convert_to_tensor()` can consume.
+        This can be a `JAX` array, `JAXSparse` or a NumPy array.
         Only called after slicing using `__getitem__`.
         Used to convert sparse tensors and densify ragged tensors.
 
@@ -147,7 +146,7 @@ class TensorflowSliceable(Sliceable):
 class TensorflowRaggedSliceable(TensorflowSliceable):
     @classmethod
     def convert_to_jax_compatible(cls, x):
-        return x.to_tensor()
+        return cls.convert_to_numpy(x)
 
     @classmethod
     def convert_to_torch_compatible(cls, x):
@@ -180,7 +179,7 @@ class TensorflowSparseSliceable(TensorflowSliceable):
         return tf_sparse.sparse_to_dense(x)
 
 
-class JaxSliceable(Sliceable):
+class JaxSparseSliceable(Sliceable):
     def __getitem__(self, indices):
         return self.array[indices, ...]
 
@@ -190,8 +189,6 @@ class JaxSliceable(Sliceable):
 
         return convert_to_numpy(x)
 
-
-class JaxSparseSliceable(JaxSliceable):
     @classmethod
     def convert_to_tf_dataset_compatible(cls, array):
         return to_tensorflow_sparse_wrapper(
@@ -386,7 +383,8 @@ def convert_to_sliceable(arrays, target_backend=None):
         if data_adapter_utils.is_jax_sparse(x):
             sliceable_class = JaxSparseSliceable
         else:
-            sliceable_class = JaxSliceable
+            x = np.asarray(x)
+            sliceable_class = NumpySliceable
     elif data_adapter_utils.is_torch_tensor(x):
         sliceable_class = TorchSliceable
     elif pandas is not None and isinstance(x, pandas.DataFrame):
@@ -433,14 +431,14 @@ def convert_to_sliceable(arrays, target_backend=None):
     if target_backend == "tensorflow":
         return sliceable_class.convert_to_tf_dataset_compatible(x)
 
-    # With dense arrays, with JAX as either input or output, it is faster to
-    # use NumPy as an intermediary representation, so wrap input array in a
-    # NumPy array, which should not use extra memory. For the input case,
-    # see https://github.com/google/jax/issues/1276 for an explanation of
+    # With dense arrays and JAX as output, it is faster to use NumPy as an
+    # intermediary representation, so wrap input array in a NumPy array,
+    # which should not use extra memory.
+    # See https://github.com/google/jax/issues/1276 for an explanation of
     # why slicing a NumPy array is faster than slicing a JAX array.
-    if sliceable_class == JaxSliceable or (
-        target_backend == "jax"
-        and sliceable_class in (TensorflowSliceable, TorchSliceable)
+    if target_backend == "jax" and sliceable_class in (
+        TensorflowSliceable,
+        TorchSliceable,
     ):
         x = np.asarray(x)
         sliceable_class = NumpySliceable
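A sketch of the "should not use extra memory" claim for the torch branch; the zero-copy behavior below holds for CPU tensors, which expose their buffer through the `__array__` protocol, and is an assumption to re-check for other device types:

```python
import numpy as np
import torch

t = torch.zeros(4, 4)  # CPU tensor
a = np.asarray(t)      # wraps via __array__; no copy on CPU
t[0, 0] = 7.0
assert a[0, 0] == 7.0  # both names see the same underlying buffer
```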
@@ -30,7 +30,8 @@ class DataAdapter:
         raise NotImplementedError
 
     def get_jax_iterator(self):
-        """Get a Python iterable for the `DataAdapter`, that yields JAX arrays.
+        """Get a Python iterable for the `DataAdapter`, that yields arrays
+        that can be fed to JAX. NumPy arrays are preferred for performance.
 
         Returns:
             A Python iterator.
@@ -176,10 +176,21 @@ def get_tensor_spec(batches):
 
 
 def get_jax_iterator(iterable):
-    from keras.src.backend.jax.core import convert_to_tensor
+    import jax
+    import jax.experimental.sparse as jax_sparse
 
+    def convert_to_jax_compatible(x):
+        if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):
+            return x
+        elif is_scipy_sparse(x):
+            return scipy_sparse_to_jax_sparse(x)
+        elif is_tensorflow_sparse(x):
+            return tf_sparse_to_jax_sparse(x)
+        else:
+            return np.asarray(x)
+
     for batch in iterable:
-        yield tree.map_structure(convert_to_tensor, batch)
+        yield tree.map_structure(convert_to_jax_compatible, batch)
 
 
 def get_numpy_iterator(iterable):
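A hedged usage sketch of the reworked helper, assuming the internal import path shown in this diff; dense leaves now pass through as NumPy arrays instead of being eagerly converted to `jax.Array`:

```python
# Dense leaves pass through as NumPy; nothing is device_put eagerly.
import numpy as np

from keras.src.trainers.data_adapters import data_adapter_utils

batches = [{"x": np.ones((2, 3), "float32"), "y": np.zeros((2,), "int32")}]
for batch in data_adapter_utils.get_jax_iterator(batches):
    assert isinstance(batch["x"], np.ndarray)  # no jax.Array created
    assert isinstance(batch["y"], np.ndarray)
```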
@@ -289,17 +300,21 @@ def scipy_sparse_to_tf_sparse(x):
 
 
 def scipy_sparse_to_jax_sparse(x):
+    import jax
     import jax.experimental.sparse as jax_sparse
 
-    return jax_sparse.BCOO.from_scipy_sparse(x)
+    with jax.default_device(jax.local_devices(backend="cpu")[0]):
+        return jax_sparse.BCOO.from_scipy_sparse(x)
 
 
 def tf_sparse_to_jax_sparse(x):
+    import jax
     import jax.experimental.sparse as jax_sparse
 
     values = np.asarray(x.values)
     indices = np.asarray(x.indices)
-    return jax_sparse.BCOO((values, indices), shape=x.shape)
+    with jax.default_device(jax.local_devices(backend="cpu")[0]):
+        return jax_sparse.BCOO((values, indices), shape=x.shape)
 
 
 def jax_sparse_to_tf_sparse(x):
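A minimal sketch of the CPU-pinning pattern introduced here: `jax.default_device` scopes where newly created buffers land, so the BCOO components stay on the host instead of being committed to device 0 (the example values are arbitrary):

```python
import jax
import jax.experimental.sparse as jax_sparse
import numpy as np

data = np.array([1.0, 2.0], dtype="float32")
indices = np.array([[0, 0], [1, 1]], dtype="int32")
with jax.default_device(jax.local_devices(backend="cpu")[0]):
    m = jax_sparse.BCOO((data, indices), shape=(2, 2))
# Buffers stay on the host until distribution logic decides placement.
print(m.data.devices())  # e.g. {CpuDevice(id=0)}
```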
@@ -26,17 +26,7 @@ class GeneratorDataAdapter(DataAdapter):
         return data_adapter_utils.get_numpy_iterator(self.generator)
 
     def get_jax_iterator(self):
-        from keras.src.backend.jax.core import convert_to_tensor
-
-        def convert_to_jax(x):
-            if data_adapter_utils.is_scipy_sparse(x):
-                return data_adapter_utils.scipy_sparse_to_jax_sparse(x)
-            elif data_adapter_utils.is_tensorflow_sparse(x):
-                return data_adapter_utils.tf_sparse_to_jax_sparse(x)
-            return convert_to_tensor(x)
-
-        for batch in self.generator:
-            yield tree.map_structure(convert_to_jax, batch)
+        return data_adapter_utils.get_jax_iterator(self.generator)
 
     def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf
@@ -73,7 +73,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = (
+                jax.Array if generator_type == "jax" else np.ndarray
+            )
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
@@ -177,7 +177,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = jax.Array if dataset_type == "jax" else np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
@@ -41,20 +41,15 @@ class TFDatasetAdapter(DataAdapter):
             yield tree.map_structure(convert_to_numpy, batch)
 
     def get_jax_iterator(self):
-        import jax.experimental.sparse as jax_sparse
-
-        from keras.src.backend.jax.core import convert_to_tensor
         from keras.src.backend.tensorflow.core import convert_to_numpy
         from keras.src.utils.module_utils import tensorflow as tf
 
         def convert_to_jax(x):
-            # We use numpy as an intermediary because the conversion
-            # tf -> numpy -> jax is more than 2x faster than tf -> jax.
             if isinstance(x, tf.SparseTensor):
-                values = convert_to_numpy(x.values)
-                indices = convert_to_numpy(x.indices)
-                return jax_sparse.BCOO((values, indices), shape=x.shape)
-            return convert_to_tensor(convert_to_numpy(x))
+                return data_adapter_utils.tf_sparse_to_jax_sparse(x)
+            else:
+                # We use numpy as an intermediary because it is faster.
+                return convert_to_numpy(x)
 
         for batch in self._dataset:
            yield tree.map_structure(convert_to_jax, batch)
@@ -32,7 +32,7 @@ class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
@@ -39,9 +39,8 @@ class TorchDataLoaderAdapter(DataAdapter):
         )
 
     def get_jax_iterator(self):
-        # We use numpy as an intermediary because the conversion
-        # torch -> numpy -> jax is faster than torch -> jax.
-        return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator())
+        # We use numpy as an intermediary because it is faster.
+        return self.get_numpy_iterator()
 
     def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf
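Delegating `get_jax_iterator` to `get_numpy_iterator()` is cheap because, for CPU tensors, `torch.Tensor.numpy()` returns a view over the same storage; a sketch (CUDA tensors would first need a host copy):

```python
import torch

t = torch.arange(6).reshape(2, 3)
n = t.numpy()         # zero-copy view of the CPU tensor's storage
t[0, 0] = 42
assert n[0, 0] == 42  # the NumPy view observes the mutation
```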
@@ -1,6 +1,5 @@
 import math
 
-import jax
 import numpy as np
 import tensorflow as tf
 import torch
@@ -35,7 +34,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
@@ -104,7 +103,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor