Wednesday, June 2, 2021

Pix2Pix I - Preprocessing images and creating training and testing datasets

Based on this


import tensorflow as tf 

from tensorflow import keras 

# Locations used for the facades dataset: the working directory, the
# downloaded archive inside it, and the directory checked for extracted files.
facadespath = "/content/facades"
facadesfile = f"{facadespath}/facades.tar.gz"
facadesextractedpath = f"{facadespath}/facades"
# Remote archive that is downloaded when no local copy exists.
url = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz"


import os 

# Make sure the dataset directory exists before anything is written into it.
# makedirs(exist_ok=True) also creates missing parents and tolerates reruns.
os.makedirs(facadespath, exist_ok=True)

# Download the archive only when it is not already on disk.
if not os.path.exists(facadesfile):
    tf.keras.utils.get_file(facadesfile, origin=url, extract=True)

# Unpack the archive into the dataset directory unless the extracted
# directory is already present.
if not os.path.exists(facadesextractedpath):
    import tarfile

    # NOTE(review): the archive is fetched from the network and extractall()
    # trusts the member paths it contains; only use this with a trusted source.
    with tarfile.open(facadesfile, "r") as tref:
        tref.extractall(facadespath)

# Directory for checkpointing.
os.makedirs("/content/checkpoint", exist_ok=True)


# Pipeline settings: shuffle buffer length and batch size for the datasets,
# plus the height/width that images are cropped or resized to.
BUFFER_SIZE, BATCH_SIZE = 400, 1
IMG_HEIGHT = IMG_WIDTH = 256

def load(imagefile):
    """Read a JPEG file and split it down the middle into two images.

    Args:
        imagefile: Path of the JPEG file, handed to ``tf.io.read_file``.

    Returns:
        A tuple ``(input_image, target_image)`` of float32 tensors. The input
        is the right half of the decoded picture and the target is the left
        half.
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(imagefile))

    # Column index at which the picture is cut in two.
    half_width = tf.shape(decoded)[1] // 2

    right_half = decoded[:, half_width:, :]
    left_half = decoded[:, :half_width, :]

    return tf.cast(right_half, tf.float32), tf.cast(left_half, tf.float32)


def resize(input_image, target_image, img_height, img_width):
    """Resize both images to ``img_height`` x ``img_width``.

    Nearest-neighbour interpolation is used for both images.

    Returns:
        A tuple ``(input_image, target_image)`` holding the resized images.
    """
    new_size = [img_height, img_width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    resized_input = tf.image.resize(input_image, new_size, method=nearest)
    resized_target = tf.image.resize(target_image, new_size, method=nearest)
    return resized_input, resized_target

def random_crop(input_image, target_image):
    """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

    The two images are stacked along a new leading axis and cropped with a
    single ``tf.image.random_crop`` call, so both receive identical offsets.

    Returns:
        A tuple ``(input_image, target_image)`` holding the two crops.
    """
    pair = tf.stack([input_image, target_image], axis=0)
    # Keep both stacked images and request three channels per crop.
    crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    cropped_pair = tf.image.random_crop(pair, crop_shape)
    return cropped_pair[0], cropped_pair[1]

def random_jitter(input_image, target_image):
    """Apply random jitter to an image pair.

    Both images are resized to 286 x 286, cropped together via
    ``random_crop``, and then mirrored left-to-right together when a uniform
    random draw exceeds 0.5.

    Returns:
        A tuple ``(input_image, target_image)`` holding the jittered images.
    """
    enlarged_input, enlarged_target = resize(input_image, target_image, 286, 286)
    jittered_input, jittered_target = random_crop(enlarged_input, enlarged_target)

    # One draw decides the flip for both images so the pair stays aligned.
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        jittered_input = tf.image.flip_left_right(jittered_input)
        jittered_target = tf.image.flip_left_right(jittered_target)

    return jittered_input, jittered_target

def normalize(input_image, target_image):
    """Rescale both images by dividing by 127.5 and subtracting 1.

    With this mapping a value of 0 becomes -1 and a value of 255 becomes 1.

    Returns:
        A tuple ``(input_image, target_image)`` holding the rescaled images.
    """
    half_range = 127.5
    return input_image / half_range - 1, target_image / half_range - 1


def load_train_images(imagefile):
    """Load one training example: split, random-jitter, then normalize.

    Args:
        imagefile: Path of the image file, forwarded to ``load``.

    Returns:
        The ``(input_image, target_image)`` tuple produced by ``normalize``.
    """
    raw_pair = load(imagefile)
    jittered_pair = random_jitter(*raw_pair)
    return normalize(*jittered_pair)

def load_test_images(imagefile):
    """Load one test example: split, resize, then normalize.

    No random jitter is applied here; the pair is only resized to
    IMG_HEIGHT x IMG_WIDTH.

    Args:
        imagefile: Path of the image file, forwarded to ``load``.

    Returns:
        The ``(input_image, target_image)`` tuple produced by ``normalize``.
    """
    raw_input, raw_target = load(imagefile)
    sized_pair = resize(raw_input, raw_target, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*sized_pair)


# Training pipeline: list the training JPEGs, preprocess each file with
# random jitter, shuffle with a BUFFER_SIZE buffer, then batch.
train_ds = (
    tf.data.Dataset.list_files(facadesextractedpath + "/train/*.jpg")
    .map(load_train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Test pipeline: list the test JPEGs, preprocess each file without jitter,
# then batch (no shuffle step).
test_ds = (
    tf.data.Dataset.list_files(facadesextractedpath + "/test/*.jpg")
    .map(load_test_images)
    .batch(BATCH_SIZE)
)

No comments:

Post a Comment

How to check local and global Angular versions

Use the command `ng version` (or `ng v`) to find the version of the Angular CLI in the current folder. Run it outside of the Angular project, to f...