diff --git a/examples/cifar10_cnn.py b/examples/cifar10_cnn.py index b49524805..14e4212fb 100644 --- a/examples/cifar10_cnn.py +++ b/examples/cifar10_cnn.py @@ -28,6 +28,17 @@ nb_classes = 10 nb_epoch = 200 data_augmentation = True +# shape of the image (SHAPE x SHAPE) +shapex, shapey = 32, 32 +# number of convolutional filters to use at each layer +nb_filters = [32, 64] +# level of pooling to perform at each layer (POOL x POOL) +nb_pool = [2, 2] +# level of convolution to perform at each layer (CONV x CONV) +nb_conv = [3, 3] +# the CIFAR10 images are RGB +image_dimensions = 3 + # the data, shuffled and split between tran and test sets (X_train, y_train), (X_test, y_test) = cifar10.load_data() print('X_train shape:', X_train.shape) @@ -40,22 +51,24 @@ Y_test = np_utils.to_categorical(y_test, nb_classes) model = Sequential() -model.add(Convolution2D(32, 3, 3, 3, border_mode='full')) +model.add(Convolution2D(nb_filters[0], image_dimensions, nb_conv[0], nb_conv[0], border_mode='full')) model.add(Activation('relu')) -model.add(Convolution2D(32, 32, 3, 3)) +model.add(Convolution2D(nb_filters[0], nb_filters[0], nb_conv[0], nb_conv[0])) model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) +model.add(MaxPooling2D(poolsize=(nb_pool[0], nb_pool[0]))) model.add(Dropout(0.25)) -model.add(Convolution2D(64, 32, 3, 3, border_mode='full')) +model.add(Convolution2D(nb_filters[1], nb_filters[0], nb_conv[0], nb_conv[0], border_mode='full')) model.add(Activation('relu')) -model.add(Convolution2D(64, 64, 3, 3)) +model.add(Convolution2D(nb_filters[1], nb_filters[1], nb_conv[1], nb_conv[1])) model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) +model.add(MaxPooling2D(poolsize=(nb_pool[1], nb_pool[1]))) model.add(Dropout(0.25)) model.add(Flatten()) -model.add(Dense(64*8*8, 512)) +# the image dimensions are the original dimensions divided by any pooling +# each pixel has a number of filters, determined by the last Convolution2D layer +model.add(Dense(nb_filters[-1] * (shapex / nb_pool[0] / nb_pool[1]) * (shapey / nb_pool[0] / nb_pool[1]), 512)) model.add(Activation('relu')) model.add(Dropout(0.5))