Function transformation
- finitediffx.define_fdjvp(func, offsets=Offset(accuracy=2), step_size=None)
Define the JVP rule for a function using finite difference.
- Parameters:
func – function to define the JVP rule for
offsets – offsets for the finite difference stencil. Accepted types are:
  - jax.Array defining the locations of the function evaluation points.
  - Offset with an accuracy field to generate the offsets automatically.
  - tuple of Offset or jax.Array defining the offsets for each argument.
step_size – step size for the finite difference stencil. Accepted types are:
  - float defining the step size for all arguments.
  - tuple of float defining the step size for each argument.
  - None to use the default step size for each argument.
- Returns:
function with JVP rule defined using finite difference.
- Return type:
Callable
Note
- This function is motivated by [JEP](google/jax#15425).
- This function can be used with jax.pure_callback to define the JVP rule for a function that JAX cannot differentiate.
- Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import finitediffx as fdx
>>> import functools as ft
>>> import numpy as onp
>>> def wrap_pure_callback(func):
...     @ft.wraps(func)
...     def wrapper(*args, **kwargs):
...         args = [jnp.asarray(arg) for arg in args]
...         func_ = lambda *a, **k: func(*a, **k).astype(a[0].dtype)
...         dtype_ = jax.ShapeDtypeStruct(
...             jnp.broadcast_shapes(*[ai.shape for ai in args]),
...             args[0].dtype,
...         )
...         return jax.pure_callback(func_, dtype_, *args, **kwargs, vectorized=True)
...     return wrapper
>>> @jax.jit
... @jax.grad
... @fdx.define_fdjvp
... @wrap_pure_callback
... def numpy_func(x, y):
...     return onp.sin(x) + onp.cos(y)
>>> print(numpy_func(1., 2.))
0.5402466
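The idea behind a finite difference JVP rule can be sketched with jax.custom_jvp: the rule estimates the derivative by finite differences at the primal point and multiplies it by the tangent. The helper fd_jvp_sketch below is hypothetical and is not finitediffx's implementation; it assumes a single argument, assumes func acts element-wise, and uses a fixed 3-point central stencil instead of Offset.

```python
import jax
import jax.numpy as jnp


def fd_jvp_sketch(func, step_size=1e-3):
    """Hypothetical helper (not finitediffx's implementation): give an
    element-wise, single-argument `func` a central-difference JVP rule."""

    @jax.custom_jvp
    def wrapped(x):
        return func(x)

    @wrapped.defjvp
    def wrapped_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        # Central difference evaluated at primal points only
        dfdx = (func(x + step_size) - func(x - step_size)) / (2 * step_size)
        # Scaling by the tangent keeps the rule linear in `t`,
        # so reverse mode (jax.grad) can transpose it
        return func(x), dfdx * t

    return wrapped


# Gradient of sin at 1.0, obtained through the finite difference rule
g = jax.grad(fd_jvp_sketch(jnp.sin))(1.0)
print(g)  # close to cos(1.0)
```

Evaluating func only at primal points and then scaling by t is the design choice that matters here: JAX can transpose a multiplication by a known value, but it cannot transpose an arbitrary function applied to the tangent, so a rule such as func(x + h * t) would work for forward mode only.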
- finitediffx.fgrad(func, *, argnums=0, step_size=None, offsets=Offset(accuracy=3), derivative=1, has_aux=False, average_gradients=False)
Finite difference derivative of a function with respect to one of its arguments.
Similar to jax.grad, but with a finite difference approximation.
- Parameters:
func – function to differentiate
argnums – argument number to differentiate. Defaults to 0. If a tuple is passed, the function is differentiated with respect to all the arguments in the tuple.
step_size – step size for the finite difference stencil. If None, the step size is set to (2) ** (-23 / (2 * derivative))
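offsets – offsets for the finite difference stencil. Accepted types are:
  - jax.Array defining the locations of the function evaluation points.
  - Offset with an accuracy field to generate the offsets automatically.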
derivative – derivative order. Defaults to 1.
has_aux – whether the function returns an auxiliary output. Defaults to False. If True, the derivative function returns a tuple of the form (derivative, aux); otherwise it returns only the derivative.
average_gradients – whether to average the array gradients. Yields faster results. Defaults to False.
- Returns:
Derivative of the function if has_aux is False, otherwise a tuple of the form (derivative, aux).
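The default step_size formula is a power of float32 machine epsilon, which is exactly 2**-23: the default equals eps ** (1 / (2 * derivative)), so for derivative=1 it is sqrt(eps), about 3.45e-4, and it grows with the derivative order. The snippet only evaluates the documented formula; default_step_size is a hypothetical name, not a finitediffx function.

```python
import jax.numpy as jnp


def default_step_size(derivative: int) -> float:
    # Hypothetical helper; formula quoted from the docstring: (2) ** (-23 / (2 * derivative))
    return 2.0 ** (-23 / (2 * derivative))


# float32 machine epsilon is exactly 2**-23
eps32 = float(jnp.finfo(jnp.float32).eps)

for n in (1, 2, 3):
    # The default step equals eps32 ** (1 / (2 * n)) and grows with the order n
    print(n, default_step_size(n), eps32 ** (1 / (2 * n)))
```

A larger step for higher orders is consistent with the usual observation that round-off error in a finite difference estimate is amplified by roughly 1 / step_size**derivative.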
Example
>>> import finitediffx as fdx
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x, y):
...     return x**2 + y**2
>>> df = fdx.fgrad(
...     func=f,
...     argnums=1,  # differentiate with respect to y
...     offsets=fdx.Offset(accuracy=2)  # use 2nd order accurate finite difference
... )
>>> df(2.0, 3.0)
Array(6., dtype=float32)
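The offsets of a stencil turn into a derivative estimate through a set of weights. One standard way to obtain those weights is to match Taylor-series terms, which gives a small Vandermonde-type linear system. fd_coefficients and fd_derivative below are hypothetical helpers that sketch this technique for a scalar argument; finitediffx may compute its weights differently.

```python
import math

import jax.numpy as jnp
import numpy as np


def fd_coefficients(offsets, derivative):
    """Weights c_j such that
    f^(derivative)(x) ~ sum_j c_j * f(x + o_j * h) / h**derivative."""
    offsets = np.asarray(offsets, dtype=np.float64)
    n = len(offsets)
    # Row k is the k-th Taylor term of every stencil point: o_j**k / k!
    taylor = np.array([offsets**k / math.factorial(k) for k in range(n)])
    rhs = np.zeros(n)
    rhs[derivative] = 1.0  # keep only the requested derivative
    return np.linalg.solve(taylor, rhs)


def fd_derivative(func, x, offsets, derivative=1, step_size=1e-3):
    coeffs = fd_coefficients(offsets, derivative)
    # Evaluate func at every stencil point x + o * step_size
    samples = jnp.stack([jnp.asarray(func(x + o * step_size)) for o in offsets])
    weighted = jnp.tensordot(jnp.asarray(coeffs, dtype=samples.dtype), samples, axes=1)
    return weighted / step_size**derivative


print(fd_coefficients([-1, 0, 1], 1))  # central first-derivative weights
# d/dy of x**2 + y**2 at (x, y) = (2, 3); the exact value is 6
print(fd_derivative(lambda y: 2.0**2 + y**2, 3.0, [-1, 0, 1]))
```

With offsets (-1, 0, 1) the solver recovers the familiar weights (-1/2, 0, 1/2) for the first derivative and (1, -2, 1) for the second.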
- finitediffx.value_and_fgrad(func, *, argnums=0, step_size=None, offsets=Offset(accuracy=3), derivative=1, has_aux=False, average_gradients=False)
Value and finite difference derivative of a function with respect to one of its arguments.
Similar to jax.value_and_grad, but with a finite difference approximation.
- Parameters:
func – function to differentiate
argnums – argument number to differentiate. Defaults to 0. If a tuple is passed, the function is differentiated with respect to all the arguments in the tuple.
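step_size – step size for the finite difference stencil. If None, the step size is set to (2) ** (-23 / (2 * derivative))
offsets – offsets for the finite difference stencil. Accepted types are:
  - jax.Array defining the locations of the function evaluation points.
  - Offset with an accuracy field to generate the offsets automatically.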
derivative – derivative order. Defaults to 1.
has_aux – whether the function returns an auxiliary output. Defaults to False. If True, the derivative function returns a tuple of the form ((value, aux), derivative); otherwise it returns (value, derivative).
average_gradients – whether to average the array gradients. Yields faster results. Defaults to False.
- Returns:
Value and derivative of the function as a tuple (value, derivative) if has_aux is False. If has_aux is True, a tuple of the form ((value, aux), derivative).
Example
>>> import finitediffx as fdx
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x, y):
...     return x**2 + y**2
>>> df = fdx.value_and_fgrad(
...     func=f,
...     argnums=1,  # differentiate with respect to y
...     offsets=fdx.Offset(accuracy=2)  # use 2nd order accurate finite difference
... )
>>> df(2.0, 3.0)
(13.0, Array(6., dtype=float32))
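The documented return structure can be mirrored with a small central-difference helper. value_and_central_diff is hypothetical: it differentiates a single scalar argument, uses a fixed 3-point central stencil instead of Offset, and ignores average_gradients. It only illustrates how (value, derivative) and ((value, aux), derivative) are assembled, with the derivative taken of the first output when has_aux is True.

```python
def value_and_central_diff(func, argnum=0, step_size=1e-3, has_aux=False):
    """Hypothetical helper: return (value, derivative) or
    ((value, aux), derivative), mirroring value_and_fgrad's documented output."""

    def wrapper(*args):
        def primary_output(delta):
            # Shift only the differentiated argument by `delta`
            shifted = list(args)
            shifted[argnum] = shifted[argnum] + delta
            out = func(*shifted)
            # With has_aux=True, func returns (value, aux); differentiate value only
            return out[0] if has_aux else out

        out = func(*args)  # value, or (value, aux) when has_aux=True
        derivative = (primary_output(step_size) - primary_output(-step_size)) / (
            2 * step_size
        )
        return out, derivative

    return wrapper


def f(x, y):
    return x**2 + y**2, "aux"


# Unpacks as ((value, aux), derivative) because has_aux=True
(value, aux), df_dy = value_and_central_diff(f, argnum=1, has_aux=True)(2.0, 3.0)
print(value, aux, df_dy)
```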
- class finitediffx.Offset(accuracy)
Convenience class for finite difference offsets used inside fgrad() and value_and_fgrad().
- Parameters:
accuracy (int) – The accuracy of the finite difference scheme. Must be >= 2.
Example
>>> import finitediffx as fdx
>>> fdx.fgrad(lambda x: x**2, offsets=fdx.Offset(accuracy=2))(1.0)
Array(2., dtype=float32)
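Assuming accuracy refers to the order of the truncation error, as it does in common finite difference references, the meaning of the order can be checked numerically: halving the step size should shrink the error by about 2**order. The two stencils below are textbook central differences chosen for illustration; they are not necessarily the stencils that Offset generates.

```python
import numpy as np


def central_2nd(f, x, h):
    # Offsets (-1, 0, 1): truncation error ~ h**2 * f'''(x) / 6
    return (f(x + h) - f(x - h)) / (2 * h)


def central_4th(f, x, h):
    # Offsets (-2, -1, 0, 1, 2): truncation error ~ h**4 * f'''''(x) / 30
    return (-f(x + 2 * h) + 8 * f(x + h) - 8 * f(x - h) + f(x - 2 * h)) / (12 * h)


x, exact = 1.0, np.cos(1.0)
steps = (1e-2, 5e-3)  # halve the step size

err2 = [abs(central_2nd(np.sin, x, h) - exact) for h in steps]
err4 = [abs(central_4th(np.sin, x, h) - exact) for h in steps]

ratio2 = err2[0] / err2[1]  # roughly 4 for a 2nd-order scheme
ratio4 = err4[0] / err4[1]  # roughly 16 for a 4th-order scheme
print(ratio2, ratio4)
```

The comparison runs in float64 so that round-off stays well below the truncation error; in float32 the higher-order stencil would hit round-off much sooner.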