Switch variational examples to new API.

This commit is contained in:
Francois Chollet 2017-04-11 13:43:04 -07:00
parent b2f0dd4cb2
commit 4aa41625bf
2 changed files with 12 additions and 23 deletions

@ -42,11 +42,6 @@ h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)
# placeholder loss
def zero_loss(y_true, y_pred):
return K.zeros_like(y_pred)
# Custom loss layer
class CustomVariationalLayer(Layer):
def __init__(self, **kwargs):
@ -63,12 +58,12 @@ class CustomVariationalLayer(Layer):
x_decoded_mean = inputs[1]
loss = self.vae_loss(x, x_decoded_mean)
self.add_loss(loss, inputs=inputs)
# we don't use this output, but it has to have the correct shape:
return K.ones_like(x)
# We won't actually use the output.
return x
loss_layer = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(x, [loss_layer])
vae.compile(optimizer='rmsprop', loss=[zero_loss])
y = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)
# train the VAE on MNIST digits
@ -79,7 +74,7 @@ x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
vae.fit(x_train, x_train,
vae.fit(x_train,
shuffle=True,
epochs=epochs,
batch_size=batch_size,

@ -106,11 +106,6 @@ x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)
# placeholder loss
def zero_loss(y_true, y_pred):
return K.zeros_like(y_pred)
# Custom loss layer
class CustomVariationalLayer(Layer):
def __init__(self, **kwargs):
@ -129,14 +124,13 @@ class CustomVariationalLayer(Layer):
x_decoded_mean_squash = inputs[1]
loss = self.vae_loss(x, x_decoded_mean_squash)
self.add_loss(loss, inputs=inputs)
# we don't use this output, but it has to have the correct shape:
return K.ones_like(x)
# We don't use this output.
return x
loss_layer = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, [loss_layer])
vae.compile(optimizer='rmsprop', loss=zero_loss)
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()
# train the VAE on MNIST digits
@ -149,7 +143,7 @@ x_test = x_test.reshape((x_test.shape[0],) + original_img_size)
print('x_train.shape:', x_train.shape)
vae.fit(x_train, x_train,
vae.fit(x_train,
shuffle=True,
epochs=epochs,
batch_size=batch_size,