diff --git a/keras/layers/core.py b/keras/layers/core.py index af55af4d0..bb04e9237 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -161,11 +161,42 @@ class Masking(MaskedLayer): return {"name": self.__class__.__name__, "mask_value": self.mask_value} +class TimeDistributedMerge(Layer): + def __init__(self, mode='sum'): + ''' + Sum/multiply/avearge over a time distributed layer's outputs. + mode: {'sum', 'mul', 'ave'} + Tensor input dimensions: (nb_sample, shared_dimension, input_dim) + Tensor output dimensions: (nb_sample, output_dim) + ''' + self.mode = mode + self.params = [] + self.regularizers = [] + self.constraints = [] + self.updates = [] + + def get_output(self, train=False): + X = self.get_input(train) + if self.mode == 'sum' or self.mode == 'ave': + s = theano.tensor.sum(X, axis=1) + if self.mode == 'ave': + s /= X.shape[1] + return s + elif self.mode == 'mul': + s = theano.tensor.mul(X, axis=1) + return s + else: + raise Exception('Unknown merge mode') + + def get_config(self): + return {"name": self.__class__.__name__, + "mode": self.mode} + class Merge(Layer): def __init__(self, layers, mode='sum', concat_axis=-1): ''' Merge the output of a list of layers or containers into a single tensor. - mode: {'sum', 'mul', 'concat'} + mode: {'sum', 'mul', 'concat', 'ave'} ''' if len(layers) < 2: raise Exception("Please specify two or more input layers (or containers) to merge") @@ -190,10 +221,12 @@ class Merge(Layer): return self.params, self.regularizers, self.constraints, self.updates def get_output(self, train=False): - if self.mode == 'sum': + if self.mode == 'sum' or self.mode == 'ave': s = self.layers[0].get_output(train) for i in range(1, len(self.layers)): s += self.layers[i].get_output(train) + if self.mode == 'ave': + s /= len(self.layers) return s elif self.mode == 'concat': inputs = [self.layers[i].get_output(train) for i in range(len(self.layers))] diff --git a/tests/auto/keras/layers/test_core.py b/tests/auto/keras/layers/test_core.py index e8034c55c..705a8f7cc 100644 --- a/tests/auto/keras/layers/test_core.py +++ b/tests/auto/keras/layers/test_core.py @@ -104,6 +104,10 @@ class TestConfigParams(unittest.TestCase): layer = core.TimeDistributedDense(10, 10) self._runner(layer) + def test_time_dist_merge(self): + layer = core.TimeDistributedMerge() + self._runner(layer) + def test_autoencoder(self): layer_1 = core.Layer() layer_2 = core.Layer()