Sunday, June 6, 2021

CycleGAN

# pip install git+https://github.com/tensorflow/examples.git


import tensorflow as tf 

from tensorflow import keras 

import tensorflow_datasets as tfds 

from tensorflow_examples.models.pix2pix import pix2pix


################################################################################
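# Load the horse2zebra dataset from TensorFlow Datasets: trainA/testA are horse images, trainB/testB are zebra images.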

dataset , metadata = tfds.load("cycle_gan/horse2zebra" , with_info = True , as_supervised = True)

train_horses , train_zebras = dataset["trainA"] , dataset["trainB"]

test_horses , test_zebras = dataset["testA"] , dataset["testB"]


BUFFER_SIZE = 1000 

BATCH_SIZE = 1 

IMG_HEIGHT = 256 

IMG_WIDTH = 256 

OUTPUT_CHANNELS = 3

LAMBDA = 10 

AUTOTUNE = tf.data.AUTOTUNE

no_of_epochs = 1


def random_jitter(image) :

  # Resize to 286x286, randomly crop back to 256x256, then randomly mirror the image.
  image = tf.image.resize(image , [286 , 286] , method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  image = tf.image.random_crop(image , [256, 256, 3])

  image = tf.image.random_flip_left_right(image)

  return image


def normalize(image) : 
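  # Cast to float32 and scale pixel values from [0, 255] to [-1, 1].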

  image = tf.cast(image , tf.float32) 

  image = (image /127.5) - 1

  return image 


def load_test_image(image , label) :

  image = normalize(image) 

  return image


def load_train_image(image , label) : 
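  # Training images get random jitter (augmentation) on top of normalization; test images are only normalized.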

  image = random_jitter(image)

  image = normalize(image) 

  return image 


train_horses = train_horses.map(load_train_image, num_parallel_calls = AUTOTUNE).cache(

    ).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.map(load_train_image, num_parallel_calls = AUTOTUNE).cache(

    ).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(load_test_image, num_parallel_calls = AUTOTUNE).cache(

    ).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(load_test_image, num_parallel_calls = AUTOTUNE).cache(

    ).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)                              

################################################################################
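
# Optional sanity check (not in the original post): pull one preprocessed batch and
# confirm it has shape (1, 256, 256, 3) with pixel values in [-1, 1].
sample_horse = next(iter(train_horses))
print(sample_horse.shape, float(tf.reduce_min(sample_horse)), float(tf.reduce_max(sample_horse)))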




################################################################################
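# Generator G maps horses -> zebras and generator F maps zebras -> horses; the two discriminators
# tell real from generated images in each domain. All four models come from the pix2pix example
# code, built with instance normalization.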

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type = "instancenorm")

generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS , norm_type = "instancenorm")

discriminator_x = pix2pix.discriminator(norm_type = "instancenorm" , target = False)

discriminator_y = pix2pix.discriminator(norm_type = "instancenorm", target = False)


gen_g_optimizer = tf.keras.optimizers.Adam(2e-4 , beta_1 = 0.5)

gen_f_optimizer = tf.keras.optimizers.Adam(2e-4 , beta_1 = 0.5)

disc_x_optimizer = tf.keras.optimizers.Adam(2e-4 , beta_1 = 0.5)

disc_y_optimizer = tf.keras.optimizers.Adam(2e-4 , beta_1 = 0.5)


cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_x , generated_x) : 
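  # Real images should be scored close to 1 and generated images close to 0; the two BCE terms are averaged.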

  real_loss = cross_entropy(tf.ones_like(real_x) , real_x)

  generated_loss = cross_entropy(tf.zeros_like(generated_x) , generated_x)

  return (real_loss + generated_loss) * 0.5


def generator_loss(disc_generated_image) : 
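  # The generator is rewarded when the discriminator scores its generated images as real (close to 1).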

  return cross_entropy(tf.ones_like(disc_generated_image) , disc_generated_image)


def cycle_loss(real_image , cycled_image)   : 
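  # Cycle-consistency: translating to the other domain and back should reconstruct the original image (L1 distance, weighted by LAMBDA).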

  loss = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return loss * LAMBDA


def identity_loss(real_image , same_image) :

  # Identity loss: a generator fed an image from its own output domain should return it unchanged (L1 distance).
  loss = tf.reduce_mean(tf.abs(real_image - same_image))

  return loss * 0.5 * LAMBDA

################################################################################



################################################################################
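# One training step: translate in both directions, compute the adversarial, cycle-consistency and
# identity losses, then update all four networks with their own Adam optimizers.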

@tf.function   # compile the step into a graph for faster training
def train_step(real_x , real_y) :

  with tf.GradientTape(persistent=True) as tape : 

    fake_y = generator_g(real_x , training = True)

    cycled_x = generator_f(fake_y , training = True)


    fake_x = generator_f(real_y , training = True)

    cycled_y = generator_g(fake_x , training = True)


    same_x = generator_f(real_x , training = True)

    same_y = generator_g(real_y , training = True)


    disc_real_x = discriminator_x(real_x , training = True)

    disc_real_y = discriminator_y(real_y , training = True)


    disc_fake_x = discriminator_x(fake_x , training = True)

    disc_fake_y = discriminator_y(fake_y , training = True)


    gen_g_loss = generator_loss(disc_fake_y)

    gen_f_loss = generator_loss(disc_fake_x)


    tot_cycle_loss = cycle_loss(real_x , cycled_x)  + cycle_loss(real_y , cycled_y)


    tot_gen_g_loss = gen_g_loss + tot_cycle_loss + identity_loss(real_y , same_y)

    tot_gen_f_loss = gen_f_loss + tot_cycle_loss + identity_loss(real_x , same_x)


    disc_x_loss = discriminator_loss(disc_real_x , disc_fake_x)

    disc_y_loss = discriminator_loss(disc_real_y , disc_fake_y)


  gen_g_gradient = tape.gradient(tot_gen_g_loss , generator_g.trainable_variables)

  gen_f_gradient = tape.gradient(tot_gen_f_loss , generator_f.trainable_variables)

  disc_x_gradient = tape.gradient(disc_x_loss , discriminator_x.trainable_variables)

  disc_y_gradient = tape.gradient(disc_y_loss , discriminator_y.trainable_variables)


  gen_g_optimizer.apply_gradients(zip(gen_g_gradient , generator_g.trainable_variables))

  gen_f_optimizer.apply_gradients(zip(gen_f_gradient , generator_f.trainable_variables))

  disc_x_optimizer.apply_gradients(zip(disc_x_gradient , discriminator_x.trainable_variables))

  disc_y_optimizer.apply_gradients(zip(disc_y_gradient , discriminator_y.trainable_variables))

################################################################################



################################################################################
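# Training loop over the zipped (horse, zebra) pairs, plus a helper to visualise generator_g on a test image.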

def train() : 

  for epoch in range(no_of_epochs) :

    print("epoch {} started".format(epoch))

    for image_x , image_y in tf.data.Dataset.zip((train_horses, train_zebras)) :

      train_step(image_x , image_y)

      print("." , end="")

    print("epoch {} ended".format(epoch))


def test(test_image) : 
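  # Run a test horse through generator_g (horse -> zebra) and plot it next to the original;
  # images are rescaled from [-1, 1] back to [0, 1] for display.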

  import matplotlib.pyplot as plt

  predicted_image = generator_g(test_image)

  plt.figure(figsize = (16,16))

  plt.subplot(121)

  plt.title("Original Image")

  plt.imshow(test_image[0]*0.5 +0.5)

  plt.subplot(122)

  plt.title("Predicted Image")

  plt.imshow(predicted_image[0]*0.5 +0.5)  

  plt.show()


train()

for inp in test_horses.take(5) : 

  test(inp)

################################################################################
