JAX trainer checkpoint.

This commit is contained in:
Francois Chollet 2023-04-18 13:02:29 -07:00
parent d0f3799dd8
commit 0f97daf7a7
4 changed files with 74 additions and 95 deletions

@@ -18,11 +18,11 @@ class MiniDense(Layer):
# Interior of MiniDense.build (post-image of the diff hunk).
# Defect fixed: the pre-image lines creating unnamed variables were left
# in alongside the new named versions, producing duplicate assignments;
# only the post-image lines (with explicit names for checkpointing) are kept.
input_dim = input_shape[-1]
w_shape = (input_dim, self.units)
w_value = initializers.GlorotUniform()(w_shape)
# Name the kernel so it can be identified in saved/ inspected state.
self.w = backend.Variable(w_value, name="kernel")
b_shape = (self.units,)
b_value = initializers.Zeros()(b_shape)
self.b = backend.Variable(b_value, name="bias")
def call(self, inputs):
    """Apply the affine transform: inputs @ kernel + bias."""
    projected = ops.matmul(inputs, self.w)
    return projected + self.b
@@ -49,10 +49,10 @@ class MiniBatchNorm(Layer):
def build(self, input_shape):
    """Create the moving statistics and learned scale/offset variables.

    Defect fixed: the diff pre-image lines (variables created without a
    `name`) were left in next to the post-image lines, making each
    `backend.Variable(...)` call contain two argument lines — a syntax
    error. Only the post-image (named) versions are kept.
    """
    shape = (input_shape[-1],)
    # Moving statistics: non-trainable state, updated outside of gradients.
    self.mean = backend.Variable(
        initializers.Zeros()(shape), trainable=False, name="mean"
    )
    self.variance = backend.Variable(
        # NOTE(review): GlorotUniform is an unusual initializer for a
        # running variance (Ones is conventional) — confirm intentional.
        initializers.GlorotUniform()(shape), trainable=False, name="variance"
    )
    # Learned offset/scale.
    self.beta = backend.Variable(initializers.Zeros()(shape))
    self.gamma = backend.Variable(initializers.Ones()(shape))
@@ -79,14 +79,14 @@ class MyModel(Layer):
def __init__(self, units, num_classes):
    """Two-layer model; bn/dropout are disabled in this checkpoint.

    Defect fixed: the diff left both the active assignments and their
    commented-out post-image replacements; only the post-image is kept
    (bn/dropout temporarily disabled, mirroring call()).
    """
    super().__init__()
    self.dense1 = MiniDense(units)
    # self.bn = MiniBatchNorm()
    # self.dropout = MiniDropout(0.5)
    self.dense2 = MiniDense(num_classes)
def call(self, x):
    """Forward pass. Defect fixed: duplicated pre-image lines removed —
    the bn/dropout calls appear only as the commented-out post-image,
    consistent with their construction being disabled in __init__."""
    x = self.dense1(x)
    # x = self.bn(x)
    # x = self.dropout(x)
    return self.dense2(x)
