f6df67f2d2
* Add jax/numpy and tensorflow/numpy * refactor code * more * even better
213 lines
5.6 KiB
Python
213 lines
5.6 KiB
Python
import json
|
|
import os
|
|
|
|
from keras_core.utils.io_utils import print_msg
|
|
|
|
# Session-wide default float dtype name, reported by `floatx()`.
_FLOATX: str = "float32"

# Small fuzz constant used in numeric expressions, reported by `epsilon()`.
_EPSILON: float = 1e-7

# Default image layout; either "channels_last" or "channels_first".
_IMAGE_DATA_FORMAT: str = "channels_last"

# Backend used unless keras.json or the KERAS_BACKEND variable overrides it.
_BACKEND: str = "tensorflow"
|
|
|
|
|
|
def floatx():
    """Get the name of the session-wide default float dtype.

    Typical values are `'float16'`, `'float32'` and `'float64'`.

    Returns:
        The current default float type, as a string.

    Example:

    >>> keras_core.backend.floatx()
    'float32'
    """
    current_dtype = _FLOATX
    return current_dtype
|
|
|
|
|
|
def set_floatx(value):
    """Set the default float type.

    Note: It is not recommended to set this to float16 for training, as this
    will likely cause numeric stability issues. Instead, mixed precision, which
    is using a mix of float16 and float32, can be used by calling
    `keras_core.mixed_precision.set_global_policy('mixed_float16')`. See the
    [mixed precision guide](
    https://www.tensorflow.org/guide/keras/mixed_precision) for details.

    Args:
        value: String; `'float16'`, `'float32'`, or `'float64'`.

    Example:

    >>> keras_core.backend.floatx()
    'float32'
    >>> keras_core.backend.set_floatx('float64')
    >>> keras_core.backend.floatx()
    'float64'
    >>> keras_core.backend.set_floatx('float32')

    Raises:
        ValueError: If `value` is not one of the accepted strings. This
            includes non-string values such as lists or `None`.
    """
    global _FLOATX
    # A tuple keeps the error message in a stable, deterministic order
    # (a set's iteration order varies between runs for strings).
    accepted_dtypes = ("float16", "float32", "float64")
    # Check the type first: `x in some_set` raises TypeError for unhashable
    # values (e.g. a list), but callers are promised a ValueError.
    if not isinstance(value, str) or value not in accepted_dtypes:
        raise ValueError(
            f"Unknown `floatx` value: {value}. "
            f"Expected one of {accepted_dtypes}"
        )
    _FLOATX = str(value)
|
|
|
|
|
|
def epsilon():
    """Get the fuzz factor added in numeric expressions.

    Returns:
        The current epsilon, as a float.

    Example:

    >>> keras_core.backend.epsilon()
    1e-07
    """
    fuzz = _EPSILON
    return fuzz
|
|
|
|
|
|
def set_epsilon(value):
    """Replace the fuzz factor added in numeric expressions.

    The value is stored as given; no type or range check is performed.

    Args:
        value: float. The new epsilon.

    Example:

    >>> keras_core.backend.epsilon()
    1e-07
    >>> keras_core.backend.set_epsilon(1e-5)
    >>> keras_core.backend.epsilon()
    1e-05
    >>> keras_core.backend.set_epsilon(1e-7)
    """
    global _EPSILON
    new_epsilon = value
    _EPSILON = new_epsilon
|
|
|
|
|
|
def image_data_format():
    """Get the default image data layout convention.

    Returns:
        Either `'channels_first'` or `'channels_last'`, as a string.

    Example:

    >>> keras_core.backend.image_data_format()
    'channels_last'
    """
    layout = _IMAGE_DATA_FORMAT
    return layout
