Fix jax dropout and start tracking all state in layer.variables.

Author: Francois Chollet
Date:   2023-04-18 16:21:27 -07:00
Commit: dd144b6097 (parent 1a45e5cd17)
4 changed files with 31 additions and 5 deletions

@@ -30,23 +30,40 @@ class MyDense(layers.Layer):
        return ops.matmul(inputs, self.w) + self.b


class MyDropout(layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Use seed_generator for managing RNG state.
        # It is a state element and its seed variable is
        # tracked as part of `layer.variables`.
        self.seed_generator = backend.random.SeedGenerator(1337)

    def call(self, inputs):
        # Use `backend.random` for random ops.
        return backend.random.dropout(
            inputs, self.rate, seed=self.seed_generator
        )


class MyModel(Model):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.dense1 = MyDense(hidden_dim)
        self.dense2 = MyDense(hidden_dim)
        self.dense3 = MyDense(output_dim)
        self.dp = MyDropout(0.5)

    def call(self, x):
        x1 = self.dense1(x)
        x2 = self.dense2(x)
        # Why not use some ops here as well
        x = ops.concatenate([x1, x2], axis=-1)
        x = self.dp(x)
        return self.dense3(x)


model = MyModel(hidden_dim=256, output_dim=16)
model.summary()

x = np.random.random((50000, 128))
y = np.random.random((50000, 16))
@@ -60,5 +77,7 @@ model.compile(
)
history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
model.summary()
print("History:")
print(history.history)
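
With the SeedGenerator registered as layer state, the demo model's `variables` list should now include the dropout seed state in addition to the dense weights. A quick check along those lines, sketched against the demo classes above (counts assume each MyDense creates one kernel and one bias, and one seed-state variable for MyDropout):

# Sketch: confirm that seed state is now reported by `variables`.
model = MyModel(hidden_dim=256, output_dim=16)
model(np.random.random((1, 128)))   # run once so every layer creates its state
print(len(model.weights))     # dense kernels and biases only
print(len(model.variables))   # weights plus the dropout layer's seed state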

@@ -82,13 +82,13 @@ class MyModel(Layer):
        super().__init__()
        self.dense1 = MiniDense(units)
        # self.bn = MiniBatchNorm()
        # self.dropout = MiniDropout(0.5)
        self.dropout = MiniDropout(0.5)
        self.dense2 = MiniDense(num_classes)

    def call(self, x):
        x = self.dense1(x)
        # x = self.bn(x)
        # x = self.dropout(x)
        x = self.dropout(x)
        return self.dense2(x)

@@ -71,6 +71,8 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def dropout(inputs, rate, noise_shape=None, seed=None):
    seed = draw_seed(seed)
    keep_prob = 1.0 - rate
    if noise_shape is None:
        noise_shape = inputs.shape
    mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape)
    mask = jax.numpy.broadcast_to(mask, inputs.shape)
    return jax.lax.select(
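
The fix above defaults `noise_shape` to the full input shape before sampling, so the Bernoulli mask is drawn per element rather than collapsing to a single scalar draw (which is what `shape=None` would give for a scalar `p`) broadcast over the whole tensor. A self-contained sketch of the same pattern, using standalone names rather than the backend's actual entry points:

import jax
import jax.numpy as jnp


def dropout_sketch(inputs, rate, key, noise_shape=None):
    keep_prob = 1.0 - rate
    if noise_shape is None:
        # The fix: fall back to a per-element mask over the full input.
        noise_shape = inputs.shape
    mask = jax.random.bernoulli(key, p=keep_prob, shape=noise_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    # Scale kept values by 1/keep_prob so the expected activation is unchanged.
    return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


x = jnp.ones((2, 4))
print(dropout_sketch(x, 0.5, jax.random.PRNGKey(0)))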

@@ -49,6 +49,7 @@ class Layer(Operation):
        self._layers = []
        self._metrics = []
        self._seed_generators = []
        self._losses = []
        self._variables = []
        self._trainable_variables = []
@@ -71,7 +72,7 @@ class Layer(Operation):
                    and not isinstance(x, Metric),
                    self._layers,
                ),
                # TODO: SeedGenerator tracking
                "seed_generators": (
                    lambda x: isinstance(x, backend.random.SeedGenerator),
                    self._seed_generators,
                ),
            }
        )
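
The new `"seed_generators"` entry tells the layer's attribute tracker to collect any `SeedGenerator` assigned on the layer into `self._seed_generators`, mirroring how sub-layers and metrics are collected above it. A toy sketch of that tracking idea (not the actual `Tracker` class):

class ToyTracker:
    """Toy version: each config entry maps a name to (predicate, storage_list)."""

    def __init__(self, config):
        self.config = config

    def track(self, value):
        # Called when an attribute is set on the layer; route the value
        # into every store whose predicate matches it.
        for predicate, store in self.config.values():
            if predicate(value) and value not in store:
                store.append(value)
        return value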
@@ -176,8 +177,12 @@ class Layer(Operation):
    @property
    def variables(self):
        # TODO: include not just weights but any variables (also from metrics, optimizers, SeedGenerators)
        # Includes weights, seed generator state, and metric variables.
        variables = self.weights[:]
        for m in self._metrics:
            variables.extend(m.variables)
        for sg in self._seed_generators:
            variables.append(sg.state)
        return variables

    @property
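
Since `variables` now reports metric and seed-generator state alongside the weights, a functional backend like JAX can lift all of a layer's state in one pass. A hedged usage sketch, assuming the `MyDropout` class from the first hunk and that these variables expose `value`/`assign`:

layer = MyDropout(0.5)
print(len(layer.weights))    # 0: dropout has no trainable weights
print(len(layer.variables))  # 1: the SeedGenerator's state variable

# Pure-function style: read all state out, and write updated state back in.
state = [v.value for v in layer.variables]
for v, new_value in zip(layer.variables, state):
    v.assign(new_value)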