Masked and non-masked merge bug fix (#3218)

* Masked and non-masked merge bug fix

* Masked merge concat logic with an expanded loop

* Cast mask of ones for unmasked input in merge to uint8
This commit is contained in:
Pradeep Dasigi 2016-07-27 17:50:49 -07:00 committed by François Chollet
parent e0179bad2f
commit 6a8815de0c
2 changed files with 19 additions and 4 deletions

@ -1356,9 +1356,19 @@ class Merge(Layer):
masks = [K.expand_dims(m, 0) for m in mask if m is not None]
return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
elif self.mode == 'concat':
masks = [K.ones_like(inputs[i][:-1]) if m is None else m for i, m in zip(inputs, mask)]
expanded_dims = [K.expand_dims(m) for m in masks]
concatenated = K.concatenate(expanded_dims, axis=self.concat_axis)
# Make a list of masks while making sure the dimensionality of each mask
# is the same as the corresponding input.
masks = []
for input_i, mask_i in zip(inputs, mask):
if mask_i is None:
# Input is unmasked. Append all 1s to masks, but cast it to uint8 first
masks.append(K.cast(K.ones_like(input_i), 'uint8'))
elif K.ndim(mask_i) < K.ndim(input_i):
# Mask is smaller than the input, expand it
masks.append(K.expand_dims(mask_i))
else:
masks.append(mask_i)
concatenated = K.concatenate(masks, axis=self.concat_axis)
return K.all(concatenated, axis=-1, keepdims=False)
elif self.mode in ['cos', 'dot']:
return None

@ -100,9 +100,10 @@ def test_merge_mask_2d():
masked_a = Masking(mask_value=0)(input_a)
masked_b = Masking(mask_value=0)(input_b)
# two different types of merging
# three different types of merging
merged_sum = merge([masked_a, masked_b], mode='sum')
merged_concat = merge([masked_a, masked_b], mode='concat', concat_axis=1)
merged_concat_mixed = merge([masked_a, input_b], mode='concat', concat_axis=1)
# test sum
model_sum = Model([input_a, input_b], [merged_sum])
@ -114,6 +115,10 @@ def test_merge_mask_2d():
model_concat.compile(loss='mse', optimizer='sgd')
model_concat.fit([rand(2, 3), rand(2, 3)], [rand(2, 6)], nb_epoch=1)
# test concatenation with masked and non-masked inputs
model_concat = Model([input_a, input_b], [merged_concat_mixed])
model_concat.compile(loss='mse', optimizer='sgd')
model_concat.fit([rand(2,3), rand(2,3)], [rand(2,6)], nb_epoch=1)
@keras_test
def test_merge_mask_3d():