Most of the Google tutorials on Keras do not show how to display a confusion matrix for the solution, yet a confusion matrix sheds much clearer light on how the model is actually performing than accuracy alone.
Below is a simple CIFAR-10 solution using Keras. Most of the code is the same as any standard CIFAR-10 TensorFlow tutorial, except for a few lines at the end that compute and plot the confusion matrix. Those lines are clearly marked with comments.
Complete Code
```python
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = \
    tf.keras.datasets.cifar10.load_data()

input_shape = train_images.shape[1:]

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu",
                                 input_shape=input_shape))
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu"))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(64, 3, activation="relu"))
model.add(tf.keras.layers.Conv2D(64, 3, activation="relu"))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1024, activation="relu"))
model.add(tf.keras.layers.Dense(10, activation="softmax"))

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

epochs = 20
history = model.fit(
    train_images, train_labels,
    validation_data=(test_images, test_labels),
    epochs=epochs
)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(range(epochs), history.history["accuracy"], "r",
         label="Training Accuracy")
plt.plot(range(epochs), history.history["val_accuracy"], "b",
         label="Validation Accuracy")
plt.legend(loc="upper left")
plt.title("Accuracy")

plt.subplot(1, 2, 2)
plt.plot(range(epochs), history.history["loss"], "r",
         label="Training Loss")
plt.plot(range(epochs), history.history["val_loss"], "b",
         label="Validation Loss")
plt.legend(loc="upper right")
plt.title("Loss")
plt.show()

predictions = model.predict(test_images)

# ---------------------------------------------------------
# The following lines plot the confusion matrix
# ---------------------------------------------------------
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Collapse the per-class probabilities to a single predicted class index
predictions_for_cm = predictions.argmax(1)

class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# CIFAR-10 labels are loaded with shape (10000, 1); flatten them so
# sklearn receives a 1-D array of true labels
cm = confusion_matrix(test_labels.flatten(), predictions_for_cm)

plt.figure(figsize=(8, 8))
# fmt="d" keeps the annotations as integer counts instead of
# seaborn's default scientific notation
sns.heatmap(cm, annot=True, fmt="d",
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()
```
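Raw counts can be hard to compare when you care about per-class error rates. As an optional refinement (not part of the original answer), you can row-normalize the matrix so each cell shows the fraction of a true class's samples that landed in each predicted class. A minimal sketch with a hypothetical 2x2 matrix:

```python
import numpy as np

# Hypothetical counts; in the script above, `cm` would come from
# sklearn's confusion_matrix on the real test set.
cm = np.array([[50, 10],
               [ 5, 35]])

# Row-normalize: each row sums to 1, so cell (i, j) is the fraction
# of true-class-i samples that the model predicted as class j.
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)

print(cm_norm)
```

The normalized matrix can be passed to `sns.heatmap` the same way, with `fmt=".2f"` instead of `fmt="d"` so the fractions render readably.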
Why Confusion Matrix Matters
A confusion matrix gives much deeper insight into model behavior than a single accuracy number: it shows exactly which classes are being confused with which.
For example, if the model frequently predicts dog instead of cat, the confusion matrix will immediately expose that pattern.
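To make that concrete, here is a tiny sketch using toy labels (not real CIFAR-10 model outputs) in which two cats are predicted as dogs; the off-diagonal cell exposes the pattern immediately:

```python
from sklearn.metrics import confusion_matrix

# Toy labels using CIFAR-10 class indices: 3 = cat, 5 = dog
y_true = [3, 3, 3, 3, 5, 5]
y_pred = [5, 5, 3, 3, 5, 5]  # two cats misclassified as dogs

cm = confusion_matrix(y_true, y_pred, labels=[3, 5])
# Rows are true classes, columns are predicted classes,
# so cm[0, 1] counts cats that were predicted as dogs.
print(cm)
# [[2 2]
#  [0 2]]
```

In the full 10x10 CIFAR-10 matrix, the same reading applies: large off-diagonal cells point directly at the class pairs the model mixes up.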
Source: My answer on Stack Overflow