Fixes and input checks for dot mode in merge.
This commit is contained in:
parent
a5d93bfdc1
commit
002a9d5d2b
@ -289,7 +289,8 @@ class Merge(Layer):
|
||||
if mode in {'sum', 'mul', 'ave', 'cos'}:
|
||||
input_shapes = set([l.output_shape for l in layers])
|
||||
if len(input_shapes) > 1:
|
||||
raise Exception("Only layers of same output shape can be merged using " + mode + " mode")
|
||||
raise Exception("Only layers of same output shape can be merged using " + mode + " mode. " +
|
||||
"Layer shapes: %s" % ([l.output_shape for l in layers]))
|
||||
if mode in {'cos', 'dot'}:
|
||||
if len(layers) > 2:
|
||||
raise Exception(mode + " merge takes exactly 2 layers")
|
||||
@ -303,13 +304,28 @@ class Merge(Layer):
|
||||
dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
|
||||
else:
|
||||
dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
|
||||
if type(dot_axes) not in [list, tuple]:
|
||||
raise Exception("Invalid type for dot_axes - should be a list.")
|
||||
if len(dot_axes) != 2:
|
||||
raise Exception("Invalid format for dot_axes - should contain two elements.")
|
||||
if type(dot_axes[0]) not in [list, tuple] or type(dot_axes[1]) not in [list, tuple]:
|
||||
raise Exception("Invalid format for dot_axes - list elements should have type 'list' or 'tuple'.")
|
||||
for i in range(len(dot_axes[0])):
|
||||
if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
|
||||
raise Exception(" Dot incompatible layers can not be merged using dot mode")
|
||||
raise Exception("Dimension incompatibility using dot mode: " +
|
||||
"%s != %s. " % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) +
|
||||
"Layer shapes: %s, %s" % (shape1, shape2))
|
||||
elif mode == 'concat':
|
||||
input_shapes = set([list(l.output_shape).pop(concat_axis) for l in layers])
|
||||
input_shapes = set()
|
||||
for l in layers:
|
||||
oshape = list(l.output_shape)
|
||||
oshape.pop(concat_axis)
|
||||
oshape = tuple(oshape)
|
||||
input_shapes.add(oshape)
|
||||
if len(input_shapes) > 1:
|
||||
raise Exception("'concat' mode can only merge layers with matching output shapes except for the concat axis")
|
||||
raise Exception("'concat' mode can only merge layers with matching " +
|
||||
"output shapes except for the concat axis. " +
|
||||
"Layer shapes: %s" % ([l.output_shape for l in layers]))
|
||||
|
||||
self.mode = mode
|
||||
self.concat_axis = concat_axis
|
||||
@ -344,14 +360,17 @@ class Merge(Layer):
|
||||
elif self.mode == 'dot':
|
||||
shape1 = list(input_shapes[0])
|
||||
shape2 = list(input_shapes[1])
|
||||
for i in self.dot_axes[0]:
|
||||
shape1.pop(i)
|
||||
for i in self.dot_axes[1]:
|
||||
shape2.pop(i)
|
||||
shape = shape1 + shape2[1:]
|
||||
if len(shape) == 1:
|
||||
shape.append(1)
|
||||
return tuple(shape)
|
||||
dot_axes = []
|
||||
for axes in self.dot_axes:
|
||||
dot_axes.append([index-1 for index in axes])
|
||||
tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
|
||||
np.zeros(tuple(shape2[1:])),
|
||||
axes=dot_axes)
|
||||
if len(tensordot_output.shape) == 0:
|
||||
shape = (1,)
|
||||
else:
|
||||
shape = tensordot_output.shape
|
||||
return (shape1[0],) + shape
|
||||
elif self.mode == 'cos':
|
||||
return tuple(input_shapes[0][0], 1)
|
||||
|
||||
@ -387,13 +406,13 @@ class Merge(Layer):
|
||||
l1 = self.layers[0].get_output(train)
|
||||
l2 = self.layers[1].get_output(train)
|
||||
output = T.batched_tensordot(l1, l2, self.dot_axes)
|
||||
output = output.dimshuffle((0, 'x'))
|
||||
return output
|
||||
elif self.mode == 'cos':
|
||||
l1 = self.layers[0].get_output(train)
|
||||
l2 = self.layers[1].get_output(train)
|
||||
output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2)/T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), sequences=[l1, l2], outputs_info=None)
|
||||
output = output.dimshuffle((0, 'x'))
|
||||
output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2) / T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)),
|
||||
sequences=[l1, l2],
|
||||
outputs_info=None)
|
||||
return output
|
||||
else:
|
||||
raise Exception('Unknown merge mode')
|
||||
@ -441,7 +460,6 @@ class Merge(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
|
||||
class Dropout(MaskedLayer):
|
||||
'''
|
||||
Hinton's dropout.
|
||||
|
Loading…
Reference in New Issue
Block a user