Minor fixes
This commit is contained in:
parent
7988a18f27
commit
bea61227f1
@ -13,6 +13,7 @@ from keras_core.api_export import keras_core_export
|
||||
from keras_core.backend.common import global_state
|
||||
from keras_core.saving import object_registration
|
||||
from keras_core.utils import python_utils
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
PLAIN_TYPES = (str, int, float, bool)
|
||||
|
||||
@ -127,11 +128,6 @@ def serialize_keras_object(obj):
|
||||
A python dict that represents the object. The python dict can be
|
||||
deserialized via `deserialize_keras_object()`.
|
||||
"""
|
||||
if backend.backend() == "tensorflow":
|
||||
import tensorflow as tf
|
||||
else:
|
||||
tf = None
|
||||
|
||||
if obj is None:
|
||||
return obj
|
||||
|
||||
@ -163,7 +159,7 @@ def serialize_keras_object(obj):
|
||||
"keras_history": history,
|
||||
},
|
||||
}
|
||||
if tf is not None and isinstance(obj, tf.TensorShape):
|
||||
if tf.available and isinstance(obj, tf.TensorShape):
|
||||
return obj.as_list() if obj._dims is not None else None
|
||||
if backend.is_tensor(obj):
|
||||
return {
|
||||
@ -185,7 +181,7 @@ def serialize_keras_object(obj):
|
||||
else:
|
||||
# Treat numpy floats / etc as plain types.
|
||||
return obj.item()
|
||||
if tf is not None and isinstance(obj, tf.DType):
|
||||
if tf.available and isinstance(obj, tf.DType):
|
||||
return obj.name
|
||||
if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
|
||||
warnings.warn(
|
||||
@ -203,7 +199,7 @@ def serialize_keras_object(obj):
|
||||
"value": python_utils.func_dump(obj),
|
||||
},
|
||||
}
|
||||
if tf is not None and isinstance(obj, tf.TypeSpec):
|
||||
if tf.available and isinstance(obj, tf.TypeSpec):
|
||||
ts_config = obj._serialize()
|
||||
# TensorShape and tf.DType conversion
|
||||
ts_config = list(
|
||||
@ -472,11 +468,6 @@ def deserialize_keras_object(
|
||||
Returns:
|
||||
The object described by the `config` dictionary.
|
||||
"""
|
||||
if backend.backend() == "tensorflow":
|
||||
import tensorflow as tf
|
||||
else:
|
||||
tf = None
|
||||
|
||||
safe_scope_arg = in_safe_mode() # Enforces SafeModeScope
|
||||
safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode
|
||||
|
||||
|
@ -46,6 +46,7 @@ from keras_core.trainers.data_adapters import generator_data_adapter
|
||||
from keras_core.trainers.data_adapters import py_dataset_adapter
|
||||
from keras_core.trainers.data_adapters import tf_dataset_adapter
|
||||
from keras_core.trainers.data_adapters import torch_data_adapter
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
class EpochIterator:
|
||||
@ -60,11 +61,6 @@ class EpochIterator:
|
||||
class_weight=None,
|
||||
steps_per_execution=1,
|
||||
):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
tf = None
|
||||
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.steps_per_execution = steps_per_execution
|
||||
if steps_per_epoch:
|
||||
@ -80,7 +76,7 @@ class EpochIterator:
|
||||
batch_size=batch_size,
|
||||
steps=steps_per_epoch,
|
||||
)
|
||||
elif tf is not None and isinstance(x, tf.data.Dataset):
|
||||
elif tf.available and isinstance(x, tf.data.Dataset):
|
||||
self.data_adapter = tf_dataset_adapter.TFDatasetAdapter(
|
||||
x, class_weight=class_weight
|
||||
)
|
||||
@ -164,11 +160,6 @@ class EpochIterator:
|
||||
return iterator
|
||||
|
||||
def enumerate_epoch(self, return_type="np"):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
tf = None
|
||||
|
||||
buffer = []
|
||||
if self.steps_per_epoch:
|
||||
if not self._current_iterator:
|
||||
@ -179,10 +170,10 @@ class EpochIterator:
|
||||
if self._insufficient_data:
|
||||
break
|
||||
|
||||
if tf is None:
|
||||
errors = (StopIteration,)
|
||||
else:
|
||||
if tf.available:
|
||||
errors = (StopIteration, tf.errors.OutOfRangeError)
|
||||
else:
|
||||
errors = (StopIteration,)
|
||||
|
||||
try:
|
||||
data = next(self._current_iterator)
|
||||
|
@ -771,7 +771,7 @@ def resolve_auto_jit_compile(model):
|
||||
def model_supports_jit(model):
|
||||
if platform.system() == "Darwin" and "arm" in platform.processor().lower():
|
||||
if backend.backend() == "tensorflow":
|
||||
import tensorflow as tf
|
||||
from keras_core.utils.module_utils import tensorflow as tf
|
||||
|
||||
if tf.config.list_physical_devices("GPU"):
|
||||
return False
|
||||
|
Loading…
Reference in New Issue
Block a user