2023-04-09 12:21:45 -07:00

186 lines
5.9 KiB

from keras_core import operations as ops
from keras_core import backend
from keras_core.engine.naming import auto_name
class Loss:
    """Base class for losses.

    Subclasses implement `call(y_true, y_pred)` and return unreduced
    losses. `__call__` converts the inputs to tensors, combines any
    `_keras_mask` attributes into a mask, and delegates weighting and
    reduction to `reduce_weighted_loss`.

    Args:
        name: Optional name. When falsy, one is generated from the class
            name with `auto_name`.
        reduction: Reduction applied to the per-element losses; validated
            by `standardize_reduction` (`"sum_over_batch_size"`, `"sum"`
            or `None`).
    """

    def __init__(self, name=None, reduction="sum_over_batch_size"):
        self.name = name or auto_name(self.__class__.__name__)
        self.reduction = standardize_reduction(reduction)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Compute the weighted, reduced loss.

        Args:
            y_true: Targets; converted to a tensor with `y_pred`'s dtype.
            y_pred: Predictions. If they carry a `_keras_mask` attribute it
                is used as a mask.
            sample_weight: Optional weights forwarded to
                `reduce_weighted_loss`.

        Returns:
            The result of `reduce_weighted_loss` for this loss's reduction.
        """
        # Read the mask before conversion. NOTE(review): presumably the
        # converted tensor may not keep the `_keras_mask` attribute.
        in_mask = getattr(y_pred, "_keras_mask", None)

        with ops.name_scope(self.name):
            y_pred = ops.convert_to_tensor(y_pred)
            y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)

            losses = self.call(y_true, y_pred)
            out_mask = getattr(losses, "_keras_mask", None)

            # Combine the prediction mask with any mask attached by `call`.
            # NOTE(review): `&` assumes both masks are boolean tensors.
            if in_mask is not None and out_mask is not None:
                mask = in_mask & out_mask
            elif in_mask is not None:
                mask = in_mask
            elif out_mask is not None:
                mask = out_mask
            else:
                mask = None

            return reduce_weighted_loss(
                losses,
                sample_weight=sample_weight,
                mask=mask,
                reduction=self.reduction,
            )

    def call(self, y_true, y_pred):
        """Return unreduced losses; must be implemented by subclasses."""
        raise NotImplementedError

    def get_config(self):
        """Return the constructor arguments needed to rebuild this loss."""
        # `reduction` is included so `from_config(get_config())` does not
        # silently fall back to the default reduction.
        return {"name": self.name, "reduction": self.reduction}

    @classmethod
    def from_config(cls, config):
        """Instantiate a loss from the output of `get_config`."""
        return cls(**config)
def standardize_reduction(reduction):
    """Validate a loss reduction name.

    Args:
        reduction: One of `"sum_over_batch_size"`, `"sum"` or `None`.

    Returns:
        `reduction`, unchanged.

    Raises:
        ValueError: If `reduction` is not one of the allowed values.
    """
    # A tuple (rather than a set) keeps the error message order stable and
    # makes unhashable inputs fail with the intended ValueError instead of
    # a TypeError from hashing.
    allowed = ("sum_over_batch_size", "sum", None)
    if reduction not in allowed:
        raise ValueError(
            "Invalid value for argument `reduction`. "
            f"Expected one of {allowed}. Received: "
            f"reduction={reduction}"
        )
    return reduction
def squeeze_to_same_rank(x1, x2):
    """Drop a trailing size-1 axis so that two tensors share a rank.

    Only a rank difference of exactly one is handled, and only when the
    higher-rank tensor's last dimension is 1. In every other case both
    inputs are returned unchanged.

    Args:
        x1: First tensor.
        x2: Second tensor.

    Returns:
        The pair `(x1, x2)`, with at most one of them squeezed on axis -1.
    """
    rank_a = len(x1.shape)
    rank_b = len(x2.shape)
    if rank_a == rank_b:
        return x1, x2
    # The two conditions below are mutually exclusive.
    if rank_a - rank_b == 1 and x1.shape[-1] == 1:
        x1 = ops.squeeze(x1, axis=-1)
    elif rank_b - rank_a == 1 and x2.shape[-1] == 1:
        x2 = ops.squeeze(x2, axis=-1)
    return x1, x2
def reduce_loss(losses, reduction="sum_over_batch_size"):
    """Collapse a tensor of losses according to `reduction`.

    Args:
        losses: Tensor of unreduced losses.
        reduction: `None` leaves `losses` untouched; `"sum"` sums every
            entry; `"sum_over_batch_size"` divides that sum by the size of
            the first axis of `losses`.

    Returns:
        `losses` itself when `reduction` is `None` or `losses` has shape
        `()` or `(0,)`; otherwise the reduced value.
    """
    if reduction is None:
        return losses
    if tuple(losses.shape) in ((), (0,)):
        return losses

    summed = ops.sum(losses)
    if reduction != "sum_over_batch_size":
        return summed
    # NOTE(review): divides by the leading axis only, not by the total
    # element count; for rank > 1 losses this differs from an element-wise
    # mean. Confirm this is intended.
    batch_size = ops.cast(ops.shape(losses)[0], summed.dtype)
    return summed / batch_size
def reduce_weighted_loss(
    losses,
    sample_weight=None,
    mask=None,
    reduction="sum_over_batch_size",
):
    """Apply an optional mask and sample weights to losses, then reduce.

    Args:
        losses: Unreduced losses (anything `ops.convert_to_tensor` accepts).
        sample_weight: Optional weights multiplied into `losses`.
        mask: Optional mask, merged into `sample_weight` by `apply_mask`.
        reduction: `"sum_over_batch_size"`, `"sum"` or `None`.

    Returns:
        The output of `reduce_loss` on the weighted losses. If `losses` had
        a non-float dtype, the result is cast back to that dtype.

    Raises:
        ValueError: If `reduction` is not an allowed value.
    """
    reduction = standardize_reduction(reduction)
    losses = ops.convert_to_tensor(losses)

    # Promote non-float losses (e.g. int or bool) to float32 *before*
    # converting weights and mask to the losses dtype. Otherwise fractional
    # sample weights would be truncated to integers, and the mask rescaling
    # in `apply_mask` would operate on an integer tensor.
    input_dtype = losses.dtype
    input_casted = not is_float(backend.standardize_dtype(input_dtype))
    if input_casted:
        losses = ops.cast(losses, "float32")

    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(sample_weight, dtype=losses.dtype)
    if mask is not None:
        mask = ops.convert_to_tensor(mask, dtype=losses.dtype)

    # Merge mask and sample weight into sample weight.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=losses.dtype, reduction=reduction
    )

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, losses.dtype)
        # Update dimensions of `sample_weight` to match `losses`.
        losses, sample_weight = squeeze_to_same_rank(losses, sample_weight)
        # Out-of-place multiply, in case `convert_to_tensor` handed back the
        # caller's own (mutable) array.
        losses = losses * sample_weight

    # Apply reduction function to the individual weighted losses.
    loss = reduce_loss(losses, reduction)
    if input_casted:
        # Convert the result back to the input type.
        loss = ops.cast(loss, input_dtype)
    return loss
def float_dtype_size(dtype):
    """Return the bit width of a floating-point dtype name.

    Args:
        dtype: A dtype name such as `"float32"`.

    Returns:
        16 for `"bfloat16"` and `"float16"`, 32 for `"float32"`, 64 for
        `"float64"`.

    Raises:
        ValueError: For any other dtype.
    """
    width_table = (
        (("bfloat16", "float16"), 16),
        (("float32",), 32),
        (("float64",), 64),
    )
    for names, bits in width_table:
        if dtype in names:
            return bits
    raise ValueError(f"Invalid dtype: {dtype}")
def is_float(dtype):
    """Return whether the dtype name denotes a floating-point type.

    This is a substring test, so `"float16"`, `"bfloat16"`, `"float32"` and
    `"float64"` all qualify, while names such as `"int32"` or `"bool"` do
    not.
    """
    marker = "float"
    return marker in dtype
def cast_to_common_dtype(tensors):
    """Cast a list of tensors to a common dtype.

    If any tensor is floating-point, all of them are cast to the widest
    floating-point dtype found among them. A mix of `float16` and
    `bfloat16` is promoted to `float32`, whichever order they appear in.
    If no tensor is floating-point, nothing is cast.

    Args:
        tensors: A list of tensors.

    Returns:
        A new list of tensors cast to the common float dtype, or `tensors`
        itself when none of them is floating-point.
    """
    highest_float = None
    # Track the width separately: comparing a bit width against the dtype
    # *name* would raise a TypeError.
    highest_size = -1
    for x in tensors:
        dtype = backend.standardize_dtype(x.dtype)
        if not is_float(dtype):
            continue
        size = float_dtype_size(dtype)
        if size > highest_size:
            highest_float, highest_size = dtype, size
        elif {dtype, highest_float} == {"float16", "bfloat16"}:
            # Same width but different exponent/mantissa splits; float32
            # holds both.
            highest_float, highest_size = "float32", 32
    if highest_float:
        tensors = [ops.cast(x, highest_float) for x in tensors]
    return tensors
def apply_mask(sample_weight, mask, dtype, reduction):
    """Fold a prediction mask into the sample weights.

    Args:
        sample_weight: Optional sample weights (may be `None`).
        mask: Optional mask (may be `None`); cast to `dtype`.
        dtype: Dtype that the mask and sample weights are cast to.
        reduction: Reduction that will be applied afterwards. For
            `"sum_over_batch_size"` the mask is rescaled so the final
            result averages over valid entries only.

    Returns:
        `sample_weight` unchanged when `mask` is `None`. Otherwise, the
        (possibly rescaled) mask alone when `sample_weight` is `None`, or
        the product of the two.
    """
    if mask is None:
        return sample_weight

    mask = ops.cast(mask, dtype=dtype)
    if reduction == "sum_over_batch_size":
        # Valid entries have weight `total/valid`, while invalid ones
        # have 0. When summed over batch, they will be reduced to:
        # mean(loss * sample_weight * total / valid)
        # = sum(loss * sample_weight * total / valid) / total
        # = sum(loss * sample_weight) / total * total / valid
        # = sum(loss * sample_weight) / valid
        total = ops.cast(ops.shape(mask)[0], dtype=dtype)
        valid = ops.sum(mask)  # May be 0!
        mask = mask * (total / (valid + backend.epsilon()))

    if sample_weight is None:
        return mask

    sample_weight = ops.cast(sample_weight, dtype=dtype)
    mask, sample_weight = squeeze_to_same_rank(mask, sample_weight)
    # Out-of-place so the caller's weight tensor is not modified.
    return sample_weight * mask