热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch实现SiameseNetwork进行人脸识别

继上次人脸识别之后,这次我们来看下。pytorch实现SiameseNetwork进行人脸识别importtorchimportosimportrandom

继上次人脸识别之后,这次我们来看下。

// pytorch实现SiameseNetwork进行人脸识别
import torch
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimroot='data/ORL'
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')class Config():root='data/ORL'txt_root='data/train.txt'train_batch_size=32train_number_epochs=56def show_plot(iteration,loss):plt.plot(iteration,loss)plt.show()def convert(train=True):if(train):try:f=open(Config.txt_root,'w')except:print('error')data_path=root+'/'if not os.path.exists(data_path):os.makedirs(data_path)names=[name for name in os.listdir(data_path)]for name in os.listdir(data_path):name_path=os.path.join(root,name)for img in os.listdir(name_path):img_path=os.path.join(name_path,img)f.write(img_path+' '+str(names.index(name))+'\n')f.close()class MyDataset(Dataset):def __init__(self,txt,transform=None,should_invert=False):self.transform=transformself.should_invert=should_invertself.txt=txtdef __getitem__(self, index):line=linecache.getline(self.txt,random.randint(1,self.__len__()))line.strip('\n')img0_list=line.split()#若为0,取得不同人的图片shouled_get_same_class=random.randint(0,1)if shouled_get_same_class:while True:img1_list=linecache.getline(self.txt,random.randint(1,self.__len__())).strip('\n').split()if img0_list[1]==img1_list[1]:breakelse:while True:img1_list=linecache.getline(self.txt,random.randint(1,self.__len__())).strip('\n').split()if img0_list[1]!=img1_list[1]:breakim0=Image.open(img0_list[0]).convert('L')im1=Image.open(img1_list[0]).convert('L')if self.transform is not None:im0=self.transform(im0)im1=self.transform(im1)return im0,im1,torch.from_numpy(np.array([int(img0_list[1]!=img1_list[1])],dtype=np.float32))def __len__(self):fh=open(self.txt,'r')num=len(fh.readlines())fh.close()return numclass SiameseNetwork(nn.Module):def __init__(self):super(SiameseNetwork,self).__init__()self.cnn1=nn.Sequential(nn.Conv2d(1,4,kernel_size=5),nn.BatchNorm2d(4),nn.ReLU(inplace=True),nn.Conv2d(4, 8, kernel_size=5),nn.BatchNorm2d(8),nn.ReLU(inplace=True),nn.Conv2d(8, 8, kernel_size=3),nn.BatchNorm2d(8),nn.ReLU(inplace=True),)self.fc1=nn.Sequential(nn.Linear(8 * 90 * 90,500),nn.ReLU(inplace=True),nn.Linear(500,500),nn.ReLU(inplace=True),nn.Linear(500,40))def forward_once(self, x):optput=self.cnn1(x)optput=optput.view(optput.size()[0],-1)optput=self.fc1(optput)return optputdef forward(self, input1,input2):output1=self.forward_once(input1)output2=self.forward_once(input2)return output1,output2class ContrastiveLoss(torch.nn.Module):def __init__(self,margin=2.0):super(ContrastiveLoss,self).__init__()self.margin=margindef forward(self,output1,output2,label):# 方法一label=label.view(label.size()[0],)euclidean_distance=F.pairwise_distance(output1,output2)loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))##方法二# euclidean_distance=F.pairwise_distance(output1,output2,keepdim=True)# loss_contrastive=torch.mean((1-label) * torch.pow(euclidean_distance,2)+# (label) * torch.pow(torch.clamp(self.margin-euclidean_distance,min=0.0),2))return loss_contrastiveif __name__ == '__main__':convert(True)train_data=MyDataset(txt=Config.txt_root,transform=transforms.Compose([transforms.Resize((100,100)),transforms.ToTensor()]),should_invert=False)train_loader=DataLoader(dataset=train_data,shuffle=True,batch_size=Config.train_batch_size)net=SiameseNetwork().to(device)criterion=ContrastiveLoss()optimizer=optim.Adam(net.parameters(),lr=0.001)counter=[]loss_history=[]iteration_number=0for epoch in range(0,Config.train_number_epochs):for i,data in enumerate(train_loader,0):img0,img1,label=dataimg0,img1,label=img0.to(device),img1.to(device),label.to(device)optimizer.zero_grad()output1,output2=net(img0,img1)loss_contrastive=criterion(output1,output2,label)loss_contrastive.backward()optimizer.step()if i %10 ==0:print('epoch:{},loss:{}\n'.format(epoch,loss_contrastive.item()))iteration_number+=10counter.append(iteration_number)loss_history.append(loss_contrastive.item())torch.save(net,'data/model/8model')show_plot(counter,loss_history)"""
测试
"""import torch
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as plt
import torch.nn.functional as Fdevice=torch.device('cuda' if torch.cuda.is_available() else 'cpu')transform=transforms.Compose([transforms.Resize((100,100)),transforms.ToTensor()
])class SiameseNetwork(nn.Module):def __init__(self):super(SiameseNetwork,self).__init__()self.cnn1=nn.Sequential(nn.Conv2d(1,4,kernel_size=5),nn.BatchNorm2d(4),nn.ReLU(inplace=True),nn.Conv2d(4, 8, kernel_size=5),nn.BatchNorm2d(8),nn.ReLU(inplace=True),nn.Conv2d(8, 8, kernel_size=3),nn.BatchNorm2d(8),nn.ReLU(inplace=True),)self.fc1=nn.Sequential(nn.Linear(8 * 90 * 90,500),nn.ReLU(inplace=True),nn.Linear(500,500),nn.ReLU(inplace=True),nn.Linear(500,40))def forward_once(self, x):optput=self.cnn1(x)optput=optput.view(optput.size()[0],-1)optput=self.fc1(optput)return optputdef forward(self, input1,input2):output1=self.forward_once(input1)output2=self.forward_once(input2)return output1,output2model=torch.load('data/model/8model')
model.eval()img1=Image.open('data/ORL/s2/s2_0002.png').convert('L')
img2=Image.open('data/ORL/s2/s2_0008.png').convert('L')img1=transform(img1)
img2=transform(img2)imgs1=np.array(img1)[0,...]
imgs2=np.array(img2)[0,...]input1=img1.unsqueeze(0).to(device)
input2=img2.unsqueeze(0).to(device)output1,output2=model(input1,input2)
en_dis=F.pairwise_distance(output1,output2)print('endis_:',en_dis)
diff=en_dis.cpu().detach().numpy()[0]
print(en_dis.cpu().detach().numpy()[0])plt.subplot(1,2,1)
plt.title('diff='+str(diff))
plt.imshow(imgs1,cmap='gray')
plt.subplot(1,2,2)
plt.imshow(imgs2,cmap='gray')
plt.show()

两个相同人比较结果
在这里插入图片描述
loss值
在这里插入图片描述


推荐阅读
  • PHP图片截取方法及应用实例
    本文介绍了使用PHP动态切割JPEG图片的方法,并提供了应用实例,包括截取视频图、提取文章内容中的图片地址、裁切图片等问题。详细介绍了相关的PHP函数和参数的使用,以及图片切割的具体步骤。同时,还提供了一些注意事项和优化建议。通过本文的学习,读者可以掌握PHP图片截取的技巧,实现自己的需求。 ... [详细]
  • Html5-Canvas实现简易的抽奖转盘效果
    本文介绍了如何使用Html5和Canvas标签来实现简易的抽奖转盘效果,同时使用了jQueryRotate.js旋转插件。文章中给出了主要的html和css代码,并展示了实现的基本效果。 ... [详细]
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • 本文详细介绍了Linux中进程控制块PCBtask_struct结构体的结构和作用,包括进程状态、进程号、待处理信号、进程地址空间、调度标志、锁深度、基本时间片、调度策略以及内存管理信息等方面的内容。阅读本文可以更加深入地了解Linux进程管理的原理和机制。 ... [详细]
  • 摘要: 在测试数据中,生成中文姓名是一个常见的需求。本文介绍了使用C#编写的随机生成中文姓名的方法,并分享了相关代码。作者欢迎读者提出意见和建议。 ... [详细]
  • 本文讨论了在iOS平台中的Metal框架中,对于if语句中的判断条件的限制和处理方式。作者提到了在Metal shader中,判断条件不能写得太长太复杂,否则可能导致程序停留或没有响应。作者还分享了自己的经验,建议在CPU端进行处理,以避免出现问题。 ... [详细]
  • 这篇文章主要介绍了Python拼接字符串的七种方式,包括使用%、format()、join()、f-string等方法。每种方法都有其特点和限制,通过本文的介绍可以帮助读者更好地理解和运用字符串拼接的技巧。 ... [详细]
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 本文介绍了一个编程问题,要求求解一个给定n阶方阵的鞍点个数。通过输入格式的描述,可以了解到输入的是一个n阶方阵,每个元素都是整数。通过输出格式的描述,可以了解到输出的是鞍点的个数。通过题目集全集传送门,可以了解到提供了两个函数is_line_max和is_rank_min,用于判断一个元素是否为鞍点。本文还提供了三个样例,分别展示了不同情况下的输入和输出。 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • 深入理解Java虚拟机的并发编程与性能优化
    本文主要介绍了Java内存模型与线程的相关概念,探讨了并发编程在服务端应用中的重要性。同时,介绍了Java语言和虚拟机提供的工具,帮助开发人员处理并发方面的问题,提高程序的并发能力和性能优化。文章指出,充分利用计算机处理器的能力和协调线程之间的并发操作是提高服务端程序性能的关键。 ... [详细]
  • 使用Spring AOP实现切面编程的步骤和注意事项
    本文介绍了使用Spring AOP实现切面编程的步骤和注意事项。首先解释了@EnableAspectJAutoProxy、@Aspect、@Pointcut等注解的作用,并介绍了实现AOP功能的方法。然后详细介绍了创建切面、编写测试代码的过程,并展示了测试结果。接着讲解了关于环绕通知的使用方法,并修改了FirstTangent类以添加环绕通知方法。最后介绍了利用AOP拦截注解的方法,只需修改全局切入点即可实现。使用Spring AOP进行切面编程可以方便地实现对代码的增强和拦截。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
author-avatar
曾经沧海难为水95531837155423
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有