keras/keras_core/utils/tracking.py
2023-07-07 16:33:53 -07:00

162 lines
4.8 KiB
Python

from functools import wraps
from keras_core.backend.common.global_state import get_global_attribute
from keras_core.backend.common.global_state import set_global_attribute
class DotNotTrackScope:
    """Context manager that switches automatic tracking off for its body.

    On entry the current ``tracking_on`` global flag is remembered and then
    forced to ``False``; on exit the remembered value is written back, so
    nested scopes restore whatever state was active before them.
    (The class name is kept as-is for backward compatibility.)
    """

    def __enter__(self):
        previous_state = is_tracking_enabled()
        self.original_value = previous_state
        set_global_attribute("tracking_on", False)

    def __exit__(self, *args, **kwargs):
        # Restore the flag unconditionally; exceptions are not suppressed
        # because this method returns None.
        restored_state = self.original_value
        set_global_attribute("tracking_on", restored_state)
def is_tracking_enabled():
    """Report whether automatic attribute tracking is currently active.

    Reads the global ``tracking_on`` attribute, passing ``True`` as the
    fallback value to ``get_global_attribute``.
    """
    enabled = get_global_attribute("tracking_on", True)
    return enabled
def no_automatic_dependency_tracking(fn):
    """Decorate ``fn`` so every call runs with tracking disabled.

    The wrapped function executes inside a ``DotNotTrackScope``; its return
    value is passed through unchanged and the previous tracking state is
    restored afterwards (also when ``fn`` raises).
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        scope = DotNotTrackScope()
        with scope:
            outcome = fn(*args, **kwargs)
        return outcome

    return wrapper
class Tracker:
    """Attribute tracker, used for e.g. Variable tracking.

    Monitors certain attribute types
    and puts them in appropriate lists in case of a match.

    Also passively tracks certain mutable collections
    (dict, list, set) so that items added to them later
    still get tracked. This is done by wrapping these
    collections into an equivalent, tracking-aware object.

    Usage:

    ```python
    def __init__(self):
        self._tracker = Tracker(
            # Format: `name: (test_fn, store)`
            {
                "variables":
                    (lambda x: isinstance(x, Variable), self._variables),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
                "layers": (lambda x: isinstance(x, Layer), self._layers),
            }
        )

    def __setattr__(self, name, value):
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
    ```
    """

    def __init__(self, config):
        """Create a tracker.

        Args:
            config: Dict mapping a store name to a ``(test_fn, store)`` pair.
                ``test_fn(value)`` decides whether ``value`` belongs to the
                store; ``store`` must support ``append`` and receives every
                newly matched value.
        """
        self.config = config
        # Per-store set of `id()`s already appended, so the same object is
        # never added to a store twice.
        self.stored_ids = {name: set() for name in self.config.keys()}
        self.locked = False
        self._lock_violation_msg = None

    def track(self, attr):
        """Register `attr` (or its contents) and return the value to keep.

        - If tracking is disabled, `attr` is returned untouched.
        - If `attr` matches a store's test function, it is appended to the
          first matching store (once) and returned as-is.
        - Tuples (including namedtuples) are rebuilt with every element
          tracked.
        - Lists, dicts and sets are wrapped in `TrackedList`, `TrackedDict`
          and `TrackedSet` so later insertions are tracked as well.
        - Anything else is returned unchanged.

        Raises:
            ValueError: If a new value must be stored while locked.
        """
        if not is_tracking_enabled():
            return attr

        for name, (is_attr_type, _) in self.config.items():
            if is_attr_type(attr):
                if id(attr) not in self.stored_ids[name]:
                    self.add_to_store(name, attr)
                return attr
        if isinstance(attr, tuple) and hasattr(attr, "_fields"):
            # Namedtuple: its constructor takes one argument per field, so
            # it cannot be rebuilt from a single list like a plain tuple.
            return attr.__class__(*[self.track(e) for e in attr])
        if isinstance(attr, tuple):
            return attr.__class__([self.track(e) for e in attr])
        elif isinstance(attr, list):
            return TrackedList(attr, self)
        elif isinstance(attr, dict):
            # TODO: OrderedDict?
            return TrackedDict(attr, self)
        elif isinstance(attr, set):
            return TrackedSet(attr, self)
        return attr

    def lock(self, msg):
        """Forbid adding new values to any store.

        After locking, `add_to_store` raises `ValueError(msg)`. Values that
        are already stored can still be passed to `track` without error.
        """
        self.locked = True
        self._lock_violation_msg = msg

    def add_to_store(self, store_name, value):
        """Append `value` to the store named `store_name` and remember its id.

        Raises:
            ValueError: If the tracker is locked (see `lock`).
        """
        if self.locked:
            raise ValueError(self._lock_violation_msg)
        self.config[store_name][1].append(value)
        self.stored_ids[store_name].add(id(value))
class TrackedList(list):
    """A `list` that routes newly added items through a tracker.

    Items given at construction and via `append`, `insert` and `extend` are
    passed to `tracker.track`, and the value it returns is what gets stored
    (so nested containers are kept in their tracking-aware form). With no
    tracker the list behaves like a plain `list`.

    Args:
        values: Optional iterable of initial items.
        tracker: Optional object exposing `track(value)`, e.g. a `Tracker`.
    """

    # TODO: override item removal methods, `__setitem__` and `+=`?

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = [tracker.track(v) for v in values]
        super().__init__(values or [])

    def append(self, value):
        if self.tracker:
            value = self.tracker.track(value)
        super().append(value)

    def insert(self, index, value):
        # Mirrors `list.insert(index, object)`; the former one-argument
        # signature could never succeed.
        if self.tracker:
            value = self.tracker.track(value)
        super().insert(index, value)

    def extend(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().extend(values)
class TrackedDict(dict):
    """A `dict` that routes newly assigned values through a tracker.

    Values given at construction and via item assignment or `update` are
    passed to `tracker.track`, and the value it returns is what gets stored.
    Keys are stored unchanged. With no tracker the dict behaves like a plain
    `dict`.

    Args:
        values: Optional mapping (or iterable of key/value pairs).
        tracker: Optional object exposing `track(value)`, e.g. a `Tracker`.
    """

    # TODO: override item removal methods and `setdefault`?

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {k: tracker.track(v) for k, v in dict(values).items()}
        super().__init__(values or [])

    def __setitem__(self, key, value):
        if self.tracker:
            value = self.tracker.track(value)
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        # Accept every form `dict.update` accepts (mapping, iterable of
        # pairs, keyword arguments, or nothing) by normalizing via `dict()`.
        mapping = dict(*args, **kwargs)
        if self.tracker:
            mapping = {k: self.tracker.track(v) for k, v in mapping.items()}
        super().update(mapping)
class TrackedSet(set):
    """A `set` that routes newly added elements through a tracker.

    Elements given at construction and via `add` or `update` are passed to
    `tracker.track`, and the value it returns is what gets stored. With no
    tracker the set behaves like a plain `set`.

    Args:
        values: Optional iterable of initial elements.
        tracker: Optional object exposing `track(value)`, e.g. a `Tracker`.
    """

    # TODO: override item removal methods and in-place operators?

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {tracker.track(v) for v in values}
        super().__init__(values or [])

    def add(self, value):
        if self.tracker:
            value = self.tracker.track(value)
        super().add(value)

    def update(self, *others):
        # Like `set.update`, accept any number of iterables.
        if self.tracker:
            others = [[self.tracker.track(v) for v in other] for other in others]
        super().update(*others)