Adversarial training

Contents

Adversarial training

Open in Colab

The following code trains a convolutional neural network (CNN) to be robust to adversarial examples generated with the projected gradient descent (PGD) method.

The projected gradient descent (PGD) method is a simple yet effective way to generate adversarial images. At each iteration, it adds a small perturbation in the direction of the sign of the gradient of the loss with respect to the input, and then projects the result back onto the L-infinity ball of radius epsilon centered at the original image. The gradient sign is the direction of steepest ascent of the loss under the L-infinity norm, so each step locally increases the loss. The projection ensures the perturbation stays within the L-infinity ball: no pixel changes by more than epsilon.

References

Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. “Explaining and harnessing adversarial examples.”, https://arxiv.org/abs/1412.6572

Madry, Aleksander, et al. “Towards deep learning models resistant to adversarial attacks.”, https://arxiv.org/abs/1706.06083

import datetime

import jax
from jax import numpy as jnp
from flax import linen as nn

import optax
from optax.losses import softmax_cross_entropy_with_integer_labels
from optax.tree_utils import tree_l2_norm

from matplotlib import pyplot as plt
# Default matplotlib font size for every figure produced below.
plt.rcParams["font.size"] = 22

import tensorflow as tf
import tensorflow_datasets as tfds

# TensorFlow is only needed here for tensorflow_datasets. Hide every GPU from
# it so that it cannot reserve GPU memory that JAX would then be unable to use.
tf.config.experimental.set_visible_devices([], "GPU")

# Report which platform backs the first JAX device.
jax_platform = jax.devices()[0].platform.upper()
print("JAX running on", jax_platform)
JAX running on GPU
# Experiment hyperparameters. The `# @markdown` / `# @param` annotations expose
# these assignments as an editable Colab form; keep their exact format.
# @markdown Total number of epochs to train for:
EPOCHS = 10  # @param{type:"integer"}
# @markdown Number of samples for each batch in the training set:
TRAIN_BATCH_SIZE = 128  # @param{type:"integer"}
# @markdown Number of samples for each batch in the test set:
TEST_BATCH_SIZE = 128  # @param{type:"integer"}
# @markdown Learning rate for the optimizer:
LEARNING_RATE = 0.001  # @param{type:"number"}
# Dataset name; it is also shown (upper-cased) in the accuracy plot title.
# @markdown The dataset to use.
DATASET = "mnist"  # @param{type:"string"}
# L2_REG weights the 0.5 * ||params||^2 penalty added to the training loss.
# @markdown The amount of L2 regularization to use:
L2_REG = 0.0001  # @param{type:"number"}
# EPSILON is expressed in the units of the network inputs. Raw pixels are
# divided by 255 before the attack, and the attack clips its output to [0, 1].
# @markdown Adversarial perturbations lie within the infinity-ball of radius epsilon.
EPSILON = 0.01  # @param{type:"number"}
class CNN(nn.Module):
  """Small image classifier.

  Architecture: two stages of 3x3 convolution + ReLU + 2x2 average pooling
  (32 then 64 channels), followed by a 256-unit ReLU layer and a linear layer
  that returns one logit per class.
  """
  num_classes: int

  @nn.compact
  def __call__(self, x):
    # Submodules are created in the same order as an unrolled version would
    # create them, so parameter names (Conv_0, Conv_1, Dense_0, Dense_1) match.
    for width in (32, 64):
      x = nn.Conv(features=width, kernel_size=(3, 3))(x)
      x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))
    # Flatten every axis except the leading batch axis.
    x = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(x))
    return nn.Dense(features=self.num_classes)(hidden)
# Load the train/test splits of the dataset named by DATASET as
# (image, label) pairs, together with its metadata. Previously "mnist" was
# hard-coded here, so changing DATASET only changed the plot title.
# The rest of the notebook assumes an image-classification dataset with
# "train"/"test" splits, "image"/"label" features, and 8-bit pixels (it
# divides by 255), as MNIST has.
(train_loader, test_loader), dataset_info = tfds.load(
    DATASET, split=["train", "test"], as_supervised=True, with_info=True
)
# Backward-compatible alias for code that uses the original name.
mnist_info = dataset_info

train_loader_batched = train_loader.shuffle(
    10 * TRAIN_BATCH_SIZE, seed=0
).batch(TRAIN_BATCH_SIZE, drop_remainder=True)
test_loader_batched = test_loader.batch(TEST_BATCH_SIZE, drop_remainder=True)

# Shape of a one-example batch, used to initialize the network parameters.
input_shape = (1,) + dataset_info.features["image"].shape
num_classes = dataset_info.features["label"].num_classes
# Full batches per epoch. These equal the loaders' batch counts because both
# loaders drop the final partial batch.
iter_per_epoch_train = (
    dataset_info.splits["train"].num_examples // TRAIN_BATCH_SIZE
)
iter_per_epoch_test = (
    dataset_info.splits["test"].num_examples // TEST_BATCH_SIZE
)
net = CNN(num_classes)

