keras/keras_core/losses/losses.py
François Chollet 4a8d8c6117 Merge pull request #2 from keras-team/add_initializers
Add RandomNormal and RandomUniform Initializers
2023-04-12 11:00:14 -07:00

38 lines
1.1 KiB
Python

from keras_core.losses.loss import Loss
from keras_core import operations as ops
from keras_core.losses.loss import squeeze_to_same_rank
class LossFunctionWrapper(Loss):
def __init__(
self, fn, reduction="sum_over_batch_size", name=None, **kwargs
):
super().__init__(reduction=reduction, name=name)
self.fn = fn
self._fn_kwargs = kwargs
def call(self, y_true, y_pred):
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
return self.fn(y_true, y_pred, **self._fn_kwargs)
def get_config(self):
raise NotImplementedError
@classmethod
def from_config(clf, config):
raise NotImplementedError
def mean_squared_error(y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
return (y_true - y_pred) ** 2
class MeanSquaredError(LossFunctionWrapper):
def __init__(
self, reduction="sum_over_batch_size", name="mean_squared_error"
):
super().__init__(mean_squared_error, reduction=reduction, name=name)