How to Train a Neural Network from Scratch in JAX with Example

Updated on: August 31, 2021

If you work with neural networks and have not heard of JAX, you should. JAX is a Python library developed by Google that accelerates machine learning research by compiling NumPy-style code to run on accelerators such as GPUs and TPUs. Not only can it help you train neural networks faster, it can also save cost and energy. In this article, we will train a very simple neural network on a small toy dataset.

This article is intended for beginners who have not used JAX before, but you should know the basics of Python and neural networks to follow everything.

This article is also available in Google Colab, so feel free to run the code there interactively.


What is a neural network?

To put it simply, a neural network is a function mapper: it takes an input (e.g. numbers, images, encoded text) and calculates an output using its tunable weights (parameters). Suppose we train a neural network on images of dogs: training tunes ("adjusts") the weights of the network, using gradients computed by the backpropagation algorithm, so that when a dog appears in an image the network outputs a positive value, and otherwise a negative one.

How to train a neural network in JAX

Since this is a beginner's tutorial, we will use neither large neural networks with millions of parameters nor datasets with thousands of data points. We will train a really small network (only 16 hidden units) on the input-output values of a simple 2-bit XOR gate so that we can focus on the syntax of JAX. The code is also deliberately naive, with no fancy syntax or complex classes, so that beginners can grasp the tools in JAX. Later, you can move on to more advanced neural network training in articles like this.

What is an XOR gate?

A logic gate takes bits (0 or 1) as input and outputs a bit according to a fixed rule. For example, an AND gate returns 1 only if both of its inputs are 1, and 0 otherwise. Similarly, there are OR, NAND and NOR gates, but today we will use the XOR gate because its data is not linearly separable (read more about this): no single straight line can separate the inputs that map to 1 from the inputs that map to 0. Neural networks with a hidden layer and a non-linear activation are great at fitting such non-linear data.
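To make "not linearly separable" concrete, here is a small sketch (not part of the original tutorial) that fits the best possible linear model, w1*x1 + w2*x2 + b, to the four XOR cases by least squares using jnp.linalg.lstsq. The best linear fit ignores both inputs and predicts 0.5 everywhere:

```python
import jax.numpy as jnp

# The four XOR input pairs and their labels
X = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = jnp.array([0.0, 1.0, 1.0, 0.0])

# Linear model w1*x1 + w2*x2 + b, written as A @ [w1, w2, b]
A = jnp.hstack([X, jnp.ones((4, 1))])

# Best linear fit in the least-squares sense
coef, *_ = jnp.linalg.lstsq(A, y)
pred = A @ coef
mse = jnp.mean((pred - y) ** 2)

print("coefficients:", coef)  # close to [0, 0, 0.5]
print("predictions:", pred)   # close to 0.5 for every input
print("MSE:", mse)            # close to 0.25
```

An MSE of 0.25 is exactly what a model gets by predicting 0.5 for every input, so a loss stuck near 0.25 usually means the network has not yet found a non-linear solution.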


The input-output characteristics of a 2-bit XOR gate are as follows:

  • If the two inputs are different (1 and 0, or 0 and 1), the output is 1.
  • Otherwise, if the two inputs are the same (0 and 0, or 1 and 1), the output is 0.

Hence there are only 4 different cases, one for each combination of the two inputs, and this makes for a great toy dataset.
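As a quick sketch (an alternative to typing the table by hand), the four cases can also be enumerated with itertools.product and labeled with Python's bitwise XOR operator ^:

```python
import itertools

import jax.numpy as jnp

# Every combination of two input bits: (0, 0), (0, 1), (1, 0), (1, 1)
pairs = list(itertools.product([0, 1], repeat=2))

# XOR is 1 exactly when the two bits differ (^ is Python's bitwise XOR)
X = jnp.array(pairs, dtype=jnp.float32)                        # shape (4, 2)
Y = jnp.array([[a ^ b] for a, b in pairs], dtype=jnp.float32)  # shape (4, 1)

for x, y in zip(X, Y):
    print(x, y)
```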

Import JAX

You can install JAX with pip install jax and then import it:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

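JAX has no global random state: to generate the initial values of the weights, it needs an explicit pseudo-random number generator (PRNG) key created from a seed. So we create a key from the seed 0 and then split it into two new keys: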

key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

w_key, b_key = random.split(key)
print(w_key, b_key)

[4146024105  967050713] [2718843009 1272950319]
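Why split at all? A short sketch of how JAX keys behave: the same key always yields the same numbers, so fresh randomness comes from splitting a key and using each subkey only once.

```python
from jax import random

key = random.PRNGKey(0)

# The same key always produces the same "random" numbers...
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(bool((a == b).all()))  # True

# ...so to get new numbers, split the key and use each new subkey only once
key, subkey = random.split(key)
c = random.normal(subkey, (3,))
print(bool((a == c).all()))  # False
```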

XOR toy dataset

Let's type up the four different cases of an XOR gate into X and Y as the training dataset (much easier than dealing with MNIST, let alone messy real-world datasets), and then convert them to JAX arrays.

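# XOR data: each row of X is one pair of input bits
X = [
  [0.0, 0.0],
  [0.0, 1.0],
  [1.0, 0.0],
  [1.0, 1.0],
]

# XOR labels: 1 if the two input bits differ, else 0
Y = [
  [0.0],
  [1.0],
  [1.0],
  [0.0],
]

X = jnp.array(X)
Y = jnp.array(Y)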

Generate the random weight matrices from layer sizes

The XOR gate takes 2 inputs and provides 1 output, hence input_dims = 2 and output_dims = 1. Our simple neural network will be a fully connected network (FCN) with just one hidden layer of 16 nodes in the middle, hence n_hidden = 16.

w1 and b1 hold the weights and biases of the hidden layer with 16 nodes: w1 has a shape of (16, 2) and b1 has a shape of (16,) (just a one-dimensional vector). Similarly, w2 and b2 belong to the output layer and have shapes of (1, 16) and (1,) respectively. The scale variable does exactly what it sounds like: it multiplies the standard normal samples by 0.01, so the initial weights are small (a standard deviation of about 0.01). Small initial weights keep the untrained network's outputs close to zero, which is a common choice for a stable start to training.

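input_dims = 2    # two input bits
n_hidden = 16     # nodes in the single hidden layer
output_dims = 1   # one output value

layer1_shape = (n_hidden, input_dims)   # (16, 2)
layer2_shape = (output_dims, n_hidden)  # (1, 16)

scale = 1e-2  # multiplies standard normal samples, giving a std of about 0.01

w1 = scale * random.normal(w_key, layer1_shape)
b1 = scale * random.normal(b_key, (n_hidden,))

# Note: w_key and b_key are reused below for brevity. JAX's documentation
# recommends using each key only once, e.g. by splitting it into more keys.
w2 = scale * random.normal(w_key, layer2_shape)
b2 = scale * random.normal(b_key, (output_dims,))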
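The snippet above reuses w_key and b_key for the second layer. A minimal sketch of the "one key per random call" pattern, assuming a hypothetical helper named init_params (not from the article), splits the seed key into four subkeys. Its numbers will differ from the outputs printed in this article.

```python
from jax import random

def init_params(key, input_dims=2, n_hidden=16, output_dims=1, scale=1e-2):
    # Split one key into four subkeys, one per parameter array,
    # so no key is used more than once
    w1_key, b1_key, w2_key, b2_key = random.split(key, 4)
    return {
        'w1': scale * random.normal(w1_key, (n_hidden, input_dims)),
        'b1': scale * random.normal(b1_key, (n_hidden,)),
        'w2': scale * random.normal(w2_key, (output_dims, n_hidden)),
        'b2': scale * random.normal(b2_key, (output_dims,)),
    }

params = init_params(random.PRNGKey(0))
for name, value in params.items():
    print(name, value.shape)
```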

Let's have a look at w1. It contains random weights, which will change as the network trains. You can print the other weights too to check their shapes.

# what does the weight matrix look like?
w1

DeviceArray([[-0.00797431,  0.00609697],
              [-0.01715074, -0.00641281],
              [ 0.00207656,  0.00676562],
              [-0.01798742,  0.00526943],
              [ 0.0181759 ,  0.00267145],
              [ 0.01292056,  0.00214754],
              [-0.00112515,  0.01107192],
              [-0.008862  , -0.01382878],
              [ 0.00024038, -0.00296096],
              [ 0.02082078,  0.01441997],
              [ 0.00993177, -0.00310528],
              [ 0.02046751, -0.00105606],
              [ 0.00622171,  0.00941826],
              [ 0.01753185, -0.00755145],
              [-0.01004758,  0.00183098],
              [ 0.00619588,  0.0072066 ]], dtype=float32)

It is convenient to keep the parameters of the neural network in a dictionary, because JAX can differentiate through nested Python data structures such as dicts, lists and tuples (JAX calls these "pytrees"). So let's define the params dictionary; you will see why this is useful a little later.

params = {
    'w1': w1,
    'b1': b1,
    'w2': w2,
    'b2': b2,
}

Now, let's get into the calculations at the heart of any neural network: matrix multiplications and activation functions. ReLU is a popular activation function that is cheap to compute and works well in practice. It is also easy to understand: if the input is negative, the output is 0; if the input is positive, it passes through unchanged. Let's define a ReLU function for that:

def relu(x):
  return jnp.maximum(0, x)
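A small sketch of what relu() does, and of its slope as computed by grad() (covered in more detail later): the slope is 0 for negative inputs and 1 for positive inputs, which is why hidden units whose input is negative receive zero gradient.

```python
import jax.numpy as jnp
from jax import grad, vmap

def relu(x):
    return jnp.maximum(0, x)

xs = jnp.array([-2.0, -0.5, 0.5, 3.0])

# Negative inputs become 0, positive inputs pass through unchanged
activations = relu(xs)
print(activations)

# grad(relu) works on one scalar; vmap applies it to each element of xs
slopes = vmap(grad(relu))(xs)
print(slopes)
```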

And now, let's do a single forward pass through our untrained neural network, just to see how JAX multiplies the weights and applies the activations:

  • Let's take X[3] (which is just [1, 1]) as the sample input.
  • z1 = jnp.dot(w1, x) + b1 performs the matrix-vector multiplication for the hidden layer with 16 nodes, and a1 = relu(z1) calculates the activations.
  • Finally, z2 = jnp.dot(w2, a1) + b2 takes the activations from the hidden layer, performs the matrix-vector multiplication of the output layer and outputs a single raw value (an array of shape (1,)), which we will treat as the output of the XOR gate for now.

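# testing forward pass
x = X[3]
y = Y[3]

z1 = jnp.dot(w1, x) + b1   # hidden layer pre-activations, shape (16,)
a1 = relu(z1)              # hidden layer activations
z2 = jnp.dot(w2, a1) + b2  # output value, shape (1,)
z2
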
DeviceArray([-0.01206843], dtype=float32)
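To double-check the shapes at each step, here is a standalone sketch with freshly drawn weights of the same sizes (so its numbers differ from the output above). It also shows how to pull the single value out of the (1,)-shaped result:

```python
import jax.numpy as jnp
from jax import random

# Standalone setup: fresh keys, same layer sizes as in the article
w1_key, b1_key, w2_key, b2_key = random.split(random.PRNGKey(0), 4)
w1 = 1e-2 * random.normal(w1_key, (16, 2))
b1 = 1e-2 * random.normal(b1_key, (16,))
w2 = 1e-2 * random.normal(w2_key, (1, 16))
b2 = 1e-2 * random.normal(b2_key, (1,))

x = jnp.array([1.0, 1.0])

z1 = jnp.dot(w1, x) + b1   # (16, 2) @ (2,) -> (16,)
a1 = jnp.maximum(0, z1)    # ReLU keeps the shape: (16,)
z2 = jnp.dot(w2, a1) + b2  # (1, 16) @ (16,) -> (1,)

print(z1.shape, a1.shape, z2.shape)  # (16,) (16,) (1,)
print(z2.item())  # the single output value as a Python float
```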


The forward pass through the FCN will happen many, many times during training, so we wrap these steps in a predict() function, which takes the params dictionary we defined, along with an input (e.g. [0, 1]).

def predict(params, input):
  w1 = params['w1']
  b1 = params['b1']
  w2 = params['w2']
  b2 = params['b2']

  z1 = jnp.dot(w1, input) + b1   # hidden layer, shape (16,)
  a1 = relu(z1)
  z2 = jnp.dot(w2, a1) + b2      # output layer, shape (1,)
  return z2
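The article imports vmap but never uses it. As a sketch, vmap can turn predict(), which handles one input, into a function that handles all four inputs at once; in_axes=(None, 0) means "share params, map over the rows of X". The weights below are freshly drawn for a standalone example.

```python
import jax.numpy as jnp
from jax import random, vmap

def relu(x):
    return jnp.maximum(0, x)

def predict(params, input):
    z1 = jnp.dot(params['w1'], input) + params['b1']
    a1 = relu(z1)
    return jnp.dot(params['w2'], a1) + params['b2']

keys = random.split(random.PRNGKey(0), 4)
params = {
    'w1': 1e-2 * random.normal(keys[0], (16, 2)),
    'b1': 1e-2 * random.normal(keys[1], (16,)),
    'w2': 1e-2 * random.normal(keys[2], (1, 16)),
    'b2': 1e-2 * random.normal(keys[3], (1,)),
}
X = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])

# Share params across the batch; map over the first axis (rows) of X
batched_predict = vmap(predict, in_axes=(None, 0))
outputs = batched_predict(params, X)
print(outputs.shape)  # (4, 1)
```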

Define the loss function (mean squared error)

Every neural network is trained against a loss function, which measures how wrong the network's outputs are compared to the labels (the error). We will use the mean squared error (MSE), which is the mean of the squared differences between the prediction of the neural network and the label from the training data.

def loss(params, input, target):
  prediction = predict(params, input)
  return jnp.mean((prediction - target) ** 2)

loss(params, x, y)

DeviceArray(1.0242825, dtype=float32)

What does the loss value of MSE represent?
When the neural network is untrained, its output is wrong, so the difference between the output and the label is large and the loss is high. As the network trains and tunes its weights, its output moves closer to the true label and the loss value goes down.

Calculate gradients with respect to the weights

In deep learning, the gradients of the loss function with respect to the weight parameters (e.g. w1, b1, etc.) are the core of the backpropagation algorithm. Each gradient tells us how the loss changes when the corresponding weight changes, so we know whether to increase or decrease each weight, a little at a time, to make the output of the neural network more accurate. The magical grad() function of JAX is all you need to calculate these gradients. Let's print out the gradient matrices; they should have the same shapes as the corresponding weight matrices:

grad_fn = grad(loss)

trial_gradients = grad_fn(params, x, y)

# What does the gradient matrix look like?
print(trial_gradients['w1'])
print(trial_gradients['b1'])

[[-0.03730299 -0.03730299]
[-0.         -0.        ]
[-0.01264969 -0.01264969]
[-0.         -0.        ]
[ 0.0213216   0.0213216 ]
[ 0.0075064   0.0075064 ]
[-0.         -0.        ]
[-0.         -0.        ]
[-0.         -0.        ]
[-0.02079077 -0.02079077]
[ 0.00804559  0.00804559]
[-0.00102622 -0.00102622]
[-0.02591824 -0.02591824]
[ 0.00158182  0.00158182]
[ 0.          0.        ]
[ 0.          0.        ]]

[-0.03730299 -0.         -0.01264969 -0.          0.0213216   0.0075064
-0.         -0.         -0.         -0.02079077  0.00804559 -0.00102622
-0.02591824  0.00158182  0.          0.        ]
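Because params is a dictionary (a pytree), grad() returns a dictionary with exactly the same keys and shapes. A standalone sketch that checks this with jax.tree_util, using freshly drawn weights:

```python
import jax
import jax.numpy as jnp
from jax import grad, random

def predict(params, x):
    a1 = jnp.maximum(0, jnp.dot(params['w1'], x) + params['b1'])
    return jnp.dot(params['w2'], a1) + params['b2']

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# Standalone parameters with the article's layer sizes
keys = random.split(random.PRNGKey(0), 4)
params = {
    'w1': 1e-2 * random.normal(keys[0], (16, 2)),
    'b1': 1e-2 * random.normal(keys[1], (16,)),
    'w2': 1e-2 * random.normal(keys[2], (1, 16)),
    'b2': 1e-2 * random.normal(keys[3], (1,)),
}

grads = grad(loss)(params, jnp.array([1.0, 1.0]), jnp.array([0.0]))

# Same dictionary structure as params, one gradient array per parameter
print(jax.tree_util.tree_map(lambda g: g.shape, grads))
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
```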

Side note: If you have implemented neural networks with NumPy, you had to derive the big derivative equations by hand and then compute the gradient separately for each and every layer, with dozens of lines of code. Frameworks like TensorFlow and PyTorch automate this too, but JAX makes it especially simple: all you need is grad(), even if the network contains complicated expressions, Python if-else statements or for loops.
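A small sketch of grad() differentiating straight through a Python for-loop and an if-else. One caveat worth knowing: this works with plain grad(), but inside a jit-compiled function an if that depends on an array's value has to be written with jax.lax.cond instead.

```python
from jax import grad

def f(x):
    # A Python for-loop that builds 1 + x + x**2
    total = 1.0
    for k in range(1, 3):
        total = total + x ** k
    # A Python if-else on the value of x
    if x > 0:
        return total
    else:
        return -total

df = grad(f)
print(df(2.0))   # d/dx (1 + x + x**2) at x = 2 -> 1 + 2*2 = 5.0
print(df(-1.0))  # d/dx -(1 + x + x**2) at x = -1 -> -(1 + 2*(-1)) = 1.0
```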

Update the weights (backpropagation)

Finally, the most important part of training: updating the weights so that the network "learns" from the data and becomes more accurate. Here's a breakdown of the update() function:

  • grad_fn = grad(loss) returns a function that computes the gradient of loss with respect to the parameters (w1, b1, w2, b2). Side note: how does it know what to differentiate the loss function against? By default, grad() differentiates with respect to the first argument of the function, which is params in our case. If you wanted to make this explicit, you could use the argnums parameter, e.g. grad_fn = grad(loss, argnums=0), where 0 refers to the first argument of the loss function.
  • gradients = grad_fn(params, input, target) returns the gradient values as a dictionary with the same keys as the params dictionary. Hence, you can access the gradient values like dw1 = gradients['w1'].
  • The weights are then updated by subtracting the gradients, scaled by step_size, from the current weight values. The step_size is also commonly known as the 'learning rate' and controls how big each update step is, and therefore how fast the weights move towards good values. This is a hyperparameter, so you can play around with it; for example, try 0.1 or 0.001.
  • At the end, update() returns the new weight matrices, and these become our new params.
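A tiny sketch of the argnums parameter on a toy function f(a, b) = a * b**2 (not from the article), showing the default, an explicit argument index, and a tuple of indices:

```python
from jax import grad

def f(a, b):
    return a * b ** 2

print(grad(f)(3.0, 2.0))                  # df/da = b**2 = 4.0
print(grad(f, argnums=0)(3.0, 2.0))       # same as the default: 4.0
print(grad(f, argnums=1)(3.0, 2.0))       # df/db = 2*a*b = 12.0
print(grad(f, argnums=(0, 1))(3.0, 2.0))  # a tuple with both gradients
```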

step_size = 0.01

def update(params, input, target):
  grad_fn = grad(loss)
  gradients = grad_fn(params, input, target)
  dw1 = gradients['w1']
  db1 = gradients['b1']
  dw2 = gradients['w2']
  db2 = gradients['b2']

  new_params = {}
  new_params['w1'] = params['w1'] - step_size * dw1
  new_params['b1'] = params['b1'] - step_size * db1
  new_params['w2'] = params['w2'] - step_size * dw2
  new_params['b2'] = params['b2'] - step_size * db2

  return new_params

update(params, x, y)

{'b1': DeviceArray([ 0.00245504, -0.0105805 , -0.00281096, -0.00441172,
              0.00215377, -0.0004177 , -0.01002556,  0.01156011,
              -0.00538138, -0.00468898,  0.00241345, -0.0141186 ,
              0.01880229,  0.00225983,  0.00497515, -0.02089684],            dtype=float32),
'b2': DeviceArray([0.00772598], dtype=float32),
'w1': DeviceArray([[-0.00760128,  0.00647   ],
              [-0.01715074, -0.00641281],
              [ 0.00220305,  0.00689212],
              [-0.01798742,  0.00526943],
              [ 0.01796268,  0.00245824],
              [ 0.0128455 ,  0.00207248],
              [-0.00112515,  0.01107192],
              [-0.008862  , -0.01382878],
              [ 0.00024038, -0.00296096],
              [ 0.02102869,  0.01462787],
              [ 0.00985131, -0.00318574],
              [ 0.02047778, -0.0010458 ],
              [ 0.0064809 ,  0.00967745],
              [ 0.01751603, -0.00756727],
              [-0.01004758,  0.00183098],
              [ 0.00619588,  0.0072066 ]], dtype=float32),
'w2': DeviceArray([[ 0.01843323,  0.01236895,  0.00636894,  0.01235389,
              -0.01006379, -0.00341038,  0.00302018,  0.00460129,
                0.00601732,  0.01088563, -0.00378617,  0.00061392,
                0.0134965 , -0.0005334 , -0.01437746, -0.00942327]],            dtype=float32)}
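update() handles each parameter by name. Since params is a pytree, the same rule can be applied to every array at once with jax.tree_util.tree_map. Below is a sketch with a hypothetical sgd_update helper and freshly drawn weights, checked against the key-by-key rule:

```python
import jax
import jax.numpy as jnp
from jax import grad, random

def predict(params, x):
    a1 = jnp.maximum(0, jnp.dot(params['w1'], x) + params['b1'])
    return jnp.dot(params['w2'], a1) + params['b2']

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

def sgd_update(params, x, y, step_size=0.01):
    # Same rule as update(): param - step_size * gradient, but applied to
    # every array in the params dict at once instead of key by key
    grads = grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

keys = random.split(random.PRNGKey(0), 4)
params = {
    'w1': 1e-2 * random.normal(keys[0], (16, 2)),
    'b1': 1e-2 * random.normal(keys[1], (16,)),
    'w2': 1e-2 * random.normal(keys[2], (1, 16)),
    'b2': 1e-2 * random.normal(keys[3], (1,)),
}
x, y = jnp.array([1.0, 1.0]), jnp.array([0.0])

new_params = sgd_update(params, x, y)

# Compare with the explicit, key-by-key rule used in update()
grads = grad(loss)(params, x, y)
for name in params:
    expected = params[name] - 0.01 * grads[name]
    print(name, bool(jnp.allclose(new_params[name], expected)))
```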

Just for the sake of understanding, let's check the outputs of our untrained neural network:

for x in X:
  result = predict(params, x)
  print(x, float(result))

[0. 0.] -0.012294544838368893
[0. 1.] -0.012028749100863934
[1. 0.] -0.01234700158238411
[1. 1.] -0.012068428099155426

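The outputs make no sense yet: the network returns almost the same small negative value (about -0.012) for every input, because its weights are still small random values.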

Training loop

And now, we shall train the network. Training means passing the data points through the network, computing the loss and its gradients with respect to the parameters, and then updating the parameters, repeatedly, until the loss is low enough. One epoch is one pass over all four data points, so running 500 epochs of training simply means calling the update function 500 × 4 = 2000 times inside nested for-loops. A bare-minimum training loop looks like this:

for epoch in range(500):
  for x, y in zip(X, Y):
    params = update(params, x, y)

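But this loop doesn't print anything as training goes on, so we cannot observe quantities such as the loss. So, we will record the average loss of every epoch and print it every 10 epochs. (Because params is reassigned inside the loop, running this after the loop above continues training from the parameters it produced.)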

losses = []
for epoch in range(500):
  avg_loss = 0
  for x, y in zip(X, Y):
    params = update(params, x, y)
    avg_loss += loss(params, x, y)

  # average the loss over the 4 training samples and store it for plotting
  avg_loss = avg_loss / len(X)
  losses.append(float(avg_loss))
  if epoch % 10 == 0:
    print(f"Epoch #{epoch}, loss = {avg_loss}")

Epoch #0, loss = 0.24474859237670898
Epoch #10, loss = 0.24462637305259705
Epoch #20, loss = 0.24456220865249634
Epoch #30, loss = 0.24450576305389404
...
Epoch #450, loss = 0.18149125576019287
Epoch #460, loss = 0.17451581358909607
Epoch #470, loss = 0.16796663403511047
Epoch #480, loss = 0.16084203124046326
Epoch #490, loss = 0.15317147970199585

The loss values are declining as expected. Let's make a simple plot to observe the trend:

import matplotlib.pyplot as plt

plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Average MSE loss")
plt.show()


Finally, let's check whether our neural network behaves like an XOR gate. The raw output of the network is not exactly 0 or 1, so we post-process it with float(result) > 0.5, which turns it into True (representing 1) or False (representing 0). Here's the output:

for x in X:
  result = predict(params, x)
  print(x, float(result) > 0.5)

[0. 0.] False
[0. 1.] True
[1. 0.] True
[1. 1.] False
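As a possible next step (not part of the original tutorial), here is a sketch that uses the jit function imported at the top of the article, together with jax.value_and_grad and vmap, to train on all four samples at once (full-batch gradient descent) instead of one sample at a time. The initial scale of 0.5, the step size of 0.05 and the 2000 steps are untuned assumptions, so check the printed predictions against the XOR table rather than assuming they match.

```python
import jax
import jax.numpy as jnp
from jax import jit, random, value_and_grad, vmap

X = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0], [1.0], [1.0], [0.0]])

SCALE = 0.5       # assumption: larger initial weights than the article's 0.01
STEP_SIZE = 0.05  # assumption: untuned learning rate for this sketch

def predict(params, x):
    a1 = jnp.maximum(0, jnp.dot(params['w1'], x) + params['b1'])
    return jnp.dot(params['w2'], a1) + params['b2']

def batch_loss(params, X, Y):
    # vmap evaluates predict on all four rows of X at once -> shape (4, 1)
    preds = vmap(predict, in_axes=(None, 0))(params, X)
    return jnp.mean((preds - Y) ** 2)

@jit
def train_step(params, X, Y):
    # value_and_grad returns the loss and its gradients from a single call
    loss_value, grads = value_and_grad(batch_loss)(params, X, Y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)
    return new_params, loss_value

keys = random.split(random.PRNGKey(0), 4)
params = {
    'w1': SCALE * random.normal(keys[0], (16, 2)),
    'b1': SCALE * random.normal(keys[1], (16,)),
    'w2': SCALE * random.normal(keys[2], (1, 16)),
    'b2': SCALE * random.normal(keys[3], (1,)),
}

losses = []
for step in range(2000):
    params, loss_value = train_step(params, X, Y)
    losses.append(float(loss_value))

print("first loss:", losses[0], "last loss:", losses[-1])
preds = vmap(predict, in_axes=(None, 0))(params, X)
print(preds[:, 0] > 0.5)
```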


Hopefully, this article introduced you to some tools in JAX and showed how you can set up a simple neural network. Neither the code nor the dataset is up to any professional standard; they are simply meant to make things easy for novice learners. After this, you can move on to bigger and more complex neural networks and better code in articles like this:

Have feedback, or facing problems with this tutorial? Please leave a comment and we will get back to you!

Shah Yasser Aziz
He is a tech and MOOC enthusiast who likes to code and build web 2.0 applications. He also likes to learn new tech stacks in the industry to sharpen his skills.
