bug fix for torch sample_weights on *_on_batch methods (#332)
Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
parent
b2fe70618f
commit
49c6ddbc42
@ -146,7 +146,7 @@ def reduce_weighted_values(
|
||||
sample_weight = ops.cast(sample_weight, values.dtype)
|
||||
# Update dimensions of `sample_weight` to match `losses`.
|
||||
values, sample_weight = squeeze_to_same_rank(values, sample_weight)
|
||||
values *= sample_weight
|
||||
values = values * sample_weight
|
||||
|
||||
# Apply reduction function to the individual weighted losses.
|
||||
loss = reduce_values(values, reduction)
|
||||
|
Loading…
Reference in New Issue
Block a user