# Copyright 2023 FiniteDiffX authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses as dc
import functools as ft
from typing import Any, Callable, Sequence, TypeVar, Union
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from typing_extensions import ParamSpec
from finitediffx._src.utils import _generate_central_offsets, generate_finitediff_coeffs
# Names exported as the public API of this module.
__all__ = ("fgrad", "Offset", "define_fdjvp", "value_and_fgrad")
# Typing helpers: ``P`` and ``T`` stand for the parameters and the return type
# of a wrapped callable in ``Callable[P, T]`` annotations.
P = ParamSpec("P")
T = TypeVar("T")
# A pytree can be any nested container, so it is aliased to ``Any``.
PyTree = Any
@dc.dataclass(frozen=True)
class Offset:
    """Convenience class for finite difference offsets used inside :func:`.fgrad`
    and :func:`.value_and_fgrad`.

    Instances are immutable (frozen dataclass) and only carry the requested
    accuracy; the actual stencil offsets are generated later, when the offsets
    are resolved against the differentiated argument.

    Args:
        accuracy: The accuracy of the finite difference scheme. Must be >=2.
            The bound is not checked at construction time.

    Example:
        >>> import finitediffx as fdx
        >>> fdx.fgrad(lambda x: x**2, offsets=fdx.Offset(accuracy=2))(1.0)
        Array(2., dtype=float32)
    """

    accuracy: int
# Accepted offset specifications: an explicit array of stencil offsets, an
# ``Offset`` accuracy request, or a pytree of those.
OffsetType = Union[jax.Array, Offset, PyTree]
# Accepted step size specifications: an array, a float, or a pytree of those.
StepsizeType = Union[jax.Array, float, PyTree]
def resolve_step_size(
    step_size: StepsizeType | Sequence[StepsizeType] | None,
    treedef: jtu.PyTreeDef,
    derivative: int,
) -> Sequence[StepsizeType] | StepsizeType:
    """Expand ``step_size`` into one step size per leaf of ``treedef``.

    Args:
        step_size: One of

            - ``None``: every leaf gets ``2 ** (-23 / (2 * derivative))``.
            - a single scalar or array (``float``, ``int``, NumPy scalar,
              ``numpy.ndarray`` or ``jax.Array``): the same value is used for
              every leaf.
            - a pytree whose structure equals ``treedef``: its leaves are used
              in flattening order.
        treedef: tree structure of the argument being differentiated.
        derivative: derivative order, only used to pick the default step size.

    Returns:
        A sequence of ``treedef.num_leaves`` step sizes (a tuple for the
        broadcast cases, the list of flattened leaves for the pytree case).

    Raises:
        TypeError: if ``step_size`` is neither a single value nor a pytree
            with the same structure as ``treedef``.
    """
    length = treedef.num_leaves

    if step_size is None:
        # the default shrinks more slowly for higher derivative orders because
        # the step is raised to the power ``derivative`` in the denominator
        default = (2) ** (-23 / (2 * derivative))
        return (default,) * length

    # A single value is broadcast to every leaf. ``int`` and NumPy values are
    # accepted here too; otherwise they would only work for single-leaf inputs
    # (as a trivially matching pytree) and raise for anything larger.
    if isinstance(step_size, (jax.Array, np.ndarray, np.number, float, int)):
        return (step_size,) * length

    step_size_leaves, step_size_treedef = jtu.tree_flatten(step_size)

    if step_size_treedef == treedef:
        # step_size is a pytree with the same structure as the input
        return step_size_leaves

    raise TypeError(
        f"`step_size` must be of type:\n"
        f"- `jax.Array`\n"
        f"- `float`\n"
        f"- tuple of `jax.Array` or `float` for tuple argnums.\n"
        f"- pytree with the same structure as the desired arg.\n"
        f"but got {type(step_size)=}"
    )
