Image Classification Example with CNN
In this lesson, we will build a simple image classification model using TensorFlow and Keras, and practice classifying digit images (0-9) using the MNIST dataset.
1. Preparing the Data
First, let's load the MNIST dataset provided by TensorFlow. This dataset consists of 28×28 pixel grayscale images of handwritten digits.
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data to range 0 to 1
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channel axis: Conv2D expects each image as (height, width, channels)
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

# Display a data sample
plt.imshow(x_train[0].squeeze(), cmap='gray')
plt.title(f"Label: {y_train[0]}")
plt.show()
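To see what the preprocessing does to the array shapes without downloading MNIST, the following minimal sketch applies the same two steps to a small synthetic batch. The array `fake_images` is a hypothetical stand-in for `x_train`, filled with random pixel values; it is not part of the lesson's dataset.

```python
import numpy as np

# Hypothetical stand-in for x_train: 4 random "images" of 28x28 uint8 pixels.
rng = np.random.default_rng(seed=0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Step 1: scale pixel values from [0, 255] to [0.0, 1.0].
scaled = fake_images / 255.0

# Step 2: append a channel axis so each image is (height, width, channels).
with_channel = scaled[..., np.newaxis]

print(fake_images.shape, fake_images.dtype)    # shape and dtype before
print(with_channel.shape, with_channel.dtype)  # shape and dtype after
```

The division promotes the integer pixels to floating point, and the new trailing axis of size 1 represents the single grayscale channel.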
2. Creating the CNN Model
A CNN model consists of convolutional layers (Conv2D), pooling layers (MaxPooling2D), and fully connected layers (Dense).
# Define the CNN model
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# Display the model architecture
model.summary()

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
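The numbers reported by model.summary() can be checked by hand. The sketch below walks a 28×28×1 input through the layers defined above using plain arithmetic, assuming the Keras defaults the model relies on (no padding and stride 1 for Conv2D, pooling stride equal to the pool size). The helper functions are hypothetical names introduced only for this illustration.

```python
def conv_out(size, kernel):
    """Output size of an unpadded ('valid') convolution with stride 1."""
    return size - kernel + 1

def pool_out(size, pool):
    """Output size of non-overlapping max pooling (any remainder is dropped)."""
    return size // pool

def conv_params(kernel, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

def dense_params(inputs, units):
    """One weight per input plus one bias, for each unit."""
    return (inputs + 1) * units

side, channels = 28, 1
total = 0

# Conv2D(32, 3x3) followed by MaxPooling2D(2x2): 28 -> 26 -> 13
total += conv_params(3, channels, 32)
side, channels = pool_out(conv_out(side, 3), 2), 32

# Conv2D(64, 3x3) followed by MaxPooling2D(2x2): 13 -> 11 -> 5
total += conv_params(3, channels, 64)
side, channels = pool_out(conv_out(side, 3), 2), 64

# Flatten, then Dense(128) and Dense(10)
flat = side * side * channels
total += dense_params(flat, 128)
total += dense_params(128, 10)

print(f"Flattened features: {flat}")
print(f"Trainable parameters: {total}")
```

Most of the parameters sit in the first Dense layer, because every one of the flattened features connects to each of its 128 units.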
3. Training the Model
Now, let's train the model with the prepared data.
# Train the model
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
During training, you should see the loss decrease and the accuracy increase from epoch to epoch.
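The loss printed during training comes from the sparse_categorical_crossentropy setting chosen when compiling the model. As a simplified sketch of the idea, the per-sample value is the negative log of the probability the model assigns to the true class; batching and numerical safeguards are left out. The probability lists below are hypothetical, not real model outputs.

```python
import math

def sparse_categorical_crossentropy(probs, label):
    """Loss for one sample: negative log of the probability given to the true class."""
    return -math.log(probs[label])

# Hypothetical softmax outputs over the 10 digit classes for a sample labeled 7.
confident = [0.01] * 10
confident[7] = 0.91      # most of the probability mass is on the true class
unsure = [0.1] * 10      # uniform guess across all digits

loss_confident = sparse_categorical_crossentropy(confident, 7)
loss_unsure = sparse_categorical_crossentropy(unsure, 7)

print(f"confident prediction: {loss_confident:.4f}")
print(f"uniform guess:        {loss_unsure:.4f}")
```

A confident, correct prediction yields a loss close to zero, while a uniform guess over ten classes yields log(10), which is why the loss falls as the model learns.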
4. Evaluating and Predicting with the Model
Next, evaluate the trained model and make predictions.
# Evaluate the model
loss, acc = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {acc:.4f}")
Run predictions on a few test samples and compare them with the true labels.
# Predict sample data
sample = x_test[:5]  # First 5 images
predictions = model.predict(sample)
predicted_labels = np.argmax(predictions, axis=1)

# Display prediction results
for i in range(5):
    plt.imshow(sample[i].squeeze(), cmap='gray')
    plt.title(f"Predicted: {predicted_labels[i]}, Actual: {y_test[i]}")
    plt.show()
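The step from class probabilities to labels, and from labels to an accuracy score, can be sketched in plain Python. This is only an illustration of what np.argmax(predictions, axis=1) and the accuracy metric compute; the helper functions and the small 3-class probability table are hypothetical and not taken from the trained model.

```python
def argmax(row):
    """Index of the largest value in a list (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(predicted, actual):
    """Fraction of positions where the predicted label equals the true label."""
    matches = sum(p == a for p, a in zip(predicted, actual))
    return matches / len(actual)

# Hypothetical probabilities over 3 classes for 4 samples (not real model output).
probabilities = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
true_labels = [0, 1, 2, 1]

# One predicted label per row: the class with the highest probability.
predicted_labels = [argmax(row) for row in probabilities]

print(predicted_labels)
print(accuracy(predicted_labels, true_labels))
```

Here the last sample is misclassified, so three of the four predictions match and the accuracy is 0.75.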
You have now built a simple CNN that classifies handwritten digit images. The same approach can be adapted to other image datasets by adjusting the input shape and the number of output classes.
In the next lesson, we will take a quiz to review what we've learned so far.