Fix torch bug
parent 0054b644fa
commit 6aac2389d5
@@ -14,6 +14,7 @@ class KerasVariable:
         self.name = name or auto_name(self.__class__.__name__)
         dtype = standardize_dtype(dtype)
         self._dtype = dtype
         self._shape = None
         self._initializer = None
+        self.trainable = trainable
         if callable(initializer):
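The `trainable` flag stored on the variable is what the TorchLayer hunk below reads as `variable.trainable`. A minimal sketch of that contract, using a hypothetical `VariableStub` in place of the much richer KerasVariable:

    # Hypothetical stand-in for KerasVariable; only the two attributes
    # the torch layer consumes are modeled here.
    class VariableStub:
        def __init__(self, value, trainable=True):
            self.value = value
            self.trainable = trainable  # becomes requires_grad downstream

    frozen = VariableStub(0.0, trainable=False)
    print(frozen.trainable)  # False -> the torch Parameter gets requires_grad=False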
@@ -7,7 +7,13 @@ class TorchLayer(torch.nn.Module):
     def forward(self, *args, **kwargs):
         # TODO: find a good place to add the params. It should be added right
         # after the variables are initialized.
-        self.params = torch.nn.ParameterList(
-            [variable.value for variable in self.variables]
-        )
+        if not hasattr(self, "torch_params"):
+            self.torch_params = torch.nn.ParameterList(
+                [
+                    torch.nn.Parameter(
+                        variable.value, requires_grad=variable.trainable
+                    )
+                    for variable in self.variables
+                ]
+            )
         return Operation.__call__(self, *args, **kwargs)
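The new code wraps each backing tensor in `torch.nn.Parameter`, so it is tracked as a real module parameter with `requires_grad` mirroring the Keras `trainable` flag, and the `hasattr` guard builds the list once instead of replacing it on every forward call. A minimal sketch of the registration behavior (the `Demo` module is illustrative, not Keras code):

    import torch

    class Demo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A bare tensor in a plain Python list: invisible to parameters().
            self.raw = [torch.zeros(3)]
            # A Parameter inside a ParameterList: tracked by the module.
            self.torch_params = torch.nn.ParameterList(
                [torch.nn.Parameter(torch.zeros(3), requires_grad=False)]
            )

    demo = Demo()
    print(len(list(demo.parameters())))           # 1 -- only the ParameterList entry
    print(next(demo.parameters()).requires_grad)  # False, honoring trainable=False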
@@ -225,7 +225,7 @@ class PyDatasetAdapter(DataAdapter):
             )
         if len(batch) == 2:
             sw = data_adapter_utils.class_weight_to_sample_weights(
-                batch[1], class_weight
+                batch[1], self.class_weight
             )
             batch = batch + (sw,)
         return batch
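The one-line fix qualifies the lookup: the mapping lives on the adapter as `self.class_weight`, and the bare name `class_weight` was not bound in this method's scope. For intuition, a toy version of what the helper call computes (a hypothetical reimplementation, not the Keras `data_adapter_utils` code; the real helper also handles one-hot and multi-output targets):

    import numpy as np

    def class_weight_to_sample_weights(y, class_weight):
        # Each integer label picks up the weight of its class (default 1.0).
        y = np.asarray(y).reshape(-1)
        return np.array([class_weight.get(int(label), 1.0) for label in y])

    print(class_weight_to_sample_weights([0, 1, 1, 2], {0: 1.0, 1: 2.5}))
    # [1.  2.5 2.5 1. ]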