热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】

文章目录GAT代码实现【PyGAT】GraphAttentionLayer【一个图注意力层实现】用上面实现的单层网络测试加入Multi-head机制的GAT对数据集Cora的处理c



文章目录


    • GAT
    • 代码实现【PyGAT】
      • GraphAttentionLayer【一个图注意力层实现】
      • 用上面实现的单层网络测试
      • 加入Multi-head机制的GAT
      • 对数据集Cora的处理
      • csr_matrix()处理稀疏矩阵
      • encode_onehot()对label编号
      • build graph
      • 邻接矩阵构造

    • GAT的推广



GAT

题:Graph Attention Networks
摘要:
提出了图形注意网络(GAT) ,这是一种基于图结构数据的新型神经网络结构,利用掩蔽的自我注意层来解决基于图卷积或其近似的先前方法的缺点。通过叠加层,节点能够参与其邻域的特征,我们能够(隐式地)为邻域中的不同节点指定不同的权重,而不需要任何代价高昂的矩阵操作(如反演) ,或者依赖于预先知道图的结构。通过这种方法,我们同时解决了基于谱的图形神经网络的几个关键问题,并使我们的模型容易地适用于归纳和转导问题。我们的 GAT 模型已经实现或匹配了四个已建立的转导和归纳图基准的最新结果: Cora,Citeseer 和 Pubmed 引用网络数据集,以及protein-protein interaction dataset(其中测试图在训练期间保持不可见)。

Paper with code 网址,可找到对应论文和github源码,原论文使用TensorFlow实现,本篇主要对Pytorch版本的 PyGAT附详细注释帮助理解和测试。

GitHUb: keras版本实现
Pytorch版本实现 PyGAT

在这里插入图片描述
在这里插入图片描述

截图及下文代码注释参考自视频:GAT详解及代码实现

视频中的eij的实现与源码不同,视频中是先拼接两个W,再与a乘;
源码在_prepare_attentional_mechanism_input()函数中先分别与a乘,再拼接

代码实现【PyGAT】

在PyGAT :

  • layers.py中定义Simple GAT layer实现(GraphAttentionLayer)和Sparse version GAT layer实现(SpGraphAttentionLayer)。
  • models.py 实现两个版本加入Multi-head机制
  • trains.py 使用model定义的GAT构建模型进行训练,使用cora数据集

GraphAttentionLayer【一个图注意力层实现】

class GraphAttentionLayer(nn.Module):
"""
Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
"""

def __init__(self, in_features, out_features, dropout, alpha, concat=True):
super(GraphAttentionLayer, self).__init__()
self.dropout = dropout
self.in_features = in_features#结点向量的特征维度
self.out_features = out_features#经过GAT之后的特征维度
self.alpha = alpha#dropout参数
self.concat = concat#LeakyReLU参数
# 定义可训练参数,即论文中的W和a
self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
nn.init.xavier_uniform_(self.W.data, gain=1.414)# xavier初始化
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)# xavier初始化
# 定义leakyReLU激活函数
self.leakyrelu = nn.LeakyReLU(self.alpha)
def forward(self, h, adj):
'''
adj图邻接矩阵,维度[N,N]非零即一
h.shape: (N, in_features), self.W.shape:(in_features,out_features)
Wh.shape: (N, out_features)
'''

Wh = torch.mm(h, self.W) # 对应eij的计算公式
e = self._prepare_attentional_mechanism_input(Wh)#对应LeakyReLU(eij)计算公式
zero_vec = -9e15*torch.ones_like(e)#将没有链接的边设置为负无穷
attention = torch.where(adj > 0, e, zero_vec)#[N,N]
# 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留
# 否则需要mask设置为非常小的值,因为softmax的时候这个最小值会不考虑
attention = F.softmax(attention, dim=1)# softmax形状保持不变[N,N],得到归一化的注意力全忠!
attention = F.dropout(attention, self.dropout, training=self.training)# dropout,防止过拟合
h_prime = torch.matmul(attention, Wh)#[N,N].[N,out_features]=>[N,out_features]
# 得到由周围节点通过注意力权重进行更新后的表示
if self.concat:
return F.elu(h_prime)
else:
return h_prime
def _prepare_attentional_mechanism_input(self, Wh):
# Wh.shape (N, out_feature)
# self.a.shape (2 * out_feature, 1)
# Wh1&2.shape (N, 1)
# e.shape (N, N)
# 先分别与a相乘再进行拼接
Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
# broadcast add
e = Wh1 + Wh2.T
return self.leakyrelu(e)
def __repr__(self):
return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

用上面实现的单层网络测试

x = torch.randn(6,10)
adj=torch.tensor([[0,1,1,0,0,0],
[1,0,1,0,0,0],
[1,1,0,1,0,0],
[0,0,1,0,1,1],
[0,0,0,1,0,0,],
[0,0,0,1,1,0]])
my_gat = GraphAttentionLayer(10,5,0.2,0.2)
print(my_gat(x,adj))

