Skip to content

liblaf.apple.jax.model ¤

Modules:

Classes:

Dirichlet ¤

Parameters:

  • dim ¤

    (int) – Number of components per point, i.e. the number of columns of the full `(n_points, dim)` array.
  • dirichlet_index ¤

    (Integer[Array, dirichlet]) – Positions of the Dirichlet-constrained entries in the flattened `(n_points, dim)` array.
  • dirichlet_value ¤

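    (Float[Array, dirichlet]) – Prescribed values of the constrained entries; used by `set_dirichlet` and `to_full` when no values are passed.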
  • free_index ¤

    (Integer[Array, free]) – Positions of the unconstrained (free) entries in the flattened `(n_points, dim)` array.
  • n_points ¤

    (int) – Number of points, i.e. the number of rows of the full array.

Methods:

Attributes:

dim instance-attribute ¤

dim: int

dirichlet_index instance-attribute ¤

dirichlet_index: Integer[Array, ' dirichlet']

dirichlet_value instance-attribute ¤

dirichlet_value: Float[Array, ' dirichlet']

free_index instance-attribute ¤

free_index: Integer[Array, ' free']

n_dirichlet property ¤

n_dirichlet: int

n_free property ¤

n_free: int

n_full property ¤

n_full: int

n_points instance-attribute ¤

n_points: int

get_dirichlet ¤

get_dirichlet(
    full: Float[Array, "points dim"],
) -> Float[Array, " dirichlet"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
27
28
29
30
31
@eqx.filter_jit
def get_dirichlet(
    self, full: Float[Array, "points dim"]
) -> Float[Array, " dirichlet"]:
    """Gather the Dirichlet-constrained entries of ``full``.

    ``full`` is viewed as a 1-D array and indexed with ``self.dirichlet_index``.
    """
    flat = full.reshape(-1)
    return flat[self.dirichlet_index]

get_free ¤

get_free(
    full: Float[Array, "points dim"],
) -> Float[Array, " free"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
33
34
35
@eqx.filter_jit
def get_free(self, full: Float[Array, "points dim"]) -> Float[Array, " free"]:
    """Gather the unconstrained (free) entries of ``full``.

    ``full`` is viewed as a 1-D array and indexed with ``self.free_index``.
    """
    flat = full.reshape(-1)
    return flat[self.free_index]

set_dirichlet ¤

set_dirichlet(
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
37
38
39
40
41
42
43
44
45
@eqx.filter_jit
def set_dirichlet(
    self,
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]:
    """Return a copy of ``full`` with its Dirichlet entries overwritten.

    Args:
        full: Array whose flattened entries at ``self.dirichlet_index`` are replaced.
        values: Replacement values; defaults to ``self.dirichlet_value``.

    Returns:
        A new array with the same shape as ``full``.
    """
    fill = self.dirichlet_value if values is None else values
    updated = full.reshape(-1).at[self.dirichlet_index].set(fill)
    return updated.reshape(full.shape)

set_free ¤

set_free(
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " free"],
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
47
48
49
50
51
@eqx.filter_jit
def set_free(
    self, full: Float[Array, "points dim"], values: Float[ArrayLike, " free"]
) -> Float[Array, "points dim"]:
    """Return a copy of ``full`` with its free entries replaced by ``values``.

    The flattened entries at ``self.free_index`` are written; the result has
    the same shape as ``full``.
    """
    updated = full.reshape(-1).at[self.free_index].set(values)
    return updated.reshape(full.shape)

to_full ¤

to_full(
    free: Float[Array, " free"],
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
53
54
55
56
57
58
59
60
61
62
63
64
@eqx.filter_jit
def to_full(
    self,
    free: Float[Array, " free"],
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]:
    """Assemble an ``(n_points, dim)`` array from free and Dirichlet values.

    Args:
        free: Values scattered to ``self.free_index``; also fixes the output dtype.
        dirichlet: Values scattered to ``self.dirichlet_index``; defaults to
            ``self.dirichlet_value`` (see ``set_dirichlet``).

    Returns:
        The assembled full array.
    """
    buffer = jnp.empty((self.n_points, self.dim), free.dtype)
    # Free entries first, then the Dirichlet entries.
    return self.set_dirichlet(self.set_free(buffer, free), dirichlet)

DirichletBuilder ¤

DirichletBuilder(dim: int = 3)

Parameters:

  • mask ¤

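    (Bool[ndarray, 'points dim']) – Per-component flag; `True` marks a Dirichlet-constrained entry. Rows added by `resize` start as `False`.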
  • value ¤

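    (Float[ndarray, 'points dim']) – Prescribed values; `finalize` reads only the entries where `mask` is `True`. Rows added by `resize` start as `0.0`.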

Methods:

Attributes:

Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
19
20
21
22
def __init__(self, dim: int = 3) -> None:
    """Create an empty builder: zero points, ``dim`` components per point.

    ``mask`` starts as an empty boolean array and ``value`` as an empty float
    array, both of shape ``(0, dim)``.
    """
    self.__attrs_init__(  # pyright: ignore[reportAttributeAccessIssue]
        mask=np.empty((0, dim), bool),
        value=np.empty((0, dim)),
    )

dim property ¤

dim: int

mask instance-attribute ¤

mask: Bool[ndarray, 'points dim']

n_points property ¤

n_points: int

value instance-attribute ¤

value: Float[ndarray, 'points dim']

add_pyvista ¤

add_pyvista(obj: DataSet) -> None
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
32
33
34
35
36
37
38
39
40
41
42
def add_pyvista(self, obj: pv.DataSet) -> None:
    """Merge per-point Dirichlet data from a PyVista dataset into the builder.

    Reads ``POINT_ID``, ``DIRICHLET_MASK`` and ``DIRICHLET_VALUE`` from
    ``obj.point_data``. Grows the builder (via ``resize``) so the largest point
    id fits, then overwrites ``mask`` and ``value`` at those ids.

    A dataset without points is a no-op, because ``point_id.max()`` would
    raise on an empty array.
    """
    point_id = obj.point_data[POINT_ID]
    if point_id.size == 0:
        return
    # `resize` is annotated to take an int; `.max()` returns a NumPy scalar.
    self.resize(int(point_id.max()) + 1)
    # NOTE(review): `_left_broadcast_to` presumably expands the stored arrays to
    # (obj.n_points, dim) -- confirm against its definition.
    dirichlet_mask: Bool[Array, "points dim"] = self._left_broadcast_to(
        obj.point_data[DIRICHLET_MASK], obj.n_points
    )
    dirichlet_value: Float[Array, "points dim"] = self._left_broadcast_to(
        obj.point_data[DIRICHLET_VALUE], obj.n_points
    )
    self.mask[point_id] = dirichlet_mask
    self.value[point_id] = dirichlet_value

finalize ¤

finalize() -> Dirichlet
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
44
45
46
47
48
49
50
51
52
53
def finalize(self) -> Dirichlet:
    """Convert the accumulated mask and values into a ``Dirichlet`` object.

    Entries where ``mask`` is ``True`` become Dirichlet entries, with values
    taken from ``value`` at the same flattened positions. All other entries
    become free entries.
    """
    mask = jnp.asarray(self.mask)
    constrained = jnp.flatnonzero(mask)
    unconstrained = jnp.flatnonzero(~mask)
    return Dirichlet(
        dim=self.dim,
        dirichlet_index=constrained,
        dirichlet_value=jnp.asarray(self.value.flat[constrained]),
        free_index=unconstrained,
        n_points=self.n_points,
    )

resize ¤

resize(n_points: int) -> None
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
55
56
57
58
59
60
def resize(self, n_points: int) -> None:
    """Grow ``mask`` and ``value`` to at least ``n_points`` rows; never shrink.

    New rows are appended with mask ``False`` and value ``0.0``.
    """
    extra = n_points - self.n_points
    if extra > 0:
        widths = ((0, extra), (0, 0))
        self.mask = np.pad(self.mask, widths, constant_values=False)
        self.value = np.pad(self.value, widths, constant_values=0.0)

JaxEnergy ¤

Bases: IdMixin


              flowchart TD
              liblaf.apple.jax.model.JaxEnergy[JaxEnergy]
              liblaf.apple.utils._id_mixin.IdMixin[IdMixin]

                              liblaf.apple.utils._id_mixin.IdMixin --> liblaf.apple.jax.model.JaxEnergy
                


              click liblaf.apple.jax.model.JaxEnergy href "" "liblaf.apple.jax.model.JaxEnergy"
              click liblaf.apple.utils._id_mixin.IdMixin href "" "liblaf.apple.utils._id_mixin.IdMixin"
            

Parameters:

  • id ¤

    (str, default: <dynamic> ) – Identifier of the energy; `JaxModelBuilder.add_energy` uses it as the key in `energies`.
  • requires_grad ¤

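    (frozenset[str], default: frozenset() ) – Parameter names for which `mixed_derivative_prod` calls the matching `mixed_derivative_prod_<name>` method.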

Methods:

Attributes:

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = field(
    default=frozenset(), kw_only=True
)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
30
31
def fun(self, u: Vector) -> Scalar:
    """Return the scalar energy at ``u``.

    Subclasses must override this. The default ``grad``, ``hess_prod`` and
    ``value_and_grad`` implementations differentiate it.

    Raises:
        NotImplementedError: always, in the base class.
    """
    # Name the concrete class so a missing override is easy to locate.
    msg = f"{type(self).__name__} does not implement fun()"
    raise NotImplementedError(msg)

grad ¤

grad(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
33
34
35
36
@eqx.filter_jit
def grad(self, u: Vector) -> Updates:
    """Gradient of ``fun`` at ``u`` as a ``(values, index)`` pair.

    ``index`` is ``arange(len(u))``, so ``values`` is the dense gradient.
    """
    dense: Vector = eqx.filter_grad(self.fun)(u)
    everywhere = jnp.arange(u.shape[0])
    return dense, everywhere

grad_and_hess_diag ¤

grad_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
61
62
63
@eqx.filter_jit
def grad_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.grad(u), self.hess_diag(u))`` from a single call."""
    gradient = self.grad(u)
    diagonal = self.hess_diag(u)
    return gradient, diagonal

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
38
39
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal at ``u`` as a ``(values, index)`` pair.

    The base class provides no generic implementation; subclasses must
    override this.

    Raises:
        NotImplementedError: always, in the base class.
    """
    # Name the concrete class so a missing override is easy to locate.
    msg = f"{type(self).__name__} does not implement hess_diag()"
    raise NotImplementedError(msg)

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
41
42
43
44
45
@eqx.filter_jit
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` as a ``(values, index)`` pair.

    Computed forward-over-reverse: ``jax.jvp`` of ``jax.grad(self.fun)``
    along direction ``p``. ``index`` is ``arange(len(u))``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, tangent = jax.jvp(grad_fn, (u,), (p,))
    return tangent, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
47
48
49
50
51
52
@eqx.filter_jit
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p``, built on ``hess_prod``.

    Only the entries of ``p`` selected by the index that ``hess_prod``
    returns contribute.
    """
    hp: Vector
    rows: Index
    hp, rows = self.hess_prod(u, p)
    return jnp.vdot(p[rows], hp)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> EnergyParams
Source code in src/liblaf/apple/jax/model/_energy.py
65
66
67
68
69
70
71
@eqx.filter_jit
def mixed_derivative_prod(self, u: Vector, p: Vector) -> EnergyParams:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For every ``name``, the method ``mixed_derivative_prod_<name>(u, p)`` is
    looked up on ``self`` and its result is stored under ``name``.
    """
    outputs: EnergyParams = {}
    for name in self.requires_grad:
        method = getattr(self, f"mixed_derivative_prod_{name}")
        outputs[name] = method(u, p)
    return outputs

update ¤

update(u: Vector) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
22
23
24
@eqx.filter_jit
def update(self, u: Vector) -> None:
    """Hook receiving the current point ``u``; the base implementation does nothing.

    ``JaxModel.update`` calls this on every energy. Subclasses may override it.
    """
    return None

update_params ¤

update_params(params: Mapping[str, Array]) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
26
27
28
def update_params(self, params: Mapping[str, Array]) -> None:
    """Assign each ``name -> value`` pair in ``params`` as an attribute of ``self``."""
    for attr_name, new_value in params.items():
        setattr(self, attr_name, new_value)

value_and_grad ¤

value_and_grad(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
54
55
56
57
58
59
@eqx.filter_jit
def value_and_grad(self, u: Vector) -> tuple[Scalar, Updates]:
    """Energy and dense gradient at ``u``.

    Returns ``(value, (grad, index))``, where ``index`` is ``arange(len(u))``.
    """
    energy_and_grad = jax.value_and_grad(self.fun)
    value, dense = energy_and_grad(u)
    index = jnp.arange(u.shape[0])
    return value, (dense, index)

JaxModel ¤

Parameters:

  • energies ¤

    (dict[str, JaxEnergy], default: <class 'dict'> ) –

    Energy terms that make up the model, keyed by name. `fun`, `grad` and the Hessian methods add up the contributions of every entry, and `update_params` routes parameters to entries by key.

Methods:

Attributes:

energies class-attribute instance-attribute ¤

energies: dict[str, JaxEnergy] = field(
    factory=dict, kw_only=True
)

fun ¤

fun(x: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_model.py
29
30
31
32
33
34
@eqx.filter_jit
def fun(self, x: Vector) -> Scalar:
    """Total energy: the sum of ``energy.fun(x)`` over all energies.

    With no energies this is a zero scalar.
    """
    return sum(
        (energy.fun(x) for energy in self.energies.values()),
        start=jnp.zeros(()),
    )

grad ¤

grad(x: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
36
37
38
39
40
41
42
43
44
@eqx.filter_jit
def grad(self, x: Vector) -> Vector:
    """Dense gradient of the total energy.

    Each energy returns ``(values, index)``. The values are scatter-added into
    a zero vector shaped like ``x``, so repeated indices accumulate.
    """
    total: Vector = jnp.zeros_like(x)
    for energy in self.energies.values():
        values, where = energy.grad(x)
        total = total.at[where].add(values)
    return total

grad_and_hess_diag ¤

grad_and_hess_diag(x: Vector) -> tuple[Vector, Vector]
Source code in src/liblaf/apple/jax/model/_model.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
@eqx.filter_jit
def grad_and_hess_diag(self, x: Vector) -> tuple[Vector, Vector]:
    """Dense gradient and Hessian diagonal of the total energy, in one pass.

    Each energy returns ``((g_values, g_index), (h_values, h_index))``. Both
    parts are scatter-added into zero vectors shaped like ``x``.
    """
    grad: Vector = jnp.zeros_like(x)
    diag: Vector = jnp.zeros_like(x)
    for energy in self.energies.values():
        grad_update, diag_update = energy.grad_and_hess_diag(x)
        grad_values, grad_index = grad_update
        diag_values, diag_index = diag_update
        grad = grad.at[grad_index].add(grad_values)
        diag = diag.at[diag_index].add(diag_values)
    return grad, diag

hess_diag ¤

hess_diag(x: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
46
47
48
49
50
51
52
53
54
@eqx.filter_jit
def hess_diag(self, x: Vector) -> Vector:
    """Dense Hessian diagonal of the total energy.

    Per-energy ``(values, index)`` contributions are scatter-added into a zero
    vector shaped like ``x``.
    """
    accumulated: Vector = jnp.zeros_like(x)
    for energy in self.energies.values():
        values, where = energy.hess_diag(x)
        accumulated = accumulated.at[where].add(values)
    return accumulated

hess_prod ¤

hess_prod(x: Vector, p: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
56
57
58
59
60
61
62
63
64
@eqx.filter_jit
def hess_prod(self, x: Vector, p: Vector) -> Vector:
    """Hessian-vector product of the total energy at ``x`` along ``p``.

    Per-energy ``(values, index)`` products are scatter-added into a zero
    vector shaped like ``x``.
    """
    accumulated: Vector = jnp.zeros_like(x)
    for energy in self.energies.values():
        values, where = energy.hess_prod(x, p)
        accumulated = accumulated.at[where].add(values)
    return accumulated

hess_quad ¤

hess_quad(x: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_model.py
66
67
68
69
70
71
@eqx.filter_jit
def hess_quad(self, x: Vector, p: Vector) -> Scalar:
    """Quadratic form of the total Hessian: the sum of each energy's ``hess_quad``.

    With no energies this is a zero scalar.
    """
    return sum(
        (energy.hess_quad(x, p) for energy in self.energies.values()),
        start=jnp.zeros(()),
    )

mixed_derivative_prod ¤

mixed_derivative_prod(
    x: Vector, p: Vector
) -> ModelParams
Source code in src/liblaf/apple/jax/model/_model.py
73
74
75
76
77
78
@eqx.filter_jit
def mixed_derivative_prod(self, x: Vector, p: Vector) -> ModelParams:
    """Per-energy mixed-derivative products, keyed like ``self.energies``."""
    products: ModelParams = {}
    for name, energy in self.energies.items():
        products[name] = energy.mixed_derivative_prod(x, p)
    return products

update ¤

update(x: Vector) -> None
Source code in src/liblaf/apple/jax/model/_model.py
21
22
23
def update(self, x: Vector) -> None:
    """Forward ``x`` to ``update`` on every energy, in dict insertion order."""
    for term in self.energies.values():
        term.update(x)

update_params ¤

update_params(params: ModelParams) -> None
Source code in src/liblaf/apple/jax/model/_model.py
25
26
27
def update_params(self, params: ModelParams) -> None:
    """Route ``params[name]`` to ``self.energies[name].update_params``.

    Raises:
        KeyError: if ``params`` names an energy that is not in the model.
    """
    for energy_name, energy_params in params.items():
        target = self.energies[energy_name]
        target.update_params(energy_params)

value_and_grad ¤

value_and_grad(x: Vector) -> tuple[Scalar, Vector]
Source code in src/liblaf/apple/jax/model/_model.py
80
81
82
83
84
85
86
87
88
89
90
@eqx.filter_jit
def value_and_grad(self, x: Vector) -> tuple[Scalar, Vector]:
    """Total energy and dense gradient, in a single pass over the energies.

    Each energy returns ``(value, (values, index))``. Values are summed and the
    gradient entries are scatter-added at ``index``.
    """
    total: Scalar = jnp.zeros(())
    dense: Vector = jnp.zeros_like(x)
    for energy in self.energies.values():
        energy_value, (grad_values, grad_index) = energy.value_and_grad(x)
        total = total + energy_value
        dense = dense.at[grad_index].add(grad_values)
    return total, dense

JaxModelBuilder ¤

Parameters:

  • energies ¤

    (dict[str, JaxEnergy], default: <class 'dict'> ) –

    Energy terms collected so far, keyed by each energy's `id` (filled by `add_energy`).

Methods:

Attributes:

energies class-attribute instance-attribute ¤

energies: dict[str, JaxEnergy] = field(
    factory=dict, kw_only=True
)

add_energy ¤

add_energy(energy: JaxEnergy) -> None
Source code in src/liblaf/apple/jax/model/_builder.py
11
12
def add_energy(self, energy: JaxEnergy) -> None:
    """Register ``energy`` under its ``id``, replacing any entry with the same id."""
    key = energy.id
    self.energies[key] = energy

finalize ¤

finalize() -> JaxModel
Source code in src/liblaf/apple/jax/model/_builder.py
14
15
def finalize(self) -> JaxModel:
    """Build a ``JaxModel`` from the energies registered so far.

    The model receives a shallow copy of ``energies``, so later ``add_energy``
    calls on this builder do not silently change an already-finalized model.
    The energy objects themselves are shared, not copied.
    """
    return JaxModel(energies=dict(self.energies))