2023-04-09 19:21:45 +00:00
|
|
|
from keras_core import backend
|
2023-04-26 17:29:40 +00:00
|
|
|
from keras_core import initializers
|
2023-05-29 01:28:19 +00:00
|
|
|
from keras_core import operations as ops
|
2023-04-09 19:53:37 +00:00
|
|
|
from keras_core.api_export import keras_core_export
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core.utils.naming import auto_name
|
|
|
|
from keras_core.utils.tracking import Tracker
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
2023-04-09 19:53:37 +00:00
|
|
|
@keras_core_export(["keras_core.Metric", "keras_core.metrics.Metric"])
class Metric:
    """Encapsulates metric logic and state.

    Args:
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    ```python
    m = SomeMetric(...)
    for input in ...:
        m.update_state(input)
    print('Final result: ', m.result())
    ```

    Usage with `compile()` API:

    ```python
    model = keras_core.Sequential()
    model.add(keras_core.layers.Dense(64, activation='relu'))
    model.add(keras_core.layers.Dense(64, activation='relu'))
    model.add(keras_core.layers.Dense(10, activation='softmax'))

    model.compile(optimizer=keras_core.optimizers.RMSprop(0.01),
                  loss=keras_core.losses.CategoricalCrossentropy(),
                  metrics=[keras_core.metrics.CategoricalAccuracy()])

    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))

    model.fit(data, labels, epochs=10)
    ```

    To be implemented by subclasses:

    * `__init__()`: All state variables should be created in this method by
      calling `self.add_variable()` like: `self.var = self.add_variable(...)`
    * `update_state()`: Has all updates to the state variables like:
      `self.var.assign(...)`.
    * `result()`: Computes and returns a scalar value or a dict of scalar values
      for the metric from the state variables.

    Example subclass implementation:

    ```python
    class BinaryTruePositives(Metric):

        def __init__(self, name='binary_true_positives', **kwargs):
            super().__init__(name=name, **kwargs)
            self.true_positives = self.add_variable(
                shape=(),
                initializer='zeros',
                name='true_positives',
            )

        def update_state(self, y_true, y_pred, sample_weight=None):
            y_true = ops.cast(y_true, "bool")
            y_pred = ops.cast(y_pred, "bool")

            values = ops.logical_and(
                ops.equal(y_true, True), ops.equal(y_pred, True))
            values = ops.cast(values, self.dtype)
            if sample_weight is not None:
                sample_weight = ops.cast(sample_weight, self.dtype)
                sample_weight = ops.broadcast_to(sample_weight, values.shape)
                values = ops.multiply(values, sample_weight)
            self.true_positives.assign(self.true_positives + ops.sum(values))

        def result(self):
            return self.true_positives
    ```
    """

    def __init__(self, dtype=None, name=None):
        """Initialize bookkeeping for the metric's name, dtype and state.

        Args:
            dtype: (Optional) data type of the metric result. Stored as-is;
                it may be `None`.
            name: (Optional) string name. If omitted, a name is derived
                from the class name via `auto_name`.
        """
        self.name = name or auto_name(self.__class__.__name__)
        self._dtype = dtype
        # Stores for the metric's own state variables and for any
        # sub-metrics assigned as attributes.
        self._metrics = []
        self._variables = []
        # `_tracker` is assigned last on purpose. `__setattr__` only routes
        # values through the tracker once `_tracker` exists, so the
        # assignments above are plain attribute sets.
        self._tracker = Tracker(
            {
                "variables": (
                    lambda x: isinstance(x, backend.Variable),
                    self._variables,
                ),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
            }
        )

    def reset_state(self):
        """Reset all of the metric state variables.

        This function is called between epochs/steps,
        when a metric is evaluated during training.

        Note: every variable (including those of sub-metrics) is assigned
        zeros, regardless of the initializer passed to `add_variable()`.
        Subclasses whose initial state is not zero should override this.
        """
        for v in self.variables:
            v.assign(ops.zeros(v.shape, dtype=v.dtype))

    def update_state(self, *args, **kwargs):
        """Accumulate statistics for the metric.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def result(self):
        """Compute the current metric value.

        Must be implemented by subclasses.

        Returns:
            A scalar tensor, or a dictionary of scalar tensors.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    @property
    def dtype(self):
        """The `dtype` passed to the constructor (may be `None`)."""
        return self._dtype

    def add_variable(self, shape, initializer, dtype=None, name=None):
        """Create a non-trainable state variable and track it on this metric.

        Args:
            shape: Shape of the variable.
            initializer: Any value accepted by `initializers.get()`,
                e.g. the string `'zeros'` or an initializer instance.
            dtype: (Optional) dtype of the variable, forwarded to
                `backend.Variable`.
            name: (Optional) name of the variable.

        Returns:
            The newly created `backend.Variable` (with `trainable=False`).

        Raises:
            RuntimeError: if `super().__init__()` was not called.
        """
        self._check_super_called()
        initializer = initializers.get(initializer)
        variable = backend.Variable(
            initializer=initializer,
            shape=shape,
            dtype=dtype,
            trainable=False,
            name=name,
        )
        # Prevent double-tracking: register the variable in the store now.
        # Presumably `Tracker.track` skips values already present in a store
        # when the caller later does `self.var = self.add_variable(...)`;
        # verify against Tracker.
        self._tracker.add_to_store("variables", variable)
        return variable

    def add_weight(self, *args, **kwargs):
        """Backwards-compatibility alias for `add_variable()`."""
        return self.add_variable(*args, **kwargs)

    @property
    def variables(self):
        """All state variables of this metric and, recursively, its sub-metrics.

        Returns:
            A new list. The metric's own variables come first, followed by
            those of each tracked sub-metric, in tracking order.
        """
        variables = list(self._variables)
        for metric in self._metrics:
            # Recurse through the public property, not `metric._variables`,
            # so that metrics nested more than one level deep are included
            # (and therefore also reset by `reset_state()`).
            variables.extend(metric.variables)
        return variables

    def __call__(self, *args, **kwargs):
        """Update the state with the given inputs, then return `result()`.

        Raises:
            RuntimeError: if `super().__init__()` was not called.
        """
        self._check_super_called()
        self.update_state(*args, **kwargs)
        return self.result()

    def get_config(self):
        """Return the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        """Re-create a metric by passing `config` as keyword arguments."""
        return cls(**config)

    def __setattr__(self, name, value):
        # Track Variables and Metrics (the two stores configured in
        # `__init__`). Before `__init__` has created `_tracker`, this is a
        # plain attribute assignment.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)

    def _check_super_called(self):
        """Raise if the subclass did not call `Metric.__init__`.

        `_tracker` is the last attribute set by `Metric.__init__`, so its
        absence means that constructor never ran.
        """
        if not hasattr(self, "_tracker"):
            raise RuntimeError(
                "You forgot to call `super().__init__()` "
                "in the `__init__()` method. Go add it!"
            )
|