How to Train a Neural Network from Scratch in JAX, with an Example

If you work with neural nets and have not heard about JAX, you need to. JAX is a Python library developed by Google that accelerates machine learning research by using specialized hardware such as GPUs and TPUs. Not only will it help you train neural networks faster, but you will also save cost and energy. In this article, we will train a very simple neural network on a small toy dataset.
This article is intended for beginners who have not used JAX before, but you should know the basics of Python and neural networks to understand everything.
This article is also available in Google Colab, so feel free to interactively run the code there.
What is a neural network?
To put it simply, a neural network is a function mapper: it takes an input (e.g. numbers, images, encoded text) and calculates an output using the tunable weights/parameters it has. Suppose we train a neural network on images of dogs; the backpropagation algorithm changes ("tunes") the weights of the network so that whenever the patterns of dog images show up in an image, the network outputs a positive result, and otherwise a negative one.
How to train a neural network in JAX
Since this is a beginner's tutorial, we will neither use large neural networks with millions of parameters nor datasets with thousands of data points. We will train a really small network (only 16 hidden units) on the input-output values of a simple 2-bit XOR gate so that we can focus on the syntax of JAX. The code is also deliberately naive, with no fancy syntax or complex classes, so that beginners can grasp the tools in JAX. Later, you can move on to more advanced neural network training in articles like the ones linked in the conclusion.
What is an XOR gate?
An XOR gate is a logic gate that takes bits as input (0 or 1) and outputs bits according to its characteristics. For example, an AND gate only returns 1 if both of its inputs are 1, else it outputs 0. Similarly, there are OR, NAND, and NOR gates, but today we will use the XOR gate, because its data is not linearly separable, and neural networks are great at fitting non-linear data points.
The input-output characteristics of a 2-bit XOR gate are as follows:
- If both inputs are different (1 and 0, or 0 and 1), then the output is 1.
- Else, if both inputs are the same (0 and 0, or 1 and 1), then the output is 0.
Hence we only have 4 different cases, one for each combination of the two inputs, and this makes for a great toy dataset.
Import JAX
You can install JAX with pip install jax, and then import it:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
JAX handles randomness explicitly: it needs a random seed, called a key, to generate the initial values of the weights, so we do this:
key = random.PRNGKey(0)
key
Output:
DeviceArray([0, 0], dtype=uint32)
Whenever we need fresh random numbers, we split the key into subkeys; here we create one key for the weights and one for the biases:
w_key, b_key = random.split(key)
print(w_key, b_key)
Output:
[4146024105 967050713] [2718843009 1272950319]
XOR toy dataset
Let's type up the four different cases of an XOR gate into X and Y as the training dataset (much easier than dealing with MNIST, let alone unclean datasets), and then convert them to JAX arrays. A small heads-up: as typed below, Y is actually flipped relative to the truth table above (it holds XNOR labels). We will compensate for this at the very end of the article, when we threshold the outputs, so the final results still read as XOR.
# XOR data
X = [
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0],
]
Y = [
1.0,
0.0,
0.0,
1.0
]
X = jnp.array(X)
Y = jnp.array(Y)
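As a quick sanity check (this snippet is our own addition, not in the original article), the arrays should hold four samples of two input bits each:
X.shape, Y.shape
Output:
((4, 2), (4,))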
Generate the random weight matrices from layer sizes
The XOR gate takes 2 inputs and provides 1 output, hence input_dims = 2 and output_dims = 1. Our simple neural network will be a fully connected network (FCN) with just one hidden layer in the middle, with only 16 nodes, hence n_hidden = 16.
w1 and b1 contain the weight values of the first hidden layer with 16 nodes. w1 has a shape of (16, 2) and b1 has a shape of (16,) (just a single-dimensional vector). Similarly, w2 and b2 belong to the output layer and have shapes of (1, 16) and (1,) respectively. The scale variable does exactly what it sounds like: it scales the random weights down to a standard deviation of about 0.01, because small initial weights tend to train better.
input_dims = 2
n_hidden = 16
output_dims = 1
layer1_shape = (n_hidden, input_dims)
layer2_shape = (output_dims, n_hidden)
scale = 1e-2
w1 = scale * random.normal(w_key, layer1_shape)
b1 = scale * random.normal(b_key, (n_hidden,))
w2 = scale * random.normal(w_key, layer2_shape)
b2 = scale * random.normal(b_key, (output_dims,))
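Side note: the code above reuses w_key for both weight matrices and b_key for both bias vectors, which goes against JAX's advice of never reusing a PRNG key (it works for this toy example, but the draws are not independent). A cleaner pattern is to split a fresh subkey per parameter; here is a minimal sketch (init_params is our own helper, not part of the original code):
def init_params(key, input_dims=2, n_hidden=16, output_dims=1, scale=1e-2):
    # split a fresh subkey for every parameter so no key is reused
    k_w1, k_b1, k_w2, k_b2 = random.split(key, 4)
    return {
        'w1': scale * random.normal(k_w1, (n_hidden, input_dims)),
        'b1': scale * random.normal(k_b1, (n_hidden,)),
        'w2': scale * random.normal(k_w2, (output_dims, n_hidden)),
        'b2': scale * random.normal(k_b2, (output_dims,)),
    }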
Let's have a look at w1: it contains random weights, which will change as the network trains. You can print the other weights too to understand their shapes.
# what does the weight matrix look like?
w1
Output:
DeviceArray([[-0.00797431, 0.00609697],
[-0.01715074, -0.00641281],
[ 0.00207656, 0.00676562],
[-0.01798742, 0.00526943],
[ 0.0181759 , 0.00267145],
[ 0.01292056, 0.00214754],
[-0.00112515, 0.01107192],
[-0.008862 , -0.01382878],
[ 0.00024038, -0.00296096],
[ 0.02082078, 0.01441997],
[ 0.00993177, -0.00310528],
[ 0.02046751, -0.00105606],
[ 0.00622171, 0.00941826],
[ 0.01753185, -0.00755145],
[-0.01004758, 0.00183098],
[ 0.00619588, 0.0072066 ]], dtype=float32)
It will be convenient to keep the parameters of the neural network in a dictionary, since JAX can easily differentiate through Python data structures like dicts (JAX calls them pytrees), so let's define the params dictionary. You will see the usefulness of this a little later.
params = {
'w1': w1,
'b1': b1,
'w2': w2,
'b2': b2,
}
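Because params is a pytree, JAX utilities can operate on every array in it at once. For example (our own snippet, not from the original article), we can list all the parameter shapes in one line:
import jax
jax.tree_util.tree_map(lambda p: p.shape, params)
Output:
{'b1': (16,), 'b2': (1,), 'w1': (16, 2), 'w2': (1, 16)}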
Now, let's get into the calculations at the heart of any neural network: matrix multiplications and activation functions. ReLU is a very popular activation function that provides great performance, and it is very easy to understand: if the input is negative, the output is 0; if the input is positive, it passes through as the output unchanged. Let's define a relu function for that:
def relu(x):
    return jnp.maximum(0, x)
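A quick sanity check (our own snippet): negative entries are zeroed out, positive ones pass through:
relu(jnp.array([-2.0, 0.0, 3.0]))  # -> [0., 0., 3.]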
And now, let's do a single forward pass through our untrained neural network, just to see how JAX can multiply the weights and apply the activations:
- Let's take X[3] (which is just [1, 1]) as the sample input.
- z1 = jnp.dot(w1, x) + b1 performs the matrix multiplication for the hidden layer with 16 nodes, and a1 = relu(z1) calculates the activations.
- Finally, z2 = jnp.dot(w2, a1) + b2 takes the activations from the hidden layer, does the matrix multiplication of the output layer, and outputs a single scalar value (i.e. a logit), which we will consider as the output of the XOR gate for now.
# testing forward pass
x = X[3]
y = Y[3]
z1 = jnp.dot(w1, x) + b1
a1 = relu(z1)
z2 = jnp.dot(w2, a1) + b2
z2
Output:
DeviceArray([-0.01206843], dtype=float32)
The forward pass through the FCN is something that will happen many, many times during training, hence we wrap these steps in a predict() function, which takes the params dictionary we defined, along with an input (e.g. [0, 1]).
def predict(params, input):
    w1 = params['w1']
    b1 = params['b1']
    w2 = params['w2']
    b2 = params['b2']
    z1 = jnp.dot(w1, input) + b1
    a1 = relu(z1)
    z2 = jnp.dot(w2, a1) + b2
    return z2
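As a quick check (our own snippet), calling it on the same sample x = X[3] should reproduce the z2 value from the manual forward pass above:
predict(params, x)  # -> DeviceArray([-0.01206843], dtype=float32)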
Define the loss function (mean squared error)
Now, any neural network trains against a loss function, which tells the network how wrong it is compared to the labels (the error value). We will use the mean squared error (MSE) loss: the mean of the squared difference between the prediction/output of the neural network and the label from the training data.
def loss(params, input, target):
    prediction = predict(params, input)
    return jnp.mean((prediction - target) ** 2)
loss(params, x, y)
Output:
DeviceArray(1.0242825, dtype=float32)
What does the loss value of MSE represent? When the neural network is untrained, its output is wrong, so the difference between the output and the label is large, and thus the loss is large. However, as the neural network trains and tunes its weights, the output gets closer and closer to the true label, and the loss value goes down.
Calculate gradients with respect to the weights
In deep learning, the gradients of the loss function with respect to the weight parameters (e.g. w1, b1, etc.) are the core concept in the backpropagation algorithm. The gradients provide the signals for tuning the weights, telling each weight whether to increase or decrease incrementally so that the output of the neural network becomes more accurate. And the magical grad() function of JAX is all you need to calculate these gradients. Let's print out and look at the gradient matrices; they should have the same shapes as the corresponding weight matrices:
grad_fn = grad(loss)
trial_gradients = grad_fn(params, x, y)
# What does the gradient matrix look like?
print(trial_gradients['w1'])
print(trial_gradients['b1'])
Output:
[[-0.03730299 -0.03730299]
[-0. -0. ]
[-0.01264969 -0.01264969]
[-0. -0. ]
[ 0.0213216 0.0213216 ]
[ 0.0075064 0.0075064 ]
[-0. -0. ]
[-0. -0. ]
[-0. -0. ]
[-0.02079077 -0.02079077]
[ 0.00804559 0.00804559]
[-0.00102622 -0.00102622]
[-0.02591824 -0.02591824]
[ 0.00158182 0.00158182]
[ 0. 0. ]
[ 0. 0. ]]
[-0.03730299 -0. -0.01264969 -0. 0.0213216 0.0075064
-0. -0. -0. -0.02079077 0.00804559 -0.00102622
-0.02591824 0.00158182 0. 0. ]
Side note: if you have implemented neural networks with NumPy, you had to code up the big derivative equations and then compute the gradient separately for each and every layer, with dozens of lines of code. Even in frameworks like TensorFlow and PyTorch, the autodifferentiation is more complicated than this. This is where JAX is so much simpler, since all you need is grad(), even if the network contains very complicated expressions, if-else statements, or for loops.
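To illustrate that last point (this snippet is our own addition, not from the article), grad() differentiates straight through ordinary Python control flow:
def piecewise(x):
    # a function with an if-branch; grad() handles it just fine
    if x > 0:
        return x ** 2
    return -x

print(grad(piecewise)(3.0))   # 6.0, the derivative of x**2 at x = 3
print(grad(piecewise)(-2.0))  # -1.0, the derivative of -x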
Update the weights (backpropagation)
Finally, the most important part of training: updating the weights so that the network "learns" from the data and becomes more accurate. Here's the breakdown of the update() function:
- grad_fn = grad(loss) returns a differentiated version of the loss function with respect to the parameters (w1, b1, etc.). Side note: how does it know what to differentiate the loss function against? The grad() function by default differentiates with respect to the first argument of the loss function, which is params in our case. If you wanted to mention this explicitly, you could do so using the argnums parameter, e.g. grad_fn = grad(loss, argnums=0), where 0 means the first argument of the loss function.
- gradients = grad_fn(params, input, target) provides the gradient values as a dictionary, just like the params dictionary. Hence, you can access the gradient values like dw1 = gradients['w1'].
- Finally, the weights are updated by subtracting step_size times the gradients from the current weight values. The step_size is also commonly known as the 'learning rate' and controls how fast the weights converge to their best values. This is a hyperparameter, so you can play around with it; for example, try 0.1 or 0.001.
- At the end, the function returns the new weight matrices, which become our new params.
step_size = 0.01

def update(params, input, target):
    grad_fn = grad(loss)
    gradients = grad_fn(params, input, target)
    dw1 = gradients['w1']
    db1 = gradients['b1']
    dw2 = gradients['w2']
    db2 = gradients['b2']
    new_params = {}
    new_params['w1'] = params['w1'] - step_size * dw1
    new_params['b1'] = params['b1'] - step_size * db1
    new_params['w2'] = params['w2'] - step_size * dw2
    new_params['b2'] = params['b2'] - step_size * db2
    return new_params
update(params, x, y)
Output:
{'b1': DeviceArray([ 0.00245504, -0.0105805 , -0.00281096, -0.00441172,
0.00215377, -0.0004177 , -0.01002556, 0.01156011,
-0.00538138, -0.00468898, 0.00241345, -0.0141186 ,
0.01880229, 0.00225983, 0.00497515, -0.02089684], dtype=float32),
'b2': DeviceArray([0.00772598], dtype=float32),
'w1': DeviceArray([[-0.00760128, 0.00647 ],
[-0.01715074, -0.00641281],
[ 0.00220305, 0.00689212],
[-0.01798742, 0.00526943],
[ 0.01796268, 0.00245824],
[ 0.0128455 , 0.00207248],
[-0.00112515, 0.01107192],
[-0.008862 , -0.01382878],
[ 0.00024038, -0.00296096],
[ 0.02102869, 0.01462787],
[ 0.00985131, -0.00318574],
[ 0.02047778, -0.0010458 ],
[ 0.0064809 , 0.00967745],
[ 0.01751603, -0.00756727],
[-0.01004758, 0.00183098],
[ 0.00619588, 0.0072066 ]], dtype=float32),
'w2': DeviceArray([[ 0.01843323, 0.01236895, 0.00636894, 0.01235389,
-0.01006379, -0.00341038, 0.00302018, 0.00460129,
0.00601732, 0.01088563, -0.00378617, 0.00061392,
0.0134965 , -0.0005334 , -0.01437746, -0.00942327]], dtype=float32)}
Just for the sake of understanding, let's check the outputs of our untrained neural network:
for x in X:
    result = predict(params, x)
    print(x, float(result))
Output:
[0. 0.] -0.012294544838368893
[0. 1.] -0.012028749100863934
[1. 0.] -0.01234700158238411
[1. 1.] -0.012068428099155426
The outputs obviously make no sense and are essentially random.
Training loop
And now, we shall train the network. Training means passing the data points through the network, computing the loss and the gradients using the target labels, and then updating the parameters, over and over, until the accuracy is high enough. So, when we run 500 epochs of training, we are simply running the update function on each data point 500 times in a for-loop. A bare-minimum training loop would look like this:
for epoch in range(500):
    for x, y in zip(X, Y):
        params = update(params, x, y)
But obviously, this does not print anything as the training goes on, and we are not observing any quantities, like the loss. So, we will record the average loss value in each epoch and print it every 10 epochs:
losses = []
for epoch in range(500):
    avg_loss = 0
    for x, y in zip(X, Y):
        params = update(params, x, y)
        avg_loss += loss(params, x, y)
    avg_loss = avg_loss / 4
    losses.append(avg_loss)
    if epoch % 10 == 0:
        print(f"Epoch #{epoch}, loss = {avg_loss}")
Output:
Epoch #0, loss = 0.24474859237670898
Epoch #10, loss = 0.24462637305259705
Epoch #20, loss = 0.24456220865249634
Epoch #30, loss = 0.24450576305389404
...
Epoch #450, loss = 0.18149125576019287
Epoch #460, loss = 0.17451581358909607
Epoch #470, loss = 0.16796663403511047
Epoch #480, loss = 0.16084203124046326
Epoch #490, loss = 0.15317147970199585
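Side note: we imported jit at the top but never used it. As an optional sketch (our own addition, not part of the original article), you can jit-compile the update function before running the loop; the first call compiles it, and every later call reuses the fast compiled version, while the loop code itself stays the same:
update = jit(update)  # drop-in replacement for the plain update function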
The loss values are declining as expected; let's do a simple plot to observe the trend:
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()
Finally, let's check whether our neural network is behaving like an XOR gate. The raw output of the neural net is not exactly 0 or 1, so we post-process it with float(result) < 0.5 to turn it into True or False. Note the direction of the comparison: since the labels in Y were the complement of the XOR outputs, thresholding with < 0.5 (rather than > 0.5) flips the results back, so True represents 1 and False represents 0 in the XOR table. Here's the output:
for x in X:
    result = predict(params, x)
    print(x, float(result) < 0.5)
Output:
[0. 0.] False
[0. 1.] True
[1. 0.] True
[1. 1.] False
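One more optional improvement (again our own sketch, not from the article): the vmap we imported at the start can vectorize predict over the whole batch at once, replacing the Python loop:
batched_predict = vmap(predict, in_axes=(None, 0))  # share params, map over rows of X
print(batched_predict(params, X) < 0.5)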
Conclusion
Hopefully, this article introduced you to some tools in JAX and how you can set up a simple neural network. Neither the code nor the dataset is up to any professional standard; they are simply meant to make things easy for novice learners. After this, you can move on to bigger and more complex neural networks and better code in articles like these:
(JAX Documentation) Training a Simple Neural Network, with tensorflow/datasets Data Loading
How to code Differentiation in JAX with Simple Examples
Have feedback, or facing some problems with this tutorial? Please comment; we will definitely get back to you!
