How to code Differentiation in JAX with Simple Examples

JAX is a Python library for accelerated linear algebra, designed primarily for high-performance neural network training and differentiable programming. This guide walks through simple examples of differentiation in JAX using basic Python and high-school calculus, so you can start coding derivatives in JAX right away.
What Makes JAX Special?
JAX shares the same interface as NumPy, making it familiar to NumPy users. However, JAX offers a significant advantage: it can use accelerators such as GPUs and TPUs to parallelize computations, while NumPy relies solely on the CPU. JAX doesn't just distribute linear algebra calculations across the numerous cores of a GPU; it can also leverage multiple TPUs simultaneously by distributing work across devices, allowing for extremely fast and efficient computation.
Getting Started with JAX
First, install JAX and import the necessary functions:
pip install jax
import jax.numpy as jnp
from jax import grad
Differentiating a Simple Linear Function
Let's start with a straightforward example: differentiating f(x) = 2x, which should give us 2.
def f1(x):
    return 2 * x
df1 = grad(f1) # Create the derivative function
x = 1.0
print(float(df1(x))) # Output: 2.0
No matter what value of x you input, the gradient will always be 2.0:
x = 2.0
print(float(df1(x))) # 2.0
x = 3.0
print(float(df1(x))) # 2.0
Note: grad requires floating-point inputs; passing a plain integer raises a TypeError.
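To see this in action, here is a minimal sketch (repeating the definition of f1 for completeness) that contrasts an integer input with a float input:

```python
import jax.numpy as jnp
from jax import grad

def f1(x):
    return 2 * x

df1 = grad(f1)

# Passing a Python int raises a TypeError: grad expects inexact (float) dtypes.
try:
    df1(1)
except TypeError as e:
    print("TypeError:", e)

# A float input works as expected.
print(float(df1(1.0)))  # 2.0
```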
Differentiating a Quadratic Function
Let's try a quadratic equation: f(x) = x² + 5x + 10
def f2(x):
    return x ** 2 + 5 * x + 10
df2 = grad(f2)
x = 2.0
print(float(df2(x))) # 9.0 (= 2*2 + 5)
x = 3.0
print(float(df2(x))) # 11.0 (= 2*3 + 5)
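If you want an independent sanity check on grad's answer, a quick sketch is to compare it against a central finite difference (a standard numerical approximation, not part of JAX itself):

```python
from jax import grad

def f2(x):
    return x ** 2 + 5 * x + 10

df2 = grad(f2)

x, h = 2.0, 1e-4
# Central finite difference: (f(x + h) - f(x - h)) / (2h)
numeric = (f2(x + h) - f2(x - h)) / (2 * h)
print(float(df2(x)))  # 9.0
print(numeric)        # ~9.0
```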
Higher Derivatives Made Easy
Calculating second or higher derivatives in JAX is as simple as applying grad() multiple times:
def f2(x):
    return x ** 2 + 5 * x + 10
df2 = grad(f2) # First derivative: 2x + 5
d2_f2 = grad(df2) # Second derivative: 2
x = 2.0
print(float(d2_f2(x))) # 2.0
This stacking capability means you can write expressions like grad(grad(grad(f))) to obtain higher-order derivatives as needed. Note that grad is applied to the function f itself, not to a call like f(x).
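As a small illustration of triple stacking, consider a cubic, whose third derivative is a constant:

```python
from jax import grad

def f3(x):
    return x ** 3  # f'(x) = 3x^2, f''(x) = 6x, f'''(x) = 6

d3_f3 = grad(grad(grad(f3)))
print(float(d3_f3(2.0)))  # 6.0
```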
Differentiating the Sigmoid Function
The sigmoid function is commonly used in neural networks and is more complex to differentiate manually:
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))
d_sig = grad(sigmoid)
x = 2.0
print(float(d_sig(x))) # Approximately 0.105
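The sigmoid has the well-known closed-form derivative σ'(x) = σ(x)(1 − σ(x)), so we can check JAX's answer against it in a short sketch:

```python
import jax.numpy as jnp
from jax import grad

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

d_sig = grad(sigmoid)

x = 2.0
s = sigmoid(x)
# Analytic derivative: sigma(x) * (1 - sigma(x))
print(float(d_sig(x)))     # ~0.105
print(float(s * (1 - s)))  # ~0.105
```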
Differentiating ReLU
ReLU (Rectified Linear Unit) is another popular activation function in neural networks. JAX can differentiate through conditional statements like if-else:
def ReLU(x):
    if x >= 0:
        return x
    else:
        return 0.0
d_ReLU = grad(ReLU)
print(float(d_ReLU(1.0))) # 1.0
print(float(d_ReLU(2.0))) # 1.0
print(float(d_ReLU(-1.0))) # 0.0
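One subtlety worth noting: the mathematical derivative of ReLU is undefined at x = 0. With the if-else definition above, JAX traces whichever branch the concrete input selects, so at exactly 0.0 it follows the x >= 0 branch and reports a gradient of 1.0:

```python
from jax import grad

def ReLU(x):
    if x >= 0:
        return x
    else:
        return 0.0

d_ReLU = grad(ReLU)

# At x = 0 the x >= 0 branch is taken, so the gradient of the identity (1.0)
# is reported, even though the true derivative is undefined there.
print(float(d_ReLU(0.0)))  # 1.0
```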
Partial Differentiation in JAX
JAX also handles partial derivatives for multivariate functions. For f(x, y) = x² + y² + 5xy:
def f5(x, y):
    return x**2 + y**2 + 5 * x * y
# Partial derivative with respect to x
df5_dx = grad(f5, argnums=0)
# Partial derivative with respect to y
df5_dy = grad(f5, argnums=1)
x = 2.0
y = 1.0
print(float(df5_dx(x, y))) # 2x + 5y = 2*2 + 5*1 = 9.0
print(float(df5_dy(x, y))) # 2y + 5x = 2*1 + 5*2 = 12.0
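If you need both partials at once, grad also accepts a tuple for argnums and returns a tuple of gradients, one per selected argument:

```python
from jax import grad

def f5(x, y):
    return x**2 + y**2 + 5 * x * y

# A tuple of argnums yields both partial derivatives in a single call.
df5 = grad(f5, argnums=(0, 1))
dx, dy = df5(2.0, 1.0)
print(float(dx), float(dy))  # 9.0 12.0
```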
Conclusion
This introduction provides a foundation for understanding JAX's differentiation capabilities. With these basics, you can now explore more advanced auto-differentiation features in JAX's documentation.
