Update cifar10 example

This commit is contained in:
fchollet 2015-07-04 14:28:24 -07:00
parent be75548ca3
commit dab55518ba

@ -73,7 +73,7 @@ if not data_augmentation:
X_test = X_test.astype("float32")
X_train /= 255
X_test /= 255
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=10)
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch)
score = model.evaluate(X_test, Y_test, batch_size=batch_size)
print('Test score:', score)
@ -105,14 +105,14 @@ else:
# batch train with realtime data augmentation
progbar = generic_utils.Progbar(X_train.shape[0])
for X_batch, Y_batch in datagen.flow(X_train, Y_train):
loss = model.train(X_batch, Y_batch)
loss = model.train_on_batch(X_batch, Y_batch)
progbar.add(X_batch.shape[0], values=[("train loss", loss)])
print("Testing...")
# test time!
progbar = generic_utils.Progbar(X_test.shape[0])
for X_batch, Y_batch in datagen.flow(X_test, Y_test):
score = model.test(X_batch, Y_batch)
score = model.test_on_batch(X_batch, Y_batch)
progbar.add(X_batch.shape[0], values=[("test loss", score)])