add tests for RandomNormal and RandomUniform initializers
This commit is contained in:
parent
6544d6b850
commit
4b018c0560
@@ -32,9 +32,7 @@ class MiniDropout(Layer):
        self.seed_generator = backend.random.RandomSeedGenerator(1337)

    def call(self, inputs):
        return backend.random.dropout(
            inputs, self.rate, seed=self.seed_generator
        )
        return backend.random.dropout(inputs, self.rate, seed=self.seed_generator)


class MiniBatchNorm(Layer):
@@ -45,9 +43,7 @@ class MiniBatchNorm(Layer):

    def build(self, input_shape):
        shape = (input_shape[-1],)
        self.mean = backend.Variable(
            initializers.Zeros()(shape), trainable=False
        )
        self.mean = backend.Variable(initializers.Zeros()(shape), trainable=False)
        self.variance = backend.Variable(
            initializers.GlorotUniform()(shape), trainable=False
        )
@@ -62,9 +58,7 @@ class MiniBatchNorm(Layer):
            self.variance.assign(
                self.variance * self.momentum + variance * (1.0 - self.momentum)
            )
            self.mean.assign(
                self.mean * self.momentum + mean * (1.0 - self.momentum)
            )
            self.mean.assign(self.mean * self.momentum + mean * (1.0 - self.momentum))
        else:
            outputs = (inputs - self.mean) / (self.variance + self.epsilon)
            outputs *= self.gamma
@@ -34,9 +34,7 @@ class MiniDropout(Layer):
        self.seed_generator = backend.random.RandomSeedGenerator(1337)

    def call(self, inputs):
        return backend.random.dropout(
            inputs, self.rate, seed=self.seed_generator
        )
        return backend.random.dropout(inputs, self.rate, seed=self.seed_generator)


class MiniBatchNorm(Layer):
@@ -47,9 +45,7 @@ class MiniBatchNorm(Layer):

    def build(self, input_shape):
        shape = (input_shape[-1],)
        self.mean = backend.Variable(
            initializers.Zeros()(shape), trainable=False
        )
        self.mean = backend.Variable(initializers.Zeros()(shape), trainable=False)
        self.variance = backend.Variable(
            initializers.GlorotUniform()(shape), trainable=False
        )
@@ -64,9 +60,7 @@ class MiniBatchNorm(Layer):
            self.variance.assign(
                self.variance * self.momentum + variance * (1.0 - self.momentum)
            )
            self.mean.assign(
                self.mean * self.momentum + mean * (1.0 - self.momentum)
            )
            self.mean.assign(self.mean * self.momentum + mean * (1.0 - self.momentum))
        else:
            outputs = (inputs - self.mean) / (self.variance + self.epsilon)
            outputs *= self.gamma
@@ -113,9 +107,7 @@ optimizer.build(model.trainable_variables)
## Currently operational workflow


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, x, y
):
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
@@ -127,9 +119,7 @@ grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(
    trainable_variables, non_trainable_variables, optimizer_variables, x, y
):
def train_step(trainable_variables, non_trainable_variables, optimizer_variables, x, y):
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
@@ -144,20 +134,14 @@ trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
for x, y in dataset:
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    ) = train_step(
    trainable_variables, non_trainable_variables, optimizer_variables = train_step(
        trainable_variables, non_trainable_variables, optimizer_variables, x, y
    )

# Post-processing model state update
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(
    model.non_trainable_variables, non_trainable_variables
):
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

print("Updated values")
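The hunks above reformat the README's stateless JAX training loop: all state goes into the jitted step and comes back out, nothing is mutated in place. A minimal self-contained sketch of the same pattern, with a toy linear model standing in for the keras_core model and optimizer (every name below is illustrative, not part of the commit):

import jax
import jax.numpy as jnp

# Toy stand-in for model.stateless_call: a linear model whose
# "trainable_variables" is a single weight vector.
def compute_loss(trainable_variables, x, y):
    (w,) = trainable_variables
    y_pred = x @ w
    return jnp.mean((y_pred - y) ** 2)

grad_fn = jax.value_and_grad(compute_loss)

@jax.jit
def train_step(trainable_variables, x, y):
    loss, grads = grad_fn(trainable_variables, x, y)
    # Plain SGD update: state in, updated state out.
    return loss, [w - 0.1 * g for w, g in zip(trainable_variables, grads)]

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
variables = [jnp.zeros(3)]
for _ in range(100):
    loss, variables = train_step(variables, x, y)
print(loss)  # approaches 0.0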
@@ -30,7 +30,9 @@ class KerasVariable:
        raise NotImplementedError

    def __repr__(self):
        return f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
        return (
            f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
        )


ALLOWED_DTYPES = {
@@ -82,8 +82,7 @@ def set_floatx(value):
    accepted_dtypes = {"float16", "float32", "float64"}
    if value not in accepted_dtypes:
        raise ValueError(
            f"Unknown `floatx` value: {value}. "
            f"Expected one of {accepted_dtypes}"
            f"Unknown `floatx` value: {value}. " f"Expected one of {accepted_dtypes}"
        )
    _FLOATX = str(value)
@@ -264,8 +264,7 @@ def compute_output_spec(fn, *args, **kwargs):
        return fn(*args, *static_args, **kwargs, **static_kwargs)

    maybe_symbolic_args, maybe_symbolic_kwargs = nest.map_structure(
        convert_keras_tensor_to_jax,
        (maybe_symbolic_args, maybe_symbolic_kwargs),
        convert_keras_tensor_to_jax, (maybe_symbolic_args, maybe_symbolic_kwargs)
    )
    _, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
        *maybe_symbolic_args, **maybe_symbolic_kwargs
@@ -72,6 +72,4 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape)
    mask = jax.numpy.broadcast_to(mask, inputs.shape)
    return jax.lax.select(
        mask, inputs / keep_prob, jax.numpy.zeros_like(inputs)
    )
    return jax.lax.select(mask, inputs / keep_prob, jax.numpy.zeros_like(inputs))
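This is inverted dropout: surviving activations are scaled by 1/keep_prob so the expected value of the output matches the input, which is why no rescaling is needed at inference time. A standalone JAX check of that property, mirroring the function above:

import jax
import jax.numpy as jnp

def dropout(inputs, rate, seed):
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(seed, p=keep_prob, shape=inputs.shape)
    # Kept entries are scaled up by 1/keep_prob; dropped entries become 0.
    return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

x = jnp.ones((1000, 1000))
y = dropout(x, rate=0.5, seed=jax.random.PRNGKey(0))
print(y.mean())  # close to 1.0: the expected activation is preserved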
@@ -22,9 +22,7 @@ def draw_seed(seed):

    if isinstance(seed, RandomSeedGenerator):
        new_seed_value = seed.state.value
        seed.state.assign(
            seed.state + convert_to_tensor([0, 1], dtype="uint32")
        )
        seed.state.assign(seed.state + convert_to_tensor([0, 1], dtype="uint32"))
        return new_seed_value
    elif isinstance(seed, int):
        return convert_to_tensor([seed, 0], dtype="uint32")
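draw_seed normalizes either an integer or a stateful generator into a uint32 [seed, counter] pair, bumping the generator's counter on every draw so consecutive draws differ while an integer seed always maps to the same pair. A minimal NumPy sketch of that contract (the class below is an illustrative stand-in, not the keras_core API):

import numpy as np

class TinySeedGenerator:
    """Illustrative stand-in: holds a [seed, counter] uint32 state."""
    def __init__(self, seed):
        self.state = np.array([seed, 0], dtype="uint32")

def draw_seed(seed):
    if isinstance(seed, TinySeedGenerator):
        value = seed.state.copy()
        seed.state += np.array([0, 1], dtype="uint32")  # bump the counter
        return value
    if isinstance(seed, int):
        return np.array([seed, 0], dtype="uint32")
    raise ValueError(f"Unsupported seed: {seed}")

gen = TinySeedGenerator(1337)
print(draw_seed(gen))  # [1337    0]
print(draw_seed(gen))  # [1337    1] -- fresh value on every call
print(draw_seed(42))   # [42 0]     -- a plain int always maps to the same pair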
@@ -128,32 +128,22 @@ class Variable(KerasVariable):
        return self.value.__rdiv__(convert_to_tensor(other, dtype=self.dtype))

    def __truediv__(self, other):
        return self.value.__truediv__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__truediv__(convert_to_tensor(other, dtype=self.dtype))

    def __rtruediv__(self, other):
        return self.value.__rtruediv__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__rtruediv__(convert_to_tensor(other, dtype=self.dtype))

    def __floordiv__(self, other):
        return self.value.__floordiv__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__floordiv__(convert_to_tensor(other, dtype=self.dtype))

    def __rfloordiv__(self, other):
        return self.value.__rfloordiv__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__rfloordiv__(convert_to_tensor(other, dtype=self.dtype))

    def __divmod__(self, other):
        return self.value.__divmod__(convert_to_tensor(other, dtype=self.dtype))

    def __rdivmod__(self, other):
        return self.value.__rdivmod__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__rdivmod__(convert_to_tensor(other, dtype=self.dtype))

    def __mod__(self, other):
        return self.value.__mod__(convert_to_tensor(other, dtype=self.dtype))
@@ -171,9 +161,7 @@ class Variable(KerasVariable):
        return self.value.__matmul__(convert_to_tensor(other, dtype=self.dtype))

    def __rmatmul__(self, other):
        return self.value.__rmatmul__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__rmatmul__(convert_to_tensor(other, dtype=self.dtype))

    def __and__(self, other):
        return self.value.__and__(convert_to_tensor(other, dtype=self.dtype))
@@ -197,17 +185,13 @@ class Variable(KerasVariable):
        return self.value.__lshift__(convert_to_tensor(other, dtype=self.dtype))

    def __rlshift__(self, other):
        return self.value.__rlshift__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__rlshift__(convert_to_tensor(other, dtype=self.dtype))

    def __rshift__(self, other):
        return self.value.__rshift__(convert_to_tensor(other, dtype=self.dtype))

    def __rrshift__(self, other):
        return self.value.__rrshift__(
            convert_to_tensor(other, dtype=self.dtype)
        )
        return self.value.__rrshift__(convert_to_tensor(other, dtype=self.dtype))

    def __round__(self, ndigits=None):
        return self.value.__round__(ndigits)
@@ -248,9 +232,7 @@ def compute_output_spec(fn, *args, **kwargs):
            return tf.compat.v1.placeholder(shape=x.shape, dtype=x.dtype)
        return x

    args, kwargs = tf.nest.map_structure(
        convert_keras_tensor_to_tf, (args, kwargs)
    )
    args, kwargs = tf.nest.map_structure(convert_keras_tensor_to_tf, (args, kwargs))
    tf_out = fn(*args, **kwargs)

    def convert_tf_to_keras_tensor(x):
@@ -265,6 +247,4 @@ def execute(op_name, *args, **kwargs):
    if hasattr(tfnp, op_name):
        op = getattr(tfnp, op_name)
        return op(*args, **kwargs)
    raise AttributeError(
        f"The TensorFlow backend does not support op '{op_name}'"
    )
    raise AttributeError(f"The TensorFlow backend does not support op '{op_name}'")
@@ -2,9 +2,7 @@ import warnings
from keras_core.api_export import keras_core_export


@keras_core_export(
    ["keras_core.Initializer", "keras_core.initializers.Initializer"]
)
@keras_core_export(["keras_core.Initializer", "keras_core.initializers.Initializer"])
class Initializer:
    """Initializer base class: all Keras initializers inherit from this class.

@@ -52,8 +52,7 @@ class VarianceScaling(Initializer):
    ):
        if scale <= 0.0:
            raise ValueError(
                "Argument `scale` must be positive float. "
                f"Received: scale={scale}"
                "Argument `scale` must be positive float. " f"Received: scale={scale}"
            )
        allowed_modes = {"fan_in", "fan_out", "fan_avg"}
        if mode not in allowed_modes:
@@ -156,9 +155,7 @@ class GlorotUniform(VarianceScaling):
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=1.0, mode="fan_avg", distribution="uniform", seed=seed
        )
        super().__init__(scale=1.0, mode="fan_avg", distribution="uniform", seed=seed)

    def get_config(self):
        return {"seed": self.seed}
@@ -284,9 +281,7 @@ class LecunUniform(VarianceScaling):
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=1.0, mode="fan_in", distribution="uniform", seed=seed
        )
        super().__init__(scale=1.0, mode="fan_in", distribution="uniform", seed=seed)

    def get_config(self):
        return {"seed": self.seed}
@@ -364,9 +359,7 @@ class HeUniform(VarianceScaling):
    """

    def __init__(self, seed=None):
        super().__init__(
            scale=2.0, mode="fan_in", distribution="uniform", seed=seed
        )
        super().__init__(scale=2.0, mode="fan_in", distribution="uniform", seed=seed)

    def get_config(self):
        return {"seed": self.seed}
@@ -398,3 +391,101 @@ def compute_fans(shape):
    fan_in = shape[-2] * receptive_field_size
    fan_out = shape[-1] * receptive_field_size
    return int(fan_in), int(fan_out)

class RandomNormal(Initializer):
    """Random normal initializer.

    Draws samples from a normal distribution for given parameters.

    Examples:

    >>> # Standalone usage:
    >>> initializer = RandomNormal(mean=0.0, stddev=1.0)
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = RandomNormal(mean=0.0, stddev=1.0)
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        mean: A python scalar or a scalar keras tensor. Mean of the random
            values to generate.
        stddev: A python scalar or a scalar keras tensor. Standard deviation
            of the random values to generate.
        seed: A Python integer or instance of
            `keras_core.backend.RandomSeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or None (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.RandomSeedGenerator`.
    """

    def __init__(self, mean=0.0, stddev=1.0, seed=None):
        self.mean = mean
        self.stddev = stddev
        self.seed = seed or random.make_default_seed()
        super().__init__()

    def __call__(self, shape, dtype=None, **kwargs):
        return random.normal(
            shape=shape,
            mean=self.mean,
            stddev=self.stddev,
            seed=self.seed,
            dtype=dtype,
        )

    def get_config(self):
        return {"mean": self.mean, "stddev": self.stddev, "seed": self.seed}


class RandomUniform(Initializer):
    """Random uniform initializer.

    Draws samples from a uniform distribution for given parameters.

    Examples:

    >>> # Standalone usage:
    >>> initializer = RandomUniform(minval=0.0, maxval=1.0)
    >>> values = initializer(shape=(2, 2))

    >>> # Usage in a Keras layer:
    >>> initializer = RandomUniform(minval=0.0, maxval=1.0)
    >>> layer = Dense(3, kernel_initializer=initializer)

    Args:
        minval: A python scalar or a scalar keras tensor. Lower bound of the
            range of random values to generate (inclusive).
        maxval: A python scalar or a scalar keras tensor. Upper bound of the
            range of random values to generate (exclusive).
        seed: A Python integer or instance of
            `keras_core.backend.RandomSeedGenerator`.
            Used to make the behavior of the initializer
            deterministic. Note that an initializer seeded with an integer
            or None (unseeded) will produce the same random values
            across multiple calls. To get different random values
            across multiple calls, use as seed an instance
            of `keras_core.backend.RandomSeedGenerator`.
    """

    def __init__(self, minval=0.0, maxval=1.0, seed=None):
        self.minval = minval
        self.maxval = maxval
        self.seed = seed or random.make_default_seed()
        super().__init__()

    def __call__(self, shape, dtype=None, **kwargs):
        return random.uniform(
            shape=shape,
            minval=self.minval,
            maxval=self.maxval,
            seed=self.seed,
            dtype=dtype,
        )

    def get_config(self):
        return {"minval": self.minval, "maxval": self.maxval, "seed": self.seed}
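The seed semantics spelled out in the docstrings are easy to demonstrate. A short sketch, assuming the keras_core package of this commit is importable and behaves as the docstrings state:

from keras_core import backend
from keras_core import initializers

# An integer seed pins the stream: two calls give identical values.
init = initializers.RandomNormal(mean=0.0, stddev=1.0, seed=1337)
a = init(shape=(2, 2))
b = init(shape=(2, 2))  # same values as `a`

# A RandomSeedGenerator advances its counter on every draw,
# so consecutive calls give different values.
init = initializers.RandomNormal(
    mean=0.0, stddev=1.0, seed=backend.random.RandomSeedGenerator(1337)
)
c = init(shape=(2, 2))
d = init(shape=(2, 2))  # different values from `c`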
49
keras_core/initializers/random_initializers_test.py
Normal file
@@ -0,0 +1,49 @@
from keras_core import testing
from keras_core import initializers
import numpy as np


class InitializersTest(testing.TestCase):
    def test_random_normal(self):
        shape = (5, 5)
        mean = 0.0
        stddev = 1.0
        seed = 1234
        external_config = {"mean": 1.0, "stddev": 0.5, "seed": 42}
        initializer = initializers.RandomNormal(
            mean=mean,
            stddev=stddev,
            seed=seed,
        )
        values = initializer(shape=shape)
        self.assertEqual(initializer.mean, mean)
        self.assertEqual(initializer.stddev, stddev)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        self.test_load_external_config(initializer, external_config)

    def test_random_uniform(self):
        shape = (5, 5)
        minval = -1.0
        maxval = 1.0
        seed = 1234
        external_config = {"minval": 0.0, "maxval": 1.0, "seed": 42}
        initializer = initializers.RandomUniform(
            minval=minval,
            maxval=maxval,
            seed=seed,
        )
        values = initializer(shape=shape)
        self.assertEqual(initializer.minval, minval)
        self.assertEqual(initializer.maxval, maxval)
        self.assertEqual(initializer.seed, seed)
        self.assertEqual(values.shape, shape)
        self.test_load_external_config(initializer, external_config)
        values = values.numpy()
        self.assertGreaterEqual(np.min(values), minval)
        self.assertLess(np.max(values), maxval)

    def test_load_external_config(self, initializer, config):
        initializer = initializer.from_config(config)
        self.assertEqual(initializer.get_config(), config)
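Beyond the range check in test_random_uniform, a distributional sanity check is a common companion test. A sketch of one (not part of the commit; assumes the returned tensor converts to NumPy, as the .numpy() call above suggests):

import numpy as np

from keras_core import initializers

# With enough samples, the empirical moments should sit near the requested
# mean/stddev; loose tolerances keep the check robust to the seed choice.
initializer = initializers.RandomNormal(mean=0.0, stddev=1.0, seed=1234)
values = np.asarray(initializer(shape=(100, 100)))
np.testing.assert_allclose(values.mean(), 0.0, atol=0.05)
np.testing.assert_allclose(values.std(), 1.0, atol=0.05)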
@@ -56,9 +56,7 @@ class InputSpec:
        allow_last_axis_squeeze=False,
        name=None,
    ):
        self.dtype = (
            backend.standardize_dtype(dtype) if dtype is not None else None
        )
        self.dtype = backend.standardize_dtype(dtype) if dtype is not None else None
        if shape is not None:
            self.shape = backend.standardize_shape(shape)
            self.ndim = len(shape)
@@ -62,8 +62,7 @@ class Layer(Operation):
            ),
            "metrics": (lambda x: isinstance(x, Metric), self._metrics),
            "layers": (
                lambda x: isinstance(x, Layer)
                and not isinstance(x, Metric),
                lambda x: isinstance(x, Layer) and not isinstance(x, Metric),
                self._layers,
            ),
            # TODO: RandomSeedGenerator tracking
@@ -252,9 +251,7 @@ class Layer(Operation):
        # Argument validation and conversion. #
        # 1. Convert first positional argument to tensor of correct dtype.
        if args and not isinstance(args[0], KerasTensor):
            args = (
                nest.map_structure(backend.convert_to_tensor, args[0]),
            ) + args[1:]
            args = (nest.map_structure(backend.convert_to_tensor, args[0]),) + args[1:]

        # 2. Convert any other array arguments to tensors of correct dtype.
        def maybe_convert(x):
@@ -436,7 +433,7 @@ class Layer(Operation):
    def add_metric(self):
        # Permanently disabled
        raise NotImplementedError


    def count_params(self):
        """Count the total number of scalars composing the weights.

@@ -538,7 +535,7 @@ class Layer(Operation):
        values = self.call.__defaults__
        mapping = dict(zip(kwargs, values))
        return mapping.get("training", None)


    def _flatten_layers(self, include_self=True, recursive=True):
        layers = []
        if include_self:
@@ -581,12 +578,9 @@ def get_shapes_dict(arguments_dict):
            shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
        elif nest.is_nested(v):
            flat = nest.flatten(v)
            if any(
                isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat
            ):
            if any(isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat):
                if not all(
                    isinstance(x, KerasTensor) or backend.is_tensor(x)
                    for x in flat
                    isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat
                ):
                    raise ValueError(
                        "You cannot mix tensors and non-tensors in a nested argument. "
@@ -13,9 +13,7 @@ class FunctionTest(testing.TestCase):
                return x + 1

        x = keras_tensor.KerasTensor(shape=(2, 3), name="x")
        with self.assertRaisesRegex(
            ValueError, "Only input tensors may be passed as"
        ):
        with self.assertRaisesRegex(ValueError, "Only input tensors may be passed as"):
            SomeLayer()(x, True)

        # This works
@@ -1,7 +1,6 @@
from keras_core import operations as ops
from keras_core import backend
from keras_core.utils.naming import auto_name
from keras_core.utils import dtype_utils
from keras_core.api_export import keras_core_export


@@ -31,10 +30,7 @@ class Loss:
            mask = None

        return reduce_weighted_loss(
            losses,
            sample_weight=sample_weight,
            mask=mask,
            reduction=self.reduction,
            losses, sample_weight=sample_weight, mask=mask, reduction=self.reduction
        )

    def call(self, y_true, y_pred):
@@ -75,11 +71,7 @@ def squeeze_to_same_rank(x1, x2):


def reduce_loss(losses, reduction="sum_over_batch_size"):
    if (
        reduction is None
        or tuple(losses.shape) == ()
        or tuple(losses.shape) == (0,)
    ):
    if reduction is None or tuple(losses.shape) == () or tuple(losses.shape) == (0,):
        return losses
    loss = ops.sum(losses)
    if reduction == "sum_over_batch_size":
@@ -109,7 +101,7 @@ def reduce_weighted_loss(
    # Convert any non float dtypes to floats, to avoid loss of precision
    # for dtype like int or bool.
    dtype = backend.standardize_dtype(losses.dtype)
    if not dtype_utils.is_float(dtype):
    if not is_float(dtype):
        input_dtype = losses.dtype
        losses = ops.cast(losses, "float32")
        input_casted = True
@@ -131,6 +123,45 @@ def reduce_weighted_loss(
    return loss


def float_dtype_size(dtype):
    if dtype in ("bfloat16", "float16"):
        return 16
    if dtype == "float32":
        return 32
    if dtype == "float64":
        return 64
    raise ValueError(f"Invalid dtype: {dtype}")


def is_float(dtype):
    return "float" in dtype


def cast_to_common_dtype(tensors):
    """Cast a list of tensors to a common dtype.

    If any tensor is floating-point, they will all be casted to the most-precise
    floating-point dtype. Otherwise the tensors are not casted.

    Args:
        tensors: A list of tensors.

    Returns:
        Same list, casted to a common dtype.
    """
    highest_float = None
    for x in tensors:
        dtype = backend.standardize_dtype(x.dtype)
        if is_float(dtype):
            if highest_float is None or float_dtype_size(dtype) > highest_float:
                highest_float = dtype
            elif dtype == "float16" and highest_float == "bfloat16":
                highest_float = "float32"
    if highest_float:
        tensors = [ops.cast(x, highest_float) for x in tensors]
    return tensors


def apply_mask(sample_weight, mask, dtype, reduction):
    """Applies any mask on predictions to sample weights."""
    if mask is not None:
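cast_to_common_dtype promotes every tensor in the list to the widest float dtype present, leaving all-integer lists untouched. Note that, as committed, the comparison `float_dtype_size(dtype) > highest_float` pits an integer against the dtype string stored in `highest_float`; the NumPy sketch below assumes the evidently intended size-to-size comparison (bfloat16 promotion omitted for brevity, names illustrative):

import numpy as np

def float_dtype_size(dtype):
    if dtype in ("bfloat16", "float16"):
        return 16
    if dtype == "float32":
        return 32
    if dtype == "float64":
        return 64
    raise ValueError(f"Invalid dtype: {dtype}")

def cast_to_common_dtype(tensors):
    # Track the widest float dtype seen, comparing sizes to sizes.
    highest_float = None
    for x in tensors:
        dtype = str(x.dtype)
        if "float" in dtype:
            if highest_float is None or (
                float_dtype_size(dtype) > float_dtype_size(highest_float)
            ):
                highest_float = dtype
    if highest_float:
        tensors = [x.astype(highest_float) for x in tensors]
    return tensors

mixed = [np.ones(2, "float16"), np.ones(2, "float64"), np.ones(2, "int32")]
print([t.dtype for t in cast_to_common_dtype(mixed)])
# [dtype('float64'), dtype('float64'), dtype('float64')]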
@@ -52,9 +52,7 @@ class LossTest(testing.TestCase):
        loss_fn = ExampleLoss()
        loss = loss_fn(y_true, y_pred)
        self.assertEqual(loss.dtype.name, "float32")
        self.assertAllClose(
            np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss
        )
        self.assertAllClose(np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss)

        # Test edge case where everything is masked.
        mask = np.array([False, False, False, False])
@@ -71,9 +69,7 @@ class LossTest(testing.TestCase):
        loss_fn = ExampleLoss()
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(loss.dtype.name, "float32")
        self.assertAllClose(
            np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss
        )
        self.assertAllClose(np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss)

        # Test edge case where every weight is 0.
        sample_weight = np.array([0.0, 0.0, 0.0, 0.0])
@@ -100,8 +96,7 @@ class LossTest(testing.TestCase):
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(loss.dtype.name, "float32")
        self.assertAllClose(
            np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2)
            / 3,
            np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2) / 3,
            loss,
        )

@@ -135,10 +130,7 @@ class LossTest(testing.TestCase):
        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
        self.assertEqual(loss.dtype.name, "float32")
        self.assertAllClose(
            np.sum(
                masked_sample_weight * (masked_y_true - masked_y_pred) ** 2
            )
            / 3,
            np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2) / 3,
            loss,
        )

@@ -4,9 +4,7 @@ from keras_core.losses.loss import squeeze_to_same_rank


class LossFunctionWrapper(Loss):
    def __init__(
        self, fn, reduction="sum_over_batch_size", name=None, **kwargs
    ):
    def __init__(self, fn, reduction="sum_over_batch_size", name=None, **kwargs):
        super().__init__(reduction=reduction, name=name)
        self.fn = fn
        self._fn_kwargs = kwargs
@@ -31,7 +29,5 @@ def mean_squared_error(y_true, y_pred):


class MeanSquaredError(LossFunctionWrapper):
    def __init__(
        self, reduction="sum_over_batch_size", name="mean_squared_error"
    ):
    def __init__(self, reduction="sum_over_batch_size", name="mean_squared_error"):
        super().__init__(mean_squared_error, reduction=reduction, name=name)
@@ -9,9 +9,7 @@ import numpy as np
class ExampleMetric(Metric):
    def __init__(self, name="mean_square_error", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.sum = self.add_variable(
            name="sum", initializer=initializers.Zeros()
        )
        self.sum = self.add_variable(name="sum", initializer=initializers.Zeros())
        self.total = self.add_variable(
            name="total", initializer=initializers.Zeros(), dtype="int32"
        )
@@ -50,9 +48,7 @@ class MetricTest(testing.TestCase):

        self.assertAllClose(metric.total, 20)
        result = metric.result()
        self.assertAllClose(
            result, np.sum((y_true - y_pred) ** 2) / num_samples
        )
        self.assertAllClose(result, np.sum((y_true - y_pred) ** 2) / num_samples)
        metric.reset_state()
        self.assertEqual(metric.result(), 0.0)

@@ -64,10 +60,7 @@ class MetricTest(testing.TestCase):

        # In dict
        metric = ExampleMetric(name="mse")
        metric.more_vars = {
            "a": backend.Variable(0.0),
            "b": backend.Variable(1.0),
        }
        metric.more_vars = {"a": backend.Variable(0.0), "b": backend.Variable(1.0)}
        self.assertEqual(len(metric.variables), 4)

        # In nested structured
@@ -6,12 +6,8 @@ from keras_core import initializers
class MeanSquareError(Metric):
    def __init__(self, name="mean_square_error", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.sum = self.add_variable(
            name="sum", initializer=initializers.Zeros()
        )
        self.total = self.add_variable(
            name="total", initializer=initializers.Zeros()
        )
        self.sum = self.add_variable(name="sum", initializer=initializers.Zeros())
        self.total = self.add_variable(name="total", initializer=initializers.Zeros())

    def update_state(self, y_true, y_pred):
        # TODO: add support for sample_weight
@@ -27,7 +27,6 @@ class Functional(Function, Model):
            return
        super().__init__(inputs, outputs, name=name)
        self._layers = self.layers
        self.built = True

    @property
    def layers(self):
@@ -60,7 +59,7 @@ class Functional(Function, Model):
        pass

    def add_loss(self, loss):
        # Symbolic only. TODO
        # Symbolic only.
        raise NotImplementedError


@@ -36,7 +36,7 @@ class Model(Layer, Trainer):
    @property
    def layers(self):
        return list(self._flatten_layers(include_self=False, recursive=False))


    @layers.setter
    def layers(self, _):
        raise AttributeError(
@@ -107,9 +107,7 @@ class Function(Operation):

    def _assert_input_compatibility(self, inputs):
        try:
            nest.assert_same_structure(
                inputs, self._inputs_struct, check_types=False
            )
            nest.assert_same_structure(inputs, self._inputs_struct, check_types=False)
        except ValueError:
            raise ValueError(
                "Function was called with an invalid input structure. "
@@ -14,9 +14,7 @@ class FunctionTest(testing.TestCase):
        x = knp.add(x1, x2)
        y1 = x * 3
        y2 = x**2
        fn = function.Function(
            inputs=[x1, x2], outputs=[y1, y2], name="test_function"
        )
        fn = function.Function(inputs=[x1, x2], outputs=[y1, y2], name="test_function")
        self.assertEqual(fn.name, "test_function")

        # Eager call
@@ -90,9 +88,7 @@ class FunctionTest(testing.TestCase):
        x = knp.add(x1, x2)
        y1 = x * 3
        y2 = x**2
        fn = function.Function(
            inputs=[x1, x2], outputs=[y1, y2], name="test_function"
        )
        fn = function.Function(inputs=[x1, x2], outputs=[y1, y2], name="test_function")
        self.assertEqual(fn.name, "test_function")

        # Bad structure
@@ -36,9 +36,7 @@ class Node:
        outputs: The output tensors of the `op.__call__()` call.
    """

    def __init__(
        self, operation, call_args=None, call_kwargs=None, outputs=None
    ):
    def __init__(self, operation, call_args=None, call_kwargs=None, outputs=None):
        self.operation = operation
        self.arguments = SymbolicArguments(*call_args, **call_kwargs)
        self.outputs = [] if outputs is None else nest.flatten(outputs)
@@ -101,9 +99,7 @@ class Node:


class KerasHistory(
    collections.namedtuple(
        "KerasHistory", ["operation", "node_index", "tensor_index"]
    )
    collections.namedtuple("KerasHistory", ["operation", "node_index", "tensor_index"])
):
    """Tracks the Operation call that created a Tensor.

@@ -285,14 +285,10 @@ class Mean(Operation):
        self.keepdims = keepdims

    def call(self, x):
        return backend.execute(
            "mean", x, axis=self.axis, keepdims=self.keepdims
        )
        return backend.execute("mean", x, axis=self.axis, keepdims=self.keepdims)

    def compute_output_spec(self, x):
        return compute_np_output_spec(
            "mean", x, axis=self.axis, keepdims=self.keepdims
        )
        return compute_np_output_spec("mean", x, axis=self.axis, keepdims=self.keepdims)


def mean(x, axis=None, keepdims=False):
@@ -310,9 +306,7 @@ class Var(Operation):
        return backend.execute("var", x, axis=self.axis, keepdims=self.keepdims)

    def compute_output_spec(self, x):
        return compute_np_output_spec(
            "var", x, axis=self.axis, keepdims=self.keepdims
        )
        return compute_np_output_spec("var", x, axis=self.axis, keepdims=self.keepdims)


def var(x, axis=None, keepdims=False):
@@ -330,9 +324,7 @@ class Sum(Operation):
        return backend.execute("sum", x, axis=self.axis, keepdims=self.keepdims)

    def compute_output_spec(self, x):
        return compute_np_output_spec(
            "sum", x, axis=self.axis, keepdims=self.keepdims
        )
        return compute_np_output_spec("sum", x, axis=self.axis, keepdims=self.keepdims)


def sum(x, axis=None, keepdims=False):
@@ -24,9 +24,7 @@ class Operation:
            # sets _keras_history on the outputs, and adds itself to the
            # `_outbound_nodes` of the ops that produced the inputs to this
            # call.
            Node(
                operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs
            )
            Node(operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs)
        return outputs

    def call(self, *args, **kwargs):
@@ -61,9 +61,7 @@ class Optimizer:
        self.iterations = backend.Variable(
            0, name="iteration", dtype="int64", trainable=False
        )
        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
        if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
            self._learning_rate = learning_rate
        elif callable(learning_rate):
            self._learning_rate = learning_rate
@@ -264,9 +262,7 @@ class Optimizer:
        return self._get_current_learning_rate()

    def _get_current_learning_rate(self):
        if isinstance(
            self._learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
        if isinstance(self._learning_rate, learning_rate_schedule.LearningRateSchedule):
            return self._learning_rate(self.iterations)
        elif callable(self._learning_rate):
            return self._learning_rate(self.iterations)
@@ -340,17 +336,13 @@ class Optimizer:
        )

        if var_list:
            self._exclude_from_weight_decay = [
                id(variable) for variable in var_list
            ]
            self._exclude_from_weight_decay = [id(variable) for variable in var_list]
        else:
            self._exclude_from_weight_decay = []
        self._exclude_from_weight_decay_names = var_names or []

    def _use_weight_decay(self, variable):
        exclude_from_weight_decay = getattr(
            self, "_exclude_from_weight_decay", []
        )
        exclude_from_weight_decay = getattr(self, "_exclude_from_weight_decay", [])
        exclude_from_weight_decay_names = getattr(
            self, "_exclude_from_weight_decay_names", []
        )
@@ -383,9 +375,7 @@ class Optimizer:
    def _update_model_variables_moving_average(self, var_list):
        """Update the stored moving average using the latest value."""
        if self.use_ema:
            for var, average in zip(
                var_list, self._model_variables_moving_average
            ):
            for var, average in zip(var_list, self._model_variables_moving_average):
                average.assign(
                    self.ema_momentum * average + (1 - self.ema_momentum) * var
                )
@@ -404,9 +394,7 @@ class Optimizer:

    def _overwrite_model_variables_with_average_value_helper(self, var_list):
        """Helper function that overwrites model variables."""
        for var, average_var in zip(
            var_list, self._model_variables_moving_average
        ):
        for var, average_var in zip(var_list, self._model_variables_moving_average):
            var.assign(average_var)

    def finalize_variable_values(self, var_list):
@@ -439,12 +427,8 @@ class Optimizer:
            Python dictionary.
        """

        if isinstance(
            self._learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            learning_rate = learning_rate_schedule.serialize(
                self._learning_rate
            )
        if isinstance(self._learning_rate, learning_rate_schedule.LearningRateSchedule):
            learning_rate = learning_rate_schedule.serialize(self._learning_rate)
        elif isinstance(self._learning_rate, backend.Variable):
            learning_rate = float(self._learning_rate.numpy())
        elif ops.is_tensor(self._learning_rate):
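The learning-rate dispatch touched above accepts a schedule object, a bare callable, a Variable, a tensor, or a float. A tiny sketch of the same dispatch pattern outside keras_core (all names illustrative):

class ExponentialDecay:
    """Stand-in for a LearningRateSchedule: maps step -> learning rate."""
    def __init__(self, initial_rate, decay_rate, decay_steps):
        self.initial_rate = initial_rate
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps

    def __call__(self, step):
        return self.initial_rate * self.decay_rate ** (step / self.decay_steps)

def current_learning_rate(learning_rate, iterations):
    # Schedules and plain callables are both invoked with the step counter;
    # anything else is treated as a constant rate.
    if callable(learning_rate):
        return learning_rate(iterations)
    return learning_rate

schedule = ExponentialDecay(0.1, 0.5, decay_steps=1000)
print(current_learning_rate(schedule, 0))     # 0.1
print(current_learning_rate(schedule, 1000))  # 0.05
print(current_learning_rate(0.01, 1000))      # 0.01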
@@ -85,9 +85,7 @@ class SGD(optimizer.Optimizer):
        self.momentums = []
        for variable in variables:
            self.momentums.append(
                self.add_variable_from_reference(
                    reference_variable=variable, name="m"
                )
                self.add_variable_from_reference(reference_variable=variable, name="m")
            )

    def update_step(self, gradient, variable, learning_rate):
@@ -100,9 +98,7 @@ class SGD(optimizer.Optimizer):
        if m is not None:
            m.assign(-gradient * learning_rate + m * momentum)
            if self.nesterov:
                variable.assign(
                    variable - gradient * learning_rate + m * momentum
                )
                variable.assign(variable - gradient * learning_rate + m * momentum)
            else:
                variable.assign(variable + m)
        else:
@@ -3,9 +3,7 @@ from keras_core import operations as ops
from keras_core.api_export import keras_core_export


@keras_core_export(
    ["keras_core.Regularizer", "keras_core.regularizers.Regularizer"]
)
@keras_core_export(["keras_core.Regularizer", "keras_core.regularizers.Regularizer"])
class Regularizer:
    """Regularizer base class.

@@ -317,15 +315,10 @@ class OrthogonalRegularizer(Regularizer):
        inputs = l2_normalize(inputs, axis=0)
        product = ops.matmul(ops.transpose(inputs), inputs)
        size = inputs.shape[1]
        product_no_diagonal = product * (
            1.0 - ops.eye(size, dtype=inputs.dtype)
        )
        product_no_diagonal = product * (1.0 - ops.eye(size, dtype=inputs.dtype))
        num_pairs = size * (size - 1.0) / 2.0
        return (
            self.factor
            * 0.5
            * ops.sum(ops.absolute(product_no_diagonal))
            / num_pairs
            self.factor * 0.5 * ops.sum(ops.absolute(product_no_diagonal)) / num_pairs
        )

    def get_config(self):
@@ -334,9 +327,7 @@ class OrthogonalRegularizer(Regularizer):

def validate_float_arg(value, name):
    """check penalty number availability, raise ValueError if failed."""
    if not isinstance(value, (float, int)) or (
        math.isinf(value) or math.isnan(value)
    ):
    if not isinstance(value, (float, int)) or (math.isinf(value) or math.isnan(value)):
        raise ValueError(
            f"Invalid value for argument {name}: expected a float. "
            f"Received: {name}={value}"
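The OrthogonalRegularizer penalty above is the mean absolute off-diagonal entry of the Gram matrix of l2-normalized columns; the 0.5 factor compensates for each column pair appearing twice in the symmetric matrix. A NumPy sketch of the same computation (illustrative, outside keras_core):

import numpy as np

def orthogonal_penalty(inputs, factor=0.01):
    # l2-normalize each column, build the Gram matrix, zero the diagonal.
    inputs = inputs / np.linalg.norm(inputs, axis=0, keepdims=True)
    product = inputs.T @ inputs
    size = inputs.shape[1]
    product_no_diagonal = product * (1.0 - np.eye(size, dtype=inputs.dtype))
    num_pairs = size * (size - 1.0) / 2.0
    return factor * 0.5 * np.abs(product_no_diagonal).sum() / num_pairs

rng = np.random.default_rng(0)
print(orthogonal_penalty(rng.normal(size=(8, 4))))  # small but nonzero
print(orthogonal_penalty(np.eye(4)))                # 0.0 for orthogonal columns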
@@ -1,41 +0,0 @@
from keras_core import backend
from keras_core import operations as ops


def float_dtype_size(dtype):
    if dtype in ("bfloat16", "float16"):
        return 16
    if dtype == "float32":
        return 32
    if dtype == "float64":
        return 64
    raise ValueError(f"Invalid dtype: {dtype}")


def is_float(dtype):
    return "float" in dtype


def cast_to_common_dtype(tensors):
    """Cast a list of tensors to a common dtype.

    If any tensor is floating-point, they will all be casted to the most-precise
    floating-point dtype. Otherwise the tensors are not casted.

    Args:
        tensors: A list of tensors.

    Returns:
        Same list, casted to a common dtype.
    """
    highest_float = None
    for x in tensors:
        dtype = backend.standardize_dtype(x.dtype)
        if is_float(dtype):
            if highest_float is None or float_dtype_size(dtype) > highest_float:
                highest_float = dtype
            elif dtype == "float16" and highest_float == "bfloat16":
                highest_float = "float32"
    if highest_float:
        tensors = [ops.cast(x, highest_float) for x in tensors]
    return tensors
@@ -42,7 +42,9 @@ def is_interactive_logging_enabled():
    """
    # Use `getattr` in case `INTERACTIVE_LOGGING`
    # does not have the `enable` attribute.
    return getattr(INTERACTIVE_LOGGING, "enable", True)
    return getattr(
        INTERACTIVE_LOGGING, "enable", True
    )


def print_msg(message, line_break=True):
@@ -1,48 +1,13 @@
from tensorflow import nest
from keras_core import backend
from keras_core.utils import io_utils
from keras_core.utils import dtype_utils
from keras_core.utils import text_rendering
from keras_core.backend import Variable
import math
import re


def count_params(weights):
    shapes = [v.shape for v in weights]
    shapes = [v.shape for v in weights if isinstance(v, Variable)]
    return int(sum(math.prod(p) for p in shapes))


def weight_memory_size(weights):
    """Compute the memory footprint for weights based on their dtypes.

    Args:
        weights: An iterable contains the weights to compute weight size.

    Returns:
        The total memory size (in Bytes) of the weights.
    """
    unique_weights = set(weights)
    total_memory_size = 0
    for w in unique_weights:
        weight_shape = math.prod(w.shape)
        dtype = backend.standardize_dtype(w.dtype)
        per_param_size = dtype_utils.float_dtype_size(dtype)
        total_memory_size += weight_shape * per_param_size
    return total_memory_size


def readable_memory_size(weight_memory_size):
    """Convert the weight memory size (Bytes) to a readable string."""
    units = ["Byte", "KB", "MB", "GB", "TB", "PB"]
    scale = 1024
    for unit in units:
        if weight_memory_size / scale < 1:
            return "{:.2f} {}".format(weight_memory_size, unit)
        else:
            weight_memory_size /= scale
    return "{:.2f} {}".format(weight_memory_size, units[-1])


def print_summary(
    model,
    line_length=None,
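readable_memory_size divides by 1024 until the remaining value drops below one scale step, then formats it with the matching unit. A quick standalone check of that behavior (a copy of the helper above, runnable outside keras_core):

def readable_memory_size(weight_memory_size):
    """Convert the weight memory size (Bytes) to a readable string."""
    units = ["Byte", "KB", "MB", "GB", "TB", "PB"]
    scale = 1024
    for unit in units:
        if weight_memory_size / scale < 1:
            return "{:.2f} {}".format(weight_memory_size, unit)
        else:
            weight_memory_size /= scale
    return "{:.2f} {}".format(weight_memory_size, units[-1])

print(readable_memory_size(512))          # 512.00 Byte
print(readable_memory_size(4 * 1024**2))  # 4.00 MB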
@@ -121,13 +86,17 @@ def print_summary(
    if sequential_like:
        line_length = line_length or 65
        positions = positions or [0.45, 0.85, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # header names for the different log elements
        header = ["Layer (type)", "Output Shape", "Param #"]
        to_display = ["Layer (type)", "Output Shape", "Param #"]
    else:
        line_length = line_length or 98
        positions = positions or [0.3, 0.6, 0.70, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # header names for the different log elements
        header = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
        to_display = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
        relevant_nodes = []
        for v in model._nodes_by_depth.values():
            relevant_nodes += v
@@ -135,14 +104,63 @@ def print_summary(
    if show_trainable:
        line_length += 11
        positions.append(line_length)
        header.append("Trainable")
        to_display.append("Trainable")

    layer_range = get_layer_index_bound_by_layer_name(model, layer_range)

    print_fn(f'Model: "{model.name}"')
    rows = []
    def print_row(fields, positions, nested_level=0):
        left_to_print = [str(x) for x in fields]
        while any(left_to_print):
            line = ""
            for col in range(len(left_to_print)):
                if col > 0:
                    start_pos = positions[col - 1]
                else:
                    start_pos = 0
                end_pos = positions[col]
                # Leave room for 2 spaces to delineate columns
                # we don't need any if we are printing the last column
                space = 2 if col != len(positions) - 1 else 0
                cutoff = end_pos - start_pos - space
                # Except for last col, offset by one to align the start of col
                if col != len(positions) - 1:
                    cutoff -= 1
                if col == 0:
                    cutoff -= nested_level
                fit_into_line = left_to_print[col][:cutoff]
                # For nicer formatting we line-break on seeing end of
                # tuple/dict etc.
                line_break_conditions = ("),", "},", "],", "',")
                candidate_cutoffs = [
                    fit_into_line.find(x) + len(x)
                    for x in line_break_conditions
                    if fit_into_line.find(x) >= 0
                ]
                if candidate_cutoffs:
                    cutoff = min(candidate_cutoffs)
                    fit_into_line = fit_into_line[:cutoff]

    def print_layer_summary(layer, prefix=" "):
                if col == 0:
                    line += "|" * nested_level + " "
                line += fit_into_line
                line += " " * space if space else ""
                left_to_print[col] = left_to_print[col][cutoff:]

                # Pad out to the next position
                # Make space for nested_level for last column
                if nested_level and col == len(positions) - 1:
                    line += " " * (positions[col] - len(line) - nested_level)
                else:
                    line += " " * (positions[col] - len(line))
            line += "|" * nested_level
            print_fn(line)

    print_fn(f'Model: "{model.name}"')
    print_fn("_" * line_length)
    print_row(to_display, positions)
    print_fn("=" * line_length)

    def print_layer_summary(layer, nested_level=0):
        """Prints a summary for a single layer.

        Args:
@@ -156,9 +174,9 @@ def print_summary(
            output_shape = "multiple"
        except RuntimeError:  # output_shape unknown in Eager mode.
            output_shape = "?"
        name = prefix + layer.name
        name = layer.name
        cls_name = layer.__class__.__name__
        if not layer.built:
        if not layer.built and not getattr(layer, "_is_graph_network", False):
            # If a subclassed model has a layer that is not called in
            # Model.call, the layer will not be built and we cannot call
            # layer.count_params().
@@ -169,9 +187,10 @@ def print_summary(

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")
        rows.append(fields)

    def print_layer_summary_with_connections(layer, prefix=""):
        print_row(fields, positions, nested_level)

    def print_layer_summary_with_connections(layer, nested_level=0):
        """Prints a summary for a single layer (including its connections).

        Args:
@@ -188,15 +207,18 @@ def print_summary(
            if relevant_nodes and node not in relevant_nodes:
                # node is not part of the current network
                continue
            for kt in node.keras_inputs:
                keras_history = kt._keras_history
                inbound_layer = keras_history.layer
                node_index = keras_history.node_index
                tensor_index = keras_history.tensor_index

            for (
                inbound_layer,
                node_index,
                tensor_index,
                _,
            ) in node.iterate_inbound():
                connections.append(
                    f"{inbound_layer.name}[{node_index}][{tensor_index}]"
                )
        name = prefix + layer.name

        name = layer.name
        cls_name = layer.__class__.__name__
        fields = [
            name + " (" + cls_name + ")",
@@ -204,38 +226,49 @@ def print_summary(
            layer.count_params(),
            connections,
        ]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")
        rows.append(fields)

    def print_layer(layer, nested_level=0):
        print_row(fields, positions, nested_level)

    def print_layer(layer, nested_level=0, is_nested_last=False):
        if sequential_like:
            print_layer_summary(layer, prefix=">>>" * nested_level + " ")
            print_layer_summary(layer, nested_level)
        else:
            print_layer_summary_with_connections(
                layer, prefix=">>>" * nested_level + " "
            )
            print_layer_summary_with_connections(layer, nested_level)

        if expand_nested and hasattr(layer, "layers") and layer.layers:
            nested_layers = layer.layers
            nested_level += 1
            for i in range(len(nested_layers)):
                print_layer(nested_layers[i], nested_level=nested_level)
            print_fn(
                "|" * (nested_level + 1)
                + "¯" * (line_length - 2 * nested_level - 2)
                + "|" * (nested_level + 1)
            )

            nested_layer = layer.layers
            is_nested_last = False
            for i in range(len(nested_layer)):
                if i == len(nested_layer) - 1:
                    is_nested_last = True
                print_layer(nested_layer[i], nested_level + 1, is_nested_last)

            print_fn(
                "|" * nested_level
                + "¯" * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

            if not is_nested_last:
                print_fn(
                    "|" * nested_level
                    + " " * (line_length - 2 * nested_level)
                    + "|" * nested_level
                )

    for layer in model.layers[layer_range[0] : layer_range[1]]:
        print_layer(layer)
    print_fn("=" * line_length)

    # Render summary as a table.
    table = text_rendering.Table(
        header=header,
        rows=rows,
        positions=positions,
        # Left align layer name, center-align everything else
        alignments=["left"] + ["center" for _ in range(len(header) - 1)],
    )
    print_fn(table.make())

    # After the table, append information about parameter count and size.
    if hasattr(model, "_collected_trainable_weights"):
        trainable_count = count_params(model._collected_trainable_weights)
        trainable_memory_size = weight_memory_size(
@@ -262,57 +295,6 @@ def print_summary(
        f"Non-trainable params: {non_trainable_count} "
        f"({readable_memory_size(non_trainable_memory_size)})"
    )
    print_fn("_" * line_length)


def get_layer_index_bound_by_layer_name(model, layer_range=None):
    """Get the layer indexes from the model based on layer names.

    The layer indexes can be used to slice the model into sub models for
    display.

    Args:
        model: `Model` instance.
        layer_names: a list or tuple of 2 strings, the starting layer name and
            ending layer name (both inclusive) for the result. All layers will
            be included when `None` is provided.

    Returns:
        The index value of layer based on its unique name (layer_names).
        Output will be [first_layer_index, last_layer_index + 1].
    """
    if layer_range is not None:
        if len(layer_range) != 2:
            raise ValueError(
                "layer_range must be a list or tuple of length 2. Received: "
                f"layer_range = {layer_range} of length {len(layer_range)}"
            )
        if not isinstance(layer_range[0], str) or not isinstance(
            layer_range[1], str
        ):
            raise ValueError(
                "layer_range should contain string type only. "
                f"Received: {layer_range}"
            )
    else:
        return [0, len(model.layers)]

    lower_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[0], layer.name)
    ]
    upper_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[1], layer.name)
    ]

    if not lower_index or not upper_index:
        raise ValueError(
            "Passed layer_names do not match the layer names in the model. "
            f"Received: {layer_range}"
        )

    if min(lower_index) > max(upper_index):
        return [min(upper_index), max(lower_index) + 1]
    return [min(lower_index), max(upper_index) + 1]
    print_dtensor_variable_summary(model, print_fn, line_length)
@@ -1,126 +0,0 @@
import shutil


class Table:
    def __init__(
        self, header, rows, positions, alignments=None, max_line_length=80
    ):
        if len(header) != len(positions):
            raise ValueError("header and positions should be the same length.")
        if not all(p <= 1.0 for p in positions):
            raise ValueError("All positions should be <= 1.")
        self.alignments = alignments or ["center" for _ in header]
        if len(self.alignments) != len(header):
            raise ValueError("header and alignments should be the same length.")
        last_p = 0.0
        for p in positions:
            if p <= last_p:
                raise ValueError(
                    "All consecutive positions should be greater than the last."
                )
            last_p = p
        self.header = header
        self.rows = rows

        # Compute columns widths
        line_length = min(
            max_line_length, shutil.get_terminal_size().columns - 4
        )
        column_widths = []
        current = 0
        for pos in positions:
            width = int(pos * line_length) - current
            if width < 4:
                raise ValueError("Insufficient console width to print summary.")
            column_widths.append(width)
            current += width
        self.column_widths = column_widths

    def make_separator(self, left, mid, right, horizontal):
        line = mid.join(horizontal * width for width in self.column_widths)
        return f"{left}{line}{right}"

    def print_row(
        self,
        fields,
        vertical_separator="│",
        alignments=None,
    ):
        alignments = alignments or ["center" for _ in fields]
        lines = []
        line_break_conditions = ("),", "},", "],", "',", ") ")
        for field, width in zip(fields, self.column_widths):
            buffered_width = width - 1
            if len(field) < buffered_width and not "\n" in field:
                lines.append([field])
                continue
            subfields = []
            while len(field) >= buffered_width or "\n" in field:
                if "\n" in field[:buffered_width]:
                    # priority: break on line break
                    cutoff = field.find("\n")
                    subfield = field[:cutoff]
                    field = field[cutoff + 1 :]
                    subfields.append(subfield)
                    continue
                # secondary: break on certain characters
                candidate_cutoffs = [
                    field.find(x) + len(x)
                    for x in line_break_conditions
                    if 0 < field.find(x) < buffered_width
                ]
                if candidate_cutoffs:
                    cutoff = min(buffered_width - 1, *candidate_cutoffs)
                else:
                    cutoff = buffered_width - 1
                subfield = field[:cutoff]
                field = field[cutoff:]
                subfields.append(subfield)
            if field:
                subfields.append(field)
            lines.append(subfields)

        max_subfield_count = max(len(subs) for subs in lines)
        rendered_lines = []
        for i in range(max_subfield_count):
            fields = []
            for subfields in lines:
                if len(subfields) < i + 1:
                    field = ""
                else:
                    field = subfields[i]
                fields.append(field)
            line = vertical_separator.join(
                self.align_field(field, width, alignment)
                for field, width, alignment in zip(
                    fields, self.column_widths, alignments
                )
            )
            line = f"{vertical_separator}{line}{vertical_separator}"
            rendered_lines.append(line)
        return "\n".join(rendered_lines)

    @staticmethod
    def align_field(field, width, alignment):
        if alignment == "center":
            return field.center(width)
        if alignment == "left":
            return field.ljust(width)
        if alignment == "right":
            return field.rjust(width)

    def make(self):
        lines = []
        # Print header
        lines.append(self.make_separator(*"┏┳┓━"))
        lines.append(self.print_row(self.header, vertical_separator="┃"))
        lines.append(self.make_separator(*"┡╇┩━"))

        # Print rows
        for i, row in enumerate(self.rows):
            lines.append(self.print_row(row, alignments=self.alignments))
            if i < len(self.rows) - 1:
                lines.append(self.make_separator(*"├┼┤─"))

        lines.append(self.make_separator(*"└┴┘─"))
        return "\n".join(lines)