Wednesday, June 2, 2021

Pix2Pix III - The training function

 LAMBDA = 100 

no_of_epochs = 1 #this should be in the range of 100 - 150

generator = get_generator() 
discriminator = get_discriminator()

gen_optimizer = tf.keras.optimizers.Adam(2e-4 , beta_1  = 0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4 , beta_1 = 0.5)

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits = True)

def gen_loss(disc_fake_output ,  gen_output , target_image) : 
  gen_loss = cross_entropy(tf.ones_like(disc_fake_output) , disc_fake_output)
  l1_loss = tf.reduce_mean(tf.abs(target_image - gen_output))
  return gen_loss + (LAMBDA * l1_loss)

def disc_loss(disc_real_output , disc_fake_output) : 
  disc_real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
  disc_fake_loss = cross_entropy(tf.zeros_like(disc_fake_output) , disc_fake_output)
  return disc_real_loss  + disc_fake_loss


def train_step(inp_image , tar_image) : 
  with tf.GradientTape() as gen_tape , tf.GradientTape() as disc_tape : 
    gen_output = generator(inp_image , training = True)
    disc_real_output = discriminator([inp_image , tar_image] , training = True)
    disc_fake_output = discriminator([inp_image , gen_output] , training = True)

    disc_loss_val = disc_loss(disc_real_output , disc_fake_output)
    gen_loss_val = gen_loss(disc_fake_output, gen_output , tar_image)

  disc_gradient = disc_tape.gradient(disc_loss_val, discriminator.trainable_variables) #**
  gen_gradient =  gen_tape.gradient(gen_loss_val , generator.trainable_variables)

  gen_optimizer.apply_gradients(zip(gen_gradient , generator.trainable_variables))
  disc_optimizer.apply_gradients(zip(disc_gradient , discriminator.trainable_variables))
  
checkpoint = tf.train.Checkpoint(gen_optimizer = gen_optimizer , 
                           disc_optimizer = disc_optimizer,
                           generator = generator, 
                           discriminator = discriminator
                           )  

import time
def train(train_dstest_ds , no_of_epochs) : 
  for epoch in range(no_of_epochs) : 
    print ("epoch {} started".format(epoch))
    starttime = time.time()
    for n , (inp , tar) in train_ds.enumerate() : 
      train_step(inp , tar)
      print("." , end = "")
    print("epoch {} took {} time".format(epoch , time.time() - starttime))
    checkpoint.save(file_prefix = "/content/checkpoint/chk_")

train(train_ds , test_ds , no_of_epochs)    

No comments:

Post a Comment

 using Microsoft.AspNetCore.Mvc; using System.Xml.Linq; using System.Xml.XPath; //<table class="common-table medium js-table js-stre...