Skip to content

liblaf.apple.inverse ¤

Classes:

Inverse ¤

Bases: ABC


              flowchart TD
              liblaf.apple.inverse.Inverse[Inverse]

              

              click liblaf.apple.inverse.Inverse href "" "liblaf.apple.inverse.Inverse"
            

Parameters:

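  • forward ¤

    (Forward) – the forward problem; `forward.step()` is run before every loss evaluation.
  • adjoint_solver ¤

    (LinearSolver, default: CompositeSolver(), keyword-only) – linear solver used by `adjoint` for the adjoint system.
  • optimizer ¤

    (Optimizer, default: ScipyOptimizer(max_steps=256, jit=False, timer=False, method='L-BFGS-B', tol=1e-05, options=None), keyword-only) – optimizer used by `solve` to minimize the inverse objective.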

Methods: adjoint, fun, loss, loss_and_grad, make_params, solve, value_and_grad (documented below).

Attributes: adjoint_solver, forward, model, optimizer (documented below).

adjoint_solver class-attribute instance-attribute ¤

# Linear solver used by `adjoint` for the adjoint system; keyword-only, and
# by default a fresh `CompositeSolver()` is built per instance via `factory`.
adjoint_solver: LinearSolver = field(
    factory=CompositeSolver, kw_only=True
)

forward instance-attribute ¤

forward: Forward

model property ¤

model: Model

optimizer class-attribute instance-attribute ¤

# Optimizer used by `solve` to minimize the objective built from
# `value_and_grad`; keyword-only, defaulting to a fresh
# `ScipyOptimizer(method="L-BFGS-B", tol=1e-05)` per instance.
optimizer: Optimizer = field(
    factory=lambda: ScipyOptimizer(
        method="L-BFGS-B", tol=1e-05
    ),
    kw_only=True,
)

adjoint ¤

adjoint(u: Full, dLdu: Full) -> Full
Source code in src/liblaf/apple/inverse/_inverse.py (lines 42–65)
def adjoint(self, u: Full, dLdu: Full) -> Full:
    """Solve the adjoint system ``H p = -dL/du`` on the free DOFs.

    ``H`` is applied matrix-free through ``model.hess_prod`` evaluated at
    ``u``. The system is preconditioned with the reciprocal of
    ``model.hess_diag`` (Jacobi preconditioning). The same operator is passed
    as ``rmatvec``, i.e. the Hessian is treated as symmetric.

    Args:
        u: Full state at which the Hessian is evaluated.
        dLdu: Gradient of the loss with respect to ``u`` (full layout).

    Returns:
        The adjoint vector ``p`` in full layout, produced by
        ``model.to_full(..., 0.0)`` (presumably zero on non-free DOFs).
        A warning is logged if the linear solver reports failure.
    """
    u_free: Free = self.model.to_free(u)
    inv_diag: Free = jnp.reciprocal(self.model.hess_diag(u_free))

    def apply_hessian(v: Free) -> Free:
        return self.model.hess_prod(u_free, v)

    def apply_jacobi(v: Free) -> Free:
        return inv_diag * v

    rhs: Free = -self.model.to_free(dLdu)
    system = LinearSystem(
        matvec=apply_hessian,
        rmatvec=apply_hessian,
        b=rhs,
        preconditioner=apply_jacobi,
        rpreconditioner=apply_jacobi,
    )
    x0: Free = jnp.zeros_like(u_free)
    result: LinearSolver.Solution = self.adjoint_solver.solve(system, x0)
    if not result.success:
        logger.warning("Adjoint fail: %r", result)
    logger.info("Adjoint time: %g sec", result.stats.time)
    return self.model.to_full(result.params, 0.0)

fun ¤

fun(params: ParamsT) -> tuple[Scalar, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py (lines 67–72)
def fun(self, params: ParamsT) -> tuple[Scalar, AuxT]:
    """Evaluate the inverse loss at ``params`` without computing a gradient.

    Pushes ``make_params(params)`` into the model, runs one forward step and
    evaluates ``loss`` on the model's resulting full state ``u_full``.

    Args:
        params: Optimization variables passed to ``make_params``.

    Returns:
        ``(loss, aux)`` exactly as returned by ``loss``.
    """
    model_params: ModelParams = self.make_params(params)
    self.model.update_params(model_params)
    solution: Optimizer.Solution = self.forward.step()
    if not solution.success:
        # The loss is still evaluated, but on a possibly non-converged state;
        # surface this the same way `adjoint` and `solve` report failures.
        logger.warning("Forward fail: %r", solution)
    logger.info("Forward time: %g sec", solution.stats.time)
    return self.loss(self.model.u_full, model_params)

loss abstractmethod ¤

loss(u: Full, params: ModelParams) -> tuple[Scalar, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py (lines 74–76)
@abc.abstractmethod
def loss(self, u: Full, params: ModelParams) -> tuple[Scalar, AuxT]:
    """Compute the scalar inverse loss for state ``u`` and model ``params``.

    Must be implemented by subclasses. ``loss_and_grad`` differentiates it
    with respect to both arguments via ``jax.value_and_grad(...,
    has_aux=True)``, so it should be JAX-traceable (NOTE(review): inferred
    from that usage).

    Returns:
        ``(loss, aux)``: the scalar loss and auxiliary data, which ``fun``,
        ``loss_and_grad`` and ``value_and_grad`` return to their callers.
    """
    raise NotImplementedError

loss_and_grad ¤

loss_and_grad(
    u: Full, params: ModelParams
) -> tuple[Scalar, Full, ModelParams, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py (lines 78–89)
@eqx.filter_jit
def loss_and_grad(
    self, u: Full, params: ModelParams
) -> tuple[Scalar, Full, ModelParams, AuxT]:
    """JIT-compiled loss value plus its partial gradients.

    Differentiates ``loss`` with respect to both ``u`` and ``params``
    (``argnums=(0, 1)``), carrying ``aux`` through untouched.

    Returns:
        ``(loss, dL/du, dL/dparams, aux)``.
    """
    value_and_grads = jax.value_and_grad(
        self.loss, argnums=(0, 1), has_aux=True
    )
    value, grads = value_and_grads(u, params)
    loss_value, aux_data = value
    grad_u, grad_params = grads
    return loss_value, grad_u, grad_params, aux_data

make_params abstractmethod ¤

make_params(params: ParamsT) -> ModelParams
Source code in src/liblaf/apple/inverse/_inverse.py (lines 91–93)
@abc.abstractmethod
def make_params(self, params: ParamsT) -> ModelParams:
    """Map optimization variables ``params`` to model parameters.

    Must be implemented by subclasses. ``fun`` and ``value_and_grad`` pass
    the result to ``model.update_params``. ``value_and_grad`` also
    differentiates through this method with ``jax.vjp``, so it should be
    JAX-differentiable.
    """
    raise NotImplementedError

solve ¤

solve(
    params: ParamsT, callback: Callback | None = None
) -> Solution
Source code in src/liblaf/apple/inverse/_inverse.py (lines 95–104)
def solve(
    self, params: ParamsT, callback: Callback | None = None
) -> Optimizer.Solution:
    """Minimize the inverse objective starting from ``params``.

    The objective wraps ``value_and_grad`` and is minimized by
    ``self.optimizer``. A warning is logged when the optimizer reports
    failure; the solution is returned in either case.

    Args:
        params: Initial optimization variables.
        callback: Optional callback forwarded to ``optimizer.minimize``.

    Returns:
        The optimizer's solution object.
    """
    objective_fn = Objective(value_and_grad=self.value_and_grad)
    result: Optimizer.Solution = self.optimizer.minimize(
        objective_fn, params, callback=callback
    )
    if result.success:
        return result
    logger.warning("Inverse fail: %r", result)
    return result

value_and_grad ¤

value_and_grad(
    params: ParamsT,
) -> tuple[Scalar, ParamsT, AuxT]
Source code in src/liblaf/apple/inverse/_inverse.py (lines 106–123)
def value_and_grad(self, params: ParamsT) -> tuple[Scalar, ParamsT, AuxT]:
    """Evaluate the inverse loss and its gradient w.r.t. ``params`` (adjoint method).

    Steps:
      1. Map ``params`` to model parameters, recording a VJP of ``make_params``.
      2. Run one forward step and read the model's full state ``u_full``.
      3. Compute the loss and its partials ``dL/du`` and ``dL/dq`` (jitted).
      4. Solve the adjoint system for ``p`` and add
         ``model.mixed_derivative_prod(u_full, p)`` to ``dL/dq``.
      5. Pull that model-parameter gradient back through ``make_params``.

    Args:
        params: Optimization variables passed to ``make_params``.

    Returns:
        ``(loss, grad, aux)`` where ``grad`` has the same pytree structure as
        ``params``.
    """
    model_params: ModelParams
    # `jax.vjp`'s pullback returns one cotangent per primal input, i.e. a
    # 1-tuple here because `make_params` takes a single argument.
    model_params_vjp: Callable[[ModelParams], tuple[ParamsT]]
    model_params, model_params_vjp = jax.vjp(self.make_params, params)
    self.model.update_params(model_params)
    solution: Optimizer.Solution = self.forward.step()
    if not solution.success:
        # Gradients computed from a non-converged forward state may be
        # inaccurate; report it like `adjoint` and `solve` do.
        logger.warning("Forward fail: %r", solution)
    logger.info("Forward time: %g sec", solution.stats.time)
    u_full: Full = self.model.u_full
    loss: Scalar
    dLdu: Full
    dLdq: ModelParams
    aux: AuxT
    loss, dLdu, dLdq, aux = self.loss_and_grad(u_full, model_params)
    p: Full = self.adjoint(u_full, dLdu)
    prod: ModelParams = self.model.mixed_derivative_prod(u_full, p)
    model_params_grad: ModelParams = jax.tree.map(operator.add, dLdq, prod)
    # Unpack the 1-tuple so the returned gradient matches `params`' structure.
    grad: ParamsT
    (grad,) = model_params_vjp(model_params_grad)
    return loss, grad, aux