手写数字卷积神经网络:PyTorch 实现
在计算机视觉领域,卷积神经网络(Convolutional Neural Network,CNN)是一种非常常见的深度学习模型。它在图像识别、目标检测、图像分类等任务中取得了巨大成功。本文将介绍如何使用 PyTorch 实现一个简单的手写数字识别卷积神经网络。
数据集准备
我们将使用 MNIST 数据集,它是一个包含了大量手写数字图片的数据集。PyTorch 提供了内置的 MNIST 数据集,我们可以使用 torchvision 库轻松加载数据。首先,我们需要导入相关的库。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
接下来,我们需要对数据进行预处理。我们可以使用 torchvision.transforms 库对数据进行常用的预处理操作,例如缩放、归一化和数据增强等。在这里,我们将对图像进行缩放到 28x28 大小,并将像素值归一化到 0-1 范围。
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
然后,我们可以使用 DataLoader 加载 MNIST 数据集。我们将数据集分为训练集和测试集,其中训练集用于模型的训练和参数优化,测试集用于评估模型的性能。
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
构建模型
我们的卷积神经网络模型将包含多个卷积层、池化层和全连接层。我们可以通过创建一个继承自 nn.Module
的类来定义我们的模型。这个类将包含模型的结构和前向传播的逻辑。
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(7 * 7 * 64, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
model = CNN()
在这个例子中,我们的模型包含两个卷积层,两个池化层和两个全连接层。我们使用了 ReLU 激活函数和最大池化操作来增强模型的非线性能力。
定义损失函数和优化器
接下来,我们需要定义损失函数和优化器。我们将使用交叉熵损失函数和随机梯度下降(SGD)优化器。
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
训练模型
现在,我们可以开始训练我们的模型了。我们将迭代多个 epoch,每个 epoch 中对训练集进行多次训练。对于每个 mini-batch,我们都需要执行以下步骤:
- 将输入数据和标签加载到 GPU 上(如果可用)。
- 清除优化器的梯度。
- 前向传播计算预测值。
- 计算损失。
- 反向传播计算梯度。
- 更新模型参数。
device = torch.device('