From 12647d83705c1627d3b1eec36e26dce57d9fd69c Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 8 May 2023 13:51:15 -0700 Subject: [PATCH] Add wrapper layer. --- keras_core/layers/__init__.py | 1 + keras_core/layers/activations/__init__.py | 1 - keras_core/layers/activations/leaky_relu.py | 56 -- .../layers/activations/leaky_relu_test.py | 36 - keras_core/layers/core/dense.py | 2 +- keras_core/layers/core/wrapper.py | 47 + keras_core/layers/core/wrapper_test.py | 66 ++ keras_core/layers/layer.py | 19 +- .../schedules/learning_rate_schedule.py | 934 +----------------- .../schedules/learning_rate_schedule_test.py | 452 --------- 10 files changed, 134 insertions(+), 1480 deletions(-) delete mode 100644 keras_core/layers/activations/leaky_relu.py delete mode 100644 keras_core/layers/activations/leaky_relu_test.py create mode 100644 keras_core/layers/core/wrapper.py create mode 100644 keras_core/layers/core/wrapper_test.py delete mode 100644 keras_core/optimizers/schedules/learning_rate_schedule_test.py diff --git a/keras_core/layers/__init__.py b/keras_core/layers/__init__.py index 6b98375e5..dcafbca75 100644 --- a/keras_core/layers/__init__.py +++ b/keras_core/layers/__init__.py @@ -9,6 +9,7 @@ from keras_core.layers.core.identity import Identity from keras_core.layers.core.input_layer import Input from keras_core.layers.core.input_layer import InputLayer from keras_core.layers.core.masking import Masking +from keras_core.layers.core.wrapper import Wrapper from keras_core.layers.layer import Layer from keras_core.layers.merging.add import Add from keras_core.layers.merging.add import add diff --git a/keras_core/layers/activations/__init__.py b/keras_core/layers/activations/__init__.py index 828c64fda..8a753d53a 100644 --- a/keras_core/layers/activations/__init__.py +++ b/keras_core/layers/activations/__init__.py @@ -1,3 +1,2 @@ from keras_core.layers.activations.elu import ELU -from keras_core.layers.activations.leaky_relu import LeakyReLU from keras_core.layers.activations.relu import ReLU diff --git a/keras_core/layers/activations/leaky_relu.py b/keras_core/layers/activations/leaky_relu.py deleted file mode 100644 index 52ff66138..000000000 --- a/keras_core/layers/activations/leaky_relu.py +++ /dev/null @@ -1,56 +0,0 @@ -from keras_core import activations -from keras_core.api_export import keras_core_export -from keras_core.layers.layer import Layer - - -@keras_core_export("keras_core.layers.LeakyReLU") -class LeakyReLU(Layer): - """Leaky version of a Rectified Linear Unit activation layer. - - The layer allows a small gradient when the unit is not active. - - Formula: - ``` python - f(x) = alpha * x if x < 0 - f(x) = x if x >= 0 - ``` - - Example: - ``` python - leaky_relu_layer = LeakyReLU(negative_slope=0.5) - input = np.array([-10, -5, 0.0, 5, 10]) - result = leaky_relu_layer(input) - # result = [-5. , -2.5, 0. , 5. , 10.] - ``` - - Args: - negative_slope: Float >= 0.0. Negative slope coefficient. - Defaults to 0.3. - **kwargs: Base layer keyword arguments, such as - `name` and `dtype`. - - """ - - def __init__(self, negative_slope=0.3, **kwargs): - super().__init__(**kwargs) - if negative_slope is None: - raise ValueError( - "The negative_slope value of a Leaky ReLU layer " - "cannot be None, Expecting a float. 
Received " - f"negative_slope: {negative_slope}" - ) - self.supports_masking = True - self.negative_slope = negative_slope - - def call(self, inputs): - return activations.leaky_relu( - inputs, negative_slope=self.negative_slope - ) - - def get_config(self): - config = super().get_config() - config.update({"negative_slope": self.negative_slope}) - return config - - def compute_output_shape(self, input_shape): - return input_shape diff --git a/keras_core/layers/activations/leaky_relu_test.py b/keras_core/layers/activations/leaky_relu_test.py deleted file mode 100644 index 926771201..000000000 --- a/keras_core/layers/activations/leaky_relu_test.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np - -from keras_core import testing -from keras_core.layers.activations import leaky_relu - - -class LeakyReLUTest(testing.TestCase): - def test_relu(self): - self.run_layer_test( - leaky_relu.LeakyReLU, - init_kwargs={ - "negative_slope": 1, - }, - input_shape=(2, 3, 4), - supports_masking=True, - ) - - def test_leaky_relu_correctness(self): - leaky_relu_layer = leaky_relu.LeakyReLU(negative_slope=0.5) - input = np.array([-10, -5, 0.0, 5, 10]) - expected_output = np.array([-5.0, -2.5, 0.0, 5.0, 10.0]) - result = leaky_relu_layer(input) - self.assertAllClose(result, expected_output) - - def test_invalid_usage(self): - with self.assertRaisesRegex( - ValueError, - "The negative_slope value of a Leaky ReLU layer cannot be None, " - "Expecting a float. Received negative_slope: None", - ): - self.run_layer_test( - leaky_relu.LeakyReLU, - init_kwargs={"negative_slope": None}, - input_shape=(2, 3, 4), - supports_masking=True, - ) diff --git a/keras_core/layers/core/dense.py b/keras_core/layers/core/dense.py index fbf9b590c..ac596aa89 100644 --- a/keras_core/layers/core/dense.py +++ b/keras_core/layers/core/dense.py @@ -70,7 +70,7 @@ class Dense(Layer): bias_constraint=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(activity_regularizer=activity_regularizer, **kwargs) self.units = units self.activation = activations.get(activation) self.use_bias = use_bias diff --git a/keras_core/layers/core/wrapper.py b/keras_core/layers/core/wrapper.py new file mode 100644 index 000000000..d5310eb7a --- /dev/null +++ b/keras_core/layers/core/wrapper.py @@ -0,0 +1,47 @@ +from keras_core.api_export import keras_core_export +from keras_core.layers.layer import Layer +from keras_core.saving import serialization_lib + + +@keras_core_export("keras_core.layers.Wrapper") +class Wrapper(Layer): + """Abstract wrapper base class. + + Wrappers take another layer and augment it in various ways. + Do not use this class as a layer, it is only an abstract base class. + Two usable wrappers are the `TimeDistributed` and `Bidirectional` layers. + + Args: + layer: The layer to be wrapped. + """ + + def __init__(self, layer, **kwargs): + try: + assert isinstance(layer, Layer) + except Exception: + raise ValueError( + f"Layer {layer} supplied to Wrapper isn't " + "a supported layer type. Please " + "ensure wrapped layer is a valid Keras layer." 
+ ) + super().__init__(**kwargs) + self.layer = layer + + def build(self, input_shape=None): + if not self.layer.built: + self.layer.build(input_shape) + self.layer.built = True + self.built = True + + def get_config(self): + config = {"layer": serialization_lib.serialize_keras_object(self.layer)} + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + layer = serialization_lib.deserialize_keras_object( + config.pop("layer"), + custom_objects=custom_objects, + ) + return cls(layer, **config) diff --git a/keras_core/layers/core/wrapper_test.py b/keras_core/layers/core/wrapper_test.py new file mode 100644 index 000000000..ae1556d6d --- /dev/null +++ b/keras_core/layers/core/wrapper_test.py @@ -0,0 +1,66 @@ +from keras_core import layers +from keras_core import testing + + +class ExampleWrapper(layers.Wrapper): + """Simple Wrapper subclass.""" + + def call(self, inputs, **kwargs): + return self.layer(inputs, **kwargs) + + +class WrapperTest(testing.TestCase): + def test_wrapper_basics(self): + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.Dense(2), + }, + input_shape=(2, 3), + expected_output_shape=(2, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.Dense(2, activity_regularizer="l2"), + }, + input_shape=(2, 3), + expected_output_shape=(2, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=1, + supports_masking=False, + ) + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.Dense(2), + "activity_regularizer": "l2", + }, + input_shape=(2, 3), + expected_output_shape=(2, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=1, + supports_masking=False, + ) + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.BatchNormalization(), + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index 7dd496183..18dbb1c0c 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -202,12 +202,24 @@ class Layer(Operation): @property def variables(self): - # Includes weights, seed generator state, and metric variables. - variables = self.weights[:] + # Return only weights/rng state/metric variables + # of all Layers, recursively. + # Also deduplicate them. 
+ variables = [] + seen_ids = set() + for v in self._trainable_variables + self._non_trainable_variables: + if id(v) not in seen_ids: + variables.append(v) + seen_ids.add(id(v)) for m in self._metrics: variables.extend(m.variables) for sg in self._seed_generators: variables.append(sg.state) + for layer in self._layers: + for v in layer.variables: + if id(v) not in seen_ids: + variables.append(v) + seen_ids.add(id(v)) return variables @property @@ -963,6 +975,9 @@ def get_shapes_dict(call_spec): if k == "mask" or k.startswith("mask_"): # Do not include mask tensors in shapes dict continue + if k == "kwargs" or k == "args": + # Do not include catch-alls in shapes dict + continue if k in call_spec.nested_tensor_argument_names: shapes_dict[f"{k}_shape"] = nest.map_structure( lambda x: backend.standardize_shape(x.shape), v diff --git a/keras_core/optimizers/schedules/learning_rate_schedule.py b/keras_core/optimizers/schedules/learning_rate_schedule.py index 566838d51..f9291e22a 100644 --- a/keras_core/optimizers/schedules/learning_rate_schedule.py +++ b/keras_core/optimizers/schedules/learning_rate_schedule.py @@ -1,933 +1,3 @@ -"""Various learning rate schedule functions.""" - -import math - -from keras_core import operations as ops -from keras_core.api_export import keras_core_export -from keras_core.saving import serialization_lib - - -@keras_core_export("keras_core.optimizers.schedules.LearningRateSchedule") class LearningRateSchedule: - """The learning rate schedule base class. - - You can use a learning rate schedule to modulate how the learning rate - of your optimizer changes over time. - - Several built-in learning rate schedules are available, such as - `keras_core.optimizers.schedules.ExponentialDecay` or - `keras_core.optimizers.schedules.PiecewiseConstantDecay`: - - ```python - lr_schedule = keras_core.optimizers.schedules.ExponentialDecay( - initial_learning_rate=1e-2, - decay_steps=10000, - decay_rate=0.9) - optimizer = keras_core.optimizers.SGD(learning_rate=lr_schedule) - ``` - - A `LearningRateSchedule` instance can be passed in as the `learning_rate` - argument of any optimizer. - - To implement your own schedule object, you should implement the `__call__` - method, which takes a `step` argument (scalar integer tensor, the - current training step count). - Like for any other Keras object, you can also optionally - make your object serializable by implementing the `get_config` - and `from_config` methods. - - Example: - - ```python - class MyLRSchedule(keras_core.optimizers.schedules.LearningRateSchedule): - - def __init__(self, initial_learning_rate): - self.initial_learning_rate = initial_learning_rate - - def __call__(self, step): - return self.initial_learning_rate / (step + 1) - - optimizer = keras_core.optimizers.SGD(learning_rate=MyLRSchedule(0.1)) - ``` - """ - - def __call__(self, step): - raise NotImplementedError( - f"Learning rate schedule '{self.__class__.__name__}' " - "must override `__call__(self, step)`." - ) - - def get_config(self): - raise NotImplementedError( - f"Learning rate schedule '{self.__class__.__name__}' " - "must override `get_config()` in order to be serializable." - ) - - @classmethod - def from_config(cls, config): - """Instantiates a `LearningRateSchedule` from its config. - - Args: - config: Output of `get_config()`. - - Returns: - A `LearningRateSchedule` instance. 
- """ - return cls(**config) - - -@keras_core_export("keras_core.optimizers.schedules.ExponentialDecay") -class ExponentialDecay(LearningRateSchedule): - """A `LearningRateSchedule` that uses an exponential decay schedule. - - When training a model, it is often useful to lower the learning rate as - the training progresses. This schedule applies an exponential decay function - to an optimizer step, given a provided initial learning rate. - - The schedule is a 1-arg callable that produces a decayed learning - rate when passed the current optimizer step. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - It is computed as: - - ```python - def decayed_learning_rate(step): - return initial_learning_rate * decay_rate ^ (step / decay_steps) - ``` - - If the argument `staircase` is `True`, then `step / decay_steps` is - an integer division and the decayed learning rate follows a - staircase function. - - You can pass this schedule directly into a `keras_core.optimizers.Optimizer` - as the learning rate. - Example: When fitting a Keras model, decay every 100000 steps with a base - of 0.96: - - ```python - initial_learning_rate = 0.1 - lr_schedule = keras_core.optimizers.schedules.ExponentialDecay( - initial_learning_rate, - decay_steps=100000, - decay_rate=0.96, - staircase=True) - - model.compile(optimizer=keras_core.optimizers.SGD(learning_rate=lr_schedule), - loss='sparse_categorical_crossentropy', - metrics=['accuracy']) - - model.fit(data, labels, epochs=5) - ``` - - The learning rate schedule is also serializable and deserializable using - `keras_core.optimizers.schedules.serialize` and - `keras_core.optimizers.schedules.deserialize`. - - Args: - initial_learning_rate: A Python float. The initial learning rate. - decay_steps: A Python integer. Must be positive. See the decay - computation above. - decay_rate: A Python float. The decay rate. - staircase: Boolean. If `True` decay the learning rate at discrete - intervals. - name: String. Optional name of the operation. Defaults to - `"ExponentialDecay`". - - Returns: - A 1-arg callable learning rate schedule that takes the current optimizer - step and outputs the decayed learning rate, a scalar tensor of the - same type as `initial_learning_rate`. - """ - - def __init__( - self, - initial_learning_rate, - decay_steps, - decay_rate, - staircase=False, - name="ExponentialDecay", - ): - super().__init__() - self.initial_learning_rate = initial_learning_rate - self.decay_steps = decay_steps - self.decay_rate = decay_rate - self.staircase = staircase - self.name = name - - def __call__(self, step): - with ops.name_scope(self.name): - initial_learning_rate = ops.convert_to_tensor( - self.initial_learning_rate - ) - dtype = initial_learning_rate.dtype - decay_steps = ops.cast(self.decay_steps, dtype) - decay_rate = ops.cast(self.decay_rate, dtype) - - global_step_recomp = ops.cast(step, dtype) - p = global_step_recomp / decay_steps - if self.staircase: - p = ops.floor(p) - return ops.multiply(initial_learning_rate, ops.power(decay_rate, p)) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "decay_steps": self.decay_steps, - "decay_rate": self.decay_rate, - "staircase": self.staircase, - "name": self.name, - } - - -@keras_core_export("keras_core.optimizers.schedules.PiecewiseConstantDecay") -class PiecewiseConstantDecay(LearningRateSchedule): - """A `LearningRateSchedule` that uses a piecewise constant decay schedule. 
- - The function returns a 1-arg callable to compute the piecewise constant - when passed the current optimizer step. This can be useful for changing the - learning rate value across different invocations of optimizer functions. - - Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 - for the next 10000 steps, and 0.1 for any additional steps. - - ```python - step = ops.array(0) - boundaries = [100000, 110000] - values = [1.0, 0.5, 0.1] - learning_rate_fn = keras_core.optimizers.schedules.PiecewiseConstantDecay( - boundaries, values) - - # Later, whenever we perform an optimization step, we pass in the step. - learning_rate = learning_rate_fn(step) - ``` - - You can pass this schedule directly into a `keras_core.optimizers.Optimizer` - as the learning rate. The learning rate schedule is also serializable and - deserializable using `keras_core.optimizers.schedules.serialize` and - `keras_core.optimizers.schedules.deserialize`. - - Args: - boundaries: A list of Python numbers with strictly increasing - entries, and with all elements having the same type as the - optimizer step. - values: A list of Python numbers that specifies the values for the - intervals defined by `boundaries`. It should have one more - element than `boundaries`, and all elements should have the same - type. - name: A string. Optional name of the operation. Defaults to - `"PiecewiseConstant"`. - - Returns: - A 1-arg callable learning rate schedule that takes the current optimizer - step and outputs the decayed learning rate, a scalar tensor of the - same type as the boundary tensors. - - The output of the 1-arg function that takes the `step` - is `values[0]` when `step <= boundaries[0]`, - `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, - ..., and `values[-1]` when `step > boundaries[-1]`. - - - Raises: - ValueError: if the number of elements in the `boundaries` and `values` - lists do not match. - """ - - def __init__(self, boundaries, values, name="PiecewiseConstant"): - super().__init__() - - if len(boundaries) != len(values) - 1: - raise ValueError( - "The length of boundaries should be 1 less than the length of " - f"values. Received: boundaries={boundaries} of length " - f"{len(boundaries)}, and values={values} " - f"of length {len(values)}." - ) - - self.boundaries = boundaries - self.values = values - self.name = name - - def __call__(self, step): - with ops.name_scope(self.name): - boundaries = [ops.convert_to_tensor(x) for x in self.boundaries] - values = [ops.convert_to_tensor(x) for x in self.values] - step = ops.convert_to_tensor(step) - - for i, b in enumerate(boundaries): - if b.dtype != step.dtype: - # We cast the boundaries to have the same type as the step - b = ops.cast(b, step.dtype) - boundaries[i] = b - - result_dtype = values[0].dtype - result_value = ops.array(0, dtype=result_dtype) - - # For each range between boundaries, we check whether the step is - # within that range, cast the resulting boolean to a number, - # and multiply the result by the corresponding value for the range. - # Taking the sum of these yields a piecewise constant function. 
- step_less_than_first_boundary = ops.cast( - step <= boundaries[0], result_dtype - ) - result_value += step_less_than_first_boundary * values[0] - - step_greater_than_last_boundary = ops.cast( - step > boundaries[-1], result_dtype - ) - result_value += step_greater_than_last_boundary * values[-1] - - for low, high, value in zip( - boundaries[:-1], boundaries[1:], values[1:-1] - ): - # Need to bind v here; can do this with lambda v=v: ... - step_in_range = ops.cast( - (step > low) & (step <= high), result_dtype - ) - result_value += step_in_range * value - - return result_value - - def get_config(self): - return { - "boundaries": self.boundaries, - "values": self.values, - "name": self.name, - } - - -@keras_core_export("keras_core.optimizers.schedules.PolynomialDecay") -class PolynomialDecay(LearningRateSchedule): - """A `LearningRateSchedule` that uses a polynomial decay schedule. - - It is commonly observed that a monotonically decreasing learning rate, whose - degree of change is carefully chosen, results in a better performing model. - This schedule applies a polynomial decay function to an optimizer step, - given a provided `initial_learning_rate`, to reach an `end_learning_rate` - in the given `decay_steps`. - - It requires a `step` value to compute the decayed learning rate. You - can just pass a backend variable that you increment at each training - step. - - The schedule is a 1-arg callable that produces a decayed learning rate - when passed the current optimizer step. This can be useful for changing the - learning rate value across different invocations of optimizer functions. - It is computed as: - - ```python - def decayed_learning_rate(step): - step = min(step, decay_steps) - return ((initial_learning_rate - end_learning_rate) * - (1 - step / decay_steps) ^ (power) - ) + end_learning_rate - ``` - - If `cycle` is True then a multiple of `decay_steps` is used, the first one - that is bigger than `step`. - - ```python - def decayed_learning_rate(step): - decay_steps = decay_steps * ceil(step / decay_steps) - return ((initial_learning_rate - end_learning_rate) * - (1 - step / decay_steps) ^ (power) - ) + end_learning_rate - ``` - - You can pass this schedule directly into a `keras_core.optimizers.Optimizer` - as the learning rate. - Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using - sqrt (i.e. power=0.5): - - ```python - ... - starter_learning_rate = 0.1 - end_learning_rate = 0.01 - decay_steps = 10000 - learning_rate_fn = keras_core.optimizers.schedules.PolynomialDecay( - starter_learning_rate, - decay_steps, - end_learning_rate, - power=0.5) - - model.compile(optimizer=keras_core.optimizers.SGD( - learning_rate=learning_rate_fn), - loss='sparse_categorical_crossentropy', - metrics=['accuracy']) - - model.fit(data, labels, epochs=5) - ``` - - The learning rate schedule is also serializable and deserializable using - `keras_core.optimizers.schedules.serialize` and - `keras_core.optimizers.schedules.deserialize`. - - Args: - initial_learning_rate: A Python float. The initial learning rate. - decay_steps: A Python integer. Must be positive. See the decay - computation above. - end_learning_rate: A Python float. The minimal end learning rate. - power: A Python float. The power of the polynomial. Defaults to - `1.0`. - cycle: A boolean, whether it should cycle beyond decay_steps. - name: String. Optional name of the operation. Defaults to - `"PolynomialDecay"`. 
- - Returns: - A 1-arg callable learning rate schedule that takes the current optimizer - step and outputs the decayed learning rate, a scalar tensor of the - same type as `initial_learning_rate`. - """ - - def __init__( - self, - initial_learning_rate, - decay_steps, - end_learning_rate=0.0001, - power=1.0, - cycle=False, - name="PolynomialDecay", - ): - super().__init__() - - self.initial_learning_rate = initial_learning_rate - self.decay_steps = decay_steps - self.end_learning_rate = end_learning_rate - self.power = power - self.cycle = cycle - self.name = name - - def __call__(self, step): - with ops.name_scope(self.name): - initial_learning_rate = ops.convert_to_tensor( - self.initial_learning_rate - ) - dtype = initial_learning_rate.dtype - end_learning_rate = ops.cast(self.end_learning_rate, dtype) - power = ops.cast(self.power, dtype) - - global_step_recomp = ops.cast(step, dtype) - decay_steps_recomp = ops.cast(self.decay_steps, dtype) - if self.cycle: - # Find the first multiple of decay_steps that is bigger than - # global_step. If global_step is zero set the multiplier to 1 - multiplier = ops.where( - ops.equal(global_step_recomp, 0), - 1.0, - ops.ceil(global_step_recomp / self.decay_steps), - ) - decay_steps_recomp = ops.multiply( - decay_steps_recomp, multiplier - ) - else: - # Make sure that the global_step used is not bigger than - # decay_steps. - global_step_recomp = ops.minimum( - global_step_recomp, decay_steps_recomp - ) - - p = ops.divide(global_step_recomp, decay_steps_recomp) - return ops.add( - ops.multiply( - initial_learning_rate - end_learning_rate, - ops.power(1 - p, power), - ), - end_learning_rate, - ) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "decay_steps": self.decay_steps, - "end_learning_rate": self.end_learning_rate, - "power": self.power, - "cycle": self.cycle, - "name": self.name, - } - - -@keras_core_export("keras_core.optimizers.schedules.InverseTimeDecay") -class InverseTimeDecay(LearningRateSchedule): - """A `LearningRateSchedule` that uses an inverse time decay schedule. - - When training a model, it is often useful to lower the learning rate as - the training progresses. This schedule applies the inverse decay function - to an optimizer step, given a provided initial learning rate. - It requires a `step` value to compute the decayed learning rate. You can - just pass a backend variable that you increment at each training step. - - The schedule is a 1-arg callable that produces a decayed learning - rate when passed the current optimizer step. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - It is computed as: - - ```python - def decayed_learning_rate(step): - return initial_learning_rate / (1 + decay_rate * step / decay_step) - ``` - - or, if `staircase` is `True`, as: - - ```python - def decayed_learning_rate(step): - return initial_learning_rate / - (1 + decay_rate * floor(step / decay_step)) - ``` - - You can pass this schedule directly into a `keras_core.optimizers.Optimizer` - as the learning rate. - Example: Fit a Keras model when decaying 1/t with a rate of 0.5: - - ```python - ... 
- initial_learning_rate = 0.1 - decay_steps = 1.0 - decay_rate = 0.5 - learning_rate_fn = keras_core.optimizers.schedules.InverseTimeDecay( - initial_learning_rate, decay_steps, decay_rate) - - model.compile(optimizer=keras_core.optimizers.SGD( - learning_rate=learning_rate_fn), - loss='sparse_categorical_crossentropy', - metrics=['accuracy']) - - model.fit(data, labels, epochs=5) - ``` - - Args: - initial_learning_rate: A Python float. The initial learning rate. - decay_steps: How often to apply decay. - decay_rate: A Python number. The decay rate. - staircase: Whether to apply decay in a discrete staircase, as o - pposed to continuous, fashion. - name: String. Optional name of the operation. Defaults to - `"InverseTimeDecay"`. - - Returns: - A 1-arg callable learning rate schedule that takes the current optimizer - step and outputs the decayed learning rate, a scalar tensor of the - same type as `initial_learning_rate`. - """ - - def __init__( - self, - initial_learning_rate, - decay_steps, - decay_rate, - staircase=False, - name="InverseTimeDecay", - ): - super().__init__() - - self.initial_learning_rate = initial_learning_rate - self.decay_steps = decay_steps - self.decay_rate = decay_rate - self.staircase = staircase - self.name = name - - def __call__(self, step): - with ops.name_scope(self.name): - initial_learning_rate = ops.convert_to_tensor( - self.initial_learning_rate - ) - dtype = initial_learning_rate.dtype - decay_steps = ops.cast(self.decay_steps, dtype) - decay_rate = ops.cast(self.decay_rate, dtype) - - global_step_recomp = ops.cast(step, dtype) - p = global_step_recomp / decay_steps - if self.staircase: - p = ops.floor(p) - const = ops.cast(ops.array(1), dtype) - denom = ops.add(const, ops.multiply(decay_rate, p)) - return ops.divide(initial_learning_rate, denom) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "decay_steps": self.decay_steps, - "decay_rate": self.decay_rate, - "staircase": self.staircase, - "name": self.name, - } - - -@keras_core_export("keras_core.optimizers.schedules.CosineDecay") -class CosineDecay(LearningRateSchedule): - """A `LearningRateSchedule` that uses a cosine decay with optional warmup. - - See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983), - SGDR: Stochastic Gradient Descent with Warm Restarts. - - For the idea of a linear warmup of our learning rate, - see [Goyal et al.](https://arxiv.org/pdf/1706.02677.pdf). - - When we begin training a model, we often want an initial increase in our - learning rate followed by a decay. If `warmup_target` is an int, this - schedule applies a linear increase per optimizer step to our learning rate - from `initial_learning_rate` to `warmup_target` for a duration of - `warmup_steps`. Afterwards, it applies a cosine decay function taking our - learning rate from `warmup_target` to `alpha` for a duration of - `decay_steps`. If `warmup_target` is None we skip warmup and our decay - will take our learning rate from `initial_learning_rate` to `alpha`. - It requires a `step` value to compute the learning rate. You can - just pass a backend variable that you increment at each training step. - - The schedule is a 1-arg callable that produces a warmup followed by a - decayed learning rate when passed the current optimizer step. This can be - useful for changing the learning rate value across different invocations of - optimizer functions. 
- - Our warmup is computed as: - - ```python - def warmup_learning_rate(step): - completed_fraction = step / warmup_steps - total_delta = target_warmup - initial_learning_rate - return completed_fraction * total_delta - ``` - - And our decay is computed as: - - ```python - if warmup_target is None: - initial_decay_lr = initial_learning_rate - else: - initial_decay_lr = warmup_target - - def decayed_learning_rate(step): - step = min(step, decay_steps) - cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps)) - decayed = (1 - alpha) * cosine_decay + alpha - return initial_decay_lr * decayed - ``` - - Example usage without warmup: - - ```python - decay_steps = 1000 - initial_learning_rate = 0.1 - lr_decayed_fn = keras_core.optimizers.schedules.CosineDecay( - initial_learning_rate, decay_steps) - ``` - - Example usage with warmup: - - ```python - decay_steps = 1000 - initial_learning_rate = 0 - warmup_steps = 1000 - target_learning_rate = 0.1 - lr_warmup_decayed_fn = keras_core.optimizers.schedules.CosineDecay( - initial_learning_rate, decay_steps, warmup_target=target_learning_rate, - warmup_steps=warmup_steps - ) - ``` - - You can pass this schedule directly into a `keras_core.optimizers.Optimizer` - as the learning rate. The learning rate schedule is also serializable and - deserializable using `keras_core.optimizers.schedules.serialize` and - `keras_core.optimizers.schedules.deserialize`. - - Args: - initial_learning_rate: A Python float. The initial learning rate. - decay_steps: A Python int. Number of steps to decay over. - alpha: A Python float. Minimum learning rate value for decay as a - fraction of `initial_learning_rate`. - name: String. Optional name of the operation. Defaults to - `"CosineDecay"`. - warmup_target: A Python float. The target learning rate for our - warmup phase. Will cast to the `initial_learning_rate` datatype. - Setting to `None` will skip warmup and begins decay phase from - `initial_learning_rate`. Otherwise scheduler will warmup from - `initial_learning_rate` to `warmup_target`. - warmup_steps: A Python int. Number of steps to warmup over. - - Returns: - A 1-arg callable learning rate schedule that takes the current optimizer - step and outputs the decayed learning rate, a scalar tensor of the - same type as `initial_learning_rate`. 
- """ - - def __init__( - self, - initial_learning_rate, - decay_steps, - alpha=0.0, - name="CosineDecay", - warmup_target=None, - warmup_steps=0, - ): - super().__init__() - - self.initial_learning_rate = initial_learning_rate - self.decay_steps = decay_steps - self.alpha = alpha - self.name = name - self.warmup_steps = warmup_steps - self.warmup_target = warmup_target - - def _decay_function(self, step, decay_steps, decay_from_lr, dtype): - with ops.name_scope(self.name): - completed_fraction = step / decay_steps - pi = ops.array(math.pi, dtype=dtype) - cosine_decayed = 0.5 * (1.0 + ops.cos(pi * completed_fraction)) - decayed = (1 - self.alpha) * cosine_decayed + self.alpha - return ops.multiply(decay_from_lr, decayed) - - def _warmup_function( - self, step, warmup_steps, warmup_target, initial_learning_rate - ): - with ops.name_scope(self.name): - completed_fraction = step / warmup_steps - total_step_delta = warmup_target - initial_learning_rate - return total_step_delta * completed_fraction + initial_learning_rate - - def __call__(self, step): - with ops.name_scope(self.name): - initial_learning_rate = ops.convert_to_tensor( - self.initial_learning_rate - ) - dtype = initial_learning_rate.dtype - decay_steps = ops.cast(self.decay_steps, dtype) - global_step_recomp = ops.cast(step, dtype) - - if self.warmup_target is None: - global_step_recomp = ops.minimum( - global_step_recomp, decay_steps - ) - return self._decay_function( - global_step_recomp, - decay_steps, - initial_learning_rate, - dtype, - ) - - warmup_target = ops.cast(self.warmup_target, dtype) - warmup_steps = ops.cast(self.warmup_steps, dtype) - - global_step_recomp = ops.minimum( - global_step_recomp, decay_steps + warmup_steps - ) - - return ops.cond( - global_step_recomp < warmup_steps, - lambda: self._warmup_function( - global_step_recomp, - warmup_steps, - warmup_target, - initial_learning_rate, - ), - lambda: self._decay_function( - global_step_recomp - warmup_steps, - decay_steps, - warmup_target, - dtype, - ), - ) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "decay_steps": self.decay_steps, - "alpha": self.alpha, - "name": self.name, - "warmup_target": self.warmup_target, - "warmup_steps": self.warmup_steps, - } - - -@keras_core_export("keras_core.optimizers.schedules.CosineDecayRestarts") -class CosineDecayRestarts(LearningRateSchedule): - """A `LearningRateSchedule` that uses a cosine decay schedule with restarts. - - See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983), - SGDR: Stochastic Gradient Descent with Warm Restarts. - - When training a model, it is often useful to lower the learning rate as - the training progresses. This schedule applies a cosine decay function with - restarts to an optimizer step, given a provided initial learning rate. - It requires a `step` value to compute the decayed learning rate. You can - just pass a backend variable that you increment at each training step. - - The schedule is a 1-arg callable that produces a decayed learning - rate when passed the current optimizer step. This can be useful for changing - the learning rate value across different invocations of optimizer functions. - - The learning rate multiplier first decays - from 1 to `alpha` for `first_decay_steps` steps. Then, a warm - restart is performed. Each new warm restart runs for `t_mul` times more - steps and with `m_mul` times initial learning rate as the new learning rate. 
- - Example usage: - ```python - first_decay_steps = 1000 - lr_decayed_fn = ( - keras_core.optimizers.schedules.CosineDecayRestarts( - initial_learning_rate, - first_decay_steps)) - ``` - - You can pass this schedule directly into a `keras_core.optimizers.Optimizer` - as the learning rate. The learning rate schedule is also serializable and - deserializable using `keras_core.optimizers.schedules.serialize` and - `keras_core.optimizers.schedules.deserialize`. - - Args: - initial_learning_rate: A Python float. The initial learning rate. - first_decay_steps: A Python integer. Number of steps to decay over. - t_mul: A Python float. Used to derive the number of iterations in - the i-th period. - m_mul: A Python float. Used to derive the initial learning rate of - the i-th period. - alpha: A Python float. Minimum learning rate value as a fraction of - the `initial_learning_rate`. - name: String. Optional name of the operation. Defaults to - `"SGDRDecay"`. - - Returns: - A 1-arg callable learning rate schedule that takes the current optimizer - step and outputs the decayed learning rate, a scalar tensor of the - same type as `initial_learning_rate`. - """ - - def __init__( - self, - initial_learning_rate, - first_decay_steps, - t_mul=2.0, - m_mul=1.0, - alpha=0.0, - name="SGDRDecay", - ): - super().__init__() - - self.initial_learning_rate = initial_learning_rate - self.first_decay_steps = first_decay_steps - self._t_mul = t_mul - self._m_mul = m_mul - self.alpha = alpha - self.name = name - - def __call__(self, step): - with ops.name_scope(self.name): - initial_learning_rate = ops.convert_to_tensor( - self.initial_learning_rate - ) - dtype = initial_learning_rate.dtype - first_decay_steps = ops.cast(self.first_decay_steps, dtype) - alpha = ops.cast(self.alpha, dtype) - t_mul = ops.cast(self._t_mul, dtype) - m_mul = ops.cast(self._m_mul, dtype) - - global_step_recomp = ops.cast(step, dtype) - completed_fraction = global_step_recomp / first_decay_steps - - def compute_step(completed_fraction, geometric=False): - """Helper for `cond` operation.""" - if geometric: - i_restart = ops.floor( - ops.log(1.0 - completed_fraction * (1.0 - t_mul)) - / ops.log(t_mul) - ) - - sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) - completed_fraction = ( - completed_fraction - sum_r - ) / t_mul**i_restart - - else: - i_restart = ops.floor(completed_fraction) - completed_fraction -= i_restart - - return i_restart, completed_fraction - - i_restart, completed_fraction = ops.cond( - ops.equal(t_mul, 1.0), - lambda: compute_step(completed_fraction, geometric=False), - lambda: compute_step(completed_fraction, geometric=True), - ) - - m_fac = m_mul**i_restart - cosine_decayed = ( - 0.5 - * m_fac - * ( - 1.0 - + ops.cos( - ops.array(math.pi, dtype=dtype) * completed_fraction - ) - ) - ) - decayed = (1 - alpha) * cosine_decayed + alpha - - return ops.multiply(initial_learning_rate, decayed) - - def get_config(self): - return { - "initial_learning_rate": self.initial_learning_rate, - "first_decay_steps": self.first_decay_steps, - "t_mul": self._t_mul, - "m_mul": self._m_mul, - "alpha": self.alpha, - "name": self.name, - } - - -@keras_core_export("keras_core.optimizers.schedules.serialize") -def serialize(learning_rate_schedule): - """Serializes a `LearningRateSchedule` into a JSON-compatible dict. - - Args: - learning_rate_schedule: The `LearningRateSchedule` object to serialize. - - Returns: - A JSON-serializable dict representing the object's config. 
- - Example: - - >>> lr_schedule = keras_core.optimizers.schedules.ExponentialDecay( - ... 0.1, decay_steps=100000, decay_rate=0.96, staircase=True) - >>> keras_core.optimizers.schedules.serialize(lr_schedule) - {'module': 'keras_core.optimizers.schedules', - 'class_name': 'ExponentialDecay', 'config': {...}, - 'registered_name': None} - """ - return serialization_lib.serialize_keras_object(learning_rate_schedule) - - -@keras_core_export("keras_core.optimizers.schedules.deserialize") -def deserialize(config, custom_objects=None): - """Instantiates a `LearningRateSchedule` object from a serialized form. - - Args: - config: The serialized form of the `LearningRateSchedule`. Dictionary of - the form {'class_name': str, 'config': dict}. - custom_objects: A dictionary mapping class names (or function names) of - custom (non-Keras) objects to class/functions. - - Returns: - A `LearningRateSchedule` object. - - Example: - - ```python - # Configuration for PolynomialDecay - config = { - 'class_name': 'PolynomialDecay', - 'config': {'cycle': False, - 'decay_steps': 10000, - 'end_learning_rate': 0.01, - 'initial_learning_rate': 0.1, - 'name': None, - 'power': 0.5 - } - } - lr_schedule = keras_core.optimizers.schedules.deserialize(config) - ``` - """ - return serialization_lib.deserialize_keras_object( - config, - module_objects=globals(), - custom_objects=custom_objects, - printable_module_name="decay", - ) + # TODO + pass diff --git a/keras_core/optimizers/schedules/learning_rate_schedule_test.py b/keras_core/optimizers/schedules/learning_rate_schedule_test.py deleted file mode 100644 index 0bdb09e8a..000000000 --- a/keras_core/optimizers/schedules/learning_rate_schedule_test.py +++ /dev/null @@ -1,452 +0,0 @@ -"""Tests for learning rate schedule API.""" - -import math - -import numpy as np - -from keras_core import backend -from keras_core import testing -from keras_core.optimizers.schedules import learning_rate_schedule - - -class ExponentialDecayTest(testing.TestCase): - def test_config(self): - self.run_class_serialization_test( - learning_rate_schedule.ExponentialDecay( - initial_learning_rate=0.05, - decay_steps=10, - decay_rate=0.96, - staircase=True, - name="my_ed", - ) - ) - - def test_continuous(self): - step = 5 - decayed_lr = learning_rate_schedule.ExponentialDecay(0.05, 10, 0.96) - expected = 0.05 * 0.96 ** (5.0 / 10.0) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_staircase(self): - step = backend.Variable(1) - decayed_lr = learning_rate_schedule.ExponentialDecay( - 0.1, 3, 0.96, staircase=True - ) - - # No change to learning rate due to staircase - expected = 0.1 - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - expected = 0.1 - step.assign(2) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - # Decayed learning rate - expected = 0.1 * 0.96 ** (100 // 3) - step.assign(100) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_variables(self): - step = backend.Variable(1) - decayed_lr = learning_rate_schedule.ExponentialDecay( - 0.1, 3, 0.96, staircase=True - ) - - # No change to learning rate - step.assign(1) - self.assertAllClose(decayed_lr(step), 0.1, 1e-6) - step.assign(2) - self.assertAllClose(decayed_lr(step), 0.1, 1e-6) - # Decayed learning rate - step.assign(100) - expected = 0.1 * 0.96 ** (100 // 3) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - -class PiecewiseConstantDecayTest(testing.TestCase): - def test_config(self): - self.run_class_serialization_test( - learning_rate_schedule.PiecewiseConstantDecay( 
- boundaries=[10, 20], values=[1, 2, 3], name="my_pcd" - ) - ) - - def test_piecewise_values(self): - x = backend.Variable(-999) - decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( - [100, 110, 120], [1.0, 0.1, 0.01, 0.001] - ) - - self.assertAllClose(decayed_lr(x), 1.0, 1e-6) - x.assign(100) - self.assertAllClose(decayed_lr(x), 1.0, 1e-6) - x.assign(105) - self.assertAllClose(decayed_lr(x), 0.1, 1e-6) - x.assign(110) - self.assertAllClose(decayed_lr(x), 0.1, 1e-6) - x.assign(120) - self.assertAllClose(decayed_lr(x), 0.01, 1e-6) - x.assign(999) - self.assertAllClose(decayed_lr(x), 0.001, 1e-6) - - def test_boundary_values(self): - # Test casting boundaries from int32 to int64. - x_int64 = backend.Variable(0, dtype="int64") - boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] - decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( - boundaries, values - ) - - self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6) - x_int64.assign(1) - self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6) - x_int64.assign(2) - self.assertAllClose(decayed_lr(x_int64), 0.5, 1e-6) - x_int64.assign(3) - self.assertAllClose(decayed_lr(x_int64), 0.6, 1e-6) - x_int64.assign(4) - self.assertAllClose(decayed_lr(x_int64), 0.7, 1e-6) - - -class LinearDecayTest(testing.TestCase): - def test_config(self): - self.run_class_serialization_test( - learning_rate_schedule.PolynomialDecay( - initial_learning_rate=0.1, - decay_steps=100, - end_learning_rate=0.005, - power=1.0, - cycle=False, - name="my_ld", - ) - ) - - def test_halfway(self): - step = 5 - lr = 0.05 - end_lr = 0.0 - decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) - expected = lr * 0.5 - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_end(self): - step = 10 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_halfway_with_end(self): - step = 5 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) - expected = (lr + end_lr) * 0.5 - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_beyond_end(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_beyond_end_with_cycle(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, 10, end_lr, cycle=True - ) - expected = (lr - end_lr) * 0.25 + end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - -class SqrtDecayTest(testing.TestCase): - def test_halfway(self): - step = 5 - lr = 0.05 - end_lr = 0.0 - power = 0.5 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, 10, end_lr, power=power - ) - expected = lr * 0.5**power - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_end(self): - step = 10 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, 10, end_lr, power=power - ) - expected = end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_halfway_with_end(self): - step = 5 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, 10, end_lr, power=power - ) - expected = (lr - end_lr) * 0.5**power + end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_beyond_end(self): - step = 15 - lr = 0.05 - end_lr = 
0.001 - power = 0.5 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, 10, end_lr, power=power - ) - expected = end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_beyond_end_with_cycle(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, 10, end_lr, power=power, cycle=True - ) - expected = (lr - end_lr) * 0.25**power + end_lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_begin_with_cycle(self): - lr = 0.001 - decay_steps = 10 - step = 0 - decayed_lr = learning_rate_schedule.PolynomialDecay( - lr, decay_steps, cycle=True - ) - expected = lr - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - -class InverseTimeDecayTest(testing.TestCase): - def test_config(self): - self.run_class_serialization_test( - learning_rate_schedule.InverseTimeDecay( - initial_learning_rate=0.05, - decay_steps=10, - decay_rate=0.96, - staircase=True, - name="my_itd", - ) - ) - - def test_decay(self): - initial_lr = 0.1 - k = 10 - decay_rate = 0.96 - step = backend.Variable(0) - decayed_lr = learning_rate_schedule.InverseTimeDecay( - initial_lr, k, decay_rate - ) - - for i in range(k + 1): - expected = initial_lr / (1 + i / k * decay_rate) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - step.assign(step + 1) - - def test_staircase(self): - initial_lr = 0.1 - k = 10 - decay_rate = 0.96 - step = backend.Variable(0) - decayed_lr = learning_rate_schedule.InverseTimeDecay( - initial_lr, k, decay_rate, staircase=True - ) - - for i in range(k + 1): - expected = initial_lr / (1 + decay_rate * (i // k)) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - step.assign(step + 1) - - -class CosineDecayTest(testing.TestCase): - def test_config(self): - self.run_class_serialization_test( - learning_rate_schedule.CosineDecay( - initial_learning_rate=0.05, - decay_steps=10, - alpha=0.1, - warmup_target=0.2, - warmup_steps=2, - name="my_cd", - ) - ) - - def np_cosine_decay(self, step, decay_steps, alpha=0.0): - step = min(step, decay_steps) - completed_fraction = step / decay_steps - decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) - return (1.0 - alpha) * decay + alpha - - def test_decay(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecay( - initial_lr, num_training_steps - ) - expected = self.np_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def linear_warmup(self, step, warmup_steps, initial_lr, target_lr): - completed_fraction = step / warmup_steps - total_delta = target_lr - initial_lr - return completed_fraction * total_delta - - def test_warmup(self): - warmup_steps = 1500 - initial_lr = 0.0 - target_lr = 10.0 - for step in range(0, 1500, 250): - lr = learning_rate_schedule.CosineDecay( - initial_lr, - 0, - warmup_target=target_lr, - warmup_steps=warmup_steps, - ) - expected = self.linear_warmup( - step, warmup_steps, initial_lr, target_lr - ) - self.assertAllClose(lr(step), expected) - - def test_alpha(self): - num_training_steps = 1000 - initial_lr = 1.0 - alpha = 0.1 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecay( - initial_lr, num_training_steps, alpha - ) - expected = self.np_cosine_decay(step, num_training_steps, alpha) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_float64(self): - num_training_steps = 1000 - initial_lr = np.float64(1.0) - for step in range(0, 1500, 
250): - decayed_lr = learning_rate_schedule.CosineDecay( - initial_lr, num_training_steps - ) - expected = self.np_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_warmup_decay(self): - warmup_steps = 2000 - decay_steps = 1000 - initial_lr = 0.0 - target_lr = 10.0 - for step in range(0, 3000, 250): - lr = learning_rate_schedule.CosineDecay( - initial_lr, - decay_steps, - warmup_target=target_lr, - warmup_steps=warmup_steps, - ) - if step < warmup_steps + 1: - expected = self.linear_warmup( - step, warmup_steps, initial_lr, target_lr - ) - else: - expected = target_lr * self.np_cosine_decay( - step - warmup_steps, decay_steps - ) - self.assertAllClose(lr(step), expected) - - -class CosineDecayRestartsTest(testing.TestCase): - def test_config(self): - self.run_class_serialization_test( - learning_rate_schedule.CosineDecayRestarts( - initial_learning_rate=0.05, - first_decay_steps=10, - alpha=0.1, - t_mul=3.0, - m_mul=4.0, - name="my_cdr", - ) - ) - - def np_cosine_decay_restarts( - self, step, decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0 - ): - fac = 1.0 - while step >= decay_steps: - step -= decay_steps - decay_steps *= t_mul - fac *= m_mul - - completed_fraction = step / decay_steps - decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) - return (1.0 - alpha) * decay + alpha - - def test_decay(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecayRestarts( - initial_lr, num_training_steps - ) - expected = self.np_cosine_decay_restarts(step, num_training_steps) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_float64(self): - num_training_steps = 1000 - initial_lr = np.float64(1.0) - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecayRestarts( - initial_lr, num_training_steps - ) - expected = self.np_cosine_decay_restarts(step, num_training_steps) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_alpha(self): - num_training_steps = 1000 - initial_lr = 1.0 - alpha = 0.1 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecayRestarts( - initial_lr, num_training_steps, alpha=alpha - ) - expected = self.np_cosine_decay_restarts( - step, num_training_steps, alpha=alpha - ) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_mmul(self): - num_training_steps = 1000 - initial_lr = 1.0 - m_mul = 0.9 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecayRestarts( - initial_lr, num_training_steps, m_mul=m_mul - ) - expected = self.np_cosine_decay_restarts( - step, num_training_steps, m_mul=m_mul - ) - self.assertAllClose(decayed_lr(step), expected, 1e-6) - - def test_tmul(self): - num_training_steps = 1000 - initial_lr = 1.0 - t_mul = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_schedule.CosineDecayRestarts( - initial_lr, num_training_steps, t_mul=t_mul - ) - expected = self.np_cosine_decay_restarts( - step, num_training_steps, t_mul=t_mul - ) - self.assertAllClose(decayed_lr(step), expected, 1e-6)
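
Usage sketch, assuming this patch is applied: the new `Wrapper` base class added in `keras_core/layers/core/wrapper.py` is meant to be subclassed rather than used directly, mirroring the `ExampleWrapper` helper defined in `wrapper_test.py`:

```python
# Minimal sketch of subclassing the new Wrapper layer, following the
# ExampleWrapper pattern from keras_core/layers/core/wrapper_test.py.
# Assumes a keras_core checkout with this patch applied; the NumPy input
# here is illustrative (the repo's own tests go through run_layer_test).
import numpy as np

from keras_core import layers


class ExampleWrapper(layers.Wrapper):
    """Forwards calls to the wrapped layer unchanged."""

    def call(self, inputs, **kwargs):
        return self.layer(inputs, **kwargs)


wrapped = ExampleWrapper(layers.Dense(2))
outputs = wrapped(np.random.random((2, 3)).astype("float32"))
print(outputs.shape)  # (2, 2): the wrapped Dense layer determines the output
```

Wrapping `Dense(2)` this way leaves the wrapped layer's weights and output shape untouched, which is what the `expected_output_shape=(2, 2)` and `expected_num_trainable_weights=2` assertions in `wrapper_test.py` exercise.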