Merge pull request #942 from matsuyamax/master

Fixes and input checks for dot mode in merge.
François Chollet 2015-11-02 20:45:51 -08:00
commit fe14a845ab
2 changed files with 32 additions and 18 deletions

@@ -289,7 +289,8 @@ class Merge(Layer):
         if mode in {'sum', 'mul', 'ave', 'cos'}:
             input_shapes = set([l.output_shape for l in layers])
             if len(input_shapes) > 1:
-                raise Exception("Only layers of same output shape can be merged using " + mode + " mode")
+                raise Exception("Only layers of same output shape can be merged using " + mode + " mode. " +
+                                "Layer shapes: %s" % ([l.output_shape for l in layers]))
         if mode in {'cos', 'dot'}:
             if len(layers) > 2:
                 raise Exception(mode + " merge takes exactly 2 layers")
@@ -303,9 +304,17 @@ class Merge(Layer):
                         dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
                     else:
                         dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
+                if type(dot_axes) not in [list, tuple]:
+                    raise Exception("Invalid type for dot_axes - should be a list.")
+                if len(dot_axes) != 2:
+                    raise Exception("Invalid format for dot_axes - should contain two elements.")
+                if type(dot_axes[0]) not in [list, tuple, range] or type(dot_axes[1]) not in [list, tuple, range]:
+                    raise Exception("Invalid format for dot_axes - list elements should have type 'list' or 'tuple'.")
                 for i in range(len(dot_axes[0])):
                     if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
-                        raise Exception(" Dot incompatible layers can not be merged using dot mode")
+                        raise Exception("Dimension incompatibility using dot mode: " +
+                                        "%s != %s. " % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) +
+                                        "Layer shapes: %s, %s" % (shape1, shape2))
         elif mode == 'concat':
             input_shapes = set()
             for l in layers:
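
Note: the checks above normalize an integer dot_axes into two axis lists before validating them. As a reference, a small standalone sketch of that behaviour (the helper name is invented for illustration and is not part of the patch); shapes are assumed to be tuples with the batch axis first:

    def normalize_dot_axes(dot_axes, shape1, shape2):
        # illustrative helper, not part of Keras: mirrors the normalization above,
        # where a single int selects the trailing axes of both layers
        n1, n2 = len(shape1), len(shape2)
        if type(dot_axes) == int:
            if dot_axes < 0:
                dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
            else:
                dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
        if type(dot_axes) not in [list, tuple] or len(dot_axes) != 2:
            raise Exception("dot_axes should be a list or tuple of two axis lists")
        # mirrors the compatibility check above: paired axes must agree in size
        for a1, a2 in zip(dot_axes[0], dot_axes[1]):
            if shape1[a1] != shape2[a2]:
                raise Exception("Incompatible dimensions: %s != %s" % (shape1[a1], shape2[a2]))
        return [list(axes) for axes in dot_axes]

    # two (None, 32) outputs contracted over axis 1:
    # normalize_dot_axes(1, (None, 32), (None, 32)) -> [[1], [1]]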
@@ -314,7 +323,9 @@ class Merge(Layer):
                 oshape = tuple(oshape)
                 input_shapes.add(oshape)
             if len(input_shapes) > 1:
-                raise Exception("'concat' mode can only merge layers with matching output shapes except for the concat axis")
+                raise Exception("'concat' mode can only merge layers with matching " +
+                                "output shapes except for the concat axis. " +
+                                "Layer shapes: %s" % ([l.output_shape for l in layers]))
 
         self.mode = mode
         self.concat_axis = concat_axis
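
Note: for 'concat', shapes only have to match once the concat axis is ignored. A standalone sketch of that check (helper name invented, not part of the patch):

    def check_concat_shapes(output_shapes, concat_axis=-1):
        # illustrative helper, not part of Keras: shapes must be identical
        # after dropping the concatenation axis
        stripped = set()
        for oshape in output_shapes:
            oshape = list(oshape)
            oshape.pop(concat_axis)
            stripped.add(tuple(oshape))
        if len(stripped) > 1:
            raise Exception("'concat' mode can only merge layers with matching "
                            "output shapes except for the concat axis. "
                            "Layer shapes: %s" % (output_shapes,))

    check_concat_shapes([(None, 10, 3), (None, 10, 5)])    # OK: concatenable along axis -1
    # check_concat_shapes([(None, 10, 3), (None, 8, 5)])   # raises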
@@ -349,14 +360,17 @@ class Merge(Layer):
         elif self.mode == 'dot':
             shape1 = list(input_shapes[0])
             shape2 = list(input_shapes[1])
-            for i in self.dot_axes[0]:
-                shape1.pop(i)
-            for i in self.dot_axes[1]:
-                shape2.pop(i)
-            shape = shape1 + shape2[1:]
-            if len(shape) == 1:
-                shape.append(1)
-            return tuple(shape)
+            dot_axes = []
+            for axes in self.dot_axes:
+                dot_axes.append([index-1 for index in axes])
+            tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
+                                            np.zeros(tuple(shape2[1:])),
+                                            axes=dot_axes)
+            if len(tensordot_output.shape) == 0:
+                shape = (1,)
+            else:
+                shape = tensordot_output.shape
+            return (shape1[0],) + shape
         elif self.mode == 'cos':
             return tuple(input_shapes[0][0], 1)
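
Note: the new shape computation strips the batch axis, shifts the dot axes down by one, and lets numpy's tensordot on dummy zero arrays infer the per-sample output shape. A standalone sketch of the same idea (helper name invented, not part of the patch):

    import numpy as np

    def dot_merge_output_shape(shape1, shape2, dot_axes):
        # illustrative helper, not part of Keras: infer the per-sample shape by
        # running tensordot on zero arrays without the batch axis
        per_sample_axes = [[axis - 1 for axis in axes] for axes in dot_axes]
        dummy = np.tensordot(np.zeros(shape1[1:]), np.zeros(shape2[1:]),
                             axes=per_sample_axes)
        per_sample_shape = dummy.shape if dummy.ndim > 0 else (1,)
        return (shape1[0],) + per_sample_shape

    # two (None, 32) inputs contracted over axis 1 give one value per sample:
    # dot_merge_output_shape((None, 32), (None, 32), [[1], [1]]) -> (None, 1)
    # a higher-rank case keeps the remaining axes:
    # dot_merge_output_shape((None, 4, 32), (None, 32, 6), [[2], [1]]) -> (None, 4, 6)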
@@ -392,13 +406,13 @@ class Merge(Layer):
             l1 = self.layers[0].get_output(train)
             l2 = self.layers[1].get_output(train)
             output = T.batched_tensordot(l1, l2, self.dot_axes)
-            output = output.dimshuffle((0, 'x'))
             return output
         elif self.mode == 'cos':
             l1 = self.layers[0].get_output(train)
             l2 = self.layers[1].get_output(train)
-            output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2)/T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), sequences=[l1, l2], outputs_info=None)
-            output = output.dimshuffle((0, 'x'))
+            output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2) / T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)),
+                                    sequences=[l1, l2],
+                                    outputs_info=None)
             return output
         else:
             raise Exception('Unknown merge mode')
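
Note: at runtime, dot mode contracts the two layer outputs sample by sample, and cos mode computes a per-sample cosine similarity. A rough numpy approximation of what the Theano graph computes for dot_axes=([1], [1]) (function names invented, not part of the patch):

    import numpy as np

    def batched_dot(l1, l2):
        # per-sample contraction, roughly what T.batched_tensordot does;
        # axis 1 of the full tensors becomes axis 0 once the batch axis is peeled off
        return np.array([np.tensordot(a, b, axes=([0], [0])) for a, b in zip(l1, l2)])

    def batched_cosine(l1, l2):
        # per-sample cosine similarity, mirroring the theano.scan expression above
        return np.array([np.dot(a, b) / np.sqrt(np.dot(a, a) * np.dot(b, b))
                         for a, b in zip(l1, l2)])

    x = np.random.rand(4, 32)
    y = np.random.rand(4, 32)
    print(batched_dot(x, y).shape)       # (4,)
    print(batched_cosine(x, y).shape)    # (4,)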
@@ -446,7 +460,6 @@ class Merge(Layer):
         return dict(list(base_config.items()) + list(config.items()))
 
 class Dropout(MaskedLayer):
     '''
         Hinton's dropout.
 
@@ -128,7 +128,7 @@ class TestSequential(unittest.TestCase):
         nloss = model.evaluate([X_train, X_train], y_train, verbose=0)
         print(nloss)
         assert(loss == nloss)
 
     def test_merge_dot1(self):
         print('Test merge: dot')
         left = Sequential()
@@ -146,7 +146,7 @@ class TestSequential(unittest.TestCase):
         model.add(Activation('softmax'))
         model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
 
     def test_merge_dot2(self):
         print('Test merge: dot')
         left = Sequential()
@@ -158,12 +158,13 @@ class TestSequential(unittest.TestCase):
         right.add(Activation('relu'))
 
         model = Sequential()
-        model.add(Merge([left, right], mode='dot', dot_axes=([1],[1])))
+        model.add(Merge([left, right], mode='dot', dot_axes=([1], [1])))
         model.add(Dense(nb_class))
         model.add(Activation('softmax'))
         model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
 
     def test_merge_concat(self):
         print('Test merge: concat')
         left = Sequential()
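
For context, a condensed usage sketch in the spirit of test_merge_dot2, assuming the Sequential/Merge API of this Keras version (the 784/32/10 layer sizes are illustrative, not taken from the tests):

    from keras.models import Sequential
    from keras.layers.core import Dense, Activation, Merge

    left = Sequential()
    left.add(Dense(32, input_dim=784))
    left.add(Activation('relu'))

    right = Sequential()
    right.add(Dense(32, input_dim=784))
    right.add(Activation('relu'))

    # mode='dot' with dot_axes=([1], [1]) contracts the two 32-dim outputs
    # into a single feature per sample
    model = Sequential()
    model.add(Merge([left, right], mode='dot', dot_axes=([1], [1])))
    model.add(Dense(10))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')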