Array finite difference
- finitediffx.curl(array, *, accuracy=1, step_size=1, method='central', keepdims=True)
Compute the curl ∇×F of the input array, where F is a vector field whose components are stacked along the first axis of the input. Returns a vector field.
Index notation: (∇×F)_i = ε_ijk dF_k/dx_j
- Parameters:
array – input array whose leading axis indexes the components of the vector field
accuracy – accuracy order of the derivatives. Default is 1; a tuple gives one value per axis
step_size – grid step size. Default is 1; a tuple gives one value per axis
method – the method to use (forward, central, backward). Default is central
keepdims – whether to keep the leading (component) axis of the result. Only applies to 2D vector fields, whose curl has a single component. Default is True
Example
>>> # Curl of a 3D vector field is defined as
>>> # F = (F1, F2, F3)
>>> # ∇×F = (dF3/dy - dF2/dz, dF1/dz - dF3/dx, dF2/dx - dF1/dy)
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import numpy.testing as npt
>>> import jax
>>> with jax.experimental.enable_x64():
...     x, y, z = [jnp.linspace(0, 1, 100)] * 3
...     dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
...     X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
...     F1 = X**2 + Y**3
...     F2 = X**4 + Y**3
...     F3 = jnp.zeros_like(F1)
...     F = jnp.stack([F1, F2, F3], axis=0)
...     curlF = fdx.curl(F, step_size=(dx, dy, dz), accuracy=6)
...     curlF_exact = jnp.stack([F1 * 0, F1 * 0, 4 * X**3 - 3 * Y**2], axis=0)
...     npt.assert_allclose(curlF, curlF_exact, atol=1e-7)
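>>> # Curl of a 2D vector field F = (F1, F2) is the scalar field
>>> # ∇×F = dF2/dx - dF1/dy
>>> x, y = [jnp.linspace(-1, 1, 50)] * 2
>>> dx, dy = x[1] - x[0], y[1] - y[0]
>>> X, Y = jnp.meshgrid(x, y, indexing="ij")
>>> F1 = jnp.sin(Y)
>>> F2 = jnp.cos(X)
>>> F = jnp.stack([F1, F2], axis=0)
>>> # a scalar step_size is used for both axes here because dx == dy
>>> curl = fdx.curl(F, accuracy=4, step_size=dx)  # exact result: -jnp.sin(X) - jnp.cos(Y)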
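As a cross-check of the index-notation definition, the sketch below computes the curl in plain NumPy with np.gradient, which uses second-order central differences inside the domain and first-order one-sided differences at the edges. It is an illustration under those assumptions, not finitediffx's implementation: the name curl_sketch is hypothetical, it has no accuracy parameter, and the size-1 leading axis it keeps for 2D fields is only assumed to mirror keepdims=True.

```python
import numpy as np


def curl_sketch(F, spacing):
    """Curl of a vector field whose components are stacked along axis 0.

    Hypothetical helper. F has shape (2, nx, ny) or (3, nx, ny, nz); spacing
    holds one grid step per spatial axis. Derivatives come from np.gradient
    (second-order central inside, first-order one-sided at the edges).
    """
    n = F.shape[0]
    # dF[k][j] approximates dF_k/dx_j
    dF = [np.gradient(F[k], *spacing) for k in range(n)]
    if n == 2:
        # Scalar curl dF2/dx - dF1/dy, kept with a leading axis of size 1
        return (dF[1][0] - dF[0][1])[None]
    if n == 3:
        return np.stack([
            dF[2][1] - dF[1][2],  # dF3/dy - dF2/dz
            dF[0][2] - dF[2][0],  # dF1/dz - dF3/dx
            dF[1][0] - dF[0][1],  # dF2/dx - dF1/dy
        ])
    raise ValueError("this sketch handles only 2D and 3D vector fields")


# Rigid rotation F = (-y, x): its curl is 2 everywhere
x = np.linspace(-1, 1, 41)
h = x[1] - x[0]
X, Y = np.meshgrid(x, x, indexing="ij")
print(curl_sketch(np.stack([-Y, X]), (h, h)).shape)  # (1, 41, 41)
```

For the rigid rotation every derivative is taken of a linear function, so np.gradient reproduces the curl up to rounding, even at the edges.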
- finitediffx.difference(array, *, axis=0, accuracy=1, step_size=1.0, derivative=1, method='central')
Compute the finite difference derivative along a given axis with a given accuracy. With method='central', central differences are used for the interior points and forward/backward differences for the boundary points. Similar to np.gradient, but with the option to specify the accuracy, the derivative order, and the step size. See: google/jax
- Parameters:
array (jax.Array) – input array
axis (int) – axis along which to compute the derivative. Default is 0
accuracy (int) – accuracy order of the derivative. Default is 1
step_size (float | jax.Array) – step size. Default is 1.0
derivative (int) – order of the derivative to compute (1 for the first derivative, 2 for the second, and so on). Default is 1
method (MethodKind) – the method to use (forward, central, backward). Default is central
- Return type:
jax.Array
- Returns:
Finite difference derivative along the given axis
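The accuracy and derivative parameters together determine a finite-difference stencil. One standard way to obtain stencil weights for arbitrary offsets is to solve the Taylor-series (Vandermonde) system, shown below in NumPy. Whether finitediffx derives its weights exactly this way is an assumption, and fd_coefficients is a hypothetical helper; in general, wider stencils give higher-order approximations.

```python
import math

import numpy as np


def fd_coefficients(offsets, derivative):
    """Hypothetical helper: weights w such that
    sum_k w[k] * f(x + offsets[k] * h) / h**derivative
    approximates the derivative-th derivative of f at x."""
    s = np.asarray(offsets, dtype=float)
    n = len(s)
    if derivative >= n:
        raise ValueError("need more stencil points than the derivative order")
    # Row p is the Taylor condition sum_k w[k] * s[k]**p = p! * (p == derivative)
    A = np.vander(s, n, increasing=True).T
    b = np.zeros(n)
    b[derivative] = math.factorial(derivative)
    return np.linalg.solve(A, b)


print(fd_coefficients([-1, 0, 1], 1))  # central first derivative, ≈ [-1/2, 0, 1/2]
print(fd_coefficients([0, 1], 1))      # forward first derivative, ≈ [-1, 1]
print(fd_coefficients([-1, 0, 1], 2))  # central second derivative, ≈ [1, -2, 1]

# Apply a 5-point central stencil to sin at x = 0.3 (exact value: cos(0.3))
h, offsets = 0.1, [-2, -1, 0, 1, 2]
w = fd_coefficients(offsets, 1)
approx = sum(wk * np.sin(0.3 + sk * h) for wk, sk in zip(w, offsets)) / h
```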
Example
>>> # Mixed derivative d2F/dxdy of a 2D array
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import jax
>>> with jax.experimental.enable_x64():
...     x, y = [jnp.linspace(0, 1, 100)] * 2
...     dx, dy = x[1] - x[0], y[1] - y[0]
...     X, Y = jnp.meshgrid(x, y, indexing="ij")
...     F = jnp.sin(X) * jnp.cos(Y)
...     dFdX = fdx.difference(F, step_size=dx, axis=0, accuracy=3, method="central")
...     dFdXdY = fdx.difference(dFdX, step_size=dy, axis=1, accuracy=3, method="central")
...     # 1D finite difference derivative
...     x = jnp.array([1.2, 1.3, 2.2, 3., 4.5, 5.5, 6., 7., 8., 20.])
...     print(fdx.difference(x, accuracy=1))
[ 0.1 0.5 0.85 1.15 1.25 0.75 0.75 1. 6.5 12. ]
Note
Handling of boundary points is done by applying forward/backward difference to the first/last element and central difference to the interior elements. For the previous example, the following steps are performed:
Apply forward difference to the first element with accuracy 1
x_1 = 1.3-1.2 = 0.1
Apply central difference to interior elements with accuracy 2
x_2 = (2.2-1.2)/2 = 0.5
x_3 = (3.-1.3)/2 = 0.85
x_4 = (4.5-2.2)/2 = 1.15
x_5 = (5.5-3.)/2 = 1.25
x_6 = (6.-4.5)/2 = 0.75
x_7 = (7.-5.5)/2 = 0.75
x_8 = (8.-6.)/2 = 1.
x_9 = (20.-7.)/2 = 6.5
Apply backward difference to the last element with accuracy 1
x_10 = 20.-8. = 12.
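The accuracy-1 scheme from the note can be reproduced in a few lines of NumPy to check the arithmetic: a forward difference at the first point, central differences (x[i+1] - x[i-1])/2 inside, and a backward difference at the last point, all divided by the step size. difference_sketch is a hypothetical helper that covers only this first-derivative, uniform-step case; it is not finitediffx's implementation.

```python
import numpy as np


def difference_sketch(x, step_size=1.0):
    """Hypothetical helper: first derivative of a 1D array with a forward
    difference at the first point, central differences inside, and a
    backward difference at the last point."""
    x = np.asarray(x, dtype=float)
    first = x[1] - x[0]                # x_1 in the note
    interior = (x[2:] - x[:-2]) / 2.0  # x_2 ... x_9 in the note
    last = x[-1] - x[-2]               # x_10 in the note
    return np.concatenate([[first], interior, [last]]) / step_size


x = [1.2, 1.3, 2.2, 3., 4.5, 5.5, 6., 7., 8., 20.]
d = difference_sketch(x)
# Same scheme as np.gradient with its default edge_order=1
print(np.allclose(d, np.gradient(x)))  # True
```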
- finitediffx.divergence(array, *, accuracy=1, step_size=1, keepdims=True, method='central')
Compute the divergence ∇·F of the input array, where F is a vector field whose components are stacked along the first axis of the input. Returns a scalar field.
- Parameters:
array – input array whose leading axis indexes the components of the vector field
accuracy – accuracy order of the derivatives. Default is 1; a tuple gives one value per axis
step_size – grid step size. Default is 1; a tuple gives one value per axis
method – the method to use (forward, central, backward). Default is central
keepdims – whether to keep the leading (size-1) axis in the returned scalar field. Default is True
Index notation: dF_i/dx_i (summed over i)
Example
>>> # ∇·F of a 2D vector field
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import numpy.testing as npt
>>> x, y = [jnp.linspace(0, 1, 100)] * 2
>>> dx, dy = x[1] - x[0], y[1] - y[0]
>>> X, Y = jnp.meshgrid(x, y, indexing="ij")
>>> F1 = X**2 + Y**3
>>> F2 = X**4 + Y**3
>>> F = jnp.stack([F1, F2], axis=0)  # 2D vector field F = (F1, F2)
>>> divZ = fdx.divergence(F, step_size=(dx, dy), accuracy=7, keepdims=False)
>>> divZ_true = 2 * X + 3 * Y**2  # dF1/dx + dF2/dy
>>> npt.assert_allclose(divZ, divZ_true, atol=5e-4)
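Since ∇·F = dF_i/dx_i is the trace of the Jacobian, a NumPy sketch only needs the diagonal derivatives dF_i/dx_i. divergence_sketch is a hypothetical helper built on np.gradient (second-order central differences inside, first-order one-sided differences at the edges); it returns the scalar field without a leading axis, which is assumed to correspond to keepdims=False.

```python
import numpy as np


def divergence_sketch(F, spacing):
    """Hypothetical helper: sum of dF_i/dx_i.

    F has shape (n, *grid) with n == len(grid) == len(spacing)."""
    return sum(np.gradient(F[i], spacing[i], axis=i) for i in range(F.shape[0]))


x = np.linspace(0, 1, 51)
h = x[1] - x[0]
X, Y = np.meshgrid(x, x, indexing="ij")
print(np.allclose(divergence_sketch(np.stack([X, Y]), (h, h)), 2.0))   # True: div(x, y) = 2
print(np.allclose(divergence_sketch(np.stack([-Y, X]), (h, h)), 0.0))  # True: a rotation is divergence-free
```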
- finitediffx.gradient(array, *, accuracy=1, step_size=1, method='central')
Compute the gradient ∇F of the input array, where F is a scalar field sampled on a grid. Returns the partial derivatives dF/dx_i, each with the same shape as the input, stacked along a new first axis.
- Parameters:
array – input array
accuracy – accuracy order of the gradient. Default is 1; a tuple gives one value per axis
step_size – grid step size. Default is 1; a tuple gives one value per axis
method – the method to use (forward, central, backward). Default is central
Index notation: dF/dx_i
Example
>>> # ∇F of a 2D array
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import numpy.testing as npt
>>> import jax
>>> with jax.experimental.enable_x64():
...     x, y = [jnp.linspace(0, 1, 100)] * 2
...     dx, dy = x[1] - x[0], y[1] - y[0]
...     X, Y = jnp.meshgrid(x, y, indexing="ij")
...     Z = X**2 + Y**3
...     dZdX, dZdY = fdx.gradient(Z, step_size=(dx, dy), accuracy=5)
...     dZdX_true, dZdY_true = 2 * X, 3 * Y**2
...     npt.assert_allclose(dZdX, dZdX_true, atol=1e-4)
...     npt.assert_allclose(dZdY, dZdY_true, atol=1e-4)
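For comparison, NumPy's np.gradient returns the same per-axis partial derivatives, which can be stacked into the layout described above. It always uses second-order central differences in the interior and only supports edge_order=1 or 2 at the boundaries, so it has no counterpart to the accuracy parameter. The tolerance below is chosen for this grid (h = 0.01) and is not taken from finitediffx.

```python
import numpy as np

x = y = np.linspace(0, 1, 101)
dx, dy = x[1] - x[0], y[1] - y[0]
X, Y = np.meshgrid(x, y, indexing="ij")
Z = X**2 + Y**3

# One partial derivative per axis, stacked along a new leading axis
grad = np.stack(np.gradient(Z, dx, dy, edge_order=2))
grad_true = np.stack([2 * X, 3 * Y**2])
print(grad.shape)                               # (2, 101, 101)
print(np.allclose(grad, grad_true, atol=1e-3))  # True
```

The x-derivative is exact up to rounding because Z is quadratic in X; the error in the y-derivative is about 1e-4 inside and 2e-4 at the edges.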
- finitediffx.hessian(array, *, accuracy=2, step_size=1, method='central')
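Compute the Hessian of F: R^m -> R, the matrix of second partial derivatives d²F/dx_i dx_j. The m × m entries are stacked along the first two axes of the result, as in the example below.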
- Parameters:
array – input array
accuracy – accuracy order of the derivatives. Default is 2; a tuple gives one value per axis
method – the method to use (forward, central, backward). Default is central
step_size – grid step size. Default is 1; a tuple gives one value per axis
Index notation: d²F/dx_i dx_j
Example
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import numpy.testing as npt
>>> import jax
>>> with jax.experimental.enable_x64():
...     x, y = [jnp.linspace(-1, 1, 100)] * 2
...     dx, dy = x[1] - x[0], y[1] - y[0]
...     X, Y = jnp.meshgrid(x, y, indexing="ij")
...     F = X**2 * Y
...     H = fdx.hessian(F, accuracy=4, step_size=(dx, dy))
...     H_true = jnp.array([[2 * Y, 2 * X], [2 * X, jnp.zeros_like(X)]])
...     npt.assert_allclose(H, H_true, atol=1e-7)
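One way to assemble a Hessian is to differentiate each gradient component again, H[i, j] = d/dx_j (dF/dx_i). The NumPy sketch below does that with np.gradient and edge_order=2. hessian_sketch is a hypothetical helper, not necessarily how finitediffx works, and it assumes an input with at least two grid axes (for 1D input np.gradient returns a bare array rather than a sequence of arrays).

```python
import numpy as np


def hessian_sketch(F, spacing):
    """Hypothetical helper: H[i, j] = d/dx_j (dF/dx_i) for F with >= 2 grid axes."""
    grads = np.gradient(F, *spacing, edge_order=2)  # dF/dx_i for each axis i
    # Differentiate every gradient component along every axis, then stack
    return np.stack([np.stack(np.gradient(g, *spacing, edge_order=2)) for g in grads])


x = np.linspace(-1, 1, 50)
h = x[1] - x[0]
X, Y = np.meshgrid(x, x, indexing="ij")
H = hessian_sketch(X**2 * Y, (h, h))
print(H.shape)  # (2, 2, 50, 50)
```

Because F = X**2 * Y is at most quadratic in the variable being differentiated at every step, the second-order stencils recover the exact Hessian [[2Y, 2X], [2X, 0]] up to rounding.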
- finitediffx.jacobian(array, *, accuracy=1, step_size=1, method='central')
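Compute the Jacobian ∂F_i/∂x_j of the input array, where F is a vector function whose components are stacked along the first axis. Entry [i, j] of the result holds ∂F_i/∂x_j, so the derivatives of each component are stacked along the second axis.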
- Parameters:
array – input array whose leading axis indexes the components of F
accuracy – accuracy order of the derivatives. Default is 1; a tuple gives one value per axis
step_size – grid step size. Default is 1; a tuple gives one value per axis
method – the method to use (forward, central, backward). Default is central
Index notation: ∂F_i/∂x_j
Example
>>> # F: R^2 -> R^2
>>> # F = [ x^2*y, 5x + sin(y) ]
>>> # JF = [ [2xy, x^2], [5, cos(y)] ]
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import numpy.testing as npt
>>> import jax
>>> with jax.experimental.enable_x64():
...     x, y = [jnp.linspace(-1, 1, 100)] * 2
...     dx, dy = x[1] - x[0], y[1] - y[0]
...     X, Y = jnp.meshgrid(x, y, indexing="ij")
...     F1 = X**2 * Y
...     F2 = 5 * X + jnp.sin(Y)
...     F = jnp.stack([F1, F2], axis=0)
...     JF = fdx.jacobian(F, accuracy=4, step_size=(dx, dy))
...     JF_true = jnp.array([[2 * X * Y, X**2], [5 * jnp.ones_like(X), jnp.cos(Y)]])
...     npt.assert_allclose(JF, JF_true, atol=1e-7)
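The layout J[i, j] = ∂F_i/∂x_j can be sketched in NumPy by taking np.gradient of each component and stacking the results. jacobian_sketch is a hypothetical helper that uses np.gradient's fixed second-order scheme rather than an accuracy parameter. A linear map is a convenient test case, because its Jacobian is constant and finite differences reproduce it exactly up to rounding.

```python
import numpy as np


def jacobian_sketch(F, spacing):
    """Hypothetical helper: J[i, j] = dF_i/dx_j.

    F has shape (n, *grid); the result has shape (n, len(grid), *grid)."""
    return np.stack([np.stack(np.gradient(Fi, *spacing, edge_order=2)) for Fi in F])


x = np.linspace(-1, 1, 50)
h = x[1] - x[0]
X, Y = np.meshgrid(x, x, indexing="ij")
# Linear map F = (2x + 3y, x - y): its Jacobian is [[2, 3], [1, -1]] everywhere
J = jacobian_sketch(np.stack([2 * X + 3 * Y, X - Y]), (h, h))
print(J.shape)  # (2, 2, 50, 50)
```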
- finitediffx.laplacian(array, *, accuracy=1, step_size=1, method='central')
Compute the Laplacian ΔF of the input array, i.e. the sum of the second derivatives of F along each axis.
- Parameters:
array – input array
accuracy – accuracy order of the derivatives. Default is 1; a tuple gives one value per axis
step_size – grid step size. Default is 1; a tuple gives one value per axis
method – the method to use (forward, central, backward). Default is central
Index notation: d²F/dx_i² (summed over i)
Example
>>> import finitediffx as fdx
>>> import jax.numpy as jnp
>>> import numpy.testing as npt
>>> import jax
>>> with jax.experimental.enable_x64():
...     x, y = [jnp.linspace(0, 1, 100)] * 2
...     dx, dy = x[1] - x[0], y[1] - y[0]
...     X, Y = jnp.meshgrid(x, y, indexing="ij")
...     Z = X**4 + Y**3
...     laplacianZ = fdx.laplacian(Z, step_size=(dx, dy), accuracy=10)
...     laplacianZ_true = 12 * X**2 + 6 * Y
...     npt.assert_allclose(laplacianZ, laplacianZ_true, atol=1e-4)
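The Laplacian sums the second derivative along each axis. The NumPy sketch below applies the three-point stencil (f[i+1] - 2f[i] + f[i-1])/h² on interior points only and simply drops the boundary, rather than switching to one-sided stencils there as difference does. laplacian_interior is a hypothetical helper, and its stencil is only second-order accurate, far below the accuracy=10 used in the example above.

```python
import numpy as np


def laplacian_interior(F, spacing):
    """Hypothetical helper: sum over axes of (f[i+1] - 2 f[i] + f[i-1]) / h**2,
    evaluated on interior points only (each axis shrinks by 2)."""
    interior = tuple(slice(1, -1) for _ in range(F.ndim))
    total = np.zeros(tuple(n - 2 for n in F.shape))
    for axis, h in enumerate(spacing):
        fwd = np.roll(F, -1, axis=axis)  # fwd[i] = F[i + 1] (wraps at the end)
        bwd = np.roll(F, 1, axis=axis)   # bwd[i] = F[i - 1] (wraps at the start)
        # Slicing to the interior discards the wrapped-around edge values
        total = total + ((fwd - 2 * F + bwd) / h**2)[interior]
    return total


x = y = np.linspace(0, 1, 101)
dx, dy = x[1] - x[0], y[1] - y[0]
X, Y = np.meshgrid(x, y, indexing="ij")
Z = X**4 + Y**3
lap = laplacian_interior(Z, (dx, dy))
lap_true = (12 * X**2 + 6 * Y)[1:-1, 1:-1]
print(np.allclose(lap, lap_true, atol=1e-3))  # True
```

The stencil is exact for the cubic part, and for the quartic part its truncation error is h²/12 · 24 = 2e-4 with h = 0.01, which is why atol=1e-3 is enough here.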