How to code Differentiation in JAX with Simple Examples

Updated on: August 16, 2021



Intro to JAX

JAX is a Python library for accelerated linear algebra, mainly targeted at efficient, high-performance neural network training and differentiable programming. In this article, we will go through very simple examples of differentiation in JAX. All you need is basic Python and high-school differential calculus; knowledge of neural networks is a plus. Let's dive right in!

The code and text for this tutorial is available in this Google Colab.

Overview

How is JAX so fast and efficient?

JAX has the same interface as NumPy, so if you already know NumPy, you already know most of JAX. But JAX goes further: NumPy only runs on the CPU, whereas JAX can run its computations on accelerators such as GPUs and TPUs. It doesn't just spread linear algebra across the many cores of a single GPU; it can also distribute work across multiple devices, such as several TPUs at once. This is great because you can rent TPUs in the cloud and get extremely fast and efficient computation exactly when you need it.
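
If you want to see which devices JAX can use on your machine, a minimal sketch (assuming the jax package is already installed) is:


import jax

# List the devices JAX will run on; a CPU-only machine prints something
# like [CpuDevice(id=0)], while a GPU/TPU machine lists those devices instead
print(jax.devices())
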

How to use the grad() function in JAX

Setup: import JAX

Install JAX using pip:


pip install jax
                

Then import:


import jax.numpy as jnp
from jax import grad
                

Differentiate a simple linear function in JAX

Let's start with an extremely simple differentiation of \(2x\), which is just \(2\), regardless of the input \(x\) value. Here's the math:

$$
f_{1}(x) = 2x
$$

$$
\frac{\partial f_{1}}{\partial x} = f_{1}'(x)= 2
$$

$$
x = 2.0; \qquad f_{1}'(2.0)= 2
$$

Now we program it; first, we define the function of \(2x\):


def f1(x):
    return 2 * x
                

And then we get the first derivative using the grad() function from JAX: df1 = grad(f1) returns another function, df1, which is the differentiated form of f1.


def f1(x):
    return 2 * x

df1 = grad(f1)
df1

# <function __main__.f1>
                

We can get the value of the derivative by simply passing in \(x\); let's try 1.0.


x = 1.0
df1(x)

# DeviceArray(2., dtype=float32)
                

We can get just the value of the gradient using float():


x = 1.0
float(df1(x))

# 2.0
                


Note: passing in integers (just 1 instead of 1.0) raises errors such as TypeError: grad requires real- or complex-valued inputs, since JAX expects floating-point numbers.
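
If your input happens to be an integer, a quick sketch of the workaround is to cast it to a float before calling the gradient function:


# Cast integer inputs to float so grad() gets the floating-point value it expects
x = 1
float(df1(float(x)))

# 2.0
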

Side note: the DeviceArray that is returned is quite interesting. When you run on a GPU/TPU, your inputs and intermediate values are stored in the accelerator's own memory rather than your local RAM, and the accelerator runs faster because it can access that on-board memory quickly.

And no matter what value of \(x\) you pass in, you will get the same gradient of \(2.0\) in this case.


x = 2.0
print(float(df1(x)))
x = 3.0
print(float(df1(x)))
x = 4.0
print(float(df1(x)))

# 2.0
# 2.0
# 2.0
                

Differentiate a quadratic function in JAX

Let's look at a slightly harder differentiation of a quadratic function; here's the math:

$$
f_{2}(x) = x^2 + 5x + 10
$$

$$
f_{2}'(x)= 2x + 5
$$

$$
x = 2.0; \qquad f_{2}'(2.0)= 9.0
$$

Just like before, let's define f2 as a quadratic function and
get df2 using grad():


def f2(x):
    return x ** 2 + 5 * x + 10

df2 = grad(f2)
                

Try out some input values of \(x\) again as a sanity check; they return the correct gradients.


x = 2.0
print(float(df2(x)))
x = 3.0
print(float(df2(x)))
x = 4.0
print(float(df2(x)))

# 9.0
# 11.0
# 13.0
                

Easily calculate higher derivatives in JAX

If you are wondering whether calculating second derivatives and beyond is also easy in JAX, you are right. It's as easy as running grad() again on the first-derivative function. Here's the previous quadratic example continued:

$$
f_{2}(x) = x^2 + 5x + 10
$$

$$
f_{2}'(x)= 2x + 5
$$

$$
f_{2}''(x)= 2
$$

$$
x = 2.0; \qquad f_{2}'(2.0)= 9.0
$$

$$
x = 2.0; \qquad f_{2}''(2.0)= 2.0
$$

Here the second derivative is calculated by simply running d2_f2 = grad(df2), where d2_f2 represents the second derivative of the quadratic function.


def f2(x):
    return x ** 2 + 5 * x + 10

df2 = grad(f2)
d2_f2 = grad(df2)
                

This means that you can stack the grad() function as many times as you want and get higher-order derivatives when needed. So, grad(grad(grad(f))) is a perfectly fine expression in JAX. Awesome!

Since the second derivative of our quadratic function always returns \(2.0\), here are some runs that confirm it:


x = 2.0
print(float(d2_f2(x)))
x = 3.0
print(float(d2_f2(x)))
x = 4.0
print(float(d2_f2(x)))

# 2.0
# 2.0
# 2.0
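
And, as mentioned above, grad() can be stacked further. Here is a small sketch of the third derivative of our quadratic function, which is zero everywhere because \(f_{2}''(x) = 2\) is a constant:


# Stack grad() three times to get the third derivative of f2
d3_f2 = grad(grad(grad(f2)))
print(float(d3_f2(2.0)))

# 0.0
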
                

Differentiate the sigmoid function in JAX

The sigmoid function is very useful in neural networks and a little harder to differentiate than our previous examples. Here's the first derivative of the sigmoid function for reference:

$$
f_{3}(x) = s(x) = \frac{1}{1 + e^{-x}}
$$

$$
f_{3}'(x)= \frac{e^{-x}}{{(1 + e^{-x})}^2} = s(x) \times (1 - s(x))
$$

$$
x = 2.0; \qquad f_{3}'(2.0) \approx 0.105
$$

Again, let's define sigmoid(x) as 1 / (1 + jnp.exp(-x))
and get d_sig as the first derivative:


def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

d_sig = grad(sigmoid)
                

Here's a sanity check of \(x = 2.0\):


x = 2.0
float(d_sig(x))

# 0.10499357432126999
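
As an extra check beyond the example above, the hand-derived form \(s(x)(1 - s(x))\) agrees with what grad() computes:


# Compare grad()'s output against the analytic s(x) * (1 - s(x))
x = 2.0
print(float(sigmoid(x) * (1 - sigmoid(x))))
print(float(d_sig(x)))

# both print approximately 0.105
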
                

At this point, you can see that JAX handles a fairly involved derivative, since differentiating the sigmoid by hand takes some work. But JAX is even more magical and powerful, as you will see in the following examples.

Differentiate ReLU in JAX

ReLU is another very popular activation function in neural networks, and it is a good example of how JAX can differentiate accurately through things like if-else statements. Here's the math for the first derivative of ReLU:

$$
f_{4}(x) = ReLU(x)= \begin{cases}
x, & \text{if $x \geqslant 0$}\\
0, & \text{otherwise}
\end{cases}
$$

$$
f'_{4}(x) = \begin{cases}
1, & \text{if $x \geqslant 0$}\\
0, & \text{otherwise}
\end{cases}
$$

$$
x = 2.0; \qquad f_{4}'(2.0) = 1.0
$$
$$
x = -2.0; \qquad f_{4}'(-2.0) = 0.0
$$

Since ReLU depends on the condition \(x \geqslant 0\), we use if x >= 0:
inside the function.


def ReLU(x):
    if x >= 0:
        return x
    else:
        return 0.0

d_ReLU = grad(ReLU)
                

Checking the gradient values using \(x\)'s that are greater than 0 and ones that are less:


x = 1.0
print(float(d_ReLU(x)))
x = 2.0
print(float(d_ReLU(x)))
x = -1.0
print(float(d_ReLU(x)))
x = -10.0
print(float(d_ReLU(x)))

# 1.0
# 1.0
# 0.0
# 0.0
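
As a side note beyond the original example, you can also write ReLU without any Python control flow using jnp.maximum, and grad() gives the same 0/1 gradients; here is a minimal sketch:


# Branch-free ReLU using jnp.maximum; grad() handles it just like the if-else version
def ReLU_max(x):
    return jnp.maximum(x, 0.0)

d_ReLU_max = grad(ReLU_max)
print(float(d_ReLU_max(2.0)))
print(float(d_ReLU_max(-2.0)))

# 1.0
# 0.0
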
                

Partial Differentiation in JAX

Now, we will look at partial derivatives, which are really useful since real-world problems are usually multivariate. And as you would expect, JAX makes it easy to run grad() on functions of multiple variables. Here's a simple two-variable example:

$$
f_{5}(x, y) = x^2 + y^2 + 5xy
$$

$$
\frac{\partial f_{5}}{\partial x} = f_{5, x}'(x, y)= 2x + 5y
$$

$$
\frac{\partial f_{5}}{\partial y} = f_{5, y}'(x, y)= 2y + 5x
$$

$$
x = 2.0; \qquad y = 1.0; \qquad \frac{\partial f_{5}}{\partial x} = 9.0
$$

$$
x = 2.0; \qquad y = 1.0; \qquad \frac{\partial f_{5}}{\partial y} = 12.0
$$

We define the function f5 with two input arguments x
and y, and then run grad() to get partial derivatives.

df5_dx is the partial derivative of f5 with respect to x, and we use the argnums argument to specify that we want to differentiate with respect to x (the first argument, i.e. the argument in position 0).

Hence, grad(f5, argnums=0) gives \(\frac{\partial f_{5}}{\partial x}\). Similarly, grad(f5, argnums=1) gives \(\frac{\partial f_{5}}{\partial y}\), since argnums=1 refers to y, the second argument.


def f5(x, y):
    return x**2 + y**2 + 5 * x * y

df5_dx = grad(f5, argnums=0)
df5_dy = grad(f5, argnums=1)
                

And after getting the df5_dx and df5_dy derivatives, we can evaluate the gradients by passing in x and y, as in df5_dx(x, y) and df5_dy(x, y). Here are the test runs:


x = 2.0
y = 1.0
print(float(df5_dx(x, y)))
print(float(df5_dy(x, y)))

# 9.0
# 12.0
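
As a small extra sketch, argnums also accepts a tuple of positions, in which case grad() returns both partial derivatives at once as a tuple:


# Passing a tuple to argnums returns a tuple of partial derivatives
df5_both = grad(f5, argnums=(0, 1))
dx, dy = df5_both(2.0, 1.0)
print(float(dx), float(dy))

# 9.0 12.0
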
                

Conclusion

This article was meant to be a very gentle introduction to differentiation in JAX, with very simple examples, so that readers can grasp JAX's grad() function easily.

You can now learn more advanced automatic differentiation in JAX in these articles:



Shah Yasser Aziz
He is a tech & MOOC enthusiast who likes to code and build web 2.0 applications using his skills. He also likes to learn new tech stacks in the industry to sharpen his skill set.
