This commit is contained in:
Francois Chollet 2016-12-19 15:12:45 -08:00
commit 070609cbac
2 changed files with 21 additions and 5 deletions

@ -75,7 +75,7 @@ def create_base_network(input_dim):
def compute_accuracy(predictions, labels): def compute_accuracy(predictions, labels):
'''Compute classification accuracy with a fixed threshold on distances. '''Compute classification accuracy with a fixed threshold on distances.
''' '''
return labels[predictions.ravel() < 0.5].mean() return np.mean(labels == (predictions.ravel() > 0.5))
# the data, shuffled and split between train and test sets # the data, shuffled and split between train and test sets

@ -8,6 +8,13 @@ e.g.:
``` ```
python neural_style_transfer.py img/tuebingen.jpg img/starry_night.jpg results/my_result python neural_style_transfer.py img/tuebingen.jpg img/starry_night.jpg results/my_result
``` ```
Optional parameters:
```
--iter, To specify the number of iterations the style transfer takes place (Default is 10)
--content_weight, The weight given to the content loss (Default is 0.025)
--style_weight, The weight given to the style loss (Default is 1.0)
--tv_weight, The weight given to the total variation loss (Default is 1.0)
```
It is preferable to run this script on GPU, for speed. It is preferable to run this script on GPU, for speed.
@ -60,16 +67,25 @@ parser.add_argument('style_reference_image_path', metavar='ref', type=str,
help='Path to the style reference image.') help='Path to the style reference image.')
parser.add_argument('result_prefix', metavar='res_prefix', type=str, parser.add_argument('result_prefix', metavar='res_prefix', type=str,
help='Prefix for the saved results.') help='Prefix for the saved results.')
parser.add_argument('--iter', type=int, default=10, required=False,
help='Number of iterations to run.')
parser.add_argument('--content_weight', type=float, default=0.025, required=False,
help='Content weight.')
parser.add_argument('--style_weight', type=float, default=1.0, required=False,
help='Style weight.')
parser.add_argument('--tv_weight', type=float, default=1.0, required=False,
help='Total Variation weight.')
args = parser.parse_args() args = parser.parse_args()
base_image_path = args.base_image_path base_image_path = args.base_image_path
style_reference_image_path = args.style_reference_image_path style_reference_image_path = args.style_reference_image_path
result_prefix = args.result_prefix result_prefix = args.result_prefix
iterations = args.iter
# these are the weights of the different loss components # these are the weights of the different loss components
total_variation_weight = 1. total_variation_weight = args.tv_weight
style_weight = 1. style_weight = args.style_weight
content_weight = 0.025 content_weight = args.content_weight
# dimensions of the generated picture. # dimensions of the generated picture.
img_nrows = 400 img_nrows = 400
@ -246,7 +262,7 @@ if K.image_dim_ordering() == 'th':
else: else:
x = np.random.uniform(0, 255, (1, img_nrows, img_ncols, 3)) - 128. x = np.random.uniform(0, 255, (1, img_nrows, img_ncols, 3)) - 128.
for i in range(10): for i in range(iterations):
print('Start of iteration', i) print('Start of iteration', i)
start_time = time.time() start_time = time.time()
x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(), x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(),