Wednesday, May 26, 2021

Revisted MNIST DCGAN : somewhat simplified now

 import tensorflow as tf 

from tensorflow import keras 
from keras import layers
import matplotlib.pyplot as plt 
import time
# rm -r "/content/ckpoint"
no_of_epochs_to_checkpoint = 2
no_of_epochs = 10
no_of_examples =16
no_of_dimensions_for_noise =100
ckpoint_prefix = "/content/ckpoint/ckpt"
BATCH_SIZE = 256 
BUFFER_SIZE = 60000


(train_images , train_labels) , (_,_) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0] , 28 , 28 , 1).astype("float32")
train_images = (train_images -127.5) /127.5

train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

checkpoint = tf.train.Checkpoint()

def get_generator() : 
  model = keras.models.Sequential()
  model.add(layers.Dense(7*7*256 , use_bias = False , input_shape = (100,)))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Reshape((7 ,7256)))

  model.add(layers.Conv2DTranspose(1285, strides = 1 , use_bias = False , padding = "same"))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(64 , 5, strides = 2 , use_bias = False , padding = "same"))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(15, strides = 2 , use_bias= False , padding = "same" , activation = "tanh"))

  return model
def get_discriminator() : 
  model = keras.models.Sequential()

  model.add(layers.Conv2D(64 , 5, strides = 2 , padding = "same" , input_shape = (28,28,1)))
  model.add(layers.LeakyReLU())
  model.add(layers.Dropout(0.3))

  model.add(layers.Conv2D(128 , 5, strides = 2 , padding = "same"))
  model.add(layers.LeakyReLU())
  model.add(layers.Dropout(0.3))

  model.add(layers.Flatten())
  model.add(layers.Dense(1))

  return model  
  
generator = get_generator()
discriminator = get_discriminator()

#LOSSES
cross_entropy = keras.losses.BinaryCrossentropy(from_logits = True)

def generator_loss(fake_output) : 
  return cross_entropy(tf.ones_like(fake_output) , fake_output)

def discriminator_loss(real_outputfake_output) : 
  real_loss = cross_entropy(tf.ones_like(real_output) , real_output)
  fake_loss = cross_entropy(tf.zeros_like(fake_output) , fake_output)
  return real_loss + fake_loss
#LOSSES

#OPTIMIZER
gen_optimizer = tf.keras.optimizers.Adam(1e-4)
disc_optimizer = tf.keras.optimizers.Adam(1e-4)
#OPTIMIZER



def generate_and_save_files_after_epoch(generator , epoch_number) : 
  seed = tf.random.normal([no_of_examples, no_of_dimensions_for_noise])
  predictions = generator(seed , training = False)

  plt.figure(figsize = (15,15) )

  for i in range(predictions.shape[0]) : 
    plt.subplot(4,4,i+1)
    plt.imshow(predictions[i,:,:,0] *127.5  + 127.5 )
    plt.axis("off")
  plt.savefig( "/content/epochwiseoutput/" + "epoch_{:04d}".format(epoch_number))
  plt.show()

  return 0 

def train_step(imageset) : 
  noise = tf.random.normal([BATCH_SIZE, no_of_dimensions_for_noise])

  with tf.GradientTape() as gen_tape , tf.GradientTape() as disc_tape : 
    fake_images = generator(noise , training = True)
    real_output = discriminator(imageset , training = True)
    fake_output = discriminator(fake_images , training = True

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output , fake_output)

    gen_gradient = gen_tape.gradient(gen_loss , generator.trainable_variables)
    disc_gradient = disc_tape.gradient(disc_loss , discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(gen_gradient , generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradient , discriminator.trainable_variables))
  return 0 

def train(batched_training_dataset , no_of_epochs) : 
  for epoch in range(no_of_epochs) : 
    start_time = time.time()
    print("started at {} epoch {:04d}".format(start_time , epoch))
    for batch in batched_training_dataset : 
      train_step(batch)
    if (epoch%no_of_epochs_to_checkpoint == 0 ) : 
      checkpoint.save(file_prefix = ckpoint_prefix)
    generate_and_save_files_after_epoch(generator , epoch + 1)
    print("time taken {} for epoch {:04d}".format(time.time() - start_time , epoch))

  generate_and_save_files_after_epoch(generator , no_of_epochs )

train(train_dataset, no_of_epochs)

import os 
if not (os.path.exists("/content/epochwiseoutput")) :
  os.mkdir("/content/epochwiseoutput")
import imageio 
import glob 

with imageio.get_writer("animatedfile.gif" , mode = "I"as writer : 
  filenames = glob.glob("/content/epochwiseoutput/*.png")
  filenames = sorted(filenames)
  for file in filenames : 
    img = imageio.imread(file)
    writer.append_data(img)
  img = imageio.imread(file)
  writer.append_data(img)

No comments:

Post a Comment

How to check local and global angular versions

 Use the command ng version (or ng v ) to find the version of Angular CLI in the current folder. Run it outside of the Angular project, to f...