'''
Fairly basic set of tools for realtime data augmentation on image data.
Can easily be extended to include new transforms, new preprocessing methods, etc...

The transforms below operate on arrays laid out as (channels, rows, cols):
axis 1 holds image rows (height) and axis 2 holds image columns (width).
'''
import math
import random
from os import listdir
from os.path import isfile, join

import numpy as np
from PIL import Image
from scipy import linalg, ndimage


def random_rotation(x, rg, fill_mode="nearest", cval=0.):
    '''Rotate ``x`` in the (rows, cols) plane by a random angle.

    # Arguments
        x: array shaped (channels, rows, cols).
        rg: maximum rotation in degrees; the angle is drawn from U(-rg, rg).
        fill_mode, cval: how pixels rotated in from outside are filled
            (passed straight to ``scipy.ndimage.rotate``).

    # Returns
        The rotated array, same shape as ``x`` (``reshape=False``).
    '''
    angle = random.uniform(-rg, rg)
    return ndimage.rotate(x, angle, axes=(1, 2), reshape=False,
                          mode=fill_mode, cval=cval)


def random_shift(x, wrg, hrg, fill_mode="nearest", cval=0.):
    '''Translate ``x`` by a random whole number of pixels in both directions.

    # Arguments
        x: array shaped (channels, rows, cols).
        wrg: width shift range as a fraction of the number of columns; the
            horizontal shift is drawn from U(-wrg, wrg) * cols. 0 disables it.
        hrg: height shift range as a fraction of the number of rows; the
            vertical shift is drawn from U(-hrg, hrg) * rows. 0 disables it.
        fill_mode, cval: how pixels shifted in from outside are filled
            (passed straight to ``scipy.ndimage.shift``).

    # Returns
        The shifted array, same shape as ``x``.
    '''
    shift_rows = 0
    shift_cols = 0
    if hrg:
        # Height governs axis 1 (rows); shifts may go up or down.
        shift_rows = int(random.uniform(-hrg, hrg) * x.shape[1])
    if wrg:
        # Width governs axis 2 (cols); shifts may go left or right.
        shift_cols = int(random.uniform(-wrg, wrg) * x.shape[2])
    return ndimage.shift(x, (0, shift_rows, shift_cols),
                         mode=fill_mode, cval=cval)


def horizontal_flip(x):
    '''Mirror every channel of ``x`` left-to-right, in place; returns ``x``.'''
    for i in range(x.shape[0]):
        x[i] = np.fliplr(x[i])
    return x


def vertical_flip(x):
    '''Mirror every channel of ``x`` top-to-bottom, in place; returns ``x``.'''
    for i in range(x.shape[0]):
        x[i] = np.flipud(x[i])
    return x


def random_barrel_transform(x, intensity):
    '''Placeholder for a barrel/fisheye distortion (not implemented yet).'''
    # TODO: implement. Raising beats silently returning None into a pipeline.
    raise NotImplementedError("random_barrel_transform is not implemented yet")


def random_shear(x, intensity):
    '''Placeholder for a random shear (not implemented yet).'''
    # TODO: implement. Raising beats silently returning None into a pipeline.
    raise NotImplementedError("random_shear is not implemented yet")


def random_channel_shift(x, rg):
    '''Placeholder for a random per-channel intensity shift (not implemented yet).'''
    # TODO: implement. Raising beats silently returning None into a pipeline.
    raise NotImplementedError("random_channel_shift is not implemented yet")


def random_zoom(x, rg, fill_mode="nearest", cval=0.):
    '''Zoom out on rows and cols by independent random factors in [1 - rg, 1].

    # Arguments
        x: array shaped (channels, rows, cols); the channel axis is not zoomed.
        rg: maximum zoom-out fraction.
        fill_mode, cval: passed straight to ``scipy.ndimage.zoom``.

    # Returns
        The zoomed array. NOTE: its shape will differ from the input's shape!
    '''
    zoom_w = random.uniform(1. - rg, 1.)
    zoom_h = random.uniform(1. - rg, 1.)
    return ndimage.zoom(x, zoom=(1., zoom_w, zoom_h), mode=fill_mode, cval=cval)
