# pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
################################################################################
# Load the horse2zebra CycleGAN dataset as (image, label) pairs, plus metadata.
dataset, metadata = tfds.load(
    "cycle_gan/horse2zebra",
    with_info=True,
    as_supervised=True,
)

# Domain A splits are bound to the "horses" names, domain B to "zebras".
train_horses = dataset["trainA"]
train_zebras = dataset["trainB"]
test_horses = dataset["testA"]
test_zebras = dataset["testB"]

# Input-pipeline settings.
BUFFER_SIZE = 1000
BATCH_SIZE = 1
AUTOTUNE = tf.data.AUTOTUNE

# Image geometry.
IMG_HEIGHT = 256
IMG_WIDTH = 256
OUTPUT_CHANNELS = 3

# Weight applied to the cycle-consistency and identity losses.
LAMBDA = 10

# Number of passes over the training data.
no_of_epochs = 1
def random_jitter(image):
    """Apply random-jitter augmentation to a single training image.

    The image is upscaled to 286x286 with nearest-neighbor interpolation,
    a random IMG_HEIGHT x IMG_WIDTH x 3 patch is cropped out of it, and the
    patch is mirrored horizontally at random.

    Args:
        image: A single 3-channel image tensor (no batch dimension).

    Returns:
        The augmented image tensor of shape [IMG_HEIGHT, IMG_WIDTH, 3].
    """
    image = tf.image.resize(
        image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    image = tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
    # Use the *random* flip: the unconditional flip_left_right mirrors every
    # image, which adds no variety to the training data.
    image = tf.image.random_flip_left_right(image)
    return image
def normalize(image):
    """Cast *image* to float32 and rescale it as ``x / 127.5 - 1``.

    For pixel values in 0..255 this yields values in [-1, 1].
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1
def load_test_image(image, label):
    """Preprocess one test example: normalize the image, drop the label."""
    return normalize(image)
def load_train_image(image, label):
    """Preprocess one training example.

    Runs the image through ``random_jitter`` and then ``normalize``; the
    label is ignored.
    """
    jittered = random_jitter(image)
    return normalize(jittered)
# Training pipelines: cache the raw examples FIRST and augment afterwards.
# Caching after the map would store the augmented images, so every epoch
# after the first would replay the exact same crops instead of fresh ones.
train_horses = (
    train_horses.cache()
    .map(load_train_image, num_parallel_calls=AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
train_zebras = (
    train_zebras.cache()
    .map(load_train_image, num_parallel_calls=AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Test pipelines: preprocessing is deterministic, so caching the mapped
# result is safe and avoids recomputing it.
test_horses = (
    test_horses.map(load_test_image, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
test_zebras = (
    test_zebras.map(load_test_image, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
################################################################################
################################################################################
# Generators: generator_g translates X -> Y and generator_f translates
# Y -> X (see how train_step wires them together).
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type="instancenorm")
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type="instancenorm")

# Discriminators: one per image domain; target=False means they score a
# single image rather than an (input, target) pair.
discriminator_x = pix2pix.discriminator(norm_type="instancenorm", target=False)
discriminator_y = pix2pix.discriminator(norm_type="instancenorm", target=False)

# One Adam optimizer per network, all with the same hyperparameters.
gen_g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
gen_f_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
disc_x_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
disc_y_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Shared adversarial loss; the discriminator outputs are treated as logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_x, generated_x):
    """Return the discriminator's adversarial loss.

    Args:
        real_x: Discriminator output for real images (target: all ones).
        generated_x: Discriminator output for generated images
            (target: all zeros).

    Returns:
        Half the sum of the two binary cross-entropy terms.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_x), real_x)
    loss_on_generated = cross_entropy(tf.zeros_like(generated_x), generated_x)
    combined = loss_on_real + loss_on_generated
    return combined * 0.5
def generator_loss(disc_generated_image):
    """Return the generator's adversarial loss.

    The discriminator output for a generated image is scored against an
    all-ones target, i.e. the generator is rewarded when the discriminator
    labels its output as real.
    """
    real_target = tf.ones_like(disc_generated_image)
    return cross_entropy(real_target, disc_generated_image)
