热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【图像分类实战】利用DenseNet在PyTorch中实现秃头识别

本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。

目录

摘要

导入项目使用的库

设置全局参数

图像预处理

 读取数据

设置模型

设置训练和验证

测试

完整代码




摘要

​     我在前面的文章已经写了很多模型的实战,这是实战的最后一篇了。我没有加入可视化,也没有对代码做过多的装饰,只希望用最简单的方式让大家知道分类模型是怎样实现的。

    今天我们用DenseNet实现对秃头的分类,数据集我放在百度网盘了,地址:链接:https://pan.baidu.com/s/177ethB_1ZLyl8_Ef1lJxSA 提取码:47fo 。这个数据集可能让广大的程序员扎心了。

下面展示一下数据集的样例。

这个都是秃顶的,他们的共同特点:都是男士,为啥女士不秃顶呢? 


导入项目使用的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torchvision.models import densenet121

 


设置全局参数

设置BatchSize、学习率和epochs,判断是否有cuda环境,如果没有设置为cpu

# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 32
EPOCHS = 5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

图像预处理

      在做图像与处理时,train数据集的transform和验证集的transform分开做,train的图像处理出了resize和归一化之外,还可以设置图像的增强,比如旋转、随机擦除等一系列的操作,验证集则不需要做图像增强,另外不要盲目的做增强,不合理的增强手段很可能会带来负作用,甚至出现Loss不收敛的情况。
 

# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

 读取数据

这个数据集,作者已经做过调整,可以直接使用Python的默认方式读取数据。数据的目录如下图:

 训练集中有1.6万张图片,其中有3千多个是秃头,验证集有2万多张,其中秃头是470张,数量差别太大我随机删了一部分。然后,写读取数据的代码。

dataset_train = datasets.ImageFolder('Dataset/Train', transform)
dataset_test = datasets.ImageFolder('Dataset/Validation',transform_test)
# 读取数据
print(dataset_train.imgs)# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型


使用交叉熵作为loss,模型采用densenet121,建议使用预训练模型,我在调试的过程中,使用预训练模型可以快速得到收敛好的模型,使用预训练模型将pretrained设置为True即可。更改最后一层的全连接,将类别设置为2,然后将模型放到DEVICE。优化器选用Adam。

# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = densenet121(pretrained=True)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, 2)
model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)def adjust_learning_rate(optimizer, epoch):"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""modellrnew = modellr * (0.1 ** (epoch // 50))print("lr:", modellrnew)for param_group in optimizer.param_groups:param_group['lr'] = modellrnew

设置训练和验证

 最外层是循环的是每个epochs,先训练,后验证。下面分别讲一下训练和验证的过程。

训练过程必须经历的步骤:

第一步:将输入input向前传播,进行运算后得到输出output,代码:output = model(data)

第二步:将output再输入loss函数,计算loss值(是个标量),代码:  loss = criterion(output, target)

第三步:将梯度反向传播到每个参数,代码:  loss.backward()

第四步:将参数的grad值初始化为0,代码: optimizer.zero_grad()

第五步:更新权重,代码: optimizer.step()

验证过程和训练过程基本相似。

 

# 定义训练过程def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 50 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))# 验证过程
def val(model, device, test_loader):model.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))# 训练for epoch in range(1, EPOCHS + 1):adjust_learning_rate(optimizer, epoch)train(model_ft, DEVICE, train_loader, optimizer, epoch)val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')

完成后就可以run了,运行结果如下: 


测试

测试集存放的目录如下图:

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!我们在训练时,Bald类别是0,NoBald类别是1,所以我定义classes为('Bald','NoBald')。

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

第三步 加载model,并将模型放在DEVICE里,

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import osclasses=('Bald','NoBald')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)path='Dataset/Test/Bald/'
testList=os.listdir(path)
for file in testList:img=Image.open(path+file)img=transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out=model(img)# Predict_, pred = torch.max(out.data, 1)print('Image Name:{},predict:{}'.format(file,classes[pred.data.item()]))

运行结果如下:

第二种方法可以使用pytorch默认加载数据集的方法。

import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variableclasses=('Bald','NoBald')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)dataset_test = datasets.ImageFolder('Dataset/Test',transform_test)
print(len(dataset_test))
# 对应文件夹的labelfor index in range(len(dataset_test)):item = dataset_test[index]img, label = itemimg.unsqueeze_(0)data = Variable(img).to(DEVICE)output = model(data)_, pred = torch.max(output.data, 1)print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))index += 1

运行结果:


完整代码

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transformsimport torchvision.datasets as datasets
from torch.autograd import Variable
from torchvision.models import densenet121# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 32
EPOCHS = 5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 数据预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataset_train = datasets.ImageFolder('Dataset/Train', transform)
dataset_test = datasets.ImageFolder('Dataset/Validation',transform_test)
# 读取数据
print(dataset_train.imgs)# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = densenet121(pretrained=True)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, 2)
model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)def adjust_learning_rate(optimizer, epoch):"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""modellrnew = modellr * (0.1 ** (epoch // 50))print("lr:", modellrnew)for param_group in optimizer.param_groups:param_group['lr'] = modellrnew# 定义训练过程def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 50 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))# 验证过程
def val(model, device, test_loader):model.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))# 训练for epoch in range(1, EPOCHS + 1):adjust_learning_rate(optimizer, epoch)train(model_ft, DEVICE, train_loader, optimizer, epoch)val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')

