Exploding and vanishing gradients
The following code explores how gradients explode or vanish when backpropagating through a neural network. Whether they explode or vanish depends on the activation function you use and on how many layers the network has.
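As a rough sketch of why depth matters (my own illustration, not part of the experiment below): backpropagation multiplies roughly one factor per layer, so the gradient scale behaves like a product of per-layer terms. A long product of factors below 1 collapses towards zero, and a long product of factors above 1 blows up:

# back-of-the-envelope: the gradient scale behaves like a product of
# per-layer factors, so depth compounds whatever each layer contributes
for factor in (0.5, 1.5):
    scale = 1.0
    for _ in range(15):  # 15 layers, matching the default n_layers below
        scale *= factor
    print(f"per-layer factor {factor}: scale after 15 layers = {scale:.6f}")
# 0.5 -> about 0.00003 (vanishing), 1.5 -> about 438 (exploding)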
Observations for some activation functions (a short numerical check of their derivatives follows this list):
- Sigmoid tends to lead to vanishing gradients.
- With relu you sometimes get a lot of zeros in the gradient.
- I didn’t see any vanishing gradients with relu, but I did see exploding gradients. The same happened with tanh, which was unexpected; I expected tanh to show vanishing gradients.
- With relu, I didn’t see a correspondence between a negative value for a given weight and a zero gradient for that weight. This was also unexpected.
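These observations line up with the derivative ranges of the activations. The following standalone check (not derived from the experiment’s output) makes that concrete: sigmoid’s derivative never exceeds 0.25, tanh’s reaches 1, and relu’s is exactly 0 or 1.

import torch

# derivative ranges of the three activations, computed numerically
x = torch.linspace(-5, 5, 1001, requires_grad=True)
for name, act in [("sigmoid", torch.sigmoid), ("tanh", torch.tanh), ("relu", torch.relu)]:
    grad, = torch.autograd.grad(act(x).sum(), x)
    print(f"{name:7s} derivative: min={grad.min().item():.4f}  max={grad.max().item():.4f}")

# sigmoid's derivative is at most 0.25, so stacking many sigmoid layers keeps
# multiplying by small numbers, which matches the vanishing gradients.
# relu's derivative is exactly 0 or 1: the zeros match the many zero entries in
# the gradients below, and the ones mean relu itself never shrinks the gradient.
# tanh's derivative reaches 1, so it damps gradients less than sigmoid does.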
Here’s the code:
import numpy as np
import torch

def print_gradients(f=torch.Tensor.relu, n_layers=15,
                    layer_size=3, normalise_x=True, print_weights=False):
    """
    Print the gradients of the intermediate layers of a simple dense neural
    network at initialisation. No training of the network is done.

    All intermediate layers are square with shape (`layer_size`, `layer_size`),
    except the last layer, which has shape (`layer_size`, 1).

    The function creates some sample data so that a sample loss can be
    computed. The loss is for a regression problem: the sum of squared errors.

    Params:
        f: activation function, like torch.Tensor.relu or torch.Tensor.tanh
        n_layers: how many layers to have in the network
        layer_size: the size of these layers
        normalise_x: set to True to set x to (x - mean(x)) / std(x)
        print_weights: set to True to print the weights of each layer along
            with its gradient
    """
    l = dict()
    n = n_layers
    w = layer_size
    # attach the activation to Tensor as .f() so it can be chained after matmul
    torch.Tensor.f = f
    # create some sample data
    x = torch.randint(low=-10, high=10, size=(100, w), dtype=torch.float)
    y = torch.randint(low=0, high=5, size=(100,), dtype=torch.float)
    if normalise_x: x = (x - x.mean(dim=0)) / x.std(dim=0)
    # create random intermediate weight layers
    for i in range(n):
        name = 'w' + str(i)
        size = (w, w) if i < (n - 1) else (w, 1)
        l[name] = torch.randn(size, dtype=torch.float, requires_grad=True)
    # forward pass, loss function, calculate gradients of intermediate layers
    tmp = f(x)
    for i in range(n):
        name = 'w' + str(i)
        if i < (n - 1): tmp = tmp.matmul(l[name]).f()
        else: tmp = tmp.matmul(l[name])
    yp = tmp
    # flatten yp from (100, 1) to (100,) so the subtraction does not broadcast
    L = (yp.view(-1) - y).pow(2).sum()
    L.backward()
    # print out the gradients
    for i in range(n):
        name = 'w' + str(i)
        print("####", name, "####")
        if print_weights: print("weights\n", l[name].detach().numpy().round(4))
        print("grad\n", str(l[name].grad.numpy().round(4)), "\n")
Here is how it is used:
n_layers = 10
f = torch.Tensor.relu
print_gradients(f, n_layers, layer_size=4, print_weights=True)
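To reproduce the observations about the other activations, the same call can be repeated with a different activation function. The comments reflect the behaviour reported in the observations above; individual runs vary with the random initialisation:

print_gradients(torch.Tensor.sigmoid, n_layers, layer_size=4)  # gradients shrink towards zero with depth
print_gradients(torch.Tensor.tanh, n_layers, layer_size=4)     # exploded rather than vanished in my runs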
Some example output from the relu call above:
#### w0 ####
weights
[[-1.86 0.6851 -0.5556 -0.5878]
[-0.5929 2.1052 -0.4813 -0.7117]
[-0.3172 -0.0884 1.0429 0.6922]
[-1.8281 -0.8995 1.0208 -0.9727]]
grad
[[ 0. 5713.2656 -1374.9153 0. ]
[ 0. 9508.161 -5190.595 1170.9376]
[ 0. 4955.623 -5194.9365 1775.924 ]
[ 0. 2636.087 -1891.1196 181.3048]]
#### w1 ####
weights
[[ 0.6907 0.1203 -0.4189 -2.7348]
[ 0.7943 1.1307 1.3316 -0.8833]
[-0.9451 -0.6082 0.3866 0.8421]
[ 1.3765 1.2247 0.4677 -0.4686]]
grad
[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[-1.5590840e+02 2.6379445e+04 -6.4444326e+03 0.0000000e+00]
[-6.6746460e+02 6.1963838e+03 -2.4536016e+03 0.0000000e+00]
[-2.0615299e+01 2.3536490e+02 -8.6139603e+01 0.0000000e+00]]
#### w2 ####
weights
[[ 0.1008 0.1396 -1.4134 -0.1433]
[ 1.8467 -0.8925 0.0248 -1.4145]
[-0.2846 0.4181 -1.2936 -0.1676]
[ 0.3009 2.1765 -1.7931 0.2762]]
grad
[[ 7417.276 -1930.0452 0. 0. ]
[ 11098.887 -6554.655 0. 0. ]
[ 13884.539 -13437.387 0. 0. ]
[ 0. 0. 0. 0. ]]
#### w3 ####
weights
[[-0.7117 -0.4262 -1.4683 0.0227]
[-1.0138 1.305 -0.5682 -0.8239]
[ 0.5405 -0.4138 1.5537 -0.6959]
[-1.1945 1.0653 -2.0544 1.5107]]
grad
[[0.000000e+00 0.000000e+00 0.000000e+00 7.618817e+05]
[0.000000e+00 0.000000e+00 0.000000e+00 4.573620e+01]
[0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00]
[0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00]]
#### w4 ####
weights
[[ 2.4130e-01 9.0000e-04 -6.2140e-01 -1.4420e-01]
[ 1.5490e+00 -8.5590e-01 9.4570e-01 -9.3770e-01]
[-3.0020e-01 2.0068e+00 1.1331e+00 -6.4380e-01]
[ 2.8960e-01 8.1560e-01 1.8700e-02 -1.3460e-01]]
grad
[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[-11734.712 24611.941 30977.482 0. ]]
#### w5 ####
weights
[[ 0.8595 0.2263 0.2642 -0.9272]
[ 0.5051 0.478 2.3984 -0.8763]
[-2.3839 -1.2242 -1.0006 0.6598]
[ 0.1359 -0.8366 1.4181 0.0567]]
grad
[[ -4956.994 -1142.9338 4243.301 0. ]
[-13960.942 -3218.9739 11950.89 0. ]
[ -320.1633 -73.82 274.0672 0. ]
[ 0. 0. 0. 0. ]]
#### w6 ####
weights
[[-1.6302 -0.3628 -0.5336 0.4115]
[-0.3475 -1.2729 0.0135 -0.0092]
[ 1.3637 -0.3827 0.3039 -0.3609]
[-0.2053 -0.7371 0.0209 1.2956]]
grad
[[ 6943.7407 0. -1444.9719 0. ]
[ 4873.084 0. -1014.0744 0. ]
[22692.582 0. -4722.259 0. ]
[ 0. 0. 0. 0. ]]
#### w7 ####
weights
[[ 0.7484 0.1641 0.3732 1.5976]
[-0.3619 1.863 -1.3552 0.4618]
[ 0.4905 0.2353 -0.4943 0.2616]
[-0.004 -0.4753 -0.1268 0.2372]]
grad
[[-10202.121 1074.2303 5650.8564 14572.965 ]
[ 0. 0. 0. 0. ]
[ -1852.5332 195.0621 1026.1002 2646.205 ]
[ 0. 0. 0. 0. ]]
#### w8 ####
weights
[[-0.515 0.6908 0.0826 1.5223]
[-0.3799 0.8091 0.0506 -1.2749]
[ 0.6146 1.4947 0.8942 -0.1798]
[ 1.3187 -0.5813 1.2728 -0.3163]]
grad
[[ 6438.6147 -893.7505 1732.8018 -3122.6318]
[ 1589.804 -220.6823 427.8584 -771.0312]
[ 2178.9844 -302.467 586.4224 -1056.7747]
[12648.798 -1755.7922 3404.1267 -6134.4775]]
#### w9 ####
weights
[[-2.6363]
[ 0.366 ]
[-0.7095]
[ 1.2786]]
grad
[[-5348.286 ]
[ -621.3683]
[-7078.1904]
[-1282.6868]]
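The full matrices above make the overall trend harder to scan than it needs to be. One option (a hypothetical tweak, not part of the function as written) is to replace the printing loop at the end of print_gradients with a one-line-per-layer summary of the gradient scale:

# hypothetical replacement for the final printing loop in print_gradients;
# n and l are the variables already defined inside the function
for i in range(n):
    name = 'w' + str(i)
    g = l[name].grad
    print(f"{name}: mean |grad| = {g.abs().mean().item():.4f}, "
          f"fraction of zeros = {(g == 0).float().mean().item():.2f}")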