liblaf.apple.jax.testing
Functions:

- assert_fraction_close
- check_grad
- check_jvp
- matrices
- numeric_jvp
- seed
- spd_matrix

assert_fraction_close
assert_fraction_close(
    actual: ArrayLike,
    expected: ArrayLike,
    *,
    atol: float = 0.0,
    fraction: float = 0.01,
    rtol: float = 1e-07,
) -> None
Source code in src/liblaf/apple/jax/testing/_close.py
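Only the signature is documented here. The following is a minimal sketch of what a check with this signature might do. It assumes that `fraction` is the largest share of entries allowed to fail an element-wise, `numpy.isclose`-style test built from `atol` and `rtol`. `fraction_close_sketch` is a hypothetical name and is not part of liblaf.apple. The real function may define the tolerance differently.

```python
import jax.numpy as jnp
from jax.typing import ArrayLike


def fraction_close_sketch(
    actual: ArrayLike,
    expected: ArrayLike,
    *,
    atol: float = 0.0,
    fraction: float = 0.01,
    rtol: float = 1e-7,
) -> None:
    """Hypothetical stand-in: tolerate mismatches in up to `fraction` of entries."""
    actual = jnp.asarray(actual)
    expected = jnp.asarray(expected)
    # Element-wise test, same convention as numpy.isclose:
    # |actual - expected| <= atol + rtol * |expected|
    close = jnp.isclose(actual, expected, rtol=rtol, atol=atol)
    # Share of entries that are NOT close.
    mismatch = float(jnp.mean(~close))
    assert mismatch <= fraction, (
        f"{mismatch:.2%} of entries differ (allowed: {fraction:.2%})"
    )


# One mismatching entry out of 200 (0.5%) stays within the default 1% budget.
x = jnp.arange(200.0)
y = x.at[0].add(1.0)
fraction_close_sketch(y, x)
```

Under this reading, a nonzero `fraction` lets a test tolerate a few outliers, such as entries near a singularity, without loosening `rtol` for every entry.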
check_grad
check_grad(
    fun: Callable[[Array], Array],
    grad: Callable[[Array], Array],
    x: Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-07,
) -> None
Source code in src/liblaf/apple/jax/testing/_jvp.py
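Here is a minimal sketch of one plausible behaviour. It assumes `check_grad` compares a hand-written `grad` with JAX's reverse-mode `jax.grad(fun)` at `x` and tolerates mismatches in at most `fraction` of the entries. With the default `fraction=0.0`, every entry must match. `check_grad_sketch` is hypothetical. The real function sits in `_jvp.py` next to `numeric_jvp`, so it could use finite differences instead.

```python
from collections.abc import Callable

import jax
import jax.numpy as jnp


def check_grad_sketch(
    fun: Callable[[jax.Array], jax.Array],
    grad: Callable[[jax.Array], jax.Array],
    x: jax.Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-7,
) -> None:
    """Hypothetical stand-in: compare a hand-written gradient with jax.grad."""
    # Reference gradient via reverse-mode autodiff; fun must return a scalar.
    expected = jax.grad(fun)(x)
    actual = grad(x)
    close = jnp.isclose(actual, expected, rtol=rtol, atol=atol)
    mismatch = float(jnp.mean(~close))
    assert mismatch <= fraction, f"gradient differs in {mismatch:.2%} of entries"


x = jnp.linspace(0.0, 1.0, 5)
# d/dx sum(sin(x)) = cos(x); rtol is loosened because JAX defaults to float32.
check_grad_sketch(lambda x: jnp.sum(jnp.sin(x)), jnp.cos, x, rtol=1e-6)
```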
check_jvp
check_jvp(
    fun: Callable[[Array], Array],
    jvp: Callable[[Array, Array], Array],
    primal: Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-07,
) -> None
Source code in src/liblaf/apple/jax/testing/_jvp.py
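A minimal sketch under stated assumptions. `jvp(primal, tangent)` is taken to return the Jacobian-vector product of `fun` at `primal` along `tangent`, and the reference comes from `jax.jvp` along a standard-normal tangent. `check_jvp_sketch` and its extra `seed` parameter are hypothetical. The real function may choose the tangent differently or compare against `numeric_jvp`.

```python
from collections.abc import Callable

import jax
import jax.numpy as jnp


def check_jvp_sketch(
    fun: Callable[[jax.Array], jax.Array],
    jvp: Callable[[jax.Array, jax.Array], jax.Array],
    primal: jax.Array,
    *,
    atol: float = 0.0,
    fraction: float = 0.0,
    rtol: float = 1e-7,
    seed: int = 0,  # hypothetical: only used to draw the tangent
) -> None:
    """Hypothetical stand-in: compare a hand-written JVP with jax.jvp."""
    tangent = jax.random.normal(jax.random.PRNGKey(seed), primal.shape, primal.dtype)
    # Forward-mode reference: directional derivative of fun at primal along tangent.
    _, expected = jax.jvp(fun, (primal,), (tangent,))
    actual = jvp(primal, tangent)
    close = jnp.isclose(actual, expected, rtol=rtol, atol=atol)
    mismatch = float(jnp.mean(~close))
    assert mismatch <= fraction, f"JVP differs in {mismatch:.2%} of entries"


x = jnp.linspace(0.0, 1.0, 5)
# The JVP of an element-wise sin is cos(x) * t (rtol loosened for float32).
check_jvp_sketch(jnp.sin, lambda x, t: jnp.cos(x) * t, x, rtol=1e-6)
```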
matrices
matrices(
    draw: DrawFn,
    shape: Sequence[int],
    dtype: DTypeLike = float64,
    *,
    min_dims: int = 1,
    max_dims: int | None = 1,
) -> Array
Source code in src/liblaf/apple/jax/testing/_matrix.py
numeric_jvp
numeric_jvp(
    fun: Callable[[Array], Array],
    primal: ArrayLike,
    tangent: ArrayLike | None = None,
    *,
    eps: float = 0.0001,
) -> Array
Source code in src/liblaf/apple/jax/testing/_jvp.py
seed
seed() -> SearchStrategy[int]
Source code in src/liblaf/apple/jax/testing/_random.py
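`seed` returns a Hypothesis `SearchStrategy[int]`, presumably for generating PRNG seeds in property-based tests. The sketch below illustrates that use. `seed_sketch` is hypothetical, and its range (non-negative values that fit in int32) is an assumption, not the library's documented range.

```python
import hypothesis.strategies as st
import jax
import jax.numpy as jnp
from hypothesis import given, settings


def seed_sketch() -> st.SearchStrategy[int]:
    """Hypothetical stand-in: integers usable as JAX PRNG seeds."""
    # Assumed range: non-negative values that fit in int32.
    return st.integers(min_value=0, max_value=2**31 - 1)


@settings(max_examples=10, deadline=None)
@given(seed_sketch())
def check_sampling_is_reproducible(seed: int) -> None:
    key = jax.random.PRNGKey(seed)
    # The same key always yields the same sample.
    assert bool(jnp.all(jax.random.normal(key, (3,)) == jax.random.normal(key, (3,))))
```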
spd_matrix
spd_matrix(
    draw: DrawFn,
    n: int = 3,
    dtype: DTypeLike = float64,
    *,
    min_dims: int = 1,
    max_dims: int | None = 1,
) -> Float[Array, "*batch D D"]
Source code in src/liblaf/apple/jax/testing/_matrix.py