'''Deep Dreaming in Keras.

Usage:

    python deep_dream.py path_to_your_base_image.jpg prefix_for_results

For example:

    python deep_dream.py img/mypic.jpg results/dream

Running on a GPU is strongly preferred for speed. When running on CPU,
prefer the TensorFlow backend, which is much faster.

Example results: http://i.imgur.com/FX6ROg9.jpg
'''
from __future__ import print_function
from scipy.misc import imread, imresize, imsave
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
import time
import argparse
import h5py
import os

from keras.models import Sequential
from keras.layers import Convolution2D, ZeroPadding2D, MaxPooling2D
from keras import backend as K

# Command-line interface: a source image and a prefix for the output files.
parser = argparse.ArgumentParser(description='Deep Dreams with Keras.')
parser.add_argument('base_image_path', metavar='base', type=str,
                    help='Path to the image to transform.')
parser.add_argument('result_prefix', metavar='res_prefix', type=str,
                    help='Prefix for the saved results.')
args = parser.parse_args()
base_image_path, result_prefix = args.base_image_path, args.result_prefix

# Dimensions of the generated picture (both sides are 600 pixels).
img_width = img_height = 600

# path to the model weights file.
weights_path = 'vgg16_weights.h5'

# Some settings we found interesting. Each entry holds:
#   features   -- {layer name: weight}; activations of these layers are maximized
#   continuity -- weight of the local-coherence (continuity) penalty
#   dream_l2   -- weight of the L2 penalty on the image pixels
#   jitter     -- amplitude of the random pixel shift applied each iteration
saved_settings = {
    'bad_trip': {'features': {'conv4_1': 0.05,
                              'conv4_2': 0.01,
                              'conv4_3': 0.01},
                 'continuity': 0.1,
                 'dream_l2': 0.8,
                 'jitter': 5},
    'dreamy': {'features': {'conv5_1': 0.05,
                            'conv5_2': 0.02},
               'continuity': 0.1,
               'dream_l2': 0.02,
               'jitter': 0},
}
# the settings we will use in this experiment
settings = saved_settings['dreamy']


def preprocess_image(image_path):
    """Open, resize and format a picture into a network input tensor.

    Args:
        image_path: path of the image file to load.

    Returns:
        A float64 array of shape (1, 3, img_width, img_height).

    Grayscale images (with or without an alpha channel) are replicated to
    three channels, and any alpha channel is dropped. Without this, the
    tensor would not match the 3-channel network input: 2-D arrays cannot be
    transposed with (2, 0, 1), and RGBA arrays would have 4 channels.
    """
    img = imread(image_path)
    if img.ndim == 2:
        # single-channel grayscale: give it an explicit channel axis
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # grayscale (optionally + alpha): replicate the luminance channel
        img = np.repeat(img[:, :, :1], 3, axis=2)
    # keep only the three color channels (drops alpha, e.g. for RGBA PNGs)
    img = imresize(img[:, :, :3], (img_width, img_height))
    # NOTE(review): no ImageNet mean subtraction / RGB->BGR reordering is
    # applied; the VGG16 weights may have been trained on such input -- confirm.
    img = img.transpose((2, 0, 1)).astype('float64')
    return np.expand_dims(img, axis=0)


def deprocess_image(x):
    """Convert a (3, rows, cols) tensor back into a uint8 (rows, cols, 3) image.

    Values are clipped to [0, 255] before the cast so that out-of-range
    pixels saturate instead of wrapping around.
    """
    x = x.transpose((1, 2, 0))
    return np.clip(x, 0, 255).astype('uint8')


# build the VGG16 network (convolutional part only)
model = Sequential()
model.add(ZeroPadding2D((1, 1), batch_input_shape=(1, 3, img_width, img_height)))
first_layer = model.layers[-1]
# this is a placeholder tensor that will contain our generated images
dream = first_layer.input

model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

# load the weights of the VGG16 networks
# (trained on ImageNet, won the ILSVRC competition in 2014)
# note: when there is a complete match between your model definition
# and your weight savefile, you can simply call model.load_weights(filename)
# Layers are matched to the savefile purely by index, so the layer order
# above must stay in sync with the file.
if not os.path.exists(weights_path):
    # explicit raise instead of assert: asserts are stripped under `python -O`
    raise IOError('Model weights not found (see "weights_path" variable in script).')
# Open read-only: older h5py versions default to append mode, which fails
# on read-only files and can modify them. The `with` closes the file even if
# set_weights raises.
with h5py.File(weights_path, 'r') as f:
    for k in range(f.attrs['nb_layers']):
        if k >= len(model.layers):
            # we don't look at the last (fully-connected) layers in the savefile
            break
        g = f['layer_{}'.format(k)]
        # copy into numpy arrays so nothing references the file after it closes
        weights = [np.asarray(g['param_{}'.format(p)])
                   for p in range(g.attrs['nb_params'])]
        model.layers[k].set_weights(weights)
print('Model loaded.')

# get the symbolic outputs of each "key" layer (we gave them unique names).
layer_dict = {layer.name: layer for layer in model.layers}


