keras/keras_core/metrics/regression_metrics.py
Chen Qian eabdb87f9f Add some numpy ops (#1)
* Add numpy ops (initial batch) and some config

* Add unit test

* fix call

* Revert "fix call"

This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0.

* full unit test coverage

* fix setup.py
2023-04-12 11:31:58 -07:00

29 lines
884 B
Python

import math

from keras_core import backend
from keras_core import initializers
from keras_core.metrics.metric import Metric
class MeanSquareError(Metric):
    """Streaming mean squared error between `y_true` and `y_pred`.

    Keeps two running accumulators across calls to `update_state`: the sum
    of squared differences and the number of values that contributed to it.
    `result` returns their ratio.

    Args:
        name: Name of the metric instance.
        dtype: Data type forwarded to the `Metric` base class.
    """

    def __init__(self, name="mean_square_error", dtype=None):
        super().__init__(name=name, dtype=dtype)
        # Running sum of squared errors over every value seen so far.
        self.sum = self.add_variable(
            name="sum", initializer=initializers.Zeros()
        )
        # Running count of the values that contributed to `self.sum`.
        self.total = self.add_variable(
            name="total", initializer=initializers.Zeros()
        )

    def update_state(self, y_true, y_pred):
        """Accumulate the squared errors of one batch.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values; must support element-wise subtraction
                with `y_true`.
        """
        # TODO: add support for sample_weight
        squared_error = (y_true - y_pred) ** 2
        # The state variables hold running totals, so reduce to a scalar
        # before accumulating. Adding the element-wise tensor directly would
        # change the variable's shape (or make `assign` fail).
        # NOTE(review): assumes `backend.execute` dispatches "sum" to the
        # active backend's NumPy-style `sum`; confirm for every backend.
        batch_sum = backend.execute("sum", squared_error)
        # Count every element, not just rows. `result` is then the mean over
        # all values, which for equally-shaped samples equals the average of
        # the per-sample means. A scalar input has shape () and counts as 1.
        # NOTE(review): assumes `backend.shape` returns static ints here.
        num_values = math.prod(backend.shape(squared_error))
        self.sum.assign(self.sum + batch_sum)
        self.total.assign(self.total + num_values)

    def result(self):
        """Return the mean squared error over all accumulated values.

        Both accumulators start at zero, so calling this before any
        `update_state` divides zero by zero.
        """
        return self.sum / self.total

    def reset_states(self):
        """Reset both accumulators to zero."""
        self.sum.assign(0.0)
        self.total.assign(0.0)