Recreate AdeMAMix Rosenbrock Plot from Paper

Open in Colab

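This notebook attempts to recreate Figure 2 from the AdEMAMix paper ("The AdEMAMix Optimizer: Better, Faster, Older"): it compares optimizer trajectories on the Rosenbrock function for Adam with several values of b1 and for AdEMAMix with several values of b3.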

Imports

from functools import partial

import matplotlib.pyplot as plt
import optax
import jax
import jax.numpy as jnp
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# Matplotlib defaults applied to every figure created below:
# a 20x10 figure size and a font size of 14.
plt.rc("figure", figsize=(20, 10))
plt.rc("font", size=14)
from optax.schedules import linear_schedule
from optax._src import base


def cond_print(cond, fmt, *args):
    """Print ``fmt`` formatted with ``args`` only when ``cond`` is true.

    The branch is selected with ``jax.lax.cond`` and the message is emitted
    with ``jax.debug.print``. Both branches return the integer ``0`` so that
    they have matching outputs.

    Args:
        cond: Scalar predicate selecting whether to print.
        fmt: Format string forwarded to ``jax.debug.print``.
        *args: Values substituted into ``fmt``.

    Returns:
        The result of ``jax.lax.cond``, which is ``0`` for either branch.
    """

    def _emit():
        jax.debug.print(fmt, *args)
        return 0

    def _skip():
        return 0

    return jax.lax.cond(cond, _emit, _skip)

Functions

def rosenbrock(x):
    """Evaluate the 2-D Rosenbrock function ``(1 - x0)^2 + 100 (x1 - x0^2)^2``.

    Args:
        x: Any indexable holding the two coordinates at ``x[0]`` and ``x[1]``
            (scalars or arrays; arrays are evaluated elementwise).

    Returns:
        The function value, which is zero at ``(1, 1)``.
    """
    first, second = x[0], x[1]
    offset_from_one = 1 - first
    offset_from_parabola = second - jnp.square(first)
    return jnp.square(offset_from_one) + 100.0 * jnp.square(offset_from_parabola)


# Create a grid of x and y values: 1000 evenly spaced points per axis on [-5, 10]
x = jnp.linspace(-5, 10, 1000)
y = jnp.linspace(-5, 10, 1000)
X, Y = jnp.meshgrid(x, y)

# Compute the Rosenbrock function values for each point on the grid
# (used as the contour background of the trajectory plots)
Z = rosenbrock([X, Y])
# Number of optimizer steps per run; also used as the length of the
# alpha and b3 schedules further down.
num_iterations = 100000

Generate Adam Trajectories (Baseline)

def _body_fn(carry, _, *, solver, b1):
    """Run one optimizer step on the Rosenbrock function (``jax.lax.scan`` body).

    Args:
        carry: Tuple ``(params, opt_state, i)`` with the current parameters,
            the optimizer state and the step counter.
        _: Per-step scan input; ignored.
        solver: Optimizer exposing ``update(grad, opt_state, params)``.
        b1: Value reported in the progress message only.

    Returns:
        ``(new_carry, new_params)``: the carry for the next step and the
        updated parameters, which the scan stacks into the trajectory.
    """
    current_params, current_state, step = carry

    gradient = jax.grad(rosenbrock)(current_params)
    updates, next_state = solver.update(gradient, current_state, current_params)
    next_params = optax.apply_updates(current_params, updates)

    # Report the objective every 25000 steps.
    should_log = step % 25000 == 0
    objective = rosenbrock(next_params)
    cond_print(
        should_log,
        "Objective function for b1={} at iteration {} = {}",
        b1,
        step,
        objective,
    )

    next_carry = (next_params, next_state, step + 1)
    return next_carry, next_params

# Run Adam from the same starting point once per b1 value and record the
# full parameter trajectory of each run.
all_b1_params = []
for b1 in [0.9, 0.99, 0.999, 0.9999]:
    solver = optax.adam(learning_rate=0.003, b1=b1, b2=0.9999)
    params = jnp.array([-3.0, 5.0])
    print("Objective function: ", rosenbrock(params))
    opt_state = solver.init(params)

    # `_body_fn` takes the optimizer through the keyword-only argument
    # `solver`; any other keyword name raises a TypeError when the scan
    # traces the body.
    _, all_params = jax.lax.scan(
        partial(_body_fn, solver=solver, b1=b1),
        (params, opt_state, 0),
        length=num_iterations,
    )
    # Prepend the initial point so each trajectory has num_iterations + 1 rows.
    all_b1_params.append(jnp.concatenate([params[None, ...], all_params], 0))
all_b1_params_array = jnp.array(all_b1_params)
Objective function:  1616.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 22
     18     params = jnp.array([-3.0, 5.0])
     19     print("Objective function: ", rosenbrock(params))
     20     opt_state = solver.init(params)
     21 
---> 22     _, all_params = jax.lax.scan(
     23         partial(_body_fn, solve=solver, b1=b1),
     24         (params, opt_state, 0),
     25         length=num_iterations

    [... skipping hidden 3 frame]

File ~/checkouts/readthedocs.org/user_builds/optax/envs/latest/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2283, in trace_to_jaxpr(***failed resolving arguments***)
   2281 with core.set_current_trace(trace):
   2282   args, kwargs = in_tracers.unflatten()
-> 2283   ans_pytree = fun(*args, **kwargs)
   2284   if fun_returns_flat_tree:
   2285     # TODO(dougalm): make result paths optional
   2286     ans = ans_pytree

TypeError: _body_fn() got an unexpected keyword argument 'solve'

Generate AdeMAMix Trajectories

Create alpha scheduler

# Final value of the `alpha` argument passed to optax.contrib.ademamix below.
alpha = 0.8
# Rebind `alpha` from the float to a schedule built from (0, 0.8,
# num_iterations); from here on `alpha` is a callable, not a number.
alpha = linear_schedule(0, alpha, num_iterations)

Create b3 scheduler

def b3_scheduler(beta_end: float, beta_start: float = 0, warmup: int = 0):
    """Build a warmup schedule for the ``b3`` decay rate.

    While ``step < warmup`` the returned value moves from ``beta_start`` to
    ``beta_end``. The interpolation is linear in ``f(beta) = log(0.5) /
    log(beta) - 1`` rather than in ``beta`` itself; ``f`` is the inverse of
    ``f_inv(t) = 0.5 ** (1 / (t + 1))``, i.e. ``beta ** (f(beta) + 1) == 0.5``.
    From ``step == warmup`` onwards the schedule returns ``beta_end``.

    Args:
        beta_end: Value reached at the end of the warmup and held afterwards.
        beta_start: Value at step 0 of the warmup.
        warmup: Number of warmup steps. With ``warmup <= 0`` there is no
            warmup and the schedule is the constant ``beta_end`` (this avoids
            dividing by zero when computing the warmup fraction).

    Returns:
        A function mapping a step (Python int or JAX scalar) to the ``b3``
        value for that step.
    """

    def f(beta):
        return jnp.log(0.5) / jnp.log(beta) - 1

    def f_inv(t):
        return jnp.power(0.5, 1 / (t + 1))

    if warmup <= 0:

        def constant_schedule(step):
            del step  # The value does not depend on the step.
            return jnp.asarray(beta_end, dtype=jnp.float32)

        return constant_schedule

    def schedule(step):
        alpha = step / float(warmup)
        warmup_value = f_inv((1.0 - alpha) * f(beta_start) + alpha * f(beta_end))
        # Select rather than multiply by a 0/1 mask, so that a non-finite
        # value in the inactive branch cannot turn the result into NaN.
        return jnp.where(step < warmup, warmup_value, beta_end)

    return schedule
def _body_fn(carry, _, *, solver, b3):
    """Run one optimizer step on the Rosenbrock function (``jax.lax.scan`` body).

    Args:
        carry: Tuple ``(params, opt_state, i)`` with the current parameters,
            the optimizer state and the step counter.
        _: Per-step scan input; ignored.
        solver: Optimizer exposing ``update(grad, opt_state, params)``.
        b3: Callable evaluated at the step counter; its value is reported in
            the progress message only.

    Returns:
        ``(new_carry, new_params)``: the carry for the next step and the
        updated parameters, which the scan stacks into the trajectory.
    """
    current_params, current_state, step = carry

    gradient = jax.grad(rosenbrock)(current_params)
    updates, next_state = solver.update(gradient, current_state, current_params)
    next_params = optax.apply_updates(current_params, updates)

    # Report the objective and the scheduled b3 every 25000 steps.
    should_log = step % 25000 == 0
    current_b3 = b3(step)
    objective = rosenbrock(next_params)
    cond_print(
        should_log,
        "Objective function for b3={} at iteration {} = {}",
        current_b3,
        step,
        objective,
    )

    next_carry = (next_params, next_state, step + 1)
    return next_carry, next_params

# Run AdEMAMix from the same starting point once per final b3 value and
# record the full parameter trajectory of each run.
all_ademamix_params = []
for b3 in [0.999, 0.9999]:
    # Rebind `b3` from the target value to a schedule that warms up to it
    # over num_iterations steps, starting from 0.
    b3 = b3_scheduler(b3, 0, num_iterations)
    solver = optax.contrib.ademamix(
        learning_rate=0.003, b1=0.99, b2=0.999, b3=b3, alpha=alpha
    )
    params = jnp.array([-3.0, 5.0])
    print("Objective function: ", rosenbrock(params))
    opt_state = solver.init(params)

    # The carry is (params, opt_state, step counter); the scan stacks the
    # parameters returned by every step.
    _, all_params = jax.lax.scan(
        partial(_body_fn, solver=solver, b3=b3),
        (params, opt_state, 0),
        length=num_iterations
    )
    # Prepend the initial point so each trajectory has num_iterations + 1 rows.
    all_ademamix_params.append(jnp.concatenate([params[None, ...], all_params], 0))
all_ademamix_params_array = jnp.array(all_ademamix_params)
Objective function:  1616.0
Objective function for b3=0.0 at iteration 0 = 1599.227294921875
Objective function for b3=0.9960060715675354 at iteration 25000 = 1.1232594943066943e-07
Objective function for b3=0.9980010390281677 at iteration 50000 = 6.792194540139462e-08
Objective function for b3=0.9986668825149536 at iteration 75000 = 1.3234323148481053e-07
Objective function:  1616.0
Objective function for b3=0.0 at iteration 0 = 1599.227294921875
Objective function for b3=0.9995999932289124 at iteration 25000 = 5.958846571729737e-08
Objective function for b3=0.9997999668121338 at iteration 50000 = 1.4210854715202004e-14
Objective function for b3=0.9998666644096375 at iteration 75000 = 2.3673273119584337e-08

Plot the Figure

# One figure with two side-by-side panels: Adam trajectories on the left,
# AdEMAMix trajectories on the right, each over filled Rosenbrock contours.
fig = plt.figure()
ax = fig.subplots(1, 2)
ax[0].set_xlabel("x")
ax[0].set_ylabel("y")
ax[0].set_title("Rosenbrock Function - Adam Trajectories")
# Mark the point (1, 1), where the Rosenbrock function is zero
ax[0].plot([1], [1], "x", mew=1, markersize=10, color="cyan")
# 100 log-spaced contour levels between 1e-1 and 1e3
ax[0].contourf(X, Y, Z, np.logspace(-1, 3, 100), cmap="jet")
for i, b1 in enumerate([0.9, 0.99, 0.999, 0.9999]):
    # Plot every 100th point of the trajectory
    ax[0].plot(
        all_b1_params_array[i, ::100, 0],
        all_b1_params_array[i, ::100, 1],
        label=f"Adam b1 = {b1}",
    )
ax[0].set_xlim(-4, 4)
ax[0].set_ylim(-3.5, 7.5)
ax[0].legend()

ax[1].set_xlabel("x")
ax[1].set_ylabel("y")
ax[1].set_title("Rosenbrock Function - Ademamix Trajectories")
# Mark the point (1, 1), where the Rosenbrock function is zero
ax[1].plot([1], [1], "x", mew=1, markersize=10, color="cyan")
# Same contour levels as the left panel
ax[1].contourf(X, Y, Z, np.logspace(-1, 3, 100), cmap="jet")
for i, b3 in enumerate([0.999, 0.9999]):
    # Plot every 100th point of the trajectory
    ax[1].plot(
        all_ademamix_params_array[i, ::100, 0],
        all_ademamix_params_array[i, ::100, 1],
        label=f"AdEMAMix b3 = {b3}",
    )
ax[1].set_xlim(-4, 4)
ax[1].set_ylim(-3.5, 7.5)
ax[1].legend()

plt.show()
../../../_images/5e4581e8e8b5c6ee05674d59bb485004a2562f5489fd70c228609fb15b1074c0.png

Plot Figure 2a from Paper

# Distance of the iterate to the Rosenbrock minimum (1, 1) at every step,
# for every Adam and AdEMAMix run.
N = num_iterations + 1  # each trajectory holds the initial point plus one point per step
steps = jnp.arange(N)
minimum = jnp.ones(2)

fig, ax = plt.subplots()
# Every curve uses the same metric (Euclidean distance) and the same axis, so
# the curves are directly comparable and each one gets its own colour from
# the colour cycle.
for i, b1 in enumerate([0.9, 0.99, 0.999, 0.9999]):
    ax.semilogy(
        steps,
        jnp.linalg.norm(all_b1_params_array[i, :, :] - minimum, axis=1),
        label=f"Adam b1 = {b1}",
    )
for i, b3 in enumerate([0.999, 0.9999]):
    ax.semilogy(
        steps,
        jnp.linalg.norm(all_ademamix_params_array[i, :, :] - minimum, axis=1),
        label=f"AdeMAMix b3 = {b3}",
    )
ax.set_xlabel("iteration")
ax.set_ylabel("distance to the minimum (1, 1)")
ax.legend(loc=0)
plt.show()
../../../_images/4feaed4475a4d6d439db1a545e8bac3524f60e256493352cdcc1587eaca5804a.png