Make BatchNormalization shareable across data flows (not yet working)
This commit is contained in:
parent
016d85c9e6
commit
c4c4fac1ae
@ -282,6 +282,9 @@ class Layer(object):
|
||||
if not hasattr(self, 'uses_learning_phase'):
|
||||
self.uses_learning_phase = False
|
||||
|
||||
# Per-input updates.
|
||||
self._per_input_updates = {}
|
||||
|
||||
# These lists will be filled via successive calls
|
||||
# to self.add_inbound_node().
|
||||
self.inbound_nodes = []
|
||||
@ -799,6 +802,26 @@ class Layer(object):
|
||||
'ill-defined for the layer. ' +
|
||||
'Use `get_output_shape_at(node_index)` instead.')
|
||||
|
||||
def add_updates(self, updates, inputs):
    """Register update ops produced by this layer for a given set of inputs.

    # Arguments
        updates: update op or list of update ops.
        inputs: input tensor or list of input tensors that
            the updates are conditioned on.
    """
    # Normalize to lists FIRST: the original extended `self.updates`
    # before coercion, so a single (non-list) update op would be
    # iterated/concatenated incorrectly.
    updates = to_list(updates)
    inputs = to_list(inputs)
    # Update self.updates (lazily created for layers built before
    # this attribute existed).
    if not hasattr(self, 'updates'):
        self.updates = []
    self.updates += updates
    # Update self._per_input_updates, keyed by the identity of the
    # input tensors (abs(id(x)) keeps the key stable for this process).
    if not hasattr(self, '_per_input_updates'):
        self._per_input_updates = {}
    inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
    self._per_input_updates.setdefault(inputs_hash, []).extend(updates)
|
||||
|
||||
def get_updates_for(self, inputs):
    """Return the update ops previously registered for `inputs`.

    # Arguments
        inputs: input tensor or list of input tensors.

    # Returns
        List of update ops (empty list if none were registered
        for these inputs).
    """
    inputs = to_list(inputs)
    # Same identity-based key as `add_updates`.
    inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
    # Single lookup with a default instead of an `in` test + indexing.
    return self._per_input_updates.get(inputs_hash, [])
|
||||
|
||||
@property
def weights(self):
    """All weights of the layer: trainable followed by non-trainable."""
    all_weights = list(self.trainable_weights)
    all_weights.extend(self.non_trainable_weights)
    return all_weights
|
||||
@ -1871,9 +1894,23 @@ class Container(Layer):
|
||||
updates = []
|
||||
for layer in self.layers:
|
||||
if hasattr(layer, 'updates'):
|
||||
updates += layer.updates
|
||||
if len(layer.inbound_nodes) == 1:
|
||||
updates += layer.updates
|
||||
else:
|
||||
for node_index, node in enumerate(layer.inbound_nodes):
|
||||
node_key = layer.name + '_ib-' + str(node_index)
|
||||
if node_key in self.container_nodes:
|
||||
# The model owns this layer node.
|
||||
inputs = node.input_tensors
|
||||
updates += layer.get_updates_for(inputs)
|
||||
return updates
|
||||
|
||||
def get_updates_for(self, inputs):
    """Return the model-wide updates, regardless of `inputs`.

    A container aggregates the updates of all of its layers; only
    atomic layers track input-specific updates, so the full update
    list is returned here.
    """
    return self.updates
|
||||
|
||||
@property
def stateful(self):
    """Whether the container holds any stateful layer.

    Uses `getattr` with a `False` default (equivalent to the
    `hasattr(...) and layer.stateful` check) and a generator
    expression so `any` can short-circuit without building a list.
    """
    return any(getattr(layer, 'stateful', False) for layer in self.layers)
|
||||
|
@ -104,7 +104,6 @@ class BatchNormalization(Layer):
|
||||
self.set_weights(self.initial_weights)
|
||||
del self.initial_weights
|
||||
self.built = True
|
||||
self.called_with = None
|
||||
|
||||
def call(self, x, mask=None):
|
||||
if self.mode == 0 or self.mode == 2:
|
||||
@ -122,23 +121,12 @@ class BatchNormalization(Layer):
|
||||
epsilon=self.epsilon)
|
||||
else:
|
||||
# mode 0
|
||||
if self.called_with not in {None, x}:
|
||||
raise Exception('You are attempting to share a '
|
||||
'same `BatchNormalization` layer across '
|
||||
'different data flows. '
|
||||
'This is not possible. '
|
||||
'You should use `mode=2` in '
|
||||
'`BatchNormalization`, which has '
|
||||
'a similar behavior but is shareable '
|
||||
'(see docs for a description of '
|
||||
'the behavior).')
|
||||
self.called_with = x
|
||||
x_normed, mean, std = K.normalize_batch_in_training(
|
||||
x, self.gamma, self.beta, reduction_axes,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
self.updates = [K.moving_average_update(self.running_mean, mean, self.momentum),
|
||||
K.moving_average_update(self.running_std, std, self.momentum)]
|
||||
self.add_updates([K.moving_average_update(self.running_mean, mean, self.momentum),
|
||||
K.moving_average_update(self.running_std, std, self.momentum)], x)
|
||||
|
||||
if K.backend() == 'tensorflow' and sorted(reduction_axes) == range(K.ndim(x))[:-1]:
|
||||
x_normed_running = K.batch_normalization(
|
||||
|
@ -226,9 +226,10 @@ class Recurrent(Layer):
|
||||
unroll=self.unroll,
|
||||
input_length=input_shape[1])
|
||||
if self.stateful:
|
||||
self.updates = []
|
||||
updates = []
|
||||
for i in range(len(states)):
|
||||
self.updates.append((self.states[i], states[i]))
|
||||
updates.append((self.states[i], states[i]))
|
||||
self.add_updates(updates, x)
|
||||
|
||||
if self.return_sequences:
|
||||
return outputs
|
||||
|
@ -473,13 +473,15 @@ class Sequential(Model):
|
||||
|
||||
@property
def updates(self):
    """Update ops of the model (legacy aggregation behavior).

    NOTE(review): the original body contained a second, unreachable
    `return self.model.updates` after this one (dead code removed here);
    confirm which return the surrounding commit intended to keep.
    """
    # support for legacy behavior
    return self._gather_list_attr('updates')
|
||||
|
||||
@property
def state_updates(self):
    """State update ops of the model (legacy aggregation behavior).

    NOTE(review): the original body contained a second, unreachable
    `return self.model.state_updates` after this one (dead code removed
    here); confirm which return the surrounding commit intended to keep.
    """
    # support for legacy behavior
    return self._gather_list_attr('state_updates')
|
||||
|
||||
def get_updates_for(self, inputs):
    """Delegate input-conditional update lookup to the inner model."""
    inner_model = self.model
    return inner_model.get_updates_for(inputs)
|
||||
|
||||
@property
|
||||
def regularizers(self):
|
||||
|
Loading…
Reference in New Issue
Block a user