liblaf.apple.jax.math

Functions:

normalize

normalize(
    v: Float[Array, "*batch dim"],
) -> Float[Array, "*batch dim"]
Source code in src/liblaf/apple/jax/math/_normalize.py
def normalize(v: Float[Array, "*batch dim"]) -> Float[Array, "*batch dim"]:
    norm: Float[Array, "*batch 1"] = jnp.linalg.norm(v, axis=-1, keepdims=True)
    return v / jnp.where(norm == 0.0, 1.0, norm)
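A minimal usage sketch, assuming only that `jax` is installed. The function body is copied locally from the listing above so the example runs without `liblaf.apple`; the type annotations are dropped for brevity, and the input values are made up for illustration. It shows that normalization acts on the last axis and that a zero vector is returned unchanged rather than producing NaN.

```python
import jax.numpy as jnp


def normalize(v):
    # Local copy of the listing above (annotations omitted) so this sketch is self-contained.
    norm = jnp.linalg.norm(v, axis=-1, keepdims=True)
    # Where the norm is zero, divide by 1 instead to avoid 0 / 0.
    return v / jnp.where(norm == 0.0, 1.0, norm)


# Illustrative batch of three 2-D vectors; the last row is the zero vector.
v = jnp.array([[3.0, 4.0], [0.0, 2.0], [0.0, 0.0]])
u = normalize(v)

# Row-wise lengths of the result: 1 for non-zero inputs, 0 for the zero vector.
lengths = jnp.linalg.norm(u, axis=-1)
print(u)
print(lengths)
```

Because the guard replaces only the divisor, the zero row stays exactly zero in the output, and the result has the same shape as the input.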