diff --git a/keras/layers/core.py b/keras/layers/core.py index 5131f52fa..726b76caa 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -289,7 +289,8 @@ class Merge(Layer): if mode in {'sum', 'mul', 'ave', 'cos'}: input_shapes = set([l.output_shape for l in layers]) if len(input_shapes) > 1: - raise Exception("Only layers of same output shape can be merged using " + mode + " mode") + raise Exception("Only layers of same output shape can be merged using " + mode + " mode. " + + "Layer shapes: %s" % ([l.output_shape for l in layers])) if mode in {'cos', 'dot'}: if len(layers) > 2: raise Exception(mode + " merge takes exactly 2 layers") @@ -303,9 +304,17 @@ class Merge(Layer): dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)] else: dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)] + if type(dot_axes) not in [list, tuple]: + raise Exception("Invalid type for dot_axes - should be a list.") + if len(dot_axes) != 2: + raise Exception("Invalid format for dot_axes - should contain two elements.") + if type(dot_axes[0]) not in [list, tuple, range] or type(dot_axes[1]) not in [list, tuple, range]: + raise Exception("Invalid format for dot_axes - list elements should have type 'list' or 'tuple'.") for i in range(len(dot_axes[0])): if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]: - raise Exception(" Dot incompatible layers can not be merged using dot mode") + raise Exception("Dimension incompatibility using dot mode: " + + "%s != %s. " % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) + + "Layer shapes: %s, %s" % (shape1, shape2)) elif mode == 'concat': input_shapes = set() for l in layers: @@ -314,7 +323,9 @@ class Merge(Layer): oshape = tuple(oshape) input_shapes.add(oshape) if len(input_shapes) > 1: - raise Exception("'concat' mode can only merge layers with matching output shapes except for the concat axis") + raise Exception("'concat' mode can only merge layers with matching " + + "output shapes except for the concat axis. " + + "Layer shapes: %s" % ([l.output_shape for l in layers])) self.mode = mode self.concat_axis = concat_axis @@ -349,14 +360,17 @@ class Merge(Layer): elif self.mode == 'dot': shape1 = list(input_shapes[0]) shape2 = list(input_shapes[1]) - for i in self.dot_axes[0]: - shape1.pop(i) - for i in self.dot_axes[1]: - shape2.pop(i) - shape = shape1 + shape2[1:] - if len(shape) == 1: - shape.append(1) - return tuple(shape) + dot_axes = [] + for axes in self.dot_axes: + dot_axes.append([index-1 for index in axes]) + tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])), + np.zeros(tuple(shape2[1:])), + axes=dot_axes) + if len(tensordot_output.shape) == 0: + shape = (1,) + else: + shape = tensordot_output.shape + return (shape1[0],) + shape elif self.mode == 'cos': return tuple(input_shapes[0][0], 1) @@ -392,13 +406,13 @@ class Merge(Layer): l1 = self.layers[0].get_output(train) l2 = self.layers[1].get_output(train) output = T.batched_tensordot(l1, l2, self.dot_axes) - output = output.dimshuffle((0, 'x')) return output elif self.mode == 'cos': l1 = self.layers[0].get_output(train) l2 = self.layers[1].get_output(train) - output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2)/T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), sequences=[l1, l2], outputs_info=None) - output = output.dimshuffle((0, 'x')) + output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2) / T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), + sequences=[l1, l2], + outputs_info=None) return output else: raise Exception('Unknown merge mode') @@ -446,7 +460,6 @@ class Merge(Layer): return dict(list(base_config.items()) + list(config.items())) - class Dropout(MaskedLayer): ''' Hinton's dropout. diff --git a/tests/auto/test_sequential_model.py b/tests/auto/test_sequential_model.py index 270ce15ac..db76ddf23 100644 --- a/tests/auto/test_sequential_model.py +++ b/tests/auto/test_sequential_model.py @@ -128,7 +128,7 @@ class TestSequential(unittest.TestCase): nloss = model.evaluate([X_train, X_train], y_train, verbose=0) print(nloss) assert(loss == nloss) - + def test_merge_dot1(self): print('Test merge: dot') left = Sequential() @@ -146,7 +146,7 @@ class TestSequential(unittest.TestCase): model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='rmsprop') - + def test_merge_dot2(self): print('Test merge: dot') left = Sequential() @@ -158,12 +158,13 @@ class TestSequential(unittest.TestCase): right.add(Activation('relu')) model = Sequential() - model.add(Merge([left, right], mode='dot', dot_axes=([1],[1]))) + model.add(Merge([left, right], mode='dot', dot_axes=([1], [1]))) model.add(Dense(nb_class)) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + def test_merge_concat(self): print('Test merge: concat') left = Sequential()