Further code cleanup

This commit is contained in:
Francois Chollet 2016-11-07 17:27:41 -08:00
parent c95c32e473
commit d32b8fa4bd

@ -236,6 +236,8 @@ class Layer(object):
input_mask, output_mask: Same as above, for masks.
trainable_weights: List of variables.
non_trainable_weights: List of variables.
weights: The concatenation of the lists trainable_weights and
non_trainable_weights (in this order).
regularizers: List of regularizers.
constraints: Dict mapping weights to constraints.
@ -872,20 +874,20 @@ class Layer(object):
'''
params = self.weights
if len(params) != len(weights):
raise Exception('You called `set_weights(weights)` on layer "' + self.name +
'" with a weight list of length ' + str(len(weights)) +
', but the layer was expecting ' + str(len(params)) +
' weights. Provided weights: ' + str(weights)[:50] + '...')
raise ValueError('You called `set_weights(weights)` on layer "' + self.name +
'" with a weight list of length ' + str(len(weights)) +
', but the layer was expecting ' + str(len(params)) +
' weights. Provided weights: ' + str(weights)[:50] + '...')
if not params:
return
weight_value_tuples = []
param_values = K.batch_get_value(params)
for pv, p, w in zip(param_values, params, weights):
if pv.shape != w.shape:
raise Exception('Layer weight shape ' +
str(pv.shape) +
' not compatible with '
'provided weight shape ' + str(w.shape))
raise ValueError('Layer weight shape ' +
str(pv.shape) +
' not compatible with '
'provided weight shape ' + str(w.shape))
weight_value_tuples.append((p, w))
K.batch_set_value(weight_value_tuples)