From 002a9d5d2b2c26fff63293e3007ede9ab7dee616 Mon Sep 17 00:00:00 2001 From: Makoto Matsuyama Date: Mon, 2 Nov 2015 19:37:00 -0800 Subject: [PATCH] Fixes and input checks for dot mode in merge. --- keras/layers/core.py | 50 ++++++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/keras/layers/core.py b/keras/layers/core.py index 7663ef421..1c52e45e7 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -289,7 +289,8 @@ class Merge(Layer): if mode in {'sum', 'mul', 'ave', 'cos'}: input_shapes = set([l.output_shape for l in layers]) if len(input_shapes) > 1: - raise Exception("Only layers of same output shape can be merged using " + mode + " mode") + raise Exception("Only layers of same output shape can be merged using " + mode + " mode. " + + "Layer shapes: %s" % ([l.output_shape for l in layers])) if mode in {'cos', 'dot'}: if len(layers) > 2: raise Exception(mode + " merge takes exactly 2 layers") @@ -303,13 +304,28 @@ class Merge(Layer): dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)] else: dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)] + if type(dot_axes) not in [list, tuple]: + raise Exception("Invalid type for dot_axes - should be a list.") + if len(dot_axes) != 2: + raise Exception("Invalid format for dot_axes - should contain two elements.") + if type(dot_axes[0]) not in [list, tuple] or type(dot_axes[1]) not in [list, tuple]: + raise Exception("Invalid format for dot_axes - list elements should have type 'list' or 'tuple'.") for i in range(len(dot_axes[0])): if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]: - raise Exception(" Dot incompatible layers can not be merged using dot mode") + raise Exception("Dimension incompatibility using dot mode: " + + "%s != %s. " % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) + + "Layer shapes: %s, %s" % (shape1, shape2)) elif mode == 'concat': - input_shapes = set([list(l.output_shape).pop(concat_axis) for l in layers]) + input_shapes = set() + for l in layers: + oshape = list(l.output_shape) + oshape.pop(concat_axis) + oshape = tuple(oshape) + input_shapes.add(oshape) if len(input_shapes) > 1: - raise Exception("'concat' mode can only merge layers with matching output shapes except for the concat axis") + raise Exception("'concat' mode can only merge layers with matching " + + "output shapes except for the concat axis. " + + "Layer shapes: %s" % ([l.output_shape for l in layers])) self.mode = mode self.concat_axis = concat_axis @@ -344,14 +360,17 @@ class Merge(Layer): elif self.mode == 'dot': shape1 = list(input_shapes[0]) shape2 = list(input_shapes[1]) - for i in self.dot_axes[0]: - shape1.pop(i) - for i in self.dot_axes[1]: - shape2.pop(i) - shape = shape1 + shape2[1:] - if len(shape) == 1: - shape.append(1) - return tuple(shape) + dot_axes = [] + for axes in self.dot_axes: + dot_axes.append([index-1 for index in axes]) + tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])), + np.zeros(tuple(shape2[1:])), + axes=dot_axes) + if len(tensordot_output.shape) == 0: + shape = (1,) + else: + shape = tensordot_output.shape + return (shape1[0],) + shape elif self.mode == 'cos': return tuple(input_shapes[0][0], 1) @@ -387,13 +406,13 @@ class Merge(Layer): l1 = self.layers[0].get_output(train) l2 = self.layers[1].get_output(train) output = T.batched_tensordot(l1, l2, self.dot_axes) - output = output.dimshuffle((0, 'x')) return output elif self.mode == 'cos': l1 = self.layers[0].get_output(train) l2 = self.layers[1].get_output(train) - output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2)/T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), sequences=[l1, l2], outputs_info=None) - output = output.dimshuffle((0, 'x')) + output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2) / T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), + sequences=[l1, l2], + outputs_info=None) return output else: raise Exception('Unknown merge mode') @@ -441,7 +460,6 @@ class Merge(Layer): return dict(list(base_config.items()) + list(config.items())) - class Dropout(MaskedLayer): ''' Hinton's dropout.