|
|
|
|
|
|
def set_image_data_format(data_format):
    """Set the value of the image data format convention.

    Args:
        data_format: string. `'channels_first'` or `'channels_last'`.

    Example:

    >>> keras_core.backend.image_data_format()
    'channels_last'
    >>> keras_core.backend.set_image_data_format('channels_first')
    >>> keras_core.backend.image_data_format()
    'channels_first'
    >>> keras_core.backend.set_image_data_format('channels_last')

    Raises:
        ValueError: If `data_format` is not one of the accepted strings.
            This includes non-string values such as lists or `None`.
    """
    global _IMAGE_DATA_FORMAT
    # A tuple keeps the error message in a stable, deterministic order
    # (a set's iteration order varies between runs for strings).
    accepted_formats = ("channels_last", "channels_first")
    # Check the type first: `x in some_set` raises TypeError for unhashable
    # values (e.g. a list), but callers are promised a ValueError.
    if not isinstance(data_format, str) or data_format not in accepted_formats:
        raise ValueError(
            f"Unknown `data_format`: {data_format}. "
            f"Expected one of {accepted_formats}"
        )
    _IMAGE_DATA_FORMAT = str(data_format)
|
|
|
|
|
|
# Resolve the Keras base directory: $KERAS_HOME when it is set to a
# non-empty value, otherwise ~/.keras (or /tmp/.keras when the home
# directory is not writable). An empty KERAS_HOME is treated as unset,
# matching how an empty KERAS_BACKEND is ignored.
_keras_home = os.environ.get("KERAS_HOME")
if _keras_home:
    # Expand "~" here so that every later use of `_keras_dir` (including
    # directory creation) refers to the same path as the config file.
    _keras_dir = os.path.expanduser(_keras_home)
else:
    _keras_base_dir = os.path.expanduser("~")
    if not os.access(_keras_base_dir, os.W_OK):
        _keras_base_dir = "/tmp"
    _keras_dir = os.path.join(_keras_base_dir, ".keras")
|
|
|
|
|
|
# Attempt to read the Keras config file. Keys missing from the file keep
# the built-in defaults defined above.
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras.json"))
if os.path.exists(_config_path):
    try:
        with open(_config_path, encoding="utf-8") as f:
            _config = json.load(f)
    except (OSError, ValueError):
        # The file exists but cannot be opened (e.g. permission denied) or
        # is not valid JSON: ignore it rather than failing at import time.
        _config = {}
    if not isinstance(_config, dict):
        # Valid JSON that is not an object (e.g. `[]`) carries no settings.
        _config = {}

    _floatx = _config.get("floatx", floatx())
    _epsilon = _config.get("epsilon", epsilon())
    _backend = _config.get("backend", _BACKEND)
    _image_data_format = _config.get("image_data_format", image_data_format())

    # Validate explicitly rather than with `assert`, which `python -O`
    # strips. JSON integers such as `1` are acceptable epsilons; bools
    # (a subclass of int) are not.
    if isinstance(_epsilon, bool) or not isinstance(_epsilon, (int, float)):
        raise ValueError(
            f"Invalid `epsilon` value in {_config_path}: {_epsilon!r}. "
            "Expected a number."
        )

    # `set_floatx` and `set_image_data_format` raise ValueError for values
    # outside their accepted sets, so they double as validation here.
    set_floatx(_floatx)
    set_epsilon(float(_epsilon))
    set_image_data_format(_image_data_format)
    _BACKEND = _backend
|
|
|
|
# Write out the effective settings as keras.json when no such file exists.
# Both steps are best-effort: OSErrors (permission denied, another process
# creating the directory first, ...) are deliberately ignored.
if not os.path.exists(_keras_dir):
    try:
        os.makedirs(_keras_dir)
    except OSError:
        pass

if not os.path.exists(_config_path):
    _config = dict(
        floatx=floatx(),
        epsilon=epsilon(),
        backend=_BACKEND,
        image_data_format=image_data_format(),
    )
    try:
        with open(_config_path, "w") as f:
            f.write(json.dumps(_config, indent=4))
    except OSError:
        # `IOError` is an alias of `OSError`; permission errors land here.
        pass
|
|
|
|
# The KERAS_BACKEND environment variable, when present and non-empty,
# overrides the backend chosen by the defaults / keras.json above.
if "KERAS_BACKEND" in os.environ:
    _backend = os.environ["KERAS_BACKEND"]
    if len(_backend) > 0:
        _BACKEND = _backend
|
|
|
|
|
|
def backend():
    """Report which backend Keras is currently using.

    Returns:
        The backend name as a string, e.g. `'tensorflow'`.

    Example:

    >>> keras_core.backend.backend()
    'tensorflow'
    """
    active_backend = _BACKEND
    return active_backend
|