2023-04-09 19:21:45 +00:00
|
|
|
import jax
|
|
|
|
import numpy as np
|
2023-04-12 18:31:58 +00:00
|
|
|
|
|
|
|
from keras_core import backend
|
|
|
|
from keras_core import initializers
|
|
|
|
from keras_core import operations as ops
|
|
|
|
from keras_core.layers.layer import Layer
|
|
|
|
from keras_core.optimizers import SGD
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MiniDense(Layer):
    """Minimal fully-connected layer: ``outputs = inputs @ kernel + bias``."""

    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily, once the input feature size is known.
        input_dim = input_shape[-1]
        kernel_value = initializers.GlorotUniform()((input_dim, self.units))
        self.w = backend.Variable(kernel_value, name="kernel")
        bias_value = initializers.Zeros()((self.units,))
        self.b = backend.Variable(bias_value, name="bias")

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b
|
|
|
|
|
|
|
|
|
|
|
|
class MiniDropout(Layer):
    """Minimal dropout layer backed by a stateful seed generator.

    NOTE(review): dropout is applied unconditionally — there is no
    `training` switch, so it also fires at inference time.
    """

    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Stateful seed so each call draws a fresh dropout mask.
        self.seed_generator = backend.random.SeedGenerator(1337)

    def call(self, inputs):
        return backend.random.dropout(
            inputs, self.rate, seed=self.seed_generator
        )
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MiniBatchNorm(Layer):
    """Minimal batch normalization over the last axis.

    Tracks running mean/variance as non-trainable variables (updated with
    exponential moving averages while training) and learns a per-feature
    scale (``gamma``) and offset (``beta``).
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        self.epsilon = 1e-5
        self.momentum = 0.99

    def build(self, input_shape):
        # One statistic / parameter per feature on the last axis.
        shape = (input_shape[-1],)
        self.mean = backend.Variable(
            initializers.Zeros()(shape), trainable=False, name="mean"
        )
        self.variance = backend.Variable(
            initializers.Ones()(shape),
            trainable=False,
            name="variance",
        )
        # Named for consistency with the running-statistic variables above.
        self.beta = backend.Variable(initializers.Zeros()(shape), name="beta")
        self.gamma = backend.Variable(initializers.Ones()(shape), name="gamma")

    def call(self, inputs, training=False):
        if training:
            mean = ops.mean(inputs, axis=(0,))  # TODO: extend to rank 3+
            variance = ops.var(inputs, axis=(0,))
            # BUG FIX: normalize by the standard deviation
            # (sqrt(var + eps)), not by the variance itself.
            outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
            # Exponential moving averages of the batch statistics.
            self.variance.assign(
                self.variance * self.momentum + variance * (1.0 - self.momentum)
            )
            self.mean.assign(
                self.mean * self.momentum + mean * (1.0 - self.momentum)
            )
        else:
            # Inference uses the running statistics instead of batch stats.
            outputs = (inputs - self.mean) / ops.sqrt(
                self.variance + self.epsilon
            )
        outputs *= self.gamma
        outputs += self.beta
        return outputs
|
|
|
|
|
|
|
|
|
|
|
|
class MyModel(Layer):
    """Tiny MLP: dense -> dropout -> dense (batch norm currently disabled)."""

    def __init__(self, units, num_classes):
        super().__init__()
        self.dense1 = MiniDense(units)
        # self.bn = MiniBatchNorm()
        self.dropout = MiniDropout(0.5)
        self.dense2 = MiniDense(num_classes)

    def call(self, x):
        h = self.dense1(x)
        # h = self.bn(h)
        h = self.dropout(h)
        return self.dense2(h)
|
|
|
|
|
|
|
|
|
|
|
|
def Dataset():
    """Yield 10 toy batches of (features, targets) random float arrays.

    NOTE: PascalCase name mimics a dataset class, but this is a generator
    function; callers iterate the returned generator once.
    """
    for _ in range(10):
        features = np.random.random((8, 4))
        targets = np.random.random((8, 2))
        yield features, targets
|
|
|
|
|
|
|
|
|
|
|
|
def loss_fn(y_true, y_pred):
    """Sum-of-squared-errors loss between targets and predictions."""
    squared_error = (y_true - y_pred) ** 2
    return ops.sum(squared_error)
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level training setup: plain SGD, the toy model, and the data stream.
# These globals are referenced by the training functions below.
optimizer = SGD()

model = MyModel(8, 2)

dataset = Dataset()


# Build model
# Trace one batch-shaped input so every layer creates its variables.
x = ops.convert_to_tensor(np.random.random((8, 4)))

model(x)

# Build optimizer
# Allocates optimizer state per trainable variable.
optimizer.build(model.trainable_variables)
|
|
|
|
|
|
|
|
|
|
|
|
################################
## Currently operational workflow
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
2023-04-12 18:00:14 +00:00
|
|
|
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, x, y
):
    """Stateless forward pass + loss, shaped for ``jax.value_and_grad``.

    Returns ``(loss, updated_non_trainable_variables)`` so the updated
    layer state can be threaded through as the auxiliary output.
    """
    y_pred, updated_non_trainable = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    return loss_fn(y, y_pred), updated_non_trainable
|
|
|
|
|
|
|
|
|
|
|
|
# Gradient w.r.t. the first argument (trainable_variables); has_aux=True
# because the function also returns the updated non-trainable variables.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
|
|
|
|
|
|
|
|
|
|
|
|
@jax.jit
def train_step(
    trainable_variables, non_trainable_variables, optimizer_variables, x, y
):
    """Run one purely-functional optimization step (state in, state out).

    The loss value is discarded; only the updated variable collections
    are returned.
    """
    (_, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        grads, trainable_variables, optimizer_variables
    )
    return trainable_variables, non_trainable_variables, optimizer_variables
|
|
|
|
|
|
|
|
|
|
|
|
# Training loop
# Pull the current state out of the stateful objects once, then thread it
# functionally through train_step (required for jax.jit purity).
trainable_variables = model.trainable_variables

non_trainable_variables = model.non_trainable_variables

optimizer_variables = optimizer.variables

for x, y in dataset:
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    ) = train_step(
        trainable_variables, non_trainable_variables, optimizer_variables, x, y
    )
|
|
|
|
|
|
|
|
# Post-processing model state update
# Write the functionally-updated values back into the model's variables.
# The printed number is the total absolute change per variable.
for variable, value in zip(model.trainable_variables, trainable_variables):
    print(variable.name, np.sum(np.abs(variable - value)))
    variable.assign(value)
for variable, value in zip(
    model.non_trainable_variables, non_trainable_variables
):
    print(variable.name, np.sum(np.abs(variable - value)))
    variable.assign(value)

print("Updated values")
|