198 lines
6.5 KiB
Python
198 lines
6.5 KiB
Python
from keras_core import backend
|
|
from keras_core import initializers
|
|
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.losses import loss
|
|
from keras_core.metrics.metric import Metric
|
|
from keras_core.saving import serialization_lib
|
|
|
|
|
|
def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype):
|
|
mask = getattr(values, "_keras_mask", None)
|
|
values = ops.cast(values, dtype=dtype)
|
|
if sample_weight is not None:
|
|
sample_weight = ops.cast(sample_weight, dtype=dtype)
|
|
if mask is not None:
|
|
sample_weight = loss.apply_mask(
|
|
sample_weight, mask, dtype=dtype, reduction="sum"
|
|
)
|
|
# Update dimensions of weights to match with values if possible.
|
|
values, sample_weight = loss.squeeze_to_same_rank(values, sample_weight)
|
|
# Reduce values to same ndim as weight array
|
|
weight_ndim = len(sample_weight.shape)
|
|
values_ndim = len(values.shape)
|
|
if values_ndim > weight_ndim:
|
|
values = reduce_fn(
|
|
values, axis=list(range(weight_ndim, values_ndim))
|
|
)
|
|
values = values * sample_weight
|
|
|
|
values_ndim = len(values.shape)
|
|
if values_ndim > 1:
|
|
values = reduce_fn(values, axis=list(range(1, values_ndim)))
|
|
return values, sample_weight
|
|
return values, sample_weight
|
|
|
|
|
|
@keras_core_export("keras_core.metrics.Sum")
|
|
class Sum(Metric):
|
|
"""Compute the (weighted) sum of the given values.
|
|
|
|
For example, if `values` is `[1, 3, 5, 7]` then their sum is 16.
|
|
If `sample_weight` was specified as `[1, 1, 0, 0]` then the sum would be 4.
|
|
|
|
This metric creates one variable, `total`.
|
|
This is ultimately returned as the sum value.
|
|
|
|
Args:
|
|
name: (Optional) string name of the metric instance.
|
|
dtype: (Optional) data type of the metric result.
|
|
|
|
Example:
|
|
|
|
>>> m = metrics.Sum()
|
|
>>> m.update_state([1, 3, 5, 7])
|
|
>>> m.result()
|
|
16.0
|
|
|
|
>>> m = metrics.Sum()
|
|
>>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
|
|
>>> m.result()
|
|
4.0
|
|
"""
|
|
|
|
def __init__(self, name="sum", dtype=None):
|
|
super().__init__(name=name, dtype=dtype)
|
|
self.total = self.add_variable(
|
|
shape=(), initializer=initializers.Zeros(), dtype=self.dtype
|
|
)
|
|
|
|
def update_state(self, values, sample_weight=None):
|
|
values, _ = reduce_to_samplewise_values(
|
|
values, sample_weight, reduce_fn=ops.sum, dtype=self.dtype
|
|
)
|
|
self.total.assign(self.total + ops.sum(values))
|
|
|
|
def reset_state(self):
|
|
self.total.assign(0.0)
|
|
|
|
def result(self):
|
|
return ops.cast(self.total, self.dtype)
|
|
|
|
|
|
@keras_core_export("keras_core.metrics.Mean")
|
|
class Mean(Metric):
|
|
"""Compute the (weighted) mean of the given values.
|
|
|
|
For example, if values is `[1, 3, 5, 7]` then the mean is 4.
|
|
If `sample_weight` was specified as `[1, 1, 0, 0]` then the mean would be 2.
|
|
|
|
This metric creates two variables, `total` and `count`.
|
|
The mean value returned is simply `total` divided by `count`.
|
|
|
|
Args:
|
|
name: (Optional) string name of the metric instance.
|
|
dtype: (Optional) data type of the metric result.
|
|
|
|
Example:
|
|
|
|
>>> m = Mean()
|
|
>>> m.update_state([1, 3, 5, 7])
|
|
>>> m.result()
|
|
4.0
|
|
|
|
>>> m.reset_state()
|
|
>>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
|
|
>>> m.result()
|
|
2.0
|
|
```
|
|
"""
|
|
|
|
def __init__(self, name="mean", dtype=None):
|
|
super().__init__(name=name, dtype=dtype)
|
|
self.total = self.add_variable(
|
|
shape=(), initializer=initializers.Zeros(), dtype=self.dtype
|
|
)
|
|
self.count = self.add_variable(
|
|
shape=(), initializer=initializers.Zeros(), dtype=self.dtype
|
|
)
|
|
|
|
def update_state(self, values, sample_weight=None):
|
|
values, sample_weight = reduce_to_samplewise_values(
|
|
values, sample_weight, reduce_fn=ops.mean, dtype=self.dtype
|
|
)
|
|
self.total.assign(self.total + ops.sum(values))
|
|
if len(values.shape) >= 1:
|
|
num_samples = ops.shape(values)[0]
|
|
else:
|
|
num_samples = 1
|
|
if sample_weight is not None:
|
|
num_samples = ops.sum(
|
|
ops.ones(shape=(num_samples,)) * sample_weight
|
|
)
|
|
self.count.assign(self.count + ops.cast(num_samples, dtype=self.dtype))
|
|
|
|
def reset_state(self):
|
|
self.total.assign(0.0)
|
|
self.count.assign(0)
|
|
|
|
def result(self):
|
|
return self.total / (
|
|
ops.maximum(
|
|
ops.cast(self.count, dtype=self.dtype), backend.epsilon()
|
|
)
|
|
)
|
|
|
|
|
|
@keras_core_export("keras_core.metrics.MeanMetricWrapper")
|
|
class MeanMetricWrapper(Mean):
|
|
"""Wrap a stateless metric function with the Mean metric.
|
|
|
|
You could use this class to quickly build a mean metric from a function. The
|
|
function needs to have the signature `fn(y_true, y_pred)` and return a
|
|
per-sample loss array. `MeanMetricWrapper.result()` will return
|
|
the average metric value across all samples seen so far.
|
|
|
|
For example:
|
|
|
|
```python
|
|
def mse(y_true, y_pred):
|
|
return (y_true - y_pred) ** 2
|
|
|
|
mse_metric = MeanMetricWrapper(fn=mse)
|
|
```
|
|
|
|
Args:
|
|
fn: The metric function to wrap, with signature
|
|
`fn(y_true, y_pred, **kwargs)`.
|
|
name: (Optional) string name of the metric instance.
|
|
dtype: (Optional) data type of the metric result.
|
|
**kwargs: Keyword arguments to pass on to `fn`.
|
|
"""
|
|
|
|
def __init__(self, fn, name=None, dtype=None, **kwargs):
|
|
super().__init__(name=name, dtype=dtype)
|
|
self._fn = fn
|
|
self._fn_kwargs = kwargs
|
|
|
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
|
mask = getattr(y_pred, "_keras_mask", None)
|
|
values = self._fn(y_true, y_pred, **self._fn_kwargs)
|
|
if sample_weight is not None and mask is not None:
|
|
sample_weight = loss.apply_mask(
|
|
sample_weight, mask, dtype=self.dtype, reduction="sum"
|
|
)
|
|
return super().update_state(values, sample_weight=sample_weight)
|
|
|
|
def get_config(self):
|
|
base_config = super().get_config()
|
|
config = {"fn": serialization_lib.serialize_keras_object(self.fn)}
|
|
config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))
|
|
return {**base_config, **config}
|
|
|
|
@classmethod
|
|
def from_config(cls, config):
|
|
if "fn" in config:
|
|
config = serialization_lib.deserialize_keras_object(config)
|
|
return cls(**config)
|