@jax.jit
def accuracy(params, data):
  """Return the fraction of examples whose highest-scoring class is the label.

  Args:
    params: parameter tree of `net`.
    data: tuple `(inputs, labels)`; `labels` holds integer class indices.
  """
  inputs, labels = data
  predicted = net.apply({"params": params}, inputs).argmax(axis=-1)
  hits = predicted == labels
  return hits.mean()


@jax.jit
def loss_fun(params, l2reg, data):
  """Mean softmax cross-entropy of `net` plus an L2 weight penalty.

  Args:
    params: parameter tree of `net`.
    l2reg: weight of the penalty term `0.5 * l2reg * ||params||^2`.
    data: tuple `(inputs, labels)` with integer class labels.

  Returns:
    Scalar loss.
  """
  inputs, labels = data
  logits = net.apply({"params": params}, inputs.astype(jnp.float32))
  data_term = softmax_cross_entropy_with_integer_labels(logits, labels).mean()
  penalty = 0.5 * l2reg * tree_l2_norm(params, squared=True)
  return data_term + penalty

def pgd_attack(image, label, params, epsilon=0.1, maxiter=10):
  """PGD attack on the L-infinity ball with radius epsilon.

  Runs `maxiter` steps of signed-gradient ascent on the unregularized training
  loss. After every step, the perturbation is projected onto the set of valid
  adversarial images: the L-infinity ball of radius `epsilon` around `image`,
  intersected with the pixel range [0, 1].

  Args:
    image: array-like, batch of input images for the CNN, with pixel values
      in [0, 1].
    label: array-like of integers, class labels corresponding to `image`.
    params: tree, parameters of the model to attack.
    epsilon: float, radius of the L-infinity ball.
    maxiter: int, number of iterations. It is static under `jax.jit` because
      it sets the length of a Python loop, so each distinct value triggers
      a recompilation.

  Returns:
    perturbed_image: adversarial images inside the L-infinity ball of radius
      `epsilon` centered at `image`, with pixel values in [0, 1].

  Notes:
    PGD attack is described in (Madry et al. 2017),
    https://arxiv.org/pdf/1706.06083.pdf
  """
  def adversarial_loss(perturbation):
    # l2reg=0: the weight penalty does not depend on the input, so only the
    # data term matters for the attack.
    return loss_fun(params, 0, (image + perturbation, label))

  grad_adversarial = jax.grad(adversarial_loss)

  def project(perturbation):
    # Project onto the L-infinity ball centered at image ...
    perturbation = jnp.clip(perturbation, -epsilon, epsilon)
    # ... then onto the valid pixel range, so that every gradient is evaluated
    # at a valid image. Both sets are per-pixel intervals that overlap
    # (image lies in [0, 1]), so clipping in sequence gives the projection
    # onto their intersection.
    return jnp.clip(image + perturbation, 0, 1) - image

  image_perturbation = jnp.zeros_like(image)
  # Heuristic step size 2 * eps / maxiter.
  step_size = 2 * epsilon / maxiter
  for _ in range(maxiter):
    # Sign of the gradient of the loss w.r.t. the perturbation (= w.r.t. the
    # perturbed image).
    sign_grad = jnp.sign(grad_adversarial(image_perturbation))
    image_perturbation = project(image_perturbation + step_size * sign_grad)

  # Final clip guards against floating-point rounding in image + perturbation.
  return jnp.clip(image + image_perturbation, 0, 1)


# `maxiter` controls a Python `for` loop, which needs a concrete integer
# while tracing. With a plain `@jax.jit`, passing `maxiter=...` explicitly
# would hand the loop a tracer and fail; marking it static fixes that.
pgd_attack = jax.jit(pgd_attack, static_argnames=("maxiter",))
def dataset_stats(params, data_loader, iter_per_epoch):
  """Measure clean and PGD-adversarial accuracy over a batched dataset.

  Each batch's accuracy is divided by `iter_per_epoch`. The totals are
  therefore per-batch averages only when `data_loader` yields exactly
  `iter_per_epoch` batches.

  Args:
    params: parameter tree of `net`.
    data_loader: batched `tf.data` dataset of `(images, labels)` pairs.
    iter_per_epoch: number of batches the loader is expected to yield.

  Returns:
    Dict with keys "adversarial accuracy" and "accuracy".
  """
  clean_total = 0.
  attacked_total = 0.
  for raw_images, batch_labels in data_loader.as_numpy_iterator():
    scaled = raw_images.astype(jnp.float32) / 255
    clean_acc = accuracy(params, (scaled, batch_labels))
    clean_total += jnp.mean(clean_acc) / iter_per_epoch
    attacked = pgd_attack(scaled, batch_labels, params, epsilon=EPSILON)
    attacked_acc = accuracy(params, (attacked, batch_labels))
    attacked_total += jnp.mean(attacked_acc) / iter_per_epoch
  return {"adversarial accuracy": attacked_total, "accuracy": clean_total}

