Docstring nits

This commit is contained in:
Francois Chollet 2023-05-27 09:44:52 -07:00
parent b4a7f36e8e
commit b10843ded6

@ -34,35 +34,36 @@ class ScatterUpdate(Operation):
@keras_core_export("keras_core.operations.scatter_update") @keras_core_export("keras_core.operations.scatter_update")
def scatter_update(inputs, indices, updates): def scatter_update(inputs, indices, updates):
"""Update inputs by scattering updates at indices. """Update inputs via updates at scattered (sparse) indices.
At a high level, this operation does `inputs[indices]=updates`. In details, At a high level, this operation does `inputs[indices] = updates`.
assume `inputs` is a tensor of shape `[D0, D1, ..., Dn]`, there are 2 main Assume `inputs` is a tensor of shape `(D0, D1, ..., Dn)`, there are 2 main
usages of `scatter_update`. usages of `scatter_update`.
- `indices` is a 2D tensor of shape `[num_updates, n]`, where `num_updates` 1. `indices` is a 2D tensor of shape `(num_updates, n)`, where `num_updates`
is the number of updates to perform, and `updates` is a 1D tensor of is the number of updates to perform, and `updates` is a 1D tensor of
shape `[num_updates]`. For example, if `inputs = np.zeros([4, 4, 4])`, shape `(num_updates,)`. For example, if `inputs` is `zeros((4, 4, 4))`,
and we want to update `inputs[1, 2, 3]` and `inputs[0, 1, 3]` as 1, then and we want to update `inputs[1, 2, 3]` and `inputs[0, 1, 3]` as 1, then
we can use we can use:
``` ```python
inputs = np.zeros([4, 4, 4]) inputs = np.zeros((4, 4, 4))
indices = [[1, 2, 3], [0, 1, 3]] indices = [[1, 2, 3], [0, 1, 3]]
updates = np.array([1., 1.]) updates = np.array([1., 1.])
inputs = keras_core.operations.scatter_update(inputs, indices, updates) inputs = keras_core.operations.scatter_update(inputs, indices, updates)
``` ```
- `indices` is a 2D tensor of shape `[num_updates, k]`, where `num_updates`
2 `indices` is a 2D tensor of shape `(num_updates, k)`, where `num_updates`
is the number of updates to perform, and `k` (`k < n`) is the size of is the number of updates to perform, and `k` (`k < n`) is the size of
each index in `indices`. `updates` is a `n - k`-D tensor of shape each index in `indices`. `updates` is a `n - k`-D tensor of shape
`[num_updates, inputs.shape[k:]]`. For example, if `(num_updates, inputs.shape[k:])`. For example, if
`inputs = np.zeros([4, 4, 4])`, and we want to update `inputs[1, 2, :]` `inputs = np.zeros((4, 4, 4))`, and we want to update `inputs[1, 2, :]`
and `inputs[2, 3, :]` as `[1, 1, 1, 1]`, then `indices` would have shape and `inputs[2, 3, :]` as `[1, 1, 1, 1]`, then `indices` would have shape
`[num_updates, 2]` (`k=2`), and `updates` would have shape `(num_updates, 2)` (`k = 2`), and `updates` would have shape
`[num_updates, 4]` (`inputs.shape[2:]=4`). See the code below: `(num_updates, 4)` (`inputs.shape[2:] = 4`). See the code below:
``` ```python
inputs = np.zeros([4, 4, 4]) inputs = np.zeros((4, 4, 4))
indices = [[1, 2], [2, 3]] indices = [[1, 2], [2, 3]]
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,]) updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
inputs = keras_core.operations.scatter_update(inputs, indices, updates) inputs = keras_core.operations.scatter_update(inputs, indices, updates)
@ -70,7 +71,7 @@ def scatter_update(inputs, indices, updates):
Args: Args:
inputs: A tensor, the tensor to be updated. inputs: A tensor, the tensor to be updated.
indices: A tensor or list/tuple of shape `[N, inputs.ndims]`, specifying indices: A tensor or list/tuple of shape `(N, inputs.ndim)`, specifying
indices to update. `N` is the number of indices to update, must be indices to update. `N` is the number of indices to update, must be
equal to the first dimension of `updates`. equal to the first dimension of `updates`.
updates: A tensor, the new values to be put to `inputs` at `indices`. updates: A tensor, the new values to be put to `inputs` at `indices`.
@ -96,25 +97,25 @@ def block_update(inputs, start_indices, updates):
"""Update inputs block. """Update inputs block.
At a high level, this operation does At a high level, this operation does
`inputs[start_indices: start_indices + updates.shape] = updates`. In `inputs[start_indices: start_indices + updates.shape] = updates`.
details, assume inputs is a tensor of shape `[D0, D1, ..., Dn]`, Assume inputs is a tensor of shape `(D0, D1, ..., Dn)`,
`start_indices` must be a list/tuple of n integers, specifying the starting `start_indices` must be a list/tuple of n integers, specifying the starting
indices. `updates` must have the same rank as `inputs`, and the size of each indices. `updates` must have the same rank as `inputs`, and the size of each
dim must not exceed `Di - start_indices[i]`. For example, if we have 2D dim must not exceed `Di - start_indices[i]`. For example, if we have 2D
inputs `inputs=np.zeros([5, 5])`, and we want to update the intersection of inputs `inputs = np.zeros((5, 5))`, and we want to update the intersection
last 2 rows and last 2 columns as 1, i.e., of last 2 rows and last 2 columns as 1, i.e.,
`inputs[3:, 3:] = np.ones([2, 2])`, then we can use the code below: `inputs[3:, 3:] = np.ones((2, 2))`, then we can use the code below:
``` ```python
inputs = np.zeros([5, 5]) inputs = np.zeros((5, 5))
start_indices = [3, 3] start_indices = [3, 3]
updates = np.ones([2, 2]) updates = np.ones((2, 2))
inputs = keras_core.operations.block_update(inputs, start_indices, updates) inputs = keras_core.operations.block_update(inputs, start_indices, updates)
``` ```
Args: Args:
inputs: A tensor, the tensor to be updated. inputs: A tensor, the tensor to be updated.
start_indices: A list/tuple of shape `[inputs.ndims]`, specifying start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
the starting indices for updating. the starting indices for updating.
updates: A tensor, the new values to be put to `inputs` at `indices`. updates: A tensor, the new values to be put to `inputs` at `indices`.
`updates` must have the same rank as `inputs`. `updates` must have the same rank as `inputs`.