diff --git a/docs/templates/preprocessing/image.md b/docs/templates/preprocessing/image.md index 163a24e72..bb0f64cfe 100644 --- a/docs/templates/preprocessing/image.md +++ b/docs/templates/preprocessing/image.md @@ -77,8 +77,9 @@ Generate batches of tensor image data with real-time data augmentation. The data The generator loops indefinitely. - __flow_from_directory(directory)__: Takes the path to a directory, and generates batches of augmented/normalized data. Yields batches indefinitely, in an infinite loop. - __Arguments__: - - __directory__: path to the target directory. It should contain one subdirectory per class, - and the subdirectories should contain PNG or JPG images. See [this script](https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d) for more details. + - __directory__: path to the target directory. It should contain one subdirectory per class. + Any PNG, JPG or BNP images inside each of the subdirectories directory tree will be included in the generator. + See [this script](https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d) for more details. - __target_size__: tuple of integers, default: `(256, 256)`. The dimensions to which all images found will be resized. - __color_mode__: one of "grayscale", "rbg". Default: "rgb". Whether the images will be converted to have 1 or 3 color channels. - __classes__: optional list of class subdirectories (e.g. `['dogs', 'cats']`). Default: None. If not provided, the list of classes will be automatically inferred (and the order of the classes, which will map to the label indices, will be alphanumeric). @@ -89,6 +90,7 @@ Generate batches of tensor image data with real-time data augmentation. The data - __save_to_dir__: None or str (default: None). This allows you to optimally specify a directory to which to save the augmented pictures being generated (useful for visualizing what you are doing). - __save_prefix__: str. Prefix to use for filenames of saved pictures (only relevant if `save_to_dir` is set). - __save_format__: one of "png", "jpeg" (only relevant if `save_to_dir` is set). Default: "jpeg". + - __follow_links__: whether to follow symlinks inside class subdirectories (default: False). - __Examples__: diff --git a/keras/preprocessing/image.py b/keras/preprocessing/image.py index 6af220238..18e433858 100644 --- a/keras/preprocessing/image.py +++ b/keras/preprocessing/image.py @@ -195,8 +195,9 @@ def load_img(path, grayscale=False, target_size=None): def list_pictures(directory, ext='jpg|jpeg|bmp|png'): - return [os.path.join(directory, f) for f in sorted(os.listdir(directory)) - if os.path.isfile(os.path.join(directory, f)) and re.match('([\w]+\.(?:' + ext + '))', f)] + return [os.path.join(root, f) + for root, dirs, files in os.walk(directory) for f in files + if re.match('([\w]+\.(?:' + ext + '))', f)] class ImageDataGenerator(object): @@ -300,14 +301,16 @@ class ImageDataGenerator(object): target_size=(256, 256), color_mode='rgb', classes=None, class_mode='categorical', batch_size=32, shuffle=True, seed=None, - save_to_dir=None, save_prefix='', save_format='jpeg'): + save_to_dir=None, save_prefix='', save_format='jpeg', + follow_links=False): return DirectoryIterator( directory, self, target_size=target_size, color_mode=color_mode, classes=classes, class_mode=class_mode, dim_ordering=self.dim_ordering, batch_size=batch_size, shuffle=shuffle, seed=seed, - save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format) + save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format, + follow_links=follow_links) def standardize(self, x): if self.preprocessing_function: @@ -589,7 +592,8 @@ class DirectoryIterator(Iterator): dim_ordering='default', classes=None, class_mode='categorical', batch_size=32, shuffle=True, seed=None, - save_to_dir=None, save_prefix='', save_format='jpeg'): + save_to_dir=None, save_prefix='', save_format='jpeg', + follow_links=False): if dim_ordering == 'default': dim_ordering = K.image_dim_ordering() self.directory = directory @@ -633,16 +637,20 @@ class DirectoryIterator(Iterator): self.nb_class = len(classes) self.class_indices = dict(zip(classes, range(len(classes)))) + def _recursive_list(subpath): + return sorted(os.walk(subpath, followlinks=follow_links), key=lambda tpl: tpl[0]) + for subdir in classes: subpath = os.path.join(directory, subdir) - for fname in sorted(os.listdir(subpath)): - is_valid = False - for extension in white_list_formats: - if fname.lower().endswith('.' + extension): - is_valid = True - break - if is_valid: - self.nb_sample += 1 + for root, dirs, files in _recursive_list(subpath): + for fname in files: + is_valid = False + for extension in white_list_formats: + if fname.lower().endswith('.' + extension): + is_valid = True + break + if is_valid: + self.nb_sample += 1 print('Found %d images belonging to %d classes.' % (self.nb_sample, self.nb_class)) # second, build an index of the images in the different class subfolders @@ -651,16 +659,19 @@ class DirectoryIterator(Iterator): i = 0 for subdir in classes: subpath = os.path.join(directory, subdir) - for fname in sorted(os.listdir(subpath)): - is_valid = False - for extension in white_list_formats: - if fname.lower().endswith('.' + extension): - is_valid = True - break - if is_valid: - self.classes[i] = self.class_indices[subdir] - self.filenames.append(os.path.join(subdir, fname)) - i += 1 + for root, dirs, files in _recursive_list(subpath): + for fname in files: + is_valid = False + for extension in white_list_formats: + if fname.lower().endswith('.' + extension): + is_valid = True + break + if is_valid: + self.classes[i] = self.class_indices[subdir] + i += 1 + # add filename relative to directory + absolute_path = os.path.join(root, fname) + self.filenames.append(os.path.relpath(absolute_path, directory)) super(DirectoryIterator, self).__init__(self.nb_sample, batch_size, shuffle, seed) def next(self): diff --git a/tests/keras/preprocessing/test_image.py b/tests/keras/preprocessing/test_image.py index 660a8388e..2792ac4cd 100644 --- a/tests/keras/preprocessing/test_image.py +++ b/tests/keras/preprocessing/test_image.py @@ -118,6 +118,48 @@ class TestImage: x = np.random.random((32, 3, 10, 10)) generator.fit(x) + def test_directory_iterator(self): + num_classes = 2 + tmp_folder = tempfile.mkdtemp(prefix='test_images') + + # create folders and subfolders + paths = [] + for cl in range(num_classes): + class_directory = 'class-{}'.format(cl) + classpaths = [ + class_directory, + os.path.join(class_directory, 'subfolder-1'), + os.path.join(class_directory, 'subfolder-2'), + os.path.join(class_directory, 'subfolder-1', 'sub-subfolder') + ] + for path in classpaths: + os.mkdir(os.path.join(tmp_folder, path)) + paths.append(classpaths) + + # save the images in the paths + count = 0 + filenames = [] + for test_images in self.all_test_images: + for im in test_images: + # rotate image class + im_class = count % num_classes + # rotate subfolders + classpaths = paths[im_class] + filename = os.path.join(classpaths[count % len(classpaths)], 'image-{}.jpg'.format(count)) + filenames.append(filename) + im.save(os.path.join(tmp_folder, filename)) + count += 1 + + # create iterator + generator = image.ImageDataGenerator() + dir_iterator = generator.flow_from_directory(tmp_folder) + + # check number of classes and images + assert(len(dir_iterator.class_indices) == num_classes) + assert(len(dir_iterator.classes) == count) + assert(sorted(dir_iterator.filenames) == sorted(filenames)) + shutil.rmtree(tmp_folder) + def test_img_utils(self): height, width = 10, 8