Fixes and input checks for dot mode in merge.

This commit is contained in:
Makoto Matsuyama 2015-11-02 19:37:00 -08:00
parent a5d93bfdc1
commit 002a9d5d2b

@ -289,7 +289,8 @@ class Merge(Layer):
if mode in {'sum', 'mul', 'ave', 'cos'}:
input_shapes = set([l.output_shape for l in layers])
if len(input_shapes) > 1:
raise Exception("Only layers of same output shape can be merged using " + mode + " mode")
raise Exception("Only layers of same output shape can be merged using " + mode + " mode. " +
"Layer shapes: %s" % ([l.output_shape for l in layers]))
if mode in {'cos', 'dot'}:
if len(layers) > 2:
raise Exception(mode + " merge takes exactly 2 layers")
@ -303,13 +304,28 @@ class Merge(Layer):
dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
else:
dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
if type(dot_axes) not in [list, tuple]:
raise Exception("Invalid type for dot_axes - should be a list.")
if len(dot_axes) != 2:
raise Exception("Invalid format for dot_axes - should contain two elements.")
if type(dot_axes[0]) not in [list, tuple] or type(dot_axes[1]) not in [list, tuple]:
raise Exception("Invalid format for dot_axes - list elements should have type 'list' or 'tuple'.")
for i in range(len(dot_axes[0])):
if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
raise Exception(" Dot incompatible layers can not be merged using dot mode")
raise Exception("Dimension incompatibility using dot mode: " +
"%s != %s. " % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) +
"Layer shapes: %s, %s" % (shape1, shape2))
elif mode == 'concat':
input_shapes = set([list(l.output_shape).pop(concat_axis) for l in layers])
input_shapes = set()
for l in layers:
oshape = list(l.output_shape)
oshape.pop(concat_axis)
oshape = tuple(oshape)
input_shapes.add(oshape)
if len(input_shapes) > 1:
raise Exception("'concat' mode can only merge layers with matching output shapes except for the concat axis")
raise Exception("'concat' mode can only merge layers with matching " +
"output shapes except for the concat axis. " +
"Layer shapes: %s" % ([l.output_shape for l in layers]))
self.mode = mode
self.concat_axis = concat_axis
@ -344,14 +360,17 @@ class Merge(Layer):
elif self.mode == 'dot':
shape1 = list(input_shapes[0])
shape2 = list(input_shapes[1])
for i in self.dot_axes[0]:
shape1.pop(i)
for i in self.dot_axes[1]:
shape2.pop(i)
shape = shape1 + shape2[1:]
if len(shape) == 1:
shape.append(1)
return tuple(shape)
dot_axes = []
for axes in self.dot_axes:
dot_axes.append([index-1 for index in axes])
tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
np.zeros(tuple(shape2[1:])),
axes=dot_axes)
if len(tensordot_output.shape) == 0:
shape = (1,)
else:
shape = tensordot_output.shape
return (shape1[0],) + shape
elif self.mode == 'cos':
return tuple(input_shapes[0][0], 1)
@ -387,13 +406,13 @@ class Merge(Layer):
l1 = self.layers[0].get_output(train)
l2 = self.layers[1].get_output(train)
output = T.batched_tensordot(l1, l2, self.dot_axes)
output = output.dimshuffle((0, 'x'))
return output
elif self.mode == 'cos':
l1 = self.layers[0].get_output(train)
l2 = self.layers[1].get_output(train)
output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2)/T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), sequences=[l1, l2], outputs_info=None)
output = output.dimshuffle((0, 'x'))
output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2) / T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)),
sequences=[l1, l2],
outputs_info=None)
return output
else:
raise Exception('Unknown merge mode')
@ -441,7 +460,6 @@ class Merge(Layer):
return dict(list(base_config.items()) + list(config.items()))
class Dropout(MaskedLayer):
'''
Hinton's dropout.