Style fixes

This commit is contained in:
Francois Chollet 2016-11-05 13:45:50 -07:00
parent fd326ddf1b
commit 9d4087a1e9
3 changed files with 23 additions and 25 deletions

@ -1,6 +1,6 @@
""" This script demonstrate the use of convolutional LSTM network
This network is used to predict the next frame of an artificialy
generated movie which contain moving squares.
""" This script demonstrates the use of a convolutional LSTM network.
This network is used to predict the next frame of an artificially
generated movie which contains moving squares.
"""
from keras.models import Sequential
from keras.layers.convolutional import Convolution3D
@ -10,7 +10,7 @@ import numpy as np
import pylab as plt
# We create a layer which take as input movies of shape
# (n_frames, width, height, channel) and that returns a movie
# (n_frames, width, height, channels) and returns a movie
# of identical shape.
seq = Sequential()
@ -38,12 +38,12 @@ seq.add(Convolution3D(nb_filter=1, kernel_dim1=1, kernel_dim2=3,
seq.compile(loss='binary_crossentropy', optimizer='adadelta')
# Generating artificial data:
# Artificial data generation:
# Generate movies with 3 to 7 moving squares inside.
# The squares are of shape one by one or two by two pixels and
# they move linearly trought time.
# For convenience we first create movies with bigger width and height, (80x80)
# and at the end we select a 40x40 window
# The squares are of shape 1x1 or 2x2 pixels,
# which move linearly over time.
# For convenience we first create movies with bigger width and height (80x80)
# and at the end we select a 40x40 window.
def generate_movies(n_samples=1200, n_frames=15):
row = 80
@ -53,8 +53,7 @@ def generate_movies(n_samples=1200, n_frames=15):
dtype=np.float)
for i in range(n_samples):
# add from 3 to 7 moving squares
# Add 3 to 7 moving squares
n = np.random.randint(3, 8)
for j in range(n):
@ -75,10 +74,10 @@ def generate_movies(n_samples=1200, n_frames=15):
y_shift - w: y_shift + w, 0] += 1
# Make it more robust by adding noise.
# The idea is that if during predict time,
# The idea is that if during inference,
# the value of the pixel is not exactly one,
# we need to train the network to be robust and stille
# consider it is a pixel belonging to a square.
# we need to train the network to be robust and still
# consider it as a pixel belonging to a square.
if np.random.randint(0, 2):
noise_f = (-1)**np.random.randint(0, 2)
noisy_movies[i, t,
@ -86,13 +85,13 @@ def generate_movies(n_samples=1200, n_frames=15):
y_shift - w - 1: y_shift + w + 1,
0] += noise_f * 0.1
# Shitf the ground truth by 1
# Shift the ground truth by 1
x_shift = xstart + directionx * (t + 1)
y_shift = ystart + directiony * (t + 1)
shifted_movies[i, t, x_shift - w: x_shift + w,
y_shift - w: y_shift + w, 0] += 1
# Cut to a forty's sized window
# Cut to a 40x40 window
noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]
shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]
noisy_movies[noisy_movies >= 1] = 1

@ -256,7 +256,7 @@ class ConvLSTM2D(ConvRecurrent2D):
forget_bias_init='one', activation='tanh',
inner_activation='hard_sigmoid',
dim_ordering='default',
border_mode='valid', sub_sample=(1, 1),
border_mode='valid', subsample=(1, 1),
W_regularizer=None, U_regularizer=None, b_regularizer=None,
dropout_W=0., dropout_U=0., **kwargs):
@ -273,7 +273,7 @@ class ConvLSTM2D(ConvRecurrent2D):
self.activation = activations.get(activation)
self.inner_activation = activations.get(inner_activation)
self.border_mode = border_mode
self.subsample = sub_sample
self.subsample = subsample
if dim_ordering == 'th':
warnings.warn('Be carefull if used with convolution3D layers:\n'

@ -10,14 +10,13 @@ from keras import regularizers
def test_recurrent_convolutional():
nb_row = 4
nb_col = 4
nb_filter = 20
nb_samples = 5
nb_row = 3
nb_col = 3
nb_filter = 5
nb_samples = 2
input_channel = 2
input_nb_row = 10
input_nb_col = 10
input_nb_row = 5
input_nb_col = 5
sequence_len = 2
for dim_ordering in ['th', 'tf']: