Based on this
import tensorflow as tf
from tensorflow import keras
import os
import tarfile

# Locations for the pix2pix "facades" dataset archive and its unpacked copy.
facadespath = "/content/facades"
facadesfile = facadespath + "/" + "facades.tar.gz"
facadesextractedpath = facadespath + "/facades"
url = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz"

# makedirs(exist_ok=True) is a no-op when the directory already exists and
# also creates a missing "/content" parent (e.g. when not running on Colab).
os.makedirs(facadespath, exist_ok=True)

if not os.path.exists(facadesfile):
    # An absolute fname makes get_file save the archive exactly at
    # facadesfile. extract=True is deliberately not used: get_file does not
    # unpack into facadesextractedpath (it uses its own cache/sibling
    # directory), so it would only produce a second, unused copy. The
    # archive is unpacked explicitly below instead.
    tf.keras.utils.get_file(facadesfile, origin=url)

if not os.path.exists(facadesextractedpath):
    with tarfile.open(facadesfile, "r") as tref:
        # The archive comes from the network: when this Python provides
        # tarfile extraction filters, use the "data" filter so members with
        # absolute paths, ".." components or unsafe links are rejected
        # instead of being written outside facadespath.
        if hasattr(tarfile, "data_filter"):
            tref.extractall(facadespath, filter="data")
        else:
            tref.extractall(facadespath)
# Directory for model checkpoints. makedirs(exist_ok=True) does nothing if it
# already exists and also creates a missing "/content" parent.
os.makedirs("/content/checkpoint", exist_ok=True)
# Size of the shuffle buffer applied to the training dataset.
BUFFER_SIZE = 400
# Number of (input, target) pairs per batch.
BATCH_SIZE = 1
# Spatial size produced by the training crop and by the test-time resize.
IMG_HEIGHT, IMG_WIDTH = 256, 256
def load(imagefile):
    """Read a side-by-side JPEG and split it into an (input, target) pair.

    The decoded image is cut at column ``width // 2``: columns from there to
    the right edge become the input image, columns to the left of it become
    the target image.

    Args:
        imagefile: Path of the JPEG file (Python string or scalar string
            tensor, as produced by ``tf.data.Dataset.list_files``).

    Returns:
        Tuple ``(input_image, target_image)`` of ``tf.float32`` tensors with
        3 channels and pixel values still in the [0, 255] range.
    """
    raw = tf.io.read_file(imagefile)
    # channels=3 forces RGB output: grayscale JPEGs are expanded instead of
    # producing a 1-channel tensor, and the channel dimension becomes
    # statically known, matching the 3-channel crop used for training.
    image = tf.image.decode_jpeg(raw, channels=3)

    half_width = tf.shape(image)[1] // 2
    input_image = image[:, half_width:, :]
    target_image = image[:, :half_width, :]

    return tf.cast(input_image, tf.float32), tf.cast(target_image, tf.float32)
def resize(input_image, target_image, img_height, img_width):
    """Resize both images of a pair to ``img_height`` x ``img_width``.

    Nearest-neighbour sampling is used, so the output only contains pixel
    values already present in the source images.

    Returns:
        Tuple ``(input_image, target_image)`` of the resized images.
    """
    def _nearest(picture):
        return tf.image.resize(
            picture,
            [img_height, img_width],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        )

    return _nearest(input_image), _nearest(target_image)
def random_crop(input_image, target_image):
    """Crop the same random IMG_HEIGHT x IMG_WIDTH window from both images.

    Both images are stacked and cropped in a single call, so they share one
    random offset and stay pixel-aligned.

    Returns:
        Tuple ``(input_image, target_image)`` of 3-channel crops.
    """
    pair = tf.stack([input_image, target_image], axis=0)
    crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    pair = tf.image.random_crop(pair, crop_shape)
    first, second = pair[0], pair[1]
    return first, second
def random_jitter(input_image, target_image):
    """Apply random jitter identically to an aligned (input, target) pair.

    Steps: upscale both images to 286x286, take a shared random
    IMG_HEIGHT x IMG_WIDTH crop, then mirror both left-right when a uniform
    random draw exceeds 0.5.

    Returns:
        Tuple ``(input_image, target_image)`` of augmented images.
    """
    src, dst = resize(input_image, target_image, 286, 286)
    src, dst = random_crop(src, dst)
    # One random draw decides the flip for both images so they stay aligned.
    if tf.random.uniform(()) > 0.5:
        src, dst = tf.image.flip_left_right(src), tf.image.flip_left_right(dst)
    return src, dst
def normalize(input_image, target_image):
    """Map both images from the [0, 255] pixel range to [-1, 1].

    Each value ``x`` becomes ``x / 127.5 - 1``.

    Returns:
        Tuple ``(input_image, target_image)`` of rescaled values.
    """
    def _to_signed_unit(pixels):
        return pixels / 127.5 - 1

    return _to_signed_unit(input_image), _to_signed_unit(target_image)
def load_train_images(imagefile):
    """Build one augmented training pair from a side-by-side JPEG file.

    Pipeline: split the file into (input, target) with ``load``, augment
    with ``random_jitter``, then rescale to [-1, 1] with ``normalize``.

    Returns:
        Tuple ``(input_image, target_image)``.
    """
    return normalize(*random_jitter(*load(imagefile)))
def load_test_images(imagefile):
    """Build one evaluation pair from a side-by-side JPEG file.

    No augmentation is applied. Both halves are resized to
    IMG_HEIGHT x IMG_WIDTH and rescaled to [-1, 1].

    Returns:
        Tuple ``(input_image, target_image)``.
    """
    pair = load(imagefile)
    pair = resize(*pair, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*pair)
# Training pipeline: file names -> shuffled -> decoded, jittered and
# normalized pairs (in parallel) -> batches -> prefetched.
train_ds = tf.data.Dataset.list_files(facadesextractedpath + "/train/*.jpg")
# Shuffle the file names before decoding so the shuffle buffer holds
# BUFFER_SIZE short strings instead of BUFFER_SIZE decoded float images.
train_ds = train_ds.shuffle(BUFFER_SIZE)
# Decode and augment several files concurrently instead of one at a time.
train_ds = train_ds.map(load_train_images, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE)
# Prepare the next batch while the current one is being consumed.
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
# Evaluation pipeline: file names in a deterministic order -> decoded,
# resized and normalized pairs (in parallel) -> batches -> prefetched.
# list_files shuffles by default; shuffle=False keeps the test order stable so
# evaluation results and sample images are reproducible between runs.
test_ds = tf.data.Dataset.list_files(
    facadesextractedpath + "/test/*.jpg", shuffle=False
)
test_ds = test_ds.map(load_test_images, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
No comments:
Post a Comment