继上次人脸识别之后,这次我们来看下如何用孪生网络(Siamese Network)和对比损失做人脸验证。
import torch
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Root directory of the ORL face dataset: one sub-directory per person,
# each containing that person's face images.
root = 'data/ORL'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Config():
    """Paths and hyper-parameters for training."""
    root = 'data/ORL'
    txt_root = 'data/train.txt'   # index file: one "<image path> <label>" per line
    train_batch_size = 32
    train_number_epochs = 56


def show_plot(iteration, loss):
    """Plot the recorded loss values against the iteration counter."""
    plt.plot(iteration, loss)
    plt.show()


def convert(train=True):
    """(Re)write the index file listing every image under ``root``.

    Each line of ``Config.txt_root`` is "<image path> <label>" where the
    label is the index of the person's sub-directory.  Directories are
    sorted so that label assignment is deterministic across runs.

    Args:
        train: when falsy, the function is a no-op.
    """
    if not train:
        return
    data_path = root + '/'
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    names = sorted(os.listdir(data_path))
    # ``with`` guarantees the file is closed even if a listdir below raises;
    # the original bare try/except left ``f`` unbound on open() failure.
    with open(Config.txt_root, 'w') as f:
        for label, name in enumerate(names):
            name_path = os.path.join(root, name)
            for img in os.listdir(name_path):
                img_path = os.path.join(name_path, img)
                f.write(img_path + ' ' + str(label) + '\n')


class MyDataset(Dataset):
    """Random Siamese-pair dataset backed by the index file.

    ``__getitem__`` draws a random anchor image and, with probability 0.5,
    a second image of the same person, otherwise one of a different
    person.  Returns ``(img0, img1, label)`` where label is 1.0 for a
    *different* pair and 0.0 for a *same* pair — the convention
    ``ContrastiveLoss`` below expects.
    """

    def __init__(self, txt, transform=None, should_invert=False):
        self.transform = transform
        # Kept for interface compatibility with callers; currently unused.
        self.should_invert = should_invert
        self.txt = txt

    def _random_line(self):
        """Return a random "[path, label]" pair from the index file."""
        # linecache line numbers are 1-based.
        return linecache.getline(self.txt, random.randint(1, len(self))).strip().split()

    def __getitem__(self, index):
        # NOTE: ``index`` is deliberately ignored — pairs are sampled randomly.
        img0_list = self._random_line()
        # 1 -> second image of the same person, 0 -> a different person.
        should_get_same_class = random.randint(0, 1)
        while True:
            img1_list = self._random_line()
            if (img0_list[1] == img1_list[1]) == bool(should_get_same_class):
                break
        im0 = Image.open(img0_list[0]).convert('L')  # grayscale
        im1 = Image.open(img1_list[0]).convert('L')
        if self.transform is not None:
            im0 = self.transform(im0)
            im1 = self.transform(im1)
        # 0.0 = same person, 1.0 = different person.
        label = np.array([int(img0_list[1] != img1_list[1])], dtype=np.float32)
        return im0, im1, torch.from_numpy(label)

    def __len__(self):
        # Number of lines in the index file, cached after the first read so
        # every __getitem__ call does not reopen the file.
        if not hasattr(self, '_length'):
            with open(self.txt, 'r') as fh:
                self._length = len(fh.readlines())
        return self._length


class SiameseNetwork(nn.Module):
    """Two-branch CNN with shared weights producing 40-d embeddings.

    For 1x100x100 inputs the three valid convolutions (5, 5, 3) leave an
    8x90x90 feature map, hence the ``8 * 90 * 90`` input of the first
    linear layer.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 90 * 90, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 40),
        )

    def forward_once(self, x):
        """Embed one batch of images."""
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Embed both inputs with the same (shared) weights."""
        return self.forward_once(input1), self.forward_once(input2)


class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss (Hadsell, Chopra & LeCun, 2006).

    label == 0 (same pair):      d^2 term pulls embeddings together.
    label == 1 (different pair): clamp(margin - d, 0)^2 pushes them at
                                 least ``margin`` apart.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Flatten (N, 1) labels to (N,) so they broadcast against distances.
        label = label.view(label.size()[0], )
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive


if __name__ == '__main__':
    convert(True)
    train_data = MyDataset(
        txt=Config.txt_root,
        transform=transforms.Compose([transforms.Resize((100, 100)),
                                      transforms.ToTensor()]),
        should_invert=False)
    train_loader = DataLoader(dataset=train_data, shuffle=True,
                              batch_size=Config.train_batch_size)

    net = SiameseNetwork().to(device)
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    counter = []
    loss_history = []
    iteration_number = 0
    for epoch in range(Config.train_number_epochs):
        for i, data in enumerate(train_loader):
            img0, img1, label = data
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            if i % 10 == 0:
                print('epoch:{},loss:{}\n'.format(epoch, loss_contrastive.item()))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())
    # Make sure the output directory exists before saving the whole model.
    os.makedirs('data/model', exist_ok=True)
    torch.save(net, 'data/model/8model')
    show_plot(counter, loss_history)

# ===================== Test / evaluation =====================
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as plt
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SiameseNetwork(nn.Module):
    """Two-branch CNN with shared weights producing 40-d embeddings.

    Must match the class the model was trained with, because
    ``torch.load`` below unpickles a full model object and needs this
    class definition to reconstruct it.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 90 * 90, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 40),
        )

    def forward_once(self, x):
        """Embed one batch of 1x100x100 images."""
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Embed both inputs with the same (shared) weights."""
        return self.forward_once(input1), self.forward_once(input2)


def main():
    """Compare two face images with the trained model and display them
    side by side, titled with their embedding distance (small distance
    means "same person")."""
    # Same preprocessing as used during training.
    transform = transforms.Compose([
        transforms.Resize((100, 100)),
        transforms.ToTensor()
    ])
    # map_location lets a CUDA-trained checkpoint load on a CPU-only box.
    model = torch.load('data/model/8model', map_location=device)
    model.eval()

    img1 = Image.open('data/ORL/s2/s2_0002.png').convert('L')
    img2 = Image.open('data/ORL/s2/s2_0008.png').convert('L')
    img1 = transform(img1)
    img2 = transform(img2)
    # Keep the raw single-channel arrays around for display.
    imgs1 = np.array(img1)[0, ...]
    imgs2 = np.array(img2)[0, ...]

    input1 = img1.unsqueeze(0).to(device)
    input2 = img2.unsqueeze(0).to(device)
    with torch.no_grad():  # inference only — no autograd graph needed
        output1, output2 = model(input1, input2)
    en_dis = F.pairwise_distance(output1, output2)
    print('endis_:', en_dis)
    diff = en_dis.cpu().detach().numpy()[0]
    print(en_dis.cpu().detach().numpy()[0])

    plt.subplot(1, 2, 1)
    plt.title('diff=' + str(diff))
    plt.imshow(imgs1, cmap='gray')
    plt.subplot(1, 2, 2)
    plt.imshow(imgs2, cmap='gray')
    plt.show()


if __name__ == '__main__':
    main()
同一个人的两张图片的比较结果
![在这里插入图片描述](https://img4.php1.cn/3cdc5/6b7c/696/6998b97191980927.png)
loss值
![在这里插入图片描述](https://img4.php1.cn/3cdc5/6b7c/696/06a2af1b630e4307.png)