add Model interpretability with integrated gradients example (#238)
This commit is contained in:
parent
df7ce56cf1
commit
02dcb428f2
497
examples/keras_io/tensorflow/vision/integrated_gradients.py
Normal file
497
examples/keras_io/tensorflow/vision/integrated_gradients.py
Normal file
@ -0,0 +1,497 @@
|
||||
"""
|
||||
Title: Model interpretability with Integrated Gradients
|
||||
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
|
||||
Date created: 2020/06/02
|
||||
Last modified: 2020/06/02
|
||||
Description: How to obtain integrated gradients for a classification model.
|
||||
Accelerator: NONE
|
||||
"""
|
||||
|
||||
"""
|
||||
## Integrated Gradients
|
||||
|
||||
[Integrated Gradients](https://arxiv.org/abs/1703.01365) is a technique for
|
||||
attributing a classification model's prediction to its input features. It is
|
||||
a model interpretability technique: you can use it to visualize the relationship
|
||||
between input features and model predictions.
|
||||
|
||||
Integrated Gradients is a variation on computing
|
||||
the gradient of the prediction output with regard to features of the input.
|
||||
To compute integrated gradients, we need to perform the following steps:
|
||||
|
||||
1. Identify the input and the output. In our case, the input is an image and the
|
||||
output is the last layer of our model (dense layer with softmax activation).
|
||||
|
||||
2. Compute which features are important to a neural network
|
||||
when making a prediction on a particular data point. To identify these features, we
|
||||
need to choose a baseline input. A baseline input can be a black image (all pixel
|
||||
values set to zero) or random noise. The shape of the baseline input needs to be
|
||||
the same as our input image, e.g. (299, 299, 3).
|
||||
|
||||
3. Interpolate the baseline for a given number of steps. The number of steps represents
|
||||
the steps we need in the gradient approximation for a given input image. The number of
|
||||
steps is a hyperparameter. The authors recommend using anywhere between
|
||||
20 and 1000 steps.
|
||||
|
||||
4. Preprocess these interpolated images and do a forward pass.
|
||||
5. Get the gradients for these interpolated images.
|
||||
6. Approximate the gradients integral using the trapezoidal rule.
|
||||
|
||||
To read in-depth about integrated gradients and why this method works,
|
||||
consider reading this excellent
|
||||
[article](https://distill.pub/2020/attribution-baselines/).
|
||||
|
||||
**References:**
|
||||
|
||||
- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
|
||||
- [Original implementation](https://github.com/ankurtaly/Integrated-Gradients)
|
||||
"""
|
||||
|
||||
"""
|
||||
## Setup
|
||||
"""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy import ndimage
|
||||
from IPython.display import Image, display
|
||||
|
||||
import tensorflow as tf
|
||||
import keras_core as keras
|
||||
from keras_core import layers
|
||||
from keras_core.applications import xception
|
||||
|
||||
keras.config.disable_traceback_filtering()
|
||||
|
||||
|
||||
# Size of the input image: (height, width, channels) as expected by Xception.
img_size = (299, 299, 3)

# Load Xception model with imagenet weights
model = xception.Xception(weights="imagenet")

# The local path to our target image (downloaded once and cached by Keras).
img_path = keras.utils.get_file("elephant.jpg", "https://i.imgur.com/Bvro0YD.png")
display(Image(img_path))
|
||||
|
||||
"""
|
||||
## Integrated Gradients algorithm
|
||||
"""
|
||||
|
||||
|
||||
def get_img_array(img_path, size=(299, 299)):
    """Load the image at ``img_path`` and return it as a (1, H, W, 3) batch.

    Args:
        img_path: Path to the image file on disk.
        size: Target (height, width) the image is resized to.

    Returns:
        float32 NumPy array of shape ``(1, *size, 3)``.
    """
    # Read the file as a PIL image, resized to `size`.
    pil_img = keras.utils.load_img(img_path, target_size=size)
    # Convert to a float32 array of shape (H, W, 3), then prepend a batch
    # axis so the result can be fed straight to the model.
    return np.expand_dims(keras.utils.img_to_array(pil_img), axis=0)
|
||||
|
||||
|
||||
def get_gradients(img_input, top_pred_idx):
    """Computes the gradients of outputs w.r.t input image.

    Args:
        img_input: 4D image tensor
        top_pred_idx: Predicted label for the input image

    Returns:
        Gradients of the predictions w.r.t img_input
    """
    inputs = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        # Explicitly watch the input tensor (it is not a trainable variable).
        tape.watch(inputs)
        # Keep only the score of the predicted class for every batch item.
        class_score = model(inputs)[:, top_pred_idx]

    return tape.gradient(class_score, inputs)
|
||||
|
||||
|
||||
def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Computes Integrated Gradients for a predicted label.

    Args:
        img_input (ndarray): Original image, shape (H, W, C)
        top_pred_idx: Predicted label for the input image
        baseline (ndarray): The baseline image to start with for interpolation
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. These
            steps alone determine the integral approximation error. By default,
            num_steps is set to 50.

    Returns:
        Integrated gradients w.r.t input image
    """
    img_input = img_input.astype(np.float32)

    # If baseline is not provided, start with a black image having the same
    # shape as the input image. Deriving the shape from the input (rather
    # than the module-level `img_size`) keeps the function correct for any
    # input resolution.
    if baseline is None:
        baseline = np.zeros(img_input.shape, dtype=np.float32)
    else:
        baseline = baseline.astype(np.float32)

    # 1. Linearly interpolate between the baseline and the input,
    # producing num_steps + 1 images (endpoints included).
    interpolated_image = np.array(
        [
            baseline + (step / num_steps) * (img_input - baseline)
            for step in range(num_steps + 1)
        ],
        dtype=np.float32,
    )

    # 2. Preprocess the interpolated images the same way the model expects.
    interpolated_image = xception.preprocess_input(interpolated_image)

    # 3. Get the gradients for each interpolated image (one forward/backward
    # pass per image; the loop index was unused, so iterate directly).
    grads = []
    for img in interpolated_image:
        img = tf.expand_dims(img, axis=0)
        grad = get_gradients(img, top_pred_idx=top_pred_idx)
        grads.append(grad[0])
    grads = tf.convert_to_tensor(grads, dtype=tf.float32)

    # 4. Approximate the integral using the trapezoidal rule
    grads = (grads[:-1] + grads[1:]) / 2.0
    avg_grads = tf.reduce_mean(grads, axis=0)

    # 5. Scale the averaged gradients by the input-baseline difference
    integrated_grads = (img_input - baseline) * avg_grads
    return integrated_grads
|
||||
|
||||
|
||||
def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Averages integrated gradients over several random baseline images.

    Args:
        img_input (ndarray): 3D image
        top_pred_idx: Predicted label for the input image
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients.
            By default, num_steps is set to 50.
        num_runs: number of baseline images to generate

    Returns:
        Averaged integrated gradients for `num_runs` baseline images
    """
    # Compute IG once per random baseline: uniform noise scaled to [0, 255).
    all_igrads = [
        get_integrated_gradients(
            img_input=img_input,
            top_pred_idx=top_pred_idx,
            baseline=np.random.random(img_size) * 255,
            num_steps=num_steps,
        )
        for _ in range(num_runs)
    ]

    # Average the per-baseline attributions into a single map.
    return tf.reduce_mean(tf.convert_to_tensor(all_igrads), axis=0)
|
||||
|
||||
|
||||
"""
|
||||
## Helper class for visualizing gradients and integrated gradients
|
||||
"""
|
||||
|
||||
|
||||
class GradVisualizer:
    """Plot gradients of the outputs w.r.t an input image."""

    def __init__(self, positive_channel=None, negative_channel=None):
        # RGB channel used to tint positive attributions (default green).
        if positive_channel is None:
            self.positive_channel = [0, 255, 0]
        else:
            self.positive_channel = positive_channel

        # RGB channel used to tint negative attributions (default red).
        if negative_channel is None:
            self.negative_channel = [255, 0, 0]
        else:
            self.negative_channel = negative_channel

    def apply_polarity(self, attributions, polarity):
        """Keep only the requested sign of the attributions (clipped to unit range)."""
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        else:
            return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        """Rescale attributions linearly so the clip percentiles map to [lower_end, 1].

        Values below `lower_end` after the transform are zeroed out, and the
        result is clipped to [0, 1].
        """
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b such that
        # f(m) = 1.0 and f(e) = lower_end
        transformed_attributions = (1 - lower_end) * (np.abs(attributions) - e) / (
            m - e
        ) + lower_end

        # 3. Make sure that the sign of transformed attributions is the same as original attributions
        transformed_attributions *= np.sign(attributions)

        # 4. Only keep values that are bigger than the lower_end
        transformed_attributions *= transformed_attributions >= lower_end

        # 5. Clip values and return
        transformed_attributions = np.clip(transformed_attributions, 0.0, 1.0)
        return transformed_attributions

    def get_thresholded_attributions(self, attributions, percentage):
        """Return the |attribution| value at which the cumulative contribution
        (sorted largest-to-smallest) first reaches `percentage` of the total.

        NOTE(review): `total` is the sum of the raw (signed) attributions while
        the cumulative sum uses absolute values — presumably intentional from
        the original example, but worth confirming for inputs with strong
        negative attributions.
        """
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten the attributions
        flatten_attr = attributions.flatten()

        # 2. Get the sum of the attributions
        total = np.sum(flatten_attr)

        # 3. Sort the attributions from largest to smallest.
        sorted_attributions = np.sort(np.abs(flatten_attr))[::-1]

        # 4. Calculate the percentage of the total sum that each attribution
        # and the values about it contribute.
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total

        # 5. Threshold the attributions by the percentage
        indices_to_consider = np.where(cum_sum >= percentage)[0][0]

        # 6. Select the desired attributions and return
        attributions = sorted_attributions[indices_to_consider]
        return attributions

    def binarize(self, attributions, threshold=0.001):
        """Boolean mask: True wherever an attribution exceeds `threshold`."""
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        """Remove speckle noise via greyscale closing followed by opening.

        NOTE(review): the array default `structure` is evaluated once at
        definition time and shared across calls — harmless here because it is
        only read, never mutated.
        """
        closed = ndimage.grey_closing(attributions, structure=structure)
        opened = ndimage.grey_opening(closed, structure=structure)
        return opened

    def draw_outlines(
        self, attributions, percentage=90, connected_component_structure=np.ones((3, 3))
    ):
        """Return a binary mask outlining the top connected components that
        together account for `percentage` of the total attribution mass
        (capped at the three largest components).
        """
        # 1. Binarize the attributions.
        attributions = self.binarize(attributions)

        # 2. Fill the gaps
        attributions = ndimage.binary_fill_holes(attributions)

        # 3. Compute connected components
        connected_components, num_comp = ndimage.label(
            attributions, structure=connected_component_structure
        )

        # 4. Sum up the attributions for each component
        total = np.sum(attributions[connected_components > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = connected_components == comp
            component_sum = np.sum(attributions[mask])
            component_sums.append((component_sum, mask))

        # 5. Compute the percentage of top components to keep
        sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
        sorted_sums = list(zip(*sorted_sums_and_masks))[0]
        cumulative_sorted_sums = np.cumsum(sorted_sums)
        cutoff_threshold = percentage * total / 100
        cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
        # Never keep more than the three largest components.
        if cutoff_idx > 2:
            cutoff_idx = 2

        # 6. Set the values for the kept components
        border_mask = np.zeros_like(attributions)
        for i in range(cutoff_idx + 1):
            border_mask[sorted_sums_and_masks[i][1]] = 1

        # 7. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        # 8. Return the outlined mask
        return border_mask

    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        """Turn raw attributions into an RGB visualization, optionally
        cleaned up, outlined, and overlaid on the original image.
        """
        if polarity not in ["positive", "negative"]:
            raise ValueError(
                f""" Allowed polarity values: 'positive' or 'negative'
                    but provided {polarity}"""
            )
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity
        if polarity == "positive":
            attributions = self.apply_polarity(attributions, polarity=polarity)
            channel = self.positive_channel
        else:
            attributions = self.apply_polarity(attributions, polarity=polarity)
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )
        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7.Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions

    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        """Show the input, plain-gradient, and integrated-gradient maps
        side by side with identical post-processing settings.
        """
        # 1. Make two copies of the original image
        img1 = np.copy(image)
        img2 = np.copy(image)

        # 2. Process the normal gradients
        grads_attr = self.process_grads(
            image=img1,
            attributions=gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        # 3. Process the integrated gradients
        igrads_attr = self.process_grads(
            image=img2,
            attributions=integrated_gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        _, ax = plt.subplots(1, 3, figsize=figsize)
        ax[0].imshow(image)
        ax[1].imshow(grads_attr.astype(np.uint8))
        ax[2].imshow(igrads_attr.astype(np.uint8))

        ax[0].set_title("Input")
        ax[1].set_title("Normal gradients")
        ax[2].set_title("Integrated gradients")
        plt.show()
|
||||
|
||||
|
||||
"""
|
||||
## Let's test-drive it
|
||||
"""
|
||||
|
||||
# 1. Convert the image to numpy array
|
||||
img = get_img_array(img_path)
|
||||
|
||||
# 2. Keep a copy of the original image
|
||||
orig_img = np.copy(img[0]).astype(np.uint8)
|
||||
|
||||
# 3. Preprocess the image
|
||||
img_processed = tf.cast(xception.preprocess_input(img), dtype=tf.float32)
|
||||
|
||||
# 4. Get model predictions
|
||||
preds = model.predict(img_processed)
|
||||
top_pred_idx = tf.argmax(preds[0])
|
||||
print("Predicted:", top_pred_idx, xception.decode_predictions(preds, top=1)[0])
|
||||
|
||||
# 5. Get the gradients of the last layer for the predicted label
|
||||
grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)
|
||||
|
||||
# 6. Get the integrated gradients
|
||||
igrads = random_baseline_integrated_gradients(
|
||||
np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
|
||||
)
|
||||
|
||||
# 7. Process the gradients and plot
|
||||
vis = GradVisualizer()
|
||||
vis.visualize(
|
||||
image=orig_img,
|
||||
gradients=grads[0].numpy(),
|
||||
integrated_gradients=igrads.numpy(),
|
||||
clip_above_percentile=99,
|
||||
clip_below_percentile=0,
|
||||
)
|
||||
|
||||
vis.visualize(
|
||||
image=orig_img,
|
||||
gradients=grads[0].numpy(),
|
||||
integrated_gradients=igrads.numpy(),
|
||||
clip_above_percentile=95,
|
||||
clip_below_percentile=28,
|
||||
morphological_cleanup=True,
|
||||
outlines=True,
|
||||
)
|
Loading…
Reference in New Issue
Block a user