diff --git a/examples/keras_io/structured_data/collaborative_filtering_movielens.py b/examples/keras_io/structured_data/collaborative_filtering_movielens.py
new file mode 100644
index 000000000..2ff859161
--- /dev/null
+++ b/examples/keras_io/structured_data/collaborative_filtering_movielens.py
@@ -0,0 +1,236 @@
+"""
+Title: Collaborative Filtering for Movie Recommendations
+Author: [Siddhartha Banerjee](https://twitter.com/sidd2006)
+Date created: 2020/05/24
+Last modified: 2020/05/24
+Description: Recommending movies using a model trained on the Movielens dataset.
+Accelerator: GPU
+"""
+"""
+## Introduction
+
+This example demonstrates
+[Collaborative filtering](https://en.wikipedia.org/wiki/Collaborative_filtering)
+using the [Movielens dataset](https://www.kaggle.com/c/movielens-100k)
+to recommend movies to users.
+The MovieLens ratings dataset lists the ratings given by a set of users to a set of movies.
+Our goal is to be able to predict ratings for movies a user has not yet watched.
+The movies with the highest predicted ratings can then be recommended to the user.
+
+The steps in the model are as follows:
+
+1. Map user ID to a "user vector" via an embedding matrix
+2. Map movie ID to a "movie vector" via an embedding matrix
+3. Compute the dot product between the user vector and movie vector, to obtain
+a match score between the user and the movie (predicted rating).
+4. Train the embeddings via gradient descent using all known user-movie pairs.
+
+**References:**
+
+- [Collaborative Filtering](https://dl.acm.org/doi/pdf/10.1145/371920.372071)
+- [Neural Collaborative Filtering](https://dl.acm.org/doi/pdf/10.1145/3038912.3052569)
+"""
+
+import pandas as pd
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+from zipfile import ZipFile
+
+import keras_core as keras
+from keras_core import layers
+import keras_core.operations as ops
+"""
+## First, load the data and apply preprocessing
+"""
+
+# Download the actual data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip
+# Use the ratings.csv file
+movielens_data_file_url = (
+    "http://files.grouplens.org/datasets/movielens/ml-latest-small.zip"
+)
+movielens_zipped_file = keras.utils.get_file(
+    "ml-latest-small.zip", movielens_data_file_url, extract=False
+)
+keras_datasets_path = Path(movielens_zipped_file).parents[0]
+movielens_dir = keras_datasets_path / "ml-latest-small"
+
+# Only extract the data the first time the script is run.
+if not movielens_dir.exists():
+    with ZipFile(movielens_zipped_file, "r") as zip:
+        # Extract files
+        print("Extracting all the files now...")
+        zip.extractall(path=keras_datasets_path)
+        print("Done!")
+
+ratings_file = movielens_dir / "ratings.csv"
+df = pd.read_csv(ratings_file)
+
+"""
+First, we need to perform some preprocessing to encode users and movies as integer indices.
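+For example, if the raw `userId` values are 3, 7 and 12 (made-up values for
+illustration), they are encoded as 0, 1 and 2 in the order they first appear, and the
+reverse dictionaries let us map an encoded index back to the original ID later on.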
+"""
+user_ids = df["userId"].unique().tolist()
+user2user_encoded = {x: i for i, x in enumerate(user_ids)}
+userencoded2user = {i: x for i, x in enumerate(user_ids)}
+movie_ids = df["movieId"].unique().tolist()
+movie2movie_encoded = {x: i for i, x in enumerate(movie_ids)}
+movie_encoded2movie = {i: x for i, x in enumerate(movie_ids)}
+df["user"] = df["userId"].map(user2user_encoded)
+df["movie"] = df["movieId"].map(movie2movie_encoded)
+
+num_users = len(user2user_encoded)
+num_movies = len(movie_encoded2movie)
+df["rating"] = df["rating"].values.astype(np.float32)
+# min and max ratings will be used to normalize the ratings later
+min_rating = min(df["rating"])
+max_rating = max(df["rating"])
+
+print(
+    "Number of users: {}, Number of Movies: {}, Min rating: {}, Max rating: {}".format(
+        num_users, num_movies, min_rating, max_rating
+    )
+)
+
+"""
+## Prepare training and validation data
+"""
+df = df.sample(frac=1, random_state=42)
+x = df[["user", "movie"]].values
+# Normalize the targets between 0 and 1. Makes it easy to train.
+y = df["rating"].apply(lambda x: (x - min_rating) / (max_rating - min_rating)).values
+# Assuming training on 90% of the data and validating on 10%.
+train_indices = int(0.9 * df.shape[0])
+x_train, x_val, y_train, y_val = (
+    x[:train_indices],
+    x[train_indices:],
+    y[:train_indices],
+    y[train_indices:],
+)
+
+"""
+## Create the model
+
+We embed both users and movies into 50-dimensional vectors.
+
+The model computes a match score between user and movie embeddings via a dot product,
+and adds a per-movie and per-user bias. The match score is scaled to the `[0, 1]`
+interval via a sigmoid (since our ratings are normalized to this range).
+"""
+EMBEDDING_SIZE = 50
+
+
+class RecommenderNet(keras.Model):
+    def __init__(self, num_users, num_movies, embedding_size, **kwargs):
+        super().__init__(**kwargs)
+        self.num_users = num_users
+        self.num_movies = num_movies
+        self.embedding_size = embedding_size
+        self.user_embedding = layers.Embedding(
+            num_users,
+            embedding_size,
+            embeddings_initializer="he_normal",
+            embeddings_regularizer=keras.regularizers.l2(1e-6),
+        )
+        self.user_bias = layers.Embedding(num_users, 1)
+        self.movie_embedding = layers.Embedding(
+            num_movies,
+            embedding_size,
+            embeddings_initializer="he_normal",
+            embeddings_regularizer=keras.regularizers.l2(1e-6),
+        )
+        self.movie_bias = layers.Embedding(num_movies, 1)
+
+    def call(self, inputs):
+        user_vector = self.user_embedding(inputs[:, 0])
+        user_bias = self.user_bias(inputs[:, 0])
+        movie_vector = self.movie_embedding(inputs[:, 1])
+        movie_bias = self.movie_bias(inputs[:, 1])
+        dot_user_movie = ops.tensordot(user_vector, movie_vector, 2)
+        # Add all the components (including bias)
+        x = dot_user_movie + user_bias + movie_bias
+        # The sigmoid activation forces the rating to be between 0 and 1
+        return ops.nn.sigmoid(x)
+
+
+model = RecommenderNet(num_users, num_movies, EMBEDDING_SIZE)
+model.compile(
+    loss=keras.losses.BinaryCrossentropy(),
+    optimizer=keras.optimizers.Adam(learning_rate=0.001),
+)
+
+"""
+## Train the model based on the data split
+"""
+history = model.fit(
+    x=x_train,
+    y=y_train,
+    batch_size=64,
+    epochs=5,
+    verbose=1,
+    validation_data=(x_val, y_val),
+)
+
+"""
+## Plot training and validation loss
+"""
+plt.plot(history.history["loss"])
+plt.plot(history.history["val_loss"])
+plt.title("model loss")
+plt.ylabel("loss")
+plt.xlabel("epoch")
+plt.legend(["train", "validation"], loc="upper left")
+plt.show()

+"""
+## Show top 10 movie recommendations to a user
+"""
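+
+"""
+The model outputs scores in the normalized `[0, 1]` interval, which is all we need
+in order to rank candidate movies. If you also want to report predictions on the
+original star scale, you can invert the normalization. The small helper below is a
+sketch and is not used by the ranking code that follows:
+"""
+
+
+def denormalize_rating(score):
+    # Invert the (x - min_rating) / (max_rating - min_rating) normalization.
+    return score * (max_rating - min_rating) + min_rating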
+movie_df = pd.read_csv(movielens_dir / "movies.csv") + +# Let us get a user and see the top recommendations. +user_id = df.userId.sample(1).iloc[0] +movies_watched_by_user = df[df.userId == user_id] +movies_not_watched = movie_df[ + ~movie_df["movieId"].isin(movies_watched_by_user.movieId.values) +]["movieId"] +movies_not_watched = list( + set(movies_not_watched).intersection(set(movie2movie_encoded.keys())) +) +movies_not_watched = [[movie2movie_encoded.get(x)] for x in movies_not_watched] +user_encoder = user2user_encoded.get(user_id) +user_movie_array = np.hstack( + ([[user_encoder]] * len(movies_not_watched), movies_not_watched) +) +ratings = model.predict(user_movie_array).flatten() +top_ratings_indices = ratings.argsort()[-10:][::-1] +recommended_movie_ids = [ + movie_encoded2movie.get(movies_not_watched[x][0]) for x in top_ratings_indices +] + +print("Showing recommendations for user: {}".format(user_id)) +print("====" * 9) +print("Movies with high ratings from user") +print("----" * 8) +top_movies_user = ( + movies_watched_by_user.sort_values(by="rating", ascending=False) + .head(5) + .movieId.values +) +movie_df_rows = movie_df[movie_df["movieId"].isin(top_movies_user)] +for row in movie_df_rows.itertuples(): + print(row.title, ":", row.genres) + +print("----" * 8) +print("Top 10 movie recommendations") +print("----" * 8) +recommended_movies = movie_df[movie_df["movieId"].isin(recommended_movie_ids)] +for row in recommended_movies.itertuples(): + print(row.title, ":", row.genres) + +""" +**Example available on HuggingFace** + +| Trained Model | Demo | +| :--: | :--: | +| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-Collaborative%20Filtering-black.svg)](https://huggingface.co/keras-io/collaborative-filtering-movielens) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-Collaborative%20Filtering-black.svg)](https://huggingface.co/spaces/keras-io/collaborative-filtering-movielens) | +""" diff --git a/examples/keras_io/timeseries/timeseries_classification_transformer.py b/examples/keras_io/timeseries/timeseries_classification_transformer.py new file mode 100644 index 000000000..66b7bc738 --- /dev/null +++ b/examples/keras_io/timeseries/timeseries_classification_transformer.py @@ -0,0 +1,176 @@ +""" +Title: Timeseries classification with a Transformer model +Author: [Theodoros Ntakouris](https://github.com/ntakouris) +Date created: 2021/06/25 +Last modified: 2021/08/05 +Description: This notebook demonstrates how to do timeseries classification using a Transformer model. +Accelerator: GPU +""" + + +""" +## Introduction + +This is the Transformer architecture from +[Attention Is All You Need](https://arxiv.org/abs/1706.03762), +applied to timeseries instead of natural language. + +This example requires TensorFlow 2.4 or higher. + +## Load the dataset + +We are going to use the same dataset and preprocessing as the +[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch) +example. 
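+Each row of these `.tsv` files stores one timeseries: the first column holds the
+class label (-1 or 1, remapped to 0/1 below) and the remaining columns hold the
+values of the univariate series, which is what the `readucr` helper below assumes.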
+""" + +import numpy as np + + +def readucr(filename): + data = np.loadtxt(filename, delimiter="\t") + y = data[:, 0] + x = data[:, 1:] + return x, y.astype(int) + + +root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/" + +x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv") +x_test, y_test = readucr(root_url + "FordA_TEST.tsv") + +x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1)) +x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1)) + +n_classes = len(np.unique(y_train)) + +idx = np.random.permutation(len(x_train)) +x_train = x_train[idx] +y_train = y_train[idx] + +y_train[y_train == -1] = 0 +y_test[y_test == -1] = 0 + +""" +## Build the model + +Our model processes a tensor of shape `(batch size, sequence length, features)`, +where `sequence length` is the number of time steps and `features` is each input +timeseries. + +You can replace your classification RNN layers with this one: the +inputs are fully compatible! +""" + +import keras_core as keras +from keras_core import layers + +""" +We include residual connections, layer normalization, and dropout. +The resulting layer can be stacked multiple times. + +The projection layers are implemented through `keras.layers.Conv1D`. +""" + + +def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0): + # Attention and Normalization + x = layers.MultiHeadAttention( + key_dim=head_size, num_heads=num_heads, dropout=dropout + )(inputs, inputs) + x = layers.Dropout(dropout)(x) + x = layers.LayerNormalization(epsilon=1e-6)(x) + res = x + inputs + + # Feed Forward Part + x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res) + x = layers.Dropout(dropout)(x) + x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x) + x = layers.LayerNormalization(epsilon=1e-6)(x) + return x + res + + +""" +The main part of our model is now complete. We can stack multiple of those +`transformer_encoder` blocks and we can also proceed to add the final +Multi-Layer Perceptron classification head. Apart from a stack of `Dense` +layers, we need to reduce the output tensor of the `TransformerEncoder` part of +our model down to a vector of features for each data point in the current +batch. A common way to achieve this is to use a pooling layer. For +this example, a `GlobalAveragePooling1D` layer is sufficient. 
+"""
+
+
+def build_model(
+    input_shape,
+    head_size,
+    num_heads,
+    ff_dim,
+    num_transformer_blocks,
+    mlp_units,
+    dropout=0,
+    mlp_dropout=0,
+):
+    inputs = keras.Input(shape=input_shape)
+    x = inputs
+    for _ in range(num_transformer_blocks):
+        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)
+
+    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
+    for dim in mlp_units:
+        x = layers.Dense(dim, activation="relu")(x)
+        x = layers.Dropout(mlp_dropout)(x)
+    outputs = layers.Dense(n_classes, activation="softmax")(x)
+    return keras.Model(inputs, outputs)
+
+
+"""
+## Train and evaluate
+"""
+
+input_shape = x_train.shape[1:]
+
+model = build_model(
+    input_shape,
+    head_size=256,
+    num_heads=4,
+    ff_dim=4,
+    num_transformer_blocks=4,
+    mlp_units=[128],
+    mlp_dropout=0.4,
+    dropout=0.25,
+)
+
+model.compile(
+    loss="sparse_categorical_crossentropy",
+    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+    metrics=["sparse_categorical_accuracy"],
+)
+model.summary()
+
+callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
+
+model.fit(
+    x_train,
+    y_train,
+    validation_split=0.2,
+    epochs=2,
+    batch_size=64,
+    callbacks=callbacks,
+)
+
+model.evaluate(x_test, y_test, verbose=1)
+
+"""
+## Conclusions
+
+In about 110-120 epochs (25s each on Colab), the model reaches a training
+accuracy of ~0.95, a validation accuracy of ~0.84 and a test
+accuracy of ~0.85, without hyperparameter tuning. And that is for a model
+with less than 100k parameters. Of course, parameter count and accuracy could be
+improved by a hyperparameter search and a more sophisticated learning rate
+schedule, or a different optimizer.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/timeseries_transformer_classification)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/timeseries_transformer_classification).
+"""
diff --git a/examples/keras_io/vision/oxford_pets_image_segmentation.py b/examples/keras_io/vision/oxford_pets_image_segmentation.py
new file mode 100644
index 000000000..58387da82
--- /dev/null
+++ b/examples/keras_io/vision/oxford_pets_image_segmentation.py
@@ -0,0 +1,272 @@
+"""
+Title: Image segmentation with a U-Net-like architecture
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2019/03/20
+Last modified: 2020/04/20
+Description: Image segmentation model trained from scratch on the Oxford Pets dataset.
+Accelerator: GPU
+"""
+"""
+## Download the data
+"""
+
+"""shell
+curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
+curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz
+
+tar -xf images.tar.gz
+tar -xf annotations.tar.gz
+"""
+
+"""
+## Prepare paths of input images and target segmentation masks
+"""
+
+import os
+
+input_dir = "images/"
+target_dir = "annotations/trimaps/"
+img_size = (160, 160)
+num_classes = 3
+batch_size = 32
+
+input_img_paths = sorted(
+    [
+        os.path.join(input_dir, fname)
+        for fname in os.listdir(input_dir)
+        if fname.endswith(".jpg")
+    ]
+)
+target_img_paths = sorted(
+    [
+        os.path.join(target_dir, fname)
+        for fname in os.listdir(target_dir)
+        if fname.endswith(".png") and not fname.startswith(".")
+    ]
+)
+
+print("Number of samples:", len(input_img_paths))
+
+for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
+    print(input_path, "|", target_path)
+
+"""
+## What does one input image and corresponding segmentation mask look like?
+"""
+
+from IPython.display import Image, display
+from keras_core.utils import load_img
+from PIL import ImageOps
+
+# Display input image #10
+display(Image(filename=input_img_paths[9]))
+
+# Display auto-contrast version of corresponding target (per-pixel categories)
+img = ImageOps.autocontrast(load_img(target_img_paths[9]))
+display(img)
+
+"""
+## Prepare dataset to load & vectorize batches of data
+"""
+
+import keras_core as keras
+import numpy as np
+from tensorflow import data as tf_data
+from tensorflow import image as tf_image
+from tensorflow import io as tf_io
+
+
+def get_dataset(
+    batch_size,
+    img_size,
+    input_img_paths,
+    target_img_paths,
+    max_dataset_len=None,
+):
+    """Returns a TF Dataset."""
+
+    def load_img_masks(input_img_path, target_img_path):
+        input_img = tf_io.read_file(input_img_path)
+        input_img = tf_io.decode_png(input_img, channels=3)
+        input_img = tf_image.resize(input_img, img_size)
+        input_img = tf_image.convert_image_dtype(input_img, "float32")
+
+        target_img = tf_io.read_file(target_img_path)
+        target_img = tf_io.decode_png(target_img, channels=1)
+        target_img = tf_image.resize(target_img, img_size, method="nearest")
+        target_img = tf_image.convert_image_dtype(target_img, "uint8")
+
+        # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
+        target_img -= 1
+        return input_img, target_img
+
+    # For faster debugging, limit the size of data
+    if max_dataset_len:
+        input_img_paths = input_img_paths[:max_dataset_len]
+        target_img_paths = target_img_paths[:max_dataset_len]
+    dataset = tf_data.Dataset.from_tensor_slices(
+        (input_img_paths, target_img_paths)
+    )
+    dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
+    return dataset.batch(batch_size)
+
+
+"""
+## Prepare U-Net Xception-style model
+"""
+
+from keras_core import layers
+
+
+def get_model(img_size, num_classes):
+    inputs = keras.Input(shape=img_size + (3,))
+
+    ### [First half of the network: downsampling inputs] ###
+
+    # Entry block
+    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
+    x = layers.BatchNormalization()(x)
+    x = layers.Activation("relu")(x)
+
+    previous_block_activation = x  # Set aside residual
+
+    # Blocks 1, 2, 3 are identical apart from the feature depth.
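+    # Each block applies two SeparableConv2D + BatchNormalization stages, then
+    # halves the spatial resolution with a strided MaxPooling2D; the residual is
+    # projected with a strided 1x1 Conv2D so its shape matches before the add.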
+    for filters in [64, 128, 256]:
+        x = layers.Activation("relu")(x)
+        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
+        x = layers.BatchNormalization()(x)
+
+        x = layers.Activation("relu")(x)
+        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
+        x = layers.BatchNormalization()(x)
+
+        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
+
+        # Project residual
+        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
+            previous_block_activation
+        )
+        x = layers.add([x, residual])  # Add back residual
+        previous_block_activation = x  # Set aside next residual
+
+    ### [Second half of the network: upsampling inputs] ###
+
+    for filters in [256, 128, 64, 32]:
+        x = layers.Activation("relu")(x)
+        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
+        x = layers.BatchNormalization()(x)
+
+        x = layers.Activation("relu")(x)
+        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
+        x = layers.BatchNormalization()(x)
+
+        x = layers.UpSampling2D(2)(x)
+
+        # Project residual
+        residual = layers.UpSampling2D(2)(previous_block_activation)
+        residual = layers.Conv2D(filters, 1, padding="same")(residual)
+        x = layers.add([x, residual])  # Add back residual
+        previous_block_activation = x  # Set aside next residual
+
+    # Add a per-pixel classification layer
+    outputs = layers.Conv2D(
+        num_classes, 3, activation="softmax", padding="same"
+    )(x)
+
+    # Define the model
+    model = keras.Model(inputs, outputs)
+    return model
+
+
+# Build model
+model = get_model(img_size, num_classes)
+model.summary()
+
+"""
+## Set aside a validation split
+"""
+
+import random
+
+# Split our img paths into a training and a validation set
+val_samples = 1000
+random.Random(1337).shuffle(input_img_paths)
+random.Random(1337).shuffle(target_img_paths)
+train_input_img_paths = input_img_paths[:-val_samples]
+train_target_img_paths = target_img_paths[:-val_samples]
+val_input_img_paths = input_img_paths[-val_samples:]
+val_target_img_paths = target_img_paths[-val_samples:]
+
+# Instantiate dataset for each split
+# Limit the input files to `max_dataset_len` for faster epoch training time.
+# Remove the `max_dataset_len` arg when running with the full dataset.
+train_dataset = get_dataset(
+    batch_size,
+    img_size,
+    train_input_img_paths,
+    train_target_img_paths,
+    max_dataset_len=1000,
+)
+valid_dataset = get_dataset(
+    batch_size, img_size, val_input_img_paths, val_target_img_paths
+)
+
+"""
+## Train the model
+"""
+
+# Configure the model for training.
+# We use the "sparse" version of categorical_crossentropy
+# because our target data is integers.
+model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")
+
+callbacks = [
+    keras.callbacks.ModelCheckpoint(
+        "oxford_segmentation.keras", save_best_only=True
+    )
+]
+
+# Train the model, doing validation at the end of each epoch.
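+# The ModelCheckpoint callback above writes `oxford_segmentation.keras` only when
+# the validation loss improves (`save_best_only=True` monitors `val_loss` by default).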
+epochs = 15
+model.fit(
+    train_dataset,
+    epochs=epochs,
+    validation_data=valid_dataset,
+    callbacks=callbacks,
+)
+
+"""
+## Visualize predictions
+"""
+
+# Generate predictions for all images in the validation set
+
+val_dataset = get_dataset(
+    batch_size, img_size, val_input_img_paths, val_target_img_paths
+)
+val_preds = model.predict(val_dataset)
+
+
+def display_mask(i):
+    """Quick utility to display a model's prediction."""
+    mask = np.argmax(val_preds[i], axis=-1)
+    mask = np.expand_dims(mask, axis=-1)
+    img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
+    display(img)
+
+
+# Display results for validation image #10
+i = 10
+
+# Display input image
+display(Image(filename=val_input_img_paths[i]))
+
+# Display ground-truth target mask
+img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
+display(img)
+
+# Display mask predicted by our model
+display_mask(i)  # Note that the model only sees inputs at 160x160.
diff --git a/keras_core/layers/reshaping/up_sampling2d.py b/keras_core/layers/reshaping/up_sampling2d.py
index dd2e5f869..840c21c4a 100644
--- a/keras_core/layers/reshaping/up_sampling2d.py
+++ b/keras_core/layers/reshaping/up_sampling2d.py
@@ -154,7 +154,13 @@ class UpSampling2D(Layer):
         if data_format == "channels_first":
             x = ops.transpose(x, [0, 2, 3, 1])
 
-        x = ops.image.resize(x, new_shape, method=interpolation)
+        # https://github.com/keras-team/keras-core/issues/294
+        # Use `ops.repeat` for `nearest` interpolation
+        if interpolation == "nearest":
+            x = ops.repeat(x, height_factor, axis=1)
+            x = ops.repeat(x, width_factor, axis=2)
+        else:
+            x = ops.image.resize(x, new_shape, method=interpolation)
 
         if data_format == "channels_first":
             x = ops.transpose(x, [0, 3, 1, 2])