optax.safe_increment
optax.safe_increment(count: jax.typing.ArrayLike) → Array
Increments counter by one while avoiding overflow.
Denote max_val and min_val as the maximum and minimum possible values for the dtype of count. Normally, max_val + 1 would overflow to min_val. This function ensures that once max_val is reached, the counter stays at max_val.
- Parameters:
count – a counter to be incremented.
- Returns:
A counter incremented by 1, or max_val if the maximum value is reached.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
>>> optax.safe_increment(jnp.asarray(2147483647, dtype=jnp.int32))
Array(2147483647, dtype=int32)
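The saturating behavior described above can be illustrated with a small, self-contained sketch. This is not optax's own source: `saturating_increment` and `count_steps` are hypothetical names, and the sketch assumes an integer counter (it rejects other dtypes). It compares the counter against `jnp.iinfo(dtype).max` and uses `jnp.where` to pick either `count + 1` or `max_val`. It then runs both the sketch and a plain `+ 1` for 300 steps on an `int8` counter inside `jax.lax.fori_loop`.

```python
import jax
import jax.numpy as jnp


def saturating_increment(count: jax.typing.ArrayLike) -> jax.Array:
  """Hypothetical sketch: add 1 to an integer counter, saturating at its max.

  If `count` already equals the largest value of its integer dtype, it is
  returned unchanged instead of wrapping around to the smallest value.
  """
  count = jnp.asarray(count)
  if not jnp.issubdtype(count.dtype, jnp.integer):
    # This sketch only handles integer counters.
    raise ValueError(f"expected an integer dtype, got {count.dtype}")
  max_val = jnp.asarray(jnp.iinfo(count.dtype).max, dtype=count.dtype)
  one = jnp.asarray(1, dtype=count.dtype)
  # jnp.where is branch-free, so the function also works under jax.jit
  # and inside jax.lax control flow.
  return jnp.where(count < max_val, count + one, max_val)


def count_steps(increment_fn, num_steps: int, dtype=jnp.int8) -> jax.Array:
  """Starts a scalar counter at 0 and applies `increment_fn` num_steps times."""
  init = jnp.zeros((), dtype=dtype)
  return jax.lax.fori_loop(0, num_steps, lambda _, c: increment_fn(c), init)


# Saturating counter: stops at 127, the int8 maximum.
saturated = count_steps(saturating_increment, 300)
# Plain increment: wraps past 127 to -128 and ends at 300 mod 256 = 44.
wrapped = count_steps(lambda c: c + 1, 300)
```

With an `int8` counter, the saturating version stays at 127 once it gets there. The plain `+ 1` silently wraps around, which is exactly the failure mode a long-running step counter should avoid.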
Added in version 0.2.4.