Add shareable BatchNormalization (BN) layers (per-datastream updates).
This commit is contained in:
parent
c4c4fac1ae
commit
771010f43b
@ -282,9 +282,6 @@ class Layer(object):
|
|||||||
if not hasattr(self, 'uses_learning_phase'):
|
if not hasattr(self, 'uses_learning_phase'):
|
||||||
self.uses_learning_phase = False
|
self.uses_learning_phase = False
|
||||||
|
|
||||||
# Per-input updates.
|
|
||||||
self._per_input_updates = {}
|
|
||||||
|
|
||||||
# These lists will be filled via successive calls
|
# These lists will be filled via successive calls
|
||||||
# to self.add_inbound_node().
|
# to self.add_inbound_node().
|
||||||
self.inbound_nodes = []
|
self.inbound_nodes = []
|
||||||
@ -806,8 +803,13 @@ class Layer(object):
|
|||||||
# Update self.updates
|
# Update self.updates
|
||||||
if not hasattr(self, 'updates'):
|
if not hasattr(self, 'updates'):
|
||||||
self.updates = []
|
self.updates = []
|
||||||
self.updates += updates
|
try:
|
||||||
|
self.updates += updates
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
# Update self._per_input_updates
|
# Update self._per_input_updates
|
||||||
|
if not hasattr(self, '_per_input_updates'):
|
||||||
|
self._per_input_updates = {}
|
||||||
inputs = to_list(inputs)
|
inputs = to_list(inputs)
|
||||||
updates = to_list(updates)
|
updates = to_list(updates)
|
||||||
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
|
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
|
||||||
@ -816,6 +818,8 @@ class Layer(object):
|
|||||||
self._per_input_updates[inputs_hash] += updates
|
self._per_input_updates[inputs_hash] += updates
|
||||||
|
|
||||||
def get_updates_for(self, inputs):
|
def get_updates_for(self, inputs):
|
||||||
|
if not hasattr(self, '_per_input_updates'):
|
||||||
|
return []
|
||||||
inputs = to_list(inputs)
|
inputs = to_list(inputs)
|
||||||
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
|
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
|
||||||
if inputs_hash in self._per_input_updates:
|
if inputs_hash in self._per_input_updates:
|
||||||
@ -1905,12 +1909,6 @@ class Container(Layer):
|
|||||||
updates += layer.get_updates_for(inputs)
|
updates += layer.get_updates_for(inputs)
|
||||||
return updates
|
return updates
|
||||||
|
|
||||||
def get_updates_for(self, inputs):
|
|
||||||
# In this case, returns model updates,
|
|
||||||
# since a model cannot have inputs-specific updates
|
|
||||||
# (only atomic layers can).
|
|
||||||
return self.updates
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stateful(self):
|
def stateful(self):
|
||||||
return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
|
return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
|
||||||
@ -2198,6 +2196,10 @@ class Container(Layer):
|
|||||||
output_tensors = to_list(layer.call(computed_tensors, computed_masks))
|
output_tensors = to_list(layer.call(computed_tensors, computed_masks))
|
||||||
output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
|
output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
|
||||||
|
|
||||||
|
# update model updates
|
||||||
|
layer_inputs = [x[0] for x in computed_data]
|
||||||
|
self.add_updates(layer.get_updates_for(layer_inputs), inputs)
|
||||||
|
|
||||||
# Update _keras_shape.
|
# Update _keras_shape.
|
||||||
if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
|
if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
|
||||||
if len(computed_tensors) == 1:
|
if len(computed_tensors) == 1:
|
||||||
|
@ -2,10 +2,10 @@ import pytest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.testing import assert_allclose
|
from numpy.testing import assert_allclose
|
||||||
|
|
||||||
from keras.layers.core import Dense, Activation
|
from keras.layers import Dense, Activation, Input
|
||||||
from keras.utils.test_utils import layer_test, keras_test
|
from keras.utils.test_utils import layer_test, keras_test
|
||||||
from keras.layers import normalization
|
from keras.layers import normalization
|
||||||
from keras.models import Sequential
|
from keras.models import Sequential, Model
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
|
|
||||||
input_1 = np.arange(10)
|
input_1 = np.arange(10)
|
||||||
@ -78,5 +78,33 @@ def test_batchnorm_mode_1():
|
|||||||
assert_allclose(K.eval(K.std(out)), 0.0, atol=1e-1)
|
assert_allclose(K.eval(K.std(out)), 0.0, atol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_test
|
||||||
|
def test_shared_batchnorm():
|
||||||
|
'''Test that a BN layer can be shared
|
||||||
|
across different data streams.
|
||||||
|
'''
|
||||||
|
# Test single layer reuse
|
||||||
|
bn = normalization.BatchNormalization(input_shape=(10,), mode=0)
|
||||||
|
x1 = Input(shape=(10,))
|
||||||
|
bn(x1)
|
||||||
|
|
||||||
|
x2 = Input(shape=(10,))
|
||||||
|
y2 = bn(x2)
|
||||||
|
|
||||||
|
x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
|
||||||
|
model = Model(x2, y2)
|
||||||
|
assert len(model.updates) == 2
|
||||||
|
model.compile('sgd', 'mse')
|
||||||
|
model.train_on_batch(x, x)
|
||||||
|
|
||||||
|
# Test model-level reuse
|
||||||
|
x3 = Input(shape=(10,))
|
||||||
|
y3 = model(x3)
|
||||||
|
new_model = Model(x3, y3)
|
||||||
|
assert len(model.updates) == 2
|
||||||
|
new_model.compile('sgd', 'mse')
|
||||||
|
new_model.train_on_batch(x, x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
Loading…
Reference in New Issue
Block a user