"""TensorFlow backend implementations of neural-network ops for keras_core."""
import tensorflow as tf
|
|
|
|
from keras_core.backend.common.backend_utils import (
|
|
compute_conv_transpose_output_shape,
|
|
)
|
|
|
|
|
|
def relu(x):
    """Element-wise rectified linear unit: `max(x, 0)`."""
    output = tf.nn.relu(x)
    return output
|
|
|
|
|
|
def relu6(x):
    """Rectified linear unit capped at 6: `min(max(x, 0), 6)`."""
    output = tf.nn.relu6(x)
    return output
|
|
|
|
|
|
def sigmoid(x):
    """Element-wise logistic sigmoid: `1 / (1 + exp(-x))`."""
    output = tf.nn.sigmoid(x)
    return output
|
|
|
|
|
|
def softplus(x):
    """Element-wise softplus: `log(exp(x) + 1)`."""
    output = tf.math.softplus(x)
    return output
|
|
|
|
|
|
def softsign(x):
    """Element-wise softsign: `x / (abs(x) + 1)`."""
    output = tf.nn.softsign(x)
    return output
|
|
|
|
|
|
def silu(x, beta=1.0):
    """Sigmoid-weighted linear unit: `x * sigmoid(beta * x)`."""
    output = tf.nn.silu(x, beta=beta)
    return output
|
|
|
|
|
|
def swish(x):
    """Swish activation: `x * sigmoid(x)`."""
    # Equivalent to the module's `sigmoid` helper, called directly on tf.
    return tf.nn.sigmoid(x) * x
|
|
|
|
|
|
def log_sigmoid(x):
    """Element-wise log-sigmoid: `log(1 / (1 + exp(-x)))`."""
    output = tf.math.log_sigmoid(x)
    return output
|
|
|
|
|
|
def leaky_relu(x, negative_slope=0.2):
    """Leaky ReLU: `x` when positive, `negative_slope * x` when negative."""
    output = tf.nn.leaky_relu(x, alpha=negative_slope)
    return output
|
|
|
|
|
|
def hard_sigmoid(x):
    """Piecewise-linear sigmoid approximation: `clip(x / 6 + 0.5, 0, 1)`."""
    shifted = x / 6.0 + 0.5
    return tf.clip_by_value(shifted, 0.0, 1.0)
|
|
|
|
|
|
def elu(x):
    """Exponential linear unit: `x` when positive, `exp(x) - 1` otherwise."""
    output = tf.nn.elu(x)
    return output
|
|
|
|
|
|
def selu(x):
    """Scaled exponential linear unit (self-normalizing ELU variant)."""
    output = tf.nn.selu(x)
    return output
|
|
|
|
|
|
def gelu(x, approximate=True):
    """Gaussian error linear unit; `approximate=True` uses the tanh form."""
    return tf.nn.gelu(x, approximate=approximate)
|
|
|
|
|
|
def softmax(x, axis=None):
    """Softmax over `axis` (tf resolves `None` to its default axis)."""
    output = tf.nn.softmax(x, axis=axis)
    return output
|
|
|
|
|
|
def log_softmax(x, axis=None):
    """Log of softmax over `axis` (tf resolves `None` to its default axis)."""
    output = tf.nn.log_softmax(x, axis=axis)
    return output
|
|
|
|
|
|
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    """Max pooling over the spatial dimensions of `inputs`.

    `strides` defaults to `pool_size` when not given; `padding` is a
    lowercase Keras string ("valid"/"same") translated to TF's uppercase
    convention, and `data_format` is mapped to TF's layout string.
    """
    if strides is None:
        strides = pool_size
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    return tf.nn.max_pool(
        inputs, pool_size, strides, padding.upper(), tf_data_format
    )
|
|
|
|
|
|
def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    """Average pooling over the spatial dimensions of `inputs`.

    `strides` defaults to `pool_size` when not given; `padding` is a
    lowercase Keras string ("valid"/"same") translated to TF's uppercase
    convention, and `data_format` is mapped to TF's layout string.
    """
    if strides is None:
        strides = pool_size
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    return tf.nn.avg_pool(
        inputs, pool_size, strides, padding.upper(), tf_data_format
    )
