16 lines
513 B
Python
16 lines
513 B
Python
from keras_core import operations as ops
|
|
from keras_core.metrics import reduction_metrics
|
|
|
|
|
|
def mean_squared_error(y_true, y_pred):
|
|
ndim = len(y_pred.shape)
|
|
return ops.mean((y_true - y_pred) ** 2, axis=list(range(1, ndim)))
|
|
|
|
|
|
class MeanSquaredError(reduction_metrics.MeanMetricWrapper):
|
|
def __init__(self, name="mean_squared_error", dtype=None):
|
|
super().__init__(fn=mean_squared_error, name=name, dtype=dtype)
|
|
|
|
def get_config(self):
|
|
return {"name": self.name, "dtype": self.dtype}
|