Minor fixes

Francois Chollet 2023-07-13 20:09:58 -07:00
parent 7988a18f27
commit bea61227f1
3 changed files with 10 additions and 28 deletions

@@ -13,6 +13,7 @@ from keras_core.api_export import keras_core_export
 from keras_core.backend.common import global_state
 from keras_core.saving import object_registration
 from keras_core.utils import python_utils
+from keras_core.utils.module_utils import tensorflow as tf
 PLAIN_TYPES = (str, int, float, bool)
@@ -127,11 +128,6 @@ def serialize_keras_object(obj):
         A python dict that represents the object. The python dict can be
         deserialized via `deserialize_keras_object()`.
     """
-    if backend.backend() == "tensorflow":
-        import tensorflow as tf
-    else:
-        tf = None
     if obj is None:
         return obj
@@ -163,7 +159,7 @@ def serialize_keras_object(obj):
                 "keras_history": history,
             },
         }
-    if tf is not None and isinstance(obj, tf.TensorShape):
+    if tf.available and isinstance(obj, tf.TensorShape):
         return obj.as_list() if obj._dims is not None else None
     if backend.is_tensor(obj):
         return {
@@ -185,7 +181,7 @@ def serialize_keras_object(obj):
         else:
             # Treat numpy floats / etc as plain types.
             return obj.item()
-    if tf is not None and isinstance(obj, tf.DType):
+    if tf.available and isinstance(obj, tf.DType):
         return obj.name
     if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
         warnings.warn(
@@ -203,7 +199,7 @@ def serialize_keras_object(obj):
                 "value": python_utils.func_dump(obj),
             },
         }
-    if tf is not None and isinstance(obj, tf.TypeSpec):
+    if tf.available and isinstance(obj, tf.TypeSpec):
         ts_config = obj._serialize()
         # TensorShape and tf.DType conversion
         ts_config = list(
@@ -472,11 +468,6 @@ def deserialize_keras_object(
     Returns:
         The object described by the `config` dictionary.
     """
-    if backend.backend() == "tensorflow":
-        import tensorflow as tf
-    else:
-        tf = None
     safe_scope_arg = in_safe_mode()  # Enforces SafeModeScope
     safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode
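Note: the diff above assumes that `keras_core.utils.module_utils` exposes a TensorFlow proxy with an `available` flag. As a rough, hedged sketch (the `LazyModule` name and details below are assumptions for illustration, not the actual keras_core implementation), such a lazy-import wrapper could look like:

    import importlib
    import importlib.util

    class LazyModule:
        """Proxy that imports the wrapped module only when first accessed."""

        def __init__(self, name):
            self.name = name
            self._module = None

        @property
        def available(self):
            # True if the module can be imported on this system; checking
            # this does not actually import it yet.
            return importlib.util.find_spec(self.name) is not None

        def __getattr__(self, attr):
            # Import lazily on first attribute access, then delegate.
            if self._module is None:
                self._module = importlib.import_module(self.name)
            return getattr(self._module, attr)

    # Module-level singleton mirroring the import used in the diff.
    tensorflow = LazyModule("tensorflow")

With a wrapper like this, `tf.available` never triggers the TensorFlow import, while `tf.TensorShape` or `tf.data.Dataset` imports it on demand.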

@@ -46,6 +46,7 @@ from keras_core.trainers.data_adapters import generator_data_adapter
 from keras_core.trainers.data_adapters import py_dataset_adapter
 from keras_core.trainers.data_adapters import tf_dataset_adapter
 from keras_core.trainers.data_adapters import torch_data_adapter
+from keras_core.utils.module_utils import tensorflow as tf
 class EpochIterator:
@@ -60,11 +61,6 @@ class EpochIterator:
         class_weight=None,
         steps_per_execution=1,
     ):
-        try:
-            import tensorflow as tf
-        except ImportError:
-            tf = None
         self.steps_per_epoch = steps_per_epoch
         self.steps_per_execution = steps_per_execution
         if steps_per_epoch:
@@ -80,7 +76,7 @@ class EpochIterator:
                 batch_size=batch_size,
                 steps=steps_per_epoch,
             )
-        elif tf is not None and isinstance(x, tf.data.Dataset):
+        elif tf.available and isinstance(x, tf.data.Dataset):
             self.data_adapter = tf_dataset_adapter.TFDatasetAdapter(
                 x, class_weight=class_weight
             )
@@ -164,11 +160,6 @@ class EpochIterator:
         return iterator
     def enumerate_epoch(self, return_type="np"):
-        try:
-            import tensorflow as tf
-        except ImportError:
-            tf = None
         buffer = []
         if self.steps_per_epoch:
             if not self._current_iterator:
@@ -179,10 +170,10 @@ class EpochIterator:
                 if self._insufficient_data:
                     break
-                if tf is None:
-                    errors = (StopIteration,)
-                else:
+                if tf.available:
                     errors = (StopIteration, tf.errors.OutOfRangeError)
+                else:
+                    errors = (StopIteration,)
                 try:
                     data = next(self._current_iterator)
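The `errors` tuple in the hunk above lets a single `except` clause absorb end-of-data signals from both plain Python iterators and `tf.data` iterators. A small standalone sketch of the same pattern (the `drain` helper below is hypothetical, not part of keras_core):

    from keras_core.utils.module_utils import tensorflow as tf

    def drain(iterator):
        # Hypothetical helper: yield items until the iterator is exhausted,
        # treating tf.data's OutOfRangeError like an ordinary StopIteration.
        if tf.available:
            errors = (StopIteration, tf.errors.OutOfRangeError)
        else:
            errors = (StopIteration,)
        while True:
            try:
                yield next(iterator)
            except errors:
                return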

@@ -771,7 +771,7 @@ def resolve_auto_jit_compile(model):
 def model_supports_jit(model):
     if platform.system() == "Darwin" and "arm" in platform.processor().lower():
         if backend.backend() == "tensorflow":
-            import tensorflow as tf
+            from keras_core.utils.module_utils import tensorflow as tf
             if tf.config.list_physical_devices("GPU"):
                 return False
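In this last hunk no `available` check is needed: the surrounding branch already requires the TensorFlow backend, so the lazy module's first attribute access (`tf.config`) simply performs the import. When that guarantee is absent, the guard comes back, e.g. (the `_tf_gpu_present` helper is illustrative only, not keras_core API):

    from keras_core.utils.module_utils import tensorflow as tf

    def _tf_gpu_present():
        # Illustrative helper: check availability first, then let attribute
        # access trigger the actual TensorFlow import.
        if not tf.available:
            return False
        return bool(tf.config.list_physical_devices("GPU"))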