Fix jax dropout and start tracking all state in layer.variables.
This commit is contained in:
parent
1a45e5cd17
commit
dd144b6097
@ -30,23 +30,40 @@ class MyDense(layers.Layer):
|
||||
return ops.matmul(inputs, self.w) + self.b
|
||||
|
||||
|
||||
class MyDropout(layers.Layer):
|
||||
def __init__(self, rate, name=None):
|
||||
super().__init__(name=name)
|
||||
self.rate = rate
|
||||
# Use seed_generator for managing RNG state.
|
||||
# It is a state element and its seed variable is
|
||||
# tracked as part of `layer.variables`.
|
||||
self.seed_generator = backend.random.SeedGenerator(1337)
|
||||
|
||||
def call(self, inputs):
|
||||
# Use `backend.random` for random ops.
|
||||
return backend.random.dropout(
|
||||
inputs, self.rate, seed=self.seed_generator
|
||||
)
|
||||
|
||||
|
||||
class MyModel(Model):
|
||||
def __init__(self, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.dense1 = MyDense(hidden_dim)
|
||||
self.dense2 = MyDense(hidden_dim)
|
||||
self.dense3 = MyDense(output_dim)
|
||||
self.dp = MyDropout(0.5)
|
||||
|
||||
def call(self, x):
|
||||
x1 = self.dense1(x)
|
||||
x2 = self.dense2(x)
|
||||
# Why not use some ops here as well
|
||||
x = ops.concatenate([x1, x2], axis=-1)
|
||||
x = self.dp(x)
|
||||
return self.dense3(x)
|
||||
|
||||
|
||||
model = MyModel(hidden_dim=256, output_dim=16)
|
||||
model.summary()
|
||||
|
||||
x = np.random.random((50000, 128))
|
||||
y = np.random.random((50000, 16))
|
||||
@ -60,5 +77,7 @@ model.compile(
|
||||
)
|
||||
history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
|
||||
|
||||
model.summary()
|
||||
|
||||
print("History:")
|
||||
print(history.history)
|
||||
|
@ -82,13 +82,13 @@ class MyModel(Layer):
|
||||
super().__init__()
|
||||
self.dense1 = MiniDense(units)
|
||||
# self.bn = MiniBatchNorm()
|
||||
# self.dropout = MiniDropout(0.5)
|
||||
self.dropout = MiniDropout(0.5)
|
||||
self.dense2 = MiniDense(num_classes)
|
||||
|
||||
def call(self, x):
|
||||
x = self.dense1(x)
|
||||
# x = self.bn(x)
|
||||
# x = self.dropout(x)
|
||||
x = self.dropout(x)
|
||||
return self.dense2(x)
|
||||
|
||||
|
||||
|
@ -71,6 +71,8 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
||||
def dropout(inputs, rate, noise_shape=None, seed=None):
|
||||
seed = draw_seed(seed)
|
||||
keep_prob = 1.0 - rate
|
||||
if noise_shape is None:
|
||||
noise_shape = inputs.shape
|
||||
mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape)
|
||||
mask = jax.numpy.broadcast_to(mask, inputs.shape)
|
||||
return jax.lax.select(
|
||||
|
@ -49,6 +49,7 @@ class Layer(Operation):
|
||||
|
||||
self._layers = []
|
||||
self._metrics = []
|
||||
self._seed_generators = []
|
||||
self._losses = []
|
||||
self._variables = []
|
||||
self._trainable_variables = []
|
||||
@ -71,7 +72,7 @@ class Layer(Operation):
|
||||
and not isinstance(x, Metric),
|
||||
self._layers,
|
||||
),
|
||||
# TODO: SeedGenerator tracking
|
||||
"seed_generators": (lambda x: isinstance(x, backend.random.SeedGenerator), self._seed_generators),
|
||||
}
|
||||
)
|
||||
|
||||
@ -176,8 +177,12 @@ class Layer(Operation):
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
# TODO: include not just weights by any variables (also from metrics, optimizers, SeedGenerators)
|
||||
# Includes weights, seed generator state, and metric variables.
|
||||
variables = self.weights[:]
|
||||
for m in self._metrics:
|
||||
variables.extend(m.variables)
|
||||
for sg in self._seed_generators:
|
||||
variables.append(sg.state)
|
||||
return variables
|
||||
|
||||
@property
|
||||
|
Loading…
Reference in New Issue
Block a user