Finish PR backporting
This commit is contained in:
parent
eb8b40cccd
commit
2902149f77
@ -63,7 +63,7 @@ batch_size = 128
|
||||
nb_classes = 10
|
||||
nb_epoch = 40
|
||||
|
||||
# the data, shuffled and split between tran and test sets
|
||||
# the data, shuffled and split between train and test sets
|
||||
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||
|
||||
X_train = X_train.reshape(60000, 784)
|
||||
|
@ -28,7 +28,7 @@ nb_pool = 2
|
||||
# convolution kernel size
|
||||
nb_conv = 3
|
||||
|
||||
# the data, shuffled and split between tran and test sets
|
||||
# the data, shuffled and split between train and test sets
|
||||
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||
|
||||
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
|
||||
|
@ -20,7 +20,7 @@ batch_size = 128
|
||||
nb_classes = 10
|
||||
nb_epoch = 20
|
||||
|
||||
# the data, shuffled and split between tran and test sets
|
||||
# the data, shuffled and split between train and test sets
|
||||
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||
|
||||
X_train = X_train.reshape(60000, 784)
|
||||
|
@ -73,7 +73,7 @@ def compute_accuracy(predictions, labels):
|
||||
return labels[predictions.ravel() < 0.5].mean()
|
||||
|
||||
|
||||
# the data, shuffled and split between tran and test sets
|
||||
# the data, shuffled and split between train and test sets
|
||||
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||
X_train = X_train.reshape(60000, 784)
|
||||
X_test = X_test.reshape(10000, 784)
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
|
||||
import warnings
|
||||
import copy
|
||||
import time
|
||||
import numpy as np
|
||||
@ -1170,6 +1171,12 @@ class Model(Container):
|
||||
samples_seen += batch_size
|
||||
|
||||
# epoch finished
|
||||
if samples_seen > samples_per_epoch:
|
||||
warnings.warn('Epoch comprised more than '
|
||||
'`samples_per_epoch` samples, '
|
||||
'which might affect learning results. '
|
||||
'Set `samples_per_epoch` correctly '
|
||||
'to avoid this warning.')
|
||||
if samples_seen >= samples_per_epoch and do_validation:
|
||||
if val_gen:
|
||||
val_outs = self.evaluate_generator(validation_data,
|
||||
|
@ -11,14 +11,13 @@ from scipy import linalg
|
||||
|
||||
from os import listdir
|
||||
from os.path import isfile, join
|
||||
import random
|
||||
import math
|
||||
from six.moves import range
|
||||
import threading
|
||||
|
||||
|
||||
def random_rotation(x, rg, fill_mode='nearest', cval=0.):
|
||||
angle = random.uniform(-rg, rg)
|
||||
angle = np.random.uniform(-rg, rg)
|
||||
x = ndimage.interpolation.rotate(x, angle,
|
||||
axes=(1, 2),
|
||||
reshape=False,
|
||||
@ -31,9 +30,9 @@ def random_shift(x, wrg, hrg, fill_mode='nearest', cval=0.):
|
||||
shift_x = shift_y = 0
|
||||
|
||||
if wrg:
|
||||
shift_x = random.uniform(-wrg, wrg) * x.shape[2]
|
||||
shift_x = np.random.uniform(-wrg, wrg) * x.shape[2]
|
||||
if hrg:
|
||||
shift_y = random.uniform(-hrg, hrg) * x.shape[1]
|
||||
shift_y = np.random.uniform(-hrg, hrg) * x.shape[1]
|
||||
x = ndimage.interpolation.shift(x, (0, shift_y, shift_x),
|
||||
order=0,
|
||||
mode=fill_mode,
|
||||
@ -59,7 +58,7 @@ def random_barrel_transform(x, intensity):
|
||||
|
||||
|
||||
def random_shear(x, intensity, fill_mode='nearest', cval=0.):
|
||||
shear = random.uniform(-intensity, intensity)
|
||||
shear = np.random.uniform(-intensity, intensity)
|
||||
shear_matrix = np.array([[1.0, -math.sin(shear), 0.0],
|
||||
[0.0, math.cos(shear), 0.0],
|
||||
[0.0, 0.0, 1.0]])
|
||||
@ -76,8 +75,8 @@ def random_channel_shift(x, rg):
|
||||
|
||||
|
||||
def random_zoom(x, rg, fill_mode='nearest', cval=0.):
|
||||
zoom_w = random.uniform(1.-rg, 1.)
|
||||
zoom_h = random.uniform(1.-rg, 1.)
|
||||
zoom_w = np.random.uniform(1.-rg, 1.)
|
||||
zoom_h = np.random.uniform(1.-rg, 1.)
|
||||
x = ndimage.interpolation.zoom(x, zoom=(1., zoom_w, zoom_h),
|
||||
mode=fill_mode,
|
||||
cval=cval)
|
||||
@ -253,10 +252,10 @@ class ImageDataGenerator(object):
|
||||
if self.width_shift_range or self.height_shift_range:
|
||||
x = random_shift(x, self.width_shift_range, self.height_shift_range)
|
||||
if self.horizontal_flip:
|
||||
if random.random() < 0.5:
|
||||
if np.random.random() < 0.5:
|
||||
x = horizontal_flip(x)
|
||||
if self.vertical_flip:
|
||||
if random.random() < 0.5:
|
||||
if np.random.random() < 0.5:
|
||||
x = vertical_flip(x)
|
||||
if self.shear_range:
|
||||
x = random_shear(x, self.shear_range)
|
||||
|
Loading…
Reference in New Issue
Block a user