def resolve_offsets(
    offsets: Sequence[OffsetType | None] | OffsetType | None,
    treedef: jax.tree_util.PyTreeDef,
    derivative: int,
) -> tuple[OffsetType, ...] | OffsetType:
    """Expand ``offsets`` into one array of stencil offsets per leaf of ``treedef``.

    Args:
        offsets: One of

            - :class:`Offset`: central offsets are generated for ``derivative``
              and the requested accuracy, then used for every leaf.
            - ``jax.Array`` / ``numpy.ndarray``: used as-is for every leaf.
            - a pytree of the above whose structure equals ``treedef``; every
              :class:`Offset` leaf is converted to an array of offsets.
        treedef: tree structure of the argument being differentiated.
        derivative: derivative order used when generating central offsets.

    Returns:
        A sequence of ``treedef.num_leaves`` offset arrays (a tuple for the
        broadcast cases, a list for the pytree case).

    Raises:
        ValueError: if an :class:`Offset` has ``accuracy < 2``.
        TypeError: if ``offsets`` is neither a single value nor a pytree with
            the same structure as ``treedef``.
    """

    def as_offset_array(offset):
        # turn an ``Offset`` request into concrete stencil offsets and leave
        # explicit arrays untouched
        if isinstance(offset, Offset):
            if offset.accuracy < 2:
                raise ValueError(f"offset accuracy must be >=2, got {offset.accuracy}")
            return jnp.array(_generate_central_offsets(derivative, offset.accuracy))
        return offset

    length = treedef.num_leaves

    # single value: broadcast to every leaf
    if isinstance(offsets, (Offset, jax.Array, np.ndarray)):
        return (as_offset_array(offsets),) * length

    # ``Offset`` is not a registered pytree node, so it shows up as a leaf here
    offsets_leaves, offsets_treedef = jtu.tree_flatten(offsets)

    if offsets_treedef == treedef:
        # offsets is a pytree with the same structure as the input; resolve
        # ``Offset`` leaves so callers always receive arrays of offsets
        return [as_offset_array(leaf) for leaf in offsets_leaves]

    raise TypeError(
        f"`offsets` must be of type:\n"
        f"- `Offset`\n"
        f"- `jax.Array`\n"
        f"- tuple of `Offset` or `jax.Array` for tuple argnums.\n"
        f"- pytree with the same structure as the desired arg.\n"
        f"but got {type(offsets)=}"
    )
def _perturb_flat_args(
*,
flat_func: Callable,
coeffs: jax.Array,
flat_offsets: jax.Array,
flat_argnum: int,
flat_step_size: jax.Array,
derivative: int,
average_gradients: bool = False,
):
def flat_args_wrapper(*flat_args):
def scalar_perturb(*, h: float):
return flat_func(
*(
flat_args[:flat_argnum]
+ (flat_args[flat_argnum] + h,)
+ flat_args[flat_argnum + 1 :]
)
)
def array_perturb(*, h: float) -> jax.Array:
# should be much slower than jax.grad for large arrays
# but can be used for non-tracable code where jax.grad fails
size = flat_args[flat_argnum].size
indices = jnp.arange(size)
shape = flat_args[flat_argnum].shape
flat_array = jnp.array(flat_args[flat_argnum].reshape(-1))
def perturb_element(index):
return flat_func(
*(
flat_args[:flat_argnum]
+ (flat_array.at[index].add(h).reshape(shape),)
+ flat_args[flat_argnum + 1 :]
)
)
try:
# in case of tracable code (jax code)
result = jax.vmap(perturb_element)(indices)
except jax.errors.TracerArrayConversionError:
# non-tracable code e.g. numpy code
result = jnp.array([perturb_element(index) for index in indices])
if result.size > size:
raise TypeError("Non scalar-output.")
return result.reshape(shape)
def array_average_perturb(*, h: float) -> jax.Array:
# perturb the array all at once and average the result
# faster than array_perturb for large arrays but gives
# average gradient
shape = flat_args[flat_argnum].shape
size = flat_args[flat_argnum].size
result = flat_func(
*(
flat_args[:flat_argnum]
+ ((flat_args[flat_argnum] + h),)
+ flat_args[flat_argnum + 1 :]
)
)
if result.size > size:
raise TypeError("Non scalar-output.")
result = jnp.broadcast_to(result, shape)
result = result / result.size
return result
perturb_func = (
(array_average_perturb if average_gradients else array_perturb)
if isinstance(flat_args[flat_argnum], (np.ndarray, jax.Array))
else scalar_perturb
)
return sum(
coeff * perturb_func(h=dx) / flat_step_size**derivative
for coeff, dx in zip(coeffs, flat_offsets * flat_step_size)
)
return flat_args_wrapper
def _fgrad_along_argnum(
    func: Callable,
    *,
    argnum: int = 0,
    step_size: StepsizeType | None = None,
    offsets: OffsetType = Offset(accuracy=3),
    derivative: int = 1,
    average_gradients: bool = False,
):
    """Finite difference derivative of ``func`` with respect to a single argument.

    The argument at position ``argnum`` may be a pytree: it is flattened, the
    derivative is approximated leaf by leaf, and the per-leaf results are put
    back into the structure of that argument.

    Args:
        func: function to differentiate.
        argnum: position of the argument to differentiate.
        step_size: step size specification, resolved per leaf.
        offsets: offsets specification, resolved per leaf.
        derivative: derivative order.
        average_gradients: forwarded to the per-leaf perturbation.

    Returns:
        A function with the call signature of ``func`` returning a pytree of
        derivatives shaped like ``args[argnum]``.
    """

    def wrapper(*args, **kwargs):
        leaves, structure = jtu.tree_flatten(args[argnum])
        leading = args[:argnum]
        trailing = args[argnum + 1 :]

        def flat_func(*new_leaves):
            # rebuild the differentiated argument and call ``func`` with the
            # remaining positional/keyword arguments unchanged
            rebuilt = jtu.tree_unflatten(structure, new_leaves)
            return func(*leading, rebuilt, *trailing, **kwargs)

        leaf_steps = resolve_step_size(step_size, structure, derivative)
        leaf_offsets = resolve_offsets(offsets, structure, derivative)

        leaf_derivatives = []
        for leaf_index, (oi, si) in enumerate(zip(leaf_offsets, leaf_steps)):
            leaf_derivative_fn = _perturb_flat_args(
                flat_func=flat_func,
                coeffs=generate_finitediff_coeffs(oi, derivative),
                flat_offsets=oi,
                flat_step_size=si,
                flat_argnum=leaf_index,
                derivative=derivative,
                average_gradients=average_gradients,
            )
            leaf_derivatives.append(leaf_derivative_fn(*leaves))

        return jtu.tree_unflatten(structure, leaf_derivatives)

    return wrapper
