Merge branch 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-06-11 17:29:33 -07:00
parent 2978ae66f0
commit 0dc1025264
4 changed files with 691 additions and 1 deletion

@@ -0,0 +1,236 @@
"""
Title: Collaborative Filtering for Movie Recommendations
Author: [Siddhartha Banerjee](https://twitter.com/sidd2006)
Date created: 2020/05/24
Last modified: 2020/05/24
Description: Recommending movies using a model trained on Movielens dataset.
Accelerator: GPU
"""
"""
## Introduction
This example demonstrates
[Collaborative filtering](https://en.wikipedia.org/wiki/Collaborative_filtering)
using the [Movielens dataset](https://www.kaggle.com/c/movielens-100k)
to recommend movies to users.
The MovieLens ratings dataset lists the ratings given by a set of users to a set of movies.
Our goal is to be able to predict ratings for movies a user has not yet watched.
The movies with the highest predicted ratings can then be recommended to the user.
The steps in the model are as follows:
1. Map user ID to a "user vector" via an embedding matrix
2. Map movie ID to a "movie vector" via an embedding matrix
3. Compute the dot product between the user vector and movie vector to obtain
a match score between the user and the movie (predicted rating).
4. Train the embeddings via gradient descent using all known user-movie pairs.
A NumPy sketch of this scoring scheme follows the references below.
**References:**
- [Collaborative Filtering](https://dl.acm.org/doi/pdf/10.1145/371920.372071)
- [Neural Collaborative Filtering](https://dl.acm.org/doi/pdf/10.1145/3038912.3052569)
"""
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from zipfile import ZipFile
import keras_core as keras
from keras_core import layers
import keras_core.operations as ops
"""
## First, load the data and apply preprocessing
"""
# Download the actual data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip
# Use the ratings.csv file
movielens_data_file_url = (
    "http://files.grouplens.org/datasets/movielens/ml-latest-small.zip"
)
movielens_zipped_file = keras.utils.get_file(
    "ml-latest-small.zip", movielens_data_file_url, extract=False
)
keras_datasets_path = Path(movielens_zipped_file).parents[0]
movielens_dir = keras_datasets_path / "ml-latest-small"
# Only extract the data the first time the script is run.
if not movielens_dir.exists():
    with ZipFile(movielens_zipped_file, "r") as zf:
        # Extract files
        print("Extracting all the files now...")
        zf.extractall(path=keras_datasets_path)
        print("Done!")
ratings_file = movielens_dir / "ratings.csv"
df = pd.read_csv(ratings_file)
"""
First, we need to perform some preprocessing to encode users and movies as integer indices.
"""
user_ids = df["userId"].unique().tolist()
user2user_encoded = {x: i for i, x in enumerate(user_ids)}
userencoded2user = {i: x for i, x in enumerate(user_ids)}
movie_ids = df["movieId"].unique().tolist()
movie2movie_encoded = {x: i for i, x in enumerate(movie_ids)}
movie_encoded2movie = {i: x for i, x in enumerate(movie_ids)}
df["user"] = df["userId"].map(user2user_encoded)
df["movie"] = df["movieId"].map(movie2movie_encoded)
num_users = len(user2user_encoded)
num_movies = len(movie_encoded2movie)
df["rating"] = df["rating"].values.astype(np.float32)
# min and max ratings will be used to normalize the ratings later
min_rating = min(df["rating"])
max_rating = max(df["rating"])
print(
    "Number of users: {}, Number of Movies: {}, Min rating: {}, Max rating: {}".format(
        num_users, num_movies, min_rating, max_rating
    )
)
"""
## Prepare training and validation data
"""
df = df.sample(frac=1, random_state=42)
x = df[["user", "movie"]].values
# Normalize the targets between 0 and 1. Makes it easy to train.
y = df["rating"].apply(lambda x: (x - min_rating) / (max_rating - min_rating)).values
# Assuming training on 90% of the data and validating on 10%.
train_indices = int(0.9 * df.shape[0])
x_train, x_val, y_train, y_val = (
    x[:train_indices],
    x[train_indices:],
    y[:train_indices],
    y[train_indices:],
)
"""
## Create the model
We embed both users and movies into 50-dimensional vectors.
The model computes a match score between user and movie embeddings via a dot product,
and adds a per-movie and per-user bias. The match score is scaled to the `[0, 1]`
interval via a sigmoid (since our ratings are normalized to this range).
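In symbols, with user vector $u$, movie vector $m$, and scalar biases $b_u$ and $b_m$,
the predicted normalized rating is $\hat{r}_{um} = \sigma(u \cdot m + b_u + b_m)$.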
"""
EMBEDDING_SIZE = 50


class RecommenderNet(keras.Model):
    def __init__(self, num_users, num_movies, embedding_size, **kwargs):
        super().__init__(**kwargs)
        self.num_users = num_users
        self.num_movies = num_movies
        self.embedding_size = embedding_size
        self.user_embedding = layers.Embedding(
            num_users,
            embedding_size,
            embeddings_initializer="he_normal",
            embeddings_regularizer=keras.regularizers.l2(1e-6),
        )
        self.user_bias = layers.Embedding(num_users, 1)
        self.movie_embedding = layers.Embedding(
            num_movies,
            embedding_size,
            embeddings_initializer="he_normal",
            embeddings_regularizer=keras.regularizers.l2(1e-6),
        )
        self.movie_bias = layers.Embedding(num_movies, 1)

    def call(self, inputs):
        user_vector = self.user_embedding(inputs[:, 0])
        user_bias = self.user_bias(inputs[:, 0])
        movie_vector = self.movie_embedding(inputs[:, 1])
        movie_bias = self.movie_bias(inputs[:, 1])
        dot_user_movie = ops.tensordot(user_vector, movie_vector, 2)
        # Add all the components (including bias)
        x = dot_user_movie + user_bias + movie_bias
        # The sigmoid activation forces the rating to be between 0 and 1
        return ops.nn.sigmoid(x)


model = RecommenderNet(num_users, num_movies, EMBEDDING_SIZE)
model.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
)
"""
## Train the model based on the data split
"""
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=5,
    verbose=1,
    validation_data=(x_val, y_val),
)
"""
## Plot training and validation loss
"""
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "validation"], loc="upper left")
plt.show()
"""
## Show top 10 movie recommendations to a user
"""
movie_df = pd.read_csv(movielens_dir / "movies.csv")

# Let us get a user and see the top recommendations.
user_id = df.userId.sample(1).iloc[0]
movies_watched_by_user = df[df.userId == user_id]
movies_not_watched = movie_df[
    ~movie_df["movieId"].isin(movies_watched_by_user.movieId.values)
]["movieId"]
movies_not_watched = list(
    set(movies_not_watched).intersection(set(movie2movie_encoded.keys()))
)
movies_not_watched = [[movie2movie_encoded.get(x)] for x in movies_not_watched]
user_encoder = user2user_encoded.get(user_id)
user_movie_array = np.hstack(
    ([[user_encoder]] * len(movies_not_watched), movies_not_watched)
)
ratings = model.predict(user_movie_array).flatten()
top_ratings_indices = ratings.argsort()[-10:][::-1]
recommended_movie_ids = [
    movie_encoded2movie.get(movies_not_watched[x][0]) for x in top_ratings_indices
]
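# Note: `ratings` are in the normalized [0, 1] range used for training. To map a
# predicted score back onto the original star scale, invert the normalization, e.g.:
# predicted_rating = ratings[0] * (max_rating - min_rating) + min_rating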
print("Showing recommendations for user: {}".format(user_id))
print("====" * 9)
print("Movies with high ratings from user")
print("----" * 8)
top_movies_user = (
    movies_watched_by_user.sort_values(by="rating", ascending=False)
    .head(5)
    .movieId.values
)
movie_df_rows = movie_df[movie_df["movieId"].isin(top_movies_user)]
for row in movie_df_rows.itertuples():
    print(row.title, ":", row.genres)

print("----" * 8)
print("Top 10 movie recommendations")
print("----" * 8)
recommended_movies = movie_df[movie_df["movieId"].isin(recommended_movie_ids)]
for row in recommended_movies.itertuples():
    print(row.title, ":", row.genres)
"""
**Example available on HuggingFace**
| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-Collaborative%20Filtering-black.svg)](https://huggingface.co/keras-io/collaborative-filtering-movielens) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-Collaborative%20Filtering-black.svg)](https://huggingface.co/spaces/keras-io/collaborative-filtering-movielens) |
"""

@@ -0,0 +1,176 @@
"""
Title: Timeseries classification with a Transformer model
Author: [Theodoros Ntakouris](https://github.com/ntakouris)
Date created: 2021/06/25
Last modified: 2021/08/05
Description: This notebook demonstrates how to do timeseries classification using a Transformer model.
Accelerator: GPU
"""
"""
## Introduction
This is the Transformer architecture from
[Attention Is All You Need](https://arxiv.org/abs/1706.03762),
applied to timeseries instead of natural language.
## Load the dataset
We are going to use the same dataset and preprocessing as the
[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)
example.
"""
import numpy as np


def readucr(filename):
    data = np.loadtxt(filename, delimiter="\t")
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)


root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"

x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")
x_test, y_test = readucr(root_url + "FordA_TEST.tsv")

x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))

n_classes = len(np.unique(y_train))

idx = np.random.permutation(len(x_train))
x_train = x_train[idx]
y_train = y_train[idx]

# The raw labels are -1/1; remap them to 0/1 for sparse categorical crossentropy.
y_train[y_train == -1] = 0
y_test[y_test == -1] = 0
"""
## Build the model
Our model processes a tensor of shape `(batch size, sequence length, features)`,
where `sequence length` is the number of time steps and `features` is the number
of input timeseries channels per step (1 in our case).
You can replace your classification RNN layers with this one: the
inputs are fully compatible! (A quick shape check follows the encoder block below.)
"""
import keras_core as keras
from keras_core import layers
"""
We include residual connections, layer normalization, and dropout.
The resulting layer can be stacked multiple times.
The projection layers are implemented through `keras.layers.Conv1D`.
"""
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Attention and Normalization
    x = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(inputs, inputs)
    x = layers.Dropout(dropout)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    res = x + inputs

    # Feed Forward Part
    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
    x = layers.Dropout(dropout)(x)
    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    return x + res
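"""
Since the block maps `(batch size, sequence length, features)` to a tensor of the
same shape, it is drop-in compatible with RNN layers, as noted above. A quick,
illustrative sanity check (the shapes and hyperparameters here are arbitrary):
"""
demo_inputs = keras.Input(shape=(500, 1))
demo_outputs = transformer_encoder(demo_inputs, head_size=256, num_heads=4, ff_dim=4)
print(demo_outputs.shape)  # (None, 500, 1) -- same shape as the input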
"""
The main part of our model is now complete. We can stack multiple of those
`transformer_encoder` blocks and we can also proceed to add the final
Multi-Layer Perceptron classification head. Apart from a stack of `Dense`
layers, we need to reduce the output tensor of the `TransformerEncoder` part of
our model down to a vector of features for each data point in the current
batch. A common way to achieve this is to use a pooling layer. For
this example, a `GlobalAveragePooling1D` layer is sufficient.
"""
def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    # Reduce each sequence to a single feature vector before the MLP head.
    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
"""
## Train and evaluate
"""
input_shape = x_train.shape[1:]

model = build_model(
    input_shape,
    head_size=256,
    num_heads=4,
    ff_dim=4,
    num_transformer_blocks=4,
    mlp_units=[128],
    mlp_dropout=0.4,
    dropout=0.25,
)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["sparse_categorical_accuracy"],
)
model.summary()

callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]

model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=2,
    batch_size=64,
    callbacks=callbacks,
)

model.evaluate(x_test, y_test, verbose=1)
"""
## Conclusions
In about 110-120 epochs (25s each on Colab), the model reaches a training
accuracy of ~0.95, a validation accuracy of ~0.84, and a testing accuracy
of ~0.85, without hyperparameter tuning. (The snippet above trains for only
2 epochs for brevity; train for longer, with early stopping, to reproduce
these numbers.) And that is for a model with fewer than 100k parameters.
Of course, parameter count and accuracy could be improved by a hyperparameter
search and a more sophisticated learning rate schedule, or a different optimizer.
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/timeseries_transformer_classification)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/timeseries_transformer_classification).
"""

@@ -0,0 +1,272 @@
"""
Title: Image segmentation with a U-Net-like architecture
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/20
Last modified: 2020/04/20
Description: Image segmentation model trained from scratch on the Oxford Pets dataset.
Accelerator: GPU
"""
"""
## Download the data
"""
"""shell
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz
tar -xf images.tar.gz
tar -xf annotations.tar.gz
"""
"""
## Prepare paths of input images and target segmentation masks
"""
import os

input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)
num_classes = 3
batch_size = 32

input_img_paths = sorted(
    [
        os.path.join(input_dir, fname)
        for fname in os.listdir(input_dir)
        if fname.endswith(".jpg")
    ]
)
target_img_paths = sorted(
    [
        os.path.join(target_dir, fname)
        for fname in os.listdir(target_dir)
        if fname.endswith(".png") and not fname.startswith(".")
    ]
)
print("Number of samples:", len(input_img_paths))
for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)
"""
## What does one input image and corresponding segmentation mask look like?
"""
from IPython.display import Image, display
from keras_core.utils import load_img
from PIL import ImageOps
# Display input image #10 (at index 9)
display(Image(filename=input_img_paths[9]))
# Display auto-contrast version of corresponding target (per-pixel categories)
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)
"""
## Prepare dataset to load & vectorize batches of data
"""
import keras_core as keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io


def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Returns a TF Dataset."""

    def load_img_masks(input_img_path, target_img_path):
        input_img = tf_io.read_file(input_img_path)
        input_img = tf_io.decode_png(input_img, channels=3)
        input_img = tf_image.resize(input_img, img_size)
        input_img = tf_image.convert_image_dtype(input_img, "float32")

        target_img = tf_io.read_file(target_img_path)
        target_img = tf_io.decode_png(target_img, channels=1)
        target_img = tf_image.resize(target_img, img_size, method="nearest")
        target_img = tf_image.convert_image_dtype(target_img, "uint8")

        # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
        target_img -= 1
        return input_img, target_img

    # For faster debugging, limit the size of data
    if max_dataset_len:
        input_img_paths = input_img_paths[:max_dataset_len]
        target_img_paths = target_img_paths[:max_dataset_len]
    dataset = tf_data.Dataset.from_tensor_slices(
        (input_img_paths, target_img_paths)
    )
    dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
    return dataset.batch(batch_size)
"""
## Prepare U-Net Xception-style model
"""
from keras_core import layers


def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(
        num_classes, 3, activation="softmax", padding="same"
    )(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Build model
model = get_model(img_size, num_classes)
model.summary()
"""
## Set aside a validation split
"""
import random

# Split our img paths into a training and a validation set
val_samples = 1000
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]

# Instantiate a dataset for each split.
# `max_dataset_len` limits the number of input files for faster epoch training time;
# remove the arg to train on the full dataset.
train_dataset = get_dataset(
    batch_size,
    img_size,
    train_input_img_paths,
    train_target_img_paths,
    max_dataset_len=1000,
)
valid_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
"""
## Train the model
"""
# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")

callbacks = [
    keras.callbacks.ModelCheckpoint(
        "oxford_segmentation.keras", save_best_only=True
    )
]

# Train the model, doing validation at the end of each epoch.
epochs = 15
model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=callbacks,
)
"""
## Visualize predictions
"""
# Generate predictions for all images in the validation set
val_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds = model.predict(val_dataset)


def display_mask(i):
    """Quick utility to display a model's prediction."""
    mask = np.argmax(val_preds[i], axis=-1)
    mask = np.expand_dims(mask, axis=-1)
    img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
    display(img)


# Display results for validation image #10
i = 10

# Display input image
display(Image(filename=val_input_img_paths[i]))

# Display ground-truth target mask
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# Display mask predicted by our model
display_mask(i)  # Note that the model only sees inputs at 160x160.

@@ -154,7 +154,13 @@ class UpSampling2D(Layer):
         if data_format == "channels_first":
             x = ops.transpose(x, [0, 2, 3, 1])
-        x = ops.image.resize(x, new_shape, method=interpolation)
+        # https://github.com/keras-team/keras-core/issues/294
+        # Use `ops.repeat` for `nearest` interpolation
+        if interpolation == "nearest":
+            x = ops.repeat(x, height_factor, axis=1)
+            x = ops.repeat(x, width_factor, axis=2)
+        else:
+            x = ops.image.resize(x, new_shape, method=interpolation)
         if data_format == "channels_first":
             x = ops.transpose(x, [0, 3, 1, 2])
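
For intuition: nearest-neighbor upsampling by integer factors is exactly a repeat
along each spatial axis, which is why the patch can bypass `ops.image.resize`
entirely for this mode. A minimal NumPy sketch of the equivalence (illustrative
only, not keras-core code):

import numpy as np

# One 2x2 single-channel image in channels_last layout: (batch, height, width, channels).
x = np.array([[1, 2], [3, 4]], dtype=np.float32).reshape(1, 2, 2, 1)
height_factor, width_factor = 2, 2

# Nearest-neighbor upsampling by integer factors == repeating rows, then columns.
up = np.repeat(x, height_factor, axis=1)
up = np.repeat(up, width_factor, axis=2)
print(up[0, :, :, 0])
# [[1. 1. 2. 2.]
#  [1. 1. 2. 2.]
#  [3. 3. 4. 4.]
#  [3. 3. 4. 4.]]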