
How does a neural network learn?


These are my notes from the book Grokking Deep Learning by Andrew Trask. Feel free to check my first post on this book to get my overall thoughts and recommendations on how to approach this series. The rest of my notes for this book can be found here.

Paradigm of predict, compare, and learn

  • Predict - Forward propagation
  • Compare - Measurement of how much our prediction ‘missed’.
    • Important & complicated in Deep Learning
    • Error is never negative
    • In these examples, we will use ‘Mean Squared Error’
    • The output is simply a ‘hot or cold’ type signal
  • Learn - Takes the error and tells each weight how to change to reduce it
    • Gradient descent
    • Calculates a number for each weight that tells it ‘higher or lower’

Does network make good predictions?

Going back to the example of turning a knob (weight) to adjust predictions:

knob_weight = 0.5
my_input = 0.5
goal_pred = 0.8 # What was actually observed
pred = my_input * knob_weight
error = (pred - goal_pred) ** 2 # Squared error, never negative
print(error)

Squaring the error amplifies big errors (>1) and shrinks small errors (<1). That's OK, because we want the network to pay more attention to the big errors and less attention to the small ones.

Why measure error?

  • It simplifies the problem. It's easier to ask ‘how do we get the error to zero?’ than ‘how do we get a more accurate prediction?’
  • It helps us prioritize the bigger errors over the smaller ones, especially if we square them.
  • We only want positive errors. If we allowed negatives, they might cancel each other out: if half the errors averaged 1000 and the other half averaged -1000, our average error would be 0.
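A quick sketch of both points, using made-up error values (the lists below are hypothetical, not from the book). Raw errors of +1000 and -1000 average out to 0, while their squares do not. Squaring also shrinks errors below 1 and amplifies errors above 1.

```python
# Hypothetical raw errors (prediction - goal): half miss high by 1000, half miss low by 1000
raw_errors = [1000, -1000, 1000, -1000]

mean_raw_error = sum(raw_errors) / len(raw_errors)
mean_squared_error = sum(e ** 2 for e in raw_errors) / len(raw_errors)
print("Mean raw error:", mean_raw_error)          # the misses cancel out
print("Mean squared error:", mean_squared_error)  # the misses add up

# Squaring shrinks errors below 1 and amplifies errors above 1
for e in [0.1, 0.5, 2, 10]:
    print(e, "->", e ** 2)
```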

What is the simplest form of neural learning?

Hot and Cold method - Adjust knob_weight either up or down (direction) so that the error reduces. Keep doing this until error gets to 0.

# An empty network
weight = 0.1
lr = 0.01 # How far to nudge the weight up or down
def neural_network(my_input, weight):
    prediction = my_input * weight
    return prediction
# PREDICT: Make prediction and evaluate error
number_of_toes = [8.5]
win_or_lose_binary = [1] # won

my_input = number_of_toes[0]
true = win_or_lose_binary[0]

pred = neural_network(my_input, weight)
error = (pred - true) ** 2
print(error)
0.022499999999999975
# COMPARE: Predict with higher weight, and evaluate error
p_up = neural_network(my_input, weight+lr)
e_up = (p_up - true) ** 2
print(e_up)
0.004225 (approximately)
# COMPARE: Predict with lower weight, and evaluate error
p_dn = neural_network(my_input, weight-lr)
e_dn = (p_dn - true) ** 2
print(e_dn)
0.05522499999999994
# COMPARE + LEARN: Compare errors and set new weight
if (e_dn < e_up):
    weight -= lr
if (e_up < e_dn):
    weight += lr
# Another example

weight = 0.5
my_input = 0.5
goal_prediction = 0.8

step_amount = 0.001 # how much to move weights each iteration

#PREDICT
for iteration in range(1101):
    prediction = my_input * weight
    error = (prediction - goal_prediction) ** 2
    
    print("Error: " + str(error) + " Prediction: " + str(prediction))
    
#COMPARE
    up_prediction = my_input * (weight + step_amount)
    up_error = (goal_prediction - up_prediction) ** 2
    
    down_prediction = my_input * (weight - step_amount)
    down_error = (goal_prediction - down_prediction) ** 2
    
#LEARN
    if (down_error < up_error): # If down is better
        weight = weight - step_amount # Keep going down
        
    if (down_error > up_error): # If up is better
        weight = weight + step_amount # Keep going up
Error: 0.30250000000000005 Prediction: 0.25
Error: 0.3019502500000001 Prediction: 0.2505
Error: 0.30140100000000003 Prediction: 0.251
Error: 0.30085225 Prediction: 0.2515
Error: 0.30030400000000007 Prediction: 0.252
...
Error: 1.0799505792475652e-27 Prediction: 0.7999999999999672

In general

  • Hot and cold learning is simple.
  • Problem 1: It’s inefficient
  • Problem 2: Sometimes it's impossible to predict the exact goal

In the example above we had to iterate 1101 times before reaching the goal. Also, step_amount is somewhat arbitrary, and it's difficult to know how big or small it should be. If it doesn't divide evenly into the distance to the right weight, the prediction can never hit the goal exactly; it just keeps bouncing back and forth between the two sides of goal_prediction as it tries to close in.

Hot and Cold only tells us which direction to go, not how far.
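To see the bouncing problem, here is a sketch of the same hot and cold loop wrapped in a hypothetical helper, hot_cold(), with a deliberately large step_amount of 0.3 (my choice, not the book's). The weight that hits the goal exactly is 1.6, which steps of 0.3 starting from 0.5 can never land on, so the weight ends up alternating between about 1.4 and 1.7 and the error never reaches 0.

```python
def hot_cold(weight, my_input, goal_prediction, step_amount, iterations):
    """Hot and cold learning; returns a (weight, error) pair for every iteration."""
    history = []
    for _ in range(iterations):
        # PREDICT
        prediction = my_input * weight
        error = (prediction - goal_prediction) ** 2
        history.append((weight, error))

        # COMPARE: try one step up and one step down
        up_error = (goal_prediction - my_input * (weight + step_amount)) ** 2
        down_error = (goal_prediction - my_input * (weight - step_amount)) ** 2

        # LEARN: move toward whichever side had the smaller error
        if down_error < up_error:
            weight -= step_amount
        elif up_error < down_error:
            weight += step_amount
    return history


# step_amount = 0.3 is deliberately too coarse to land on the ideal weight of 1.6
history = hot_cold(weight=0.5, my_input=0.5, goal_prediction=0.8,
                   step_amount=0.3, iterations=10)
for weight, error in history:
    print(f"weight={weight:.2f} error={error:.4f}")
```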

Calculating both direction and amount of error

weight = 0.5
goal_pred = 0.8
my_input = 0.5

for iteration in range(20):
    pred = my_input * weight
    error = (pred - goal_pred) ** 2
    direction_and_amount = (pred - goal_pred) * my_input
    weight = weight - direction_and_amount
    
    print('Error: ' + str(error) + ' Prediction: ' + str(pred))
Error: 0.30250000000000005 Prediction: 0.25
Error: 0.17015625000000004 Prediction: 0.3875
Error: 0.095712890625 Prediction: 0.49062500000000003
Error: 0.05383850097656251 Prediction: 0.56796875
Error: 0.03028415679931642 Prediction: 0.6259765625
Error: 0.0170348381996155 Prediction: 0.669482421875
Error: 0.00958209648728372 Prediction: 0.70211181640625
Error: 0.005389929274097089 Prediction: 0.7265838623046875
Error: 0.0030318352166796153 Prediction: 0.7449378967285156
Error: 0.0017054073093822882 Prediction: 0.7587034225463867
Error: 0.0009592916115275371 Prediction: 0.76902756690979
Error: 0.0005396015314842384 Prediction: 0.7767706751823426
Error: 0.000303525861459885 Prediction: 0.7825780063867569
Error: 0.00017073329707118678 Prediction: 0.7869335047900676
Error: 9.603747960254256e-05 Prediction: 0.7902001285925507
Error: 5.402108227642978e-05 Prediction: 0.7926500964444131
Error: 3.038685878049206e-05 Prediction: 0.7944875723333098
Error: 1.7092608064027242e-05 Prediction: 0.7958656792499823
Error: 9.614592036015323e-06 Prediction: 0.7968992594374867
Error: 5.408208020258491e-06 Prediction: 0.7976744445781151

The magic happens in this line of code:

direction_and_amount = (pred - goal_pred) * my_input

  • pred - goal_pred is the “pure error”: the raw direction and amount by which we missed.

Multiplying the pure error by my_input gives it three properties that translate it into the amount (and direction) we want to change our weight:

  • Stopping - If the input is 0, the update is 0, so there is nothing to learn: this weight had no effect on the prediction.
  • Negative reversal - If the input is negative, multiplying by it flips the sign of the update. That is what we want, because with a negative input, increasing the weight moves the prediction down.
  • Scaling - If the input is big, the weight update is also big. This can get out of control, which is what alpha (covered below) is for.

    # Positive input
    # Positive weight
    test_input = 100
    w = .5
    goal_prediction = .8
    prediction = test_input * w
    print('Prediction: ' + str(prediction))
    direction_amount = (prediction - goal_prediction) * test_input
    print('Pure error: ' + str(prediction - goal_prediction))
    print('Direction and Amount: ' + str(direction_amount))

    Prediction: 50.0
    Pure error: 49.2
    Direction and Amount: 4920.0

    # If input is positive, increasing the weight will move prediction UP
    
    test_input = 100
    w = .75
    goal_prediction = .8
    prediction = test_input * w
    print('Prediction: ' + str(prediction))
    direction_amount = (prediction - goal_prediction) * test_input
    print('Pure error: ' + str(prediction - goal_prediction))
    print('Direction and Amount: ' + str(direction_amount))

    Prediction: 75.0
    Pure error: 74.2
    Direction and Amount: 7420.0

    # NEGATIVE REVERSAL:
    # Multiplying pure error by input will flip the sign if input is negative
    test_input = -100
    w = .5
    goal_prediction = .8
    prediction = test_input * w
    print('Prediction: ' + str(prediction))
    direction_amount = (prediction - goal_prediction) * test_input
    print('Pure error: ' + str(prediction - goal_prediction))
    print('Direction and Amount: ' + str(direction_amount))

    Prediction: -50.0
    Pure error: -50.8
    Direction and Amount: 5080.0

    # If input negative, increasing weight will move prediction DOWN
    # Multiplying pure error by input will flip the sign if input is negative
    test_input = -100
    w = .75
    goal_prediction = .8
    prediction = test_input * w
    print('Prediction: ' + str(prediction))
    direction_amount = (prediction - goal_prediction) * test_input
    print('Pure error: ' + str(prediction - goal_prediction))
    print('Direction and Amount: ' + str(direction_amount))

    Prediction: -75.0
    Pure error: -75.8
    Direction and Amount: 7580.0

    # STOPPING: If input is 0, then it stops learning
    test_input = 0
    w = .5
    goal_prediction = .8
    prediction = test_input * w
    print('Prediction: ' + str(prediction))
    direction_amount = (prediction - goal_prediction) * test_input
    print('Pure error: ' + str(prediction - goal_prediction))
    print('Direction and Amount: ' + str(direction_amount))

    Prediction: 0.0
    Pure error: -0.8
    Direction and Amount: -0.0

    # SCALING: If input is big, the prediction and the weight update are big
    test_input = 50000
    w = .5
    goal_prediction = .8
    prediction = test_input * w
    print('Prediction: ' + str(prediction))
    direction_amount = (prediction - goal_prediction) * test_input
    print('Pure error: ' + str(prediction - goal_prediction))
    print('Direction and Amount: ' + str(direction_amount))

    Prediction: 25000.0
    Pure error: 24999.2
    Direction and Amount: 1249960000.0
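The three properties can also be checked in one place. This is a sketch with hypothetical helper names (direction_and_amount(), squared_error()); the alpha of 1e-5 is an assumption, picked to be small enough that inputs of size 100 don't overshoot.

```python
def direction_and_amount(test_input, weight, goal_prediction):
    """Pure error multiplied by the input, as in the examples above."""
    prediction = test_input * weight
    pure_error = prediction - goal_prediction
    return pure_error * test_input


def squared_error(test_input, weight, goal_prediction):
    return (test_input * weight - goal_prediction) ** 2


w, goal = 0.5, 0.8
alpha = 1e-5  # assumption: small enough that inputs around 100 don't overshoot

# STOPPING: a zero input gives a zero update, so the weight doesn't change
print("Stopping:", direction_and_amount(0, w, goal))

# NEGATIVE REVERSAL: stepping against the update lowers the error for either sign of input
for test_input in (100, -100):
    before = squared_error(test_input, w, goal)
    new_w = w - alpha * direction_and_amount(test_input, w, goal)
    after = squared_error(test_input, new_w, goal)
    print(f"input={test_input}: error {before:.2f} -> {after:.2f}")

# SCALING: a bigger input gives a bigger update
print("Scaling:", direction_and_amount(100, w, goal), direction_and_amount(200, w, goal))
```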

Learning is just reducing error

Here’s another walk-through of gradient descent

# Empty network
weight = 0.1
alpha = 0.1
def neural_network(my_input, weight):
    prediction = my_input * weight
    return prediction

# PREDICT: Make prediction & evaluate error
number_of_toes = [8.5]
win_or_lose_binary = [1] # Won

my_input = number_of_toes[0]
goal_pred = win_or_lose_binary[0]
pred = neural_network(my_input, weight)

error = (pred - goal_pred) **2 # Squared error

# COMPARE: Calculate 'node delta' and put it on the output node
# Raw amount of how much this node missed. Either too high or too low.
delta = pred - goal_pred 

# LEARN: Calculate 'weight_delta' and put it on the weight
# A measure of how much this weight caused the network to miss
# This accounts for scaling, negative reversal, and stopping
weight_delta = my_input * delta

# LEARN: Update the weight
# Multiply by a small number 'alpha' to control how fast the
# network learns, then subtract to move against the error.
weight -= weight_delta * alpha

# The same steps in a loop (alpha left out, i.e. alpha = 1):
weight, goal_pred, my_input = (0.0, 0.8, 0.5)

for iteration in range(10):
    pred = my_input * weight
    error = (pred - goal_pred) **2
    delta = pred - goal_pred
    weight_delta = delta * my_input
    weight = weight - weight_delta
    print("Error: " + str(error) + " Prediction: " + str(pred) 
          + " Weight: " + str(weight))
Error: 0.6400000000000001 Prediction: 0.0 Weight: 0.4
Error: 0.3600000000000001 Prediction: 0.2 Weight: 0.7000000000000001
Error: 0.2025 Prediction: 0.35000000000000003 Weight: 0.925
Error: 0.11390625000000001 Prediction: 0.4625 Weight: 1.09375
Error: 0.06407226562500003 Prediction: 0.546875 Weight: 1.2203125
Error: 0.036040649414062535 Prediction: 0.61015625 Weight: 1.315234375
Error: 0.020272865295410177 Prediction: 0.6576171875 Weight: 1.38642578125
Error: 0.011403486728668217 Prediction: 0.693212890625 Weight: 1.4398193359375
Error: 0.006414461284875877 Prediction: 0.71990966796875 Weight: 1.479864501953125
Error: 0.0036081344727426873 Prediction: 0.7399322509765625 Weight: 1.5098983764648437

We keep adjusting the weight in small increments in the correct direction and correct amount so that error goes to 0.

The error is a function of the weight: with the input and goal fixed, each weight value produces an error value.

A function is just something that takes numbers as input, does some calculations, and gives another number as output. The error equation is such a function. It tells you, for any change in the weight, what happens to the error.

It’s all about showing the relationship between the error and the weight.

# In the case above:
pred = my_input * weight
error = (pred - goal_pred) **2

# Is the same as:
error = ((my_input * weight) - goal_pred) ** 2

# Which is:
error = ((0.5 * weight) - 0.8) ** 2

This shows that if you increase or decrease the weight by some amount and plug it into the formula, it tells you the resulting error.

If you plotted error = ((0.5 * weight) - 0.8) ** 2, it would look like a parabola (a bowl), with error on the y-axis and weight on the x-axis. The curve represents the error for every weight according to the formula. Following the slope downhill leads to the bottom of the bowl, which is the lowest error: error = 0, reached at weight = 1.6 (since 0.5 * 1.6 = 0.8).
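A text version of that plot, as a sketch: error_for() is a hypothetical helper that evaluates the formula above for a handful of weights, and each row's bar length is proportional to the error. The bars shrink to nothing at weight = 1.6 and grow again on the other side.

```python
def error_for(weight, my_input=0.5, goal_pred=0.8):
    """Squared error of the one-weight network for a given weight."""
    return ((my_input * weight) - goal_pred) ** 2


# Sample weights from 0.0 to 3.2 in steps of 0.4 (an arbitrary range around the bottom)
weights = [w / 10 for w in range(0, 33, 4)]
for weight in weights:
    e = error_for(weight)
    print(f"weight={weight:.1f} error={e:.4f} " + "#" * round(e * 40))
```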

Watch 4 steps of learning below:

weight, goal_pred, my_input = (0.0, 0.8, 1.1)

for iteration in range(4):
    print("-----\nWeight:" + str(weight))
    pred = my_input * weight
    error = (pred - goal_pred) ** 2
    delta = pred - goal_pred
    weight_delta = delta * my_input
    weight = weight - weight_delta
    print("Error:" + str(error) + " Prediction:" + str(pred))
    print("Delta:" + str(delta) + " Weight Delta:" + str(weight_delta))
-----
Weight:0.0
Error:0.6400000000000001 Prediction:0.0
Delta:-0.8 Weight Delta:-0.8800000000000001
-----
Weight:0.8800000000000001
Error:0.02822400000000005 Prediction:0.9680000000000002
Delta:0.16800000000000015 Weight Delta:0.1848000000000002
-----
Weight:0.6951999999999999
Error:0.0012446784000000064 Prediction:0.76472
Delta:-0.03528000000000009 Weight Delta:-0.0388080000000001
-----
Weight:0.734008
Error:5.4890317439999896e-05 Prediction:0.8074088
Delta:0.007408799999999993 Weight Delta:0.008149679999999992

Let’s look again at the function error = ((input * weight) - goal_pred) ** 2. What can we change to make the error go down to zero? Really, the only thing that makes sense to change is the weight, which is part of the prediction input * weight. The input is our data and goal_pred is what was actually observed, so those are not ours to change. The goal is to get our network to learn, adjusting the weights so that the error goes to zero. Many deep learning researchers spend their careers trying everything to make the pred calculation as accurate as possible, i.e. getting error as close to 0 as possible.

Tunnel vision on a single concept

It’s all about getting error to zero

  • Understand the relationship between our weight and error
  • How do changes in one variable affect the other? What is the sensitivity between the two?

Back to this formula: error = ((input * weight) - goal_pred) ** 2

This formula describes the relationship between error and weight. How can we use it to change the weight so that it moves the error in a certain direction?

A Box with rods poking out

Imagine a box with a blue rod sticking out by 2 inches and a red rod sticking out by 4 inches. Every time you move the blue rod in or out by 1 inch, the red rod moves in or out by 2 inches, respectively. The stuff happening inside the box is the formula. In this case it could be: red_length = blue_length * 2

The formal definition of “how much does rod x move, when I tug rod y” is called a derivative.

The derivative for how much red moves when I tug on blue is 2: for every inch blue moves, red moves 2 inches.

The derivative is between two variables. If the derivative is positive, then when we change one variable, the other moves in the same direction. If the derivative is negative, the other variable moves in the opposite direction.

In the example above, a derivative of 2 means both rods move in the same direction, with red moving twice as far as blue. If the derivative had been -1, red would move the same distance in the opposite direction. So the derivative captures both the direction and the amount one variable changes when the other changes.
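One way to estimate a derivative numerically, as a sketch: nudge the input a tiny bit in both directions and see how much the output moves. This is a central difference; the helper names and the step h = 1e-6 are my own choices, not from the book. For the rod box it recovers a derivative of about 2, and for a box where red moves opposite to blue, about -1.

```python
def red_length(blue_length):
    """What happens inside the box: red sticks out twice as far as blue."""
    return blue_length * 2


def numerical_derivative(f, x, h=1e-6):
    """Estimate df/dx by nudging x by h in both directions (central difference)."""
    return (f(x + h) - f(x - h)) / (2 * h)


# Tug the blue rod starting from its 2-inch position
print(numerical_derivative(red_length, 2.0))                        # about 2
print(numerical_derivative(lambda blue_length: -blue_length, 2.0))  # about -1
```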

Derivatives… take two

Another way to think about derivatives is as the slope at a point on a line or curve. If you plot error = ((input * weight) - goal_pred) ** 2 with input and goal_pred fixed, you get a U-shaped curve with error on the y-axis and weight on the x-axis. The bottom of the U is where error == 0. Every other point on the curve has a slope, and that slope is the derivative. To the right of the zero-error point the slope is positive, and to the left it is negative. The further you move from the goal weight (the weight that gives error == 0), the steeper the slope gets.

The slope’s sign gives direction, and the steepness gives amount.

  • The sign of the slope always points away from the lowest point of the curve (it points uphill). If the slope is negative, increase the weight to move toward the minimum error; if it is positive, decrease the weight.
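Here’s the slope at a few points along the bowl from the earlier example (my_input = 0.5, goal_pred = 0.8), as a sketch. slope() is a hypothetical helper using the exact calculus derivative 2 * my_input * (pred - goal_pred); the book's weight_delta drops the constant 2, which only changes the step size, not the direction. The slopes are negative to the left of weight = 1.6, positive to the right, and steeper the further away you go.

```python
def slope(weight, my_input=0.5, goal_pred=0.8):
    """Exact derivative of error = ((my_input * weight) - goal_pred) ** 2 w.r.t. weight."""
    return 2 * my_input * ((my_input * weight) - goal_pred)


for weight in [0.0, 1.0, 1.6, 2.2, 3.2]:
    s = slope(weight)
    # Move opposite the slope's sign to head toward the bottom of the bowl
    move = "increase" if s < 0 else "decrease" if s > 0 else "stay"
    print(f"weight={weight:.1f} slope={s:+.2f} -> {move} the weight")
```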

What you really need to know

  • A neural network is just a bunch of weights used to compute an error function.
  • For any error function, we can compute the relationship between each weight and the final error of the network.
  • Since we know that relationship, we can change each weight in the network to reduce the error toward zero.

Gradient descent

  • We move our weight in the opposite direction of the derivative to find the lowest error.
  • This method for learning is called Gradient Descent.
  • Moving the weight opposite the gradient makes the error descend toward zero.
  • Increase the weight when the gradient is negative, and decrease it when the gradient is positive.
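Because gradient descent only needs the derivative, the same loop works for any error function whose slope we can compute or estimate. A minimal sketch, assuming a numerically estimated derivative (central difference) in place of the hand-derived formula and an alpha of 0.5 chosen for illustration; all function names are hypothetical.

```python
def error(weight, my_input=0.5, goal_pred=0.8):
    return ((my_input * weight) - goal_pred) ** 2


def numerical_derivative(f, x, h=1e-6):
    """Central-difference estimate of df/dx."""
    return (f(x + h) - f(x - h)) / (2 * h)


def gradient_descent(f, weight, alpha=0.5, steps=100):
    for _ in range(steps):
        # Step opposite the derivative: negative slope -> weight goes up, positive -> down
        weight -= alpha * numerical_derivative(f, weight)
    return weight


learned = gradient_descent(error, weight=0.0)
print(learned, error(learned))  # weight close to 1.6, error close to 0
```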

Divergence

  • Neural networks can sometimes explode in value.
  • If the input is large, the weight update can be large even when the error is small.
  • The update can then overshoot the goal; if the new error is even bigger, the next update overcorrects even more.
  • With a big input, the prediction is very sensitive to changes in the weight, since pred = input * weight.
  • So the derivative is big, i.e. the error is very sensitive to the weight.
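Why does a big input diverge? A short derivation (my own, not from the book) for this one-weight network, writing x for my_input, g for goal_pred, α for alpha (α = 1 when no alpha is used), and w* = g / x for the weight that gives zero error (assuming x ≠ 0):

```latex
\begin{aligned}
\text{pred}_t - g &= x w_t - g = x\,(w_t - w^*) \\
w_{t+1} &= w_t - \alpha\, x\,(\text{pred}_t - g) = w_t - \alpha x^2 (w_t - w^*) \\
w_{t+1} - w^* &= (1 - \alpha x^2)\,(w_t - w^*) \\
\text{error}_{t+1} &= x^2 (w_{t+1} - w^*)^2 = (1 - \alpha x^2)^2\,\text{error}_t
\end{aligned}
```

So every update multiplies the error by (1 - αx²)². With x = 2 and no alpha that factor is 9, matching the divergence example below (0.04, 0.36, 3.24, ...); with α = 0.1 it is 0.36 (0.04, 0.0144, ...). The error grows whenever αx² > 2.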

Use Alpha to prevent overcorrecting weight updates

  • Multiply the weight update by a fraction to make it smaller.
  • This is usually a single number between 0 and 1, called alpha.
  • Finding the right alpha is trial and error.
  • Watch the error over time: if it starts diverging, alpha is too high.
  • If learning is too slow, increase alpha.
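A sketch of that trial-and-error search for alpha on the divergence example below (my_input = 2, goal_pred = 0.8, starting weight 0.5). train() is a hypothetical helper that runs 20 updates and returns the final squared error; the list of alphas is arbitrary. With these numbers, alpha = 1.0 explodes, 0.5 just bounces in place, 0.25 lands on the goal in one step, 0.1 converges, and 0.01 is still slow after 20 steps.

```python
def train(my_input, goal_pred, weight, alpha, steps):
    """Plain gradient descent on pred = my_input * weight; returns the final squared error."""
    for _ in range(steps):
        pred = my_input * weight
        derivative = my_input * (pred - goal_pred)  # same convention as the examples: constant 2 dropped
        weight -= alpha * derivative
    pred = my_input * weight
    return (pred - goal_pred) ** 2


for alpha in [1.0, 0.5, 0.25, 0.1, 0.01]:
    final_error = train(my_input=2, goal_pred=0.8, weight=0.5, alpha=alpha, steps=20)
    print(f"alpha={alpha:<5} error after 20 steps: {final_error:.3g}")
```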

Example of using a big input and causing divergence

weight = 0.5
goal_pred = 0.8
my_input = 2
for iteration in range(20):
    pred = my_input * weight
    error = (pred - goal_pred) ** 2
    derivative = my_input * (pred - goal_pred) # Exact derivative is 2 * this; the constant 2 is dropped
    weight = weight - derivative
    print('Error: ' + str(error) + ' Prediction: ' + str(pred))
Error: 0.03999999999999998 Prediction: 1.0
Error: 0.3599999999999998 Prediction: 0.20000000000000018
Error: 3.2399999999999984 Prediction: 2.5999999999999996
Error: 29.159999999999986 Prediction: -4.599999999999999
Error: 262.4399999999999 Prediction: 16.999999999999996
Error: 2361.959999999998 Prediction: -47.79999999999998
Error: 21257.639999999978 Prediction: 146.59999999999994
Error: 191318.75999999983 Prediction: -436.5999999999998
Error: 1721868.839999999 Prediction: 1312.9999999999995
Error: 15496819.559999991 Prediction: -3935.799999999999
Error: 139471376.03999993 Prediction: 11810.599999999997
Error: 1255242384.3599997 Prediction: -35428.59999999999
Error: 11297181459.239996 Prediction: 106288.99999999999
Error: 101674633133.15994 Prediction: -318863.79999999993
Error: 915071698198.4395 Prediction: 956594.5999999997
Error: 8235645283785.954 Prediction: -2869780.599999999
Error: 74120807554073.56 Prediction: 8609344.999999996
Error: 667087267986662.1 Prediction: -25828031.799999986
Error: 6003785411879960.0 Prediction: 77484098.59999996
Error: 5.403406870691965e+16 Prediction: -232452292.5999999

Example of using alpha to eliminate divergence

weight = 0.5
goal_pred = 0.8
my_input = 2
alpha = 0.1
for iteration in range(20):
    pred = my_input * weight
    error = (pred - goal_pred) ** 2
    derivative = my_input * (pred - goal_pred)
    weight = weight - (alpha * derivative)
    print('Error: ' + str(error) + ' Prediction: ' + str(pred))
Error: 0.03999999999999998 Prediction: 1.0
Error: 0.0144 Prediction: 0.92
Error: 0.005183999999999993 Prediction: 0.872
Error: 0.0018662400000000014 Prediction: 0.8432000000000001
Error: 0.0006718464000000028 Prediction: 0.8259200000000001
Error: 0.00024186470400000033 Prediction: 0.815552
Error: 8.70712934399997e-05 Prediction: 0.8093312
Error: 3.134566563839939e-05 Prediction: 0.80559872
Error: 1.1284439629823931e-05 Prediction: 0.803359232
Error: 4.062398266736526e-06 Prediction: 0.8020155392
Error: 1.4624633760252567e-06 Prediction: 0.8012093235200001
Error: 5.264868153690924e-07 Prediction: 0.8007255941120001
Error: 1.8953525353291194e-07 Prediction: 0.8004353564672001
Error: 6.82326912718715e-08 Prediction: 0.8002612138803201
Error: 2.456376885786678e-08 Prediction: 0.8001567283281921
Error: 8.842956788836216e-09 Prediction: 0.8000940369969153
Error: 3.1834644439835434e-09 Prediction: 0.8000564221981492
Error: 1.1460471998340758e-09 Prediction: 0.8000338533188895
Error: 4.125769919393652e-10 Prediction: 0.8000203119913337
Error: 1.485277170987127e-10 Prediction: 0.8000121871948003