Further code cleanup
This commit is contained in:
parent
c95c32e473
commit
d32b8fa4bd
@ -236,6 +236,8 @@ class Layer(object):
|
||||
input_mask, output_mask: Same as above, for masks.
|
||||
trainable_weights: List of variables.
|
||||
non_trainable_weights: List of variables.
|
||||
weights: The concatenation of the lists trainable_weights and
|
||||
non_trainable_weights (in this order).
|
||||
regularizers: List of regularizers.
|
||||
constraints: Dict mapping weights to constraints.
|
||||
|
||||
@ -872,20 +874,20 @@ class Layer(object):
|
||||
'''
|
||||
params = self.weights
|
||||
if len(params) != len(weights):
|
||||
raise Exception('You called `set_weights(weights)` on layer "' + self.name +
|
||||
'" with a weight list of length ' + str(len(weights)) +
|
||||
', but the layer was expecting ' + str(len(params)) +
|
||||
' weights. Provided weights: ' + str(weights)[:50] + '...')
|
||||
raise ValueError('You called `set_weights(weights)` on layer "' + self.name +
|
||||
'" with a weight list of length ' + str(len(weights)) +
|
||||
', but the layer was expecting ' + str(len(params)) +
|
||||
' weights. Provided weights: ' + str(weights)[:50] + '...')
|
||||
if not params:
|
||||
return
|
||||
weight_value_tuples = []
|
||||
param_values = K.batch_get_value(params)
|
||||
for pv, p, w in zip(param_values, params, weights):
|
||||
if pv.shape != w.shape:
|
||||
raise Exception('Layer weight shape ' +
|
||||
str(pv.shape) +
|
||||
' not compatible with '
|
||||
'provided weight shape ' + str(w.shape))
|
||||
raise ValueError('Layer weight shape ' +
|
||||
str(pv.shape) +
|
||||
' not compatible with '
|
||||
'provided weight shape ' + str(w.shape))
|
||||
weight_value_tuples.append((p, w))
|
||||
K.batch_set_value(weight_value_tuples)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user