From 40685c3b2a8c535f6c68fa3917ce5f1507dce186 Mon Sep 17 00:00:00 2001
From: dolaameng
Date: Fri, 9 Sep 2016 01:15:57 +0800
Subject: [PATCH] add examples/neural_doodle.py (#3724)

---
 examples/neural_doodle.py | 363 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 363 insertions(+)
 create mode 100644 examples/neural_doodle.py

diff --git a/examples/neural_doodle.py b/examples/neural_doodle.py
new file mode 100644
index 000000000..aa5259537
--- /dev/null
+++ b/examples/neural_doodle.py
@@ -0,0 +1,363 @@
'''Neural doodle with Keras

Script Usage:
    # Arguments:
    ```
    --nlabels: number of regions (colors) in mask images
    --style-image: image to learn style from
    --style-mask: semantic labels for style image
    --target-mask: semantic labels for target image (your doodle)
    --content-image: optional image to learn content from
    --target-image-prefix: path prefix for generated target images
    ```

    # Example 1: doodle using a style image, style mask
    and target mask.
    ```
    python neural_doodle.py --nlabels 4 --style-image Monet/style.png \
    --style-mask Monet/style_mask.png --target-mask Monet/target_mask.png \
    --target-image-prefix generated/monet
    ```

    # Example 2: doodle using a style image, style mask,
    target mask and an optional content image.
    ```
    python neural_doodle.py --nlabels 4 --style-image Renoir/style.png \
    --style-mask Renoir/style_mask.png --target-mask Renoir/target_mask.png \
    --content-image Renoir/creek.jpg \
    --target-image-prefix generated/renoir
    ```

References:
[Dmitry Ulyanov's blog on fast-neural-doodle](http://dmitryulyanov.github.io/feed-forward-neural-doodle/)
[Torch code for fast-neural-doodle](https://github.com/DmitryUlyanov/fast-neural-doodle)
[Torch code for online-neural-doodle](https://github.com/DmitryUlyanov/online-neural-doodle)
[Paper Texture Networks: Feed-forward Synthesis of Textures and Stylized Images](http://arxiv.org/abs/1603.03417)
[Discussion on parameter tuning](https://github.com/fchollet/keras/issues/3705)

Resources:
Example images can be downloaded from
https://github.com/DmitryUlyanov/fast-neural-doodle/tree/master/data
'''
from __future__ import print_function
import time
import argparse
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
from scipy.misc import imread, imsave

from keras import backend as K
from keras.layers import Input, AveragePooling2D
from keras.models import Model
from keras.preprocessing.image import load_img, img_to_array
from keras.applications import vgg19

# Command line arguments
parser = argparse.ArgumentParser(description='Keras neural doodle example')
parser.add_argument('--nlabels', type=int,
                    help='number of semantic labels'
                    ' (regions in different colors)'
                    ' in style_mask/target_mask')
parser.add_argument('--style-image', type=str,
                    help='path to image to learn style from')
parser.add_argument('--style-mask', type=str,
                    help='path to semantic mask of style image')
parser.add_argument('--target-mask', type=str,
                    help='path to semantic mask of target image')
parser.add_argument('--content-image', type=str, default=None,
                    help='path to optional content image')
parser.add_argument('--target-image-prefix', type=str,
                    help='path prefix for generated results')
args = parser.parse_args()

style_img_path = args.style_image
style_mask_path = args.style_mask
target_mask_path = args.target_mask
content_img_path = args.content_image
target_img_prefix = args.target_image_prefix
use_content_img = content_img_path is not None
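# Minimal sanity check on the required arguments (an illustrative
# convenience added for clarity; using argparse's `required=True`
# would be an equivalent alternative):
for arg_name in ('nlabels', 'style_image', 'style_mask',
                 'target_mask', 'target_image_prefix'):
    if getattr(args, arg_name) is None:
        parser.error('missing required argument: --%s'
                     % arg_name.replace('_', '-'))
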
nb_labels = args.nlabels
nb_colors = 3  # RGB
# determine image sizes based on target_mask
ref_img = imread(target_mask_path)
img_nrows, img_ncols = ref_img.shape[:2]

total_variation_weight = 8.5e-5
style_weight = 1.
content_weight = 0.1 if use_content_img else 0

content_feature_layers = ['block5_conv2']
# For better generation quality, use more conv layers for style features
style_feature_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1',
                        'block4_conv1', 'block5_conv1']


# helper functions for reading/processing images
def preprocess_image(image_path):
    img = load_img(image_path, target_size=(img_nrows, img_ncols))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg19.preprocess_input(img)
    return img


def deprocess_image(x):
    if K.image_dim_ordering() == 'th':
        x = x.reshape((3, img_nrows, img_ncols))
        x = x.transpose((1, 2, 0))
    else:
        x = x.reshape((img_nrows, img_ncols, 3))
    # remove zero-center by VGG mean pixel (BGR order)
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR' -> 'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype('uint8')
    return x


def kmeans(xs, k):
    assert xs.ndim == 2
    try:
        from sklearn.cluster import k_means
        _, labels, _ = k_means(xs.astype('float64'), k)
    except ImportError:
        from scipy.cluster.vq import kmeans2
        _, labels = kmeans2(xs, k, missing='raise')
    return labels


def load_mask_labels():
    '''Load both target and style masks.
    A mask image (nr x nc) with m labels/colors will be loaded
    as a 4D boolean tensor: (1, m, nr, nc) for 'th' or (1, nr, nc, m) for 'tf'
    '''
    target_mask_img = load_img(target_mask_path,
                               target_size=(img_nrows, img_ncols))
    target_mask_img = img_to_array(target_mask_img)
    style_mask_img = load_img(style_mask_path,
                              target_size=(img_nrows, img_ncols))
    style_mask_img = img_to_array(style_mask_img)
    if K.image_dim_ordering() == 'th':
        mask_vecs = np.vstack([style_mask_img.reshape((3, -1)).T,
                               target_mask_img.reshape((3, -1)).T])
    else:
        mask_vecs = np.vstack([style_mask_img.reshape((-1, 3)),
                               target_mask_img.reshape((-1, 3))])

    # cluster the pixel colors of both masks jointly, so that matching
    # colors in the two masks are assigned the same label
    labels = kmeans(mask_vecs, nb_labels)
    size = img_nrows * img_ncols
    style_mask_label = labels[:size].reshape((img_nrows, img_ncols))
    target_mask_label = labels[size:].reshape((img_nrows, img_ncols))

    stack_axis = 0 if K.image_dim_ordering() == 'th' else -1
    style_mask = np.stack([style_mask_label == r for r in range(nb_labels)],
                          axis=stack_axis)
    target_mask = np.stack([target_mask_label == r for r in range(nb_labels)],
                           axis=stack_axis)

    return (np.expand_dims(style_mask, axis=0),
            np.expand_dims(target_mask, axis=0))


# Create tensor variables for images
if K.image_dim_ordering() == 'th':
    shape = (1, nb_colors, img_nrows, img_ncols)
else:
    shape = (1, img_nrows, img_ncols, nb_colors)

style_image = K.variable(preprocess_image(style_img_path))
target_image = K.placeholder(shape=shape)
if use_content_img:
    content_image = K.variable(preprocess_image(content_img_path))
else:
    content_image = K.zeros(shape=shape)

images = K.concatenate([style_image, target_image, content_image], axis=0)

# Create tensor variables for masks
raw_style_mask, raw_target_mask = load_mask_labels()
style_mask = K.variable(raw_style_mask.astype('float32'))
target_mask = K.variable(raw_target_mask.astype('float32'))
masks = K.concatenate([style_mask, target_mask], axis=0)

# index constants for images/masks in the concatenated tensors
STYLE, TARGET, CONTENT = 0, 1, 2
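# For reference (illustrative numbers assuming 'tf' dim ordering,
# nb_labels=4 and 256x256 masks), the concatenated tensors have shapes:
#   images: (3, 256, 256, 3)  -- style, target, content along axis 0
#   masks:  (2, 256, 256, 4)  -- style and target masks along axis 0
assert K.ndim(images) == 4 and K.ndim(masks) == 4
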
# Build image model and mask model, and use layer outputs as features
# image model as VGG19
image_model = vgg19.VGG19(include_top=False, input_tensor=images)

# mask model as a series of pooling
mask_input = Input(tensor=masks, shape=(None, None, None), name='mask_input')
x = mask_input
for layer in image_model.layers[1:]:  # skip the input layer
    name = 'mask_%s' % layer.name
    if 'conv' in layer.name:
        # mimic the receptive field of conv layers with average pooling
        x = AveragePooling2D((3, 3), strides=(1, 1),
                             name=name, border_mode='same')(x)
    elif 'pool' in layer.name:
        x = AveragePooling2D((2, 2), name=name)(x)
mask_model = Model(mask_input, x)

# Collect features from image_model and mask_model
image_features = {}
mask_features = {}
for img_layer, mask_layer in zip(image_model.layers, mask_model.layers):
    if 'conv' in img_layer.name:
        assert 'mask_' + img_layer.name == mask_layer.name
        layer_name = img_layer.name
        img_feat, mask_feat = img_layer.output, mask_layer.output
        image_features[layer_name] = img_feat
        mask_features[layer_name] = mask_feat


# Define loss functions
def gram_matrix(x):
    assert K.ndim(x) == 3
    features = K.batch_flatten(x)
    gram = K.dot(features, K.transpose(features))
    return gram


def region_style_loss(style_image, target_image, style_mask, target_mask):
    '''Calculate style loss between style_image and target_image,
    for one common region specified by their (boolean) masks
    '''
    assert 3 == K.ndim(style_image) == K.ndim(target_image)
    assert 2 == K.ndim(style_mask) == K.ndim(target_mask)
    if K.image_dim_ordering() == 'th':
        masked_style = style_image * style_mask
        masked_target = target_image * target_mask
    else:
        masked_style = K.permute_dimensions(
            style_image, (2, 0, 1)) * style_mask
        masked_target = K.permute_dimensions(
            target_image, (2, 0, 1)) * target_mask
    s = gram_matrix(masked_style) * K.sum(style_mask)
    c = gram_matrix(masked_target) * K.sum(target_mask)
    return K.sum(K.square(s - c))


def style_loss(style_image, target_image, style_masks, target_masks):
    '''Calculate style loss between style_image and target_image,
    in all regions.
    '''
    assert 3 == K.ndim(style_image) == K.ndim(target_image)
    assert 3 == K.ndim(style_masks) == K.ndim(target_masks)
    loss = K.variable(0)
    for i in range(nb_labels):
        if K.image_dim_ordering() == 'th':
            style_mask = style_masks[i, :, :]
            target_mask = target_masks[i, :, :]
        else:
            style_mask = style_masks[:, :, i]
            target_mask = target_masks[:, :, i]
        loss += region_style_loss(style_image,
                                  target_image, style_mask, target_mask)
    size = img_nrows * img_ncols
    return loss / (4. * nb_colors ** 2 * size ** 2)


def content_loss(content_image, target_image):
    return K.sum(K.square(target_image - content_image))


def total_variation_loss(x):
    assert 4 == K.ndim(x)
    if K.image_dim_ordering() == 'th':
        a = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] -
                     x[:, :, 1:, :img_ncols - 1])
        b = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] -
                     x[:, :, :img_nrows - 1, 1:])
    else:
        a = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] -
                     x[:, 1:, :img_ncols - 1, :])
        b = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] -
                     x[:, :img_nrows - 1, 1:, :])
    return K.sum(K.pow(a + b, 1.25))
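# Intuition for the masked gram matrices used above, as a NumPy sketch
# kept in comments (made-up sizes, not executed):
#   feats = np.random.rand(64, 32 * 32)  # channels x flattened pixels
#   gram = feats.dot(feats.T)            # (64, 64) channel co-occurrences
# region_style_loss masks the features before computing the gram, so
# style statistics are matched per region instead of over the full image.
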
# Overall loss is the weighted sum of content_loss, style_loss and tv_loss
# Each individual loss uses features from image/mask models.
loss = K.variable(0)
for layer in content_feature_layers:
    content_feat = image_features[layer][CONTENT, :, :, :]
    target_feat = image_features[layer][TARGET, :, :, :]
    loss += content_weight * content_loss(content_feat, target_feat)

for layer in style_feature_layers:
    style_feat = image_features[layer][STYLE, :, :, :]
    target_feat = image_features[layer][TARGET, :, :, :]
    style_masks = mask_features[layer][STYLE, :, :, :]
    target_masks = mask_features[layer][TARGET, :, :, :]
    sl = style_loss(style_feat, target_feat, style_masks, target_masks)
    loss += (style_weight / len(style_feature_layers)) * sl

loss += total_variation_weight * total_variation_loss(target_image)
loss_grads = K.gradients(loss, target_image)

# Gather loss and gradients into one function call for efficiency
outputs = [loss]
if isinstance(loss_grads, (list, tuple)):
    outputs += loss_grads
else:
    outputs.append(loss_grads)

f_outputs = K.function([target_image], outputs)


def eval_loss_and_grads(x):
    if K.image_dim_ordering() == 'th':
        x = x.reshape((1, 3, img_nrows, img_ncols))
    else:
        x = x.reshape((1, img_nrows, img_ncols, 3))
    outs = f_outputs([x])
    loss_value = outs[0]
    if len(outs[1:]) == 1:
        grad_values = outs[1].flatten().astype('float64')
    else:
        grad_values = np.array(outs[1:]).flatten().astype('float64')
    return loss_value, grad_values


class Evaluator(object):
    '''Computes loss and gradients in one pass, then serves them through
    two separate functions, as required by scipy's L-BFGS optimizer.
    '''

    def __init__(self):
        self.loss_value = None
        self.grad_values = None

    def loss(self, x):
        assert self.loss_value is None
        loss_value, grad_values = eval_loss_and_grads(x)
        self.loss_value = loss_value
        self.grad_values = grad_values
        return self.loss_value

    def grads(self, x):
        assert self.loss_value is not None
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values


evaluator = Evaluator()

# Generate images by iterative optimization
if K.image_dim_ordering() == 'th':
    x = np.random.uniform(0, 255, (1, 3, img_nrows, img_ncols)) - 128.
else:
    x = np.random.uniform(0, 255, (1, img_nrows, img_ncols, 3)) - 128.

for i in range(100):
    print('Start of iteration', i)
    start_time = time.time()
    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(),
                                     fprime=evaluator.grads, maxfun=20)
    print('Current loss value:', min_val)
    # save current generated image
    img = deprocess_image(x.copy())
    fname = target_img_prefix + '_at_iteration_%d.png' % i
    imsave(fname, img)
    end_time = time.time()
    print('Image saved as', fname)
    print('Iteration %d completed in %ds' % (i, end_time - start_time))
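# To warm-start from an earlier result instead of random noise, one could
# initialize x from a previously saved image before the loop above
# (an illustrative sketch, assuming iteration 99 of a prior run exists):
#   x = preprocess_image(target_img_prefix + '_at_iteration_99.png')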