热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch实现Classification分类

跟着莫凡大神学习importtorchfromtorch.autogradimportVariableimporttorch.nn.functionalasFimport

跟着莫凡大神学习

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

# make fake data
n_data =torch.ones(100,2)
# https://ptorch.com/docs/1/torchlists
x0 = torch.normal(2*n_data,1) # class0 x data (tensor), shape=(100, 2)
# torch.normal(means, std, out=None) means (Tensor) – 均值 , std (Tensor) – 标准差, out (Tensor) – 可选的输出张量
y0 = torch.zeros(100)
x1 = torch.normal(-2*n_data,1) # class1 x data (tensor), shape=(100, 2)
y1 = torch.ones(100)

x=torch.cat((x0,x1),0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floating
y=torch.cat((y0,y1),0).type(torch.LongTensor) # shape (200,) LOngTensor= 64-bit integer



x,y=Variable(x),Variable(y)

# plt.scatter(x.data.numpy(),y.data.numpy())
# plt.show()

class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_output):
super(Net,self).__init__()
self.hidden = torch.nn.Linear(n_feature,n_hidden)
self.predict = torch.nn.Linear(n_hidden,n_output)

def forward(self, x):
x=F.relu(self.hidden(x))
x=self.predict(x)
return x


net =Net(2,10,2) # define the network
plt.ion() # something about plotting
plt.show()

optimizer =torch.optim.SGD(net.parameters(),lr=0.002) #优化参数
loss_func = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted

for t in range(100):
out =net(x) #开始训练

loss = loss_func(out,y) # 一定要预测的值在前,真实值在后

# below are
optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step()
if t % 2==0: # 每训练2次 ,打印一次
# plot and show learning process
plt.cla()
prediction = torch.max(out,1)[1] # why is 1
predy = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:,0], x.data.numpy()[:,1],c=predy,s=100,lw=0,cmap='RdYlGn')
accuracy = sum(predy == target_y)/200
plt.text(1.5,-4,'Accuracy=%.2f' % accuracy,fOntdict={'size':20,'color':'red'})
plt.pause(0.1)

plt.ioff()
plt.show()


刚开始不太清楚上面使用的数据,所以自己做了一些其他测试

import torchn_data =torch.ones(4,2)print(n_data)x0 = torch.normal(2*n_data,1)print('x0\n', x0)y0 = torch.zeros(4)print('yo\n',y0)x1 = torch.normal(-2*n_data,1)print('x1\n', x1)y1 = torch.ones(4)print('y1\n',y1)x=torch.cat((x0,x1),0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floatingprint('x\n',x)y=torch.cat((y0,y1),0).type(torch.LongTensor) # shape (200,) LOngTensor= 64-bit integerprint('y\n',y)

输出结果如下:

 1  1 1  1 1  1 1  1[torch.FloatTensor of size 4x2]x0  0.2261  3.0315 2.0241  1.5661 4.7188  2.0684 1.8433  2.0262[torch.FloatTensor of size 4x2]yo  0 0 0 0[torch.FloatTensor of size 4]x1 -0.4156 -1.0854-1.5244 -1.1929-2.2120 -0.3639-1.4513 -2.1948[torch.FloatTensor of size 4x2]y1  1 1 1 1[torch.FloatTensor of size 4]x  (PS: 二维平面的坐标)  0.2261  3.0315 2.0241  1.5661 4.7188  2.0684 1.8433  2.0262-0.4156 -1.0854-1.5244 -1.1929-2.2120 -0.3639-1.4513 -2.1948[torch.FloatTensor of size 8x2]y (针对二维平面每个坐标的 标签)  0 0 0 0 1 1 1 1[torch.LongTensor of size 8]Process finished with exit code 0



推荐阅读
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • 如何自行分析定位SAP BSP错误
    The“BSPtag”Imentionedintheblogtitlemeansforexamplethetagchtmlb:configCelleratorbelowwhichi ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文讨论了如何使用IF函数从基于有限输入列表的有限输出列表中获取输出,并提出了是否有更快/更有效的执行代码的方法。作者希望了解是否有办法缩短代码,并从自我开发的角度来看是否有更好的方法。提供的代码可以按原样工作,但作者想知道是否有更好的方法来执行这样的任务。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 本文详细介绍了如何使用MySQL来显示SQL语句的执行时间,并通过MySQL Query Profiler获取CPU和内存使用量以及系统锁和表锁的时间。同时介绍了效能分析的三种方法:瓶颈分析、工作负载分析和基于比率的分析。 ... [详细]
  • 本文介绍了如何使用Express App提供静态文件,同时提到了一些不需要使用的文件,如package.json和/.ssh/known_hosts,并解释了为什么app.get('*')无法捕获所有请求以及为什么app.use(express.static(__dirname))可能会提供不需要的文件。 ... [详细]
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
author-avatar
爱着你心却痛_534
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有