Remove magic numbers from cifar10_cnn.py
(fixes #469)
This commit is contained in:
parent
93c1a8c675
commit
49335d4345
@ -28,6 +28,17 @@ nb_classes = 10
|
|||||||
nb_epoch = 200
|
nb_epoch = 200
|
||||||
data_augmentation = True
|
data_augmentation = True
|
||||||
|
|
||||||
|
# shape of the image (SHAPE x SHAPE)
|
||||||
|
shapex, shapey = 32, 32
|
||||||
|
# number of convolutional filters to use at each layer
|
||||||
|
nb_filters = [32, 64]
|
||||||
|
# level of pooling to perform at each layer (POOL x POOL)
|
||||||
|
nb_pool = [2, 2]
|
||||||
|
# level of convolution to perform at each layer (CONV x CONV)
|
||||||
|
nb_conv = [3, 3]
|
||||||
|
# the CIFAR10 images are RGB
|
||||||
|
image_dimensions = 3
|
||||||
|
|
||||||
# the data, shuffled and split between tran and test sets
|
# the data, shuffled and split between tran and test sets
|
||||||
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
|
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
|
||||||
print('X_train shape:', X_train.shape)
|
print('X_train shape:', X_train.shape)
|
||||||
@ -40,22 +51,24 @@ Y_test = np_utils.to_categorical(y_test, nb_classes)
|
|||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
|
|
||||||
model.add(Convolution2D(32, 3, 3, 3, border_mode='full'))
|
model.add(Convolution2D(nb_filters[0], image_dimensions, nb_conv[0], nb_conv[0], border_mode='full'))
|
||||||
model.add(Activation('relu'))
|
model.add(Activation('relu'))
|
||||||
model.add(Convolution2D(32, 32, 3, 3))
|
model.add(Convolution2D(nb_filters[0], nb_filters[0], nb_conv[0], nb_conv[0]))
|
||||||
model.add(Activation('relu'))
|
model.add(Activation('relu'))
|
||||||
model.add(MaxPooling2D(poolsize=(2, 2)))
|
model.add(MaxPooling2D(poolsize=(nb_pool[0], nb_pool[0])))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
|
|
||||||
model.add(Convolution2D(64, 32, 3, 3, border_mode='full'))
|
model.add(Convolution2D(nb_filters[1], nb_filters[0], nb_conv[0], nb_conv[0], border_mode='full'))
|
||||||
model.add(Activation('relu'))
|
model.add(Activation('relu'))
|
||||||
model.add(Convolution2D(64, 64, 3, 3))
|
model.add(Convolution2D(nb_filters[1], nb_filters[1], nb_conv[1], nb_conv[1]))
|
||||||
model.add(Activation('relu'))
|
model.add(Activation('relu'))
|
||||||
model.add(MaxPooling2D(poolsize=(2, 2)))
|
model.add(MaxPooling2D(poolsize=(nb_pool[1], nb_pool[1])))
|
||||||
model.add(Dropout(0.25))
|
model.add(Dropout(0.25))
|
||||||
|
|
||||||
model.add(Flatten())
|
model.add(Flatten())
|
||||||
model.add(Dense(64*8*8, 512))
|
# the image dimensions are the original dimensions divided by any pooling
|
||||||
|
# each pixel has a number of filters, determined by the last Convolution2D layer
|
||||||
|
model.add(Dense(nb_filters[-1] * (shapex / nb_pool[0] / nb_pool[1]) * (shapey / nb_pool[0] / nb_pool[1]), 512))
|
||||||
model.add(Activation('relu'))
|
model.add(Activation('relu'))
|
||||||
model.add(Dropout(0.5))
|
model.add(Dropout(0.5))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user