Make BN shareable (not yet working)

commit c4c4fac1ae
parent 016d85c9e6
Author: Francois Chollet
Date: 2016-11-15 05:16:40 -08:00

4 changed files with 48 additions and 20 deletions

keras/engine/topology.py

@@ -282,6 +282,9 @@ class Layer(object):
         if not hasattr(self, 'uses_learning_phase'):
             self.uses_learning_phase = False
 
+        # Per-input updates.
+        self._per_input_updates = {}
+
         # These lists will be filled via successive calls
         # to self.add_inbound_node().
         self.inbound_nodes = []
@@ -799,6 +802,26 @@ class Layer(object):
                                 'ill-defined for the layer. ' +
                                 'Use `get_output_shape_at(node_index)` instead.')
 
+    def add_updates(self, updates, inputs):
+        # Update self.updates
+        if not hasattr(self, 'updates'):
+            self.updates = []
+        self.updates += updates
+        # Update self._per_input_updates
+        inputs = to_list(inputs)
+        updates = to_list(updates)
+        inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
+        if inputs_hash not in self._per_input_updates:
+            self._per_input_updates[inputs_hash] = []
+        self._per_input_updates[inputs_hash] += updates
+
+    def get_updates_for(self, inputs):
+        inputs = to_list(inputs)
+        inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
+        if inputs_hash in self._per_input_updates:
+            return self._per_input_updates[inputs_hash]
+        return []
+
     @property
     def weights(self):
         return self.trainable_weights + self.non_trainable_weights
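The scheme above keys each list of update ops on the Python ids of the input tensors that produced them, so a single layer instance can hold separate updates per call site. A minimal self-contained sketch of the same bookkeeping (plain objects stand in for backend tensors; `to_list` is re-implemented here, and `SharedLayer` is illustrative, not a Keras class):

def to_list(x):
    # Normalize a single element or a list to a list.
    return x if isinstance(x, list) else [x]


class SharedLayer(object):
    def __init__(self):
        self.updates = []
        self._per_input_updates = {}

    def add_updates(self, updates, inputs):
        updates = to_list(updates)
        self.updates += updates
        # Key the updates on the identity of the inputs that produced them.
        inputs_hash = ', '.join([str(abs(id(x))) for x in to_list(inputs)])
        self._per_input_updates.setdefault(inputs_hash, []).extend(updates)

    def get_updates_for(self, inputs):
        inputs_hash = ', '.join([str(abs(id(x))) for x in to_list(inputs)])
        return self._per_input_updates.get(inputs_hash, [])


a, b = object(), object()                # stand-ins for two input tensors
layer = SharedLayer()
layer.add_updates(['ema_update_a'], a)   # first call site
layer.add_updates(['ema_update_b'], b)   # second call site, same layer
assert layer.get_updates_for(a) == ['ema_update_a']
assert layer.get_updates_for(b) == ['ema_update_b']
assert layer.updates == ['ema_update_a', 'ema_update_b']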
@@ -1871,9 +1894,23 @@ class Container(Layer):
         updates = []
         for layer in self.layers:
             if hasattr(layer, 'updates'):
-                updates += layer.updates
+                if len(layer.inbound_nodes) == 1:
+                    updates += layer.updates
+                else:
+                    for node_index, node in enumerate(layer.inbound_nodes):
+                        node_key = layer.name + '_ib-' + str(node_index)
+                        if node_key in self.container_nodes:
+                            # The model owns this layer node.
+                            inputs = node.input_tensors
+                            updates += layer.get_updates_for(inputs)
         return updates
 
+    def get_updates_for(self, inputs):
+        # In this case, returns model updates,
+        # since a model cannot have inputs-specific updates
+        # (only atomic layers can).
+        return self.updates
+
     @property
     def stateful(self):
         return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
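When a layer is used only once, its `updates` list is unambiguous; once shared, that list mixes ops from every call site, including calls made outside this container. The key `layer.name + '_ib-' + str(node_index)` names each inbound node, and only keys present in `self.container_nodes` (recorded when the graph was traced) belong to the model. A hedged sketch of that filtering with stub objects (`StubNode` and `StubLayer` are illustrative, not Keras classes):

class StubNode(object):
    def __init__(self, input_tensors):
        self.input_tensors = input_tensors


class StubLayer(object):
    def __init__(self, name):
        self.name = name
        self.inbound_nodes = []
        self.updates = []
        self._per_input_updates = {}

    def simulate_call(self, x, update):
        # Each call registers an inbound node plus a per-input update.
        self.inbound_nodes.append(StubNode([x]))
        self.updates.append(update)
        self._per_input_updates.setdefault(id(x), []).append(update)

    def get_updates_for(self, inputs):
        return self._per_input_updates.get(id(inputs[0]), [])


bn = StubLayer('bn')
x_inside, x_outside = object(), object()
bn.simulate_call(x_inside, 'update_0')   # node 0: part of the model
bn.simulate_call(x_outside, 'update_1')  # node 1: some unrelated graph

container_nodes = {'bn_ib-0'}            # the model owns only node 0
collected = []
for node_index, node in enumerate(bn.inbound_nodes):
    node_key = bn.name + '_ib-' + str(node_index)
    if node_key in container_nodes:
        collected += bn.get_updates_for(node.input_tensors)
assert collected == ['update_0']         # the outside update is excluded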

keras/layers/normalization.py

@@ -104,7 +104,6 @@ class BatchNormalization(Layer):
             self.set_weights(self.initial_weights)
             del self.initial_weights
         self.built = True
-        self.called_with = None
 
     def call(self, x, mask=None):
         if self.mode == 0 or self.mode == 2:
@@ -122,23 +121,12 @@ class BatchNormalization(Layer):
                 epsilon=self.epsilon)
         else:
             # mode 0
-            if self.called_with not in {None, x}:
-                raise Exception('You are attempting to share a '
-                                'same `BatchNormalization` layer across '
-                                'different data flows. '
-                                'This is not possible. '
-                                'You should use `mode=2` in '
-                                '`BatchNormalization`, which has '
-                                'a similar behavior but is shareable '
-                                '(see docs for a description of '
-                                'the behavior).')
-            self.called_with = x
             x_normed, mean, std = K.normalize_batch_in_training(
                 x, self.gamma, self.beta, reduction_axes,
                 epsilon=self.epsilon)
 
-            self.updates = [K.moving_average_update(self.running_mean, mean, self.momentum),
-                            K.moving_average_update(self.running_std, std, self.momentum)]
+            self.add_updates([K.moving_average_update(self.running_mean, mean, self.momentum),
+                              K.moving_average_update(self.running_std, std, self.momentum)], x)
 
             if K.backend() == 'tensorflow' and sorted(reduction_axes) == range(K.ndim(x))[:-1]:
                 x_normed_running = K.batch_normalization(
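With the moving-average updates registered through `add_updates`, the guard against sharing a mode-0 layer can be dropped: each call site gets its own pair of `moving_average_update` ops, keyed on its input. The intended usage, sketched against the Keras 1.x functional API of this era (the commit title warns this path was not yet working, so treat this as the target behavior rather than verified output):

from keras.layers import Input, BatchNormalization
from keras.models import Model

x1 = Input(shape=(16,))
x2 = Input(shape=(16,))
bn = BatchNormalization(mode=0)  # a single layer instance...
y1 = bn(x1)                      # ...shared across two data flows,
y2 = bn(x2)                      # which previously raised an Exception

model = Model(input=[x1, x2], output=[y1, y2])
# Each call site registered its own moving-average update ops:
print(bn.get_updates_for([x1]))
print(bn.get_updates_for([x2]))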

keras/layers/recurrent.py

@@ -226,9 +226,10 @@ class Recurrent(Layer):
                                              unroll=self.unroll,
                                              input_length=input_shape[1])
         if self.stateful:
-            self.updates = []
+            updates = []
             for i in range(len(states)):
-                self.updates.append((self.states[i], states[i]))
+                updates.append((self.states[i], states[i]))
+            self.add_updates(updates, x)
 
         if self.return_sequences:
             return outputs
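The updates here are `(variable, new_value)` tuples rather than backend ops; the training function treats such a pair as an assignment of the freshly computed state back into the persistent state variable. Routing them through `add_updates(updates, x)` ties the state writes to the specific input `x`, matching the per-input scheme above. A toy illustration of the pair format (plain strings stand in for backend variables and tensors):

state_variables = ['h_var', 'c_var']   # persistent, e.g. an LSTM's h and c
new_states = ['h_out', 'c_out']        # returned by the backend rnn loop

updates = []
for i in range(len(new_states)):
    updates.append((state_variables[i], new_states[i]))

assert updates == [('h_var', 'h_out'), ('c_var', 'c_out')]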

keras/models.py

@@ -473,13 +473,15 @@ class Sequential(Model):
     @property
     def updates(self):
         # support for legacy behavior
-        return self._gather_list_attr('updates')
+        return self.model.updates
 
     @property
     def state_updates(self):
         # support for legacy behavior
-        return self._gather_list_attr('state_updates')
+        return self.model.state_updates
+
+    def get_updates_for(self, inputs):
+        return self.model.get_updates_for(inputs)
 
     @property
     def regularizers(self):
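A `Sequential` builds an inner functional `Model` (stored as `self.model`) on first use, so forwarding these properties to it replaces the old `_gather_list_attr` scan and keeps a single source of truth, including the new per-input lookup. The delegation pattern in isolation (a hedged stand-in, not the actual Sequential class):

class DelegatingWrapper(object):
    # Illustrative: forwards all update queries to the wrapped inner model.
    def __init__(self, inner_model):
        self.model = inner_model

    @property
    def updates(self):
        return self.model.updates

    @property
    def state_updates(self):
        return self.model.state_updates

    def get_updates_for(self, inputs):
        return self.model.get_updates_for(inputs)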