from keras_core import backend
from keras_core import operations as ops
from keras_core.engine.naming import auto_name


class Loss:
    """Loss base class. Subclasses should implement `call()`."""

    def __init__(self, name=None, reduction="sum_over_batch_size"):
        self.name = name or auto_name(self.__class__.__name__)
        self.reduction = standardize_reduction(reduction)

    def __call__(self, y_true, y_pred, sample_weight=None):
        # Pick up any mask attached to the predictions by an upstream layer.
        in_mask = getattr(y_pred, "_keras_mask", None)

        with ops.name_scope(self.name):
            y_pred = ops.convert_to_tensor(y_pred)
            y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)

            losses = self.call(y_true, y_pred)
            out_mask = getattr(losses, "_keras_mask", None)

            # Merge the input and output masks when both are present.
            if in_mask is not None and out_mask is not None:
                mask = in_mask & out_mask
            elif in_mask is not None:
                mask = in_mask
            elif out_mask is not None:
                mask = out_mask
            else:
                mask = None

            return reduce_weighted_loss(
                losses,
                sample_weight=sample_weight,
                mask=mask,
                reduction=self.reduction,
            )

    def call(self, y_true, y_pred):
        raise NotImplementedError

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
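

# Example usage (an illustrative sketch, not part of this module): a subclass
# only needs to implement `call()` to return per-sample losses; `__call__`
# handles masking, weighting, and reduction. `ExampleMeanSquaredError` is a
# hypothetical name used for demonstration.
#
#     class ExampleMeanSquaredError(Loss):
#         def call(self, y_true, y_pred):
#             return ops.mean(ops.square(y_pred - y_true), axis=-1)
#
#     loss_fn = ExampleMeanSquaredError()
#     loss = loss_fn(y_true, y_pred)  # Scalar, averaged over the batch.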


def standardize_reduction(reduction):
    allowed = {"sum_over_batch_size", "sum", None}
    if reduction not in allowed:
        raise ValueError(
            "Invalid value for argument `reduction`. "
            f"Expected one of {allowed}. Received: "
            f"reduction={reduction}"
        )
    return reduction


def squeeze_to_same_rank(x1, x2):
    """Squeeze the last dim of `x1` or `x2` if their ranks differ by 1."""
    x1_rank = len(x1.shape)
    x2_rank = len(x2.shape)
    if x1_rank == x2_rank:
        return x1, x2
    if x1_rank == x2_rank + 1:
        if x1.shape[-1] == 1:
            x1 = ops.squeeze(x1, axis=-1)
    if x2_rank == x1_rank + 1:
        if x2.shape[-1] == 1:
            x2 = ops.squeeze(x2, axis=-1)
    return x1, x2
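

# For example (illustrative): with `x1` of shape `(8, 1)` and `x2` of shape
# `(8,)`, the trailing singleton dim of `x1` is squeezed so both come out
# with shape `(8,)`:
#
#     x1 = ops.zeros((8, 1))
#     x2 = ops.zeros((8,))
#     x1, x2 = squeeze_to_same_rank(x1, x2)  # Both now have shape (8,).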
def reduce_loss(losses, reduction="sum_over_batch_size"):
|
|
if reduction is None or tuple(losses.shape) == () or tuple(losses.shape) == (0,):
|
|
return losses
|
|
loss = ops.sum(losses)
|
|
if reduction == "sum_over_batch_size":
|
|
loss /= ops.cast(ops.shape(losses)[0], loss.dtype)
|
|
return loss
|
|
|
|
|
|
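

# With `reduction="sum_over_batch_size"` the summed loss is divided by the
# batch dimension, e.g. (illustrative):
#
#     losses = ops.convert_to_tensor([1.0, 2.0, 3.0])
#     reduce_loss(losses, reduction="sum")                  # 6.0
#     reduce_loss(losses, reduction="sum_over_batch_size")  # 2.0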


def reduce_weighted_loss(
    losses,
    sample_weight=None,
    mask=None,
    reduction="sum_over_batch_size",
):
    reduction = standardize_reduction(reduction)

    losses = ops.convert_to_tensor(losses)
    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(
            sample_weight, dtype=losses.dtype
        )
    if mask is not None:
        mask = ops.convert_to_tensor(mask, dtype=losses.dtype)

    # Merge mask and sample weight into sample weight.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=losses.dtype, reduction=reduction
    )

    # Convert any non-float dtypes to floats, to avoid loss of precision
    # for dtypes like int or bool.
    dtype = backend.standardize_dtype(losses.dtype)
    if not is_float(dtype):
        input_dtype = losses.dtype
        losses = ops.cast(losses, "float32")
        input_casted = True
    else:
        input_casted = False

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, losses.dtype)
        # Update dimensions of `sample_weight` to match `losses`.
        losses, sample_weight = squeeze_to_same_rank(losses, sample_weight)
        losses *= sample_weight

    # Apply reduction function to the individual weighted losses.
    loss = reduce_loss(losses, reduction)

    if input_casted:
        # Convert the result back to the input type.
        loss = ops.cast(loss, input_dtype)
    return loss
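

# Illustrative example: per-sample losses are multiplied by the sample
# weights, then the sum is divided by the full batch size (4), regardless
# of how many weights are nonzero:
#
#     losses = ops.convert_to_tensor([1.0, 2.0, 3.0, 4.0])
#     weights = ops.convert_to_tensor([1.0, 1.0, 0.0, 0.0])
#     reduce_weighted_loss(losses, sample_weight=weights)  # (1 + 2) / 4 = 0.75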


def float_dtype_size(dtype):
    if dtype in ("bfloat16", "float16"):
        return 16
    if dtype == "float32":
        return 32
    if dtype == "float64":
        return 64
    raise ValueError(f"Invalid dtype: {dtype}")


def is_float(dtype):
    return "float" in dtype


def cast_to_common_dtype(tensors):
    """Cast a list of tensors to a common dtype.

    If any tensor is floating-point, they will all be cast to the
    most-precise floating-point dtype. Otherwise the tensors are not cast.

    Args:
        tensors: A list of tensors.

    Returns:
        Same list, cast to a common dtype.
    """
    highest_float = None
    for x in tensors:
        dtype = backend.standardize_dtype(x.dtype)
        if is_float(dtype):
            if highest_float is None or (
                float_dtype_size(dtype) > float_dtype_size(highest_float)
            ):
                highest_float = dtype
            elif dtype == "float16" and highest_float == "bfloat16":
                # float16 and bfloat16 have the same size but neither can
                # represent the other; promote to float32.
                highest_float = "float32"
    if highest_float:
        tensors = [ops.cast(x, highest_float) for x in tensors]
    return tensors
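

# Illustrative example: mixing `float16` and `bfloat16` inputs promotes
# both to `float32`, since neither 16-bit format can represent the other:
#
#     a = ops.cast(ops.zeros((2,)), "float16")
#     b = ops.cast(ops.zeros((2,)), "bfloat16")
#     a, b = cast_to_common_dtype([a, b])  # Both are now float32.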


def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is not None:
        mask = ops.cast(mask, dtype=dtype)
        if reduction == "sum_over_batch_size":
            # Valid entries have weight `total/valid`, while invalid ones
            # have 0. When summed over batch, they will be reduced to:
            #
            #   mean(loss * sample_weight * total / valid)
            # = sum(loss * sample_weight * total / valid) / total
            # = sum(loss * sample_weight) / total * total / valid
            # = sum(loss * sample_weight) / valid
            total = ops.cast(ops.shape(mask)[0], dtype=dtype)
            valid = ops.sum(mask)  # May be 0!
            mask *= total / (valid + backend.epsilon())

        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, dtype=dtype)
            mask, sample_weight = squeeze_to_same_rank(mask, sample_weight)
            sample_weight *= mask
        else:
            sample_weight = mask
    return sample_weight
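

# Worked example (illustrative): with a batch of 4 samples of which only 2
# are valid, each valid sample receives weight `total / valid = 4 / 2 = 2`,
# so that the later division by the batch size (4) yields the mean over the
# 2 valid entries only:
#
#     mask = ops.convert_to_tensor([1.0, 1.0, 0.0, 0.0])
#     weights = apply_mask(
#         None, mask, dtype="float32", reduction="sum_over_batch_size"
#     )
#     # weights ~= [2.0, 2.0, 0.0, 0.0]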