Merge pull request #942 from matsuyamax/master
Fixes and input checks for dot mode in merge.
This commit is contained in:
commit
fe14a845ab
@ -289,7 +289,8 @@ class Merge(Layer):
|
|||||||
if mode in {'sum', 'mul', 'ave', 'cos'}:
|
if mode in {'sum', 'mul', 'ave', 'cos'}:
|
||||||
input_shapes = set([l.output_shape for l in layers])
|
input_shapes = set([l.output_shape for l in layers])
|
||||||
if len(input_shapes) > 1:
|
if len(input_shapes) > 1:
|
||||||
raise Exception("Only layers of same output shape can be merged using " + mode + " mode")
|
raise Exception("Only layers of same output shape can be merged using " + mode + " mode. " +
|
||||||
|
"Layer shapes: %s" % ([l.output_shape for l in layers]))
|
||||||
if mode in {'cos', 'dot'}:
|
if mode in {'cos', 'dot'}:
|
||||||
if len(layers) > 2:
|
if len(layers) > 2:
|
||||||
raise Exception(mode + " merge takes exactly 2 layers")
|
raise Exception(mode + " merge takes exactly 2 layers")
|
||||||
@ -303,9 +304,17 @@ class Merge(Layer):
|
|||||||
dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
|
dot_axes = [range(dot_axes % n1, n1), range(dot_axes % n2, n2)]
|
||||||
else:
|
else:
|
||||||
dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
|
dot_axes = [range(n1 - dot_axes, n2), range(1, dot_axes + 1)]
|
||||||
|
if type(dot_axes) not in [list, tuple]:
|
||||||
|
raise Exception("Invalid type for dot_axes - should be a list.")
|
||||||
|
if len(dot_axes) != 2:
|
||||||
|
raise Exception("Invalid format for dot_axes - should contain two elements.")
|
||||||
|
if type(dot_axes[0]) not in [list, tuple, range] or type(dot_axes[1]) not in [list, tuple, range]:
|
||||||
|
raise Exception("Invalid format for dot_axes - list elements should have type 'list' or 'tuple'.")
|
||||||
for i in range(len(dot_axes[0])):
|
for i in range(len(dot_axes[0])):
|
||||||
if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
|
if shape1[dot_axes[0][i]] != shape2[dot_axes[1][i]]:
|
||||||
raise Exception(" Dot incompatible layers can not be merged using dot mode")
|
raise Exception("Dimension incompatibility using dot mode: " +
|
||||||
|
"%s != %s. " % (shape1[dot_axes[0][i]], shape2[dot_axes[1][i]]) +
|
||||||
|
"Layer shapes: %s, %s" % (shape1, shape2))
|
||||||
elif mode == 'concat':
|
elif mode == 'concat':
|
||||||
input_shapes = set()
|
input_shapes = set()
|
||||||
for l in layers:
|
for l in layers:
|
||||||
@ -314,7 +323,9 @@ class Merge(Layer):
|
|||||||
oshape = tuple(oshape)
|
oshape = tuple(oshape)
|
||||||
input_shapes.add(oshape)
|
input_shapes.add(oshape)
|
||||||
if len(input_shapes) > 1:
|
if len(input_shapes) > 1:
|
||||||
raise Exception("'concat' mode can only merge layers with matching output shapes except for the concat axis")
|
raise Exception("'concat' mode can only merge layers with matching " +
|
||||||
|
"output shapes except for the concat axis. " +
|
||||||
|
"Layer shapes: %s" % ([l.output_shape for l in layers]))
|
||||||
|
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.concat_axis = concat_axis
|
self.concat_axis = concat_axis
|
||||||
@ -349,14 +360,17 @@ class Merge(Layer):
|
|||||||
elif self.mode == 'dot':
|
elif self.mode == 'dot':
|
||||||
shape1 = list(input_shapes[0])
|
shape1 = list(input_shapes[0])
|
||||||
shape2 = list(input_shapes[1])
|
shape2 = list(input_shapes[1])
|
||||||
for i in self.dot_axes[0]:
|
dot_axes = []
|
||||||
shape1.pop(i)
|
for axes in self.dot_axes:
|
||||||
for i in self.dot_axes[1]:
|
dot_axes.append([index-1 for index in axes])
|
||||||
shape2.pop(i)
|
tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
|
||||||
shape = shape1 + shape2[1:]
|
np.zeros(tuple(shape2[1:])),
|
||||||
if len(shape) == 1:
|
axes=dot_axes)
|
||||||
shape.append(1)
|
if len(tensordot_output.shape) == 0:
|
||||||
return tuple(shape)
|
shape = (1,)
|
||||||
|
else:
|
||||||
|
shape = tensordot_output.shape
|
||||||
|
return (shape1[0],) + shape
|
||||||
elif self.mode == 'cos':
|
elif self.mode == 'cos':
|
||||||
return tuple(input_shapes[0][0], 1)
|
return tuple(input_shapes[0][0], 1)
|
||||||
|
|
||||||
@ -392,13 +406,13 @@ class Merge(Layer):
|
|||||||
l1 = self.layers[0].get_output(train)
|
l1 = self.layers[0].get_output(train)
|
||||||
l2 = self.layers[1].get_output(train)
|
l2 = self.layers[1].get_output(train)
|
||||||
output = T.batched_tensordot(l1, l2, self.dot_axes)
|
output = T.batched_tensordot(l1, l2, self.dot_axes)
|
||||||
output = output.dimshuffle((0, 'x'))
|
|
||||||
return output
|
return output
|
||||||
elif self.mode == 'cos':
|
elif self.mode == 'cos':
|
||||||
l1 = self.layers[0].get_output(train)
|
l1 = self.layers[0].get_output(train)
|
||||||
l2 = self.layers[1].get_output(train)
|
l2 = self.layers[1].get_output(train)
|
||||||
output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2)/T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)), sequences=[l1, l2], outputs_info=None)
|
output, _ = theano.scan(lambda v1, v2: T.dot(v1, v2) / T.sqrt(T.dot(v1, v1) * T.dot(v2, v2)),
|
||||||
output = output.dimshuffle((0, 'x'))
|
sequences=[l1, l2],
|
||||||
|
outputs_info=None)
|
||||||
return output
|
return output
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown merge mode')
|
raise Exception('Unknown merge mode')
|
||||||
@ -446,7 +460,6 @@ class Merge(Layer):
|
|||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Dropout(MaskedLayer):
|
class Dropout(MaskedLayer):
|
||||||
'''
|
'''
|
||||||
Hinton's dropout.
|
Hinton's dropout.
|
||||||
|
@ -158,12 +158,13 @@ class TestSequential(unittest.TestCase):
|
|||||||
right.add(Activation('relu'))
|
right.add(Activation('relu'))
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
model.add(Merge([left, right], mode='dot', dot_axes=([1],[1])))
|
model.add(Merge([left, right], mode='dot', dot_axes=([1], [1])))
|
||||||
|
|
||||||
model.add(Dense(nb_class))
|
model.add(Dense(nb_class))
|
||||||
model.add(Activation('softmax'))
|
model.add(Activation('softmax'))
|
||||||
|
|
||||||
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
|
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
|
||||||
|
|
||||||
def test_merge_concat(self):
|
def test_merge_concat(self):
|
||||||
print('Test merge: concat')
|
print('Test merge: concat')
|
||||||
left = Sequential()
|
left = Sequential()
|
||||||
|
Loading…
Reference in New Issue
Block a user