from keras_core import operations as ops
from keras_core import backend
from keras_core.engine.naming import auto_name


class Loss:
    def __init__(self, name=None, reduction="sum_over_batch_size"):
        self.name = name or auto_name(self.__class__.__name__)
        self.reduction = standardize_reduction(reduction)

    def __call__(self, y_true, y_pred, sample_weight=None):
        in_mask = getattr(y_pred, "_keras_mask", None)

        with ops.name_scope(self.name):
            y_pred = ops.convert_to_tensor(y_pred)
            y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)

            losses = self.call(y_true, y_pred)
            out_mask = getattr(losses, "_keras_mask", None)

            if in_mask is not None and out_mask is not None:
                mask = in_mask & out_mask
            elif in_mask is not None:
                mask = in_mask
            elif out_mask is not None:
                mask = out_mask
            else:
                mask = None

            return reduce_weighted_loss(
                losses,
                sample_weight=sample_weight,
                mask=mask,
                reduction=self.reduction,
            )

    def call(self, y_true, y_pred):
        raise NotImplementedError

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def standardize_reduction(reduction):
    allowed = {"sum_over_batch_size", "sum", None}
    if reduction not in allowed:
        raise ValueError(
            "Invalid value for argument `reduction`. "
            f"Expected one of {allowed}. Received: "
            f"reduction={reduction}"
        )
    return reduction


def squeeze_to_same_rank(x1, x2):
    """Squeeze last dim if ranks differ from expected by exactly 1."""
    x1_rank = len(x1.shape)
    x2_rank = len(x2.shape)
    if x1_rank == x2_rank:
        return x1, x2
    if x1_rank == x2_rank + 1:
        if x1.shape[-1] == 1:
            x1 = ops.squeeze(x1, axis=-1)
    if x2_rank == x1_rank + 1:
        if x2.shape[-1] == 1:
            x2 = ops.squeeze(x2, axis=-1)
    return x1, x2


def reduce_loss(losses, reduction="sum_over_batch_size"):
    if (
        reduction is None
        or tuple(losses.shape) == ()
        or tuple(losses.shape) == (0,)
    ):
        return losses
    loss = ops.sum(losses)
    if reduction == "sum_over_batch_size":
        loss /= ops.cast(ops.shape(losses)[0], loss.dtype)
    return loss


def reduce_weighted_loss(
    losses,
    sample_weight=None,
    mask=None,
    reduction="sum_over_batch_size",
):
    reduction = standardize_reduction(reduction)

    losses = ops.convert_to_tensor(losses)
    if sample_weight is not None:
        sample_weight = ops.convert_to_tensor(
            sample_weight, dtype=losses.dtype
        )
    if mask is not None:
        mask = ops.convert_to_tensor(mask, dtype=losses.dtype)

    # Merge mask and sample weight into sample weight.
    sample_weight = apply_mask(
        sample_weight, mask, dtype=losses.dtype, reduction=reduction
    )

    # Convert any non-float dtypes to floats, to avoid loss of precision
    # for dtypes like int or bool.
    dtype = backend.standardize_dtype(losses.dtype)
    if not is_float(dtype):
        input_dtype = losses.dtype
        losses = ops.cast(losses, "float32")
        input_casted = True
    else:
        input_casted = False

    if sample_weight is not None:
        sample_weight = ops.cast(sample_weight, losses.dtype)
        # Update dimensions of `sample_weight` to match `losses`.
        losses, sample_weight = squeeze_to_same_rank(losses, sample_weight)
        losses *= sample_weight

    # Apply reduction function to the individual weighted losses.
    loss = reduce_loss(losses, reduction)

    if input_casted:
        # Convert the result back to the input type.
        loss = ops.cast(loss, input_dtype)
    return loss


def float_dtype_size(dtype):
    if dtype in ("bfloat16", "float16"):
        return 16
    if dtype == "float32":
        return 32
    if dtype == "float64":
        return 64
    raise ValueError(f"Invalid dtype: {dtype}")


def is_float(dtype):
    return "float" in dtype
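

# Illustrative sketch (not part of keras_core): how `reduce_weighted_loss`
# merges a boolean mask with per-sample weights. The tensor values below are
# made up for demonstration purposes only.
def _example_masked_weighted_reduction():
    # Three per-sample losses; the last sample is masked out.
    losses = ops.convert_to_tensor([1.0, 2.0, 3.0])
    sample_weight = ops.convert_to_tensor([1.0, 0.5, 1.0])
    mask = ops.convert_to_tensor([True, True, False])
    # With "sum_over_batch_size", masked entries get weight 0 and the
    # remaining weights are rescaled by total/valid, so the result is the
    # mean over the two valid samples:
    # (1.0 * 1.0 + 2.0 * 0.5) / 2, approximately 1.0 (up to the epsilon
    # guarding against an all-zero mask).
    return reduce_weighted_loss(
        losses,
        sample_weight=sample_weight,
        mask=mask,
        reduction="sum_over_batch_size",
    )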


def cast_to_common_dtype(tensors):
    """Cast a list of tensors to a common dtype.

    If any tensor is floating-point, they will all be cast to the
    most-precise floating-point dtype. Otherwise the tensors are not cast.

    Args:
        tensors: A list of tensors.

    Returns:
        Same list, cast to a common dtype.
    """
    highest_float = None
    for x in tensors:
        dtype = backend.standardize_dtype(x.dtype)
        if is_float(dtype):
            if highest_float is None or (
                float_dtype_size(dtype) > float_dtype_size(highest_float)
            ):
                highest_float = dtype
            elif dtype == "float16" and highest_float == "bfloat16":
                highest_float = "float32"
    if highest_float:
        tensors = [ops.cast(x, highest_float) for x in tensors]
    return tensors


def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is not None:
        mask = ops.cast(mask, dtype=dtype)
        if reduction == "sum_over_batch_size":
            # Valid entries have weight `total/valid`, while invalid ones
            # have 0. When summed over batch, they will be reduced to:
            #
            #   mean(loss * sample_weight * total / valid)
            #   = sum(loss * sample_weight * total / valid) / total
            #   = sum(loss * sample_weight) / total * total / valid
            #   = sum(loss * sample_weight) / valid
            total = ops.cast(ops.shape(mask)[0], dtype=dtype)
            valid = ops.sum(mask)  # May be 0!
            mask *= total / (valid + backend.epsilon())

        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, dtype=dtype)
            mask, sample_weight = squeeze_to_same_rank(mask, sample_weight)
            sample_weight *= mask
        else:
            sample_weight = mask
    return sample_weight
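

# Illustrative sketch (not part of keras_core): a minimal `Loss` subclass.
# The class name is hypothetical, and the op names used in `call()`
# (`ops.subtract`, `ops.square`, `ops.mean`) are assumed to follow the
# numpy-style `keras_core.operations` API.
class ExampleMeanSquaredError(Loss):
    """Hypothetical mean squared error loss, for demonstration only."""

    def call(self, y_true, y_pred):
        # Return one loss value per sample; `__call__` takes care of masking,
        # sample weighting, and the final reduction.
        y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
        return ops.mean(ops.square(ops.subtract(y_true, y_pred)), axis=-1)


# Hedged usage note: `ExampleMeanSquaredError()(y_true, y_pred)` would return
# a scalar averaged over the batch, honoring any `_keras_mask` set on `y_pred`
# and any `sample_weight` passed to `__call__`.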