keras/preprocessing/image.py
2015-03-27 17:59:42 -07:00

239 lines
7.1 KiB
Python

import math
import random
import re
from os import listdir
from os.path import isfile, join

import numpy as np
from PIL import Image
from scipy import linalg
from scipy import ndimage
'''
Fairly basic set of tools for real-time data augmentation on image data.
Can easily be extended to include new transforms, new preprocessing methods, etc.
'''
def random_rotation(x, rg, fill_mode="nearest", cval=0.):
    '''
    Rotate every channel of x by one shared random angle.

    The angle, in degrees, is drawn uniformly from [-rg, rg]. The rotation
    acts in the plane spanned by axes 1 and 2 and keeps the input shape
    (reshape=False). fill_mode and cval control how points outside the
    input are filled.
    '''
    degrees = random.uniform(-rg, rg)
    return ndimage.rotate(x, degrees, axes=(1, 2), reshape=False,
                          mode=fill_mode, cval=cval)
def random_shift(x, wrg, hrg, fill_mode="nearest", cval=0.):
    '''
    Randomly translate a channel-first array x of shape (channels, rows, cols).

    # Arguments
        x: rank-3 array laid out as produced by img_to_array,
            i.e. (channels, rows, cols).
        wrg: maximum horizontal shift as a fraction of the width
            (x.shape[2]). The shift is drawn uniformly from [-wrg, wrg],
            so content can move left as well as right.
        hrg: maximum vertical shift as a fraction of the height
            (x.shape[1]), drawn the same way.
        fill_mode, cval: passed to scipy.ndimage.shift for points that
            fall outside the input.

    # Returns
        The shifted array, same shape as x.
    '''
    shift_cols = 0
    shift_rows = 0
    if wrg:
        # Symmetric range. The old `split * crop` product was never
        # negative, so images only moved one way. int() truncates toward
        # zero, which keeps the shift symmetric and a whole number of pixels.
        shift_cols = int(random.uniform(-wrg, wrg) * x.shape[2])
    if hrg:
        shift_rows = int(random.uniform(-hrg, hrg) * x.shape[1])
    # Axis 1 is rows (height) and axis 2 is columns (width), matching
    # img_to_array's transpose(2, 0, 1) of a PIL (rows, cols, channels) array.
    return ndimage.shift(x, (0, shift_rows, shift_cols), mode=fill_mode, cval=cval)
def horizontal_flip(x):
    '''
    Mirror each channel of x left-to-right.

    Every x[i] is reversed along its columns with np.fliplr. The flip is
    written back into x, and x itself is returned.
    '''
    for channel in x:
        channel[...] = np.fliplr(channel)
    return x
def vertical_flip(x):
    '''
    Mirror each channel of x top-to-bottom.

    Every x[i] is reversed along its rows with np.flipud. The flip is
    written back into x, and x itself is returned.
    '''
    for channel in x:
        channel[...] = np.flipud(channel)
    return x
def random_barrel_transform(x, intensity):
    '''
    Placeholder for a random barrel / fisheye distortion.

    This is not implemented yet. It raises instead of silently returning
    None, which would otherwise propagate into callers as a bogus "image".

    # Raises
        NotImplementedError: always.
    '''
    raise NotImplementedError("random_barrel_transform is not implemented yet.")
def random_shear(x, intensity, fill_mode="nearest", cval=0.):
    '''
    Apply one random shear to every channel of x (channels, rows, cols).

    # Arguments
        x: rank-3 array, (channels, rows, cols).
        intensity: maximum shear angle in radians. The angle is drawn
            uniformly from [-intensity, intensity].
        fill_mode, cval: passed to scipy.ndimage.affine_transform for points
            that map outside the input.

    # Returns
        The sheared array, same shape as x. The centre of the (rows, cols)
        plane stays fixed.
    '''
    shear = random.uniform(-intensity, intensity)
    # affine_transform maps each OUTPUT coordinate o to the INPUT coordinate
    # matrix.dot(o) + offset. The offset is chosen so the centre maps to itself.
    matrix = np.array([[1., -math.sin(shear)],
                       [0., math.cos(shear)]])
    center = (np.array(x.shape[1:3], dtype="float64") - 1.) / 2.
    offset = center - matrix.dot(center)
    sheared = [ndimage.affine_transform(channel, matrix, offset=offset,
                                        mode=fill_mode, cval=cval)
               for channel in x]
    return np.stack(sheared, axis=0)
