diff --git a/keras/engine/topology.py b/keras/engine/topology.py index 8d91853fc..88075a8ea 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -1143,8 +1143,6 @@ class Merge(Layer): self.mode = mode self.concat_axis = concat_axis self.dot_axes = dot_axes - if type(self.dot_axes) == int: - self.dot_axes = [self.dot_axes, ] * 2 self._output_shape = output_shape self.node_indices = node_indices self._output_mask = output_mask @@ -1220,16 +1218,16 @@ class Merge(Layer): n2 = len(shape2) if type(dot_axes) == int: if dot_axes < 0: - dot_axes = [dot_axes % n1, dot_axes % n2] + self.dot_axes = [dot_axes % n1, dot_axes % n2] else: - dot_axes = [n1 - dot_axes, n2 - dot_axes] - if type(dot_axes) not in [list, tuple]: + self.dot_axes = [dot_axes, ] * 2 + if type(self.dot_axes) not in [list, tuple]: raise Exception('Invalid type for dot_axes - should be a list.') - if len(dot_axes) != 2: + if len(self.dot_axes) != 2: raise Exception('Invalid format for dot_axes - should contain two elements.') - if type(dot_axes[0]) is not int or type(dot_axes[1]) is not int: + if type(self.dot_axes[0]) is not int or type(self.dot_axes[1]) is not int: raise Exception('Invalid format for dot_axes - list elements should be "int".') - if shape1[dot_axes[0]] != shape2[dot_axes[1]]: + if shape1[self.dot_axes[0]] != shape2[self.dot_axes[1]]: raise Exception('Dimension incompatibility using dot mode: ' + '%s != %s. ' % (shape1[dot_axes[0]], shape2[dot_axes[1]]) + 'Layer shapes: %s, %s' % (shape1, shape2))