keras/examples/antirectifier.py

'''The example demonstrates how to write custom layers for Keras.

We build a custom activation layer called 'Antirectifier',
which modifies the shape of the tensor that passes through it.
We need to specify two methods: `compute_output_shape` and `call`.

Note that the same result can also be achieved via a Lambda layer.

Because our custom layer is written with primitives from the Keras
backend (`K`), our code can run both on TensorFlow and Theano.
'''
from __future__ import print_function
import keras
from keras.models import Sequential
from keras import layers
from keras.datasets import mnist
from keras import backend as K


class Antirectifier(layers.Layer):
    '''This layer combines sample-wise L2 normalization with the
    concatenation of the positive part of the input and the
    negative part of the input. The result is a tensor of samples
    that are twice as large as the input samples.

    It can be used in place of a ReLU.

    # Input shape
        2D tensor of shape (samples, n)

    # Output shape
        2D tensor of shape (samples, 2*n)

    # Theoretical justification
        When applying ReLU, assuming that the distribution
        of the previous output is approximately centered around 0.,
        you are discarding half of your input. This is inefficient.

        Antirectifier returns all-positive outputs, like ReLU,
        but without discarding any data.

        Tests on MNIST show that Antirectifier can train networks
        with half as many parameters, yet with classification accuracy
        comparable to an equivalent ReLU-based network.
    '''

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        assert len(shape) == 2  # only valid for 2D tensors
        shape[-1] *= 2
        return tuple(shape)

    def call(self, inputs):
        # center each sample, then L2-normalize it
        inputs -= K.mean(inputs, axis=1, keepdims=True)
        inputs = K.l2_normalize(inputs, axis=1)
        # concatenate the positive and negative parts,
        # doubling the feature dimension
        pos = K.relu(inputs)
        neg = K.relu(-inputs)
        return K.concatenate([pos, neg], axis=1)
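

# As the module docstring notes, the same result can also be achieved with a
# Lambda layer instead of a custom Layer subclass. A minimal, commented-out
# sketch for illustration only (the script below keeps using the
# Antirectifier class defined above):
#
#     def antirectifier(x):
#         x -= K.mean(x, axis=1, keepdims=True)
#         x = K.l2_normalize(x, axis=1)
#         return K.concatenate([K.relu(x), K.relu(-x)], axis=1)
#
#     def antirectifier_output_shape(input_shape):
#         shape = list(input_shape)
#         shape[-1] *= 2
#         return tuple(shape)
#
#     layers.Lambda(antirectifier, output_shape=antirectifier_output_shape)
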
# global parameters
batch_size = 128
num_classes = 10
epochs = 40

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# build the model
model = Sequential()
model.add(layers.Dense(256, input_shape=(784,)))
model.add(Antirectifier())
model.add(layers.Dropout(0.1))
model.add(layers.Dense(256))
model.add(Antirectifier())
model.add(layers.Dropout(0.1))
model.add(layers.Dense(num_classes))
model.add(layers.Activation('softmax'))

# compile the model
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# train the model
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

# next, compare with an equivalent network
# with 2x bigger Dense layers and ReLU
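
# A minimal sketch of what that comparison might look like. This block is an
# illustration, not part of the original script: the same architecture with
# 512-unit Dense layers and ReLU activations in place of 256 units followed
# by Antirectifier. Note that the ReLU version has roughly twice the
# parameters (e.g. its first layer holds 784*512 + 512 = 401,920 parameters
# vs 784*256 + 256 = 200,960 above).
#
#     reference_model = Sequential()
#     reference_model.add(layers.Dense(512, activation='relu',
#                                      input_shape=(784,)))
#     reference_model.add(layers.Dropout(0.1))
#     reference_model.add(layers.Dense(512, activation='relu'))
#     reference_model.add(layers.Dropout(0.1))
#     reference_model.add(layers.Dense(num_classes, activation='softmax'))
#     reference_model.compile(loss='categorical_crossentropy',
#                             optimizer='rmsprop',
#                             metrics=['accuracy'])
#     reference_model.fit(x_train, y_train,
#                         batch_size=batch_size,
#                         epochs=epochs,
#                         verbose=1,
#                         validation_data=(x_test, y_test))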