🔢 Finite difference JVP rule
``define_fdjvp`` combines ``custom_jvp`` and ``fgrad`` to define custom finite difference rules. When used together with ``pure_callback``, it makes non-traceable code work within JAX machinery.
This example is based on a comment on the proposed ``jax`` `JEP <google/jax#15425>`__.
For example, the following code fails under JAX transformations because it calls NumPy functions.
[2]:
import functools as ft
import jax
import jax.numpy as jnp
import numpy as onp
import finitediffx as fdx
import functools as ft
def numpy_func(x: onp.ndarray) -> onp.ndarray:
    """Return ``x`` squared, computed with plain NumPy.

    Because ``onp.power`` needs a concrete NumPy array, JAX tracers cannot
    flow through this function (see the error printed below).
    """
    squared = onp.power(x, 2)
    return squared
# Differentiating the pure-NumPy function makes ``onp.power`` receive a JAX
# tracer, which JAX refuses to convert into a NumPy array.  Show that error
# message instead of letting the cell crash.
try:
    jax.grad(numpy_func)(2.0)
except jax.errors.TracerArrayConversionError as conversion_err:
    print(conversion_err)
def wrap_pure_callback(func):
    """Wrap a NumPy function so JAX can call it through ``jax.pure_callback``.

    Every positional argument is converted with ``jnp.asarray`` and handed to
    ``func`` on the host via ``jax.pure_callback``.  The result is declared to
    have the broadcast shape of the positional arguments and the dtype of the
    first one, so ``func`` should behave like an elementwise / broadcasting
    NumPy function.  Keyword arguments are forwarded to ``func`` but do not
    take part in the output-shape computation.

    ``jax.pure_callback`` does not provide a derivative on its own; pair the
    wrapper with a custom JVP rule (e.g. ``fdx.define_fdjvp``) to differentiate
    through it.

    Args:
        func: Callable taking array-likes and returning an array-like
            (NumPy array, NumPy scalar, or plain Python scalar/list).

    Returns:
        A ``functools.wraps``-decorated wrapper usable under ``jax.jit`` and
        ``jax.vmap``.

    Raises:
        TypeError: (when the wrapper is called) if no positional argument is
            given, since the output shape and dtype are derived from them.
    """
    import inspect

    # ``vectorized=True`` is deprecated (and later removed) in favour of
    # ``vmap_method``.  Use the new keyword when the installed JAX knows it,
    # otherwise fall back to the old one.  "expand_dims" suits a broadcasting
    # NumPy function: unbatched operands get a size-1 leading axis.
    try:
        has_vmap_method = (
            "vmap_method" in inspect.signature(jax.pure_callback).parameters
        )
    except (TypeError, ValueError):
        has_vmap_method = False
    batching_kwargs = (
        {"vmap_method": "expand_dims"} if has_vmap_method else {"vectorized": True}
    )

    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        if not args:
            raise TypeError(
                f"{getattr(func, '__name__', 'function')} wrapped by "
                "wrap_pure_callback needs at least one positional argument"
            )
        arrays = [jnp.asarray(arg) for arg in args]
        out_dtype = arrays[0].dtype

        def host_call(*host_args, **host_kwargs):
            # ``onp.asarray`` (rather than ``.astype``) also accepts plain
            # Python scalars/lists returned by ``func``.
            return onp.asarray(func(*host_args, **host_kwargs), dtype=out_dtype)

        # Shape + dtype spec of the callback result, required by pure_callback.
        result_spec = jax.ShapeDtypeStruct(
            jnp.broadcast_shapes(*(a.shape for a in arrays)),
            out_dtype,
        )
        return jax.pure_callback(
            host_call, result_spec, *arrays, **kwargs, **batching_kwargs
        )

    return wrapper
@jax.jit  # compile the whole pipeline
@jax.grad  # differentiate through the custom finite-difference JVP rule
@ft.partial(
    fdx.define_fdjvp,
    step_size=1e-3,  # explicit finite-difference step size
    offsets=fdx.Offset(accuracy=4),  # stencil generated by finitediffx from accuracy=4
)
@wrap_pure_callback
def numpy_func(x: onp.ndarray) -> onp.ndarray:
    """Square ``x`` with NumPy; its derivative comes from a finite-difference rule."""
    return onp.power(x, 2)


print(numpy_func(1.0))  # 1.9999794
@jax.jit  # -> can compile
@jax.grad  # -> can take gradient
@ft.partial(
    fdx.define_fdjvp,
    # Provide the evaluation points of the finite-difference stencil explicitly,
    # in units of step_size.  Offsets [1, -1] give the centered difference
    #   (f(x + step_size) - f(x - step_size)) / (2 * step_size)
    offsets=jnp.array([1, -1]),
    # manually set step size
    step_size=1e-3,
)
@wrap_pure_callback
def numpy_func(x: onp.ndarray) -> onp.ndarray:
    """Square ``x`` with NumPy; its derivative comes from a centered difference."""
    return onp.power(x, 2)


print(numpy_func(1.0))
# 2.0000048  (the centered difference is exact for x**2; the deviation from
# 2.0 is float32 rounding)
The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
1.9999794
2.0000048