图神经网络模型综述
作者:1小柱子8_814 | 来源:互联网 | 2024-11-28 13:27
本文综述了图神经网络(GraphNeuralNetworks,GNN)的发展,从传统的数据存储模型转向图和动态模型,探讨了模型中的显性和隐性结构,并详细介绍了GNN的关键组件及其应用。
### 图神经网络模型综述 随着数据科学的发展,传统的数据存储模型逐渐向图和动态模型转变。图神经网络(Graph Neural Networks, GNN)作为一种新兴的模型,能够在图结构数据中捕捉复杂的依赖关系。尽管模型中可能存在隐性的结构,但显性的结构往往更易于引导和控制。 #### 关键组件 GNN 的核心组件包括传播模块、采样模块和池化模块: 1. **传播模块**:用于在节点之间传播信息,使得聚合的信息能够同时捕获特征信息和拓扑信息。 2. **采样模块**:通常需要在图上进行传播,采样模块通常与传播模块结合使用,以提高效率和准确性。 3. **池化模块**:当需要高级子图或图的整体表示时,池化模块可以从节点中提取关键信息。 #### 传播模块的实现 传播模块通常包含卷积算子和递归算子,这些算子用于聚合来自邻居节点的信息。此外,跳过连接操作可以从节点的历史表示中收集信息,并缓解过度平滑(over-smoothing)问题。 #### GNN 的工作流程 GNN 将图映射到输出的过程通常分为两个步骤: 1. **节点表示生成**:通过传播步骤,生成每个节点的表示。 2. **输出模型**:使用输出模型将每个节点的表示和标签映射为最终的输出。 为了处理图的整体分类任务,一些模型建议引入一个特殊的“超级节点”(supernode),该节点通过特殊边与所有其他节点相连,从而简化整体分类任务。 #### 一般 GNN 模型架构 以下是一般的 GNN 模型架构示意图: ![GNN 架构](https://img.php1.cn/3cd4a/1eebe/cd5/fb32005f2115b419.webp) #### 实现代码示例 以下是使用 DGL 库实现的一个简单的 GNN 模型示例: ```python # -*- coding: utf-8 -*- """ ============================================================= File Name: gcn.py Author: songdongdong Date: 2021/3/8 15:44 Description: GCN (Graph Convolutional Networks) 是一种图卷积网络,提出于 2017 年。 GCN 与 CNN 类似,都是特征提取器,不同的是 GCN 提取的是图数据特征。 ============================================================= """ import torch import torch.nn as nn import torch.nn.functional as F from dgl.nn.pytorch import GraphConv # DGL 库中的图卷积层 from dgl.data import CoraGraphDataset class GCN(nn.Module): def __init__(self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout): super(GCN, self).__init__() self.g = g self.layers = nn.ModuleList() self.layers.append(GraphConv(in_feats, n_hidden, activation=activation)) # 输入层 for i in range(n_layers - 1): self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation)) self.layers.append(GraphConv(n_hidden, n_classes)) # 输出层 self.dropout = nn.Dropout(p=dropout) def forward(self, features): h = features for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) h = layer(self.g, h) return h @torch.no_grad() def evaluate(self, model, features, labels, mask): model.eval() with torch.no_grad(): logits = model(features) logits = logits[mask] labels = labels[mask] _, indices = torch.max(logits, dim=1) correct = torch.sum(indices == labels) return correct.item() * 1.0 / len(labels) def train(self, n_epochs=100, lr=1e-2, weight_decay=5e-4, n_hidden=16, n_layers=1, activation=F.relu, dropout=0.5): data = CoraGraphDataset() g = data[0] features = g.ndata['feat'] labels = g.ndata['label'] train_mask = g.ndata['train_mask'] val_mask = g.ndata['val_mask'] test_mask = g.ndata['test_mask'] in_feats = features.shape[1] n_classes = data.num_classes model = GCN(g, in_feats, n_hidden, n_classes, n_layers, activation, dropout) loss_fcn = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) for epoch in range(n_epochs): model.train() logits = model(features) loss = loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() acc = self.evaluate(model, features, labels, val_mask) print(f'Epoch {epoch} | Loss: {loss.item():.4f} | Accuracy: {acc:.4f}') acc = self.evaluate(model, features, labels, test_mask) print(f'Test accuracy: {acc:.2%}') if __name__ == '__main__': gcn = GCN() gcn.train() ``` #### 相关资源 - [RESIDUAL GATED GRAPH CONVNETS](https://arxiv.org/abs/1711.07553) - [GRAPH CONVOLUTIONAL NETWORKS](https://arxiv.org/abs/1609.02907) - [Transformers 作为一种图神经网络](https://arxiv.org/abs/2010.02502) - [DGL 官方教程](https://docs.dgl.ai/en/latest/api/python/dgl.nn.html)
推荐阅读
本文探讨了Flutter和Angular这两个流行框架的主要区别,包括它们的设计理念、适用场景及技术实现。 ...
[详细]
蜡笔小新 2024-11-28 13:19:52
本文将指导你如何通过自定义配置,使 Windows Terminal 中的 PowerShell 7 更加高效且美观。我们将移除默认的广告和提示符,设置快捷键,并添加实用的别名和功能。 ...
[详细]
蜡笔小新 2024-11-28 07:25:46
本文介绍了一种算法,用于在一个给定的二叉树中找到一个节点,该节点的子树包含最大数量的值小于该节点的节点。如果存在多个符合条件的节点,可以选择任意一个。 ...
[详细]
蜡笔小新 2024-11-27 18:08:54
.NetFramework中处理字符和字符串的主要有以下这么几个类:(1)、System.Char类一基础字符串处理类(2)、System.String类一处理不可变的字符串(一经 ...
[详细]
蜡笔小新 2024-11-26 21:04:40
本文介绍了如何使用Workman框架构建一个功能全面的即时通讯系统,该系统不仅支持一对一聊天、群组聊天,还集成了视频会议和实时音视频通话功能,同时提供了红包发送等附加功能。 ...
[详细]
蜡笔小新 2024-11-26 15:42:43
自从AlexNet等模型在计算机视觉领域取得突破以来,深度学习技术迅速发展。近年来,随着BERT等大型模型的广泛应用,AI模型的规模持续扩大,对硬件提出了更高的要求。本文介绍了新加坡国立大学尤洋教授团队开发的夸父AI系统,旨在解决大规模模型训练中的并行计算挑战。 ...
[详细]
蜡笔小新 2024-11-25 19:02:33
本文详细探讨了DropBlock这一正则化方法在卷积神经网络中的应用与效果。通过结构化的dropout方式,即在特征图中连续区域内的单元同时被丢弃,DropBlock有效解决了传统dropout在卷积层应用时效果不佳的问题。更多理论分析及其实现细节可参考原文链接。 ...
[详细]
蜡笔小新 2024-11-28 11:54:39
本文详细介绍了Python中的流程控制与条件判断技术,包括数据导入、数据变换、统计描述、假设检验、可视化以及自定义函数的创建等方面的内容。 ...
[详细]
蜡笔小新 2024-11-27 20:04:59
本文探讨了在Qt框架下实现TCP多线程服务器端的方法,解决了一个常见的问题:服务器端仅能与最后一个连接的客户端通信。通过继承QThread类并利用socketDescriptor标识符,实现了多个客户端与服务器端的同时通信。 ...
[详细]
蜡笔小新 2024-11-27 16:31:40
本文介绍了如何利用snownlp库对微博内容进行情感分析,包括安装、基本使用以及如何自定义训练模型以提高分析准确性。 ...
[详细]
蜡笔小新 2024-11-27 15:01:46
本文主要解决了在编译CM10.2时出现的关于Samsung Exynos 4 HDMI HAL库中SecHdmiV4L2Utils.cpp文件的编译错误。 ...
[详细]
蜡笔小新 2024-11-26 17:26:47
继上次把backTracking的题目做了一下之后:backTracking,我把LeetCode的动态规划的题目又做了一下,还有几道比较难的Medium的题和Hard的题没做出来,后面会继续 ...
[详细]
蜡笔小新 2024-11-26 14:31:10
本文探讨了集群(Cluster)的概念,即通过网络连接的一组计算机系统,它们作为一个整体提供服务,实现分布式计算。文章还详细介绍了负载均衡技术,旨在提高网络服务的效率和可靠性。 ...
[详细]
蜡笔小新 2024-11-26 13:44:24
随着移动互联网的发展,Feed流系统成为了众多社交应用的核心组成部分。本文将深入探讨如何设计一个高效、稳定的Feed流系统,涵盖从基础架构到高级特性的各个方面。 ...
[详细]
蜡笔小新 2024-11-26 12:55:53
本文详细介绍了 SQL Server Express LocalDB,这是一种轻量级的本地 T-SQL 数据库解决方案,特别适合开发环境使用。文章还探讨了 LocalDB 与其他轻量级数据库的对比,并提供了安装和连接 LocalDB 的步骤。 ...
[详细]
蜡笔小新 2024-11-25 20:36:01