Add shareable BN (per-datastream updates).
This commit is contained in:
parent
c4c4fac1ae
commit
771010f43b
@ -282,9 +282,6 @@ class Layer(object):
|
||||
if not hasattr(self, 'uses_learning_phase'):
|
||||
self.uses_learning_phase = False
|
||||
|
||||
# Per-input updates.
|
||||
self._per_input_updates = {}
|
||||
|
||||
# These lists will be filled via successive calls
|
||||
# to self.add_inbound_node().
|
||||
self.inbound_nodes = []
|
||||
@ -806,8 +803,13 @@ class Layer(object):
|
||||
# Update self.updates
|
||||
if not hasattr(self, 'updates'):
|
||||
self.updates = []
|
||||
self.updates += updates
|
||||
try:
|
||||
self.updates += updates
|
||||
except AttributeError:
|
||||
pass
|
||||
# Update self._per_input_updates
|
||||
if not hasattr(self, '_per_input_updates'):
|
||||
self._per_input_updates = {}
|
||||
inputs = to_list(inputs)
|
||||
updates = to_list(updates)
|
||||
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
|
||||
@ -816,6 +818,8 @@ class Layer(object):
|
||||
self._per_input_updates[inputs_hash] += updates
|
||||
|
||||
def get_updates_for(self, inputs):
|
||||
if not hasattr(self, '_per_input_updates'):
|
||||
return []
|
||||
inputs = to_list(inputs)
|
||||
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
|
||||
if inputs_hash in self._per_input_updates:
|
||||
@ -1905,12 +1909,6 @@ class Container(Layer):
|
||||
updates += layer.get_updates_for(inputs)
|
||||
return updates
|
||||
|
||||
def get_updates_for(self, inputs):
    """Return the container-level updates.

    A model cannot track updates for a specific set of inputs
    (only atomic layers can), so `inputs` is ignored and the
    full list of model updates is returned.
    """
    model_updates = self.updates
    return model_updates
|
||||
|
||||
@property
def stateful(self):
    """Whether any layer in the container is stateful."""
    for layer in self.layers:
        # A layer counts only if it declares the attribute
        # and the attribute is truthy.
        if hasattr(layer, 'stateful') and layer.stateful:
            return True
    return False
|
||||
@ -2198,6 +2196,10 @@ class Container(Layer):
|
||||
output_tensors = to_list(layer.call(computed_tensors, computed_masks))
|
||||
output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
|
||||
|
||||
# update model updates
|
||||
layer_inputs = [x[0] for x in computed_data]
|
||||
self.add_updates(layer.get_updates_for(layer_inputs), inputs)
|
||||
|
||||
# Update _keras_shape.
|
||||
if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
|
||||
if len(computed_tensors) == 1:
|
||||
|
@ -2,10 +2,10 @@ import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from keras.layers.core import Dense, Activation
|
||||
from keras.layers import Dense, Activation, Input
|
||||
from keras.utils.test_utils import layer_test, keras_test
|
||||
from keras.layers import normalization
|
||||
from keras.models import Sequential
|
||||
from keras.models import Sequential, Model
|
||||
from keras import backend as K
|
||||
|
||||
input_1 = np.arange(10)
|
||||
@ -78,5 +78,33 @@ def test_batchnorm_mode_1():
|
||||
assert_allclose(K.eval(K.std(out)), 0.0, atol=1e-1)
|
||||
|
||||
|
||||
@keras_test
def test_shared_batchnorm():
    '''Test that a BN layer can be shared
    across different data streams.
    '''
    # Test single layer reuse: call the same BN instance on two inputs.
    bn = normalization.BatchNormalization(input_shape=(10,), mode=0)
    x1 = Input(shape=(10,))
    bn(x1)

    x2 = Input(shape=(10,))
    y2 = bn(x2)

    x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
    model = Model(x2, y2)
    # The model should only pick up the updates generated for x2
    # (one for the running mean, one for the running std).
    assert len(model.updates) == 2
    model.compile('sgd', 'mse')
    model.train_on_batch(x, x)

    # Test model-level reuse: calling a Model on a new input should
    # propagate the inner BN updates to the outer model.
    x3 = Input(shape=(10,))
    y3 = model(x3)
    new_model = Model(x3, y3)
    # Fix: assert on new_model — the object under test here — rather
    # than re-asserting on model, which was already checked above and
    # left the model-level-reuse path unverified.
    assert len(new_model.updates) == 2
    new_model.compile('sgd', 'mse')
    new_model.train_on_batch(x, x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this test module directly.
    pytest.main([__file__])
|
||||
|
Loading…
Reference in New Issue
Block a user