热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

图神经网络模型综述

本文综述了图神经网络(GraphNeuralNetworks,GNN)的发展,从传统的数据存储模型转向图和动态模型,探讨了模型中的显性和隐性结构,并详细介绍了GNN的关键组件及其应用。
### 图神经网络模型综述

随着数据科学的发展,传统的数据存储模型逐渐向图和动态模型转变。图神经网络(Graph Neural Networks, GNN)作为一种新兴的模型,能够在图结构数据中捕捉复杂的依赖关系。尽管模型中可能存在隐性的结构,但显性的结构往往更易于引导和控制。

#### 关键组件

GNN 的核心组件包括传播模块、采样模块和池化模块:

1. **传播模块**:用于在节点之间传播信息,使得聚合的信息能够同时捕获特征信息和拓扑信息。
2. **采样模块**:通常需要在图上进行传播,采样模块通常与传播模块结合使用,以提高效率和准确性。
3. **池化模块**:当需要高级子图或图的整体表示时,池化模块可以从节点中提取关键信息。

#### 传播模块的实现

传播模块通常包含卷积算子和递归算子,这些算子用于聚合来自邻居节点的信息。此外,跳过连接操作可以从节点的历史表示中收集信息,并缓解过度平滑(over-smoothing)问题。

#### GNN 的工作流程

GNN 将图映射到输出的过程通常分为两个步骤:

1. **节点表示生成**:通过传播步骤,生成每个节点的表示。
2. **输出模型**:使用输出模型将每个节点的表示和标签映射为最终的输出。

为了处理图的整体分类任务,一些模型建议引入一个特殊的“超级节点”(supernode),该节点通过特殊边与所有其他节点相连,从而简化整体分类任务。

#### 一般 GNN 模型架构

以下是一般的 GNN 模型架构示意图:

![GNN 架构](https://img.php1.cn/3cd4a/1eebe/cd5/fb32005f2115b419.webp)

#### 实现代码示例

以下是使用 DGL 库实现的一个简单的 GNN 模型示例:

```python
# -*- coding: utf-8 -*-
"""
=============================================================
File Name: gcn.py
Author: songdongdong
Date: 2021/3/8 15:44
Description: GCN (Graph Convolutional Networks) 是一种图卷积网络,提出于 2017 年。
GCN 与 CNN 类似,都是特征提取器,不同的是 GCN 提取的是图数据特征。
=============================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv # DGL 库中的图卷积层
from dgl.data import CoraGraphDataset

class GCN(nn.Module):
def __init__(self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
super(GCN, self).__init__()
self.g = g
self.layers = nn.ModuleList()
self.layers.append(GraphConv(in_feats, n_hidden, activation=activation)) # 输入层
for i in range(n_layers - 1):
self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
self.layers.append(GraphConv(n_hidden, n_classes)) # 输出层
self.dropout = nn.Dropout(p=dropout)

def forward(self, features):
h = features
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(self.g, h)
return h

@torch.no_grad()
def evaluate(self, model, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)

def train(self, n_epochs=100, lr=1e-2, weight_decay=5e-4, n_hidden=16, n_layers=1, activation=F.relu, dropout=0.5):
data = CoraGraphDataset()
g = data[0]
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
in_feats = features.shape[1]
n_classes = data.num_classes
model = GCN(g, in_feats, n_hidden, n_classes, n_layers, activation, dropout)
loss_fcn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
for epoch in range(n_epochs):
model.train()
logits = model(features)
loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = self.evaluate(model, features, labels, val_mask)
print(f'Epoch {epoch} | Loss: {loss.item():.4f} | Accuracy: {acc:.4f}')
acc = self.evaluate(model, features, labels, test_mask)
print(f'Test accuracy: {acc:.2%}')

if __name__ == '__main__':
gcn = GCN()
gcn.train()
```

#### 相关资源

- [RESIDUAL GATED GRAPH CONVNETS](https://arxiv.org/abs/1711.07553)
- [GRAPH CONVOLUTIONAL NETWORKS](https://arxiv.org/abs/1609.02907)
- [Transformers 作为一种图神经网络](https://arxiv.org/abs/2010.02502)
- [DGL 官方教程](https://docs.dgl.ai/en/latest/api/python/dgl.nn.html)

推荐阅读
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • 基因组浏览器中的Wig格式解析
    本文详细介绍了Wiggle(Wig)格式及其在基因组浏览器中的应用,涵盖variableStep和fixedStep两种主要格式的特点、适用场景及具体使用方法。同时,还提供了关于数据值和自定义参数的补充信息。 ... [详细]
  • 基于KVM的SRIOV直通配置及性能测试
    SRIOV介绍、VF直通配置,以及包转发率性能测试小慢哥的原创文章,欢迎转载目录?1.SRIOV介绍?2.环境说明?3.开启SRIOV?4.生成VF?5.VF ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 本文提供了一系列Python编程基础练习题,涵盖了列表操作、循环结构、字符串处理和元组特性等内容。通过这些练习题,读者可以巩固对Python语言的理解并提升编程技能。 ... [详细]
  • 开发笔记:2020 BJDCTF Re encode
    开发笔记:2020 BJDCTF Re encode ... [详细]
  • CentOS 7.6环境下Prometheus与Grafana的集成部署指南
    本文旨在提供一套详细的步骤,指导读者如何在CentOS 7.6操作系统上成功安装和配置Prometheus 2.17.1及Grafana 6.7.2-1,实现高效的数据监控与可视化。 ... [详细]
  • 一个登陆界面
    预览截图html部分123456789101112用户登入1314邮箱名称邮箱为空15密码密码为空16登 ... [详细]
  • 本文档汇总了Python编程的基础与高级面试题目,涵盖语言特性、数据结构、算法以及Web开发等多个方面,旨在帮助开发者全面掌握Python核心知识。 ... [详细]
  • YB02 防水车载GPS追踪器
    YB02防水车载GPS追踪器由Yuebiz科技有限公司设计生产,适用于车辆防盗、车队管理和实时追踪等多种场合。 ... [详细]
  • 本文总结了在使用Ionic 5进行Android平台APK打包时遇到的问题,特别是针对QRScanner插件的改造。通过详细分析和提供具体的解决方法,帮助开发者顺利打包并优化应用性能。 ... [详细]
  • 微软Exchange服务器遭遇2022年版“千年虫”漏洞
    微软Exchange服务器在新年伊始遭遇了一个类似于‘千年虫’的日期处理漏洞,导致邮件传输受阻。该问题主要影响配置了FIP-FS恶意软件引擎的Exchange 2016和2019版本。 ... [详细]
  • 本文介绍如何在PostgreSQL数据库中正确插入和处理JSON数据类型,确保数据完整性和避免常见错误。 ... [详细]
  • 实用正则表达式有哪些
    小编给大家分享一下实用正则表达式有哪些,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下 ... [详细]
author-avatar
1小柱子8_814
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有