Utilities
General
Scales gradients for the backwards pass.
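Scaling gradients only on the backward pass can be sketched with `jax.custom_vjp`: the forward pass is the identity and the backward pass multiplies the incoming cotangent by `scale`. This is a minimal sketch; `scale_gradient` and its helpers are hypothetical names, and `scale` is assumed to be a static Python number.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper (sketch): identity forward, scaled cotangent backward.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale_gradient(x, scale):
    return x


def _scale_gradient_fwd(x, scale):
    # No residuals are needed for the backward pass.
    return x, None


def _scale_gradient_bwd(scale, _residuals, g):
    # Scale every leaf of the cotangent; one entry per differentiable input.
    return (jax.tree_util.tree_map(lambda t: t * scale, g),)


scale_gradient.defvjp(_scale_gradient_fwd, _scale_gradient_bwd)


def loss(x):
    return scale_gradient(x, 0.5) ** 2


value = loss(3.0)             # forward value is unchanged: 3.0 ** 2
grad = jax.grad(loss)(3.0)    # 0.5 * (2 * 3.0)
print(value, grad)
```

Because the forward value is untouched, this can damp or amplify the gradient flowing into one part of a model without changing its outputs.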
Alternative to
Numerical Stability
Increments counter by one while avoiding overflow.
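One way to increment an integer counter without overflow is to saturate at the dtype's maximum instead of letting `count + 1` wrap around to a negative number. This is a sketch assuming an integer counter; `safe_increment` is a hypothetical name and saturation is one reasonable policy, not necessarily the documented one.

```python
import jax.numpy as jnp


def safe_increment(count):
    """Hypothetical helper: add one, but stay at the dtype's maximum once reached."""
    count = jnp.asarray(count)
    max_value = jnp.iinfo(count.dtype).max
    # At max_value, count + 1 would wrap; jnp.where discards that result.
    return jnp.where(count < max_value, count + 1, max_value)


step = safe_increment(jnp.array(5, dtype=jnp.int32))
```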
Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
Returns maximum(sqrt(mean(abs_sq(x))), min_norm) with correct gradients.
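The "correct gradients" in the two entries above matter because the gradient of `sqrt` at zero is infinite, so differentiating a naive norm at `x = 0` yields NaN. A common fix is the "double where" trick: feed `sqrt` a harmless value wherever its output will be discarded. This sketch assumes `min_norm >= 0`; `safe_norm`, `safe_rms`, and the underscore helpers are hypothetical names.

```python
import jax
import jax.numpy as jnp


def _abs_sq(x):
    # |x|^2, valid for real and complex inputs.
    return jnp.real(x * jnp.conj(x))


def _clamped_sqrt(s, min_value):
    """max(sqrt(s), min_value) with finite gradients, assuming min_value >= 0."""
    above = s > min_value ** 2
    # Where the sqrt branch is discarded, evaluate sqrt(1.0) instead of sqrt(s),
    # so the unused branch cannot inject inf/NaN into the gradient.
    safe_s = jnp.where(above, s, 1.0)
    return jnp.where(above, jnp.sqrt(safe_s), min_value)


def safe_norm(x, min_norm=0.0):
    return _clamped_sqrt(jnp.sum(_abs_sq(x)), min_norm)


def safe_rms(x, min_rms=0.0):
    # Root mean square: same clamp, but averaging instead of summing.
    return _clamped_sqrt(jnp.mean(_abs_sq(x)), min_rms)


grad_at_zero = jax.grad(safe_norm)(jnp.zeros(3))  # finite (all zeros)
```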
Linear Algebra Operators
Computes matrix^(-1/p), where p is a positive integer.
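The entry does not say how the root is computed. For a symmetric positive (semi)definite matrix, one standard approach is an eigendecomposition $A = V \operatorname{diag}(w) V^\top$, giving $A^{-1/p} = V \operatorname{diag}(w^{-1/p}) V^\top$. The sketch below uses that technique and should not be read as the documented implementation; `inverse_pth_root_eigh` and the `eps` floor are assumptions.

```python
import jax.numpy as jnp


def inverse_pth_root_eigh(matrix, p, eps=1e-6):
    """Hypothetical helper: matrix^(-1/p) for a symmetric PSD matrix via eigh."""
    eigvals, eigvecs = jnp.linalg.eigh(matrix)
    # Floor tiny or negative eigenvalues so the negative power stays finite.
    eigvals = jnp.maximum(eigvals, eps)
    # V * row_vector scales the columns of V, i.e. V @ diag(w^(-1/p)).
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
R = inverse_pth_root_eigh(A, 2)  # R @ R is approximately inv(A)
```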
Power iteration algorithm for estimating the dominant eigenvalue and eigenvector of a matrix.
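Power iteration repeatedly applies the matrix to a vector and renormalizes; the vector converges to the eigenvector of the largest-magnitude eigenvalue, at a rate set by the ratio of the two largest eigenvalue magnitudes. A minimal sketch, assuming a square symmetric matrix; `power_iteration`, `num_iters`, and `seed` are hypothetical names, not the documented signature.

```python
import jax
import jax.numpy as jnp


def power_iteration(matrix, num_iters=100, seed=0):
    """Hypothetical sketch: estimate the dominant eigenvalue and eigenvector."""
    v = jax.random.normal(jax.random.PRNGKey(seed), (matrix.shape[0],))
    v = v / jnp.linalg.norm(v)

    def step(_, v):
        # Apply the matrix, then renormalize to unit length.
        w = matrix @ v
        return w / jnp.linalg.norm(w)

    v = jax.lax.fori_loop(0, num_iters, step, v)
    # Rayleigh quotient of the unit-norm vector estimates the eigenvalue.
    eigenvalue = v @ matrix @ v
    return eigenvalue, v


eigval, eigvec = power_iteration(jnp.diag(jnp.array([1.0, 3.0])))
```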
Solves the non-negative least squares problem.
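The problem is $\min_x \lVert Ax - b \rVert^2$ subject to $x \ge 0$. The sketch below solves it with projected gradient descent (a gradient step, then clipping negatives to zero), which is one simple technique and not necessarily the solver used by the documented function; `nnls_projected_gradient` and `num_iters` are hypothetical.

```python
import jax
import jax.numpy as jnp


def nnls_projected_gradient(A, b, num_iters=1000):
    """Hypothetical sketch: approximately solve min ||A x - b||^2 s.t. x >= 0."""
    AtA = A.T @ A
    Atb = A.T @ b
    # Step 1/L with L = largest eigenvalue of A^T A (Lipschitz constant of the gradient).
    step_size = 1.0 / jnp.linalg.eigvalsh(AtA)[-1]

    def step(_, x):
        grad = AtA @ x - Atb  # gradient of 0.5 * ||A x - b||^2
        return jnp.maximum(x - step_size * grad, 0.0)  # project onto x >= 0

    x0 = jnp.zeros(A.shape[1], dtype=A.dtype)
    return jax.lax.fori_loop(0, num_iters, step, x0)


# The unconstrained solution is [2, -2]; the constraint clips the second entry.
x = nnls_projected_gradient(jnp.diag(jnp.array([1.0, 2.0])), jnp.array([2.0, -4.0]))
```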
Second Order Optimization
Computes the diagonal of the (observed) Fisher information matrix.
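A common estimate of the Fisher diagonal is the mean, over examples, of the squared per-example gradients of the negative log-likelihood. The sketch below assumes that definition and a flat parameter vector; `empirical_fisher_diag` and `nll_fn(params, x, y)` are hypothetical, and the documented function may aggregate over the batch differently.

```python
import jax
import jax.numpy as jnp


def empirical_fisher_diag(nll_fn, params, inputs, targets):
    """Hypothetical sketch: mean of squared per-example NLL gradients."""
    # Per-example gradients w.r.t. params; params are shared across the batch.
    per_example_grads = jax.vmap(jax.grad(nll_fn), in_axes=(None, 0, 0))(
        params, inputs, targets)
    return jnp.mean(per_example_grads ** 2, axis=0)


def gaussian_nll(w, x, y):
    # Negative log-likelihood of a unit-variance Gaussian, up to a constant.
    return 0.5 * (jnp.dot(w, x) - y) ** 2


fisher = empirical_fisher_diag(
    gaussian_nll, jnp.zeros(2), jnp.array([[1.0, 0.0], [0.0, 2.0]]), jnp.array([1.0, 1.0]))
```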
Computes the diagonal of the Hessian of the loss at (inputs, targets).
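For a small flat parameter vector, the Hessian diagonal can be read off the full Hessian. This sketch assumes a loss of the form `loss_fn(params, inputs, targets)`; `hessian_diag` is a hypothetical name, and materializing the full n × n Hessian makes it suitable only for small n.

```python
import jax
import jax.numpy as jnp


def hessian_diag(loss_fn, params, inputs, targets):
    """Hypothetical sketch: diag of the Hessian w.r.t. a flat parameter vector."""
    loss_at_batch = lambda p: loss_fn(p, inputs, targets)
    # jax.hessian builds the full n x n matrix; only its diagonal is kept.
    return jnp.diag(jax.hessian(loss_at_batch)(params))


def squared_error(p, x, y):
    return jnp.sum((p * x - y) ** 2)


# d^2/dp_i^2 of (p_i x_i - y_i)^2 is 2 x_i^2.
diag = hessian_diag(squared_error, jnp.ones(3), jnp.array([1.0, 2.0, 3.0]), jnp.zeros(3))
```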
Computes a vector-Hessian product of the loss efficiently.
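A Hessian-vector product can be computed without forming the Hessian by differentiating the gradient function along the direction `v` (forward-over-reverse). Because the Hessian is symmetric, the vector-Hessian product $v^\top H$ equals $(Hv)^\top$. A minimal sketch; `hvp` and its argument order are assumptions.

```python
import jax
import jax.numpy as jnp


def hvp(loss_fn, params, v, inputs, targets):
    """Hypothetical sketch: H @ v, with H the Hessian of loss_fn at params."""
    loss_at_batch = lambda p: loss_fn(p, inputs, targets)
    # jvp of grad: directional derivative of the gradient along v, i.e. H v.
    return jax.jvp(jax.grad(loss_at_batch), (params,), (v,))[1]


def least_squares(p, x, y):
    return 0.5 * jnp.sum((x @ p - y) ** 2)


X = jnp.array([[1.0, 0.0], [1.0, 1.0]])  # Hessian is X^T X = [[2, 1], [1, 1]]
product = hvp(least_squares, jnp.zeros(2), jnp.array([1.0, 0.0]), X, jnp.zeros(2))
```

The cost is a small constant multiple of one gradient evaluation, which is what makes this product cheap compared with building the Hessian.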
Tree
KeyType for a NamedTuple in a tree.
Add two (or more) pytrees.
Add two trees, where the second tree is scaled by a scalar.
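Leaf-wise tree arithmetic such as the two additions above can be written with `jax.tree_util.tree_map`, which applies a function to matching leaves of trees that share one structure. The helper names `tree_add_all` and `tree_add_scaled` are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_add_all(*trees):
    """Hypothetical sketch: leaf-wise sum of two or more trees."""
    return jax.tree_util.tree_map(lambda *leaves: sum(leaves), *trees)


def tree_add_scaled(tree_x, scalar, tree_y):
    """Hypothetical sketch: leaf-wise tree_x + scalar * tree_y."""
    return jax.tree_util.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.5, 0.5]), "b": jnp.array(1.0)}
# A plain gradient-descent step: params - 0.1 * grads.
new_params = tree_add_scaled(params, -0.1, grads)
```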
Check whether two trees are element-wise approximately equal within a tolerance.
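A sketch of a tree-level `allclose`: first require identical tree structure, then compare each pair of leaves with `jnp.allclose`. `tree_allclose` and its default tolerances (borrowed from `jnp.allclose`) are assumptions; returning a Python `bool` means this version is not usable inside `jax.jit`.

```python
import jax
import jax.numpy as jnp


def tree_allclose(tree_a, tree_b, rtol=1e-5, atol=1e-8):
    """Hypothetical sketch: same structure and every leaf pair allclose."""
    if jax.tree_util.tree_structure(tree_a) != jax.tree_util.tree_structure(tree_b):
        return False
    leaves_a = jax.tree_util.tree_leaves(tree_a)
    leaves_b = jax.tree_util.tree_leaves(tree_b)
    return all(bool(jnp.allclose(a, b, rtol=rtol, atol=atol))
               for a, b in zip(leaves_a, leaves_b))


a = {"w": jnp.ones(3)}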
Add leading batch dimensions to each leaf of a pytree.
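One reading of "add leading batch dimensions" is to broadcast every leaf to `batch_shape + leaf.shape`. This sketch assumes that reading and copies values by broadcasting; the documented helper may instead work on shapes only. `tree_add_batch_dims` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_add_batch_dims(tree, batch_shape):
    """Hypothetical sketch: broadcast each leaf to batch_shape + leaf.shape."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, tuple(batch_shape) + jnp.shape(x)), tree)


batched = tree_add_batch_dims({"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, (4,))
```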
Cast a tree to the given dtype; skip the cast if the dtype is None.
Cast a tree to the dtypes of other_tree.
Creates an identical tree where all tensors are clipped to [min, max].
Compute the conjugate of a pytree.
Divide two pytrees.
Fetch the dtype of a tree.
Creates an identical tree where all tensors are filled with
Extract a value from a pytree matching a given key.
Extract values of a pytree matching a given key.
Compute the vector norm of a pytree for the given ord.
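The vector norm of a pytree treats all leaves as one long vector. A simple sketch concatenates the flattened leaves and calls `jnp.linalg.norm` with the requested `ord` (1, 2, or `jnp.inf` for vectors). `tree_vector_norm` is a hypothetical name, and the tree is assumed to have at least one leaf.

```python
import jax
import jax.numpy as jnp


def tree_vector_norm(tree, ord=2):
    """Hypothetical sketch: norm of all leaves flattened into one vector."""
    leaves = jax.tree_util.tree_leaves(tree)
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return jnp.linalg.norm(flat, ord=ord)


tree = {"a": jnp.array([3.0]), "b": jnp.array([[4.0]])}
l2 = tree_vector_norm(tree)  # sqrt(3^2 + 4^2)
```

Concatenation copies every leaf; summing per-leaf partial norms avoids the copy at the cost of handling each `ord` separately.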
Apply a callable over all params in the given optimizer state.
Compute the max of all the elements in a pytree.
Compute the min of all the elements in a pytree.
Multiply two pytrees.
Creates an all-ones tree with the same structure.
Create tree with random entries of the same shape as target tree.
Compute the real part of a pytree.
Split a random key to match the structure of a target tree.
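Splitting one key into a tree of keys gives every leaf an independent random stream, which is also the natural building block for filling a tree with random entries. This sketch assumes floating-point leaves; `tree_split_key_like` and `tree_normal_like` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def tree_split_key_like(key, target_tree):
    """Hypothetical sketch: one PRNG key per leaf, in target_tree's structure."""
    treedef = jax.tree_util.tree_structure(target_tree)
    keys = jax.random.split(key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, list(keys))


def tree_normal_like(key, target_tree):
    """Hypothetical sketch: standard-normal leaves with target_tree's shapes/dtypes."""
    keys = tree_split_key_like(key, target_tree)
    return jax.tree_util.tree_map(
        lambda leaf, k: jax.random.normal(k, jnp.shape(leaf), jnp.result_type(leaf)),
        target_tree, keys)


target = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}
keys = tree_split_key_like(jax.random.PRNGKey(0), target)
noise = tree_normal_like(jax.random.PRNGKey(0), target)
```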
Multiply a tree by a scalar.
Creates a copy of tree with some values replaced as specified by kwargs.
Total size of a pytree.
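Reading "size" as the total number of scalar elements across all leaves (an assumption), it is the sum of each leaf's element count. `tree_num_elements` is a hypothetical name; scalar leaves count as one element.

```python
import math

import jax
import jax.numpy as jnp


def tree_num_elements(tree):
    """Hypothetical sketch: total number of elements across all leaves."""
    # math.prod(()) == 1, so scalar leaves contribute one element each.
    return sum(math.prod(jnp.shape(leaf)) for leaf in jax.tree_util.tree_leaves(tree))


n = tree_num_elements({"w": jnp.zeros((2, 3)), "b": jnp.zeros(3), "s": 1.0})
```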
Subtract two pytrees.
Compute the sum of all the elements in a pytree.
Compute the inner product between two pytrees.
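Both reductions above sum per-leaf results. The inner product of two trees is the sum, over matching leaves, of `jnp.vdot`, which flattens its inputs and conjugates the first argument. `tree_sum` and `tree_vdot` here are hypothetical sketches.

```python
import jax
import jax.numpy as jnp


def tree_sum(tree):
    """Hypothetical sketch: sum of every element of every leaf."""
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))


def tree_vdot(tree_x, tree_y):
    """Hypothetical sketch: <x, y> as a sum of per-leaf vdot products."""
    products = jax.tree_util.tree_map(jnp.vdot, tree_x, tree_y)
    return sum(jax.tree_util.tree_leaves(products))


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([[3.0]])}
y = {"a": jnp.array([4.0, 5.0]), "b": jnp.array([[6.0]])}
inner = tree_vdot(x, y)  # 1*4 + 2*5 + 3*6
```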
Select tree_x values if condition is true else tree_y values.
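A leaf-wise select is `jnp.where` mapped over two trees with the same structure; `condition` must broadcast against every leaf. A typical use is keeping the old parameters when an update produced a non-finite loss. `tree_where` is a hypothetical sketch.

```python
import jax
import jax.numpy as jnp


def tree_where(condition, tree_x, tree_y):
    """Hypothetical sketch: leaf-wise jnp.where(condition, x, y)."""
    return jax.tree_util.tree_map(lambda x, y: jnp.where(condition, x, y), tree_x, tree_y)


old = {"w": jnp.array([1.0, 2.0])}
new = {"w": jnp.array([0.9, 1.8])}
loss = jnp.array(jnp.nan)
# Keep the old parameters because the loss is not finite.
kept = tree_where(jnp.isfinite(loss), new, old)
```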
Creates an all-zeros tree with the same structure.