From 42bdabf76ad0f1b03f2f662b573bdee627d1561f Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:04:11 -0500 Subject: [PATCH] Resolve shape via ops in broadcast_to operation (#547) --- keras_core/metrics/iou_metrics.py | 2 +- keras_core/metrics/metric.py | 4 +++- keras_core/metrics/metrics_utils.py | 6 +++--- keras_core/metrics/regression_metrics.py | 10 ++++++---- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/keras_core/metrics/iou_metrics.py b/keras_core/metrics/iou_metrics.py index 66357b466..fbec25aad 100644 --- a/keras_core/metrics/iou_metrics.py +++ b/keras_core/metrics/iou_metrics.py @@ -106,7 +106,7 @@ class _IoUBase(Metric): if len(sample_weight.shape) > 1: sample_weight = ops.reshape(sample_weight, [-1]) - sample_weight = ops.broadcast_to(sample_weight, y_true.shape) + sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true)) if self.ignore_class is not None: ignore_class = ops.convert_to_tensor( diff --git a/keras_core/metrics/metric.py b/keras_core/metrics/metric.py index 530587198..a7d324bea 100644 --- a/keras_core/metrics/metric.py +++ b/keras_core/metrics/metric.py @@ -72,7 +72,9 @@ class Metric: values = ops.cast(values, self.dtype) if sample_weight is not None: sample_weight = ops.cast(sample_weight, self.dtype) - sample_weight = ops.broadcast_to(sample_weight, values.shape) + sample_weight = ops.broadcast_to( + sample_weight, ops.shape(values) + ) values = ops.multiply(values, sample_weight) self.true_positives.assign(self.true_positives + ops.sum(values)) diff --git a/keras_core/metrics/metrics_utils.py b/keras_core/metrics/metrics_utils.py index 6b5a3a61a..8a450a075 100644 --- a/keras_core/metrics/metrics_utils.py +++ b/keras_core/metrics/metrics_utils.py @@ -195,7 +195,7 @@ def _update_confusion_matrix_variables_optimized( sample_weights = 1.0 else: sample_weights = ops.broadcast_to( - ops.cast(sample_weights, dtype=y_pred.dtype), y_pred.shape + ops.cast(sample_weights, dtype=y_pred.dtype), ops.shape(y_pred) ) if not multi_label: sample_weights = ops.reshape(sample_weights, [-1]) @@ -203,7 +203,7 @@ def _update_confusion_matrix_variables_optimized( label_weights = 1.0 else: label_weights = ops.expand_dims(label_weights, 0) - label_weights = ops.broadcast_to(label_weights, y_pred.shape) + label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred)) if not multi_label: label_weights = ops.reshape(label_weights, [-1]) weights = ops.cast( @@ -533,7 +533,7 @@ def update_confusion_matrix_variables( if label_weights is not None and not multi_label: label_weights = ops.expand_dims(label_weights, 0) - label_weights = ops.broadcast_to(label_weights, y_pred.shape) + label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred)) label_weights_tiled = ops.tile( ops.reshape(label_weights, thresh_tiles), data_tiles ) diff --git a/keras_core/metrics/regression_metrics.py b/keras_core/metrics/regression_metrics.py index 74793804b..abba6a359 100644 --- a/keras_core/metrics/regression_metrics.py +++ b/keras_core/metrics/regression_metrics.py @@ -428,7 +428,6 @@ class R2Score(reduction_metrics.Metric): shape=(), initializer=initializers.Zeros(), name="num_samples", - dtype="int32", ) self._built = False @@ -500,16 +499,19 @@ class R2Score(reduction_metrics.Metric): # Make sure there's a features dimension sample_weight = ops.expand_dims(sample_weight, axis=1) - sample_weight = ops.broadcast_to(sample_weight, y_true.shape) + sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true)) - weighted_y_true = y_true * sample_weight + weighted_y_true = y_true * ops.cast(sample_weight, y_true.dtype) self.sum.assign(self.sum + ops.sum(weighted_y_true, axis=0)) self.squared_sum.assign( self.squared_sum + ops.sum(y_true * weighted_y_true, axis=0) ) self.total_mse.assign( self.total_mse - + ops.sum((y_true - y_pred) ** 2 * sample_weight, axis=0) + + ops.sum( + (y_true - y_pred) ** 2 * ops.cast(sample_weight, y_true.dtype), + axis=0, + ) ) self.count.assign(self.count + ops.sum(sample_weight, axis=0)) self.num_samples.assign(self.num_samples + ops.size(y_true))