def random_channel_shift(x, rg):
    '''
    Shift the intensities of each channel by an independent random offset.

    For every channel x[i], an offset drawn uniformly from [-rg, rg] is added
    to all of its values. Results are clipped to the [min, max] range of the
    whole input, so the shift cannot create out-of-range values.

    # Arguments
        x: array whose first axis indexes channels.
        rg: maximum absolute offset.

    # Returns
        A new array with the same shape as x. x itself is not modified.
    '''
    low, high = np.min(x), np.max(x)
    shifted = [np.clip(channel + random.uniform(-rg, rg), low, high)
               for channel in x]
    return np.stack(shifted, axis=0)
def random_zoom(x, rg, fill_mode="nearest", cval=0.):
    '''
    Randomly shrink axes 1 and 2 of x by independent factors in [1 - rg, 1].

    Axis 0 is left at its original size. scipy.ndimage.zoom resamples to a
    new size, so the returned array is generally smaller than x: its shape
    differs from the input's.
    '''
    # The tuple is evaluated left to right: first draw -> axis 1, second -> axis 2.
    factors = (1., random.uniform(1. - rg, 1.), random.uniform(1. - rg, 1.))
    return ndimage.zoom(x, zoom=factors, mode=fill_mode, cval=cval)
def array_to_img(x, scale=True):
    '''
    Convert a channel-first array (channels, rows, cols) to a PIL image.

    # Arguments
        x: rank-3 array with 3 (RGB) or 1 (grayscale) channels on axis 0.
            It is never modified.
        scale: if True, values are shifted so the minimum is >= 0, then
            rescaled so the maximum becomes 255. An all-zero image stays
            all-zero. If False, values are cast to uint8 as they are.

    # Returns
        A PIL Image: RGB for 3 channels, L (grayscale) for 1 channel.

    # Raises
        ValueError: if x does not have 1 or 3 channels.
    '''
    x = np.asarray(x).transpose(1, 2, 0)
    if scale:
        # transpose() returns a view, so the old in-place arithmetic rescaled
        # the caller's array. Work on a floating-point copy instead; astype
        # copies, and integer input no longer breaks the in-place `/=`.
        x = x.astype(np.promote_types(x.dtype, np.float32))
        x += max(-np.min(x), 0)
        x_max = np.max(x)
        if x_max > 0:  # avoid 0/0 -> NaN for an all-zero image
            x /= x_max
        x *= 255
    n_channels = x.shape[2]
    # For uint8 data Pillow infers the mode from the shape:
    # (rows, cols, 3) -> "RGB" and (rows, cols) -> "L".
    if n_channels == 3:
        return Image.fromarray(x.astype("uint8"))
    if n_channels == 1:
        # mode "L" needs a 2-D array; drop the singleton channel axis
        return Image.fromarray(x[:, :, 0].astype("uint8"))
    raise ValueError("array_to_img expects 1 or 3 channels, got %d" % n_channels)
def img_to_array(img):
    '''
    Convert a PIL image (or any array-like) to a float32 channel-first array.

    # Returns
        A float32 array of shape (channels, rows, cols). Single-band images,
        such as those from load_img(grayscale=True), come back with a
        leading channel axis of size 1.

    # Raises
        ValueError: if the image data is neither 2-D nor 3-D.
    '''
    x = np.asarray(img, dtype='float32')
    if x.ndim == 2:
        # A grayscale image has no channel axis; transpose(2, 0, 1) would fail.
        return x[np.newaxis, :, :]
    if x.ndim == 3:
        return x.transpose(2, 0, 1)
    raise ValueError("Unsupported image array shape: %s" % (x.shape,))
def load_img(path, grayscale=False):
    '''
    Load an image file into memory as a PIL Image.

    # Arguments
        path: path to the image file.
        grayscale: if True, convert the image to mode "L".

    # Returns
        A PIL Image whose pixel data is already loaded. The file handle
        has been closed.
    '''
    # Binary mode is required: the previous open(path) used text mode,
    # which breaks on Python 3. load() reads the pixels before the `with`
    # block closes the file.
    with open(path, "rb") as f:
        img = Image.open(f)
        img.load()
    if grayscale:
        img = img.convert('L')
    return img
