Utilities

General

scale_gradient(inputs, scale)

Scales gradients for the backward pass.
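The forward/backward split can be illustrated with `jax.custom_vjp`. This is a minimal sketch, not Optax's source, and `scale_gradient_sketch` is a hypothetical name. The output equals the input, while the cotangent flowing back through it is multiplied by `scale`.

```python
import jax
import jax.numpy as jnp


def scale_gradient_sketch(inputs, scale):
    """Hypothetical helper: identity forward, gradients multiplied by `scale` backward."""

    @jax.custom_vjp
    def _identity(x):
        return x

    def _fwd(x):
        # No residuals are needed: the backward rule depends only on `scale`.
        return x, None

    def _bwd(_, g):
        # Scale every leaf of the incoming cotangent pytree.
        return (jax.tree_util.tree_map(lambda t: t * scale, g),)

    _identity.defvjp(_fwd, _bwd)
    return _identity(inputs)


x = jnp.array([1.0, 2.0])
loss = lambda x: jnp.sum(scale_gradient_sketch(x, 0.5) ** 2)
print(loss(x))            # forward value is unchanged: 1 + 4 = 5
print(jax.grad(loss)(x))  # 0.5 * 2x, i.e. [1, 2]
```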

value_and_grad_from_state(value_fn)

Alternative to jax.value_and_grad() that fetches the value and gradient from the optimizer state.

Numerical Stability

safe_increment(count)

Increments counter by one while avoiding overflow.
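This is a minimal sketch that assumes an integer counter. The library version may also handle other dtypes, and `safe_increment_sketch` is a hypothetical name. Instead of wrapping around to a negative number, the counter saturates at the largest representable value.

```python
import jax.numpy as jnp


def safe_increment_sketch(count):
    """Hypothetical helper: adds one to an integer counter, saturating at the dtype max."""
    count = jnp.asarray(count)
    max_value = jnp.array(jnp.iinfo(count.dtype).max, dtype=count.dtype)
    one = jnp.array(1, dtype=count.dtype)
    # `count + one` wraps around at the maximum, but jnp.where discards that branch.
    return jnp.where(count < max_value, count + one, max_value)


c = jnp.array(41, dtype=jnp.int32)
print(safe_increment_sketch(c))  # 42
top = jnp.array(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
print(safe_increment_sketch(top) == top)  # stays at the maximum
```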

safe_norm(x, min_norm[, ord, axis, keepdims])

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
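The gradient of `jnp.linalg.norm` at `x = 0` is not finite. Wrapping the result in `jnp.maximum` does not fix this, because the backward pass still multiplies by the norm's infinite derivative at zero. A common workaround is the "double `where`" pattern. The sketch below applies it to the L2 norm over the whole array only; `safe_norm_sketch` is a hypothetical name, not the library's code.

```python
import jax
import jax.numpy as jnp


def safe_norm_sketch(x, min_norm):
    """Hypothetical helper: max(||x||, min_norm) with a finite gradient at x = 0."""
    norm = jnp.linalg.norm(x)
    small = norm <= min_norm  # boolean mask: no gradient flows through it
    # Swap x for ones where the norm is small, so the norm that is actually
    # differentiated is never evaluated at zero.
    safe_x = jnp.where(small, jnp.ones_like(x), x)
    return jnp.where(small, min_norm, jnp.linalg.norm(safe_x))


zeros = jnp.zeros(3)
print(safe_norm_sketch(jnp.array([3.0, 4.0]), 1e-6))         # 5.0
print(jax.grad(lambda v: safe_norm_sketch(v, 1e-6))(zeros))  # all zeros, no NaN
```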

safe_root_mean_squares(x, min_rms)

Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct gradients.

Linear Algebra Operators

matrix_inverse_pth_root(matrix, p[, ...])

Computes matrix^(-1/p), where p is a positive integer.
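For a symmetric positive-definite matrix with eigendecomposition V diag(λ) Vᵀ, the inverse p-th root is V diag(λ^(-1/p)) Vᵀ. The sketch below uses that eigendecomposition route. It is not necessarily how the library computes the root, and `inverse_pth_root_eigh` is a hypothetical name.

```python
import jax.numpy as jnp


def inverse_pth_root_eigh(matrix, p):
    """Hypothetical helper: matrix^(-1/p) for a symmetric positive-definite matrix."""
    eigvals, eigvecs = jnp.linalg.eigh(matrix)
    # Scale column j of V by lambda_j^(-1/p), then multiply by V^T.
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
root = inverse_pth_root_eigh(a, 2)
# Sanity check: (a^(-1/2))^2 @ a should be close to the identity.
print(root @ root @ a)
```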

power_iteration(matrix, *[, v0, num_iters, ...])

Power iteration algorithm for estimating the dominant eigenvalue of a matrix and its associated eigenvector.
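This sketch shows the textbook power iteration: repeatedly multiply by the matrix and renormalize, then read off the eigenvalue with a Rayleigh quotient. The name `power_iteration_sketch` and its argument defaults are hypothetical and do not reproduce the library's signature.

```python
import jax
import jax.numpy as jnp


def power_iteration_sketch(matrix, v0, num_iters=100):
    """Hypothetical helper: estimates the dominant eigenpair of `matrix`."""

    def body(_, v):
        w = matrix @ v
        return w / jnp.linalg.norm(w)  # renormalize to avoid overflow/underflow

    v = jax.lax.fori_loop(0, num_iters, body, v0 / jnp.linalg.norm(v0))
    eigenvalue = v @ (matrix @ v)  # Rayleigh quotient of the unit vector v
    return v, eigenvalue


m = jnp.array([[2.0, 1.0], [1.0, 2.0]])
vec, val = power_iteration_sketch(m, jnp.array([1.0, 0.0]))
print(val)  # close to 3.0, the largest eigenvalue of m
```

The number of iterations needed depends on the ratio between the two largest eigenvalue magnitudes. Here that ratio is 1/3, so convergence is fast.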

nnls(A, b, iters[, unroll, L])

Solves the non-negative least squares problem.
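The non-negative least squares problem is min_x ½‖Ax − b‖² subject to x ≥ 0. The sketch below uses plain projected gradient descent with step size 1/σ_max(A)², which is not necessarily the method Optax uses. `nnls_projected_gradient` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def nnls_projected_gradient(A, b, iters=500):
    """Hypothetical helper: approximately solves min 0.5*||Ax - b||^2 s.t. x >= 0."""
    # Lipschitz constant of the gradient A^T (A x - b) is sigma_max(A)^2.
    lipschitz = jnp.linalg.norm(A, ord=2) ** 2

    def step(_, x):
        grad = A.T @ (A @ x - b)
        # Gradient step, then projection onto the non-negative orthant.
        return jnp.maximum(x - grad / lipschitz, 0.0)

    return jax.lax.fori_loop(0, iters, step, jnp.zeros(A.shape[1]))


A = jnp.array([[1.0, 0.0], [0.0, 2.0]])
b = jnp.array([2.0, -2.0])
print(nnls_projected_gradient(A, b))  # close to [2.0, 0.0]
```

Without the constraint, the solution would be [2, -1]. Projection clamps the second coordinate to zero.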

Second Order Optimization

fisher_diag(negative_log_likelihood, params, ...)

Computes the diagonal of the (observed) Fisher information matrix.
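This is a minimal sketch. It assumes the observed Fisher diagonal is taken as the element-wise square of the gradient of the negative log-likelihood at the given batch, flattened into one vector. `fisher_diag_sketch` and the toy `nll` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def fisher_diag_sketch(negative_log_likelihood, params, inputs, targets):
    """Hypothetical helper: squared NLL gradient, flattened to a 1-D vector."""
    grads = jax.grad(negative_log_likelihood)(params, inputs, targets)
    flat_grads, _ = ravel_pytree(grads)  # dict leaves are flattened in sorted-key order
    return jnp.square(flat_grads)


def nll(params, x, y):
    # Gaussian negative log-likelihood (up to a constant) of a linear model.
    pred = x @ params["w"] + params["b"]
    return 0.5 * jnp.sum((pred - y) ** 2)


params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0]])
y = jnp.array([0.0])
print(fisher_diag_sketch(nll, params, x, y))  # entries for b, w[0], w[1]
```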

hessian_diag(loss, params, inputs, targets)

Computes the diagonal of the Hessian of loss at (inputs, targets).
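One way to get the Hessian diagonal without materializing the full matrix is to take one Hessian-vector product per standard basis vector and keep entry i of H eᵢ. The sketch below assumes a 1-D parameter vector and uses the hypothetical name `hessian_diag_sketch`. It costs one HVP per coordinate, so it only suits small problems.

```python
import jax
import jax.numpy as jnp


def hessian_diag_sketch(f, x):
    """Hypothetical helper: diagonal of the Hessian of scalar f at a 1-D x."""

    def hvp(v):
        # Forward-over-reverse: directional derivative of grad(f) along v.
        return jax.jvp(jax.grad(f), (x,), (v,))[1]

    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    # Row i of vmap(hvp)(basis) is H e_i; its i-th entry is H_ii.
    return jnp.diagonal(jax.vmap(hvp)(basis))


f = lambda x: jnp.sum(jnp.array([1.0, 2.0, 3.0]) * x ** 2) + x[0] * x[1]
print(hessian_diag_sketch(f, jnp.ones(3)))  # [2, 4, 6]; the x0*x1 term is off-diagonal
```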

hvp(loss, v, params, inputs, targets)

Performs an efficient Hessian-vector product of loss (the Hessian is symmetric, so this equals the vector-Hessian product).
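In JAX, a Hessian-vector product is usually computed "forward-over-reverse": differentiate the gradient function in the direction `v` with `jax.jvp`. This costs roughly a small constant multiple of one gradient evaluation and never forms H. The sketch mirrors the argument order listed above. `hvp_sketch` and the quadratic `loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


def hvp_sketch(loss, v, params, inputs, targets):
    """Hypothetical helper: H v, with H the Hessian of loss w.r.t. params."""
    grad_fn = lambda p: jax.grad(loss)(p, inputs, targets)
    # Forward-mode derivative of the gradient along v (forward-over-reverse).
    return jax.jvp(grad_fn, (params,), (v,))[1]


def loss(params, x, y):
    return 0.5 * jnp.sum((x @ params - y) ** 2)  # Hessian is x.T @ x


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.zeros(2)
params = jnp.zeros(2)
v = jnp.array([1.0, 0.0])
print(hvp_sketch(loss, v, params, x, y))  # (x.T @ x) @ v = [10, 14]
```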

Tree

NamedTupleKey(tuple_name, name)

KeyType for a NamedTuple in a tree.

tree_add(tree_x, tree_y, *other_trees)

Add two (or more) pytrees.

tree_add_scale(tree_x, scalar, tree_y)

Add two trees, where the second tree is scaled by a scalar.
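Leaf by leaf, this is `x + scalar * y`. With a negative scalar, that is exactly a plain SGD update. This is a minimal sketch assuming both trees share the same structure; `tree_add_scale_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_add_scale_sketch(tree_x, scalar, tree_y):
    """Hypothetical helper: leaf-wise x + scalar * y for same-structure pytrees."""
    return jax.tree_util.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, -0.2]), "b": jnp.array(1.0)}
# One SGD step with learning rate 0.1: params - 0.1 * grads.
print(tree_add_scale_sketch(params, -0.1, grads))  # w -> [0.99, 2.02], b -> 0.4
```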

tree_allclose(a, b[, rtol, atol, equal_nan])

Check whether two trees are element-wise approximately equal within a tolerance.

tree_batch_shape(tree[, shape])

Add leading batch dimensions to each leaf of a pytree.
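This is a minimal sketch. It assumes the leaves are arrays whose values are broadcast (repeated) along the new leading dimensions given by `shape`. `tree_batch_shape_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_batch_shape_sketch(tree, shape=()):
    """Hypothetical helper: prepends the dimensions in `shape` to every leaf."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, tuple(shape) + jnp.shape(x)), tree
    )


tree = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
batched = tree_batch_shape_sketch(tree, (4,))
print(jax.tree_util.tree_map(jnp.shape, batched))  # w -> (4, 3, 2), b -> (4, 2)
```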

tree_cast(tree, dtype)

Cast tree to the given dtype; the tree is returned unchanged if dtype is None.
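This minimal sketch shows the "skip if None" behavior, assuming every leaf is cast with `astype`. `tree_cast_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_cast_sketch(tree, dtype):
    """Hypothetical helper: casts every leaf to `dtype`; no-op when dtype is None."""
    if dtype is None:
        return tree
    return jax.tree_util.tree_map(lambda x: jnp.asarray(x).astype(dtype), tree)


params = {"w": jnp.ones(2), "b": jnp.zeros(3)}
half = tree_cast_sketch(params, jnp.bfloat16)
print(jax.tree_util.tree_map(lambda x: x.dtype, half))  # both leaves are bfloat16
```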

tree_cast_like(tree, other_tree)

Cast tree to the dtypes of other_tree.

tree_clip(tree[, min_value, max_value])

Creates an identical tree where all tensors are clipped to [min_value, max_value].
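This minimal sketch applies `jnp.clip` element-wise to every leaf, for example to clip gradient values. `tree_clip_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_clip_sketch(tree, min_value=None, max_value=None):
    """Hypothetical helper: clips every leaf element-wise to [min_value, max_value]."""
    return jax.tree_util.tree_map(lambda x: jnp.clip(x, min_value, max_value), tree)


grads = {"w": jnp.array([-3.0, 0.5, 2.0]), "b": jnp.array(-0.2)}
print(tree_clip_sketch(grads, -1.0, 1.0))  # w -> [-1.0, 0.5, 1.0], b -> -0.2
```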

tree_conj(tree)

Compute the conjugate of a pytree.

tree_div(tree_x, tree_y)

Divide two pytrees.

tree_dtype(tree[, mixed_dtype_handler])

Fetch the dtype of tree.

tree_full_like(tree, fill_value[, dtype])

Creates an identical tree where all tensors are filled with fill_value.

tree_get(tree, key[, default, filtering])

Extract a value from a pytree matching a given key.

tree_get_all_with_path(tree, key[, filtering])

Extract all values of a pytree matching a given key, together with their paths.

tree_norm(tree[, ord, squared])

Compute the vector norm of a pytree for the given ord.
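For ord=2, the norm of a pytree is the Euclidean norm of all its leaves concatenated into one vector. This minimal sketch covers only that case, plus the `squared` flag. `tree_norm_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_norm_sketch(tree, squared=False):
    """Hypothetical helper: L2 norm of all leaves, as if concatenated into one vector."""
    leaves = jax.tree_util.tree_leaves(tree)
    sq_norm = sum(jnp.sum(jnp.abs(x) ** 2) for x in leaves)
    return sq_norm if squared else jnp.sqrt(sq_norm)


tree = {"a": jnp.array([3.0]), "b": jnp.array([[4.0]])}
print(tree_norm_sketch(tree))                # 5.0
print(tree_norm_sketch(tree, squared=True))  # 25.0
```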

tree_map_params(initable, f, state, /, *rest)

Apply a callable over all params in the given optimizer state.

tree_max(tree)

Compute the max of all the elements in a pytree.

tree_min(tree)

Compute the min of all the elements in a pytree.

tree_mul(tree_x, tree_y)

Multiply two pytrees.

tree_ones_like(tree[, dtype])

Creates an all-ones tree with the same structure.

tree_random_like(rng_key, target_tree, ...)

Create a tree with random entries of the same shape as target_tree.

tree_real(tree)

Compute the real part of a pytree.

tree_split_key_like(rng_key, target_tree)

Split rng_key into keys matching the structure of target_tree.
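Giving each leaf its own PRNG key keeps the random draws for different leaves independent. This minimal sketch splits the key once per leaf and then rebuilds the tree structure. The usage lines show how such keys can feed a random-like tree of Gaussian noise. `tree_split_key_like_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_split_key_like_sketch(rng_key, target_tree):
    """Hypothetical helper: one PRNG key per leaf, arranged like target_tree."""
    treedef = jax.tree_util.tree_structure(target_tree)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, list(keys))


params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}
keys = tree_split_key_like_sketch(jax.random.PRNGKey(0), params)
# Draw independent Gaussian noise for each leaf, matching its shape and dtype.
noise = jax.tree_util.tree_map(
    lambda k, x: jax.random.normal(k, x.shape, x.dtype), keys, params
)
```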

tree_scale(scalar, tree)

Multiply a tree by a scalar.

tree_set(tree[, filtering])

Creates a copy of tree with some values replaced as specified by kwargs.

tree_size(tree)

Total size (number of elements) of a pytree.

tree_sub(tree_x, tree_y)

Subtract two pytrees.

tree_sum(tree[, associative_reduction])

Compute the sum of all the elements in a pytree.

tree_vdot(tree_x, tree_y)

Compute the inner product between two pytrees.
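The inner product of two pytrees with the same structure is the sum of the per-leaf inner products. This minimal sketch uses `jnp.vdot`, which flattens its inputs and conjugates the first argument. `tree_vdot_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_vdot_sketch(tree_x, tree_y):
    """Hypothetical helper: sum over leaves of jnp.vdot(x_leaf, y_leaf)."""
    per_leaf = jax.tree_util.tree_map(jnp.vdot, tree_x, tree_y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([[3.0]])}
y = {"a": jnp.array([4.0, 5.0]), "b": jnp.array([[6.0]])}
print(tree_vdot_sketch(x, y))  # 1*4 + 2*5 + 3*6 = 32
```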

tree_where(condition, tree_x, tree_y)

Select tree_x values if condition is true else tree_y values.

tree_zeros_like(tree[, dtype])

Creates an all-zeros tree with the same structure.