From bea61227f1c2dffe6d475e401b0e18037c0f7601 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 13 Jul 2023 20:09:58 -0700 Subject: [PATCH] Minor fixes --- keras_core/saving/serialization_lib.py | 17 ++++------------- keras_core/trainers/epoch_iterator.py | 19 +++++-------------- keras_core/trainers/trainer.py | 2 +- 3 files changed, 10 insertions(+), 28 deletions(-) diff --git a/keras_core/saving/serialization_lib.py b/keras_core/saving/serialization_lib.py index 7ad5caaff..86d8dc635 100644 --- a/keras_core/saving/serialization_lib.py +++ b/keras_core/saving/serialization_lib.py @@ -13,6 +13,7 @@ from keras_core.api_export import keras_core_export from keras_core.backend.common import global_state from keras_core.saving import object_registration from keras_core.utils import python_utils +from keras_core.utils.module_utils import tensorflow as tf PLAIN_TYPES = (str, int, float, bool) @@ -127,11 +128,6 @@ def serialize_keras_object(obj): A python dict that represents the object. The python dict can be deserialized via `deserialize_keras_object()`. """ - if backend.backend() == "tensorflow": - import tensorflow as tf - else: - tf = None - if obj is None: return obj @@ -163,7 +159,7 @@ def serialize_keras_object(obj): "keras_history": history, }, } - if tf is not None and isinstance(obj, tf.TensorShape): + if tf.available and isinstance(obj, tf.TensorShape): return obj.as_list() if obj._dims is not None else None if backend.is_tensor(obj): return { @@ -185,7 +181,7 @@ def serialize_keras_object(obj): else: # Treat numpy floats / etc as plain types. return obj.item() - if tf is not None and isinstance(obj, tf.DType): + if tf.available and isinstance(obj, tf.DType): return obj.name if isinstance(obj, types.FunctionType) and obj.__name__ == "": warnings.warn( @@ -203,7 +199,7 @@ def serialize_keras_object(obj): "value": python_utils.func_dump(obj), }, } - if tf is not None and isinstance(obj, tf.TypeSpec): + if tf.available and isinstance(obj, tf.TypeSpec): ts_config = obj._serialize() # TensorShape and tf.DType conversion ts_config = list( @@ -472,11 +468,6 @@ def deserialize_keras_object( Returns: The object described by the `config` dictionary. """ - if backend.backend() == "tensorflow": - import tensorflow as tf - else: - tf = None - safe_scope_arg = in_safe_mode() # Enforces SafeModeScope safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode diff --git a/keras_core/trainers/epoch_iterator.py b/keras_core/trainers/epoch_iterator.py index 212335229..5feb0da2e 100644 --- a/keras_core/trainers/epoch_iterator.py +++ b/keras_core/trainers/epoch_iterator.py @@ -46,6 +46,7 @@ from keras_core.trainers.data_adapters import generator_data_adapter from keras_core.trainers.data_adapters import py_dataset_adapter from keras_core.trainers.data_adapters import tf_dataset_adapter from keras_core.trainers.data_adapters import torch_data_adapter +from keras_core.utils.module_utils import tensorflow as tf class EpochIterator: @@ -60,11 +61,6 @@ class EpochIterator: class_weight=None, steps_per_execution=1, ): - try: - import tensorflow as tf - except ImportError: - tf = None - self.steps_per_epoch = steps_per_epoch self.steps_per_execution = steps_per_execution if steps_per_epoch: @@ -80,7 +76,7 @@ class EpochIterator: batch_size=batch_size, steps=steps_per_epoch, ) - elif tf is not None and isinstance(x, tf.data.Dataset): + elif tf.available and isinstance(x, tf.data.Dataset): self.data_adapter = tf_dataset_adapter.TFDatasetAdapter( x, class_weight=class_weight ) @@ -164,11 +160,6 @@ class EpochIterator: return iterator def enumerate_epoch(self, return_type="np"): - try: - import tensorflow as tf - except ImportError: - tf = None - buffer = [] if self.steps_per_epoch: if not self._current_iterator: @@ -179,10 +170,10 @@ class EpochIterator: if self._insufficient_data: break - if tf is None: - errors = (StopIteration,) - else: + if tf.available: errors = (StopIteration, tf.errors.OutOfRangeError) + else: + errors = (StopIteration,) try: data = next(self._current_iterator) diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index ec56d3bba..6eeb85ec2 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -771,7 +771,7 @@ def resolve_auto_jit_compile(model): def model_supports_jit(model): if platform.system() == "Darwin" and "arm" in platform.processor().lower(): if backend.backend() == "tensorflow": - import tensorflow as tf + from keras_core.utils.module_utils import tensorflow as tf if tf.config.list_physical_devices("GPU"): return False