def array_to_img(x, scale=True):
    '''Convert a (channels, rows, cols) array into a PIL image.

    # Arguments
        x: array shaped (channels, rows, cols); 3 channels -> "RGB",
            1 channel -> grayscale "L".
        scale: if True, shift values so the minimum is non-negative and
            stretch them so the maximum becomes 255.

    # Returns
        A PIL ``Image``. The input array is left untouched.

    # Raises
        ValueError: if ``x`` has neither 1 nor 3 channels.
    '''
    # Work on a float copy: transpose() returns a view, so the in-place scaling
    # below would otherwise rewrite the caller's array (e.g. the batch that
    # flow() is about to yield) and would fail outright on integer arrays.
    x = np.asarray(x, dtype="float64").transpose(1, 2, 0).copy()
    if scale:
        x += max(-np.min(x), 0)
        x_max = np.max(x)
        if x_max != 0:  # an all-zero image would otherwise turn into NaNs
            x /= x_max
        x *= 255
    if x.shape[2] == 3:
        # RGB
        return Image.fromarray(x.astype("uint8"), "RGB")
    if x.shape[2] == 1:
        # grayscale: mode "L" needs a 2-D array, so drop the channel axis
        return Image.fromarray(x[:, :, 0].astype("uint8"), "L")
    raise ValueError("Unsupported number of channels: %d" % x.shape[2])


def img_to_array(img):
    '''Convert a PIL image (or array-like) to a float32 (channels, rows, cols) array.

    Grayscale images come back from ``np.asarray`` as 2-D (rows, cols); they get
    a single channel axis so the result is always 3-D.
    '''
    x = np.asarray(img, dtype='float32')
    if x.ndim == 2:
        x = x[:, :, np.newaxis]
    return x.transpose(2, 0, 1)


def load_img(path, grayscale=False):
    '''Open the image at ``path`` with PIL, optionally converting it to mode "L".'''
    img = Image.open(path)
    if grayscale:
        img = img.convert('L')
    return img


def list_pictures(directory, ext='jpg|jpeg|bmp|png'):
    '''Return paths of files directly in ``directory`` whose names match ``ext``.

    ``ext`` is a regex alternation of extensions (e.g. ``'jpg|png'``); a name
    matches when it starts with word characters followed by ``.`` and one of
    those extensions.
    '''
    import re  # `re` is not among the module's top-level imports
    pattern = re.compile(r'([\w]+\.(?:' + ext + '))')
    return [join(directory, f) for f in listdir(directory)
            if isfile(join(directory, f)) and pattern.match(f)]


