热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Fer2013表情识别pytorch(CNN、VGG、Resnet)

#fer2013数据集##数据集介绍*Fer2013人脸表情数据集由35886张人脸表情图片组成,其中,测试图(Training)28708张,公共验证图(PublicT

fer2013数据集

数据集介绍

  • Fer2013人脸表情数据集由35886张人脸表情图片组成,其中,测试图(Training)28708张,公共验证图(PublicTest)和私有验证图(PrivateTest)各3589张,每张图片是由大小固定为48×48的灰度图像组成,共有7种表情,分别对应于数字标签0-6,具体表情对应的标签和中英文如下:0 anger 生气; 1 disgust 厌恶; 2 fear 恐惧; 3 happy 开心; 4 sad 伤心;5 surprised 惊讶; 6 normal 中性。

数据整理

  • 数据给的是一个csv文件,其中的表情数据并没有直接给图片,而是给了像素值,没关系,整理的时候顺便转换成图片就好
  • 将数据分类顺便转换成图片,这里直接分成训练集和验证集两个文件夹。
    import numpy as np
    import pandas as pd
    from PIL import Image
    import os
    
    train_path = \'./data/train/\'
    vaild_path = \'./data/vaild/\'
    data_path = \'./icml_face_data.csv\'
    
    def make_dir():
        for i in range(0,7):
            p1 = os.path.join(train_path,str(i))
            p2 = os.path.join(vaild_path,str(i))
            if not os.path.exists(p1):
                os.makedirs(p1)
            if not os.path.exists(p2):
                os.makedirs(p2)       
    
    def save_images():
        df = pd.read_csv(data_path)
        t_i = [1 for i in range(0,7)]
        v_i = [1 for i in range(0,7)]
        for index in range(len(df)):
            emotion = df.loc[index][0]
            usage = df.loc[index][1] 
            image = df.loc[index][2]
            data_array = list(map(float, image.split()))
            data_array = np.asarray(data_array)
            image = data_array.reshape(48, 48)
            im = Image.fromarray(image).convert(\'L\')#8位黑白图片
            if(usage==\'Training\'):
                t_p = os.path.join(train_path,str(emotion),\'{}.jpg\'.format(t_i[emotion]))
                im.save(t_p)
                t_i[emotion] += 1
                #print(t_p)
            else:
                v_p = os.path.join(vaild_path,str(emotion),\'{}.jpg\'.format(v_i[emotion]))
                im.save(v_p)
                v_i[emotion] += 1
                #print(v_p)
    
    make_dir()
    save_images()
    

简单分析

  • 整理好后看一下数据的分布情况,我们可以看到厌恶表情的数据特别少,其他表情尚可。

数据预处理

  • 我们可以对这些灰度图片做一点数据增强
    path_train = \'./data/train/\'
    path_vaild = \'./data/vaild/\'
    
    transforms_train = transforms.Compose([
        transforms.Grayscale(),#使用ImageFolder默认扩展为三通道,重新变回去就行
        transforms.RandomHorizontalFlip(),#随机翻转
        transforms.ColorJitter(brightness=0.5, cOntrast=0.5),#随机调整亮度和对比度
        transforms.ToTensor()
    ])
    transforms_vaild = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor()
    ])
    
    data_train = torchvision.datasets.ImageFolder(root=path_train,transform=transforms_train)
    data_vaild = torchvision.datasets.ImageFolder(root=path_vaild,transform=transforms_vaild)
    
    train_set = torch.utils.data.DataLoader(dataset=data_train,batch_size=BATCH_SIZE,shuffle=True)
    vaild_set = torch.utils.data.DataLoader(dataset=data_vaild,batch_size=BATCH_SIZE,shuffle=False)
    
  • 看一下效果
    for i in range(1,16+1):
        plt.subplot(4,4,i)
        plt.imshow(data_train[0][0],cmap=\'Greys_r\')
        plt.axis(\'off\')
    plt.show()
    

CNN

模型搭建

  • 使用nn.Sequential快速搭建模型
    CNN = nn.Sequential(
        nn.Conv2d(1,64,3),
        nn.ReLU(True),
        nn.MaxPool2d(2,2),
        nn.Conv2d(64,256,3),
        nn.ReLU(True),
        nn.MaxPool2d(3,3),
        Reshape(),# 两个卷积和池化后,tensor形状为(batchsize,256,7,7)
        nn.Linear(256*7*7,4096),
        nn.ReLU(True),
        nn.Linear(4096,1024),
        nn.ReLU(True),
        nn.Linear(1024,7)
        )
    
  • 其中自己实现Reshape,将tensor打平以送入全连接层
    class Reshape(nn.Module):
        def __init__(self, *args):
            super(Reshape, self).__init__()
    
        def forward(self, x):
            return x.view(x.shape[0],-1)
    

训练效果

  • 显然,在第17个epoch的时候验证集准确率就到了瓶颈

VGG

模型搭建

  • def vgg_block(num_convs, in_channels, out_channels):
        blk = []
        for i in range(num_convs):
            if i == 0:
                blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            else:
                blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
            blk.append(nn.ReLU())
        blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 这里会使宽高减半
        return nn.Sequential(*blk)
    
    def vgg(conv_arch, fc_features, fc_hidden_units):
        net = nn.Sequential()
        # 卷积层部分
        for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
            # 每经过一个vgg_block都会使宽高减半
            net.add_module("vgg_block_" + str(i+1), vgg_block(num_convs, in_channels, out_channels))
        # 全连接层部分
        net.add_module("fc", nn.Sequential(
                                    Reshape(),
                                    nn.Linear(fc_features, fc_hidden_units),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Linear(fc_hidden_units, fc_hidden_units),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Linear(fc_hidden_units, 7)
                                    ))
        return net
    
    conv_arch = ((1, 3, 32), (1, 32, 64), (2, 64, 128))
    # 经过5个vgg_block, 宽高会减半5次, 变成 224/32 = 7
    fc_features = 128 * 6* 6 # c * w * h
    fc_hidden_units = 1024 
    
    model = vgg(conv_arch, fc_features, fc_hidden_units)
    

训练效果

  • 先训练了30个epoch
  • vgg的优点在于能使用相同的模块快速加深网络,更深的网络可能会带来更好的学习效果,我们可以增加训练次数来观察曲线

Resnet

模型搭建

  • class Residual(nn.Module): 
        def __init__(self, in_channels, out_channels, use_1x1cOnv=False, stride=1):
            super(Residual, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            if use_1x1conv:
                self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            else:
                self.conv3 = None
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
    
        def forward(self, X):
            Y = F.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3:
                X = self.conv3(X)
            return F.relu(Y + X)
    
        
    def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
        if first_block:
            assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(in_channels, out_channels, use_1x1cOnv=True, stride=2))
            else:
                blk.append(Residual(out_channels, out_channels))
        return nn.Sequential(*blk)
    
    class GlobalAvgPool2d(nn.Module):
        # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
        def __init__(self):
            super(GlobalAvgPool2d, self).__init__()
        def forward(self, x):
            return F.avg_pool2d(x, kernel_size=x.size()[2:])
    
    net = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=7 , stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    net.add_module("fc", nn.Sequential(Reshape(), nn.Linear(512, 7))) 
    

训练效果

  • 让我们看看残差块的设计给我们带来……

  • 带来了更好的过拟合效果(逃

总结

  • 事已至此,我们浏览一下混淆矩阵

    • 0-angry
    • 1-disgust
    • 2-fear
    • 3-happy
    • 4-sad
    • 5-surprised
    • 6-neutral

  • 貌似除了开心和惊喜,其他表情准确率都挺一言难尽的,可能这两个比较好认,笑了就是开心,O型嘴就是惊喜,其他表情别说机器,人都不一定认得出


推荐阅读
author-avatar
mobiledu2502861465
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有