|
|
|
|
|
|
def _convert_data_format(data_format, ndim):
|
|
if data_format == "channels_last":
|
|
if ndim == 3:
|
|
return "NWC"
|
|
elif ndim == 4:
|
|
return "NHWC"
|
|
elif ndim == 5:
|
|
return "NDHWC"
|
|
else:
|
|
raise ValueError(
|
|
f"Input rank not supported: {ndim}. "
|
|
"Expected values are [3, 4, 5]"
|
|
)
|
|
elif data_format == "channels_first":
|
|
if ndim == 3:
|
|
return "NCW"
|
|
elif ndim == 4:
|
|
return "NCHW"
|
|
elif ndim == 5:
|
|
return "NCDHW"
|
|
else:
|
|
raise ValueError(
|
|
f"Input rank not supported: {ndim}. "
|
|
"Expected values are [3, 4, 5]"
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid data_format: {data_format}. "
|
|
'Expected values are ["channels_first", "channels_last"]'
|
|
)
|
|
|
|
|
|
def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """General N-D convolution via `tf.nn.convolution`.

    Args:
        inputs: Input tensor of rank N+2 (batch + N spatial dims + channels,
            ordered per `data_format`).
        kernel: Convolution kernel tensor.
        strides: Int or tuple of N ints, stride per spatial dim.
        padding: `"valid"` or `"same"` (lowercase Keras convention).
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of N ints, dilation per spatial dim.

    Returns:
        The convolution output tensor.
    """
    # Fix: default was the typo "channel_last", which made every
    # default-argument call raise in `_convert_data_format`.
    data_format = _convert_data_format(data_format, len(inputs.shape))
    padding = padding.upper()

    return tf.nn.convolution(
        inputs,
        kernel,
        strides,
        padding,
        data_format=data_format,
        dilations=dilation_rate,
    )
