Add support for sample_weights in CompileLoss (#370)

* Add support for sample_weights in CompileLoss

* is not None
This commit is contained in:
Ian Stenbit 2023-06-17 11:51:55 -06:00 committed by Francois Chollet
parent 6b0cb5598a
commit c6d71e6a68
2 changed files with 26 additions and 6 deletions

@ -548,19 +548,35 @@ class CompileLoss(losses_module.Loss):
self.flat_loss_weights = flat_loss_weights
self.built = True
def call(self, y_true, y_pred):
def __call__(self, y_true, y_pred, sample_weight=None):
with ops.name_scope(self.name):
return self.call(y_true, y_pred, sample_weight)
def call(self, y_true, y_pred, sample_weight=None):
if not self.built:
self.build(y_true, y_pred)
y_true = nest.flatten(y_true)
y_pred = nest.flatten(y_pred)
if sample_weight is not None:
sample_weight = nest.flatten(sample_weight)
else:
sample_weight = [None for _ in y_true]
loss_values = []
for loss, y_t, y_p, w in zip(
self.flat_losses, y_true, y_pred, self.flat_loss_weights
for loss, y_t, y_p, loss_weight, sample_weight in zip(
self.flat_losses,
y_true,
y_pred,
self.flat_loss_weights,
sample_weight,
):
if loss:
value = w * ops.cast(loss(y_t, y_p), dtype=backend.floatx())
value = loss_weight * ops.cast(
loss(y_t, y_p, sample_weight), dtype=backend.floatx()
)
loss_values.append(value)
if loss_values:
total_loss = sum(loss_values)

@ -270,6 +270,10 @@ class TestCompileLoss(testing.TestCase, parameterized.TestCase):
"a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),
"b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),
}
sample_weight = {
"a": np.array([1.0, 2.0, 3.0]),
"b": np.array([3.0, 2.0, 1.0]),
}
compile_loss.build(y_true, y_pred)
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 0.953333, atol=1e-5)
value = compile_loss(y_true, y_pred, sample_weight)
self.assertAllClose(value, 1.266666, atol=1e-5)