from tensorflow import nest

from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.utils import dtype_utils
from keras_core.utils.naming import auto_name


@keras_core_export(["keras_core.Loss", "keras_core.losses.Loss"])
class Loss:
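    """Loss base class.

    Subclasses should implement `call()`, which returns the unreduced,
    per-sample loss values. `__call__()` then merges any Keras masks into
    the sample weights and applies the configured reduction.

    A minimal illustrative subclass (the name and formula below are just
    an example, not part of this module):

    ```python
    class MeanSquaredError(Loss):
        def call(self, y_true, y_pred):
            return ops.mean(ops.square(y_pred - y_true), axis=-1)
    ```

    Args:
        name: Optional name for the loss instance.
        reduction: Reduction to apply to the loss. One of
            `"sum_over_batch_size"`, `"sum"`, or `None`.
    """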

    def __init__(self, name=None, reduction="sum_over_batch_size"):
        self.name = name or auto_name(self.__class__.__name__)
        self.reduction = standardize_reduction(reduction)

    def __call__(self, y_true, y_pred, sample_weight=None):
        in_mask = getattr(y_pred, "_keras_mask", None)

        with ops.name_scope(self.name):
            dtype = backend.floatx()
            y_pred = nest.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype=dtype), y_pred
            )
            y_true = nest.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype=y_pred.dtype), y_true
            )

            losses = self.call(y_true, y_pred)
            out_mask = getattr(losses, "_keras_mask", None)

            if in_mask is not None and out_mask is not None:
                mask = in_mask & out_mask
            elif in_mask is not None:
                mask = in_mask
            elif out_mask is not None:
                mask = out_mask
            else:
                mask = None

            return reduce_weighted_values(
                losses,
                sample_weight=sample_weight,
                mask=mask,
                reduction=self.reduction,
            )

    def call(self, y_true, y_pred):
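        """Compute the unreduced, per-sample loss values. To be overridden."""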
        raise NotImplementedError

    def get_config(self):
        return {"name": self.name, "reduction": self.reduction}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def standardize_reduction(reduction):
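    """Validate the `reduction` argument and return it unchanged."""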
    allowed = {"sum_over_batch_size", "sum", None}
    if reduction not in allowed:
        raise ValueError(
            "Invalid value for argument `reduction`. "
            f"Expected one of {allowed}. Received: "
            f"reduction={reduction}"
        )
    return reduction


def squeeze_to_same_rank(x1, x2):
    """Squeeze the last dim of `x1` or `x2` if ranks differ by exactly 1."""
    x1_rank = len(x1.shape)
    x2_rank = len(x2.shape)
    if x1_rank == x2_rank:
        return x1, x2
    if x1_rank == x2_rank + 1:
        if x1.shape[-1] == 1:
            x1 = ops.squeeze(x1, axis=-1)
    if x2_rank == x1_rank + 1:
        if x2.shape[-1] == 1:
            x2 = ops.squeeze(x2, axis=-1)
    return x1, x2


def reduce_values(values, reduction="sum_over_batch_size"):
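    """Reduce `values` to a scalar according to `reduction`."""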
    if (
        reduction is None
        or tuple(values.shape) == ()
        or tuple(values.shape) == (0,)
    ):
        return values
    loss = ops.sum(values)
    if reduction == "sum_over_batch_size":
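        # Average over the total number of elements, not just the batch dim.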
        loss /= ops.cast(
            ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")),
            loss.dtype,
        )
    return loss


def reduce_weighted_values(
    values,
    sample_weight=None,
    mask=None,
    reduction="sum_over_batch_size",
):
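    """Apply `mask` and `sample_weight` to `values`, then reduce to a scalar."""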
    reduction = standardize_reduction(reduction)

    values = ops.convert_to_tensor(values)
    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(sample_weight, dtype=values.dtype)
    if mask is not None:
        mask = ops.convert_to_tensor(mask, dtype=values.dtype)

    # Merge mask and sample weight into sample weight.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=values.dtype, reduction=reduction
    )

    # Convert any non-float dtypes to floats to avoid loss of precision
    # for dtypes like int or bool.
    dtype = backend.standardize_dtype(values.dtype)
    if not dtype_utils.is_float(dtype):
        input_dtype = values.dtype
        values = ops.cast(values, "float32")
        input_casted = True
    else:
        input_casted = False

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, values.dtype)
        # Update dimensions of `sample_weight` to match `values`.
        values, sample_weight = squeeze_to_same_rank(values, sample_weight)
        values *= sample_weight

    # Apply the reduction function to the individual weighted losses.
    loss = reduce_values(values, reduction)

    if input_casted:
        # Convert the result back to the input dtype.
        loss = ops.cast(loss, input_dtype)
    return loss


def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is not None:
        mask = ops.cast(mask, dtype=dtype)
        if reduction == "sum_over_batch_size":
            # Valid entries have weight `total/valid`, while invalid ones
            # have 0. When summed over the batch, they will be reduced to:
            #
            # mean(loss * sample_weight * total / valid)
            #   = sum(loss * sample_weight * total / valid) / total
            #   = sum(loss * sample_weight) / total * total / valid
            #   = sum(loss * sample_weight) / valid
            total = ops.cast(ops.shape(mask)[0], dtype=dtype)
            valid = ops.sum(mask)  # May be 0!
            mask *= total / (valid + backend.epsilon())

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, dtype=dtype)
        mask, sample_weight = squeeze_to_same_rank(mask, sample_weight)
        sample_weight *= mask
    else:
        sample_weight = mask
    return sample_weight