parent
0e18cb3efa
commit
f573a86b42
@ -927,7 +927,10 @@ class Layer(object):
|
||||
def get_updates_for(self, inputs):
|
||||
if not hasattr(self, '_per_input_updates'):
|
||||
return []
|
||||
inputs_hash = object_list_uid(inputs)
|
||||
if inputs is not None:
|
||||
inputs_hash = object_list_uid(inputs)
|
||||
else:
|
||||
inputs_hash = None
|
||||
if inputs_hash in self._per_input_updates:
|
||||
return self._per_input_updates[inputs_hash]
|
||||
return []
|
||||
@ -935,7 +938,10 @@ class Layer(object):
|
||||
def get_losses_for(self, inputs):
|
||||
if not hasattr(self, '_per_input_losses'):
|
||||
return []
|
||||
inputs_hash = object_list_uid(inputs)
|
||||
if inputs is not None:
|
||||
inputs_hash = object_list_uid(inputs)
|
||||
else:
|
||||
inputs_hash = None
|
||||
if inputs_hash in self._per_input_losses:
|
||||
return self._per_input_losses[inputs_hash]
|
||||
return []
|
||||
|
@ -9,6 +9,27 @@ from keras import backend as K
|
||||
from keras.models import model_from_json, model_from_yaml
|
||||
from keras.utils.test_utils import keras_test
|
||||
|
||||
@keras_test
|
||||
def test_get_updates_for():
|
||||
a = Input(shape=(2,))
|
||||
dense_layer = Dense(1)
|
||||
dense_layer.add_update(0, inputs=a)
|
||||
dense_layer.add_update(1, inputs=None)
|
||||
|
||||
assert dense_layer.get_updates_for(a) == [0]
|
||||
assert dense_layer.get_updates_for(None) == [1]
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_get_losses_for():
|
||||
a = Input(shape=(2,))
|
||||
dense_layer = Dense(1)
|
||||
dense_layer.add_loss(0, inputs=a)
|
||||
dense_layer.add_loss(1, inputs=None)
|
||||
|
||||
assert dense_layer.get_losses_for(a) == [0]
|
||||
assert dense_layer.get_losses_for(None) == [1]
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_trainable_weights():
|
||||
|
Loading…
Reference in New Issue
Block a user