热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

深度学习优化器

优化器理解:通过梯度下降来更新神经网络中参数的值,使实际值慢慢的向要求值靠近。但是优化器有很多种,因为只是入门,所以我先不

优化器理解:

通过梯度下降来更新神经网络中参数的值,使实际值慢慢的向要求值靠近。但是优化器有很多种,因为只是入门,所以我先不仔细了解各种优化器中要求的算法
可以观察官方文档中的示例:

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

理解与解释:从dataset这个数据集中依次拿出一个数据input,然后我们先将优化器清零,因为如果不清零,上一次循环中原有的值还在,会影响这次循环的优化,然后我们将input放入我们搭建的模型model中,然后利用损失函数来计算差值,利用反向传播来计算我们应该更新的值的大小,然后利用优化器的step()函数来进行一个更新。一般来说,我们会进行很多次这样的优化,来使实际值趋于我们需要的值,所以可以在上方再加一重循环。

对于各种优化器其中的算法,入门的我先不做了解,可以看下官方文档


代码呈现优化器过程

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("dataset2",train=False,transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset,batch_size=1)class Test(nn.Module):def __init__(self):super(Test, self).__init__()self.conv1 = Conv2d(3,32,5,padding=2)self.maxpool1 = MaxPool2d(kernel_size=2)self.conv2 = Conv2d(32,32,5,padding=2)self.maxpool2 = MaxPool2d(2)self.conv3 = Conv2d(32,64,5,padding=2)self.maxpool3 = MaxPool2d(kernel_size=2)self.flatten = Flatten()self.linear1 = Linear(1024,64)self.linear2 = Linear(64,10)def forward(self,x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.linear1(x)x = self.linear2(x)return x;test = Test()
loss = CrossEntropyLoss()
optim = torch.optim.SGD(test.parameters(),lr=0.01)
for epoch in range(20):running_loss = 0for data in dataloader:imgs,target = data;optim.zero_grad()output = test(imgs)result_loss = loss(output,target)result_loss.backward()optim.step()running_loss+=result_lossprint(running_loss)

输出:

tensor(18688.1523, grad_fn&#61;<AddBackward0>)
tensor(16161.1523, grad_fn&#61;<AddBackward0>)
tensor(15446.3174, grad_fn&#61;<AddBackward0>)
...

可以看出我们的差值在不断减小。


推荐阅读
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 个人学习使用:谨慎参考1Client类importcom.thoughtworks.gauge.Step;importcom.thoughtworks.gauge.T ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • 3.223.28周学习总结中的贪心作业收获及困惑
    本文是对3.223.28周学习总结中的贪心作业进行总结,作者在解题过程中参考了他人的代码,但前提是要先理解题目并有解题思路。作者分享了自己在贪心作业中的收获,同时提到了一道让他困惑的题目,即input details部分引发的疑惑。 ... [详细]
  • 本文讨论了如何使用IF函数从基于有限输入列表的有限输出列表中获取输出,并提出了是否有更快/更有效的执行代码的方法。作者希望了解是否有办法缩短代码,并从自我开发的角度来看是否有更好的方法。提供的代码可以按原样工作,但作者想知道是否有更好的方法来执行这样的任务。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 本文讨论了编写可保护的代码的重要性,包括提高代码的可读性、可调试性和直观性。同时介绍了优化代码的方法,如代码格式化、解释函数和提炼函数等。还提到了一些常见的坏代码味道,如不规范的命名、重复代码、过长的函数和参数列表等。最后,介绍了如何处理数据泥团和进行函数重构,以提高代码质量和可维护性。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
author-avatar
暗夜风线_371
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有