From 49c6ddbc428829d89c1cbdb413465079f00ddcc3 Mon Sep 17 00:00:00 2001 From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Date: Mon, 12 Jun 2023 13:17:09 -0700 Subject: [PATCH] bug fix for torch sample_weights on *_on_batch methods (#332) Co-authored-by: Haifeng Jin --- keras_core/losses/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/losses/loss.py b/keras_core/losses/loss.py index 660bf146d..6aaaac5ac 100644 --- a/keras_core/losses/loss.py +++ b/keras_core/losses/loss.py @@ -146,7 +146,7 @@ def reduce_weighted_values( sample_weight = ops.cast(sample_weight, values.dtype) # Update dimensions of `sample_weight` to match `losses`. values, sample_weight = squeeze_to_same_rank(values, sample_weight) - values *= sample_weight + values = values * sample_weight # Apply reduction function to the individual weighted losses. loss = reduce_values(values, reduction)