def cycle_loss(real_image, cycled_image):
    """Return the LAMBDA-weighted mean absolute error between an image and
    its round-trip reconstruction."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return mean_abs_error * LAMBDA
def identity_loss(real_image, same_image):
    """Return the identity-mapping loss for a generator.

    Args:
        real_image: An image that already belongs to the generator's
            target domain.
        same_image: The generator's output when fed ``real_image``.

    Returns:
        The mean absolute difference between the two images, scaled by
        ``0.5 * LAMBDA``.
    """
    # The difference must be taken explicitly: tf.abs(a, b) would treat the
    # second argument as the op's ``name`` and never compare the images.
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return loss * 0.5 * LAMBDA
################################################################################
################################################################################
def train_step(real_x, real_y):
    """Run one optimization step for both generators and both discriminators.

    Args:
        real_x: A batch of real images from domain X.
        real_y: A batch of real images from domain Y.
    """
    # persistent=True because the same tape is queried four times below.
    with tf.GradientTape(persistent=True) as tape:
        # Forward cycle: X -> Y -> X.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        # Backward cycle: Y -> X -> Y.
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # Identity mappings: each generator fed an image that is already in
        # its output domain.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)
        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Generator objectives: adversarial + cycle + identity.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)
        tot_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y)
        tot_gen_g_loss = gen_g_loss + tot_cycle_loss + identity_loss(real_y, same_y)
        tot_gen_f_loss = gen_f_loss + tot_cycle_loss + identity_loss(real_x, same_x)

        # Discriminator objectives.
        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    gen_g_gradient = tape.gradient(tot_gen_g_loss, generator_g.trainable_variables)
    gen_f_gradient = tape.gradient(tot_gen_f_loss, generator_f.trainable_variables)
    disc_x_gradient = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
    disc_y_gradient = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
    # A persistent tape holds its resources until it is dropped explicitly.
    del tape

    gen_g_optimizer.apply_gradients(
        zip(gen_g_gradient, generator_g.trainable_variables)
    )
    # generator_f must be updated with its OWN gradients; pairing it with
    # gen_g_gradient would apply generator_g's gradients to the wrong network.
    gen_f_optimizer.apply_gradients(
        zip(gen_f_gradient, generator_f.trainable_variables)
    )
    disc_x_optimizer.apply_gradients(
        zip(disc_x_gradient, discriminator_x.trainable_variables)
    )
    disc_y_optimizer.apply_gradients(
        zip(disc_y_gradient, discriminator_y.trainable_variables)
    )
################################################################################
################################################################################
def train():
    """Train for ``no_of_epochs`` epochs over paired horse/zebra batches.

    Prints a start and end line per epoch and one dot per training step.
    """
    for epoch in range(no_of_epochs):
        print("epoch {} started".format(epoch))
        # Pair the two unaligned domains batch-by-batch.
        paired_batches = tf.data.Dataset.zip((train_horses, train_zebras))
        for horse_batch, zebra_batch in paired_batches:
            train_step(horse_batch, zebra_batch)
            print(".", end="")
        print("epoch {} ended".format(epoch))
def test(test_image):
    """Show an input batch's first image next to generator_g's translation.

    Args:
        test_image: A batch of normalized images; only element 0 is drawn.
    """
    import matplotlib.pyplot as plt

    predicted_image = generator_g(test_image)

    plt.figure(figsize=(16, 16))
    panels = (
        (121, "Original Image", test_image),
        (122, "Predicted Image", predicted_image),
    )
    for position, caption, batch in panels:
        plt.subplot(position)
        plt.title(caption)
        # Undo the [-1, 1] normalization so imshow receives values in [0, 1].
        plt.imshow(batch[0] * 0.5 + 0.5)
    plt.show()
# Train the models, then visualize translations for five test horse batches.
train()
for inp in test_horses.take(5):
    test(inp)
################################################################################
# (Blog-page residue kept as a comment so the file parses:)
# No comments: / Post a Comment