Fix torch bug

This commit is contained in:
Francois Chollet 2023-06-03 14:22:16 -07:00
parent 0054b644fa
commit 6aac2389d5
3 changed files with 11 additions and 4 deletions

@ -14,6 +14,7 @@ class KerasVariable:
self.name = name or auto_name(self.__class__.__name__)
dtype = standardize_dtype(dtype)
self._dtype = dtype
self._shape = None
self._initializer = None
self.trainable = trainable
if callable(initializer):

@ -7,7 +7,13 @@ class TorchLayer(torch.nn.Module):
def forward(self, *args, **kwargs):
# TODO: find a good place to add the params. It should be added right
# after the variables are initialized.
self.params = torch.nn.ParameterList(
[variable.value for variable in self.variables]
)
if not hasattr(self, "torch_params"):
self.torch_params = torch.nn.ParameterList(
[
torch.nn.Parameter(
variable.value, requires_grad=variable.trainable
)
for variable in self.variables
]
)
return Operation.__call__(self, *args, **kwargs)

@ -225,7 +225,7 @@ class PyDatasetAdapter(DataAdapter):
)
if len(batch) == 2:
sw = data_adapter_utils.class_weight_to_sample_weights(
batch[1], class_weight
batch[1], self.class_weight
)
batch = batch + (sw,)
return batch