Differentially private convolutional neural network on MNIST.

Differentially private convolutional neural network on MNIST.

Open in Colab

A large portion of this code is forked from the differentially private SGD example in the JAX repo.

Differentially Private Stochastic Gradient Descent (DP-SGD) requires clipping the per-example parameter gradients, which is non-trivial to implement efficiently for convolutional neural networks. XLA, the compiler used by JAX, shines in this setting by optimizing the minibatch-vectorized computation for convolutional architectures. Training takes a few seconds per epoch on a commodity GPU.

import warnings
import dp_accounting
import jax
import jax.numpy as jnp
from optax import contrib
from optax import losses
import optax
from jax.example_libraries import stax
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Report which backend (e.g. CPU or GPU) hosts the first device JAX sees.
_platform = jax.devices()[0].platform
print("JAX running on", _platform.upper())
JAX running on GPU

The table below lists hyperparameter settings and the corresponding expected final test accuracy.

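| DPSGD | LEARNING_RATE | NOISE_MULTIPLIER | L2_NORM_CLIP | BATCH_SIZE | NUM_EPOCHS | DELTA | FINAL TEST ACCURACY |
|-------|---------------|------------------|--------------|------------|------------|-------|---------------------|
| False | 0.1           | NA               | NA           | 256        | 20         | NA    | ~99%                |
| True  | 0.25          | 1.3              | 1.5          | 256        | 15         | 1e-5  | ~95%                |
| True  | 0.15          | 1.1              | 1.0          | 256        | 60         | 1e-5  | ~96.6%              |
| True  | 0.25          | 0.7              | 1.5          | 256        | 45         | 1e-5  | ~97%                |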

# Hyperparameters, exposed as Colab form fields. Keep each `# @param`
# annotation on the same line as its assignment so the form still renders.
# @markdown Whether to use DP-SGD or vanilla SGD:
DPSGD = True  # @param{type:"boolean"}
# @markdown Learning rate for the optimizer:
LEARNING_RATE = 0.25  # @param{type:"number"}
# @markdown Noise multiplier for the DP-SGD optimizer (only used when DPSGD is True):
NOISE_MULTIPLIER = 1.3  # @param{type:"number"}
# @markdown L2 norm clip for the DP-SGD optimizer (only used when DPSGD is True):
L2_NORM_CLIP = 1.5  # @param{type:"number"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 256  # @param{type:"integer"}
# @markdown Number of epochs:
NUM_EPOCHS = 15  # @param{type:"integer"}
# @markdown Target delta of the (epsilon, delta) privacy guarantee used when reporting epsilon:
DELTA = 1e-5  # @param{type:"number"}

MNIST is composed of 28x28 grayscale images with a single channel. We’ll now load the dataset using tensorflow_datasets, scale the pixel values to [0, 1], and batch the training split.

(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)


def min_max_rgb(image, label):
  """Scales pixel intensities (presumably uint8 in [0, 255]) to float32 in [0, 1]."""
  return tf.cast(image, tf.float32) / 255., label


train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

# Reshuffled every epoch; the final partial batch is dropped, so each epoch
# performs exactly NUM_EXAMPLES // BATCH_SIZE optimizer steps.
train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

# NUM_EXAMPLES is the size of the *training* set. It is the dataset size used
# for privacy accounting (sampling rate BATCH_SIZE / NUM_EXAMPLES and the DELTA
# sanity check), so it must not be taken from the test split.
NUM_EXAMPLES = info.splits["train"].num_examples
NUM_TEST_EXAMPLES = info.splits["test"].num_examples
# The whole test split is evaluated as one batch.
test_batch = next(
    test_loader.batch(NUM_TEST_EXAMPLES, drop_remainder=True).as_numpy_iterator()
)
init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding="SAME", strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding="VALID", strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)

This function computes the privacy parameter epsilon spent after the given number of training steps, for the target failure probability DELTA of the (epsilon, delta) guarantee.

def compute_epsilon(steps):
  """Returns the privacy budget epsilon spent after `steps` training steps.

  Composes `steps` Poisson-subsampled Gaussian mechanisms (sampling rate
  BATCH_SIZE / NUM_EXAMPLES, noise multiplier NOISE_MULTIPLIER) in an RDP
  accountant and converts the result to epsilon at the target DELTA.

  NOTE(review): the accountant assumes Poisson subsampling, while the training
  data is shuffled and cut into fixed-size batches. The reported epsilon is
  therefore the conventional approximation, not an exact bound for this loop.

  Args:
    steps: Number of optimizer steps taken so far.

  Returns:
    The epsilon reported by the accountant for DELTA.
  """
  if NUM_EXAMPLES * DELTA > 1.:
    # A meaningful guarantee needs DELTA well below 1 / NUM_EXAMPLES.
    warnings.warn("Your delta might be too high.", stacklevel=2)
  q = BATCH_SIZE / float(NUM_EXAMPLES)
  # RDP orders to evaluate: a fine grid on [1.1, 10.9] plus the integers 11..63.
  # `.tolist()` hands the accountant plain Python floats rather than 0-d JAX
  # device arrays (float32 by default), so its NumPy/SciPy math runs on
  # ordinary scalars.
  orders = jnp.linspace(1.1, 10.9, 99).tolist() + list(range(11, 64))
  accountant = dp_accounting.rdp.RdpAccountant(orders)
  accountant.compose(
      dp_accounting.PoissonSampledDpEvent(
          q, dp_accounting.GaussianDpEvent(NOISE_MULTIPLIER)),
      steps)
  return accountant.get_epsilon(DELTA)
@jax.jit
def loss_fn(params, batch):
  """Mean softmax cross-entropy of the model on `batch`, plus its logits.

  `batch` is an `(inputs, integer_labels)` pair. The `(loss, logits)` return
  value lets `jax.grad(..., has_aux=True)` differentiate the loss while the
  logits pass through as auxiliary output.
  """
  inputs, targets = batch
  scores = predict(params, inputs)
  per_example = losses.softmax_cross_entropy_with_integer_labels(scores, targets)
  return per_example.mean(), scores


@jax.jit
def test_step(params, batch):
  """Evaluates the model on an `(inputs, integer_labels)` batch.

  Returns:
    A `(loss, accuracy)` pair: the mean cross-entropy and the top-1 accuracy
    expressed as a percentage.
  """
  _, targets = batch
  mean_loss, scores = loss_fn(params, batch)
  hits = scores.argmax(1) == targets
  return mean_loss, hits.mean() * 100
# Optimizer: optax's contrib DP-SGD (configured with a clipping norm and a
# noise multiplier) for private training, or plain SGD as the baseline.
# NOTE(review): fixed seeds make the DP noise reproducible; confirm that is
# intended beyond this demo.
tx = (
    contrib.dpsgd(
        learning_rate=LEARNING_RATE,
        l2_norm_clip=L2_NORM_CLIP,
        noise_multiplier=NOISE_MULTIPLIER,
        seed=1337,
    )
    if DPSGD
    else optax.sgd(learning_rate=LEARNING_RATE)
)

# Initialise parameters for inputs of shape (batch, 28, 28, 1); -1 leaves the
# batch dimension unspecified.
_, params = init_random_params(jax.random.PRNGKey(1337), (-1, 28, 28, 1))
opt_state = tx.init(params)
@jax.jit
def train_step(params, opt_state, batch):
  """Applies one optimizer update computed on `batch`.

  With DP-SGD the optimizer needs per-example gradients. Each example is
  treated as its own batch of size one and the gradient function is vmapped
  over them, so every gradient leaf gains a leading per-example axis.

  Returns:
    The updated `(params, opt_state)` pair.
  """
  compute_grads = jax.grad(loss_fn, has_aux=True)
  if DPSGD:
    # (B, ...) -> (B, 1, ...): vmap strips axis 0, so each call still sees a
    # batched input of size one.
    batch = jax.tree_util.tree_map(lambda leaf: leaf[:, None], batch)
    compute_grads = jax.vmap(compute_grads, in_axes=(None, 0))

  grads = compute_grads(params, batch)[0]  # discard the auxiliary logits
  updates, opt_state = tx.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state
# Per-epoch test metrics, plus the privacy spent so far when training with DP-SGD.
accuracy, loss, epsilon = [], [], []
# Number of optimizer updates actually applied; this is what the privacy
# accountant must be charged for.
total_steps = 0

for epoch in range(NUM_EPOCHS):
  for batch in train_loader_batched.as_numpy_iterator():
    params, opt_state = train_step(params, opt_state, batch)
    total_steps += 1

  # Evaluates test accuracy.
  test_loss, test_acc = test_step(params, test_batch)
  accuracy.append(test_acc)
  loss.append(test_loss)
  print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, test accuracy: {test_acc}")

  if DPSGD:
    # Epsilon after every step taken so far. Counting consumed batches keeps
    # the step count exact even though each epoch drops its final partial batch.
    epsilon.append(compute_epsilon(total_steps))
# One panel per tracked metric; the epsilon panel exists only for DP-SGD runs.
_panels = [(accuracy, "Test accuracy"), (loss, "Test loss")]
if DPSGD:
  _panels.append((epsilon, "Epsilon"))

_, axs = plt.subplots(ncols=len(_panels), figsize=(3 * len(_panels), 3))
for _ax, (_series, _title) in zip(axs, _panels):
  _ax.plot(_series)
  _ax.set_title(_title)

plt.tight_layout()
print(f'Final accuracy: {accuracy[-1]}')
Array(98.99, dtype=float32)