keras/keras_core/callbacks/remote_monitor.py
Ramesh Sampath d04747f234 Adds RemoteMonitor Callback (#108)
* Add Remote Monitor Callback

* Add Remote Monitor Callback

* Add Remote Monitor Callback

* Add Remote Monitor
2023-05-08 15:15:17 -05:00

84 lines
2.7 KiB
Python

import json
import warnings
import numpy as np
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
try:
import requests
except ImportError:
requests = None
@keras_core_export("keras_core.callbacks.RemoteMonitor")
class RemoteMonitor(Callback):
    """Callback used to stream events to a server.

    Requires the `requests` library.

    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
    HTTP POST, with a `data` argument which is a JSON-encoded dictionary of
    event data. If `send_as_json=True`, the content type of the request will
    be `"application/json"`. Otherwise the serialized JSON will be sent
    within a form.

    The payload always contains the key `"epoch"` plus every entry of the
    `logs` dictionary passed to `on_epoch_end`. NumPy scalars and
    single-element arrays are unwrapped to plain Python scalars;
    arrays of any other size are converted to (nested) Python lists so that
    they remain JSON-serializable.

    Args:
        root: String; root url of the target server.
        path: String; path relative to `root` to which the events will be
            sent.
        field: String; JSON field under which the data will be stored.
            The field is used only if the payload is sent within a form
            (i.e. send_as_json is set to False).
        headers: Dictionary; optional custom HTTP headers.
        send_as_json: Boolean; whether the request should be
            sent as `"application/json"`.
    """

    def __init__(
        self,
        root="http://localhost:9000",
        path="/publish/epoch/end/",
        field="data",
        headers=None,
        send_as_json=False,
    ):
        super().__init__()
        self.root = root
        self.path = path
        self.field = field
        self.headers = headers
        self.send_as_json = send_as_json

    @staticmethod
    def _to_serializable(value):
        """Convert NumPy values to plain Python objects; pass others through.

        Args:
            value: A single entry from the `logs` dictionary.

        Returns:
            A Python scalar for NumPy scalars and single-element arrays, a
            (possibly nested) list for arrays of any other size, and `value`
            itself when it is not a NumPy type.
        """
        if isinstance(value, (np.ndarray, np.generic)):
            # `item()` only works when there is exactly one element; it
            # raises ValueError otherwise. Fall back to `tolist()` for
            # empty or multi-element arrays instead of crashing the
            # training loop.
            if value.size == 1:
                return value.item()
            return value.tolist()
        return value

    def on_epoch_end(self, epoch, logs=None):
        """POST the epoch index and its logs to `root + path`.

        Args:
            epoch: Index of the epoch that just finished; sent under the
                `"epoch"` key.
            logs: Optional dictionary of metric names to values. `None` is
                treated as an empty dictionary.

        Raises:
            ImportError: If the `requests` library is not installed.
        """
        if requests is None:
            raise ImportError("RemoteMonitor requires the `requests` library.")
        logs = logs or {}
        send = {}
        send["epoch"] = epoch
        for k, v in logs.items():
            # np.ndarray and np.generic are not JSON-serializable,
            # therefore we must unwrap their values before adding them to
            # the json-serializable dict 'send'.
            send[k] = self._to_serializable(v)
        try:
            if self.send_as_json:
                requests.post(
                    self.root + self.path, json=send, headers=self.headers
                )
            else:
                requests.post(
                    self.root + self.path,
                    {self.field: json.dumps(send)},
                    headers=self.headers,
                )
        except requests.exceptions.RequestException:
            # Delivery is best-effort: a failed request only produces a
            # warning so that training is not interrupted.
            warnings.warn(
                f"Could not reach RemoteMonitor root server at {self.root}",
                stacklevel=2,
            )