from keras_core import backend
from keras_core import initializers
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.utils.naming import auto_name
from keras_core.utils.tracking import Tracker


@keras_core_export(["keras_core.Metric", "keras_core.metrics.Metric"])
class Metric:
    """Encapsulates metric logic and state.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    ```python
    m = SomeMetric(...)
    for input in ...:
        m.update_state(input)
    print('Final result: ', m.result())
    ```

    Usage with `compile()` API:

    ```python
    model = keras_core.Sequential()
    model.add(keras_core.layers.Dense(64, activation='relu'))
    model.add(keras_core.layers.Dense(64, activation='relu'))
    model.add(keras_core.layers.Dense(10, activation='softmax'))
    model.compile(optimizer=keras_core.optimizers.RMSprop(0.01),
                  loss=keras_core.losses.CategoricalCrossentropy(),
                  metrics=[keras_core.metrics.CategoricalAccuracy()])
    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))
    model.fit(data, labels, epochs=10)
    ```

    To be implemented by subclasses:

    * `__init__()`: All state variables should be created in this method by
      calling `self.add_variable()` like: `self.var = self.add_variable(...)`
    * `update_state()`: Has all updates to the state variables like:
      `self.var.assign(...)`.
    * `result()`: Computes and returns a scalar value or a dict of scalar
      values for the metric from the state variables.

    Example subclass implementation:

    ```python
    class BinaryTruePositives(Metric):

        def __init__(self, name='binary_true_positives', **kwargs):
            super().__init__(name=name, **kwargs)
            self.true_positives = self.add_variable(
                shape=(),
                initializer='zeros',
                name='true_positives'
            )

        def update_state(self, y_true, y_pred, sample_weight=None):
            y_true = ops.cast(y_true, "bool")
            y_pred = ops.cast(y_pred, "bool")

            values = ops.logical_and(
                ops.equal(y_true, True), ops.equal(y_pred, True))
            values = ops.cast(values, self.dtype)
            if sample_weight is not None:
                sample_weight = ops.cast(sample_weight, self.dtype)
                sample_weight = ops.broadcast_to(
                    sample_weight, values.shape)
                values = ops.multiply(values, sample_weight)
            self.true_positives.assign(
                self.true_positives + ops.sum(values))

        def result(self):
            return self.true_positives
    ```
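
    Usage with the stateless API (functional state handling, e.g. under the
    JAX backend). A minimal sketch; `MyMetric`, `y_true` and `y_pred` are
    illustrative placeholders:

    ```python
    metric = MyMetric()
    # Snapshot the current state variable values.
    metric_variables = [v.value for v in metric.variables]
    # Compute updated values without mutating the metric object.
    metric_variables = metric.stateless_update_state(
        metric_variables, y_true, y_pred)
    # Compute the result from the updated values.
    value = metric.stateless_result(metric_variables)
    ```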
"""
2023-04-09 19:21:45 +00:00
    def __init__(self, dtype=None, name=None):
        self.name = name or auto_name(self.__class__.__name__)
        self._dtype = dtype
        self._metrics = []
        self._variables = []
        # Track variables and sub-metrics assigned as attributes
        # (see `__setattr__`), so they show up in `self.variables`.
        self._tracker = Tracker(
            {
                "variables": (
                    lambda x: isinstance(x, backend.Variable),
                    self._variables,
                ),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
            }
        )

    def reset_state(self):
        """Reset all of the metric state variables.

        This function is called between epochs/steps,
        when a metric is evaluated during training.
        """
        for v in self.variables:
            v.assign(ops.zeros(v.shape, dtype=v.dtype))

    def update_state(self, *args, **kwargs):
        """Accumulate statistics for the metric."""
        raise NotImplementedError

    def stateless_update_state(self, metric_variables, *args, **kwargs):
        """Functional counterpart of `update_state()`.

        Instead of mutating `self.variables` in place, computes and returns
        the updated variable values, leaving the metric object untouched.
        """
        if len(metric_variables) != len(self.variables):
            raise ValueError(
                "Argument `metric_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(metric_variables)}, but "
                f"expected {len(self.variables)} variables."
            )
        # Gather variable mapping
        mapping = list(zip(self.variables, metric_variables))
        # Call in stateless scope
        with backend.StatelessScope(state_mapping=mapping) as scope:
            self.update_state(*args, **kwargs)
        # Gather updated variables
        metric_variables = []
        for v in self.variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                metric_variables.append(new_v)
            else:
                metric_variables.append(v)
        return metric_variables
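
    # Illustrative usage (names are placeholders): seed a functional
    # training loop with the current values, then thread the returned
    # list through each step.
    #
    #   metric_variables = [v.value for v in metric.variables]
    #   metric_variables = metric.stateless_update_state(
    #       metric_variables, y_true, y_pred)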

    def result(self):
        """Compute the current metric value.

        Returns:
            A scalar tensor, or a dictionary of scalar tensors.
        """
        raise NotImplementedError

    def stateless_result(self, metric_variables):
        """Functional counterpart of `result()`.

        Computes the metric value from the provided variable values
        rather than from the metric's own state.
        """
        if len(metric_variables) != len(self.variables):
            raise ValueError(
                "Argument `metric_variables` must be a list of tensors "
                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
                f"Received list with length {len(metric_variables)}, but "
                f"expected {len(self.variables)} variables."
            )
        # Gather variable mapping
        mapping = list(zip(self.variables, metric_variables))
        # Call in stateless scope
        with backend.StatelessScope(state_mapping=mapping):
            res = self.result()
        return res
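
    # Illustrative follow-up: compute the value from the threaded state,
    # then optionally write it back into the metric's variables so the
    # stateful API (`result()`, `reset_state()`) sees the updates.
    #
    #   value = metric.stateless_result(metric_variables)
    #   for v, new_v in zip(metric.variables, metric_variables):
    #       v.assign(new_v)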

    @property
    def dtype(self):
        return self._dtype

    def add_variable(self, shape, initializer, dtype=None, name=None):
        self._check_super_called()
        initializer = initializers.get(initializer)
        variable = backend.Variable(
            initializer=initializer,
            shape=shape,
            dtype=dtype,
            trainable=False,
            name=name,
        )
        # Prevent double-tracking
        self._tracker.add_to_store("variables", variable)
        return variable
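
    # Illustrative usage inside a subclass `__init__()` (mirrors the
    # docstring example above):
    #
    #   self.total = self.add_variable(
    #       shape=(), initializer="zeros", name="total")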

    def add_weight(self, shape=(), initializer=None, dtype=None, name=None):
        # Backwards compatibility alias
        return self.add_variable(
            shape=shape, initializer=initializer, dtype=dtype, name=name
        )

    @property
    def variables(self):
        variables = self._variables[:]
        for metric in self._metrics:
            variables.extend(metric._variables)
        return variables

    def __call__(self, *args, **kwargs):
        self._check_super_called()
        self.update_state(*args, **kwargs)
        return self.result()
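
    # Illustrative one-shot usage: calling a metric instance runs
    # `update_state()` then returns `result()`.
    #
    #   acc = keras_core.metrics.CategoricalAccuracy()
    #   value = acc(y_true, y_pred)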

    def get_config(self):
        """Return the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
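
    # Illustrative round trip: `from_config(get_config())` rebuilds an
    # equivalent, freshly initialized metric (state is not serialized).
    #
    #   clone = metric.__class__.from_config(metric.get_config())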

    def __setattr__(self, name, value):
        # Track Variables and sub-Metrics assigned as attributes.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)

    def _check_super_called(self):
        if not hasattr(self, "_tracker"):
            raise RuntimeError(
                "You forgot to call `super().__init__()` "
                "in the `__init__()` method. Go add it!"
            )