热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

softmax分类器_一文看懂如何实现softmax分类器(内附完全代码)

本文首发于公众号【拇指笔记】1.实现softmax回归模型首先还是导入需要的包#实现softmax回归importtorchimporttorchvisionimportsy
f2965dba482191f8b5de61bb107ce2f6.png

本文首发于公众号【拇指笔记】

1.实现softmax回归模型

首先还是导入需要的包

#实现softmax回归
import torch
import torchvision
import sys
import numpy as npfrom IPython import display
from numpy import argmax
import torchvision.transforms as transforms
from time import time
import matplotlib.pyplot as plt

1.1获取和读取数据

设置小批量数目为256。这一部分与之前的线性回归的读取数据大同小异,都是转换类型-->生成迭代器。

batch_size = 256mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())
#获取训练集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())
#获取测试集(这两个数据集在已存在的情况下不会被再次下载)#生成迭代器(调用一次返回一次数据)
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)

1.2初始化模型参数

由输入的数据可知:每个图像都是28*28像素,也就是说每个图像都有28*28=784个特征值。由于图像有10个类别,所以这个网络一共有10个输出。共计存在:784*10个权重参数和10个偏差参数。

num_inputs = 784
num_outputs = 10#初始化参数与线性回归也类似,权重参数设置为均值为0 标准差为0.01的正态分布;偏差设置为0
W = torch.tensor(np.random.normal(0,0.01,(num_inputs,num_outouts)),dtype = torch.float)
b = torch.zeros(num_outputs,dtype=torch.float32)#同样的,开启模型参数梯度
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)

1.3实现softmax运算

softmax运算本质就是将每个元素变成非负数,且每一行和为1。

首先回顾一下tensor的按维度操作。

X = torch.tensor([[1,2,3],[4,5,6]])#dim=0表示对列求和。keepdim表示是否在结果中保留行和列这两个维度
X.sum(dim=0,keepdim=True)
X.sum(dim=1,keepdim=True)

然后定义一下softmax运算:softmax运算会先对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行各元素与该行元素之和相除,最终得到的矩阵每行元素和为1且非负数。

def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1,keepdim=True)return X_exp/partition #这部分用了广播机制

1.4定义模型

将第二步做的和第三步做的合起来。

def net(X):return softmax(torch.mm(X.view((-1,num_inputs)), W) +b)#第一步:首先将X换形成28*28的张量,然后用.mm函数将换形后的张量与权重参数相乘,最后加偏差参数#第二步:对第一步进行softmax运算

1.5定义损失函数

首先介绍一下torch.gather函数

#gather函数的定义
torch.gather(input,dim,index,out=None) → Tensor
#gather的作用是这样的,index实际上是索引,具体是行(dim=1)还是列(dim=0)的索引要看前面dim 的指定,输出的大小由index决定

这个函数的原理我归结如下

假设输入与上同;index=B;输出为C
B中每个元素分别为b(0,0)=0,b(0,1)=0b(1,0)=1,b(1,1)=0如果dim=0(列)
则取B中元素的列号,如:b(0,1)的1
b(0,1)=0,所以C中的c(0,1)=输入的(0,1)处元素2如果dim=1(行)
则取B中元素的列号,如:b(0,1)的0
b(0,1)=0,所以C中的c(0,1)=输入的(0,0)处元素1总结如下:
输出 元素 在 输入张量 中的位置为:
输出元素位置取决与同位置的index元素
dim=1时,取同位置的index元素的行号做行号,该位置处index元素做列号
dim=0时,取同位置的index元素的列号做列号,该位置处index元素做行号。最后根据得到的索引在输入中取值index类型必须为LongTensor
gather最终的输出变量与index同形。

例子如下:

import torcha = torch.Tensor([[1,2],[3,4]])b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))
#1. 取各个元素行号:[(0,y)(0,y)][(1,y)(1,y)]
#2. 取各个元素值做行号:[(0,0)(0,0)][(1,1)(1,0)]
#3. 根据得到的索引在输入中取值
#[1,1],[4,3]c = torch.gather(a,0,torch.LongTensor([[0,0],[1,0]]))
#1. 取各个元素列号:[(x,0)(x,1)][(x,0)(x,1)]
#2. 取各个元素值做行号:[(0,0)(0,1)][(1,0)(0,1)]
#3. 根据得到的索引在输入中取值
#[1,2],[3,2]

因为softmax回归模型得到的结果可能是多个标签对应的概率,为了得到与真实标签之间的损失值,我们需要使用gather函数提取出在结果中提取出真实标签对应的概率。

假设y_hat是1个样本在3个类别中的预测概率(其余七个为0),y是这个样本的真实标签(数字0-9表示)。

