diff --git a/examples/mnist_acgan.py b/examples/mnist_acgan.py index b28e8ac1c..20d949b02 100644 --- a/examples/mnist_acgan.py +++ b/examples/mnist_acgan.py @@ -160,8 +160,6 @@ if __name__ == '__main__': loss=['binary_crossentropy', 'sparse_categorical_crossentropy'] ) - discriminator.trainable = True - # get our mnist data, and force it to be of shape (..., 1, 28, 28) with # range [-1, 1] (X_train, y_train), (X_test, y_test) = mnist.load_data() @@ -217,10 +215,7 @@ if __name__ == '__main__': noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size)) sampled_labels = np.random.randint(0, 10, 2 * batch_size) - # we want to fix the discriminator and let the generator train to - # trick it - discriminator.trainable = False - + # we want to train the genrator to trick the discriminator # For the generator, we want all the {fake, not-fake} labels to say # not-fake trick = np.ones(2 * batch_size) @@ -228,8 +223,6 @@ if __name__ == '__main__': epoch_gen_loss.append(combined.train_on_batch( [noise, sampled_labels.reshape((-1, 1))], [trick, sampled_labels])) - discriminator.trainable = True - print('\nTesting for epoch {}:'.format(epoch + 1)) # evaluate the testing loss here