Skip to content

liblaf.apple.warp.utils ¤

Functions:

to_warp ¤

to_warp(
    arr: ndarray | Array,
    dtype: int | tuple[int, int] | WarpDType | None = None,
    *,
    requires_grad: bool = ...,
) -> array
Source code in `src/liblaf/apple/warp/utils/_to_warp.py` (lines 28–44):
def to_warp(arr: np.ndarray | Array, dtype: Any = None, **kwargs) -> wp.array:
    """Convert ``arr`` into a ``wp.array`` via the adapter registered for its type.

    The adapter is looked up with ``_registry(arr)``. How ``dtype`` is
    interpreted depends on its kind:

    * ``None``: ``arr`` is passed to the adapter unchanged, and the adapter
      picks the Warp dtype.
    * ``int``: treated as a vector length. The result uses a
      ``wp.types.vector`` of that length whose scalar type is
      ``adapter.dtype_from(arr.dtype)``.
    * ``tuple``: treated as a matrix shape. The result uses a
      ``wp.types.matrix`` of that shape with the same scalar-type rule.
    * anything else: treated as a Warp dtype. ``arr`` is first cast
      (``arr.astype``) to the array-library scalar type corresponding to
      that dtype's scalar type, and is then converted with ``dtype``.

    Args:
        arr: Source array. ``Array`` is a project-level alias (presumably a
            JAX array — confirm against the import).
        dtype: ``None``, a vector length, a matrix shape, or a Warp dtype.
        **kwargs: Forwarded verbatim to ``adapter.array_from`` (the
            rendered overload documents ``requires_grad``). Must not contain
            ``dtype`` when ``dtype`` is given here.

    Returns:
        The ``wp.array`` produced by ``adapter.array_from``.
    """
    adapter: Adapter = _registry(arr)
    if dtype is None:
        return adapter.array_from(arr, **kwargs)
    # NOTE(review): ``bool`` is a subclass of ``int``, so ``dtype=True`` takes
    # this branch and builds a length-1 vector type. Confirm that is intended.
    if isinstance(dtype, int):
        # The vector's scalar type mirrors ``arr``'s own element type (no cast).
        vec_type = wp.types.vector(dtype, adapter.dtype_from(arr.dtype))
        return adapter.array_from(arr, dtype=vec_type, **kwargs)
    if isinstance(dtype, tuple):
        # Same rule as the vector branch, but for an (rows, cols) matrix.
        mat_type = wp.types.matrix(dtype, adapter.dtype_from(arr.dtype))
        return adapter.array_from(arr, dtype=mat_type, **kwargs)
    # An explicit Warp dtype: cast the elements to the matching scalar type
    # first. ``dtype`` is passed by keyword, like the branches above, so this
    # call does not rely on ``array_from``'s positional parameter order.
    scalar_type = adapter.dtype_to(_type_scalar_type(dtype))
    return adapter.array_from(arr.astype(scalar_type), dtype=dtype, **kwargs)