@jax.jit
def train_step(params, opt_state, batch):
  """Perform one adversarial-training update.

  The batch is first attacked with PGD; the parameters are then updated with
  the gradient of the regularized loss on the attacked images only.

  Returns:
    Tuple `(new_params, new_opt_state)`.
  """
  raw_images, labels = batch
  # The attack differentiates with respect to the inputs, so they must be
  # floating point. Scaling by 1/255 matches the [0, 1] clip in the attack.
  images = raw_images.astype(jnp.float32) / 255
  attacked_images = pgd_attack(images, labels, params, epsilon=EPSILON)
  grads = jax.grad(loss_fun)(params, L2_REG, (attacked_images, labels))
  # `optimizer` is the module-level Optax optimizer, bound before training.
  updates, new_opt_state = optimizer.update(grads, opt_state)
  return optax.apply_updates(params, updates), new_opt_state
# Initialize the network parameters from a dummy one-example batch.
key = jax.random.PRNGKey(0)
var_params = net.init(key, jnp.zeros(input_shape))["params"]

# Build the optimizer and its state. `train_step` reads the module-level
# name `optimizer`, so it must be bound before the first training step.
optimizer = optax.adam(LEARNING_RATE)
opt_state = optimizer.init(var_params)

start = datetime.datetime.now().replace(microsecond=0)

# Per-epoch accuracy histories, plotted afterwards.
accuracy_train, accuracy_test = [], []
adversarial_accuracy_train, adversarial_accuracy_test = [], []

for epoch in range(EPOCHS):
  # One pass of adversarial training over the training set.
  for train_batch in train_loader_batched.as_numpy_iterator():
    var_params, opt_state = train_step(var_params, opt_state, train_batch)

  # Clean and adversarial accuracy, first on the train split, then on test.
  evaluations = (
      (train_loader_batched, iter_per_epoch_train,
       accuracy_train, adversarial_accuracy_train),
      (test_loader_batched, iter_per_epoch_test,
       accuracy_test, adversarial_accuracy_test),
  )
  for loader, n_batches, clean_history, adv_history in evaluations:
    stats = dataset_stats(var_params, loader, n_batches)
    clean_history.append(stats["accuracy"])
    adv_history.append(stats["adversarial accuracy"])

  time_elapsed = datetime.datetime.now().replace(microsecond=0) - start
  print(f"Epoch {epoch} out of {EPOCHS}")
  print(f"Accuracy on train set: {accuracy_train[-1]:.3f}")
  print(f"Accuracy on test set: {accuracy_test[-1]:.3f}")
  print(f"Adversarial accuracy on train set: {adversarial_accuracy_train[-1]:.3f}")
  print(f"Adversarial accuracy on test set: {adversarial_accuracy_test[-1]:.3f}")
  print(f"Time elapsed: {time_elapsed}\n")
Epoch 0 out of 10
Accuracy on train set: 0.982
Accuracy on test set: 0.982
Adversarial accuracy on train set: 0.979
Adversarial accuracy on test set: 0.977
Time elapsed: 0:00:10

Epoch 1 out of 10
Accuracy on train set: 0.989
Accuracy on test set: 0.987
Adversarial accuracy on train set: 0.986
Adversarial accuracy on test set: 0.984
Time elapsed: 0:00:15

Epoch 2 out of 10
Accuracy on train set: 0.991
Accuracy on test set: 0.988
Adversarial accuracy on train set: 0.988
Adversarial accuracy on test set: 0.986
Time elapsed: 0:00:21

Epoch 3 out of 10
Accuracy on train set: 0.992
Accuracy on test set: 0.989
Adversarial accuracy on train set: 0.990
Adversarial accuracy on test set: 0.986
Time elapsed: 0:00:26

Epoch 4 out of 10
Accuracy on train set: 0.992
Accuracy on test set: 0.988
Adversarial accuracy on train set: 0.990
Adversarial accuracy on test set: 0.985
Time elapsed: 0:00:32

Epoch 5 out of 10
Accuracy on train set: 0.995
Accuracy on test set: 0.991
Adversarial accuracy on train set: 0.994
Adversarial accuracy on test set: 0.989
Time elapsed: 0:00:37

Epoch 6 out of 10
Accuracy on train set: 0.995
Accuracy on test set: 0.990
Adversarial accuracy on train set: 0.993
Adversarial accuracy on test set: 0.988
Time elapsed: 0:00:43

