Monday, May 24, 2021

Plotting a Confusion Matrix for CIFAR10 using Seaborn

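Most of the Keras tutorials found on Google do not show how to display a confusion matrix for the solution. A confusion matrix throws a clear light on how the model is performing: it shows, class by class, which images were predicted correctly and which classes were confused with each other.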
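To make the idea concrete, here is a minimal pure-Python sketch of the counting behind a confusion matrix. It follows the convention of `sklearn.metrics.confusion_matrix`: rows are true labels and columns are predicted labels. The helper name `simple_confusion_matrix` and the toy labels are hypothetical, chosen only for illustration; they are not CIFAR10 results.

```python
def simple_confusion_matrix(y_true, y_pred, num_classes):
    """Hypothetical helper: count (true, predicted) label pairs.

    Entry [i][j] is the number of samples whose true class is i
    and whose predicted class is j (same convention as sklearn).
    """
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


# Toy example with 3 classes (made-up labels, not model output)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = simple_confusion_matrix(y_true, y_pred, 3)
for row in cm:
    print(row)
# [1, 1, 0]
# [0, 2, 0]
# [1, 0, 1]
```

The diagonal holds the correct predictions; an off-diagonal entry such as row 2, column 0 counts samples of class 2 that were predicted as class 0. On CIFAR10, this is where, for example, cats mistaken for dogs would appear.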

Below is a simple CIFAR10 solution using Keras. Most of the code is similar to any other CIFAR10 TensorFlow tutorial, except for a few lines at the end that plot the confusion matrix. Those lines are marked by a comment.




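import tensorflow as tf
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1] so the network trains more stably
train_images, test_images = train_images / 255.0, test_images / 255.0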

input_shape = train_images.shape[1:]  # (32, 32, 3) for CIFAR10

# Two convolutional blocks followed by a dense classifier
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=input_shape))
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu"))
model.add(tf.keras.layers.MaxPooling2D())

model.add(tf.keras.layers.Conv2D(64, 3, activation="relu"))
model.add(tf.keras.layers.Conv2D(64, 3, activation="relu"))
model.add(tf.keras.layers.MaxPooling2D())

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1024, activation="relu"))
model.add(tf.keras.layers.Dense(10, activation="softmax"))  # one probability per class

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # labels are integer class ids
    metrics=["accuracy"]
)


epochs = 20
history = model.fit(
    train_images,
    train_labels,
    validation_data=(test_images, test_labels),
    epochs=epochs
)


# Plot training vs. validation accuracy and loss side by side
plt.figure(figsize=(8, 8))

plt.subplot(1, 2, 1)
plt.plot(range(epochs), history.history["accuracy"], "r", label="Training Accuracy")
plt.plot(range(epochs), history.history["val_accuracy"], "b", label="Validation Accuracy")
plt.legend(loc="upper left")
plt.title("Accuracy")

plt.subplot(1, 2, 2)
plt.plot(range(epochs), history.history["loss"], "r", label="Training Loss")
plt.plot(range(epochs), history.history["val_loss"], "b", label="Validation Loss")
plt.legend(loc="upper right")
plt.title("Loss")

plt.show()


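predictions = model.predict(test_images)

# The following lines are all that is required to plot the confusion matrix.
# argmax over axis 1 turns each row of 10 class probabilities into a class index.
predictions_for_cm = predictions.argmax(1)

from sklearn.metrics import confusion_matrix
import seaborn as sns
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# Rows are true labels, columns are predicted labels
cm = confusion_matrix(test_labels.ravel(), predictions_for_cm)
plt.figure(figsize=(8, 8))
# fmt="d" prints integer counts instead of the default ".2g" (which would show 700 as 7e+02)
sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.show()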
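Once `cm` is computed, the same matrix can also be summarised per class. The sketch below is a hedged illustration in pure Python on a small made-up 3-class matrix (not output from the model above): per-class recall is the diagonal entry divided by its row sum (all samples of that true class), and precision is the diagonal entry divided by its column sum (all samples predicted as that class). The helper name `per_class_scores` is hypothetical.

```python
def per_class_scores(cm):
    """Hypothetical helper: per-class recall and precision from a square
    confusion matrix given as a list of lists (rows = true, cols = predicted)."""
    n = len(cm)
    recall, precision = [], []
    for i in range(n):
        row_total = sum(cm[i])                       # samples whose true class is i
        col_total = sum(cm[r][i] for r in range(n))  # samples predicted as class i
        recall.append(cm[i][i] / row_total if row_total else 0.0)
        precision.append(cm[i][i] / col_total if col_total else 0.0)
    return recall, precision


# Made-up 3-class matrix, for illustration only
toy_cm = [[8, 2, 0],
          [1, 6, 3],
          [0, 1, 9]]
recall, precision = per_class_scores(toy_cm)
for i, (r, p) in enumerate(zip(recall, precision)):
    print(f"class {i}: recall={r:.2f} precision={p:.2f}")
```

For the real 10x10 CIFAR10 matrix, which `confusion_matrix` returns as a NumPy array, `cm.diagonal() / cm.sum(axis=1)` gives the same per-class recall values in one line.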













My answer on Stack Overflow
