KERAS and MNIST#

import matplotlib.pyplot as plt
import numpy as np

We’ll apply the ideas we just learned to a neural network that does character recognition using the MNIST database. This is a set of handwritten digits (0–9), each represented as a 28×28 pixel grayscale image.

There are two datasets: the training set with 60,000 images and the test set with 10,000 images.

import keras

Important

Keras requires a backend, which can be tensorflow, pytorch, or jax.

By default, it will assume tensorflow.

This notebook has been tested with both pytorch and tensorflow.

Tip

To have keras use pytorch, set the environment variable KERAS_BACKEND as:

export KERAS_BACKEND="torch"
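
Alternatively, you can select the backend from within Python, as long as you do so before keras is imported for the first time. A minimal sketch (assuming one of the three supported backends is installed):

import os
os.environ["KERAS_BACKEND"] = "torch"   # or "tensorflow" or "jax"

import keras
print(keras.backend.backend())          # confirms which backend is active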

We follow the example for setting up the network: Vict0rSch/deep_learning

Note

For visualization of the network, you need to have pydot installed.

The MNIST data#

The keras library can download the MNIST data directly and provides a function to give us both the training and test images and the corresponding digits. This is already in a format that Keras wants, so we don’t use the classes that we defined earlier.

from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

As before, the training set consists of 60,000 digits, each represented as a 28x28 array (there are no color channels, so this is grayscale data). The pixel values are stored as 8-bit unsigned integers.

X_train.shape
(60000, 28, 28)
X_train.dtype
dtype('uint8')

Let’s look at the first digit and the “y” value (target) associated with it—that’s the correct answer.

plt.imshow(X_train[0], cmap="gray_r")
print(y_train[0])
5

[image: the first training digit, displayed with plt.imshow]

Preparing the Data#

The neural network takes a 1-d vector of input and will return a 1-d vector of output. We need to convert our data to this form.

We’ll scale the image data to fall in [0, 1] and convert the integer labels into categorical (one-hot) arrays. Finally, we need the input data to be one-dimensional, so we will flatten the 28x28 images into a single 784-element vector.

X_train = X_train.astype('float32')/255
X_test = X_test.astype('float32')/255

X_train = np.reshape(X_train, (60000, 784))
X_test = np.reshape(X_test, (10000, 784))
X_train[0]
array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.01176471, 0.07058824, 0.07058824,
       0.07058824, 0.49411765, 0.53333336, 0.6862745 , 0.10196079,
       0.6509804 , 1.        , 0.96862745, 0.49803922, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.11764706, 0.14117648, 0.36862746, 0.6039216 ,
       0.6666667 , 0.99215686, 0.99215686, 0.99215686, 0.99215686,
       0.99215686, 0.88235295, 0.6745098 , 0.99215686, 0.9490196 ,
       0.7647059 , 0.2509804 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.19215687, 0.93333334,
       0.99215686, 0.99215686, 0.99215686, 0.99215686, 0.99215686,
       0.99215686, 0.99215686, 0.99215686, 0.9843137 , 0.3647059 ,
       0.32156864, 0.32156864, 0.21960784, 0.15294118, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.07058824, 0.85882354, 0.99215686, 0.99215686,
       0.99215686, 0.99215686, 0.99215686, 0.7764706 , 0.7137255 ,
       0.96862745, 0.94509804, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.3137255 , 0.6117647 , 0.41960785, 0.99215686, 0.99215686,
       0.8039216 , 0.04313726, 0.        , 0.16862746, 0.6039216 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.05490196,
       0.00392157, 0.6039216 , 0.99215686, 0.3529412 , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.54509807,
       0.99215686, 0.74509805, 0.00784314, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.04313726, 0.74509805, 0.99215686,
       0.27450982, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.13725491, 0.94509804, 0.88235295, 0.627451  ,
       0.42352942, 0.00392157, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.31764707, 0.9411765 , 0.99215686, 0.99215686, 0.46666667,
       0.09803922, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.1764706 ,
       0.7294118 , 0.99215686, 0.99215686, 0.5882353 , 0.10588235,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.0627451 , 0.3647059 ,
       0.9882353 , 0.99215686, 0.73333335, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.9764706 , 0.99215686,
       0.9764706 , 0.2509804 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.18039216, 0.50980395,
       0.7176471 , 0.99215686, 0.99215686, 0.8117647 , 0.00784314,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.15294118,
       0.5803922 , 0.8980392 , 0.99215686, 0.99215686, 0.99215686,
       0.98039216, 0.7137255 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.09411765, 0.44705883, 0.8666667 , 0.99215686, 0.99215686,
       0.99215686, 0.99215686, 0.7882353 , 0.30588236, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.09019608, 0.25882354, 0.8352941 , 0.99215686,
       0.99215686, 0.99215686, 0.99215686, 0.7764706 , 0.31764707,
       0.00784314, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.07058824, 0.67058825, 0.85882354,
       0.99215686, 0.99215686, 0.99215686, 0.99215686, 0.7647059 ,
       0.3137255 , 0.03529412, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.21568628, 0.6745098 ,
       0.8862745 , 0.99215686, 0.99215686, 0.99215686, 0.99215686,
       0.95686275, 0.52156866, 0.04313726, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.53333336, 0.99215686, 0.99215686, 0.99215686,
       0.83137256, 0.5294118 , 0.5176471 , 0.0627451 , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        ], dtype=float32)

We will treat the output as categorical data. Keras includes routines to categorize data. In our case, since there are 10 possible digits, we want to put the output into 10 categories (represented by 10 output neurons).

from keras.utils import to_categorical

y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

Now let’s look at the target for the first training digit. We know from above that it was 5. Here we see that there is a 1 in the index corresponding to 5 (remember that we start counting at 0 in Python).

y_train[0]
array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])

Build the Neural Network#

Now we’ll build the neural network. We will have 2 hidden layers, and the number of neurons will look like:

784 → 500 → 300 → 10

Layers#

Let’s start by setting up the layers. For each layer, we tell Keras the number of output neurons; it infers the number of inputs from the previous layer (with the exception of the input layer, where we need to tell it what to expect as input).

We also set properties on the layers: the activation function for each Dense layer (relu for the hidden layers, softmax for the output) and a Dropout fraction after each hidden layer to help prevent overfitting.

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Input

model = Sequential()
model.add(Input(shape=(784,)))
model.add(Dense(500, activation="relu"))
model.add(Dropout(0.4))
model.add(Dense(300, activation="relu"))
model.add(Dropout(0.4))
model.add(Dense(10, activation="softmax"))

Loss function#

We need to specify what we want to optimize and how we are going to do it.

Recall: the loss (or cost) function measures how well our predictions match the expected target. Previously we were using the sum of the squares of the error.

For categorical data like ours, the “cross-entropy” loss is often used. See here for an explanation: https://jamesmccaffrey.wordpress.com/2013/11/05/why-you-should-use-cross-entropy-error-instead-of-classification-error-or-mean-squared-error-for-neural-network-classifier-training/
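
As a rough illustration (with made-up probabilities, not output from our network), the categorical cross-entropy for a single example reduces to minus the log of the probability assigned to the correct class:

import numpy as np

# one-hot target: the correct digit is 5
y_true = np.zeros(10)
y_true[5] = 1.0

# hypothetical softmax output from a network
y_pred = np.array([0.01, 0.01, 0.02, 0.05, 0.02,
                   0.80, 0.03, 0.02, 0.02, 0.02])

# cross-entropy = -sum(y_true * log(y_pred)) = -log(p_correct)
print(-np.sum(y_true * np.log(y_pred)))   # about 0.22

A confident, correct prediction gives a small loss, while putting low probability on the correct class makes the loss blow up, which is exactly the behavior we want to penalize.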

Optimizer#

We also need to specify an optimizer. This could be gradient descent, as we used before. Here’s a list of the optimizers supported by Keras: https://keras.io/api/optimizers/ We’ll use RMSprop, which builds on gradient descent and adds a form of momentum.

Finally, we need to specify a metric that is evaluated during training and testing. We’ll use "accuracy" here. This means that we’ll see the accuracy of our model reported as we are training and testing.

More details on these options are here: https://keras.io/api/models/model/

from keras.optimizers import RMSprop

rms = RMSprop()
model.compile(loss='categorical_crossentropy',
              optimizer=rms, metrics=['accuracy'])

Network summary#

Let’s take a look at the network:

model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 500)            │       392,500 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 500)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 300)            │       150,300 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 300)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 10)             │         3,010 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 545,810 (2.08 MB)
 Trainable params: 545,810 (2.08 MB)
 Non-trainable params: 0 (0.00 B)

We see that there are more than 500,000 parameters that we will be training.
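
As a quick check, we can reproduce these counts by hand: each Dense layer has one weight per input-output pair plus one bias per output, and the Dropout layers add no parameters.

n_in, n_h1, n_h2, n_out = 784, 500, 300, 10

p1 = n_in * n_h1 + n_h1      # 392,500
p2 = n_h1 * n_h2 + n_h2      # 150,300
p3 = n_h2 * n_out + n_out    #   3,010
print(p1, p2, p3, p1 + p2 + p3)   # total: 545,810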

Train#

For training, we pass in the inputs and targets, along with the number of epochs to run, and Keras optimizes the network by adjusting the weights connecting the nodes in the layers.

The number of epochs is the number of times the entire dataset is passed forward and backward through the network. The batch size is the number of training pairs you pass through the network at a time; the parameters in the model (the weights) are updated once per batch. This makes training more efficient and less noisy.
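
For example, with 60,000 training images and a batch size of 256, each epoch consists of ceil(60000 / 256) = 235 weight updates, which matches the 235/235 step counter in the training output below.

import math
print(math.ceil(60000 / 256))   # 235 batches (weight updates) per epoch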

epochs = 20
batch_size = 256
model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size,
          validation_data=(X_test, y_test), verbose=2)
Epoch 1/20
235/235 - 4s - 18ms/step - accuracy: 0.8874 - loss: 0.3675 - val_accuracy: 0.9579 - val_loss: 0.1355
Epoch 2/20
235/235 - 4s - 18ms/step - accuracy: 0.9510 - loss: 0.1599 - val_accuracy: 0.9714 - val_loss: 0.0883
Epoch 3/20
235/235 - 4s - 19ms/step - accuracy: 0.9653 - loss: 0.1171 - val_accuracy: 0.9729 - val_loss: 0.0913
Epoch 4/20
235/235 - 4s - 19ms/step - accuracy: 0.9703 - loss: 0.0969 - val_accuracy: 0.9763 - val_loss: 0.0758
Epoch 5/20
235/235 - 4s - 18ms/step - accuracy: 0.9753 - loss: 0.0802 - val_accuracy: 0.9787 - val_loss: 0.0712
Epoch 6/20
235/235 - 4s - 19ms/step - accuracy: 0.9782 - loss: 0.0712 - val_accuracy: 0.9793 - val_loss: 0.0685
Epoch 7/20
235/235 - 4s - 18ms/step - accuracy: 0.9796 - loss: 0.0641 - val_accuracy: 0.9822 - val_loss: 0.0582
Epoch 8/20
235/235 - 4s - 18ms/step - accuracy: 0.9823 - loss: 0.0572 - val_accuracy: 0.9826 - val_loss: 0.0561
Epoch 9/20
235/235 - 5s - 19ms/step - accuracy: 0.9830 - loss: 0.0540 - val_accuracy: 0.9836 - val_loss: 0.0570
Epoch 10/20
235/235 - 5s - 19ms/step - accuracy: 0.9854 - loss: 0.0475 - val_accuracy: 0.9844 - val_loss: 0.0572
Epoch 11/20
235/235 - 4s - 18ms/step - accuracy: 0.9863 - loss: 0.0446 - val_accuracy: 0.9830 - val_loss: 0.0575
Epoch 12/20
235/235 - 5s - 19ms/step - accuracy: 0.9864 - loss: 0.0426 - val_accuracy: 0.9840 - val_loss: 0.0505
Epoch 13/20
235/235 - 4s - 19ms/step - accuracy: 0.9874 - loss: 0.0383 - val_accuracy: 0.9837 - val_loss: 0.0603
Epoch 14/20
235/235 - 4s - 18ms/step - accuracy: 0.9877 - loss: 0.0380 - val_accuracy: 0.9845 - val_loss: 0.0542
Epoch 15/20
235/235 - 5s - 19ms/step - accuracy: 0.9890 - loss: 0.0331 - val_accuracy: 0.9829 - val_loss: 0.0664
Epoch 16/20
235/235 - 4s - 18ms/step - accuracy: 0.9894 - loss: 0.0323 - val_accuracy: 0.9862 - val_loss: 0.0545
Epoch 17/20
235/235 - 4s - 19ms/step - accuracy: 0.9900 - loss: 0.0321 - val_accuracy: 0.9856 - val_loss: 0.0545
Epoch 18/20
235/235 - 4s - 19ms/step - accuracy: 0.9900 - loss: 0.0304 - val_accuracy: 0.9861 - val_loss: 0.0571
Epoch 19/20
235/235 - 4s - 18ms/step - accuracy: 0.9908 - loss: 0.0286 - val_accuracy: 0.9854 - val_loss: 0.0588
Epoch 20/20
235/235 - 5s - 19ms/step - accuracy: 0.9909 - loss: 0.0274 - val_accuracy: 0.9847 - val_loss: 0.0642
<keras.src.callbacks.history.History at 0x7f584430c6e0>

Test#

Keras has a routine, evaluate(), that takes the inputs and targets of a test dataset and returns the loss value and accuracy (or other defined metrics) on that data.

Here we see we are > 98% accurate on the test data—this is the data that the model has never seen before (and was not trained with).

loss_value, accuracy = model.evaluate(X_test, y_test, batch_size=16)
print(accuracy)
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.9847 - loss: 0.0642
0.9847000241279602

Predicting#

Suppose we simply want to ask our neural network to predict the target for an input. We can use the predict() method to return the array of category probabilities, and then use np.argmax() to select the most probable one.

np.argmax(model.predict(np.array([X_test[0]])))
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
np.int64(7)
y_test[0]
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])

Now let’s loop over the test set and print out what we predict vs. the true answer for those we get wrong. We can also plot the image of the digit.

wrong = 0
max_wrong = 10

for n, (x, y) in enumerate(zip(X_test, y_test)):
    try:
        res = model.predict(np.array([x]), verbose=0)
        if np.argmax(res) != np.argmax(y):
            print(f"test {n}: prediction = {np.argmax(res)}, truth is {np.argmax(y)}")
            plt.imshow(x.reshape(28, 28), cmap="gray_r")
            plt.show()
            wrong += 1
            if wrong >= max_wrong:
                break
    except KeyboardInterrupt:
        print("stopping")
        break
test 247: prediction = 2, truth is 4
test 321: prediction = 7, truth is 2
test 340: prediction = 3, truth is 5
test 445: prediction = 0, truth is 6
test 495: prediction = 0, truth is 8
test 582: prediction = 2, truth is 8
test 619: prediction = 8, truth is 1
test 684: prediction = 3, truth is 7
test 691: prediction = 4, truth is 8
test 720: prediction = 8, truth is 5

[an image of each misclassified digit is also displayed]

Experimenting#

There are a number of things we can play with to see how the network performance changes (a sketch of one variation follows the list):

  • batch size

  • adding or removing hidden layers

  • changing the dropout

  • changing the activation function
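
For example, here is a minimal sketch (an illustrative variant, not a tuned model) with a single, smaller hidden layer, lighter dropout, and a tanh activation; it is compiled the same way and can be passed to fit() just like the model above:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Input

variant = Sequential()
variant.add(Input(shape=(784,)))
variant.add(Dense(128, activation="tanh"))
variant.add(Dropout(0.2))
variant.add(Dense(10, activation="softmax"))

variant.compile(loss="categorical_crossentropy",
                optimizer="rmsprop", metrics=["accuracy"])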

Callbacks#

Keras allows for callbacks that run each epoch and store some information. These allow you to, for example, plot the accuracy vs. epoch. Take a look here for some inspiration:

https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/History
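
In fact, fit() already returns a History object whose history attribute is a dictionary of per-epoch metrics. A minimal sketch of plotting accuracy vs. epoch (note that re-running fit() here would continue training the existing model; this is just to show the idea):

history = model.fit(X_train, y_train, epochs=5, batch_size=256,
                    validation_data=(X_test, y_test), verbose=0)

# history.history is a dict keyed by metric name, one entry per epoch
plt.plot(history.history["accuracy"], label="train")
plt.plot(history.history["val_accuracy"], label="validation")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()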

Going Further#

Convolutional neural networks are often used for image recognition, especially with larger images. They use filters to recognize patterns in portions of an image (tiles). See this for a Keras example:

https://www.tensorflow.org/tutorials/images/cnn
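
A minimal sketch of what such a network might look like in Keras (untrained here; note that it expects the images as 28x28x1 arrays rather than the flattened 784-element vectors we used above):

from keras.models import Sequential
from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense

cnn = Sequential()
cnn.add(Input(shape=(28, 28, 1)))
cnn.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
cnn.add(MaxPooling2D(pool_size=(2, 2)))
cnn.add(Conv2D(64, kernel_size=(3, 3), activation="relu"))
cnn.add(MaxPooling2D(pool_size=(2, 2)))
cnn.add(Flatten())
cnn.add(Dense(10, activation="softmax"))

cnn.compile(loss="categorical_crossentropy",
            optimizer="rmsprop", metrics=["accuracy"])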