Possibly faster RNNs

Forgot dropout.

Various fixes

Fix SimpleRNN dropout typo
Francois Chollet 2016-02-24 20:57:27 -08:00
parent 06a1545645
commit abca83373d
2 changed files with 137 additions and 74 deletions

examples/imdb_lstm.py
@@ -47,7 +47,7 @@ print('X_test shape:', X_test.shape)
 print('Build model...')
 model = Sequential()
 model.add(Embedding(max_features, 128, input_length=maxlen, dropout=0.5))
-model.add(LSTM(128, dropout_W=0.5, dropout_U=0.5))  # try using a GRU instead, for fun
+model.add(LSTM(128, dropout_W=0.5, dropout_U=0.1))  # try using a GRU instead, for fun
 model.add(Dropout(0.5))
 model.add(Dense(1))
 model.add(Activation('sigmoid'))
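
Context for the example change: dropout_W is the fraction of input units dropped at the input-to-hidden transform W, and dropout_U the fraction dropped at the recurrent transform U. A minimal usage sketch, assuming the 0.3-era Keras module paths; max_features=20000 and maxlen=100 are placeholder values, not taken from this diff:

    # Sketch of the new dropout arguments (module paths assume 0.3-era Keras).
    from keras.models import Sequential
    from keras.layers.core import Dense, Dropout, Activation
    from keras.layers.embeddings import Embedding
    from keras.layers.recurrent import LSTM

    model = Sequential()
    model.add(Embedding(20000, 128, input_length=100, dropout=0.5))
    # dropout_W masks the inputs fed to W; dropout_U masks h_tm1 before U
    model.add(LSTM(128, dropout_W=0.5, dropout_U=0.1))
    model.add(Dropout(0.5))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam')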

keras/layers/recurrent.py
@@ -7,6 +7,26 @@ from .. import activations, initializations, regularizers
 from ..layers.core import MaskedLayer

+def time_distributed_dense(x, w, b=None, dropout=None,
+                           input_dim=None, output_dim=None, timesteps=None):
+    if not input_dim:
+        # won't work with TensorFlow
+        input_dim = K.shape(x)[2]
+    if not timesteps:
+        # won't work with TensorFlow
+        timesteps = K.shape(x)[1]
+    if not output_dim:
+        output_dim = K.shape(w)[1]
+
+    x = K.reshape(x, (-1, input_dim))
+    if dropout:
+        x *= K.concatenate([dropout] * timesteps, 0)
+    x = K.dot(x, w)
+    if b:
+        x = x + b
+    x = K.reshape(x, (-1, timesteps, output_dim))
+    return x
+
+
 class Recurrent(MaskedLayer):
     '''Abstract base class for recurrent layers.
     Do not use in a model -- it's not a functional layer!
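
The new time_distributed_dense helper is the source of the "possibly faster" in the commit title: rather than computing x_t.w once per timestep inside the backend scan loop, it folds the time axis into the batch axis and performs a single large matrix product (applying the optional dropout mask, tiled across timesteps, beforehand). A NumPy sketch of the equivalence, illustration only, not Keras code:

    import numpy as np

    samples, timesteps, input_dim, output_dim = 32, 10, 8, 16
    x = np.random.randn(samples, timesteps, input_dim)
    w = np.random.randn(input_dim, output_dim)
    b = np.random.randn(output_dim)

    # fold time into the batch axis, one big matmul, unfold again
    folded = x.reshape(-1, input_dim).dot(w) + b
    folded = folded.reshape(samples, timesteps, output_dim)

    # the per-timestep loop it replaces
    looped = np.stack([x[:, t].dot(w) + b for t in range(timesteps)], axis=1)
    assert np.allclose(folded, looped)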
@@ -116,23 +136,25 @@ class Recurrent(MaskedLayer):
     def step(self, x, states):
         raise NotImplementedError

-    def get_constants(self, X, train=False):
-        return None
+    def get_constants(self, x, train=False):
+        return []

-    def get_initial_states(self, X):
+    def get_initial_states(self, x):
         # build an all-zero tensor of shape (samples, output_dim)
-        initial_state = K.zeros_like(X)  # (samples, timesteps, input_dim)
+        initial_state = K.zeros_like(x)  # (samples, timesteps, input_dim)
         initial_state = K.sum(initial_state, axis=1)  # (samples, input_dim)
         reducer = K.zeros((self.input_dim, self.output_dim))
         initial_state = K.dot(initial_state, reducer)  # (samples, output_dim)
         initial_states = [initial_state for _ in range(len(self.states))]
         return initial_states

+    def preprocess_input(self, x, train=False):
+        return x
+
     def get_output(self, train=False):
         # input shape: (nb_samples, time (padded with zeros), input_dim)
         X = self.get_input(train)
         mask = self.get_input_mask(train)
-        constants = self.get_constants(X, train)

         assert K.ndim(X) == 3
         if K._BACKEND == 'tensorflow':
@@ -150,8 +172,10 @@ class Recurrent(MaskedLayer):
             initial_states = self.states
         else:
             initial_states = self.get_initial_states(X)
+        constants = self.get_constants(X, train)
+        preprocessed_input = self.preprocess_input(X, train)

-        last_output, outputs, states = K.rnn(self.step, X,
+        last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                              initial_states,
                                              go_backwards=self.go_backwards,
                                              mask=mask,
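
After this reshuffle, Recurrent.get_output builds the initial states, asks get_constants for any recurrent dropout masks, and runs preprocess_input once over the whole sequence before handing off to K.rnn. A rough Python rendering of that flow, simplified and hypothetical in its details; the constants presumably reach step() as extra entries in states, which is what the len(states) checks in the subclasses below rely on:

    # Simplified illustration of the new forward pass; the backend scan
    # is reduced to a plain Python loop here.
    def forward(layer, X, train=False):
        states = layer.get_initial_states(X)        # zero states
        constants = layer.get_constants(X, train)   # e.g. [B_U] when training
        H = layer.preprocess_input(X, train)        # one big matmul, outside the loop
        outputs = []
        for t in range(H.shape[1]):                 # what K.rnn does symbolically
            out, states = layer.step(H[:, t], states + constants)
            outputs.append(out)
        return outputs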
@@ -263,33 +287,39 @@ class SimpleRNN(Recurrent):
         else:
             self.states = [K.zeros((input_shape[0], self.output_dim))]

-    def step(self, x, states):
-        # states contains the previous output,
-        # and the two dropout matrices from self.get_constants()
-        assert len(states) == 3  # 1 state and 2 constants
+    def preprocess_input(self, x, train=False):
+        if train and (0 < self.dropout_W < 1):
+            ones = K.ones_like(K.reshape(x[:, 0, :], (-1, self.input_dim)))
+            B_W = K.dropout(ones, self.dropout_W)
+        else:
+            B_W = None
+        input_shape = self.input_shape
+        input_dim = input_shape[2]
+        if input_shape[1]:
+            timesteps = input_shape[1]
+        else:
+            # this won't work with TensorFlow
+            timesteps = K.shape(x)[1]
+        return time_distributed_dense(x, self.W, self.b, B_W,
+                                      input_dim, self.output_dim, timesteps)
+
+    def step(self, h, states):
         prev_output = states[0]
-        B_W = states[1]
-        B_U = states[2]
-
-        h = K.dot(x * B_W, self.W) + self.b
+        if len(states) == 2:
+            B_U = states[1]
+        else:
+            B_U = 1.
         output = self.activation(h + K.dot(prev_output * B_U, self.U))
         return output, [output]

-    def get_constants(self, X, train=False):
-        retain_p_W = 1. - self.dropout_W
-        retain_p_U = 1. - self.dropout_U
-        if train and (self.dropout_W > 0 or self.dropout_U > 0):
-            nb_samples = K.shape(X)[0]
-            if K._BACKEND == 'tensorflow':
-                if not self.input_shape[0]:
-                    raise Exception('For RNN dropout in tensorflow, a complete ' +
-                                    'input_shape must be provided (including batch size).')
-                nb_samples = self.input_shape[0]
-            B_W = K.random_binomial((nb_samples, self.input_dim), p=retain_p_W)
-            B_U = K.random_binomial((nb_samples, self.output_dim), p=retain_p_U)
+    def get_constants(self, x, train=False):
+        if train and (0 < self.dropout_U < 1):
+            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
+            ones = K.concatenate([ones] * self.output_dim, 1)
+            B_U = K.dropout(ones, self.dropout_U)
+            return [B_U]
         else:
-            B_W = np.ones(1, dtype=K.floatx()) * retain_p_W
-            B_U = np.ones(1, dtype=K.floatx()) * retain_p_U
-        return [B_W, B_U]
+            B_U = []
+        return B_U

     def get_config(self):
         config = {"output_dim": self.output_dim,
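
get_constants now samples the recurrent dropout mask B_U once per batch, and the scan reuses that same mask at every timestep instead of resampling per step, which ties the dropout pattern across time. In NumPy terms, illustration only, assuming K.dropout behaves as inverted dropout (kept units rescaled by 1/(1 - p)):

    import numpy as np

    def dropout_mask(shape, p, rng=np.random):
        # keep a unit with probability 1 - p and rescale it by 1 / (1 - p),
        # so activations keep the same expected value
        return (rng.uniform(size=shape) >= p) / (1.0 - p)

    samples, output_dim = 32, 128
    B_U = dropout_mask((samples, output_dim), p=0.1)  # sampled once per batch

    h_tm1 = np.random.randn(samples, output_dim)
    # every call to step() reuses the same B_U on the previous output:
    h_masked = h_tm1 * B_U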
@@ -405,15 +435,38 @@ class GRU(Recurrent):
         else:
             self.states = [K.zeros((input_shape[0], self.output_dim))]

-    def step(self, x, states):
-        assert len(states) == 3  # 1 state and 2 constants
-        h_tm1 = states[0]  # previous memory
-        B_W = states[1]  # dropout matrix for input units
-        B_U = states[2]  # dropout matrix for recurrent units
+    def preprocess_input(self, x, train=False):
+        if train and (0 < self.dropout_W < 1):
+            ones = K.ones_like(K.reshape(x[:, 0, :], (-1, self.input_dim)))
+            B_W = [K.dropout(ones, self.dropout_W) for _ in range(3)]
+        else:
+            B_W = [None for _ in range(3)]
+        input_shape = self.input_shape
+        input_dim = input_shape[2]
+        if input_shape[1]:
+            timesteps = input_shape[1]
+        else:
+            # this won't work with TensorFlow
+            timesteps = K.shape(x)[1]

-        x_z = K.dot(x * B_W[0], self.W_z) + self.b_z
-        x_r = K.dot(x * B_W[1], self.W_r) + self.b_r
-        x_h = K.dot(x * B_W[2], self.W_h) + self.b_h
+        x_z = time_distributed_dense(x, self.W_z, self.b_z,
+                                     B_W[0], input_dim, self.output_dim, timesteps)
+        x_r = time_distributed_dense(x, self.W_r, self.b_r,
+                                     B_W[1], input_dim, self.output_dim, timesteps)
+        x_h = time_distributed_dense(x, self.W_h, self.b_h,
+                                     B_W[2], input_dim, self.output_dim, timesteps)
+        return K.concatenate([x_z, x_r, x_h], axis=2)
+
+    def step(self, x, states):
+        h_tm1 = states[0]  # previous memory
+        if len(states) == 2:
+            B_U = states[1]  # dropout matrices for recurrent units
+        else:
+            B_U = [1., 1., 1.]
+
+        x_z = x[:, :self.output_dim]
+        x_r = x[:, self.output_dim: 2 * self.output_dim]
+        x_h = x[:, 2 * self.output_dim:]
         z = self.inner_activation(x_z + K.dot(h_tm1 * B_U[0], self.U_z))
         r = self.inner_activation(x_r + K.dot(h_tm1 * B_U[1], self.U_r))
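
GRU.step no longer owns the input projections: preprocess_input has already packed x.W_z, x.W_r and x.W_h side by side along the feature axis, and step() slices them back out in output_dim-wide blocks. The slicing convention in NumPy, illustration only:

    import numpy as np

    output_dim = 4
    x_z0 = np.full((2, output_dim), 0.0)
    x_r0 = np.full((2, output_dim), 1.0)
    x_h0 = np.full((2, output_dim), 2.0)
    packed = np.concatenate([x_z0, x_r0, x_h0], axis=1)  # (2, 3 * output_dim)

    # the same slices step() takes:
    x_z = packed[:, :output_dim]
    x_r = packed[:, output_dim: 2 * output_dim]
    x_h = packed[:, 2 * output_dim:]
    assert (x_z == x_z0).all() and (x_r == x_r0).all() and (x_h == x_h0).all()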
@@ -422,22 +475,14 @@ class GRU(Recurrent):
         h = z * h_tm1 + (1 - z) * hh
         return h, [h]

-    def get_constants(self, X, train=False):
-        retain_p_W = 1. - self.dropout_W
-        retain_p_U = 1. - self.dropout_U
-        if train and (self.dropout_W > 0 or self.dropout_U > 0):
-            nb_samples = K.shape(X)[0]
-            if K._BACKEND == 'tensorflow':
-                if not self.input_shape[0]:
-                    raise Exception('For RNN dropout in tensorflow, a complete ' +
-                                    'input_shape must be provided (including batch size).')
-                nb_samples = self.input_shape[0]
-            B_W = [K.random_binomial((nb_samples, self.input_dim), p=retain_p_W) for _ in range(3)]
-            B_U = [K.random_binomial((nb_samples, self.output_dim), p=retain_p_U) for _ in range(3)]
+    def get_constants(self, x, train=False):
+        if train and (0 < self.dropout_U < 1):
+            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
+            ones = K.concatenate([ones] * self.output_dim, 1)
+            B_U = [K.dropout(ones, self.dropout_U) for _ in range(3)]
+            return [B_U]
         else:
-            B_W = np.ones(3, dtype=K.floatx()) * retain_p_W
-            B_U = np.ones(3, dtype=K.floatx()) * retain_p_U
-        return [B_W, B_U]
+            B_U = []
+        return B_U

     def get_config(self):
         config = {"output_dim": self.output_dim,
@@ -572,41 +617,59 @@ class LSTM(Recurrent):
             self.states = [K.zeros((input_shape[0], self.output_dim)),
                            K.zeros((input_shape[0], self.output_dim))]

+    def preprocess_input(self, x, train=False):
+        if train and (0 < self.dropout_W < 1):
+            ones = K.ones_like(K.reshape(x[:, 0, :], (-1, self.input_dim)))
+            B_W = [K.dropout(ones, self.dropout_W) for _ in range(4)]
+        else:
+            B_W = [None for _ in range(4)]
+        input_shape = self.input_shape
+        input_dim = input_shape[2]
+        if input_shape[1]:
+            timesteps = input_shape[1]
+        else:
+            # this won't work with TensorFlow
+            timesteps = K.shape(x)[1]
+
+        x_i = time_distributed_dense(x, self.W_i, self.b_i,
+                                     B_W[0], input_dim, self.output_dim, timesteps)
+        x_f = time_distributed_dense(x, self.W_f, self.b_f,
+                                     B_W[1], input_dim, self.output_dim, timesteps)
+        x_c = time_distributed_dense(x, self.W_c, self.b_c,
+                                     B_W[2], input_dim, self.output_dim, timesteps)
+        x_o = time_distributed_dense(x, self.W_o, self.b_o,
+                                     B_W[3], input_dim, self.output_dim, timesteps)
+        return K.concatenate([x_i, x_f, x_c, x_o], axis=2)
+
     def step(self, x, states):
-        assert len(states) == 4  # 2 states and 2 constants
         h_tm1 = states[0]
         c_tm1 = states[1]
-        B_W = states[2]
-        B_U = states[3]
+        if len(states) == 3:
+            B_U = states[2]
+        else:
+            B_U = [1. for _ in range(4)]

-        x_i = K.dot(x * B_W[0], self.W_i) + self.b_i
-        x_f = K.dot(x * B_W[1], self.W_f) + self.b_f
-        x_c = K.dot(x * B_W[2], self.W_c) + self.b_c
-        x_o = K.dot(x * B_W[3], self.W_o) + self.b_o
+        x_i = x[:, :self.output_dim]
+        x_f = x[:, self.output_dim: 2 * self.output_dim]
+        x_c = x[:, 2 * self.output_dim: 3 * self.output_dim]
+        x_o = x[:, 3 * self.output_dim:]

         i = self.inner_activation(x_i + K.dot(h_tm1 * B_U[0], self.U_i))
         f = self.inner_activation(x_f + K.dot(h_tm1 * B_U[1], self.U_f))
         c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1 * B_U[2], self.U_c))
         o = self.inner_activation(x_o + K.dot(h_tm1 * B_U[3], self.U_o))
         h = o * self.activation(c)
         return h, [h, c]

-    def get_constants(self, X, train=False):
-        retain_p_W = 1. - self.dropout_W
-        retain_p_U = 1. - self.dropout_U
-        if train and (self.dropout_W > 0 or self.dropout_U > 0):
-            nb_samples = K.shape(X)[0]
-            if K._BACKEND == 'tensorflow':
-                if not self.input_shape[0]:
-                    raise Exception('For RNN dropout in tensorflow, a complete ' +
-                                    'input_shape must be provided (including batch size).')
-                nb_samples = self.input_shape[0]
-            B_W = [K.random_binomial((nb_samples, self.input_dim), p=retain_p_W) for _ in range(4)]
-            B_U = [K.random_binomial((nb_samples, self.output_dim), p=retain_p_U) for _ in range(4)]
+    def get_constants(self, x, train=False):
+        if train and (0 < self.dropout_U < 1):
+            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
+            ones = K.concatenate([ones] * self.output_dim, 1)
+            B_U = [K.dropout(ones, self.dropout_U) for _ in range(4)]
+            return [B_U]
         else:
-            B_W = np.ones(4, dtype=K.floatx()) * retain_p_W
-            B_U = np.ones(4, dtype=K.floatx()) * retain_p_U
-        return [B_W, B_U]
+            return []

     def get_config(self):
         config = {"output_dim": self.output_dim,
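
For reference, one LSTM step as now written, reduced to NumPy as an illustration only: sigmoid and tanh stand in for inner_activation and activation, the x_* inputs are the per-gate projections precomputed by preprocess_input, and all-ones B_U corresponds to the no-dropout branch:

    import numpy as np

    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))

    dim = 4
    h_tm1 = np.zeros((1, dim))
    c_tm1 = np.zeros((1, dim))
    U_i = U_f = U_c = U_o = np.eye(dim)
    B_U = [np.ones((1, dim))] * 4                      # no-dropout masks
    x_i = x_f = x_c = x_o = np.random.randn(1, dim)    # precomputed projections

    i = sigmoid(x_i + (h_tm1 * B_U[0]).dot(U_i))       # input gate
    f = sigmoid(x_f + (h_tm1 * B_U[1]).dot(U_f))       # forget gate
    c = f * c_tm1 + i * np.tanh(x_c + (h_tm1 * B_U[2]).dot(U_c))
    o = sigmoid(x_o + (h_tm1 * B_U[3]).dot(U_o))       # output gate
    h = o * np.tanh(c)                                 # new hidden state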