def continuity_loss(x):
    """Return the continuity penalty for a 4-D image tensor ``x``.

    For every pixel (except the last row/column along axes 2 and 3), this
    takes the squared difference to its next neighbor along each spatial
    axis. It sums those two terms, raises them to the power 1.25, and sums
    over the whole tensor. Smaller values mean a locally smoother image.
    """
    assert K.ndim(x) == 4
    a = K.square(x[:, :, :img_width - 1, :img_height - 1] - x[:, :, 1:, :img_height - 1])
    b = K.square(x[:, :, :img_width - 1, :img_height - 1] - x[:, :, :img_width - 1, 1:])
    return K.sum(K.pow(a + b, 1.25))


# define the loss
loss = K.variable(0.)
for layer_name, coeff in settings['features'].items():
    # add the L2 norm of the features of a layer to the loss
    if layer_name not in layer_dict:
        raise ValueError('Layer ' + layer_name + ' not found in model.')
    x = layer_dict[layer_name].output
    shape = layer_dict[layer_name].output_shape
    # Subtracting means the optimizer (a minimizer) *maximizes* activations.
    # We avoid border artifacts by only involving non-border pixels in the
    # loss; the sum is normalized by the layer's full size.
    loss -= coeff * K.sum(K.square(x[:, :, 2: shape[2] - 2, 2: shape[3] - 2])) / np.prod(shape[1:])

# add continuity loss (gives image local coherence, can result in an artful blur)
loss += settings['continuity'] * continuity_loss(dream) / (3 * img_width * img_height)
# add image L2 norm to loss (prevents pixels from taking very high values, makes image darker)
loss += settings['dream_l2'] * K.sum(K.square(dream)) / (3 * img_width * img_height)

# feel free to further modify the loss as you see fit, to achieve new effects...

# compute the gradients of the dream wrt the loss
grads = K.gradients(loss, dream)

# K.gradients may return a single tensor or a list/tuple of tensors
# depending on the backend; flatten both cases into one output list.
outputs = [loss]
if isinstance(grads, (list, tuple)):
    outputs += grads
else:
    outputs.append(grads)

f_outputs = K.function([dream], outputs)


def eval_loss_and_grads(x):
    """Evaluate the loss and its gradient for a flat pixel vector ``x``.

    Args:
        x: array holding 3 * img_width * img_height pixel values; it is
            reshaped to the (1, 3, img_width, img_height) network input.

    Returns:
        (loss_value, grad_values), where grad_values is a flat float64
        array, as scipy's L-BFGS expects.
    """
    x = x.reshape((1, 3, img_width, img_height))
    outs = f_outputs([x])
    loss_value = outs[0]
    if len(outs[1:]) == 1:
        grad_values = outs[1].flatten().astype('float64')
    else:
        grad_values = np.array(outs[1:]).flatten().astype('float64')
    return loss_value, grad_values

# this Evaluator class makes it possible
# to compute loss and gradients in one pass
# while retrieving them via two separate functions,
# "loss" and "grads". This is done because scipy.optimize
# requires separate functions for loss and gradients,
# but computing them separately would be inefficient.
class Evaluator(object):
    """Serve one combined loss/gradient evaluation through two callables.

    ``fmin_l_bfgs_b`` takes separate loss and gradient functions. ``loss``
    runs ``eval_loss_and_grads`` once and caches both results. The next
    ``grads`` call returns a copy of the cached gradient and clears the
    cache. The asserts require the two methods to be called strictly
    alternately, ``loss`` first.
    """

    def __init__(self):
        self.loss_value = self.grad_values = None

    def loss(self, x):
        assert self.loss_value is None
        self.loss_value, self.grad_values = eval_loss_and_grads(x)
        return self.loss_value

    def grads(self, x):
        assert self.loss_value is not None
        cached_grads = np.copy(self.grad_values)
        self.loss_value = self.grad_values = None
        return cached_grads


evaluator = Evaluator()

# Run scipy-based optimization (L-BFGS) over the pixels of the generated
# image so as to minimize the loss. There are 5 rounds, and a snapshot is
# saved after each one.
x = preprocess_image(base_image_path)
for i in range(5):
    print('Start of iteration', i)
    start_time = time.time()

    # Shift the image by uniform noise in [-jitter, jitter); the same offset
    # is subtracted again before the image is decoded.
    jitter_offset = (settings['jitter'] * 2) * (np.random.random((3, img_width, img_height)) - 0.5)
    x += jitter_offset

    # maxfun=7 caps the number of loss/gradient evaluations per round
    x, min_val, _ = fmin_l_bfgs_b(evaluator.loss, x.flatten(),
                                  fprime=evaluator.grads, maxfun=7)
    print('Current loss value:', min_val)

    # undo the jitter, then decode this round's dream and save it
    x = x.reshape((3, img_width, img_height))
    x -= jitter_offset
    out_path = result_prefix + '_at_iteration_%d.png' % i
    imsave(out_path, deprocess_image(x))

    elapsed = time.time() - start_time
    print('Image saved as', out_path)
    print('Iteration %d completed in %ds' % (i, elapsed))