def value_and_fgrad(
    func: Callable[P, T],
    *,
    argnums: int | tuple[int, ...] = 0,
    step_size: StepsizeType | None = None,
    offsets: OffsetType = Offset(accuracy=3),
    derivative: int = 1,
    has_aux: bool = False,
    average_gradients: bool = False,
):
    """Finite difference derivative of a function with respect to one of its arguments.

    Similar to ``jax.value_and_grad`` but with finite difference approximation.
    The value is obtained from one extra call to ``func``; the derivative calls
    ``func`` once per stencil point.

    Args:
        func: function to differentiate. It is not modified; the generated
            docstring is attached to the returned function only.
        argnums: argument number to differentiate. Defaults to 0.
            If a tuple is passed, the function is differentiated with respect to
            all the arguments in the tuple.
        step_size: step size for the finite difference stencil. If `None`, the step size
            is set to ``(2) ** (-23 / (2 * derivative))``. For tuple ``argnums``
            a tuple of the same length gives one step size per argument.
        offsets: offsets for the finite difference stencil. Accepted types are:

            - ``jax.Array`` defining location of function evaluation points.
            - :class:`Offset` with accuracy field to automatically generate offsets.
            - pytree of ``jax.Array``/ :class:`.Offset` to define offsets for
              each argument of the same pytree structure as argument defined
              by ``argnums``.

            For tuple ``argnums`` a tuple of the same length gives one offsets
            specification per argument.
        derivative: derivative order. Defaults to 1.
        has_aux: whether the function returns an auxiliary output. Defaults to
            ``False``. If True, the derivative function will return a tuple of
            the form: ((value,aux), derivative) otherwise (value, derivative)
        average_gradients: whether to average the array gradients. Yields faster
            results. Defaults to ``False``.

    Returns:
        Value and derivative of the function if ``has_aux`` is ``False``.
        If ``has_aux`` is True, the derivative function will return a tuple of the form:
        ((value,aux), derivative)

    Raises:
        TypeError: if ``has_aux`` is not a ``bool`` or ``argnums`` is not an
            ``int`` or a tuple of ``int``.
        AssertionError: if ``argnums`` is a tuple and a tuple ``offsets`` or
            ``step_size`` has a different length.

    Example:
        >>> import finitediffx as fdx
        >>> import jax
        >>> import jax.numpy as jnp
        >>> def f(x, y):
        ...    return x**2 + y**2
        >>> df=fdx.value_and_fgrad(
        ...    func=f,
        ...    argnums=1,  # differentiate with respect to y
        ...    offsets=fdx.Offset(accuracy=2)  # use 2nd order accurate finite difference
        ... )
        >>> df(2.0, 3.0)
        (13.0, Array(6., dtype=float32))
    """
    if not isinstance(has_aux, bool):
        raise TypeError(f"{type(has_aux)} not a bool")

    # Docstring of the returned function. It is set on the wrapper instead of
    # on ``func``: writing to ``func.__doc__`` would change the caller's object
    # on every call and fails for callables with a read-only ``__doc__``.
    wrapper_doc = (
        f"Finite difference derivative of {getattr(func,'__name__', func)}"
        f" w.r.t {argnums=}\n\n{func.__doc__}"
    )

    # the derivative is taken of the primary output only
    func_ = (lambda *a, **k: func(*a, **k)[0]) if has_aux else func

    if isinstance(argnums, int):
        # fgrad(func, argnums=0)
        dfunc = _fgrad_along_argnum(
            func=func_,
            argnum=argnums,
            step_size=step_size,
            offsets=offsets,
            derivative=derivative,
            average_gradients=average_gradients,
        )

        def derivative_func(*a, **k):
            return dfunc(*a, **k)

    elif isinstance(argnums, tuple):
        # fgrad(func, argnums=(0,1))
        # return a tuple of derivatives evaluated at each argnum
        # this behavior is similar to jax.grad(func, argnums=(...))
        if not all(isinstance(ai, int) for ai in argnums):
            raise TypeError(f"{argnums=} must be an integer or a tuple of integers")

        if isinstance(offsets, tuple):
            if len(offsets) != len(argnums):
                raise AssertionError("offsets must have the same length as argnums")
        else:
            offsets = (offsets,) * len(argnums)

        if isinstance(step_size, tuple):
            if len(step_size) != len(argnums):
                raise AssertionError("step_size must have the same length as argnums")
        else:
            step_size = (step_size,) * len(argnums)

        dfuncs = [
            _fgrad_along_argnum(
                func=func_,
                argnum=ai,
                step_size=si,
                offsets=oi,
                derivative=derivative,
                average_gradients=average_gradients,
            )
            for oi, si, ai in zip(offsets, step_size, argnums)
        ]

        def derivative_func(*a, **k):
            return tuple(df(*a, **k) for df in dfuncs)

    else:
        raise TypeError(f"argnums must be an int or a tuple of ints, got {argnums}")

    @ft.wraps(func)
    def wrapper(*a, **k):
        if has_aux:
            # destructuring the tuple to ensure
            # two item tuple is returned
            value, aux = func(*a, **k)
            return (value, aux), derivative_func(*a, **k)
        return func(*a, **k), derivative_func(*a, **k)

    wrapper.__doc__ = wrapper_doc
    return wrapper
