keras/keras_core/backend/jax/nn.py


import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import nn as jnn
from keras_core.backend import standardize_data_format
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_padding,
)
from keras_core.backend.config import epsilon
from keras_core.backend.jax.core import convert_to_tensor


def relu(x):
    return jnn.relu(x)

def relu6(x):
    return jnn.relu6(x)

def sigmoid(x):
    return jnn.sigmoid(x)

def tanh(x):
    return jnn.tanh(x)

def softplus(x):
    return jnn.softplus(x)

def softsign(x):
    return jnn.soft_sign(x)

def silu(x):
    return jnn.silu(x)

def swish(x):
    return jnn.swish(x)

def log_sigmoid(x):
    return jnn.log_sigmoid(x)

def leaky_relu(x, negative_slope=0.2):
    return jnn.leaky_relu(x, negative_slope=negative_slope)

def hard_sigmoid(x):
    return jnn.hard_sigmoid(x)

def elu(x, alpha=1.0):
    return jnn.elu(x, alpha=alpha)

def selu(x):
    return jnn.selu(x)

def gelu(x, approximate=True):
    return jnn.gelu(x, approximate)

def softmax(x, axis=None):
    return jnn.softmax(x, axis=axis)

def log_softmax(x, axis=-1):
    return jnn.log_softmax(x, axis=axis)


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    # Helper that broadcasts an int to a per-spatial-dim tuple and, by
    # default, adds singleton batch and channel entries.
    x = (x,) * num_spatial_dims if isinstance(x, int) else x
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        x = (1,) + x + (1,)
    else:
        x = (1,) + (1,) + x
    return x
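
# A minimal sketch of what the helper above produces (hypothetical inputs):
#   _convert_to_spatial_operand(2, 2, "channels_last")   -> (1, 2, 2, 1)
#   _convert_to_spatial_operand(2, 2, "channels_first")  -> (1, 1, 2, 2)
#   _convert_to_spatial_operand(
#       (2, 3), 2, "channels_last", include_batch_and_channels=False
#   )                                                     -> (2, 3)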


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Helper function to define pooling functions.

    Args:
        inputs: input data of rank `N+2` (batch, spatial dims, channels).
        initial_value: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        pool_size: a sequence of `N+2` integers, representing the window size
            to reduce over (with 1s in the batch and channel positions).
        strides: a sequence of `N+2` integers, representing the inter-window
            strides.
        padding: either the string `"same"` or `"valid"`.

    Returns:
        The output of the reduction for each window slice.
    """
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    padding = padding.upper()
    return lax.reduce_window(
        inputs,
        initial_value,
        reduce_fn,
        pool_size,
        strides,
        padding,
    )


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    # Resolve the strides default before converting either argument, so that
    # an already-expanded pool_size is not expanded a second time.
    strides = pool_size if strides is None else strides
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)
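
# Usage sketch (hypothetical shapes, assuming the default data format resolves
# to "channels_last"): a 2x2 max pool with stride 2 halves each spatial dim.
#   x = jnp.ones((2, 8, 8, 3))
#   max_pool(x, pool_size=2, strides=2, padding="valid").shape  # (2, 4, 4, 3)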


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format=None,
):
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    # Resolve the strides default before converting either argument, so that
    # an already-expanded pool_size is not expanded a second time.
    strides = pool_size if strides is None else strides
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Avoid the extra reduce_window.
        return pooled / np.prod(pool_size)
    else:
        # Count the number of valid entries at each input point, then use that
        # count to compute the average. Assumes that any two arrays of the
        # same shape will be padded the same way. Avoid broadcasting on axes
        # where pooling is skipped.
        shape = [
            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
        ]
        window_counts = _pool(
            jnp.ones(shape, inputs.dtype),
            0.0,
            lax.add,
            pool_size,
            strides,
            padding,
        )
        return pooled / window_counts
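
# With "same" padding the divisor varies per position because border windows
# overlap the zero padding. A 1D sketch (hypothetical values): pooling a
# length-4 row of ones with pool_size=3, strides=1 yields window_counts of
# [2, 3, 3, 2], so the border outputs are averaged over the 2 real entries
# they cover instead of the full window of 3.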


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2

    if data_format == "channels_last":
        spatial_dims = tuple(range(1, num_dims - 1))
        inputs_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        inputs_dn = (0, 1) + spatial_dims

    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
    )
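
# Example (hypothetical, 2D inputs, "channels_last", transpose=False): the
# helper returns
#   lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)
# i.e. NHWC inputs/outputs with an HWIO kernel (out channels on the last
# kernel axis, in channels on the axis before it).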


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    if data_format == "channels_last":
        channels = inputs.shape[-1]
    else:
        channels = inputs.shape[1]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"the kernel's in_channels. Received input channels {channels} "
            f"and kernel in_channels {kernel_in_channels}."
        )
    feature_group_count = channels // kernel_in_channels
    return jax.lax.conv_general_dilated(
        convert_to_tensor(inputs),
        convert_to_tensor(kernel),
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
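
# Grouped-convolution sketch (hypothetical shapes): NHWC inputs of shape
# (2, 8, 8, 4) with an HWIO kernel of shape (3, 3, 2, 8) give
# feature_group_count = 4 // 2 = 2, so lax.conv_general_dilated splits the
# 4 input channels into 2 groups of 2 and convolves each group with its own
# half of the 8 output filters.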


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    feature_group_count = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel = jnp.reshape(
        kernel,
        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
    )
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
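
# The reshape above maps the Keras depthwise kernel layout
# (h, w, channels, depth_multiplier) onto the grouped-convolution layout
# (h, w, 1, channels * depth_multiplier) that lax.conv_general_dilated expects
# when feature_group_count == channels, i.e. one group per input channel.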


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    data_format = standardize_data_format(data_format)
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
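
# separable_conv factors a convolution into the depthwise pass above followed
# by a pointwise convolution (a 1x1 spatial kernel, per the Keras separable
# conv convention) that mixes channels; the requested strides are applied only
# in the depthwise step, while the pointwise step always uses stride 1.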


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    data_format = standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    padding_values = compute_conv_transpose_padding(
        inputs.shape,
        kernel.shape,
        strides,
        padding,
        output_padding,
        data_format,
        dilation_rate,
    )
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    return jax.lax.conv_transpose(
        inputs,
        kernel,
        strides,
        padding=padding_values,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        transpose_kernel=True,
    )


def one_hot(x, num_classes, axis=-1):
    return jnn.one_hot(x, num_classes, axis=axis)


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = jnp.array(target)
    output = jnp.array(output)
    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = jax.nn.log_softmax(output, axis=axis)
    else:
        output = output / jnp.sum(output, axis, keepdims=True)
        output = jnp.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = jnp.log(output)
    return -jnp.sum(target * log_prob, axis=axis)
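
# Worked example (hypothetical values): with a one-hot target [0, 1, 0] and
# already-normalized probabilities [0.2, 0.7, 0.1], the sum over the class
# axis reduces to -log(0.7) ~= 0.357.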


def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = jnp.array(target, dtype="int32")
    output = jnp.array(output)
    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = jnp.squeeze(target, axis=-1)
    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_prob = jax.nn.log_softmax(output, axis=axis)
    else:
        output = output / jnp.sum(output, axis, keepdims=True)
        output = jnp.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = jnp.log(output)
    target = jnn.one_hot(target, output.shape[axis], axis=axis)
    return -jnp.sum(target * log_prob, axis=axis)
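
# Sketch (hypothetical values): integer targets are one-hot encoded against
# the class axis and reduced exactly like categorical_crossentropy, so a
# target of 1 with probabilities [0.2, 0.7, 0.1] also yields -log(0.7) ~= 0.357.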


def binary_crossentropy(target, output, from_logits=False):
    target = jnp.array(target)
    output = jnp.array(output)
    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_logits = jax.nn.log_sigmoid(output)
        log_neg_logits = jax.nn.log_sigmoid(-output)
        return -1.0 * target * log_logits - (1.0 - target) * log_neg_logits

    output = jnp.clip(output, epsilon(), 1.0 - epsilon())
    bce = target * jnp.log(output)
    bce += (1.0 - target) * jnp.log(1.0 - output)
    return -bce
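
# Worked example (hypothetical values): for target 1.0 and probability 0.8 the
# loss is -log(0.8) ~= 0.223; with from_logits=True the same quantity is
# computed from raw logits via log_sigmoid, which stays numerically stable for
# large-magnitude logits.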