Skip to content

liblaf.apple.warp.model ¤

Classes:

WarpEnergy ¤

Bases: IdMixin


              flowchart TD
              liblaf.apple.warp.model.WarpEnergy[WarpEnergy]
              liblaf.apple.utils._id_mixin.IdMixin[IdMixin]

                              liblaf.apple.utils._id_mixin.IdMixin --> liblaf.apple.warp.model.WarpEnergy
                


              click liblaf.apple.warp.model.WarpEnergy href "" "liblaf.apple.warp.model.WarpEnergy"
              click liblaf.apple.utils._id_mixin.IdMixin href "" "liblaf.apple.utils._id_mixin.IdMixin"
            

Parameters:

Methods:

Attributes:

id class-attribute instance-attribute ¤

# Keyword-only identifier. When omitted, attrs calls ``_default_id`` with the
# instance being initialised (``takes_self=True``) to produce the default.
id: str = field(kw_only=True, default=Factory(_default_id, takes_self=True))

requires_grad class-attribute instance-attribute ¤

# Keyword-only sequence of names; defaults to an empty tuple.
requires_grad: Sequence[str] = field(kw_only=True, default=())

fun ¤

fun(u: Vector, output: Scalar) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
27
28
def fun(self, u: Vector, output: Scalar) -> None:
    """Evaluate this energy at ``u`` into the caller-provided ``output`` buffer.

    Abstract hook: concrete energies must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

grad ¤

grad(u: Vector, output: Vector) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
30
31
def grad(self, u: Vector, output: Vector) -> None:
    """Write this energy's gradient at ``u`` into the ``output`` buffer.

    Abstract hook: concrete energies must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

grad_and_hess_diag ¤

grad_and_hess_diag(
    u: Vector, grad: Vector, hess_diag: Vector
) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
49
50
51
def grad_and_hess_diag(self, u: Vector, grad: Vector, hess_diag: Vector) -> None:
    """Fill ``grad`` and ``hess_diag`` for the state ``u``.

    Default implementation: delegates to ``self.grad`` and then to
    ``self.hess_diag``, in that order. Subclasses may override it to compute
    both in a single pass.
    """
    self.grad(u, grad)
    self.hess_diag(u, hess_diag)

hess_diag ¤

hess_diag(u: Vector, output: Vector) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
33
34
def hess_diag(self, u: Vector, output: Vector) -> None:
    """Write the Hessian diagonal at ``u`` into the ``output`` buffer.

    Abstract hook: concrete energies must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

hess_prod ¤

hess_prod(u: Vector, p: Vector, output: Vector) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
36
37
def hess_prod(self, u: Vector, p: Vector, output: Vector) -> None:
    """Write the Hessian-at-``u`` times ``p`` product into ``output``.

    Abstract hook: concrete energies must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

hess_quad ¤

hess_quad(u: Vector, p: Vector, output: Scalar) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
39
40
def hess_quad(self, u: Vector, p: Vector, output: Scalar) -> None:
    """Write the Hessian quadratic form in ``p`` at ``u`` into ``output``.

    Presumably ``p^T H(u) p`` (per the method name), to be confirmed against
    concrete energies. Abstract hook: concrete energies must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, array]
Source code in src/liblaf/apple/warp/model/_energy.py
42
43
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, wp.array]:
    """Return mixed-derivative products at ``u`` with ``p``, keyed by name.

    NOTE(review): keys are presumably parameter names (as in
    ``update_params``). Confirm against concrete energies.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

update ¤

update(u: Vector) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
19
20
def update(self, u: Vector) -> None:
    """Hook called with the current state ``u``; the base class does nothing."""

update_params ¤

update_params(params: EnergyParams) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
22
23
24
25
def update_params(self, params: EnergyParams) -> None:
    """Copy new parameter values into the matching Warp array attributes.

    Each key of ``params`` names an attribute of ``self`` that holds a
    ``wp.array``. The value is converted to that array's dtype with
    ``wpu.to_warp`` and copied into the existing array in place, so the array
    object itself is kept.
    """
    for attr_name, new_value in params.items():
        target: wp.array = getattr(self, attr_name)
        converted = wpu.to_warp(new_value, target.dtype)
        wp.copy(target, converted)

value_and_grad ¤

value_and_grad(
    u: Vector, value: Scalar, grad: Vector
) -> None
Source code in src/liblaf/apple/warp/model/_energy.py
45
46
47
def value_and_grad(self, u: Vector, value: Scalar, grad: Vector) -> None:
    """Fill ``value`` and ``grad`` for the state ``u``.

    Default implementation: calls ``self.fun`` and then ``self.grad``, in that
    order. Subclasses may override it to share work between the two.
    """
    self.fun(u, value)
    self.grad(u, grad)

