Layer API refactor

fchollet 2015-05-09 15:36:47 -07:00
parent 03fa41801a
commit be1ec8aecc
7 changed files with 29 additions and 26 deletions
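
In short: every layer's output(train) method becomes get_output(train), connect() now takes a generic node (self.previous_layer becomes self.previous), and Sequential grows a get_output() of its own, so a whole model exposes the same interface as a single layer. A minimal sketch of a custom layer written against the new API (the DoubleIt class is illustrative, not part of this commit):

    from keras.layers.core import Layer

    class DoubleIt(Layer):
        def __init__(self):
            super(DoubleIt, self).__init__()

        def get_output(self, train):   # was output(self, train) before this commit
            X = self.get_input(train)  # resolves to self.previous.get_output(train)
            return X * 2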

keras/layers/advanced_activations.py
@@ -6,7 +6,7 @@ class LeakyReLU(Layer):
         super(LeakyReLU,self).__init__()
         self.alpha = alpha
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         return ((X + abs(X)) / 2.0) + self.alpha * ((X - abs(X)) / 2.0)
@@ -27,7 +27,7 @@ class PReLU(Layer):
         self.params = [self.alphas]
         self.input_shape = input_shape
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         pos = ((X + abs(X)) / 2.0)
         neg = self.alphas * ((X - abs(X)) / 2.0)

keras/layers/convolutional.py
@@ -41,7 +41,7 @@ class Convolution2D(Layer):
         if weights is not None:
             self.set_weights(weights)
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         conv_out = theano.tensor.nnet.conv.conv2d(X, self.W,
@@ -69,7 +69,7 @@ class MaxPooling2D(Layer):
         self.poolsize = poolsize
         self.ignore_border = ignore_border
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         output = downsample.max_pool_2d(X, self.poolsize, ignore_border=self.ignore_border)
         return output

keras/layers/core.py
@@ -16,15 +16,15 @@ class Layer(object):
     def __init__(self):
         self.params = []
 
-    def connect(self, previous_layer):
-        self.previous_layer = previous_layer
+    def connect(self, node):
+        self.previous = node
 
-    def output(self, train):
+    def get_output(self, train):
         raise NotImplementedError
 
     def get_input(self, train):
-        if hasattr(self, 'previous_layer'):
-            return self.previous_layer.output(train=train)
+        if hasattr(self, 'previous'):
+            return self.previous.get_output(train=train)
         else:
             return self.input
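
The get_input() chain above is the heart of the API: a node with a previous attribute recursively asks its predecessor for output, while the first node in the chain reads from self.input. A self-contained toy version (plain Python stand-ins, no Theano) to make the recursion concrete:

    class Node(object):
        # Toy stand-in for the refactored Layer base class.
        def connect(self, node):
            self.previous = node

        def get_output(self, train):
            raise NotImplementedError

        def get_input(self, train):
            if hasattr(self, 'previous'):
                return self.previous.get_output(train=train)
            return self.input  # first node in the chain reads its own input

    class Double(Node):
        def get_output(self, train):
            return self.get_input(train) * 2

    first, second = Double(), Double()
    second.connect(first)
    first.input = 3
    print(second.get_output(train=False))  # prints 12: 3 doubled twice
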
@@ -50,7 +50,7 @@ class Dropout(Layer):
         super(Dropout,self).__init__()
         self.p = p
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         if self.p > 0.:
             retain_prob = 1. - self.p
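
The train flag threaded through every get_output() exists largely for layers like Dropout, which behave differently at training time (randomly zeroing units) and at test time (scaling activations down instead). A NumPy toy version of the two branches (NumPy stands in for Theano's random streams; this is not the code in the commit):

    import numpy as np

    def dropout(X, p, train):
        if p > 0.:
            retain_prob = 1. - p
            if train:
                # training: drop each unit with probability p
                return X * np.random.binomial(1, retain_prob, size=X.shape)
            # test: scale down so expected activations match training
            return X * retain_prob
        return X
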
@@ -75,7 +75,7 @@ class Activation(Layer):
         self.target = target
         self.beta = beta
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         return self.activation(X)
@@ -96,7 +96,7 @@ class Reshape(Layer):
         super(Reshape,self).__init__()
         self.dims = dims
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         nshape = make_tuple(X.shape[0], *self.dims)
         return theano.tensor.reshape(X, nshape)
@@ -114,7 +114,7 @@ class Flatten(Layer):
     def __init__(self):
         super(Flatten,self).__init__()
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         size = theano.tensor.prod(X.shape) // X.shape[0]
         nshape = (X.shape[0], size)
@@ -132,7 +132,7 @@ class RepeatVector(Layer):
         super(RepeatVector,self).__init__()
         self.n = n
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         tensors = [X]*self.n
         stacked = theano.tensor.stack(*tensors)
@@ -168,7 +168,7 @@ class Dense(Layer):
         if weights is not None:
             self.set_weights(weights)
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         output = self.activation(T.dot(X, self.W) + self.b)
         return output
@@ -210,7 +210,7 @@ class TimeDistributedDense(Layer):
         if weights is not None:
             self.set_weights(weights)
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
 
         def act_func(X):

keras/layers/embeddings.py
@@ -27,7 +27,7 @@ class Embedding(Layer):
         if weights is not None:
             self.set_weights(weights)
 
-    def output(self, train=False):
+    def get_output(self, train=False):
         X = self.get_input(train)
         out = self.W[X]
         return out
@@ -83,7 +83,7 @@ class WordContextProduct(Layer):
             self.set_weights(weights)
 
-    def output(self, train=False):
+    def get_output(self, train=False):
         X = self.get_input(train)
         w = self.W_w[X[:, 0]]  # nb_samples, proj_dim
         c = self.W_c[X[:, 1]]  # nb_samples, proj_dim

keras/layers/normalization.py
@@ -27,7 +27,7 @@ class BatchNormalization(Layer):
         if weights is not None:
             self.set_weights(weights)
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         if self.mode == 0:

keras/layers/recurrent.py
@@ -46,7 +46,7 @@ class SimpleRNN(Layer):
         '''
         return self.activation(x_t + T.dot(h_tm1, u))
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)  # shape: (nb_samples, time (padded with zeros at the end), input_dim)
         # new shape: (time, nb_samples, input_dim) -> because theano.scan iterates over main dimension
         X = X.dimshuffle((1,0,2))
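
The dimshuffle((1,0,2)) above swaps the batch and time axes because theano.scan iterates over the leading dimension. For intuition, the NumPy equivalent (illustration only; the layer itself stays in Theano):

    import numpy as np

    X = np.zeros((32, 10, 64))      # (nb_samples, time, input_dim)
    X = X.transpose((1, 0, 2))      # (time, nb_samples, input_dim)
    assert X.shape == (10, 32, 64)  # scan now steps over the 10 timesteps
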
@@ -119,7 +119,7 @@ class SimpleDeepRNN(Layer):
             o += self.inner_activation(T.dot(args[i], args[i+self.depth]))
         return self.activation(o)
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         X = X.dimshuffle((1,0,2))
@@ -222,7 +222,7 @@ class GRU(Layer):
         h_t = z * h_tm1 + (1 - z) * hh_t
         return h_t
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         X = X.dimshuffle((1,0,2))
@@ -331,7 +331,7 @@ class LSTM(Layer):
         h_t = o_t * self.activation(c_t)
         return h_t, c_t
 
-    def output(self, train):
+    def get_output(self, train):
         X = self.get_input(train)
         X = X.dimshuffle((1,0,2))

keras/models.py
@@ -66,7 +66,9 @@ class Sequential(object):
                 self.constraints += [layer.constraint for _ in range(len(layer.params))]
             else:
                 self.constraints += [constraints.identity for _ in range(len(layer.params))]
 
+    def get_output(self, train):
+        return self.layers[-1].get_output(train)
+
     def compile(self, optimizer, loss, class_mode="categorical", y_dim_components=1):
         self.optimizer = optimizers.get(optimizer)
@@ -81,8 +83,8 @@ class Sequential(object):
         self.layers[0].input = ndim_tensor(ndim)
         self.X = self.layers[0].input
 
-        self.y_train = self.layers[-1].output(train=True)
-        self.y_test = self.layers[-1].output(train=False)
+        self.y_train = self.get_output(train=True)
+        self.y_test = self.get_output(train=False)
 
         # output of model
         self.y = ndim_tensor(y_dim_components+1)
@@ -114,6 +116,7 @@ class Sequential(object):
         self._test_with_acc = theano.function([self.X, self.y], [test_score, test_accuracy],
                                               allow_input_downcast=True)
 
     def train(self, X, y, accuracy=False):
         y = standardize_y(y)
         if accuracy:
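
Because connect() now stores a generic node and get_input() only ever calls node.get_output(train), anything exposing get_output can sit upstream of a layer, including a whole Sequential model. A hedged sketch of the duck typing this enables (the layer sizes and the Dense(input_dim, output_dim) constructor reflect the Keras of this era and are illustrative only):

    from keras.models import Sequential
    from keras.layers.core import Dense

    base = Sequential()
    base.add(Dense(784, 128))

    # A standalone layer can treat the whole model as its upstream node,
    # since Sequential.get_output(train) satisfies the interface that
    # Layer.get_input(train) expects of self.previous.
    head = Dense(128, 10)
    head.connect(base)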