add tests for RandomNormal and RandomUniform initializers

AakashKumarNain 2023-04-12 23:23:38 +05:30 committed by Francois Chollet
parent 6544d6b850
commit 4b018c0560
34 changed files with 381 additions and 530 deletions
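For context, a minimal usage sketch (not part of the diff below) of the two initializers added and tested in this commit, assuming the keras_core package from this repository is importable:

import numpy as np
from keras_core import initializers

# Draw a (2, 2) sample from N(mean=0, stddev=1); the integer seed makes the draw reproducible.
normal_init = initializers.RandomNormal(mean=0.0, stddev=1.0, seed=1234)
normal_values = normal_init(shape=(2, 2))

# Draw a (2, 2) sample from U[0, 1); minval is inclusive, maxval exclusive.
uniform_init = initializers.RandomUniform(minval=0.0, maxval=1.0, seed=1234)
uniform_values = uniform_init(shape=(2, 2))

print(np.array(normal_values).shape, np.array(uniform_values).shape)  # (2, 2) (2, 2)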

@ -32,9 +32,7 @@ class MiniDropout(Layer):
self.seed_generator = backend.random.RandomSeedGenerator(1337)
def call(self, inputs):
return backend.random.dropout(
inputs, self.rate, seed=self.seed_generator
)
return backend.random.dropout(inputs, self.rate, seed=self.seed_generator)
class MiniBatchNorm(Layer):
@ -45,9 +43,7 @@ class MiniBatchNorm(Layer):
def build(self, input_shape):
shape = (input_shape[-1],)
self.mean = backend.Variable(
initializers.Zeros()(shape), trainable=False
)
self.mean = backend.Variable(initializers.Zeros()(shape), trainable=False)
self.variance = backend.Variable(
initializers.GlorotUniform()(shape), trainable=False
)
@ -62,9 +58,7 @@ class MiniBatchNorm(Layer):
self.variance.assign(
self.variance * self.momentum + variance * (1.0 - self.momentum)
)
self.mean.assign(
self.mean * self.momentum + mean * (1.0 - self.momentum)
)
self.mean.assign(self.mean * self.momentum + mean * (1.0 - self.momentum))
else:
outputs = (inputs - self.mean) / (self.variance + self.epsilon)
outputs *= self.gamma

@ -34,9 +34,7 @@ class MiniDropout(Layer):
self.seed_generator = backend.random.RandomSeedGenerator(1337)
def call(self, inputs):
return backend.random.dropout(
inputs, self.rate, seed=self.seed_generator
)
return backend.random.dropout(inputs, self.rate, seed=self.seed_generator)
class MiniBatchNorm(Layer):
@ -47,9 +45,7 @@ class MiniBatchNorm(Layer):
def build(self, input_shape):
shape = (input_shape[-1],)
self.mean = backend.Variable(
initializers.Zeros()(shape), trainable=False
)
self.mean = backend.Variable(initializers.Zeros()(shape), trainable=False)
self.variance = backend.Variable(
initializers.GlorotUniform()(shape), trainable=False
)
@ -64,9 +60,7 @@ class MiniBatchNorm(Layer):
self.variance.assign(
self.variance * self.momentum + variance * (1.0 - self.momentum)
)
self.mean.assign(
self.mean * self.momentum + mean * (1.0 - self.momentum)
)
self.mean.assign(self.mean * self.momentum + mean * (1.0 - self.momentum))
else:
outputs = (inputs - self.mean) / (self.variance + self.epsilon)
outputs *= self.gamma
@ -113,9 +107,7 @@ optimizer.build(model.trainable_variables)
## Currently operational workflow
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, x, y
):
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
@ -127,9 +119,7 @@ grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(
trainable_variables, non_trainable_variables, optimizer_variables, x, y
):
def train_step(trainable_variables, non_trainable_variables, optimizer_variables, x, y):
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
@ -144,20 +134,14 @@ trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
for x, y in dataset:
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
) = train_step(
trainable_variables, non_trainable_variables, optimizer_variables = train_step(
trainable_variables, non_trainable_variables, optimizer_variables, x, y
)
# Post-processing model state update
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(
model.non_trainable_variables, non_trainable_variables
):
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
variable.assign(value)
print("Updated values")

@ -30,7 +30,9 @@ class KerasVariable:
raise NotImplementedError
def __repr__(self):
return f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
return (
f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
)
ALLOWED_DTYPES = {

@ -82,8 +82,7 @@ def set_floatx(value):
accepted_dtypes = {"float16", "float32", "float64"}
if value not in accepted_dtypes:
raise ValueError(
f"Unknown `floatx` value: {value}. "
f"Expected one of {accepted_dtypes}"
f"Unknown `floatx` value: {value}. " f"Expected one of {accepted_dtypes}"
)
_FLOATX = str(value)

@ -264,8 +264,7 @@ def compute_output_spec(fn, *args, **kwargs):
return fn(*args, *static_args, **kwargs, **static_kwargs)
maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
convert_keras_tensor_to_jax,
(maybe_symbolic_args, maybe_symbolic_kwargs),
convert_keras_tensor_to_jax, (maybe_symbolic_args, maybe_symbolic_kwargs)
)
_, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
*maybe_symbolic_args, **maybe_symbolic_kwargs

@ -72,6 +72,4 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
keep_prob = 1.0 - rate
mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape)
mask = jax.numpy.broadcast_to(mask, inputs.shape)
return jax.lax.select(
mask, inputs / keep_prob, jax.numpy.zeros_like(inputs)
)
return jax.lax.select(mask, inputs / keep_prob, jax.numpy.zeros_like(inputs))
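A quick numeric check of the inverted-dropout scaling implemented above (plain NumPy, independent of any backend): surviving activations are divided by keep_prob so the expected value of the output matches the input.

import numpy as np

rng = np.random.default_rng(0)
x = np.ones(100_000)
keep_prob = 0.8
mask = rng.random(x.shape) < keep_prob          # keep ~80% of the activations
dropped = np.where(mask, x / keep_prob, 0.0)    # scale survivors by 1/keep_prob
print(round(dropped.mean(), 2))                 # ~1.0: the expectation is preserved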

@ -22,9 +22,7 @@ def draw_seed(seed):
if isinstance(seed, RandomSeedGenerator):
new_seed_value = seed.state.value
seed.state.assign(
seed.state + convert_to_tensor([0, 1], dtype="uint32")
)
seed.state.assign(seed.state + convert_to_tensor([0, 1], dtype="uint32"))
return new_seed_value
elif isinstance(seed, int):
return convert_to_tensor([seed, 0], dtype="uint32")
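A brief sketch of the seed bookkeeping above, assuming the [seed, counter] uint32 layout shown: every draw returns the current state and then bumps the counter, so consecutive draws use different seeds.

import numpy as np

state = np.array([1337, 0], dtype="uint32")
first_seed = state.copy()
state = state + np.array([0, 1], dtype="uint32")  # mirrors seed.state.assign(seed.state + [0, 1])
second_seed = state.copy()
print(first_seed, second_seed)  # [1337 0] [1337 1]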

@ -128,32 +128,22 @@ class Variable(KerasVariable):
return self.value.__rdiv__(convert_to_tensor(other, dtype=self.dtype))
def __truediv__(self, other):
return self.value.__truediv__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__truediv__(convert_to_tensor(other, dtype=self.dtype))
def __rtruediv__(self, other):
return self.value.__rtruediv__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__rtruediv__(convert_to_tensor(other, dtype=self.dtype))
def __floordiv__(self, other):
return self.value.__floordiv__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__floordiv__(convert_to_tensor(other, dtype=self.dtype))
def __rfloordiv__(self, other):
return self.value.__rfloordiv__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__rfloordiv__(convert_to_tensor(other, dtype=self.dtype))
def __divmod__(self, other):
return self.value.__divmod__(convert_to_tensor(other, dtype=self.dtype))
def __rdivmod__(self, other):
return self.value.__rdivmod__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__rdivmod__(convert_to_tensor(other, dtype=self.dtype))
def __mod__(self, other):
return self.value.__mod__(convert_to_tensor(other, dtype=self.dtype))
@ -171,9 +161,7 @@ class Variable(KerasVariable):
return self.value.__matmul__(convert_to_tensor(other, dtype=self.dtype))
def __rmatmul__(self, other):
return self.value.__rmatmul__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__rmatmul__(convert_to_tensor(other, dtype=self.dtype))
def __and__(self, other):
return self.value.__and__(convert_to_tensor(other, dtype=self.dtype))
@ -197,17 +185,13 @@ class Variable(KerasVariable):
return self.value.__lshift__(convert_to_tensor(other, dtype=self.dtype))
def __rlshift__(self, other):
return self.value.__rlshift__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__rlshift__(convert_to_tensor(other, dtype=self.dtype))
def __rshift__(self, other):
return self.value.__rshift__(convert_to_tensor(other, dtype=self.dtype))
def __rrshift__(self, other):
return self.value.__rrshift__(
convert_to_tensor(other, dtype=self.dtype)
)
return self.value.__rrshift__(convert_to_tensor(other, dtype=self.dtype))
def __round__(self, ndigits=None):
return self.value.__round__(ndigits)
@ -248,9 +232,7 @@ def compute_output_spec(fn, *args, **kwargs):
return tf.compat.v1.placeholder(shape=x.shape, dtype=x.dtype)
return x
args, kwargs = tf.nest.map_structure(
convert_keras_tensor_to_tf, (args, kwargs)
)
args, kwargs = tf.nest.map_structure(convert_keras_tensor_to_tf, (args, kwargs))
tf_out = fn(*args, **kwargs)
def convert_tf_to_keras_tensor(x):
@ -265,6 +247,4 @@ def execute(op_name, *args, **kwargs):
if hasattr(tfnp, op_name):
op = getattr(tfnp, op_name)
return op(*args, **kwargs)
raise AttributeError(
f"The TensorFlow backend does not support op '{op_name}'"
)
raise AttributeError(f"The TensorFlow backend does not support op '{op_name}'")

@ -2,9 +2,7 @@ import warnings
from keras_core.api_export import keras_core_export
@keras_core_export(
["keras_core.Initializer", "keras_core.initializers.Initializer"]
)
@keras_core_export(["keras_core.Initializer", "keras_core.initializers.Initializer"])
class Initializer:
"""Initializer base class: all Keras initializers inherit from this class.

@ -52,8 +52,7 @@ class VarianceScaling(Initializer):
):
if scale <= 0.0:
raise ValueError(
"Argument `scale` must be positive float. "
f"Received: scale={scale}"
"Argument `scale` must be positive float. " f"Received: scale={scale}"
)
allowed_modes = {"fan_in", "fan_out", "fan_avg"}
if mode not in allowed_modes:
@ -156,9 +155,7 @@ class GlorotUniform(VarianceScaling):
"""
def __init__(self, seed=None):
super().__init__(
scale=1.0, mode="fan_avg", distribution="uniform", seed=seed
)
super().__init__(scale=1.0, mode="fan_avg", distribution="uniform", seed=seed)
def get_config(self):
return {"seed": self.seed}
@ -284,9 +281,7 @@ class LecunUniform(VarianceScaling):
"""
def __init__(self, seed=None):
super().__init__(
scale=1.0, mode="fan_in", distribution="uniform", seed=seed
)
super().__init__(scale=1.0, mode="fan_in", distribution="uniform", seed=seed)
def get_config(self):
return {"seed": self.seed}
@ -364,9 +359,7 @@ class HeUniform(VarianceScaling):
"""
def __init__(self, seed=None):
super().__init__(
scale=2.0, mode="fan_in", distribution="uniform", seed=seed
)
super().__init__(scale=2.0, mode="fan_in", distribution="uniform", seed=seed)
def get_config(self):
return {"seed": self.seed}
@ -398,3 +391,101 @@ def compute_fans(shape):
fan_in = shape[-2] * receptive_field_size
fan_out = shape[-1] * receptive_field_size
return int(fan_in), int(fan_out)
class RandomNormal(Initializer):
"""Random normal initializer.
Draws samples from a normal distribution with the given parameters.
Examples:
>>> # Standalone usage:
>>> initializer = RandomNormal(mean=0.0, stddev=1.0)
>>> values = initializer(shape=(2, 2))
>>> # Usage in a Keras layer:
>>> initializer = RandomNormal(mean=0.0, stddev=1.0)
>>> layer = Dense(3, kernel_initializer=initializer)
Args:
mean: A python scalar or a scalar keras tensor. Mean of the random values to
generate.
stddev: A python scalar or a scalar keras tensor. Standard deviation of the
random values to generate.
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
"""
def __init__(self, mean=0.0, stddev=1.0, seed=None):
self.mean = mean
self.stddev = stddev
self.seed = seed or random.make_default_seed()
super().__init__()
def __call__(self, shape, dtype=None, **kwargs):
return random.normal(
shape=shape,
mean=self.mean,
stddev=self.stddev,
seed=self.seed,
dtype=dtype
)
def get_config(self):
return {"mean": self.mean, "stddev":self.stddev, "seed":self.seed}
class RandomUniform(Initializer):
"""Random uniform initializer.
Draws samples from a uniform distribution with the given parameters.
Examples:
>>> # Standalone usage:
>>> initializer = RandomUniform(minval=0.0, maxval=1.0)
>>> values = initializer(shape=(2, 2))
>>> # Usage in a Keras layer:
>>> initializer = RandomUniform(minval=0.0, maxval=1.0)
>>> layer = Dense(3, kernel_initializer=initializer)
Args:
minval: A python scalar or a scalar keras tensor. Lower bound of the range of
random values to generate (inclusive).
maxval: A python scalar or a scalar keras tensor. Upper bound of the range of
random values to generate (exclusive).
seed: A Python integer or instance of
`keras_core.backend.RandomSeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.backend.RandomSeedGenerator`.
"""
def __init__(self, minval=0.0, maxval=1.0, seed=None):
self.minval = minval
self.maxval = maxval
self.seed = seed or random.make_default_seed()
super().__init__()
def __call__(self, shape, dtype=None, **kwargs):
return random.uniform(
shape=shape,
minval=self.minval,
maxval=self.maxval,
seed=self.seed,
dtype=dtype
)
def get_config(self):
return {"minval": self.minval, "maxval":self.maxval, "seed":self.seed}

@ -0,0 +1,49 @@
from keras_core import testing
from keras_core import initializers
import numpy as np
class InitializersTest(testing.TestCase):
def test_random_normal(self):
shape = (5, 5)
mean = 0.0
stddev = 1.0
seed = 1234
external_config = {"mean": 1.0, "stddev": 0.5, "seed": 42}
initializer = initializers.RandomNormal(
mean=mean,
stddev=stddev,
seed=seed
)
values = initializer(shape=shape)
self.assertEqual(initializer.mean, mean)
self.assertEqual(initializer.stddev, stddev)
self.assertEqual(initializer.seed, seed)
self.assertEqual(values.shape, shape)
self._check_load_external_config(initializer, external_config)
def test_random_uniform(self):
shape = (5, 5)
minval = -1.0
maxval = 1.0
seed = 1234
external_config = {"minval": 0.0, "maxval": 1.0, "seed": 42}
initializer = initializers.RandomUniform(
minval=minval,
maxval=maxval,
seed=seed
)
values = initializer(shape=shape)
self.assertEqual(initializer.minval, minval)
self.assertEqual(initializer.maxval, maxval)
self.assertEqual(initializer.seed, seed)
self.assertEqual(values.shape, shape)
self._check_load_external_config(initializer, external_config)
values = np.array(values)
self.assertGreaterEqual(np.min(values), minval)
self.assertLess(np.max(values), maxval)
def _check_load_external_config(self, initializer, config):
initializer = initializer.from_config(config)
self.assertEqual(initializer.get_config(), config)

@ -56,9 +56,7 @@ class InputSpec:
allow_last_axis_squeeze=False,
name=None,
):
self.dtype = (
backend.standardize_dtype(dtype) if dtype is not None else None
)
self.dtype = backend.standardize_dtype(dtype) if dtype is not None else None
if shape is not None:
self.shape = backend.standardize_shape(shape)
self.ndim = len(shape)

@ -62,8 +62,7 @@ class Layer(Operation):
),
"metrics": (lambda x: isinstance(x, Metric), self._metrics),
"layers": (
lambda x: isinstance(x, Layer)
and not isinstance(x, Metric),
lambda x: isinstance(x, Layer) and not isinstance(x, Metric),
self._layers,
),
# TODO: RandomSeedGenerator tracking
@ -252,9 +251,7 @@ class Layer(Operation):
# Argument validation and conversion. #
# 1. Convert first positional argument to tensor of correct dtype.
if args and not isinstance(args[0], KerasTensor):
args = (
nest.map_structure(backend.convert_to_tensor, args[0]),
) + args[1:]
args = (nest.map_structure(backend.convert_to_tensor, args[0]),) + args[1:]
# 2. Convert any other array arguments to tensors of correct dtype.
def maybe_convert(x):
@ -436,7 +433,7 @@ class Layer(Operation):
def add_metric(self):
# Permanently disabled
raise NotImplementedError
def count_params(self):
"""Count the total number of scalars composing the weights.
@ -538,7 +535,7 @@ class Layer(Operation):
values = self.call.__defaults__
mapping = dict(zip(kwargs, values))
return mapping.get("training", None)
def _flatten_layers(self, include_self=True, recursive=True):
layers = []
if include_self:
@ -581,12 +578,9 @@ def get_shapes_dict(arguments_dict):
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
elif nest.is_nested(v):
flat = nest.flatten(v)
if any(
isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat
):
if any(isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat):
if not all(
isinstance(x, KerasTensor) or backend.is_tensor(x)
for x in flat
isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat
):
raise ValueError(
"You cannot mix tensors and non-tensors in a nested argument. "

@ -13,9 +13,7 @@ class FunctionTest(testing.TestCase):
return x + 1
x = keras_tensor.KerasTensor(shape=(2, 3), name="x")
with self.assertRaisesRegex(
ValueError, "Only input tensors may be passed as"
):
with self.assertRaisesRegex(ValueError, "Only input tensors may be passed as"):
SomeLayer()(x, True)
# This works

@ -1,7 +1,6 @@
from keras_core import operations as ops
from keras_core import backend
from keras_core.utils.naming import auto_name
from keras_core.utils import dtype_utils
from keras_core.api_export import keras_core_export
@ -31,10 +30,7 @@ class Loss:
mask = None
return reduce_weighted_loss(
losses,
sample_weight=sample_weight,
mask=mask,
reduction=self.reduction,
losses, sample_weight=sample_weight, mask=mask, reduction=self.reduction
)
def call(self, y_true, y_pred):
@ -75,11 +71,7 @@ def squeeze_to_same_rank(x1, x2):
def reduce_loss(losses, reduction="sum_over_batch_size"):
if (
reduction is None
or tuple(losses.shape) == ()
or tuple(losses.shape) == (0,)
):
if reduction is None or tuple(losses.shape) == () or tuple(losses.shape) == (0,):
return losses
loss = ops.sum(losses)
if reduction == "sum_over_batch_size":
@ -109,7 +101,7 @@ def reduce_weighted_loss(
# Convert any non float dtypes to floats, to avoid loss of precision
# for dtype like int or bool.
dtype = backend.standardize_dtype(losses.dtype)
if not dtype_utils.is_float(dtype):
if not is_float(dtype):
input_dtype = losses.dtype
losses = ops.cast(losses, "float32")
input_casted = True
@ -131,6 +123,45 @@ def reduce_weighted_loss(
return loss
def float_dtype_size(dtype):
if dtype in ("bfloat16", "float16"):
return 16
if dtype == "float32":
return 32
if dtype == "float64":
return 64
raise ValueError(f"Invalid dtype: {dtype}")
def is_float(dtype):
return "float" in dtype
def cast_to_common_dtype(tensors):
"""Cast a list of tensors to a common dtype.
If any tensor is floating-point, they will all be cast to the most precise
floating-point dtype. Otherwise the tensors are not cast.
Args:
tensors: A list of tensors.
Returns:
The same list, cast to a common dtype.
"""
highest_float = None
for x in tensors:
dtype = backend.standardize_dtype(x.dtype)
if is_float(dtype):
if highest_float is None or float_dtype_size(dtype) > float_dtype_size(highest_float):
highest_float = dtype
elif dtype == "float16" and highest_float == "bfloat16":
highest_float = "float32"
if highest_float:
tensors = [ops.cast(x, highest_float) for x in tensors]
return tensors
def apply_mask(sample_weight, mask, dtype, reduction):
"""Applies any mask on predictions to sample weights."""
if mask is not None:

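A standalone NumPy rendering of the dtype-promotion rule implemented by cast_to_common_dtype above; the helper names here are illustrative, not part of the library.

import numpy as np

def _float_bits(dtype):
    # Mirrors float_dtype_size: bfloat16/float16 -> 16, float32 -> 32, float64 -> 64.
    return {"bfloat16": 16, "float16": 16, "float32": 32, "float64": 64}[dtype]

def cast_to_common_dtype_np(tensors):
    # Pick the widest float dtype present; if both float16 and bfloat16 appear,
    # promote to float32. Integer-only inputs are returned unchanged.
    highest = None
    for x in tensors:
        dtype = str(x.dtype)
        if "float" in dtype:
            if highest is None or _float_bits(dtype) > _float_bits(highest):
                highest = dtype
            elif dtype == "float16" and highest == "bfloat16":
                highest = "float32"
    if highest:
        tensors = [x.astype(highest) for x in tensors]
    return tensors

mixed = cast_to_common_dtype_np([np.ones(2, "float16"), np.ones(2, "float64")])
print([t.dtype for t in mixed])  # [dtype('float64'), dtype('float64')]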
@ -52,9 +52,7 @@ class LossTest(testing.TestCase):
loss_fn = ExampleLoss()
loss = loss_fn(y_true, y_pred)
self.assertEqual(loss.dtype.name, "float32")
self.assertAllClose(
np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss
)
self.assertAllClose(np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss)
# Test edge case where everything is masked.
mask = np.array([False, False, False, False])
@ -71,9 +69,7 @@ class LossTest(testing.TestCase):
loss_fn = ExampleLoss()
loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
self.assertEqual(loss.dtype.name, "float32")
self.assertAllClose(
np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss
)
self.assertAllClose(np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss)
# Test edge case where every weight is 0.
sample_weight = np.array([0.0, 0.0, 0.0, 0.0])
@ -100,8 +96,7 @@ class LossTest(testing.TestCase):
loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
self.assertEqual(loss.dtype.name, "float32")
self.assertAllClose(
np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2)
/ 3,
np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2) / 3,
loss,
)
@ -135,10 +130,7 @@ class LossTest(testing.TestCase):
loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
self.assertEqual(loss.dtype.name, "float32")
self.assertAllClose(
np.sum(
masked_sample_weight * (masked_y_true - masked_y_pred) ** 2
)
/ 3,
np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2) / 3,
loss,
)

@ -4,9 +4,7 @@ from keras_core.losses.loss import squeeze_to_same_rank
class LossFunctionWrapper(Loss):
def __init__(
self, fn, reduction="sum_over_batch_size", name=None, **kwargs
):
def __init__(self, fn, reduction="sum_over_batch_size", name=None, **kwargs):
super().__init__(reduction=reduction, name=name)
self.fn = fn
self._fn_kwargs = kwargs
@ -31,7 +29,5 @@ def mean_squared_error(y_true, y_pred):
class MeanSquaredError(LossFunctionWrapper):
def __init__(
self, reduction="sum_over_batch_size", name="mean_squared_error"
):
def __init__(self, reduction="sum_over_batch_size", name="mean_squared_error"):
super().__init__(mean_squared_error, reduction=reduction, name=name)

@ -9,9 +9,7 @@ import numpy as np
class ExampleMetric(Metric):
def __init__(self, name="mean_square_error", dtype=None):
super().__init__(name=name, dtype=dtype)
self.sum = self.add_variable(
name="sum", initializer=initializers.Zeros()
)
self.sum = self.add_variable(name="sum", initializer=initializers.Zeros())
self.total = self.add_variable(
name="total", initializer=initializers.Zeros(), dtype="int32"
)
@ -50,9 +48,7 @@ class MetricTest(testing.TestCase):
self.assertAllClose(metric.total, 20)
result = metric.result()
self.assertAllClose(
result, np.sum((y_true - y_pred) ** 2) / num_samples
)
self.assertAllClose(result, np.sum((y_true - y_pred) ** 2) / num_samples)
metric.reset_state()
self.assertEqual(metric.result(), 0.0)
@ -64,10 +60,7 @@ class MetricTest(testing.TestCase):
# In dict
metric = ExampleMetric(name="mse")
metric.more_vars = {
"a": backend.Variable(0.0),
"b": backend.Variable(1.0),
}
metric.more_vars = {"a": backend.Variable(0.0), "b": backend.Variable(1.0)}
self.assertEqual(len(metric.variables), 4)
# In nested structured

@ -6,12 +6,8 @@ from keras_core import initializers
class MeanSquareError(Metric):
def __init__(self, name="mean_square_error", dtype=None):
super().__init__(name=name, dtype=dtype)
self.sum = self.add_variable(
name="sum", initializer=initializers.Zeros()
)
self.total = self.add_variable(
name="total", initializer=initializers.Zeros()
)
self.sum = self.add_variable(name="sum", initializer=initializers.Zeros())
self.total = self.add_variable(name="total", initializer=initializers.Zeros())
def update_state(self, y_true, y_pred):
# TODO: add support for sample_weight

@ -27,7 +27,6 @@ class Functional(Function, Model):
return
super().__init__(inputs, outputs, name=name)
self._layers = self.layers
self.built = True
@property
def layers(self):
@ -60,7 +59,7 @@ class Functional(Function, Model):
pass
def add_loss(self, loss):
# Symbolic only. TODO
# Symbolic only.
raise NotImplementedError

@ -36,7 +36,7 @@ class Model(Layer, Trainer):
@property
def layers(self):
return list(self._flatten_layers(include_self=False, recursive=False))
@layers.setter
def layers(self, _):
raise AttributeError(

@ -107,9 +107,7 @@ class Function(Operation):
def _assert_input_compatibility(self, inputs):
try:
nest.assert_same_structure(
inputs, self._inputs_struct, check_types=False
)
nest.assert_same_structure(inputs, self._inputs_struct, check_types=False)
except ValueError:
raise ValueError(
"Function was called with an invalid input structure. "

@ -14,9 +14,7 @@ class FunctionTest(testing.TestCase):
x = knp.add(x1, x2)
y1 = x * 3
y2 = x**2
fn = function.Function(
inputs=[x1, x2], outputs=[y1, y2], name="test_function"
)
fn = function.Function(inputs=[x1, x2], outputs=[y1, y2], name="test_function")
self.assertEqual(fn.name, "test_function")
# Eager call
@ -90,9 +88,7 @@ class FunctionTest(testing.TestCase):
x = knp.add(x1, x2)
y1 = x * 3
y2 = x**2
fn = function.Function(
inputs=[x1, x2], outputs=[y1, y2], name="test_function"
)
fn = function.Function(inputs=[x1, x2], outputs=[y1, y2], name="test_function")
self.assertEqual(fn.name, "test_function")
# Bad structure

@ -36,9 +36,7 @@ class Node:
outputs: The output tensors of the `op.__call__()` call.
"""
def __init__(
self, operation, call_args=None, call_kwargs=None, outputs=None
):
def __init__(self, operation, call_args=None, call_kwargs=None, outputs=None):
self.operation = operation
self.arguments = SymbolicArguments(*call_args, **call_kwargs)
self.outputs = [] if outputs is None else nest.flatten(outputs)
@ -101,9 +99,7 @@ class Node:
class KerasHistory(
collections.namedtuple(
"KerasHistory", ["operation", "node_index", "tensor_index"]
)
collections.namedtuple("KerasHistory", ["operation", "node_index", "tensor_index"])
):
"""Tracks the Operation call that created a Tensor.

@ -285,14 +285,10 @@ class Mean(Operation):
self.keepdims = keepdims
def call(self, x):
return backend.execute(
"mean", x, axis=self.axis, keepdims=self.keepdims
)
return backend.execute("mean", x, axis=self.axis, keepdims=self.keepdims)
def compute_output_spec(self, x):
return compute_np_output_spec(
"mean", x, axis=self.axis, keepdims=self.keepdims
)
return compute_np_output_spec("mean", x, axis=self.axis, keepdims=self.keepdims)
def mean(x, axis=None, keepdims=False):
@ -310,9 +306,7 @@ class Var(Operation):
return backend.execute("var", x, axis=self.axis, keepdims=self.keepdims)
def compute_output_spec(self, x):
return compute_np_output_spec(
"var", x, axis=self.axis, keepdims=self.keepdims
)
return compute_np_output_spec("var", x, axis=self.axis, keepdims=self.keepdims)
def var(x, axis=None, keepdims=False):
@ -330,9 +324,7 @@ class Sum(Operation):
return backend.execute("sum", x, axis=self.axis, keepdims=self.keepdims)
def compute_output_spec(self, x):
return compute_np_output_spec(
"sum", x, axis=self.axis, keepdims=self.keepdims
)
return compute_np_output_spec("sum", x, axis=self.axis, keepdims=self.keepdims)
def sum(x, axis=None, keepdims=False):

@ -24,9 +24,7 @@ class Operation:
# sets _keras_history on the outputs, and adds itself to the
# `_outbound_nodes` of the ops that produced the inputs to this
# call.
Node(
operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs
)
Node(operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs)
return outputs
def call(self, *args, **kwargs):

@ -61,9 +61,7 @@ class Optimizer:
self.iterations = backend.Variable(
0, name="iteration", dtype="int64", trainable=False
)
if isinstance(
learning_rate, learning_rate_schedule.LearningRateSchedule
):
if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
self._learning_rate = learning_rate
elif callable(learning_rate):
self._learning_rate = learning_rate
@ -264,9 +262,7 @@ class Optimizer:
return self._get_current_learning_rate()
def _get_current_learning_rate(self):
if isinstance(
self._learning_rate, learning_rate_schedule.LearningRateSchedule
):
if isinstance(self._learning_rate, learning_rate_schedule.LearningRateSchedule):
return self._learning_rate(self.iterations)
elif callable(self._learning_rate):
return self._learning_rate(self.iterations)
@ -340,17 +336,13 @@ class Optimizer:
)
if var_list:
self._exclude_from_weight_decay = [
id(variable) for variable in var_list
]
self._exclude_from_weight_decay = [id(variable) for variable in var_list]
else:
self._exclude_from_weight_decay = []
self._exclude_from_weight_decay_names = var_names or []
def _use_weight_decay(self, variable):
exclude_from_weight_decay = getattr(
self, "_exclude_from_weight_decay", []
)
exclude_from_weight_decay = getattr(self, "_exclude_from_weight_decay", [])
exclude_from_weight_decay_names = getattr(
self, "_exclude_from_weight_decay_names", []
)
@ -383,9 +375,7 @@ class Optimizer:
def _update_model_variables_moving_average(self, var_list):
"""Update the stored moving average using the latest value."""
if self.use_ema:
for var, average in zip(
var_list, self._model_variables_moving_average
):
for var, average in zip(var_list, self._model_variables_moving_average):
average.assign(
self.ema_momentum * average + (1 - self.ema_momentum) * var
)
@ -404,9 +394,7 @@ class Optimizer:
def _overwrite_model_variables_with_average_value_helper(self, var_list):
"""Helper function that overwrites model variables."""
for var, average_var in zip(
var_list, self._model_variables_moving_average
):
for var, average_var in zip(var_list, self._model_variables_moving_average):
var.assign(average_var)
def finalize_variable_values(self, var_list):
@ -439,12 +427,8 @@ class Optimizer:
Python dictionary.
"""
if isinstance(
self._learning_rate, learning_rate_schedule.LearningRateSchedule
):
learning_rate = learning_rate_schedule.serialize(
self._learning_rate
)
if isinstance(self._learning_rate, learning_rate_schedule.LearningRateSchedule):
learning_rate = learning_rate_schedule.serialize(self._learning_rate)
elif isinstance(self._learning_rate, backend.Variable):
learning_rate = float(self._learning_rate.numpy())
elif ops.is_tensor(self._learning_rate):

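A small numeric sketch of the exponential-moving-average update used in _update_model_variables_moving_average above:

ema_momentum = 0.99
average, var = 1.0, 2.0
average = ema_momentum * average + (1 - ema_momentum) * var
print(round(average, 4))  # 1.01 -- the average moves 1% of the way toward the latest variable value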
@ -85,9 +85,7 @@ class SGD(optimizer.Optimizer):
self.momentums = []
for variable in variables:
self.momentums.append(
self.add_variable_from_reference(
reference_variable=variable, name="m"
)
self.add_variable_from_reference(reference_variable=variable, name="m")
)
def update_step(self, gradient, variable, learning_rate):
@ -100,9 +98,7 @@ class SGD(optimizer.Optimizer):
if m is not None:
m.assign(-gradient * learning_rate + m * momentum)
if self.nesterov:
variable.assign(
variable - gradient * learning_rate + m * momentum
)
variable.assign(variable - gradient * learning_rate + m * momentum)
else:
variable.assign(variable + m)
else:

@ -3,9 +3,7 @@ from keras_core import operations as ops
from keras_core.api_export import keras_core_export
@keras_core_export(
["keras_core.Regularizer", "keras_core.regularizers.Regularizer"]
)
@keras_core_export(["keras_core.Regularizer", "keras_core.regularizers.Regularizer"])
class Regularizer:
"""Regularizer base class.
@ -317,15 +315,10 @@ class OrthogonalRegularizer(Regularizer):
inputs = l2_normalize(inputs, axis=0)
product = ops.matmul(ops.transpose(inputs), inputs)
size = inputs.shape[1]
product_no_diagonal = product * (
1.0 - ops.eye(size, dtype=inputs.dtype)
)
product_no_diagonal = product * (1.0 - ops.eye(size, dtype=inputs.dtype))
num_pairs = size * (size - 1.0) / 2.0
return (
self.factor
* 0.5
* ops.sum(ops.absolute(product_no_diagonal))
/ num_pairs
self.factor * 0.5 * ops.sum(ops.absolute(product_no_diagonal)) / num_pairs
)
def get_config(self):
@ -334,9 +327,7 @@ class OrthogonalRegularizer(Regularizer):
def validate_float_arg(value, name):
"""check penalty number availability, raise ValueError if failed."""
if not isinstance(value, (float, int)) or (
math.isinf(value) or math.isnan(value)
):
if not isinstance(value, (float, int)) or (math.isinf(value) or math.isnan(value)):
raise ValueError(
f"Invalid value for argument {name}: expected a float. "
f"Received: {name}={value}"

@ -1,41 +0,0 @@
from keras_core import backend
from keras_core import operations as ops
def float_dtype_size(dtype):
if dtype in ("bfloat16", "float16"):
return 16
if dtype == "float32":
return 32
if dtype == "float64":
return 64
raise ValueError(f"Invalid dtype: {dtype}")
def is_float(dtype):
return "float" in dtype
def cast_to_common_dtype(tensors):
"""Cast a list of tensors to a common dtype.
If any tensor is floating-point, they will all be casted to the most-precise
floating-point dtype. Otherwise the tensors are not casted.
Args:
tensors: A list of tensors.
Returns:
Same list, casted to a common dtype.
"""
highest_float = None
for x in tensors:
dtype = backend.standardize_dtype(x.dtype)
if is_float(dtype):
if highest_float is None or float_dtype_size(dtype) > float_dtype_size(highest_float):
highest_float = dtype
elif dtype == "float16" and highest_float == "bfloat16":
highest_float = "float32"
if highest_float:
tensors = [ops.cast(x, highest_float) for x in tensors]
return tensors

@ -42,7 +42,9 @@ def is_interactive_logging_enabled():
"""
# Use `getattr` in case `INTERACTIVE_LOGGING`
# does not have the `enable` attribute.
return getattr(INTERACTIVE_LOGGING, "enable", True)
return getattr(
INTERACTIVE_LOGGING, "enable", True
)
def print_msg(message, line_break=True):

@ -1,48 +1,13 @@
from tensorflow import nest
from keras_core import backend
from keras_core.utils import io_utils
from keras_core.utils import dtype_utils
from keras_core.utils import text_rendering
from keras_core.backend import Variable
import math
import re
def count_params(weights):
shapes = [v.shape for v in weights]
shapes = [v.shape for v in weights if isinstance(v, Variable)]
return int(sum(math.prod(p) for p in shapes))
def weight_memory_size(weights):
"""Compute the memory footprint for weights based on their dtypes.
Args:
weights: An iterable contains the weights to compute weight size.
Returns:
The total memory size (in Bytes) of the weights.
"""
unique_weights = set(weights)
total_memory_size = 0
for w in unique_weights:
weight_shape = math.prod(w.shape)
dtype = backend.standardize_dtype(w.dtype)
per_param_size = dtype_utils.float_dtype_size(dtype)
total_memory_size += weight_shape * per_param_size
return total_memory_size
def readable_memory_size(weight_memory_size):
"""Convert the weight memory size (Bytes) to a readable string."""
units = ["Byte", "KB", "MB", "GB", "TB", "PB"]
scale = 1024
for unit in units:
if weight_memory_size / scale < 1:
return "{:.2f} {}".format(weight_memory_size, unit)
else:
weight_memory_size /= scale
return "{:.2f} {}".format(weight_memory_size, units[-1])
def print_summary(
model,
line_length=None,
@ -121,13 +86,17 @@ def print_summary(
if sequential_like:
line_length = line_length or 65
positions = positions or [0.45, 0.85, 1.0]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
header = ["Layer (type)", "Output Shape", "Param #"]
to_display = ["Layer (type)", "Output Shape", "Param #"]
else:
line_length = line_length or 98
positions = positions or [0.3, 0.6, 0.70, 1.0]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
header = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
to_display = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
relevant_nodes = []
for v in model._nodes_by_depth.values():
relevant_nodes += v
@ -135,14 +104,63 @@ def print_summary(
if show_trainable:
line_length += 11
positions.append(line_length)
header.append("Trainable")
to_display.append("Trainable")
layer_range = get_layer_index_bound_by_layer_name(model, layer_range)
print_fn(f'Model: "{model.name}"')
rows = []
def print_row(fields, positions, nested_level=0):
left_to_print = [str(x) for x in fields]
while any(left_to_print):
line = ""
for col in range(len(left_to_print)):
if col > 0:
start_pos = positions[col - 1]
else:
start_pos = 0
end_pos = positions[col]
# Leave room for 2 spaces to delineate columns
# we don't need any if we are printing the last column
space = 2 if col != len(positions) - 1 else 0
cutoff = end_pos - start_pos - space
# Except for last col, offset by one to align the start of col
if col != len(positions) - 1:
cutoff -= 1
if col == 0:
cutoff -= nested_level
fit_into_line = left_to_print[col][:cutoff]
# For nicer formatting we line-break on seeing end of
# tuple/dict etc.
line_break_conditions = ("),", "},", "],", "',")
candidate_cutoffs = [
fit_into_line.find(x) + len(x)
for x in line_break_conditions
if fit_into_line.find(x) >= 0
]
if candidate_cutoffs:
cutoff = min(candidate_cutoffs)
fit_into_line = fit_into_line[:cutoff]
def print_layer_summary(layer, prefix=" "):
if col == 0:
line += "|" * nested_level + " "
line += fit_into_line
line += " " * space if space else ""
left_to_print[col] = left_to_print[col][cutoff:]
# Pad out to the next position
# Make space for nested_level for last column
if nested_level and col == len(positions) - 1:
line += " " * (positions[col] - len(line) - nested_level)
else:
line += " " * (positions[col] - len(line))
line += "|" * nested_level
print_fn(line)
print_fn(f'Model: "{model.name}"')
print_fn("_" * line_length)
print_row(to_display, positions)
print_fn("=" * line_length)
def print_layer_summary(layer, nested_level=0):
"""Prints a summary for a single layer.
Args:
@ -156,9 +174,9 @@ def print_summary(
output_shape = "multiple"
except RuntimeError: # output_shape unknown in Eager mode.
output_shape = "?"
name = prefix + layer.name
name = layer.name
cls_name = layer.__class__.__name__
if not layer.built:
if not layer.built and not getattr(layer, "_is_graph_network", False):
# If a subclassed model has a layer that is not called in
# Model.call, the layer will not be built and we cannot call
# layer.count_params().
@ -169,9 +187,10 @@ def print_summary(
if show_trainable:
fields.append("Y" if layer.trainable else "N")
rows.append(fields)
def print_layer_summary_with_connections(layer, prefix=""):
print_row(fields, positions, nested_level)
def print_layer_summary_with_connections(layer, nested_level=0):
"""Prints a summary for a single layer (including its connections).
Args:
@ -188,15 +207,18 @@ def print_summary(
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for kt in node.keras_inputs:
keras_history = kt._keras_history
inbound_layer = keras_history.layer
node_index = keras_history.node_index
tensor_index = keras_history.tensor_index
for (
inbound_layer,
node_index,
tensor_index,
_,
) in node.iterate_inbound():
connections.append(
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
)
name = prefix + layer.name
name = layer.name
cls_name = layer.__class__.__name__
fields = [
name + " (" + cls_name + ")",
@ -204,38 +226,49 @@ def print_summary(
layer.count_params(),
connections,
]
if show_trainable:
fields.append("Y" if layer.trainable else "N")
rows.append(fields)
def print_layer(layer, nested_level=0):
print_row(fields, positions, nested_level)
def print_layer(layer, nested_level=0, is_nested_last=False):
if sequential_like:
print_layer_summary(layer, prefix=">>>" * nested_level + " ")
print_layer_summary(layer, nested_level)
else:
print_layer_summary_with_connections(
layer, prefix=">>>" * nested_level + " "
)
print_layer_summary_with_connections(layer, nested_level)
if expand_nested and hasattr(layer, "layers") and layer.layers:
nested_layers = layer.layers
nested_level += 1
for i in range(len(nested_layers)):
print_layer(nested_layers[i], nested_level=nested_level)
print_fn(
"|" * (nested_level + 1)
+ "¯" * (line_length - 2 * nested_level - 2)
+ "|" * (nested_level + 1)
)
nested_layer = layer.layers
is_nested_last = False
for i in range(len(nested_layer)):
if i == len(nested_layer) - 1:
is_nested_last = True
print_layer(nested_layer[i], nested_level + 1, is_nested_last)
print_fn(
"|" * nested_level
+ "¯" * (line_length - 2 * nested_level)
+ "|" * nested_level
)
if not is_nested_last:
print_fn(
"|" * nested_level
+ " " * (line_length - 2 * nested_level)
+ "|" * nested_level
)
for layer in model.layers[layer_range[0] : layer_range[1]]:
print_layer(layer)
print_fn("=" * line_length)
# Render summary as a table.
table = text_rendering.Table(
header=header,
rows=rows,
positions=positions,
# Left align layer name, center-align everything else
alignments=["left"] + ["center" for _ in range(len(header) - 1)],
)
print_fn(table.make())
# After the table, append information about parameter count and size.
if hasattr(model, "_collected_trainable_weights"):
trainable_count = count_params(model._collected_trainable_weights)
trainable_memory_size = weight_memory_size(
@ -262,57 +295,6 @@ def print_summary(
f"Non-trainable params: {non_trainable_count} "
f"({readable_memory_size(non_trainable_memory_size)})"
)
print_fn("_" * line_length)
def get_layer_index_bound_by_layer_name(model, layer_range=None):
"""Get the layer indexes from the model based on layer names.
The layer indexes can be used to slice the model into sub models for
display.
Args:
model: `Model` instance.
layer_names: a list or tuple of 2 strings, the starting layer name and
ending layer name (both inclusive) for the result. All layers will
be included when `None` is provided.
Returns:
The index value of layer based on its unique name (layer_names).
Output will be [first_layer_index, last_layer_index + 1].
"""
if layer_range is not None:
if len(layer_range) != 2:
raise ValueError(
"layer_range must be a list or tuple of length 2. Received: "
f"layer_range = {layer_range} of length {len(layer_range)}"
)
if not isinstance(layer_range[0], str) or not isinstance(
layer_range[1], str
):
raise ValueError(
"layer_range should contain string type only. "
f"Received: {layer_range}"
)
else:
return [0, len(model.layers)]
lower_index = [
idx
for idx, layer in enumerate(model.layers)
if re.match(layer_range[0], layer.name)
]
upper_index = [
idx
for idx, layer in enumerate(model.layers)
if re.match(layer_range[1], layer.name)
]
if not lower_index or not upper_index:
raise ValueError(
"Passed layer_names do not match the layer names in the model. "
f"Received: {layer_range}"
)
if min(lower_index) > max(upper_index):
return [min(upper_index), max(lower_index) + 1]
return [min(lower_index), max(upper_index) + 1]
print_dtensor_variable_summary(model, print_fn, line_length)

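To make the layer_range resolution shown above concrete, a small sketch with hypothetical layer names, mirroring the index logic of get_layer_index_bound_by_layer_name:

import re

layer_names = ["input_1", "dense", "dense_1", "dense_2"]   # hypothetical model layers
layer_range = ("dense", "dense_1")
lower = [i for i, n in enumerate(layer_names) if re.match(layer_range[0], n)]
upper = [i for i, n in enumerate(layer_names) if re.match(layer_range[1], n)]
bounds = [min(lower), max(upper) + 1] if min(lower) <= max(upper) else [min(upper), max(lower) + 1]
print(bounds, layer_names[bounds[0]:bounds[1]])  # [1, 3] ['dense', 'dense_1']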
@ -1,126 +0,0 @@
import shutil
class Table:
def __init__(
self, header, rows, positions, alignments=None, max_line_length=80
):
if len(header) != len(positions):
raise ValueError("header and positions should be the same length.")
if not all(p <= 1.0 for p in positions):
raise ValueError("All positions should be <= 1.")
self.alignments = alignments or ["center" for _ in header]
if len(self.alignments) != len(header):
raise ValueError("header and alignments should be the same length.")
last_p = 0.0
for p in positions:
if p <= last_p:
raise ValueError(
"All consecutive positions should be greater than the last."
)
last_p = p
self.header = header
self.rows = rows
# Compute columns widths
line_length = min(
max_line_length, shutil.get_terminal_size().columns - 4
)
column_widths = []
current = 0
for pos in positions:
width = int(pos * line_length) - current
if width < 4:
raise ValueError("Insufficient console width to print summary.")
column_widths.append(width)
current += width
self.column_widths = column_widths
def make_separator(self, left, mid, right, horizontal):
line = mid.join(horizontal * width for width in self.column_widths)
return f"{left}{line}{right}"
def print_row(
self,
fields,
vertical_separator="",
alignments=None,
):
alignments = alignments or ["center" for _ in fields]
lines = []
line_break_conditions = ("),", "},", "],", "',", ") ")
for field, width in zip(fields, self.column_widths):
buffered_width = width - 1
if len(field) < buffered_width and not "\n" in field:
lines.append([field])
continue
subfields = []
while len(field) >= buffered_width or "\n" in field:
if "\n" in field[:buffered_width]:
# priority: break on line break
cutoff = field.find("\n")
subfield = field[:cutoff]
field = field[cutoff + 1 :]
subfields.append(subfield)
continue
# secondary: break on certain characters
candidate_cutoffs = [
field.find(x) + len(x)
for x in line_break_conditions
if 0 < field.find(x) < buffered_width
]
if candidate_cutoffs:
cutoff = min(buffered_width - 1, *candidate_cutoffs)
else:
cutoff = buffered_width - 1
subfield = field[:cutoff]
field = field[cutoff:]
subfields.append(subfield)
if field:
subfields.append(field)
lines.append(subfields)
max_subfield_count = max(len(subs) for subs in lines)
rendered_lines = []
for i in range(max_subfield_count):
fields = []
for subfields in lines:
if len(subfields) < i + 1:
field = ""
else:
field = subfields[i]
fields.append(field)
line = vertical_separator.join(
self.align_field(field, width, alignment)
for field, width, alignment in zip(
fields, self.column_widths, alignments
)
)
line = f"{vertical_separator}{line}{vertical_separator}"
rendered_lines.append(line)
return "\n".join(rendered_lines)
@staticmethod
def align_field(field, width, alignment):
if alignment == "center":
return field.center(width)
if alignment == "left":
return field.ljust(width)
if alignment == "right":
return field.rjust(width)
def make(self):
lines = []
# Print header
lines.append(self.make_separator(*"┏┳┓━"))
lines.append(self.print_row(self.header, vertical_separator=""))
lines.append(self.make_separator(*"┡╇┩━"))
# Print rows
for i, row in enumerate(self.rows):
lines.append(self.print_row(row, alignments=self.alignments))
if i < len(self.rows) - 1:
lines.append(self.make_separator(*"├┼┤─"))
lines.append(self.make_separator(*"└┴┘─"))
return "\n".join(lines)
