keras/keras_core/utils/tracking.py

159 lines
4.7 KiB
Python
Raw Normal View History

2023-05-03 22:33:40 +00:00
import functools

from keras_core.backend.common.global_state import get_global_attribute
from keras_core.backend.common.global_state import set_global_attribute
2023-04-12 22:20:56 +00:00
class DotNotTrackScope:
    """Context manager that switches automatic tracking off while active.

    Entering records the current tracking state and clears the global
    "tracking_on" flag; exiting writes the recorded state back. Each instance
    keeps its own saved state, so separately-created nested scopes unwind
    correctly. Exceptions raised inside the scope are not suppressed.
    """

    def __enter__(self):
        previous_state = is_tracking_enabled()
        self.original_value = previous_state
        set_global_attribute("tracking_on", False)

    def __exit__(self, *args, **kwargs):
        state_to_restore = self.original_value
        set_global_attribute("tracking_on", state_to_restore)
2023-04-12 22:20:56 +00:00
def is_tracking_enabled():
    """Report whether automatic attribute tracking is currently active.

    Reads the global "tracking_on" attribute, passing True as the fallback
    value to use when it has not been set.
    """
    tracking_flag = get_global_attribute("tracking_on", True)
    return tracking_flag
2023-04-12 22:20:56 +00:00
def no_automatic_dependency_tracking(fn):
    """Decorate `fn` so that it always runs with automatic tracking disabled.

    Each call to the returned wrapper executes `fn` inside a
    `DotNotTrackScope`, so the previous tracking state is restored when the
    call returns or raises.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same call signature as `fn`. Its metadata
        (`__name__`, `__doc__`, `__wrapped__`, ...) is copied from `fn`.
    """

    # `functools.wraps` keeps the decorated function introspectable
    # (name, docstring, signature via `__wrapped__`).
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with DotNotTrackScope():
            return fn(*args, **kwargs)

    return wrapper
2023-04-09 19:21:45 +00:00
class Tracker:
    """Attribute tracker, used for e.g. Variable tracking.

    Monitors certain attribute types
    and puts them in the appropriate lists in case of a match.

    Also passively tracks certain mutable collections
    (dict, list, set) so that items added to them later
    still get tracked. This is done by wrapping these
    collections into an equivalent, tracking-aware object.

    Usage:

    ```python
    def __init__(self):
        self._tracker = Tracker(
            # Format: `name: (test_fn, store)`
            {
                "variables":
                    (lambda x: isinstance(x, Variable), self._variables),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
                "layers": (lambda x: isinstance(x, Layer), self._layers),
            }
        )

    def __setattr__(self, name, value):
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
    ```
    """

    def __init__(self, config):
        # Maps a store name to a `(predicate, store)` pair. `store` must
        # support `append`.
        self.config = config
        # `id()` of every object already added to each store. This prevents
        # the same object from being appended twice to one store.
        self.stored_ids = {name: set() for name in self.config.keys()}
        self.locked = False
        self._lock_violation_msg = None

    def track(self, attr):
        """Register `attr` (or its contents) with the matching store.

        Args:
            attr: Any value about to be assigned or inserted.

        Returns:
            The value to actually store:
            - `attr` itself when tracking is disabled, when it matches a
              configured predicate, or when it is not a tuple, list, dict or
              set.
            - For tuples (including namedtuples), a new tuple of the same
              class whose elements have been tracked.
            - For lists, dicts and sets, a new `TrackedList` / `TrackedDict` /
              `TrackedSet` built from `attr`'s contents, so that later
              insertions are tracked too.

        Raises:
            ValueError: If the tracker is locked and `attr` would be added to
                a store.
        """
        if not is_tracking_enabled():
            return attr

        # The first store whose predicate accepts `attr` wins.
        for name, (is_attr_type, _) in self.config.items():
            if is_attr_type(attr):
                if id(attr) not in self.stored_ids[name]:
                    self.add_to_store(name, attr)
                return attr

        if isinstance(attr, tuple) and hasattr(attr, "_fields"):
            # Namedtuple: its constructor takes one positional argument per
            # field, not a single iterable.
            return attr.__class__(*[self.track(e) for e in attr])
        if isinstance(attr, tuple):
            return attr.__class__([self.track(e) for e in attr])
        if isinstance(attr, list):
            return TrackedList(attr, self)
        if isinstance(attr, dict):
            # TODO: OrderedDict?
            return TrackedDict(attr, self)
        if isinstance(attr, set):
            return TrackedSet(attr, self)
        return attr

    def lock(self, msg):
        """Forbid further additions to any store.

        Args:
            msg: Message of the `ValueError` raised by later additions.
        """
        self.locked = True
        self._lock_violation_msg = msg

    def add_to_store(self, store_name, value):
        """Append `value` to the store named `store_name` and remember its id.

        Raises:
            ValueError: If the tracker has been locked via `lock()`.
        """
        if self.locked:
            raise ValueError(self._lock_violation_msg)
        self.config[store_name][1].append(value)
        self.stored_ids[store_name].add(id(value))
