keras/keras_core/backend/common.py

138 lines
3.7 KiB
Python
Raw Normal View History

2023-04-09 19:21:45 +00:00
from keras_core.backend.config import floatx
from tensorflow import nest
import threading
class KerasVariable:
def __init__(self, value, dtype, trainable=True, name=None):
raise NotImplementedError
@property
def value(self):
raise NotImplementedError
@property
def dtype(self):
raise NotImplementedError
@property
def shape(self):
raise NotImplementedError
@property
def ndim(self):
raise NotImplementedError
def numpy(self):
raise NotImplementedError
def assign(self, value):
raise NotImplementedError
def __repr__(self):
return f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
2023-04-09 19:21:45 +00:00
ALLOWED_DTYPES = {
"float16",
"float32",
"float64",
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"bfloat16",
}
def standardize_dtype(dtype):
if dtype is None:
return floatx()
if hasattr(dtype, "name"):
dtype = dtype.name
if dtype not in ALLOWED_DTYPES:
raise ValueError(f"Invalid dtype: {dtype}")
return dtype
def standardize_shape(shape, fully_defined=False):
if not isinstance(shape, tuple):
if shape is None:
raise ValueError("Undefined shapes are not supported.")
if not hasattr(shape, "__iter__"):
raise ValueError(f"Cannot convert '{shape}' to a shape.")
shape = tuple(shape)
for e in shape:
if not fully_defined and e is None:
continue
if not isinstance(e, int):
raise ValueError(
f"Cannot convert '{shape}' to a shape. Found invalid entry '{e}'"
)
if e < 0:
raise ValueError(
f"Cannot convert '{shape}' to a shape. Negative dimensions are not allowed."
)
return shape
### Stateless context manager
GLOBAL_SCOPE_TRACKER = threading.local()
class StatelessScope:
def __init__(self, state_mapping=None, collect_losses=False):
from keras_core import backend
self.collect_losses = collect_losses
self.losses = []
self.state_mapping = {}
state_mapping = state_mapping or {}
for k, v in state_mapping:
if not isinstance(k, KerasVariable):
raise ValueError(
"Invalid reference variable in VariableSwapScope: "
"all keys in argument `mapping` must be KerasVariable "
f"instances. Received instead: {k}"
)
v = backend.convert_to_tensor(v, dtype=k.dtype)
if k.shape != v.shape:
raise ValueError(
"Invalid variable value in VariableSwapScope: "
"all values in argument `mapping` must be tensors with "
"a shape that matches the corresponding variable shape. "
f"For variable {k}, received invalid value {v} with shape {v.shape}."
)
self.state_mapping[id(k)] = v
def __enter__(self):
self.original_scope = get_stateless_scope()
GLOBAL_SCOPE_TRACKER.stateless_scope = self
return self
def add_loss(self, loss):
self.losses.append(loss)
def add_update(self, update):
variable, value = update
self.state_mapping[id(variable)] = value
def get_current_value(self, variable):
return self.state_mapping.get(id(variable), None)
def __exit__(self, *args, **kwargs):
GLOBAL_SCOPE_TRACKER.stateless_scope = self.original_scope
def in_stateless_scope():
return getattr(GLOBAL_SCOPE_TRACKER, "stateless_scope", None) is not None
def get_stateless_scope():
return getattr(GLOBAL_SCOPE_TRACKER, "stateless_scope", None)