热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【Pytorch深度学习50篇】·······第二篇:【人脸识别】(5)

【Pytorch深度学习50篇】·······第二篇:【人脸识别】(5)-hello啊朋友们,时隔几日我又回来了,脱更了,因为我去驻厂了,驻厂的意思就是去厂里写代码,在没有网络的环

hello啊朋友们,时隔几日我又回来了,脱更了,因为我去驻厂了,驻厂的意思就是去厂里写代码,在没有网络的环境下,对我来说挑战也不小,对我任何程序员来说没网的话,ctrl+c和ctrl+v这一必杀技就没法用了,所以难顶啊。

3.训练篇

为什么直接就是3了,因为前面已经讲了1和2,不懂就去看,骗流量,哈哈哈。

闲话不多说,开始上训练代码,前面数据准备和网络搭建都已经完成了,现在就要开始训练了

import torch
import torch.nn as nn
import dataset
import my_net as nets
import os


if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_image_folder = r'D:\DATAS\manhua\manhua_tou'
    pre_trian_flag = False
    model_folder = r'D:\DATAS\manhua\models'
    lr = 0.001
    epoches = 100
    batch_size = 16

    # 数据准备
    train_data = dataset.dataset(train_image_folder)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    # 模型初始化
    if pre_trian_flag == True:
        model_path = os.path.join(model_folder, 'best.pth')
        if os.path.exists(model_path):
            net = torch.load(model_path, map_location=device)
            net.train()
            print('加载预训练模型成功')
        else:
            print('未找到模型或预训练模型,开始重新训练')
    else:
        net = nets.My_Net().to(device)  # 在net.py中自定义一个网络
        net.train()

    # 定义损失函数
    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    loss_proess = 1
    print('开始训练')
    for epoch in range(epoches):
        train_loss = 0
        for index, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device)
            output = net(image)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print('epoch[%s/%s]---iteration[%s/%s]---------------loss=%s'%(epoch+1,epoches,index+1,len(train_loader),loss.item()))

        train_loss = train_loss / len(train_loader)
        print('current epoch[%s] total_loss = '%(epoch+1), train_loss)
        if train_loss 

3.1引入的包

还是先说一下导入的包,请看

import torch
import torch.nn as nn
import dataset
import my_net as nets
import os

其中dataset和my_net就是之前提到的两个脚本,其他的就无须多言了

3.2定义的一些路径和超参数

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_image_folder = r'D:\DATAS\manhua\manhua_tou'
    pre_trian_flag = False
    model_folder = r'D:\DATAS\manhua\models'
    lr = 0.001
    epoches = 100
    batch_size = 16

device最后要么是‘cuda’要么是‘cpu’这个就看你电脑的配置了,有没有独立的显卡,一般来说,你搞深度学习没有显卡确实是有点太不方便了,游戏都没法玩,多累啊。哈哈

train_image_folder是训练图片的路径,在dataset篇的时候也讲过了,数据要怎么放,在来截图给你们看看吧

然后每个文件夹里面就是图片文件了,以gangan为例截图演示一下

pre_train_flag 是用来判定有没有预训练模型的标志

model_folder是用来保存模型的文件夹,运行之前,记得先创建一下这个文件夹,免得报错,其实程序里也可以直接加上自己生成这个文件夹的代码,我没加,就是想你报错了回来找我,心机

好了好了,到了最能体现调参侠经验的地方了,这个三个参数,lr,epoches,batch_size,他们是什么,中文名字叫,学习率,迭代次数,批数量。具体都是什么意思,咱们以后开个专题来讲一讲

3.3数据准备和模型定义

    # 数据准备
    train_data = dataset.dataset(train_image_folder)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    # 模型初始化
    if pre_trian_flag == True:
        model_path = os.path.join(model_folder, 'best.pth')
        if os.path.exists(model_path):
            net = torch.load(model_path, map_location=device)
            net.train()
            print('加载预训练模型成功')
        else:
            print('未找到模型或预训练模型,开始重新训练')
    else:
        net = nets.My_Net().to(device)  # 在net.py中自定义一个网络
        net.train()

是不是似曾相识啊,因为我们前面以及写过这个代码了,对不,那我们就不详细说明了

3.4优化器和损失函数的定义

    # 定义损失函数
    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

一般分类的损失函数都会用到crossentropy,只是它是什么,怎么计算的,这里面涉及到了一些数学姿势,要慢慢来说,这里就先记住就行了

优化器呢一般也就是选择Adam,至于为什么,我到现在也还是半蒙状态,所以先不讲。一般盲选Adam就没错了。

3.5训练

    loss_proess = 1
    print('开始训练')
    for epoch in range(epoches):
        train_loss = 0
        for index, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device)
            output = net(image)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print('epoch[%s/%s]---iteration[%s/%s]---------------loss=%s'%(epoch+1,epoches,index+1,len(train_loader),loss.item()))

        train_loss = train_loss / len(train_loader)
        print('current epoch[%s] total_loss = '%(epoch+1), train_loss)
        if train_loss 

过程就是,先从train_loader里面读数据,然后将数据都放到gpu上,然后数据送入网络,得到的输出和label做损失,得到loss,然后loss方向传播,用于调整网络里的参数,然后再下一次循环,不断的调整参数,使得loss越来越小,也就是说,输入的数据得到的输出结果就会月接近label,这就是我们要的效果。一切计算的过程都交给了计算机,你只需要等待就好了,nice!!我们看看跑起来是什么效果吧

你看total_loss是不是下降了很多了,这就是训练的魔力。

3.6 test的程序

上代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import PIL.Image as pimg

