keras/keras_core/metrics/regression_metrics.py

16 lines
513 B
Python
Raw Normal View History

2023-04-16 19:21:29 +00:00
from keras_core import operations as ops
from keras_core.metrics import reduction_metrics
2023-04-17 22:41:48 +00:00
def mean_squared_error(y_true, y_pred):
    """Compute the mean of squared differences between targets and predictions.

    The squared difference `(y_true - y_pred) ** 2` is averaged over every
    axis of `y_pred` except axis 0, so the result keeps the leading axis.

    Args:
        y_true: Ground-truth values. Must support subtraction with `y_pred`.
        y_pred: Predicted values. Must expose a `.shape` attribute; its rank
            determines which axes are reduced.

    Returns:
        The result of `ops.mean` over axes `1 .. rank - 1` of the squared
        difference.
    """
    # NOTE(review): axis 0 is presumably the batch axis -- confirm with callers.
    rank = len(y_pred.shape)
    squared_error = (y_true - y_pred) ** 2
    reduced_axes = list(range(1, rank))
    return ops.mean(squared_error, axis=reduced_axes)
2023-04-16 19:21:29 +00:00
2023-04-17 22:41:48 +00:00
class MeanSquaredError(reduction_metrics.MeanMetricWrapper):
    """Metric that wraps `mean_squared_error` in a `MeanMetricWrapper`.

    Args:
        name: Name passed to the parent constructor. Defaults to
            `"mean_squared_error"`.
        dtype: Passed unchanged to the parent constructor. Defaults to
            `None`.
    """

    def __init__(self, name="mean_squared_error", dtype=None):
        # The wrapped function is fixed by this class; callers only choose
        # the metric's name and dtype.
        super().__init__(fn=mean_squared_error, name=name, dtype=dtype)

    def get_config(self):
        """Return a dict with this metric's `name` and `dtype`.

        The keys mirror the keyword arguments accepted by `__init__`.
        """
        return dict(name=self.name, dtype=self.dtype)