diff --git a/keras/backend/common/keras_tensor.py b/keras/backend/common/keras_tensor.py
index db2156726..00f07239a 100644
--- a/keras/backend/common/keras_tensor.py
+++ b/keras/backend/common/keras_tensor.py
@@ -48,10 +48,10 @@ class KerasTensor:
     def ndim(self):
         return len(self.shape)
 
-    def reshape(self, new_shape):
+    def reshape(self, newshape):
         from keras import ops
 
-        return ops.Reshape(new_shape)(self)
+        return ops.Reshape(newshape)(self)
 
     def squeeze(self, axis=None):
         from keras import ops
diff --git a/keras/backend/jax/numpy.py b/keras/backend/jax/numpy.py
index 50c81b64d..8ce52d752 100644
--- a/keras/backend/jax/numpy.py
+++ b/keras/backend/jax/numpy.py
@@ -647,8 +647,8 @@ def repeat(x, repeats, axis=None):
     return jnp.repeat(x, repeats, axis=axis)
 
 
-def reshape(x, new_shape):
-    return jnp.reshape(x, new_shape)
+def reshape(x, newshape):
+    return jnp.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):
diff --git a/keras/backend/numpy/numpy.py b/keras/backend/numpy/numpy.py
index bc7a1e312..667ef603b 100644
--- a/keras/backend/numpy/numpy.py
+++ b/keras/backend/numpy/numpy.py
@@ -770,8 +770,8 @@ def repeat(x, repeats, axis=None):
     return np.repeat(x, repeats, axis=axis)
 
 
-def reshape(x, new_shape):
-    return np.reshape(x, new_shape)
+def reshape(x, newshape):
+    return np.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):
diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py
index 2aca10e2b..6bf567da1 100644
--- a/keras/backend/tensorflow/numpy.py
+++ b/keras/backend/tensorflow/numpy.py
@@ -1339,18 +1339,18 @@ def repeat(x, repeats, axis=None):
     return tf.repeat(x, repeats, axis=axis)
 
 
-def reshape(x, new_shape):
+def reshape(x, newshape):
     x = convert_to_tensor(x)
     if isinstance(x, tf.SparseTensor):
         from keras.ops.operation_utils import compute_reshape_output_shape
 
         output_shape = compute_reshape_output_shape(
-            x.shape, new_shape, "new_shape"
+            x.shape, newshape, "newshape"
         )
-        output = tf.sparse.reshape(x, new_shape)
+        output = tf.sparse.reshape(x, newshape)
         output.set_shape(output_shape)
         return output
-    return tf.reshape(x, new_shape)
+    return tf.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):
diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py
index 3ed5122e2..4da9059b1 100644
--- a/keras/backend/torch/numpy.py
+++ b/keras/backend/torch/numpy.py
@@ -1122,11 +1122,11 @@ def repeat(x, repeats, axis=None):
     return torch.repeat_interleave(x, repeats, dim=axis)
 
 
-def reshape(x, new_shape):
-    if not isinstance(new_shape, (list, tuple)):
-        new_shape = (new_shape,)
+def reshape(x, newshape):
+    if not isinstance(newshape, (list, tuple)):
+        newshape = (newshape,)
     x = convert_to_tensor(x)
-    return torch.reshape(x, new_shape)
+    return torch.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):
diff --git a/keras/metrics/accuracy_metrics.py b/keras/metrics/accuracy_metrics.py
index b32b5ec01..a75eb7998 100644
--- a/keras/metrics/accuracy_metrics.py
+++ b/keras/metrics/accuracy_metrics.py
@@ -165,7 +165,7 @@ def categorical_accuracy(y_true, y_pred):
         y_pred = ops.cast(y_pred, dtype=y_true.dtype)
     matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
     return matches
 
 
@@ -251,7 +251,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
         y_pred = ops.cast(y_pred, y_true.dtype)
     matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
     # if shape is (num_samples, 1) squeeze
     if len(matches.shape) > 1 and matches.shape[-1] == 1:
         matches = ops.squeeze(matches, -1)
@@ -337,7 +337,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
 
     # returned matches is expected to have same shape as y_true input
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
 
     return matches
 
@@ -415,7 +415,7 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 
     # returned matches is expected to have same shape as y_true input
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
 
     return matches
 
diff --git a/keras/ops/numpy.py b/keras/ops/numpy.py
index 3461d0dc5..332afb833 100644
--- a/keras/ops/numpy.py
+++ b/keras/ops/numpy.py
@@ -4463,28 +4463,28 @@ def repeat(x, repeats, axis=None):
 
 
 class Reshape(Operation):
-    def __init__(self, new_shape):
+    def __init__(self, newshape):
         super().__init__()
-        self.new_shape = new_shape
+        self.newshape = newshape
 
     def call(self, x):
-        return backend.numpy.reshape(x, self.new_shape)
+        return backend.numpy.reshape(x, self.newshape)
 
     def compute_output_spec(self, x):
         output_shape = operation_utils.compute_reshape_output_shape(
-            x.shape, self.new_shape, "new_shape"
+            x.shape, self.newshape, "newshape"
         )
         sparse = getattr(x, "sparse", False)
         return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)
 
 
 @keras_export(["keras.ops.reshape", "keras.ops.numpy.reshape"])
-def reshape(x, new_shape):
+def reshape(x, newshape):
     """Gives a new shape to a tensor without changing its data.
 
     Args:
         x: Input tensor.
-        new_shape: The new shape should be compatible with the original shape.
+        newshape: The new shape should be compatible with the original shape.
             One shape dimension can be -1 in which case the value is inferred
             from the length of the array and remaining dimensions.
 
@@ -4492,8 +4492,8 @@ def reshape(x, new_shape):
         The reshaped tensor.
     """
     if any_symbolic_tensors((x,)):
-        return Reshape(new_shape).symbolic_call(x)
-    return backend.numpy.reshape(x, new_shape)
+        return Reshape(newshape).symbolic_call(x)
+    return backend.numpy.reshape(x, newshape)
 
 
 class Roll(Operation):
diff --git a/keras/ops/operation_utils.py b/keras/ops/operation_utils.py
index 72d112377..7ef1e60c0 100644
--- a/keras/ops/operation_utils.py
+++ b/keras/ops/operation_utils.py
@@ -272,36 +272,36 @@ def compute_matmul_output_shape(shape1, shape2):
     return tuple(output_shape)
 
 
-def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
-    """Converts `-1` in `new_shape` to either an actual dimension or `None`.
+def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
+    """Converts `-1` in `newshape` to either an actual dimension or `None`.
 
     This utility does not special case the 0th dimension (batch size).
     """
-    unknown_dim_count = new_shape.count(-1)
+    unknown_dim_count = newshape.count(-1)
     if unknown_dim_count > 1:
         raise ValueError(
             "There must be at most one unknown dimension (-1) in "
-            f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
+            f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
         )
 
     # If there is a None in input_shape, we can't infer what the -1 is
     if None in input_shape:
-        return tuple(dim if dim != -1 else None for dim in new_shape)
+        return tuple(dim if dim != -1 else None for dim in newshape)
 
     input_size = math.prod(input_shape)
-    # If the new_shape fully defined, return it
+    # If the `newshape` is fully defined, return it
     if unknown_dim_count == 0:
-        if input_size != math.prod(new_shape):
+        if input_size != math.prod(newshape):
             raise ValueError(
                 "The total size of the tensor must be unchanged. Received: "
-                f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
+                f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
             )
-        return new_shape
+        return newshape
 
-    # We have one -1 in new_shape, compute the actual value
+    # We have one -1 in `newshape`, compute the actual value
     known_output_size = 1
     unknown_dim_index = None
-    for index, dim in enumerate(new_shape):
+    for index, dim in enumerate(newshape):
         if dim == -1:
             unknown_dim_index = index
         else:
@@ -311,11 +311,11 @@ def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
         raise ValueError(
             "The total size of the tensor must be unchanged, however, the "
             "input size cannot by divided by the specified dimensions in "
-            f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
-            f"{new_shape_arg_name}={new_shape}"
+            f"{newshape_arg_name}. Received: input_shape={input_shape}, "
+            f"{newshape_arg_name}={newshape}"
         )
 
-    output_shape = list(new_shape)
+    output_shape = list(newshape)
     output_shape[unknown_dim_index] = input_size // known_output_size
     return tuple(output_shape)
 
diff --git a/keras/ops/operation_utils_test.py b/keras/ops/operation_utils_test.py
index 8d8b066bf..5cf304f9a 100644
--- a/keras/ops/operation_utils_test.py
+++ b/keras/ops/operation_utils_test.py
@@ -139,7 +139,7 @@ class OperationUtilsTest(testing.TestCase):
         input_shape = (1, 4, 4, 1)
         target_shape = (16, 1)
         output_shape = operation_utils.compute_reshape_output_shape(
-            input_shape, new_shape=target_shape, new_shape_arg_name="New shape"
+            input_shape, newshape=target_shape, newshape_arg_name="New shape"
        )
         self.assertEqual(output_shape, target_shape)
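Usage note (illustrative, not part of the patch): a minimal sketch of the renamed argument, assuming Keras 3 with any installed backend. After this change, keras.ops.reshape spells its shape parameter "newshape", matching numpy.reshape; positional calls are unaffected, only keyword calls change.

    # Sketch: keyword is now `newshape` (was `new_shape`); positional use is unchanged.
    import numpy as np
    from keras import ops

    x = np.arange(6)
    y = ops.reshape(x, (2, 3))            # positional, same as before the rename
    z = ops.reshape(x, newshape=(3, -1))  # keyword now matches numpy; -1 is inferred as 2
    print(ops.shape(y), ops.shape(z))     # shapes: (2, 3) and (3, 2)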