Saturday, June 5, 2021

Summary Tips for CycleGAN

1. Assume two functions G: x -> y and F: y -> x.
2. Create two generators accordingly: generator_g (implements G) and generator_f (implements F).
3. Create two discriminators: discriminator_x (tells real x from generated x) and discriminator_y (tells real y from generated y).
4. Create four loss functions:
 discriminator_loss
 generator_loss
 cycle_loss
 identity_loss
5. Remember that discriminator_loss and identity_loss are halved (multiplied by 0.5).
6. Remember that cycle_loss and identity_loss are multiplied by a factor of 10. Together with point 5, this means identity_loss is scaled by 10 * 0.5 = 5 overall.
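The four loss functions and their weights can be sketched in plain Python. This is a minimal illustration, not the post's code: it assumes the discriminators output raw logits (so binary cross-entropy is computed from logits), that cycle and identity losses are mean absolute (L1) differences, and it works on flat lists of floats instead of tensors. The helper names bce_with_logits and mean_abs_diff and the constant LAMBDA are hypothetical.

```python
import math

LAMBDA = 10  # hypothetical name for the factor of 10 applied to cycle and identity losses


def bce_with_logits(labels, logits):
    """Mean binary cross-entropy computed from raw logits (numerically stable form)."""
    total = 0.0
    for z, x in zip(labels, logits):
        # max(x, 0) - x*z + log(1 + exp(-|x|)) == -[z*log(s) + (1-z)*log(1-s)], s = sigmoid(x)
        total += max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def mean_abs_diff(a, b):
    """Mean absolute (L1) difference between two equally sized flat lists."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def discriminator_loss(disc_real, disc_generated):
    # Real inputs should be scored 1, generated inputs 0; the sum is halved.
    real_loss = bce_with_logits([1.0] * len(disc_real), disc_real)
    generated_loss = bce_with_logits([0.0] * len(disc_generated), disc_generated)
    return 0.5 * (real_loss + generated_loss)


def generator_loss(disc_generated):
    # The generator wants the discriminator to score its fakes as real (label 1).
    return bce_with_logits([1.0] * len(disc_generated), disc_generated)


def cycle_loss(real_image, cycled_image):
    return LAMBDA * mean_abs_diff(real_image, cycled_image)


def identity_loss(real_image, same_image):
    # Weighted by 10 and halved, i.e. an overall factor of 5.
    return LAMBDA * 0.5 * mean_abs_diff(real_image, same_image)


print(cycle_loss([1.0, 2.0], [1.0, 3.0]), identity_loss([1.0, 2.0], [1.0, 3.0]))  # 5.0 2.5
```

With the same L1 difference of 0.5, identity_loss comes out at half of cycle_loss, which is exactly the 0.5 factor from point 5.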

Training Tips

1. We already have two functions G: x -> y and F: y -> x.
2. Create the first cycle:
    x -> y -> x
    real_x -> generated_y -> cycled_x

generated_y = generator_g(real_x, training=True)
cycled_x = generator_f(generated_y, training=True)
3. Create the second cycle:
    y -> x -> y
    real_y -> generated_x -> cycled_y

generated_x = generator_f(real_y, training=True)
cycled_y = generator_g(generated_x, training=True)
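The two cycles can be illustrated with toy stand-ins. Here generator_g and generator_f are hypothetical, exactly invertible functions on lists of floats (real generators are neural networks, and the training flag is dropped), so both cycles reproduce their inputs exactly. A trained CycleGAN only approximates this, and cycle_loss measures how far it falls short.

```python
# Hypothetical stand-ins for the two generators; they are exact inverses,
# so a perfect cycle returns the original input unchanged.
def generator_g(x):  # G: x -> y
    return [2.0 * v + 1.0 for v in x]


def generator_f(y):  # F: y -> x
    return [(v - 1.0) / 2.0 for v in y]


real_x = [0.0, 1.5, -2.0]
real_y = [3.0, -1.0]

# First cycle: x -> y -> x
generated_y = generator_g(real_x)
cycled_x = generator_f(generated_y)

# Second cycle: y -> x -> y
generated_x = generator_f(real_y)
cycled_y = generator_g(generated_x)

print(generated_y, cycled_x)  # [1.0, 4.0, -3.0] [0.0, 1.5, -2.0]
print(generated_x, cycled_y)  # [1.0, -1.0] [3.0, -1.0]
```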
4. "Same" outputs (same_x, same_y) are generated by giving each generator an input from its own output domain.
E.g. generator_g normally takes x as input and generates y. To generate a "same" output,
generator_g is given y as input (and should return it essentially unchanged).
Remember it this way: "same" means the output should be the same as the input:
same_x = generator_f(real_x, training=True)
same_y = generator_g(real_y, training=True)
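A toy illustration of the "same" outputs, assuming hypothetical domains: x values are non-positive and y values are non-negative, and each toy generator simply pushes its input into its output domain. Feeding generator_g a real y (and generator_f a real x) then returns the input unchanged, so the identity difference is zero. A real generator that altered such inputs would be penalised by identity_loss.

```python
# Hypothetical toy domains: x = non-positive numbers, y = non-negative numbers.
def generator_g(x):  # G: x -> y, clamps values into the y domain
    return [max(v, 0.0) for v in x]


def generator_f(y):  # F: y -> x, clamps values into the x domain
    return [min(v, 0.0) for v in y]


def mean_abs_diff(a, b):
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


real_x = [-1.0, -0.5, 0.0]
real_y = [0.0, 2.0, 3.5]

# Each generator receives a sample that is already in its OUTPUT domain.
same_y = generator_g(real_y)  # generator_g is given y
same_x = generator_f(real_x)  # generator_f is given x

print(mean_abs_diff(real_y, same_y), mean_abs_diff(real_x, same_x))  # 0.0 0.0
```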
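5. Discriminator outputs are straightforward (the variable names match: disc_real_x is the output for real_x, and so on). Each discriminator only sees images from its own domain: discriminator_x judges real_x and generated_x, discriminator_y judges real_y and generated_y.
disc_real_x = discriminator_x(real_x, training=True)
disc_real_y = discriminator_y(real_y, training=True)
disc_generated_x = discriminator_x(generated_x, training=True)
disc_generated_y = discriminator_y(generated_y, training=True)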
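6. Generator losses are straightforward: pass in the discriminator output (NOT the generator output) for the images that generator produced. generator_g produces generated_y, so its loss uses disc_generated_y; generator_f produces generated_x, so its loss uses disc_generated_x.

gen_g_loss = generator_loss(disc_generated_y)
gen_f_loss = generator_loss(disc_generated_x)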
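A small sketch of why the generator loss is computed on discriminator outputs. It assumes the discriminator returns raw logits and that generator_loss is binary cross-entropy against a label of 1 ("real"). generator_g's fakes (generated_y) are scored by discriminator_y, so a lower gen_g_loss means discriminator_y was fooled. The logit lists are made-up values.

```python
import math


def generator_loss(disc_generated):
    """Mean BCE against all-ones labels, from discriminator logits.
    With label 1 this is softplus(-logit), written in a numerically stable form."""
    return sum(max(-x, 0.0) + math.log1p(math.exp(-abs(x))) for x in disc_generated) / len(disc_generated)


# Hypothetical discriminator_y logits on generated_y (one logit per output patch).
fooled = [2.0, 3.0]    # discriminator_y thinks these fakes look real
caught = [-2.0, -3.0]  # discriminator_y confidently rejects them

print(round(generator_loss([0.0]), 4))  # 0.6931, i.e. log(2) for an undecided discriminator
print(generator_loss(fooled) < generator_loss(caught))  # True
```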
7. Cycle loss is the sum of the cycle losses of x and y:
tot_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y)
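Written as an equation, with cycled_x = F(G(x)) and cycled_y = G(F(y)), and assuming cycle_loss is the mean absolute (L1) difference used in the CycleGAN paper, scaled by the factor of 10 from the summary tips ($N_x$ and $N_y$ are the numbers of pixel values in x and y):

```latex
\mathcal{L}_{\text{cyc}}
  = \lambda \left( \frac{1}{N_x} \sum_{i=1}^{N_x} \left| x_i - F(G(x))_i \right|
  + \frac{1}{N_y} \sum_{j=1}^{N_y} \left| y_j - G(F(y))_j \right| \right),
  \qquad \lambda = 10
```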
8. Identity losses are also straightforward: calculate each one from real_? and same_? of the same domain, i.e. identity_loss(real_x, same_x) and identity_loss(real_y, same_y).
9. Each total generator loss is the sum of its gen loss, the total cycle loss and the identity loss of that generator's OUTPUT domain
(y for generator_g, x for generator_f).
10. So the total generator losses will be:
tot_gen_g_loss = gen_g_loss + tot_cycle_loss + identity_loss(real_y, same_y)
tot_gen_f_loss = gen_f_loss + tot_cycle_loss + identity_loss(real_x, same_x)
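A worked example of the sums, using hypothetical loss values chosen to be exact in binary floating point. The point it shows: tot_cycle_loss is added to both totals, because both generators take part in both cycles, while each total gets only the identity loss of its own output domain.

```python
# Hypothetical loss values for one training step.
gen_g_loss, gen_f_loss = 0.75, 1.25
tot_cycle_loss = 1.5              # shared: cycle_loss for x plus cycle_loss for y
id_loss_y, id_loss_x = 0.25, 0.5  # identity_loss(real_y, same_y), identity_loss(real_x, same_x)

tot_gen_g_loss = gen_g_loss + tot_cycle_loss + id_loss_y  # generator_g outputs y
tot_gen_f_loss = gen_f_loss + tot_cycle_loss + id_loss_x  # generator_f outputs x

print(tot_gen_g_loss, tot_gen_f_loss)  # 2.5 3.25
```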
11. Discriminator losses are straightforward. Each one is calculated from the discriminator's output for a real input and for a generated input. Remember that discriminator_loss never takes images (real_x, generated_y, ...) as input; it works on discriminator outputs only.
disc_x_loss = discriminator_loss(disc_real_x, disc_generated_x)
disc_y_loss = discriminator_loss(disc_real_y, disc_generated_y)
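A sketch of how discriminator_loss behaves, assuming raw-logit outputs and binary cross-entropy with label 1 for real and 0 for generated inputs, halved as in the summary tips. bce_with_logits is a hypothetical helper and the logit lists are made-up values.

```python
import math


def bce_with_logits(label, logits):
    # Stable mean BCE for one constant label (1.0 or 0.0) over a list of logits.
    return sum(max(x, 0.0) - x * label + math.log1p(math.exp(-abs(x))) for x in logits) / len(logits)


def discriminator_loss(disc_real, disc_generated):
    real_loss = bce_with_logits(1.0, disc_real)            # real inputs should score 1
    generated_loss = bce_with_logits(0.0, disc_generated)  # generated inputs should score 0
    return 0.5 * (real_loss + generated_loss)


undecided = discriminator_loss([0.0, 0.0], [0.0, 0.0])  # log(2)
sharp = discriminator_loss([4.0, 5.0], [-4.0, -5.0])    # confident and correct: near 0
fooled = discriminator_loss([-4.0, -5.0], [4.0, 5.0])   # confident and wrong: large

print(round(undecided, 4))  # 0.6931
```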
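12. Gradient calculations are straightforward. All forward passes and losses above must be computed inside the tape's with block so they are recorded, and because tape.gradient is called four times on the same tape, the tape has to be persistent (created with tf.GradientTape(persistent=True)); a default tape can compute gradients only once.

gen_g_gradient = tape.gradient(tot_gen_g_loss, generator_g.trainable_variables)
gen_f_gradient = tape.gradient(tot_gen_f_loss, generator_f.trainable_variables)
disc_x_gradient = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
disc_y_gradient = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)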
13. apply_gradients is straightforward:

gen_g_optimizer.apply_gradients(zip(gen_g_gradient, generator_g.trainable_variables))
gen_f_optimizer.apply_gradients(zip(gen_f_gradient, generator_f.trainable_variables))
disc_x_optimizer.apply_gradients(zip(disc_x_gradient, discriminator_x.trainable_variables))
disc_y_optimizer.apply_gradients(zip(disc_y_gradient, discriminator_y.trainable_variables))
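apply_gradients receives (gradient, variable) pairs, and zip builds them by position, which is why each gradient list must be zipped with the same trainable_variables it was computed for. The sketch below swaps in plain gradient descent (variable - learning_rate * gradient) for whatever optimizer gen_g_optimizer and the others actually are, and returns new values instead of updating TensorFlow variables in place; apply_gradients_sgd is a hypothetical helper.

```python
# Hypothetical stand-in for optimizer.apply_gradients using plain gradient descent.
def apply_gradients_sgd(grads_and_vars, learning_rate=0.5):
    """Return updated values; each (gradient, variable) pair is updated independently."""
    return [var - learning_rate * grad for grad, var in grads_and_vars]


trainable_variables = [1.0, -2.0, 0.5]
gradients = [0.5, -1.0, 0.0]  # must be in the same order as trainable_variables

updated = apply_gradients_sgd(zip(gradients, trainable_variables))
print(updated)  # [0.75, -1.5, 0.5]
```

A gradient of 0.0 leaves its variable unchanged, and a negative gradient moves the variable up, as expected for descent.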
