bug fix for torch sample_weights on *_on_batch methods (#332)

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
Haifeng Jin 2023-06-12 13:17:09 -07:00 committed by Francois Chollet
parent b2fe70618f
commit 49c6ddbc42

@ -146,7 +146,7 @@ def reduce_weighted_values(
sample_weight = ops.cast(sample_weight, values.dtype)
# Update dimensions of `sample_weight` to match `losses`.
values, sample_weight = squeeze_to_same_rank(values, sample_weight)
values *= sample_weight
values = values * sample_weight
# Apply reduction function to the individual weighted losses.
loss = reduce_values(values, reduction)