Rename ops.reshape argument from new_shape to newshape. (#19097)

This is for consistency with the NumPy API.
Authored by hertschuh on 2024-01-24 13:51:19 -08:00; committed by GitHub
Parent: 1d0008bc70
Commit: bf1f463e0e
9 changed files with 41 additions and 41 deletions
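For context, NumPy's np.reshape names its shape argument `newshape`, which is what this rename aligns with. A minimal sketch of the matching call sites (assuming a Keras 3 install and a NumPy release that still accepts `newshape` as a keyword):

    import numpy as np
    from keras import ops

    x = np.arange(6)

    # NumPy spells the parameter `newshape` (positional or keyword).
    a = np.reshape(x, newshape=(2, 3))   # shape (2, 3)

    # After this commit, keras.ops.reshape mirrors that name.
    b = ops.reshape(x, newshape=(2, 3))  # shape (2, 3)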

@@ -48,10 +48,10 @@ class KerasTensor:
     def ndim(self):
         return len(self.shape)
 
-    def reshape(self, new_shape):
+    def reshape(self, newshape):
         from keras import ops
 
-        return ops.Reshape(new_shape)(self)
+        return ops.Reshape(newshape)(self)
 
     def squeeze(self, axis=None):
         from keras import ops

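KerasTensor.reshape above is a thin wrapper around the Reshape operation, so the rename is also visible when reshaping symbolic tensors. A small illustrative sketch (shapes are hypothetical; assumes the functional-API keras.Input):

    import keras

    # Symbolic tensor with an unknown batch dimension: shape (None, 4, 4).
    x = keras.Input(shape=(4, 4))

    # Delegates to ops.Reshape(newshape)(x). Because the batch dimension is
    # None, the -1 cannot be resolved and stays unknown in the output spec.
    y = x.reshape((-1, 16))
    print(y.shape)  # (None, 16)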
@@ -647,8 +647,8 @@ def repeat(x, repeats, axis=None):
     return jnp.repeat(x, repeats, axis=axis)
 
 
-def reshape(x, new_shape):
-    return jnp.reshape(x, new_shape)
+def reshape(x, newshape):
+    return jnp.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):

@@ -770,8 +770,8 @@ def repeat(x, repeats, axis=None):
     return np.repeat(x, repeats, axis=axis)
 
 
-def reshape(x, new_shape):
-    return np.reshape(x, new_shape)
+def reshape(x, newshape):
+    return np.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):

@@ -1339,18 +1339,18 @@ def repeat(x, repeats, axis=None):
     return tf.repeat(x, repeats, axis=axis)
 
 
-def reshape(x, new_shape):
+def reshape(x, newshape):
     x = convert_to_tensor(x)
     if isinstance(x, tf.SparseTensor):
         from keras.ops.operation_utils import compute_reshape_output_shape
 
         output_shape = compute_reshape_output_shape(
-            x.shape, new_shape, "new_shape"
+            x.shape, newshape, "newshape"
         )
-        output = tf.sparse.reshape(x, new_shape)
+        output = tf.sparse.reshape(x, newshape)
         output.set_shape(output_shape)
         return output
-    return tf.reshape(x, new_shape)
+    return tf.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):

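The TensorFlow branch above keeps sparse inputs sparse by routing them through tf.sparse.reshape. A rough sketch of that path (assumes TensorFlow is the active Keras backend; other backends never reach this code):

    import tensorflow as tf
    from keras import ops

    st = tf.sparse.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[2, 3]
    )

    # Routed through tf.sparse.reshape; the statically known output shape is
    # attached via compute_reshape_output_shape.
    out = ops.reshape(st, (3, 2))
    print(isinstance(out, tf.SparseTensor), out.shape)  # True (3, 2)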
@@ -1122,11 +1122,11 @@ def repeat(x, repeats, axis=None):
     return torch.repeat_interleave(x, repeats, dim=axis)
 
 
-def reshape(x, new_shape):
-    if not isinstance(new_shape, (list, tuple)):
-        new_shape = (new_shape,)
+def reshape(x, newshape):
+    if not isinstance(newshape, (list, tuple)):
+        newshape = (newshape,)
     x = convert_to_tensor(x)
-    return torch.reshape(x, new_shape)
+    return torch.reshape(x, newshape)
 
 
 def roll(x, shift, axis=None):

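The tuple-wrapping in the torch branch exists because torch.reshape only accepts a sequence of ints, whereas NumPy (and therefore keras.ops.reshape) also tolerates a bare int. A quick illustration of the difference:

    import numpy as np
    import torch

    np.reshape(np.zeros((2, 3)), 6)          # OK: a bare int is accepted
    torch.reshape(torch.zeros(2, 3), (6,))   # OK: shape given as a tuple
    # torch.reshape(torch.zeros(2, 3), 6)    # TypeError: must be a tuple of ints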
@@ -165,7 +165,7 @@ def categorical_accuracy(y_true, y_pred):
     y_pred = ops.cast(y_pred, dtype=y_true.dtype)
     matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
     return matches
@@ -251,7 +251,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
     y_pred = ops.cast(y_pred, y_true.dtype)
     matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
     # if shape is (num_samples, 1) squeeze
     if len(matches.shape) > 1 and matches.shape[-1] == 1:
         matches = ops.squeeze(matches, -1)
@@ -337,7 +337,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
     # returned matches is expected to have same shape as y_true input
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
     return matches
@@ -415,7 +415,7 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
     # returned matches is expected to have same shape as y_true input
     if reshape_matches:
-        matches = ops.reshape(matches, new_shape=y_true_org_shape)
+        matches = ops.reshape(matches, y_true_org_shape)
     return matches

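The metric helpers above were switched from the keyword form to a positional argument; any downstream code that spelled out the old keyword needs the same treatment, since new_shape no longer exists as a parameter name. A hedged before/after sketch:

    from keras import ops

    matches = ops.ones((4, 1))

    ops.reshape(matches, (4,))             # positional: unaffected by the rename
    ops.reshape(matches, newshape=(4,))    # keyword: must now be spelled newshape
    # ops.reshape(matches, new_shape=(4,)) # TypeError after this commit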
@@ -4463,28 +4463,28 @@ def repeat(x, repeats, axis=None):
 
 
 class Reshape(Operation):
-    def __init__(self, new_shape):
+    def __init__(self, newshape):
         super().__init__()
-        self.new_shape = new_shape
+        self.newshape = newshape
 
     def call(self, x):
-        return backend.numpy.reshape(x, self.new_shape)
+        return backend.numpy.reshape(x, self.newshape)
 
     def compute_output_spec(self, x):
         output_shape = operation_utils.compute_reshape_output_shape(
-            x.shape, self.new_shape, "new_shape"
+            x.shape, self.newshape, "newshape"
         )
         sparse = getattr(x, "sparse", False)
         return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)
 
 
 @keras_export(["keras.ops.reshape", "keras.ops.numpy.reshape"])
-def reshape(x, new_shape):
+def reshape(x, newshape):
     """Gives a new shape to a tensor without changing its data.
 
     Args:
         x: Input tensor.
-        new_shape: The new shape should be compatible with the original shape.
+        newshape: The new shape should be compatible with the original shape.
             One shape dimension can be -1 in which case the value is
             inferred from the length of the array and remaining dimensions.

@@ -4492,8 +4492,8 @@ def reshape(x, new_shape):
         The reshaped tensor.
     """
     if any_symbolic_tensors((x,)):
-        return Reshape(new_shape).symbolic_call(x)
-    return backend.numpy.reshape(x, new_shape)
+        return Reshape(newshape).symbolic_call(x)
+    return backend.numpy.reshape(x, newshape)
 
 
 class Roll(Operation):

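As the docstring notes, one entry of newshape may be -1 and is inferred from the total element count. A short eager-mode sketch (backend-agnostic; values are illustrative):

    from keras import ops

    x = ops.arange(12)

    ops.reshape(x, (3, 4))     # shape (3, 4)
    ops.reshape(x, (-1, 6))    # shape (2, 6); the -1 is inferred as 12 // 6
    ops.reshape(x, (2, 2, 3))  # shape (2, 2, 3)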
@@ -272,36 +272,36 @@ def compute_matmul_output_shape(shape1, shape2):
     return tuple(output_shape)
 
 
-def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
-    """Converts `-1` in `new_shape` to either an actual dimension or `None`.
+def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
+    """Converts `-1` in `newshape` to either an actual dimension or `None`.
 
     This utility does not special case the 0th dimension (batch size).
     """
-    unknown_dim_count = new_shape.count(-1)
+    unknown_dim_count = newshape.count(-1)
     if unknown_dim_count > 1:
         raise ValueError(
             "There must be at most one unknown dimension (-1) in "
-            f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
+            f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
         )
 
     # If there is a None in input_shape, we can't infer what the -1 is
     if None in input_shape:
-        return tuple(dim if dim != -1 else None for dim in new_shape)
+        return tuple(dim if dim != -1 else None for dim in newshape)
 
     input_size = math.prod(input_shape)
-    # If the new_shape fully defined, return it
+    # If the `newshape` is fully defined, return it
     if unknown_dim_count == 0:
-        if input_size != math.prod(new_shape):
+        if input_size != math.prod(newshape):
             raise ValueError(
                 "The total size of the tensor must be unchanged. Received: "
-                f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
+                f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
             )
-        return new_shape
+        return newshape
 
-    # We have one -1 in new_shape, compute the actual value
+    # We have one -1 in `newshape`, compute the actual value
     known_output_size = 1
     unknown_dim_index = None
-    for index, dim in enumerate(new_shape):
+    for index, dim in enumerate(newshape):
         if dim == -1:
             unknown_dim_index = index
         else:
@@ -311,11 +311,11 @@ def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
         raise ValueError(
             "The total size of the tensor must be unchanged, however, the "
             "input size cannot by divided by the specified dimensions in "
-            f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
-            f"{new_shape_arg_name}={new_shape}"
+            f"{newshape_arg_name}. Received: input_shape={input_shape}, "
+            f"{newshape_arg_name}={newshape}"
         )
-    output_shape = list(new_shape)
+    output_shape = list(newshape)
     output_shape[unknown_dim_index] = input_size // known_output_size
     return tuple(output_shape)

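For reference, the renamed utility behaves roughly as follows, both with a fully known input shape and with an unknown (None) dimension:

    from keras.ops.operation_utils import compute_reshape_output_shape

    # Fully known input: the single -1 is resolved to a concrete size.
    compute_reshape_output_shape((2, 3, 4), (-1, 4), "newshape")       # (6, 4)

    # Unknown (None) dimension in the input: the -1 stays unresolved.
    compute_reshape_output_shape((None, 3, 4), (-1, 12), "newshape")   # (None, 12)

    # Incompatible sizes raise a ValueError that names `newshape`.
    # compute_reshape_output_shape((2, 3), (4, 4), "newshape")         # ValueError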
@@ -139,7 +139,7 @@ class OperationUtilsTest(testing.TestCase):
         input_shape = (1, 4, 4, 1)
         target_shape = (16, 1)
         output_shape = operation_utils.compute_reshape_output_shape(
-            input_shape, new_shape=target_shape, new_shape_arg_name="New shape"
+            input_shape, newshape=target_shape, newshape_arg_name="New shape"
         )
         self.assertEqual(output_shape, target_shape)