@@ -120,6 +120,7 @@ def compute_loss_and_updates(
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
return loss, non_trainable_variables
@@ -155,10 +156,12 @@ for x, y in dataset:
# Post-processing model state update: the JAX train step is functional,
# so the new variable values are written back into the stateful model
# objects here, in eager context.
for variable, value in zip(model.trainable_variables, trainable_variables):
    # Debug print: total absolute change between old and new value.
    print(variable.name, np.sum(np.abs(variable - value)))
    variable.assign(value)
for variable, value in zip(
    model.non_trainable_variables, non_trainable_variables
):
    print(variable.name, np.sum(np.abs(variable - value)))
    variable.assign(value)
print("Updated values")

@@ -52,11 +52,14 @@ class Variable(KerasVariable):
# Interior of Variable.__init__ (JAX backend).
self.name = name or auto_name(self.__class__.__name__)
dtype = standardize_dtype(dtype)
self._value = jnp.array(value, dtype=dtype)
# Cache dtype/shape/ndim at construction so the corresponding
# properties below read the cached fields instead of self._value.
self._dtype = dtype
self._shape = tuple(self._value.shape)
self._ndim = len(self._shape)
self.trainable = trainable
def assign(self, value):
value = convert_to_tensor(value, dtype=self.dtype)
if value.shape != self.value.shape:
if value.shape != self.shape:
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
@@ -68,8 +71,11 @@ class Variable(KerasVariable):
scope = get_stateless_scope()
scope.add_update((self, value))
else:
# TODO: optimize by avoiding memory copies
self._value = jnp.array(value, dtype=self.dtype)
if isinstance(value, jnp.ndarray) and value.dtype == self.dtype:
# Avoid a memory copy
self._value = value
else:
self._value = jnp.array(value, dtype=self.dtype)
@property
def value(self):
@@ -82,15 +88,15 @@ class Variable(KerasVariable):
@property
def dtype(self):
    """Dtype string of the variable, cached at construction.

    Defect fixed: the diff pre-image `return self.value.dtype.name` was
    left above the post-image return, making the new line unreachable;
    only the cached-attribute version is kept.
    """
    return self._dtype
@property
def shape(self):
    """Static shape tuple, cached at construction.

    Defect fixed: removed the unreachable pre-image
    `return self.value.shape` the diff left before the post-image return.
    """
    return self._shape
@property
def ndim(self):
    """Rank of the variable, cached at construction.

    Defect fixed: removed the unreachable pre-image
    `return self.value.ndim` the diff left before the post-image return.
    """
    return self._ndim
def numpy(self):
    """Return the variable's current value as a (copied) NumPy array."""
    current = self.value
    return np.array(current)

@@ -1,4 +1,5 @@
import jax
import numpy as np
from keras_core import backend
from keras_core import callbacks as callbacks_module
@@ -10,73 +11,6 @@ from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer):
def compute_loss_and_updates(
    self, trainable_variables, non_trainable_variables, data
):
    """Functional forward pass returning (loss, (new_state, y_pred)).

    The model is run statelessly so this function can be differentiated
    with jax.value_and_grad; non-trainable state updates are threaded
    through the aux output instead of mutating variables.
    """
    x, y, sample_weight = data_adapters_utils.unpack_x_y_sample_weight(data)
    y_pred, non_trainable_variables = self.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = self.compute_loss(x, y, y_pred, sample_weight=sample_weight)
    return loss, (non_trainable_variables, y_pred)
def _get_gradient_fn(self):
    """Build the value-and-grad transform over the functional loss."""
    # has_aux=True: the loss fn also returns (non_trainable_vars, y_pred).
    grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
    return grad_fn
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
    """Stateless loss computation.

    Combines the compiled loss (if configured) with layer-added losses
    (e.g. regularizers), and raises if there is nothing to compute.

    Fixes: renamed ambiguous single-letter `l` (easily confused with `1`,
    flake8 E741) and replaced `len(losses) == 0` with the idiomatic
    truthiness test. Behavior is unchanged.
    """
    del x  # The default implementation does not use `x`.
    losses = []
    if self._compile_loss is not None:
        loss = self._compile_loss(y, y_pred, sample_weight)
        if loss is not None:
            losses.append(loss)
    # Layer-added losses, cast to the backend's float dtype.
    for extra_loss in self.losses:
        losses.append(extra_loss.astype(backend.floatx()))
    if not losses:
        raise ValueError(
            "No loss to compute. Provide a `loss` argument in `compile()`."
        )
    if len(losses) == 1:
        return losses[0]
    return sum(losses)
def train_step(
    self,
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    data,
):
    """One functional optimization step.

    Returns the updated trainable / non-trainable / optimizer state plus
    the batch predictions and loss.
    """
    value_and_grad = self._get_gradient_fn()
    (loss, (non_trainable_variables, y_pred)), gradients = value_and_grad(
        trainable_variables, non_trainable_variables, data
    )
    # Stateless optimizer update: produces new params and new slots.
    trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
        gradients, optimizer_variables
    )
    return (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        y_pred,
        loss,
    )
def test_step(self, data):
    # Evaluation is not implemented yet in this JAX-trainer checkpoint.
    raise NotImplementedError
def predict_step(self, data):
    # Inference is not implemented yet in this JAX-trainer checkpoint.
    raise NotImplementedError