输出:
tensor([[-0.2965, 2.8110, -0.6680, -0.9643, -0.9882],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.4981, -0.7515, 1.1159, 0.3546, 1.3592],
[ 0.4679, 1.7208, 0.3084, -0.5331, -0.1291],
[-0.4375, -0.8778, 1.1767, -0.5869, 1.5154],
[-0.2164, -0.5897, 0.4988, -0.3125, 0.6423]], grad_fn&#61;<EluBackward>)

加入Multi-head机制的GAT

用不同head捕捉不同特征&#xff0c;使模型有更好的拟合能力。

class GAT(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
"""Dense version of GAT."""
super(GAT, self).__init__()
self.dropout &#61; dropout
# 加入Multi-head机制
self.attentions &#61; [GraphAttentionLayer(nfeat, nhid, dropout&#61;dropout, alpha&#61;alpha, concat&#61;True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
self.add_module(&#39;attention_{}&#39;.format(i), attention)
self.out_att &#61; GraphAttentionLayer(nhid * nheads, nclass, dropout&#61;dropout, alpha&#61;alpha, concat&#61;False)
def forward(self, x, adj):
x &#61; F.dropout(x, self.dropout, training&#61;self.training)
x &#61; torch.cat([att(x, adj) for att in self.attentions], dim&#61;1)
x &#61; F.dropout(x, self.dropout, training&#61;self.training)
x &#61; F.elu(self.out_att(x, adj))
return F.log_softmax(x, dim&#61;1)

对数据集Cora的处理

在这里插入图片描述
数据集中两个文件&#xff0c;cites&#xff1a;比如上图11行&#xff1a;编号25和编号1114331的文章
content文件&#xff1a;如下图&#xff0c;每篇文章的id、features及类别
在这里插入图片描述

csr_matrix()处理稀疏矩阵

utils.py中对数据进行的处理

#数据是稀疏的&#xff0c;csr_matrix操作从行开始将1的位置取出来,对数据进行压缩

features &#61; sp.csr_matrix(idx_features_labels[:, 1:-1], dtype&#61;np.float32)
labels &#61; encode_onehot(idx_features_labels[:, -1])

在这里插入图片描述

encode_onehot()对label编号

有7个类别&#xff0c;通过classes_dict是7*7的对角阵把每个类别映射成不同向量&#xff0c;对所有label进行编号&#xff0c;再将编号转换为one_hot向量
在这里插入图片描述

def encode_onehot(labels):
# The classes must be sorted before encoding to enable static class encoding.
# In other words, make sure the first class always maps to index 0.
classes &#61; sorted(list(set(labels)))
classes_dict &#61; {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}
labels_onehot &#61; np.array(list(map(classes_dict.get, labels)), dtype&#61;np.int32)
return labels_onehot

build graph

见注释&#xff1a;

# build graph
idx &#61; np.array(idx_features_labels[:, 0], dtype&#61;np.int32)#获取所有文章id
idx_map &#61; {j: i for i, j in enumerate(idx)}#按文章数目&#xff0c;对id重新映射
# 读取数据集中文章和文章直接的引用关系
edges_unordered &#61; np.genfromtxt("{}{}.cites".format(path, dataset), dtype&#61;np.int32)
# 根据idx_map,将文章引用关系也重新映射
edges &#61; np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype&#61;np.int32).reshape(edges_unordered.shape)
adj &#61; sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape&#61;(labels.shape[0], labels.shape[0]), dtype&#61;np.float32)
# build symmetric adjacency matrix 生成邻接矩阵
adj &#61; adj &#43; adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
features &#61; normalize_features(features)
adj &#61; normalize_adj(adj &#43; sp.eye(adj.shape[0]))

邻接矩阵构造

csr_matrix()只记录了&#xff08;0&#xff0c;1&#xff09;1,忽略了&#xff08;1&#xff0c;0&#xff09;1。所以需要coo_matrix()操作&#xff01;才能还原出无向图的邻接矩阵&#xff01;
在这里插入图片描述

本文的一些代码注释及截图还可见视频

一个拓展&#xff1a;

GAT的推广

GAT的推广
GAT仅仅是应用在了单层图结构网络上&#xff0c;我们是否可以将它推广到多层网络结构呢&#xff1f;

这里我们假设一个有N层网络的结构&#xff0c;每层网络都定义了相同的节点&#xff0c;但是节点之间的关系有所差异。举一个简单的例子&#xff0c;假设有一个用户关系型网络&#xff0c;每层网络的节点都被定义成了网络中的用户&#xff0c;网络的第一层视图的关系可以定义为&#xff0c;两个用户之间是否具有好友关系&#xff1b;网络的第二层视图可以定义为&#xff0c;你评论过我的动态&#xff1b;网络的第三层视图可以定义为你转发过我的动态&#xff1b;第四层关系可以定义为&#xff0c;你at过我等等。

通过这样的定义我们就完成了一个多层网络的构建&#xff0c;他们共享相同的节点&#xff0c;但又分别具有不同的邻边&#xff0c;如果我们分别处理每一层视图视图&#xff0c;然后将他们得出的节点表示单纯相加的话&#xff0c;就可能会失去不同视图之间的协作关系&#xff0c;降低分类&#xff08;预测&#xff09;的精度。

基于以上观点&#xff0c;我们提出了一种新的方法&#xff1a;首先在每一层单视图中应用GAT进行学习&#xff0c;并计算出每层视图的节点表示。之后再不同视图之间引入attention机制来让网络自行学习不同视图的权重。之后根据学习的权重&#xff0c;将各个视图加权相加得到全局节点表示并进行后续的诸如节点表示&#xff0c;链接预测等任务。

同时&#xff0c;因为不同视图共享同样的节点&#xff0c;即使每一层视图都表示了不同的节点关系&#xff0c;最终得到的每一层的节点嵌入表示应具有一定的相关性。基于以上理论&#xff0c;我们在每层GAT的网络参数间引入正则化项来约束参数&#xff0c;使其向互相相近的方向学习。大致的网络流程图如下&#xff1a;

这部分来源于 链接&#xff1a;https://www.jianshu.com/p/d5d366ba1a57 来源&#xff1a;简书







推荐阅读
  • 本文详细介绍了在Windows操作系统上使用Python 3.8.5编译支持CUDA 11和cuDNN 8.0.2的TensorFlow 2.3的步骤。文章不仅提供了详细的编译指南,还分享了编译后的文件下载链接,方便用户快速获取所需资源。此外,文中还涵盖了常见的编译问题及其解决方案,确保用户能够顺利进行编译和安装。 ... [详细]
  • 本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。 ... [详细]
  • 处理Android EditText中数字输入与parseInt方法
    本文探讨了如何在Android应用中从EditText组件安全地获取并解析用户输入的数字,特别是用于设置端口号的情况。通过示例代码和异常处理策略,展示了有效的方法来避免因非法输入导致的应用崩溃。 ... [详细]
  • 在1995年,Simon Plouffe 发现了一种特殊的求和方法来表示某些常数。两年后,Bailey 和 Borwein 在他们的论文中发表了这一发现,这种方法被命名为 Bailey-Borwein-Plouffe (BBP) 公式。该问题要求计算圆周率 π 的第 n 个十六进制数字。 ... [详细]
  • spring boot使用jetty无法启动 ... [详细]
  • Android与JUnit集成测试实践
    本文探讨了如何在Android项目中集成JUnit进行单元测试,并详细介绍了修改AndroidManifest.xml文件以支持测试的方法。 ... [详细]
  • 使用QT构建基础串口辅助工具
    本文详细介绍了如何利用QT框架创建一个简易的串口助手应用程序,包括项目的建立、界面设计与编程实现、运行测试以及最终的应用程序打包。 ... [详细]
  • 在Windows系统中安装TensorFlow GPU版的详细指南与常见问题解决
    在Windows系统中安装TensorFlow GPU版是许多深度学习初学者面临的挑战。本文详细介绍了安装过程中的每一个步骤,并针对常见的问题提供了有效的解决方案。通过本文的指导,读者可以顺利地完成安装并避免常见的陷阱。 ... [详细]
  • 在Conda环境中高效配置并安装PyTorch和TensorFlow GPU版的方法如下:首先,创建一个新的Conda环境以避免与基础环境发生冲突,例如使用 `conda create -n pytorch_gpu python=3.7` 命令。接着,激活该环境,确保所有依赖项都正确安装。此外,建议在安装过程中指定CUDA版本,以确保与GPU兼容性。通过这些步骤,可以确保PyTorch和TensorFlow GPU版的顺利安装和运行。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 在Windows命令行中,通过Conda工具可以高效地管理和操作虚拟环境。具体步骤包括:1. 列出现有虚拟环境:`conda env list`;2. 创建新虚拟环境:`conda create --name 环境名`;3. 删除虚拟环境:`conda env remove --name 环境名`。这些命令不仅简化了环境管理流程,还提高了开发效率。此外,Conda还支持环境文件导出和导入,方便在不同机器间迁移配置。 ... [详细]
  • 本文通过一个具体的实例,介绍如何利用TensorFlow框架来计算神经网络模型在多分类任务中的Top-K准确率。代码中包含了随机种子设置、模拟预测结果生成、真实标签生成以及准确率计算等步骤。 ... [详细]
  • 深入理解Dockerfile及其作用
    Dockerfile是一种文本格式的配置文件,用于定义构建Docker镜像所需的步骤。通过使用`docker build`命令,用户可以将Dockerfile中的一系列指令转换成一个可执行的Docker镜像。 ... [详细]
  • 在将 Android Studio 从 3.0 升级到 3.1 版本后,遇到项目无法正常编译的问题,具体错误信息为:org.gradle.api.tasks.TaskExecutionException: Execution failed for task ':app:processDemoProductDebugResources'。 ... [详细]
author-avatar
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有