1 | Assume two functions G:x->y and F:y->x |
2 | Create two generators accordingly: generator_g (implements G: x->y) and generator_f (implements F: y->x) |
3 | Create two discriminators: discriminator_x (tells real x from generated x) and discriminator_y (tells real y from generated y) |
4 | Create four loss functions: discriminator_loss, generator_loss, cycle_loss and identity_loss |
5 | Remember that discriminator_loss and identity_loss are halved (multiplied by 0.5) |
6 | Remember that cycle_loss and identity_loss are multiplied by a factor of 10, so identity_loss ends up weighted by 0.5 * 10 = 5 |
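The four loss functions from steps 4 to 6 can be sketched in TensorFlow as follows. This is a minimal sketch built on three assumptions:

- The adversarial terms use binary cross-entropy on raw discriminator logits. The CycleGAN paper itself uses a least-squares adversarial loss, which could be swapped in.
- The cycle and identity terms use mean absolute error (L1).
- `LAMBDA` is an assumed name for the factor of 10.

```python
import tensorflow as tf

# Assumed name for the factor of 10 applied to cycle_loss and identity_loss (step 6).
LAMBDA = 10.0

# Assumption: discriminators output raw logits, scored with binary cross-entropy.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(disc_real_output, disc_generated_output):
    """Real outputs should score 1 and generated outputs 0; the sum is halved (step 5)."""
    real_loss = bce(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
    return 0.5 * (real_loss + generated_loss)


def generator_loss(disc_generated_output):
    """The generator wants the discriminator to label its output as real (1)."""
    return bce(tf.ones_like(disc_generated_output), disc_generated_output)


def cycle_loss(real_image, cycled_image):
    """L1 distance between an image and its round trip, scaled by LAMBDA."""
    return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))


def identity_loss(real_image, same_image):
    """L1 distance between an image and its identity mapping, scaled by 0.5 * LAMBDA."""
    return 0.5 * LAMBDA * tf.reduce_mean(tf.abs(real_image - same_image))
```

For example, with an all-zeros `real` and an all-ones `fake` image batch, `cycle_loss(real, fake)` evaluates to 10.0 and `identity_loss(real, fake)` to 5.0. This shows the weights from steps 5 and 6 directly.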
Training Tips | |
1 | We already have two functions G:x->y and F:y->x |
2 | Create the first cycle, x->y->x (real_x -> generated_y -> cycled_x): generated_y = generator_g(real_x, training=True); cycled_x = generator_f(generated_y, training=True) |
3 | Create the second cycle, y->x->y (real_y -> generated_x -> cycled_y): generated_x = generator_f(real_y, training=True); cycled_y = generator_g(generated_x, training=True) |
4 | "Same outputs" (same_x, same_y) are produced by feeding each generator an image that is already in its output domain. For example, generator_g normally takes x as input and produces y; to get same_y, generator_g is given real_y instead. Think of "same" as "input in the same domain as the output": same_x = generator_f(real_x, training=True); same_y = generator_g(real_y, training=True) |
5 | Discriminator outputs are straightforward: each discriminator only sees images of its own domain (discriminator_x gets the ?_x inputs, discriminator_y gets the ?_y inputs): disc_real_x = discriminator_x(real_x, training=True); disc_real_y = discriminator_y(real_y, training=True); disc_generated_x = discriminator_x(generated_x, training=True); disc_generated_y = discriminator_y(generated_y, training=True) |
6 | Generator losses are straightforward: pass generator_loss the discriminator output (NOT the generator output) for that generator's generated images. generator_g produces y, so its loss uses disc_generated_y: gen_g_loss = generator_loss(disc_generated_y); gen_f_loss = generator_loss(disc_generated_x) |
7 | The total cycle loss is the sum of the x and y cycle losses: tot_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y) |
8 | Identity losses are also straightforward: compute them from real_? and same_? of the same domain, i.e. identity_loss(real_x, same_x) and identity_loss(real_y, same_y) |
9 | Each total generator loss is the sum of its generator loss, the total cycle loss and the identity loss of that generator's OUTPUT domain (y for generator_g, x for generator_f) |
10 | So the total generator losses are: tot_gen_g_loss = gen_g_loss + tot_cycle_loss + identity_loss(real_y, same_y); tot_gen_f_loss = gen_f_loss + tot_cycle_loss + identity_loss(real_x, same_x) |
11 | Discriminator losses are straightforward. They are computed from the discriminator's outputs for real input and for generated input. Note that discriminator_loss never takes images (real or generated) directly; it works on discriminator outputs only: disc_x_loss = discriminator_loss(disc_real_x, disc_generated_x); disc_y_loss = discriminator_loss(disc_real_y, disc_generated_y) |
12 | Gradient calculations are straightforward. Because the same tape is queried four times, it must be created with tf.GradientTape(persistent=True): gen_g_gradient = tape.gradient(tot_gen_g_loss, generator_g.trainable_variables); gen_f_gradient = tape.gradient(tot_gen_f_loss, generator_f.trainable_variables); disc_x_gradient = tape.gradient(disc_x_loss, discriminator_x.trainable_variables); disc_y_gradient = tape.gradient(disc_y_loss, discriminator_y.trainable_variables) |
13 | Applying the gradients is straightforward: gen_g_optimizer.apply_gradients(zip(gen_g_gradient, generator_g.trainable_variables)); gen_f_optimizer.apply_gradients(zip(gen_f_gradient, generator_f.trainable_variables)); disc_x_optimizer.apply_gradients(zip(disc_x_gradient, discriminator_x.trainable_variables)); disc_y_optimizer.apply_gradients(zip(disc_y_gradient, discriminator_y.trainable_variables)) |
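Putting the training tips together, one training step might look like the sketch below. It assumes TensorFlow 2 with Keras and re-declares the four loss functions so that the block runs on its own. Several pieces are assumptions rather than part of the original tips:

- The adversarial terms use binary cross-entropy on logits.
- `LAMBDA` is an assumed name for the factor of 10.
- `tiny_model` is a hypothetical one-layer convolutional stand-in for the real CycleGAN generators and discriminators.
- The Adam settings (learning rate 2e-4, beta_1=0.5) are assumed values.

```python
import tensorflow as tf

LAMBDA = 10.0  # assumed name for the factor of 10
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)  # assumption: logit outputs


def discriminator_loss(real, generated):
    # Real -> 1, generated -> 0; the sum is halved.
    return 0.5 * (bce(tf.ones_like(real), real) + bce(tf.zeros_like(generated), generated))


def generator_loss(generated):
    # The generator wants its output labelled as real.
    return bce(tf.ones_like(generated), generated)


def cycle_loss(real_image, cycled_image):
    return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))


def identity_loss(real_image, same_image):
    return 0.5 * LAMBDA * tf.reduce_mean(tf.abs(real_image - same_image))


def tiny_model(out_channels):
    # Hypothetical stand-in network: a single 3x3 convolution on RGB images.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(None, None, 3)),
        tf.keras.layers.Conv2D(out_channels, 3, padding="same"),
    ])


generator_g = tiny_model(3)      # G: x -> y
generator_f = tiny_model(3)      # F: y -> x
discriminator_x = tiny_model(1)  # per-pixel logits: real x vs generated x
discriminator_y = tiny_model(1)  # per-pixel logits: real y vs generated y

gen_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
gen_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


def train_step(real_x, real_y):
    # persistent=True because tape.gradient is called four times on this tape.
    with tf.GradientTape(persistent=True) as tape:
        generated_y = generator_g(real_x, training=True)   # cycle 1: x -> y -> x
        cycled_x = generator_f(generated_y, training=True)
        generated_x = generator_f(real_y, training=True)   # cycle 2: y -> x -> y
        cycled_y = generator_g(generated_x, training=True)

        same_x = generator_f(real_x, training=True)  # input already in F's output domain
        same_y = generator_g(real_y, training=True)  # input already in G's output domain

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)
        disc_generated_x = discriminator_x(generated_x, training=True)
        disc_generated_y = discriminator_y(generated_y, training=True)

        gen_g_loss = generator_loss(disc_generated_y)  # G produces y
        gen_f_loss = generator_loss(disc_generated_x)  # F produces x
        tot_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y)
        tot_gen_g_loss = gen_g_loss + tot_cycle_loss + identity_loss(real_y, same_y)
        tot_gen_f_loss = gen_f_loss + tot_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_generated_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_generated_y)

    # Step 12: compute all four gradients from the one persistent tape.
    gen_g_gradient = tape.gradient(tot_gen_g_loss, generator_g.trainable_variables)
    gen_f_gradient = tape.gradient(tot_gen_f_loss, generator_f.trainable_variables)
    disc_x_gradient = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
    disc_y_gradient = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
    del tape  # release resources held by the persistent tape

    # Step 13: each optimizer updates only its own network.
    gen_g_optimizer.apply_gradients(zip(gen_g_gradient, generator_g.trainable_variables))
    gen_f_optimizer.apply_gradients(zip(gen_f_gradient, generator_f.trainable_variables))
    disc_x_optimizer.apply_gradients(zip(disc_x_gradient, discriminator_x.trainable_variables))
    disc_y_optimizer.apply_gradients(zip(disc_y_gradient, discriminator_y.trainable_variables))

    return {"gen_g": tot_gen_g_loss, "gen_f": tot_gen_f_loss,
            "disc_x": disc_x_loss, "disc_y": disc_y_loss}
```

Calling `train_step(real_x, real_y)` with one image batch from each domain updates all four networks once. Computing every gradient before applying any of them mirrors the order of steps 12 and 13. Wrapping `train_step` in `tf.function` would usually speed it up, but it is left out here to keep the sketch easy to step through.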
Saturday, June 5, 2021
Summary Tips for CycleGAN