WarpModel ¤

Parameters:

  • energies ¤

    (dict[str, WarpEnergy], default: <class 'dict'> ) –

    Energies that make up the model, keyed by name. Every model method (`fun`, `grad`, `hess_diag`, …) forwards its arguments to each energy in insertion order. `update_params` and `mixed_derivative_prod` use the same keys. Defaults to an empty dict.

Methods:

Attributes:

energies class-attribute instance-attribute ¤

# Energies keyed by name; each instance gets its own fresh empty dict by default.
energies: dict[str, WarpEnergy] = field(factory=dict)

fun ¤

fun(u: Vector, output: Scalar) -> None
Source code in src/liblaf/apple/warp/model/_model.py
27
28
29
def fun(self, u: Vector, output: Scalar) -> None:
    """Evaluate every energy at ``u``, handing each the same ``output`` buffer.

    Energies are visited in the insertion order of ``self.energies``.
    NOTE(review): a correct total presumably requires each energy to
    accumulate into ``output`` rather than overwrite it. Confirm.
    """
    for term in self.energies.values():
        term.fun(u, output)

grad ¤

grad(u: Vector, output: Vector) -> None
Source code in src/liblaf/apple/warp/model/_model.py
31
32
33
def grad(self, u: Vector, output: Vector) -> None:
    """Ask every energy, in insertion order, to write its gradient at ``u`` into ``output``."""
    for term in self.energies.values():
        term.grad(u, output)

grad_and_hess_diag ¤

grad_and_hess_diag(
    u: Vector, grad: Vector, hess_diag: Vector
) -> None
Source code in src/liblaf/apple/warp/model/_model.py
60
61
62
def grad_and_hess_diag(self, u: Vector, grad: Vector, hess_diag: Vector) -> None:
    """Forward ``grad_and_hess_diag`` to every energy, sharing both output buffers."""
    for term in self.energies.values():
        term.grad_and_hess_diag(u, grad, hess_diag)

hess_diag ¤

hess_diag(u: Vector, output: Vector) -> None
Source code in src/liblaf/apple/warp/model/_model.py
35
36
37
def hess_diag(self, u: Vector, output: Vector) -> None:
    """Forward ``hess_diag`` to every energy, sharing the ``output`` buffer."""
    for term in self.energies.values():
        term.hess_diag(u, output)

hess_prod ¤

hess_prod(u: Vector, p: Vector, output: Vector) -> None
Source code in src/liblaf/apple/warp/model/_model.py
39
40
41
def hess_prod(self, u: Vector, p: Vector, output: Vector) -> None:
    """Forward ``hess_prod`` with direction ``p`` to every energy, sharing ``output``."""
    for term in self.energies.values():
        term.hess_prod(u, p, output)

hess_quad ¤

hess_quad(u: Vector, p: Vector, output: Scalar) -> None
Source code in src/liblaf/apple/warp/model/_model.py
43
44
45
def hess_quad(self, u: Vector, p: Vector, output: Scalar) -> None:
    """Forward ``hess_quad`` with direction ``p`` to every energy, sharing ``output``."""
    for term in self.energies.values():
        term.hess_quad(u, p, output)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, dict[str, array]]
Source code in src/liblaf/apple/warp/model/_model.py
47
48
49
50
51
52
53
54
def mixed_derivative_prod(
    self, u: Vector, p: Vector
) -> dict[str, dict[str, wp.array]]:
    """Collect each energy's mixed-derivative products, keyed by energy name.

    Returns:
        A dict mapping every key of ``self.energies`` (in insertion order) to
        the dict that energy's ``mixed_derivative_prod(u, p)`` returned.
    """
    output: dict[str, dict[str, wp.array]] = {}
    for energy_name, term in self.energies.items():
        output[energy_name] = term.mixed_derivative_prod(u, p)
    return output

update ¤

update(u: Vector) -> None
Source code in src/liblaf/apple/warp/model/_model.py
19
20
21
def update(self, u: Vector) -> None:
    """Pass the current state ``u`` to every energy's ``update`` hook."""
    for term in self.energies.values():
        term.update(u)

update_params ¤

update_params(params: ModelParams) -> None
Source code in src/liblaf/apple/warp/model/_model.py
23
24
25
def update_params(self, params: ModelParams) -> None:
    """Dispatch per-energy parameter updates.

    Each key of ``params`` must name an entry of ``self.energies``; an
    unknown name raises ``KeyError``. Energies absent from ``params`` are
    left untouched.
    """
    energies = self.energies
    for energy_name, energy_params in params.items():
        energies[energy_name].update_params(energy_params)

