import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import nn as jnn

from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_length,
)
from keras_core.backend.config import epsilon
from keras_core.backend.jax.core import convert_to_tensor


def relu(x):
    return jnn.relu(x)


def relu6(x):
    return jnn.relu6(x)


def sigmoid(x):
    return jnn.sigmoid(x)


def tanh(x):
    return jnn.tanh(x)


def softplus(x):
    return jnn.softplus(x)


def softsign(x):
    return jnn.soft_sign(x)


def silu(x):
    return jnn.silu(x)


def swish(x):
    return jnn.swish(x)


def log_sigmoid(x):
    return jnn.log_sigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return jnn.leaky_relu(x, negative_slope=negative_slope)


def hard_sigmoid(x):
    return jnn.hard_sigmoid(x)


def elu(x):
    return jnn.elu(x)


def selu(x):
    return jnn.selu(x)


def gelu(x, approximate=True):
    return jnn.gelu(x, approximate)


def softmax(x, axis=None):
    return jnn.softmax(x, axis=axis)


def log_softmax(x, axis=-1):
    return jnn.log_softmax(x, axis=axis)


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    # Helper function that converts an operand to a spatial operand.
    x = (x,) * num_spatial_dims if isinstance(x, int) else x
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        x = (1,) + x + (1,)
    else:
        x = (1,) + (1,) + x
    return x
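

# Quick illustration (not part of the upstream file): for a 2D op with
# `data_format="channels_last"`, `_convert_to_spatial_operand(2, 2)` returns
# `(1, 2, 2, 1)` -- the scalar is broadcast across the spatial dims and
# singleton batch/channel entries are added. With
# `include_batch_and_channels=False` it returns just `(2, 2)`.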


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Helper function to define pooling functions.

    Args:
        inputs: input data of rank `N+2` (batch, `N` spatial dims, channels).
        initial_value: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        pool_size: a sequence of `N+2` integers, representing the window size
            to reduce over (singleton entries for the batch and channel
            dimensions).
        strides: a sequence of `N+2` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `same` or `valid`.

    Returns:
        The output of the reduction for each window slice.
    """
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    padding = padding.upper()
    return lax.reduce_window(
        inputs,
        initial_value,
        reduce_fn,
        pool_size,
        strides,
        padding,
    )


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)
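

# Shape sketch (illustrative): for `inputs` of shape `(1, 4, 4, 3)` and
# `pool_size=2`, the window operand becomes `(1, 2, 2, 1)` and, with
# `padding="valid"`, the result has shape `(1, 2, 2, 3)`. Using `-inf` as the
# initial value makes `lax.max` a correct reduction for max pooling.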


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format="channels_last",
):
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Avoid the extra reduce_window.
        return pooled / np.prod(pool_size)
    else:
        # Count the number of valid entries at each input point, then use that
        # for computing average. Assumes that any two arrays of same shape
        # will be padded the same. Avoid broadcasting on axis where pooling is
        # skipped.
        shape = [
            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
        ]
        window_counts = _pool(
            jnp.ones(shape, inputs.dtype),
            0.0,
            lax.add,
            pool_size,
            strides,
            padding,
        )
        return pooled / window_counts
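

# Worked 1D example of the `window_counts` trick (illustrative): for an input
# of length 4, window 3, stride 1, and `padding="same"`, summing a ones-array
# over the same windows gives counts `[2, 3, 3, 2]`, so border windows are
# averaged over the entries that actually exist rather than over the zero
# padding.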


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2

    if data_format == "channels_last":
        spatial_dims = tuple(range(1, num_dims - 1))
        inputs_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        inputs_dn = (0, 1) + spatial_dims

    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
    )
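

# Illustration (not part of the upstream file): with `num_spatial_dims=2`,
# `data_format="channels_last"`, and `transpose=False`, this returns
# `lhs_spec=(0, 3, 1, 2)` and `rhs_spec=(3, 2, 0, 1)`, i.e. NHWC inputs with
# an HWIO kernel -- the same layouts the string spec ("NHWC", "HWIO", "NHWC")
# would describe for `lax.conv_general_dilated`.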


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    if data_format == "channels_last":
        channels = inputs.shape[-1]
    else:
        channels = inputs.shape[1]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}."
        )
    feature_group_count = channels // kernel_in_channels
    return jax.lax.conv_general_dilated(
        convert_to_tensor(inputs),
        convert_to_tensor(kernel),
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
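

# Usage sketch (illustrative shapes): a standard 2D convolution of an NHWC
# input `(1, 10, 10, 3)` with an HWIO kernel `(3, 3, 3, 8)` at stride 1 and
# `padding="valid"` yields `(1, 8, 8, 8)`. If the kernel's in_channels were 1
# instead of 3, `feature_group_count` would become 3 and the call would run
# as a grouped convolution.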


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    feature_group_count = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel = jnp.reshape(
        kernel,
        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
    )
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
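

# Note on the reshape in `depthwise_conv` (illustrative): a Keras depthwise
# kernel has shape `(*spatial, in_channels, channel_multiplier)`, while a
# `lax` grouped convolution with `feature_group_count == in_channels` expects
# `(*spatial, 1, in_channels * channel_multiplier)`. E.g. a `(3, 3, 16, 2)`
# kernel is reshaped to `(3, 3, 1, 32)` and applied with
# `feature_group_count=16`.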


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )


def _compute_padding_value_one_dim(
    input_length,
    output_length,
    kernel_size,
    stride,
    padding,
    dilation_rate,
):
    """Computes adjusted padding for `conv_transpose` in one dim."""
    kernel_size = (kernel_size - 1) * dilation_rate + 1
    if padding == "valid":
        padding_before = 0
    else:
        # padding == "same".
        padding_needed = max(
            0, (input_length - 1) * stride + kernel_size - output_length
        )
        padding_before = padding_needed // 2

    expanded_input_length = (input_length - 1) * stride + 1
    padded_out_length = output_length + kernel_size - 1
    pad_before = kernel_size - 1 - padding_before
    pad_after = padded_out_length - expanded_input_length - pad_before
    return (pad_before, pad_after)
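

# Worked example (illustrative): `input_length=4`, `kernel_size=3`,
# `stride=2`, `padding="same"`, `dilation_rate=1`. The "same" transpose
# output length is `4 * 2 = 8`, so `padding_needed = max(0, 3*2 + 3 - 8) = 1`
# and `padding_before = 0`; then `pad_before = 3 - 1 - 0 = 2` and
# `pad_after = (8 + 3 - 1) - ((4 - 1) * 2 + 1) - 2 = 1`, giving `(2, 1)`.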


def _compute_padding_values(
    input_shape,
    kernel_shape,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    """Computes adjusted padding for `conv_transpose`."""
    num_spatial_dims = len(input_shape) - 2
    if isinstance(output_padding, int):
        output_padding = (output_padding,) * num_spatial_dims
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims

    kernel_spatial_shape = kernel_shape[:-2]
    if data_format == "channels_last":
        input_spatial_shape = input_shape[1:-1]
    else:
        input_spatial_shape = input_shape[2:]
    padding_values = []
    for i in range(num_spatial_dims):
        input_length = input_spatial_shape[i]
        current_output_padding = (
            None if output_padding is None else output_padding[i]
        )
        output_length = compute_conv_transpose_output_length(
            input_spatial_shape[i],
            kernel_spatial_shape[i],
            padding=padding,
            output_padding=current_output_padding,
            stride=strides[i],
            dilation=dilation_rate[i],
        )
        padding_value = _compute_padding_value_one_dim(
            input_length,
            output_length,
            kernel_spatial_shape[i],
            strides[i],
            padding=padding,
            dilation_rate=dilation_rate[i],
        )
        padding_values.append(padding_value)
    return padding_values


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    padding_values = _compute_padding_values(
        inputs.shape,
        kernel.shape,
        strides,
        padding,
        output_padding,
        data_format,
        dilation_rate,
    )
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )

    return jax.lax.conv_transpose(
        inputs,
        kernel,
        strides,
        padding=padding_values,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        transpose_kernel=True,
    )
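

# Usage sketch (illustrative): explicit per-dimension `(pad_before,
# pad_after)` pairs are computed above, so the lowercase `padding` string is
# never handed to `lax` directly. For an NHWC input `(1, 4, 4, 3)` with
# `strides=2`, `padding="same"`, and no `output_padding`, the output spatial
# size is `input * stride`, i.e. shape `(1, 8, 8, filters)`.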


def one_hot(x, num_classes, axis=-1):
    return jnn.one_hot(x, num_classes, axis=axis)


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = jnp.array(target)
    output = jnp.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = jax.nn.log_softmax(output, axis=axis)
    else:
        output = output / jnp.sum(output, axis, keepdims=True)
        output = jnp.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = jnp.log(output)
    return -jnp.sum(target * log_prob, axis=axis)
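

# Sanity check (illustrative): with `target = [0.0, 1.0, 0.0]` and
# probabilities `output = [0.1, 0.8, 0.1]` (and `from_logits=False`), the
# loss is `-log(0.8) ~= 0.2231`. Renormalizing and clipping by `epsilon()`
# first keeps the `log` finite for inputs that contain exact zeros.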


def sparse_categorical_crossentropy(
    target, output, from_logits=False, axis=-1
):
    # Use int32 indices: JAX defaults to 32-bit types and would downcast an
    # explicit int64 request (with a warning) unless x64 is enabled.
    target = jnp.array(target, dtype="int32")
    output = jnp.array(output)
    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = jnp.squeeze(target, axis=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_prob = jax.nn.log_softmax(output, axis=axis)
    else:
        output = output / jnp.sum(output, axis, keepdims=True)
        output = jnp.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = jnp.log(output)
    target = jnn.one_hot(target, output.shape[axis], axis=axis)
    return -jnp.sum(target * log_prob, axis=axis)


def binary_crossentropy(target, output, from_logits=False):
    target = jnp.array(target)
    output = jnp.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        output = jnn.sigmoid(output)

    output = jnp.clip(output, epsilon(), 1.0 - epsilon())
    bce = target * jnp.log(output)
    bce += (1.0 - target) * jnp.log(1.0 - output)
    return -bce
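

# Sanity check (illustrative): for `target=1.0` and `output=0.9` with
# `from_logits=False`, the element-wise loss is `-log(0.9) ~= 0.1054`;
# clipping to `[epsilon(), 1 - epsilon()]` keeps both log terms finite at
# exact 0 and 1.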