def fgrad(
    func: Callable,
    *,
    argnums: int | tuple[int, ...] = 0,
    step_size: StepsizeType | None = None,
    offsets: OffsetType = Offset(accuracy=3),
    derivative: int = 1,
    has_aux: bool = False,
    average_gradients: bool = False,
) -> Callable:
    """Finite difference derivative of a function with respect to one of its arguments.

    Similar to ``jax.grad`` but with finite difference approximation. This is a
    thin wrapper over :func:`value_and_fgrad` that drops the function value.

    Args:
        func: function to differentiate
        argnums: argument number to differentiate. Defaults to 0.
            If a tuple is passed, the function is differentiated with respect to
            all the arguments in the tuple.
        step_size: step size for the finite difference stencil. If `None`, the step size
            is set to `(2) ** (-23 / (2 * derivative))`
        offsets: offsets for the finite difference stencil. Accepted types are:

            - ``jax.Array`` defining location of function evaluation points.
            - :class:`.Offset` with accuracy field to automatically generate offsets.
            - pytree of ``jax.Array``/:class:`.Offset` to define offsets for
              each argument of the same pytree structure as argument defined
              by ``argnums``.

        derivative: derivative order. Defaults to 1.
        has_aux: whether the function returns an auxiliary output. Defaults to
            ``False``. If ``True``, the derivative function will return a tuple
            of the form: (derivative, aux) otherwise it will return only the derivative.
        average_gradients: whether to average the array gradients. Yields faster
            results. Defaults to ``False``.

    Returns:
        Derivative of the function if ``has_aux`` is ``False``, otherwise a tuple of
        the form: (derivative, aux)

    Example:
        >>> import finitediffx as fdx
        >>> import jax
        >>> import jax.numpy as jnp
        >>> def f(x, y):
        ...    return x**2 + y**2
        >>> df=fdx.fgrad(
        ...    func=f,
        ...    argnums=1,  # differentiate with respect to y
        ...    offsets=fdx.Offset(accuracy=2)  # use 2nd order accurate finite difference
        ... )
        >>> df(2.0, 3.0)
        Array(6., dtype=float32)
    """
    # argument validation (``has_aux``, ``argnums``, ...) happens in here
    value_and_fgrad_func = value_and_fgrad(
        func=func,
        argnums=argnums,
        step_size=step_size,
        offsets=offsets,
        derivative=derivative,
        has_aux=has_aux,
        average_gradients=average_gradients,
    )

    @ft.wraps(func)
    def wrapper(*a, **k):
        if has_aux:
            (_, aux), g = value_and_fgrad_func(*a, **k)
            return g, aux
        _, g = value_and_fgrad_func(*a, **k)
        return g

    # Take the generated docstring from the ``value_and_fgrad`` wrapper
    # explicitly, so it does not depend on what ``func.__doc__`` holds.
    wrapper.__doc__ = value_and_fgrad_func.__doc__
    return wrapper
