热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

CV算法复现(分类算法2/6):AlexNet(2012年Hinton组)

致谢:霹雳吧啦Wz:https:space.bilibili.com18161609目录致谢:霹雳吧啦Wz:https:

致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609

目录

致谢:霹雳吧啦Wz:https://space.bilibili.com/18161609

1 本次要点

1.1 深度学习理论

1.2 pytorch框架语法

2 网络简介

2.1 历史意义

2.2 网络亮点

2.3 网络架构

3 代码结构

3.1 model.py

3.2 train.py

3.3 predict.py

3.4 split_data.py




1 本次要点


1.1 深度学习理论


  1. 经过一次卷积操作后,图像新尺寸计算公式:(如果padding [p1, p2]中p1,p2不相等,那么公式中2P就变为P1+P2)(如果结果值不是整数,pytorch中会自动忽略最后一行以及最后一列,以保证N为整数。)
  2.  

1.2 pytorch框架语法


  1. pytorch可以自定义网络权重的初始化方法(见model.py)。
  2. pata = list(net.parameters()) #查看模型参数


2 网络简介


2.1 历史意义


  • 2012年ImageNet图像分类冠军网络,分类准确率由传统的 70%+直接提升到 80%+。在那年之后,深
    度学习开始迅速发展。

2.2 网络亮点


  1. 首次利用 GPU 进行网络加速训练。
  2. 使用了 ReLU 激活函数,而不是传统的 Sigmoid 激活函数以及 Tanh 激活函数。
  3. 在前两层的全连接层中使用了 Dropout 随机失活神经元操作,以减少过拟合。

2.3 网络架构

备注:padding: [1, 2]即图像最左边缘加1列0,最右边缘加2列0。图像最上边缘加1行0,图像最下边缘加2行0。


3 代码结构


  • model.py
  • train.py
  • predict.py
  • split_data.py(数据集划分)

3.1 model.py

import torch.nn as nn
import torch"""
本AlexNet复现相比原论文,每层的卷积核个数减半。
"""
class AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()# nn.Sequential():将一系列层结构进行打包。省去每一层都用一个变量去表示。self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2), # input[3, 224, 224] output[48, 55, 55]nn.ReLU(inplace=True), #inplace:通过增加计算量来降低内存使用,从而可以载入更大模型(默认False)。nn.MaxPool2d(kernel_size=3, stride=2), # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2), # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1), # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1), # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1), # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048), # 输入:128通道*6*6(特征图大小)(到此之前会拉成1维)nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1) # torch中顺序[B,C,H,W],start_dim=1就是将C维度拉平。x = self.classifier(x)return x# 初始化权重方式(框架有默认,如果要自定义可如下方式写)def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

3.2 train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time"""
数据集:花分类(5类)
"""def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),#水平随机翻转transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) #os.getcwd():获取当前绝对路径。"../.."返回到上上层路径。image_path = os.path.join(data_root, "data_set", "flower_data") # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())#将键和值顺序反过来。目的是让模型预测的结果索引,可直接找到对应的类型。# write dict into json filejson_str = json.dumps(cla_dict, indent=4)#编码成json格式with open('class_indices.json', 'w') as json_file:#新建json文件并写入内容json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4, shuffle=False,num_workers=nw)print("using {} images for training, {} images fot validation.".format(train_num,# 查看数据集代码 val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()## def imshow(img):# img = img / 2 + 0.5 # unnormalize# npimg = img.numpy()# plt.imshow(np.transpose(npimg, (1, 2, 0)))# plt.show()## print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))# imshow(utils.make_grid(test_image))net = AlexNet(num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()# pata = list(net.parameters()) #查看模型参数(调试用)optimizer = optim.Adam(net.parameters(), lr=0.0002)save_path = './AlexNet.pth'best_acc = 0.0for epoch in range(10):# 训练阶段net.train() #自动判定dropout或BN层是否应该启用。running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()#反向传播optimizer.step()#更新每个节点参数# print statisticsrunning_loss += loss.item()# print train process 打印训练信息rate = (step + 1) / len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)# 验证阶段net.eval() #自动判定dropout或BN层是否应该启用。acc = 0.0 # accumulate accurate number / epochwith torch.no_grad():#不去计算损失梯度for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc: best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')if __name__ == '__main__':main()

训练结果:


3.3 predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import jsondata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():#不去计算损失梯度# predict classoutput = torch.squeeze(model(img))#torch.squeeze():对数据的维度进行压缩,去掉维数为1的的维度predict = torch.softmax(output, dim=0)#将预测结果值转换为概率分布形式。predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

