commit 1588998ee8
Author: Yarin
Date:   2016-02-22 17:54:23 +00:00

5 changed files with 29 additions and 31 deletions

@@ -59,7 +59,7 @@ model.add(LSTM(50,
                return_sequences=False,
                stateful=True))
 model.add(Dense(1))
-model.compile(loss='rmse', optimizer='rmsprop')
+model.compile(loss='mse', optimizer='rmsprop')
 print('Training')
 for i in range(epochs):
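Note: 'rmse' is not one of Keras' built-in objective strings, which is why the old line fails and 'mse' is substituted. If root-mean-squared error is really wanted, compile() also accepts a custom objective function. A minimal sketch, assuming `model` is the Sequential model built above (the `rmse` helper is user code, not part of Keras):

from keras import backend as K

def rmse(y_true, y_pred):
    # same contract as the built-ins in keras/objectives.py:
    # maps (y_true, y_pred) tensors to a per-sample loss
    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))

model.compile(loss=rmse, optimizer='rmsprop')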

@@ -1,11 +1,8 @@
 from __future__ import absolute_import
 import numpy as np

 from .. import backend as K
-from .. import activations, initializations, regularizers, constraints
-from ..layers.core import Layer, MaskedLayer
-from ..constraints import unitnorm
+from .. import initializations, regularizers, constraints
+from ..layers.core import Layer


 class Embedding(Layer):
@@ -108,7 +105,8 @@ class Embedding(Layer):
             B = K.random_binomial((self.input_dim,), p=retain_p)
         else:
             B = K.ones((self.input_dim)) * retain_p
-        out = K.gather(self.W * K.expand_dims(B), X) # we zero-out rows of W at random
+        # we zero-out rows of W at random
+        out = K.gather(self.W * K.expand_dims(B), X)
         return out

     def get_config(self):
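The reordering above only moves the comment onto its own line; the behaviour is unchanged: one Bernoulli bit is drawn per vocabulary row of W and broadcast across the embedding dimension, so every occurrence of a dropped word type is zeroed consistently. A plain NumPy sketch of the same idea (illustrative only, not the Keras backend API):

import numpy as np

rng = np.random.RandomState(0)
input_dim, output_dim = 5, 3           # vocabulary size, embedding size
W = rng.randn(input_dim, output_dim)   # embedding matrix
retain_p = 0.8

B = rng.binomial(1, retain_p, size=(input_dim,))  # one bit per word type
X = np.array([0, 2, 2, 4])                        # sequence of word indices
out = (W * B[:, None])[X]   # mirrors K.gather(self.W * K.expand_dims(B), X)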

@@ -5,7 +5,6 @@ import numpy as np
 from .. import backend as K
 from .. import activations, initializations, regularizers
 from ..layers.core import MaskedLayer
-from ..backend.common import _FLOATX


 class Recurrent(MaskedLayer):
@@ -208,7 +207,7 @@ class SimpleRNN(Recurrent):
     '''
     def __init__(self, output_dim,
                  init='glorot_uniform', inner_init='orthogonal',
-                 activation='sigmoid',
+                 activation='tanh',
                  W_regularizer=None, U_regularizer=None, b_regularizer=None,
                  dropout_W=0., dropout_U=0., **kwargs):
         self.output_dim = output_dim
@@ -265,8 +264,9 @@ class SimpleRNN(Recurrent):
         self.states = [K.zeros((input_shape[0], self.output_dim))]

     def step(self, x, states):
-        # states only contains the previous output.
-        assert len(states) == 3 # 1 state and 2 constants
+        # states contains the previous output,
+        # and the two dropout matrices from self.get_constants()
+        assert len(states) == 3  # 1 state and 2 constants
         prev_output = states[0]
         B_W = states[1]
         B_U = states[2]
@@ -287,8 +287,8 @@
             B_W = K.random_binomial((nb_samples, self.input_dim), p=retain_p_W)
             B_U = K.random_binomial((nb_samples, self.output_dim), p=retain_p_U)
         else:
-            B_W = np.ones(1, dtype=_FLOATX) * retain_p_W
-            B_U = np.ones(1, dtype=_FLOATX) * retain_p_U
+            B_W = np.ones(1, dtype=K.floatx()) * retain_p_W
+            B_U = np.ones(1, dtype=K.floatx()) * retain_p_U
         return [B_W, B_U]

     def get_config(self):
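For context, the two branches of get_constants implement standard (non-inverted) dropout bookkeeping: during training a fresh 0/1 mask is sampled, and at test time the mask is replaced by its expectation retain_p, so x * B_W keeps the same expected value in both phases. A NumPy sketch of the pattern (an assumed illustration, not Keras internals):

import numpy as np

def dropout_constant(nb_samples, dim, retain_p, train):
    if train:
        # fresh Bernoulli mask, one row per sample in the batch
        return np.random.binomial(1, retain_p, size=(nb_samples, dim))
    # test time: a broadcastable constant that rescales activations
    return np.ones(1, dtype='float32') * retain_p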
@@ -334,7 +334,7 @@ class GRU(Recurrent):
     '''
     def __init__(self, output_dim,
                  init='glorot_uniform', inner_init='orthogonal',
-                 activation='sigmoid', inner_activation='hard_sigmoid',
+                 activation='tanh', inner_activation='hard_sigmoid',
                  W_regularizer=None, U_regularizer=None, b_regularizer=None,
                  dropout_W=0., dropout_U=0., **kwargs):
         self.output_dim = output_dim
@@ -406,10 +406,10 @@ class GRU(Recurrent):
         self.states = [K.zeros((input_shape[0], self.output_dim))]

     def step(self, x, states):
-        assert len(states) == 3 # 1 state and 2 constants
-        h_tm1 = states[0]
-        B_W = states[1]
-        B_U = states[2]
+        assert len(states) == 3  # 1 state and 2 constants
+        h_tm1 = states[0]  # previous memory
+        B_W = states[1]  # dropout matrix for input units
+        B_U = states[2]  # dropout matrix for recurrent units

         x_z = K.dot(x * B_W[0], self.W_z) + self.b_z
         x_r = K.dot(x * B_W[1], self.W_r) + self.b_r
@@ -435,8 +435,8 @@
             B_W = [K.random_binomial((nb_samples, self.input_dim), p=retain_p_W) for _ in range(3)]
             B_U = [K.random_binomial((nb_samples, self.output_dim), p=retain_p_U) for _ in range(3)]
         else:
-            B_W = np.ones(3, dtype=_FLOATX) * retain_p_W
-            B_U = np.ones(3, dtype=_FLOATX) * retain_p_U
+            B_W = np.ones(3, dtype=K.floatx()) * retain_p_W
+            B_U = np.ones(3, dtype=K.floatx()) * retain_p_U
         return [B_W, B_U]

     def get_config(self):
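The GRU samples three masks because dropout is applied separately to each of its input transforms (z, r, h); the LSTM below uses four, one per gate. A NumPy sketch of a single GRU step with these masks, following the x_z/x_r lines above (a sketch under assumed weight dicts, not the actual Keras implementation):

import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h_tm1, W, U, b, B_W, B_U):
    # W, U, b: weights per gate; B_W, B_U: one dropout mask per gate
    z = sigmoid(np.dot(x * B_W[0], W['z']) + np.dot(h_tm1 * B_U[0], U['z']) + b['z'])
    r = sigmoid(np.dot(x * B_W[1], W['r']) + np.dot(h_tm1 * B_U[1], U['r']) + b['r'])
    hh = np.tanh(np.dot(x * B_W[2], W['h']) + np.dot(r * h_tm1 * B_U[2], U['h']) + b['h'])
    return z * h_tm1 + (1.0 - z) * hh   # new hidden state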
@@ -573,7 +573,7 @@ class LSTM(Recurrent):
                        K.zeros((input_shape[0], self.output_dim))]

     def step(self, x, states):
-        assert len(states) == 4 # 2 states and 2 constants
+        assert len(states) == 4  # 2 states and 2 constants
         h_tm1 = states[0]
         c_tm1 = states[1]
         B_W = states[2]
@@ -604,8 +604,8 @@
             B_W = [K.random_binomial((nb_samples, self.input_dim), p=retain_p_W) for _ in range(4)]
             B_U = [K.random_binomial((nb_samples, self.output_dim), p=retain_p_U) for _ in range(4)]
         else:
-            B_W = np.ones(4, dtype=_FLOATX) * retain_p_W
-            B_U = np.ones(4, dtype=_FLOATX) * retain_p_U
+            B_W = np.ones(4, dtype=K.floatx()) * retain_p_W
+            B_U = np.ones(4, dtype=K.floatx()) * retain_p_U
         return [B_W, B_U]

     def get_config(self):

@@ -63,15 +63,15 @@ class Progbar(object):
             numdigits = int(np.floor(np.log10(self.target))) + 1
             barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
             bar = barstr % (current, self.target)
-            prog = float(current)/self.target
-            prog_width = int(self.width*prog)
+            prog = float(current) / self.target
+            prog_width = int(self.width * prog)
             if prog_width > 0:
-                bar += ('='*(prog_width-1))
+                bar += ('=' * (prog_width-1))
                 if current < self.target:
                     bar += '>'
                 else:
                     bar += '='
-            bar += ('.'*(self.width-prog_width))
+            bar += ('.' * (self.width - prog_width))
             bar += ']'
             sys.stdout.write(bar)
             self.total_width = len(bar)
@@ -80,7 +80,7 @@ class Progbar(object):
                 time_per_unit = (now - self.start) / current
             else:
                 time_per_unit = 0
-            eta = time_per_unit*(self.target - current)
+            eta = time_per_unit * (self.target - current)
             info = ''
             if current < self.target:
                 info += ' - ETA: %ds' % eta
@@ -99,7 +99,7 @@
             self.total_width += len(info)
             if prev_total_width > self.total_width:
-                info += ((prev_total_width-self.total_width) * " ")
+                info += ((prev_total_width - self.total_width) * " ")
             sys.stdout.write(info)
             sys.stdout.flush()
@@ -120,4 +120,4 @@
             sys.stdout.write(info + "\n")

     def add(self, n, values=[]):
-        self.update(self.seen_so_far+n, values)
+        self.update(self.seen_so_far + n, values)
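The Progbar changes are whitespace-only PEP8 fixes; the arithmetic is unchanged. A quick worked example of the ETA computation with made-up numbers:

# 40 of 100 units finished after 8 seconds:
now, start, current, target = 8.0, 0.0, 40, 100
time_per_unit = (now - start) / current    # 0.2 s per unit
eta = time_per_unit * (target - current)   # 12 s remaining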

@@ -273,7 +273,7 @@ class TestBackend(object):
         check_single_tensor_operation('tanh', (4, 2))

         # dropout
-        val = np.random.random((20, 20))
+        val = np.random.random((100, 100))
         xth = KTH.variable(val)
         xtf = KTF.variable(val)
         zth = KTH.eval(KTH.dropout(xth, level=0.2))
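Growing the test matrix from 20x20 to 100x100 makes the empirical dropout fraction concentrate around level: with 10,000 entries its standard deviation is about sqrt(0.2 * 0.8 / 10000) ~= 0.004, versus roughly 0.02 with 400 entries. A sketch of the kind of check this enables (assumed, not the actual test body):

import numpy as np

val = np.random.random((100, 100))
level = 0.2
mask = np.random.binomial(1, 1 - level, size=val.shape)  # stand-in for dropout
dropped = 1.0 - np.count_nonzero(val * mask) / float(val.size)
assert abs(dropped - level) < 0.05  # reliable only with this many entries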