热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Fer2013表情识别pytorch(CNN、VGG、Resnet)

#fer2013数据集##数据集介绍*Fer2013人脸表情数据集由35886张人脸表情图片组成,其中,测试图(Training)28708张,公共验证图(PublicT

fer2013数据集

数据集介绍

  • Fer2013人脸表情数据集由35886张人脸表情图片组成,其中,测试图(Training)28708张,公共验证图(PublicTest)和私有验证图(PrivateTest)各3589张,每张图片是由大小固定为48×48的灰度图像组成,共有7种表情,分别对应于数字标签0-6,具体表情对应的标签和中英文如下:0 anger 生气; 1 disgust 厌恶; 2 fear 恐惧; 3 happy 开心; 4 sad 伤心;5 surprised 惊讶; 6 normal 中性。

数据整理

  • 数据给的是一个csv文件,其中的表情数据并没有直接给图片,而是给了像素值,没关系,整理的时候顺便转换成图片就好
  • 将数据分类顺便转换成图片,这里直接分成训练集和验证集两个文件夹。
    import numpy as np
    import pandas as pd
    from PIL import Image
    import os
    
    train_path = \'./data/train/\'
    vaild_path = \'./data/vaild/\'
    data_path = \'./icml_face_data.csv\'
    
    def make_dir():
        for i in range(0,7):
            p1 = os.path.join(train_path,str(i))
            p2 = os.path.join(vaild_path,str(i))
            if not os.path.exists(p1):
                os.makedirs(p1)
            if not os.path.exists(p2):
                os.makedirs(p2)       
    
    def save_images():
        df = pd.read_csv(data_path)
        t_i = [1 for i in range(0,7)]
        v_i = [1 for i in range(0,7)]
        for index in range(len(df)):
            emotion = df.loc[index][0]
            usage = df.loc[index][1] 
            image = df.loc[index][2]
            data_array = list(map(float, image.split()))
            data_array = np.asarray(data_array)
            image = data_array.reshape(48, 48)
            im = Image.fromarray(image).convert(\'L\')#8位黑白图片
            if(usage==\'Training\'):
                t_p = os.path.join(train_path,str(emotion),\'{}.jpg\'.format(t_i[emotion]))
                im.save(t_p)
                t_i[emotion] += 1
                #print(t_p)
            else:
                v_p = os.path.join(vaild_path,str(emotion),\'{}.jpg\'.format(v_i[emotion]))
                im.save(v_p)
                v_i[emotion] += 1
                #print(v_p)
    
    make_dir()
    save_images()
    

简单分析

  • 整理好后看一下数据的分布情况,我们可以看到厌恶表情的数据特别少,其他表情尚可。

数据预处理

  • 我们可以对这些灰度图片做一点数据增强
    path_train = \'./data/train/\'
    path_vaild = \'./data/vaild/\'
    
    transforms_train = transforms.Compose([
        transforms.Grayscale(),#使用ImageFolder默认扩展为三通道,重新变回去就行
        transforms.RandomHorizontalFlip(),#随机翻转
        transforms.ColorJitter(brightness=0.5, cOntrast=0.5),#随机调整亮度和对比度
        transforms.ToTensor()
    ])
    transforms_vaild = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor()
    ])
    
    data_train = torchvision.datasets.ImageFolder(root=path_train,transform=transforms_train)
    data_vaild = torchvision.datasets.ImageFolder(root=path_vaild,transform=transforms_vaild)
    
    train_set = torch.utils.data.DataLoader(dataset=data_train,batch_size=BATCH_SIZE,shuffle=True)
    vaild_set = torch.utils.data.DataLoader(dataset=data_vaild,batch_size=BATCH_SIZE,shuffle=False)
    
  • 看一下效果
    for i in range(1,16+1):
        plt.subplot(4,4,i)
        plt.imshow(data_train[0][0],cmap=\'Greys_r\')
        plt.axis(\'off\')
    plt.show()
    

CNN

模型搭建

  • 使用nn.Sequential快速搭建模型
    CNN = nn.Sequential(
        nn.Conv2d(1,64,3),
        nn.ReLU(True),
        nn.MaxPool2d(2,2),
        nn.Conv2d(64,256,3),
        nn.ReLU(True),
        nn.MaxPool2d(3,3),
        Reshape(),# 两个卷积和池化后,tensor形状为(batchsize,256,7,7)
        nn.Linear(256*7*7,4096),
        nn.ReLU(True),
        nn.Linear(4096,1024),
        nn.ReLU(True),
        nn.Linear(1024,7)
        )
    
  • 其中自己实现Reshape,将tensor打平以送入全连接层
    class Reshape(nn.Module):
        def __init__(self, *args):
            super(Reshape, self).__init__()
    
        def forward(self, x):
            return x.view(x.shape[0],-1)
    

训练效果

  • 显然,在第17个epoch的时候验证集准确率就到了瓶颈

VGG

模型搭建

  • def vgg_block(num_convs, in_channels, out_channels):
        blk = []
        for i in range(num_convs):
            if i == 0:
                blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            else:
                blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
            blk.append(nn.ReLU())
        blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 这里会使宽高减半
        return nn.Sequential(*blk)
    
    def vgg(conv_arch, fc_features, fc_hidden_units):
        net = nn.Sequential()
        # 卷积层部分
        for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
            # 每经过一个vgg_block都会使宽高减半
            net.add_module("vgg_block_" + str(i+1), vgg_block(num_convs, in_channels, out_channels))
        # 全连接层部分
        net.add_module("fc", nn.Sequential(
                                    Reshape(),
                                    nn.Linear(fc_features, fc_hidden_units),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Linear(fc_hidden_units, fc_hidden_units),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Linear(fc_hidden_units, 7)
                                    ))
        return net
    
    conv_arch = ((1, 3, 32), (1, 32, 64), (2, 64, 128))
    # 经过5个vgg_block, 宽高会减半5次, 变成 224/32 = 7
    fc_features = 128 * 6* 6 # c * w * h
    fc_hidden_units = 1024 
    
    model = vgg(conv_arch, fc_features, fc_hidden_units)
    

训练效果

  • 先训练了30个epoch
  • vgg的优点在于能使用相同的模块快速加深网络,更深的网络可能会带来更好的学习效果,我们可以增加训练次数来观察曲线

Resnet

模型搭建

  • class Residual(nn.Module): 
        def __init__(self, in_channels, out_channels, use_1x1cOnv=False, stride=1):
            super(Residual, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            if use_1x1conv:
                self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            else:
                self.conv3 = None
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
    
        def forward(self, X):
            Y = F.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3:
                X = self.conv3(X)
            return F.relu(Y + X)
    
        
    def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
        if first_block:
            assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(in_channels, out_channels, use_1x1cOnv=True, stride=2))
            else:
                blk.append(Residual(out_channels, out_channels))
        return nn.Sequential(*blk)
    
    class GlobalAvgPool2d(nn.Module):
        # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
        def __init__(self):
            super(GlobalAvgPool2d, self).__init__()
        def forward(self, x):
            return F.avg_pool2d(x, kernel_size=x.size()[2:])
    
    net = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=7 , stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    net.add_module("fc", nn.Sequential(Reshape(), nn.Linear(512, 7))) 
    

训练效果

  • 让我们看看残差块的设计给我们带来……

  • 带来了更好的过拟合效果(逃

总结

  • 事已至此,我们浏览一下混淆矩阵

    • 0-angry
    • 1-disgust
    • 2-fear
    • 3-happy
    • 4-sad
    • 5-surprised
    • 6-neutral

  • 貌似除了开心和惊喜,其他表情准确率都挺一言难尽的,可能这两个比较好认,笑了就是开心,O型嘴就是惊喜,其他表情别说机器,人都不一定认得出


推荐阅读
  • Day2列表、字典、集合操作详解
    本文详细介绍了列表、字典、集合的操作方法,包括定义列表、访问列表元素、字符串操作、字典操作、集合操作、文件操作、字符编码与转码等内容。内容详实,适合初学者参考。 ... [详细]
  • 本文介绍了如何使用python从列表中删除所有的零,并将结果以列表形式输出,同时提供了示例格式。 ... [详细]
  • OpenMap教程4 – 图层概述
    本文介绍了OpenMap教程4中关于地图图层的内容,包括将ShapeLayer添加到MapBean中的方法,OpenMap支持的图层类型以及使用BufferedLayer创建图像的MapBean。此外,还介绍了Layer背景标志的作用和OMGraphicHandlerLayer的基础层类。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
  • 很多时候在注册一些比较重要的帐号,或者使用一些比较重要的接口的时候,需要使用到随机字符串,为了方便,我们设计这个脚本需要注意 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 本文讨论了Kotlin中扩展函数的一些惯用用法以及其合理性。作者认为在某些情况下,定义扩展函数没有意义,但官方的编码约定支持这种方式。文章还介绍了在类之外定义扩展函数的具体用法,并讨论了避免使用扩展函数的边缘情况。作者提出了对于扩展函数的合理性的质疑,并给出了自己的反驳。最后,文章强调了在编写Kotlin代码时可以自由地使用扩展函数的重要性。 ... [详细]
  • MyBatis多表查询与动态SQL使用
    本文介绍了MyBatis多表查询与动态SQL的使用方法,包括一对一查询和一对多查询。同时还介绍了动态SQL的使用,包括if标签、trim标签、where标签、set标签和foreach标签的用法。文章还提供了相关的配置信息和示例代码。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • 从批量eml文件中提取附件的Python代码实现方法
    本文介绍了使用Python代码从批量eml文件中提取附件的实现方法,包括获取eml附件信息、递归文件夹下所有文件、创建目的文件夹等步骤。通过该方法可以方便地提取eml文件中的附件,并保存到指定的文件夹中。 ... [详细]
author-avatar
mobiledu2502861465
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有