Fix torch bug
This commit is contained in:
parent
0054b644fa
commit
6aac2389d5
@ -14,6 +14,7 @@ class KerasVariable:
|
|||||||
self.name = name or auto_name(self.__class__.__name__)
|
self.name = name or auto_name(self.__class__.__name__)
|
||||||
dtype = standardize_dtype(dtype)
|
dtype = standardize_dtype(dtype)
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
self._shape = None
|
||||||
self._initializer = None
|
self._initializer = None
|
||||||
self.trainable = trainable
|
self.trainable = trainable
|
||||||
if callable(initializer):
|
if callable(initializer):
|
||||||
|
@ -7,7 +7,13 @@ class TorchLayer(torch.nn.Module):
|
|||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
# TODO: find a good place to add the params. It should be added right
|
# TODO: find a good place to add the params. It should be added right
|
||||||
# after the variables are initialized.
|
# after the variables are initialized.
|
||||||
self.params = torch.nn.ParameterList(
|
if not hasattr(self, "torch_params"):
|
||||||
[variable.value for variable in self.variables]
|
self.torch_params = torch.nn.ParameterList(
|
||||||
|
[
|
||||||
|
torch.nn.Parameter(
|
||||||
|
variable.value, requires_grad=variable.trainable
|
||||||
|
)
|
||||||
|
for variable in self.variables
|
||||||
|
]
|
||||||
)
|
)
|
||||||
return Operation.__call__(self, *args, **kwargs)
|
return Operation.__call__(self, *args, **kwargs)
|
||||||
|
@ -225,7 +225,7 @@ class PyDatasetAdapter(DataAdapter):
|
|||||||
)
|
)
|
||||||
if len(batch) == 2:
|
if len(batch) == 2:
|
||||||
sw = data_adapter_utils.class_weight_to_sample_weights(
|
sw = data_adapter_utils.class_weight_to_sample_weights(
|
||||||
batch[1], class_weight
|
batch[1], self.class_weight
|
||||||
)
|
)
|
||||||
batch = batch + (sw,)
|
batch = batch + (sw,)
|
||||||
return batch
|
return batch
|
||||||
|
Loading…
Reference in New Issue
Block a user