"""
relu
relu6
sigmoid
tanh
softplus
softsign
silu
swish
log_sigmoid
leaky_relu
hard_sigmoid
elu
selu
gelu
softmax
log_softmax
max_pool
average_pool
conv
depthwise_conv
separable_conv
conv_transpose
one_hot
binary_crossentropy
categorical_crossentropy
sparse_categorical_crossentropy
top_k
in_top_k
ctc ??
"""
import numpy as np
from keras_core import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_output_shape,
)
from keras_core.operations import operation_utils
from keras_core.operations.operation import Operation
class Relu(Operation):
def call(self, x):
return backend.nn.relu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def relu(x):
if any_symbolic_tensors((x,)):
return Relu().symbolic_call(x)
return backend.nn.relu(x)
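

# Example (an illustrative sketch, not part of the library): every op in this
# module follows the dispatch pattern of `relu` above. If any argument is
# symbolic (a `KerasTensor`), a symbolic tensor carrying the computed output
# spec is returned; otherwise the active backend computes an eager result.
def _example_relu_dispatch():  # hypothetical helper, for illustration only
    y_eager = relu(np.array([-2.0, 0.0, 3.0]))  # concrete result: [0., 0., 3.]
    y_symbolic = relu(KerasTensor((None, 4)))  # KerasTensor of shape (None, 4)
    return y_eager, y_symbolic
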
class Relu6(Operation):
def call(self, x):
return backend.nn.relu6(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def relu6(x):
if any_symbolic_tensors((x,)):
return Relu6().symbolic_call(x)
return backend.nn.relu6(x)
class Sigmoid(Operation):
def call(self, x):
return backend.nn.sigmoid(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def sigmoid(x):
if any_symbolic_tensors((x,)):
return Sigmoid().symbolic_call(x)
return backend.nn.sigmoid(x)
class Tanh(Operation):
def call(self, x):
return backend.nn.tanh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def tanh(x):
if any_symbolic_tensors((x,)):
return Tanh().symbolic_call(x)
return backend.nn.tanh(x)
class Softplus(Operation):
def call(self, x):
return backend.nn.softplus(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def softplus(x):
if any_symbolic_tensors((x,)):
return Softplus().symbolic_call(x)
return backend.nn.softplus(x)
class Softsign(Operation):
def call(self, x):
return backend.nn.softsign(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def softsign(x):
if any_symbolic_tensors((x,)):
return Softsign().symbolic_call(x)
return backend.nn.softsign(x)
class Silu(Operation):
def call(self, x):
return backend.nn.silu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def silu(x):
if any_symbolic_tensors((x,)):
return Silu().symbolic_call(x)
return backend.nn.silu(x)
class Swish(Operation):
def call(self, x):
return backend.nn.swish(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def swish(x):
if any_symbolic_tensors((x,)):
return Swish().symbolic_call(x)
return backend.nn.swish(x)
class LogSigmoid(Operation):
def call(self, x):
return backend.nn.log_sigmoid(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def log_sigmoid(x):
if any_symbolic_tensors((x,)):
return LogSigmoid().symbolic_call(x)
return backend.nn.log_sigmoid(x)
class LeakyRelu(Operation):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def call(self, x):
return backend.nn.leaky_relu(x, self.negative_slope)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def leaky_relu(x, negative_slope=0.2):
if any_symbolic_tensors((x,)):
return LeakyRelu(negative_slope).symbolic_call(x)
return backend.nn.leaky_relu(x, negative_slope=negative_slope)
class HardSigmoid(Operation):
def call(self, x):
return backend.nn.hard_sigmoid(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def hard_sigmoid(x):
if any_symbolic_tensors((x,)):
return HardSigmoid().symbolic_call(x)
return backend.nn.hard_sigmoid(x)
class Elu(Operation):
def call(self, x):
return backend.nn.elu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def elu(x):
if any_symbolic_tensors((x,)):
return Elu().symbolic_call(x)
return backend.nn.elu(x)
class Selu(Operation):
def call(self, x):
return backend.nn.selu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def selu(x):
if any_symbolic_tensors((x,)):
return Selu().symbolic_call(x)
return backend.nn.selu(x)
class Gelu(Operation):
def __init__(self, approximate=True):
super().__init__()
self.approximate = approximate
def call(self, x):
return backend.nn.gelu(x, self.approximate)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def gelu(x, approximate=True):
if any_symbolic_tensors((x,)):
return Gelu(approximate).symbolic_call(x)
return backend.nn.gelu(x, approximate)
class Softmax(Operation):
def __init__(self, axis=None):
super().__init__()
self.axis = axis if axis is not None else -1
def call(self, x):
return backend.nn.softmax(x, axis=self.axis)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def softmax(x, axis=None):
if any_symbolic_tensors((x,)):
return Softmax(axis).symbolic_call(x)
return backend.nn.softmax(x, axis=axis)
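

# Example (an illustrative sketch): with `axis=None`, `softmax` normalizes
# over the last axis, so each row of a 2D input sums to 1.
def _example_softmax():  # hypothetical helper, for illustration only
    x = np.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
    probs = softmax(x)  # same shape as `x`; each row sums to 1
    return probs
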
class LogSoftmax(Operation):
def __init__(self, axis=None):
super().__init__()
self.axis = axis if axis is not None else -1
def call(self, x):
return backend.nn.log_softmax(x, axis=self.axis)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def log_softmax(x, axis=None):
if any_symbolic_tensors((x,)):
return LogSoftmax(axis).symbolic_call(x)
return backend.nn.log_softmax(x, axis=axis)
class MaxPool(Operation):
def __init__(
self,
pool_size,
strides=None,
padding="valid",
data_format="channels_last",
):
super().__init__()
self.pool_size = pool_size
self.strides = strides
self.padding = padding
self.data_format = data_format
def call(self, inputs):
return backend.nn.max_pool(
inputs,
self.pool_size,
self.strides,
self.padding,
self.data_format,
)
def compute_output_spec(self, inputs):
output_shape = operation_utils.compute_pooling_output_shape(
inputs.shape,
self.pool_size,
self.strides,
self.padding,
self.data_format,
)
return KerasTensor(output_shape, dtype=inputs.dtype)
def max_pool(
inputs,
pool_size,
strides=None,
padding="valid",
data_format="channels_last",
):
"""Max pooling operation.
Args:
inputs: Tensor of rank N+2. `inputs` has shape
[batch_size] + inputs_spatial_shape + [num_channels] if
`data_format="channels_last"`, or
[batch_size, num_channels] + inputs_spatial_shape if
`data_format="channels_first"`. Pooling happens over the spatial
dimensions only.
pool_size: int or tuple/list of integers of size
`len(inputs_spatial_shape)`, specifying the size of the pooling
window for each spatial dimension of the input tensor. If
`pool_size` is int, then every spatial dimension shares the same
`pool_size`.
strides: int or tuple/list of integers of size
`len(inputs_spatial_shape)`. The stride of the sliding window for
each spatial dimension of the input tensor. If `strides` is int,
then every spatial dimension shares the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to
            the left/right or up/down of the input such that the output has
            the same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or
            `"channels_first"`. `data_format` determines the ordering of
            the dimensions in the inputs. If `data_format="channels_last"`,
            `inputs` is of shape `(batch_size, spatial_shape, channels)`
            while if `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, spatial_shape)`.
Returns:
A tensor of rank N+2, the result of the max pooling operation.
"""
if any_symbolic_tensors((inputs,)):
return MaxPool(
pool_size,
strides,
padding,
data_format,
).symbolic_call(inputs)
return backend.nn.max_pool(inputs, pool_size, strides, padding, data_format)
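

# Example (an illustrative sketch, assuming a backend is active): with
# `padding="valid"` and `strides=2`, each spatial dimension shrinks by a
# factor of 2 for a 2x2 pooling window.
def _example_max_pool():  # hypothetical helper, for illustration only
    x = np.random.rand(2, 8, 8, 3)  # (batch, height, width, channels)
    y = max_pool(x, pool_size=2, strides=2, padding="valid")
    # `y` has shape (2, 4, 4, 3): each 2x2 window is reduced to its maximum.
    return y
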
class AveragePool(Operation):
def __init__(
self,
pool_size,
strides=None,
padding="valid",
data_format="channels_last",
):
super().__init__()
self.pool_size = pool_size
self.strides = strides
self.padding = padding
self.data_format = data_format
def call(self, inputs):
return backend.nn.average_pool(
inputs,
self.pool_size,
self.strides,
self.padding,
self.data_format,
)
def compute_output_spec(self, inputs):
output_shape = operation_utils.compute_pooling_output_shape(
inputs.shape,
self.pool_size,
self.strides,
self.padding,
self.data_format,
)
return KerasTensor(output_shape, dtype=inputs.dtype)
def average_pool(
inputs,
pool_size,
strides=None,
padding="valid",
data_format="channels_last",
):
"""Average pooling operation.
Args:
inputs: Tensor of rank N+2. `inputs` has shape
[batch_size] + inputs_spatial_shape + [num_channels] if
`data_format="channels_last"`, or
[batch_size, num_channels] + inputs_spatial_shape if
`data_format="channels_first"`. Pooling happens over the spatial
dimensions only.
pool_size: int or tuple/list of integers of size
`len(inputs_spatial_shape)`, specifying the size of the pooling
window for each spatial dimension of the input tensor. If
`pool_size` is int, then every spatial dimension shares the same
`pool_size`.
strides: int or tuple/list of integers of size
`len(inputs_spatial_shape)`. The stride of the sliding window for
each spatial dimension of the input tensor. If `strides` is int,
then every spatial dimension shares the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to
            the left/right or up/down of the input such that the output has
            the same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or
            `"channels_first"`. `data_format` determines the ordering of
            the dimensions in the inputs. If `data_format="channels_last"`,
            `inputs` is of shape `(batch_size, spatial_shape, channels)`
            while if `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, spatial_shape)`.
Returns:
A tensor of rank N+2, the result of the average pooling operation.
"""
if any_symbolic_tensors((inputs,)):
return AveragePool(
pool_size,
strides,
padding,
data_format,
).symbolic_call(inputs)
return backend.nn.average_pool(
inputs, pool_size, strides, padding, data_format
)
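

# Example (an illustrative sketch): with `padding="same"` and `strides=1`,
# the spatial dimensions are preserved; the borders are padded so that every
# pooling window is full size.
def _example_average_pool():  # hypothetical helper, for illustration only
    x = np.random.rand(2, 8, 8, 3)  # (batch, height, width, channels)
    y = average_pool(x, pool_size=3, strides=1, padding="same")
    # `y` has shape (2, 8, 8, 3), the same as `x`.
    return y
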
class Conv(Operation):
def __init__(
self,
strides=1,
padding="valid",
data_format="channel_last",
dilation_rate=1,
):
super().__init__()
self.strides = strides
self.padding = padding
self.data_format = data_format
self.dilation_rate = dilation_rate
def call(self, inputs, kernel):
return backend.nn.conv(
inputs,
kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
def compute_output_spec(self, inputs, kernel):
output_shape = operation_utils.compute_conv_output_shape(
inputs.shape,
kernel.shape[-1],
kernel.shape[:-2],
self.strides,
self.padding,
self.data_format,
self.dilation_rate,
)
return KerasTensor(output_shape, dtype=inputs.dtype)
def conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
"""General N-D convolution.
    This op supports 1D, 2D and 3D convolution.
Args:
inputs: Tensor of rank N+2. `inputs` has shape
[batch_size] + inputs_spatial_shape + [num_channels] if
`data_format="channels_last"`, or
[batch_size, num_channels] + inputs_spatial_shape if
`data_format="channels_first"`. Pooling happens over the spatial
dimensions only.
kernel: Tensor of rank N+2. `kernel` has shape
[kernel_spatial_shape, num_input_channels, num_output_channels],
`num_input_channels` should match the number of channels in
`inputs`.
strides: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the strides of the convolution along each spatial
dimension. If `strides` is int, then every spatial dimension shares
the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to
            the left/right or up/down of the input such that the output has
            the same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or
            `"channels_first"`. `data_format` determines the ordering of
            the dimensions in the inputs. If `data_format="channels_last"`,
            `inputs` is of shape `(batch_size, spatial_shape, channels)`
            while if `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, spatial_shape)`.
dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the dilation rate to use for dilated convolution. If
`dilation_rate` is int, then every spatial dimension shares
the same `dilation_rate`.
Returns:
A tensor of rank N+2, the result of the conv operation.
"""
if any_symbolic_tensors((inputs,)):
return Conv(strides, padding, data_format, dilation_rate).symbolic_call(
inputs, kernel
)
return backend.nn.conv(
inputs, kernel, strides, padding, data_format, dilation_rate
)
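

# Example (an illustrative sketch): a 2D convolution with a 3x3 kernel
# mapping 3 input channels to 16 output channels. With `padding="valid"` and
# `strides=1`, each spatial dimension shrinks by `kernel_size - 1 = 2`.
def _example_conv():  # hypothetical helper, for illustration only
    x = np.random.rand(2, 10, 10, 3)  # (batch, height, width, in_channels)
    kernel = np.random.rand(3, 3, 3, 16)  # (kh, kw, in_channels, out_channels)
    y = conv(x, kernel, strides=1, padding="valid")
    # `y` has shape (2, 8, 8, 16).
    return y
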
class DepthwiseConv(Operation):
def __init__(
self,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
super().__init__()
self.strides = strides
self.padding = padding
self.data_format = data_format
self.dilation_rate = dilation_rate
def call(self, inputs, kernel):
return backend.nn.depthwise_conv(
inputs,
kernel,
self.strides,
self.padding,
self.data_format,
self.dilation_rate,
)
def compute_output_spec(self, inputs, kernel):
input_shape = inputs.shape
if self.data_format == "channels_last":
spatial_shape = input_shape[1:-1]
else:
spatial_shape = input_shape[2:]
if len(kernel.shape) != len(inputs.shape):
raise ValueError(
"Kernel shape must have the same length as input, but received "
f"kernel of shape {kernel.shape} and "
f"input of shape {input_shape}."
)
if isinstance(self.dilation_rate, int):
dilation_rate = (self.dilation_rate,) * len(spatial_shape)
else:
dilation_rate = self.dilation_rate
if len(dilation_rate) != len(spatial_shape):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and input of "
f"shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])
        dilation_rate = np.array(dilation_rate)
if self.padding == "valid":
output_spatial_shape = (
np.floor(
(
spatial_shape
- dilation_rate * (kernel_spatial_shape - 1)
- 1
)
/ self.strides
)
+ 1
)
            negative_in_shape = np.any(output_spatial_shape < 0)
            if negative_in_shape:
                raise ValueError(
                    "Computed output size would be negative. Received "
                    f"`inputs shape={inputs.shape}`, "
                    f"`kernel shape={kernel.shape}`, "
                    f"`dilation_rate={self.dilation_rate}`."
                )
elif self.padding == "same":
output_spatial_shape = (
np.floor((spatial_shape - 1) / self.strides) + 1
)
output_spatial_shape = [int(i) for i in output_spatial_shape]
output_channels = kernel.shape[-1] * kernel.shape[-2]
if self.data_format == "channels_last":
output_shape = (
[input_shape[0]] + output_spatial_shape + [output_channels]
)
else:
output_shape = [
input_shape[0],
output_channels,
] + output_spatial_shape
return KerasTensor(output_shape, dtype=inputs.dtype)
def depthwise_conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
"""General N-D depthwise convolution.
    This op supports 1D and 2D depthwise convolution.
Args:
inputs: Tensor of rank N+2. `inputs` has shape
[batch_size] + inputs_spatial_shape + [num_channels] if
`data_format="channels_last"`, or
[batch_size, num_channels] + inputs_spatial_shape if
`data_format="channels_first"`. Pooling happens over the spatial
dimensions only.
kernel: Tensor of rank N+2. `kernel` has shape
[kernel_spatial_shape, num_input_channels, num_channels_multiplier],
`num_input_channels` should match the number of channels in
`inputs`.
strides: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the strides of the convolution along each spatial
dimension. If `strides` is int, then every spatial dimension shares
the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to
            the left/right or up/down of the input such that the output has
            the same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or
            `"channels_first"`. `data_format` determines the ordering of
            the dimensions in the inputs. If `data_format="channels_last"`,
            `inputs` is of shape `(batch_size, spatial_shape, channels)`
            while if `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, spatial_shape)`.
dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the dilation rate to use for dilated convolution. If
`dilation_rate` is int, then every spatial dimension shares
the same `dilation_rate`.
Returns:
A tensor of rank N+2, the result of the depthwise conv operation.
"""
if any_symbolic_tensors((inputs,)):
return DepthwiseConv(
strides, padding, data_format, dilation_rate
).symbolic_call(inputs, kernel)
return backend.nn.depthwise_conv(
inputs,
kernel,
strides,
padding,
data_format,
dilation_rate,
)
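

# Example (an illustrative sketch): the last kernel dimension is a channel
# multiplier, so the output has `in_channels * multiplier` channels.
def _example_depthwise_conv():  # hypothetical helper, for illustration only
    x = np.random.rand(2, 10, 10, 3)  # (batch, height, width, in_channels)
    kernel = np.random.rand(3, 3, 3, 2)  # (kh, kw, in_channels, multiplier)
    y = depthwise_conv(x, kernel, strides=1, padding="valid")
    # `y` has shape (2, 8, 8, 6): 3 input channels x multiplier 2.
    return y
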
class SeparableConv(Operation):
def __init__(
self,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
super().__init__()
self.strides = strides
self.padding = padding
self.data_format = data_format
self.dilation_rate = dilation_rate
def call(self, inputs, depthwise_kernel, pointwise_kernel):
return backend.nn.separable_conv(
inputs,
depthwise_kernel,
pointwise_kernel,
self.strides,
self.padding,
self.data_format,
self.dilation_rate,
)
def compute_output_spec(self, inputs, depthwise_kernel, pointwise_kernel):
output_shape = list(
depthwise_conv(
inputs,
depthwise_kernel,
self.strides,
self.padding,
self.data_format,
self.dilation_rate,
).shape
)
if self.data_format == "channels_last":
output_shape[-1] = pointwise_kernel.shape[-1]
else:
output_shape[1] = pointwise_kernel.shape[-1]
return KerasTensor(output_shape, dtype=inputs.dtype)
def separable_conv(
inputs,
depthwise_kernel,
pointwise_kernel,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
"""General N-D separable convolution.
    This op supports 1D and 2D separable convolution. `separable_conv` is
    a depthwise conv followed by a pointwise conv.
Args:
inputs: Tensor of rank N+2. `inputs` has shape
[batch_size] + inputs_spatial_shape + [num_channels] if
`data_format="channels_last"`, or
[batch_size, num_channels] + inputs_spatial_shape if
`data_format="channels_first"`. Pooling happens over the spatial
dimensions only.
depthwise_kernel: Tensor of rank N+2. `depthwise_kernel` has shape
[kernel_spatial_shape, num_input_channels, num_channels_multiplier],
`num_input_channels` should match the number of channels in
`inputs`.
pointwise_kernel: Tensor of rank N+2. `pointwise_kernel` has shape
[ones_like(kernel_spatial_shape),
num_input_channels * num_channels_multiplier, num_output_channels].
strides: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the strides of the convolution along each spatial
dimension. If `strides` is int, then every spatial dimension shares
the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to
            the left/right or up/down of the input such that the output has
            the same height/width dimension as the input when `strides=1`.
        data_format: A string, either `"channels_last"` or
            `"channels_first"`. `data_format` determines the ordering of
            the dimensions in the inputs. If `data_format="channels_last"`,
            `inputs` is of shape `(batch_size, spatial_shape, channels)`
            while if `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, spatial_shape)`.
dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the dilation rate to use for dilated convolution. If
`dilation_rate` is int, then every spatial dimension shares
the same `dilation_rate`.
Returns:
        A tensor of rank N+2, the result of the separable conv operation.
"""
if any_symbolic_tensors((inputs,)):
return SeparableConv(
strides,
padding,
data_format,
dilation_rate,
).symbolic_call(inputs, depthwise_kernel, pointwise_kernel)
return backend.nn.separable_conv(
inputs,
depthwise_kernel,
pointwise_kernel,
strides,
padding,
data_format,
dilation_rate,
)
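

# Example (an illustrative sketch): a depthwise conv followed by a 1x1
# pointwise conv. The pointwise kernel mixes the
# `in_channels * multiplier = 6` depthwise outputs into 16 output channels.
def _example_separable_conv():  # hypothetical helper, for illustration only
    x = np.random.rand(2, 10, 10, 3)
    depthwise_kernel = np.random.rand(3, 3, 3, 2)  # (kh, kw, in, multiplier)
    pointwise_kernel = np.random.rand(1, 1, 6, 16)  # (1, 1, in * mult, out)
    y = separable_conv(
        x, depthwise_kernel, pointwise_kernel, strides=1, padding="valid"
    )
    # `y` has shape (2, 8, 8, 16).
    return y
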
class ConvTranspose(Operation):
def __init__(
self,
strides,
padding="valid",
output_padding=None,
data_format="channels_last",
dilation_rate=1,
):
super().__init__()
self.strides = strides
self.output_padding = output_padding
self.padding = padding
self.data_format = data_format
self.dilation_rate = dilation_rate
def call(
self,
inputs,
kernel,
):
return backend.nn.conv_transpose(
inputs,
kernel,
self.strides,
self.output_padding,
self.padding,
self.data_format,
self.dilation_rate,
)
def compute_output_spec(self, inputs, kernel):
output_shape = compute_conv_transpose_output_shape(
inputs,
kernel,
self.strides,
self.padding,
self.output_padding,
self.data_format,
self.dilation_rate,
)
return KerasTensor(output_shape, dtype=inputs.dtype)
def conv_transpose(
inputs,
kernel,
strides,
padding="valid",
output_padding=None,
data_format="channels_last",
dilation_rate=1,
):
"""General N-D convolution transpose.
    Also known as deconvolution. This op supports 1D, 2D and 3D convolution.
Args:
inputs: Tensor of rank N+2. `inputs` has shape
[batch_size] + inputs_spatial_shape + [num_channels] if
`data_format="channels_last"`, or
[batch_size, num_channels] + inputs_spatial_shape if
`data_format="channels_first"`. Pooling happens over the spatial
dimensions only.
kernel: Tensor of rank N+2. `kernel` has shape
[kernel_spatial_shape, num_output_channels, num_input_channels],
`num_input_channels` should match the number of channels in
`inputs`.
strides: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the strides of the convolution along each spatial
dimension. If `strides` is int, then every spatial dimension shares
the same `strides`.
        padding: string, either `"valid"` or `"same"`. `"valid"` means no
            padding is applied, and `"same"` results in padding evenly to
            the left/right or up/down of the input such that the output has
            the same height/width dimension as the input when `strides=1`.
        output_padding: int or int tuple/list of `len(inputs_spatial_shape)`,
            specifying the amount of padding along the height and width of
            the output tensor. Can be a single integer to specify the same
            value for all spatial dimensions. The amount of output padding
            along a given dimension must be lower than the stride along
            that same dimension. If set to `None` (default), the output
            shape is inferred.
        data_format: A string, either `"channels_last"` or
            `"channels_first"`. `data_format` determines the ordering of
            the dimensions in the inputs. If `data_format="channels_last"`,
            `inputs` is of shape `(batch_size, spatial_shape, channels)`
            while if `data_format="channels_first"`, `inputs` is of shape
            `(batch_size, channels, spatial_shape)`.
dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,
specifying the dilation rate to use for dilated convolution. If
`dilation_rate` is int, then every spatial dimension shares
the same `dilation_rate`.
Returns:
        A tensor of rank N+2, the result of the transposed conv operation.
"""
if any_symbolic_tensors((inputs,)):
return ConvTranspose(
strides, padding, output_padding, data_format, dilation_rate
).symbolic_call(inputs, kernel)
return backend.nn.conv_transpose(
inputs,
kernel,
strides,
padding,
output_padding,
data_format,
dilation_rate,
)
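

# Example (an illustrative sketch): with `padding="same"` and `strides=2`, a
# transposed convolution doubles each spatial dimension. Note the kernel
# layout: output channels come before input channels.
def _example_conv_transpose():  # hypothetical helper, for illustration only
    x = np.random.rand(2, 4, 4, 3)  # (batch, height, width, in_channels)
    kernel = np.random.rand(3, 3, 16, 3)  # (kh, kw, out_channels, in_channels)
    y = conv_transpose(x, kernel, strides=2, padding="same")
    # `y` has shape (2, 8, 8, 16).
    return y
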
class OneHot(Operation):
def __init__(self, num_classes, axis=-1):
super().__init__()
self.num_classes = num_classes
self.axis = axis
def call(self, x):
return backend.nn.one_hot(x, self.num_classes, axis=self.axis)
def compute_output_spec(self, x):
x_shape = list(getattr(x, "shape", []))
if self.axis == -1:
x_shape.append(self.num_classes)
elif self.axis >= 0 and self.axis < len(x_shape):
x_shape.insert(self.axis, self.num_classes)
else:
raise ValueError(
f"axis must be -1 or between [0, {len(x.shape)}), but "
f"received {self.axis}."
)
return KerasTensor(x_shape)
def one_hot(x, num_classes, axis=-1):
if any_symbolic_tensors((x,)):
return OneHot(num_classes, axis=axis).symbolic_call(x)
return backend.nn.one_hot(x, num_classes, axis=axis)
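

# Example (an illustrative sketch): integer class indices become one-hot
# vectors along a new final axis.
def _example_one_hot():  # hypothetical helper, for illustration only
    indices = np.array([0, 2, 1])
    y = one_hot(indices, num_classes=3)
    # `y` has shape (3, 3):
    # [[1, 0, 0],
    #  [0, 0, 1],
    #  [0, 1, 0]]
    return y
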
class BinaryCrossentropy(Operation):
def __init__(self, from_logits=False):
super().__init__()
self.from_logits = from_logits
def call(self, target, output):
return backend.nn.binary_crossentropy(
target, output, from_logits=self.from_logits
)
def compute_output_spec(self, target, output):
if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
return KerasTensor(output.shape, dtype=output.dtype)
def binary_crossentropy(target, output, from_logits=False):
if any_symbolic_tensors((target, output)):
return BinaryCrossentropy(from_logits=from_logits).symbolic_call(
target, output
)
return backend.nn.binary_crossentropy(
target, output, from_logits=from_logits
)
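

# Example (an illustrative sketch): `target` and `output` must share a
# shape, and the loss is computed elementwise. With `from_logits=True`,
# `output` is passed through a sigmoid internally.
def _example_binary_crossentropy():  # hypothetical helper, for illustration
    target = np.array([0.0, 1.0, 1.0])
    logits = np.array([-1.0, 2.0, 0.5])
    loss = binary_crossentropy(target, logits, from_logits=True)
    # `loss` has shape (3,): one value per element.
    return loss
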
class CategoricalCrossentropy(Operation):
def __init__(self, from_logits=False, axis=-1):
super().__init__()
self.from_logits = from_logits
self.axis = axis
def call(self, target, output):
return backend.nn.categorical_crossentropy(
target, output, from_logits=self.from_logits, axis=self.axis
)
def compute_output_spec(self, target, output):
if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if len(target.shape) < 1:
raise ValueError(
"Arguments `target` and `output` must be at least rank 1. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
return KerasTensor(output.shape[:-1], dtype=output.dtype)
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
if any_symbolic_tensors((target, output)):
return CategoricalCrossentropy(
from_logits=from_logits, axis=axis
).symbolic_call(target, output)
return backend.nn.categorical_crossentropy(
target, output, from_logits=from_logits, axis=axis
)
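

# Example (an illustrative sketch): one-hot targets scored against a
# distribution (or logits) over the last axis; the class axis is reduced
# away in the result.
def _example_categorical_crossentropy():  # hypothetical helper
    target = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # one-hot, (2, 3)
    logits = np.random.rand(2, 3)
    loss = categorical_crossentropy(target, logits, from_logits=True)
    # `loss` has shape (2,): one value per sample.
    return loss
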
class SparseCategoricalCrossentropy(Operation):
def __init__(self, from_logits=False, axis=-1):
super().__init__()
self.from_logits = from_logits
self.axis = axis
def call(self, target, output):
return backend.nn.sparse_categorical_crossentropy(
target, output, from_logits=self.from_logits, axis=self.axis
)
def compute_output_spec(self, target, output):
if len(output.shape) < 1:
raise ValueError(
"Argument `output` must be at least rank 1. "
"Received: "
f"output.shape={output.shape}"
)
target_shape = target.shape
if len(target_shape) == len(output.shape) and target_shape[-1] == 1:
target_shape = target_shape[:-1]
if target_shape != output.shape[:-1]:
raise ValueError(
"Arguments `target` and `output` must have the same shape "
"up until the last dimension: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
return KerasTensor(output.shape[:-1], dtype=output.dtype)
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
if any_symbolic_tensors((target, output)):
return SparseCategoricalCrossentropy(
from_logits=from_logits, axis=axis
).symbolic_call(target, output)
return backend.nn.sparse_categorical_crossentropy(
target, output, from_logits=from_logits, axis=axis
)
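

# Example (an illustrative sketch): like `categorical_crossentropy`, but the
# targets are integer class indices rather than one-hot vectors, so
# `target.shape` matches `output.shape` minus the class axis.
def _example_sparse_categorical_crossentropy():  # hypothetical helper
    target = np.array([1, 0])  # shape (2,), integer class indices
    logits = np.random.rand(2, 3)  # shape (2, 3)
    loss = sparse_categorical_crossentropy(target, logits, from_logits=True)
    # `loss` has shape (2,).
    return loss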