Docstring nits
This commit is contained in:
parent
b4a7f36e8e
commit
b10843ded6
@ -34,43 +34,44 @@ class ScatterUpdate(Operation):
|
||||
|
||||
@keras_core_export("keras_core.operations.scatter_update")
|
||||
def scatter_update(inputs, indices, updates):
|
||||
"""Update inputs by scattering updates at indices.
|
||||
"""Update inputs via updates at scattered (sparse) indices.
|
||||
|
||||
At a high level, this operation does `inputs[indices]=updates`. In details,
|
||||
assume `inputs` is a tensor of shape `[D0, D1, ..., Dn]`, there are 2 main
|
||||
At a high level, this operation does `inputs[indices] = updates`.
|
||||
Assume `inputs` is a tensor of shape `(D0, D1, ..., Dn)`, there are 2 main
|
||||
usages of `scatter_update`.
|
||||
|
||||
- `indices` is a 2D tensor of shape `[num_updates, n]`, where `num_updates`
|
||||
1. `indices` is a 2D tensor of shape `(num_updates, n)`, where `num_updates`
|
||||
is the number of updates to perform, and `updates` is a 1D tensor of
|
||||
shape `[num_updates]`. For example, if `inputs = np.zeros([4, 4, 4])`,
|
||||
shape `(num_updates,)`. For example, if `inputs` is `zeros((4, 4, 4))`,
|
||||
and we want to update `inputs[1, 2, 3]` and `inputs[0, 1, 3]` as 1, then
|
||||
we can use
|
||||
we can use:
|
||||
|
||||
```
|
||||
inputs = np.zeros([4, 4, 4])
|
||||
indices = [[1, 2, 3], [0, 1, 3]]
|
||||
updates = np.array([1., 1.])
|
||||
inputs = keras_core.operations.scatter_update(inputs, indices, updates)
|
||||
```
|
||||
- `indices` is a 2D tensor of shape `[num_updates, k]`, where `num_updates`
|
||||
```python
|
||||
inputs = np.zeros((4, 4, 4))
|
||||
indices = [[1, 2, 3], [0, 1, 3]]
|
||||
updates = np.array([1., 1.])
|
||||
inputs = keras_core.operations.scatter_update(inputs, indices, updates)
|
||||
```
|
||||
|
||||
2 `indices` is a 2D tensor of shape `(num_updates, k)`, where `num_updates`
|
||||
is the number of updates to perform, and `k` (`k < n`) is the size of
|
||||
each index in `indices`. `updates` is a `n-k`-D tensor of shape
|
||||
`[num_updates, inputs.shape[k:]]`. For example, if
|
||||
`inputs = np.zeros([4, 4, 4])`, and we want to update `inputs[1, 2, :]`
|
||||
each index in `indices`. `updates` is a `n - k`-D tensor of shape
|
||||
`(num_updates, inputs.shape[k:])`. For example, if
|
||||
`inputs = np.zeros((4, 4, 4))`, and we want to update `inputs[1, 2, :]`
|
||||
and `inputs[2, 3, :]` as `[1, 1, 1, 1]`, then `indices` would have shape
|
||||
`[num_updates, 2]` (`k=2`), and `updates` would have shape
|
||||
`[num_updates, 4]` (`inputs.shape[2:]=4`). See the code below:
|
||||
`(num_updates, 2)` (`k = 2`), and `updates` would have shape
|
||||
`(num_updates, 4)` (`inputs.shape[2:] = 4`). See the code below:
|
||||
|
||||
```
|
||||
inputs = np.zeros([4, 4, 4])
|
||||
indices = [[1, 2], [2, 3]]
|
||||
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
|
||||
inputs = keras_core.operations.scatter_update(inputs, indices, updates)
|
||||
```
|
||||
```python
|
||||
inputs = np.zeros((4, 4, 4))
|
||||
indices = [[1, 2], [2, 3]]
|
||||
updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,])
|
||||
inputs = keras_core.operations.scatter_update(inputs, indices, updates)
|
||||
```
|
||||
|
||||
Args:
|
||||
inputs: A tensor, the tensor to be updated.
|
||||
indices: A tensor or list/tuple of shape `[N, inputs.ndims]`, specifying
|
||||
indices: A tensor or list/tuple of shape `(N, inputs.ndim)`, specifying
|
||||
indices to update. `N` is the number of indices to update, must be
|
||||
equal to the first dimension of `updates`.
|
||||
updates: A tensor, the new values to be put to `inputs` at `indices`.
|
||||
@ -96,25 +97,25 @@ def block_update(inputs, start_indices, updates):
|
||||
"""Update inputs block.
|
||||
|
||||
At a high level, this operation does
|
||||
`inputs[start_indices: start_indices + updates.shape] = updates`. In
|
||||
details, assume inputs is a tensor of shape `[D0, D1, ..., Dn]`,
|
||||
`inputs[start_indices: start_indices + updates.shape] = updates`.
|
||||
Assume inputs is a tensor of shape `(D0, D1, ..., Dn)`,
|
||||
`start_indices` must be a list/tuple of n integers, specifying the starting
|
||||
indices. `updates` must have the same rank as `inputs`, and the size of each
|
||||
dim must not exceed `Di - start_indices[i]`. For example, if we have 2D
|
||||
inputs `inputs=np.zeros([5, 5])`, and we want to update the intersection of
|
||||
last 2 rows and last 2 columns as 1, i.e.,
|
||||
`inputs[3:, 3:] = np.ones([2, 2])`, then we can use the code below:
|
||||
inputs `inputs = np.zeros((5, 5))`, and we want to update the intersection
|
||||
of last 2 rows and last 2 columns as 1, i.e.,
|
||||
`inputs[3:, 3:] = np.ones((2, 2))`, then we can use the code below:
|
||||
|
||||
```
|
||||
inputs = np.zeros([5, 5])
|
||||
```python
|
||||
inputs = np.zeros((5, 5))
|
||||
start_indices = [3, 3]
|
||||
updates = np.ones([2, 2])
|
||||
updates = np.ones((2, 2))
|
||||
inputs = keras_core.operations.block_update(inputs, start_indices, updates)
|
||||
```
|
||||
|
||||
Args:
|
||||
inputs: A tensor, the tensor to be updated.
|
||||
start_indices: A list/tuple of shape `[inputs.ndims]`, specifying
|
||||
start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
|
||||
the starting indices for updating.
|
||||
updates: A tensor, the new values to be put to `inputs` at `indices`.
|
||||
`updates` must have the same rank as `inputs`.
|
||||
|
Loading…
Reference in New Issue
Block a user