import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import nn as jnn

from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_length,
)
from keras_core.backend.config import epsilon
from keras_core.backend.jax.core import convert_to_tensor


def relu(x):
    return jnn.relu(x)


def relu6(x):
    return jnn.relu6(x)


def sigmoid(x):
    return jnn.sigmoid(x)


def tanh(x):
    return jnn.tanh(x)


def softplus(x):
    return jnn.softplus(x)


def softsign(x):
    return jnn.soft_sign(x)


def silu(x):
    return jnn.silu(x)


def swish(x):
    return jnn.swish(x)


def log_sigmoid(x):
    return jnn.log_sigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return jnn.leaky_relu(x, negative_slope=negative_slope)


def hard_sigmoid(x):
    return jnn.hard_sigmoid(x)


def elu(x):
    return jnn.elu(x)


def selu(x):
    return jnn.selu(x)


def gelu(x, approximate=True):
    return jnn.gelu(x, approximate=approximate)


def softmax(x, axis=None):
    return jnn.softmax(x, axis=axis)


def log_softmax(x, axis=-1):
    return jnn.log_softmax(x, axis=axis)


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    # Convert an int (or sequence) into a tuple covering the spatial
    # dimensions, optionally padded with 1s for the batch and channel
    # dimensions.
    x = (x,) * num_spatial_dims if isinstance(x, int) else x
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        x = (1,) + x + (1,)
    else:
        x = (1,) + (1,) + x
    return x


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Helper function to define pooling functions.

    Args:
        inputs: input data of rank `N+2` (batch, spatial dims, channels).
        initial_value: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        pool_size: a sequence of `N` integers, representing the window size
            to reduce over.
        strides: a sequence of `N` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `"same"` or `"valid"`.

    Returns:
        The output of the reduction for each window slice.
    """
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    padding = padding.upper()
    return lax.reduce_window(
        inputs,
        initial_value,
        reduce_fn,
        pool_size,
        strides,
        padding,
    )


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    # Only convert `strides` when it was passed explicitly: `pool_size` has
    # already been expanded with batch and channel entries, so converting it
    # a second time would double-pad it.
    strides = (
        pool_size
        if strides is None
        else _convert_to_spatial_operand(
            strides, num_spatial_dims, data_format
        )
    )
    return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)
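
# Illustrative usage sketch (hypothetical values, not part of the API):
# with a 1x4x4x1 `channels_last` input, a 2x2 window, and stride 2,
# `max_pool` keeps the per-window maximum via `lax.reduce_window`:
#
#   x = jnp.arange(16.0).reshape(1, 4, 4, 1)
#   y = max_pool(x, pool_size=2, strides=2, padding="valid")
#   # y.shape == (1, 2, 2, 1) and y[0, 0, 0, 0] == 5.0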


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format="channels_last",
):
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    # As in `max_pool`, avoid converting `pool_size` a second time when it
    # is reused as the default strides.
    strides = (
        pool_size
        if strides is None
        else _convert_to_spatial_operand(
            strides, num_spatial_dims, data_format
        )
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Avoid the extra reduce_window.
        return pooled / np.prod(pool_size)
    else:
        # Count the number of valid entries at each input point, then use
        # that for computing the average. Assumes that any two arrays of the
        # same shape will be padded the same. Avoid broadcasting on axes
        # where pooling is skipped.
        shape = [
            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
        ]
        window_counts = _pool(
            jnp.ones(shape, inputs.dtype),
            0.0,
            lax.add,
            pool_size,
            strides,
            padding,
        )
        return pooled / window_counts


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2

    if data_format == "channels_last":
        spatial_dims = tuple(range(1, num_dims - 1))
        inputs_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        inputs_dn = (0, 1) + spatial_dims

    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    if data_format == "channels_last":
        channels = inputs.shape[-1]
    else:
        channels = inputs.shape[1]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}."
        )
    feature_group_count = channels // kernel_in_channels
    return jax.lax.conv_general_dilated(
        convert_to_tensor(inputs),
        convert_to_tensor(kernel),
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    feature_group_count = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel = jnp.reshape(
        kernel,
        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
    )
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
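
# Illustrative shape sketch (hypothetical values, not part of the API):
# `conv` infers `feature_group_count` from the channel counts, so a
# 1x8x8x4 `channels_last` input convolved with a kernel of shape
# (3, 3, 2, 8) runs as a grouped convolution with 4 // 2 = 2 groups:
#
#   x = jnp.ones((1, 8, 8, 4))
#   w = jnp.ones((3, 3, 2, 8))
#   y = conv(x, w, strides=1, padding="same")
#   # y.shape == (1, 8, 8, 8)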
"same". padding_needed = max( 0, (input_length - 1) * stride + kernel_size - output_length ) padding_before = padding_needed // 2 expanded_input_length = (input_length - 1) * stride + 1 padded_out_length = output_length + kernel_size - 1 pad_before = kernel_size - 1 - padding_before pad_after = padded_out_length - expanded_input_length - pad_before return (pad_before, pad_after) def _compute_padding_values( input_shape, kernel_shape, strides=1, padding="valid", output_padding=None, data_format="channels_last", dilation_rate=1, ): """Computes adjusted padding for `conv_transpose`.""" num_spatial_dims = len(input_shape) - 2 if isinstance(output_padding, int): output_padding = (output_padding,) * num_spatial_dims if isinstance(strides, int): strides = (strides,) * num_spatial_dims if isinstance(dilation_rate, int): dilation_rate = (dilation_rate,) * num_spatial_dims kernel_spatial_shape = kernel_shape[:-2] if data_format == "channels_last": input_spatial_shape = input_shape[1:-1] else: input_spatial_shape = input_shape[2:] padding_values = [] for i in range(num_spatial_dims): input_length = input_spatial_shape[i] current_output_padding = ( None if output_padding is None else output_padding[i] ) output_length = compute_conv_transpose_output_length( input_spatial_shape[i], kernel_spatial_shape[i], padding=padding, output_padding=current_output_padding, stride=strides[i], dilation=dilation_rate[i], ) padding_value = _compute_padding_value_one_dim( input_length, output_length, kernel_spatial_shape[i], strides[i], padding=padding, dilation_rate=dilation_rate[i], ) padding_values.append(padding_value) return padding_values def conv_transpose( inputs, kernel, strides=1, padding="valid", output_padding=None, data_format="channels_last", dilation_rate=1, ): num_spatial_dims = inputs.ndim - 2 padding_values = _compute_padding_values( inputs.shape, kernel.shape, strides, padding, output_padding, data_format, dilation_rate, ) dimension_numbers = _convert_to_lax_conv_dimension_numbers( num_spatial_dims, data_format, transpose=False, ) strides = _convert_to_spatial_operand( strides, num_spatial_dims, data_format, include_batch_and_channels=False, ) dilation_rate = _convert_to_spatial_operand( dilation_rate, num_spatial_dims, data_format, include_batch_and_channels=False, ) return jax.lax.conv_transpose( inputs, kernel, strides, padding=padding_values, rhs_dilation=dilation_rate, dimension_numbers=dimension_numbers, transpose_kernel=True, ) def one_hot(x, num_classes, axis=-1): return jnn.one_hot(x, num_classes, axis=axis) def categorical_crossentropy(target, output, from_logits=False, axis=-1): target = jnp.array(target) output = jnp.array(output) if target.shape != output.shape: raise ValueError( "Arguments `target` and `output` must have the same shape. " "Received: " f"target.shape={target.shape}, output.shape={output.shape}" ) if len(target.shape) < 1: raise ValueError( "Arguments `target` and `output` must be at least rank 1. 
" "Received: " f"target.shape={target.shape}, output.shape={output.shape}" ) if from_logits: log_prob = jax.nn.log_softmax(output, axis=axis) else: output = output / jnp.sum(output, axis, keepdims=True) output = jnp.clip(output, epsilon(), 1.0 - epsilon()) log_prob = jnp.log(output) return -jnp.sum(target * log_prob, axis=axis) def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): target = jnp.array(target, dtype="int64") output = jnp.array(output) if len(target.shape) == len(output.shape) and target.shape[-1] == 1: target = jnp.squeeze(target, axis=-1) if len(output.shape) < 1: raise ValueError( "Argument `output` must be at least rank 1. " "Received: " f"output.shape={output.shape}" ) if target.shape != output.shape[:-1]: raise ValueError( "Arguments `target` and `output` must have the same shape " "up until the last dimension: " f"target.shape={target.shape}, output.shape={output.shape}" ) if from_logits: log_prob = jax.nn.log_softmax(output, axis=axis) else: output = output / jnp.sum(output, axis, keepdims=True) output = jnp.clip(output, epsilon(), 1.0 - epsilon()) log_prob = jnp.log(output) target = jnn.one_hot(target, output.shape[axis], axis=axis) return -jnp.sum(target * log_prob, axis=axis) def binary_crossentropy(target, output, from_logits=False): target = jnp.array(target) output = jnp.array(output) if target.shape != output.shape: raise ValueError( "Arguments `target` and `output` must have the same shape. " "Received: " f"target.shape={target.shape}, output.shape={output.shape}" ) if from_logits: output = jnn.sigmoid(output) output = jnp.clip(output, epsilon(), 1.0 - epsilon()) bce = target * jnp.log(output) bce += (1.0 - target) * jnp.log(1.0 - output) return -bce