Rename ops.reshape
argument from new_shape
to newshape
. (#19097)
This is for consistency with the NumPy API.
This commit is contained in:
parent
1d0008bc70
commit
bf1f463e0e
@ -48,10 +48,10 @@ class KerasTensor:
|
||||
def ndim(self):
|
||||
return len(self.shape)
|
||||
|
||||
def reshape(self, new_shape):
|
||||
def reshape(self, newshape):
|
||||
from keras import ops
|
||||
|
||||
return ops.Reshape(new_shape)(self)
|
||||
return ops.Reshape(newshape)(self)
|
||||
|
||||
def squeeze(self, axis=None):
|
||||
from keras import ops
|
||||
|
@ -647,8 +647,8 @@ def repeat(x, repeats, axis=None):
|
||||
return jnp.repeat(x, repeats, axis=axis)
|
||||
|
||||
|
||||
def reshape(x, new_shape):
|
||||
return jnp.reshape(x, new_shape)
|
||||
def reshape(x, newshape):
|
||||
return jnp.reshape(x, newshape)
|
||||
|
||||
|
||||
def roll(x, shift, axis=None):
|
||||
|
@ -770,8 +770,8 @@ def repeat(x, repeats, axis=None):
|
||||
return np.repeat(x, repeats, axis=axis)
|
||||
|
||||
|
||||
def reshape(x, new_shape):
|
||||
return np.reshape(x, new_shape)
|
||||
def reshape(x, newshape):
|
||||
return np.reshape(x, newshape)
|
||||
|
||||
|
||||
def roll(x, shift, axis=None):
|
||||
|
@ -1339,18 +1339,18 @@ def repeat(x, repeats, axis=None):
|
||||
return tf.repeat(x, repeats, axis=axis)
|
||||
|
||||
|
||||
def reshape(x, new_shape):
|
||||
def reshape(x, newshape):
|
||||
x = convert_to_tensor(x)
|
||||
if isinstance(x, tf.SparseTensor):
|
||||
from keras.ops.operation_utils import compute_reshape_output_shape
|
||||
|
||||
output_shape = compute_reshape_output_shape(
|
||||
x.shape, new_shape, "new_shape"
|
||||
x.shape, newshape, "newshape"
|
||||
)
|
||||
output = tf.sparse.reshape(x, new_shape)
|
||||
output = tf.sparse.reshape(x, newshape)
|
||||
output.set_shape(output_shape)
|
||||
return output
|
||||
return tf.reshape(x, new_shape)
|
||||
return tf.reshape(x, newshape)
|
||||
|
||||
|
||||
def roll(x, shift, axis=None):
|
||||
|
@ -1122,11 +1122,11 @@ def repeat(x, repeats, axis=None):
|
||||
return torch.repeat_interleave(x, repeats, dim=axis)
|
||||
|
||||
|
||||
def reshape(x, new_shape):
|
||||
if not isinstance(new_shape, (list, tuple)):
|
||||
new_shape = (new_shape,)
|
||||
def reshape(x, newshape):
|
||||
if not isinstance(newshape, (list, tuple)):
|
||||
newshape = (newshape,)
|
||||
x = convert_to_tensor(x)
|
||||
return torch.reshape(x, new_shape)
|
||||
return torch.reshape(x, newshape)
|
||||
|
||||
|
||||
def roll(x, shift, axis=None):
|
||||
|
@ -165,7 +165,7 @@ def categorical_accuracy(y_true, y_pred):
|
||||
y_pred = ops.cast(y_pred, dtype=y_true.dtype)
|
||||
matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
|
||||
if reshape_matches:
|
||||
matches = ops.reshape(matches, new_shape=y_true_org_shape)
|
||||
matches = ops.reshape(matches, y_true_org_shape)
|
||||
return matches
|
||||
|
||||
|
||||
@ -251,7 +251,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
|
||||
y_pred = ops.cast(y_pred, y_true.dtype)
|
||||
matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())
|
||||
if reshape_matches:
|
||||
matches = ops.reshape(matches, new_shape=y_true_org_shape)
|
||||
matches = ops.reshape(matches, y_true_org_shape)
|
||||
# if shape is (num_samples, 1) squeeze
|
||||
if len(matches.shape) > 1 and matches.shape[-1] == 1:
|
||||
matches = ops.squeeze(matches, -1)
|
||||
@ -337,7 +337,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
|
||||
|
||||
# returned matches is expected to have same shape as y_true input
|
||||
if reshape_matches:
|
||||
matches = ops.reshape(matches, new_shape=y_true_org_shape)
|
||||
matches = ops.reshape(matches, y_true_org_shape)
|
||||
|
||||
return matches
|
||||
|
||||
@ -415,7 +415,7 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
|
||||
|
||||
# returned matches is expected to have same shape as y_true input
|
||||
if reshape_matches:
|
||||
matches = ops.reshape(matches, new_shape=y_true_org_shape)
|
||||
matches = ops.reshape(matches, y_true_org_shape)
|
||||
|
||||
return matches
|
||||
|
||||
|
@ -4463,28 +4463,28 @@ def repeat(x, repeats, axis=None):
|
||||
|
||||
|
||||
class Reshape(Operation):
|
||||
def __init__(self, new_shape):
|
||||
def __init__(self, newshape):
|
||||
super().__init__()
|
||||
self.new_shape = new_shape
|
||||
self.newshape = newshape
|
||||
|
||||
def call(self, x):
|
||||
return backend.numpy.reshape(x, self.new_shape)
|
||||
return backend.numpy.reshape(x, self.newshape)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
output_shape = operation_utils.compute_reshape_output_shape(
|
||||
x.shape, self.new_shape, "new_shape"
|
||||
x.shape, self.newshape, "newshape"
|
||||
)
|
||||
sparse = getattr(x, "sparse", False)
|
||||
return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)
|
||||
|
||||
|
||||
@keras_export(["keras.ops.reshape", "keras.ops.numpy.reshape"])
|
||||
def reshape(x, new_shape):
|
||||
def reshape(x, newshape):
|
||||
"""Gives a new shape to a tensor without changing its data.
|
||||
|
||||
Args:
|
||||
x: Input tensor.
|
||||
new_shape: The new shape should be compatible with the original shape.
|
||||
newshape: The new shape should be compatible with the original shape.
|
||||
One shape dimension can be -1 in which case the value is
|
||||
inferred from the length of the array and remaining dimensions.
|
||||
|
||||
@ -4492,8 +4492,8 @@ def reshape(x, new_shape):
|
||||
The reshaped tensor.
|
||||
"""
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Reshape(new_shape).symbolic_call(x)
|
||||
return backend.numpy.reshape(x, new_shape)
|
||||
return Reshape(newshape).symbolic_call(x)
|
||||
return backend.numpy.reshape(x, newshape)
|
||||
|
||||
|
||||
class Roll(Operation):
|
||||
|
@ -272,36 +272,36 @@ def compute_matmul_output_shape(shape1, shape2):
|
||||
return tuple(output_shape)
|
||||
|
||||
|
||||
def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
|
||||
"""Converts `-1` in `new_shape` to either an actual dimension or `None`.
|
||||
def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
|
||||
"""Converts `-1` in `newshape` to either an actual dimension or `None`.
|
||||
|
||||
This utility does not special case the 0th dimension (batch size).
|
||||
"""
|
||||
unknown_dim_count = new_shape.count(-1)
|
||||
unknown_dim_count = newshape.count(-1)
|
||||
if unknown_dim_count > 1:
|
||||
raise ValueError(
|
||||
"There must be at most one unknown dimension (-1) in "
|
||||
f"{new_shape_arg_name}. Received: {new_shape_arg_name}={new_shape}."
|
||||
f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
|
||||
)
|
||||
|
||||
# If there is a None in input_shape, we can't infer what the -1 is
|
||||
if None in input_shape:
|
||||
return tuple(dim if dim != -1 else None for dim in new_shape)
|
||||
return tuple(dim if dim != -1 else None for dim in newshape)
|
||||
|
||||
input_size = math.prod(input_shape)
|
||||
# If the new_shape fully defined, return it
|
||||
# If the `newshape` is fully defined, return it
|
||||
if unknown_dim_count == 0:
|
||||
if input_size != math.prod(new_shape):
|
||||
if input_size != math.prod(newshape):
|
||||
raise ValueError(
|
||||
"The total size of the tensor must be unchanged. Received: "
|
||||
f"input_shape={input_shape}, {new_shape_arg_name}={new_shape}"
|
||||
f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
|
||||
)
|
||||
return new_shape
|
||||
return newshape
|
||||
|
||||
# We have one -1 in new_shape, compute the actual value
|
||||
# We have one -1 in `newshape`, compute the actual value
|
||||
known_output_size = 1
|
||||
unknown_dim_index = None
|
||||
for index, dim in enumerate(new_shape):
|
||||
for index, dim in enumerate(newshape):
|
||||
if dim == -1:
|
||||
unknown_dim_index = index
|
||||
else:
|
||||
@ -311,11 +311,11 @@ def compute_reshape_output_shape(input_shape, new_shape, new_shape_arg_name):
|
||||
raise ValueError(
|
||||
"The total size of the tensor must be unchanged, however, the "
|
||||
"input size cannot by divided by the specified dimensions in "
|
||||
f"{new_shape_arg_name}. Received: input_shape={input_shape}, "
|
||||
f"{new_shape_arg_name}={new_shape}"
|
||||
f"{newshape_arg_name}. Received: input_shape={input_shape}, "
|
||||
f"{newshape_arg_name}={newshape}"
|
||||
)
|
||||
|
||||
output_shape = list(new_shape)
|
||||
output_shape = list(newshape)
|
||||
output_shape[unknown_dim_index] = input_size // known_output_size
|
||||
return tuple(output_shape)
|
||||
|
||||
|
@ -139,7 +139,7 @@ class OperationUtilsTest(testing.TestCase):
|
||||
input_shape = (1, 4, 4, 1)
|
||||
target_shape = (16, 1)
|
||||
output_shape = operation_utils.compute_reshape_output_shape(
|
||||
input_shape, new_shape=target_shape, new_shape_arg_name="New shape"
|
||||
input_shape, newshape=target_shape, newshape_arg_name="New shape"
|
||||
)
|
||||
self.assertEqual(output_shape, target_shape)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user