import os
import time

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# To start over in Colab, delete any old checkpoints first: !rm -r /content/ckpoint
no_of_epochs_to_checkpoint = 2
no_of_epochs = 10
no_of_examples = 16                # images per preview grid (4 x 4)
no_of_dimensions_for_noise = 100   # length of the generator's latent vector
ckpoint_prefix = "/content/ckpoint/ckpt"
BATCH_SIZE = 256
BUFFER_SIZE = 60000                # whole MNIST training set, for a full shuffle
# The per-epoch preview images are written here, so the directory must exist before training.
os.makedirs("/content/epochwiseoutput", exist_ok=True)
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype("float32")
train_images = (train_images - 127.5) / 127.5   # scale pixels to [-1, 1] to match the generator's tanh output
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
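# Optional sanity check: peek at one batch to confirm the pipeline's shape and value range.
for batch in train_dataset.take(1):
    print(batch.shape)                                                # (256, 28, 28, 1)
    print(float(tf.reduce_min(batch)), float(tf.reduce_max(batch)))   # ~-1.0, ~1.0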
def get_generator():
    model = keras.models.Sequential()
    # Project the latent noise vector up to a 7x7x256 feature map.
    model.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(no_of_dimensions_for_noise,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    # Upsample 7x7 -> 7x7 -> 14x14 -> 28x28 with transposed convolutions.
    model.add(layers.Conv2DTranspose(128, 5, strides=1, use_bias=False, padding="same"))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(64, 5, strides=2, use_bias=False, padding="same"))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # tanh keeps the output in [-1, 1], matching the normalized training images.
    model.add(layers.Conv2DTranspose(1, 5, strides=2, use_bias=False, padding="same", activation="tanh"))
    return model
def get_discriminator():
    model = keras.models.Sequential()
    # Downsample 28x28 -> 14x14 -> 7x7 with strided convolutions.
    model.add(layers.Conv2D(64, 5, strides=2, padding="same", input_shape=(28, 28, 1)))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(128, 5, strides=2, padding="same"))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    # Single raw logit: positive means "real", negative means "fake".
    model.add(layers.Dense(1))
    return model
generator = get_generator()
discriminator = get_discriminator()
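# Optional sanity check: the untrained generator should emit a 28x28x1 image
# and the discriminator a single logit for it.
sample_noise = tf.random.normal([1, no_of_dimensions_for_noise])
sample_image = generator(sample_noise, training=False)
print(sample_image.shape)                                  # (1, 28, 28, 1)
print(discriminator(sample_image, training=False).shape)   # (1, 1)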
#LOSSES
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)  # from_logits: the discriminator outputs raw scores
def generator_loss(fake_output):
    # The generator wins when the discriminator labels its fakes as real (1).
    return cross_entropy(tf.ones_like(fake_output), fake_output)
def discriminator_loss(real_output, fake_output):
    # The discriminator should score real images as 1 and generated images as 0.
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss
#LOSSES
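# Optional worked example, using the helpers above: a confident, correct
# discriminator (large positive logit on real, large negative on fake) drives
# discriminator_loss toward 0, while generator_loss becomes large.
demo_real = tf.constant([[10.0]])
demo_fake = tf.constant([[-10.0]])
print(discriminator_loss(demo_real, demo_fake).numpy())   # ~0.0001
print(generator_loss(demo_fake).numpy())                  # ~10.0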
#OPTIMIZER
gen_optimizer = tf.keras.optimizers.Adam(1e-4)
disc_optimizer = tf.keras.optimizers.Adam(1e-4)
#OPTIMIZER
# The checkpoint must be told which objects to track, otherwise it saves nothing,
# so it is created only after the models and optimizers exist.
checkpoint = tf.train.Checkpoint(generator=generator,
                                 discriminator=discriminator,
                                 gen_optimizer=gen_optimizer,
                                 disc_optimizer=disc_optimizer)
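# Optional: to resume an interrupted run, restore the latest checkpoint before
# training; tf.train.latest_checkpoint returns None if none has been saved yet.
latest = tf.train.latest_checkpoint("/content/ckpoint")
if latest:
    checkpoint.restore(latest)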
def generate_and_save_files_after_epoch(generator, epoch_number):
    # Note: a fresh seed is drawn each call, so every grid samples new latent points.
    seed = tf.random.normal([no_of_examples, no_of_dimensions_for_noise])
    predictions = generator(seed, training=False)
    plt.figure(figsize=(15, 15))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Undo the [-1, 1] normalization so the digits render correctly.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap="gray")
        plt.axis("off")
    plt.savefig("/content/epochwiseoutput/epoch_{:04d}.png".format(epoch_number))
    plt.show()
@tf.function  # optional: compiles the step into a TF graph for speed
def train_step(imageset):
    noise = tf.random.normal([BATCH_SIZE, no_of_dimensions_for_noise])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)
        real_output = discriminator(imageset, training=True)
        fake_output = discriminator(fake_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
def train(batched_training_dataset, no_of_epochs):
    for epoch in range(no_of_epochs):
        start_time = time.time()
        print("epoch {:04d} started at {}".format(epoch, start_time))
        for batch in batched_training_dataset:
            train_step(batch)
        if epoch % no_of_epochs_to_checkpoint == 0:
            checkpoint.save(file_prefix=ckpoint_prefix)
        generate_and_save_files_after_epoch(generator, epoch + 1)
        print("time taken {} for epoch {:04d}".format(time.time() - start_time, epoch))
    generate_and_save_files_after_epoch(generator, no_of_epochs)
train(train_dataset, no_of_epochs)
import imageio
import glob
with imageio.get_writer("animatedfile.gif", mode="I") as writer:
    filenames = sorted(glob.glob("/content/epochwiseoutput/*.png"))
    for file in filenames:
        img = imageio.imread(file)
        writer.append_data(img)
    # Append the last frame once more so the finished result lingers in the GIF.
    img = imageio.imread(file)
    writer.append_data(img)
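# Optional, for notebook environments such as Colab: display the finished GIF inline.
from IPython.display import Image
Image(filename="animatedfile.gif")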