def list_pictures(directory, ext='jpg|jpeg|bmp|png'):
    '''
    Return the paths of picture files located directly in `directory`.

    A file counts as a picture when its whole name is word characters, a
    dot, and one of the `|`-separated extensions in `ext`. The match is
    case-insensitive, e.g. "a.jpg" and "b.PNG" both match. Sub-directories
    are neither searched nor returned.

    # Arguments
        directory: path of the directory to scan.
        ext: regular-expression alternation of the accepted extensions.

    # Returns
        A list of `join(directory, name)` paths, in os.listdir order.
    '''
    # Raw strings avoid invalid-escape warnings. \Z anchors the extension
    # at the end of the name, so "a.jpg.bak" and "a.pngx" are rejected.
    pattern = re.compile(r'\w+\.(?:' + ext + r')\Z', re.IGNORECASE)
    return [join(directory, f) for f in listdir(directory)
            if isfile(join(directory, f)) and pattern.match(f)]
class ImageDataGenerator(object):
    '''
    Generate minibatches with real-time data augmentation.

    Each sample is a channel-first array (channels, rows, cols), and datasets
    are stacked along a leading sample axis. The featurewise options
    (featurewise_center, featurewise_std_normalization, zca_whitening) need
    statistics computed by `fit` before `flow` or `standardize` is used.
    '''
    def __init__(self,
                 featurewise_center=True,  # set input mean to 0 over the dataset
                 samplewise_center=False,  # set each sample mean to 0
                 featurewise_std_normalization=True,  # divide inputs by std of the dataset
                 samplewise_std_normalization=False,  # divide each input by its std
                 zca_whitening=False,  # apply ZCA whitening
                 rotation_range=0.,  # degrees (0 to 180)
                 width_shift_range=0.,  # fraction of total width
                 height_shift_range=0.,  # fraction of total height
                 horizontal_flip=False,
                 vertical_flip=False,
                 ):
        # Explicit assignments replace self.__dict__.update(locals()), which
        # also stored a reference cycle (self.self).
        self.featurewise_center = featurewise_center
        self.samplewise_center = samplewise_center
        self.featurewise_std_normalization = featurewise_std_normalization
        self.samplewise_std_normalization = samplewise_std_normalization
        self.zca_whitening = zca_whitening
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip
        self.mean = None  # per-feature mean over the dataset, set by fit()
        self.std = None  # single std over the whole (centered) dataset, set by fit()
        self.principal_components = None  # ZCA whitening matrix, set by fit()

    def flow(self, X, y, batch_size=32, shuffle=False, seed=None, save_to_dir=None, save_prefix="", save_format="jpeg"):
        '''
        Yield (batch_X, batch_y) tuples of augmented, standardized samples.

        # Arguments
            X: array of samples, shape (nb_samples, channels, rows, cols).
                It is never modified.
            y: labels aligned with X along the first axis.
            batch_size: samples per batch. The last batch may be smaller.
            shuffle: if True, samples are visited in a random order. X and y
                are not shuffled in place, and pairs stay aligned.
            seed: if not None (0 included), seeds Python's `random` module,
                which drives both the shuffle order and the augmentations.
            save_to_dir: if set, every augmented sample is also written to
                this directory as "<save_prefix>_<sample number>.<save_format>".
                The sample number counts across batches, so files from
                different batches do not overwrite each other.
        '''
        if seed is not None:
            random.seed(seed)
        nb_sample = X.shape[0]
        index_array = np.arange(nb_sample)
        if shuffle:
            # Build a local numpy RNG seeded from `random`, so `seed` controls
            # the order without reseeding numpy's global RNG. randint needs
            # integer bounds; 10e6 (a float) raises TypeError on Python >= 3.12.
            rng = np.random.RandomState(random.randint(1, 10 ** 7))
            index_array = rng.permutation(nb_sample)
            y = np.asarray(y)
        nb_batch = int(math.ceil(float(nb_sample) / batch_size))
        for b in range(nb_batch):
            start = b * batch_size
            batch_index = index_array[start:start + batch_size]
            nb_samples = len(batch_index)
            bX = np.zeros(tuple([nb_samples] + list(X.shape)[1:]))
            for i, j in enumerate(batch_index):
                x = X[j].astype("float32")  # astype copies, so X is never mutated
                x = self.random_transform(x)
                x = self.standardize(x)
                bX[i] = x
            if save_to_dir:
                for i in range(nb_samples):
                    # Pass a copy so rescaling for display cannot alter the
                    # batch that is yielded below.
                    img = array_to_img(bX[i].copy(), scale=True)
                    fname = save_prefix + "_" + str(start + i) + "." + save_format
                    img.save(join(save_to_dir, fname))
            if shuffle:
                batch_y = y[batch_index]
            else:
                batch_y = y[start:start + nb_samples]
            yield bX, batch_y

    def standardize(self, x):
        '''
        Apply the configured normalizations to one sample x and return it.

        The featurewise centering and scaling steps operate in place on x.

        # Raises
            ValueError: if a featurewise option is enabled but `fit` has not
                computed the statistic it needs.
        '''
        if self.featurewise_center:
            if self.mean is None:
                raise ValueError("featurewise_center is set but fit() has not been called.")
            x -= self.mean
        if self.featurewise_std_normalization:
            if self.std is None:
                raise ValueError("featurewise_std_normalization is set but fit() has not been called.")
            x /= self.std
        if self.zca_whitening:
            if self.principal_components is None:
                raise ValueError("zca_whitening is set but fit() has not been called.")
            flatx = np.reshape(x, (x.shape[0] * x.shape[1] * x.shape[2]))
            whitex = np.dot(flatx, self.principal_components)
            x = np.reshape(whitex, (x.shape[0], x.shape[1], x.shape[2]))
        if self.samplewise_center:
            x -= np.mean(x)
        if self.samplewise_std_normalization:
            x /= np.std(x)
        return x

    def random_transform(self, x):
        '''
        Randomly augment one channel-first sample x and return the result.

        Rotation and shift return new arrays. A chosen flip is applied to x
        in place (each flip has a 50% chance when enabled).
        '''
        if self.rotation_range:
            x = random_rotation(x, self.rotation_range)
        if self.width_shift_range or self.height_shift_range:
            x = random_shift(x, self.width_shift_range, self.height_shift_range)
        if self.horizontal_flip:
            if random.random() < 0.5:
                x = horizontal_flip(x)
        if self.vertical_flip:
            if random.random() < 0.5:
                x = vertical_flip(x)
        # TODO:
        # zoom
        # barrel/fisheye
        # shearing
        # channel shifting
        return x

    def fit(self, X,
            augment=False,  # fit on randomly augmented samples
            rounds=1,  # if augment, how many augmentation passes over the data do we use
            seed=None
            ):
        '''
        Compute the statistics needed by featurewise_center,
        featurewise_std_normalization and zca_whitening.

        # Arguments
            X: array of samples, shape (nb_samples, channels, rows, cols).
                It is not modified. Integer data is converted to float32.
            augment: if True, statistics are computed on `rounds` passes of
                random_transform over X instead of on X itself.
            rounds: number of augmentation passes when augment is True.
            seed: if not None, seeds Python's `random` module first so the
                augmentations are reproducible.
        '''
        if seed is not None:
            random.seed(seed)
        X = np.asarray(X)
        if np.issubdtype(X.dtype, np.floating):
            X = np.copy(X)
        else:
            # Integer input would make the in-place float ops below raise.
            X = X.astype("float32")
        if augment:
            aX = np.zeros(tuple([rounds * X.shape[0]] + list(X.shape)[1:]))
            for r in range(rounds):
                for i in range(X.shape[0]):
                    # Transform the array directly. The old PIL round-trip
                    # rescaled values to 0-255, and random_transform works on
                    # arrays. The copy is needed because flips run in place.
                    aX[i + r * X.shape[0]] = self.random_transform(X[i].copy())
            X = aX
        if self.featurewise_center:
            self.mean = np.mean(X, axis=0)
            X -= self.mean
        if self.featurewise_std_normalization:
            self.std = np.std(X)
            X /= self.std
        if self.zca_whitening:
            flatX = np.reshape(X, (X.shape[0], X.shape[1] * X.shape[2] * X.shape[3]))
            fudge = 10e-6
            # Covariance over features: normalize by the number of SAMPLES
            # (rows of flatX), not the number of features.
            sigma = np.dot(flatX.T, flatX) / flatX.shape[0]
            U, S, V = linalg.svd(sigma)
            self.principal_components = np.dot(np.dot(U, np.diag(1. / np.sqrt(S + fudge))), U.T)