This commit is contained in:
Stanislaw Jastrzebski 2016-12-30 10:19:01 +01:00 committed by François Chollet
parent 0e18cb3efa
commit f573a86b42
2 changed files with 29 additions and 2 deletions

@ -927,7 +927,10 @@ class Layer(object):
def get_updates_for(self, inputs):
if not hasattr(self, '_per_input_updates'):
return []
inputs_hash = object_list_uid(inputs)
if inputs is not None:
inputs_hash = object_list_uid(inputs)
else:
inputs_hash = None
if inputs_hash in self._per_input_updates:
return self._per_input_updates[inputs_hash]
return []
@ -935,7 +938,10 @@ class Layer(object):
def get_losses_for(self, inputs):
if not hasattr(self, '_per_input_losses'):
return []
inputs_hash = object_list_uid(inputs)
if inputs is not None:
inputs_hash = object_list_uid(inputs)
else:
inputs_hash = None
if inputs_hash in self._per_input_losses:
return self._per_input_losses[inputs_hash]
return []

@ -9,6 +9,27 @@ from keras import backend as K
from keras.models import model_from_json, model_from_yaml
from keras.utils.test_utils import keras_test
@keras_test
def test_get_updates_for():
a = Input(shape=(2,))
dense_layer = Dense(1)
dense_layer.add_update(0, inputs=a)
dense_layer.add_update(1, inputs=None)
assert dense_layer.get_updates_for(a) == [0]
assert dense_layer.get_updates_for(None) == [1]
@keras_test
def test_get_losses_for():
a = Input(shape=(2,))
dense_layer = Dense(1)
dense_layer.add_loss(0, inputs=a)
dense_layer.add_loss(1, inputs=None)
assert dense_layer.get_losses_for(a) == [0]
assert dense_layer.get_losses_for(None) == [1]
@keras_test
def test_trainable_weights():