Further reduce TF footprint. Now only attempt to import TF if saving to GCS or if using KPL.
This commit is contained in:
parent
e395f16bfd
commit
7b97f71622
@ -83,7 +83,7 @@ def clear_session():
|
||||
GLOBAL_SETTINGS_TRACKER = threading.local()
|
||||
|
||||
if backend.backend() == "tensorflow":
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
if tf.executing_eagerly():
|
||||
|
@ -2,6 +2,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.CategoryEncoding")
|
||||
@ -83,9 +84,7 @@ class CategoryEncoding(Layer):
|
||||
"""
|
||||
|
||||
def __init__(self, num_tokens=None, output_mode="multi_hot", **kwargs):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer CategoryEncoding requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.Discretization")
|
||||
@ -102,9 +103,7 @@ class Discretization(Layer):
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer Discretization requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -193,8 +192,6 @@ class Discretization(Layer):
|
||||
return backend.KerasTensor(shape=inputs.shape, dtype="int32")
|
||||
|
||||
def __call__(self, inputs):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, backend.KerasTensor)):
|
||||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
|
||||
if not self.built:
|
||||
|
@ -2,6 +2,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.HashedCrossing")
|
||||
@ -74,9 +75,7 @@ class HashedCrossing(Layer):
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer HashedCrossing requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.Hashing")
|
||||
@ -143,9 +144,7 @@ class Hashing(Layer):
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer Hashing requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -183,8 +182,6 @@ class Hashing(Layer):
|
||||
self.supports_jit = False
|
||||
|
||||
def call(self, inputs):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(np.array(inputs))
|
||||
outputs = self.layer.call(inputs)
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.IntegerLookup")
|
||||
@ -309,9 +310,7 @@ class IntegerLookup(Layer):
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer IntegerLookup requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -449,8 +448,6 @@ class IntegerLookup(Layer):
|
||||
self.layer.set_vocabulary(vocabulary, idf_weights=idf_weights)
|
||||
|
||||
def call(self, inputs):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(np.array(inputs))
|
||||
outputs = self.layer.call(inputs)
|
||||
|
@ -6,6 +6,7 @@ from keras_core import backend
|
||||
from keras_core import ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.Normalization")
|
||||
@ -216,8 +217,6 @@ class Normalization(Layer):
|
||||
data is batched, and if that assumption doesn't hold, the mean
|
||||
and variance may be incorrectly computed.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
if isinstance(data, np.ndarray) or backend.is_tensor(data):
|
||||
input_shape = data.shape
|
||||
elif isinstance(data, tf.data.Dataset):
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.RandomCrop")
|
||||
@ -52,9 +53,7 @@ class RandomCrop(Layer):
|
||||
"""
|
||||
|
||||
def __init__(self, height, width, seed=None, name=None, **kwargs):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer RandomCrop requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -74,8 +73,6 @@ class RandomCrop(Layer):
|
||||
self._allow_non_tensor_positional_args = True
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
|
||||
outputs = self.layer.call(inputs, training=training)
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
HORIZONTAL = "horizontal"
|
||||
VERTICAL = "vertical"
|
||||
@ -45,9 +46,7 @@ class RandomFlip(Layer):
|
||||
def __init__(
|
||||
self, mode=HORIZONTAL_AND_VERTICAL, seed=None, name=None, **kwargs
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer RandomFlip requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -66,8 +65,6 @@ class RandomFlip(Layer):
|
||||
self._allow_non_tensor_positional_args = True
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
|
||||
outputs = self.layer.call(inputs, training=training)
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.RandomRotation")
|
||||
@ -85,9 +86,7 @@ class RandomRotation(Layer):
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer RandomRotation requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -109,8 +108,6 @@ class RandomRotation(Layer):
|
||||
self._allow_non_tensor_positional_args = True
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
|
||||
outputs = self.layer.call(inputs, training=training)
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.RandomTranslation")
|
||||
@ -75,9 +76,7 @@ class RandomTranslation(Layer):
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer RandomTranslation requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -99,8 +98,6 @@ class RandomTranslation(Layer):
|
||||
self._allow_non_tensor_positional_args = True
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
|
||||
outputs = self.layer.call(inputs, training=training)
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.RandomZoom")
|
||||
@ -97,9 +98,7 @@ class RandomZoom(Layer):
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer RandomZoom requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -122,8 +121,6 @@ class RandomZoom(Layer):
|
||||
self.supports_jit = False
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
|
||||
outputs = self.layer.call(inputs, training=training)
|
||||
|
@ -4,6 +4,7 @@ from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.StringLookup")
|
||||
@ -307,9 +308,7 @@ class StringLookup(Layer):
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer StringLookup requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -443,8 +442,6 @@ class StringLookup(Layer):
|
||||
self.layer.set_vocabulary(vocabulary, idf_weights=idf_weights)
|
||||
|
||||
def call(self, inputs):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(np.array(inputs))
|
||||
outputs = self.layer.call(inputs)
|
||||
|
@ -5,6 +5,7 @@ from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.saving import serialization_lib
|
||||
from keras_core.utils import backend_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.TextVectorization")
|
||||
@ -212,9 +213,7 @@ class TextVectorization(Layer):
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
if not tf.available:
|
||||
raise ImportError(
|
||||
"Layer TextVectorization requires TensorFlow. "
|
||||
"Install it via `pip install tensorflow`."
|
||||
@ -366,8 +365,6 @@ class TextVectorization(Layer):
|
||||
self.layer.set_vocabulary(vocabulary, idf_weights=idf_weights)
|
||||
|
||||
def call(self, inputs):
|
||||
import tensorflow as tf
|
||||
|
||||
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||
inputs = tf.convert_to_tensor(np.array(inputs))
|
||||
outputs = self.layer.call(inputs)
|
||||
|
@ -109,7 +109,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
||||
|
||||
# Support for remote zip files
|
||||
if (
|
||||
saving_lib.is_remote_path(filepath)
|
||||
file_utils.is_remote_path(filepath)
|
||||
and not file_utils.isdir(filepath)
|
||||
and not is_keras_zip
|
||||
):
|
||||
|
@ -4,7 +4,6 @@ import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import warnings
|
||||
import zipfile
|
||||
@ -87,7 +86,7 @@ def save_model(model, filepath, weights_format="h5"):
|
||||
}
|
||||
)
|
||||
# TODO(rameshsampath): Need a better logic for local vs remote path
|
||||
if is_remote_path(filepath):
|
||||
if file_utils.is_remote_path(filepath):
|
||||
# Remote path. Zip to local drive and copy to remote
|
||||
zip_filepath = os.path.join(get_temp_dir(), "tmp_model.keras")
|
||||
else:
|
||||
@ -124,7 +123,7 @@ def save_model(model, filepath, weights_format="h5"):
|
||||
weights_store.close()
|
||||
asset_store.close()
|
||||
|
||||
if is_remote_path(filepath):
|
||||
if file_utils.is_remote_path(filepath):
|
||||
# Using gfile context manager doesn't close zip file when
|
||||
# writing to GCS. Hence writing to local and copying to filepath.
|
||||
file_utils.copy(zip_filepath, filepath, overwrite=True)
|
||||
@ -245,12 +244,6 @@ def load_weights_only(model, filepath, skip_mismatch=False):
|
||||
archive.close()
|
||||
|
||||
|
||||
def is_remote_path(filepath):
|
||||
if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):
|
||||
if not file_utils.isdir(system_path):
|
||||
zipfile_to_save.write(system_path, zip_path)
|
||||
|
@ -107,7 +107,7 @@ class ArrayDataAdapter(DataAdapter):
|
||||
yield tree.map_structure(lambda x: x[start:stop], inputs)
|
||||
|
||||
def get_tf_dataset(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
inputs = self._inputs
|
||||
shuffle = self._shuffle
|
||||
@ -302,7 +302,7 @@ def convert_to_arrays(arrays, dtype=None):
|
||||
elif isinstance(x, pandas.DataFrame):
|
||||
x = x.to_numpy(dtype=dtype)
|
||||
if is_tf_ragged_tensor(x):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
return tf.cast(x, dtype=dtype)
|
||||
if not isinstance(x, np.ndarray):
|
||||
|
@ -17,7 +17,7 @@ except ImportError:
|
||||
# backend framework we are not currently using just to do type-checking.
|
||||
ARRAY_TYPES = (np.ndarray,)
|
||||
if backend.backend() == "tensorflow":
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
ARRAY_TYPES = ARRAY_TYPES + (np.ndarray, tf.RaggedTensor)
|
||||
if pandas:
|
||||
|
@ -22,7 +22,7 @@ class GeneratorDataAdapter(DataAdapter):
|
||||
)
|
||||
|
||||
def _set_tf_output_signature(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
data, generator = peek_and_restore(self.generator)
|
||||
self.generator = generator
|
||||
@ -47,7 +47,7 @@ class GeneratorDataAdapter(DataAdapter):
|
||||
yield batch
|
||||
|
||||
def get_tf_dataset(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
if self._output_signature is None:
|
||||
self._set_tf_output_signature()
|
||||
|
@ -188,7 +188,7 @@ class PyDatasetAdapter(DataAdapter):
|
||||
self._output_signature = None
|
||||
|
||||
def _set_tf_output_signature(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
def get_tensor_spec(x):
|
||||
shape = x.shape
|
||||
@ -273,7 +273,7 @@ class PyDatasetAdapter(DataAdapter):
|
||||
self.enqueuer.stop()
|
||||
|
||||
def get_tf_dataset(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
if self._output_signature is None:
|
||||
self._set_tf_output_signature()
|
||||
|
@ -8,7 +8,7 @@ class TFDatasetAdapter(DataAdapter):
|
||||
"""Adapter that handles `tf.data.Dataset`."""
|
||||
|
||||
def __init__(self, dataset, class_weight=None):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
if not isinstance(dataset, tf.data.Dataset):
|
||||
raise ValueError(
|
||||
@ -64,7 +64,7 @@ def make_class_weight_map_fn(class_weight):
|
||||
A function that can be used with `tf.data.Dataset.map` to apply class
|
||||
weighting.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
class_weight_tensor = tf.convert_to_tensor(
|
||||
[
|
||||
|
@ -28,7 +28,7 @@ class TorchDataLoaderAdapter(DataAdapter):
|
||||
return self._dataloader
|
||||
|
||||
def get_tf_dataset(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
output_signature = self.peek_and_get_tensor_spec()
|
||||
return tf.data.Dataset.from_generator(
|
||||
@ -37,7 +37,7 @@ class TorchDataLoaderAdapter(DataAdapter):
|
||||
)
|
||||
|
||||
def peek_and_get_tensor_spec(self):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
batch_data = next(iter(self._dataloader))
|
||||
|
||||
|
@ -513,7 +513,7 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
|
||||
reason="Only tensorflow supports raggeds",
|
||||
)
|
||||
def test_trainer_with_raggeds(self, model_class):
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
def loss_fn(y, y_pred, sample_weight=None):
|
||||
return 0
|
||||
|
@ -5,7 +5,7 @@ from keras_core import backend as backend_module
|
||||
|
||||
def in_tf_graph():
|
||||
if "tensorflow" in sys.modules:
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
return not tf.executing_eagerly()
|
||||
return False
|
||||
|
@ -1,4 +1,5 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@keras_core_export("keras_core.utils.split_dataset")
|
||||
@ -35,8 +36,6 @@ def split_dataset(
|
||||
>>> int(right_ds.cardinality())
|
||||
200
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO: long-term, port implementation.
|
||||
return tf.keras.utils.split_dataset(
|
||||
dataset,
|
||||
@ -181,8 +180,6 @@ def image_dataset_from_directory(
|
||||
- if `color_mode` is `"rgba"`,
|
||||
there are 4 channels in the image tensors.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO: long-term, port implementation.
|
||||
return tf.keras.utils.image_dataset_from_directory(
|
||||
directory,
|
||||
@ -331,8 +328,6 @@ def timeseries_dataset_from_array(
|
||||
break
|
||||
```
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO: long-term, port implementation.
|
||||
return tf.keras.utils.timeseries_dataset_from_array(
|
||||
data,
|
||||
@ -452,8 +447,6 @@ def text_dataset_from_directory(
|
||||
of shape `(batch_size, num_classes)`, representing a one-hot
|
||||
encoding of the class index.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO: long-term, port implementation.
|
||||
return tf.keras.utils.text_dataset_from_directory(
|
||||
directory,
|
||||
@ -574,8 +567,6 @@ def audio_dataset_from_directory(
|
||||
of shape `(batch_size, num_classes)`, representing a one-hot
|
||||
encoding of the class index.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO: long-term, port implementation.
|
||||
return tf.keras.utils.audio_dataset_from_directory(
|
||||
directory,
|
||||
|
@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import tarfile
|
||||
import urllib
|
||||
@ -382,52 +383,91 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):
|
||||
return False
|
||||
|
||||
|
||||
# Below are gfile utils
|
||||
def is_remote_path(filepath):
|
||||
"""Returns `True` for paths that represent a remote GCS location."""
|
||||
# TODO: improve generality.
|
||||
if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Below are gfile-replacement utils.
|
||||
|
||||
|
||||
def _raise_if_no_gfile(path):
|
||||
raise ValueError(
|
||||
"Handling remote paths requires installing TensorFlow "
|
||||
f"(in order to use gfile). Received path: {path}"
|
||||
)
|
||||
|
||||
|
||||
def exists(path):
|
||||
if gfile.available:
|
||||
return gfile.exists(path)
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.exists(path)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return os.path.exists(path)
|
||||
|
||||
|
||||
def File(fname, mode="r"):
|
||||
if gfile.available:
|
||||
return gfile.GFile(fname, mode=mode)
|
||||
return open(fname, mode=mode)
|
||||
def File(path, mode="r"):
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.GFile(path, mode=mode)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return open(path, mode=mode)
|
||||
|
||||
|
||||
def join(path, *paths):
|
||||
if gfile.available:
|
||||
return gfile.join(path, *paths)
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.join(path, *paths)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return os.path.join(path, *paths)
|
||||
|
||||
|
||||
def isdir(path):
|
||||
if gfile.available:
|
||||
return gfile.isdir(path)
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.isdir(path)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return os.path.isdir(path)
|
||||
|
||||
|
||||
def rmtree(path):
|
||||
if gfile.available:
|
||||
return gfile.rmtree(path)
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.rmtree(path)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return shutil.rmtree
|
||||
|
||||
|
||||
def listdir(path):
|
||||
if gfile.available:
|
||||
return gfile.listdir(path)
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.listdir(path)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return os.listdir(path)
|
||||
|
||||
|
||||
def copy(src, dst):
|
||||
if gfile.available:
|
||||
return gfile.copy(src, dst)
|
||||
if is_remote_path(src) or is_remote_path(dst):
|
||||
if gfile.available:
|
||||
return gfile.copy(src, dst)
|
||||
else:
|
||||
_raise_if_no_gfile(f"src={src} dst={dst}")
|
||||
return shutil.copy(src, dst)
|
||||
|
||||
|
||||
def makedirs(path):
|
||||
if gfile.available:
|
||||
return gfile.makedirs(path)
|
||||
if is_remote_path(path):
|
||||
if gfile.available:
|
||||
return gfile.makedirs(path)
|
||||
else:
|
||||
_raise_if_no_gfile(path)
|
||||
return os.makedirs(path)
|
||||
|
@ -2,6 +2,7 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
@ -21,7 +22,7 @@ def set_random_seed(seed):
|
||||
```python
|
||||
import random
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
tf.random.set_seed(seed)
|
||||
@ -42,5 +43,9 @@ def set_random_seed(seed):
|
||||
)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
tf.random.set_seed(seed)
|
||||
# TODO: also seed other backends.
|
||||
if tf.available:
|
||||
tf.random.set_seed(seed)
|
||||
if backend.backend() == "torch":
|
||||
import torch
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
Loading…
Reference in New Issue
Block a user