Function transformation

finitediffx.define_fdjvp(func, offsets=Offset(accuracy=2), step_size=None)

Define a JVP rule for a function using finite differences.

Parameters:
  • func – function for which to define the JVP rule.

  • offsets

    offsets for the finite difference stencil. Accepted types are:

    • jax.Array defining the locations of the function evaluation points.

    • Offset with an accuracy field to generate the offsets automatically.

    • tuple of Offset or jax.Array defining the offsets for each argument.

  • step_size

    step size for the finite difference stencil. Accepted types are:

    • float defining the step size for all arguments.

    • tuple of float defining the step size for each argument.

    • None to use the default step size for each argument.

Returns:

The function with its JVP rule defined using finite differences.

Return type:

Callable

Note

  • This function is motivated by [JEP](google/jax#15425)

  • This function can be used with jax.pure_callback to define the JVP rule for a function that JAX cannot differentiate, such as one implemented in NumPy.
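
Conceptually, a finite difference JVP rule replaces the exact Jacobian-vector product with a difference quotient taken along the tangent direction. The sketch below illustrates this idea in pure Python. It assumes a scalar-valued function of float arguments and uses a fixed second-order central difference. The helper `fd_jvp` is hypothetical and does not reproduce finitediffx internals, offsets, or step-size handling.

```python
import math
from typing import Callable, Sequence, Tuple


def fd_jvp(
    func: Callable[..., float],
    primals: Sequence[float],
    tangents: Sequence[float],
    step_size: float = 1e-3,
) -> Tuple[float, float]:
    """Hypothetical helper: approximate (f(x), J(x) @ v) for scalar-valued f.

    Uses the central difference (f(x + h*v) - f(x - h*v)) / (2*h),
    which is second-order accurate in the step size h.
    """
    # Shift every primal input along its tangent component, both directions.
    plus = [x + step_size * t for x, t in zip(primals, tangents)]
    minus = [x - step_size * t for x, t in zip(primals, tangents)]
    primal_out = func(*primals)
    # Directional derivative of func along the tangent vector.
    tangent_out = (func(*plus) - func(*minus)) / (2 * step_size)
    return primal_out, tangent_out


def numpy_like_func(x: float, y: float) -> float:
    # Stand-in for a function JAX cannot trace (e.g. one calling NumPy).
    return math.sin(x) + math.cos(y)


# Tangent (1, 0) selects d/dx; the exact value at x = 1 is cos(1).
value, dx = fd_jvp(numpy_like_func, (1.0, 2.0), (1.0, 0.0))
# Tangent (0, 1) selects d/dy; the exact value at y = 2 is -sin(2).
_, dy = fd_jvp(numpy_like_func, (1.0, 2.0), (0.0, 1.0))
print(value, dx, dy)
```

Because only function evaluations are needed, the same idea works for opaque callbacks. This is why the approach pairs naturally with jax.pure_callback.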

Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import finitediffx as fdx
>>> import functools as ft
>>> import numpy as onp
>>> def wrap_pure_callback(func):
...     @ft.wraps(func)
...     def wrapper(*args, **kwargs):
...         args = [jnp.asarray(arg) for arg in args]
...         func_ = lambda *a, **k: func(*a, **k).astype(a[0].dtype)
...         dtype_ = jax.ShapeDtypeStruct(
...             jnp.broadcast_shapes(*[ai.shape for ai in args]),
...             args[0].dtype,
...         )
...         return jax.pure_callback(func_, dtype_, *args, **kwargs, vectorized=True)
...     return wrapper
>>> @jax.jit
... @jax.grad
... @fdx.define_fdjvp
... @wrap_pure_callback
... def numpy_func(x, y):
...     return onp.sin(x) + onp.cos(y)
>>> print(numpy_func(1., 2.))
0.5402466

finitediffx.fgrad(func, *, argnums=0, step_size=None, offsets=Offset(accuracy=3), derivative=1, has_aux=False, average_gradients=False)

Finite difference derivative of a function with respect to one of its arguments.

Similar to jax.grad, but uses a finite difference approximation.

Parameters:
  • func – function to differentiate

  • argnums – index of the argument to differentiate with respect to. Defaults to 0. If a tuple is passed, the function is differentiated with respect to each argument in the tuple.

  • step_size – step size for the finite difference stencil. If None, the step size is set to 2 ** (-23 / (2 * derivative)).

  • offsets

    offsets for the finite difference stencil. Accepted types are:

    • jax.Array defining location of function evaluation points.

    • Offset with accuracy field to automatically generate offsets.

    • pytree of jax.Array/Offset with the same structure as the argument(s) selected by argnums, defining the offsets for each argument.

  • derivative – derivative order. Defaults to 1.

  • has_aux – whether the function returns an auxiliary output. Defaults to False. If True, the derivative function returns a tuple of the form (derivative, aux); otherwise, it returns only the derivative.

  • average_gradients – whether to average the array gradients, which yields faster results. Defaults to False.
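
To make the default step size concrete: 2 ** -23 is the machine epsilon of float32. The default step therefore equals that epsilon raised to the power 1 / (2 * derivative), and it grows as the derivative order increases. The snippet below simply evaluates the documented formula in plain Python. The name `default_step_size` is hypothetical and is not part of finitediffx.

```python
import math

# Machine epsilon of float32: the gap between 1.0 and the next float32 value.
FLOAT32_EPS = 2.0 ** -23


def default_step_size(derivative: int) -> float:
    """Evaluate the documented default, 2 ** (-23 / (2 * derivative))."""
    return 2.0 ** (-23 / (2 * derivative))


for order in (1, 2, 3):
    h = default_step_size(order)
    # Same value written as FLOAT32_EPS ** (1 / (2 * order)).
    assert math.isclose(h, FLOAT32_EPS ** (1 / (2 * order)))
    print(order, h)  # order 1 gives about 3.45e-4, order 2 about 1.86e-2
```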

Returns:

Derivative of the function if has_aux is False; otherwise, a tuple of the form (derivative, aux).

Example

>>> import finitediffx as fdx
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x, y):
...     return x**2 + y**2
>>> df = fdx.fgrad(
...     func=f,
...     argnums=1,  # differentiate with respect to y
...     offsets=fdx.Offset(accuracy=2),  # use 2nd order accurate finite difference
... )
>>> df(2.0, 3.0)
Array(6., dtype=float32)
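
The argnums behaviour can be sketched in plain Python for scalar float arguments. The sketch assumes, as with jax.grad, that a tuple of indices returns a tuple of derivatives in the same order. `central_fgrad` is a hypothetical helper that uses a fixed second-order central difference. It is not the library's implementation and does not handle arrays, pytrees, or has_aux.

```python
from typing import Callable, Sequence, Tuple, Union


def central_fgrad(
    func: Callable[..., float],
    argnums: Union[int, Tuple[int, ...]] = 0,
    step_size: float = 1e-3,
) -> Callable[..., Union[float, Tuple[float, ...]]]:
    """Hypothetical fgrad-like wrapper for scalar float arguments only."""

    def partial_derivative(args: Sequence[float], index: int) -> float:
        # Perturb only the argument at `index`; all others stay fixed.
        plus, minus = list(args), list(args)
        plus[index] += step_size
        minus[index] -= step_size
        return (func(*plus) - func(*minus)) / (2 * step_size)

    def grad_fn(*args: float) -> Union[float, Tuple[float, ...]]:
        if isinstance(argnums, int):
            return partial_derivative(args, argnums)
        # Assumed convention (as in jax.grad): one derivative per index, in order.
        return tuple(partial_derivative(args, i) for i in argnums)

    return grad_fn


def f(x, y):
    return x**2 + y**2


df_dy = central_fgrad(f, argnums=1)(2.0, 3.0)  # approximately 6.0
df_dx, df_dy2 = central_fgrad(f, argnums=(0, 1))(2.0, 3.0)  # approximately (4.0, 6.0)
print(df_dy, df_dx, df_dy2)
```

For a quadratic such as f, the central difference is exact apart from floating-point rounding. This is consistent with the clean 6.0 in the example above.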

finitediffx.value_and_fgrad(func, *, argnums=0, step_size=None, offsets=Offset(accuracy=3), derivative=1, has_aux=False, average_gradients=False)

Value and finite difference derivative of a function with respect to one of its arguments.

Similar to jax.value_and_grad, but uses a finite difference approximation.

Parameters:
  • func – function to differentiate

  • argnums – index of the argument to differentiate with respect to. Defaults to 0. If a tuple is passed, the function is differentiated with respect to each argument in the tuple.

  • step_size – step size for the finite difference stencil. If None, the step size is set to 2 ** (-23 / (2 * derivative)).

  • offsets

    offsets for the finite difference stencil. Accepted types are:

    • jax.Array defining location of function evaluation points.

    • Offset with accuracy field to automatically generate offsets.

    • pytree of jax.Array/Offset with the same structure as the argument(s) selected by argnums, defining the offsets for each argument.

  • derivative – derivative order. Defaults to 1.

  • has_aux – whether the function returns an auxiliary output. Defaults to False. If True, the derivative function returns a tuple of the form ((value, aux), derivative); otherwise, it returns (value, derivative).

  • average_gradients – whether to average the array gradients, which yields faster results. Defaults to False.

Returns:

Value and derivative of the function as a tuple (value, derivative) if has_aux is False. If has_aux is True, a tuple of the form ((value, aux), derivative).
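
The two return layouts described above can be illustrated with a pure-Python sketch of a one-argument value-and-derivative wrapper based on a central difference. `value_and_central_diff` is hypothetical and only mirrors the documented output structure. It is not the finitediffx implementation, and it ignores argnums, offsets, and average_gradients.

```python
from typing import Any, Callable


def value_and_central_diff(
    func: Callable[[float], Any],
    has_aux: bool = False,
    step_size: float = 1e-3,
) -> Callable[[float], Any]:
    """Hypothetical one-argument analogue of value_and_fgrad (illustration only)."""

    def primal(x: float) -> float:
        # With has_aux=True, func returns (value, aux); only value is differentiated.
        return func(x)[0] if has_aux else func(x)

    def wrapped(x: float) -> Any:
        derivative = (primal(x + step_size) - primal(x - step_size)) / (2 * step_size)
        if has_aux:
            value, aux = func(x)
            return (value, aux), derivative  # ((value, aux), derivative)
        return func(x), derivative  # (value, derivative)

    return wrapped


def g(x: float):
    # Returns the value x**2 plus an auxiliary dict that is not differentiated.
    return x**2, {"input": x}


(value, aux), dg = value_and_central_diff(g, has_aux=True)(3.0)
plain_value, dh = value_and_central_diff(lambda x: x**2)(3.0)
print(value, aux, dg, plain_value, dh)
```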

Example

>>> import finitediffx as fdx
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x, y):
...     return x**2 + y**2
>>> df = fdx.value_and_fgrad(
...     func=f,
...     argnums=1,  # differentiate with respect to y
...     offsets=fdx.Offset(accuracy=2),  # use 2nd order accurate finite difference
... )
>>> df(2.0, 3.0)
(13.0, Array(6., dtype=float32))

class finitediffx.Offset(accuracy)

Convenience class for finite difference offsets used by fgrad() and value_and_fgrad().

Parameters:

accuracy (int) – the order of accuracy of the finite difference scheme. Must be >= 2.
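
Passing an explicit jax.Array of offsets fixes where the function is evaluated. The weights for those points can then be derived by matching Taylor-series terms (the method of undetermined coefficients). The sketch below shows that textbook technique with exact rational arithmetic, assuming offsets are measured in multiples of the step size h. `stencil_coefficients` is a hypothetical helper, and finitediffx may compute its weights differently.

```python
from fractions import Fraction
from math import factorial
from typing import List, Sequence


def stencil_coefficients(offsets: Sequence[int], derivative: int) -> List[Fraction]:
    """Weights c_j with f^(d)(x) ~= sum_j c_j * f(x + s_j * h) / h**d.

    Matches Taylor terms: sum_j c_j * s_j**k == d! if k == d else 0,
    for k = 0 .. len(offsets) - 1, solved exactly by Gauss-Jordan elimination.
    """
    n = len(offsets)
    if derivative >= n:
        raise ValueError("need more offsets than the derivative order")
    # Augmented Vandermonde system [A | b] with A[k][j] = s_j**k.
    rows = [
        [Fraction(s) ** k for s in offsets]
        + [Fraction(factorial(derivative)) if k == derivative else Fraction(0)]
        for k in range(n)
    ]
    for col in range(n):
        # Distinct offsets make the system nonsingular, so a pivot exists.
        pivot = next(r for r in range(col, n) if rows[r][col] != 0)
        rows[col], rows[pivot] = rows[pivot], rows[col]
        pivot_val = rows[col][col]
        rows[col] = [v / pivot_val for v in rows[col]]
        # Eliminate this column from every other row.
        for r in range(n):
            if r != col and rows[r][col] != 0:
                factor = rows[r][col]
                rows[r] = [a - factor * b for a, b in zip(rows[r], rows[col])]
    return [row[-1] for row in rows]


print([str(c) for c in stencil_coefficients([-1, 0, 1], 1)])  # ['-1/2', '0', '1/2']
print([str(c) for c in stencil_coefficients([-1, 0, 1], 2)])  # ['1', '-2', '1']
```

Adding more offsets lets the weights cancel more Taylor terms. This is how wider stencils reach a higher order of accuracy.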

Example

>>> import finitediffx as fdx
>>> fdx.fgrad(lambda x: x**2, offsets=fdx.Offset(accuracy=2))(1.0)
Array(2., dtype=float32)
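
The accuracy value is the order of the scheme; the fgrad example pairs Offset(accuracy=2) with a "2nd order accurate" stencil. Halving the step size should therefore shrink the truncation error by roughly 2 ** accuracy. The pure-Python check below illustrates this for two textbook central stencils applied to d/dx sin(x) at x = 1. `apply_stencil` and `observed_ratio` are hypothetical helpers, and the stencils are not necessarily the ones Offset generates.

```python
import math
from typing import Callable, Sequence


def apply_stencil(
    func: Callable[[float], float],
    x: float,
    offsets: Sequence[float],
    coeffs: Sequence[float],
    step_size: float,
) -> float:
    """First-derivative estimate: sum_j c_j * f(x + s_j * h) / h."""
    return sum(c * func(x + s * step_size) for s, c in zip(offsets, coeffs)) / step_size


def observed_ratio(offsets: Sequence[float], coeffs: Sequence[float], h: float = 0.1) -> float:
    """Error at step h divided by error at step h / 2, for d/dx sin(x) at x = 1."""
    exact = math.cos(1.0)
    err_h = abs(apply_stencil(math.sin, 1.0, offsets, coeffs, h) - exact)
    err_half = abs(apply_stencil(math.sin, 1.0, offsets, coeffs, h / 2) - exact)
    return err_h / err_half


# Textbook central stencils of order 2 and order 4.
ratio_2 = observed_ratio([-1, 0, 1], [-0.5, 0.0, 0.5])
ratio_4 = observed_ratio([-2, -1, 0, 1, 2], [1 / 12, -2 / 3, 0.0, 2 / 3, -1 / 12])
print(ratio_2, ratio_4)  # close to 2**2 = 4 and 2**4 = 16
```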