🔢 Finite difference JVP rule

``define_fdjvp`` combines ``custom_jvp`` and ``fgrad`` to define custom finite difference rules. When used with ``pure_callback``, it can make non-traceable code work within JAX's transformation machinery.

This example is based on a comment on the proposed ``jax`` `JEP <https://github.com/google/jax/issues/15425>`__ (google/jax#15425).

For example, the following code fails under JAX transformations because it calls NumPy functions, which cannot operate on traced values.

[2]:
import functools as ft

import jax
import jax.numpy as jnp
import numpy as onp
import finitediffx as fdx
import functools as ft


def numpy_func(x: onp.ndarray) -> onp.ndarray:
    """Return ``x`` squared using plain NumPy.

    Because the body calls NumPy directly, JAX cannot trace through it.
    """
    exponent = 2
    return onp.power(x, exponent)


# jax.grad evaluates numpy_func on traced values; onp.power then calls
# __array__() on the tracer, which JAX rejects with the error printed below.
grad_of_numpy_func = jax.grad(numpy_func)
try:
    grad_of_numpy_func(2.0)
except jax.errors.TracerArrayConversionError as err:
    print(err)


def wrap_pure_callback(func):
    """Wrap a NumPy-only function so it can be called on JAX values.

    The wrapper converts positional arguments to JAX arrays and runs ``func``
    on the host through ``jax.pure_callback``. This lets it be used inside
    ``jax.jit``-compiled code even though ``func`` itself is not traceable.

    The declared result has the broadcast shape of the positional arguments
    and the dtype of the first positional argument. Whatever ``func`` returns
    (array, NumPy scalar, Python scalar or list) is coerced to that dtype.

    Args:
        func: callable operating on NumPy arrays. Keyword arguments given to
            the wrapper are passed through ``jax.pure_callback`` to ``func``
            unchanged. At least one positional argument is required.

    Returns:
        The wrapped callable; ``functools.wraps`` copies ``func``'s metadata.
    """
    import inspect

    # ``vectorized=True`` is deprecated in newer JAX in favour of the
    # equivalent ``vmap_method="legacy_vectorized"``. Use whichever keyword
    # this JAX version accepts so batching behaviour stays the same.
    if "vmap_method" in inspect.signature(jax.pure_callback).parameters:
        batching = {"vmap_method": "legacy_vectorized"}
    else:
        batching = {"vectorized": True}

    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        args = [jnp.asarray(arg) for arg in args]
        # Shape/dtype contract for the value the host callback must return.
        result_spec = jax.ShapeDtypeStruct(
            jnp.broadcast_shapes(*[arg.shape for arg in args]),
            args[0].dtype,
        )

        def host_call(*host_args, **host_kwargs):
            # ``onp.asarray`` instead of ``.astype``: ``func`` may return a
            # Python scalar or list, which has no ``astype`` method.
            return onp.asarray(
                func(*host_args, **host_kwargs), dtype=result_spec.dtype
            )

        return jax.pure_callback(
            host_call, result_spec, *args, **kwargs, **batching
        )

    return wrapper


def numpy_func(x: onp.ndarray) -> onp.ndarray:
    """Square ``x`` with NumPy; made differentiable by the wrappers below."""
    return onp.power(x, 2)


# The same pipeline a decorator stack would build, spelled out innermost-first:
#   1. wrap_pure_callback -> run the NumPy code on the host via pure_callback
#   2. fdx.define_fdjvp   -> attach a finite-difference JVP rule whose offsets
#                            are generated automatically by
#                            fdx.Offset(accuracy=4), with a manual step of 1e-3
#   3. jax.grad           -> take the gradient through that rule
#   4. jax.jit            -> compile the gradient function
numpy_func = jax.jit(
    jax.grad(
        fdx.define_fdjvp(
            wrap_pure_callback(numpy_func),
            offsets=fdx.Offset(accuracy=4),
            step_size=1e-3,
        )
    )
)

print(numpy_func(1.0))
# 1.9999794


@jax.jit  # -> can compile
@jax.grad  # -> can take gradient
@ft.partial(
    fdx.define_fdjvp,
    # provide the desired evaluation points for the finite difference stencil;
    # offsets [1, -1] give the centered difference
    # (f(x + step_size) - f(x - step_size)) / (2 * step_size)
    offsets=jnp.array([1, -1]),
    # manually set step size
    step_size=1e-3,
)
@wrap_pure_callback
def numpy_func(x: onp.ndarray) -> onp.ndarray:
    """Square ``x`` with NumPy; differentiated with a user-chosen stencil."""
    return onp.power(x, 2)


print(numpy_func(1.0))
# 2.0000048
The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
1.9999794
2.0000048