Add shareable BN (per-datastream updates).

This commit is contained in:
Francois Chollet 2016-11-16 19:06:46 -08:00
parent c4c4fac1ae
commit 771010f43b
2 changed files with 42 additions and 12 deletions

@@ -282,9 +282,6 @@ class Layer(object):
if not hasattr(self, 'uses_learning_phase'):
self.uses_learning_phase = False
# Per-input updates.
self._per_input_updates = {}
# These lists will be filled via successive calls
# to self.add_inbound_node().
self.inbound_nodes = []
@@ -806,8 +803,13 @@ class Layer(object):
# Update self.updates
if not hasattr(self, 'updates'):
self.updates = []
try:
self.updates += updates
except AttributeError:
pass
# Update self._per_input_updates
if not hasattr(self, '_per_input_updates'):
self._per_input_updates = {}
inputs = to_list(inputs)
updates = to_list(updates)
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
@@ -816,6 +818,8 @@ class Layer(object):
self._per_input_updates[inputs_hash] += updates
def get_updates_for(self, inputs):
if not hasattr(self, '_per_input_updates'):
return []
inputs = to_list(inputs)
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
if inputs_hash in self._per_input_updates:
@@ -1905,12 +1909,6 @@ class Container(Layer):
updates += layer.get_updates_for(inputs)
return updates
def get_updates_for(self, inputs):
    """Return the model's full list of updates.

    Unlike atomic layers, a `Container` does not track updates per
    input tensor, so `inputs` is ignored and every update owned by
    the model is returned.
    """
    return self.updates
@property
def stateful(self):
    """Whether any layer in the model carries internal state.

    A model is stateful if at least one of its layers declares
    `stateful = True` (e.g. stateful RNN layers). Layers without a
    `stateful` attribute are treated as stateless.
    """
    # getattr with a False default replaces the hasattr-then-access
    # pair, and the generator avoids materializing a throwaway list
    # inside any().
    return any(getattr(layer, 'stateful', False) for layer in self.layers)
@@ -2198,6 +2196,10 @@ class Container(Layer):
output_tensors = to_list(layer.call(computed_tensors, computed_masks))
output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
# update model updates
layer_inputs = [x[0] for x in computed_data]
self.add_updates(layer.get_updates_for(layer_inputs), inputs)
# Update _keras_shape.
if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
if len(computed_tensors) == 1:

@@ -2,10 +2,10 @@ import pytest
import numpy as np
from numpy.testing import assert_allclose
from keras.layers.core import Dense, Activation
from keras.layers import Dense, Activation, Input
from keras.utils.test_utils import layer_test, keras_test
from keras.layers import normalization
from keras.models import Sequential
from keras.models import Sequential, Model
from keras import backend as K
input_1 = np.arange(10)
@@ -78,5 +78,33 @@ def test_batchnorm_mode_1():
assert_allclose(K.eval(K.std(out)), 0.0, atol=1e-1)
@keras_test
def test_shared_batchnorm():
    '''Test that a BN layer can be shared
    across different data streams.
    '''
    # Test single layer reuse: the same BN layer is applied to two
    # distinct inputs; only the updates belonging to the stream the
    # model actually consumes (x2) should be collected.
    bn = normalization.BatchNormalization(input_shape=(10,), mode=0)
    x1 = Input(shape=(10,))
    bn(x1)

    x2 = Input(shape=(10,))
    y2 = bn(x2)

    x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
    model = Model(x2, y2)
    # Two updates: the moving mean and the moving variance of `bn`.
    assert len(model.updates) == 2
    model.compile('sgd', 'mse')
    model.train_on_batch(x, x)

    # Test model-level reuse: calling the whole model on a fresh input
    # should propagate the BN updates to the outer model.
    x3 = Input(shape=(10,))
    y3 = model(x3)
    new_model = Model(x3, y3)
    # BUGFIX: assert on `new_model.updates`, not `model.updates` —
    # re-checking `model` here would not exercise model-level reuse.
    assert len(new_model.updates) == 2
    new_model.compile('sgd', 'mse')
    new_model.train_on_batch(x, x)
# Allow running this test module directly (outside a pytest invocation).
if __name__ == '__main__':
    pytest.main([__file__])