输出:


3.4 split_data.py

import os
from shutil import copy, rmtree
import random"""
使用步骤如下:
(1)在data_set文件夹下创建新文件夹"flower_data"
(2)点击链接下载花分类数据集 http://download.tensorflow.org/example_images/flower_photos.tgz
(3)解压数据集到flower_data文件夹下
(4)执行"split_data.py"脚本自动将数据集划分成训练集train和验证集val├── flower_data ├── flower_photos(解压的数据集文件夹,3670个样本) ├── train(生成的训练集,3306个样本) └── val(生成的验证集,364个样本)
"""def mk_file(file_path: str):if os.path.exists(file_path):# 如果文件夹存在,则先删除原文件夹在重新创建rmtree(file_path)os.makedirs(file_path)def main():# 保证随机可复现random.seed(0)# 将数据集中10%的数据划分到验证集中split_rate = 0.1# 指向你解压后的flower_photos文件夹cwd = os.getcwd()data_root = os.path.join(cwd, "flower_data")origin_flower_path = os.path.join(data_root, "flower_photos")assert os.path.exists(origin_flower_path)flower_class = [cla for cla in os.listdir(origin_flower_path)if os.path.isdir(os.path.join(origin_flower_path, cla))]# 建立保存训练集的文件夹train_root = os.path.join(data_root, "train")mk_file(train_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(train_root, cla))# 建立保存验证集的文件夹val_root = os.path.join(data_root, "val")mk_file(val_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(val_root, cla))for cla in flower_class:cla_path = os.path.join(origin_flower_path, cla)images = os.listdir(cla_path)num = len(images)# 随机采样验证集的索引eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:# 将分配至验证集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(val_root, cla)copy(image_path, new_path)else:# 将分配至训练集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(train_root, cla)copy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing barprint()print("processing done!")if __name__ == '__main__':main()

输出:


推荐阅读
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • GPT-3发布,动动手指就能自动生成代码的神器来了!
    近日,OpenAI发布了最新的NLP模型GPT-3,该模型在GitHub趋势榜上名列前茅。GPT-3使用的数据集容量达到45TB,参数个数高达1750亿,训练好的模型需要700G的硬盘空间来存储。一位开发者根据GPT-3模型上线了一个名为debuid的网站,用户只需用英语描述需求,前端代码就能自动生成。这个神奇的功能让许多程序员感到惊讶。去年,OpenAI在与世界冠军OG战队的表演赛中展示了他们的强化学习模型,在限定条件下以2:0完胜人类冠军。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • t-io 2.0.0发布-法网天眼第一版的回顾和更新说明
    本文回顾了t-io 1.x版本的工程结构和性能数据,并介绍了t-io在码云上的成绩和用户反馈。同时,还提到了@openSeLi同学发布的t-io 30W长连接并发压力测试报告。最后,详细介绍了t-io 2.0.0版本的更新内容,包括更简洁的使用方式和内置的httpsession功能。 ... [详细]
  • 本文介绍了PhysioNet网站提供的生理信号处理工具箱WFDB Toolbox for Matlab的安装和使用方法。通过下载并添加到Matlab路径中或直接在Matlab中输入相关内容,即可完成安装。该工具箱提供了一系列函数,可以方便地处理生理信号数据。详细的安装和使用方法可以参考本文内容。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 本文介绍了JavaScript进化到TypeScript的历史和背景,解释了TypeScript相对于JavaScript的优势和特点。作者分享了自己对TypeScript的观察和认识,并提到了在项目开发中使用TypeScript的好处。最后,作者表示对TypeScript进行尝试和探索的态度。 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 【疑难杂症】allennlp安装报错:Installing build dependencies ... error
    背景:配置PURE的算法环境,安装allennlp0.9.0(pipinstallallennlp0.9.0)报错ÿ ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • S3D算法详解
    S3D论文详解论文地址:RethinkingSpatiotemporalFeatureLearning:Speed-AccuracyTrade-offsinVide ... [详细]
  • pytorch Dropout过拟合的操作
    这篇文章主要介绍了pytorchDropout过拟合的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完 ... [详细]
  • 都会|可能会_###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了###haohaohao###图神经网络之神器——PyTorchGeometric上手&实战相关的知识,希望对你有一定的参考价值。 ... [详细]
author-avatar
mobiledu2502859163
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有