Add mobilenet and efficientnet models.

This commit is contained in:
Francois Chollet 2023-05-22 16:39:24 -07:00
parent c3810a76dd
commit 5b7e5f88b2
16 changed files with 2289 additions and 162 deletions

@ -42,6 +42,8 @@ ALL_OBJECTS = {
}
ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS}
# Additional aliases
ALL_OBJECTS_DICT["swish"] = silu
@keras_core_export("keras_core.activations.serialize")
@ -89,8 +91,8 @@ def get(identifier):
if identifier is None:
return linear
if isinstance(identifier, (str, dict)):
return deserialize(identifier)
elif callable(identifier):
identifier = deserialize(identifier)
if callable(identifier):
return identifier
raise TypeError(
f"Could not interpret activation function identifier: {identifier}"

@ -4,6 +4,8 @@ from absl.testing import parameterized
from keras_core import backend
from keras_core import testing
from keras_core.applications import efficientnet
from keras_core.applications import efficientnet_v2
from keras_core.applications import mobilenet
from keras_core.applications import mobilenet_v2
from keras_core.applications import mobilenet_v3
@ -19,14 +21,32 @@ except ImportError:
PIL = None
MODEL_LIST = [
# cls, last_dim
# vgg
(vgg16.VGG16, 512, vgg16),
(vgg19.VGG19, 512, vgg19),
# xception
(xception.Xception, 2048, xception),
# mobilnet
(mobilenet.MobileNet, 1024, mobilenet),
(mobilenet_v2.MobileNetV2, 1280, mobilenet_v2),
(mobilenet_v3.MobileNetV3Small, 576, mobilenet_v3),
(mobilenet_v3.MobileNetV3Large, 960, mobilenet_v3),
# efficientnet
(efficientnet.EfficientNetB0, 1280, efficientnet),
(efficientnet.EfficientNetB1, 1280, efficientnet),
(efficientnet.EfficientNetB2, 1408, efficientnet),
(efficientnet.EfficientNetB3, 1536, efficientnet),
(efficientnet.EfficientNetB4, 1792, efficientnet),
(efficientnet.EfficientNetB5, 2048, efficientnet),
(efficientnet.EfficientNetB6, 2304, efficientnet),
(efficientnet.EfficientNetB7, 2560, efficientnet),
(efficientnet_v2.EfficientNetV2B0, 1280, efficientnet_v2),
(efficientnet_v2.EfficientNetV2B1, 1280, efficientnet_v2),
(efficientnet_v2.EfficientNetV2B2, 1408, efficientnet_v2),
(efficientnet_v2.EfficientNetV2B3, 1536, efficientnet_v2),
(efficientnet_v2.EfficientNetV2S, 1280, efficientnet_v2),
(efficientnet_v2.EfficientNetV2M, 1280, efficientnet_v2),
(efficientnet_v2.EfficientNetV2L, 1280, efficientnet_v2),
]
# Add names for `named_parameters`.
MODEL_LIST = [(e[0].__name__, *e) for e in MODEL_LIST]

@ -0,0 +1,855 @@
import copy
import math
from tensorflow.io import gfile
from keras_core import backend
from keras_core import layers
from keras_core.api_export import keras_core_export
from keras_core.applications import imagenet_utils
from keras_core.models import Functional
from keras_core.operations import operation_utils
from keras_core.utils import file_utils
BASE_WEIGHTS_PATH = "https://storage.googleapis.com/keras-applications/"
WEIGHTS_HASHES = {
"b0": (
"902e53a9f72be733fc0bcb005b3ebbac",
"50bc09e76180e00e4465e1a485ddc09d",
),
"b1": (
"1d254153d4ab51201f1646940f018540",
"74c4e6b3e1f6a1eea24c589628592432",
),
"b2": (
"b15cce36ff4dcbd00b6dd88e7857a6ad",
"111f8e2ac8aa800a7a99e3239f7bfb39",
),
"b3": (
"ffd1fdc53d0ce67064dc6a9c7960ede0",
"af6d107764bb5b1abb91932881670226",
),
"b4": (
"18c95ad55216b8f92d7e70b3a046e2fc",
"ebc24e6d6c33eaebbd558eafbeedf1ba",
),
"b5": (
"ace28f2a6363774853a83a0b21b9421a",
"38879255a25d3c92d5e44e04ae6cec6f",
),
"b6": (
"165f6e37dce68623721b423839de8be5",
"9ecce42647a20130c1f39a5d4cb75743",
),
"b7": (
"8c03f828fec3ef71311cd463b6759d99",
"cbcfe4450ddf6f3ad90b1b398090fe4a",
),
}
DEFAULT_BLOCKS_ARGS = [
{
"kernel_size": 3,
"repeats": 1,
"filters_in": 32,
"filters_out": 16,
"expand_ratio": 1,
"id_skip": True,
"strides": 1,
"se_ratio": 0.25,
},
{
"kernel_size": 3,
"repeats": 2,
"filters_in": 16,
"filters_out": 24,
"expand_ratio": 6,
"id_skip": True,
"strides": 2,
"se_ratio": 0.25,
},
{
"kernel_size": 5,
"repeats": 2,
"filters_in": 24,
"filters_out": 40,
"expand_ratio": 6,
"id_skip": True,
"strides": 2,
"se_ratio": 0.25,
},
{
"kernel_size": 3,
"repeats": 3,
"filters_in": 40,
"filters_out": 80,
"expand_ratio": 6,
"id_skip": True,
"strides": 2,
"se_ratio": 0.25,
},
{
"kernel_size": 5,
"repeats": 3,
"filters_in": 80,
"filters_out": 112,
"expand_ratio": 6,
"id_skip": True,
"strides": 1,
"se_ratio": 0.25,
},
{
"kernel_size": 5,
"repeats": 4,
"filters_in": 112,
"filters_out": 192,
"expand_ratio": 6,
"id_skip": True,
"strides": 2,
"se_ratio": 0.25,
},
{
"kernel_size": 3,
"repeats": 1,
"filters_in": 192,
"filters_out": 320,
"expand_ratio": 6,
"id_skip": True,
"strides": 1,
"se_ratio": 0.25,
},
]
CONV_KERNEL_INITIALIZER = {
"class_name": "VarianceScaling",
"config": {
"scale": 2.0,
"mode": "fan_out",
"distribution": "truncated_normal",
},
}
DENSE_KERNEL_INITIALIZER = {
"class_name": "VarianceScaling",
"config": {
"scale": 1.0 / 3.0,
"mode": "fan_out",
"distribution": "uniform",
},
}
BASE_DOCSTRING = """Instantiates the {name} architecture.
Reference:
- [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
https://arxiv.org/abs/1905.11946) (ICML 2019)
This function returns a Keras image classification model,
optionally loaded with weights pre-trained on ImageNet.
For image classification use cases, see
[this page for detailed examples](
https://keras.io/api/applications/#usage-examples-for-image-classification-models).
For transfer learning use cases, make sure to read the
[guide to transfer learning & fine-tuning](
https://keras.io/guides/transfer_learning/).
Note: each Keras Application expects a specific kind of input preprocessing.
For EfficientNet, input preprocessing is included as part of the model
(as a `Rescaling` layer), and thus
`keras_core.applications.efficientnet.preprocess_input` is actually a
pass-through function. EfficientNet models expect their inputs to be float
tensors of pixels with values in the `[0-255]` range.
Args:
include_top: Whether to include the fully-connected
layer at the top of the network. Defaults to `True`.
weights: One of `None` (random initialization),
`"imagenet"` (pre-training on ImageNet),
or the path to the weights file to be loaded.
Defaults to `"imagenet"`.
input_tensor: Optional Keras tensor
(i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: Optional shape tuple, only to be specified
if `include_top` is False.
It should have exactly 3 inputs channels.
pooling: Optional pooling mode for feature extraction
when `include_top` is `False`. Defaults to `None`.
- `None` means that the output of the model will be
the 4D tensor output of the
last convolutional layer.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional layer, and thus
the output of the model will be a 2D tensor.
- `max` means that global max pooling will
be applied.
classes: Optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified. 1000 is how many
ImageNet classes there are. Defaults to `1000`.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Defaults to 'softmax'.
When loading pretrained weights, `classifier_activation` can only
be `None` or `"softmax"`.
Returns:
A model instance.
"""
IMAGENET_STDDEV_RGB = [0.229, 0.224, 0.225]
def EfficientNet(
width_coefficient,
depth_coefficient,
default_size,
dropout_rate=0.2,
drop_connect_rate=0.2,
depth_divisor=8,
activation="swish",
blocks_args="default",
model_name="efficientnet",
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
):
"""Instantiates the EfficientNet architecture.
Args:
width_coefficient: float, scaling coefficient for network width.
depth_coefficient: float, scaling coefficient for network depth.
default_size: integer, default input image size.
dropout_rate: float, dropout rate before final classifier layer.
drop_connect_rate: float, dropout rate at skip connections.
depth_divisor: integer, a unit of network width.
activation: activation function.
blocks_args: list of dicts, parameters to construct block modules.
model_name: string, model name.
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
'imagenet' (pre-training on ImageNet),
or the path to the weights file to be loaded.
input_tensor: optional Keras tensor
(i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False.
It should have exactly 3 inputs channels.
pooling: optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be
the 4D tensor output of the
last convolutional layer.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional layer, and thus
the output of the model will be a 2D tensor.
- `max` means that global max pooling will
be applied.
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A model instance.
"""
if blocks_args == "default":
blocks_args = DEFAULT_BLOCKS_ARGS
if not (weights in {"imagenet", None} or gfile.exists(weights)):
raise ValueError(
"The `weights` argument should be either "
"`None` (random initialization), `imagenet` "
"(pre-training on ImageNet), "
"or the path to the weights file to be loaded."
)
if weights == "imagenet" and include_top and classes != 1000:
raise ValueError(
'If using `weights="imagenet"` with `include_top`'
" as true, `classes` should be 1000"
)
# Determine proper input shape
input_shape = imagenet_utils.obtain_input_shape(
input_shape,
default_size=default_size,
min_size=32,
data_format=backend.image_data_format(),
require_flatten=include_top,
weights=weights,
)
if input_tensor is None:
img_input = layers.Input(shape=input_shape)
else:
if not backend.is_keras_tensor(input_tensor):
img_input = layers.Input(tensor=input_tensor, shape=input_shape)
else:
img_input = input_tensor
bn_axis = 3 if backend.image_data_format() == "channels_last" else 1
def round_filters(filters, divisor=depth_divisor):
"""Round number of filters based on depth multiplier."""
filters *= width_coefficient
new_filters = max(
divisor, int(filters + divisor / 2) // divisor * divisor
)
# Make sure that round down does not go down by more than 10%.
if new_filters < 0.9 * filters:
new_filters += divisor
return int(new_filters)
def round_repeats(repeats):
"""Round number of repeats based on depth multiplier."""
return int(math.ceil(depth_coefficient * repeats))
# Build stem
x = img_input
x = layers.Rescaling(1.0 / 255.0)(x)
x = layers.Normalization(axis=bn_axis)(x)
if weights == "imagenet":
# Note that the normaliztion layer uses square value of STDDEV as the
# variance for the layer: result = (input - mean) / sqrt(var)
# However, the original implemenetation uses (input - mean) / var to
# normalize the input, we need to divide another sqrt(var) to match the
# original implementation.
# See https://github.com/tensorflow/tensorflow/issues/49930 for more
# details
x = layers.Rescaling(
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB]
)(x)
x = layers.ZeroPadding2D(
padding=imagenet_utils.correct_pad(x, 3), name="stem_conv_pad"
)(x)
x = layers.Conv2D(
round_filters(32),
3,
strides=2,
padding="valid",
use_bias=False,
kernel_initializer=CONV_KERNEL_INITIALIZER,
name="stem_conv",
)(x)
x = layers.BatchNormalization(axis=bn_axis, name="stem_bn")(x)
x = layers.Activation(activation, name="stem_activation")(x)
# Build blocks
blocks_args = copy.deepcopy(blocks_args)
b = 0
blocks = float(sum(round_repeats(args["repeats"]) for args in blocks_args))
for i, args in enumerate(blocks_args):
assert args["repeats"] > 0
# Update block input and output filters based on depth multiplier.
args["filters_in"] = round_filters(args["filters_in"])
args["filters_out"] = round_filters(args["filters_out"])
for j in range(round_repeats(args.pop("repeats"))):
# The first block needs to take care of stride and filter size
# increase.
if j > 0:
args["strides"] = 1
args["filters_in"] = args["filters_out"]
x = block(
x,
activation,
drop_connect_rate * b / blocks,
name=f"block{i + 1}{chr(j + 97)}_",
**args,
)
b += 1
# Build top
x = layers.Conv2D(
round_filters(1280),
1,
padding="same",
use_bias=False,
kernel_initializer=CONV_KERNEL_INITIALIZER,
name="top_conv",
)(x)
x = layers.BatchNormalization(axis=bn_axis, name="top_bn")(x)
x = layers.Activation(activation, name="top_activation")(x)
if include_top:
x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
if dropout_rate > 0:
x = layers.Dropout(dropout_rate, name="top_dropout")(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(
classes,
activation=classifier_activation,
kernel_initializer=DENSE_KERNEL_INITIALIZER,
name="predictions",
)(x)
else:
if pooling == "avg":
x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
elif pooling == "max":
x = layers.GlobalMaxPooling2D(name="max_pool")(x)
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = operation_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
model = Functional(inputs, x, name=model_name)
# Load weights.
if weights == "imagenet":
if include_top:
file_suffix = ".h5"
file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
else:
file_suffix = "_notop.h5"
file_hash = WEIGHTS_HASHES[model_name[-2:]][1]
file_name = model_name + file_suffix
weights_path = file_utils.get_file(
file_name,
BASE_WEIGHTS_PATH + file_name,
cache_subdir="models",
file_hash=file_hash,
)
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)
return model
def block(
inputs,
activation="swish",
drop_rate=0.0,
name="",
filters_in=32,
filters_out=16,
kernel_size=3,
strides=1,
expand_ratio=1,
se_ratio=0.0,
id_skip=True,
):
"""An inverted residual block.
Args:
inputs: input tensor.
activation: activation function.
drop_rate: float between 0 and 1, fraction of the input units to drop.
name: string, block label.
filters_in: integer, the number of input filters.
filters_out: integer, the number of output filters.
kernel_size: integer, the dimension of the convolution window.
strides: integer, the stride of the convolution.
expand_ratio: integer, scaling coefficient for the input filters.
se_ratio: float between 0 and 1, fraction to squeeze the input filters.
id_skip: boolean.
Returns:
output tensor for the block.
"""
bn_axis = 3 if backend.image_data_format() == "channels_last" else 1
# Expansion phase
filters = filters_in * expand_ratio
if expand_ratio != 1:
x = layers.Conv2D(
filters,
1,
padding="same",
use_bias=False,
kernel_initializer=CONV_KERNEL_INITIALIZER,
name=name + "expand_conv",
)(inputs)
x = layers.BatchNormalization(axis=bn_axis, name=name + "expand_bn")(x)
x = layers.Activation(activation, name=name + "expand_activation")(x)
else:
x = inputs
# Depthwise Convolution
if strides == 2:
x = layers.ZeroPadding2D(
padding=imagenet_utils.correct_pad(x, kernel_size),
name=name + "dwconv_pad",
)(x)
conv_pad = "valid"
else:
conv_pad = "same"
x = layers.DepthwiseConv2D(
kernel_size,
strides=strides,
padding=conv_pad,
use_bias=False,
depthwise_initializer=CONV_KERNEL_INITIALIZER,
name=name + "dwconv",
)(x)
x = layers.BatchNormalization(axis=bn_axis, name=name + "bn")(x)
x = layers.Activation(activation, name=name + "activation")(x)
# Squeeze and Excitation phase
if 0 < se_ratio <= 1:
filters_se = max(1, int(filters_in * se_ratio))
se = layers.GlobalAveragePooling2D(name=name + "se_squeeze")(x)
if bn_axis == 1:
se_shape = (filters, 1, 1)
else:
se_shape = (1, 1, filters)
se = layers.Reshape(se_shape, name=name + "se_reshape")(se)
se = layers.Conv2D(
filters_se,
1,
padding="same",
activation=activation,
kernel_initializer=CONV_KERNEL_INITIALIZER,
name=name + "se_reduce",
)(se)
se = layers.Conv2D(
filters,
1,
padding="same",
activation="sigmoid",
kernel_initializer=CONV_KERNEL_INITIALIZER,
name=name + "se_expand",
)(se)
x = layers.multiply([x, se], name=name + "se_excite")
# Output phase
x = layers.Conv2D(
filters_out,
1,
padding="same",
use_bias=False,
kernel_initializer=CONV_KERNEL_INITIALIZER,
name=name + "project_conv",
)(x)
x = layers.BatchNormalization(axis=bn_axis, name=name + "project_bn")(x)
if id_skip and strides == 1 and filters_in == filters_out:
if drop_rate > 0:
x = layers.Dropout(
drop_rate, noise_shape=(None, 1, 1, 1), name=name + "drop"
)(x)
x = layers.add([x, inputs], name=name + "add")
return x
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB0",
"keras_core.applications.EfficientNetB0",
]
)
def EfficientNetB0(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.0,
1.0,
224,
0.2,
model_name="efficientnetb0",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB1",
"keras_core.applications.EfficientNetB1",
]
)
def EfficientNetB1(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.0,
1.1,
240,
0.2,
model_name="efficientnetb1",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB2",
"keras_core.applications.EfficientNetB2",
]
)
def EfficientNetB2(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.1,
1.2,
260,
0.3,
model_name="efficientnetb2",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB3",
"keras_core.applications.EfficientNetB3",
]
)
def EfficientNetB3(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.2,
1.4,
300,
0.3,
model_name="efficientnetb3",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB4",
"keras_core.applications.EfficientNetB4",
]
)
def EfficientNetB4(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.4,
1.8,
380,
0.4,
model_name="efficientnetb4",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB5",
"keras_core.applications.EfficientNetB5",
]
)
def EfficientNetB5(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.6,
2.2,
456,
0.4,
model_name="efficientnetb5",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB6",
"keras_core.applications.EfficientNetB6",
]
)
def EfficientNetB6(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
1.8,
2.6,
528,
0.5,
model_name="efficientnetb6",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
@keras_core_export(
[
"keras_core.applications.efficientnet.EfficientNetB7",
"keras_core.applications.EfficientNetB7",
]
)
def EfficientNetB7(
include_top=True,
weights="imagenet",
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation="softmax",
**kwargs,
):
return EfficientNet(
2.0,
3.1,
600,
0.5,
model_name="efficientnetb7",
include_top=include_top,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation,
**kwargs,
)
EfficientNetB0.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB0")
EfficientNetB1.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB1")
EfficientNetB2.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB2")
EfficientNetB3.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB3")
EfficientNetB4.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB4")
EfficientNetB5.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB5")
EfficientNetB6.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB6")
EfficientNetB7.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB7")
@keras_core_export("keras_core.applications.efficientnet.preprocess_input")
def preprocess_input(x, data_format=None):
"""A placeholder method for backward compatibility.
The preprocessing logic has been included in the efficientnet model
implementation. Users are no longer required to call this method to
normalize the input data. This method does nothing and only kept as a
placeholder to align the API surface between old and new version of model.
Args:
x: A floating point `numpy.array` or a tensor.
data_format: Optional data format of the image tensor/array. `None`
means the global setting `keras_core.backend.image_data_format()`
is used (unless you changed it, it uses `"channels_last"`).
Defaults to `None`.
Returns:
Unchanged `numpy.array` or tensor.
"""
return x
@keras_core_export("keras_core.applications.efficientnet.decode_predictions")
def decode_predictions(preds, top=5):
return imagenet_utils.decode_predictions(preds, top=top)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

File diff suppressed because it is too large Load Diff

@ -120,7 +120,7 @@ def MobileNetV2(
if weights == "imagenet" and include_top and classes != 1000:
raise ValueError(
'If using `weights` as `"imagenet"` with `include_top` '
'If using `weights="imagenet"` with `include_top` '
f"as true, `classes` should be 1000. Received `classes={classes}`"
)

@ -109,7 +109,7 @@ Args:
include_top: Boolean, whether to include the fully-connected
layer at the top of the network. Defaults to `True`.
weights: String, one of `None` (random initialization),
'imagenet' (pre-training on ImageNet),
`"imagenet"` (pre-training on ImageNet),
or the path to the weights file to be loaded.
input_tensor: Optional Keras tensor (i.e. output of
`layers.Input()`)
@ -127,7 +127,7 @@ Args:
- `max` means that global max pooling will
be applied.
classes: Integer, optional number of classes to classify images
into, only to be specified if `include_top` is True, and
into, only to be specified if `include_top` is `True`, and
if no `weights` argument is specified.
dropout_rate: fraction of the input units to drop on the last layer.
classifier_activation: A `str` or callable. The activation function to use
@ -176,7 +176,7 @@ def MobileNetV3(
if weights == "imagenet" and include_top and classes != 1000:
raise ValueError(
'If using `weights` as `"imagenet"` with `include_top` '
'If using `weights="imagenet"` with `include_top` '
"as true, `classes` should be 1000. "
f"Received classes={classes}"
)

@ -111,7 +111,6 @@ from keras_core.layers.reshaping.reshape import Reshape
from keras_core.layers.reshaping.up_sampling1d import UpSampling1D
from keras_core.layers.reshaping.up_sampling2d import UpSampling2D
from keras_core.layers.reshaping.up_sampling3d import UpSampling3D
from keras_core.layers.reshaping.zero_padding1d import ZeroPadding1D
from keras_core.layers.reshaping.zero_padding2d import ZeroPadding2D
from keras_core.layers.reshaping.zero_padding3d import ZeroPadding3D
from keras_core.layers.rnn.bidirectional import Bidirectional

@ -63,14 +63,15 @@ class BaseDepthwiseConv(Layer):
specified, the same dilation rate will be used for all dimensions.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
depthwise_initializer: Initializer for the depthwsie convolution
kernel. If `None`, the default initializer (`"glorot_uniform"`)
will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
depthwise_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
depthwise_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
@ -91,12 +92,12 @@ class BaseDepthwiseConv(Layer):
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
depthwise_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
depthwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
depthwise_constraint=None,
bias_constraint=None,
trainable=True,
name=None,
@ -119,11 +120,11 @@ class BaseDepthwiseConv(Layer):
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.depthwise_initializer = initializers.get(depthwise_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.depthwise_constraint = constraints.get(depthwise_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=self.rank + 2)
self.data_format = self.data_format
@ -164,16 +165,16 @@ class BaseDepthwiseConv(Layer):
self.input_spec = InputSpec(
min_ndim=self.rank + 2, axes={channel_axis: input_channel}
)
kernel_shape = self.kernel_size + (
depthwise_shape = self.kernel_size + (
input_channel,
self.depth_multiplier,
)
self.kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
shape=depthwise_shape,
initializer=self.depthwise_initializer,
regularizer=self.depthwise_regularizer,
constraint=self.depthwise_constraint,
trainable=True,
dtype=self.dtype,
)
@ -249,14 +250,14 @@ class BaseDepthwiseConv(Layer):
"dilation_rate": self.dilation_rate,
"activation": activations.serialize(self.activation),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(
self.kernel_initializer
"depthwise_initializer": initializers.serialize(
self.depthwise_initializer
),
"bias_initializer": initializers.serialize(
self.bias_initializer
),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
"depthwise_regularizer": regularizers.serialize(
self.depthwise_regularizer
),
"bias_regularizer": regularizers.serialize(
self.bias_regularizer
@ -264,8 +265,8 @@ class BaseDepthwiseConv(Layer):
"activity_regularizer": regularizers.serialize(
self.activity_regularizer
),
"kernel_constraint": constraints.serialize(
self.kernel_constraint
"depthwise_constraint": constraints.serialize(
self.depthwise_constraint
),
"bias_constraint": constraints.serialize(self.bias_constraint),
}

@ -51,14 +51,15 @@ class DepthwiseConv1D(BaseDepthwiseConv):
rate to use for dilated convolution.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
depthwise_initializer: Initializer for the convolution kernel.
If `None`, the default initializer (`"glorot_uniform"`)
will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
depthwise_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
depthwise_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
@ -106,12 +107,12 @@ class DepthwiseConv1D(BaseDepthwiseConv):
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
depthwise_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
depthwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
depthwise_constraint=None,
bias_constraint=None,
**kwargs
):
@ -125,12 +126,12 @@ class DepthwiseConv1D(BaseDepthwiseConv):
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
depthwise_initializer=depthwise_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
depthwise_regularizer=depthwise_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
depthwise_constraint=depthwise_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@ -51,14 +51,15 @@ class DepthwiseConv2D(BaseDepthwiseConv):
rate to use for dilated convolution.
activation: Activation function. If `None`, no activation is applied.
use_bias: bool, if `True`, bias will be added to the output.
kernel_initializer: Initializer for the convolution kernel. If `None`,
the default initializer (`"glorot_uniform"`) will be used.
depthwise_initializer: Initializer for the convolution kernel.
If `None`, the default initializer (`"glorot_uniform"`)
will be used.
bias_initializer: Initializer for the bias vector. If `None`, the
default initializer (`"zeros"`) will be used.
kernel_regularizer: Optional regularizer for the convolution kernel.
depthwise_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Optional regularizer function for the output.
kernel_constraint: Optional projection function to be applied to the
depthwise_constraint: Optional projection function to be applied to the
kernel after being updated by an `Optimizer` (e.g. used to implement
norm constraints or value constraints for layer weights). The
function must take as input the unprojected variable and must return
@ -106,12 +107,12 @@ class DepthwiseConv2D(BaseDepthwiseConv):
dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer="glorot_uniform",
depthwise_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
depthwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
depthwise_constraint=None,
bias_constraint=None,
**kwargs
):
@ -125,12 +126,12 @@ class DepthwiseConv2D(BaseDepthwiseConv):
dilation_rate=dilation_rate,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
depthwise_initializer=depthwise_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
depthwise_regularizer=depthwise_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
depthwise_constraint=depthwise_constraint,
bias_constraint=bias_constraint,
**kwargs
)

@ -170,6 +170,14 @@ class Normalization(Layer):
initializer="ones",
trainable=False,
)
# For backwards compatibility with older saved models.
self.count = self.add_weight(
name="count",
shape=(),
dtype="int",
initializer="zeros",
trainable=False,
)
self.built = True
self.finalize_state()
else:
@ -307,6 +315,6 @@ class Normalization(Layer):
return config
def load_own_variables(self, store):
# Ensure that we call finalize_state after variable loading.
super().load_own_variables(store)
# Ensure that we call finalize_state after variable loading.
self.finalize_state()

@ -26,12 +26,10 @@ class Cropping1D(Layer):
[[8 9]]]
Args:
cropping: Int, or tuple of int (length 2), or dictionary.
- If int: how many units should be trimmed off at the beginning and
end of the cropping dimension (axis 1).
- If tuple of 2 ints: how many units should be trimmed off at the
beginning and end of the cropping dimension
(`(left_crop, right_crop)`).
cropping: Integer or tuple of integers of length 2.
How many units should be trimmed off at the beginning and end of
the cropping dimension (axis 1).
If a single int is provided, the same value will be used for both.
Input shape:
3D tensor with shape `(batch_size, axis_to_crop, features)`

@ -1,69 +0,0 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.ZeroPadding1D")
class ZeroPadding1D(Layer):
"""Zero-padding layer for 1D input (e.g. temporal sequence).
Examples:
>>> input_shape = (2, 2, 3)
>>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
>>> x
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
>>> y = keras_core.layers.ZeroPadding1D(padding=2)(x)
>>> y
[[[ 0 0 0]
[ 0 0 0]
[ 0 1 2]
[ 3 4 5]
[ 0 0 0]
[ 0 0 0]]
[[ 0 0 0]
[ 0 0 0]
[ 6 7 8]
[ 9 10 11]
[ 0 0 0]
[ 0 0 0]]]
Args:
padding: Int, or tuple of int (length 2), or dictionary.
- If int: how many zeros to add at the beginning and end of
the padding dimension (axis 1).
- If tuple of 2 ints: how many zeros to add at the beginning and the
end of the padding dimension (`(left_pad, right_pad)`).
Input shape:
3D tensor with shape `(batch_size, axis_to_pad, features)`
Output shape:
3D tensor with shape `(batch_size, padded_axis, features)`
"""
def __init__(self, padding=1, name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
if isinstance(padding, int):
padding = (padding, padding)
self.padding = padding
self.input_spec = InputSpec(ndim=3)
def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
if input_shape[1] is not None:
input_shape[1] += self.padding[0] + self.padding[1]
return tuple(output_shape)
def call(self, inputs):
all_dims_padding = ((0, 0), self.padding, (0, 0))
return ops.pad(inputs, all_dims_padding)
def get_config(self):
config = {"padding": self.padding}
base_config = super().get_config()
return {**base_config, **config}

@ -1,35 +0,0 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import testing
class ZeroPadding1DTest(testing.TestCase, parameterized.TestCase):
def test_zero_padding_1d(self):
inputs = np.random.rand(1, 2, 3)
outputs = layers.ZeroPadding1D(padding=(1, 2))(inputs)
for index in [0, -1, -2]:
self.assertAllClose(outputs[:, index, :], 0.0)
self.assertAllClose(outputs[:, 1:-2, :], inputs)
@parameterized.named_parameters(("one_tuple", (2, 2)), ("one_int", 2))
def test_zero_padding_1d_with_same_padding(self, padding):
inputs = np.random.rand(1, 2, 3)
outputs = layers.ZeroPadding1D(padding=padding)(inputs)
for index in [0, 1, -1, -2]:
self.assertAllClose(outputs[:, index, :], 0.0)
self.assertAllClose(outputs[:, 2:-2, :], inputs)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_zero_padding_1d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, None, 3))
padded = layers.ZeroPadding1D((1, 2))(input_layer)
self.assertEqual(padded.shape, (1, None, 3))

@ -67,8 +67,8 @@ class ZeroPadding2D(Layer):
`(batch_size, channels, padded_height, padded_width)`
"""
def __init__(self, padding=(1, 1), data_format=None, name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
def __init__(self, padding=(1, 1), data_format=None, **kwargs):
super().__init__(**kwargs)
self.data_format = backend.standardize_data_format(data_format)
if isinstance(padding, int):
self.padding = ((padding, padding), (padding, padding))