Update documentation to new API.

Makoto Matsuyama 2015-10-05 16:28:17 -07:00
parent cb77f7d7e2
commit 0b8a52e463
17 changed files with 271 additions and 222 deletions

README.md

@@ -34,13 +34,16 @@ from keras.layers.core import Dense, Dropout, Activation
 from keras.optimizers import SGD
 model = Sequential()
-model.add(Dense(20, 64, init='uniform'))
+# Dense(64) is a fully-connected layer with 64 hidden units.
+# in the first layer, you must specify the expected input data shape:
+# here, 20-dimensional vectors.
+model.add(Dense(64, input_dim=20, init='uniform'))
 model.add(Activation('tanh'))
 model.add(Dropout(0.5))
-model.add(Dense(64, 64, init='uniform'))
+model.add(Dense(64, init='uniform'))
 model.add(Activation('tanh'))
 model.add(Dropout(0.5))
-model.add(Dense(64, 2, init='uniform'))
+model.add(Dense(2, init='uniform'))
 model.add(Activation('softmax'))
 sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
@@ -54,11 +57,11 @@ score = model.evaluate(X_test, y_test, batch_size=16)
 ```python
 model = Sequential()
-model.add(Dense(20, 64, init='uniform', activation='tanh'))
+model.add(Dense(64, input_dim=20, init='uniform', activation='tanh'))
 model.add(Dropout(0.5))
-model.add(Dense(64, 64, init='uniform', activation='tanh'))
+model.add(Dense(64, init='uniform', activation='tanh'))
 model.add(Dropout(0.5))
-model.add(Dense(64, 2, init='uniform', activation='softmax'))
+model.add(Dense(2, init='uniform', activation='softmax'))
 sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
 model.compile(loss='mean_squared_error', optimizer=sgd)
@@ -73,26 +76,29 @@ from keras.layers.convolutional import Convolution2D, MaxPooling2D
 from keras.optimizers import SGD
 model = Sequential()
-model.add(Convolution2D(32, 3, 3, 3, border_mode='full'))
+# input: 100x100 images with 3 channels -> (3, 100, 100) tensors.
+# this applies 32 convolution filters of size 3x3 each.
+model.add(Convolution2D(32, 3, 3, border_mode='full', input_shape=(3, 100, 100)))
 model.add(Activation('relu'))
-model.add(Convolution2D(32, 32, 3, 3))
+model.add(Convolution2D(32, 3, 3))
 model.add(Activation('relu'))
 model.add(MaxPooling2D(poolsize=(2, 2)))
 model.add(Dropout(0.25))
-model.add(Convolution2D(64, 32, 3, 3, border_mode='full'))
+model.add(Convolution2D(64, 3, 3, border_mode='valid'))
 model.add(Activation('relu'))
-model.add(Convolution2D(64, 64, 3, 3))
+model.add(Convolution2D(64, 3, 3))
 model.add(Activation('relu'))
 model.add(MaxPooling2D(poolsize=(2, 2)))
 model.add(Dropout(0.25))
 model.add(Flatten())
-model.add(Dense(64*8*8, 256))
+# Note: Keras does automatic shape inference.
+model.add(Dense(256))
 model.add(Activation('relu'))
 model.add(Dropout(0.5))
-model.add(Dense(256, 10))
+model.add(Dense(10))
 model.add(Activation('softmax'))
 sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
@@ -112,9 +118,9 @@ from keras.layers.recurrent import LSTM
 model = Sequential()
 model.add(Embedding(max_features, 256))
-model.add(LSTM(256, 128, activation='sigmoid', inner_activation='hard_sigmoid'))
+model.add(LSTM(output_dim=128, activation='sigmoid', inner_activation='hard_sigmoid'))
 model.add(Dropout(0.5))
-model.add(Dense(128, 1))
+model.add(Dense(1))
 model.add(Activation('sigmoid'))
 model.compile(loss='binary_crossentropy', optimizer='rmsprop')
@@ -126,51 +132,67 @@ score = model.evaluate(X_test, Y_test, batch_size=16)
 ### Architecture for learning image captions with a convnet and a Gated Recurrent Unit:
 (word-level embedding, caption of maximum length 16 words).
-Note that getting this to actually "work" will require using a bigger convnet, initialized with pre-trained weights.
+Note that getting this to work well will require using a bigger convnet, initialized with pre-trained weights.
+Displaying readable results will also require an embedding decoder.
 ```python
 max_caption_len = 16
-model = Sequential()
-model.add(Convolution2D(32, 3, 3, 3, border_mode='full'))
-model.add(Activation('relu'))
-model.add(Convolution2D(32, 32, 3, 3))
-model.add(Activation('relu'))
-model.add(MaxPooling2D(poolsize=(2, 2)))
-model.add(Convolution2D(64, 32, 3, 3, border_mode='full'))
-model.add(Activation('relu'))
-model.add(Convolution2D(64, 64, 3, 3))
-model.add(Activation('relu'))
-model.add(MaxPooling2D(poolsize=(2, 2)))
-model.add(Convolution2D(128, 64, 3, 3, border_mode='full'))
-model.add(Activation('relu'))
-model.add(Convolution2D(128, 128, 3, 3))
-model.add(Activation('relu'))
-model.add(MaxPooling2D(poolsize=(2, 2)))
-model.add(Flatten())
-model.add(Dense(128*4*4, 256))
-model.add(Activation('relu'))
-model.add(Dropout(0.5))
-model.add(RepeatVector(max_caption_len))
-# the GRU below returns sequences of max_caption_len vectors of size 256 (our word embedding size)
-model.add(GRU(256, 256, return_sequences=True))
-model.compile(loss='mean_squared_error', optimizer='rmsprop')
-# "images" is a numpy array of shape (nb_samples, nb_channels=3, width, height)
-# "captions" is a numpy array of shape (nb_samples, max_caption_len=16, embedding_dim=256)
-# captions are supposed already embedded (dense vectors).
-model.fit(images, captions, batch_size=16, nb_epoch=100)
+vocab_size = 10000
+# first, let's define an image model that
+# will encode pictures into 128-dimensional vectors.
+# it should be initialized with pre-trained weights.
+image_model = Sequential()
+image_model.add(Convolution2D(32, 3, 3, border_mode='full', input_shape=(3, 100, 100)))
+image_model.add(Activation('relu'))
+image_model.add(Convolution2D(32, 3, 3))
+image_model.add(Activation('relu'))
+image_model.add(MaxPooling2D(poolsize=(2, 2)))
+image_model.add(Convolution2D(64, 3, 3, border_mode='full'))
+image_model.add(Activation('relu'))
+image_model.add(Convolution2D(64, 3, 3))
+image_model.add(Activation('relu'))
+image_model.add(MaxPooling2D(poolsize=(2, 2)))
+image_model.add(Flatten())
+image_model.add(Dense(128))
+# let's load the weights from a save file.
+image_model.load_weights('weight_file.h5')
+# next, let's define a RNN model that encodes sequences of words
+# into sequences of 128-dimensional word vectors.
+language_model = Sequential()
+language_model.add(Embedding(vocab_size, 256, input_length=max_caption_len))
+language_model.add(GRU(output_dim=128, return_sequences=True))
+language_model.add(Dense(128))
+# let's repeat the image vector to turn it into a sequence.
+image_model.add(RepeatVector(max_caption_len))
+# the output of both models will be tensors of shape (samples, max_caption_len, 128).
+# let's concatenate these 2 vector sequences.
+model = Merge([image_model, language_model], mode='concat', concat_axis=-1)
+# let's encode this vector sequence into a single vector
+model.add(GRU(256, 256, return_sequences=False))
+# which will be used to compute a probability
+# distribution over what the next word in the caption should be!
+model.add(Dense(vocab_size))
+model.add(Activation('softmax'))
+model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+# "images" is a numpy float array of shape (nb_samples, nb_channels=3, width, height).
+# "captions" is a numpy integer array of shape (nb_samples, max_caption_len)
+# containing word index sequences representing partial captions.
+# "next_words" is a numpy float array of shape (nb_samples, vocab_size)
+# containing a categorical encoding (0s and 1s) of the next word in the corresponding
+# partial caption.
+model.fit([images, partial_captions], next_words, batch_size=16, nb_epoch=100)
 ```
 In the examples folder, you will find example models for real datasets:
-- CIFAR10 small images classification: Convnet with realtime data augmentation
+- CIFAR10 small images classification: Convolutional Neural Network (CNN) with realtime data augmentation
 - IMDB movie review sentiment classification: LSTM over sequences of words
 - Reuters newswires topic classification: Multilayer Perceptron (MLP)
 - MNIST handwritten digits classification: MLP & CNN
@@ -183,7 +205,7 @@ In the examples folder, you will find example models for real datasets:
 For complete coverage of the API, check out [the Keras documentation](http://keras.io).
-A few highlights: convnets, LSTM, GRU, word2vec-style embeddings, PReLU, batch normalization...
+A few highlights: convnets, LSTM, GRU, word2vec-style embeddings, PReLU, BatchNormalization...
 ## Installation
@@ -196,7 +218,7 @@ Keras uses the following dependencies:
 - HDF5 and h5py (optional, required if you use model saving/loading functions)
 - Optional but recommended if you use CNNs: cuDNN.
-Once you have the dependencies installed, cd to the Keras folder and run the install command:
+To install, `cd` to the Keras folder and run the install command:
 ```
 sudo python setup.py install
 ```

@@ -6,12 +6,12 @@ Activations can either be used through an `Activation` layer, or through the `ac
 ```python
 from keras.layers.core import Activation, Dense
-model.add(Dense(64, 64, init='uniform'))
+model.add(Dense(64))
 model.add(Activation('tanh'))
 ```
 is equivalent to:
 ```python
-model.add(Dense(20, 64, init='uniform', activation='tanh'))
+model.add(Dense(64, activation='tanh'))
 ```
 You can also pass an element-wise Theano function as an activation:
@@ -20,7 +20,7 @@ You can also pass an element-wise Theano function as an activation:
 def tanh(x):
     return theano.tensor.tanh(x)
-model.add(Dense(20, 64, init='uniform', activation=tanh))
+model.add(Dense(64, activation=tanh))
 model.add(Activation(tanh))
 ```

@@ -75,7 +75,7 @@ class LossHistory(keras.callbacks.Callback):
         self.losses.append(logs.get('loss'))
 model = Sequential()
-model.add(Dense(784, 10, init='uniform'))
+model.add(Dense(10, input_dim=784, init='uniform'))
 model.add(Activation('softmax'))
 model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
@@ -97,7 +97,7 @@ print history.losses
 from keras.callbacks import ModelCheckpoint
 model = Sequential()
-model.add(Dense(784, 10, init='uniform'))
+model.add(Dense(10, input_dim=784, init='uniform'))
 model.add(Activation('softmax'))
 model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
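For context, a minimal sketch of how the checkpoint callback above might be wired into training under the updated API; the data arrays (`X_train`, `Y_train`) and the file path are placeholders, and the exact `ModelCheckpoint` options should be checked against the callbacks reference:

```python
from keras.callbacks import ModelCheckpoint

# save weights after every epoch, keeping only the best model seen so far
checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5', verbose=1, save_best_only=True)
model.fit(X_train, Y_train, batch_size=128, nb_epoch=20,
          validation_split=0.1, callbacks=[checkpointer])
```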

@@ -12,7 +12,7 @@ These layers expose 2 keyword arguments:
 ```python
 from keras.constraints import maxnorm
-model.add(Dense(64, 64, W_constraint = maxnorm(2)))
+model.add(Dense(64, W_constraint = maxnorm(2)))
 ```
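As a further illustration, a hedged sketch combining a max-norm constraint on the weight matrix with a non-negativity constraint on the bias; this assumes `nonneg` is among the constraints listed below:

```python
from keras.constraints import maxnorm, nonneg

# constrain the weight matrix norm and keep the bias non-negative
model.add(Dense(64, W_constraint=maxnorm(2), b_constraint=nonneg()))
```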
## Available constraints

@@ -1,7 +1,7 @@
 Here are a few examples to get you started!
-### Multilayer Perceptron (MLP)
+### Multilayer Perceptron (MLP):
 ```python
 from keras.models import Sequential
@ -9,13 +9,16 @@ from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD from keras.optimizers import SGD
model = Sequential() model = Sequential()
model.add(Dense(20, 64, init='uniform')) # Dense(64) is a fully-connected layer with 64 hidden units.
# in the first layer, you must specify the expected input data shape:
# here, 20-dimensional vectors.
model.add(Dense(64, input_dim=20, init='uniform'))
model.add(Activation('tanh')) model.add(Activation('tanh'))
model.add(Dropout(0.5)) model.add(Dropout(0.5))
model.add(Dense(64, 64, init='uniform')) model.add(Dense(64, init='uniform'))
model.add(Activation('tanh')) model.add(Activation('tanh'))
model.add(Dropout(0.5)) model.add(Dropout(0.5))
model.add(Dense(64, 2, init='uniform')) model.add(Dense(2, init='uniform'))
model.add(Activation('softmax')) model.add(Activation('softmax'))
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True) sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
@@ -25,25 +28,21 @@ model.fit(X_train, y_train, nb_epoch=20, batch_size=16)
 score = model.evaluate(X_test, y_test, batch_size=16)
 ```
----
-### Alternative implementation of MLP
+### Alternative implementation of MLP:
 ```python
 model = Sequential()
-model.add(Dense(20, 64, init='uniform', activation='tanh'))
+model.add(Dense(64, input_dim=20, init='uniform', activation='tanh'))
 model.add(Dropout(0.5))
-model.add(Dense(64, 64, init='uniform', activation='tanh'))
+model.add(Dense(64, init='uniform', activation='tanh'))
 model.add(Dropout(0.5))
-model.add(Dense(64, 2, init='uniform', activation='softmax'))
+model.add(Dense(2, init='uniform', activation='softmax'))
 sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
 model.compile(loss='mean_squared_error', optimizer=sgd)
 ```
----
-### VGG-like convnet
+### VGG-like convnet:
 ```python
 from keras.models import Sequential
@ -52,26 +51,29 @@ from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD from keras.optimizers import SGD
model = Sequential() model = Sequential()
model.add(Convolution2D(32, 3, 3, 3, border_mode='full')) # input: 100x100 images with 3 channels -> (3, 100, 100) tensors.
# this applies 32 convolution filters of size 3x3 each.
model.add(Convolution2D(32, 3, 3, border_mode='full', input_shape=(3, 100, 100)))
model.add(Activation('relu')) model.add(Activation('relu'))
model.add(Convolution2D(32, 32, 3, 3)) model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu')) model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2))) model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Convolution2D(64, 32, 3, 3, border_mode='full')) model.add(Convolution2D(64, 3, 3, border_mode='valid'))
model.add(Activation('relu')) model.add(Activation('relu'))
model.add(Convolution2D(64, 64, 3, 3)) model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu')) model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2))) model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Dropout(0.25)) model.add(Dropout(0.25))
model.add(Flatten()) model.add(Flatten())
model.add(Dense(64*8*8, 256)) # Note: Keras does automatic shape inference.
model.add(Dense(256))
model.add(Activation('relu')) model.add(Activation('relu'))
model.add(Dropout(0.5)) model.add(Dropout(0.5))
model.add(Dense(256, 10)) model.add(Dense(10))
model.add(Activation('softmax')) model.add(Activation('softmax'))
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True) sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
@@ -81,9 +83,7 @@ model.fit(X_train, Y_train, batch_size=32, nb_epoch=1)
 ```
----
-### Sequence classification with LSTM
+### Sequence classification with LSTM:
 ```python
 from keras.models import Sequential
@ -92,11 +92,10 @@ from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM from keras.layers.recurrent import LSTM
model = Sequential() model = Sequential()
# Add a mask_zero=True to the Embedding connstructor if 0 is a left-padding value in your data
model.add(Embedding(max_features, 256)) model.add(Embedding(max_features, 256))
model.add(LSTM(256, 128, activation='sigmoid', inner_activation='hard_sigmoid')) model.add(LSTM(output_dim=128, activation='sigmoid', inner_activation='hard_sigmoid'))
model.add(Dropout(0.5)) model.add(Dropout(0.5))
model.add(Dense(128, 1)) model.add(Dense(1))
model.add(Activation('sigmoid')) model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='rmsprop') model.compile(loss='binary_crossentropy', optimizer='rmsprop')
@ -105,59 +104,73 @@ model.fit(X_train, Y_train, batch_size=16, nb_epoch=10)
score = model.evaluate(X_test, Y_test, batch_size=16) score = model.evaluate(X_test, Y_test, batch_size=16)
``` ```
--- ### Architecture for learning image captions with a convnet and a Gated Recurrent Unit:
(word-level embedding, caption of maximum length 16 words).
### Image captioning Note that getting this to work well will require using a bigger convnet, initialized with pre-trained weights.
Architecture for learning image captions with a convnet and a Gated Recurrent Unit (word-level embedding, caption of maximum length 16 words).
Note that getting this to actually "work" will require using a bigger convnet, initialized with pre-trained weights.
Displaying readable results will also require an embedding decoder.
```python ```python
max_caption_len = 16 max_caption_len = 16
vocab_size = 10000
model = Sequential() # first, let's define an image model that
model.add(Convolution2D(32, 3, 3, 3, border_mode='full')) # will encode pictures into 128-dimensional vectors.
model.add(Activation('relu')) # it should be initialized with pre-trained weights.
model.add(Convolution2D(32, 32, 3, 3)) image_model = Sequential()
model.add(Activation('relu')) image_model.add(Convolution2D(32, 3, 3, border_mode='full', input_shape=(3, 100, 100)))
model.add(MaxPooling2D(poolsize=(2, 2))) image_model.add(Activation('relu'))
image_model.add(Convolution2D(32, 3, 3))
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Convolution2D(64, 32, 3, 3, border_mode='full')) image_model.add(Convolution2D(64, 3, 3, border_mode='full'))
model.add(Activation('relu')) image_model.add(Activation('relu'))
model.add(Convolution2D(64, 64, 3, 3)) image_model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu')) image_model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2))) image_model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Convolution2D(128, 64, 3, 3, border_mode='full')) image_model.add(Flatten())
model.add(Activation('relu')) image_model.add(Dense(128))
model.add(Convolution2D(128, 128, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Flatten()) # let's load the weights from a save file.
model.add(Dense(128*4*4, 256)) image_model.load_weights('weight_file.h5')
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(RepeatVector(max_caption_len)) # next, let's define a RNN model that encodes sequences of words
# the GRU below returns sequences of max_caption_len vectors of size 256 (our word embedding size) # into sequences of 128-dimensional word vectors.
model.add(GRU(256, 256, return_sequences=True)) language_model = Sequential()
language_model.add(Embedding(vocab_size, 256, input_length=max_caption_len))
language_model.add(GRU(output_dim=128, return_sequences=True))
language_model.add(Dense(128))
model.compile(loss='mean_squared_error', optimizer='rmsprop') # let's repeat the image vector to turn it into a sequence.
image_model.add(RepeatVector(max_caption_len))
# "images" is a numpy array of shape (nb_samples, nb_channels=3, width, height) # the output of both models will be tensors of shape (samples, max_caption_len, 128).
# "captions" is a numpy array of shape (nb_samples, max_caption_len=16, embedding_dim=256) # let's concatenate these 2 vector sequences.
# captions are supposed already embedded (dense vectors). model = Merge([image_model, language_model], mode='concat', concat_axis=-1)
model.fit(images, captions, batch_size=16, nb_epoch=100) # let's encode this vector sequence into a single vector
model.add(GRU(256, 256, return_sequences=False))
# which will be used to compute a probability
# distribution over what the next word in the caption should be!
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
# "images" is a numpy float array of shape (nb_samples, nb_channels=3, width, height).
# "captions" is a numpy integer array of shape (nb_samples, max_caption_len)
# containing word index sequences representing partial captions.
# "next_words" is a numpy float array of shape (nb_samples, vocab_size)
# containing a categorical encoding (0s and 1s) of the next word in the corresponding
# partial caption.
model.fit([images, partial_captions], next_words, batch_size=16, nb_epoch=100)
``` ```
----
-In the [examples folder](https://github.com/fchollet/keras/tree/master/examples), you will find example models for real datasets:
-- CIFAR10 small images classification: Convnet with realtime data augmentation
+In the examples folder, you will find example models for real datasets:
+- CIFAR10 small images classification: Convolutional Neural Network (CNN) with realtime data augmentation
 - IMDB movie review sentiment classification: LSTM over sequences of words
-- Reuters newswires topic classification: Multilayer Perceptron
+- Reuters newswires topic classification: Multilayer Perceptron (MLP)
+- MNIST handwritten digits classification: MLP & CNN
+- Character-level text generation with LSTM
+...and more.

@@ -46,9 +46,9 @@ Stacking layers is as easy as `.add()`:
 ```python
 from keras.layers.core import Dense, Activation
-model.add(Dense(input_dim=100, output_dim=64, init="glorot_uniform"))
+model.add(Dense(output_dim=64, input_dim=100, init="glorot_uniform"))
 model.add(Activation("relu"))
-model.add(Dense(input_dim=64, output_dim=10, init="glorot_uniform"))
+model.add(Dense(output_dim=10, init="glorot_uniform"))
 model.add(Activation("softmax"))
 ```
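After stacking layers, the model would typically be compiled and fit; a minimal sketch continuing the snippet above, with the optimizer settings and the `X_train`/`Y_train` arrays as placeholders:

```python
from keras.optimizers import SGD

# configure the learning process, then iterate on the training data in batches
model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01, momentum=0.9, nesterov=True))
model.fit(X_train, Y_train, nb_epoch=5, batch_size=32)
```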

@@ -6,7 +6,7 @@ Initializations define the probability distribution used to set the initial rand
 The keyword arguments used for passing initializations to layers will depend on the layer. Usually it is simply `init`:
 ```python
-model.add(Dense(64, 64, init='uniform'))
+model.add(Dense(64, init='uniform'))
 ```
## Available initializations ## Available initializations

@@ -7,7 +7,8 @@ keras.layers.advanced_activations.LeakyReLU(alpha=0.3)
 Special version of a Rectified Linear Unit that allows a small gradient when the unit is not active (`f(x) = alpha*x for x < 0`).
-- __Input shape__: This layer does not assume a specific input shape. As a result, it cannot be used as the first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: Same as input.
@@ -19,18 +20,16 @@ Special version of a Rectified Linear Unit that allows a small gradient when the
 ## PReLU
 ```python
-keras.layers.advanced_activations.PReLU(input_shape)
+keras.layers.advanced_activations.PReLU()
 ```
 Parametrized linear unit. Similar to a LeakyReLU, where each input unit has its alpha coefficient, and where these coefficients are learned during training.
-- __Input shape__: Same as `input_shape`. This layer cannot be used as first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: Same as input.
-- __Arguments__:
-    - __input_shape__: tuple.
 - __References__:
     - [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](http://arxiv.org/pdf/1502.01852v1.pdf)
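A short sketch of the updated usage: the advanced activation is simply added after a regular layer and no longer takes an `input_shape` of its own (the layer sizes here are arbitrary):

```python
from keras.models import Sequential
from keras.layers.core import Dense
from keras.layers.advanced_activations import PReLU

model = Sequential()
model.add(Dense(64, input_dim=20))
model.add(PReLU())  # learns one alpha coefficient per input unit
```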
@@ -39,18 +38,15 @@ Parametrized linear unit. Similar to a LeakyReLU, where each input unit has its
 ## ParametricSoftplus
 ```python
-keras.layers.advanced_activations.ParametricSoftplus(input_shape)
+keras.layers.advanced_activations.ParametricSoftplus()
 ```
 Parametric Softplus of the form: (`f(x) = alpha * (1 + exp(beta * x))`). This is essentially a smooth version of ReLU where the parameters control the sharpness of the rectification. The parameters are initialized to more closely approximate a ReLU than the standard `softplus`: `alpha` initialized to `0.2` and `beta` initialized to `5.0`. The parameters are fit separately for each hidden unit.
-- __Input shape__: Same as `input_shape`. This layer cannot be used as first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape=...` when using this layer as the first layer in a model.
 - __Output shape__: Same as input.
-- __Arguments__:
-    - __input_shape__: tuple.
 - __References__:
     - [Inferring Nonlinear Neuronal Computation Based on Physiologically Plausible Inputs](http://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1003143)
@@ -62,7 +58,8 @@ keras.layers.advanced_activations.ThresholdedLinear(theta)
 Parametrized linear unit. provides a threshold near zero where values are zeroed.
-- __Input shape__: Same as `input_shape`. This layer cannot be used as first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: Same as input.
@@ -80,7 +77,7 @@ keras.layers.advanced_activations.ThresholdedReLu(theta)
 Parametrized rectified linear unit. provides a threshold near zero where values are zeroed.
-- __Input shape__: Same as `input_shape`. This layer cannot be used as first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape=...` when using this layer as the first layer in a model.
 - __Output shape__: Same as input.

@@ -2,22 +2,20 @@
 ## Convolution1D
 ```python
-keras.layers.convolutional.Convolution1D(input_dim, nb_filter, filter_length,
+keras.layers.convolutional.Convolution1D(nb_filter, filter_length,
     init='uniform', activation='linear', weights=None,
     border_mode='valid', subsample_length=1,
     W_regularizer=None, b_regularizer=None, W_constraint=None,
-    b_constraint=None)
+    b_constraint=None, input_dim=None, input_length=None)
 ```
-Convolution operator for filtering neighborhoods of one-dimensional inputs.
+Convolution operator for filtering neighborhoods of one-dimensional inputs. When using this layer as the first layer in a model, either provide the keyword argument `input_dim` (int, e.g. 128 for sequences of 128-dimensional vectors), or `input_shape` (tuple of integers, e.g. (10, 128) for sequences of 10 vectors of 128-dimensional vectors).
 - __Input shape__: 3D tensor with shape: `(nb_samples, steps, input_dim)`.
 - __Output shape__: 3D tensor with shape: `(nb_samples, steps, nb_filter)`. `steps` value might have changed due to padding.
 - __Arguments__:
-    - __input_dim__: Number of channels/dimensions in the input.
     - __nb_filter__: Number of convolution kernels to use (dimensionality of the output).
     - __filter_length__: The extension (spatial or temporal) of each filter.
     - __init__: name of initialization function for the weights of the layer (see: [initializations](../initializations.md)), or alternatively, Theano function to use for weights initialization. This parameter is only relevant if you don't pass a `weights` argument.
@@ -30,31 +28,32 @@ Convolution operator for filtering neighborhoods of one-dimensional inputs.
     - __activity_regularizer__: instance of [ActivityRegularizer](../regularizers.md), applied to the network output.
     - __W_constraint__: instance of the [constraints](../constraints.md) module (eg. maxnorm, nonneg), applied to the main weights matrix.
     - __b_constraint__: instance of the [constraints](../constraints.md) module, applied to the bias.
+    - __input_dim__: Number of channels/dimensions in the input. Either this argument or the keyword argument `input_shape` must be provided when using this layer as the first layer in a model.
+    - __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
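A brief sketch of the updated constructor on sequence data, assuming 10-step sequences of 32-dimensional vectors (all sizes are illustrative):

```python
from keras.models import Sequential
from keras.layers.convolutional import Convolution1D

model = Sequential()
# input: (nb_samples, 10, 32) -> output: (nb_samples, new_steps, 64)
model.add(Convolution1D(64, 3, border_mode='valid', input_shape=(10, 32)))
```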
 ---
 ## Convolution2D
 ```python
-keras.layers.convolutional.Convolution2D(nb_filter, stack_size, nb_row, nb_col,
+keras.layers.convolutional.Convolution2D(nb_filter, nb_row, nb_col,
     init='glorot_uniform', activation='linear', weights=None,
     border_mode='valid', subsample=(1, 1),
     W_regularizer=None, b_regularizer=None, W_constraint=None)
 ```
-Convolution operator for filtering windows of two-dimensional inputs.
+Convolution operator for filtering windows of two-dimensional inputs. When using this layer as the first layer in a model, provide the keyword argument `input_shape` (tuple of integers, does not include the sample axis), e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.
-- __Input shape__: 4D tensor with shape: `(nb_samples, stack_size, nb_row, nb_col)`.
+- __Input shape__: 4D tensor with shape: `(nb_samples, channels, rows, cols)`.
-- __Output shape__: 4D tensor with shape: `(nb_samples, nb_filter, nb_row, nb_col)`. `nb_row`, `nb_col` might have changed due to padding.
+- __Output shape__: 4D tensor with shape: `(nb_samples, nb_filter, rows, cols)`. `rows`, `cols` might have changed due to padding.
 - __Arguments__:
-    - __nb_filter__: Number of convolution kernels to use.
-    - __stack_size__: Number of channels in the input.
-    - __nb_row__: Number of rows in the convolution kernels
-    - __nb_col__: Number of columns in the convolution kernels
+    - __nb_filter__: Number of convolution filters to use.
+    - __nb_row__: Number of rows in the convolution kernel.
+    - __nb_col__: Number of columns in the convolution kernel.
     - __init__: name of initialization function for the weights of the layer (see: [initializations](../initializations.md)), or alternatively, Theano function to use for weights initialization. This parameter is only relevant if you don't pass a `weights` argument.
     - __activation__: name of activation function to use (see: [activations](../activations.md)), or alternatively, elementwise Theano function. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).
     - __weights__: list of numpy arrays to set as initial weights.
@@ -90,7 +89,7 @@ keras.layers.convolutional.MaxPooling1D(pool_length=2, stride=None, ignore_borde
 ## MaxPooling2D
 ```python
-keras.layers.convolutional.MaxPooling2D(poolsize=(2, 2), ignore_border=True)
+keras.layers.convolutional.MaxPooling2D(pool_size=(2, 2), ignore_border=True)
 ```
 - __Input shape__: 4D tensor with shape: `(nb_samples, stack_size, nb_row, nb_col)`.
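Putting the updated 2D layers together, a hedged sketch for 32x32 RGB inputs (the filter counts are arbitrary):

```python
from keras.models import Sequential
from keras.layers.convolutional import Convolution2D, MaxPooling2D

model = Sequential()
# input: (nb_samples, 3, 32, 32) tensors
model.add(Convolution2D(32, 3, 3, border_mode='valid', input_shape=(3, 32, 32)))
model.add(MaxPooling2D(pool_size=(2, 2)))
```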

@@ -76,8 +76,9 @@ get_config()
 ## Dense
 ```python
-keras.layers.core.Dense(input_dim, output_dim, init='glorot_uniform', activation='linear', weights=None \
-    W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None)
+keras.layers.core.Dense(output_dim, init='glorot_uniform', activation='linear', weights=None
+    W_regularizer=None, b_regularizer=None, activity_regularizer=None,
+    W_constraint=None, b_constraint=None, input_dim=None)
 ```
 Standard 1D fully-connect layer.
@@ -88,7 +89,6 @@ Standard 1D fully-connect layer.
 - __Arguments__:
-    - __input_dim__: int >= 0.
     - __output_dim__: int >= 0.
     - __init__: name of initialization function for the weights of the layer (see: [initializations](../initializations.md)), or alternatively, Theano function to use for weights initialization. This parameter is only relevant if you don't pass a `weights` argument.
     - __activation__: name of activation function to use (see: [activations](../activations.md)), or alternatively, elementwise Theano function. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).
@@ -98,21 +98,22 @@ Standard 1D fully-connect layer.
     - __activity_regularizer__: instance of [ActivityRegularizer](../regularizers.md), applied to the network output.
     - __W_constraint__: instance of the [constraints](../constraints.md) module (eg. maxnorm, nonneg), applied to the main weights matrix.
     - __b_constraint__: instance of the [constraints](../constraints.md) module, applied to the bias.
+    - __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
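A two-line sketch of the new `Dense` usage: only the first layer declares its input dimensionality, and later layers infer their input shape automatically:

```python
from keras.models import Sequential
from keras.layers.core import Dense

model = Sequential()
model.add(Dense(64, input_dim=20, init='glorot_uniform', activation='relu'))
model.add(Dense(10, activation='softmax'))  # input shape (nb_samples, 64) is inferred
```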
 ---
 ## TimeDistributedDense
 ```python
-keras.layers.core.TimeDistributedDense(input_dim, output_dim, init='glorot_uniform', activation='linear', weights=None \
-    W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None)
+keras.layers.core.TimeDistributedDense(output_dim, init='glorot_uniform', activation='linear', weights=None
+    W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None,
+    input_dim=None, input_length=None)
 ```
 Fully-connected layer distributed over the time dimension. Useful after a recurrent network set to `return_sequences=True`.
-- __Input shape__: 3D tensor with shape: `(nb_samples, nb_timesteps, input_dim)`.
+- __Input shape__: 3D tensor with shape: `(nb_samples, timesteps, input_dim)`.
 - __Arguments__:
-    - __input_dim__: int >= 0.
     - __output_dim__: int >= 0.
     - __init__: name of initialization function for the weights of the layer (see: [initializations](../initializations.md)), or alternatively, Theano function to use for weights initialization. This parameter is only relevant if you don't pass a `weights` argument.
     - __activation__: name of activation function to use (see: [activations](../activations.md)), or alternatively, elementwise Theano function. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).
@@ -122,12 +123,14 @@ Fully-connected layer distributed over the time dimension. Useful after a recurr
     - __activity_regularizer__: instance of [ActivityRegularizer](../regularizers.md), applied to the network output.
     - __W_constraint__: instance of the [constraints](../constraints.md) module (eg. maxnorm, nonneg), applied to the main weights matrix.
     - __b_constraint__: instance of the [constraints](../constraints.md) module, applied to the bias.
+    - __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
+    - __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
 - __Example__:
 ```python
-# input shape: (nb_samples, nb_timesteps, 10)
+# input shape: (nb_samples, timesteps, 10)
-model.add(LSTM(10, 5, return_sequences=True)) # output shape: (nb_samples, nb_timesteps, 5)
+model.add(LSTM(5, return_sequences=True, input_dim=10)) # output shape: (nb_samples, timesteps, 5)
-model.add(TimeDistributedDense(5, 10)) # output shape: (nb_samples, nb_timesteps, 10)
+model.add(TimeDistributedDense(15)) # output shape: (nb_samples, timesteps, 15)
 ```
@@ -160,8 +163,8 @@ A customizable autoencoder model. If `output_reconstruction = True` then dim(inp
 from keras.layers import containers
 # input shape: (nb_samples, 32)
-encoder = containers.Sequential([Dense(32, 16), Dense(16, 8)])
-decoder = containers.Sequential([Dense(8, 16), Dense(16, 32)])
+encoder = containers.Sequential([Dense(16, input_dim=32), Dense(8)])
+decoder = containers.Sequential([Dense(16, input_dim=8), Dense(32)])
 autoencoder = Sequential()
 autoencoder.add(AutoEncoder(encoder=encoder, decoder=decoder, output_reconstruction=False))
@@ -176,7 +179,8 @@ keras.layers.core.Activation(activation)
 ```
 Apply an activation function to the input.
-- __Input shape__: This layer does not assume a specific input shape. As a result, it cannot be used as the first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: Same as input.
@@ -193,7 +197,8 @@ keras.layers.core.Dropout(p)
 ```
 Apply dropout to the input. Dropout consists in randomly setting a fraction `p` of input units to 0 at each update during training time, which helps prevent overfitting. Reference: [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
-- __Input shape__: This layer does not assume a specific input shape.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: Same as input.
@@ -206,24 +211,25 @@ Apply dropout to the input. Dropout consists in randomly setting a fraction `p`
 ## Reshape
 ```python
-keras.layers.core.Reshape(*dims)
+keras.layers.core.Reshape(dims)
 ```
 Reshape the input to a new shape containing the same number of units.
-- __Input shape__: This layer does not assume a specific input shape.
-- __Output shape__: `(nb_samples, *dims)`.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
+- __Output shape__: `(nb_samples, dims)`.
 - __Arguments__:
-    - *dims: integers. Dimensions of the new shape.
+    - dims: tuple of integers. Dimensions of the new shape.
 - __Example__:
 ```python
 # input shape: (nb_samples, 10)
-model.add(Dense(10, 100)) # output shape: (nb_samples, 100)
+model.add(Dense(100, input_dim=10)) # output shape: (nb_samples, 100)
-model.add(Reshape(10, 10)) # output shape: (nb_samples, 10, 10)
+model.add(Reshape(dims=(10, 10))) # output shape: (nb_samples, 10, 10)
 ```
 ---
@@ -235,7 +241,7 @@ keras.layers.core.Flatten()
 Convert a nD input to 1D.
-- __Input shape__: (nb_samples, *). This layer cannot be used as the first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: `(nb_samples, nb_input_units)`.
@@ -250,7 +256,7 @@ Repeat the 1D input n times. Dimensions of input are assumed to be `(nb_samples,
 Note that the output is still a single tensor; `RepeatVector` does not split the data flow.
-- __Input shape__: This layer does not assume a specific input shape. This layer cannot be used as the first layer in a model.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: `(nb_samples, n, input_dims)`.
@@ -265,7 +271,7 @@ keras.layers.core.Permute(dims)
 ```
 Permute the dimensions of the input data according to the given tuple. Sometimes useful for connecting RNNs and convnets together.
-- __Input shape__: This layer does not assume a specific input shape.
+- __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
 - __Output shape__: Same as the input shape, but with the dimensions re-ordered according to the ordering specified by the tuple.
@@ -274,9 +280,9 @@ Permute the dimensions of the input data according to the given tuple. Sometimes
 - __Example__:
 ```python
 # input shape: (nb_samples, 10)
-model.add(Dense(10, 50)) # output shape: (nb_samples, 50)
+model.add(Dense(50, input_dim=10)) # output shape: (nb_samples, 50)
-model.add(Reshape(10, 5)) # output shape: (nb_samples, 10, 5)
+model.add(Reshape(dims=(10, 5))) # output shape: (nb_samples, 10, 5)
-model.add(Permute((2, 1))) #output shape: (nb_samples, 5, 10)
+model.add(Permute(dims=(2, 1))) #output shape: (nb_samples, 5, 10)
 ```
 ---
@@ -294,8 +300,9 @@ This layer can be used, for instance, to induce activation sparsity in the previ
 ## MaxoutDense
 ```python
-keras.layers.core.MaxoutDense(input_dim, output_dim, nb_feature=4, init='glorot_uniform', weights=None, \
-    W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None)
+keras.layers.core.MaxoutDense(output_dim, nb_feature=4, init='glorot_uniform', weights=None,
+    W_regularizer=None, b_regularizer=None, activity_regularizer=None,
+    W_constraint=None, b_constraint=None, input_dim=None)
 ```
 A dense maxout layer. A `MaxoutDense` layer takes the element-wise maximum of `nb_feature` `Dense(input_dim, output_dim)` linear layers. This allows the layer to learn a convex, piecewise linear activation function over the inputs. See [this paper](http://arxiv.org/pdf/1302.4389.pdf) for more details. Note that this is a *linear* layer -- if you wish to apply activation function (you shouldn't need to -- they are universal function approximators), an `Activation` layer must be added after.
@@ -306,7 +313,6 @@ A dense maxout layer. A `MaxoutDense` layer takes the element-wise maximum of `n
 - __Arguments__:
-    - __input_dim__: int >= 0.
     - __output_dim__: int >= 0.
     - __nb_feature__: int >= 0. the number of features to create for the maxout. This is equivalent to the number of piecewise elements to be allowed for the activation function.
     - __init__: name of initialization function for the weights of the layer (see: [initializations](../initializations.md)), or alternatively, Theano function to use for weights initialization. This parameter is only relevant if you don't pass a `weights` argument.
@@ -316,12 +322,12 @@ A dense maxout layer. A `MaxoutDense` layer takes the element-wise maximum of `n
     - __activity_regularizer__: instance of [ActivityRegularizer](../regularizers.md), applied to the network output.
     - __W_constraint__: instance of the [constraints](../constraints.md) module (eg. maxnorm, nonneg), applied to the main weights matrix.
     - __b_constraint__: instance of the [constraints](../constraints.md) module, applied to the bias.
+    - __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
 ```python
 # input shape: (nb_samples, 10)
-model.add(Dense(10, 100)) # output shape: (nb_samples, 100)
+model.add(Dense(100, input_dim=10)) # output shape: (nb_samples, 100)
-model.add(MaxoutDense(100, 100, nb_feature=10)) # output shape: (nb_samples, 100)
+model.add(MaxoutDense(50, nb_feature=10)) # output shape: (nb_samples, 50)
-model.add(RepeatVector(2)) # output shape: (nb_samples, 2, 10)
 ```
 ## Merge
@@ -339,17 +345,17 @@ Merge the output of a list of layers (or containers) into a single tensor, follo
 ```python
 left = Sequential()
-left.add(Dense(784, 50))
+left.add(Dense(50, input_shape=(784,)))
 left.add(Activation('relu'))
 right = Sequential()
-right.add(Dense(784, 50))
+right.add(Dense(50, input_shape=(784,)))
 right.add(Activation('relu'))
 model = Sequential()
 model.add(Merge([left, right], mode='sum'))
-model.add(Dense(50, 10))
+model.add(Dense(10))
 model.add(Activation('softmax'))
 model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

@@ -2,15 +2,15 @@
 ## Embedding
 ```python
-keras.layers.embeddings.Embedding(input_dim, output_dim, init='uniform', weights=None, W_regularizer=None, W_constraint=None, mask_zero=False)
+keras.layers.embeddings.Embedding(input_dim, output_dim, init='uniform', weights=None, W_regularizer=None, W_constraint=None, mask_zero=False, max_length=None)
 ```
 Turn positive integers (indexes) into denses vectors of fixed size,
 eg. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`
-- __Input shape__: 2D tensor with shape: `(nb_samples, maxlen)`.
+- __Input shape__: 2D tensor with shape: `(nb_samples, sequence_length)`.
-- __Output shape__: 3D tensor with shape: `(nb_samples, maxlen, output_dim)`.
+- __Output shape__: 3D tensor with shape: `(nb_samples, sequence_length, output_dim)`.
 - __Arguments__:
@@ -21,6 +21,7 @@ eg. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`
     - __W_regularizer__: instance of the [regularizers](../regularizers.md) module (eg. L1 or L2 regularization), applied to the embedding matrix.
     - __W_constraint__: instance of the [constraints](../constraints.md) module (eg. maxnorm, nonneg), applied to the embedding matrix.
     - __mask_zero__: Whether or not the input value 0 is a special "padding" value that should be masked out. This is useful for [recurrent layers](recurrent.md) which may take variable length input. If this is `True` then all subsequent layers in the model need to support masking or an exception will be raised.
+    - __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
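A short sketch of an embedding front-end for padded integer sequences; the vocabulary size and lengths are placeholders, and the sequence-length keyword is assumed to be `input_length` as in the argument list above (the signature shows it as `max_length`, so check which name your version expects):

```python
from keras.models import Sequential
from keras.layers.embeddings import Embedding

model = Sequential()
# input: (nb_samples, 10) integer matrix -> output: (nb_samples, 10, 128)
model.add(Embedding(10000, 128, input_length=10, mask_zero=True))
```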
## WordContextProduct

@ -6,9 +6,9 @@ keras.layers.noise.GaussianNoise(sigma)
``` ```
Apply to the input an additive zero-centred gaussian noise with standard deviation `sigma`. This is useful to mitigate overfitting (you could see it as a kind of random data augmentation). Gaussian Noise (GS) is a natural choice as corruption process for real valued inputs. Apply to the input an additive zero-centred gaussian noise with standard deviation `sigma`. This is useful to mitigate overfitting (you could see it as a kind of random data augmentation). Gaussian Noise (GS) is a natural choice as corruption process for real valued inputs.
The Gaussian noise is only added at training time. Only active at training time.
- __Input shape__: This layer does not assume a specific input shape. - __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
- __Output shape__: Same as input. - __Output shape__: Same as input.
@ -24,11 +24,9 @@ keras.layers.noise.GaussianDropout(p)
``` ```
Apply multiplicative one-centred Gaussian noise with standard deviation `sqrt(p/(1-p))` to the input. `p` refers to the drop probability, to match the Dropout layer syntax. Apply multiplicative one-centred Gaussian noise with standard deviation `sqrt(p/(1-p))` to the input. `p` refers to the drop probability, to match the Dropout layer syntax.
http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf Only active at training time.
The Gaussian noise is only used at training time. - __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
- __Input shape__: This layer does not assume a specific input shape.
- __Output shape__: Same as input. - __Output shape__: Same as input.
@ -36,3 +34,4 @@ The Gaussian noise is only used at training time.
- __p__: float, drop probability as with Dropout. - __p__: float, drop probability as with Dropout.
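
A sketch using it as a drop-in replacement for `Dropout` (layer sizes are arbitrary):

```python
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.noise import GaussianDropout

model = Sequential()
model.add(Dense(64, input_dim=20))
model.add(Activation('relu'))
# p=0.5 -> multiplicative noise with std sqrt(0.5/(1-0.5)) = 1.0, active at training time only.
model.add(GaussianDropout(0.5))
model.add(Dense(1))
model.compile(loss='mse', optimizer='sgd')
```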

@ -2,17 +2,16 @@
## BatchNormalization ## BatchNormalization
```python ```python
keras.layers.normalization.BatchNormalization(input_shape, epsilon=1e-6, weights=None) keras.layers.normalization.BatchNormalization(epsilon=1e-6, weights=None)
``` ```
Normalize the activations of the previous layer at each batch. Normalize the activations of the previous layer at each batch.
- __Input shape__: Same as `input_shape`. This layer cannot be used as first layer in a model. - __Input shape__: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
- __Output shape__: Same as input. - __Output shape__: Same as input.
- __Arguments__: - __Arguments__:
- __input_shape__: tuple.
- __epsilon__: small float > 0. Fuzz parameter. - __epsilon__: small float > 0. Fuzz parameter.
- __weights__: Initialization weights. List of 2 numpy arrays, with shapes: `[(input_shape,), (input_shape,)]` - __weights__: Initialization weights. List of 2 numpy arrays, with shapes: `[(input_shape,), (input_shape,)]`
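
A minimal sketch, normalizing the activations of a `Dense` layer before the nonlinearity (layer sizes are arbitrary):

```python
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.normalization import BatchNormalization

model = Sequential()
model.add(Dense(64, input_dim=20))
# Normalize the 64-dimensional Dense activations per batch; no input_shape is needed
# here because this is not the first layer of the model.
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dense(1))
model.compile(loss='mse', optimizer='sgd')
```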

@ -2,9 +2,9 @@
## SimpleRNN ## SimpleRNN
```python ```python
keras.layers.recurrent.SimpleRNN(input_dim, output_dim, keras.layers.recurrent.SimpleRNN(output_dim,
init='glorot_uniform', inner_init='orthogonal', activation='sigmoid', weights=None, init='glorot_uniform', inner_init='orthogonal', activation='sigmoid', weights=None,
truncate_gradient=-1, return_sequences=False) truncate_gradient=-1, return_sequences=False, input_dim=None, input_length=None)
``` ```
Fully connected RNN where the output is fed back to the input. Fully connected RNN where the output is fed back to the input.
@ -18,23 +18,25 @@ Fully connected RNN where output is to fed back to input.
- __Arguments__: - __Arguments__:
- __input_dim__: dimension of the input.
- __output_dim__: dimension of the internal projections and the final output. - __output_dim__: dimension of the internal projections and the final output.
- __init__: weight initialization function. Can be the name of an existing function (str), or a Theano function (see: [initializations](../initializations.md)). - __init__: weight initialization function. Can be the name of an existing function (str), or a Theano function (see: [initializations](../initializations.md)).
- __activation__: activation function. Can be the name of an existing function (str), or a Theano function (see: [activations](../activations.md)). - __activation__: activation function. Can be the name of an existing function (str), or a Theano function (see: [activations](../activations.md)).
- __weights__: list of numpy arrays to set as initial weights. The list should have 3 elements, of shapes: `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`. - __weights__: list of numpy arrays to set as initial weights. The list should have 3 elements, of shapes: `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
- __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html). - __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html).
- __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence. - __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence.
- __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
- __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
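
A sketch of first-layer usage (the 16-dimensional inputs and 32 output units are illustration values):

```python
from keras.models import Sequential
from keras.layers.core import Dense
from keras.layers.recurrent import SimpleRNN

model = Sequential()
# input_dim is required because the RNN is the first layer of the model;
# with return_sequences=False (the default) the output shape is (nb_samples, 32).
model.add(SimpleRNN(32, input_dim=16))
model.add(Dense(1))
model.compile(loss='mse', optimizer='rmsprop')
```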
--- ---
## SimpleDeepRNN ## SimpleDeepRNN
```python ```python
keras.layers.recurrent.SimpleDeepRNN(input_dim, output_dim, depth=3, keras.layers.recurrent.SimpleDeepRNN(output_dim, depth=3,
init='glorot_uniform', inner_init='orthogonal', init='glorot_uniform', inner_init='orthogonal',
activation='sigmoid', inner_activation='hard_sigmoid', activation='sigmoid', inner_activation='hard_sigmoid',
weights=None, truncate_gradient=-1, return_sequences=False) weights=None, truncate_gradient=-1, return_sequences=False,
input_dim=None, input_length=None)
``` ```
Fully connected RNN where the output of multiple timesteps (up to "depth" steps in the past) is fed back to the input: Fully connected RNN where the output of multiple timesteps (up to "depth" steps in the past) is fed back to the input:
@ -64,6 +66,8 @@ Not a particularly useful model, included for demonstration purposes.
- __weights__: list of numpy arrays to set as initial weights. The list should have depth+2 elements. - __weights__: list of numpy arrays to set as initial weights. The list should have depth+2 elements.
- __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html). - __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html).
- __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence. - __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence.
- __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
- __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
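
The constructor follows the same conventions as `SimpleRNN`; a brief sketch (depth, sizes and lengths are illustration values):

```python
from keras.models import Sequential
from keras.layers.recurrent import SimpleDeepRNN

model = Sequential()
# Feed the outputs of the 3 previous timesteps back into the input.
model.add(SimpleDeepRNN(32, depth=3, input_dim=16, input_length=10))
model.compile(loss='mse', optimizer='rmsprop')
```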
--- ---
@ -74,7 +78,8 @@ Not a particularly useful model, included for demonstration purposes.
keras.layers.recurrent.GRU(input_dim, output_dim=128, keras.layers.recurrent.GRU(output_dim=128,
init='glorot_uniform', inner_init='orthogonal', init='glorot_uniform', inner_init='orthogonal',
activation='sigmoid', inner_activation='hard_sigmoid', activation='sigmoid', inner_activation='hard_sigmoid',
weights=None, truncate_gradient=-1, return_sequences=False) weights=None, truncate_gradient=-1, return_sequences=False,
input_dim=None, input_length=None)
``` ```
Gated Recurrent Unit - Cho et al. 2014. Gated Recurrent Unit - Cho et al. 2014.
@ -97,6 +102,8 @@ Gated Recurrent Unit - Cho et al. 2014.
- __weights__: list of numpy arrays to set as initial weights. The list should have 9 elements. - __weights__: list of numpy arrays to set as initial weights. The list should have 9 elements.
- __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html). - __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html).
- __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence. - __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence.
- __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
- __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
- __References__: - __References__:
- [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](http://www.aclweb.org/anthology/W14-4012) - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](http://www.aclweb.org/anthology/W14-4012)
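
A common sketch: integer sequences fed through an `Embedding`, then a `GRU` (vocabulary size and dimensions are illustration values; `input_dim` is inferred from the embedding output, so it is not needed here):

```python
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import GRU

model = Sequential()
model.add(Embedding(1000, 64))   # word indexes -> 64-dimensional vectors
model.add(GRU(128))              # last output only: (nb_samples, 128)
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='rmsprop')
```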
@ -110,7 +117,8 @@ Gated Recurrent Unit - Cho et al. 2014.
keras.layers.recurrent.LSTM(input_dim, output_dim=128, keras.layers.recurrent.LSTM(output_dim=128,
init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one', init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',
activation='tanh', inner_activation='hard_sigmoid', activation='tanh', inner_activation='hard_sigmoid',
weights=None, truncate_gradient=-1, return_sequences=False) weights=None, truncate_gradient=-1, return_sequences=False,
input_dim=None, input_length=None)
``` ```
Long Short-Term Memory unit - Hochreiter 1997. Long Short-Term Memory unit - Hochreiter 1997.
@ -134,6 +142,8 @@ Long-Short Term Memory unit - Hochreiter 1997.
- __weights__: list of numpy arrays to set as initial weights. The list should have 12 elements. - __weights__: list of numpy arrays to set as initial weights. The list should have 12 elements.
- __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html). - __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html).
- __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence. - __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence.
- __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
- __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
- __References__: - __References__:
- [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf) (original 1997 paper) - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf) (original 1997 paper)
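
A sketch of stacking two LSTMs with `return_sequences` (all sizes and lengths are illustration values):

```python
from keras.models import Sequential
from keras.layers.core import Dense
from keras.layers.recurrent import LSTM

model = Sequential()
# First layer: input_dim is required; return_sequences=True keeps the full
# (nb_samples, 10, 128) sequence so a second LSTM can be stacked on top.
model.add(LSTM(128, return_sequences=True, input_dim=16, input_length=10))
model.add(LSTM(128))   # returns only the last output: (nb_samples, 128)
model.add(Dense(1))
model.compile(loss='mse', optimizer='rmsprop')
```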
@ -148,7 +158,8 @@ Long-Short Term Memory unit - Hochreiter 1997.
keras.layers.recurrent.JZS1(input_dim, output_dim=128, keras.layers.recurrent.JZS1(output_dim=128,
init='glorot_uniform', inner_init='orthogonal', init='glorot_uniform', inner_init='orthogonal',
activation='tanh', inner_activation='sigmoid', activation='tanh', inner_activation='sigmoid',
weights=None, truncate_gradient=-1, return_sequences=False) weights=None, truncate_gradient=-1, return_sequences=False,
input_dim=None, input_length=None)
``` ```
Top 3 RNN architectures evolved from the evaluation of thousands of models. They serve as alternatives to LSTMs and GRUs, and correspond to the `MUT1`, `MUT2`, and `MUT3` architectures described in the paper: An Empirical Exploration of Recurrent Network Architectures, Jozefowicz et al. 2015. Top 3 RNN architectures evolved from the evaluation of thousands of models. They serve as alternatives to LSTMs and GRUs, and correspond to the `MUT1`, `MUT2`, and `MUT3` architectures described in the paper: An Empirical Exploration of Recurrent Network Architectures, Jozefowicz et al. 2015.
@ -171,6 +182,8 @@ Top 3 RNN architectures evolved from the evaluation of thousands of models. Serv
- __weights__: list of numpy arrays to set as initial weights. The list should have 9 elements. - __weights__: list of numpy arrays to set as initial weights. The list should have 9 elements.
- __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html). - __truncate_gradient__: Number of steps to use in truncated BPTT. See: [Theano "scan"](http://deeplearning.net/software/theano/library/scan.html).
- __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence. - __return_sequences__: Boolean. Whether to return the last output in the output sequence, or the full sequence.
- __input_dim__: dimensionality of the input (integer). This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model.
- __input_length__: Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).
- __References__: - __References__:
- [An Empirical Exploration of Recurrent Network Architectures](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) - [An Empirical Exploration of Recurrent Network Architectures](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
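
Usage mirrors the other recurrent layers; a sketch (sizes and lengths are illustration values):

```python
from keras.models import Sequential
from keras.layers.core import Dense
from keras.layers.recurrent import JZS1

model = Sequential()
# Drop-in alternative to LSTM/GRU, with the same input_dim / input_length conventions.
model.add(JZS1(64, input_dim=16, input_length=10))
model.add(Dense(1))
model.compile(loss='mse', optimizer='rmsprop')
```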

@ -52,7 +52,7 @@ from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD from keras.optimizers import SGD
model = Sequential() model = Sequential()
model.add(Dense(64, 2, init='uniform')) model.add(Dense(2, init='uniform', input_dim=64))
model.add(Activation('softmax')) model.add(Activation('softmax'))
model.compile(loss='mse', optimizer='sgd') model.compile(loss='mse', optimizer='sgd')
@ -125,10 +125,10 @@ Arbitrary connection graph. It can have any number of inputs and outputs, with e
model = keras.models.Graph() model = keras.models.Graph()
``` ```
- __Methods__: - __Methods__:
- __add_input__(name, ndim=2, dtype='float'): Add an input with shape dimensionality `ndim`. - __add_input__(name, input_shape, dtype='float'): Add an input with shape `input_shape`.
- __Arguments__: - __Arguments__:
- __ndim__: Use `ndim=2` for vector input `(samples, features)`, ndim=3 for temporal input `(samples, time, features)`, ndim=4 for image input `(samples, channels, height, width)`. - __input_shape__: Integer tuple, shape of the expected input (not including the samples axis). E.g. (10,) for 10-dimensional vectors, (None, 128) for sequences (of variable length) of 128-dimensional vectors, (3, 32, 32) for 32x32 images with RGB channels.
- __dtype__: `float` or `int`. Use `int` if the input is connected to an Embedding layer, `float` otherwise. - __dtype__: `float` or `int`. Type of the expected input data.
- __add_output__(name, input=None, inputs=[], merge_mode='concat'): Add an output connected to `input` or `inputs`. - __add_output__(name, input=None, inputs=[], merge_mode='concat'): Add an output connected to `input` or `inputs`.
- __Arguments__: - __Arguments__:
- __name__: str. unique identifier of the output. - __name__: str. unique identifier of the output.
@ -176,10 +176,10 @@ __Examples__:
```python ```python
# graph model with one input and two outputs # graph model with one input and two outputs
graph = Graph() graph = Graph()
graph.add_input(name='input', ndim=2) graph.add_input(name='input', input_shape=(32,))
graph.add_node(Dense(32, 16), name='dense1', input='input') graph.add_node(Dense(16), name='dense1', input='input')
graph.add_node(Dense(32, 4), name='dense2', input='input') graph.add_node(Dense(4), name='dense2', input='input')
graph.add_node(Dense(16, 4), name='dense3', input='dense1') graph.add_node(Dense(4), name='dense3', input='dense1')
graph.add_output(name='output1', input='dense2') graph.add_output(name='output1', input='dense2')
graph.add_output(name='output2', input='dense3') graph.add_output(name='output2', input='dense3')
@ -191,11 +191,11 @@ history = graph.fit({'input':X_train, 'output1':y_train, 'output2':y2_train}, nb
```python ```python
# graph model with two inputs and one output # graph model with two inputs and one output
graph = Graph() graph = Graph()
graph.add_input(name='input1', ndim=2) graph.add_input(name='input1', input_shape=(32,))
graph.add_input(name='input2', ndim=2) graph.add_input(name='input2', input_shape=(32,))
graph.add_node(Dense(32, 16), name='dense1', input='input1') graph.add_node(Dense(16), name='dense1', input='input1')
graph.add_node(Dense(32, 4), name='dense2', input='input2') graph.add_node(Dense(4), name='dense2', input='input2')
graph.add_node(Dense(16, 4), name='dense3', input='dense1') graph.add_node(Dense(4), name='dense3', input='dense1')
graph.add_output(name='output', inputs=['dense2', 'dense3'], merge_mode='sum') graph.add_output(name='output', inputs=['dense2', 'dense3'], merge_mode='sum')
graph.compile('rmsprop', {'output':'mse'}) graph.compile('rmsprop', {'output':'mse'})
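# Training and prediction both take dicts keyed by input/output name. A sketch,
# assuming X1_train, X2_train, y_train and the test arrays are numpy arrays of
# compatible shapes:
history = graph.fit({'input1': X1_train, 'input2': X2_train, 'output': y_train},
                    nb_epoch=10, batch_size=32)
# predict returns a dict keyed by output name
predictions = graph.predict({'input1': X1_test, 'input2': X2_test})['output']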

@ -5,7 +5,7 @@ An optimizer is one of the two arguments required for compiling a Keras model:
```python ```python
model = Sequential() model = Sequential()
model.add(Dense(20, 64, init='uniform')) model.add(Dense(64, init='uniform', input_dim=10))
model.add(Activation('tanh')) model.add(Activation('tanh'))
model.add(Activation('softmax')) model.add(Activation('softmax'))
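# The optimizer can then be passed either by name (default parameters are used)
# or as a configured instance; a sketch continuing the snippet above:
from keras.optimizers import SGD

model.compile(loss='mean_squared_error', optimizer='rmsprop')   # by name

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)     # or as an instance
model.compile(loss='mean_squared_error', optimizer=sgd)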

@ -15,7 +15,7 @@ These layers expose 3 keyword arguments:
```python ```python
from keras.regularizers import l2, activity_l2 from keras.regularizers import l2, activity_l2
model.add(Dense(64, 64, W_regularizer=l2(0.01), activity_regularizer=activity_l2(0.01))) model.add(Dense(64, input_dim=64, W_regularizer=l2(0.01), activity_regularizer=activity_l2(0.01)))
``` ```
## Available penalties ## Available penalties