import warnings

import tensorflow as tf

from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_shape,
)
from keras_core.backend.config import epsilon


def relu(x):
    return tf.nn.relu(x)


def relu6(x):
    return tf.nn.relu6(x)


def sigmoid(x):
    logits = x
    output = tf.nn.sigmoid(x)
    # Stash the pre-activation logits on the output tensor so the
    # crossentropy functions below can recover them (see `_get_logits`).
    output._keras_logits = logits
    return output


def tanh(x):
    return tf.nn.tanh(x)


def softplus(x):
    return tf.math.softplus(x)


def softsign(x):
    return tf.nn.softsign(x)


def silu(x, beta=1.0):
    return tf.nn.silu(x, beta=beta)


def swish(x):
    return x * sigmoid(x)


def log_sigmoid(x):
    return tf.math.log_sigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return tf.nn.leaky_relu(x, alpha=negative_slope)


def hard_sigmoid(x):
    # Piecewise-linear approximation of sigmoid: clip(x / 6 + 0.5, 0, 1).
    x = x / 6.0 + 0.5
    return tf.clip_by_value(x, 0.0, 1.0)


def elu(x):
    return tf.nn.elu(x)


def selu(x):
    return tf.nn.selu(x)


def gelu(x, approximate=True):
    return tf.nn.gelu(x, approximate)

def softmax(x, axis=None):
    logits = x
    if axis is None:
        # Unlike numpy, tf will handle axis=None as axis=-1.
        # We need this workaround for the reduction on every dim.
        logits_exp = tf.exp(logits)
        output = logits_exp / tf.reduce_sum(logits_exp, keepdims=True)
    else:
        output = tf.nn.softmax(x, axis=axis)
    output._keras_logits = logits
    return output


def log_softmax(x, axis=None):
    if axis is None:
        # Unlike numpy, tf will handle axis=None as axis=-1.
        # We need this workaround for the reduction on every dim.
        logits = x
        logits_exp = tf.exp(logits)
        return logits - tf.math.log(tf.reduce_sum(logits_exp, keepdims=True))
    return tf.nn.log_softmax(x, axis=axis)

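# Example (illustrative sketch added by the editor; tensors are hypothetical):
# with `axis=None`, `softmax` normalizes over every element of the tensor
# rather than just the last axis, so the whole result sums to 1:
#
#   x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
#   y = softmax(x, axis=None)
#   float(tf.reduce_sum(y))  # -> 1.0 (it would be 2.0 per-row with axis=-1)
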
def _transpose_spatial_inputs(inputs):
    num_spatial_dims = len(inputs.shape) - 2
    # Tensorflow pooling does not support `channels_first` format, so
    # we need to transpose to `channels_last` format.
    if num_spatial_dims == 1:
        inputs = tf.transpose(inputs, (0, 2, 1))
    elif num_spatial_dims == 2:
        inputs = tf.transpose(inputs, (0, 2, 3, 1))
    elif num_spatial_dims == 3:
        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))
    else:
        raise ValueError(
            "Pooling inputs' rank must be 3, 4 or 5, corresponding to 1D, 2D "
            f"and 3D inputs. But received shape: {inputs.shape}."
        )
    return inputs


def _transpose_spatial_outputs(outputs):
    # Undo the transpose in `_transpose_spatial_inputs`.
    num_spatial_dims = len(outputs.shape) - 2
    if num_spatial_dims == 1:
        outputs = tf.transpose(outputs, (0, 2, 1))
    elif num_spatial_dims == 2:
        outputs = tf.transpose(outputs, (0, 3, 1, 2))
    elif num_spatial_dims == 3:
        outputs = tf.transpose(outputs, (0, 4, 1, 2, 3))
    return outputs

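# Example (illustrative sketch added by the editor; shapes are hypothetical):
# for 2D inputs, the permutation (0, 2, 3, 1) moves the channel axis last and
# (0, 3, 1, 2) moves it back:
#
#   x = tf.zeros((8, 3, 32, 32))              # NCHW, channels_first
#   _transpose_spatial_inputs(x).shape        # -> (8, 32, 32, 3), NHWC
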
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    strides = pool_size if strides is None else strides
    padding = padding.upper()
    tf_data_format = _convert_data_format("channels_last", len(inputs.shape))
    if data_format == "channels_first":
        # Tensorflow pooling does not support `channels_first` format, so
        # we need to transpose to `channels_last` format.
        inputs = _transpose_spatial_inputs(inputs)

    outputs = tf.nn.max_pool(
        inputs,
        pool_size,
        strides,
        padding,
        tf_data_format,
    )
    if data_format == "channels_first":
        outputs = _transpose_spatial_outputs(outputs)
    return outputs


def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    strides = pool_size if strides is None else strides
    padding = padding.upper()
    tf_data_format = _convert_data_format("channels_last", len(inputs.shape))
    if data_format == "channels_first":
        # Tensorflow pooling does not support `channels_first` format, so
        # we need to transpose to `channels_last` format.
        inputs = _transpose_spatial_inputs(inputs)

    outputs = tf.nn.avg_pool(
        inputs,
        pool_size,
        strides,
        padding,
        tf_data_format,
    )
    if data_format == "channels_first":
        outputs = _transpose_spatial_outputs(outputs)
    return outputs

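# Example (illustrative sketch added by the editor; shapes are hypothetical):
# `channels_first` inputs are round-tripped through `channels_last`, since the
# tf.nn pooling ops only accept the latter (per the comments above), but the
# caller still gets `channels_first` back:
#
#   x = tf.random.normal((2, 3, 8, 8))                  # NCHW
#   max_pool(x, 2, data_format="channels_first").shape  # -> (2, 3, 4, 4)
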
def _convert_data_format(data_format, ndim):
    if data_format == "channels_last":
        if ndim == 3:
            return "NWC"
        elif ndim == 4:
            return "NHWC"
        elif ndim == 5:
            return "NDHWC"
        else:
            raise ValueError(
                f"Input rank not supported: {ndim}. "
                "Expected values are [3, 4, 5]"
            )
    elif data_format == "channels_first":
        if ndim == 3:
            return "NCW"
        elif ndim == 4:
            return "NCHW"
        elif ndim == 5:
            return "NCDHW"
        else:
            raise ValueError(
                f"Input rank not supported: {ndim}. "
                "Expected values are [3, 4, 5]"
            )
    else:
        raise ValueError(
            f"Invalid data_format: {data_format}. "
            'Expected values are ["channels_first", "channels_last"]'
        )

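# Example (illustrative sketch added by the editor): the mapping from Keras
# data formats to TF format strings, keyed by input rank:
#
#   _convert_data_format("channels_last", 3)   # -> "NWC"
#   _convert_data_format("channels_last", 4)   # -> "NHWC"
#   _convert_data_format("channels_first", 5)  # -> "NCDHW"
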
def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    def _conv():
        tf_data_format = _convert_data_format(data_format, len(inputs.shape))
        return tf.nn.convolution(
            inputs,
            kernel,
            strides,
            padding.upper(),
            data_format=tf_data_format,
            dilations=dilation_rate,
        )

    # In Tensorflow, `tf.nn.convolution` with `groups > 1` does not work on
    # CPU, but compiling the call with XLA does, hence this wrapper.
    @tf.function(jit_compile=True)
    def _conv_xla():
        return _conv()

    if data_format == "channels_last":
        channels = inputs.shape[-1]
    else:
        channels = inputs.shape[1]
    if channels != kernel.shape[-2]:
        # If the kernel's in_channels does not match the input's channel
        # count, the convolution is broken down into groups, so take the
        # XLA path.
        return _conv_xla()
    return _conv()

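# Example (illustrative sketch added by the editor; tensors are hypothetical):
# a grouped conv is detected by comparing input channels with the kernel's
# in_channels axis. With 4 input channels and a kernel of shape (3, 3, 2, 8),
# in_channels=2, so groups = 4 / 2 = 2 and the XLA-compiled path is taken:
#
#   x = tf.random.normal((1, 8, 8, 4))
#   k = tf.random.normal((3, 3, 2, 8))
#   conv(x, k, padding="same").shape   # -> (1, 8, 8, 8)
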
def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(inputs.shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`inputs` rank must be 3 (1D conv) or 4 (2D conv). Received: "
            f"{inputs.ndim}."
        )
    # Because we use `tf.nn.depthwise_conv2d` for both 1D and 2D convs, we set
    # `tf_data_format` using the 2D conv format.
    tf_data_format = _convert_data_format(data_format, 4)
    padding = padding.upper()
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if num_spatial_dims == 1:
        # 1D depthwise conv: expand to a height-1 2D conv, run it, then
        # squeeze the extra spatial dim back out.
        if data_format == "channels_last":
            strides = (1,) + strides * 2 + (1,)
            spatial_start_dim = 1
        else:
            strides = (1, 1) + strides * 2
            spatial_start_dim = 2
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        kernel = tf.expand_dims(kernel, axis=0)

        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate

        outputs = tf.nn.depthwise_conv2d(
            inputs,
            kernel,
            strides,
            padding,
            data_format=tf_data_format,
            dilations=dilation_rate,
        )
        return tf.squeeze(outputs, [spatial_start_dim])

    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
        spatial_start_dim = 1
    else:
        strides = (1, 1) + strides
        spatial_start_dim = 2
    return tf.nn.depthwise_conv2d(
        inputs,
        kernel,
        strides,
        padding,
        data_format=tf_data_format,
        dilations=dilation_rate,
    )

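# Example (illustrative sketch added by the editor; tensors are hypothetical):
# `tf.nn.depthwise_conv2d` applies `channel_multiplier` filters per input
# channel, so a kernel of shape (3, 3, 4, 2) maps 4 input channels to
# 4 * 2 = 8 output channels:
#
#   x = tf.random.normal((1, 8, 8, 4))
#   k = tf.random.normal((3, 3, 4, 2))
#   depthwise_conv(x, k, padding="same").shape   # -> (1, 8, 8, 8)
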
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(inputs.shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`num_spatial_dims` must be 1 or 2. Received: "
            f"num_spatial_dims={num_spatial_dims}."
        )
    # Because we use `tf.nn.separable_conv2d` for both 1D and 2D convs, we set
    # `tf_data_format` using the 2D conv format.
    tf_data_format = _convert_data_format(data_format, 4)
    padding = padding.upper()
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if num_spatial_dims == 1:
        # 1D separable conv: expand to a height-1 2D conv, run it, then
        # squeeze the extra spatial dim back out.
        if data_format == "channels_last":
            strides = (1,) + strides * 2 + (1,)
            spatial_start_dim = 1
        else:
            strides = (1, 1) + strides * 2
            spatial_start_dim = 2
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        depthwise_kernel = tf.expand_dims(depthwise_kernel, axis=0)
        pointwise_kernel = tf.expand_dims(pointwise_kernel, axis=0)
        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate

        outputs = tf.nn.separable_conv2d(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            strides,
            padding,
            data_format=tf_data_format,
            dilations=dilation_rate,
        )
        return tf.squeeze(outputs, [spatial_start_dim])

    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides
    return tf.nn.separable_conv2d(
        inputs,
        depthwise_kernel,
        pointwise_kernel,
        strides,
        padding,
        data_format=tf_data_format,
        dilations=dilation_rate,
    )

def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    kernel_size = kernel.shape[:-2]
    filters = kernel.shape[-2]
    output_shape = compute_conv_transpose_output_shape(
        inputs.shape,
        kernel_size,
        filters,
        strides,
        padding,
        output_padding,
        data_format,
        dilation_rate,
    )

    return tf.nn.conv_transpose(
        inputs,
        kernel,
        output_shape,
        strides,
        padding=padding.upper(),
        data_format=tf_data_format,
        dilations=dilation_rate,
    )

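# Example (illustrative sketch added by the editor; tensors are hypothetical):
# for a transposed conv, the kernel layout is (..., out_channels, in_channels),
# which is why `filters` above reads axis -2 rather than -1:
#
#   x = tf.random.normal((1, 8, 8, 3))
#   k = tf.random.normal((3, 3, 16, 3))   # (h, w, out_channels, in_channels)
#   conv_transpose(x, k, strides=2, padding="same").shape
#   # -> (1, 16, 16, 16)
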
def one_hot(x, num_classes, axis=-1):
    return tf.one_hot(x, num_classes, axis=axis)

def _get_logits(output, from_logits, op_type, fn_name):
    """Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor."""
    output_ = output
    from_logits_ = from_logits

    has_keras_logits = hasattr(output, "_keras_logits")
    if has_keras_logits:
        output_ = output._keras_logits
        from_logits_ = True

    from_expected_op_type = (
        not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable))
        and output.op.type == op_type
    ) and not has_keras_logits

    if from_expected_op_type:
        # When the output was produced by a softmax op, use the logits that
        # fed into it directly to compute the loss; this is more numerically
        # stable during training than recovering them from the probabilities.
        assert len(output.op.inputs) == 1
        output_ = output.op.inputs[0]
        from_logits_ = True

    if from_logits and (has_keras_logits or from_expected_op_type):
        warnings.warn(
            f"`{fn_name}` received `from_logits=True`, but "
            f"the `output` argument was produced by a {op_type} "
            "activation and thus does not represent logits. "
            "Was this intended?",
            stacklevel=2,
        )
    return output_, from_logits_

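# Example (illustrative sketch added by the editor; tensors are hypothetical):
# because `sigmoid` and `softmax` above stash their inputs in `_keras_logits`,
# the crossentropy functions can silently recover the raw logits:
#
#   logits = tf.constant([[2.0, -1.0]])
#   probs = sigmoid(logits)
#   recovered, from_logits = _get_logits(
#       probs, False, "Sigmoid", "binary_crossentropy"
#   )
#   # recovered is `logits`, and from_logits is now True.
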
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Categorical crossentropy between an output tensor and a target tensor.

    Args:
        target: A tensor of the same shape as `output`.
        output: A tensor resulting from a softmax
            (unless `from_logits` is `True`, in which
            case `output` is expected to be the logits).
        from_logits: Boolean, whether `output` is the
            result of a softmax, or is a tensor of logits.
        axis: Int specifying the channels axis. `axis=-1` corresponds to data
            format `channels_last`, and `axis=1` corresponds to data format
            `channels_first`.

    Returns:
        Output tensor.

    Example:

    >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3, 3])
    >>> print(a)
    tf.Tensor(
      [[1. 0. 0.]
       [0. 1. 0.]
       [0. 0. 1.]], shape=(3, 3), dtype=float32)
    >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94],
    ...                 shape=[3, 3])
    >>> print(b)
    tf.Tensor(
      [[0.9 0.05 0.05]
       [0.05 0.89 0.06]
       [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
    >>> loss = categorical_crossentropy(a, b)
    >>> print(np.around(loss, 5))
    [0.10536 0.11653 0.06188]
    >>> loss = categorical_crossentropy(a, a)
    >>> print(np.around(loss, 5))
    [0. 0. 0.]
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    output, from_logits = _get_logits(
        output, from_logits, "Softmax", "categorical_crossentropy"
    )
    if from_logits:
        return tf.nn.softmax_cross_entropy_with_logits(
            labels=target, logits=output, axis=axis
        )

    # Adjust the predictions so that the probability of each class for every
    # sample adds up to 1. This is needed to ensure that the crossentropy is
    # computed correctly.
    output = output / tf.reduce_sum(output, axis, keepdims=True)

    # Compute crossentropy from probabilities.
    output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon())
    return -tf.reduce_sum(target * tf.math.log(output), axis)

def sparse_categorical_crossentropy(
    target, output, from_logits=False, axis=-1
):
    """Categorical crossentropy with integer targets.

    Args:
        target: An integer tensor.
        output: A tensor resulting from a softmax
            (unless `from_logits` is True, in which
            case `output` is expected to be the logits).
        from_logits: Boolean, whether `output` is the
            result of a softmax, or is a tensor of logits.
        axis: Int specifying the channels axis. `axis=-1` corresponds to data
            format `channels_last`, and `axis=1` corresponds to data format
            `channels_first`.

    Returns:
        Output tensor.
    """
    if axis != -1 and axis != len(output.shape) - 1:
        raise ValueError(
            f"Only axis=-1 is currently supported. Received: axis={axis}"
        )

    target = tf.convert_to_tensor(target)
    target = tf.cast(target, dtype="int64")
    output = tf.convert_to_tensor(output)
    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = tf.squeeze(target, axis=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    output, from_logits = _get_logits(
        output, from_logits, "Softmax", "sparse_categorical_crossentropy"
    )
    if not from_logits:
        output = tf.clip_by_value(output, epsilon(), 1 - epsilon())
        output = tf.math.log(output)

    result = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target, logits=output
    )
    return result

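# Example (illustrative sketch added by the editor; tensors are hypothetical):
# sparse targets are class indices rather than one-hot vectors, so the target
# rank is one less than the output rank:
#
#   target = tf.constant([0, 2])                     # shape (2,)
#   output = tf.constant([[0.9, 0.05, 0.05],
#                         [0.1, 0.1, 0.8]])          # shape (2, 3)
#   sparse_categorical_crossentropy(target, output)  # shape (2,)
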
def binary_crossentropy(target, output, from_logits=False):
    """Binary crossentropy between an output tensor and a target tensor.

    Args:
        target: A tensor with the same shape as `output`.
        output: A tensor.
        from_logits: Whether `output` is expected to be a logits tensor.
            By default, we consider that `output`
            encodes a probability distribution.

    Returns:
        A tensor.
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    output, from_logits = _get_logits(
        output, from_logits, "Sigmoid", "binary_crossentropy"
    )
    if from_logits:
        return tf.nn.sigmoid_cross_entropy_with_logits(
            labels=target, logits=output
        )

    # Compute crossentropy from probabilities:
    # bce = target * log(output) + (1 - target) * log(1 - output)
    output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon())
    bce = target * tf.math.log(output)
    bce += (1 - target) * tf.math.log(1 - output)
    return -bce

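# Example (illustrative sketch added by the editor; tensors are hypothetical):
# with probability inputs, values are clipped to [epsilon, 1 - epsilon] before
# the log to avoid -inf; with `from_logits=True` the numerically stable fused
# op is used instead:
#
#   target = tf.constant([1.0, 0.0])
#   probs = tf.constant([0.9, 0.2])
#   binary_crossentropy(target, probs)   # -> [-log(0.9), -log(0.8)]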