Cristóbal Alcázar

Directional derivatives and JAX



Directional derivatives measure how a function changes when its input is nudged in any direction within the input space. They can be computed with the Jacobian-vector product, which the automatic differentiation library JAX implements.

The partial derivative $\partial f / \partial x_i$ gives the rate of change of $f$ when we nudge the $i$-th element of the input vector ${\bf x}$ by a small amount $h$ while holding the other elements constant.


$$
\begin{aligned}
\frac{\partial f}{\partial x_1} &= \lim_{h \to 0} \frac{f(x_1 + h, x_2, \dots, x_n) - f({\bf x})}{h} \\
&\;\;\vdots \\
\frac{\partial f}{\partial x_n} &= \lim_{h \to 0} \frac{f(x_1, x_2, \dots, x_n + h) - f({\bf x})}{h}
\end{aligned}
$$


The above definition can be written more compactly using vector notation.


$$ \frac{\partial f}{\partial x_i}({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf e_i}) - f({\bf x_0})}{h} $$

The vector ${\bf e_i}$ is the unit vector in the direction of the $i$-th axis, with the same number of dimensions as ${\bf x_0}$. Its only nonzero element is the $i$-th one, which equals 1.
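The limit above can be approximated numerically with a small but finite step. The following is only a sketch: the helper name `partial_derivative`, the step size `h=1e-6`, and the test function $f(x, y) = x^2 y$ are choices made for illustration, and plain Python floats are used so that rounding error stays small.

```python
def partial_derivative(f, x0, i, h=1e-6):
    """Forward-difference estimate of df/dx_i at x0 (h is a hand-picked small step)."""
    # e_i has a 1.0 in position i and 0.0 everywhere else
    e_i = [1.0 if j == i else 0.0 for j in range(len(x0))]
    # x0 + h * e_i: only the i-th coordinate moves
    shifted = [xj + h * ej for xj, ej in zip(x0, e_i)]
    return (f(*shifted) - f(*x0)) / h


def f(x, y):
    return x**2 * y


x0 = [1.0, 1.0]
df_dx = partial_derivative(f, x0, 0)  # approximates 2*x*y = 2 at (1, 1)
df_dy = partial_derivative(f, x0, 1)  # approximates x**2 = 1 at (1, 1)
print(df_dx, df_dy)
```

The estimates are close to, but not exactly, the analytic values because `h` is finite.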

As you can see in the initial diagram, in a 2D input space, there are two partial derivatives, $\partial f / \partial x_1$ and $\partial f / \partial x_2$, one along each axis.

Computing derivatives with unit vectors such as ${\bf e_i}$ gives us the change of $f$ in the direction of $i$, that is, parallel to the $i$-axis. How can we compute the derivative of $f$ given a slight nudge of the inputs in an arbitrary direction?

The directional derivative gives the rate of change of $f$ in the direction of a vector ${\bf v}$.


$$ \nabla_{{\bf v}}f({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf v}) - f({\bf x_0})}{h} $$

Think of ${\bf v}$ as a weighted combination of the $n$ axis directions of the input space. We are no longer limited to changes of $f$ parallel to the axes.
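As a rough numerical illustration of this definition, the limit can again be replaced by a small finite step. The helper name `directional_derivative`, the step `h=1e-6`, and the example function are assumptions made for this sketch.

```python
def directional_derivative(f, x0, v, h=1e-6):
    """Forward-difference estimate of the derivative of f at x0 along v."""
    # x0 + h * v: every coordinate moves, weighted by the matching entry of v
    shifted = [xj + h * vj for xj, vj in zip(x0, v)]
    return (f(*shifted) - f(*x0)) / h


def f(x, y):
    return x**2 * y


x0 = [1.0, 1.0]
along_x = directional_derivative(f, x0, [1.0, 0.0])  # v = e_1, the partial derivative in x
along_v = directional_derivative(f, x0, [1.0, 2.0])  # an arbitrary direction
print(along_x, along_v)
```

With `v` set to a unit axis vector the estimate reduces to a partial derivative; any other `v` mixes the axis directions.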

We can compute directional derivatives as the dot product between the Jacobian of $f$ (for a scalar-valued $f$, the gradient $\nabla f$) and the vector ${\bf v}$. For instance, for a two-dimensional input space, ${\bf v} = (v_1, v_2)$, and an arbitrary point $p$:

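$$ \nabla_{{\bf v}}f(p) = \nabla f(p) \cdot {\bf v} = \frac{\partial f}{\partial x_1}(p)\, v_1 + \frac{\partial f}{\partial x_2}(p)\, v_2 $$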

More generally:

$$ \nabla_{{\bf v}}f(p) = \nabla f(p) \cdot {\bf v} = \sum_{i=1}^{n} \frac{\partial f}{\partial x_i}(p)\, v_i $$
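The sum above can be sketched directly in JAX by building the gradient with `jax.grad` and taking its dot product with ${\bf v}$. The function, the point `p`, and the direction `v` below are illustrative choices; the function takes a single array argument so that `jax.grad` returns the whole gradient at once.

```python
import jax
import jax.numpy as jnp


def f(x):
    # x is a length-2 array; this is f(x, y) = x**2 * y written for one array input
    return x[0] ** 2 * x[1]


p = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 2.0])

grad_at_p = jax.grad(f)(p)           # (2*x*y, x**2) evaluated at p
directional = jnp.dot(grad_at_p, v)  # sum_i (df/dx_i)(p) * v_i
print(grad_at_p, directional)
```

This route materialises the full gradient first, whereas a Jacobian-vector product yields the directional derivative without building it explicitly.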

Let's focus on computing the above with the function jax.jvp, where jvp stands for Jacobian-vector product.

The function jax.jvp computes the directional derivative. Its arguments are:

  1. A differentiable function $f$, whose Jacobian $\nabla f$ is used.
  2. A primal vector ${\bf p}$, the point at which the Jacobian $\nabla f({\bf p})$ is evaluated.
  3. A tangent vector ${\bf v}$, which represents the direction in which we want to calculate the derivative.

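jax.jvp returns a tuple $(f({\bf p}), \nabla_{{\bf v}}f({\bf p}))$: the function value at ${\bf p}$ and the directional derivative along ${\bf v}$.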
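As a minimal sketch of the call shape, here is jax.jvp applied to a function of a single array argument. The function `f`, the point `p`, and the direction `v` are made up for illustration; the primals and tangents are passed as tuples with one entry per argument of `f`.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative scalar function of one array argument; its gradient is 2 * x
    return jnp.sum(x ** 2)


p = jnp.array([1.0, 2.0, 3.0])   # primal: where to evaluate
v = jnp.array([1.0, 0.0, -1.0])  # tangent: direction of the nudge

# One primal and one tangent per argument of f, each wrapped in a tuple
value, directional = jax.jvp(f, (p,), (v,))
print(value, directional)
```

Here `value` is $f({\bf p}) = 14$ and `directional` is $2 \cdot 1 \cdot 1 + 2 \cdot 2 \cdot 0 + 2 \cdot 3 \cdot (-1) = -4$.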

Example

We compute the directional derivative of $f(x, y) = x^2 y$ by hand-coding all the necessary pieces, and then check the result against jax.jvp.

import jax
import jax.numpy as jnp
# f(x, y) = x**2 * y and its hand-coded partial derivatives
def fun(x, y): return x**2 * y
def fun_dx(x, y): return 2*x*y
def fun_dy(x, y): return x**2

We define the primal vector ${\bf p}$, the point at which we evaluate, and the tangent vector ${\bf v}$, the direction along which we want to compute the directional derivative.

p = [1., 1.]
v = [1., 2.]

Evaluate f(p):

# *p unpacks the list into positional arguments: fun(p[0], p[1])
fun(*p)
> 1.0

Compute the directional derivative using fun_dx and fun_dy.

fun_dx(*p) * v[0] + fun_dy(*p) * v[1]
> 4.0
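For reference, the same number can be worked out by hand from the gradient of $f(x, y) = x^2 y$ at ${\bf p} = (1, 1)$ with ${\bf v} = (1, 2)$:

$$ \nabla f(x, y) = (2xy,\; x^2) \quad \Rightarrow \quad \nabla f(1, 1) = (2,\; 1) $$

$$ \nabla_{{\bf v}}f({\bf p}) = \nabla f(1, 1) \cdot (1, 2) = 2 \cdot 1 + 1 \cdot 2 = 4 $$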

Now using jax.jvp we obtain the same results: $f({\bf p})$ and $\nabla_{{\bf v}}f({\bf p})$.

jax.jvp(fun, p, v)
> (DeviceArray(1., dtype=float32, weak_type=True),
   DeviceArray(4., dtype=float32, weak_type=True))

A surface plot shows the output space, and a contour plot the input space, of $f(x, y) = x^2 y$. We will compute the directional derivatives at three points, each with its own direction vector.

Look at the directional vectors in the plot (tangent vectors, as JAX calls them): they have different lengths. If we want the directional derivative to have the usual "slope" meaning, we need to make ${\bf v}$ a unit-length vector, which is equivalent to dividing the directional derivative definition by $||{\bf v}||$. Remember that partial derivatives are computed using unit vectors (${\bf e_i}$).

$$ \nabla_{{\bf v}}f({\bf x}) = \lim_{h \to 0} \frac{f({\bf x} + h{\bf v}) - f({\bf x})}{h\,||{\bf v}||} $$

primal_a = jnp.array([-5., 3.2])
primal_b = jnp.array([5., -3.2])
primal_c = jnp.array([0., 0.])
va = jnp.array([-7.5, 5.7])
vb = jnp.array([7.5, -5.7])
vc = jnp.array([-1.0, -0.7])
unit_va = va/va.dot(va)**.5
unit_vb = vb/vb.dot(vb)**.5
unit_vc = vc/vc.dot(vc)**.5
# Compute the slopes using the unit-length direction vectors
_, slope_a = jax.jvp(fun, primal_a.tolist(), unit_va.tolist())
_, slope_b = jax.jvp(fun, primal_b.tolist(), unit_vb.tolist())
_, slope_c = jax.jvp(fun, primal_c.tolist(), unit_vc.tolist())
slope_a, slope_b, slope_c
> (DeviceArray(40.60427, dtype=float32, weak_type=True),
   DeviceArray(-40.60427, dtype=float32, weak_type=True),
   DeviceArray(-0., dtype=float32, weak_type=True))
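Because the Jacobian-vector product is linear in the tangent vector, normalising ${\bf v}$ before the call should give the same slope as calling jax.jvp with the raw ${\bf v}$ and dividing the result by $||{\bf v}||$. The sketch below checks this for the first point; rewriting the function to take a single array argument and using `jnp.linalg.norm` are choices made here for brevity.

```python
import jax
import jax.numpy as jnp


def fun_vec(x):
    # Same function as fun(x, y) = x**2 * y, written for one array argument
    return x[0] ** 2 * x[1]


primal_a = jnp.array([-5.0, 3.2])
va = jnp.array([-7.5, 5.7])
norm_va = jnp.linalg.norm(va)

# Option 1: normalise the tangent, then take the JVP
_, slope_unit = jax.jvp(fun_vec, (primal_a,), (va / norm_va,))

# Option 2: take the JVP with the raw tangent, then divide by its length
_, raw = jax.jvp(fun_vec, (primal_a,), (va,))
slope_scaled = raw / norm_va

print(slope_unit, slope_scaled)
```

Both values should agree up to floating-point rounding, and match the slope reported for the first point above.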

We can make some observations about the points and their directional derivatives.