Skip to content

liblaf.apple.jax.energies ¤

Classes:

Gravity ¤

Bases: JaxEnergy


              flowchart TD
              liblaf.apple.jax.energies.Gravity[Gravity]
              liblaf.apple.jax.model._energy.JaxEnergy[JaxEnergy]
              liblaf.apple.utils._id_mixin.IdMixin[IdMixin]

                              liblaf.apple.jax.model._energy.JaxEnergy --> liblaf.apple.jax.energies.Gravity
                                liblaf.apple.utils._id_mixin.IdMixin --> liblaf.apple.jax.model._energy.JaxEnergy
                



              click liblaf.apple.jax.energies.Gravity href "" "liblaf.apple.jax.energies.Gravity"
              click liblaf.apple.jax.model._energy.JaxEnergy href "" "liblaf.apple.jax.model._energy.JaxEnergy"
              click liblaf.apple.utils._id_mixin.IdMixin href "" "liblaf.apple.utils._id_mixin.IdMixin"
            

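Parameters:

  • id ¤

    (str, default: <dynamic>) – Identifier of this energy; generated by `_default_id` when not given.
  • requires_grad ¤

    (frozenset[str], default: frozenset()) – Names of the parameters for which `mixed_derivative_prod` computes a product; each name `name` needs a matching `mixed_derivative_prod_{name}` method.
  • gravity ¤

    (Float[Array, " dim"]) – Gravitational acceleration vector; `from_pyvista` defaults it to `[0.0, -9.81, 0.0]`.
  • indices ¤

    (Integer[Array, " points"]) – Rows of the displacement `u` that this energy acts on.
  • mass ¤

    (Float[Array, " points"]) – Mass of each point listed in `indices`.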

Methods:

Attributes:

gravity instance-attribute ¤

gravity: Float[Array, ' dim']

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

indices instance-attribute ¤

indices: Integer[Array, ' points']

mass instance-attribute ¤

mass: Float[Array, ' points']

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = field(
    default=frozenset(), kw_only=True
)

from_pyvista classmethod ¤

from_pyvista(
    obj: DataSet,
    gravity: Float[ArrayLike, " dim"] | None = None,
) -> Self
Source code in `src/liblaf/apple/jax/energies/_gravity.py` (lines 23–33):
@classmethod
def from_pyvista(
    cls, obj: pv.DataSet, gravity: Float[ArrayLike, " dim"] | None = None
) -> Self:
    """Build a gravity energy from the point data of a PyVista dataset.

    Reads the point arrays ``POINT_ID`` (rows of ``u`` to act on) and
    ``MASS`` (per-point mass).

    Args:
        obj: Dataset carrying the ``POINT_ID`` and ``MASS`` point arrays.
        gravity: Gravitational acceleration; ``[0.0, -9.81, 0.0]`` when
            omitted.

    Returns:
        A new instance of ``cls``.
    """
    acceleration = (
        jnp.asarray([0.0, -9.81, 0.0]) if gravity is None else gravity
    )
    point_data = obj.point_data
    return cls(
        gravity=jnp.asarray(acceleration),
        indices=jnp.asarray(point_data[POINT_ID]),
        mass=jnp.asarray(point_data[MASS]),
    )

fun ¤

fun(u: Vector) -> Scalar
Source code in `src/liblaf/apple/jax/energies/_gravity.py` (lines 35–38):
@override
def fun(self, u: Vector) -> Scalar:
    u = u[self.indices]
    return -jnp.vdot(self.mass, jnp.vecdot(u, self.gravity, axis=-1))

grad ¤

grad(u: Vector) -> Updates
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 33–36):
@eqx.filter_jit
def grad(self, u: Vector) -> Updates:
    """Gradient of ``self.fun`` at ``u``, computed by automatic differentiation.

    Args:
        u: Displacement vector.

    Returns:
        ``(values, index)`` where ``values`` has the shape of ``u`` and
        ``index = arange(u.shape[0])``, i.e. one update per row of ``u``.
    """
    grad_fn = eqx.filter_grad(self.fun)
    gradient: Vector = grad_fn(u)
    row_index = jnp.arange(u.shape[0])
    return gradient, row_index

grad_and_hess_diag ¤

grad_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 61–63):
@eqx.filter_jit
def grad_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.grad(u), self.hess_diag(u))`` from a single call.

    Args:
        u: Displacement vector.

    Returns:
        The gradient updates followed by the Hessian-diagonal updates.
    """
    gradient_updates = self.grad(u)
    diagonal_updates = self.hess_diag(u)
    return gradient_updates, diagonal_updates

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in `src/liblaf/apple/jax/energies/_gravity.py` (lines 40–42):
@override
def hess_diag(self, u: Vector) -> Updates:
    return jnp.zeros_like(u[self.indices]), self.indices

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 41–45):
@eqx.filter_jit
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` via forward-over-reverse autodiff.

    The gradient function ``jax.grad(self.fun)`` is differentiated in the
    direction ``p`` with ``jax.jvp``; the tangent output is ``H(u) p``.

    Args:
        u: Point at which the Hessian is evaluated.
        p: Direction; must match the structure of ``u`` (a ``jax.jvp``
            requirement).

    Returns:
        ``(values, index)`` with ``values`` shaped like ``u`` and
        ``index = arange(u.shape[0])``.
    """
    gradient_fn = jax.grad(self.fun)
    _primal_out, tangent_out = jax.jvp(gradient_fn, (u,), (p,))
    values: Vector = tangent_out
    return values, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 47–52):
@eqx.filter_jit
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p`` built on top of ``self.hess_prod``.

    Args:
        u: Point at which the Hessian is evaluated.
        p: Direction vector.

    Returns:
        ``vdot(p[index], values)`` for
        ``(values, index) = self.hess_prod(u, p)``.
    """
    hp: Updates = self.hess_prod(u, p)
    hp_values, hp_index = hp
    selected: Vector = p[hp_index]
    return jnp.vdot(selected, hp_values)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> EnergyParams
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 65–71):
@eqx.filter_jit
def mixed_derivative_prod(self, u: Vector, p: Vector) -> EnergyParams:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    For each ``name`` in ``self.requires_grad`` this dispatches to the
    method ``mixed_derivative_prod_{name}(u, p)``. ``getattr`` raises
    ``AttributeError`` if that method does not exist.

    Args:
        u: Displacement vector.
        p: Direction vector forwarded to each per-parameter method.

    Returns:
        Mapping from parameter name to that method's result.
    """
    outputs: EnergyParams = {}
    for name in self.requires_grad:
        per_param = getattr(self, f"mixed_derivative_prod_{name}")
        outputs[name] = per_param(u, p)
    return outputs

update ¤

update(u: Vector) -> None
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 22–24):
@eqx.filter_jit
def update(self, u: Vector) -> None:
    """Do nothing: accept ``u`` and return ``None``.

    NOTE(review): presumably an override point for energies that cache
    state derived from ``u`` — confirm against subclasses.
    """

update_params ¤

update_params(params: Mapping[str, Array]) -> None
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 26–28):
def update_params(self, params: Mapping[str, Array]) -> None:
    """Overwrite attributes of ``self`` from ``params``.

    Each ``(attr_name, new_value)`` pair is applied with ``setattr``, so the
    keys of ``params`` are used directly as attribute names.

    Args:
        params: Mapping from attribute name to its new value.
    """
    for attr_name, new_value in params.items():
        setattr(self, attr_name, new_value)

value_and_grad ¤

value_and_grad(u: Vector) -> tuple[Scalar, Updates]
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 54–59):
@eqx.filter_jit
def value_and_grad(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate ``self.fun(u)`` together with its gradient.

    Args:
        u: Displacement vector.

    Returns:
        ``(value, (gradient, index))`` where ``gradient`` is shaped like
        ``u`` and ``index = arange(u.shape[0])``.
    """
    energy_and_gradient = jax.value_and_grad(self.fun)
    energy, gradient = energy_and_gradient(u)
    row_index = jnp.arange(u.shape[0])
    return energy, (gradient, row_index)

MassSpring ¤

Bases: JaxEnergy


              flowchart TD
              liblaf.apple.jax.energies.MassSpring[MassSpring]
              liblaf.apple.jax.model._energy.JaxEnergy[JaxEnergy]
              liblaf.apple.utils._id_mixin.IdMixin[IdMixin]

                              liblaf.apple.jax.model._energy.JaxEnergy --> liblaf.apple.jax.energies.MassSpring
                                liblaf.apple.utils._id_mixin.IdMixin --> liblaf.apple.jax.model._energy.JaxEnergy
                



              click liblaf.apple.jax.energies.MassSpring href "" "liblaf.apple.jax.energies.MassSpring"
              click liblaf.apple.jax.model._energy.JaxEnergy href "" "liblaf.apple.jax.model._energy.JaxEnergy"
              click liblaf.apple.utils._id_mixin.IdMixin href "" "liblaf.apple.utils._id_mixin.IdMixin"
            

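Parameters:

  • id ¤

    (str, default: <dynamic>) – Identifier of this energy; generated by `_default_id` when not given.
  • requires_grad ¤

    (frozenset[str], default: frozenset()) – Names of the parameters for which `mixed_derivative_prod` computes a product; each name `name` needs a matching `mixed_derivative_prod_{name}` method.
  • edges ¤

    (Integer[Array, "edges 2"]) – For each spring, the rows of the displacement `u` holding its two endpoints.
  • length ¤

    (Float[Array, " edges"]) – Rest length of each spring.
  • points ¤

    (Float[Array, "edges 2 3"]) – Undeformed positions of both endpoints of each spring; deformed positions are `points + u[edges]`.
  • stiffness ¤

    (Float[Array, " edges"]) – Stiffness `k` of each spring in the energy `0.5 * k * (|x_1 - x_0| - length)**2`.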

Methods:

Attributes:

edges instance-attribute ¤

edges: Integer[Array, ' edges 2']

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

length instance-attribute ¤

length: Float[Array, ' edges']

n_edges property ¤

n_edges: int

points instance-attribute ¤

points: Float[Array, 'edges 2 3']

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = field(
    default=frozenset(), kw_only=True
)

stiffness instance-attribute ¤

stiffness: Float[Array, ' edges']

from_pyvista classmethod ¤

from_pyvista(obj: PolyData) -> Self
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 29–43):
@classmethod
def from_pyvista(cls, obj: pv.PolyData) -> Self:
    """Build a spring network from the line cells of a PyVista ``PolyData``.

    Each 2-point line cell becomes one spring. The mesh must carry the point
    array ``POINT_ID`` (maps mesh points to rows of ``u``) and the cell array
    ``STIFFNESS``. The cell array ``LENGTH`` is used as the rest length and is
    computed with ``compute_cell_sizes`` when missing.

    Args:
        obj: Line mesh.

    Returns:
        A new instance of ``cls``.

    Raises:
        ValueError: If any line cell does not have exactly two points.
    """
    if LENGTH not in obj.cell_data:
        obj = obj.compute_cell_sizes(length=True, area=False, volume=False)  # pyright: ignore[reportAssignmentType]
    point_id: Integer[np.ndarray, " points"] = obj.point_data[POINT_ID]
    # ``obj.lines`` is padded connectivity ``[n, i_0, ..., i_{n-1}, n, ...]``.
    # Reading it as rows of 3 is only valid if every cell has n == 2. If some
    # cell has another size, the first such cell appears as a leading count
    # != 2 in its (still aligned) row, so checking column 0 is sufficient.
    cells: Integer[np.ndarray, "edges 3"] = obj.lines.reshape((-1, 3))
    if (cells[:, 0] != 2).any():
        msg = "Every line cell must have exactly 2 points to be used as a spring."
        raise ValueError(msg)
    edges: Integer[np.ndarray, "edges 2"] = cells[:, 1:]
    length: Float[Array, " edges"] = jnp.asarray(obj.cell_data[LENGTH])
    if jnp.any(length < 0.0):
        logger.warning("Length < 0")
    return cls(
        # Global row ids into ``u`` for both endpoints of each spring.
        edges=jnp.asarray(point_id[edges]),
        length=length,
        # Undeformed endpoint positions, indexed by mesh-local point index.
        points=jnp.asarray(obj.points[edges]),
        stiffness=jnp.asarray(obj.cell_data[STIFFNESS]),
    )

fun ¤

fun(u: Vector) -> Scalar
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 49–58):
@override
def fun(self, u: Vector) -> Scalar:
    x: Float[Array, "edges 2 3"] = self.points + u[self.edges]
    delta: Float[Array, "edges 3"] = x[:, 1, :] - x[:, 0, :]
    energy: Float[Array, " edges"] = (
        0.5
        * self.stiffness
        * jnp.square(jnp.linalg.norm(delta, axis=-1) - self.length)
    )
    return jnp.sum(energy)

grad ¤

grad(u: Vector) -> Updates
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 60–74):
@override
def grad(self, u: Vector) -> Updates:
    x: Float[Array, "edges 2 3"] = self.points + u[self.edges]
    delta: Float[Array, "edges 3"] = x[:, 1, :] - x[:, 0, :]
    length: Float[Array, " edges"] = jnp.linalg.norm(delta, axis=-1)
    direction: Float[Array, "edges 3"] = (
        delta / jnp.where(length > 0, length, 1.0)[:, jnp.newaxis]
    )
    force: Float[Array, "edges 3"] = (
        self.stiffness[:, jnp.newaxis]
        * (length - self.length)[:, jnp.newaxis]
        * direction
    )
    grad: Float[Array, "edges 2 3"] = jnp.stack([-force, force], axis=1)
    return grad, self.edges.flatten()

grad_and_hess_diag ¤

grad_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 61–63):
@eqx.filter_jit
def grad_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.grad(u), self.hess_diag(u))`` from a single call.

    Args:
        u: Displacement vector.

    Returns:
        The gradient updates followed by the Hessian-diagonal updates.
    """
    gradient_updates = self.grad(u)
    diagonal_updates = self.hess_diag(u)
    return gradient_updates, diagonal_updates

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 76–81):
@override
def hess_diag(self, u: Vector) -> Updates:
    """Approximate Hessian diagonal: each spring contributes its stiffness.

    ``u`` is not used, so the result is constant. Each edge's stiffness
    ``k`` is repeated for both endpoints and all 3 coordinates.

    For reference, the exact Hessian block of ``0.5 * k * (l - L)**2`` with
    respect to one endpoint is ``k * ((1 - L/l) I + (L/l) n n^T)``. Its
    diagonal never exceeds ``k`` and can be negative under compression.
    NOTE(review): the constant ``k`` is presumably a deliberate positive
    (Jacobi-style) approximation — confirm before changing.

    Returns:
        ``(values, self.edges.flatten())`` with ``values`` of shape
        ``(2 * edges, 3)``.
    """
    endpoint_index = self.edges.flatten()
    values: Float[Array, "edges*2 3"] = einops.repeat(
        self.stiffness, "edges -> (edges i) j", i=2, j=3
    )
    return values, endpoint_index

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 41–45):
@eqx.filter_jit
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` via forward-over-reverse autodiff.

    The gradient function ``jax.grad(self.fun)`` is differentiated in the
    direction ``p`` with ``jax.jvp``; the tangent output is ``H(u) p``.

    Args:
        u: Point at which the Hessian is evaluated.
        p: Direction; must match the structure of ``u`` (a ``jax.jvp``
            requirement).

    Returns:
        ``(values, index)`` with ``values`` shaped like ``u`` and
        ``index = arange(u.shape[0])``.
    """
    gradient_fn = jax.grad(self.fun)
    _primal_out, tangent_out = jax.jvp(gradient_fn, (u,), (p,))
    values: Vector = tangent_out
    return values, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 47–52):
@eqx.filter_jit
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p`` built on top of ``self.hess_prod``.

    Args:
        u: Point at which the Hessian is evaluated.
        p: Direction vector.

    Returns:
        ``vdot(p[index], values)`` for
        ``(values, index) = self.hess_prod(u, p)``.
    """
    hp: Updates = self.hess_prod(u, p)
    hp_values, hp_index = hp
    selected: Vector = p[hp_index]
    return jnp.vdot(selected, hp_values)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> EnergyParams
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 65–71):
@eqx.filter_jit
def mixed_derivative_prod(self, u: Vector, p: Vector) -> EnergyParams:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    For each ``name`` in ``self.requires_grad`` this dispatches to the
    method ``mixed_derivative_prod_{name}(u, p)``. ``getattr`` raises
    ``AttributeError`` if that method does not exist.

    Args:
        u: Displacement vector.
        p: Direction vector forwarded to each per-parameter method.

    Returns:
        Mapping from parameter name to that method's result.
    """
    outputs: EnergyParams = {}
    for name in self.requires_grad:
        per_param = getattr(self, f"mixed_derivative_prod_{name}")
        outputs[name] = per_param(u, p)
    return outputs

update ¤

update(u: Vector) -> None
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 22–24):
@eqx.filter_jit
def update(self, u: Vector) -> None:
    """Do nothing: accept ``u`` and return ``None``.

    NOTE(review): presumably an override point for energies that cache
    state derived from ``u`` — confirm against subclasses.
    """

update_params ¤

update_params(params: Mapping[str, Array]) -> None
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 26–28):
def update_params(self, params: Mapping[str, Array]) -> None:
    """Overwrite attributes of ``self`` from ``params``.

    Each ``(attr_name, new_value)`` pair is applied with ``setattr``, so the
    keys of ``params`` are used directly as attribute names.

    Args:
        params: Mapping from attribute name to its new value.
    """
    for attr_name, new_value in params.items():
        setattr(self, attr_name, new_value)

value_and_grad ¤

value_and_grad(u: Vector) -> tuple[Scalar, Updates]
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 54–59):
@eqx.filter_jit
def value_and_grad(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate ``self.fun(u)`` together with its gradient.

    Args:
        u: Displacement vector.

    Returns:
        ``(value, (gradient, index))`` where ``gradient`` is shaped like
        ``u`` and ``index = arange(u.shape[0])``.
    """
    energy_and_gradient = jax.value_and_grad(self.fun)
    energy, gradient = energy_and_gradient(u)
    row_index = jnp.arange(u.shape[0])
    return energy, (gradient, row_index)

MassSpringPrestrain ¤

Bases: MassSpring


              flowchart TD
              liblaf.apple.jax.energies.MassSpringPrestrain[MassSpringPrestrain]
              liblaf.apple.jax.energies._mass_spring.MassSpring[MassSpring]
              liblaf.apple.jax.model._energy.JaxEnergy[JaxEnergy]
              liblaf.apple.utils._id_mixin.IdMixin[IdMixin]

                              liblaf.apple.jax.energies._mass_spring.MassSpring --> liblaf.apple.jax.energies.MassSpringPrestrain
                                liblaf.apple.jax.model._energy.JaxEnergy --> liblaf.apple.jax.energies._mass_spring.MassSpring
                                liblaf.apple.utils._id_mixin.IdMixin --> liblaf.apple.jax.model._energy.JaxEnergy
                




              click liblaf.apple.jax.energies.MassSpringPrestrain href "" "liblaf.apple.jax.energies.MassSpringPrestrain"
              click liblaf.apple.jax.energies._mass_spring.MassSpring href "" "liblaf.apple.jax.energies._mass_spring.MassSpring"
              click liblaf.apple.jax.model._energy.JaxEnergy href "" "liblaf.apple.jax.model._energy.JaxEnergy"
              click liblaf.apple.utils._id_mixin.IdMixin href "" "liblaf.apple.utils._id_mixin.IdMixin"
            

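Parameters:

  • id ¤

    (str, default: <dynamic>) – Identifier of this energy; generated by `_default_id` when not given.
  • requires_grad ¤

    (frozenset[str], default: frozenset()) – Names of the parameters for which `mixed_derivative_prod` computes a product; each name `name` needs a matching `mixed_derivative_prod_{name}` method.
  • edges ¤

    (Integer[Array, "edges 2"]) – For each spring, the rows of the displacement `u` holding its two endpoints.
  • length ¤

    (Float[Array, " edges"]) – Rest length of each spring; `from_pyvista` sets it to the measured cell length times `1 + prestrain`.
  • points ¤

    (Float[Array, "edges 2 3"]) – Undeformed positions of both endpoints of each spring; deformed positions are `points + u[edges]`.
  • stiffness ¤

    (Float[Array, " edges"]) – Stiffness `k` of each spring in the energy `0.5 * k * (|x_1 - x_0| - length)**2`.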

Methods:

Attributes:

edges instance-attribute ¤

edges: Integer[Array, ' edges 2']

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

length instance-attribute ¤

length: Float[Array, ' edges']

n_edges property ¤

n_edges: int

points instance-attribute ¤

points: Float[Array, 'edges 2 3']

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = field(
    default=frozenset(), kw_only=True
)

stiffness instance-attribute ¤

stiffness: Float[Array, ' edges']

from_pyvista classmethod ¤

from_pyvista(obj: PolyData) -> Self
Source code in `src/liblaf/apple/jax/energies/_mass_spring_prestrain.py` (lines 25–40):
@classmethod
def from_pyvista(cls, obj: pv.PolyData) -> Self:
    """Build a prestrained spring network from a PyVista ``PolyData``.

    Each 2-point line cell becomes one spring whose rest length is the
    cell length scaled by ``1 + prestrain``. The mesh must carry the point
    array ``POINT_ID`` and the cell arrays ``STIFFNESS`` and ``PRESTRAIN``.
    The cell array ``LENGTH`` is computed with ``compute_cell_sizes`` when
    missing.

    Args:
        obj: Line mesh.

    Returns:
        A new instance of ``cls``.

    Raises:
        ValueError: If any line cell does not have exactly two points.
    """
    if LENGTH not in obj.cell_data:
        obj = obj.compute_cell_sizes(length=True, area=False, volume=False)  # pyright: ignore[reportAssignmentType]
    point_id: Integer[np.ndarray, " points"] = obj.point_data[POINT_ID]
    # ``obj.lines`` is padded connectivity ``[n, i_0, ..., i_{n-1}, n, ...]``.
    # Reading it as rows of 3 is only valid if every cell has n == 2. If some
    # cell has another size, the first such cell appears as a leading count
    # != 2 in its (still aligned) row, so checking column 0 is sufficient.
    cells: Integer[np.ndarray, "edges 3"] = obj.lines.reshape((-1, 3))
    if (cells[:, 0] != 2).any():
        msg = "Every line cell must have exactly 2 points to be used as a spring."
        raise ValueError(msg)
    edges: Integer[np.ndarray, "edges 2"] = cells[:, 1:]
    length: Float[Array, " edges"] = jnp.asarray(obj.cell_data[LENGTH])
    if jnp.any(length < 0.0):
        logger.warning("Length < 0")
    prestrain: Float[Array, " edges"] = jnp.asarray(obj.cell_data[PRESTRAIN])
    rest_length: Float[Array, " edges"] = length * (1.0 + prestrain)
    # A prestrain below -1 flips the sign of the rest length even when the
    # measured length is valid, so also validate the value actually stored.
    if jnp.any(rest_length < 0.0):
        logger.warning("Rest length < 0 after applying prestrain")
    return cls(
        # Global row ids into ``u`` for both endpoints of each spring.
        edges=jnp.asarray(point_id[edges]),
        length=rest_length,
        # Undeformed endpoint positions, indexed by mesh-local point index.
        points=jnp.asarray(obj.points[edges]),
        stiffness=jnp.asarray(obj.cell_data[STIFFNESS]),
    )

fun ¤

fun(u: Vector) -> Scalar
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 49–58):
@override
def fun(self, u: Vector) -> Scalar:
    x: Float[Array, "edges 2 3"] = self.points + u[self.edges]
    delta: Float[Array, "edges 3"] = x[:, 1, :] - x[:, 0, :]
    energy: Float[Array, " edges"] = (
        0.5
        * self.stiffness
        * jnp.square(jnp.linalg.norm(delta, axis=-1) - self.length)
    )
    return jnp.sum(energy)

grad ¤

grad(u: Vector) -> Updates
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 60–74):
@override
def grad(self, u: Vector) -> Updates:
    x: Float[Array, "edges 2 3"] = self.points + u[self.edges]
    delta: Float[Array, "edges 3"] = x[:, 1, :] - x[:, 0, :]
    length: Float[Array, " edges"] = jnp.linalg.norm(delta, axis=-1)
    direction: Float[Array, "edges 3"] = (
        delta / jnp.where(length > 0, length, 1.0)[:, jnp.newaxis]
    )
    force: Float[Array, "edges 3"] = (
        self.stiffness[:, jnp.newaxis]
        * (length - self.length)[:, jnp.newaxis]
        * direction
    )
    grad: Float[Array, "edges 2 3"] = jnp.stack([-force, force], axis=1)
    return grad, self.edges.flatten()

grad_and_hess_diag ¤

grad_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 61–63):
@eqx.filter_jit
def grad_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.grad(u), self.hess_diag(u))`` from a single call.

    Args:
        u: Displacement vector.

    Returns:
        The gradient updates followed by the Hessian-diagonal updates.
    """
    gradient_updates = self.grad(u)
    diagonal_updates = self.hess_diag(u)
    return gradient_updates, diagonal_updates

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in `src/liblaf/apple/jax/energies/_mass_spring.py` (lines 76–81):
@override
def hess_diag(self, u: Vector) -> Updates:
    """Approximate Hessian diagonal: each spring contributes its stiffness.

    ``u`` is not used, so the result is constant. Each edge's stiffness
    ``k`` is repeated for both endpoints and all 3 coordinates.

    For reference, the exact Hessian block of ``0.5 * k * (l - L)**2`` with
    respect to one endpoint is ``k * ((1 - L/l) I + (L/l) n n^T)``. Its
    diagonal never exceeds ``k`` and can be negative under compression.
    NOTE(review): the constant ``k`` is presumably a deliberate positive
    (Jacobi-style) approximation — confirm before changing.

    Returns:
        ``(values, self.edges.flatten())`` with ``values`` of shape
        ``(2 * edges, 3)``.
    """
    endpoint_index = self.edges.flatten()
    values: Float[Array, "edges*2 3"] = einops.repeat(
        self.stiffness, "edges -> (edges i) j", i=2, j=3
    )
    return values, endpoint_index

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 41–45):
@eqx.filter_jit
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` via forward-over-reverse autodiff.

    The gradient function ``jax.grad(self.fun)`` is differentiated in the
    direction ``p`` with ``jax.jvp``; the tangent output is ``H(u) p``.

    Args:
        u: Point at which the Hessian is evaluated.
        p: Direction; must match the structure of ``u`` (a ``jax.jvp``
            requirement).

    Returns:
        ``(values, index)`` with ``values`` shaped like ``u`` and
        ``index = arange(u.shape[0])``.
    """
    gradient_fn = jax.grad(self.fun)
    _primal_out, tangent_out = jax.jvp(gradient_fn, (u,), (p,))
    values: Vector = tangent_out
    return values, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 47–52):
@eqx.filter_jit
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p`` built on top of ``self.hess_prod``.

    Args:
        u: Point at which the Hessian is evaluated.
        p: Direction vector.

    Returns:
        ``vdot(p[index], values)`` for
        ``(values, index) = self.hess_prod(u, p)``.
    """
    hp: Updates = self.hess_prod(u, p)
    hp_values, hp_index = hp
    selected: Vector = p[hp_index]
    return jnp.vdot(selected, hp_values)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> EnergyParams
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 65–71):
@eqx.filter_jit
def mixed_derivative_prod(self, u: Vector, p: Vector) -> EnergyParams:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    For each ``name`` in ``self.requires_grad`` this dispatches to the
    method ``mixed_derivative_prod_{name}(u, p)``. ``getattr`` raises
    ``AttributeError`` if that method does not exist.

    Args:
        u: Displacement vector.
        p: Direction vector forwarded to each per-parameter method.

    Returns:
        Mapping from parameter name to that method's result.
    """
    outputs: EnergyParams = {}
    for name in self.requires_grad:
        per_param = getattr(self, f"mixed_derivative_prod_{name}")
        outputs[name] = per_param(u, p)
    return outputs

update ¤

update(u: Vector) -> None
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 22–24):
@eqx.filter_jit
def update(self, u: Vector) -> None:
    """Do nothing: accept ``u`` and return ``None``.

    NOTE(review): presumably an override point for energies that cache
    state derived from ``u`` — confirm against subclasses.
    """

update_params ¤

update_params(params: Mapping[str, Array]) -> None
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 26–28):
def update_params(self, params: Mapping[str, Array]) -> None:
    """Overwrite attributes of ``self`` from ``params``.

    Each ``(attr_name, new_value)`` pair is applied with ``setattr``, so the
    keys of ``params`` are used directly as attribute names.

    Args:
        params: Mapping from attribute name to its new value.
    """
    for attr_name, new_value in params.items():
        setattr(self, attr_name, new_value)

value_and_grad ¤

value_and_grad(u: Vector) -> tuple[Scalar, Updates]
Source code in `src/liblaf/apple/jax/model/_energy.py` (lines 54–59):
@eqx.filter_jit
def value_and_grad(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate ``self.fun(u)`` together with its gradient.

    Args:
        u: Displacement vector.

    Returns:
        ``(value, (gradient, index))`` where ``gradient`` is shaped like
        ``u`` and ``index = arange(u.shape[0])``.
    """
    energy_and_gradient = jax.value_and_grad(self.fun)
    energy, gradient = energy_and_gradient(u)
    row_index = jnp.arange(u.shape[0])
    return energy, (gradient, row_index)