value_and_grad ¤

value_and_grad(
    u: Vector, value: Scalar, grad: Vector
) -> None
Source code in src/liblaf/apple/warp/model/_model.py
56
57
58
def value_and_grad(self, u: Vector, value: Scalar, grad: Vector) -> None:
    """Forward ``value_and_grad`` to every energy, sharing the ``value`` and ``grad`` buffers."""
    for term in self.energies.values():
        term.value_and_grad(u, value, grad)

WarpModelAdapter ¤

Parameters:

Methods:

Attributes:

energies property ¤

# Read-only property exposing energies as a Mapping.
# NOTE(review): presumably delegates to ``self.wrapped.energies``; confirm in _adapter.py.
energies: Mapping[str, WarpEnergy]

wrapped instance-attribute ¤

# The Warp-side model that every adapter method delegates to after
# converting its JAX inputs with ``_to_warp``.
wrapped: WarpModel

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/warp/model/_adapter.py
32
33
34
35
36
def fun(self, u: Vector) -> Scalar:
    """Evaluate the wrapped model's energy at ``u`` and return it as a JAX scalar.

    Args:
        u: State vector; converted to Warp with ``_to_warp``.

    Returns:
        Element 0 of the one-element buffer the wrapped model wrote into.
    """
    u_wp: wp.array = _to_warp(u)
    # Allocate on the input's device. A bare ``wp.zeros`` uses Warp's default
    # device, which can differ from ``u_wp.device``; the other adapter
    # methods get this for free via ``wp.zeros_like(u_wp)``.
    output_wp: wp.array = wp.zeros(
        (1,), dtype=wp.dtype_from_jax(u.dtype), device=u_wp.device
    )
    self.wrapped.fun(u_wp, output_wp)
    return wp.to_jax(output_wp)[0]

grad ¤

grad(u: Vector) -> Vector
Source code in src/liblaf/apple/warp/model/_adapter.py
38
39
40
41
42
def grad(self, u: Vector) -> Vector:
    """Return the wrapped model's gradient at ``u`` as a JAX array.

    The output buffer is a zero-initialised array with the same shape, dtype
    and device as the converted input.
    """
    u_warp = _to_warp(u)
    out = wp.zeros_like(u_warp)
    self.wrapped.grad(u_warp, out)
    return wp.to_jax(out)

grad_and_hess_diag ¤

grad_and_hess_diag(u: Vector) -> tuple[Vector, Vector]
Source code in src/liblaf/apple/warp/model/_adapter.py
86
87
88
89
90
91
92
93
def grad_and_hess_diag(self, u: Vector) -> tuple[Vector, Vector]:
    """Return ``(grad, hess_diag)`` at ``u`` as JAX arrays.

    Both outputs are zero-initialised buffers shaped like the converted input
    and are filled by a single ``wrapped.grad_and_hess_diag`` call.
    """
    u_warp = _to_warp(u)
    grad_out, diag_out = wp.zeros_like(u_warp), wp.zeros_like(u_warp)
    self.wrapped.grad_and_hess_diag(u_warp, grad_out, diag_out)
    return wp.to_jax(grad_out), wp.to_jax(diag_out)

hess_diag ¤

