Skip to content

liblaf.apple.jax.testing ¤

Functions:

assert_fraction_close ¤

assert_fraction_close(
    actual: ArrayLike,
    expected: ArrayLike,
    *,
    atol: float = 0.0,
    fraction: float = 0.01,
    rtol: float = 1e-07,
) -> None
Source code in src/liblaf/apple/jax/testing/_close.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
def assert_fraction_close(
    actual: npt.ArrayLike,
    expected: npt.ArrayLike,
    *,
    atol: float = 0.0,
    fraction: float = 0.01,
    rtol: float = 1e-7,
) -> None:
    """Assert that ``actual`` matches ``expected`` except for a small fraction of elements.

    An element mismatches when it is not ``np.isclose`` to its counterpart,
    i.e. ``|actual - expected| > atol + rtol * |expected|``. NaNs in matching
    positions count as equal. This is exactly the per-element criterion used
    by :func:`numpy.testing.assert_allclose` with its default
    ``equal_nan=True``.

    If the number of mismatching elements (counted over the broadcast shape)
    is strictly less than ``fraction`` times the number of elements, the
    check passes. Otherwise it defers to ``np.testing.assert_allclose``, which
    raises with a detailed per-element report. With ``fraction=0`` this is
    therefore equivalent to ``assert_allclose``.

    Args:
        actual: Computed values.
        expected: Reference values; broadcast against ``actual``.
        atol: Absolute tolerance.
        fraction: Allowed fraction of mismatching elements (strict bound).
        rtol: Relative tolerance, scaled by ``|expected|``.

    Raises:
        AssertionError: If too many elements mismatch.
    """
    __tracebackhide__ = True  # pytest convention: hide this frame in failure tracebacks
    actual_arr: np.ndarray = np.asarray(actual)
    expected_arr: np.ndarray = np.asarray(expected)
    # Use the same criterion assert_allclose applies (it delegates to
    # np.isclose). A hand-written `diff > tol` test evaluates to False for NaN
    # and for inf vs -inf, which would silently count those elements as passes.
    close: np.ndarray = np.asarray(
        np.isclose(actual_arr, expected_arr, rtol=rtol, atol=atol, equal_nan=True)
    )
    n_total: int = close.size  # broadcast size, not just actual's size
    n_fail: int = n_total - int(np.count_nonzero(close))
    if n_fail < fraction * n_total:
        return
    np.testing.assert_allclose(
        actual_arr,
        expected_arr,
        rtol=rtol,
        atol=atol,
        err_msg=f"{n_fail}/{n_total} elements mismatch (allowed fraction: {fraction})",
    )

check_grad ¤

check_grad(
    fun: Callable[[Array], Array],
    grad: Callable[[Array], Array],
    x: Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-07,
) -> None
Source code in src/liblaf/apple/jax/testing/_jvp.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def check_grad(
    fun: Callable[[Array], Array],
    grad: Callable[[Array], Array],
    x: Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-7,
) -> None:
    """Check a gradient implementation against finite differences of ``fun``.

    A random direction is drawn with ``_rand_like(x)``. The inner product
    ``vdot(grad(x), direction)`` is compared to the central finite-difference
    estimate ``numeric_jvp(fun, x, direction)``.

    Args:
        fun: Function whose gradient is being checked.
        grad: Candidate gradient of ``fun``.
        x: Point at which to check.
        atol: Absolute tolerance forwarded to ``assert_fraction_close``.
        fraction: Allowed fraction of mismatches forwarded to ``assert_fraction_close``.
        rtol: Relative tolerance forwarded to ``assert_fraction_close``.

    Raises:
        AssertionError: Propagated from ``assert_fraction_close`` on mismatch.
    """
    __tracebackhide__ = True  # pytest convention: hide this frame in failure tracebacks
    direction: Array = _rand_like(x)
    analytic: Array = jnp.vdot(grad(x), direction)
    assert_fraction_close(
        analytic,
        numeric_jvp(fun, x, direction),
        atol=atol,
        fraction=fraction,
        rtol=rtol,
    )

check_jvp ¤

check_jvp(
    fun: Callable[[Array], Array],
    jvp: Callable[[Array, Array], Array],
    primal: Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-07,
) -> None
Source code in src/liblaf/apple/jax/testing/_jvp.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def check_jvp(
    fun: Callable[[Array], Array],
    jvp: Callable[[Array, Array], Array],
    primal: Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-7,
) -> None:
    """Check a Jacobian-vector-product implementation against finite differences.

    A random tangent is drawn with ``_rand_like(primal)``. The result of
    ``jvp(primal, tangent)`` is compared to
    ``numeric_jvp(fun, primal, tangent)``.

    Args:
        fun: Function whose JVP is being checked.
        jvp: Candidate JVP, called as ``jvp(primal, tangent)``.
        primal: Point at which to check.
        atol: Absolute tolerance forwarded to ``assert_fraction_close``.
        fraction: Allowed fraction of mismatches forwarded to ``assert_fraction_close``.
        rtol: Relative tolerance forwarded to ``assert_fraction_close``.

    Raises:
        AssertionError: Propagated from ``assert_fraction_close`` on mismatch.
    """
    __tracebackhide__ = True  # pytest convention: hide this frame in failure tracebacks
    direction: Array = _rand_like(primal)
    analytic: Array = jvp(primal, direction)
    assert_fraction_close(
        analytic,
        numeric_jvp(fun, primal, direction),
        atol=atol,
        fraction=fraction,
        rtol=rtol,
    )

