diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py
index 988b66a79..572676fcb 100644
--- a/keras/src/backend/jax/distribution_lib.py
+++ b/keras/src/backend/jax/distribution_lib.py
@@ -142,7 +142,7 @@ def distribute_data_input(inputs, layout):
                 f"{num_split}"
             )
         global_batch_size = per_process_batch_size * jax.process_count()
-        per_replica_batches = jax.numpy.split(inputs, num_split, axis=0)
+        per_replica_batches = np.split(inputs, num_split, axis=0)
     elif mesh_rank == 2:
         # Data+Model parallel
         # In this case, we need to check if the mesh batch dim shape is large
diff --git a/keras/src/trainers/data_adapters/array_data_adapter.py b/keras/src/trainers/data_adapters/array_data_adapter.py
index 3832d8e55..10b4dc37a 100644
--- a/keras/src/trainers/data_adapters/array_data_adapter.py
+++ b/keras/src/trainers/data_adapters/array_data_adapter.py
@@ -243,8 +243,6 @@ class ArrayDataAdapter(DataAdapter):
         return dataset.prefetch(tf.data.AUTOTUNE)

     def get_jax_iterator(self):
-        from keras.src.backend.jax.core import convert_to_tensor
-
         inputs = array_slicing.convert_to_sliceable(
             self._inputs, target_backend="jax"
         )
@@ -252,7 +250,6 @@ class ArrayDataAdapter(DataAdapter):
         def slice_and_convert_to_jax(sliceable, indices=None):
             x = sliceable[indices]
             x = sliceable.convert_to_jax_compatible(x)
-            x = convert_to_tensor(x)
             return x

         return self._get_iterator(slice_and_convert_to_jax, inputs)
diff --git a/keras/src/trainers/data_adapters/array_data_adapter_test.py b/keras/src/trainers/data_adapters/array_data_adapter_test.py
index a61a90424..6b042f62c 100644
--- a/keras/src/trainers/data_adapters/array_data_adapter_test.py
+++ b/keras/src/trainers/data_adapters/array_data_adapter_test.py
@@ -92,7 +92,7 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
             if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"):
                 expected_class = jax_sparse.JAXSparse
             else:
-                expected_class = jax.Array
+                expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
diff --git a/keras/src/trainers/data_adapters/array_slicing.py b/keras/src/trainers/data_adapters/array_slicing.py
index 50279aa11..a0a75c3a3 100644
--- a/keras/src/trainers/data_adapters/array_slicing.py
+++ b/keras/src/trainers/data_adapters/array_slicing.py
@@ -91,8 +91,7 @@ class Sliceable:
     def convert_to_jax_compatible(cls, x):
         """Convert a tensor to something that the JAX backend can consume.

-        This can be a `JAX` array, NumPy array or any other type of tensor that
-        `keras.backend.jax.core.convert_to_tensor()` can consume.
+        This can be a `JAX` array, `JAXSparse` or a NumPy array.

         Only called after slicing using `__getitem__`.
         Used to convert sparse tensors and densify ragged tensors.
@@ -147,7 +146,7 @@ class TensorflowSliceable(Sliceable):
 class TensorflowRaggedSliceable(TensorflowSliceable):
     @classmethod
     def convert_to_jax_compatible(cls, x):
-        return x.to_tensor()
+        return cls.convert_to_numpy(x)

     @classmethod
     def convert_to_torch_compatible(cls, x):
@@ -180,7 +179,7 @@ class TensorflowSparseSliceable(TensorflowSliceable):
         return tf_sparse.sparse_to_dense(x)


-class JaxSliceable(Sliceable):
+class JaxSparseSliceable(Sliceable):
     def __getitem__(self, indices):
         return self.array[indices, ...]

@@ -190,8 +189,6 @@ class JaxSliceable(Sliceable):

         return convert_to_numpy(x)

-
-class JaxSparseSliceable(JaxSliceable):
     @classmethod
     def convert_to_tf_dataset_compatible(cls, array):
         return to_tensorflow_sparse_wrapper(
@@ -386,7 +383,8 @@ def convert_to_sliceable(arrays, target_backend=None):
         if data_adapter_utils.is_jax_sparse(x):
             sliceable_class = JaxSparseSliceable
         else:
-            sliceable_class = JaxSliceable
+            x = np.asarray(x)
+            sliceable_class = NumpySliceable
     elif data_adapter_utils.is_torch_tensor(x):
         sliceable_class = TorchSliceable
     elif pandas is not None and isinstance(x, pandas.DataFrame):
@@ -433,14 +431,14 @@ def convert_to_sliceable(arrays, target_backend=None):
         if target_backend == "tensorflow":
             return sliceable_class.convert_to_tf_dataset_compatible(x)

-        # With dense arrays, with JAX as either input or output, it is faster to
-        # use NumPy as an intermediary representation, so wrap input array in a
-        # NumPy array, which should not use extra memory. For the input case,
-        # see https://github.com/google/jax/issues/1276 for an explanation of
+        # With dense arrays and JAX as output, it is faster to use NumPy as an
+        # intermediary representation, so wrap input array in a NumPy array,
+        # which should not use extra memory.
+        # See https://github.com/google/jax/issues/1276 for an explanation of
         # why slicing a NumPy array is faster than slicing a JAX array.
-        if sliceable_class == JaxSliceable or (
-            target_backend == "jax"
-            and sliceable_class in (TensorflowSliceable, TorchSliceable)
+        if target_backend == "jax" and sliceable_class in (
+            TensorflowSliceable,
+            TorchSliceable,
         ):
             x = np.asarray(x)
             sliceable_class = NumpySliceable
diff --git a/keras/src/trainers/data_adapters/data_adapter.py b/keras/src/trainers/data_adapters/data_adapter.py
index 1be272853..9f0cd315f 100644
--- a/keras/src/trainers/data_adapters/data_adapter.py
+++ b/keras/src/trainers/data_adapters/data_adapter.py
@@ -30,7 +30,8 @@ class DataAdapter:
         raise NotImplementedError

     def get_jax_iterator(self):
-        """Get a Python iterable for the `DataAdapter`, that yields JAX arrays.
+        """Get a Python iterable for the `DataAdapter`, that yields arrays
+        that can be fed to JAX. NumPy arrays are preferred for performance.

         Returns:
             A Python iterator.
diff --git a/keras/src/trainers/data_adapters/data_adapter_utils.py b/keras/src/trainers/data_adapters/data_adapter_utils.py
index 83dae01e1..2ac98f142 100644
--- a/keras/src/trainers/data_adapters/data_adapter_utils.py
+++ b/keras/src/trainers/data_adapters/data_adapter_utils.py
@@ -176,10 +176,21 @@ def get_tensor_spec(batches):


 def get_jax_iterator(iterable):
-    from keras.src.backend.jax.core import convert_to_tensor
+    import jax
+    import jax.experimental.sparse as jax_sparse
+
+    def convert_to_jax_compatible(x):
+        if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):
+            return x
+        elif is_scipy_sparse(x):
+            return scipy_sparse_to_jax_sparse(x)
+        elif is_tensorflow_sparse(x):
+            return tf_sparse_to_jax_sparse(x)
+        else:
+            return np.asarray(x)

     for batch in iterable:
-        yield tree.map_structure(convert_to_tensor, batch)
+        yield tree.map_structure(convert_to_jax_compatible, batch)


 def get_numpy_iterator(iterable):
@@ -289,17 +300,21 @@ def scipy_sparse_to_tf_sparse(x):


 def scipy_sparse_to_jax_sparse(x):
+    import jax
     import jax.experimental.sparse as jax_sparse

-    return jax_sparse.BCOO.from_scipy_sparse(x)
+    with jax.default_device(jax.local_devices(backend="cpu")[0]):
+        return jax_sparse.BCOO.from_scipy_sparse(x)


 def tf_sparse_to_jax_sparse(x):
+    import jax
     import jax.experimental.sparse as jax_sparse

     values = np.asarray(x.values)
     indices = np.asarray(x.indices)
-    return jax_sparse.BCOO((values, indices), shape=x.shape)
+    with jax.default_device(jax.local_devices(backend="cpu")[0]):
+        return jax_sparse.BCOO((values, indices), shape=x.shape)


 def jax_sparse_to_tf_sparse(x):
diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py
index 30ebfac27..7f2418388 100644
--- a/keras/src/trainers/data_adapters/generator_data_adapter.py
+++ b/keras/src/trainers/data_adapters/generator_data_adapter.py
@@ -26,17 +26,7 @@ class GeneratorDataAdapter(DataAdapter):
         return data_adapter_utils.get_numpy_iterator(self.generator)

     def get_jax_iterator(self):
-        from keras.src.backend.jax.core import convert_to_tensor
-
-        def convert_to_jax(x):
-            if data_adapter_utils.is_scipy_sparse(x):
-                return data_adapter_utils.scipy_sparse_to_jax_sparse(x)
-            elif data_adapter_utils.is_tensorflow_sparse(x):
-                return data_adapter_utils.tf_sparse_to_jax_sparse(x)
-            return convert_to_tensor(x)
-
-        for batch in self.generator:
-            yield tree.map_structure(convert_to_jax, batch)
+        return data_adapter_utils.get_jax_iterator(self.generator)

     def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf
diff --git a/keras/src/trainers/data_adapters/generator_data_adapter_test.py b/keras/src/trainers/data_adapters/generator_data_adapter_test.py
index c40f8822a..8a9ef09b2 100644
--- a/keras/src/trainers/data_adapters/generator_data_adapter_test.py
+++ b/keras/src/trainers/data_adapters/generator_data_adapter_test.py
@@ -73,7 +73,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = (
+                jax.Array if generator_type == "jax" else np.ndarray
+            )
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py
index 4a7b869d0..715f729fa 100644
--- a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py
+++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py
@@ -177,7 +177,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = jax.Array if dataset_type == "jax" else np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py
index fcd4c9893..f3fa19b4f 100644
--- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py
+++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py
@@ -41,20 +41,15 @@ class TFDatasetAdapter(DataAdapter):
             yield tree.map_structure(convert_to_numpy, batch)

     def get_jax_iterator(self):
-        import jax.experimental.sparse as jax_sparse
-
-        from keras.src.backend.jax.core import convert_to_tensor
         from keras.src.backend.tensorflow.core import convert_to_numpy
         from keras.src.utils.module_utils import tensorflow as tf

         def convert_to_jax(x):
-            # We use numpy as an intermediary because the conversion
-            # tf -> numpy -> jax is more than 2x faster than tf -> jax.
             if isinstance(x, tf.SparseTensor):
-                values = convert_to_numpy(x.values)
-                indices = convert_to_numpy(x.indices)
-                return jax_sparse.BCOO((values, indices), shape=x.shape)
-            return convert_to_tensor(convert_to_numpy(x))
+                return data_adapter_utils.tf_sparse_to_jax_sparse(x)
+            else:
+                # We use numpy as an intermediary because it is faster.
+                return convert_to_numpy(x)

         for batch in self._dataset:
             yield tree.map_structure(convert_to_jax, batch)
diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
index 2535e505d..a69d859f1 100644
--- a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
+++ b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
@@ -32,7 +32,7 @@ class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
index 59a89050c..8aeb45110 100644
--- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
+++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
@@ -39,9 +39,8 @@ class TorchDataLoaderAdapter(DataAdapter):
         )

     def get_jax_iterator(self):
-        # We use numpy as an intermediary because the conversion
-        # torch -> numpy -> jax is faster than torch -> jax.
-        return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator())
+        # We use numpy as an intermediary because it is faster.
+        return self.get_numpy_iterator()

     def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf
diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py
index 4d02f5592..2c87b7509 100644
--- a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py
+++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py
@@ -1,6 +1,5 @@
 import math

-import jax
 import numpy as np
 import tensorflow as tf
 import torch
@@ -35,7 +34,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
@@ -104,7 +103,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
             expected_class = tf.Tensor
         elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-            expected_class = jax.Array
+            expected_class = np.ndarray
         elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
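Context for reviewers (not part of the patch): the comments in the hunks above claim that slicing a NumPy array is cheaper than slicing a `jax.Array`, and that NumPy batches can be handed to JAX as-is. Below is a minimal standalone sketch of both effects; the array shape, loop count, and `train_step` function are illustrative choices for this example only.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

x_np = np.random.rand(1024, 64).astype("float32")
x_jax = jnp.asarray(x_np)


def time_slicing(x, n=1000):
    # Repeatedly take small batch slices, as the data adapters do.
    start = time.perf_counter()
    for i in range(n):
        lo = (i * 32) % 992
        _ = x[lo : lo + 32]
    return time.perf_counter() - start


# Slicing the NumPy array returns a view; slicing the jax.Array goes
# through device dispatch (see https://github.com/google/jax/issues/1276).
print("numpy slicing:", time_slicing(x_np))
print("jax slicing:  ", time_slicing(x_jax))


# A NumPy batch can be fed to a jitted function directly; JAX transfers it
# to the device at call time, so no explicit convert_to_tensor is needed.
@jax.jit
def train_step(batch):
    return batch.sum()


print(train_step(x_np[:32]))
```

This is why the adapters can drop `convert_to_tensor` entirely and yield NumPy arrays: the JAX trainer consumes them unchanged, and the host-to-device copy happens once, at dispatch.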