Remove optional loss masking (now automatic)

This commit is contained in:
fchollet 2015-08-16 05:28:33 +09:00
parent 9d76926eba
commit 7115efc37b
2 changed files with 6 additions and 16 deletions

@ -345,8 +345,7 @@ class Sequential(Model, containers.Sequential):
- set_weights
'''
def compile(self, optimizer, loss, class_mode="categorical", theano_mode=None,
mask_cost=False):
def compile(self, optimizer, loss, class_mode="categorical", theano_mode=None):
self.optimizer = optimizers.get(optimizer)
self.loss = objectives.get(loss)
@ -364,7 +363,7 @@ class Sequential(Model, containers.Sequential):
self.weights = T.ones_like(self.y_train)
if mask_cost:
if hasattr(self.layers[-1], "get_output_mask"):
mask = self.layers[-1].get_output_mask()
else:
mask = None
@ -556,7 +555,7 @@ class Sequential(Model, containers.Sequential):
class Graph(Model, containers.Graph):
def compile(self, optimizer, loss, theano_mode=None, mask_cost=False):
def compile(self, optimizer, loss, theano_mode=None):
# loss is a dictionary mapping output name to loss functions
ys = []
ys_train = []
@ -574,10 +573,10 @@ class Graph(Model, containers.Graph):
ys_train.append(y_train)
ys_test.append(y_test)
if mask_cost is None:
mask = None
if hasattr(self.layers[-1], "get_output_mask"):
mask = self.layers[-1].get_output_mask()
else:
mask = output.get_output_mask()
mask = None
weight = T.ones_like(y_test)
weights.append(weight)

@ -9,20 +9,11 @@ class TestLossMasking(unittest.TestCase):
X = np.array(
[[[1, 1], [2, 1], [3, 1], [5, 5]],
[[1, 5], [5, 0], [0, 0], [0, 0]]], dtype=np.int32)
model = Sequential()
model.add(Masking(mask_value=0))
model.add(TimeDistributedDense(2, 1, init='one'))
model.compile(loss='mse', optimizer='sgd')
y = model.predict(X)
loss = model.fit(X, 4*y, nb_epoch=1, batch_size=2, verbose=1).history['loss'][0]
assert loss == 213.75
model = Sequential()
model.add(Masking(mask_value=0))
model.add(TimeDistributedDense(2, 1, init='one'))
model.compile(loss='mse', optimizer='sgd', mask_cost=True)
loss = model.fit(X, 4*y, nb_epoch=1, batch_size=2, verbose=1).history['loss'][0]
assert loss == 282.375