Update symbolic_arguments.py (#513)

* Update symbolic_arguments.py

Added validations to __init__ function

* Update symbolic_arguments.py

Removed the # TODO as requested
Sayed Qaiser Ali 2023-07-18 21:57:40 +05:30 committed by Francois Chollet
parent 7de5c53be4
commit 336c6a042b
14 changed files with 664 additions and 163 deletions

@ -1,9 +1,10 @@
# Benchmark the performance of torch custom training loop
This directory contains benchmarks to compare the performance between Keras and
Torch while using Torch custom training loop. The benchmark purpose is to
understand the performance diff resulting from the modeling API choice (Keras
or Torch).
This directory contains benchmarks to compare the performance of a Keras model
and an equivalent Torch model while using the same Torch custom training loop.
The purpose of the benchmark is to understand the performance difference
resulting from the modeling API choice (Keras or Torch).
To run the benchmark, use the command below and change to your target:

@ -0,0 +1,400 @@
"""
Title: Compact Convolutional Transformers
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Converted to Keras Core by: [Muhammad Anas Raza](https://anasrz.com)
Date created: 2021/06/30
Last modified: 2023/07/17
Description: Compact Convolutional Transformers for efficient image classification.
Accelerator: GPU
"""
"""
As discussed in the [Vision Transformers (ViT)](https://arxiv.org/abs/2010.11929) paper,
a Transformer-based architecture for vision typically requires a larger dataset than
usual, as well as a longer pre-training schedule. [ImageNet-1k](http://imagenet.org/)
(which has about a million images) is considered to fall under the medium-sized data regime with
respect to ViTs. This is primarily because, unlike CNNs, ViTs (or a typical
Transformer-based architecture) do not have well-informed inductive biases (such as
convolutions for processing images). This raises the question: can't we combine the
benefits of convolutions and the benefits of Transformers in a single network
architecture? These benefits include parameter efficiency and self-attention for
processing long-range and global dependencies (interactions between different
regions in an image).
In [Escaping the Big Data Paradigm with Compact Transformers](https://arxiv.org/abs/2104.05704),
Hassani et al. present an approach for doing exactly this. They propose the
**Compact Convolutional Transformer** (CCT) architecture. In this example, we will
implement CCT and see how well it performs on the CIFAR-10 dataset.
If you are unfamiliar with the concept of self-attention or Transformers, you can read
[this chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11/r-3/312)
from François Chollet's book *Deep Learning with Python*. This example uses
code snippets from another example,
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
"""
"""
## Imports
"""
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
from keras_core import layers
import keras_core as keras
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
"""
## Hyperparameters and constants
"""
positional_emb = True
conv_layers = 2
projection_dim = 128
num_heads = 2
transformer_units = [
projection_dim,
projection_dim,
]
transformer_layers = 2
stochastic_depth_rate = 0.1
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 30
image_size = 32
"""
## Load CIFAR-10 dataset
"""
num_classes = 10
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
"""
## The CCT tokenizer
The first recipe introduced by the CCT authors is the tokenizer for processing the
images. In a standard ViT, images are organized into uniform *non-overlapping* patches.
This discards the boundary-level information present between adjacent patches, which a
neural network needs in order to effectively exploit locality. The figure below
illustrates how images are organized into patches.
![](https://i.imgur.com/IkBK9oY.png)
We already know that convolutions are quite good at exploiting locality information. So,
based on this, the authors introduce an all-convolution mini-network to produce image
patches.
"""
class CCTTokenizer(layers.Layer):
def __init__(
self,
kernel_size=3,
stride=1,
padding=1,
pooling_kernel_size=3,
pooling_stride=2,
num_conv_layers=conv_layers,
num_output_channels=[64, 128],
positional_emb=positional_emb,
**kwargs,
):
super().__init__(**kwargs)
# This is our tokenizer.
self.conv_model = keras.Sequential()
for i in range(num_conv_layers):
self.conv_model.add(
layers.Conv2D(
num_output_channels[i],
kernel_size,
stride,
padding="valid",
use_bias=False,
activation="relu",
kernel_initializer="he_normal",
)
)
self.conv_model.add(layers.ZeroPadding2D(padding))
self.conv_model.add(
layers.MaxPool2D(pooling_kernel_size, pooling_stride, "same")
)
self.positional_emb = positional_emb
def call(self, images):
outputs = self.conv_model(images)
# After passing the images through our mini-network, the spatial
# dimensions are flattened to form sequences.
reshaped = tf.reshape(
outputs,
(-1, tf.shape(outputs)[1] * tf.shape(outputs)[2], tf.shape(outputs)[-1]),
)
return reshaped
def positional_embedding(self, image_size):
# Positional embeddings are optional in CCT. Here, we calculate
# the number of sequences and initialize an `Embedding` layer to
# compute the positional embeddings later.
if self.positional_emb:
dummy_inputs = tf.ones((1, image_size, image_size, 3))
dummy_outputs = self.call(dummy_inputs)
sequence_length = tf.shape(dummy_outputs)[1]
projection_dim = tf.shape(dummy_outputs)[-1]
embed_layer = layers.Embedding(
input_dim=sequence_length, output_dim=projection_dim
)
return embed_layer, sequence_length
else:
return None
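"""
As a quick sanity check (illustrative; not part of the original example): with the
default two convolution blocks, each 3x3 convolution with `"valid"` padding shrinks
32 -> 30, the `ZeroPadding2D(1)` restores 32, and the stride-2 `"same"` max-pool
halves the spatial size, so 32 -> 16 -> 8. The tokenizer therefore turns a 32x32x3
image into a sequence of 8 * 8 = 64 tokens with 128 channels.
"""
print(CCTTokenizer()(tf.ones((1, 32, 32, 3))).shape)  # (1, 64, 128)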
"""
## Sequence Pooling
Another recipe introduced in CCT is attention pooling, or sequence pooling. In ViT, only
the feature map corresponding to the class token is pooled and then used for the
subsequent classification task (or any other downstream task). CCT instead pools an
attention-weighted average over the entire output sequence, removing the need for a
class token.
"""
class SequencePooling(layers.Layer):
def __init__(self):
super().__init__()
self.attention = layers.Dense(1)
def call(self, x):
attention_weights = tf.nn.softmax(self.attention(x), axis=1)
weighted_representation = tf.matmul(
attention_weights, x, transpose_a=True
)
return tf.squeeze(weighted_representation, -2)
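"""
The layer above learns a single attention score per token, softmaxes the scores over
the sequence axis, and returns the attention-weighted average of the token embeddings.
As an illustrative shape check, pooling a `(batch, seq, dim)` input yields a
`(batch, dim)` representation:
"""
print(SequencePooling()(tf.random.normal((2, 5, 8))).shape)  # (2, 8)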
"""
## Stochastic depth for regularization
[Stochastic depth](https://arxiv.org/abs/1603.09382) is a regularization technique that
randomly drops a set of layers during training. During inference, the layers are kept
as they are. It is very similar to [Dropout](https://jmlr.org/papers/v15/srivastava14a.html),
except that it operates on a block of layers rather than on individual nodes inside a
layer. In CCT, stochastic depth is applied just before the residual connections in the
Transformer encoder blocks.
"""
# Adapted from: https://github.com/rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
def __init__(self, drop_prop, **kwargs):
super().__init__(**kwargs)
self.drop_prob = drop_prop
def call(self, x, training=None):
if training:
keep_prob = 1 - self.drop_prob
shape = (tf.shape(x)[0],) + (1,) * (tf.shape(x).shape[0] - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
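"""
At training time, each sample's residual branch is either zeroed out entirely or kept
and rescaled by `1 / keep_prob`, so the expected output matches the identity behavior
used at inference. As an illustrative check (not part of the original example), with
`drop_prob=0.5` each row below comes out as either all zeros (dropped) or all 2.0
(kept and scaled by `1 / 0.5`):
"""
print(StochasticDepth(0.5)(tf.ones((8, 4)), training=True))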
"""
## MLP for the Transformer encoder
"""
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
"""
## Data augmentation
In the [original paper](https://arxiv.org/abs/2104.05704), the authors use
[AutoAugment](https://arxiv.org/abs/1805.09501) to induce stronger regularization. For
this example, we will be using the standard geometric augmentations like random cropping
and flipping.
"""
# Note the rescaling layer. These layers have pre-defined inference behavior.
data_augmentation = keras.Sequential(
[
layers.Rescaling(scale=1.0 / 255),
layers.RandomCrop(image_size, image_size),
layers.RandomFlip("horizontal"),
],
name="data_augmentation",
)
"""
## The final CCT model
In CCT, the outputs of the Transformer encoder are weighted by sequence pooling and then
passed to the final task-specific layer (in this example, classification).
"""
def create_cct_model(
image_size=image_size,
input_shape=input_shape,
num_heads=num_heads,
projection_dim=projection_dim,
transformer_units=transformer_units,
):
inputs = layers.Input(input_shape)
# Augment data.
augmented = data_augmentation(inputs)
# Encode patches.
cct_tokenizer = CCTTokenizer()
encoded_patches = cct_tokenizer(augmented)
# Apply positional embedding.
if positional_emb:
pos_embed, seq_length = cct_tokenizer.positional_embedding(image_size)
positions = tf.range(start=0, limit=seq_length, delta=1)
position_embeddings = pos_embed(positions)
encoded_patches += position_embeddings
# Calculate Stochastic Depth probabilities.
dpr = [x for x in np.linspace(0, stochastic_depth_rate, transformer_layers)]
# Create multiple layers of the Transformer block.
for i in range(transformer_layers):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
# Create a multi-head attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection 1.
attention_output = StochasticDepth(dpr[i])(attention_output)
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# Skip connection 2.
x3 = StochasticDepth(dpr[i])(x3)
encoded_patches = layers.Add()([x3, x2])
# Apply sequence pooling.
representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
weighted_representation = SequencePooling()(representation)
# Classify outputs.
logits = layers.Dense(num_classes)(weighted_representation)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=logits)
return model
"""
## Model training and evaluation
"""
def run_experiment(model):
optimizer = keras.optimizers.AdamW(
    learning_rate=learning_rate, weight_decay=weight_decay
)
model.compile(
optimizer=optimizer,
loss=keras.losses.CategoricalCrossentropy(
from_logits=True, label_smoothing=0.1
),
metrics=[
keras.metrics.CategoricalAccuracy(name="accuracy"),
keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
history = model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=0.1,
callbacks=[checkpoint_callback],
)
model.load_weights(checkpoint_filepath)
_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
return history
cct_model = create_cct_model()
history = run_experiment(cct_model)
"""
Let's now visualize the training progress of the model.
"""
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()
"""
The CCT model we just trained has just **0.4 million** parameters, and it gets us to
~78% top-1 accuracy within 30 epochs. The plot above shows no signs of overfitting
either. This means we can train this network for longer (perhaps with a bit more
regularization) and obtain even better performance. This performance can be further
improved by additional recipes like a cosine-decay learning-rate schedule and other
data augmentation techniques such as [AutoAugment](https://arxiv.org/abs/1805.09501),
[MixUp](https://arxiv.org/abs/1710.09412), or
[CutMix](https://arxiv.org/abs/1905.04899). With these modifications, the authors report
95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of
experiments to study how the number of convolution blocks, Transformer layers, etc.
affect the final performance of CCTs.
For comparison, a ViT model takes about **4.7 million** parameters and **100
epochs** of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You
can refer to
[this notebook](https://colab.research.google.com/gist/sayakpaul/1a80d9f582b044354a1a26c5cb3d69e5/image_classification_with_vision_transformer.ipynb)
for details of the experimental setup.
The authors also demonstrate the performance of Compact Convolutional Transformers on
NLP tasks, where they report competitive results.
"""

@ -0,0 +1 @@
from keras_core.backend.torch.optimizers.torch_optimizer import TorchOptimizer

@ -0,0 +1,24 @@
import torch
from keras_core.optimizers.base_optimizer import BaseOptimizer
class TorchOptimizer(BaseOptimizer):
def __new__(cls, *args, **kwargs):
# Import locally to avoid circular imports.
from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_sgd
OPTIMIZERS = {optimizers.SGD: torch_sgd.SGD}
if cls in OPTIMIZERS:
return OPTIMIZERS[cls](*args, **kwargs)
return super().__new__(cls)
def _apply_weight_decay(self, variables):
if self.weight_decay is None:
return
torch._foreach_mul_(
[v.value for v in variables if self._use_weight_decay(v)],
1 - self.weight_decay * self._get_current_learning_rate(),
)
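A note on the `__new__` dispatch above: constructing a generic optimizer class returns
a backend-specialized subclass when one is registered. A minimal self-contained sketch
of the pattern (hypothetical `Base`/`Fast` classes, not keras_core code):

class Base:
    def __new__(cls, *args, **kwargs):
        # Swap in a specialized implementation at construction time.
        registry = {Base: Fast}
        if cls in registry:
            return registry[cls](*args, **kwargs)
        return super().__new__(cls)

class Fast(Base):
    pass

assert type(Base()) is Fast  # Base() transparently yields the specialized subclass.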

@ -0,0 +1,43 @@
import torch
from keras_core import optimizers
class SGD(optimizers.SGD):
def _internal_apply_gradients(self, grads_and_vars):
grads, trainable_variables = zip(*grads_and_vars)
self._parallel_update_step(
grads,
[v.value for v in trainable_variables],
self._get_current_learning_rate(),
)
self.iterations.assign(self.iterations + 1)
def _parallel_update_step(
self,
grads,
variables,
learning_rate,
):
if self.momentum != 0:
bufs = [
self.momentums[self._get_variable_index(variable.value)]
for variable in variables
]
for i in range(len(bufs)):
if bufs[i] is None:
bufs[i] = torch.clone(grads[i]).detach()
torch._foreach_mul_(bufs, self.momentum)
torch._foreach_add_(bufs, grads, alpha=-learning_rate)
if self.nesterov:
torch._foreach_add_(variables, grads, alpha=-learning_rate)
torch._foreach_add_(variables, bufs, alpha=self.momentum)
else:
torch._foreach_add_(variables, bufs)
else:
torch._foreach_add_(variables, grads, alpha=-learning_rate)
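The `torch._foreach_*` calls above perform one fused, in-place update across the whole
list of parameter tensors. An illustrative equivalent of the no-momentum branch as a
plain per-tensor loop (a sketch, assuming the variables are raw torch tensors):

import torch

variables = [torch.ones(3), torch.ones(2)]
grads = [torch.full((3,), 0.5), torch.full((2,), 0.5)]
learning_rate = 0.1
torch._foreach_add_(variables, grads, alpha=-learning_rate)
# Equivalent per-tensor loop:
#     for v, g in zip(variables, grads):
#         v.add_(g, alpha=-learning_rate)
print(variables)  # each entry is now 1 - 0.1 * 0.5 = 0.95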

@ -0,0 +1,88 @@
import numpy as np
from keras_core import testing
from keras_core.backend import KerasTensor
from keras_core.layers import InputLayer
class InputLayerTest(testing.TestCase):
# Testing happy path for layer without input tensor
def test_input_basic(self):
input_shape = (2, 3)
batch_size = 4
dtype = "float32"
ndim = len((batch_size,) + input_shape)
values = InputLayer(
shape=input_shape, batch_size=batch_size, dtype=dtype
)
self.assertEqual(values.dtype, dtype)
self.assertEqual(values.batch_shape[0], batch_size)
self.assertEqual(values.batch_shape[1:], input_shape)
self.assertEqual(values.trainable, True)
self.assertIsInstance(values.output, KerasTensor)
self.assertEqual(values.output.ndim, ndim)
self.assertEqual(values.output.dtype, dtype)
# Testing shape is not None and batch_shape is not None condition
def test_input_error1(self):
input_shape = (2, 3)
with self.assertRaisesRegex(
ValueError, "cannot pass both `shape` and `batch_shape`"
):
InputLayer(shape=input_shape, batch_shape=input_shape)
# Testing batch_size is not None and batch_shape is not None
def test_input_error2(self):
input_shape = (2, 3)
batch_size = 4
with self.assertRaisesRegex(
ValueError, "cannot pass both `batch_size` and `batch_shape`"
):
InputLayer(batch_size=batch_size, batch_shape=input_shape)
# Testing shape is None and batch_shape is None
def test_input_error3(self):
with self.assertRaisesRegex(ValueError, "pass a `shape` argument."):
InputLayer(shape=None, batch_shape=None)
# Testing Input tensor is not Keras tensor
def test_input_tensor_error(self):
input_shape = (2, 3)
batch_size = 4
input_tensor = np.zeros(input_shape)
with self.assertRaisesRegex(
ValueError, "Argument `input_tensor` must be a KerasTensor"
):
InputLayer(
shape=input_shape,
batch_size=batch_size,
input_tensor=input_tensor,
)
# Testing happy path for layer with input tensor
def testing_input_tensor(self):
input_shape = (2, 3)
batch_size = 4
dtype = "float32"
input_tensor = KerasTensor(shape=input_shape, dtype=dtype)
values = InputLayer(
shape=input_shape,
batch_size=batch_size,
input_tensor=input_tensor,
dtype=dtype,
)
self.assertEqual(values.dtype, dtype)
self.assertEqual(values.batch_shape[0], batch_size)
self.assertEqual(values.batch_shape[1:], input_shape)
self.assertEqual(values.trainable, True)
self.assertIsInstance(values.output, KerasTensor)
self.assertEqual(values.output, input_tensor)
self.assertEqual(values.output.ndim, input_tensor.ndim)
self.assertEqual(values.output.dtype, dtype)
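For reference, the validations exercised above also surface through the user-facing
functional API; a quick illustrative sketch (assuming `keras_core.layers.Input` mirrors
these `InputLayer` arguments):

from keras_core import layers

x = layers.Input(shape=(2, 3), batch_size=4)  # OK: batch shape is (4, 2, 3)
try:
    layers.Input(shape=(2, 3), batch_shape=(4, 2, 3))
except ValueError as e:
    print(e)  # cannot pass both `shape` and `batch_shape`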

@ -84,8 +84,8 @@ class Layer(BackendLayer, Operation):
Attributes:
name: The name of the layer (string).
dtype: The dtype of the layer's weights.
variable_dtype: Dtype of the layer's variables.
dtype: Dtype of the layer's weights. Alias of `layer.variable_dtype`.
variable_dtype: Dtype of the layer's weights.
compute_dtype: The dtype of the layer's computations.
Layers automatically cast inputs to this dtype, which causes
the computations and output to also be in this dtype.
@ -374,21 +374,19 @@ class Layer(BackendLayer, Operation):
constraint=None,
name=None,
):
# TODO: handle layout
self._check_super_called()
initializer = initializers.get(initializer)
variable = backend.Variable(
initializer=initializer,
"""Add a weight variable to the layer.
Alias of `add_weight()`.
"""
return self.add_weight(
shape=shape,
dtype=dtype or self.variable_dtype,
initializer=initializer,
dtype=dtype,
trainable=trainable,
regularizer=regularizer,
constraint=constraint,
name=name,
)
# Will be added to layer.losses
variable.regularizer = regularizer
variable.constraint = constraint
self._track_variable(variable)
return variable
def add_weight(
self,
@ -402,8 +400,6 @@ class Layer(BackendLayer, Operation):
):
"""Add a weight variable to the layer.
Alias of `add_variable()`.
Args:
shape: Shape tuple for the variable.
Must be fully-defined (no `None` entries).
@ -422,15 +418,21 @@ class Layer(BackendLayer, Operation):
name: String name of the variable. Useful
for debugging purposes.
"""
return self.add_variable(
shape=shape,
# TODO: handle layout
self._check_super_called()
initializer = initializers.get(initializer)
variable = backend.Variable(
initializer=initializer,
dtype=dtype,
shape=shape,
dtype=dtype or self.variable_dtype,
trainable=trainable,
regularizer=regularizer,
constraint=constraint,
name=name,
)
# Will be added to layer.losses
variable.regularizer = regularizer
variable.constraint = constraint
self._track_variable(variable)
return variable
@property
def trainable(self):
@ -459,9 +461,13 @@ class Layer(BackendLayer, Operation):
@property
def variables(self):
# Return only weights/rng state/metric variables
# of all Layers, recursively.
# Also deduplicate them.
"""List of all layer state, including metric variables and random seeds.
This extends `layer.weights` to include all state used by the layer,
including state for metrics and `SeedGenerator`s.
"""
# Return all `Variables` associated with the layer, including metrics
# and random seeds. Also deduplicate them.
variables = []
seen_ids = set()
for v in self._trainable_variables + self._non_trainable_variables:
@ -481,20 +487,32 @@ class Layer(BackendLayer, Operation):
@property
def trainable_variables(self):
"""List of all trainable layer state.
This is equivalent to `layer.trainable_weights`.
"""
if not self.trainable:
return []
return [v for v in self.variables if v.trainable]
@property
def non_trainable_variables(self):
"""List of all non-trainable layer state.
This extends `layer.non_trainable_weights` to include all state used by
the layer, including state for metrics and `SeedGenerator`s.
"""
if not self.trainable:
return self.variables
return [v for v in self.variables if not v.trainable]
@property
def weights(self):
"""List of weight variables of the layer."""
# Return only "own weights" of all Layers, recursively.
"""List of all weight variables of the layer.
Unlike `layer.variables`, this excludes metric state and random seeds.
"""
# Return only `Variables` directly owned by layers and sub-layers.
# Also deduplicate them.
weights = []
seen_ids = set()
@ -511,10 +529,9 @@ class Layer(BackendLayer, Operation):
@property
def trainable_weights(self):
"""List of trainable weight variables of the layer.
"""List of all trainable weight variables of the layer.
These are the weights that get updated by the optimizer
during training.
These are the weights that get updated by the optimizer during training.
"""
if not self.trainable:
return []
@ -522,10 +539,11 @@ class Layer(BackendLayer, Operation):
@property
def non_trainable_weights(self):
"""List of non-trainable weight variables of the layer.
"""List of all non-trainable weight variables of the layer.
Non-trainable weights may include batch normalization statistics,
metric variables, or RNG seed variables.
These are the weights that should not be updated by the optimizer during
training. Unlike `layer.non_trainable_variables`, this excludes metric
state and random seeds.
"""
if not self.trainable:
return self.weights
@ -555,7 +573,7 @@ class Layer(BackendLayer, Operation):
@property
def dtype(self):
"""The dtype of the state (weights) of the layer."""
"""Alias of `layer.variable_dtype`."""
return self.variable_dtype
@property

@ -1,15 +1,14 @@
import numpy as np
from keras_core import backend
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.random.seed_generator import SeedGenerator
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
from keras_core.utils import image_utils
from keras_core.utils import dtype_utils
from keras_core.utils.module_utils import tensorflow as tf
@keras_core_export("keras_core.layers.RandomCrop")
class RandomCrop(TFDataLayer):
class RandomCrop(Layer):
"""A preprocessing layer which randomly crops images during training.
During training, this layer will randomly choose a location to crop images
@ -53,128 +52,41 @@ class RandomCrop(TFDataLayer):
`name` and `dtype`.
"""
def __init__(
self, height, width, seed=None, data_format=None, name=None, **kwargs
):
def __init__(self, height, width, seed=None, name=None, **kwargs):
if not tf.available:
raise ImportError(
"Layer RandomCrop requires TensorFlow. "
"Install it via `pip install tensorflow`."
)
super().__init__(name=name, **kwargs)
self.height = height
self.width = width
self.seed = seed or backend.random.make_default_seed()
self.seed_generator = SeedGenerator(seed)
self.data_format = backend.standardize_data_format(data_format)
if self.data_format == "channels_first":
self.height_axis = -2
self.width_axis = -1
elif self.data_format == "channels_last":
self.height_axis = -3
self.width_axis = -2
self.layer = tf.keras.layers.RandomCrop(
height=height,
width=width,
seed=self.seed,
name=name,
)
self.supports_masking = False
self.supports_jit = False
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
def call(self, inputs, training=True):
inputs = self.backend.cast(inputs, self.compute_dtype)
input_shape = self.backend.shape(inputs)
is_batched = len(input_shape) > 3
inputs = (
self.backend.numpy.expand_dims(inputs, axis=0)
if not is_batched
else inputs
)
h_diff = input_shape[self.height_axis] - self.height
w_diff = input_shape[self.width_axis] - self.width
def random_crop():
input_height, input_width = (
input_shape[self.height_axis],
input_shape[self.width_axis],
)
h_start = self.backend.cast(
ops.random.uniform(
(),
0,
maxval=float(input_height - self.height + 1),
dtype=inputs.dtype,
seed=self.seed_generator,
),
h_diff.dtype,
)
w_start = self.backend.cast(
ops.random.uniform(
(),
0,
maxval=float(input_width - self.width + 1),
dtype=inputs.dtype,
seed=self.seed_generator,
),
h_diff.dtype,
)
if self.data_format == "channels_last":
return inputs[
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
]
else:
return inputs[
:,
:,
h_start : h_start + self.height,
w_start : w_start + self.width,
]
def resize():
outputs = image_utils.smart_resize(
inputs,
[self.height, self.width],
data_format=self.data_format,
backend_module=self.backend,
)
# smart_resize will always output float32, so we need to re-cast.
return self.backend.cast(outputs, self.compute_dtype)
outputs = self.backend.cond(
self.backend.numpy.all((training, h_diff >= 0, w_diff >= 0)),
random_crop,
resize,
)
if self.backend != "tensorflow" and not backend_utils.in_tf_graph():
outputs = self.backend.convert_to_tensor(outputs)
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
outputs = self.layer.call(inputs, training=training)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs
def compute_output_shape(self, input_shape, *args, **kwargs):
input_shape = list(input_shape)
input_shape[self.height_axis] = self.height
input_shape[self.width_axis] = self.width
return tuple(input_shape)
def compute_output_shape(self, input_shape):
return tuple(self.layer.compute_output_shape(input_shape))
def get_config(self):
config = super().get_config()
config.update(
{
"height": self.height,
"width": self.width,
"seed": self.seed,
"data_format": self.data_format,
}
)
config = self.layer.get_config()
config.update({"seed": self.seed})
return config

@ -522,7 +522,7 @@ class GRU(RNN):
# implementation of the inner GRU loop. In the case of
# TF for instance, it will leverage cuDNN when feasible, and
# it will raise NotImplementedError otherwise.
return backend.gru(
out = backend.gru(
sequences,
initial_state,
mask,
@ -536,6 +536,11 @@ class GRU(RNN):
unroll=self.unroll,
reset_after=self.cell.reset_after,
)
# We disable jit_compile for the model in this case,
# since cuDNN ops aren't XLA compatible.
if backend.backend() == "tensorflow":
self.supports_jit = False
return out
except NotImplementedError:
pass
return super().inner_loop(

@ -502,7 +502,7 @@ class LSTM(RNN):
# implementation of the inner LSTM loop. In the case of
# TF for instance, it will leverage cuDNN when feasible, and
# it will raise NotImplementedError otherwise.
return backend.lstm(
out = backend.lstm(
sequences,
initial_state[0],
initial_state[1],
@ -516,6 +516,11 @@ class LSTM(RNN):
go_backwards=self.go_backwards,
unroll=self.unroll,
)
# We disable jit_compile for the model in this case,
# since cuDNN ops aren't XLA compatible.
if backend.backend() == "tensorflow":
self.supports_jit = False
return out
except NotImplementedError:
pass
return super().inner_loop(

@ -5,7 +5,7 @@ from keras_core.backend import KerasTensor
class SymbolicArguments:
def __init__(self, *args, **kwargs):
# TODO: validation
self.args = tree.map_structure(lambda x: x, args)
self.kwargs = tree.map_structure(lambda x: x, kwargs)
self._flat_arguments = tree.flatten((self.args, self.kwargs))

@ -3,9 +3,13 @@ from keras_core.api_export import keras_core_export
from keras_core.optimizers import base_optimizer
if backend.backend() == "tensorflow":
from keras_core.backend.tensorflow import optimizer as tf_optimizer
from keras_core.backend.tensorflow.optimizer import TFOptimizer
BackendOptimizer = tf_optimizer.TFOptimizer
BackendOptimizer = TFOptimizer
elif backend.backend() == "torch":
from keras_core.backend.torch.optimizers import TorchOptimizer
BackendOptimizer = TorchOptimizer
else:
BackendOptimizer = base_optimizer.BaseOptimizer
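The effect of this wiring under the torch backend, as a hedged sketch (assuming
`KERAS_BACKEND="torch"` is set before importing keras_core):

import os
os.environ["KERAS_BACKEND"] = "torch"

from keras_core import optimizers
from keras_core.backend.torch.optimizers import TorchOptimizer

# Constructing the generic SGD transparently yields the torch-specialized
# subclass via TorchOptimizer.__new__ (see the torch_optimizer.py hunk above).
opt = optimizers.SGD(learning_rate=0.01)
print(isinstance(opt, TorchOptimizer))  # True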

@ -21,7 +21,7 @@ class SGDTest(testing.TestCase):
def test_single_step(self):
optimizer = SGD(learning_rate=0.5)
self.assertEqual(len(optimizer.variables), 2)
grads = np.array([1.0, 6.0, 7.0, 2.0])
grads = ops.array([1.0, 6.0, 7.0, 2.0])
vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.build([vars])
optimizer.apply_gradients(zip([grads], [vars]))
@ -32,7 +32,7 @@ class SGDTest(testing.TestCase):
def test_weight_decay(self):
grads, var1, var2, var3 = (
np.zeros(()),
ops.zeros(()),
backend.Variable(2.0),
backend.Variable(2.0, name="exclude"),
backend.Variable(2.0),
@ -56,8 +56,8 @@ class SGDTest(testing.TestCase):
optimizer = SGD(nesterov=True)
x = backend.Variable(np.ones([10]))
grads = np.arange(0.1, 1.1, 0.1)
first_grads = np.full((10,), 0.01)
grads = ops.arange(0.1, 1.1, 0.1)
first_grads = ops.full((10,), 0.01)
# fmt: off
golden = np.array(