97 lines
2.9 KiB
Python
97 lines
2.9 KiB
Python
from keras_core import backend
|
|
from keras_core import initializers
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.utils.naming import auto_name
|
|
from keras_core.utils.tracking import Tracker
|
|
|
|
|
|
@keras_core_export(["keras_core.Metric", "keras_core.metrics.Metric"])
class Metric:
    """Base class for stateful metrics.

    Subclasses are expected to override `update_state()` and `result()`;
    both raise `NotImplementedError` in this base class. State is held in
    backend variables, normally created through `add_variable()`.

    Once `__init__()` has run, every attribute assignment is passed through
    an internal `Tracker` (see `__setattr__`), which is configured to
    recognize `backend.Variable` and `Metric` instances and to record them
    in `self._variables` and `self._metrics` respectively.

    Args:
        dtype: Stored as-is and exposed through the `dtype` property.
        name: Optional metric name. When falsy, a name is generated by
            `auto_name()` from the class name.
    """

    def __init__(self, dtype=None, name=None):
        self.name = name or auto_name(self.__class__.__name__)
        self._dtype = dtype
        self._metrics = []
        self._variables = []
        # `_tracker` must be assigned last: `__setattr__` only starts
        # tracking once this attribute exists, and `_check_super_called()`
        # uses its presence to detect that this `__init__()` has run.
        self._tracker = Tracker(
            {
                "variables": (
                    lambda x: isinstance(x, backend.Variable),
                    self._variables,
                ),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
            }
        )

    def reset_state(self):
        """Reset all of the metric state variables.

        This function is called between epochs/steps,
        when a metric is evaluated during training.
        """
        # NOTE(review): a scalar 0 is assigned to every variable regardless
        # of its shape; confirm the backend `assign()` accepts this for
        # non-scalar state.
        for v in self.variables:
            v.assign(0)

    def update_state(self, *args, **kwargs):
        """Accumulate statistics for the metric."""
        raise NotImplementedError

    def result(self):
        """Compute the current metric value.

        Returns:
            A scalar tensor, or a dictionary of scalar tensors.
        """
        raise NotImplementedError

    @property
    def dtype(self):
        """The `dtype` value that was passed to the constructor."""
        return self._dtype

    def add_variable(self, shape, initializer, dtype=None, name=None):
        """Create a non-trainable state variable owned by this metric.

        Args:
            shape: Shape forwarded to `backend.Variable`.
            initializer: Initializer identifier or instance; resolved with
                `initializers.get()` before use.
            dtype: Dtype forwarded to `backend.Variable`.
            name: Name forwarded to `backend.Variable`.

        Returns:
            The newly created `backend.Variable`.

        Raises:
            RuntimeError: If `Metric.__init__()` has not been called.
        """
        self._check_super_called()
        initializer = initializers.get(initializer)
        variable = backend.Variable(
            initializer=initializer,
            shape=shape,
            dtype=dtype,
            trainable=False,
            name=name,
        )
        self._variables.append(variable)
        # Prevent double-tracking: the variable is already in
        # `self._variables`, so mark its id as seen. Presumably this stops
        # the tracker from appending it again when the caller assigns the
        # returned variable to an attribute -- confirm against `Tracker`.
        self._tracker.stored_ids["variables"].add(id(variable))
        return variable

    @property
    def variables(self):
        """State variables of this metric and of all nested sub-metrics.

        Returns:
            A new list holding this metric's own variables followed by the
            variables of each tracked sub-metric, gathered recursively.
        """
        variables = list(self._variables)
        for metric in self._metrics:
            # Go through the public property rather than `_variables` so
            # that sub-metrics of sub-metrics are included as well;
            # otherwise `reset_state()` would leave their state untouched.
            variables.extend(metric.variables)
        return variables

    def __call__(self, *args, **kwargs):
        """Update the metric state, then return the current result.

        All arguments are forwarded unchanged to `update_state()`.

        Raises:
            RuntimeError: If `Metric.__init__()` has not been called.
        """
        self._check_super_called()
        self.update_state(*args, **kwargs)
        return self.result()

    def get_config(self):
        """Return the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        """Instantiate the metric by passing `config` as keyword arguments."""
        return cls(**config)

    def __setattr__(self, name, value):
        # Track Variables, Layers, Metrics
        # Tracking only starts once `_tracker` exists, i.e. after
        # `Metric.__init__()` has finished setting up its own attributes.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)

    def _check_super_called(self):
        """Raise if the subclass skipped `Metric.__init__()`.

        Raises:
            RuntimeError: If the `_tracker` attribute is missing.
        """
        if not hasattr(self, "_tracker"):
            raise RuntimeError(
                "You forgot to call `super().__init__()` "
                "in the `__init__()` method. Go add it!"
            )
|