eabdb87f9f
* Add numpy ops (initial batch) and some config * Add unit test * fix call * Revert "fix call" This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0. * full unit test coverage * fix setup.py
29 lines
884 B
Python
29 lines
884 B
Python
from keras_core import backend
|
|
from keras_core import initializers
|
|
from keras_core.metrics.metric import Metric
|
|
|
|
|
|
class MeanSquareError(Metric):
    """Streaming squared-error metric backed by two accumulator variables.

    `sum` accumulates `(y_true - y_pred) ** 2` across calls to
    `update_state`, and `total` accumulates the size of the first
    dimension of `y_true`. `result` reports `sum / total`.
    """

    def __init__(self, name="mean_square_error", dtype=None):
        """Create the metric and its zero-initialized accumulators.

        Args:
            name: Name forwarded to the `Metric` constructor.
            dtype: Dtype forwarded to the `Metric` constructor.
        """
        super().__init__(name=name, dtype=dtype)
        # Each variable gets its own initializer instance.
        zeros = initializers.Zeros
        self.sum = self.add_variable(name="sum", initializer=zeros())
        self.total = self.add_variable(name="total", initializer=zeros())

    def update_state(self, y_true, y_pred):
        """Fold one batch of targets and predictions into the accumulators.

        Args:
            y_true: Ground-truth values; the size of its first dimension
                is added to `total`.
            y_pred: Predicted values, subtracted element-wise from
                `y_true`.
        """
        # TODO: add support for sample_weight
        # NOTE(review): the squared error is added without any reduction;
        # confirm its shape is compatible with the `sum` variable.
        squared_error = (y_true - y_pred) ** 2
        self.sum.assign(self.sum + squared_error)
        leading_dim = backend.shape(y_true)[0]
        self.total.assign(self.total + leading_dim)

    def result(self):
        """Return the accumulated squared error divided by `total`."""
        accumulated_error = self.sum
        return accumulated_error / self.total

    def reset_states(self):
        """Set both accumulators back to `0.0`."""
        for accumulator in (self.sum, self.total):
            accumulator.assign(0.0)
|