Add support for sample_weights in CompileLoss (#370)
* Add support for sample_weights in CompileLoss * is not None
This commit is contained in:
parent
6b0cb5598a
commit
c6d71e6a68
@ -548,19 +548,35 @@ class CompileLoss(losses_module.Loss):
|
||||
self.flat_loss_weights = flat_loss_weights
|
||||
self.built = True
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
def __call__(self, y_true, y_pred, sample_weight=None):
|
||||
with ops.name_scope(self.name):
|
||||
return self.call(y_true, y_pred, sample_weight)
|
||||
|
||||
def call(self, y_true, y_pred, sample_weight=None):
|
||||
if not self.built:
|
||||
self.build(y_true, y_pred)
|
||||
|
||||
y_true = nest.flatten(y_true)
|
||||
y_pred = nest.flatten(y_pred)
|
||||
|
||||
if sample_weight is not None:
|
||||
sample_weight = nest.flatten(sample_weight)
|
||||
else:
|
||||
sample_weight = [None for _ in y_true]
|
||||
|
||||
loss_values = []
|
||||
|
||||
for loss, y_t, y_p, w in zip(
|
||||
self.flat_losses, y_true, y_pred, self.flat_loss_weights
|
||||
for loss, y_t, y_p, loss_weight, sample_weight in zip(
|
||||
self.flat_losses,
|
||||
y_true,
|
||||
y_pred,
|
||||
self.flat_loss_weights,
|
||||
sample_weight,
|
||||
):
|
||||
if loss:
|
||||
value = w * ops.cast(loss(y_t, y_p), dtype=backend.floatx())
|
||||
value = loss_weight * ops.cast(
|
||||
loss(y_t, y_p, sample_weight), dtype=backend.floatx()
|
||||
)
|
||||
loss_values.append(value)
|
||||
if loss_values:
|
||||
total_loss = sum(loss_values)
|
||||
|
@ -270,6 +270,10 @@ class TestCompileLoss(testing.TestCase, parameterized.TestCase):
|
||||
"a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),
|
||||
"b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),
|
||||
}
|
||||
sample_weight = {
|
||||
"a": np.array([1.0, 2.0, 3.0]),
|
||||
"b": np.array([3.0, 2.0, 1.0]),
|
||||
}
|
||||
compile_loss.build(y_true, y_pred)
|
||||
value = compile_loss(y_true, y_pred)
|
||||
self.assertAllClose(value, 0.953333, atol=1e-5)
|
||||
value = compile_loss(y_true, y_pred, sample_weight)
|
||||
self.assertAllClose(value, 1.266666, atol=1e-5)
|
||||
|
Loading…
Reference in New Issue
Block a user