def define_fdjvp(
    func: Callable[P, T],
    offsets: tuple[OffsetType, ...] | OffsetType = Offset(accuracy=2),
    step_size: tuple[float, ...] | float | None = None,
) -> Callable[P, T]:
    """Attach a finite difference JVP rule to a function.

    ``func`` is wrapped with ``jax.custom_jvp``; its tangent output is the sum
    over all positional arguments of the finite difference derivative along
    that argument times the matching input tangent.

    Args:
        func: function to define the JVP rule for
        offsets: offsets for the finite difference stencil. Accepted types are:

            - ``jax.Array`` defining location of function evaluation points.
            - :class:`.Offset` with accuracy field to automatically generate offsets.
            - tuple of `Offset` or ``jax.Array`` defining offsets for each argument.

        step_size: step size for the finite difference stencil. Accepted types are:

            - ``float`` defining the step size for all arguments.
            - ``tuple`` of ``float`` defining the step size for each argument.
            - ``None`` to use the default step size for each argument.

    Returns:
        Callable: function with JVP rule defined using finite difference.

    Note:
        - This function is motivated by [``JEP``](https://github.com/google/jax/issues/15425)
        - This function can be used with ``jax.pure_callback`` to define the JVP
          rule for a function that is not differentiable by ``jax``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import finitediffx as fdx
        >>> import functools as ft
        >>> import numpy as onp
        >>> def wrap_pure_callback(func):
        ...     @ft.wraps(func)
        ...     def wrapper(*args, **kwargs):
        ...         args = [jnp.asarray(arg) for arg in args]
        ...         func_ = lambda *a, **k: func(*a, **k).astype(a[0].dtype)
        ...         dtype_ = jax.ShapeDtypeStruct(
        ...             jnp.broadcast_shapes(*[ai.shape for ai in args]),
        ...             args[0].dtype,
        ...         )
        ...         return jax.pure_callback(func_, dtype_, *args, **kwargs, vectorized=True)
        ...     return wrapper
        >>> @jax.jit
        ... @jax.grad
        ... @fdx.define_fdjvp
        ... @wrap_pure_callback
        ... def numpy_func(x, y):
        ...     return onp.sin(x) + onp.cos(y)
        >>> print(numpy_func(1., 2.))
        0.5402466
    """
    func = jax.custom_jvp(func)

    def finite_difference_jvp(primals, tangents):
        # one step size / offsets entry per positional primal
        primals_structure = jtu.tree_structure(primals)
        step_size_ = resolve_step_size(
            step_size, treedef=primals_structure, derivative=1
        )
        offsets_ = resolve_offsets(offsets, treedef=primals_structure, derivative=1)

        primal_out = func(*primals)

        # directional derivative: sum of per-argument derivative times tangent
        tangent_out = sum(
            fgrad(func, argnums=index, step_size=si, offsets=oi)(*primals) * ti
            for index, (si, oi, ti) in enumerate(zip(step_size_, offsets_, tangents))
        )
        return jnp.array(primal_out), jnp.array(tangent_out)

    func.defjvp(finite_difference_jvp)
    return func