From f0c0c877babe95253667783653e6540219609cf9 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Wed, 16 Dec 2020 09:34:29 -0800 Subject: [PATCH] Sync OSS keras to head. PiperOrigin-RevId: 347838100 --- keras/engine/BUILD | 3 + keras/layers/gru_v2_test.py | 2 +- keras/layers/local.py | 384 ++++++++---------- keras/layers/lstm_v2_test.py | 2 +- keras/layers/wrappers_test.py | 5 +- keras/losses.py | 223 +++++----- keras/mixed_precision/autocast_variable.py | 32 +- .../mixed_precision/autocast_variable_test.py | 70 +++- keras/saving/saved_model/load.py | 21 +- keras/saving/saved_model/saved_model_test.py | 16 + 10 files changed, 408 insertions(+), 350 deletions(-) diff --git a/keras/engine/BUILD b/keras/engine/BUILD index 9724b120f..18872141a 100644 --- a/keras/engine/BUILD +++ b/keras/engine/BUILD @@ -1,7 +1,10 @@ # Description: # Contains the Keras engine API (internal TensorFlow version). +# buildifier: disable=same-origin-load load("@org_keras//keras:keras.bzl", "tf_py_test") + +# buildifier: disable=same-origin-load load("@org_keras//keras:keras.bzl", "cuda_py_test") package( diff --git a/keras/layers/gru_v2_test.py b/keras/layers/gru_v2_test.py index 786d9ce7d..315e8f7f8 100644 --- a/keras/layers/gru_v2_test.py +++ b/keras/layers/gru_v2_test.py @@ -27,8 +27,8 @@ import shutil from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import test_util as tf_test_util import keras +from tensorflow.python.framework import test_util as tf_test_util from keras import combinations from keras import keras_parameterized from keras import testing_utils diff --git a/keras/layers/local.py b/keras/layers/local.py index 37552e0cc..1b397ca51 100644 --- a/keras/layers/local.py +++ b/keras/layers/local.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Locally-connected layers. -""" +"""Locally-connected layers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -59,79 +58,61 @@ class LocallyConnected1D(Layer): ``` Arguments: - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, - specifying the length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: Currently only supports `"valid"` (case-insensitive). - `"same"` may be supported in the future. - `"valid"` means no padding. - data_format: A string, - one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, length)`. - It defaults to the `image_data_format` value found in your - Keras config file at `~/.keras/keras.json`. - If you never set it, then it will be "channels_last". - activation: Activation function to use. - If you don't specify anything, no activation is applied + filters: Integer, the dimensionality of the output space (i.e. the number + of output filters in the convolution). 
+ kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, specifying the + stride length of the convolution. + padding: Currently only supports `"valid"` (case-insensitive). `"same"` + may be supported in the future. `"valid"` means no padding. + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, length, + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, length)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. If you + never set it, then it will be "channels_last". + activation: Activation function to use. If you don't specify anything, no + activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix. bias_initializer: Initializer for the bias vector. - kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation").. kernel_constraint: Constraint function applied to the kernel matrix. bias_constraint: Constraint function applied to the bias vector. - implementation: implementation mode, either `1`, `2`, or `3`. - `1` loops over input spatial locations to perform the forward pass. - It is memory-efficient but performs a lot of (small) ops. - - `2` stores layer weights in a dense but sparsely-populated 2D matrix - and implements the forward pass as a single matrix-multiply. It uses - a lot of RAM but performs few (large) ops. - - `3` stores layer weights in a sparse tensor and implements the forward - pass as a single sparse matrix-multiply. - + implementation: implementation mode, either `1`, `2`, or `3`. `1` loops + over input spatial locations to perform the forward pass. It is + memory-efficient but performs a lot of (small) ops. `2` stores layer + weights in a dense but sparsely-populated 2D matrix and implements the + forward pass as a single matrix-multiply. It uses a lot of RAM but + performs few (large) ops. `3` stores layer weights in a sparse tensor + and implements the forward pass as a single sparse matrix-multiply. How to choose: - `1`: large, dense models, `2`: small models, - `3`: large, sparse models, - - where "large" stands for large input/output activations - (i.e. many `filters`, `input_filters`, large `input_size`, - `output_size`), and "sparse" stands for few connections between inputs - and outputs, i.e. small ratio - `filters * input_filters * kernel_size / (input_size * strides)`, - where inputs to and outputs of the layer are assumed to have shapes - `(input_size, input_filters)`, `(output_size, filters)` - respectively. - - It is recommended to benchmark each in the setting of interest to pick - the most efficient one (in terms of speed and memory usage). Correct - choice of implementation can lead to dramatic speed improvements (e.g. - 50X), potentially at the expense of RAM. 
- - Also, only `padding="valid"` is supported by `implementation=1`. - + `3`: large, sparse models, where "large" stands for large + input/output activations (i.e. many `filters`, `input_filters`, + large `input_size`, `output_size`), and "sparse" stands for few + connections between inputs and outputs, i.e. small ratio `filters * + input_filters * kernel_size / (input_size * strides)`, where inputs + to and outputs of the layer are assumed to have shapes `(input_size, + input_filters)`, `(output_size, filters)` respectively. It is + recommended to benchmark each in the setting of interest to pick the + most efficient one (in terms of speed and memory usage). Correct + choice of implementation can lead to dramatic speed improvements + (e.g. 50X), potentially at the expense of RAM. Also, only + `padding="valid"` is supported by `implementation=1`. Input shape: 3D tensor with shape: `(batch_size, steps, input_dim)` - Output shape: - 3D tensor with shape: `(batch_size, new_steps, filters)` - `steps` value might have changed due to padding or strides. + 3D tensor with shape: `(batch_size, new_steps, filters)` `steps` value + might have changed due to padding or strides. """ def __init__(self, @@ -158,8 +139,8 @@ class LocallyConnected1D(Layer): self.padding = conv_utils.normalize_padding(padding) if self.padding != 'valid' and implementation == 1: raise ValueError('Invalid border mode for LocallyConnected1D ' - '(only "valid" is supported if implementation is 1): ' - + padding) + '(only "valid" is supported if implementation is 1): ' + + padding) self.data_format = conv_utils.normalize_data_format(data_format) self.activation = activations.get(activation) self.use_bias = use_bias @@ -181,10 +162,13 @@ class LocallyConnected1D(Layer): input_dim, input_length = input_shape[2], input_shape[1] if input_dim is None: - raise ValueError('Axis 2 of input should be fully-defined. ' - 'Found shape:', input_shape) - self.output_length = conv_utils.conv_output_length( - input_length, self.kernel_size[0], self.padding, self.strides[0]) + raise ValueError( + 'Axis 2 of input should be fully-defined. 
' + 'Found shape:', input_shape) + self.output_length = conv_utils.conv_output_length(input_length, + self.kernel_size[0], + self.padding, + self.strides[0]) if self.implementation == 1: self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim, @@ -199,17 +183,18 @@ class LocallyConnected1D(Layer): elif self.implementation == 2: if self.data_format == 'channels_first': - self.kernel_shape = (input_dim, input_length, - self.filters, self.output_length) + self.kernel_shape = (input_dim, input_length, self.filters, + self.output_length) else: - self.kernel_shape = (input_length, input_dim, - self.output_length, self.filters) + self.kernel_shape = (input_length, input_dim, self.output_length, + self.filters) - self.kernel = self.add_weight(shape=self.kernel_shape, - initializer=self.kernel_initializer, - name='kernel', - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) + self.kernel = self.add_weight( + shape=self.kernel_shape, + initializer=self.kernel_initializer, + name='kernel', + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) self.kernel_mask = get_locallyconnected_mask( input_shape=(input_length,), @@ -231,8 +216,7 @@ class LocallyConnected1D(Layer): padding=self.padding, filters_in=input_dim, filters_out=self.filters, - data_format=self.data_format) - ) + data_format=self.data_format)) self.kernel = self.add_weight( shape=(len(self.kernel_idxs),), @@ -242,8 +226,8 @@ class LocallyConnected1D(Layer): constraint=self.kernel_constraint) else: - raise ValueError('Unrecognized implementation mode: %d.' - % self.implementation) + raise ValueError('Unrecognized implementation mode: %d.' % + self.implementation) if self.use_bias: self.bias = self.add_weight( @@ -291,8 +275,8 @@ class LocallyConnected1D(Layer): self.compute_output_shape(inputs.shape)) else: - raise ValueError('Unrecognized implementation mode: %d.' - % self.implementation) + raise ValueError('Unrecognized implementation mode: %d.' % + self.implementation) if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) @@ -366,87 +350,71 @@ class LocallyConnected2D(Layer): ``` Arguments: - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: Currently only support `"valid"` (case-insensitive). - `"same"` will be supported in future. - `"valid"` means no padding. - data_format: A string, - one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, height, width)`. - It defaults to the `image_data_format` value found in your - Keras config file at `~/.keras/keras.json`. - If you never set it, then it will be "channels_last". - activation: Activation function to use. - If you don't specify anything, no activation is applied + filters: Integer, the dimensionality of the output space (i.e. the number + of output filters in the convolution). 
+ kernel_size: An integer or tuple/list of 2 integers, specifying the width + and height of the 2D convolution window. Can be a single integer to + specify the same value for all spatial dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides of + the convolution along the width and height. Can be a single integer to + specify the same value for all spatial dimensions. + padding: Currently only support `"valid"` (case-insensitive). `"same"` + will be supported in future. `"valid"` means no padding. + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, height, width, + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + "channels_last". + activation: Activation function to use. If you don't specify anything, no + activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix. bias_initializer: Initializer for the bias vector. - kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation"). + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation"). kernel_constraint: Constraint function applied to the kernel matrix. bias_constraint: Constraint function applied to the bias vector. - implementation: implementation mode, either `1`, `2`, or `3`. - `1` loops over input spatial locations to perform the forward pass. - It is memory-efficient but performs a lot of (small) ops. - - `2` stores layer weights in a dense but sparsely-populated 2D matrix - and implements the forward pass as a single matrix-multiply. It uses - a lot of RAM but performs few (large) ops. - - `3` stores layer weights in a sparse tensor and implements the forward - pass as a single sparse matrix-multiply. - + implementation: implementation mode, either `1`, `2`, or `3`. `1` loops + over input spatial locations to perform the forward pass. It is + memory-efficient but performs a lot of (small) ops. `2` stores layer + weights in a dense but sparsely-populated 2D matrix and implements the + forward pass as a single matrix-multiply. It uses a lot of RAM but + performs few (large) ops. `3` stores layer weights in a sparse tensor + and implements the forward pass as a single sparse matrix-multiply. How to choose: - `1`: large, dense models, `2`: small models, - `3`: large, sparse models, - - where "large" stands for large input/output activations - (i.e. many `filters`, `input_filters`, large `np.prod(input_size)`, - `np.prod(output_size)`), and "sparse" stands for few connections - between inputs and outputs, i.e. small ratio - `filters * input_filters * np.prod(kernel_size) / (np.prod(input_size) - * np.prod(strides))`, where inputs to and outputs of the layer are - assumed to have shapes `input_size + (input_filters,)`, - `output_size + (filters,)` respectively. 
- - It is recommended to benchmark each in the setting of interest to pick - the most efficient one (in terms of speed and memory usage). Correct - choice of implementation can lead to dramatic speed improvements (e.g. - 50X), potentially at the expense of RAM. - - Also, only `padding="valid"` is supported by `implementation=1`. - + `3`: large, sparse models, where "large" stands for large + input/output activations (i.e. many `filters`, `input_filters`, + large `np.prod(input_size)`, `np.prod(output_size)`), and "sparse" + stands for few connections between inputs and outputs, i.e. small + ratio `filters * input_filters * np.prod(kernel_size) / + (np.prod(input_size) * np.prod(strides))`, where inputs to and + outputs of the layer are assumed to have shapes `input_size + + (input_filters,)`, `output_size + (filters,)` respectively. It is + recommended to benchmark each in the setting of interest to pick the + most efficient one (in terms of speed and memory usage). Correct + choice of implementation can lead to dramatic speed improvements + (e.g. 50X), potentially at the expense of RAM. Also, only + `padding="valid"` is supported by `implementation=1`. Input shape: - 4D tensor with shape: - `(samples, channels, rows, cols)` if data_format='channels_first' - or 4D tensor with shape: - `(samples, rows, cols, channels)` if data_format='channels_last'. - + 4D tensor with shape: `(samples, channels, rows, cols)` if + data_format='channels_first' + or 4D tensor with shape: `(samples, rows, cols, channels)` if + data_format='channels_last'. Output shape: - 4D tensor with shape: - `(samples, filters, new_rows, new_cols)` if data_format='channels_first' - or 4D tensor with shape: - `(samples, new_rows, new_cols, filters)` if data_format='channels_last'. - `rows` and `cols` values might have changed due to padding. + 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if + data_format='channels_first' + or 4D tensor with shape: `(samples, new_rows, new_cols, filters)` if + data_format='channels_last'. `rows` and `cols` values might have changed + due to padding. 
""" def __init__(self, @@ -473,8 +441,8 @@ class LocallyConnected2D(Layer): self.padding = conv_utils.normalize_padding(padding) if self.padding != 'valid' and implementation == 1: raise ValueError('Invalid border mode for LocallyConnected2D ' - '(only "valid" is supported if implementation is 1): ' - + padding) + '(only "valid" is supported if implementation is 1): ' + + padding) self.data_format = conv_utils.normalize_data_format(data_format) self.activation = activations.get(activation) self.use_bias = use_bias @@ -509,10 +477,8 @@ class LocallyConnected2D(Layer): self.output_col = output_col if self.implementation == 1: - self.kernel_shape = ( - output_row * output_col, - self.kernel_size[0] * self.kernel_size[1] * input_filter, - self.filters) + self.kernel_shape = (output_row * output_col, self.kernel_size[0] * + self.kernel_size[1] * input_filter, self.filters) self.kernel = self.add_weight( shape=self.kernel_shape, @@ -523,17 +489,18 @@ class LocallyConnected2D(Layer): elif self.implementation == 2: if self.data_format == 'channels_first': - self.kernel_shape = (input_filter, input_row, input_col, - self.filters, self.output_row, self.output_col) + self.kernel_shape = (input_filter, input_row, input_col, self.filters, + self.output_row, self.output_col) else: self.kernel_shape = (input_row, input_col, input_filter, self.output_row, self.output_col, self.filters) - self.kernel = self.add_weight(shape=self.kernel_shape, - initializer=self.kernel_initializer, - name='kernel', - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) + self.kernel = self.add_weight( + shape=self.kernel_shape, + initializer=self.kernel_initializer, + name='kernel', + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) self.kernel_mask = get_locallyconnected_mask( input_shape=(input_row, input_col), @@ -555,8 +522,7 @@ class LocallyConnected2D(Layer): padding=self.padding, filters_in=input_filter, filters_out=self.filters, - data_format=self.data_format) - ) + data_format=self.data_format)) self.kernel = self.add_weight( shape=(len(self.kernel_idxs),), @@ -566,8 +532,8 @@ class LocallyConnected2D(Layer): constraint=self.kernel_constraint) else: - raise ValueError('Unrecognized implementation mode: %d.' - % self.implementation) + raise ValueError('Unrecognized implementation mode: %d.' % + self.implementation) if self.use_bias: self.bias = self.add_weight( @@ -619,8 +585,8 @@ class LocallyConnected2D(Layer): self.compute_output_shape(inputs.shape)) else: - raise ValueError('Unrecognized implementation mode: %d.' - % self.implementation) + raise ValueError('Unrecognized implementation mode: %d.' % + self.implementation) if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) @@ -686,10 +652,10 @@ def get_locallyconnected_mask(input_shape, kernel_shape, strides, padding, `strides`, `padding` and `data_format`. Arguments: - input_shape: tuple of size N: `(d_in1, ..., d_inN)` - spatial shape of the input. - kernel_shape: tuple of size N, spatial shape of the convolutional kernel - / receptive field. + input_shape: tuple of size N: `(d_in1, ..., d_inN)` spatial shape of the + input. + kernel_shape: tuple of size N, spatial shape of the convolutional kernel / + receptive field. strides: tuple of size N, strides along each spatial dimension. padding: type of padding, string `"same"` or `"valid"`. data_format: a string, `"channels_first"` or `"channels_last"`. 
@@ -709,8 +675,7 @@ def get_locallyconnected_mask(input_shape, kernel_shape, strides, padding, input_shape=input_shape, kernel_shape=kernel_shape, strides=strides, - padding=padding - ) + padding=padding) ndims = int(mask.ndim / 2) @@ -739,34 +704,26 @@ def local_conv_matmul(inputs, kernel, kernel_mask, output_shape): reshapes to make `inputs` and `kernel` 2-D and `output` (N+2)-D. Arguments: - inputs: (N+2)-D tensor with shape - `(batch_size, channels_in, d_in1, ..., d_inN)` - or - `(batch_size, d_in1, ..., d_inN, channels_in)`. + inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ..., + d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`. kernel: the unshared weights for N-D convolution, - an (N+2)-D tensor of shape: - `(d_in1, ..., d_inN, channels_in, d_out2, ..., d_outN, channels_out)` - or - `(channels_in, d_in1, ..., d_inN, channels_out, d_out2, ..., d_outN)`, - with the ordering of channels and spatial dimensions matching - that of the input. - Each entry is the weight between a particular input and - output location, similarly to a fully-connected weight matrix. - kernel_mask: a float 0/1 mask tensor of shape: - `(d_in1, ..., d_inN, 1, d_out2, ..., d_outN, 1)` - or - `(1, d_in1, ..., d_inN, 1, d_out2, ..., d_outN)`, - with the ordering of singleton and spatial dimensions - matching that of the input. - Mask represents the connectivity pattern of the layer and is - precomputed elsewhere based on layer parameters: stride, - padding, and the receptive field shape. + an (N+2)-D tensor of shape: `(d_in1, ..., d_inN, channels_in, d_out2, + ..., d_outN, channels_out)` or `(channels_in, d_in1, ..., d_inN, + channels_out, d_out2, ..., d_outN)`, with the ordering of channels + and spatial dimensions matching that of the input. Each entry is the + weight between a particular input and output location, similarly to + a fully-connected weight matrix. + kernel_mask: a float 0/1 mask tensor of shape: `(d_in1, ..., d_inN, 1, + d_out2, ..., d_outN, 1)` or `(1, d_in1, ..., d_inN, 1, d_out2, ..., + d_outN)`, with the ordering of singleton and spatial dimensions matching + that of the input. Mask represents the connectivity pattern of the layer + and is + precomputed elsewhere based on layer parameters: stride, padding, and + the receptive field shape. output_shape: a tuple of (N+2) elements representing the output shape: - `(batch_size, channels_out, d_out1, ..., d_outN)` - or - `(batch_size, d_out1, ..., d_outN, channels_out)`, - with the ordering of channels and spatial dimensions matching that of - the input. + `(batch_size, channels_out, d_out1, ..., d_outN)` or `(batch_size, + d_out1, ..., d_outN, channels_out)`, with the ordering of channels and + spatial dimensions matching that of the input. Returns: Output (N+2)-D tensor with shape `output_shape`. 
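The `local_conv_matmul` docstring above explains the idea behind `implementation=2`: flatten the input and multiply it by a dense but sparsely populated weight matrix. A tiny NumPy sketch of that equivalence with hand-built shapes (an illustration only, not the Keras code itself):

```python
import numpy as np

# Hypothetical sizes: input length 4, kernel size 2, stride 2, one input
# channel, one filter, so output length 2.
x = np.array([1.0, 2.0, 3.0, 4.0])      # flattened input, shape (4,)
w_loc = np.array([[0.1, 0.2],           # unshared weights: one (kernel_size,)
                  [0.3, 0.4]])          # row per output location, shape (2, 2)

# Reference: loop over output locations (implementation=1 style).
ref = np.array([x[0:2] @ w_loc[0], x[2:4] @ w_loc[1]])

# implementation=2 style: scatter the same weights into a dense but sparsely
# populated (input_size, output_size) matrix and do a single matmul.
dense = np.zeros((4, 2))
dense[0:2, 0] = w_loc[0]
dense[2:4, 1] = w_loc[1]
out = x @ dense

np.testing.assert_allclose(out, ref)    # both give [0.5, 2.5]
```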
@@ -777,8 +734,9 @@ def local_conv_matmul(inputs, kernel, kernel_mask, output_shape): kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2) output_flat = tf.compat.v1.sparse_matmul(inputs_flat, kernel, b_is_sparse=True) - output = K.reshape(output_flat, - [K.shape(output_flat)[0],] + output_shape.as_list()[1:]) + output = K.reshape(output_flat, [ + K.shape(output_flat)[0], + ] + output_shape.as_list()[1:]) return output @@ -810,14 +768,16 @@ def local_conv_sparse_matmul(inputs, kernel, kernel_idxs, kernel_shape, """ inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1)) output_flat = tf.raw_ops.SparseTensorDenseMatMul( - a_indices=kernel_idxs, a_values=kernel, a_shape=kernel_shape, - b=inputs_flat, adjoint_b=True) + a_indices=kernel_idxs, + a_values=kernel, + a_shape=kernel_shape, + b=inputs_flat, + adjoint_b=True) output_flat_transpose = K.transpose(output_flat) - output_reshaped = K.reshape( - output_flat_transpose, - [K.shape(output_flat_transpose)[0],] + output_shape.as_list()[1:] - ) + output_reshaped = K.reshape(output_flat_transpose, [ + K.shape(output_flat_transpose)[0], + ] + output_shape.as_list()[1:]) return output_reshaped @@ -830,7 +790,7 @@ def make_2d(tensor, split_dim): Arguments: tensor: a tensor of shape `(d0, ..., d(N-1))`. split_dim: an integer from 1 to N-1, index of the dimension to group - dimensions before (excluding) and after (including). + dimensions before (excluding) and after (including). Returns: Tensor of shape diff --git a/keras/layers/lstm_v2_test.py b/keras/layers/lstm_v2_test.py index 1638edcc3..d0fed33aa 100644 --- a/keras/layers/lstm_v2_test.py +++ b/keras/layers/lstm_v2_test.py @@ -28,8 +28,8 @@ import time from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import test_util as tf_test_util import keras +from tensorflow.python.framework import test_util as tf_test_util from keras import keras_parameterized from keras import testing_utils from keras.layers import recurrent as rnn_v1 diff --git a/keras/layers/wrappers_test.py b/keras/layers/wrappers_test.py index 6358c0b16..2f983ee28 100644 --- a/keras/layers/wrappers_test.py +++ b/keras/layers/wrappers_test.py @@ -26,6 +26,7 @@ from absl.testing import parameterized import numpy as np import keras +from tensorflow.python.framework import test_util as tf_test_util from keras import combinations from keras import keras_parameterized from keras import testing_utils @@ -33,8 +34,6 @@ from keras.engine import base_layer_utils from keras.layers import core from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper from keras.utils import generic_utils -from tensorflow.python.eager import context -from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.training.tracking import util as trackable_util @@ -653,7 +652,7 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase): model.compile(loss='mse', optimizer='sgd') model.fit(x, y, epochs=1, batch_size=1) - if context.executing_eagerly(): + if tf.executing_eagerly(): run_test() else: tf_test_util.enable_output_all_intermediates(run_test)() diff --git a/keras/losses.py b/keras/losses.py index 3a14291cf..768471ba5 100644 --- a/keras/losses.py +++ b/keras/losses.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Built-in loss functions. -""" +"""Built-in loss functions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -85,8 +84,8 @@ class Loss(object): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. """ losses_utils.ReductionV2.validate(reduction) @@ -115,15 +114,15 @@ class Loss(object): sparse loss functions such as sparse categorical crossentropy where shape = `[batch_size, d0, .. dN-1]` y_pred: The predicted values. shape = `[batch_size, d0, .. dN]` - sample_weight: Optional `sample_weight` acts as a - coefficient for the loss. If a scalar is provided, then the loss is - simply scaled by the given value. If `sample_weight` is a tensor of size - `[batch_size]`, then the total loss for each sample of the batch is - rescaled by the corresponding element in the `sample_weight` vector. If - the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be - broadcasted to this shape), then each loss element of `y_pred` is scaled + sample_weight: Optional `sample_weight` acts as a coefficient for the + loss. If a scalar is provided, then the loss is simply scaled by the + given value. If `sample_weight` is a tensor of size `[batch_size]`, then + the total loss for each sample of the batch is rescaled by the + corresponding element in the `sample_weight` vector. If the shape of + `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted to + this shape), then each loss element of `y_pred` is scaled by the corresponding value of `sample_weight`. (Note on`dN-1`: all loss - functions reduce by 1 dimension, usually axis=-1.) + functions reduce by 1 dimension, usually axis=-1.) Returns: Weighted loss float `Tensor`. If `reduction` is `NONE`, this has @@ -223,8 +222,8 @@ class LossFunctionWrapper(Loss): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: (Optional) name for the loss. **kwargs: The keyword arguments that are passed on to `fn`. """ @@ -243,8 +242,7 @@ class LossFunctionWrapper(Loss): Loss values per sample. """ if tf.is_tensor(y_pred) and tf.is_tensor(y_true): - y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( - y_pred, y_true) + y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true) ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx()) return ag_fn(y_true, y_pred, **self._fn_kwargs) @@ -307,8 +305,8 @@ class MeanSquaredError(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. 
+ https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'mean_squared_error'. """ super(MeanSquaredError, self).__init__( @@ -366,8 +364,8 @@ class MeanAbsoluteError(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'mean_absolute_error'. """ super(MeanAbsoluteError, self).__init__( @@ -426,8 +424,8 @@ class MeanAbsolutePercentageError(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'mean_absolute_percentage_error'. """ @@ -487,8 +485,8 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'mean_squared_logarithmic_error'. """ @@ -500,44 +498,64 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper): class BinaryCrossentropy(LossFunctionWrapper): """Computes the cross-entropy loss between true labels and predicted labels. - Use this cross-entropy loss when there are only two label classes (assumed to - be 0 and 1). For each example, there should be a single floating-point value - per prediction. + Use this cross-entropy loss for binary (0 or 1) classification applications. + The loss function requires the following inputs: - In the snippet below, each of the four examples has only a single - floating-pointing value, and both `y_pred` and `y_true` have the shape - `[batch_size]`. + - `y_true` (true label): This is either 0 or 1. + - `y_pred` (predicted value): This is the model's prediction, i.e, a single + floating-point value which either represents a + [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] + when `from_logits=True`) or a probability (i.e, value in [0., 1.] when + `from_logits=False`). - Standalone usage: + **Recommended Usage:** (set `from_logits=True`) - >>> y_true = [[0., 1.], [0., 0.]] - >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] - >>> # Using 'auto'/'sum_over_batch_size' reduction type. - >>> bce = tf.keras.losses.BinaryCrossentropy() - >>> bce(y_true, y_pred).numpy() - 0.815 - - >>> # Calling with 'sample_weight'. - >>> bce(y_true, y_pred, sample_weight=[1, 0]).numpy() - 0.458 - - >>> # Using 'sum' reduction type. - >>> bce = tf.keras.losses.BinaryCrossentropy( - ... reduction=tf.keras.losses.Reduction.SUM) - >>> bce(y_true, y_pred).numpy() - 1.630 - - >>> # Using 'none' reduction type. - >>> bce = tf.keras.losses.BinaryCrossentropy( - ... 
reduction=tf.keras.losses.Reduction.NONE) - >>> bce(y_true, y_pred).numpy() - array([0.916 , 0.714], dtype=float32) - - Usage with the `tf.keras` API: + With `tf.keras` API: ```python - model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy()) + model.compile( + loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), + .... + ) ``` + + As a standalone function: + + >>> # Example 1: (batch_size = 1, number of samples = 4) + >>> y_true = [0, 1, 0, 0] + >>> y_pred = [-18.6, 0.51, 2.94, -12.8] + >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) + >>> bce(y_true, y_pred).numpy() + 0.865 + + >>> # Example 2: (batch_size = 2, number of samples = 4) + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]] + >>> # Using default 'auto'/'sum_over_batch_size' reduction type. + >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) + >>> bce(y_true, y_pred).numpy() + 0.865 + >>> # Using 'sample_weight' attribute + >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() + 0.243 + >>> # Using 'sum' reduction` type. + >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, + ... reduction=tf.keras.losses.Reduction.SUM) + >>> bce(y_true, y_pred).numpy() + 1.730 + >>> # Using 'none' reduction type. + >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, + ... reduction=tf.keras.losses.Reduction.NONE) + >>> bce(y_true, y_pred).numpy() + array([0.235, 1.496], dtype=float32) + + **Default Usage:** (set `from_logits=False`) + + >>> # Make the following updates to the above "Recommended Usage" section + >>> # 1. Set `from_logits=False` + >>> tf.keras.losses.BinaryCrossentropy() # OR ...('from_logits=False') + >>> # 2. Update `y_pred` to use probabilities instead of logits + >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]] """ def __init__(self, @@ -563,8 +581,8 @@ class BinaryCrossentropy(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: (Optional) Name for the op. Defaults to 'binary_crossentropy'. """ super(BinaryCrossentropy, self).__init__( @@ -633,9 +651,9 @@ class CategoricalCrossentropy(LossFunctionWrapper): default, we assume that `y_pred` encodes a probability distribution. **Note - Using from_logits=True is more numerically stable.** label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, - meaning the confidence on label values are relaxed. e.g. - `label_smoothing=0.2` means that we will use a value of `0.1` for label - `0` and `0.9` for label `1`" + meaning the confidence on label values are relaxed. For example, if + `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss. Default value is `AUTO`. `AUTO` indicates that the reduction option will be determined by the usage context. For almost all cases @@ -643,8 +661,8 @@ class CategoricalCrossentropy(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. 
Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'categorical_crossentropy'. """ super(CategoricalCrossentropy, self).__init__( @@ -720,8 +738,8 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'sparse_categorical_crossentropy'. """ @@ -784,8 +802,8 @@ class Hinge(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'hinge'. """ super(Hinge, self).__init__(hinge, name=name, reduction=reduction) @@ -845,8 +863,8 @@ class SquaredHinge(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'squared_hinge'. """ super(SquaredHinge, self).__init__( @@ -905,8 +923,8 @@ class CategoricalHinge(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'categorical_hinge'. """ super(CategoricalHinge, self).__init__( @@ -962,8 +980,8 @@ class Poisson(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'poisson'. """ super(Poisson, self).__init__(poisson, name=name, reduction=reduction) @@ -1019,8 +1037,8 @@ class LogCosh(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. 
name: Optional name for the op. Defaults to 'log_cosh'. """ super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction) @@ -1079,8 +1097,8 @@ class KLDivergence(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'kl_divergence'. """ super(KLDivergence, self).__init__( @@ -1147,20 +1165,17 @@ class Huber(LossFunctionWrapper): `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this custom training [tutorial]( - https://www.tensorflow.org/tutorials/distribute/custom_training) - for more details. + https://www.tensorflow.org/tutorials/distribute/custom_training) for + more details. name: Optional name for the op. Defaults to 'huber_loss'. """ super(Huber, self).__init__( huber, name=name, reduction=reduction, delta=delta) -@keras_export('keras.metrics.mean_squared_error', - 'keras.metrics.mse', - 'keras.metrics.MSE', - 'keras.losses.mean_squared_error', - 'keras.losses.mse', - 'keras.losses.MSE') +@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse', + 'keras.metrics.MSE', 'keras.losses.mean_squared_error', + 'keras.losses.mse', 'keras.losses.MSE') @tf.__internal__.dispatch.add_dispatch_support def mean_squared_error(y_true, y_pred): """Computes the mean squared error between labels and predictions. @@ -1191,12 +1206,9 @@ def mean_squared_error(y_true, y_pred): return K.mean(tf.math.squared_difference(y_pred, y_true), axis=-1) -@keras_export('keras.metrics.mean_absolute_error', - 'keras.metrics.mae', - 'keras.metrics.MAE', - 'keras.losses.mean_absolute_error', - 'keras.losses.mae', - 'keras.losses.MAE') +@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae', + 'keras.metrics.MAE', 'keras.losses.mean_absolute_error', + 'keras.losses.mae', 'keras.losses.MAE') @tf.__internal__.dispatch.add_dispatch_support def mean_absolute_error(y_true, y_pred): """Computes the mean absolute error between labels and predictions. @@ -1225,11 +1237,9 @@ def mean_absolute_error(y_true, y_pred): @keras_export('keras.metrics.mean_absolute_percentage_error', - 'keras.metrics.mape', - 'keras.metrics.MAPE', + 'keras.metrics.mape', 'keras.metrics.MAPE', 'keras.losses.mean_absolute_percentage_error', - 'keras.losses.mape', - 'keras.losses.MAPE') + 'keras.losses.mape', 'keras.losses.MAPE') @tf.__internal__.dispatch.add_dispatch_support def mean_absolute_percentage_error(y_true, y_pred): """Computes the mean absolute percentage error between `y_true` and `y_pred`. @@ -1262,11 +1272,9 @@ def mean_absolute_percentage_error(y_true, y_pred): @keras_export('keras.metrics.mean_squared_logarithmic_error', - 'keras.metrics.msle', - 'keras.metrics.MSLE', + 'keras.metrics.msle', 'keras.metrics.MSLE', 'keras.losses.mean_squared_logarithmic_error', - 'keras.losses.msle', - 'keras.losses.MSLE') + 'keras.losses.msle', 'keras.losses.MSLE') @tf.__internal__.dispatch.add_dispatch_support def mean_squared_logarithmic_error(y_true, y_pred): """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 
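The updated `label_smoothing` descriptions spell out the exact smoothing rule (with `label_smoothing=0.1` and one-hot targets, the target class becomes `0.9 + 0.1 / num_classes` and the others `0.1 / num_classes`). A short worked check, assuming that rule, comparing the built-in argument against smoothing the labels by hand:

```python
import tensorflow as tf

y_true = tf.constant([[0., 1., 0.]])         # one-hot targets, num_classes = 3
y_pred = tf.constant([[0.1, 0.8, 0.1]])
ls = 0.1

# Smooth the labels by hand: 0.9 + 0.1/3 for the target, 0.1/3 elsewhere.
smoothed = y_true * (1.0 - ls) + ls / 3.0

with_arg = tf.keras.losses.categorical_crossentropy(
    y_true, y_pred, label_smoothing=ls)
by_hand = tf.keras.losses.categorical_crossentropy(smoothed, y_pred)

tf.debugging.assert_near(with_arg, by_hand)  # both paths give the same loss
```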
@@ -1511,7 +1519,9 @@ def categorical_crossentropy(y_true, y_pred: Tensor of predicted targets. from_logits: Whether `y_pred` is expected to be a logits tensor. By default, we assume that `y_pred` encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels + and `0.9 + 0.1 / num_classes` for target labels. Returns: Categorical crossentropy loss value. @@ -1582,7 +1592,9 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. from_logits: Whether `y_pred` is expected to be a logits tensor. By default, we assume that `y_pred` encodes a probability distribution. - label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by + squeezing them towards 0.5 That is, using `1. - 0.5 * label_smoothing` + for the target class and `0.5 * label_smoothing` for the non-target class. Returns: Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`. @@ -1602,12 +1614,9 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): @keras_export('keras.metrics.kl_divergence', - 'keras.metrics.kullback_leibler_divergence', - 'keras.metrics.kld', - 'keras.metrics.KLD', - 'keras.losses.kl_divergence', - 'keras.losses.kullback_leibler_divergence', - 'keras.losses.kld', + 'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld', + 'keras.metrics.KLD', 'keras.losses.kl_divergence', + 'keras.losses.kullback_leibler_divergence', 'keras.losses.kld', 'keras.losses.KLD') @tf.__internal__.dispatch.add_dispatch_support def kl_divergence(y_true, y_pred): diff --git a/keras/mixed_precision/autocast_variable.py b/keras/mixed_precision/autocast_variable.py index e9edd1c41..aa5ed0223 100644 --- a/keras/mixed_precision/autocast_variable.py +++ b/keras/mixed_precision/autocast_variable.py @@ -69,12 +69,11 @@ class AutoCastVariable(tf.Variable, core.Tensor): called. """ - def __init__(self, variable, op=None): + def __init__(self, variable): """Creates an AutoCastVariable instance. Args: variable: A floating-point resource variable to wrap. - op: Optional operation of this variable. Raises: ValueError: If `variable` is not a floating-point resource variable @@ -86,7 +85,11 @@ class AutoCastVariable(tf.Variable, core.Tensor): raise ValueError('variable must be a floating point variable but has ' 'type: %s' % variable.dtype.name) self._variable = variable - self._op = op + # 'delegate' means AutoCastVariable.op return self._variable.op, which will + # raise an AttributeError in Eager (as intended). If set to any other value, + # AutoCastVariable.op returns that value instead, which is used to set the + # op attribute in AutoCastVariable.assign(). + self._op = 'delegate' def _should_cast(self): """Returns True if this variable should be casted when accessed.""" @@ -211,10 +214,18 @@ class AutoCastVariable(tf.Variable, core.Tensor): use_locking=None, name=None, read_value=True): + # TODO(b/146181571): This logic can be simplified once + # DistributedVariable.assign returns a DistributedVariable. Currently for + # MirroredStrategy, it returns a Mirrored value. 
if tf.compat.v1.executing_eagerly_outside_functions(): assign_op = update_fn(value, use_locking, name, False) if read_value: - return create_autocast_variable(self._variable, op=assign_op) + # We create a new AutoCastVariable with the same underlying tf.Variable. + # The new AutoCastVariable is identical except the 'op' attribute is + # defined. This matches the behavior of tf.Variable.assign. + var = create_autocast_variable(self._variable) + var._op = assign_op # pylint:disable=protected-access + return var return assign_op # Fallback to wrapping the returned variable in graph mode if possible @@ -310,9 +321,9 @@ class AutoCastVariable(tf.Variable, core.Tensor): @property def op(self): - if self._op is not None: - return self._op - return self._variable.op + if self._op == 'delegate': + return self._variable.op + return self._op def _as_graph_element(self): graph_element = self._variable._as_graph_element() # pylint:disable=protected-access @@ -481,7 +492,7 @@ tf.register_tensor_conversion_function(AutoCastVariable, AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access -def create_autocast_variable(variable, op=None): +def create_autocast_variable(variable): """Creates an AutoCastVariable that wraps another variable. This typically just returns `AutoCastVariable(variable)`. But, if the variable @@ -493,14 +504,13 @@ def create_autocast_variable(variable, op=None): Args: variable: A floating-point resource variable to wrap. - op: Optional operation of this variable. Returns: An AutoCastVariable that wraps the variable. """ if not isinstance(variable, (distribute_values.DistributedVariable, ps_distribute_values.AggregatingVariable)): - return AutoCastVariable(variable, op=op) + return AutoCastVariable(variable) class AutoCastDistributedVariable(AutoCastVariable, variable.__class__): """An AutoCastVariable that also subclasses from variable.__class__. 
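The `'delegate'` sentinel introduced above is easy to miss in the diff. The following stripped-down class (hypothetical, not the Keras implementation) isolates the design choice: `op` forwards to the wrapped variable unless an `assign` result has pinned a concrete op onto a fresh wrapper, matching plain `tf.Variable` behavior in both eager and graph modes:

```python
class _OpDelegatingWrapper(object):
  """Hypothetical sketch of the sentinel pattern, not the Keras class."""

  def __init__(self, inner):
    self._inner = inner
    self._op = 'delegate'        # sentinel: forward `.op` to the wrapped variable

  @property
  def op(self):
    if self._op == 'delegate':
      return self._inner.op      # raises AttributeError in eager mode, as intended
    return self._op              # an assign() result pinned a concrete op here

  def assign(self, value):
    assign_op = self._inner.assign(value)
    new = _OpDelegatingWrapper(self._inner)  # same underlying variable...
    new._op = assign_op                      # ...but with `.op` defined
    return new
```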
@@ -523,7 +533,7 @@ def create_autocast_variable(variable, op=None): ).format(v=self) # pylint: enable=missing-format-attribute - return AutoCastDistributedVariable(variable, op=op) + return AutoCastDistributedVariable(variable) class enable_auto_cast_variables(object): # pylint:disable=invalid-name diff --git a/keras/mixed_precision/autocast_variable_test.py b/keras/mixed_precision/autocast_variable_test.py index ec89bea49..8c309e975 100644 --- a/keras/mixed_precision/autocast_variable_test.py +++ b/keras/mixed_precision/autocast_variable_test.py @@ -26,7 +26,14 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.distribute import test_util from keras.mixed_precision import autocast_variable +from keras.optimizer_v2 import adadelta +from keras.optimizer_v2 import adagrad +from keras.optimizer_v2 import adam +from keras.optimizer_v2 import adamax +from keras.optimizer_v2 import ftrl from keras.optimizer_v2 import gradient_descent as gradient_descent_v2 +from keras.optimizer_v2 import nadam +from keras.optimizer_v2 import rmsprop maybe_distribute = tf.__internal__.test.combinations.combine(distribution=[ tf.__internal__.distribute.combinations.default_strategy, @@ -335,11 +342,28 @@ class AutoCastVariableTest(tf.test.TestCase, parameterized.TestCase): self.assertAllClose(5., self.evaluate(run_assign())) @tf.__internal__.distribute.combinations.generate(maybe_distribute) - def test_assign_op(self, distribution): + def test_op_attribute(self, distribution): with distribution.scope(): x = get_var(0., tf.float32) x = autocast_variable.create_autocast_variable(x) + # Variable.op raises an AttributeError in Eager mode and is an op in graph + # mode. Variable.assign(...).op is None in Eager mode and an op in Graph + # mode or a tf.function. We test this is also true of AutoCastVariable. + if tf.executing_eagerly(): + with self.assertRaisesRegex( + AttributeError, + 'Tensor.op is meaningless when eager execution is enabled'): + x.op # pylint: disable=pointless-statement + self.assertIsNone(x.assign(1.0).op) + self.assertIsNone(x.assign_add(1.0).op) + self.assertIsNone(x.assign_sub(1.0).op) + else: + self.assertIsNotNone(x.op) + self.assertIsNotNone(x.assign(1.0).op) + self.assertIsNotNone(x.assign_add(1.0).op) + self.assertIsNotNone(x.assign_sub(1.0).op) + @tf.function def func(): self.assertIsNotNone(x.assign(1.0).op) @@ -486,25 +510,51 @@ class AutoCastVariableTest(tf.test.TestCase, parameterized.TestCase): 'dtype_to_cast_to=float32 ' 'inner_variable=MirroredVariable.*>') - @parameterized.named_parameters( - ('v1', tf.compat.v1.train.GradientDescentOptimizer), - ('v2', gradient_descent_v2.SGD)) - def test_optimizer(self, optimizer_class): + @tf.__internal__.distribute.combinations.generate(tf.__internal__.test.combinations.combine( + optimizer_class=[ + adadelta.Adadelta, + adagrad.Adagrad, + adam.Adam, + adamax.Adamax, + ftrl.Ftrl, + gradient_descent_v2.SGD, + nadam.Nadam, + rmsprop.RMSprop, + tf.compat.v1.train.GradientDescentOptimizer + ], + use_tf_function=[False, True])) + def test_optimizer(self, optimizer_class, use_tf_function): + if use_tf_function and not tf.executing_eagerly(): + self.skipTest('Test does not support graph mode with tf.function') x = get_var(1., tf.float32) x = autocast_variable.create_autocast_variable(x) - opt = optimizer_class(1.) + y = get_var(1., tf.float32) + opt = optimizer_class(learning_rate=1.) - @tf.function def f(): - opt.minimize(lambda: x + 1., var_list=[x]) + # Minimize both the AutoCastVariable and the normal tf.Variable. 
Both + # variables should be updated to the same value. + op = opt.minimize(lambda: x + y, var_list=[x, y]) + return None if tf.compat.v1.executing_eagerly_outside_functions() else op + + if use_tf_function: + f = tf.function(f) if tf.executing_eagerly(): f() else: - op = f() # pylint: disable=assignment-from-no-return + op = f() self.evaluate(tf.compat.v1.global_variables_initializer()) self.evaluate(op) - self.assertEqual(self.evaluate(x), 0) + # Assert the AutoCastVariable has changed from its initial value + self.assertNotEqual(self.evaluate(x), 1.) + # Assert AutoCastVariable is updated correctly by comparing it to the normal + # variable + self.assertAlmostEqual(self.evaluate(x), self.evaluate(y)) + if optimizer_class in (gradient_descent_v2.SGD, + tf.compat.v1.train.GradientDescentOptimizer): + # With SGD, the variables decreases by exactly 1 + self.assertEqual(self.evaluate(x), 0) if __name__ == '__main__': diff --git a/keras/saving/saved_model/load.py b/keras/saving/saved_model/load.py index b944fe8ae..73a38e0a9 100644 --- a/keras/saving/saved_model/load.py +++ b/keras/saving/saved_model/load.py @@ -139,7 +139,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin # Recreate layers and metrics using the info stored in the metadata. keras_loader = KerasObjectLoader(metadata, object_graph_def) - keras_loader.load_layers() + keras_loader.load_layers(compile=compile) # Generate a dictionary of all loaded nodes. nodes_to_load = {'root': None} @@ -364,7 +364,7 @@ class KerasObjectLoader(object): obj_child, child_proto, child_id) self.loaded_nodes[child_id] = obj_child, setter - def load_layers(self): + def load_layers(self, compile=True): # pylint: disable=redefined-builtin """Load all layer nodes from the metadata.""" # Load metrics after models and layers, since it's likely that models # and layers will create the metric when initialized (this avoids wasting @@ -380,9 +380,20 @@ class KerasObjectLoader(object): node_metadata.metadata) for node_metadata in metric_list: - self.loaded_nodes[node_metadata.node_id] = self._load_layer( - node_metadata.node_id, node_metadata.identifier, - node_metadata.metadata) + try: + self.loaded_nodes[node_metadata.node_id] = self._load_layer( + node_metadata.node_id, node_metadata.identifier, + node_metadata.metadata) + except ValueError: + # Metrics are only needed when the model is compiled later. We ignore + # errors when trying to load custom metrics when `compile=False` until + # custom metrics are serialized properly (b/135550038). + if compile: + raise + logging.warning('Unable to restore custom metric. Please ensure that ' + 'the layer implements `get_config` and `from_config` ' + 'when saving. 
In addition, please use the ' + '`custom_objects` arg when calling `load_model()`.') def _load_layer(self, node_id, identifier, metadata): """Load a single layer from a SavedUserObject proto.""" diff --git a/keras/saving/saved_model/saved_model_test.py b/keras/saving/saved_model/saved_model_test.py index fe6864792..cb7f59118 100644 --- a/keras/saving/saved_model/saved_model_test.py +++ b/keras/saving/saved_model/saved_model_test.py @@ -1142,6 +1142,22 @@ class MetricTest(tf.test.TestCase, parameterized.TestCase): self._test_metric_save_and_load( metric, self._save_model_dir(), 1, test_sample_weight=False) + @keras_parameterized.run_with_all_model_types + def test_custom_metric_model(self): + + class CustomMetric(keras.metrics.MeanSquaredError): + pass + + model = testing_utils.get_small_mlp(1, 4, input_dim=3) + model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()]) + + saved_model_dir = self._save_model_dir() + tf.saved_model.save(model, saved_model_dir) + with self.assertRaisesRegex(ValueError, 'custom_objects'): + keras_load.load(saved_model_dir) + + keras_load.load(saved_model_dir, compile=False) + if __name__ == '__main__': tf.test.main()
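At the public API level, the new `compile` plumbing in `load_layers` and the accompanying test translate to roughly the following workflow (a hedged sketch; the save path and metric name are hypothetical):

```python
import tensorflow as tf

class CustomMSE(tf.keras.metrics.MeanSquaredError):
  pass

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMSE()])
model.save('/tmp/custom_metric_model')   # SavedModel format

# Restoring the compiled model without `custom_objects` may fail on the metric...
try:
  tf.keras.models.load_model('/tmp/custom_metric_model')
except ValueError as err:
  print('compile=True load failed:', err)

# ...but the architecture and weights still load once compilation is skipped.
restored = tf.keras.models.load_model('/tmp/custom_metric_model', compile=False)
print(restored.predict(tf.zeros((1, 3))).shape)   # (1, 1)
```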