Add RemoteMonitor callback

This commit is contained in:
fchollet 2015-06-20 19:32:28 -07:00
parent 97f23268c2
commit 5a1a00e69e

@ -4,7 +4,7 @@ import theano
import theano.tensor as T
import numpy as np
import warnings
import time
import time, json
from collections import deque
from .utils.generic_utils import Progbar
@ -216,7 +216,7 @@ class ModelCheckpoint(Callback):
class EarlyStopping(Callback):
def __init__(self, patience=1, verbose=0):
def __init__(self, patience=0, verbose=0):
super(Callback, self).__init__()
self.patience = patience
@ -238,3 +238,29 @@ class EarlyStopping(Callback):
print("Epoch %05d: early stopping" % (epoch))
self.model.stop_training = True
self.wait += 1
class RemoteMonitor(Callback):
def __init__(self, root='http://localhost:9000'):
self.root = root
self.seen = 0
self.tot_loss = 0.
self.tot_accuracy = 0.
def on_epoch_begin(self, epoch, logs={}):
self.seen = 0
self.tot_loss = 0.
self.tot_accuracy = 0.
def on_batch_end(self, batch, logs={}):
batch_size = logs.get('size', 0)
self.seen += batch_size
self.tot_loss += logs.get('loss', 0.) * batch_size
if self.params['show_accuracy']:
self.tot_accuracy += logs.get('accuracy', 0.) * batch_size
def on_epoch_end(self, epoch, logs={}):
import requests
logs['epoch'] = epoch
logs['loss'] = self.tot_loss / self.seen
r = requests.post(self.root + '/publish/epoch/end/', {'data':json.dumps(logs)})