Wednesday, June 2, 2021

Pix2Pix II - The Generator and Discriminator

 


def upsample(filters, size, apply_dropout=False):
  """Build one decoder block: Conv2DTranspose -> BatchNorm -> [Dropout] -> ReLU.

  Args:
    filters: Number of output channels of the transposed convolution.
    size: Kernel size of the transposed convolution.
    apply_dropout: When true, a Dropout(0.5) layer is inserted after batch
      normalization and before the ReLU activation.

  Returns:
    A `tf.keras.Sequential` holding the layers described above. The
    transposed convolution uses stride 2 with "same" padding, so the block
    doubles the spatial height and width of its input.
  """
  # Weights are drawn from N(mean=0, stddev=0.02).
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(tf.keras.layers.Conv2DTranspose(
      filters,
      size,
      strides=2,
      padding="same",
      # No bias term on the convolution; it is followed directly by batch norm.
      use_bias=False,
      kernel_initializer=initializer,
  ))
  result.add(tf.keras.layers.BatchNormalization())
  if apply_dropout:
    result.add(tf.keras.layers.Dropout(0.5))
  result.add(tf.keras.layers.ReLU())
  return result

def downsample(filters, size, apply_batchnorm=True):
  """Build one encoder block: Conv2D -> [BatchNorm] -> LeakyReLU.

  Args:
    filters: Number of output channels of the convolution.
    size: Kernel size of the convolution.
    apply_batchnorm: When true (the default), a BatchNormalization layer is
      inserted between the convolution and the activation.

  Returns:
    A `tf.keras.Sequential` holding the layers described above. The
    convolution uses stride 2 with "same" padding, so the block halves the
    spatial height and width of its input.
  """
  # Weights are drawn from N(mean=0, stddev=0.02).
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(tf.keras.layers.Conv2D(
      filters,
      size,
      strides=2,
      padding="same",
      use_bias=False,
      kernel_initializer=initializer,
  ))
  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())
  result.add(tf.keras.layers.LeakyReLU())

  return result

def get_generator():
  """Build the Pix2Pix generator: an encoder-decoder with skip connections.

  The model takes a 256x256x3 image, runs it through eight `downsample`
  blocks, then through seven `upsample` blocks, concatenating each decoder
  output with the encoder output of matching resolution. A final transposed
  convolution with `tanh` activation produces a 3-channel image.

  Returns:
    A `tf.keras.Model` mapping a (256, 256, 3) input to the generated image.
  """
  initializer = tf.random_normal_initializer(0., 0.02)
  inputs = tf.keras.layers.Input(shape=(256, 256, 3))

  # Encoder: every block halves height/width (256 -> 128 -> ... -> 1).
  downstack = [
      downsample(64, 4, apply_batchnorm=False),
      downsample(128, 4),
      downsample(256, 4),
      downsample(512, 4),
      downsample(512, 4),
      downsample(512, 4),
      downsample(512, 4),
      downsample(512, 4),
  ]

  # Decoder: every block doubles height/width (1 -> 2 -> ... -> 128).
  # Only the first three blocks use dropout.
  upstack = [
      upsample(512, 4, apply_dropout=True),
      upsample(512, 4, apply_dropout=True),
      upsample(512, 4, apply_dropout=True),
      upsample(512, 4),
      upsample(256, 4),
      upsample(128, 4),
      upsample(64, 4),
  ]

  OUTPUT_CHANNELS = 3
  # Final stride-2 transposed convolution brings the output back to the
  # input resolution.
  last = tf.keras.layers.Conv2DTranspose(
      OUTPUT_CHANNELS,
      4,
      strides=2,
      use_bias=False,
      padding="same",
      kernel_initializer=initializer,
      activation="tanh",
  )

  # Run the encoder, remembering every intermediate output for the skips.
  x = inputs
  skips = []
  for down in downstack:
    x = down(x)
    skips.append(x)

  # The last encoder output is the bottleneck itself (already in `x`), so it
  # is excluded; the rest are visited deepest-first to line up with the
  # decoder blocks.
  for up, skip in zip(upstack, reversed(skips[:-1])):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
def get_discriminator():
  """Build the Pix2Pix discriminator.

  The model takes two 256x256x3 images, concatenates them along the channel
  axis, and reduces them through three `downsample` blocks followed by two
  zero-padded stride-1 convolutions to a single-channel map. No activation
  is applied to the final convolution.

  Returns:
    A `tf.keras.Model` with inputs `[inp1, inp2]` and the single-channel
    map as output.
  """
  initializer = tf.random_normal_initializer(0., 0.02)
  inp1 = tf.keras.layers.Input(shape=(256, 256, 3))
  inp2 = tf.keras.layers.Input(shape=(256, 256, 3))
  # Shapes in the comments below omit the batch dimension.
  x = tf.keras.layers.concatenate([inp1, inp2])  # (256, 256, 6)

  downsample1 = downsample(64, 4, apply_batchnorm=False)(x)  # (128, 128, 64)
  downsample2 = downsample(128, 4)(downsample1)  # (64, 64, 128)
  downsample3 = downsample(256, 4)(downsample2)  # (32, 32, 256)

  zeropad1 = tf.keras.layers.ZeroPadding2D()(downsample3)  # (34, 34, 256)
  conv2d = tf.keras.layers.Conv2D(
      512,
      4,
      strides=1,
      use_bias=False,
      kernel_initializer=initializer,
  )(zeropad1)  # (31, 31, 512)
  batchnorm = tf.keras.layers.BatchNormalization()(conv2d)
  leakyrelu = tf.keras.layers.LeakyReLU()(batchnorm)
  zeropad2 = tf.keras.layers.ZeroPadding2D()(leakyrelu)  # (33, 33, 512)
  last = tf.keras.layers.Conv2D(
      1,
      4,
      strides=1,
      kernel_initializer=initializer,
  )(zeropad2)  # (30, 30, 1)
  return tf.keras.Model(inputs=[inp1, inp2], outputs=last)

No comments:

Post a Comment

How to check local and global Angular versions

 Use the command ng version (or ng v ) to find the version of Angular CLI in the current folder. Run it outside of the Angular project, to f...