if __name__ == '__main__':
    classes = ['pangpang','shoushou','gangan','xixi','haha']   #设置成你的类别名称
    model_path = os.path.join(r'D:\DATAS\manhua\models','best.pth')
    test_image_folder = r'D:\DATAS\manhua\test_img'
    image_size = 96
    save_folder = r'D:\DATAS\manhua\output_img'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 测试数据准备
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])


    net = torch.load(model_path).to(device)
    net.eval()

    for i in os.listdir(test_image_folder):
        image_path = os.path.join(test_image_folder,i)
        img_ = pimg.open(image_path)

        img = img_.resize((image_size,image_size))
        img = test_transform(img)
        try:
            img = img.view(1, 3, image_size, image_size).to(device)
        except:
            img = img.view(1, 1, image_size, image_size).to(device)

        pre_prob = net(img)
        pre_class = pre_prob.argmax(1).view(-1)
        print(classes[pre_class.item()])
        save_path = os.path.join(save_folder,classes[pre_class.item()])
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        img_.save(save_path + '/' + i)

 试一试能不能看懂test吧,如果看懂了,就学会了

3.7 整个项目的代码链接

链接:https://pan.baidu.com/s/1bZbDQ1f4bfumjkK3RUUDag 
提取码:wrk4 
截图以示清白

整个深度学习分类任务的代码就到自己了,你看看用你的数据集来试一试效果

另外要谢谢大家的支持,已经400个粉丝了,撒花

至此,敬礼,salute!!!


推荐阅读
  • 视觉Transformer综述
    本文综述了视觉Transformer在计算机视觉领域的应用,从原始Transformer出发,详细介绍了其在图像分类、目标检测和图像分割等任务中的最新进展。文章不仅涵盖了基础的Transformer架构,还深入探讨了各类增强版Transformer模型的设计思路和技术细节。 ... [详细]
  • 本文探讨了如何在PHP与MySQL环境中实现高效的分页查询,包括基本的分页实现、性能优化技巧以及高级的分页策略。 ... [详细]
  • 入门指南:使用FastRPC技术连接Qualcomm Hexagon DSP
    本文旨在为初学者提供关于如何使用FastRPC技术连接Qualcomm Hexagon DSP的基础知识。FastRPC技术允许开发者在本地客户端实现远程调用,从而简化Hexagon DSP的开发和调试过程。 ... [详细]
  • 本文将深入探讨 Unreal Engine 4 (UE4) 中的距离场技术,包括其原理、实现细节以及在渲染中的应用。距离场技术在现代游戏引擎中用于提高光照和阴影的效果,尤其是在处理复杂几何形状时。文章将结合具体代码示例,帮助读者更好地理解和应用这一技术。 ... [详细]
  • 雨林木风 GHOST XP SP3 经典珍藏版 YN2014.04
    雨林木风 GHOST XP SP3 经典珍藏版 YN2014.04 ... [详细]
  • Exploring issues and solutions when defining multiple Faust agents programmatically. ... [详细]
  • Kubernetes Services详解
    本文深入探讨了Kubernetes中的服务(Services)概念,解释了如何通过Services实现Pods之间的稳定通信,以及如何管理没有选择器的服务。 ... [详细]
  • Excel技巧:单元格中显示公式而非结果的解决方法
    本文探讨了在Excel中如何通过简单的方法解决单元格显示公式而非计算结果的问题,包括使用快捷键和调整单元格格式两种方法。 ... [详细]
  • 服务器虚拟化存储设计,完美规划储存与资源,部署高性能虚拟化桌面
    规划部署虚拟桌面环境前,必须先估算目前所使用实体桌面环境的工作负载与IOPS性能,并慎选储存设备。唯有谨慎估算贴近实际的IOPS性能,才能 ... [详细]
  • 函子(Functor)是函数式编程中的一个重要概念,它不仅是一个特殊的容器,还提供了一种优雅的方式来处理值和函数。本文将详细介绍函子的基本概念及其在函数式编程中的应用,包括如何通过函子控制副作用、处理异常以及进行异步操作。 ... [详细]
  • 如何将955万数据表的17秒SQL查询优化至300毫秒
    本文详细介绍了通过优化SQL查询策略,成功将一张包含955万条记录的财务流水表的查询时间从17秒缩短至300毫秒的方法。文章不仅提供了具体的SQL优化技巧,还深入探讨了背后的数据库原理。 ... [详细]
  • 流处理中的计数挑战与解决方案
    本文探讨了在流处理中进行计数的各种技术和挑战,并基于作者在2016年圣何塞举行的Hadoop World大会上的演讲进行了深入分析。文章不仅介绍了传统批处理和Lambda架构的局限性,还详细探讨了流处理架构的优势及其在现代大数据应用中的重要作用。 ... [详细]
  • 本文详细探讨了在Web开发中常见的UTF-8编码问题及其解决方案,包括HTML页面、PHP脚本、MySQL数据库以及JavaScript和Flash应用中的乱码问题。 ... [详细]
  • PHP面试题精选及答案解析
    本文精选了新浪PHP笔试题及最新的PHP面试题,并提供了详细的答案解析,帮助求职者更好地准备PHP相关的面试。 ... [详细]
  • 深入解析WebP图片格式及其应用
    随着互联网技术的发展,无论是PC端还是移动端,图片数据流量占据了很大比重。尤其在高分辨率屏幕普及的背景下,如何在保证图片质量的同时减少文件大小,成为了亟待解决的问题。本文将详细介绍Google推出的WebP图片格式,探讨其在实际项目中的应用及优化策略。 ... [详细]
author-avatar
老男孩2702938107
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有