# ---- Hyper-parameters -------------------------------------------------------
# Weight applied to the L1 reconstruction term in the generator loss.
LAMBDA = 100
# Number of training epochs; 1 is a smoke-test value, the author recommends
# somewhere in the range of 100 - 150 for a real run.
no_of_epochs = 1

# ---- Networks (factories are defined elsewhere in the notebook/file) --------
generator = get_generator()
discriminator = get_discriminator()

# ---- Optimizers: Adam, learning rate 2e-4, beta_1 = 0.5, one per network ----
gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# ---- Loss object ------------------------------------------------------------
# from_logits=True: the loss applies the sigmoid itself, so it must be fed raw
# (unsquashed) discriminator scores.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def gen_loss(disc_fake_output, gen_output, target_image, l1_weight=None):
    """Compute the generator loss: adversarial term plus weighted L1 term.

    The adversarial term rewards the generator when the discriminator labels
    (input, generated) pairs as real, i.e. it compares the discriminator's
    scores on fake pairs against a tensor of ones. The L1 term is the mean
    absolute pixel difference between the generated and target images.

    Args:
        disc_fake_output: discriminator logits for (input, generated) pairs.
        gen_output: images produced by the generator.
        target_image: ground-truth images; expected to be shape-compatible
            with ``gen_output`` for the element-wise subtraction.
        l1_weight: weight of the L1 term. ``None`` (the default) uses the
            module-level ``LAMBDA``, looked up at call time.

    Returns:
        Scalar tensor ``adversarial_loss + l1_weight * l1_loss``.
    """
    if l1_weight is None:
        # Resolve at call time so later reassignments of LAMBDA take effect,
        # exactly as the original global reference did.
        l1_weight = LAMBDA
    # Named differently from the function to avoid shadowing `gen_loss`.
    adversarial_loss = cross_entropy(tf.ones_like(disc_fake_output), disc_fake_output)
    l1_loss = tf.reduce_mean(tf.abs(target_image - gen_output))
    return adversarial_loss + (l1_weight * l1_loss)
def disc_loss(disc_real_output, disc_fake_output):
    """Return the discriminator loss on one batch.

    Binary cross-entropy that pushes scores for real (input, target) pairs
    toward 1 and scores for (input, generated) pairs toward 0; the two terms
    are summed.
    """
    real_targets = tf.ones_like(disc_real_output)
    fake_targets = tf.zeros_like(disc_fake_output)
    real_term = cross_entropy(real_targets, disc_real_output)
    fake_term = cross_entropy(fake_targets, disc_fake_output)
    return real_term + fake_term
@tf.function
def train_step(inp_image, tar_image):
    """Run one optimization step for both the generator and the discriminator.

    Decorated with ``tf.function`` so the step is traced into a TensorFlow
    graph and reused across batches instead of executing eagerly op-by-op.
    A new input shape (e.g. a smaller final batch) triggers one re-trace.

    Args:
        inp_image: batch of conditioning (input) images.
        tar_image: batch of corresponding target images.

    Returns:
        Tuple ``(gen_loss_val, disc_loss_val)`` for optional logging. The
        original returned nothing, so callers that ignore the result are
        unaffected.
    """
    # Two tapes: each network is updated from the gradient of its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(inp_image, training=True)
        # The discriminator is conditioned on the input image: it scores
        # (input, target) pairs and (input, generated) pairs.
        disc_real_output = discriminator([inp_image, tar_image], training=True)
        disc_fake_output = discriminator([inp_image, gen_output], training=True)
        disc_loss_val = disc_loss(disc_real_output, disc_fake_output)
        gen_loss_val = gen_loss(disc_fake_output, gen_output, tar_image)

    # Take gradients outside the recording context so the tapes do not also
    # record the gradient computation itself.
    disc_gradient = disc_tape.gradient(disc_loss_val, discriminator.trainable_variables)
    gen_gradient = gen_tape.gradient(gen_loss_val, generator.trainable_variables)

    gen_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
    return gen_loss_val, disc_loss_val
# Track both networks and their optimizers so their state can be written to
# (and later restored from) a checkpoint.
checkpoint = tf.train.Checkpoint(
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    generator=generator,
    discriminator=discriminator,
)
import time
def train(train_ds, test_ds, no_of_epochs, checkpoint_prefix="/content/checkpoint/chk_"):
    """Train the generator and discriminator for ``no_of_epochs`` epochs.

    Prints a dot per batch as a progress indicator and the wall-clock time of
    each epoch. Saves a checkpoint after every epoch.

    Args:
        train_ds: iterable yielding ``(input_image, target_image)`` batches
            (e.g. a ``tf.data.Dataset``).
        test_ds: currently unused; kept so existing call sites keep working.
        no_of_epochs: number of full passes over ``train_ds``.
        checkpoint_prefix: file prefix passed to ``checkpoint.save`` after
            each epoch.
    """
    for epoch in range(no_of_epochs):
        print("epoch {} started".format(epoch))
        # perf_counter is monotonic, so it is the right clock for durations.
        start = time.perf_counter()
        # The batch index from enumerate() was never used; iterate directly.
        for inp, tar in train_ds:
            train_step(inp, tar)
            print(".", end="")
        # End the line of progress dots before printing the epoch summary.
        print()
        print("epoch {} took {:.2f} s".format(epoch, time.perf_counter() - start))
        checkpoint.save(file_prefix=checkpoint_prefix)
# Start training; train_ds and test_ds are defined outside this section.
train(train_ds, test_ds, no_of_epochs=no_of_epochs)
# (Scraped blog page footer, kept as a comment so the file stays valid Python.)
# No comments:
# Post a Comment