
Monday, May 24, 2021

Plotting a Confusion Matrix for CIFAR-10 using Seaborn

Most of Google's Keras tutorials do not show how to display a confusion matrix for the trained model, yet a confusion matrix gives a much clearer picture of how the model is actually performing.

Below is a simple CIFAR-10 solution using Keras. Most of the code mirrors any standard CIFAR-10 TensorFlow tutorial; only the few lines at the end, which plot the confusion matrix, are different.

Those specific lines are clearly marked with comments.

Complete Code

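import tensorflow as tf
import matplotlib.pyplot as plt

# Load CIFAR-10: 50,000 training and 10,000 test images, each 32x32 RGB
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1] so training is better conditioned
train_images, test_images = train_images / 255.0, test_images / 255.0

input_shape = train_images.shape[1:]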

# A small CNN: two Conv2D-Conv2D-MaxPooling2D blocks followed by a dense classifier
model = tf.keras.models.Sequential()

model.add(tf.keras.layers.Conv2D(
    32,
    3,
    activation="relu",
    input_shape=input_shape
))

model.add(tf.keras.layers.Conv2D(
    32,
    3,
    activation="relu"
))

model.add(tf.keras.layers.MaxPooling2D())

model.add(tf.keras.layers.Conv2D(
    64,
    3,
    activation="relu"
))

model.add(tf.keras.layers.Conv2D(
    64,
    3,
    activation="relu"
))

model.add(tf.keras.layers.MaxPooling2D())

model.add(tf.keras.layers.Flatten())

model.add(tf.keras.layers.Dense(
    1024,
    activation="relu"
))

model.add(tf.keras.layers.Dense(
    10,
    activation="softmax"
))

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

epochs = 20

history = model.fit(
    train_images,
    train_labels,
    validation_data=(test_images, test_labels),
    epochs=epochs
)

plt.figure(figsize=(8,8))

plt.subplot(1,2,1)

plt.plot(
    range(epochs),
    history.history["accuracy"],
    "r",
    label="Training Accuracy"
)

plt.plot(
    range(epochs),
    history.history["val_accuracy"],
    "b",
    label="Validation Accuracy"
)

plt.legend(loc="upper left")
plt.title("Accuracy")

plt.subplot(1,2,2)

plt.plot(
    range(epochs),
    history.history["loss"],
    "r",
    label="Training Loss"
)

plt.plot(
    range(epochs),
    history.history["val_loss"],
    "b",
    label="Validation Loss"
)

plt.legend(loc="upper right")
plt.title("Loss")

plt.show()

predictions = model.predict(test_images)


# ---------------------------------------------------------
# The following lines plot the confusion matrix
# ---------------------------------------------------------


# Pick the most likely class for each test image (argmax over the 10 softmax outputs)
predictions_for_cm = predictions.argmax(1)

from sklearn.metrics import confusion_matrix
import seaborn as sns

class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]

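# Rows are true labels, columns are predicted labels.
# test_labels has shape (10000, 1), so flatten it to match the 1-D predictions.
cm = confusion_matrix(
    test_labels.flatten(),
    predictions_for_cm
)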

plt.figure(figsize=(8,8))

sns.heatmap(
    cm,
    annot=True,
    fmt="d",  # print raw integer counts; the default ".2g" would show e.g. 7.5e+02
    xticklabels=class_names,
    yticklabels=class_names
)

plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")

plt.show()
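
For readers curious about what confusion_matrix actually counts, here is a minimal NumPy-only sketch. confusion_matrix_np is a hypothetical helper written for illustration, not part of scikit-learn. It assumes integer labels in the range 0 to num_classes - 1; under that assumption it should match scikit-learn's output whenever every class appears in the true or predicted labels.

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs: rows are true labels, columns are predictions."""
    # Flatten in case labels arrive as a column vector, like CIFAR-10's (N, 1) labels
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Toy example with 3 classes; the true labels are a column vector on purpose
y_true = np.array([[0], [0], [1], [1], [2], [2]])
y_pred = np.array([0, 1, 1, 1, 2, 0])
print(confusion_matrix_np(y_true, y_pred, num_classes=3))
```

With the variables from the script above, confusion_matrix_np(test_labels, predictions_for_cm, 10) should produce the same 10x10 matrix as cm.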


Why the Confusion Matrix Matters

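A confusion matrix gives much deeper insight into model behavior than a single accuracy figure. Each row corresponds to a true class and each column to a predicted class, so the diagonal holds the correct predictions and every off-diagonal cell shows how often one class was mistaken for another.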
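The row/column reading can be made concrete with a little arithmetic. Assuming scikit-learn's convention (rows are true labels, columns are predictions), the sketch below uses a made-up 3-class matrix, not CIFAR-10 results, to compute overall accuracy together with per-class recall (diagonal divided by row sums) and precision (diagonal divided by column sums).

```python
import numpy as np

# Made-up 3-class confusion matrix (rows = true label, columns = predicted label)
cm = np.array([
    [50,  0,  0],   # class 0: always right
    [ 0, 45,  5],   # class 1: sometimes predicted as class 2
    [ 0, 30, 20],   # class 2: mostly predicted as class 1
])

accuracy = np.trace(cm) / cm.sum()        # fraction of all samples on the diagonal
recall = np.diag(cm) / cm.sum(axis=1)     # correct / all samples of each true class
precision = np.diag(cm) / cm.sum(axis=0)  # correct / all predictions of each class

print(f"accuracy: {accuracy:.2f}")
print("recall:   ", np.round(recall, 2))
print("precision:", np.round(precision, 2))
```

Here overall accuracy is about 77% (115 of 150), yet class 2 is recognized only 40% of the time, which is exactly the kind of per-class failure a single accuracy number hides.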

For example, if the model frequently predicts dog when the image is actually a cat, that pattern shows up immediately as a large count in the cell at row "cat", column "dog".
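With ten classes it is easy to miss a hot spot by eye, so a small helper can list the largest off-diagonal cells instead. most_confused_pairs is a hypothetical function written for this sketch, and the counts below are invented for illustration; they are not results from the model above.

```python
import numpy as np

def most_confused_pairs(cm, class_names, top_k=3):
    """Return the top_k (true, predicted, count) off-diagonal cells, largest first."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # Sort the flattened matrix in descending order, then map back to (row, col)
    flat_order = np.argsort(off_diag, axis=None)[::-1][:top_k]
    rows, cols = np.unravel_index(flat_order, off_diag.shape)
    return [(class_names[r], class_names[c], int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Invented counts for three of the CIFAR-10 classes (rows = true, columns = predicted)
names = ["cat", "dog", "frog"]
cm = np.array([
    [60, 30, 10],
    [15, 80,  5],
    [ 4,  2, 94],
])
for true_name, pred_name, count in most_confused_pairs(cm, names):
    print(f"true {true_name} -> predicted {pred_name}: {count}")
```

On the real matrix from the script, most_confused_pairs(cm, class_names) should list the class pairs the trained model mixes up most often.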

What is Pydantic

Pydantic is a data validation and settings management library for Python. ...