How to code Differentiation in JAX with Simple Examples
Intro to JAX
JAX is a Python library for accelerated linear algebra, aimed mainly at efficient, high-performance neural network training and differentiable programming. In this article, we will go through very simple examples of differentiation in JAX; all you need is basic Python and high-school differential calculus, and knowledge of neural networks is a plus. Let's dive right in!
The code and text for this tutorial are available in this Google Colab.
Overview
- How is JAX so fast and efficient?
- How to use the `grad()` function in JAX
- Differentiate a simple linear function in JAX
- Differentiate quadratic function in JAX
- Easily calculate higher derivatives in JAX
- Differentiate sigmoid function in JAX
- Differentiate ReLU in JAX
- Partial Differentiation in JAX
How is JAX so fast and efficient?
JAX has the same interface as NumPy, so if you already know NumPy, you already know most of JAX. But JAX goes further than NumPy: NumPy only uses the CPU for its computations, whereas JAX can use accelerators such as GPUs and TPUs to parallelize them. It doesn't just distribute linear algebra calculations across the huge number of cores in a GPU; it can also use multiple TPUs at the same time by spreading the work across several devices. This is amazing because you can rent TPUs in the cloud and get extremely fast, efficient computations when you need them.
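As a tiny illustration of the shared interface (a minimal sketch, not a benchmark; the array shapes here are arbitrary, and the JAX computation only lands on a GPU/TPU if one is actually available):
import numpy as np
import jax.numpy as jnp

a = np.arange(6.0).reshape(2, 3)    # a regular NumPy array, stored in RAM and computed on the CPU
b = jnp.arange(6.0).reshape(2, 3)   # same API, but a JAX array placed on the default device

print(np.dot(a, a.T))               # NumPy result
print(jnp.dot(b, b.T))              # identical values, computed on an accelerator when present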
How to use the grad() function in JAX
Setup: import JAX
Install JAX using pip:
pip install jax
Then import:
import jax.numpy as jnp
from jax import grad
Differentiate a simple linear function in JAX
Let's start with an extremely simple function, \(2x\), whose derivative is just \(2\) regardless of the input value \(x\). Here's the math:
$$
f_{1}(x) = 2x
$$
$$
\frac{\partial f_{1}}{\partial x} = f_{1}'(x)= 2
$$
$$
x = 2.0; \qquad f_{1}'(2.0)= 2
$$
Now let's program it. First, we define the function \(f_{1}(x) = 2x\):
def f1(x):
    return 2 * x
Then we get the first derivative using the `grad()` function from JAX: `df1 = grad(f1)` returns another function, `df1`, which is the differentiated form of `f1`.
def f1(x):
    return 2 * x

df1 = grad(f1)
df1
# <function __main__.f1>
We can get the value of the derivative by simply inputting \(x\); let's try `1.0`.
x = 1.0
df1(x)
# DeviceArray(2., dtype=float32)
We can get just the numerical value of the gradient using `float()`:
x = 1.0
float(df1(x))
# 2.0
Note: inputting integers (just `1` instead of `1.0`) causes errors such as `TypeError: grad requires real- or complex-valued input`, since JAX expects floating-point numbers.
Side note: the `DeviceArray` returned is quite interesting. When you use a GPU/TPU, your inputs and intermediate variables are stored in the accelerator's onboard memory rather than in your local RAM, and the accelerator runs faster because it can access that onboard memory more quickly.
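For example (a small illustrative sketch reusing the `df1` defined above), you can see the difference between integer and float inputs directly:
try:
    df1(1)                 # integer input: grad() raises a TypeError
except TypeError as err:
    print(err)             # prints the error message quoted above
print(float(df1(1.0)))     # float input works fine and prints 2.0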
And no matter what value of \(x\) you input, you will get the same gradient of \(2.0\) in this case:
x = 2.0
print(float(df1(x)))
x = 3.0
print(float(df1(x)))
x = 4.0
print(float(df1(x)))
# 2.0
# 2.0
# 2.0
Differentiate quadratic function in JAX
Let's look at a slightly harder case: differentiating a quadratic function. Here's the math:
$$
f_{2}(x) = x^2 + 5x + 10
$$
$$
f_{2}'(x)= 2x + 5
$$
$$
x = 2.0; \qquad f_{2}'(2.0)= 9.0
$$
Just like before, let's define `f2` as the quadratic function and get `df2` using `grad()`:
def f2(x):
    return x ** 2 + 5 * x + 10

df2 = grad(f2)
Try out some input values of \(x\) again as a sanity check; they return the correct gradients.
x = 2.0
print(float(df2(x)))
x = 3.0
print(float(df2(x)))
x = 4.0
print(float(df2(x)))
# 9.0
# 11.0
# 13.0
Easily calculate higher derivatives in JAX
If you are wondering whether calculating second derivatives and beyond is also easy in JAX, you are right: it's as easy as running `grad()` again on the first-derivative function. Here's the previous quadratic example continued:
$$
f_{2}(x) = x^2 + 5x + 10
$$
$$
f_{2}'(x)= 2x + 5
$$
$$
f_{2}''(x)= 2
$$
$$
x = 2.0; \qquad f_{2}'(2.0)= 9.0
$$
$$
x = 2.0; \qquad f_{2}''(2.0)= 2.0
$$
Here the second derivative is calculated by simply running `d2_f2 = grad(df2)`, where `d2_f2` represents the second derivative of the quadratic function.
def f2(x):
    return x ** 2 + 5 * x + 10

df2 = grad(f2)
d2_f2 = grad(df2)
This means that you can stack the `grad()` function as many times as you want and get higher-order derivatives when needed. So `grad(grad(grad(f)))` is a perfectly fine line in JAX. Awesome!
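As a quick illustrative sketch (the cubic function here is our own example, reusing the `grad` imported earlier), stacking `grad()` three times gives a third derivative:
def f(x):
    return x ** 3          # f'''(x) = 6 for every x

d3_f = grad(grad(grad(f)))
print(float(d3_f(2.0)))
# 6.0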
Since the second derivative of our quadratic function is always \(2.0\), here are some runs that confirm it:
x = 2.0
print(float(d2_f2(x)))
x = 3.0
print(float(d2_f2(x)))
x = 4.0
print(float(d2_f2(x)))
# 2.0
# 2.0
# 2.0
Differentiate sigmoid function in JAX
The sigmoid function is very useful in neural networks and a little harder to differentiate than our previous examples. Here's the first derivative of the sigmoid function for reference:
$$
f_{3}(x) = s(x) = \frac{1}{1 + e^{-x}}
$$
$$
f_{3}'(x)= \frac{e^{-x}}{{(1 + e^{-x})}^2} = s(x) \times (1 - s(x))
$$
$$
x = 2.0; \qquad f_{3}'(2.0) = 0.104
$$
Again, let's define `sigmoid(x)` as `1 / (1 + jnp.exp(-x))` and get `d_sig` as its first derivative:
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

d_sig = grad(sigmoid)
Here's a sanity check of \(x = 2.0\):
x = 2.0
float(d_sig(x))
# 0.10499357432126999
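As a quick cross-check (reusing the `sigmoid` and `x` defined just above), the analytic formula \(s(x) \times (1 - s(x))\) gives the same value:
s = sigmoid(x)
print(float(s * (1 - s)))   # ≈ 0.105, matching d_sig(x) up to floating-point rounding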
At this point, you can see that JAX handles a fairly complex derivative, since differentiating the sigmoid by hand takes some work. But JAX is even more magical and powerful, as you will see in the following examples.
Differentiate ReLU in JAX
ReLU is another very popular activation function in neural networks, and it is a good example of how JAX can accurately differentiate through Python control flow such as `if-else` statements. Here's the math for the first derivative of ReLU:
$$
f_{4}(x) = ReLU(x)= \begin{cases}
x, & \text{if $x \geqslant 0$}\\
0, & \text{otherwise}
\end{cases}
$$
$$
f'_{4}(x) = \begin{cases}
1, & \text{if $x \geqslant 0$}\\
0, & \text{otherwise}
\end{cases}
$$
$$
x = 2.0; \qquad f_{4}'(2.0) = 1.0
$$
$$
x = -2.0; \qquad f_{4}'(-2.0) = 0.0
$$
Since ReLU depends on whether \(x \geqslant 0\), we use an `if x >= 0:` check inside the function.
def ReLU(x):
    if x >= 0:
        return x
    else:
        return 0.0

d_ReLU = grad(ReLU)
Let's check the gradient for values of \(x\) greater than zero and values less than zero:
x = 1.0
print(float(d_ReLU(x)))
x = 2.0
print(float(d_ReLU(x)))
x = -1.0
print(float(d_ReLU(x)))
x = -10.0
print(float(d_ReLU(x)))
# 1.0
# 1.0
# 0.0
# 0.0
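One extra check (an illustrative aside, not one of the runs above): at exactly \(x = 0.0\) the `x >= 0` branch is taken, so the gradient is \(1.0\), consistent with our definition of \(f_{4}'\):
x = 0.0
print(float(d_ReLU(x)))
# 1.0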
Partial Differentiation in JAX
Now we will look at partial derivatives, which are really useful since real-world problems are usually multivariate. And as you would expect, JAX makes it easy to run `grad()` on functions with multiple variables. Here's a simple two-variable example:
$$
f_{5}(x, y) = x^2 + y^2 + 5xy
$$
$$
\frac{\partial f_{5}}{\partial x} = f_{5, x}'(x, y)= 2x + 5y
$$
$$
\frac{\partial f_{5}}{\partial y} = f_{5, y}'(x, y)= 2y + 5x
$$
$$
x = 2.0; \qquad y = 1.0; \qquad \frac{\partial f_{5}}{\partial x} = 9.0
$$
$$
x = 2.0; \qquad y = 1.0; \qquad \frac{\partial f_{5}}{\partial y} = 12.0
$$
We define the function `f5` with two input arguments, `x` and `y`, and then run `grad()` to get the partial derivatives. `df5_dx` is the partial derivative of `f5` with respect to `x`, and we use the `argnums` argument to specify that we want to differentiate with respect to `x` (the first argument, i.e. the argument in position 0). Hence `grad(f5, argnums=0)` gives \(\frac{\partial f_{5}}{\partial x}\). Similarly, `grad(f5, argnums=1)` gives \(\frac{\partial f_{5}}{\partial y}\), since `argnums=1` refers to `y`, the second argument.
def f5(x, y):
    return x**2 + y**2 + 5 * x * y

df5_dx = grad(f5, argnums=0)
df5_dy = grad(f5, argnums=1)
After getting the `df5_dx` and `df5_dy` derivatives, we can get the values of the gradients by passing in `x` and `y`, as in `df5_dx(x, y)` and `df5_dy(x, y)`. Here are the test runs:
x = 2.0
y = 1.0
print(float(df5_dx(x, y)))
print(float(df5_dy(x, y)))
# 9.0
# 12.0
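As a small extra sketch (reusing the `f5` defined above), `grad()` also accepts a tuple of argument positions, returning both partial derivatives at once:
df5_dx_dy = grad(f5, argnums=(0, 1))   # returns a tuple of partial derivatives
dx, dy = df5_dx_dy(2.0, 1.0)
print(float(dx), float(dy))
# 9.0 12.0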
Conclusion
This article was meant to be a very gentle introduction to differentiation in JAX, with very simple examples, so that readers can easily grasp JAX's `grad()` function.
You can now learn more advanced automatic differentiation in JAX in these articles: