Fix Theano tests

Francois Chollet 2016-11-09 20:33:42 -08:00
parent 92e8a20761
commit e916f748db
9 changed files with 28 additions and 1 deletion
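The fix follows one pattern across all nine files: every affected build() method now ends with self.built = True, and the Theano correctness tests call layer.build(shape) explicitly before applying the layer to a backend variable, presumably because the implicit build on call did not get a usable input shape under the Theano backend. A minimal sketch of the pattern, assuming the Keras 1.x layer API (the shapes and surrounding lines are illustrative, not the exact diff context):

    import numpy as np
    from keras import backend as K
    from keras.layers import convolutional

    # Layer-side pattern added throughout this commit: build() ends by
    # flagging the layer as built.
    #
    #     def build(self, input_shape):
    #         ...
    #         if self.initial_weights is not None:
    #             self.set_weights(self.initial_weights)
    #             del self.initial_weights
    #         self.built = True
    #
    # Test-side pattern: build explicitly with the known input shape, then
    # apply the layer to a plain backend variable.
    shape = (2, 5, 2)                                # (nb_samples, nb_steps, input_dim)
    data = np.ones(shape)

    layer = convolutional.ZeroPadding1D(padding=2)
    layer.build(shape)
    output = layer(K.variable(data))
    np_output = K.eval(output)                       # shape (2, 9, 2): 2 padded steps per side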

@@ -142,6 +142,7 @@ class Convolution1D(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         length = conv_output_length(input_shape[1],
@@ -434,6 +435,7 @@ class Convolution2D(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         if self.dim_ordering == 'th':
@@ -982,6 +984,7 @@ class SeparableConvolution2D(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         if self.dim_ordering == 'th':
@@ -1179,6 +1182,7 @@ class Convolution3D(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         if self.dim_ordering == 'th':

@@ -371,6 +371,7 @@ class ConvLSTM2D(ConvRecurrent2D):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def reset_states(self):
         assert self.stateful, 'Layer must be stateful.'

@@ -723,6 +723,7 @@ class Dense(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def call(self, x, mask=None):
         output = K.dot(x, self.W)
@@ -891,6 +892,7 @@ class MaxoutDense(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         assert input_shape and len(input_shape) == 2
@@ -1028,6 +1030,7 @@ class Highway(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def call(self, x, mask=None):
         y = K.dot(x, self.W_carry)
@@ -1168,6 +1171,7 @@ class TimeDistributedDense(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         return (input_shape[0], input_shape[1], self.output_dim)

@@ -110,6 +110,7 @@ class Embedding(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
+        self.built = True

     def compute_mask(self, x, mask=None):
         if not self.mask_zero:

@@ -139,6 +139,7 @@ class LocallyConnected1D(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         length = conv_output_length(input_shape[1],
@@ -333,6 +334,7 @@ class LocallyConnected2D(Layer):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def get_output_shape_for(self, input_shape):
         if self.dim_ordering == 'th':

@@ -325,6 +325,7 @@ class SimpleRNN(Recurrent):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def reset_states(self):
         assert self.stateful, 'Layer must be stateful.'
@@ -515,6 +516,7 @@ class GRU(Recurrent):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def reset_states(self):
         assert self.stateful, 'Layer must be stateful.'
@@ -745,6 +747,7 @@ class LSTM(Recurrent):
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
+        self.built = True

     def reset_states(self):
         assert self.stateful, 'Layer must be stateful.'

@@ -389,7 +389,8 @@ def test_zero_padding_1d():
     nb_samples = 2
     input_dim = 2
     nb_steps = 5
-    input = np.ones((nb_samples, nb_steps, input_dim))
+    shape = (nb_samples, nb_steps, input_dim)
+    input = np.ones(shape)

     # basic test
     layer_test(convolutional.ZeroPadding1D,
@@ -404,6 +405,7 @@ def test_zero_padding_1d():
     # correctness test
     layer = convolutional.ZeroPadding1D(padding=2)
+    layer.build(shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     for offset in [0, 1, -1, -2]:
@@ -411,6 +413,7 @@ def test_zero_padding_1d():
     assert_allclose(np_output[:, 2:-2, :], 1.)

     layer = convolutional.ZeroPadding1D(padding=(1, 2))
+    layer.build(shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     for left_offset in [0]:
@@ -449,6 +452,7 @@ def test_zero_padding_2d():
     # correctness test
     layer = convolutional.ZeroPadding2D(padding=(2, 2))
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     if dim_ordering == 'tf':
         for offset in [0, 1, -1, -2]:
@@ -462,6 +466,7 @@ def test_zero_padding_2d():
     assert_allclose(np_output[:, 2:-2, 2:-2, :], 1.)

     layer = convolutional.ZeroPadding2D(padding=(1, 2, 3, 4))
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     if dim_ordering == 'tf':
@@ -505,6 +510,7 @@ def test_zero_padding_3d():
     # correctness test
     layer = convolutional.ZeroPadding3D(padding=(2, 2, 2))
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     for offset in [0, 1, -1, -2]:
@@ -542,6 +548,7 @@ def test_upsampling_2d():
     layer = convolutional.UpSampling2D(
         size=(length_row, length_col),
         dim_ordering=dim_ordering)
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     if dim_ordering == 'th':
@@ -582,6 +589,7 @@ def test_upsampling_3d():
     layer = convolutional.UpSampling3D(
         size=(length_dim1, length_dim2, length_dim3),
         dim_ordering=dim_ordering)
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     if dim_ordering == 'th':
@@ -641,6 +649,7 @@ def test_cropping_2d():
     # correctness test
     layer = convolutional.Cropping2D(cropping=cropping,
                                      dim_ordering=dim_ordering)
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     # compare with numpy
@@ -681,6 +690,7 @@ def test_cropping_3d():
     # correctness test
     layer = convolutional.Cropping3D(cropping=cropping,
                                      dim_ordering=dim_ordering)
+    layer.build(input.shape)
     output = layer(K.variable(input))
     np_output = K.eval(output)
     # compare with numpy

@@ -110,6 +110,7 @@ def test_recurrent_convolutional():
               'border_mode': "same"}
     layer = convolutional_recurrent.ConvLSTM2D(**kwargs)
+    layer.build(input.shape)
     output = layer(K.variable(np.ones(input.shape)))
     K.eval(output)

@@ -129,6 +129,7 @@ def test_regularizer(layer_class):
                         U_regularizer=regularizers.WeightRegularizer(l1=0.01),
                         b_regularizer='l2')
     shape = (nb_samples, timesteps, embedding_dim)
+    layer.build(shape)
     output = layer(K.variable(np.ones(shape)))
     K.eval(output)