# keras/keras_core/backend/jax/nn.py

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import nn as jnn
from keras_core.backend import standardize_data_format
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_padding,
)
from keras_core.backend.config import epsilon
from keras_core.backend.jax.core import convert_to_tensor


def relu(x):
    return jnn.relu(x)


def relu6(x):
    return jnn.relu6(x)


def sigmoid(x):
    return jnn.sigmoid(x)


def tanh(x):
    return jnn.tanh(x)


def softplus(x):
    return jnn.softplus(x)


def softsign(x):
    return jnn.soft_sign(x)


def silu(x):
    return jnn.silu(x)


def swish(x):
    return jnn.swish(x)


def log_sigmoid(x):
    return jnn.log_sigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return jnn.leaky_relu(x, negative_slope=negative_slope)


def hard_sigmoid(x):
    return jnn.hard_sigmoid(x)


def elu(x, alpha=1.0):
    return jnn.elu(x, alpha=alpha)


def selu(x):
    return jnn.selu(x)


def gelu(x, approximate=True):
    return jnn.gelu(x, approximate)


def softmax(x, axis=None):
    return jnn.softmax(x, axis=axis)


def log_softmax(x, axis=-1):
    return jnn.log_softmax(x, axis=axis)
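

# Example (illustrative): these wrappers delegate directly to `jax.nn`,
# so they behave exactly like their JAX counterparts:
#     relu(jnp.array([-1.0, 2.0]))              # -> [0., 2.]
#     softmax(jnp.array([1.0, 1.0]), axis=-1)   # -> [0.5, 0.5]
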
def _convert_to_spatial_operand(
x,
num_spatial_dims,
data_format="channels_last",
include_batch_and_channels=True,
):
# Helper function that converts an operand to a spatial operand.
x = (x,) * num_spatial_dims if isinstance(x, int) else x
if not include_batch_and_channels:
return x
if data_format == "channels_last":
x = (1,) + x + (1,)
else:
x = (1,) + (1,) + x
return x
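

# Example (illustrative): an int is broadcast across the spatial dims, and
# batch/channel entries of 1 are added according to `data_format`:
#     _convert_to_spatial_operand(2, 2)                    # -> (1, 2, 2, 1)
#     _convert_to_spatial_operand(2, 2, "channels_first")  # -> (1, 1, 2, 2)
#     _convert_to_spatial_operand(
#         (2, 3), 2, include_batch_and_channels=False
#     )                                                    # -> (2, 3)
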
def _pool(
inputs,
initial_value,
reduce_fn,
pool_size,
strides=None,
padding="valid",
):
"""Helper function to define pooling functions.
Args:
inputs: input data of shape `N+2`.
initial_value: the initial value for the reduction.
reduce_fn: a reduce function of the form `(T, T) -> T`.
pool_size: a sequence of `N` integers, representing the window size to
reduce over.
strides: a sequence of `N` integers, representing the inter-window
strides (default: `(1, ..., 1)`).
padding: either the string `same` or `valid`.
Returns:
The output of the reduction for each window slice.
"""
if padding not in ("same", "valid"):
raise ValueError(
f"Invalid padding '{padding}', must be 'same' or 'valid'."
)
padding = padding.upper()
return lax.reduce_window(
inputs,
initial_value,
reduce_fn,
pool_size,
strides,
padding,
)
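

# Example (illustrative): max-reducing 2x2 windows of a 4x4 NHWC input.
# `pool_size` and `strides` here are already in the full rank-4 form:
#     x = jnp.arange(16.0).reshape(1, 4, 4, 1)
#     _pool(x, -jnp.inf, lax.max, (1, 2, 2, 1), (1, 2, 2, 1), "valid").shape
#     # -> (1, 2, 2, 1)
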
def max_pool(
inputs,
pool_size,
strides=None,
padding="valid",
data_format=None,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    if strides is None:
        # Reuse the already-expanded pool size; converting it again would
        # wrap the batch/channel entries a second time.
        strides = pool_size
    else:
        strides = _convert_to_spatial_operand(
            strides, num_spatial_dims, data_format
        )
return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)
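

# Example (illustrative): with `channels_last` data, a 2x2 window and
# stride 2 halve each spatial dimension:
#     x = jnp.ones((2, 5, 5, 3))
#     max_pool(x, pool_size=2, strides=2, padding="valid").shape
#     # -> (2, 2, 2, 3)
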
def average_pool(
inputs,
pool_size,
strides,
padding,
data_format=None,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    if strides is None:
        # Reuse the already-expanded pool size; converting it again would
        # wrap the batch/channel entries a second time.
        strides = pool_size
    else:
        strides = _convert_to_spatial_operand(
            strides, num_spatial_dims, data_format
        )
pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
if padding == "valid":
# Avoid the extra reduce_window.
return pooled / np.prod(pool_size)
else:
# Count the number of valid entries at each input point, then use that
# for computing average. Assumes that any two arrays of same shape will
# be padded the same. Avoid broadcasting on axis where pooling is
# skipped.
shape = [
(a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
]
window_counts = _pool(
jnp.ones(shape, inputs.dtype),
0.0,
lax.add,
pool_size,
strides,
padding,
)
return pooled / window_counts
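

# Example (illustrative): with `"same"` padding, dividing by
# `window_counts` rather than `prod(pool_size)` keeps edge averages
# unbiased, so a constant input stays constant:
#     x = jnp.ones((1, 3, 3, 1))
#     average_pool(x, 2, 1, "same")  # -> all entries are exactly 1.0
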
def _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format="channels_last",
transpose=False,
):
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if data_format == "channels_last":
spatial_dims = tuple(range(1, num_dims - 1))
inputs_dn = (0, num_dims - 1) + spatial_dims
else:
spatial_dims = tuple(range(2, num_dims))
inputs_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(
lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
)
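

# Example (illustrative): for 2D `channels_last` data this yields NHWC
# inputs/outputs with HWIO kernels:
#     _convert_to_lax_conv_dimension_numbers(2)
#     # -> ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2),
#     #                         rhs_spec=(3, 2, 0, 1),
#     #                         out_spec=(0, 3, 1, 2))
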
def conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
if data_format == "channels_last":
channels = inputs.shape[-1]
else:
channels = inputs.shape[1]
kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"the kernel's in_channels. Received input channels {channels} "
            f"and kernel in_channels {kernel_in_channels}."
        )
feature_group_count = channels // kernel_in_channels
return jax.lax.conv_general_dilated(
convert_to_tensor(inputs),
convert_to_tensor(kernel),
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
)
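

# Example (illustrative): a 3x3 convolution over NHWC inputs with an
# HWIO kernel; `feature_group_count` is 1 since the channel counts match:
#     x = jnp.ones((2, 8, 8, 3))
#     kernel = jnp.ones((3, 3, 3, 16))
#     conv(x, kernel, strides=1, padding="valid").shape  # -> (2, 6, 6, 16)
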
def depthwise_conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
feature_group_count = (
inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
)
kernel = jnp.reshape(
kernel,
kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
)
return jax.lax.conv_general_dilated(
inputs,
kernel,
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
)
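

# Example (illustrative): a depthwise kernel of shape
# (3, 3, in_channels=3, channel_multiplier=2) produces 3 * 2 = 6 output
# channels, one convolution group per input channel:
#     x = jnp.ones((2, 8, 8, 3))
#     depthwise_conv(x, jnp.ones((3, 3, 3, 2)), 1, "valid").shape
#     # -> (2, 6, 6, 6)
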
def separable_conv(
inputs,
depthwise_kernel,
pointwise_kernel,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
depthwise_conv_output = depthwise_conv(
inputs,
depthwise_kernel,
strides,
padding,
data_format,
dilation_rate,
)
return conv(
depthwise_conv_output,
pointwise_kernel,
strides=1,
padding="valid",
data_format=data_format,
dilation_rate=dilation_rate,
)
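

# Example (illustrative): the pointwise kernel's in_channels must equal
# in_channels * channel_multiplier of the depthwise step (here 3 * 2 = 6):
#     x = jnp.ones((2, 8, 8, 3))
#     separable_conv(
#         x, jnp.ones((3, 3, 3, 2)), jnp.ones((1, 1, 6, 8)), 1, "valid"
#     ).shape  # -> (2, 6, 6, 8)
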
def conv_transpose(
inputs,
kernel,
strides=1,
padding="valid",
output_padding=None,
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
padding_values = compute_conv_transpose_padding(
inputs.shape,
kernel.shape,
strides,
padding,
output_padding,
data_format,
dilation_rate,
)
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
return jax.lax.conv_transpose(
inputs,
kernel,
strides,
padding=padding_values,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
transpose_kernel=True,
)
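

# Example (illustrative): since `transpose_kernel=True`, the kernel is
# laid out as (spatial..., out_channels, in_channels); with `"same"`
# padding and stride 2 the spatial dims are upsampled by 2:
#     x = jnp.ones((2, 4, 4, 3))
#     conv_transpose(x, jnp.ones((3, 3, 5, 3)), 2, "same").shape
#     # -> (2, 8, 8, 5)
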
def one_hot(x, num_classes, axis=-1):
return jnn.one_hot(x, num_classes, axis=axis)
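

# Example (illustrative):
#     one_hot(jnp.array([0, 2]), 3)
#     # -> [[1., 0., 0.],
#     #     [0., 0., 1.]]
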
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = jnp.array(target)
output = jnp.array(output)
if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if len(target.shape) < 1:
raise ValueError(
"Arguments `target` and `output` must be at least rank 1. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if from_logits:
log_prob = jax.nn.log_softmax(output, axis=axis)
else:
output = output / jnp.sum(output, axis, keepdims=True)
output = jnp.clip(output, epsilon(), 1.0 - epsilon())
log_prob = jnp.log(output)
return -jnp.sum(target * log_prob, axis=axis)
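

# Example (illustrative): with one-hot targets and probabilities that
# already sum to 1, this reduces to -log(p_true):
#     target = jnp.array([[0.0, 0.0, 1.0]])
#     output = jnp.array([[0.1, 0.2, 0.7]])
#     categorical_crossentropy(target, output)  # -> [0.3567]
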
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = jnp.array(target, dtype="int32")
output = jnp.array(output)
if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
target = jnp.squeeze(target, axis=-1)
if len(output.shape) < 1:
raise ValueError(
"Argument `output` must be at least rank 1. "
"Received: "
f"output.shape={output.shape}"
)
if target.shape != output.shape[:-1]:
raise ValueError(
"Arguments `target` and `output` must have the same shape "
"up until the last dimension: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if from_logits:
log_prob = jax.nn.log_softmax(output, axis=axis)
else:
output = output / jnp.sum(output, axis, keepdims=True)
output = jnp.clip(output, epsilon(), 1.0 - epsilon())
log_prob = jnp.log(output)
target = jnn.one_hot(target, output.shape[axis], axis=axis)
return -jnp.sum(target * log_prob, axis=axis)
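

# Example (illustrative): integer targets are one-hot encoded internally,
# so this matches `categorical_crossentropy` on the same distribution:
#     sparse_categorical_crossentropy(
#         jnp.array([2]), jnp.array([[0.1, 0.2, 0.7]])
#     )  # -> [0.3567]
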
def binary_crossentropy(target, output, from_logits=False):
target = jnp.array(target)
output = jnp.array(output)
if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if from_logits:
log_logits = jax.nn.log_sigmoid(output)
log_neg_logits = jax.nn.log_sigmoid(-output)
return -1.0 * target * log_logits - (1.0 - target) * log_neg_logits
output = jnp.clip(output, epsilon(), 1.0 - epsilon())
bce = target * jnp.log(output)
bce += (1.0 - target) * jnp.log(1.0 - output)
return -bce
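

# Example (illustrative): for probabilities this is the elementwise
# -[t * log(p) + (1 - t) * log(1 - p)]; with logits, the numerically
# stable `log_sigmoid` path is used instead:
#     binary_crossentropy(jnp.array([1.0, 0.0]), jnp.array([0.9, 0.1]))
#     # -> [0.1054, 0.1054]
#     binary_crossentropy(jnp.array([1.0]), jnp.array([0.0]), from_logits=True)
#     # -> [0.6931]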