y_hat = torch.tensor([0.1,0.3,0.6])
y = torch.LongTensor([0]) #gather函数中的index参数类型必须为LongTensor
y_hat.gather(1,y.vies(-1,1))
#如果y不是列向量,则需要将变量y换形为列向量。选取第一维度(行)。
#套用上述公式可知,输出为0.1,0.1就是真是类别0的概率。

有了上述理论基础,并根据交叉熵函数的公式

我们可以得到最终的损失函数。

def cross_entropy(y_hat,y):return -torch.log(y_hat.gather(1,y.view(-1,1)))

1.6计算分类准确率

计算准确率的原理:

我们把预测概率最大的类别作为输出类别,如果它与真实类别y一致,说明预测正确。分类准确率就是正确预测数量与总预测数量之比

首先我们需要得到预测的结果。

从一组预测概率(变量y_hat)中找出最大的概率对应的索引(索引即代表了类别)

#argmax(f(x))函数,对f(x)求最大值所对应的点x。我们令f(x)= dim=1,即可实现求所有行上的最大值对应的索引。
A = y_hat.argmax(dim=1)
#最终输出结果为一个行数与y_hat相同的列向量

然后我们需要将得到的最大概率对应的类别与真实类别(y)比较,判断预测是否是正确的

B = (y_hat.argmax(dim=1)==y).float()
#由于y_hat.argmax(dim=1)==y得到的是ByteTensor型数据,所以我们通过.float()将其转换为浮点型Tensor()

最后我们需要计算分类准确率

我们知道y_hat的行数就对应着样本总数,所以,对B求平均值得到的就是分类准确率

(y_hat.argmax(dim=1)==y).float().mean()

上一步最终得到的数据为tensor(x)的形式,为了得到最终的pytorch number,需要对其进行下一步操作

(y_hat.argmax(dim=1)==y).float().mean().item()
#pytorch number的获取统一通过.item()实现

整理一下,得到计算分类准确率函数

def accuracy(y_hat,y):return (y_hat.argmax(dim=1).float().mean().item())

作为推广,该函数还可以评价模型net在数据集data_iter上的准确率。

def net_accurary(data_iter,net):right_sum,n = 0.0,0for X,y in data_iter:#从迭代器data_iter中获取X和yright_sum += (net(X).argmax(dim=1)==y).float().sum().item()#计算准确判断的数量n +=y.shape[0]#通过shape[0]获取y的零维度(列)的元素数量return right_sum/n

1.7优化算法

softmax回归应用的优化算法同样使用小批量随机梯度下降算法。

def sgd(params,lr,batch_size):#lr:学习率,params:权重参数和偏差参数for param in params:param.data -= lr*param.grad/batch_size#.data是对数据备份进行操作,不改变数据本身。

1.8训练模型

在训练模型时,迭代周期数num_epochs和学习率lr都是可以调节的超参数,通过调节超参数的值可以获得分类更准确的模型。

num_epochs,lr = 5,0.1def train_softmax(net,train_iter,test_iter,loss,num_epochs,batch_size,params,lr ,optimizer):for epoch in range(num_epochs):#损失值、正确数量、总数 初始化。train_l_sum,train_right_sum,n= 0.0,0.0,0for X,y in train_iter:y_hat = net(X)l = loss(y_hat,y).sum()#数据集损失函数的值=每个样本的损失函数值的和。 if optimizer is not None:optimizer.zero_grad() #对优化函数梯度清零elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backgrad() #对损失函数求梯度optimzer(params,lr,batch_size)train_l_sum += l.item()train_right_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = net_accuracy(test_iter, net) #测试集的准确率print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train_softmax(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,[W,b],lr,sgd)

1.9预测

做一个模型的最终目的当然不是训练了,所以来预测一下试试。

def get_Fashion_MNIST_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]#labels是一个列表,所以有了for循环获取这个列表对应的文本列表def show_fashion_mnist(images,labels):display.set_matplotlib_formats('svg')#绘制矢量图_,figs = plt.subplots(1,len(images),figsize=(12,12))#设置添加子图的数量、大小for f,img,lbl in zip(figs,images,labels):f.imshow(img.view(28,28).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()X, y = iter(test_iter).next()true_labels = get_Fashion_MNIST_labels(y.numpy())
pred_labels = get_Fashion_MNIST_labels(net(X).argmax(dim=1).numpy())
titles = [true + 'n' + pred for true, pred in zip(true_labels, pred_labels)]show_fashion_mnist(X[0:9], titles[0:9])

最终效果

769e3984051aa01372e27557d7ce0846.png

由于训练比较耗时,我只训练了五次,可以看出,随着训练次数的增加,损失值从0.7854减少到0.4846;准确率从0.785提升到0.826。

完整程序

