Skip to content

liblaf.apple.model ¤

Classes:

Forward ¤

Parameters:

  • model ¤

    (Model) –
  • optimizer ¤

    (Optimizer, default: PNCG(jit=False, timer=False, max_steps=1000, norm=None, atol=Array(1.e-28, dtype=float32, weak_type=True), rtol=Array(1.e-05, dtype=float32, weak_type=True), d_hat=Array(inf, dtype=float32, weak_type=True)) ) –

Methods:

Attributes:

model instance-attribute ¤

model: Model

optimizer class-attribute instance-attribute ¤

optimizer: Optimizer = field(
    factory=lambda: PNCG(max_steps=1000)
)

u_full property ¤

u_full: Float[Array, 'points dim']

step ¤

step(callback: Callback | None = None) -> Solution
Source code in src/liblaf/apple/model/_forward.py (lines 29–45)
def step(self, callback: Callback | None = None) -> Optimizer.Solution:
    """Run one forward solve over the model's free degrees of freedom.

    The model's bound evaluation methods are bundled into an ``Objective``.
    ``self.optimizer.minimize`` is then started from the current
    ``self.model.u_free``. The optimizer's returned ``params`` are written back
    to ``self.model.u_free`` whether or not the solve succeeded. A warning is
    logged when ``success`` is false.

    Args:
        callback: Optional callback, passed unchanged to the optimizer.

    Returns:
        The solution object returned by ``self.optimizer.minimize``.
    """
    model = self.model
    objective = Objective(
        fun=model.fun,
        grad=model.grad,
        hess_diag=model.hess_diag,
        hess_prod=model.hess_prod,
        hess_quad=model.hess_quad,
        value_and_grad=model.value_and_grad,
        grad_and_hess_diag=model.grad_and_hess_diag,
    )
    result: Optimizer.Solution = self.optimizer.minimize(
        objective, model.u_free, callback=callback
    )
    if not result.success:
        logger.warning("Forward fail: %r", result)
    # Written back even on failure, so the model holds whatever the optimizer returned.
    model.u_free = result.params
    return result

update_params ¤

update_params(params: ModelParams) -> None
Source code in src/liblaf/apple/model/_forward.py (lines 26–27)
def update_params(self, params: ModelParams) -> None:
    """Forward new parameters to the wrapped model.

    Args:
        params: Passed through unchanged to ``self.model.update_params``.
    """
    model = self.model
    model.update_params(params)

Model ¤

Parameters:

Methods:

Attributes:

dim property ¤

dim: int

dirichlet instance-attribute ¤

dirichlet: Dirichlet

jax instance-attribute ¤

jax: JaxModel

n_free property ¤

n_free: int

n_full property ¤

n_full: int

n_points property ¤

n_points: int

u_free property writable ¤

u_free: Free

u_full instance-attribute ¤

u_full: Full

warp instance-attribute ¤

fun ¤

fun(u: FreeOrFull) -> Scalar
Source code in src/liblaf/apple/model/_model.py (lines 84–90)
def fun(self, u: FreeOrFull) -> Scalar:
    """Evaluate the total energy at ``u``.

    ``u`` is expanded to the full field and the cached state is refreshed via
    ``self.update``. The JAX and Warp contributions are then added, with JAX
    evaluated first.

    Args:
        u: Displacement given either as free DOFs or as the full field.

    Returns:
        ``self.jax.fun(full) + self.warp.fun(full)``.
    """
    full = self.to_full(u)
    self.update(full)
    energy_jax = self.jax.fun(full)
    return energy_jax + self.warp.fun(full)

grad ¤

grad(u: FreeOrFull) -> FreeOrFull
Source code in src/liblaf/apple/model/_model.py (lines 92–98)
def grad(self, u: FreeOrFull) -> FreeOrFull:
    """Gradient of the total energy, returned in the same layout as ``u``.

    Args:
        u: Displacement given either as free DOFs or as the full field.

    Returns:
        Sum of the JAX and Warp ``grad`` results (JAX evaluated first).
        The sum is mapped back to ``u``'s layout with ``self.to_shape_like``.
    """
    full = self.to_full(u)
    self.update(full)
    total = self.jax.grad(full) + self.warp.grad(full)
    return self.to_shape_like(total, u)

grad_and_hess_diag ¤