class ImageDataGenerator(object):
    '''
    Generate minibatches with realtime data augmentation.

    Featurewise options and ZCA whitening use statistics computed by ``fit()``;
    call it before ``flow()``/``standardize()`` when any of them is enabled.
    '''
    def __init__(self,
                 featurewise_center=True,  # set input mean to 0 over the dataset
                 samplewise_center=False,  # set each sample mean to 0
                 featurewise_std_normalization=True,  # divide inputs by std of the dataset
                 samplewise_std_normalization=False,  # divide each input by its std
                 zca_whitening=False,  # apply ZCA whitening
                 rotation_range=0.,  # degrees (0 to 180)
                 width_shift_range=0.,  # fraction of total width
                 height_shift_range=0.,  # fraction of total height
                 horizontal_flip=False,
                 vertical_flip=False,
                 ):
        # Store every constructor argument as an attribute of the same name,
        # but not `self` itself (that would create a pointless self.self cycle).
        params = dict(locals())
        params.pop('self')
        self.__dict__.update(params)
        # Dataset statistics, populated by fit().
        self.mean = None
        self.std = None
        self.principal_components = None

    def flow(self, X, y, batch_size=32, shuffle=False, seed=None,
             save_to_dir=None, save_prefix="", save_format="jpeg"):
        '''Yield ``(bX, y_batch)`` minibatches of augmented, standardized samples.

        # Arguments
            X: array of samples shaped (n, channels, rows, cols).
            y: labels aligned with ``X`` along the first axis.
            batch_size: samples per batch; the last batch may be smaller.
            shuffle: if True, iterate over a random permutation of the samples.
                The caller's ``X`` and ``y`` are never modified.
            seed: if not None, seeds Python's ``random`` (used for the
                permutation and for the random transforms).
            save_to_dir: if set, every augmented sample is also saved there as
                ``<save_prefix>_<sample index>.<save_format>``.

        # Yields
            ``bX`` (float64 array, shape (batch, channels, rows, cols)) and the
            matching slice of ``y``.
        '''
        if seed is not None:  # seed=0 is a valid seed too
            random.seed(seed)
        n = X.shape[0]
        if shuffle:
            # One shared permutation keeps X and y aligned, without shuffling the
            # caller's arrays in place or reseeding numpy's global RNG.
            perm_seed = random.randint(1, 10 ** 7)
            index = np.random.RandomState(perm_seed).permutation(n)
            X = X[index]
            y = np.asarray(y)[index]

        nb_batch = int(math.ceil(float(n) / batch_size))
        for b in range(nb_batch):
            start = b * batch_size
            stop = min(start + batch_size, n)
            bX = np.zeros((stop - start,) + tuple(X.shape[1:]))
            for i, j in enumerate(range(start, stop)):
                x = self.random_transform(X[j].astype("float32"))
                bX[i] = self.standardize(x)
            if save_to_dir:
                for i in range(bX.shape[0]):
                    img = array_to_img(bX[i], scale=True)
                    # Use the global sample index so files from different
                    # batches do not overwrite each other.
                    fname = save_prefix + "_" + str(start + i) + "." + save_format
                    img.save(join(save_to_dir, fname))
            yield bX, y[start:stop]

    def standardize(self, x):
        '''Normalize one (channels, rows, cols) sample; float ``x`` is modified in place.

        Applies, in order and as enabled: featurewise centering and std division
        (statistics from ``fit()``), ZCA whitening, then samplewise centering and
        std division.
        '''
        if self.featurewise_center:
            x -= self.mean
        if self.featurewise_std_normalization:
            x /= self.std
        if self.zca_whitening:
            flatx = np.reshape(x, (x.size,))
            whitex = np.dot(flatx, self.principal_components)
            x = np.reshape(whitex, x.shape)
        if self.samplewise_center:
            x -= np.mean(x)
        if self.samplewise_std_normalization:
            x /= np.std(x)
        return x

    def random_transform(self, x):
        '''Apply the enabled random augmentations to one (channels, rows, cols) sample.

        Flips are performed in place on ``x``; pass a copy if the original matters.
        '''
        if self.rotation_range:
            x = random_rotation(x, self.rotation_range)
        if self.width_shift_range or self.height_shift_range:
            x = random_shift(x, self.width_shift_range, self.height_shift_range)
        if self.horizontal_flip:
            if random.random() < 0.5:
                x = horizontal_flip(x)
        if self.vertical_flip:
            if random.random() < 0.5:
                x = vertical_flip(x)
        # TODO:
        # zoom
        # barrel/fisheye
        # shearing
        # channel shifting
        return x

    def fit(self, X,
            augment=False,  # fit on randomly augmented samples
            rounds=1,  # if augment, how many augmentation passes over the data do we use
            seed=None,
            ):
        '''
        Required for featurewise_center, featurewise_std_normalization and zca_whitening.

        Computes ``self.mean`` (per-feature, shape X.shape[1:]), ``self.std``
        (a single scalar over the whole dataset) and ``self.principal_components``
        (the ZCA whitening matrix), as enabled. ``X`` itself is never modified.
        '''
        # Float copy: protects the caller's data and lets the in-place centering
        # below work on integer (e.g. uint8) image arrays.
        X = np.array(X, dtype="float64")
        if seed is not None:
            random.seed(seed)  # makes augmentation reproducible
        if augment:
            n = X.shape[0]
            aX = np.zeros((rounds * n,) + tuple(X.shape[1:]))
            for r in range(rounds):
                for i in range(n):
                    # random_transform works on arrays (not PIL images), and its
                    # flips mutate their input, so hand it a copy of the sample.
                    aX[i + r * n] = self.random_transform(X[i].copy())
            X = aX

        if self.featurewise_center:
            self.mean = np.mean(X, axis=0)
            X -= self.mean
        if self.featurewise_std_normalization:
            self.std = np.std(X)
            X /= self.std

        if self.zca_whitening:
            flatX = np.reshape(X, (X.shape[0], -1))
            fudge = 10e-6  # regularizer keeping 1/sqrt(eigenvalue) finite
            # Covariance over samples: normalize by the number of samples (rows).
            sigma = np.dot(flatX.T, flatX) / flatX.shape[0]
            U, S, V = linalg.svd(sigma)
            self.principal_components = np.dot(
                np.dot(U, np.diag(1. / np.sqrt(S + fudge))), U.T)