optax.safe_increment

optax.safe_increment(count: jax.typing.ArrayLike) -> Array

Increments counter by one while avoiding overflow.

Let max_val and min_val denote the maximum and minimum values representable in the dtype of count. Normally max_val + 1 would overflow to min_val. This function ensures that once max_val is reached, the counter stays at max_val.

Parameters:

count – a counter to be incremented.

Returns:

A counter incremented by 1, or max_val if the maximum value is reached.

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
>>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32))
Array(2147483647, dtype=int32)
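
To make the overflow behavior concrete, the following is a minimal sketch of a saturating increment written with plain jax.numpy. The helper name saturating_increment is hypothetical and is not Optax's source code; it assumes an integer dtype and only illustrates the idea described above, contrasted with ordinary addition.

```python
import jax.numpy as jnp


def saturating_increment(count):
    """Hypothetical illustration of a saturating increment (integer dtypes only)."""
    count = jnp.asarray(count)
    # Largest value representable in the counter's dtype, kept in that dtype.
    max_val = jnp.asarray(jnp.iinfo(count.dtype).max, dtype=count.dtype)
    # Add 1 only while there is headroom; otherwise stay at max_val.
    return jnp.where(count < max_val, count + 1, max_val)


at_max = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)

wrapped = at_max + 1                      # plain addition wraps around to min_val
saturated = saturating_increment(at_max)  # stays at max_val
```

Computing count + 1 on every element and then selecting with jnp.where keeps the function free of Python-level branching, so a sketch like this should remain usable under jax.jit.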

Added in version 0.2.4.