grad_and_hess_diag(
    u: FreeOrFull,
) -> tuple[FreeOrFull, FreeOrFull]
Source code in src/liblaf/apple/model/_model.py (lines 148–159)
def grad_and_hess_diag(self, u: FreeOrFull) -> tuple[FreeOrFull, FreeOrFull]:
    """Gradient and Hessian diagonal of the total energy at ``u``.

    Each backend returns a ``(grad, hess_diag)`` pair, and the pairs are
    summed component-wise (JAX first, then Warp).

    Args:
        u: Displacement given either as free DOFs or as the full field.

    Returns:
        ``(grad, hess_diag)``, each mapped back to ``u``'s layout with
        ``self.to_shape_like``.
    """
    full = self.to_full(u)
    self.update(full)
    g_jax, d_jax = self.jax.grad_and_hess_diag(full)
    g_wp, d_wp = self.warp.grad_and_hess_diag(full)
    g_total = g_jax + g_wp
    d_total = d_jax + d_wp
    return self.to_shape_like(g_total, u), self.to_shape_like(d_total, u)

hess_diag ¤

hess_diag(u: FreeOrFull) -> FreeOrFull
Source code in src/liblaf/apple/model/_model.py (lines 100–106)
def hess_diag(self, u: FreeOrFull) -> FreeOrFull:
    """Hessian diagonal of the total energy, in the same layout as ``u``.

    Args:
        u: Displacement given either as free DOFs or as the full field.

    Returns:
        Sum of the JAX and Warp ``hess_diag`` results (JAX evaluated first).
        The sum is mapped back to ``u``'s layout with ``self.to_shape_like``.
    """
    full = self.to_full(u)
    self.update(full)
    diag = self.jax.hess_diag(full) + self.warp.hess_diag(full)
    return self.to_shape_like(diag, u)

hess_prod ¤

hess_prod(u: FreeOrFull, p: FreeOrFull) -> FreeOrFull
Source code in src/liblaf/apple/model/_model.py (lines 108–115)
def hess_prod(self, u: FreeOrFull, p: FreeOrFull) -> FreeOrFull:
    """Hessian-vector product of the total energy at ``u`` along ``p``.

    Args:
        u: Displacement given either as free DOFs or as the full field.
        p: Direction given either as free DOFs or as the full field.

    Returns:
        Sum of the JAX and Warp ``hess_prod`` results (JAX evaluated first).
        The sum is mapped back to the layout of ``u`` (not ``p``) with
        ``self.to_shape_like``.
    """
    full_u = self.to_full(u)
    self.update(full_u)
    # 0.0 is passed as ``to_full``'s ``dirichlet`` argument for the direction.
    full_p = self.to_full(p, 0.0)
    prod = self.jax.hess_prod(full_u, full_p) + self.warp.hess_prod(full_u, full_p)
    return self.to_shape_like(prod, u)

hess_quad ¤

hess_quad(u: FreeOrFull, p: FreeOrFull) -> Scalar
Source code in src/liblaf/apple/model/_model.py (lines 117–124)
def hess_quad(self, u: FreeOrFull, p: FreeOrFull) -> Scalar:
    """Hessian quadratic form of the total energy at ``u`` along ``p``.

    Args:
        u: Displacement given either as free DOFs or as the full field.
        p: Direction given either as free DOFs or as the full field.

    Returns:
        Sum of the JAX and Warp ``hess_quad`` results (JAX evaluated first).
    """
    full_u = self.to_full(u)
    self.update(full_u)
    # 0.0 is passed as ``to_full``'s ``dirichlet`` argument for the direction.
    full_p = self.to_full(p, 0.0)
    quad_jax = self.jax.hess_quad(full_u, full_p)
    return quad_jax + self.warp.hess_quad(full_u, full_p)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: FreeOrFull, p: FreeOrFull
) -> ModelParams
Source code in src/liblaf/apple/model/_model.py (lines 126–133)
def mixed_derivative_prod(self, u: FreeOrFull, p: FreeOrFull) -> ModelParams:
    """Combine both backends' ``mixed_derivative_prod`` results at ``(u, p)``.

    Args:
        u: Displacement given either as free DOFs or as the full field.
        p: Direction given either as free DOFs or as the full field.

    Returns:
        A single mapping produced by ``tlz.merge`` of the JAX result followed
        by the Warp result. If both contain the same key, the Warp value wins.
    """
    full_u: Full = self.to_full(u)
    self.update(full_u)
    # 0.0 is passed as ``to_full``'s ``dirichlet`` argument for the direction.
    full_p: Full = self.to_full(p, 0.0)
    from_jax = self.jax.mixed_derivative_prod(full_u, full_p)
    from_warp = self.warp.mixed_derivative_prod(full_u, full_p)
    return tlz.merge(from_jax, from_warp)

