2016-11-23 21:19:29 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
|
|
Train an Auxiliary Classifier Generative Adversarial Network (ACGAN) on the
|
|
|
|
MNIST dataset. See https://arxiv.org/abs/1610.09585 for more details.
|
|
|
|
|
|
|
|
You should start to see reasonable images after ~5 epochs, and good images
|
|
|
|
by ~15 epochs. You should use a GPU, as the convolution-heavy operations are
|
2017-03-12 03:44:29 +00:00
|
|
|
very slow on the CPU. Prefer the TensorFlow backend if you plan on iterating,
|
|
|
|
as the compilation time can be a blocker using Theano.
|
2016-11-23 21:19:29 +00:00
|
|
|
|
|
|
|
Timings:
|
|
|
|
|
|
|
|
Hardware | Backend | Time / Epoch
|
|
|
|
-------------------------------------------
|
|
|
|
CPU | TF | 3 hrs
|
|
|
|
Titan X (maxwell) | TF | 4 min
|
|
|
|
Titan X (maxwell) | TH | 7 min
|
|
|
|
|
|
|
|
Consult https://github.com/lukedeo/keras-acgan for more information and
|
|
|
|
example output.
|
|
|
|
"""
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
from collections import defaultdict
|
2016-12-15 07:07:21 +00:00
|
|
|
try:
|
|
|
|
import cPickle as pickle
|
|
|
|
except ImportError:
|
|
|
|
import pickle
|
2016-11-23 21:19:29 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
from six.moves import range
|
|
|
|
|
|
|
|
import keras.backend as K
|
|
|
|
from keras.datasets import mnist
|
2017-03-12 03:44:29 +00:00
|
|
|
from keras import layers
|
|
|
|
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout
|
2016-11-23 21:19:29 +00:00
|
|
|
from keras.layers.advanced_activations import LeakyReLU
|
2017-03-12 03:44:29 +00:00
|
|
|
from keras.layers.convolutional import UpSampling2D, Conv2D
|
2016-11-23 21:19:29 +00:00
|
|
|
from keras.models import Sequential, Model
|
|
|
|
from keras.optimizers import Adam
|
|
|
|
from keras.utils.generic_utils import Progbar
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
np.random.seed(1337)
|
|
|
|
|
2017-01-13 23:39:04 +00:00
|
|
|
K.set_image_data_format('channels_first')
|
2016-11-23 21:19:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
def build_generator(latent_size):
    """Build the ACGAN generator.

    Maps a pair (z, L) -- a latent vector z and a class label L drawn
    from P_c -- to image space (..., 1, 28, 28).
    """
    # Convolutional decoder shared by every (z, L) pair.
    decoder = Sequential()

    decoder.add(Dense(1024, input_dim=latent_size, activation='relu'))
    decoder.add(Dense(128 * 7 * 7, activation='relu'))
    decoder.add(Reshape((128, 7, 7)))

    # upsample (..., 7, 7) -> (..., 14, 14)
    decoder.add(UpSampling2D(size=(2, 2)))
    decoder.add(Conv2D(256, 5, padding='same', activation='relu',
                       kernel_initializer='glorot_normal'))

    # upsample (..., 14, 14) -> (..., 28, 28)
    decoder.add(UpSampling2D(size=(2, 2)))
    decoder.add(Conv2D(128, 5, padding='same', activation='relu',
                       kernel_initializer='glorot_normal'))

    # collapse the channel axis to a single tanh image plane
    decoder.add(Conv2D(1, 2, padding='same', activation='tanh',
                       kernel_initializer='glorot_normal'))

    # the z space commonly referred to in GAN papers
    latent = Input(shape=(latent_size, ))

    # the conditioning label, fed as a length-one integer sequence
    image_class = Input(shape=(1,), dtype='int32')

    # MNIST has 10 classes; embed each label into the latent dimension
    cls = Flatten()(Embedding(10, latent_size,
                              embeddings_initializer='glorot_normal')(image_class))

    # condition z on the class via a hadamard product with its embedding
    h = layers.multiply([latent, cls])

    fake_image = decoder(h)

    return Model([latent, image_class], fake_image)
|
2016-11-23 21:19:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
def build_discriminator():
    """Build the ACGAN discriminator.

    A relatively standard conv net with LeakyReLUs as suggested in the
    reference paper, with two output heads: a real/fake probability and
    a 10-way class prediction.
    """
    net = Sequential()

    # four conv stages, each given as (filters, strides); every stage is
    # followed by LeakyReLU + Dropout as in the reference paper
    stages = [(32, 2), (64, 1), (128, 2), (256, 1)]
    for stage_idx, (n_filters, stride) in enumerate(stages):
        if stage_idx == 0:
            # only the first layer declares the input shape
            net.add(Conv2D(n_filters, 3, padding='same', strides=stride,
                           input_shape=(1, 28, 28)))
        else:
            net.add(Conv2D(n_filters, 3, padding='same', strides=stride))
        net.add(LeakyReLU())
        net.add(Dropout(0.3))

    net.add(Flatten())

    image = Input(shape=(1, 28, 28))

    features = net(image)

    # first output (name=generation) is whether or not the discriminator
    # thinks the image that is being shown is fake, and the second output
    # (name=auxiliary) is the class that the discriminator thinks the image
    # belongs to.
    fake = Dense(1, activation='sigmoid', name='generation')(features)
    aux = Dense(10, activation='softmax', name='auxiliary')(features)

    return Model(image, [fake, aux])
|
2016-11-23 21:19:29 +00:00
|
|
|
|
|
|
|
if __name__ == '__main__':

    # batch and latent size taken from the paper
    epochs = 50
    batch_size = 100
    latent_size = 100

    # Adam parameters suggested in https://arxiv.org/abs/1511.06434
    adam_lr = 0.0002
    adam_beta_1 = 0.5

    # build the discriminator; its two heads get a binary loss (real/fake)
    # and a sparse categorical loss (digit class)
    discriminator = build_discriminator()
    discriminator.compile(
        optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
        loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
    )

    # build the generator
    generator = build_generator(latent_size)
    generator.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
                      loss='binary_crossentropy')

    latent = Input(shape=(latent_size, ))
    image_class = Input(shape=(1,), dtype='int32')

    # get a fake image
    fake = generator([latent, image_class])

    # we only want to be able to train generation for the combined model,
    # so freeze the discriminator before stacking it on the generator
    discriminator.trainable = False
    fake, aux = discriminator(fake)
    combined = Model([latent, image_class], [fake, aux])

    combined.compile(
        optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
        loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
    )

    # get our mnist data, and force it to be of shape (..., 1, 28, 28) with
    # range [-1, 1] (matching the generator's tanh output)
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=1)

    X_test = (X_test.astype(np.float32) - 127.5) / 127.5
    X_test = np.expand_dims(X_test, axis=1)

    num_train, num_test = X_train.shape[0], X_test.shape[0]

    train_history = defaultdict(list)
    test_history = defaultdict(list)

    for epoch in range(epochs):
        print('Epoch {} of {}'.format(epoch + 1, epochs))

        # integer floor division drops any partial trailing batch
        num_batches = num_train // batch_size
        progress_bar = Progbar(target=num_batches)

        epoch_gen_loss = []
        epoch_disc_loss = []

        for index in range(num_batches):
            progress_bar.update(index)
            # generate a new batch of noise
            noise = np.random.uniform(-1, 1, (batch_size, latent_size))

            # get a batch of real images
            image_batch = X_train[index * batch_size:(index + 1) * batch_size]
            label_batch = y_train[index * batch_size:(index + 1) * batch_size]

            # sample some labels from p_c
            sampled_labels = np.random.randint(0, 10, batch_size)

            # generate a batch of fake images, using the generated labels as a
            # conditioner. We reshape the sampled labels to be
            # (batch_size, 1) so that we can feed them into the embedding
            # layer as a length one sequence
            generated_images = generator.predict(
                [noise, sampled_labels.reshape((-1, 1))], verbose=0)

            X = np.concatenate((image_batch, generated_images))
            y = np.array([1] * batch_size + [0] * batch_size)
            aux_y = np.concatenate((label_batch, sampled_labels), axis=0)

            # see if the discriminator can figure itself out...
            epoch_disc_loss.append(discriminator.train_on_batch(X, [y, aux_y]))

            # make new noise. we generate 2 * batch size here such that we have
            # the generator optimize over an identical number of images as the
            # discriminator
            noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))
            sampled_labels = np.random.randint(0, 10, 2 * batch_size)

            # we want to train the generator to trick the discriminator.
            # For the generator, we want all the {fake, not-fake} labels to
            # say not-fake
            trick = np.ones(2 * batch_size)

            epoch_gen_loss.append(combined.train_on_batch(
                [noise, sampled_labels.reshape((-1, 1))],
                [trick, sampled_labels]))

        print('\nTesting for epoch {}:'.format(epoch + 1))

        # evaluate the testing loss here

        # generate a new batch of noise
        noise = np.random.uniform(-1, 1, (num_test, latent_size))

        # sample some labels from p_c and generate images from them
        sampled_labels = np.random.randint(0, 10, num_test)
        generated_images = generator.predict(
            [noise, sampled_labels.reshape((-1, 1))], verbose=False)

        X = np.concatenate((X_test, generated_images))
        y = np.array([1] * num_test + [0] * num_test)
        aux_y = np.concatenate((y_test, sampled_labels), axis=0)

        # see if the discriminator can figure itself out...
        discriminator_test_loss = discriminator.evaluate(
            X, [y, aux_y], verbose=False)

        discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)

        # make new noise
        noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))
        sampled_labels = np.random.randint(0, 10, 2 * num_test)

        trick = np.ones(2 * num_test)

        generator_test_loss = combined.evaluate(
            [noise, sampled_labels.reshape((-1, 1))],
            [trick, sampled_labels], verbose=False)

        generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)

        # generate an epoch report on performance
        train_history['generator'].append(generator_train_loss)
        train_history['discriminator'].append(discriminator_train_loss)

        test_history['generator'].append(generator_test_loss)
        test_history['discriminator'].append(discriminator_test_loss)

        print('{0:<22s} | {1:4s} | {2:15s} | {3:5s}'.format(
            'component', *discriminator.metrics_names))
        print('-' * 65)

        ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.2f} | {3:<5.2f}'
        print(ROW_FMT.format('generator (train)',
                             *train_history['generator'][-1]))
        print(ROW_FMT.format('generator (test)',
                             *test_history['generator'][-1]))
        print(ROW_FMT.format('discriminator (train)',
                             *train_history['discriminator'][-1]))
        print(ROW_FMT.format('discriminator (test)',
                             *test_history['discriminator'][-1]))

        # save weights every epoch
        generator.save_weights(
            'params_generator_epoch_{0:03d}.hdf5'.format(epoch), True)
        discriminator.save_weights(
            'params_discriminator_epoch_{0:03d}.hdf5'.format(epoch), True)

        # generate some digits to display: 10 rows, one per class label
        noise = np.random.uniform(-1, 1, (100, latent_size))

        sampled_labels = np.array([
            [i] * 10 for i in range(10)
        ]).reshape(-1, 1)

        # get a batch to display
        generated_images = generator.predict(
            [noise, sampled_labels], verbose=0)

        # arrange them into a grid
        img = (np.concatenate([r.reshape(-1, 28)
                               for r in np.split(generated_images, 10)
                               ], axis=-1) * 127.5 + 127.5).astype(np.uint8)

        Image.fromarray(img).save(
            'plot_epoch_{0:03d}_generated.png'.format(epoch))

    # use a context manager so the history file is flushed and closed
    # deterministically (the original leaked the handle from a bare open())
    with open('acgan-history.pkl', 'wb') as history_file:
        pickle.dump({'train': train_history, 'test': test_history},
                    history_file)
|