Fix torch bug

This commit is contained in:
Francois Chollet 2023-06-03 14:22:16 -07:00
parent 0054b644fa
commit 6aac2389d5
3 changed files with 11 additions and 4 deletions

@ -14,6 +14,7 @@ class KerasVariable:
self.name = name or auto_name(self.__class__.__name__) self.name = name or auto_name(self.__class__.__name__)
dtype = standardize_dtype(dtype) dtype = standardize_dtype(dtype)
self._dtype = dtype self._dtype = dtype
self._shape = None
self._initializer = None self._initializer = None
self.trainable = trainable self.trainable = trainable
if callable(initializer): if callable(initializer):

@ -7,7 +7,13 @@ class TorchLayer(torch.nn.Module):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
# TODO: find a good place to add the params. It should be added right # TODO: find a good place to add the params. It should be added right
# after the variables are initialized. # after the variables are initialized.
self.params = torch.nn.ParameterList( if not hasattr(self, "torch_params"):
[variable.value for variable in self.variables] self.torch_params = torch.nn.ParameterList(
[
torch.nn.Parameter(
variable.value, requires_grad=variable.trainable
)
for variable in self.variables
]
) )
return Operation.__call__(self, *args, **kwargs) return Operation.__call__(self, *args, **kwargs)

@ -225,7 +225,7 @@ class PyDatasetAdapter(DataAdapter):
) )
if len(batch) == 2: if len(batch) == 2:
sw = data_adapter_utils.class_weight_to_sample_weights( sw = data_adapter_utils.class_weight_to_sample_weights(
batch[1], class_weight batch[1], self.class_weight
) )
batch = batch + (sw,) batch = batch + (sw,)
return batch return batch