DenseNet图像分类.zip-深度学习文档类资源-CSDN下载


推荐阅读
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 探索偶数次幂二项式系数的求和方法及其数学意义 ... [详细]
  • 利用Python进行学生学业表现评估与成绩预测分析
    利用Python进行学生学业表现评估与成绩预测分析 ... [详细]
  • 在《Python编程基础》课程中,我们将深入探讨Python中的循环结构。通过详细解析for循环和while循环的语法与应用场景,帮助初学者掌握循环控制语句的核心概念和实际应用技巧。此外,还将介绍如何利用循环结构解决复杂问题,提高编程效率和代码可读性。 ... [详细]
  • 从2019年AI顶级会议最佳论文,探索深度学习的理论根基与前沿进展 ... [详细]
  • 在Ubuntu系统中配置Python环境变量是确保项目顺利运行的关键步骤。本文介绍了如何将Windows上的Django项目迁移到Ubuntu,并解决因虚拟环境导致的模块缺失问题。通过详细的操作指南,帮助读者正确配置虚拟环境,确保所有第三方库都能被正确识别和使用。此外,还提供了一些实用的技巧,如如何检查环境变量配置是否正确,以及如何在多个虚拟环境之间切换。 ... [详细]
  • TensorFlow Lite在移动设备上的部署实践与优化笔记
    近期在探索如何将服务器端的模型迁移到移动设备上,并记录了一些关键问题和解决方案。本文假设读者具备以下基础知识:了解TensorFlow的计算图(Graph)、图定义(GraphDef)和元图定义(MetaGraphDef)。此外,文中还详细介绍了模型转换、性能优化和资源管理等方面的实践经验,为开发者提供有价值的参考。 ... [详细]
  • PyCharm 作为 JetBrains 出品的知名集成开发环境(IDE),提供了丰富的功能和强大的工具支持,包括项目视图、代码结构视图、代码导航、语法高亮、自动补全和错误检测等。本文详细介绍了 PyCharm 的高级使用技巧和程序调试方法,旨在帮助开发者提高编码效率和调试能力。此外,还探讨了如何利用 PyCharm 的插件系统扩展其功能,以满足不同开发场景的需求。 ... [详细]
  • 本文介绍了如何利用Apache POI库高效读取Excel文件中的数据。通过实际测试,除了分数被转换为小数存储外,其他数据均能正确读取。若在使用过程中发现任何问题,请及时留言反馈,以便我们进行更新和改进。 ... [详细]
  • 如何高效启动大数据应用之旅?
    在前一篇文章中,我探讨了大数据的定义及其与数据挖掘的区别。本文将重点介绍如何高效启动大数据应用项目,涵盖关键步骤和最佳实践,帮助读者快速踏上大数据之旅。 ... [详细]
  • Java SE 文件操作类详解与应用
    ### Java SE 文件操作类详解与应用#### 1. File 类##### 1.1 File 类概述File 类是 Java SE 中用于表示文件和目录路径名的对象。它提供了丰富的方法来操作文件和目录,包括创建、删除、重命名文件,以及获取文件属性和信息。通过 File 类,开发者可以轻松地进行文件系统操作,如检查文件是否存在、读取文件内容、列出目录下的文件等。此外,File 类还支持跨平台操作,确保在不同操作系统中的一致性。 ... [详细]
  • 探索聚类分析中的K-Means与DBSCAN算法及其应用
    聚类分析是一种用于解决样本或特征分类问题的统计分析方法,也是数据挖掘领域的重要算法之一。本文主要探讨了K-Means和DBSCAN两种聚类算法的原理及其应用场景。K-Means算法通过迭代优化簇中心来实现数据点的划分,适用于球形分布的数据集;而DBSCAN算法则基于密度进行聚类,能够有效识别任意形状的簇,并且对噪声数据具有较好的鲁棒性。通过对这两种算法的对比分析,本文旨在为实际应用中选择合适的聚类方法提供参考。 ... [详细]
  • 在Java编程中,若需实现两个整数(例如2和3)相除并保留两位小数的结果,可以通过精确计算方法来达到预期效果。具体而言,可以利用BigDecimal类进行高精度运算,确保2除以3的结果准确显示为0.66。此外,还可以通过格式化输出来控制小数位数,确保最终结果符合要求。 ... [详细]
  • 本文详细解析了 MySQL 5.7.20 版本中二进制日志(binlog)崩溃恢复机制的工作流程。假设使用 InnoDB 存储引擎,并且启用了 `sync_binlog=1` 配置,文章深入探讨了在系统崩溃后如何通过 binlog 进行数据恢复,确保数据的一致性和完整性。 ... [详细]
  • 探讨 `org.openide.windows.TopComponent.componentOpened()` 方法的应用及其代码实例分析 ... [详细]
author-avatar
闲看静观_925
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有