Saturday, May 8, 2021

Revisiting the good old fashion_mnist 


#The CNN model does well, does not overfit, and predicts fairly accurately,
#with maybe just one or two wrong predictions.


import tensorflow as tf 
import tensorflow.keras 
import matplotlib.pyplot as plt 
import numpy as np 


# Load the Fashion-MNIST train/test splits bundled with Keras.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Human-readable class names; class_names[label] is the name of integer label
# `label` (ten entries, matching the model's 10-way output layer).
# Each name needs its own comma: adjacent string literals are silently
# concatenated by Python, which would collapse the ten names into two strings
# and make class_names[label] raise IndexError for most labels.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Sanity-check the array shapes of both splits.
print(train_images.shape)
print(train_labels.shape)
print(test_images.shape)
print(test_labels.shape)

#Printing the images below and the shapes above is not strictly required; it
#is done only to stay in sync with the TensorFlow (Google) tutorials.

# Show a single training image, labelled with its class name.
plt.figure()
plt.imshow(train_images[0])
plt.xticks([])
plt.yticks([])
plt.xlabel(class_names[train_labels[0]])
plt.show()


# Show the first 25 training images in a 5x5 grid, each labelled with its
# class name. Ticks and grid lines are hidden so only the images show.
plt.figure(figsize=(6,6))
for i in range(25):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(train_images[i])
  plt.xlabel(class_names[train_labels[i]])
plt.show()

# Scale pixel values into [0, 1] (assumes raw 0-255 intensities, as the
# divisor implies).
train_images = train_images / 255.0
test_images = test_images / 255.0

# Add a trailing single-channel axis so each image is (28, 28, 1), matching
# the Conv2D input_shape used by the model below.
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)


# Three Conv2D(16 filters, 3x3) + max-pool stages, then a 512-unit dense layer
# and a 10-way softmax that outputs one probability per class.
model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(16, 3, activation="relu", input_shape=(28,28,1)),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(16, 3, activation="relu"),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(16, 3, activation="relu"),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation="relu"),
  tf.keras.layers.Dense(10, activation="softmax")
])


# The last layer already applies softmax, so the loss must be told its inputs
# are probabilities, not logits (from_logits=False). With from_logits=True the
# loss would wrongly treat the softmax output as raw logits.
# The sparse variant is used because the labels are integer class ids.
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=["accuracy"])
epochs = 10
# The test split doubles as validation data, so the "val_*" metrics recorded
# in `history` are measured on the test set.
history = model.fit(train_images, train_labels, epochs=epochs,
                    validation_data=(test_images, test_labels))


# Side-by-side training vs. validation curves: accuracy on the left, loss on
# the right. Each row describes one panel:
# (subplot slot, training key, validation key, training label,
#  validation label, legend corner, panel title).
curve_panels = [
  (1, "accuracy", "val_accuracy", "Training Accuracy", "Validation Accuracy", "upper left", "Accuracy"),
  (2, "loss", "val_loss", "Training Loss", "Validation Loss", "upper right", "Loss"),
]
epoch_axis = range(epochs)

plt.figure(figsize=(8,8))
for slot, train_key, val_key, train_label, val_label, corner, heading in curve_panels:
  plt.subplot(1, 2, slot)
  plt.plot(epoch_axis, history.history[train_key], label=train_label)
  plt.plot(epoch_axis, history.history[val_key], label=val_label)
  plt.legend(loc=corner)
  plt.title(heading)

plt.show()


#predict for a single image
img_index = 10
# model.predict works on batches, so add a leading batch axis of size 1.
img = tf.expand_dims(test_images[img_index], 0)
predictions = model.predict(img)
print(predictions[0])


# Predicted class (highest-scoring output) followed by the true class.
print(class_names[np.argmax(predictions[0])])
print(class_names[test_labels[img_index]])
# Drop the channel axis so imshow gets a 2-D image, then display it; without
# plt.show() nothing appears when this runs as a plain script.
plt.imshow(test_images[img_index].reshape(28,28))
plt.show()


#predict for multiple images
# Predict training images 1-24 in one batched model.predict call instead of
# 24 single-image calls, then print each predicted vs. actual class name.
# NOTE: these are training images, so they overstate accuracy on unseen data.
batch_predictions = model.predict(train_images[1:25])
for image_index, scores in zip(range(1, 25), batch_predictions):
  print("Predicted class {} ; actual class {} ".format(
      class_names[np.argmax(scores)], class_names[train_labels[image_index]]))

No comments:

Post a Comment

How to check local and global angular versions

 Use the command ng version (or ng v ) to find the version of Angular CLI in the current folder. Run it outside of the Angular project, to f...