#实现softmax回归
import torch
import torchvision
import sys
import numpy as npfrom IPython import display
from numpy import argmax
import torchvision.transforms as transforms
from time import time
import matplotlib.pyplot as pltbatch_size =256
num_inputs = 784
num_outputs = 10
num_epochs,lr = 5,0.1mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())
#获取训练集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())
#获取测试集(这两个数据集在已存在的情况下不会被再次下载)#生成迭代器(调用一次返回一次数据)
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)#初始化参数与线性回归也类似,权重参数设置为均值为0 标准差为0.01的正态分布;偏差设置为0
W = torch.tensor(np.random.normal(0,0.01,(num_inputs,num_outputs)),dtype = torch.float)
b = torch.zeros(num_outputs,dtype=torch.float32)#同样的,开启模型参数梯度
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1,keepdim=True)return X_exp/partition #这部分用了广播机制def net(X):return softmax(torch.mm(X.view((-1,num_inputs)), W) +b)def cross_entropy(y_hat,y):return -torch.log(y_hat.gather(1,y.view(-1,1)))def accuracy(y_hat,y):return (y_hat.argmax(dim=1).float().mean().item())def net_accurary(data_iter,net):right_sum,n = 0.0,0for X,y in data_iter:#从迭代器data_iter中获取X和yright_sum += (net(X).argmax(dim=1)==y).float().sum().item()#计算准确判断的数量n +=y.shape[0]#通过shape[0]获取y的零维度(列)的元素数量return right_sum/ndef sgd(params,lr,batch_size):#lr:学习率,params:权重参数和偏差参数for param in params:param.data -= lr*param.grad/batch_sizedef train_softmax(net,train_iter,test_iter,loss,num_epochs,batch_size,params,lr ,optimizer,net_accuracy):for epoch in range(num_epochs):#损失值、正确数量、总数 初始化。train_l_sum,train_right_sum,n= 0.0,0.0,0for X,y in train_iter:y_hat = net(X)l = loss(y_hat,y).sum()#数据集损失函数的值=每个样本的损失函数值的和。 if params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward() #对损失函数求梯度optimizer(params,lr,batch_size)train_l_sum += l.item()train_right_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = net_accurary(test_iter, net) #测试集的准确率print('epoch %d, loss %.4f, train right %.3f, test right %.3f' % (epoch + 1, train_l_sum / n, train_right_sum / n, test_acc))def get_Fashion_MNIST_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]#labels是一个列表,所以有了for循环获取这个列表对应的文本列表def show_fashion_mnist(images,labels):display.set_matplotlib_formats('svg')#绘制矢量图_,figs = plt.subplots(1,len(images),figsize=(12,12))#设置添加子图的数量、大小for f,img,lbl in zip(figs,images,labels):f.imshow(img.view(28,28).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()time1 = time()
train_softmax(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,[W,b],lr,sgd,net_accurary)
print('n',time()-time1,'s')X, y = iter(test_iter).next()true_labels = get_Fashion_MNIST_labels(y.numpy())
pred_labels = get_Fashion_MNIST_labels(net(X).argmax(dim=1).numpy())
titles = [true + 'n' + pred for true, pred in zip(true_labels, pred_labels)]show_fashion_mnist(X[0:9], titles[0:9])




推荐阅读
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 本文介绍了C#中生成随机数的三种方法,并分析了其中存在的问题。首先介绍了使用Random类生成随机数的默认方法,但在高并发情况下可能会出现重复的情况。接着通过循环生成了一系列随机数,进一步突显了这个问题。文章指出,随机数生成在任何编程语言中都是必备的功能,但Random类生成的随机数并不可靠。最后,提出了需要寻找其他可靠的随机数生成方法的建议。 ... [详细]
  • 拥抱Android Design Support Library新变化(导航视图、悬浮ActionBar)
    转载请注明明桑AndroidAndroid5.0Loollipop作为Android最重要的版本之一,为我们带来了全新的界面风格和设计语言。看起来很受欢迎࿰ ... [详细]
  • ASP.NET2.0数据教程之十四:使用FormView的模板
    本文介绍了在ASP.NET 2.0中使用FormView控件来实现自定义的显示外观,与GridView和DetailsView不同,FormView使用模板来呈现,可以实现不规则的外观呈现。同时还介绍了TemplateField的用法和FormView与DetailsView的区别。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • EPPlus绘制刻度线的方法及示例代码
    本文介绍了使用EPPlus绘制刻度线的方法,并提供了示例代码。通过ExcelPackage类和List对象,可以实现在Excel中绘制刻度线的功能。具体的方法和示例代码在文章中进行了详细的介绍和演示。 ... [详细]
  • Python使用Pillow包生成验证码图片的方法
    本文介绍了使用Python中的Pillow包生成验证码图片的方法。通过随机生成数字和符号,并添加干扰象素,生成一幅验证码图片。需要配置好Python环境,并安装Pillow库。代码实现包括导入Pillow包和随机模块,定义随机生成字母、数字和字体颜色的函数。 ... [详细]
author-avatar
再度重相逢jc_866
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有