Jax trainer checkpoint.
parent d0f3799dd8
commit 0f97daf7a7
@@ -18,11 +18,11 @@ class MiniDense(Layer):
         input_dim = input_shape[-1]
         w_shape = (input_dim, self.units)
         w_value = initializers.GlorotUniform()(w_shape)
-        self.w = backend.Variable(w_value)
+        self.w = backend.Variable(w_value, name="kernel")

         b_shape = (self.units,)
         b_value = initializers.Zeros()(b_shape)
-        self.b = backend.Variable(b_value)
+        self.b = backend.Variable(b_value, name="bias")

     def call(self, inputs):
         return ops.matmul(inputs, self.w) + self.b
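The kernel is drawn with GlorotUniform before being wrapped in a backend.Variable; the only change in this hunk is giving the kernel and bias explicit names, which the state-sync loop at the end of the demo prints. As a reminder of what a Glorot (Xavier) uniform draw amounts to, here is a small illustrative sketch under the usual definition, limit = sqrt(6 / (fan_in + fan_out)); it is not keras_core's initializer:

import jax
import jax.numpy as jnp

def glorot_uniform(key, shape):
    # Usual Glorot/Xavier uniform: U(-limit, limit) with
    # limit = sqrt(6 / (fan_in + fan_out)).
    fan_in, fan_out = shape[0], shape[-1]
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, shape, minval=-limit, maxval=limit)

w_value = glorot_uniform(jax.random.PRNGKey(0), (64, 32))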
@@ -49,10 +49,10 @@ class MiniBatchNorm(Layer):
     def build(self, input_shape):
         shape = (input_shape[-1],)
         self.mean = backend.Variable(
-            initializers.Zeros()(shape), trainable=False
+            initializers.Zeros()(shape), trainable=False, name="mean"
         )
         self.variance = backend.Variable(
-            initializers.GlorotUniform()(shape), trainable=False
+            initializers.GlorotUniform()(shape), trainable=False, name="variance"
         )
         self.beta = backend.Variable(initializers.Zeros()(shape))
         self.gamma = backend.Variable(initializers.Ones()(shape))
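MiniBatchNorm stores its running mean and variance as trainable=False variables, so in the stateless workflow they travel as non-trainable state that the forward pass returns rather than mutates in place. A hypothetical plain-JAX sketch of that split; the function name, momentum, and epsilon are illustrative, not taken from the demo:

import jax.numpy as jnp

def batch_norm_forward(x, trainable, non_trainable, momentum=0.99, eps=1e-5):
    # Normalize with batch statistics, then return updated running
    # statistics instead of assigning them in place.
    gamma, beta = trainable["gamma"], trainable["beta"]
    batch_mean = jnp.mean(x, axis=0)
    batch_var = jnp.var(x, axis=0)
    outputs = gamma * (x - batch_mean) / jnp.sqrt(batch_var + eps) + beta
    new_non_trainable = {
        "mean": momentum * non_trainable["mean"] + (1 - momentum) * batch_mean,
        "variance": momentum * non_trainable["variance"] + (1 - momentum) * batch_var,
    }
    return outputs, new_non_trainable

x = jnp.ones((8, 4))
trainable = {"gamma": jnp.ones((4,)), "beta": jnp.zeros((4,))}
non_trainable = {"mean": jnp.zeros((4,)), "variance": jnp.ones((4,))}
outputs, new_non_trainable = batch_norm_forward(x, trainable, non_trainable)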
@@ -79,14 +79,14 @@ class MyModel(Layer):
     def __init__(self, units, num_classes):
         super().__init__()
         self.dense1 = MiniDense(units)
-        self.bn = MiniBatchNorm()
-        self.dropout = MiniDropout(0.5)
+        # self.bn = MiniBatchNorm()
+        # self.dropout = MiniDropout(0.5)
         self.dense2 = MiniDense(num_classes)

     def call(self, x):
         x = self.dense1(x)
-        x = self.bn(x)
-        x = self.dropout(x)
+        # x = self.bn(x)
+        # x = self.dropout(x)
         return self.dense2(x)

@@ -120,6 +120,7 @@ def compute_loss_and_updates(
     y_pred, non_trainable_variables = model.stateless_call(
         trainable_variables, non_trainable_variables, x
     )
+
     loss = loss_fn(y, y_pred)
     return loss, non_trainable_variables

@@ -155,10 +156,12 @@ for x, y in dataset:

 # Post-processing model state update
 for variable, value in zip(model.trainable_variables, trainable_variables):
+    print(variable.name, np.sum(np.abs(variable - value)))
     variable.assign(value)
 for variable, value in zip(
     model.non_trainable_variables, non_trainable_variables
 ):
+    print(variable.name, np.sum(np.abs(variable - value)))
     variable.assign(value)

 print("Updated values")
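The two added print lines report, per named variable, the total absolute drift between the value still stored on the model and the value produced by the functional training loop, just before it is assigned back. A minimal standalone sketch of that update-then-write-back pattern, using plain JAX/NumPy and hypothetical parameter names rather than the demo's model:

import jax
import jax.numpy as jnp
import numpy as np

params = {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}

def loss_fn(params, x):
    return jnp.sum(jnp.matmul(x, params["kernel"]) + params["bias"])

# Functional update: gradients and new values, old values untouched.
grads = jax.grad(loss_fn)(params, jnp.ones((8, 4)))
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Post-processing: report how far each value moved, then adopt it.
for name in params:
    drift = np.sum(np.abs(np.asarray(params[name]) - np.asarray(new_params[name])))
    print(name, drift)
params = new_params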
@@ -52,11 +52,14 @@ class Variable(KerasVariable):
         self.name = name or auto_name(self.__class__.__name__)
         dtype = standardize_dtype(dtype)
         self._value = jnp.array(value, dtype=dtype)
+        self._dtype = dtype
+        self._shape = tuple(self._value.shape)
+        self._ndim = len(self._shape)
         self.trainable = trainable

     def assign(self, value):
         value = convert_to_tensor(value, dtype=self.dtype)
-        if value.shape != self.value.shape:
+        if value.shape != self.shape:
             raise ValueError(
                 "The shape of the target variable and "
                 "the shape of the target value in "
@@ -68,8 +71,11 @@ class Variable(KerasVariable):
             scope = get_stateless_scope()
             scope.add_update((self, value))
         else:
-            # TODO: optimize by avoiding memory copies
-            self._value = jnp.array(value, dtype=self.dtype)
+            if isinstance(value, jnp.ndarray) and value.dtype == self.dtype:
+                # Avoid a memory copy
+                self._value = value
+            else:
+                self._value = jnp.array(value, dtype=self.dtype)

     @property
     def value(self):
@@ -82,15 +88,15 @@ class Variable(KerasVariable):

     @property
     def dtype(self):
-        return self.value.dtype.name
+        return self._dtype

     @property
     def shape(self):
-        return self.value.shape
+        return self._shape

     @property
     def ndim(self):
-        return self.value.ndim
+        return self._ndim

     def numpy(self):
         return np.array(self.value)
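Caching _dtype, _shape and _ndim at construction lets these properties be answered without ever touching self.value, and assign now skips the jnp.array copy when it is already handed a JAX array of the right dtype. A standalone sketch of the same idea, not the keras_core class:

import jax.numpy as jnp

class SimpleVariable:
    # Toy stand-in: static metadata cached once, value swapped functionally.
    def __init__(self, value, dtype="float32"):
        self._value = jnp.array(value, dtype=dtype)
        self._dtype = dtype
        self._shape = tuple(self._value.shape)
        self._ndim = len(self._shape)

    @property
    def dtype(self):
        return self._dtype

    @property
    def shape(self):
        return self._shape

    @property
    def ndim(self):
        return self._ndim

    def assign(self, value):
        if isinstance(value, jnp.ndarray) and value.dtype == self._dtype:
            # Already a matching JAX array: adopt it without a copy.
            self._value = value
        else:
            self._value = jnp.array(value, dtype=self._dtype)

v = SimpleVariable(jnp.zeros((3,)))
v.assign(jnp.ones((3,)))
print(v.shape, v.dtype, v.ndim)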
@@ -1,4 +1,5 @@
+import jax
 import numpy as np

 from keras_core import backend
 from keras_core import callbacks as callbacks_module
@@ -10,73 +11,6 @@ from keras_core.trainers.epoch_iterator import EpochIterator

 class Trainer(base_trainer.Trainer):

-    def compute_loss_and_updates(
-        self, trainable_variables, non_trainable_variables, data
-    ):
-        x, y, sample_weight = data_adapters_utils.unpack_x_y_sample_weight(data)
-        y_pred, non_trainable_variables = self.stateless_call(
-            trainable_variables, non_trainable_variables, x
-        )
-        loss = self.compute_loss(x, y, y_pred, sample_weight=sample_weight)
-        return loss, (non_trainable_variables, y_pred)
-
-    def _get_gradient_fn(self):
-        return jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
-
-    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
-        """Stateless loss computation."""
-        del x  # The default implementation does not use `x`.
-        losses = []
-        if self._compile_loss is not None:
-            loss = self._compile_loss(y, y_pred, sample_weight)
-            if loss is not None:
-                losses.append(loss)
-        for l in self.losses:
-            losses.append(l.astype(backend.floatx()))
-        if len(losses) == 0:
-            raise ValueError(
-                "No loss to compute. Provide a `loss` argument in `compile()`."
-            )
-        if len(losses) == 1:
-            total_loss = losses[0]
-        else:
-            total_loss = sum(losses)
-        return total_loss
-
-    def train_step(self, trainable_variables, non_trainable_variables, optimizer_variables, data):
-        grad_fn = self._get_gradient_fn()
-        (loss, (non_trainable_variables, y_pred)), grads = grad_fn(
-            trainable_variables, non_trainable_variables, data
-        )
-        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
-            grads, optimizer_variables
-        )
-        return trainable_variables, non_trainable_variables, optimizer_variables, y_pred, loss
-
-    def test_step(self, data):
-        raise NotImplementedError
-
-    def predict_step(self, data):
-        raise NotImplementedError
-
-    def make_train_function(self, force=False):
-        if self.run_eagerly or not self.jit_compile:
-            self.train_function = self.train_step
-            return self.train_step
-
-        @jax.jit
-        def train_fn(trainable_variables, non_trainable_variables, optimizer_variables, data):
-            return self.train_step(trainable_variables, non_trainable_variables, optimizer_variables, data)
-
-        self.train_function = train_fn
-        return train_fn
-
-    def make_test_function(self, force=False):
-        raise NotImplementedError
-
-    def make_predict_function(self, force=False):
-        raise NotImplementedError
-
     def fit(
         self,
         x=None,
@@ -152,7 +86,7 @@ class Trainer(base_trainer.Trainer):
         )

         self.stop_training = False
-        self.make_train_function()
+        # self.make_train_function()
         callbacks.on_train_begin()
         training_logs = None
         logs = None
@@ -163,6 +97,41 @@
         non_trainable_variables = self.non_trainable_variables
         optimizer_variables = self.optimizer.variables

+
+        def compute_loss_and_updates(
+            trainable_variables, non_trainable_variables, x, y
+        ):
+            y_pred, non_trainable_variables = self.stateless_call(
+                trainable_variables, non_trainable_variables, x
+            )
+
+            loss = self._compile_loss(y, y_pred)
+            return loss, (y_pred, non_trainable_variables)
+
+
+        grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
+
+
+        @jax.jit
+        def train_step(
+            trainable_variables, non_trainable_variables, optimizer_variables, x, y
+        ):
+            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
+                trainable_variables, non_trainable_variables, x, y
+            )
+            # trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
+            #     grads, optimizer_variables
+            # )
+            new_trainable_variables = []
+            for grad, var in zip(grads, trainable_variables):
+                new_trainable_variables.append(var - 0.0001 * grad)
+            trainable_variables = new_trainable_variables
+
+
+            return loss, trainable_variables, non_trainable_variables, optimizer_variables
+
+        # counter = jax.numpy.ones(())
+
         for epoch in range(initial_epoch, epochs):
             self.reset_metrics()
             callbacks.on_epoch_begin(epoch)
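This inline pair is the core of the checkpoint: compute_loss_and_updates returns (loss, aux) so that jax.value_and_grad(..., has_aux=True) can carry the prediction and the updated non-trainable state alongside the gradients, and a hard-coded 1e-4 SGD step temporarily stands in for the commented-out optimizer.stateless_apply call. A self-contained sketch of the same pattern on a toy parameter dict; all names and shapes here are illustrative:

import jax
import jax.numpy as jnp

def compute_loss_and_updates(params, x, y):
    # The auxiliary output (here just y_pred) rides along with the loss.
    y_pred = jnp.matmul(x, params["w"]) + params["b"]
    loss = jnp.mean((y - y_pred) ** 2)
    return loss, y_pred

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

@jax.jit
def train_step(params, x, y):
    (loss, y_pred), grads = grad_fn(params, x, y)
    # Plain SGD stands in for a real optimizer, mirroring the diff above.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-4 * g, params, grads)
    return loss, params, y_pred

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
loss, params, y_pred = train_step(params, x, y)
print(float(loss))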
@@ -170,22 +139,24 @@
                 # Callbacks
                 callbacks.on_train_batch_begin(step)

-                # Train step (in JAX context)
-                trainable_variables, non_trainable_variables, optimizer_variables, y_pred, loss = self.train_function(trainable_variables, non_trainable_variables, optimizer_variables, data)
-                first_var = trainable_variables[0]

-                # Run variable updates (back in eager context)

+                # Train step
                 x, y, sample_weight = data_adapters_utils.unpack_x_y_sample_weight(data)
                 logs = self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
-                # for variable, value in zip(self.trainable_variables, trainable_variables):
-                #     variable.assign(value)
-                # for variable, value in zip(
-                #     self.non_trainable_variables, non_trainable_variables
-                # ):
-                #     variable.assign(value)
-                # for variable, value in zip(self.optimizer.variables, optimizer_variables):
-                #     variable.assign(value)
+                (
+                    loss,
+                    trainable_variables,
+                    non_trainable_variables,
+                    optimizer_variables,
+                ) = train_step(
+                    trainable_variables, non_trainable_variables, optimizer_variables, x, y
+                )
+
+                self._loss_tracker.update_state(loss)
+
+                # print(np.sum(np.abs(trainable_variables[0] - first_var)))
+
                 # Callbacks
                 callbacks.on_train_batch_end(step, logs)
                 if self.stop_training:
@@ -31,4 +31,3 @@ class TestStatelessScope(testing.TestCase):
         # Updates can be reapplied.
         var_out.assign(scope.get_current_value(var_out))
         self.assertAllClose(var_out_value, 2 * np.ones((2,)))
-