def make_train_function(self, force=False):
    """Select and cache the train function; jit-compile unless eager.

    NOTE(review): `force` is accepted but unused in this checkpoint.
    """
    if self.run_eagerly or not self.jit_compile:
        # Eager path: use train_step directly, uncompiled.
        self.train_function = self.train_step
        return self.train_step
    @jax.jit
    def train_fn(trainable_variables, non_trainable_variables, optimizer_variables, data):
        # NOTE(review): `self` is captured by closure under jit — confirm
        # train_step touches no Python-side state when traced.
        return self.train_step(trainable_variables, non_trainable_variables, optimizer_variables, data)
    self.train_function = train_fn
    return train_fn
def make_test_function(self, force=False):
    # No evaluation function yet (see test_step).
    raise NotImplementedError
def make_predict_function(self, force=False):
    # No inference function yet (see predict_step).
    raise NotImplementedError
def fit(
self,
x=None,
@@ -152,7 +86,7 @@ class Trainer(base_trainer.Trainer):
)
self.stop_training = False
self.make_train_function()
# self.make_train_function()
callbacks.on_train_begin()
training_logs = None
logs = None
@@ -163,6 +97,41 @@ class Trainer(base_trainer.Trainer):
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer.variables
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, x, y
):
    # Pure forward pass (closes over `self` from the enclosing fit()):
    # stateless_call returns predictions plus the updated
    # non-trainable state instead of mutating variables.
    y_pred, non_trainable_variables = self.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = self._compile_loss(y, y_pred)
    return loss, (y_pred, non_trainable_variables)


# Differentiable wrapper; has_aux=True because the function also
# returns (y_pred, non_trainable_variables) alongside the loss.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(
    trainable_variables, non_trainable_variables, optimizer_variables, x, y
):
    # Loss, aux state, and gradients in one differentiated pass.
    (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    # NOTE(review): the stateless optimizer path is disabled in this
    # checkpoint; a plain SGD update with a hard-coded 1e-4 learning
    # rate is used instead. Restore stateless_apply once it works.
    # trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
    #     grads, optimizer_variables
    # )
    new_trainable_variables = []
    for grad, var in zip(grads, trainable_variables):
        new_trainable_variables.append(var - 0.0001 * grad)
    trainable_variables = new_trainable_variables
    # NOTE(review): y_pred is computed but not returned, so callers
    # cannot use it for metrics — confirm this is intentional here.
    return loss, trainable_variables, non_trainable_variables, optimizer_variables
# counter = jax.numpy.ones(())
for epoch in range(initial_epoch, epochs):
self.reset_metrics()
callbacks.on_epoch_begin(epoch)
@@ -170,22 +139,24 @@ class Trainer(base_trainer.Trainer):
# Callbacks
callbacks.on_train_batch_begin(step)
# Train step (in JAX context)
trainable_variables, non_trainable_variables, optimizer_variables, y_pred, loss = self.train_function(trainable_variables, non_trainable_variables, optimizer_variables, data)
first_var = trainable_variables[0]
# Run variable updates (back in eager context)
# Train step
x, y, sample_weight = data_adapters_utils.unpack_x_y_sample_weight(data)
logs = self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
# for variable, value in zip(self.trainable_variables, trainable_variables):
# variable.assign(value)
# for variable, value in zip(
# self.non_trainable_variables, non_trainable_variables
# ):
# variable.assign(value)
# for variable, value in zip(self.optimizer.variables, optimizer_variables):
# variable.assign(value)
(
loss,
trainable_variables,
non_trainable_variables,
optimizer_variables,
) = train_step(
trainable_variables, non_trainable_variables, optimizer_variables, x, y
)
self._loss_tracker.update_state(loss)
# print(np.sum(np.abs(trainable_variables[0] - first_var)))
# Callbacks
callbacks.on_train_batch_end(step, logs)
if self.stop_training:

@@ -31,4 +31,3 @@ class TestStatelessScope(testing.TestCase):
# Updates can be reapplied.
var_out.assign(scope.get_current_value(var_out))
self.assertAllClose(var_out_value, 2 * np.ones((2,)))