Style fixes

This commit is contained in:
fchollet 2016-07-24 14:20:36 -07:00
parent 59bd247603
commit 09d75a4347

@ -5,7 +5,7 @@ Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from keras.layers import Input, Dense, merge from keras.layers import Input, Dense, Lambda
from keras.models import Model from keras.models import Model
from keras import backend as K from keras import backend as K
from keras import objectives from keras import objectives
@ -14,21 +14,22 @@ from keras.datasets import mnist
batch_size = 100 batch_size = 100
original_dim = 784 original_dim = 784
latent_dim = 2 latent_dim = 2
intermediate_dim = 500 intermediate_dim = 256
nb_epoch = 100 nb_epoch = 50
x = Input(batch_shape=(batch_size, original_dim)) x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x) h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h) z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h) z_log_var = Dense(latent_dim)(h)
def sampling(args): def sampling(args):
z_mean, z_log_var = args z_mean, z_log_var = args
epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.) epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.)
return z_mean + K.exp(z_log_var/2) * epsilon return z_mean + K.exp(z_log_var / 2) * epsilon
# note that "output_shape" isn't necessary with the TensorFlow backend # note that "output_shape" isn't necessary with the TensorFlow backend
z = merge([z_mean, z_log_var], mode=sampling, output_shape=(latent_dim,)) z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# we instantiate these layers separately so as to reuse them later # we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu') decoder_h = Dense(intermediate_dim, activation='relu')
@ -36,6 +37,7 @@ decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z) h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded) x_decoded_mean = decoder_mean(h_decoded)
def vae_loss(x, x_decoded_mean): def vae_loss(x, x_decoded_mean):
xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean) xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)