pytorch学习3(pytorch手写数字识别练习)

2023-09-21 16:01:10

网络模型

设置三层网络,一般最后一层激活函数不选择relu
在这里插入图片描述

任务步骤

手写数字识别任务共有四个步骤:
1、数据加载--Load Data
2、构建网络--Build Model
3、训练--Train
4、测试--Test

实战

1、导入各种需要的包

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision

from matplotlib import pyplot as plt

from minist_utils import plot_image, plot_curve, one_hot ##自写文件

minist_utils:
在这里插入图片描述
在这里插入图片描述在这里插入图片描述

2、加载数据

batch_size = 512

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=False

取一些样本看数据的shape以及图片内容

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')

在这里插入图片描述在这里插入图片描述

注:经过load加载处理后的数据集包含x(图像信息)和y(标签信息)
next(iter())的用法是取一组样本,重复运行可以依次顺序取样,直到样本被取完
可在csdn自行搜索学习了解

3、网络构建

按之前设想的三层线性模型嵌套的思想搭建模型,为了模型简单,第三层不加激活函数。

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        # xw+b
        self.fc1 = nn.Linear(28*28, 256) #输入特征数,输出特征数
        self.fc2 = nn.Linear(256, 64)  #256,64是根据经验判断
        self.fc3 = nn.Linear(64, 10)  #最开始的28*28和输出的10是一定的

    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x)) #输入x后第一次线性模型得到H1作第二层输入
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x)) #输入H1得到H2作第三层输入
        # h3 = h2w3 + b3
        x = self.fc3(x)	#输入H3得到最终结果,维度为10

        return x

4、模型训练

net = Net()

# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []

for epoch in range(3):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [512]
        # [b, 1, 28, 28] => [b, feature] 全连接层只能接受这样的数据
        x = x.view(x.size(0), 28*28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward() # 梯度计算过程
        # w` = w - lr * grad
        optimizer.step() # 优化更新w,b

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)

在这里插入图片描述

5、测试

1、计算准确率acc

total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10] => pred: [b]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print(("acc:", acc))

在这里插入图片描述
2、展示部分测试样本原图以及预测标签结果

x, y =next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

在这里插入图片描述

更多推荐

WPF样式

样式是组织和重用格式化选项的重要工具。不是使用重复的标记填充XAML,以便设置外边距、内边距、颜色以及字体等细节,而是创建一系列封装所有这些细节的样式,然后在需要之处通过属性来应用样式。样式基础样式是可应用与元素的属性值集合。WPF样式系统与HTML标记中的层叠样式表(CSS)标准担当类似的角色。与CSS类似,通过WP

强化学习从基础到进阶-案例与实践[4]:深度Q网络-DQN、double DQN、经验回放、rainbow、分布式DQN

【强化学习原理+项目专栏】必看系列:单智能体、多智能体算法原理+项目实战、相关技巧(调参、画图等、趣味项目实现、学术应用项目实现专栏详细介绍:【强化学习原理+项目专栏】必看系列:单智能体、多智能体算法原理+项目实战、相关技巧(调参、画图等、趣味项目实现、学术应用项目实现对于深度强化学习这块规划为:基础单智能算法教学(g

分布式运用之rsync远程同步

一、rsync的相关知识1.1rsync简介rsync(RemoteSync,远程同步)是一个开源的快速备份工具,可以在不同主机之间镜像同步整个目录树,支持增量备份,并保持链接和权限,且采用优化的同步算法,传输前执行压缩,因此非常适用于异地备份、镜像服务器等应用。rsync的官方站点的网址是rsync.samba.or

ubuntu搭建sftp服务

安装OpenSSH服务器Ubuntu通常已经预装了OpenSSH客户端,但如果您还没有OpenSSH服务器,请在终端中执行以下命令来安装:sudoaptupdatesudoaptinstallopenssh-server创建SFTP用户和组创建一个新的用户组(例如sftp_users),用于管理SFTP用户:sudog

Linux之shell条件测试

目录作用基本用法格式:案例-f用法[]用法[[]]用法(())语法文件测试参数案例编写脚本,测试文件是否存在,不存在则创建整数测试作用操作符案例系统用户个数小于50的则输出信息逻辑操作符符号案例命令分隔符案例分析案例1---判断当前已登录的账户数,超过5个则输出信息案例2---取出/etc/passwd文件的第6行内容

Layui快速入门之第十四节 分页

目录一:基本用法API渲染属性二:自定义主题三:自定义文本四:自定义排版五:完整显示一:基本用法分页组件laypage提供了前端的分页逻辑,使得我们可以很灵活处理不同量级的数据,从而提升渲染效率<!DOCTYPEhtml><html><head><metacharset="utf-8"><title>分页</title

STM32低功耗分析

1.ARM发布最新内核2023年5月29日,Arm公司今天发布了处理器核心:Cortex-X4、Cortex-A720和Cortex-A520。这些核心都是基于Armv9.2架构,只支持64位指令集,不再兼容32位应用。Arm公司表示,这些核心在性能和效率方面都有显著的提升,同时也加强了安全性和可扩展性。Cortex-

字符串相似度算法

相似度算法JaccardSimilarityCoefficient、JaroWinkler、CosineSimilarity、Levenshtein距离编辑算法案例。Jaccard相似性系数衡量两个集合的相似程度,通过计算两个集合的交集大小除以并集大小得出。适用于处理文本、推荐系统、生物信息学等领域CosineSimi

青龙面板从0到1的实现

文章目录需要有一台云服务器Docker、SSH、青龙如何打开云服务器上的青龙面板青龙注册登录看这个青龙配置最后、从此需要有一台云服务器我这里选择的是阿里云新用户免费送的三个月服务器,服务器操作系统:CenOS(其他操作系统也可以:Ubantu、Debian)。Docker、SSH、青龙为云服务器系统安装Docker容器

支付功能、支付平台、支持渠道如何测试?

有学员提问:作为一个支付平台,接入了快钱、易宝或直连银行等多家的渠道,内在的产品流程是自己的。业内有什么比较好的测试办法,来测试各渠道及其支持的银行通道呢?作为产品,我自己办了十几张银行卡方便测试,但QA和开发不愿意这样做,怎么办呢?回答:对支付平台而言,与支付渠道相关的测试大致可以分为:测试支付渠道功能、测试支付产品

scons体验以及rtthread中的简单使用

SCons是一个用于构建软件项目的软件构建工具。它使用Python脚本作为配置文件,提供了一种简单而灵活的方式来描述软件项目的构建过程。下面是一个简单的SCons使用示例:安装SCons:首先,确保你已经安装了Python。然后,可以使用Python的包管理器pip安装SCons。在命令行中运行以下命令安装SCons:

热文推荐