Most of the Google tutorials on Keras do not show how to display a confusion matrix for
the solution. A confusion matrix sheds clear light on how the model is performing.
Below is a simple CIFAR-10 solution using Keras. Most of the code is similar to any other
CIFAR-10 TensorFlow tutorial, except for a small number of lines at the end, which plot
the confusion matrix. Those lines are marked by a comment.
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# Load CIFAR-10: each split is a (images, labels) pair.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
# load_data() returns uint8 pixel values in 0..255. Scale them to [0, 1] so the
# convolutional layers see small, well-conditioned inputs (this is what the
# standard Keras CIFAR-10 examples do before training).
train_images = train_images / 255.0
test_images = test_images / 255.0
# Per-image shape (height, width, channels), taken from the data itself.
input_shape = train_images.shape[1:]

# Two conv/conv/pool stages followed by a dense classifier over the 10 classes.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=input_shape))
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu"))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(64, 3, activation="relu"))
model.add(tf.keras.layers.Conv2D(64, 3, activation="relu"))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1024, activation="relu"))
# Softmax output: one probability per CIFAR-10 class.
model.add(tf.keras.layers.Dense(10, activation="softmax"))

# Labels are integer class ids (not one-hot), hence the *Sparse* loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

epochs = 20
# The test split doubles as validation data so the curves plotted below show
# generalization after every epoch.
history = model.fit(
    train_images,
    train_labels,
    validation_data=(test_images, test_labels),
    epochs=epochs,
)
# Draw training vs. validation curves side by side: accuracy in the left
# panel, loss in the right. Training is red ("r"), validation is blue ("b").
epoch_axis = range(epochs)
curve_panels = [
    # (train key, train label, val key, val label, legend location, title)
    ("accuracy", "Training Accuracy", "val_accuracy", "Validation Accuracy", "upper left", "Accuracy"),
    ("loss", "Training Loss", "val_loss", "Validation Loss", "upper right", "Loss"),
]
plt.figure(figsize=(8, 8))
for panel_index, (train_key, train_label, val_key, val_label, legend_loc, panel_title) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(epoch_axis, history.history[train_key], "r", label=train_label)
    plt.plot(epoch_axis, history.history[val_key], "b", label=val_label)
    plt.legend(loc=legend_loc)
    plt.title(panel_title)
plt.show()
predictions = model.predict(test_images)
# Everything from here on is all that is required to plot the confusion matrix.
from sklearn.metrics import confusion_matrix
import seaborn as sns
class_names = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]
# The model's softmax layer outputs one probability per class; the predicted
# class id is the index of the largest one.
predictions_for_cm = predictions.argmax(axis=1)
# cifar10.load_data() returns labels with shape (N, 1); flatten them so both
# arguments to confusion_matrix are 1-D label arrays.
cm = confusion_matrix(test_labels.ravel(), predictions_for_cm)
plt.figure(figsize=(8,8))
# fmt="d" prints the raw integer counts. seaborn's default fmt (".2g") would
# render any count >= 100 in scientific notation, e.g. 750 -> "7.5e+02".
sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)
# sklearn's confusion_matrix puts true labels on rows and predictions on columns.
plt.xlabel("Predicted label")
plt.ylabel("True label")
# Needed to actually display the figure when run as a script (not a notebook).
plt.show()
My answer on SO
No comments:
Post a Comment