Sunday, May 16, 2021

Gradient Tape Basic Tutorial

tf.GradientTape is TensorFlow's API for automatic differentiation.

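While it is active, a GradientTape records the differentiable TensorFlow operations performed on the tensors it watches (trainable tf.Variable objects are watched automatically). Afterwards, it can compute the gradient of any recorded result with respect to those tensors.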
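A minimal sketch of the record-then-differentiate idea, using an arbitrary function y = x² (not from the post): the tape watches the trainable variable automatically, while a plain tf.constant has to be watched explicitly with tape.watch().

```python
import tensorflow as tf

# A trainable tf.Variable is watched by the tape automatically.
x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x ** 2  # this operation is recorded on the tape

# dy/dx = 2x, which is 6.0 at x = 3.0
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 6.0

# A plain tf.constant is not watched unless tape.watch() is called.
c = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(c)
    z = c ** 3

# dz/dc = 3c^2, which is 27.0 at c = 3.0
dz_dc = tape.gradient(z, c)
print(dz_dc.numpy())  # 27.0
```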

Because it gives you the gradients directly, GradientTape also lets you write your own custom training loops.
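A custom training loop usually follows the pattern sketched below: run the forward pass inside the tape, ask the tape for the gradients of the loss, and let an optimizer apply them. This is a minimal sketch assuming a one-unit Keras Dense model, mean squared error and plain SGD; the toy data (y = 2x + 1), the learning rate and the number of steps are illustrative choices, not taken from this post.

```python
import tensorflow as tf

# Hypothetical toy data (not from the post): y = 2x + 1, shaped (N, 1) for Dense.
xs = tf.constant([[0.0], [1.0], [2.0], [3.0]])
ys = 2.0 * xs + 1.0

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)
loss_fn = tf.keras.losses.MeanSquaredError()

def train_step(x, y):
    # Record the forward pass so the tape can differentiate the loss.
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
    # One gradient per trainable variable (the kernel and the bias).
    grads = tape.gradient(loss, model.trainable_variables)
    # The optimizer applies the update instead of calling assign_sub by hand.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for step in range(500):
    loss = train_step(xs, ys)

print(model(tf.constant([[4.0]])).numpy())  # approximately [[9.]]
```

Swapping the optimizer (for example to Adam) only changes the optimizer line; the tape and gradient calls stay the same.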

As an example, consider the linear equation y = 4x - 5, where 4 is the weight and -5 is the bias.
Let us write a custom training loop for this equation and see whether it can learn the weight and bias from data.
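For reference, this is the math the training loop below carries out, assuming (as in the code) that the per-sample loss is the absolute error and that the tape sums the per-sample losses before differentiating. Here e_i is the residual of the i-th point and η is the learning rate (lr in the code); the derivative of |e| is taken as sign(e), which is the gradient TensorFlow uses for tf.abs.

```latex
\begin{aligned}
e_i &= y_i - (w x_i + b), \qquad L(w, b) = \sum_i \lvert e_i \rvert \\
\frac{\partial L}{\partial w} &= -\sum_i \operatorname{sign}(e_i)\, x_i, \qquad
\frac{\partial L}{\partial b} = -\sum_i \operatorname{sign}(e_i) \\
w &\leftarrow w - \eta \frac{\partial L}{\partial w}, \qquad
b \leftarrow b - \eta \frac{\partial L}{\partial b}
\end{aligned}
```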

import numpy as np 

import tensorflow as tf

import random 


x = np.array([-2, -1, 0, 1, 2, 4, 5, 6], dtype=float)

y = 4 * x - 5


print(x)

print(y) 



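# Define the weight and bias, starting from random values in [0, 1).

# Trainable variables are watched by GradientTape automatically.

w = tf.Variable(random.random(), trainable = True)

b = tf.Variable(random.random(), trainable = True)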


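# Simple loss function: the absolute error for each sample.

# The result is a vector; tape.gradient() differentiates the sum of its elements,

# so training effectively minimizes the total absolute error.

def simple_loss(y_groundtruth, y_predicted):

  return tf.abs(y_groundtruth - y_predicted)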



# Learning rate: the step size of each gradient-descent update.

lr = 0.001


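def fit_function(x_groundtruth, y_groundtruth):

  # Record the forward pass so the tape can differentiate the loss.
  with tf.GradientTape() as tape:

    y_predicted = w * x_groundtruth + b

    loss = simple_loss(y_groundtruth, y_predicted)


  # Request both gradients in one call, so a non-persistent tape is enough.
  w_gradient, b_gradient = tape.gradient(loss, [w, b])


  # Gradient-descent step: move each parameter against its gradient.
  w.assign_sub(w_gradient * lr)

  b.assign_sub(b_gradient * lr)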



# 2000 epochs: each call performs one gradient-descent step on the full dataset.
for _ in range(2000):

  fit_function(x, y)


# w and b are tf.Variable objects; printing them directly shows the full

# Variable representation (name, shape, dtype and value), so call .numpy() to get just the value.

print("Expected weight: 4; Predicted weight: {}".format(w.numpy()))

print("Expected bias : -5; Predicted bias : {}".format(b.numpy()))


Output:

[-2. -1.  0.  1.  2.  4.  5.  6.]

[-13.  -9.  -5.  -1.   3.  11.  15.  19.]

Expected weight: 4; Predicted weight: 3.9907336235046387

Expected bias : -5; Predicted bias : -5.000271320343018

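After 2000 epochs (here, 2000 full-batch gradient-descent steps), the learned weight and bias are very close to the true values of 4 and -5. Because the starting values are random, the exact numbers will vary from run to run.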
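To see that the tape computes the gradients one would derive by hand, here is a small check (not part of the original post). It assumes fixed, arbitrary starting values w = 0.5 and b = 0.25 instead of random ones so the result is reproducible, and uses float32 data to match the variables' dtype. The hand-derived gradients of the summed absolute error are -Σ sign(e_i)·x_i for w and -Σ sign(e_i) for b.

```python
import numpy as np
import tensorflow as tf

# Same data as in the post, but float32 to match the variables.
x = np.array([-2, -1, 0, 1, 2, 4, 5, 6], dtype=np.float32)
y = 4 * x - 5
x_t = tf.constant(x)
y_t = tf.constant(y)

# Hypothetical fixed starting values, chosen only so the check is reproducible.
w = tf.Variable(0.5)
b = tf.Variable(0.25)

with tf.GradientTape() as tape:
    loss = tf.abs(y_t - (w * x_t + b))  # vector of per-sample absolute errors

# For a non-scalar target, tape.gradient returns the gradient of its sum.
w_grad, b_grad = tape.gradient(loss, [w, b])

# Hand-derived gradients of sum_i |y_i - (w * x_i + b)|.
residual = y - (w.numpy() * x + b.numpy())
w_grad_manual = -np.sum(np.sign(residual) * x)
b_grad_manual = -np.sum(np.sign(residual))

print(w_grad.numpy(), w_grad_manual)
print(b_grad.numpy(), b_grad_manual)
```

With these starting values every residual left of x = 2 is negative and every other one is positive, so the bias gradient comes out as zero while the weight gradient does not.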
