Matlab 使用CNN拟合回归模型预测手写数字的旋转角度
  NTGlKyq7MwNU 2023年11月02日 221 0

✅作者简介:热爱科研的算法开发者,Python、Matlab项目可交流、沟通、学习。

🍎个人主页:算法工程师的学习日志

一个深度学习文档分享一下,很简单,但思路不错,在个人项目上也可以按照需求变化数据集来实现CNN回归计算

加载数据

clc
close all
clear
%% 加载数据
%% 数据集包含手写数字的合成图像,以及每幅图像旋转的对应角度(以角度为单位)。
%% 使用digitTrain4DArrayData和digitTest4DArrayData将训练和验证图像加载为4D数组。
%% 输出YTrain和YValidation是以角度为单位的旋转角度。每个训练和验证数据集包含5000张图像。
[XTrain, ~, Ytrain] = digitTrain4DArrayData;
[XValidation, ~, YValidation] = digitTest4DArrayData;
%% 随机显示20张训练图像
numTrainImages = numel(Ytrain);
figure;
idx = randperm(numTrainImages, 20);
for i = 1 : numel(idx)
  subplot(4, 5, i);
  imshow(XTrain(:, :, :, idx(i)))
  drawnow
end

Matlab 使用CNN拟合回归模型预测手写数字的旋转角度_旋转角度

数据归一化

当训练神经网络时,确保你的数据在网络的所有阶段都是标准化的。归一化有助于使用梯度下降来稳定和加速网络训练。如果数据规模太小,那么损失可能会变成NaN,并且在培训期间网络参数可能会出现分歧。

标准化数据的常用方法包括重新标定数据,使其范围变为[0,1]或使其均值为0,标准差为1。

标准化以下数据:

1、输入数据。在将预测器输入到网络之前对数据进行规范化。

2、层输出。使用批处理规范化层对每个卷积和完全连接层的输出进行规范化。

3、响应。如果使用批处理规范化层对网络末端的层输出进行规范化,则在开始训练时对网络的预测进行规范化。

%% 绘制响应分布:在分类问题中,输出是类概率,类概率总是归一化的。
figure;
histogram(Ytrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

Matlab 使用CNN拟合回归模型预测手写数字的旋转角度_旋转角度_02

创建网络层

%% 创建网络层
%% 第一层定义输入数据的大小和类型。输入的图像大小为28×28×1。创建与训练图像大小相同的图像输入层。
%% 网络的中间层定义了网络的核心架构,大部分计算和学习都在这个架构中进行。
%% 最后一层定义输出数据的大小和类型。对于回归问题,全连接层必须先于网络末端的回归层。
layers = [
  imageInputLayer([28 28 1])
  batchNormalizationLayer
  reluLayer
  
  averagePooling2dLayer(2, 'Stride', 2)
  
  convolution2dLayer(3, 16, 'Padding', 'same')
  batchNormalizationLayer
  reluLayer
  
  averagePooling2dLayer(2, 'Stride', 2)
  
  convolution2dLayer(3, 32, 'Padding', 'same')
  batchNormalizationLayer
  reluLayer
  
  convolution2dLayer(3, 32, 'Padding', 'same')
  batchNormalizationLayer
  reluLayer
  
  dropoutLayer(0.2)
  fullyConnectedLayer(1)
  regressionLayer];

Matlab 使用CNN拟合回归模型预测手写数字的旋转角度_2d_03

训练网络设置

使用 trainNetwork 创建网络。如果存在兼容的 GPU,此命令会使用 GPU。否则,trainNetwork 将使用 CPU。在 GPU 上进行训练需要具有 3.0 或更高计算能力的支持 CUDA® 的 NVIDIA® GPU。

%% 训练网络——Options
%% Train for 30 epochs 学习率0.001 在20个epoch后降低学习率。
%% 通过指定验证数据和验证频率,监控培训过程中的网络准确性。
%% 根据训练数据对网络进行训练,并在训练过程中定期对验证数据进行精度计算。
%% 验证数据不用于更新网络权重。打开训练进度图,并关闭命令窗口输出。
miniBatchSize = 128;
validationFrequency = floor(numel(Ytrain) / miniBatchSize);
options = trainingOptions('sgdm', ...
  'MiniBatchSize', miniBatchSize, ...
  'MaxEpochs', 30, ...
  'InitialLearnRate', 1e-3, ...
  'LearnRateSchedule', 'piecewise', ...
  'LearnRateDropFactor', 0.1, ...
  'LearnRateDropPeriod', 20, ...
  'Shuffle', 'every-epoch', ...
  'ValidationData', {XValidation, YValidation}, ...
  'ValidationFrequency', validationFrequency, ...
  'Plots', 'training-progress', ...
  'Verbose', false);

训练网络

net = trainNetwork(XTrain, Ytrain, layers, options);

Matlab 使用CNN拟合回归模型预测手写数字的旋转角度_数据_04

预测结果


基于验证数据评估准确度来测试网络性能。使用 predict 预测验证图像的旋转角度。

YPredicted = predict(net,XValidation);

评估性能

通过计算以下值来评估模型性能:

predictionError = YValidation - YPredicted;

计算在实际角度的可接受误差界限内的预测值的数量。将阈值设置为 10 度。计算此阈值范围内的预测值的百分比。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);
accuracy = numCorrect/numValidationImages

使用均方根误差 (RMSE) 来衡量预测旋转角度和实际旋转角度之间的差异。

squares = predictionError.^2;
rmse = sqrt(mean(squares))
accuracy =
    0.9584




rmse =
  single
    4.8987


显示原始数字以及校正旋转后的数字,使用 montage (Image Processing Toolbox) 将数字显示在同一个图像上。

Matlab 使用CNN拟合回归模型预测手写数字的旋转角度_2d_05

【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月08日 0

暂无评论

推荐阅读
  3XDZIv8qh70z   2023年12月23日   19   0   0 2d2d
NTGlKyq7MwNU