2023-04-09 19:21:45 +00:00
class TrackedList(list):
    """A `list` that passes newly added items through a `Tracker`.

    Items added at construction and via `append`, `insert`, `extend`, `+=`
    and item/slice assignment are given to `tracker.track()`. When `tracker`
    is None (or falsy) the object behaves like a plain list.
    """

    # TODO: override item removal methods?
    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = [tracker.track(v) for v in values]
        super().__init__(values or [])

    def append(self, value):
        if self.tracker:
            self.tracker.track(value)
        super().append(value)

    def insert(self, index, value):
        # Same signature as `list.insert(index, object)`.
        if self.tracker:
            self.tracker.track(value)
        super().insert(index, value)

    def extend(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().extend(values)

    def __iadd__(self, values):
        # `list.__iadd__` is implemented in C and does not call `extend`,
        # so without this override `+=` would bypass tracking.
        self.extend(values)
        return self

    def __setitem__(self, index, value):
        # Both `lst[i] = x` and `lst[a:b] = iterable` insert new items.
        if self.tracker:
            if isinstance(index, slice):
                value = [self.tracker.track(v) for v in value]
            else:
                self.tracker.track(value)
        super().__setitem__(index, value)
class TrackedDict(dict):
    """A `dict` whose newly added values are passed through a `Tracker`.

    Only values are tracked, not keys. Values added at construction and via
    item assignment, `update` and `setdefault` are given to
    `tracker.track()`. When `tracker` is None (or falsy) the object behaves
    like a plain dict.
    """

    # TODO: override item removal methods?
    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            # `dict(values)` accepts a mapping or an iterable of key/value
            # pairs, just like the `dict` constructor.
            values = {k: tracker.track(v) for k, v in dict(values).items()}
        super().__init__(values or [])

    def __setitem__(self, key, value):
        if self.tracker:
            self.tracker.track(value)
        super().__setitem__(key, value)

    def update(self, mapping=(), **kwargs):
        # Accepts what `dict.update` accepts: a mapping or an iterable of
        # key/value pairs, plus keyword arguments (which take precedence).
        if self.tracker:
            items = dict(mapping, **kwargs)
            tracked = {k: self.tracker.track(v) for k, v in items.items()}
            super().update(tracked)
        else:
            super().update(mapping, **kwargs)

    def setdefault(self, key, default=None):
        # `dict.setdefault` is implemented in C and bypasses `__setitem__`.
        if key not in self:
            self[key] = default
        return self[key]
class TrackedSet(set):
    """A `set` that passes newly added elements through a `Tracker`.

    Elements added at construction and via `add` and `update` are given to
    `tracker.track()`. When `tracker` is None (or falsy) the object behaves
    like a plain set.
    """

    # TODO: override item removal methods?
    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {tracker.track(v) for v in values}
        super().__init__(values or [])

    def add(self, value):
        if self.tracker:
            self.tracker.track(value)
        super().add(value)

    def update(self, values=(), *others):
        # Mirrors `set.update(*others)`: accepts any number of iterables,
        # including none. The first one keeps its original name `values`.
        iterables = (values,) + others
        if self.tracker:
            iterables = [[self.tracker.track(v) for v in it] for it in iterables]
        super().update(*iterables)