# L-BFGS

L-BFGS is a classical optimization method that uses past gradient and parameter information to iteratively refine a solution to a minimization problem. In this notebook, we illustrate

- how to use L-BFGS as a simple gradient transformation,
- how to wrap L-BFGS in a solver, and how linesearches are incorporated,
- how to debug the solver if needed.

```
from typing import NamedTuple
import chex
import jax
import jax.numpy as jnp
import jax.random as jrd
import optax
import optax.tree_utils as otu
```

## L-BFGS as a gradient transformation

### What is L-BFGS?

To solve a problem of the form

\[
\min_w f(w),
\]

L-BFGS (limited-memory Broyden–Fletcher–Goldfarb–Shanno algorithm) makes steps of the form

\[
w_{k+1} = w_k - \eta_k P_k g_k,
\]

where, at iteration \(k\), \(w_k\) are the parameters, \(g_k = \nabla f(w_k)\) are the gradients, \(\eta_k\) is the stepsize, and \(P_k\) is a *preconditioning* matrix, that is, a matrix that transforms the gradients to ease the optimization process.
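To see why a preconditioner can ease optimization, here is a minimal sketch on a small convex quadratic. The matrix `mat`, the minimizer `w_opt`, and the other names are illustrative assumptions, not part of the notebook. When the preconditioner is the exact inverse Hessian, a single step \(w - \eta P g\) with \(\eta = 1\) lands on the minimizer, which is the behavior L-BFGS tries to approximate.

```python
import jax
import jax.numpy as jnp

# Illustrative 2-D quadratic with Hessian `mat` and minimizer `w_opt`.
mat = jnp.array([[3.0, 1.0], [1.0, 2.0]])
w_opt = jnp.array([1.0, -1.0])


def fun(w):
  return 0.5 * (w - w_opt) @ mat @ (w - w_opt)


w = jnp.zeros(2)
g = jax.grad(fun)(w)

# Ideal preconditioner: the exact inverse Hessian of the quadratic.
precond = jnp.linalg.inv(mat)

# One preconditioned step with unit stepsize.
w_next = w - 1.0 * precond @ g
print(w_next, w_opt)
```

In practice the inverse Hessian is too expensive to form, which is why L-BFGS only approximates its action on the gradient.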

L-BFGS builds the preconditioning matrix \(P_k\) as an approximation of the inverse Hessian, \(P_k \approx \nabla^2 f(w_k)^{-1}\), using past gradient and parameter information. Briefly, at iteration \(k\), the previous preconditioning matrix \(P_{k-1}\) is updated such that \(P_k\) satisfies the secant condition \(P_k(g_k - g_{k-1}) = w_k - w_{k-1}\). The original BFGS algorithm updates \(P_k\) using all past information, whereas the limited-memory variant only uses a fixed number of past parameters and gradients to build \(P_k\). See Nocedal and Wright, Numerical Optimization, 1999, or the documentation for more details on the implementation.
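The product \(P_k g_k\) is usually computed without ever forming \(P_k\), via the classical two-loop recursion over the stored differences. The sketch below is a plain Python/JAX illustration of that recursion under simplifying assumptions (Python lists as memory, a scaled identity as initial matrix). The names `two_loop_direction`, `s_list`, and `y_list` are hypothetical, and this is not necessarily how optax implements it.

```python
import jax.numpy as jnp


def two_loop_direction(grad, s_list, y_list):
  """Returns P @ grad for the L-BFGS inverse-Hessian approximation P.

  s_list[i] holds parameter differences, y_list[i] gradient differences,
  both ordered from oldest to newest.
  """
  q = grad
  coeffs = []
  # First loop: newest pair to oldest pair.
  for s, y in zip(reversed(s_list), reversed(y_list)):
    rho = 1.0 / jnp.dot(y, s)
    alpha = rho * jnp.dot(s, q)
    q = q - alpha * y
    coeffs.append((rho, alpha))
  # Initial matrix: identity scaled with the most recent pair.
  if s_list:
    gamma = jnp.dot(s_list[-1], y_list[-1]) / jnp.dot(y_list[-1], y_list[-1])
  else:
    gamma = 1.0
  r = gamma * q
  # Second loop: oldest pair to newest pair.
  for s, y, (rho, alpha) in zip(s_list, y_list, reversed(coeffs)):
    beta = rho * jnp.dot(y, r)
    r = r + (alpha - beta) * s
  return r


# Two illustrative (s, y) pairs with positive curvature <s, y> > 0.
s_list = [jnp.array([1.0, 0.0]), jnp.array([0.5, 1.0])]
y_list = [jnp.array([2.0, 0.0]), jnp.array([1.0, 3.0])]

# Secant check: applying P to the latest gradient difference
# recovers the latest parameter difference.
print(two_loop_direction(y_list[-1], s_list, y_list), s_list[-1])
```

With an empty memory the function returns the gradient unchanged, so the first step reduces to plain gradient descent in this sketch.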

### Using L-BFGS as a gradient transformation

The function `optax.scale_by_lbfgs()` implements the update of the preconditioning matrix given a running optimizer state \(s_k\). Given \((g_k, s_k, w_k)\), this function returns \((P_k g_k, s_{k+1})\). We illustrate its performance below on a simple convex quadratic.

```
# Define objective
dim = 8
w_opt = jnp.ones(dim)
mat = jrd.normal(jrd.PRNGKey(0), (dim, dim))
mat = mat.dot(mat.T)  # Symmetric positive semi-definite matrix


def fun(w):
  return 0.5 * (w - w_opt).dot(mat.dot(w - w_opt))


# Define optimizer
lr = 1e-1
opt = optax.scale_by_lbfgs()

# Initialize optimization (the state is built from the initial parameters)
w = jrd.normal(jrd.PRNGKey(1), (dim,))
state = opt.init(w)

# Run optimization
for i in range(16):
  v, g = jax.value_and_grad(fun)(w)
  print(f'Iteration: {i}, Value:{v:.2e}')
  u, state = opt.update(g, state, w)  # u is the preconditioned gradient
  w = w - lr * u
print(f'Final value: {fun(w):.2e}')
```

```
Iteration: 0, Value:4.27e+01
Iteration: 1, Value:1.08e+01
Iteration: 2, Value:9.57e+00
Iteration: 3, Value:8.17e+00
Iteration: 4, Value:6.80e+00
Iteration: 5, Value:5.56e+00
Iteration: 6, Value:4.49e+00
Iteration: 7, Value:3.63e+00
Iteration: 8, Value:2.94e+00
Iteration: 9, Value:2.38e+00
Iteration: 10, Value:1.93e+00
Iteration: 11, Value:1.55e+00
Iteration: 12, Value:1.26e+00
Iteration: 13, Value:1.03e+00
Iteration: 14, Value:8.49e-01
Iteration: 15, Value:7.03e-01
Final value: 5.83e-01
```

## L-BFGS as a solver

L-BFGS is a staple in numerical optimization for solving medium-scale problems. It is often the backend of generic minimization functions in software libraries like SciPy. A key ingredient in making it a simple black-box optimizer is removing the need to tune the stepsize, a.k.a. the learning rate in machine learning. In a deterministic setting (no additional varying inputs such as batches of inputs/labels), such automatic tuning of the stepsize is done by means of linesearches, reviewed below.

### What are linesearches?

Given current parameters \(w_k\) and an update direction \(u_k\) (such as the negative preconditioned gradient \(u_k = -P_k g_k\) returned by L-BFGS), a linesearch computes a stepsize \(\eta_k\) such that the next iterate \(w_{k+1} = w_k + \eta_k u_k\) satisfies some criteria.

#### Sufficient decrease (Armijo-Goldstein criterion)

The first criterion that a good stepsize may need to satisfy is that the next iterate decreases the value of the objective by a sufficient amount. Mathematically, the criterion is expressed as finding \(\eta_k\) such that

\[
f(w_k + \eta_k u_k) \leq f(w_k) + c_1 \eta_k \langle u_k, g_k \rangle,
\]

where \(c_1\) is some constant set to \(10^{-4}\) by default. Consider for example the update direction to be \(u_k = -g_k\), i.e., moving along the negative gradient direction. In that case the criterion above reduces to \(f(w_k - \eta_k g_k) \leq f(w_k) - c_1 \eta_k ||g_k||_2^2\). The criterion then amounts to choosing the stepsize such that it decreases the objective by an amount proportional to the squared gradient norm.

As long as the update direction is a *descent direction*, that is, \(\langle u_k, g_k\rangle < 0\), the above criterion is guaranteed to be satisfied by some sufficiently small stepsize.
A simple linesearch technique to ensure a sufficient decrease is then to decrease a candidate stepsize by a constant factor until the criterion is satisfied. This amounts to the backtracking linesearch implemented in `optax.scale_by_backtracking_linesearch()` and briefly reviewed below.
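The backtracking idea can be sketched in a few lines. The function below is a simplified illustration, not optax's implementation. The name `backtracking_stepsize` and the parameters `factor`, `init_stepsize`, and `max_steps` are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def backtracking_stepsize(
    fun, w, u, c1=1e-4, factor=0.5, init_stepsize=1.0, max_steps=30
):
  """Shrinks the stepsize until the sufficient decrease criterion holds."""
  value, grad = jax.value_and_grad(fun)(w)
  slope = jnp.vdot(grad, u)  # <u, g>, negative for a descent direction
  stepsize = init_stepsize
  for _ in range(max_steps):
    if fun(w + stepsize * u) <= value + c1 * stepsize * slope:
      return stepsize
    stepsize = factor * stepsize
  return stepsize  # Criterion not met within max_steps


def quadratic(w):
  return 5.0 * jnp.sum(w**2)


w = jnp.array([1.0, 1.0])
u = -jax.grad(quadratic)(w)  # Negative gradient direction
stepsize = backtracking_stepsize(quadratic, w, u)
print(stepsize)  # 0.125: the candidates 1, 0.5 and 0.25 are rejected
```

Note that this procedure can only shrink the candidate stepsize, so the result is never larger than `init_stepsize`.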

#### Small curvature (Strong Wolfe criterion)

The sufficient decrease criterion ensures that the algorithm does not produce a sequence of diverging objective values. However, we may want to not only reduce a current stepsize but also increase it to ensure maximal speed. Ideally, we would like to find the stepsize that minimizes the function along the current update, i.e., \(\eta_k^* = \arg\min_\eta f(w_k + \eta u_k)\). Such an endeavor can be computationally prohibitive, so we may instead select a stepsize that satisfies some properties that an optimal stepsize would satisfy. In particular, we may search for a stepsize such that the derivative of \(h(\eta) = f(w_k + \eta u_k)\) is small enough compared to its derivative at \(\eta=0\). Formally, we may want to select the stepsize \(\eta_k\) such that \(|h'(\eta_k)| \leq c_2 |h'(0)|\), that is,

\[
|\langle \nabla f(w_k + \eta_k u_k), u_k \rangle| \leq c_2 |\langle g_k, u_k \rangle|,
\]

where \(c_2 \in (c_1, 1)\) is some constant.

See Chapter 3 of Nocedal and Wright, Numerical Optimization, 1999 for some illustrations of this criterion. A linesearch method that ensures both criteria requires some form of bisection, implemented in optax as `optax.scale_by_zoom_linesearch()`. Several other linesearch techniques exist, see e.g. https://github.com/JuliaNLSolvers/LineSearches.jl. It is generally recommended to combine L-BFGS with a linesearch ensuring both sufficient decrease and small curvature, which `optax.scale_by_zoom_linesearch()` does.
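To make the two criteria concrete, the following sketch checks them for a given stepsize by differentiating the one-dimensional function \(h(\eta) = f(w + \eta u)\). The helper `satisfies_strong_wolfe` and the constants `c1=1e-4`, `c2=0.9` are illustrative choices for this example. The zoom linesearch itself is more involved.

```python
import jax
import jax.numpy as jnp


def satisfies_strong_wolfe(fun, w, u, stepsize, c1=1e-4, c2=0.9):
  """Returns (sufficient_decrease_ok, small_curvature_ok) for a stepsize."""
  h = lambda eta: fun(w + eta * u)  # Objective restricted to the line
  h0, dh0 = jax.value_and_grad(h)(0.0)
  h1, dh1 = jax.value_and_grad(h)(stepsize)
  decrease_ok = h1 <= h0 + c1 * stepsize * dh0
  curvature_ok = abs(dh1) <= c2 * abs(dh0)
  return bool(decrease_ok), bool(curvature_ok)


def quadratic(w):
  return 5.0 * jnp.sum(w**2)


w = jnp.array([1.0, 1.0])
u = -jax.grad(quadratic)(w)

# The exact minimizer along the line (eta = 0.1) passes both checks.
print(satisfies_strong_wolfe(quadratic, w, u, 0.1))  # (True, True)
# A tiny stepsize decreases the objective but leaves the slope almost as is.
print(satisfies_strong_wolfe(quadratic, w, u, 0.001))  # (True, False)
```

The second call shows why sufficient decrease alone is not enough: arbitrarily small stepsizes pass it, while the curvature criterion rules them out.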

### Linesearches in practice

To find a stepsize satisfying the above criteria, a linesearch needs to access the value and potentially the gradient of the function. Linesearches in optax are therefore implemented as `optax.GradientTransformationExtraArgs()`, which take the current value and gradient of the objective as well as the function itself. We illustrate this below with `optax.scale_by_backtracking_linesearch()`.

```
# Objective
def fun(w):
  return jnp.sum(jnp.abs(w))


# Linesearch, comment/uncomment the desired one
linesearch = optax.scale_by_backtracking_linesearch(max_backtracking_steps=15)
# linesearch = optax.scale_by_zoom_linesearch(max_linesearch_steps=15)

# Optimizer
opt = optax.chain(
    optax.sgd(learning_rate=1.0),
    # Compare with or without linesearch by commenting this line
    linesearch,
)

# Initialize
w = jrd.normal(jrd.PRNGKey(0), (8,))
state = opt.init(w)

# Run optimization
for i in range(16):
  v, g = jax.value_and_grad(fun)(w)
  print(f'Iteration: {i}, Value:{v:.2e}')
  u, state = opt.update(g, state, w, value=v, grad=g, value_fn=fun)
  w = w + u
print(f'Final value: {fun(w):.2e}')
```

```
Iteration: 0, Value:7.90e+00
Iteration: 1, Value:4.49e+00
Iteration: 2, Value:3.74e+00
Iteration: 3, Value:2.91e+00
Iteration: 4, Value:2.45e+00
Iteration: 5, Value:1.97e+00
Iteration: 6, Value:1.81e+00
Iteration: 7, Value:1.26e+00
Iteration: 8, Value:1.14e+00
Iteration: 9, Value:1.00e+00
Iteration: 10, Value:7.17e-01
Iteration: 11, Value:6.61e-01
Iteration: 12, Value:5.50e-01
Iteration: 13, Value:3.80e-01
Iteration: 14, Value:3.36e-01
Iteration: 15, Value:2.17e-01
Final value: 2.16e-01
```

To validate a stepsize, the linesearch calls the function several times. Once a stepsize is accepted, we therefore already have access to the value of the function, and potentially its gradient, at the next iterate. The linesearches in optax store the value and the gradient computed by the linesearch to avoid recomputing them at the next step. In practice, the code above can be modified as follows.

*Note:*
The backtracking linesearch only evaluates the function and does not compute the gradient natively. To make the backtracking linesearch compute and store the gradient at the stepsize taken, we add the flag `store_grad=True`, see below.
The zoom linesearch always computes both the function and the gradient, so there is no need to specify an additional flag.

```
# Objective
def fun(w):
  return jnp.sum(jnp.abs(w))


# Linesearch
linesearch = optax.scale_by_backtracking_linesearch(
    max_backtracking_steps=15, store_grad=True
)
# linesearch = optax.scale_by_zoom_linesearch(max_linesearch_steps=15)

# Optimizer
opt = optax.chain(optax.sgd(learning_rate=1.0), linesearch)

# Initialize
w = jrd.normal(jrd.PRNGKey(0), (8,))
state = opt.init(w)

# Run optimization
for _ in range(16):
  # Replace `v, g = jax.value_and_grad(fun)(w)` by
  v, g = optax.value_and_grad_from_state(fun)(w, state=state)
  u, state = opt.update(g, state, w, value=v, grad=g, value_fn=fun)
  w = w + u
print(f'Final value: {fun(w):.2e}')
```

```
Final value: 2.16e-01
```

### L-BFGS solver

Optax then combines the gradient transformation of L-BFGS and a linesearch in `optax.lbfgs()`.

We present below a wrapper that combines both into a solver which tries to find the minimizer of a function given

- some initial parameters `init_params`,
- the function to optimize `fun`,
- the instance of the L-BFGS solver considered `opt`,
- a maximal number of iterations `max_iter`,
- a tolerance `tol` on the optimization error, measured here as the gradient norm.

```
def run_lbfgs(init_params, fun, opt, max_iter, tol):
  value_and_grad_fun = optax.value_and_grad_from_state(fun)

  def step(carry):
    params, state = carry
    value, grad = value_and_grad_fun(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=fun
    )
    params = optax.apply_updates(params, updates)
    return params, state

  def continuing_criterion(carry):
    _, state = carry
    iter_num = otu.tree_get(state, 'count')
    grad = otu.tree_get(state, 'grad')
    err = otu.tree_l2_norm(grad)
    # Always take the first step, then stop at max_iter or once err < tol
    return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))

  init_carry = (init_params, opt.init(init_params))
  final_params, final_state = jax.lax.while_loop(
      continuing_criterion, step, init_carry
  )
  return final_params, final_state
```

We can test the solver on the Rosenbrock function.

```
def fun(w):
  return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)


opt = optax.lbfgs()
init_params = jnp.zeros((8,))
print(
    f'Initial value: {fun(init_params):.2e} '
    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'
)
final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)
print(
    f'Final value: {fun(final_params):.2e}, '
    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'
)
```

```
Initial value: 7.00e+00 Initial gradient norm: 5.29e+00
Final value: 2.11e-02, Final gradient norm: 1.83e+00
```

We may print additional information by simply chaining `optax.lbfgs` with an identity transform that just prints relevant information, as follows.

```
class InfoState(NamedTuple):
  iter_num: chex.Numeric


def print_info():
  def init_fn(params):
    del params
    return InfoState(iter_num=0)

  def update_fn(updates, state, params, *, value, grad, **extra_args):
    del params, extra_args
    jax.debug.print(
        'Iteration: {i}, Value: {v}, Gradient norm: {e}',
        i=state.iter_num,
        v=value,
        e=otu.tree_l2_norm(grad),
    )
    # Identity transform: the updates are returned unchanged
    return updates, InfoState(iter_num=state.iter_num + 1)

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)


def fun(w):
  return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)


opt = optax.chain(print_info(), optax.lbfgs())
init_params = jnp.zeros((8,))
print(
    f'Initial value: {fun(init_params):.2e} '
    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'
)
final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)
print(
    f'Final value: {fun(final_params):.2e}, '
    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'
)
```

```
Initial value: 7.00e+00 Initial gradient norm: 5.29e+00
```

```
Iteration: 0, Value: 7.0, Gradient norm: 5.291502475738525
Iteration: 1, Value: 6.925538539886475, Gradient norm: 3.311527729034424
Iteration: 2, Value: 6.906764507293701, Gradient norm: 2.6968793869018555
Iteration: 3, Value: 6.844264030456543, Gradient norm: 2.102635622024536
Iteration: 4, Value: 6.742486476898193, Gradient norm: 4.286610126495361
Iteration: 5, Value: 6.674825668334961, Gradient norm: 6.461292266845703
Iteration: 6, Value: 6.592029571533203, Gradient norm: 7.5424957275390625
Iteration: 7, Value: 6.418954849243164, Gradient norm: 8.496240615844727
Iteration: 8, Value: 6.174844741821289, Gradient norm: 8.72452163696289
Iteration: 9, Value: 5.94936466217041, Gradient norm: 12.077523231506348
Iteration: 10, Value: 5.505581855773926, Gradient norm: 7.263585567474365
Iteration: 11, Value: 5.194385528564453, Gradient norm: 5.545096397399902
Iteration: 12, Value: 4.789361953735352, Gradient norm: 6.743480205535889
Iteration: 13, Value: 4.4702653884887695, Gradient norm: 8.249473571777344
Iteration: 14, Value: 4.3805060386657715, Gradient norm: 9.349186897277832
Iteration: 15, Value: 4.151071548461914, Gradient norm: 8.896681785583496
Iteration: 16, Value: 3.9318795204162598, Gradient norm: 7.139084339141846
Iteration: 17, Value: 3.7333321571350098, Gradient norm: 6.095333576202393
Iteration: 18, Value: 3.503709554672241, Gradient norm: 6.3326416015625
Iteration: 19, Value: 3.196208953857422, Gradient norm: 4.808716297149658
Iteration: 20, Value: 2.972954273223877, Gradient norm: 5.9373369216918945
Iteration: 21, Value: 2.696378231048584, Gradient norm: 7.466330528259277
Iteration: 22, Value: 2.511641263961792, Gradient norm: 9.061756134033203
Iteration: 23, Value: 2.354001998901367, Gradient norm: 8.319526672363281
Iteration: 24, Value: 2.1314697265625, Gradient norm: 6.8462018966674805
Iteration: 25, Value: 1.8888188600540161, Gradient norm: 5.587310791015625
Iteration: 26, Value: 1.6743313074111938, Gradient norm: 6.713437557220459
Iteration: 27, Value: 1.4451940059661865, Gradient norm: 6.366178035736084
Iteration: 28, Value: 1.27220618724823, Gradient norm: 6.59257698059082
Iteration: 29, Value: 1.146568775177002, Gradient norm: 5.683746337890625
Iteration: 30, Value: 1.0164188146591187, Gradient norm: 8.674849510192871
Iteration: 31, Value: 0.918250322341919, Gradient norm: 7.402083396911621
Iteration: 32, Value: 0.8004761338233948, Gradient norm: 5.253969669342041
Iteration: 33, Value: 0.7032359838485718, Gradient norm: 5.31928014755249
Iteration: 34, Value: 0.61725914478302, Gradient norm: 5.614082336425781
Iteration: 35, Value: 0.5446212887763977, Gradient norm: 4.689955711364746
Iteration: 36, Value: 0.4713936150074005, Gradient norm: 4.368592739105225
Iteration: 37, Value: 0.39408063888549805, Gradient norm: 4.842732906341553
Iteration: 38, Value: 0.34117570519447327, Gradient norm: 4.462212562561035
Iteration: 39, Value: 0.26454681158065796, Gradient norm: 4.792419910430908
Iteration: 40, Value: 0.20966416597366333, Gradient norm: 3.5169851779937744
Iteration: 41, Value: 0.16897892951965332, Gradient norm: 2.185563087463379
Iteration: 42, Value: 0.14081722497940063, Gradient norm: 2.384767770767212
Iteration: 43, Value: 0.10892663151025772, Gradient norm: 3.4658374786376953
Iteration: 44, Value: 0.0865096002817154, Gradient norm: 2.4952120780944824
Iteration: 45, Value: 0.06923285871744156, Gradient norm: 2.4269776344299316
Iteration: 46, Value: 0.060979172587394714, Gradient norm: 3.792335033416748
Iteration: 47, Value: 0.050240471959114075, Gradient norm: 3.4059038162231445
Iteration: 48, Value: 0.036705415695905685, Gradient norm: 2.779883623123169
Iteration: 49, Value: 0.02758537232875824, Gradient norm: 2.363450765609741
Final value: 2.11e-02, Final gradient norm: 1.83e+00
```

## Debugging solver

In some cases, L-BFGS with a linesearch as a solver will fail. Most of the time, the culprit is the linesearch. To debug the solver in such cases, we provide a `verbose` option to `optax.scale_by_zoom_linesearch`. We show below how to proceed.

First we try to minimize the Zakharov function without any changes. You’ll observe that the final value is no smaller than the initial value, which indicates that the solver failed, probably because the linesearch did not find a stepsize that ensured a sufficient decrease.

```
def fun(w):
  ii = jnp.arange(1, len(w) + 1, step=1, dtype=w.dtype)
  sum1 = (w**2).sum()
  sum2 = (0.5 * ii * w).sum()
  return sum1 + sum2**2 + sum2**4


opt = optax.chain(print_info(), optax.lbfgs())
init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
print(
    f'Initial value: {fun(init_params)} '
    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params))}'
)
final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)
print(
    f'Final value: {fun(final_params)}, '
    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params))}'
)
```

```
Initial value: 1.0129932568095621e+18 Initial gradient norm: 609193933406208.0
```

```
Iteration: 0, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 1, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 2, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 3, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 4, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 5, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 6, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 7, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 8, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 9, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 10, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 11, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 12, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 13, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 14, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 15, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 16, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 17, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 18, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 19, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 20, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 21, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 22, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 23, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 24, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 25, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 26, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 27, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 28, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 29, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 30, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 31, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 32, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 33, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 34, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 35, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 36, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 37, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 38, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 39, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 40, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 41, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 42, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 43, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 44, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 45, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 46, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 47, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 48, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 49, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Final value: 1.0129932568095621e+18, Final gradient norm: 609193933406208.0
```
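A quick numerical check hints at why no progress is made from this starting point. The sketch below is an illustration under stated assumptions: it works in single precision and takes the negative gradient as the search direction, which may differ from the direction the solver actually uses. The names `zakharov`, `w0`, and `trial_value` are chosen for this example. It computes the slope of the objective along that direction at stepsize 0 and evaluates the objective after a unit step.

```python
import jax
import jax.numpy as jnp


def zakharov(w):
  ii = jnp.arange(1, len(w) + 1, step=1, dtype=w.dtype)
  sum1 = (w**2).sum()
  sum2 = (0.5 * ii * w).sum()
  return sum1 + sum2**2 + sum2**4


# Same starting point as above, explicitly in single precision.
w0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4], dtype=jnp.float32)
value, grad = jax.value_and_grad(zakharov)(w0)

# Slope of eta -> zakharov(w0 - eta * grad) at eta = 0.
slope = -jnp.vdot(grad, grad)

# Objective after a unit step along the negative gradient.
trial_value = zakharov(w0 - grad)

print(f'Value: {value:.2e}, slope at stepsize 0: {slope:.2e}')
print(f'Value after a unit step along -grad: {trial_value}')
```

The slope has a magnitude of about 3.7e29, and the unit step overflows to `inf` in single precision, so any acceptable stepsize along this direction has to be extremely small.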

We can change the linesearch used in `optax.lbfgs` through its `linesearch` argument. Here we keep the default number of linesearch steps (15) and set the verbose option to `True`.

```
opt = optax.chain(
    print_info(),
    optax.lbfgs(
        linesearch=optax.scale_by_zoom_linesearch(
            max_linesearch_steps=15, verbose=True
        )
    ),
)
init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
print(
    f'Initial value: {fun(init_params):.2e} '
    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'
)
final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)
print(
    f'Final value: {fun(final_params):.2e}, '
    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'
)
```

```
Initial value: 1.01e+18 Initial gradient norm: 6.09e+14
```

```
Iteration: 0, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.711172570311185e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 6.103515625e-05 Decrease Error: inf Curvature Error: inf
Iteration: 1, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.711172570311185e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.7111734675608165e+28
Iteration: 2, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 3, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 4, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 5, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 6, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 7, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 8, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 9, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
[Iterations 10 to 39 omitted: each repeats the six warnings and the INFO line above, with unchanged value 1.0129932568095621e+18 and gradient norm 609193933406208.0.]
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 40, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 41, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 42, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 43, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 44, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 45, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 46, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 47, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 48, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Iteration: 49, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: optax.zoom_linesearch: Computed stepsize (=0.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: optax.zoom_linesearch: Very large absolute slope at stepsize=0. (|slope|=3.7111729481005034e+29). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 15 Stepsize: 0.0 Decrease Error: 0.0 Curvature Error: 3.711173939797465e+28
Final value: 1.01e+18, Final gradient norm: 6.09e+14
```

As expected, the linesearch failed at the very first step, taking a stepsize that did not ensure a sufficient decrease. The log displays several pieces of information. For example, the slope (the derivative along the update direction) at the first step is extremely large, which explains why it is difficult to find an appropriate stepsize. As the log above points out, the first thing to try is a larger number of linesearch steps.

```
opt = optax.chain(
    print_info(),
    optax.lbfgs(
        linesearch=optax.scale_by_zoom_linesearch(
            max_linesearch_steps=50, verbose=True
        )
    ),
)
init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
print(
    f'Initial value: {fun(init_params):.2e} '
    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'
)
final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)
print(
    f'Final value: {fun(final_params):.2e}, '
    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'
)
```

```
Initial value: 1.01e+18 Initial gradient norm: 6.09e+14
```

```
Iteration: 0, Value: 1.0129932568095621e+18, Gradient norm: 609193933406208.0
Iteration: 1, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 2, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 3, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 4, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 5, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 6, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 7, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 8, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 9, Value: 1.2426158578597888e+16, Gradient norm: 22454497968128.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 10, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 11, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
WARNING: optax.zoom_linesearch: The linesearch failed because the provided direction is not a descent direction. The slope (=5448015154249728.0) at stepsize=0 should be negative
WARNING: optax.zoom_linesearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
INFO: optax.zoom_linesearch: Iter: 50 Stepsize: 1.9852334701272664e-23 Decrease Error: 0.0 Curvature Error: 726286016184320.0
Iteration: 12, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 13, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 14, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 15, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 16, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 17, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 18, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 19, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 20, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 21, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 22, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 23, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 24, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 25, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 26, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 27, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 28, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 29, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 30, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 31, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 32, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 33, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 34, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 35, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 36, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 37, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 38, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 39, Value: 1.2426149988663296e+16, Gradient norm: 22454487482368.0
Iteration: 40, Value: 9894877726769152.0, Gradient norm: 18928164143104.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 41, Value: 9894877726769152.0, Gradient norm: 18928164143104.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 42, Value: 9894877726769152.0, Gradient norm: 18928164143104.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 43, Value: 9894877726769152.0, Gradient norm: 18928164143104.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 44, Value: 9894877726769152.0, Gradient norm: 18928164143104.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 45, Value: 9894877726769152.0, Gradient norm: 18928164143104.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 46, Value: 9894874505543680.0, Gradient norm: 18928162045952.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 47, Value: 9894874505543680.0, Gradient norm: 18928162045952.0
Iteration: 48, Value: 5347937584414720.0, Gradient norm: 11931374059520.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Iteration: 49, Value: 5347937584414720.0, Gradient norm: 11931374059520.0
WARNING: optax.zoom_linesearch: Length of searched interval has been reduced below threshold.
WARNING: optax.zoom_linesearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
Final value: 5.35e+15, Final gradient norm: 1.19e+13
```

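By simply allowing a maximum of 50 linesearch steps instead of 15, we ensured that the first stepsize taken provided a sufficient decrease, and the solver now makes progress: the value drops from 1.01e+18 to 5.35e+15 over the 50 iterations. The log also shows that many linesearches still return a stepsize that does not satisfy the curvature condition.
Additional debugging information can be found in the source code accessible from the docs of `optax.scale_by_zoom_linesearch()`.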
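
To help read the logs above, the sketch below computes the three quantities they mention (slope at stepsize 0, decrease error, curvature error) on a toy one-dimensional problem, using the textbook strong Wolfe conditions. It is plain Python and does not call optax. The helper `wolfe_errors`, the toy objective, and the constants `c1=1e-4` and `c2=0.9` are assumptions made for illustration, and the exact definitions used inside `optax.scale_by_zoom_linesearch()` may differ.

```python
def wolfe_errors(phi, dphi, stepsize, c1=1e-4, c2=0.9):
    """Violations of the textbook strong Wolfe conditions.

    phi(eta) = f(w + eta * d) is the objective restricted to the search
    direction d, and dphi is its derivative (the "slope").
    c1 and c2 are hypothetical textbook constants, not read from optax.
    """
    value0, slope0 = phi(0.0), dphi(0.0)
    # Sufficient decrease (Armijo): phi(eta) <= phi(0) + c1 * eta * phi'(0).
    decrease_error = max(phi(stepsize) - value0 - c1 * stepsize * slope0, 0.0)
    # Strong curvature condition: |phi'(eta)| <= c2 * |phi'(0)|.
    curvature_error = max(abs(dphi(stepsize)) - c2 * abs(slope0), 0.0)
    return slope0, decrease_error, curvature_error


# Toy objective f(w) = 0.5 * w**2 at w = 1, along the descent direction d = -1.
def phi(eta):
    return 0.5 * (1.0 - eta) ** 2


def dphi(eta):
    return -(1.0 - eta)


for eta in (0.0, 1.0, 3.0):
    slope0, dec_err, curv_err = wolfe_errors(phi, dphi, eta)
    print(
        f'Stepsize: {eta} Slope at 0: {slope0} '
        f'Decrease Error: {dec_err:.4f} Curvature Error: {curv_err:.4f}'
    )
```

In this toy setting, a stepsize of 0 trivially satisfies the decrease condition but leaves a positive curvature error, which is the same pattern as the `Stepsize: 0.0 Decrease Error: 0.0` lines in the failing run. The slope at 0 is negative here because `d` is a descent direction. A positive slope corresponds to the "not a descent direction" warning.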