Bases: ABC
flowchart TD
liblaf.apple.inverse.Inverse[Inverse]
click liblaf.apple.inverse.Inverse href "" "liblaf.apple.inverse.Inverse"
Parameters:
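- forward (Forward) – the forward problem. `fun` and `value_and_grad` call its `step()` after pushing new parameters into the model.
- adjoint_solver (LinearSolver, default: <dynamic>) – keyword-only linear solver used by `adjoint`. When omitted, the field factory creates a `CompositeSolver`.
- optimizer (Optimizer, default: ScipyOptimizer(max_steps=256, jit=False, timer=False, method='L-BFGS-B', tol=1e-05, options=None)) – keyword-only optimizer that `solve` uses to minimize the inverse objective.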
Methods:
Attributes:
adjoint_solver
class-attribute
instance-attribute
# Linear solver used by ``adjoint``; keyword-only, and when not supplied the
# default is produced by calling ``CompositeSolver``.
adjoint_solver: LinearSolver = field(kw_only=True, factory=CompositeSolver)
forward
instance-attribute
optimizer
class-attribute
instance-attribute
# Optimizer used by ``solve``; keyword-only, defaulting to a ``ScipyOptimizer``
# configured with method "L-BFGS-B" and tol=1e-05.
optimizer: Optimizer = field(
    kw_only=True,
    factory=lambda: ScipyOptimizer(tol=1e-05, method="L-BFGS-B"),
)
adjoint
adjoint(u: Full, dLdu: Full) -> Full
Source code in src/liblaf/apple/inverse/_inverse.py
def adjoint(self, u: Full, dLdu: Full) -> Full:
    """Solve the adjoint system ``H p = -dL/du`` on the free DOFs.

    ``H`` is never formed explicitly: the solver only receives the operator
    ``vec -> self.model.hess_prod(u_free, vec)``. A Jacobi preconditioner is
    supplied for both directions: the element-wise reciprocal of
    ``self.model.hess_diag``.

    Args:
        u: Full-space state at which the operator is evaluated.
        dLdu: Full-space gradient of the loss with respect to ``u``.

    Returns:
        The solver's solution mapped back to full space with
        ``self.model.to_full(..., 0.0)`` (``0.0`` presumably fills the
        non-free entries -- confirm against ``to_full``).
    """
    free_u: Free = self.model.to_free(u)
    # NOTE(review): a zero entry in hess_diag makes this reciprocal infinite;
    # confirm the diagonal is nonzero on the free DOFs.
    inv_diag: Free = jnp.reciprocal(self.model.hess_diag(free_u))

    def hess_times(vec: Free) -> Free:
        return self.model.hess_prod(free_u, vec)

    def apply_jacobi(vec: Free) -> Free:
        return inv_diag * vec

    # The same product is passed as ``rmatvec``, i.e. the operator is treated
    # as symmetric; the same holds for the preconditioner.
    system = LinearSystem(
        matvec=hess_times,
        b=-self.model.to_free(dLdu),
        rmatvec=hess_times,
        preconditioner=apply_jacobi,
        rpreconditioner=apply_jacobi,
    )
    initial_guess: Free = jnp.zeros_like(free_u)
    result: LinearSolver.Solution = self.adjoint_solver.solve(system, initial_guess)
    if not result.success:
        logger.warning("Adjoint fail: %r", result)
    logger.info("Adjoint time: %g sec", result.stats.time)
    return self.model.to_full(result.params, 0.0)
fun
fun(params: ParamsT) -> tuple[Scalar, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py
def fun(self, params: ParamsT) -> tuple[Scalar, AuxT]:
    """Evaluate the inverse-problem loss for ``params`` (no gradient).

    Converts ``params`` with ``self.make_params`` and pushes the result into
    the model. Then runs one forward ``step()`` and evaluates ``self.loss``
    at the model's resulting full state.

    Args:
        params: Optimization parameters.

    Returns:
        ``(loss, aux)`` exactly as returned by ``self.loss``.
    """
    model_params: ModelParams = self.make_params(params)
    self.model.update_params(model_params)
    solution: Optimizer.Solution = self.forward.step()
    # Report forward failures the same way ``adjoint`` and ``solve`` report
    # theirs; otherwise the loss would be evaluated at the state left by a
    # failed step without any trace in the logs.
    if not solution.success:
        logger.warning("Forward fail: %r", solution)
    logger.info("Forward time: %g sec", solution.stats.time)
    return self.loss(self.model.u_full, model_params)
loss
abstractmethod
loss(u: Full, params: ModelParams) -> tuple[Scalar, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py
@abc.abstractmethod
def loss(self, u: Full, params: ModelParams) -> tuple[Scalar, AuxT]:
    """Return ``(loss, aux)`` for state ``u`` under model parameters ``params``.

    Subclasses must override this. ``loss_and_grad`` differentiates it with
    ``jax.value_and_grad`` with respect to both ``u`` and ``params``, so the
    implementation must be traceable by JAX.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError()
loss_and_grad
loss_and_grad(
u: Full, params: ModelParams
) -> tuple[Scalar, Full, ModelParams, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py
@eqx.filter_jit
def loss_and_grad(
    self, u: Full, params: ModelParams
) -> tuple[Scalar, Full, ModelParams, AuxT]:
    """Evaluate ``self.loss`` and its partial gradients in one pass.

    Args:
        u: Full-space state.
        params: Model parameters.

    Returns:
        ``(loss, dL/du, dL/dparams, aux)``. ``aux`` is the auxiliary output
        of ``self.loss`` and is not differentiated.
    """
    grad_fn = jax.value_and_grad(self.loss, argnums=(0, 1), has_aux=True)
    (value, aux), (grad_u, grad_params) = grad_fn(u, params)
    return value, grad_u, grad_params, aux
make_params
abstractmethod
make_params(params: ParamsT) -> ModelParams
Source code in src/liblaf/apple/inverse/_inverse.py
@abc.abstractmethod
def make_params(self, params: ParamsT) -> ModelParams:
    """Convert optimization parameters into model parameters.

    Subclasses must override this. ``value_and_grad`` pulls gradients back
    through it with ``jax.vjp``, so the implementation must be
    differentiable by JAX.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError()
solve
solve(
params: ParamsT, callback: Callback | None = None
) -> Solution
Source code in src/liblaf/apple/inverse/_inverse.py
def solve(
    self, params: ParamsT, callback: Callback | None = None
) -> Optimizer.Solution:
    """Minimize the inverse objective with ``self.optimizer``.

    ``self.value_and_grad`` is wrapped in an ``Objective`` and handed to
    ``self.optimizer.minimize``. If the optimizer reports failure, a warning
    is logged; the solution is returned in either case.

    Args:
        params: Starting parameters passed to ``minimize``.
        callback: Optional callback forwarded unchanged to ``minimize``.

    Returns:
        The optimizer's solution object.
    """
    objective = Objective(value_and_grad=self.value_and_grad)
    result: Optimizer.Solution = self.optimizer.minimize(
        objective, params, callback=callback
    )
    if result.success:
        return result
    logger.warning("Inverse fail: %r", result)
    return result
value_and_grad
value_and_grad(
params: ParamsT,
) -> tuple[Scalar, ParamsT, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py
def value_and_grad(self, params: ParamsT) -> tuple[Scalar, ParamsT, AuxT]:
    """Return the loss and its total gradient w.r.t. ``params`` (adjoint method).

    Steps:
      1. Map ``params`` to model parameters while recording a VJP of
         ``make_params``. Push the result into the model and run one forward
         ``step()``.
      2. Evaluate the loss and its partial gradients w.r.t. ``u`` and the
         model parameters at the resulting state (``loss_and_grad``).
      3. Solve for the adjoint vector ``p`` and add
         ``mixed_derivative_prod(u, p)`` to the explicit parameter gradient.
      4. Pull that model-parameter gradient back to ``params`` through the
         recorded VJP.

    Args:
        params: Optimization parameters.

    Returns:
        ``(loss, grad, aux)``, where ``grad`` has the structure of ``params``.
    """
    model_params: ModelParams
    # ``jax.vjp``'s pullback returns one cotangent per primal input, as a tuple.
    model_params_vjp: Callable[[ModelParams], tuple[ParamsT]]
    model_params, model_params_vjp = jax.vjp(self.make_params, params)
    self.model.update_params(model_params)
    solution: Optimizer.Solution = self.forward.step()
    # The adjoint gradient assumes the forward problem was solved; make a
    # failed step visible instead of silently returning a gradient for it.
    if not solution.success:
        logger.warning("Forward fail: %r", solution)
    logger.info("Forward time: %g sec", solution.stats.time)
    u_full: Full = self.model.u_full
    loss: Scalar
    dLdu: Full
    dLdq: ModelParams
    aux: AuxT
    loss, dLdu, dLdq, aux = self.loss_and_grad(u_full, model_params)
    p: Full = self.adjoint(u_full, dLdu)
    prod: ModelParams = self.model.mixed_derivative_prod(u_full, p)
    model_params_grad: ModelParams = jax.tree.map(operator.add, dLdq, prod)
    # ``make_params`` takes a single primal, so the pullback yields a
    # 1-tuple; unpack it so ``grad`` matches the declared ``ParamsT``.
    (grad,) = model_params_vjp(model_params_grad)
    return loss, grad, aux