Improve README, functional API guide

Francois Chollet 2016-04-04 21:46:27 -07:00
parent d3615e682e
commit 263de77a5a
2 changed files with 129 additions and 66 deletions

@@ -37,9 +37,9 @@ Keras is compatible with: __Python 2.7-3.5__.
## Getting started: 30 seconds to Keras
The core data structure of Keras is a __model__, a way to organize layers. The main type of model is the [`Sequential`](http://keras.io/models/#sequential) model, a linear stack of layers. For more complex architectures, you should use the [Keras functional API]().
Here's the `Sequential` model:
```python
from keras.models import Sequential
@@ -52,15 +52,15 @@ Stacking layers is as easy as `.add()`:
```python
from keras.layers.core import Dense, Activation
model.add(Dense(output_dim=64, input_dim=100))
model.add(Activation("relu"))
model.add(Dense(output_dim=10))
model.add(Activation("softmax"))
```
Once your model looks good, configure its learning process with `.compile()`:
```python
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
```
If you need to, you can further configure your optimizer. A core principle of Keras is to make things reasonably simple, while allowing the user to be fully in control when they need to (the ultimate control being the easy extensibility of the source code).
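For instance, here is a minimal sketch of passing a configured optimizer instance instead of a string (the hyperparameter values below are illustrative):
```python
from keras.optimizers import SGD
# an optimizer instance exposes its hyperparameters (learning rate, momentum, ...)
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.01, momentum=0.9, nesterov=True),
              metrics=['accuracy'])
```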
@@ -92,9 +92,13 @@ proba = model.predict_proba(X_test, batch_size=32)
Building a network of LSTMs, a deep CNN, a Neural Turing Machine, a word2vec embedder or any other model is just as fast. The ideas behind deep learning are simple, so why should their implementation be painful?
For a more in-depth tutorial about Keras, you can check out:
- [Getting started with the Sequential model]()
- [Getting started with the functional API]()
- [Starter examples]()
In the [examples folder](https://github.com/fchollet/keras/tree/master/examples) of the repository, you will find more advanced models: question-answering with memory networks, text generation with stacked LSTMs, etc.
------------------

@@ -57,74 +57,74 @@ processed_sequences = TimeDistributed(model)(input_sequences)
Here's a good use case for the functional API: models with multiple inputs and outputs. The functional API makes it really easy to manipulate a large number of intertwined datastreams.
Let's consider the following model. We seek to predict how many retweets and likes a news headline will receive on Twitter. The main input to the model will be the headline itself, as a sequence of words, but to spice things up, our model will also have an auxiliary input, receiving extra data such as the time of day when the headline was posted, etc.
The model will also be supervised via two loss functions. Using the main loss function earlier in a model is a good regularization mechanism for deep models.
Here's what our model looks like:
[model graph]
Let's implement it with the functional API.
```python
from keras.layers import Input, Embedding, LSTM, Dense, merge
from keras.models import Model
import numpy as np
# the main input will receive the headline,
# as a sequence of integers (each integer encodes a word).
# The integers will be between 1 and 10000 (a vocabulary of 10000 words)
# and the sequences will be 10 words long.
# Note that we can name any layer by passing it a "name" argument.
main_input = Input(shape=(10,), dtype='int32', name='main_input')
# this embedding layer will encode the input sequence
# into a sequence of dense 512-dimensional vectors
x = Embedding(output_dim=512, input_dim=10000, input_length=10)(main_input)
# an LSTM will transform the vector sequence into a single vector,
# containing information about the entire sequence
lstm_out = LSTM(32)(x)
# here we insert the auxiliary loss, allowing the LSTM and Embedding layer
# to be trained smoothly even though the main loss will be much higher in the model
auxiliary_loss = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)
# at this point we feed into the model our auxiliary input data
# by concatenating it with the LSTM output
auxiliary_input = Input(shape=(5,), name='aux_input')
x = merge([lstm_out, auxiliary_input], mode='concat')
# we stack a deep fully-connected network on top
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
# and finally we add the main logistic regression layer
main_loss = Dense(1, activation='sigmoid', name='main_output')(x)
# this defines a model with two inputs and two outputs
model = Model(input=[main_input, auxiliary_input], output=[main_loss, auxiliary_loss])
# we compile the model and assign a weight of 0.2 to the auxiliary loss.
# to specify `loss_weights` or `loss`, you can use a list or a dictionary.
# here we pass a single loss, so the same loss will be used on all outputs.
model.compile(optimizer='rmsprop', loss='binary_crossentropy',
              loss_weights=[1., 0.2])
# dummy training data
headline_data = np.random.randint(5000, size=(20, 10))
additional_data = np.random.random((20, 5))
labels = np.random.randint(2, size=(20, 1))
# we can train the model by passing it lists of input arrays and target arrays:
model.fit([headline_data, additional_data], [labels, labels],
          nb_epoch=50, batch_size=32)
# since our inputs and outputs are named (we passed them a "name" argument),
# we could also have compiled the model via:
model.compile(optimizer='rmsprop',
              loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
              loss_weights={'main_output': 1., 'aux_output': 0.2})
# and trained it via:
model.fit({'main_input': headline_data, 'aux_input': additional_data},
          {'main_output': labels, 'aux_output': labels},
          nb_epoch=50, batch_size=32)
```
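As a usage sketch (reusing the dummy arrays defined above), prediction with a multi-input, multi-output model follows the same list convention as `fit`:
```python
# returns one array of predictions per model output, in the same order as the outputs
main_preds, aux_preds = model.predict([headline_data, additional_data], batch_size=32)
```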
## Shared layers
Another good use for the functional API is models that use shared layers. Let's take a look at shared layers.
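As a quick illustration (a sketch with assumed input shapes, not taken from this diff), sharing a layer just means reusing the same layer instance on several inputs, so that all branches use the same weights:
```python
from keras.layers import Input, LSTM, Dense, merge
from keras.models import Model

# two sequence inputs with assumed shapes (timesteps, features)
input_a = Input(shape=(140, 256))
input_b = Input(shape=(140, 256))

# the same LSTM instance encodes both inputs, so its weights are shared
shared_lstm = LSTM(64)
encoded_a = shared_lstm(input_a)
encoded_b = shared_lstm(input_b)

# concatenate both encodings and put a logistic regression on top
merged = merge([encoded_a, encoded_b], mode='concat')
predictions = Dense(1, activation='sigmoid')(merged)

model = Model(input=[input_a, input_b], output=predictions)
```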
@@ -225,8 +225,6 @@ Simple enough, right?
The same is true for the properties `input_shape` and `output_shape`: as long as the layer has only one node, or as long as all nodes have the same input/output shape, then the notion of "layer output/input shape" is well defined, and that one shape will be returned by `layer.output_shape`/`layer.input_shape`. But if, for instance, you apply the same `Convolution2D` layer to an input of shape `(3, 32, 32)`, and then to an input of shape `(3, 64, 64)`, the layer will have multiple input/output shapes, and you will have to fetch them via the index of the node they belong to:
```python
a = Input(shape=(3, 32, 32))
b = Input(shape=(3, 64, 64))
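# (a hedged sketch of how this snippet might continue; the rest of the example
# is truncated in this hunk. Convolution2D is assumed to be imported.)
from keras.layers import Convolution2D
# applying the same layer to both inputs creates two nodes with different shapes,
conv = Convolution2D(16, 3, 3, border_mode='same')
conved_a = conv(a)
conved_b = conv(b)
# so the shapes have to be fetched by node index:
assert conv.get_input_shape_at(0) == (None, 3, 32, 32)
assert conv.get_input_shape_at(1) == (None, 3, 64, 64)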
@@ -251,6 +249,8 @@ Code examples are still the best way to get started, so here are a few more.
For more information about the Inception architecture, see [Going Deeper with Convolutions](http://arxiv.org/abs/1409.4842).
```python
from keras.layers import merge, Convolution2D, MaxPooling2D, Input
input_img = Input(shape=(3, 256, 256))
tower_1 = Convolution2D(64, 1, 1, border_mode='same', activation='relu')(input_img)
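# (a hedged sketch of how the inception module might continue; the remaining
# lines of this example are truncated in this hunk)
tower_1 = Convolution2D(64, 3, 3, border_mode='same', activation='relu')(tower_1)
tower_2 = Convolution2D(64, 1, 1, border_mode='same', activation='relu')(input_img)
tower_2 = Convolution2D(64, 5, 5, border_mode='same', activation='relu')(tower_2)
tower_3 = MaxPooling2D((3, 3), strides=(1, 1), border_mode='same')(input_img)
tower_3 = Convolution2D(64, 1, 1, border_mode='same', activation='relu')(tower_3)
# concatenate the parallel towers along the channel axis
output = merge([tower_1, tower_2, tower_3], mode='concat', concat_axis=1)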
@@ -285,12 +285,15 @@ z = merge([x, y], mode='sum')
This model re-uses the same image-processing module on two inputs, to classify whether two MNIST digits are the same digit or different digits.
```python
from keras.layers import merge, Convolution2D, MaxPooling2D, Input, Dense, Flatten
from keras.models import Model
# first, define the vision modules
digit_input = Input(shape=(1, 27, 27))
x = Convolution2D(64, 3, 3)(digit_input)
x = Convolution2D(64, 3, 3)(x)
x = MaxPooling2D((2, 2))(x)
out = Flatten()(x)
vision_model = Model(digit_input, out)
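# (a hedged sketch of the rest of this example, truncated in this hunk:
# the vision module is reused on two digit inputs, and a logistic regression
# on top decides whether the two digits are the same)
digit_a = Input(shape=(1, 27, 27))
digit_b = Input(shape=(1, 27, 27))
# the vision model is shared, weights and all
out_a = vision_model(digit_a)
out_b = vision_model(digit_b)
concatenated = merge([out_a, out_b], mode='concat')
out = Dense(1, activation='sigmoid')(concatenated)
classification_model = Model(input=[digit_a, digit_b], output=out)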
@@ -315,7 +318,46 @@ This model can select the correct one-word answer when asked a natural-language
It works by encoding the question into a vector, encoding the image into a vector, concatenating the two, and training on top a logistic regression over some vocabulary of potential answers.
```python
from keras.layers import Convolution2D, MaxPooling2D, Flatten
from keras.layers import Input, LSTM, Embedding, Dense, merge
from keras.models import Model, Sequential
# first, let's define a vision model using a Sequential model.
# this model will encode an image into a vector.
vision_model = Sequential()
vision_model.add(Convolution2D(64, 3, 3, activation='relu', border_mode='same', input_shape=(3, 224, 224)))
vision_model.add(Convolution2D(64, 3, 3, activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Convolution2D(128, 3, 3, activation='relu', border_mode='same'))
vision_model.add(Convolution2D(128, 3, 3, activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Convolution2D(256, 3, 3, activation='relu', border_mode='same'))
vision_model.add(Convolution2D(256, 3, 3, activation='relu'))
vision_model.add(Convolution2D(256, 3, 3, activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Flatten())
# now let's get a tensor with the output of our vision model:
image_input = Input(shape=(3, 224, 224))
encoded_image = vision_model(image_input)
# next, let's define a language model to encode the question into a vector.
# each question will be at most 100 words long,
# and we will index words as integers from 1 to 9999.
question_input = Input(shape=(100,), dtype='int32')
embedded_question = Embedding(input_dim=10000, output_dim=256, input_length=100)(question_input)
encoded_question = LSTM(256)(embedded_question)
# let's concatenate the question vector and the image vector:
merged = merge([encoded_question, encoded_image], mode='concat')
# and let's train a logistic regression over 1000 words on top:
output = Dense(1000, activation='softmax')(merged)
# this is our final model:
vqa_model = Model(input=[image_input, question_input], output=output)
# the next stage would be training this model on actual data.
```
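As a minimal training sketch (the loss choice and data variables are assumptions, since this step is left to the reader), compiling and fitting could look like:
```python
# assumes answers are one-hot encoded over the 1000-word answer vocabulary
vqa_model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
# hypothetical arrays: image_data, question_data, answer_labels
# vqa_model.fit([image_data, question_data], answer_labels, nb_epoch=10, batch_size=32)
```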
### Video question answering model
@@ -323,5 +365,22 @@ It works by encoding the question into a vector, encoding the image into a vecto
Now that we have trained our image QA model, we can quickly turn it into a video QA model. With appropriate training, you will be able to show it a short video (e.g. 100-frame human action) and ask a natural language question about the video (e.g. "what sport is the boy playing?" -> "football").
```python
from keras.layers import TimeDistributed
video_input = Input(shape=(100, 3, 224, 224))
# this is our video encoded via the previously trained vision_model (weights are reused)
encoded_frame_sequence = TimeDistributed(vision_model)(video_input) # the output will be a sequence of vectors
encoded_video = LSTM(256)(encoded_frame_sequence) # the output will be a vector
# this is a model-level representation of the question encoder, reusing the same weights as before:
question_encoder = Model(input=question_input, output=encoded_question)
# let's use it to encode the question:
video_question_input = Input(shape=(100,), dtype='int32')
encoded_video_question = question_encoder(video_question_input)
# and this is our video question answering model:
merged = merge([encoded_video, encoded_video_question], mode='concat')
output = Dense(1000, activation='softmax')(merged)
video_qa_model = Model(input=[video_input, video_question_input], output=output)
```