2023-04-18 15:52:21 +00:00
|
|
|
import threading
|
|
|
|
|
2023-04-18 21:49:38 +00:00
|
|
|
from keras_core.backend.common import KerasVariable
|
2023-04-18 15:52:21 +00:00
|
|
|
|
2023-04-20 01:38:56 +00:00
|
|
|
### Stateless context manager
|
|
|
|
|
2023-04-18 15:52:21 +00:00
|
|
|
GLOBAL_SCOPE_TRACKER = threading.local()
|
|
|
|
|
|
|
|
|
|
|
|
class StatelessScope:
|
|
|
|
def __init__(self, state_mapping=None, collect_losses=False):
|
|
|
|
from keras_core import backend
|
|
|
|
|
|
|
|
self.collect_losses = collect_losses
|
|
|
|
self.losses = []
|
|
|
|
self.state_mapping = {}
|
|
|
|
state_mapping = state_mapping or {}
|
|
|
|
for k, v in state_mapping:
|
|
|
|
if not isinstance(k, KerasVariable):
|
|
|
|
raise ValueError(
|
|
|
|
"Invalid reference variable in VariableSwapScope: "
|
|
|
|
"all keys in argument `mapping` must be KerasVariable "
|
|
|
|
f"instances. Received instead: {k}"
|
|
|
|
)
|
|
|
|
v = backend.convert_to_tensor(v, dtype=k.dtype)
|
|
|
|
if k.shape != v.shape:
|
|
|
|
raise ValueError(
|
|
|
|
"Invalid variable value in VariableSwapScope: "
|
|
|
|
"all values in argument `mapping` must be tensors with "
|
|
|
|
"a shape that matches the corresponding variable shape. "
|
2023-04-22 06:16:51 +00:00
|
|
|
f"For variable {k}, received invalid value {v} with shape "
|
|
|
|
f"{v.shape}."
|
2023-04-18 15:52:21 +00:00
|
|
|
)
|
|
|
|
self.state_mapping[id(k)] = v
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
self.original_scope = get_stateless_scope()
|
|
|
|
GLOBAL_SCOPE_TRACKER.stateless_scope = self
|
|
|
|
return self
|
|
|
|
|
|
|
|
def add_loss(self, loss):
|
|
|
|
self.losses.append(loss)
|
|
|
|
|
|
|
|
def add_update(self, update):
|
|
|
|
variable, value = update
|
|
|
|
self.state_mapping[id(variable)] = value
|
|
|
|
|
|
|
|
def get_current_value(self, variable):
|
|
|
|
return self.state_mapping.get(id(variable), None)
|
|
|
|
|
|
|
|
def __exit__(self, *args, **kwargs):
|
|
|
|
GLOBAL_SCOPE_TRACKER.stateless_scope = self.original_scope
|
|
|
|
|
|
|
|
|
|
|
|
def in_stateless_scope():
|
|
|
|
return getattr(GLOBAL_SCOPE_TRACKER, "stateless_scope", None) is not None
|
|
|
|
|
|
|
|
|
|
|
|
def get_stateless_scope():
|
|
|
|
return getattr(GLOBAL_SCOPE_TRACKER, "stateless_scope", None)
|