Add RemoteMonitor callback
This commit is contained in:
parent
97f23268c2
commit
5a1a00e69e
@ -4,7 +4,7 @@ import theano
|
||||
import theano.tensor as T
|
||||
import numpy as np
|
||||
import warnings
|
||||
import time
|
||||
import time, json
|
||||
from collections import deque
|
||||
|
||||
from .utils.generic_utils import Progbar
|
||||
@ -216,7 +216,7 @@ class ModelCheckpoint(Callback):
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
def __init__(self, patience=1, verbose=0):
|
||||
def __init__(self, patience=0, verbose=0):
|
||||
super(Callback, self).__init__()
|
||||
|
||||
self.patience = patience
|
||||
@ -238,3 +238,29 @@ class EarlyStopping(Callback):
|
||||
print("Epoch %05d: early stopping" % (epoch))
|
||||
self.model.stop_training = True
|
||||
self.wait += 1
|
||||
|
||||
|
||||
class RemoteMonitor(Callback):
|
||||
def __init__(self, root='http://localhost:9000'):
|
||||
self.root = root
|
||||
self.seen = 0
|
||||
self.tot_loss = 0.
|
||||
self.tot_accuracy = 0.
|
||||
|
||||
def on_epoch_begin(self, epoch, logs={}):
|
||||
self.seen = 0
|
||||
self.tot_loss = 0.
|
||||
self.tot_accuracy = 0.
|
||||
|
||||
def on_batch_end(self, batch, logs={}):
|
||||
batch_size = logs.get('size', 0)
|
||||
self.seen += batch_size
|
||||
self.tot_loss += logs.get('loss', 0.) * batch_size
|
||||
if self.params['show_accuracy']:
|
||||
self.tot_accuracy += logs.get('accuracy', 0.) * batch_size
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
import requests
|
||||
logs['epoch'] = epoch
|
||||
logs['loss'] = self.tot_loss / self.seen
|
||||
r = requests.post(self.root + '/publish/epoch/end/', {'data':json.dumps(logs)})
|
||||
|
Loading…
Reference in New Issue
Block a user