Epoch 7 out of 10
Accuracy on train set: 0.996
Accuracy on test set: 0.992
Adversarial accuracy on train set: 0.995
Adversarial accuracy on test set: 0.990
Time elapsed: 0:00:48

Epoch 8 out of 10
Accuracy on train set: 0.994
Accuracy on test set: 0.990
Adversarial accuracy on train set: 0.992
Adversarial accuracy on test set: 0.987
Time elapsed: 0:00:54

Epoch 9 out of 10
Accuracy on train set: 0.997
Accuracy on test set: 0.992
Adversarial accuracy on train set: 0.995
Adversarial accuracy on test set: 0.991
Time elapsed: 0:00:59
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))

plt.suptitle("Adversarial training on " + f"{DATASET}".upper())

# Left panel: accuracy on clean images.
axes[0].plot(
    accuracy_train, lw=3, label="train set.", marker="<", markersize=10
)
axes[0].plot(accuracy_test, lw=3, label="test set.", marker="d", markersize=10)
axes[0].grid()
axes[0].set_ylabel("accuracy on clean images")

# Right panel: accuracy on PGD-perturbed images. Train is plotted before test,
# as on the left, so with the default colour cycle both panels share colours.
axes[1].plot(
    adversarial_accuracy_train,
    lw=3,
    label="adversarial accuracy on train set.",
    marker="^",
    markersize=10,
)
axes[1].plot(
    adversarial_accuracy_test,
    lw=3,
    label="adversarial accuracy on test set.",
    marker=">",
    markersize=10,
)
axes[1].grid()
axes[1].set_ylabel("accuracy on adversarial images")

# A single legend below the left panel serves both panels.
axes[0].legend(
    frameon=False, ncol=2, loc="upper center", bbox_to_anchor=(0.8, -0.1)
)
# Label the x-axis on both panels; previously only the left panel had one.
for ax in axes:
  ax.set_xlabel("epochs")
plt.subplots_adjust(wspace=0.5)


plt.show()

Find a test-set image that the model classifies correctly but whose adversarial perturbation it misclassifies

def find_adversarial_imgs(params, loader_batched):
  """Find an image that is classified correctly but whose PGD attack is not.

  Scans `loader_batched` in order and returns the first example whose clean
  prediction equals its label while the prediction on its PGD perturbation
  does not.

  Args:
    params: parameter tree of `net`.
    loader_batched: batched `tf.data` dataset of `(images, labels)` pairs,
      e.g. the test set.

  Returns:
    Tuple `(clean_image, prediction_clean, adversarial_image,
    adversarial_prediction)`.

  Raises:
    ValueError: if no such example exists in the loader.
  """
  for images, labels in loader_batched.as_numpy_iterator():
    images = images.astype(jnp.float32) / 255
    labels_clean = jnp.argmax(net.apply({"params": params}, images), axis=-1)

    adversarial_images = pgd_attack(images, labels, params, epsilon=EPSILON)
    labels_adversarial = jnp.argmax(
        net.apply({"params": params}, adversarial_images), axis=-1
    )
    # Correct on the clean image but fooled by the attack. One vectorized mask
    # replaces the per-index Python loop and its device-to-host round trips.
    fooled = (labels_clean == labels) & (labels_adversarial != labels)
    if fooled.any():
      # argmax of a boolean mask is the index of its first True entry.
      j = int(jnp.argmax(fooled))
      return (
          images[j],
          labels_clean[j],
          adversarial_images[j],
          labels_adversarial[j],
      )

  raise ValueError("No mismatch between clean and adversarial prediction found")


img_clean, pred_clean, img_adversarial, prediction_adversarial = (
    find_adversarial_imgs(var_params, test_loader_batched)
)
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(6 * 3, 6))

# Pass the colormap by name. `plt.cm.get_cmap` (matplotlib.cm.get_cmap) was
# deprecated in Matplotlib 3.7 and removed in 3.9, where it raises
# AttributeError.
cmap = "Greys"

axes[0].set_title("Clean image \n Prediction %s" % int(pred_clean))
axes[0].imshow(img_clean, cmap=cmap, vmax=1, vmin=0)
axes[1].set_title(
    "Adversarial image \n Prediction %s" % int(prediction_adversarial)
)
axes[1].imshow(img_adversarial, cmap=cmap, vmax=1, vmin=0)
# The per-pixel difference is at most EPSILON, so scaling by 1 / EPSILON
# maps it into the [0, 1] display range.
axes[2].set_title(r"|Adversarial - clean| $\times$ %.0f" % (1 / EPSILON))
axes[2].imshow(
    jnp.abs(img_clean - img_adversarial) / EPSILON,
    cmap=cmap,
    vmax=1,
    vmin=0,
)
for ax in axes:
  ax.set_xticks(())
  ax.set_yticks(())
plt.show()