matrices ¤

matrices(
    draw: DrawFn,
    shape: Sequence[int],
    dtype: DTypeLike = float64,
    *,
    min_dims: int = 1,
    max_dims: int | None = 1,
) -> Array
Source code in src/liblaf/apple/jax/testing/_matrix.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@st.composite
def matrices(
    draw: st.DrawFn,
    shape: Sequence[int],
    dtype: DTypeLike = np.float64,
    *,
    min_dims: int = 1,
    max_dims: int | None = 1,
) -> Array:
    """Hypothesis strategy for random arrays of shape ``(*batch, *shape)``.

    Entries are sampled uniformly from ``[-1, 1)`` with a JAX PRNG key
    seeded from ``seed()``.

    The batch shape comes from ``st.shared(..., key="batch")``. Within one
    Hypothesis example, every strategy using that key receives the same batch
    shape; ``spd_matrix`` in this module uses the same key.
    NOTE(review): sharing is by key alone, so differing ``min_dims`` /
    ``max_dims`` between strategies sharing "batch" presumably do not
    take independent effect — confirm against Hypothesis' ``shared`` semantics.
    """
    batch_shape: Sequence[int] = draw(
        st.shared(hnp.array_shapes(min_dims=min_dims, max_dims=max_dims), key="batch")
    )
    prng_key: Key = jax.random.key(draw(seed()))
    full_shape = (*batch_shape, *shape)
    return jax.random.uniform(prng_key, full_shape, dtype, minval=-1.0, maxval=1.0)

numeric_jvp ¤

numeric_jvp(
    fun: Callable[[Array], Array],
    primal: ArrayLike,
    tangent: ArrayLike | None = None,
    *,
    eps: float = 0.0001,
) -> Array
Source code in src/liblaf/apple/jax/testing/_jvp.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def numeric_jvp(
    fun: Callable[[Array], Array],
    primal: ArrayLike,
    tangent: ArrayLike | None = None,
    *,
    eps: float = 1e-4,
) -> Array:
    primal = jnp.asarray(primal)
    if tangent is None:
        key: Key = jax.random.key(0)
        tangent = jax.random.uniform(
            key, primal.shape, primal.dtype, minval=-1.0, maxval=1.0
        )
    else:
        tangent = jnp.asarray(tangent)
    f0: Array = fun(primal - 0.5 * eps * tangent)
    f1: Array = fun(primal + 0.5 * eps * tangent)
    output: Array = (f1 - f0) / eps
    return output

seed ¤

seed() -> SearchStrategy[int]
Source code in src/liblaf/apple/jax/testing/_random.py
4
5
def seed() -> st.SearchStrategy[int]:
    """Hypothesis strategy for PRNG seeds in the signed 32-bit range.

    Draws integers from ``-2**31`` to ``2**31 - 1`` inclusive.
    """
    bound = 1 << 31
    return st.integers(min_value=-bound, max_value=bound - 1)

spd_matrix ¤

spd_matrix(
    draw: DrawFn,
    n: int = 3,
    dtype: DTypeLike = float64,
    *,
    min_dims: int = 1,
    max_dims: int | None = 1,
) -> Float[Array, "*batch D D"]
Source code in src/liblaf/apple/jax/testing/_matrix.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
@st.composite
def spd_matrix(
    draw: st.DrawFn,
    n: int = 3,
    dtype: DTypeLike = np.float64,
    *,
    min_dims: int = 1,
    max_dims: int | None = 1,
) -> Float[Array, "*batch D D"]:
    """Hypothesis strategy for batches of symmetric positive-definite matrices.

    From a sample ``A`` of shape ``(*batch, n, n)`` with entries uniform in
    ``[-1, 1)``, this returns ``0.5 * (A^T + A) + 2 n I``, where ``A^T`` swaps
    the last two axes. The symmetric part has entries of magnitude at most 1.
    Each diagonal entry is therefore at least ``2n - 1``, while each row's
    off-diagonal magnitudes sum to at most ``n - 1``. The result is strictly
    diagonally dominant with a positive diagonal, hence SPD (Gershgorin).

    The batch shape comes from ``st.shared(..., key="batch")``, the same key
    ``matrices`` uses, so both strategies agree on it within one example.
    """
    batch: Sequence[int] = draw(
        st.shared(hnp.array_shapes(min_dims=min_dims, max_dims=max_dims), key="batch")
    )
    rng: Key = jax.random.key(draw(seed()))
    raw: Array = jax.random.uniform(rng, (*batch, n, n), dtype, minval=-1.0, maxval=1.0)
    symmetric: Array = 0.5 * (raw.mT + raw)
    diagonal_shift = 2.0 * n * np.identity(n, dtype)
    return symmetric + diagonal_shift