import tensorflow as tf

from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_shape,
)


def relu(x):
    return tf.nn.relu(x)


def relu6(x):
    return tf.nn.relu6(x)


def sigmoid(x):
    return tf.nn.sigmoid(x)


def softplus(x):
    return tf.math.softplus(x)


def softsign(x):
    return tf.nn.softsign(x)


def silu(x, beta=1.0):
    return tf.nn.silu(x, beta=beta)


def swish(x):
    return x * sigmoid(x)


def log_sigmoid(x):
    return tf.math.log_sigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return tf.nn.leaky_relu(x, alpha=negative_slope)


def hard_sigmoid(x):
    # Piecewise-linear approximation of the sigmoid: clip(x / 6 + 0.5, 0, 1).
    x = x / 6.0 + 0.5
    return tf.clip_by_value(x, 0.0, 1.0)


def elu(x):
    return tf.nn.elu(x)


def selu(x):
    return tf.nn.selu(x)


def gelu(x, approximate=True):
    return tf.nn.gelu(x, approximate=approximate)


def softmax(x, axis=None):
    return tf.nn.softmax(x, axis=axis)


def log_softmax(x, axis=None):
    return tf.nn.log_softmax(x, axis=axis)


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    strides = pool_size if strides is None else strides
    padding = padding.upper()
    data_format = _convert_data_format(data_format, len(inputs.shape))
    return tf.nn.max_pool(inputs, pool_size, strides, padding, data_format)


def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    strides = pool_size if strides is None else strides
    padding = padding.upper()
    data_format = _convert_data_format(data_format, len(inputs.shape))
    return tf.nn.avg_pool(inputs, pool_size, strides, padding, data_format)


def _convert_data_format(data_format, ndim):
    # Map the Keras-style data format to the TF string for the given rank.
    if data_format == "channels_last":
        if ndim == 3:
            return "NWC"
        elif ndim == 4:
            return "NHWC"
        elif ndim == 5:
            return "NDHWC"
        else:
            raise ValueError(
                f"Input rank not supported: {ndim}. "
                "Expected values are [3, 4, 5]"
            )
    elif data_format == "channels_first":
        if ndim == 3:
            return "NCW"
        elif ndim == 4:
            return "NCHW"
        elif ndim == 5:
            return "NCDHW"
        else:
            raise ValueError(
                f"Input rank not supported: {ndim}. "
                "Expected values are [3, 4, 5]"
            )
    else:
        raise ValueError(
            f"Invalid data_format: {data_format}. "
            'Expected values are ["channels_first", "channels_last"]'
        )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """General N-D convolution function.

    Args:
        inputs: Tensor of rank N+2 (batch and channel dimensions included).
        kernel: Convolution kernel of rank N+2.
        strides: Integer or tuple of N integers, the convolution strides.
        padding: Either `"valid"` or `"same"`.
        data_format: Either `"channels_last"` or `"channels_first"`.
        dilation_rate: Integer or tuple of N integers, the dilation rate.
    """
    data_format = _convert_data_format(data_format, len(inputs.shape))
    padding = padding.upper()
    return tf.nn.convolution(
        inputs,
        kernel,
        strides,
        padding,
        data_format=data_format,
        dilations=dilation_rate,
    )


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(inputs.shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`inputs` rank must be 3 (1D conv) or 4 (2D conv). Received: "
            f"{inputs.ndim}."
        )
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    padding = padding.upper()
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if num_spatial_dims == 1:
        # 1D depthwise conv: insert a dummy spatial dim, run
        # `tf.nn.depthwise_conv2d`, then squeeze it back out.
        if data_format == "channels_last":
            strides = (1,) + strides * 2 + (1,)
            spatial_start_dim = 1
        else:
            strides = (1, 1) + strides * 2
            spatial_start_dim = 2
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        kernel = tf.expand_dims(kernel, axis=0)
        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate
        outputs = tf.nn.depthwise_conv2d(
            inputs,
            kernel,
            strides,
            padding,
            data_format=tf_data_format,
            dilations=dilation_rate,
        )
        return tf.squeeze(outputs, [spatial_start_dim])

    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
        spatial_start_dim = 1
    else:
        strides = (1, 1) + strides
        spatial_start_dim = 2
    return tf.nn.depthwise_conv2d(
        inputs,
        kernel,
        strides,
        padding,
        data_format=tf_data_format,
        dilations=dilation_rate,
    )


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(inputs.shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`num_spatial_dims` must be 1 or 2. Received: "
            f"num_spatial_dims={num_spatial_dims}."
        )
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    padding = padding.upper()
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims
    if num_spatial_dims == 1:
        # 1D separable conv: insert a dummy spatial dim, run
        # `tf.nn.separable_conv2d`, then squeeze it back out.
        if data_format == "channels_last":
            strides = (1,) + strides * 2 + (1,)
            spatial_start_dim = 1
        else:
            strides = (1, 1) + strides * 2
            spatial_start_dim = 2
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        depthwise_kernel = tf.expand_dims(depthwise_kernel, axis=0)
        pointwise_kernel = tf.expand_dims(pointwise_kernel, axis=0)
        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate
        outputs = tf.nn.separable_conv2d(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            strides,
            padding,
            data_format=tf_data_format,
            dilations=dilation_rate,
        )
        return tf.squeeze(outputs, [spatial_start_dim])

    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
    else:
        strides = (1, 1) + strides
    return tf.nn.separable_conv2d(
        inputs,
        depthwise_kernel,
        pointwise_kernel,
        strides,
        padding,
        data_format=tf_data_format,
        dilations=dilation_rate,
    )


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    output_shape = compute_conv_transpose_output_shape(
        inputs,
        kernel,
        strides,
        padding,
        output_padding,
        data_format,
        dilation_rate,
    )
    return tf.nn.conv_transpose(
        inputs,
        kernel,
        output_shape,
        strides,
        padding=padding.upper(),
        data_format=tf_data_format,
        dilations=dilation_rate,
    )


def one_hot(x, num_classes, axis=-1):
    return tf.one_hot(x, num_classes, axis=axis)
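

if __name__ == "__main__":
    # Minimal usage sketch exercising `conv` and `max_pool` on small random
    # tensors. The input and kernel shapes below are illustrative assumptions,
    # and this block only runs when the module is executed directly.
    x = tf.random.normal((2, 8, 8, 3))  # Batch of two 8x8 RGB images (NHWC).
    w = tf.random.normal((3, 3, 3, 16))  # 3x3 kernel, 3 in / 16 out channels.

    y = conv(x, w, strides=1, padding="same", data_format="channels_last")
    print("conv output shape:", y.shape)  # (2, 8, 8, 16)

    p = max_pool(x, pool_size=2, strides=2, padding="valid")
    print("max_pool output shape:", p.shape)  # (2, 4, 4, 3)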