hess_diag(u: Vector) -> Vector
Source code in src/liblaf/apple/warp/model/_adapter.py
44
45
46
47
48
def hess_diag(self, u: Vector) -> Vector:
    """Return the wrapped model's Hessian diagonal at ``u`` as a JAX array."""
    u_warp = _to_warp(u)
    out = wp.zeros_like(u_warp)
    self.wrapped.hess_diag(u_warp, out)
    return wp.to_jax(out)

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Vector
Source code in src/liblaf/apple/warp/model/_adapter.py
50
51
52
53
54
55
def hess_prod(self, u: Vector, p: Vector) -> Vector:
    """Return the wrapped model's Hessian-at-``u`` times ``p`` product as a JAX array.

    The output buffer is shaped like the converted ``u``.
    """
    u_warp, p_warp = _to_warp(u), _to_warp(p)
    out = wp.zeros_like(u_warp)
    self.wrapped.hess_prod(u_warp, p_warp, out)
    return wp.to_jax(out)

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/warp/model/_adapter.py
57
58
59
60
61
62
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the wrapped model's Hessian quadratic form at ``u`` in ``p``.

    Args:
        u: State vector; converted with ``_to_warp``.
        p: Direction vector; converted with ``_to_warp``.

    Returns:
        Element 0 of the one-element buffer the wrapped model wrote into.
    """
    u_wp: wp.array = _to_warp(u)
    p_wp: wp.array = _to_warp(p)
    # Keep the scalar buffer on the inputs' device instead of Warp's default
    # device, matching the ``wp.zeros_like(u_wp)`` buffers elsewhere.
    output_wp: wp.array = wp.zeros(
        (1,), dtype=wp.dtype_from_jax(u.dtype), device=u_wp.device
    )
    self.wrapped.hess_quad(u_wp, p_wp, output_wp)
    return wp.to_jax(output_wp)[0]

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, dict[str, Array]]
Source code in src/liblaf/apple/warp/model/_adapter.py
64
65
66
67
68
69
70
71
72
73
74
75
def mixed_derivative_prod(
    self, u: Vector, p: Vector
) -> dict[str, dict[str, Array]]:
    """Return the wrapped model's mixed-derivative products as JAX arrays.

    ``u`` and ``p`` are converted with ``_to_warp``. The nested result is then
    converted leaf by leaf with ``wp.to_jax``. Both key levels are kept: the
    outer key is the energy name and the inner key is whatever each energy
    produced.
    """
    nested_wp: dict[str, dict[str, wp.array]] = self.wrapped.mixed_derivative_prod(
        _to_warp(u), _to_warp(p)
    )
    return {
        energy_name: {key: wp.to_jax(arr) for key, arr in per_energy.items()}
        for energy_name, per_energy in nested_wp.items()
    }

update ¤

update(u: Vector) -> None
Source code in src/liblaf/apple/warp/model/_adapter.py
25
26
27
def update(self, u: Vector) -> None:
    """Convert ``u`` with ``_to_warp`` and forward it to ``wrapped.update``."""
    self.wrapped.update(_to_warp(u))

update_params ¤

update_params(params: ModelParams) -> None
Source code in src/liblaf/apple/warp/model/_adapter.py
29
30
def update_params(self, params: ModelParams) -> None:
    """Forward ``params`` unchanged to the wrapped model (no conversion here)."""
    model = self.wrapped
    model.update_params(params)

value_and_grad ¤

value_and_grad(u: Vector) -> tuple[Scalar, Vector]
Source code in src/liblaf/apple/warp/model/_adapter.py
77
78
79
80
81
82
83
84
def value_and_grad(self, u: Vector) -> tuple[Scalar, Vector]:
    """Return ``(value, grad)`` at ``u`` from one ``wrapped.value_and_grad`` call.

    Args:
        u: State vector; converted with ``_to_warp``.

    Returns:
        The scalar value (element 0 of a one-element buffer) and the gradient
        as JAX arrays.
    """
    u_wp: wp.array = _to_warp(u)
    # Put the scalar buffer on the input's device; ``wp.zeros`` would
    # otherwise use Warp's default device while ``grad_wp`` follows ``u_wp``.
    value_wp: wp.array = wp.zeros(
        (1,), dtype=wp.dtype_from_jax(u.dtype), device=u_wp.device
    )
    grad_wp: wp.array = wp.zeros_like(u_wp)
    self.wrapped.value_and_grad(u_wp, value_wp, grad_wp)
    value: Scalar = wp.to_jax(value_wp)[0]
    grad: Vector = wp.to_jax(grad_wp)
    return value, grad

WarpModelBuilder ¤

Parameters:

  • energies ¤

    (dict[str, WarpEnergy], default: <class 'dict'> ) –

    Energies registered so far, keyed by each energy's `id` (see `add_energy`). Defaults to an empty dict; `finalize` builds a `WarpModel` from them.

Methods:

Attributes:

energies class-attribute instance-attribute ¤

# Registered energies keyed by ``energy.id``; a fresh empty dict per builder.
energies: dict[str, WarpEnergy] = field(factory=dict)

add_energy ¤

add_energy(energy: WarpEnergy) -> None
Source code in src/liblaf/apple/warp/model/_builder.py
11
12
def add_energy(self, energy: WarpEnergy) -> None:
    """Register ``energy`` under its ``id``, replacing any energy with the same id."""
    key = energy.id
    self.energies[key] = energy

finalize ¤

finalize() -> WarpModel
Source code in src/liblaf/apple/warp/model/_builder.py
14
15
def finalize(self) -> WarpModel:
    """Build a ``WarpModel`` from the energies registered so far.

    Returns:
        A new ``WarpModel`` whose ``energies`` is a shallow copy of this
        builder's mapping. The energy objects themselves are shared, not copied.
    """
    # Copy the mapping so later ``add_energy`` calls on this builder cannot
    # silently add energies to a model that has already been finalized.
    return WarpModel(energies=dict(self.energies))