From c61d075abcc17fd87532688898577c6a1c0785e9 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 26 Jan 2016 09:36:35 -0800 Subject: [PATCH] Fix ImageDataGenerator docs --- docs/templates/preprocessing/image.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/templates/preprocessing/image.md b/docs/templates/preprocessing/image.md index a40104e60..dfc119977 100644 --- a/docs/templates/preprocessing/image.md +++ b/docs/templates/preprocessing/image.md @@ -14,7 +14,7 @@ keras.preprocessing.image.ImageDataGenerator(featurewise_center=True, vertical_flip=False) ``` -Generate batches of tensor image data with real-time data augmentation. +Generate batches of tensor image data with real-time data augmentation. The data will be looped over (in batches) indefinitely. - __Arguments__: - __featurewise_center__: Boolean. Set input mean to 0 over the dataset. @@ -62,9 +62,19 @@ datagen = ImageDataGenerator( # (std, mean, and principal components if ZCA whitening is applied) datagen.fit(X_train) +# fits the model on batches with real-time data augmentation: +model.fit_generator(datagen.flow(X_train, Y_train, batch_size=32), + samples_per_epoch=len(X_train), nb_epoch=nb_epoch) + +# here's a more "manual" example for e in range(nb_epoch): print 'Epoch', e - # batch train with realtime data augmentation - for X_batch, Y_batch in datagen.flow(X_train, Y_train): + batches = 0 + for X_batch, Y_batch in datagen.flow(X_train, Y_train, batch_size=32): loss = model.train(X_batch, Y_batch) + batches += 1 + if batches >= len(X_train) / 32: + # we need to break the loop by hand because + # the generator loops indefinitely + break ``` \ No newline at end of file