27d6e4aec1
Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
32 lines
804 B
Python
32 lines
804 B
Python
import torch
|
|
|
|
from keras_core import layers
|
|
from keras_core.backend.common import KerasVariable
|
|
|
|
|
|
class Net(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fc1 = layers.Dense(1)
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
return x
|
|
|
|
|
|
net = Net()
|
|
|
|
# Test using Keras layer in a nn.Module.
|
|
# Test forward pass
|
|
assert list(net(torch.empty(100, 10)).shape) == [100, 1]
|
|
# Test KerasVariables are added as nn.Parameter.
|
|
assert len(list(net.parameters())) == 2
|
|
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
|
|
|
|
|
|
# Test using KerasVariable as a torch tensor for torch ops.
|
|
kernel = net.fc1.kernel
|
|
transposed_kernel = torch.transpose(kernel, 0, 1)
|
|
assert isinstance(kernel, KerasVariable)
|
|
assert isinstance(torch.mul(kernel, transposed_kernel), torch.Tensor)
|