|
|
|
|
|
|
def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Depthwise convolution for 1D or 2D inputs.

    1D inputs are lifted to 2D by inserting a dummy spatial dim, run through
    `tf.nn.depthwise_conv2d`, then squeezed back.

    Args:
        inputs: Tensor of rank 3 (1D conv) or 4 (2D conv).
        kernel: Depthwise kernel tensor.
        strides: Int or tuple of ints, stride per spatial dim.
        padding: `"valid"` or `"same"` (lowercase Keras convention).
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of ints, dilation per spatial dim.

    Returns:
        The depthwise convolution output tensor.

    Raises:
        ValueError: If `inputs` has more than 2 spatial dims.
    """
    num_spatial_dims = len(inputs.shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`inputs` rank must be 3 (1D conv) or 4 (2D conv). Received: "
            # Fix: this literal was missing the f-prefix, so the message
            # printed "{inputs.ndim}" verbatim instead of the actual rank.
            f"{inputs.ndim}."
        )
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    padding = padding.upper()
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if num_spatial_dims == 1:
        # 1D depthwise conv: insert a dummy spatial dim so the op can be
        # delegated to tf.nn.depthwise_conv2d. The single stride is
        # duplicated (`strides * 2`) to cover both 2D spatial dims.
        if data_format == "channels_last":
            strides = (1,) + strides * 2 + (1,)
            spatial_start_dim = 1
        else:
            strides = (1, 1) + strides * 2
            spatial_start_dim = 2
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        kernel = tf.expand_dims(kernel, axis=0)

        # Pad the dilation for the dummy dim as well.
        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate

        outputs = tf.nn.depthwise_conv2d(
            inputs,
            kernel,
            strides,
            padding,
            data_format=tf_data_format,
            dilations=dilation_rate,
        )
        # Drop the dummy spatial dim to return a rank-3 tensor.
        return tf.squeeze(outputs, [spatial_start_dim])

    # 2D depthwise conv: expand strides to the full NHWC/NCHW 4-tuple.
    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
        spatial_start_dim = 1
    else:
        strides = (1, 1) + strides
        spatial_start_dim = 2
    return tf.nn.depthwise_conv2d(
        inputs,
        kernel,
        strides,
        padding,
        data_format=tf_data_format,
        dilations=dilation_rate,
    )
|
|
|
|
|
|
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Separable convolution (depthwise followed by 1x1 pointwise) for 1D
    or 2D inputs, delegated to `tf.nn.separable_conv2d`.

    Args:
        inputs: Tensor of rank 3 (1D conv) or 4 (2D conv).
        depthwise_kernel: Per-channel spatial kernel.
        pointwise_kernel: 1x1 kernel mixing the depthwise outputs.
        strides: Int or tuple of ints, stride per spatial dim.
        padding: `"valid"` or `"same"` (lowercase Keras convention).
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of ints, dilation per spatial dim.

    Returns:
        The separable convolution output tensor.

    Raises:
        ValueError: If `inputs` has more than 2 spatial dims.
    """
    num_spatial_dims = len(inputs.shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`num_spatial_dims` must be 1 or 2. Received: "
            f"num_spatial_dims={num_spatial_dims}."
        )
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    padding = padding.upper()
    # Normalize scalar strides/dilations to one entry per spatial dim.
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if num_spatial_dims == 1:
        # 1D depthwise conv: lift to 2D by inserting a dummy spatial dim.
        # The single stride is duplicated (`strides * 2`) to cover both
        # 2D spatial dims.
        if data_format == "channels_last":
            strides = (1,) + strides * 2 + (1,)
            spatial_start_dim = 1
        else:
            strides = (1, 1) + strides * 2
            spatial_start_dim = 2
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        depthwise_kernel = tf.expand_dims(depthwise_kernel, axis=0)
        pointwise_kernel = tf.expand_dims(pointwise_kernel, axis=0)
        # Pad the dilation for the dummy dim as well.
        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate

        outputs = tf.nn.separable_conv2d(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            strides,
            padding,
            data_format=tf_data_format,
            dilations=dilation_rate,
        )
        # Drop the dummy spatial dim to return a rank-3 tensor.
        return tf.squeeze(outputs, [spatial_start_dim])

    # 2D case: expand strides to the full NHWC/NCHW 4-tuple.
    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides
    return tf.nn.separable_conv2d(
        inputs,
        depthwise_kernel,
        pointwise_kernel,
        strides,
        padding,
        data_format=tf_data_format,
        dilations=dilation_rate,
    )
|
|
|
|
|
|
def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    """Transposed (a.k.a. deconvolution) N-D convolution.

    The static output shape is derived by the backend helper
    `compute_conv_transpose_output_shape` and passed to
    `tf.nn.conv_transpose`.

    Args:
        inputs: Input tensor (batch + spatial dims + channels, ordered per
            `data_format`).
        kernel: Convolution kernel tensor.
        strides: Int or tuple of ints, stride per spatial dim.
        padding: `"valid"` or `"same"` (lowercase Keras convention).
        output_padding: Extra padding on the output's spatial dims, or None.
        data_format: `"channels_last"` or `"channels_first"`.
        dilation_rate: Int or tuple of ints, dilation per spatial dim.

    Returns:
        The transposed convolution output tensor.
    """
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    # NOTE(review): arguments are passed positionally to the helper —
    # presumably its signature matches this order; confirm against
    # keras_core.backend.common.backend_utils.
    output_shape = compute_conv_transpose_output_shape(
        inputs,
        kernel,
        strides,
        padding,
        output_padding,
        data_format,
        dilation_rate,
    )

    return tf.nn.conv_transpose(
        inputs,
        kernel,
        output_shape,
        strides,
        padding=padding.upper(),
        data_format=tf_data_format,
        dilations=dilation_rate,
    )
|
|
|
|
|
|
def one_hot(x, num_classes, axis=-1):
    """One-hot encode integer indices `x` into `num_classes` along `axis`."""
    output = tf.one_hot(x, num_classes, axis=axis)
    return output
|