Adds PyTorch NN (except Pooling and Conv) (#111)
* Add NN for pytorch backend
* Add PyTorch nn functions
* Adds PyTorch Backend
* Adds PyTorch nn backend
* Add PyTorch backend functions
* Add Torch NN Backend
parent 2cae421d47
commit 3b115c1073
@@ -0,0 +1,240 @@
import torch
import torch.nn.functional as tnn

from keras_core.backend.config import epsilon


def relu(x):
    return tnn.relu(x)


def relu6(x):
    return tnn.relu6(x)


def sigmoid(x):
    return tnn.sigmoid(x)


def tanh(x):
    return tnn.tanh(x)


def softplus(x):
    return tnn.softplus(x)


def softsign(x):
    return tnn.softsign(x)


def silu(x, beta=1.0):
    return x * sigmoid(beta * x)


def swish(x):
    return silu(x, beta=1)


def log_sigmoid(x):
    return tnn.logsigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return tnn.leaky_relu(x, negative_slope=negative_slope)


def hard_sigmoid(x):
    return tnn.hardsigmoid(x)


def elu(x):
    return tnn.elu(x)


def selu(x):
    return tnn.selu(x)

def gelu(x, approximate=True):
    # `tnn.gelu` takes `approximate` as a string ("none" or "tanh"),
    # so map the boolean flag onto the torch API.
    if approximate:
        return tnn.gelu(x, approximate="tanh")
    return tnn.gelu(x)


def softmax(x, axis=None):
    return tnn.softmax(x, dim=axis)


def log_softmax(x, axis=-1):
    return tnn.log_softmax(x, dim=axis)

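# Illustrative sanity check of the activation wrappers above (editor's
# sketch, not part of the diff; assumes this module has been imported):
#
#     x = torch.tensor([-2.0, 0.0, 3.0])
#     relu(x)              # tensor([0., 0., 3.])
#     softmax(x, axis=-1)  # probabilities summing to 1 along the last dim
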
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    raise NotImplementedError(
        "`max_pool` not yet implemented for PyTorch Backend"
    )


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format="channels_last",
):
    raise NotImplementedError(
        "`average_pool` not yet implemented for PyTorch Backend"
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    raise NotImplementedError("`conv` not yet implemented for PyTorch Backend")


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    raise NotImplementedError(
        "`depthwise_conv` not yet implemented for PyTorch Backend"
    )


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    # A separable conv is a depthwise conv followed by a 1x1 pointwise conv;
    # until `depthwise_conv` and `conv` are implemented, this raises
    # `NotImplementedError` through the `depthwise_conv` call.
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    raise NotImplementedError(
        "`conv_transpose` not yet implemented for PyTorch backend"
    )

def one_hot(x, num_classes, axis=-1):
    # Only the last axis is supported; `len(x.shape)` is the index of the
    # new class axis that `one_hot` appends.
    if axis not in (-1, len(x.shape)):
        raise ValueError(
            "`one_hot` is only implemented for last axis for PyTorch backend. "
            f"`axis` arg value {axis} should be -1 or last axis of the input "
            f"tensor with shape {x.shape}."
        )
    return tnn.one_hot(x, num_classes)

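# Example (editor's sketch, not part of the diff): `tnn.one_hot` expects an
# integer (long) tensor and appends the class axis last.
#
#     one_hot(torch.tensor([0, 2, 1]), num_classes=3)
#     # -> tensor([[1, 0, 0],
#     #            [0, 0, 1],
#     #            [0, 1, 0]])
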
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = torch.as_tensor(target)
    output = torch.as_tensor(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = tnn.log_softmax(output, dim=axis)
    else:
        output = output / torch.sum(output, dim=axis, keepdim=True)
        output = torch.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = torch.log(output)
    return -torch.sum(target * log_prob, dim=axis)

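# Worked example (editor's sketch): with probabilities that already sum to 1
# and a one-hot target, the loss reduces to -log(p) of the true class.
#
#     target = torch.tensor([[0.0, 1.0, 0.0]])
#     output = torch.tensor([[0.1, 0.8, 0.1]])
#     categorical_crossentropy(target, output)  # ~ -log(0.8) ~ 0.2231
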
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = torch.as_tensor(target, dtype=torch.long)
    output = torch.as_tensor(output)

    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = torch.squeeze(target, dim=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_prob = tnn.log_softmax(output, dim=axis)
    else:
        output = output / torch.sum(output, dim=axis, keepdim=True)
        output = torch.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = torch.log(output)
    target = one_hot(target, output.shape[axis], axis=axis)
    return -torch.sum(target * log_prob, dim=axis)

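# Illustrative (editor's sketch): integer class ids as targets, logits as
# the model output.
#
#     target = torch.tensor([1])
#     output = torch.tensor([[0.0, 2.0, 0.0]])
#     sparse_categorical_crossentropy(target, output, from_logits=True)
#     # -> tensor([0.2396])  (= logsumexp([0., 2., 0.]) - 2.0)
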
def binary_crossentropy(target, output, from_logits=False):
    # TODO: `torch.as_tensor` has device arg. Need to think how to pass it.
    target = torch.as_tensor(target)
    output = torch.as_tensor(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        return tnn.binary_cross_entropy_with_logits(output, target)
    else:
        return tnn.binary_cross_entropy(output, target)
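# Illustrative (editor's sketch): with `from_logits=False`, `output` must
# already hold probabilities in (0, 1); the result is the mean over elements.
#
#     target = torch.tensor([1.0, 0.0])
#     output = torch.tensor([0.9, 0.2])
#     binary_crossentropy(target, output)
#     # -> mean of [-log(0.9), -log(0.8)] ~ 0.1643
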
@@ -12,4 +12,3 @@ pandas
 absl-py
 requests
 h5py
-torch