to_free ¤

to_free(u: FreeOrFull) -> Free
Source code in src/liblaf/apple/model/_model.py (lines 50–53)
def to_free(self, u: FreeOrFull) -> Free:
    """Return ``u`` as a flat vector of free DOFs.

    If ``u`` already has exactly ``self.n_free`` elements, it is only reshaped
    to ``(n_free,)``. Otherwise ``self.dirichlet.get_free`` extracts the free
    entries.
    """
    n_free = self.n_free
    if u.size != n_free:
        return self.dirichlet.get_free(u)
    return u.reshape((n_free,))

to_full ¤

to_full(
    u: FreeOrFull,
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Full
Source code in src/liblaf/apple/model/_model.py (lines 55–60)
def to_full(
    self, u: FreeOrFull, dirichlet: Float[ArrayLike, " dirichlet"] | None = None
) -> Full:
    """Return ``u`` as the full ``(n_points, dim)`` field.

    If ``u`` already has ``self.n_full`` elements, it is only reshaped.
    Otherwise ``u`` and ``dirichlet`` are handed to
    ``self.dirichlet.to_full``, which builds the full field.

    Args:
        u: Free DOFs or an already-full field.
        dirichlet: Optional values forwarded to ``self.dirichlet.to_full``.
            Ignored when ``u`` is already full.
    """
    if u.size != self.n_full:
        return self.dirichlet.to_full(u, dirichlet)
    return u.reshape((self.n_points, self.dim))

to_shape_like ¤

to_shape_like(u_full: Full, like: FreeOrFull) -> FreeOrFull
Source code in src/liblaf/apple/model/_model.py (lines 62–65)
def to_shape_like(self, u_full: Full, like: FreeOrFull) -> FreeOrFull:
    """Map a full field into the layout of ``like``.

    If ``u_full`` and ``like`` have the same number of elements, ``u_full``
    is reshaped to ``like.shape``. Otherwise the free entries are extracted
    with ``self.dirichlet.get_free``.
    """
    same_size = u_full.size == like.size
    return u_full.reshape(like.shape) if same_size else self.dirichlet.get_free(u_full)

update ¤

update(u: FreeOrFull) -> None
Source code in src/liblaf/apple/model/_model.py (lines 67–73)
def update(self, u: FreeOrFull) -> None:
    """Refresh both backends for displacement ``u``, skipping unchanged input.

    ``self.u_full`` acts as the cache key. If the full field is element-wise
    equal to it (``jnp.array_equiv``), nothing happens. Otherwise
    ``self.jax.update`` and then ``self.warp.update`` are called with the new
    field, and only then is the cache key replaced.

    Committing the key last matters. If either backend update raises, the
    cache still holds the previous field, so a retry with the same ``u`` will
    run the refresh again. The alternative would be to report "up to date"
    while the backends are stale.

    Args:
        u: Displacement given either as free DOFs or as the full field.
    """
    u_full: Full = self.to_full(u)
    # Exact comparison: any change in the field triggers a refresh.
    if jnp.array_equiv(u_full, self.u_full):
        return
    self.jax.update(u_full)
    self.warp.update(u_full)
    # Commit the cache key only after both backends accepted the new state.
    self.u_full = u_full

update_params ¤

update_params(params: ModelParams) -> None
Source code in src/liblaf/apple/model/_model.py (lines 75–82)
def update_params(self, params: ModelParams) -> None:
    """Split ``params`` between the JAX and Warp sub-models and apply them.

    Each sub-model receives only the entries whose key is contained in its
    ``energies`` collection, keeping the iteration order of ``params``. Keys
    matched by neither backend are dropped without warning. A key matched by
    both is sent to both.

    Args:
        params: Mapping whose keys are tested against each backend's
            ``energies`` (presumably energy names; verify against callers).
    """
    jax_names = self.jax.energies
    params_jax: ModelParams = {k: v for k, v in params.items() if k in jax_names}
    warp_names = self.warp.energies
    params_warp: ModelParams = {k: v for k, v in params.items() if k in warp_names}
    self.jax.update_params(params_jax)
    self.warp.update_params(params_warp)

value_and_grad ¤

value_and_grad(u: FreeOrFull) -> tuple[Scalar, FreeOrFull]
Source code in src/liblaf/apple/model/_model.py (lines 135–146)
def value_and_grad(self, u: FreeOrFull) -> tuple[Scalar, FreeOrFull]:
    """Total energy and its gradient at ``u``.

    Each backend returns a ``(value, grad)`` pair, and the pairs are summed
    component-wise (JAX first, then Warp).

    Args:
        u: Displacement given either as free DOFs or as the full field.

    Returns:
        ``(value, grad)``, with ``grad`` mapped back to ``u``'s layout by
        ``self.to_shape_like``.
    """
    full = self.to_full(u)
    self.update(full)
    e_jax, g_jax = self.jax.value_and_grad(full)
    e_wp, g_wp = self.warp.value_and_grad(full)
    return e_jax + e_wp, self.to_shape_like(g_jax + g_wp, u)

ModelBuilder ¤

ModelBuilder(dim: int = 3)

Parameters:

Methods:

Attributes:

Source code in src/liblaf/apple/model/_builder.py (lines 27–29)
def __init__(self, dim: int = 3) -> None:
    """Create a builder whose Dirichlet builder is configured with ``dim``.

    Only ``dirichlet`` is passed explicitly to the attrs-generated
    initializer. The remaining fields use their declared factories.

    Args:
        dim: Forwarded as ``DirichletBuilder(dim=dim)``.
    """
    self.__attrs_init__(dirichlet=DirichletBuilder(dim=dim))  # pyright: ignore[reportAttributeAccessIssue]

dirichlet class-attribute instance-attribute ¤

dirichlet: DirichletBuilder = field(
    factory=DirichletBuilder
)

jax class-attribute instance-attribute ¤

jax: JaxModelBuilder = field(factory=JaxModelBuilder)

n_points property ¤

n_points: int

warp class-attribute instance-attribute ¤

warp: WarpModelBuilder = field(factory=WarpModelBuilder)

add_dirichlet ¤

add_dirichlet(obj: DataSet) -> None
Source code in src/liblaf/apple/model/_builder.py (lines 35–36)
def add_dirichlet(self, obj: pv.DataSet) -> None:
    """Register Dirichlet boundary data carried by a PyVista dataset.

    Args:
        obj: Dataset handed unchanged to ``self.dirichlet.add_pyvista``.
    """
    builder = self.dirichlet
    builder.add_pyvista(obj)

add_energy ¤

add_energy(energy: JaxEnergy | WarpEnergy) -> None
Source code in src/liblaf/apple/model/_builder.py (lines 38–44)
def add_energy(self, energy: JaxEnergy | WarpEnergy) -> None:
    """Route ``energy`` to the sub-builder for its backend.

    Args:
        energy: A ``JaxEnergy`` goes to ``self.jax``, a ``WarpEnergy``
            goes to ``self.warp``.

    Raises:
        TypeError: If ``energy`` is neither a ``JaxEnergy`` nor a
            ``WarpEnergy``.
    """
    if isinstance(energy, JaxEnergy):
        self.jax.add_energy(energy)
    elif isinstance(energy, WarpEnergy):
        self.warp.add_energy(energy)
    else:
        # Name the accepted types so the error is actionable, not just a repr.
        msg = (
            "expected JaxEnergy or WarpEnergy, "
            f"got {energy!r} of type {type(energy).__name__}"
        )
        raise TypeError(msg)

assign_global_ids ¤

assign_global_ids[T: DataSet](obj: T) -> T
Source code in src/liblaf/apple/model/_builder.py (lines 46–51)
def assign_global_ids[T: pv.DataSet](self, obj: T) -> T:
    start: int = self.n_points
    stop: int = start + obj.n_points
    self.dirichlet.resize(stop)
    obj.point_data[POINT_ID] = np.arange(start, stop)
    return obj

finalize ¤

finalize() -> Model
Source code in src/liblaf/apple/model/_builder.py (lines 53–62)
def finalize(self) -> Model:
    """Freeze the builder state into a ``Model``.

    Steps:
    - The Dirichlet builder is finalized first.
    - The initial full displacement starts as zeros of shape
      ``(self.dirichlet.n_points, self.dirichlet.dim)``. It is passed through
      the finalized ``set_dirichlet`` (presumably imposing the prescribed
      boundary values; confirm in ``Dirichlet``).
    - The Warp sub-model is wrapped in a ``WarpModelAdapter``.

    Returns:
        A new ``Model`` built from the finalized parts.
    """
    dirichlet: Dirichlet = self.dirichlet.finalize()
    shape = (self.dirichlet.n_points, self.dirichlet.dim)
    initial: Full = dirichlet.set_dirichlet(jnp.zeros(shape))
    return Model(
        dirichlet=dirichlet,
        u_full=initial,
        jax=self.jax.finalize(),
        warp=WarpModelAdapter(self.warp.finalize()),
    )