# How to code Differentiation in JAX with Simple Examples

## Intro to JAX

JAX is a Python library for accelerated linear algebra, aimed mainly at efficient, high-performance neural network training and differentiable programming. In this article, we will go through very simple examples of differentiation in JAX. All you need to know is the basics of Python and high school differential calculus; knowledge of neural networks is a plus. Let's dive right in!

The code and text for this tutorial are available in this Google Colab.

## Overview

- How is JAX so fast and efficient?
- How to use `grad()` function in JAX
- Differentiate a simple linear function in JAX
- Differentiate quadratic function in JAX
- Easily calculate higher derivatives in JAX
- Differentiate sigmoid function in JAX
- Differentiate ReLU in JAX
- Partial Differentiation in JAX

## How is JAX so fast and efficient?

The `jax.numpy` module closely mirrors the NumPy interface, so if you already know NumPy, you already know most of JAX. The key difference is that NumPy runs its computations only on the CPU, whereas JAX can also run them on accelerators such as GPUs and TPUs and parallelize the work there. It doesn't just distribute linear algebra calculations across the huge number of cores in a GPU; it can also spread a computation across multiple devices, such as several TPUs at the same time. This is useful because you can rent TPUs in the cloud and get extremely fast and efficient computations when you need them.
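To make the NumPy similarity concrete, here is a minimal sketch that runs the same small matrix product in NumPy and in `jax.numpy`. The array names (`a_np`, `a_jx`, `prod_np`, `prod_jx`) are made up for this illustration, and the devices listed depend entirely on the machine you run it on.

```python
import numpy as np
import jax
import jax.numpy as jnp

# The same 2x3 matrix, built once with NumPy and once with jax.numpy.
a_np = np.arange(6.0).reshape(2, 3)
a_jx = jnp.arange(6.0).reshape(2, 3)

# Identical syntax for the matrix product in both libraries.
prod_np = a_np @ a_np.T
prod_jx = a_jx @ a_jx.T

# Devices JAX can see on this machine (CPU only, or GPUs/TPUs if present).
print(jax.devices())

# Both libraries agree on the result.
print(np.allclose(prod_np, np.asarray(prod_jx)))
```

On a machine without an accelerator, JAX simply falls back to the CPU, so the same code runs everywhere.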

## How to use `grad()` function in JAX

### Setup: import JAX

Install JAX using pip:

```
pip install jax
```

Then import:

```
import jax.numpy as jnp
from jax import grad
```

### Differentiate a simple linear function in JAX

Let's start with an extremely simple differentiation of \(2x\), which is just \(2\), regardless of the input \(x\) value. Here's the math:

$$
f_{1}(x) = 2x
$$

$$
\frac{\partial f_{1}}{\partial x} = f_{1}'(x)= 2
$$

$$
x = 2.0; \qquad f_{1}'(2.0)= 2
$$

Now we program it; first, we define the function of \(2x\):

```
def f1(x):
    return 2 * x
```

And then we get the first derivative using the `grad()` function from JAX: `df1 = grad(f1)`. This returns another function, `df1`, which is the derivative of `f1`.

```
def f1(x):
    return 2 * x
df1 = grad(f1)
df1
# <function __main__.f1>
```

We can get the value of the derivative by simply passing in \(x\). Let's try `1.0`:

```
x = 1.0
df1(x)
# DeviceArray(2., dtype=float32)
```

We can get just the value of the gradient using `float()`:

```
x = 1.0
float(df1(x))
# 2.0
```

Note: passing an integer (just `1` instead of `1.0`) causes an error such as `TypeError: grad requires real- or complex-valued input`, since `grad()` expects floating-point inputs.
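Here is a small sketch of that pitfall and its fix. It redefines `f1` so the block runs on its own; the `raised` flag is just a helper name for this illustration.

```python
from jax import grad

def f1(x):
    return 2 * x

df1 = grad(f1)

# An integer input is rejected by grad().
try:
    df1(1)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Converting the input to a float first works as expected.
print(float(df1(float(1))))
```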

Side note: the `DeviceArray` that is returned is quite interesting. When you use a GPU/TPU, your inputs and intermediate variables are stored in the accelerator's own memory, not in your local RAM. The accelerator can access its onboard memory faster, which helps it work faster. (Recent JAX releases display this type as `Array` instead of `DeviceArray`.)

And no matter what value of \(x\) you input, we will get the same gradient of \(2.0\) in this case.

```
x = 2.0
print(float(df1(x)))
x = 3.0
print(float(df1(x)))
x = 4.0
print(float(df1(x)))
# 2.0
# 2.0
# 2.0
```

### Differentiate quadratic function in JAX

Let's look at a slightly harder differentiation, this time of a quadratic function. Here's the math:

$$
f_{2}(x) = x^2 + 5x + 10
$$

$$
f_{2}'(x)= 2x + 5
$$

$$
x = 2.0; \qquad f_{2}'(2.0)= 9.0
$$

Just like before, let's define `f2` as a quadratic function and get `df2` using `grad()`:

```
def f2(x):
    return x ** 2 + 5 * x + 10
df2 = grad(f2)
```

Try out some input values of \(x\) again as a sanity check. They return the expected gradients:

```
x = 2.0
print(float(df2(x)))
x = 3.0
print(float(df2(x)))
x = 4.0
print(float(df2(x)))
# 9.0
# 11.0
# 13.0
```
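If you want an independent check on these numbers, one option is to compare `grad()` against the hand-derived formula \(2x + 5\) and against a numerical estimate. The sketch below does that; `central_difference` and its step size `h` are helper names chosen for this illustration, not part of JAX.

```python
from jax import grad

def f2(x):
    return x ** 2 + 5 * x + 10

df2 = grad(f2)

def central_difference(f, x, h=1e-3):
    # Numerical estimate of f'(x) from two nearby function values.
    return (f(x + h) - f(x - h)) / (2 * h)

for x in (2.0, 3.0, 4.0):
    autodiff = float(df2(x))               # derivative from JAX
    by_hand = 2 * x + 5                    # derivative from the formula
    numeric = central_difference(f2, x)    # finite-difference estimate
    print(x, autodiff, by_hand, numeric)
```

All three columns should agree closely. The finite-difference value can differ in the last few decimal places, since it is only an approximation.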

### Easily calculate higher derivatives in JAX

If you are wondering whether calculating second derivatives and beyond is also easy in JAX, you are right. It's as easy as running `grad()` again on the first derivative function. Here's the previous example with quadratic equations continued:

$$
f_{2}(x) = x^2 + 5x + 10
$$

$$
f_{2}'(x)= 2x + 5
$$

$$
f_{2}''(x)= 2
$$

$$
x = 2.0; \qquad f_{2}'(2.0)= 9.0
$$

$$
x = 2.0; \qquad f_{2}''(2.0)= 2.0
$$

Here the second derivative is calculated by simply running `d2_f2 = grad(df2)`, where `d2_f2` represents the second derivative of the quadratic function.

```
def f2(x):
    return x ** 2 + 5 * x + 10
df2 = grad(f2)
d2_f2 = grad(df2)
```

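This means that you can stack the `grad()` function as many times as you want and get higher-order derivatives when needed. Note that `grad()` takes a function, not a value, so you nest it around the function and then call the result: `grad(grad(grad(f)))(x)` is a perfectly fine line in JAX. Awesome!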

Since the second derivative of our quadratic function is always \(2.0\), here are some runs that confirm it:

```
x = 2.0
print(float(d2_f2(x)))
x = 3.0
print(float(d2_f2(x)))
x = 4.0
print(float(d2_f2(x)))
# 2.0
# 2.0
# 2.0
```
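Stacking `grad()` can also be wrapped in a small loop. The sketch below uses a hypothetical helper, `nth_derivative`, and a cubic function (both introduced only for this illustration) so that the first three derivatives are all non-trivial: \(3x^2\), \(6x\), and \(6\).

```python
from jax import grad

def nth_derivative(f, n):
    # Apply grad() n times; each pass returns a new function.
    for _ in range(n):
        f = grad(f)
    return f

def cubic(x):
    return x ** 3

x = 2.0
print(float(nth_derivative(cubic, 1)(x)))  # 3 * x**2 -> 12.0
print(float(nth_derivative(cubic, 2)(x)))  # 6 * x    -> 12.0
print(float(nth_derivative(cubic, 3)(x)))  # 6        -> 6.0
```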

### Differentiate sigmoid function in JAX

The sigmoid function is very useful in neural networks and is a little harder to differentiate than our previous examples. Here's the first derivative of the sigmoid function for reference:

$$
f_{3}(x) = s(x) = \frac{1}{1 + e^{-x}}
$$

$$
f_{3}'(x)= \frac{e^{-x}}{{(1 + e^{-x})}^2} = s(x) \times (1 - s(x))
$$

$$
x = 2.0; \qquad f_{3}'(2.0) \approx 0.105
$$

Again, let's define `sigmoid(x)` as `1 / (1 + jnp.exp(-x))` and get `d_sig` as the first derivative:

```
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))
d_sig = grad(sigmoid)
```

Here's a sanity check of \(x = 2.0\):

```
x = 2.0
float(d_sig(x))
# 0.10499357432126999
```
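As a further check, the gradient from `grad()` can be compared with the closed form \(s(x) \times (1 - s(x))\). This is a minimal sketch; `d_sig_closed_form` is a helper name made up for this illustration.

```python
import jax.numpy as jnp
from jax import grad

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

d_sig = grad(sigmoid)

def d_sig_closed_form(x):
    # Hand-derived derivative: s(x) * (1 - s(x)).
    s = sigmoid(x)
    return s * (1 - s)

for x in (-2.0, 0.0, 2.0):
    print(x, float(d_sig(x)), float(d_sig_closed_form(x)))
```

The two columns should match up to floating-point rounding.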

At this point, you should be able to see that JAX can deal with a pretty complex derivative because differentiating the sigmoid is quite hard. But JAX is even more magical and powerful, as you will see in the following examples.

### Differentiate ReLU in JAX

ReLU is another very popular activation function in neural networks, and it is a good example of how JAX can differentiate through things like `if-else` statements accurately. Here's the math for ReLU and its first derivative:

$$
f_{4}(x) = \mathrm{ReLU}(x)= \begin{cases}
x, & \text{if $x \geqslant 0$}\\
0, & \text{otherwise}
\end{cases}
$$

$$
f'_{4}(x) = \begin{cases}
1, & \text{if $x \geqslant 0$}\\
0, & \text{otherwise}
\end{cases}
$$

$$
x = 2.0; \qquad f_{4}'(2.0) = 1.0
$$

$$
x = -2.0; \qquad f_{4}'(-2.0) = 0.0
$$

Since ReLU depends on whether \(x \geqslant 0\), we use `if x >= 0:` inside the function.

```
def ReLU(x):
    if x >= 0:
        return x
    else:
        return 0.0
d_ReLU = grad(ReLU)
```

Checking the gradient values using \(x\)'s that are greater than 0 and ones that are less:

```
x = 1.0
print(float(d_ReLU(x)))
x = 2.0
print(float(d_ReLU(x)))
x = -1.0
print(float(d_ReLU(x)))
x = -10.0
print(float(d_ReLU(x)))
# 1.0
# 1.0
# 0.0
# 0.0
```
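The `if`-based version handles one scalar at a time. If you would rather check several inputs in one call, one possible variant (an alternative sketch, not the version used above) replaces the Python `if` with `jnp.where` and maps the gradient over an array with `jax.vmap`. The names `relu_where`, `xs`, and `grads` are made up for this illustration.

```python
import jax.numpy as jnp
from jax import grad, vmap

def relu_where(x):
    # Same rule as the if/else version: x when x >= 0, otherwise 0.
    return jnp.where(x >= 0, x, 0.0)

d_relu_where = grad(relu_where)

# vmap applies the scalar gradient function to every element of xs.
xs = jnp.array([-10.0, -1.0, 1.0, 2.0])
grads = vmap(d_relu_where)(xs)
print(grads)
```

The gradients should be 0 for the negative inputs and 1 for the positive ones, matching the scalar runs above.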

## Partial Differentiation in JAX

Now, we will look at partial derivatives, which are really useful since most real-world problems are multivariate. And as you would expect, JAX makes it easy to run `grad()` on functions with multiple variables. Here's a simple two-variable example:

$$
f_{5}(x, y) = x^2 + y^2 + 5xy
$$

$$
\frac{\partial f_{5}}{\partial x} = f_{5, x}'(x, y)= 2x + 5y
$$

$$
\frac{\partial f_{5}}{\partial y} = f_{5, y}'(x, y)= 2y + 5x
$$

$$
x = 2.0; \qquad y = 1.0; \qquad \frac{\partial f_{5}}{\partial x} = 9.0
$$

$$
x = 2.0; \qquad y = 1.0; \qquad \frac{\partial f_{5}}{\partial y} = 12.0
$$

We define the function `f5` with two input arguments, `x` and `y`, and then run `grad()` to get the partial derivatives.

`df5_dx` is the partial derivative of `f5` with respect to `x`. We use the `argnums` argument to specify that we want to differentiate with respect to `x` (i.e. the first argument, which is in position 0). Hence, `grad(f5, argnums=0)` gives \(\frac{\partial f_{5}}{\partial x}\). Similarly, `grad(f5, argnums=1)` gives \(\frac{\partial f_{5}}{\partial y}\), since `argnums=1` refers to `y`, the second argument.

```
def f5(x, y):
    return x**2 + y**2 + 5 * x * y
df5_dx = grad(f5, argnums=0)
df5_dy = grad(f5, argnums=1)
```

And after getting the `df5_dx` and `df5_dy` derivatives, we can get the values of the gradients by passing in `x` and `y`, as in `df5_dx(x, y)` and `df5_dy(x, y)`. Here are the test runs:

```
x = 2.0
y = 1.0
print(float(df5_dx(x, y)))
print(float(df5_dy(x, y)))
# 9.0
# 12.0
```
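If you need both partial derivatives at once, `argnums` also accepts a tuple of positions, in which case the returned function gives back one gradient per listed argument. Here is a minimal sketch; `grad_f5`, `dfdx`, and `dfdy` are names chosen for this illustration.

```python
from jax import grad

def f5(x, y):
    return x**2 + y**2 + 5 * x * y

# Differentiate with respect to both arguments in a single call.
grad_f5 = grad(f5, argnums=(0, 1))

dfdx, dfdy = grad_f5(2.0, 1.0)
print(float(dfdx))  # 2x + 5y -> 9.0
print(float(dfdy))  # 2y + 5x -> 12.0
```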

## Conclusion

This article was meant to be a very gentle introduction to differentiation in JAX, with very simple examples, so that readers can grasp JAX's `grad()` function easily.

You can now learn more advanced auto differentiation in JAX in these articles: