热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

多种_DGL中的消息传递相关内容的讲解

篇首语:本文由编程笔记#小编为大家整理,主要介绍了DGL中的消息传递相关内容的讲解相关的知识,希望对你有一定的参考价值。文章目录

篇首语:本文由编程笔记#小编为大家整理,主要介绍了DGL中的消息传递相关内容的讲解相关的知识,希望对你有一定的参考价值。



文章目录


  • 前言
  • 消息传递范式
    • DGL中的自定义消息函数
    • DGL中的自定义聚合函数
    • DGL中的自定义更新函数
    • 实例分析
      • 创建图

    • 消息传递
    • graph.apply_edges

  • 参考


前言

学会DGL中的消息传递,基本就能够比较好的来理解编写各种图神经网络的代码了吧。


消息传递范式

消息传递是实现GNN的一种通用框架和编程范式。它从聚合与更新的角度归纳总结了多种GNN模型的实现。

因此在DGL代码编写消息传递部分时,我们需要三个函数,分别是消息函数、聚合函数、更新函数。
简单来说就是:
消息函数用来取边和节点的特征。
聚合函数用来计算边和节点的特征,例如特征求和,根据特征求个注意力权重等等。
更新函数用来更新节点的特征,对聚合函数传来的特征可以过个激活函数等,最后得到最终的节点特征即可更新。


DGL中的自定义消息函数

在DGL中,消息函数 接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。 edges 有 src、 dst 和 data 共3个成员属性, 分别用于访问源节点、目标节点和边的特征。

用法就是定义一个函数,然后需要传入一个edges参数,这个参数有src、 dst 和 data 共3个成员属性,能够索引对应的特征
例:

def message_func(edges):
print("-"*20)
print("edges.data[x]", edges.data["x"]) # 获得边的特征
print("edges.src[x]", edges.src["x"]) # 获得边的源节点的特征
print("edges.dst[x]", edges.dst["x"]) # 获得边的目标节点的特征
# 返回得到需要传递的消息特征
return 'e_data': edges.src['h'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]

DGL中的自定义聚合函数

聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 等。

用法就是定义一个函数,然后需要传入一个nodes参数,这个参数能够通过mailbox索引消息函数return来的特征。

例:

def reduce_func(nodes):
print("+"*20)
# 获取每个节点的边特征的和并储存在节点的e_data中
data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
# 获取每条边的源节点特征并求和储存在节点的src_data中
src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
# 获取每条边的目标节点特征并求和储存在节点的src_data中
dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
print("nodes_e_data", data_sum)
print("nodes_e_src", src_sum)
print("nodes_e_dst", dst_sum)
return "data_sum":data_sum, "src_sum":src_sum, "dst_sum":dst_sum

DGL中的自定义更新函数

更新函数 同样接受参数 nodes。此函数对聚合函数的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。
例:

def apply_node_func(nodes):
# 将x用data_sum更新
return 'x': nodes.data["data_sum"]

最后加上

g.update_all(message_func, reduce_func, apply_node_func)

即可完成消息传递的操作。


实例分析


创建图

首先我们先创建那么一张图:

其中黑色的为节点的特征,红色的为边的特征
对应创建代码如下:

import dgl
import dgl.function as fn
import torch
import torch as th
# 构建图
g = dgl.graph(([0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4], [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]))
# 每个节点的特征都为[1, 1]
g.ndata['x'] = torch.ones(5, 2)
# 每边节点的特征都为[1, 1]
g.edata['x'] = torch.ones(11, 2)
# 节点4的特征为[0.2, 0.5]
g.ndata['x'][4] = torch.tensor([0.2, 0.5])
# 边5的特征为[0.1, 0.1]
g.edata['x'][5] = torch.tensor([0.1, 0.1])
# 消息汇聚更新
# g.update_all(fn.copy_u(u='x', out='m'), fn.sum(msg='m', out='h'))
print(g.ndata['x'])
print(g.edata["x"])

消息传递

然后我们来试着理解一下消息传递:

def message_func(edges):
print("-"*20)
print("edges.data[x]", edges.data["x"]) # 获得边的特征
print("edges.src[x]", edges.src["x"]) # 获得边的源节点的特征
print("edges.dst[x]", edges.dst["x"]) # 获得边的目标节点的特征
# 返回得到需要传递的消息特征
return 'e_data': edges.data['x'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]
def reduce_func(nodes):
print("+"*20)
# 获取每个节点的边特征的和并储存在节点的e_data中
data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
# 获取每条边的源节点特征并求和储存在节点的src_data中
src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
# 获取每条边的目标节点特征并求和储存在节点的src_data中
dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
print("nodes_e_data", data_sum)
print("nodes_e_src", src_sum)
print("nodes_e_dst", dst_sum)
return "data_sum":data_sum, "src_sum":src_sum, "dst_sum":dst_sum
def apply_node_func(nodes):
# 将x用data_sum更新
return 'x': nodes.data["data_sum"]
g.update_all(message_func, reduce_func, apply_node_func)
print(g.ndata["x"])

我们就看一下用边特征更新后的x特征的输出好了。
g.ndata[“x”]的输出:

tensor([[1.0000, 1.0000],
[2.1000, 2.1000],
[2.0000, 2.0000],
[2.0000, 2.0000],
[3.0000, 3.0000]])

说明每个节点的入度的边的特征都求和之后汇聚到x特征上了,还是非常好理解的。

另外两个大概是求源节点相同的边的目标节点的特征的和来更新节点特征
以及求目标节点相同的边的源节点的特征的和来更新节点特征
可能说起来有点绕,但是看看代码运行的结果再结合图应该就懂了,这里就不放运行结果了。

这里是为了演示步骤,一般不再update_all中自己设置更新函数的。


graph.apply_edges

在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算。 apply_edges() 的参数是一个消息函数。并且在默认情况下,这个接口将更新所有的边。

import dgl
import torch
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
# 三种方式
# def add(edges):
# return"x": edges.src['h'] + edges.dst['h']
# g.apply_edges(add)
# g.apply_edges(lambda edges: 'x' : edges.src['h'] + edges.dst['h']) # 二者等价
g.apply_edges(fn.u_add_v('h', 'h', 'x')) # 使用内置函数,是最好的
print(g.edata['x'])

算图的注意力机制的时候,可以先计算出每个边的注意力权重,此时直接边计算即可。


参考

https://docs.dgl.ai/guide_cn/message.html
https://docs.dgl.ai/guide_cn/message-api.html


推荐阅读
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 本文讨论了如何优化解决hdu 1003 java题目的动态规划方法,通过分析加法规则和最大和的性质,提出了一种优化的思路。具体方法是,当从1加到n为负时,即sum(1,n)sum(n,s),可以继续加法计算。同时,还考虑了两种特殊情况:都是负数的情况和有0的情况。最后,通过使用Scanner类来获取输入数据。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 本文介绍了P1651题目的描述和要求,以及计算能搭建的塔的最大高度的方法。通过动态规划和状压技术,将问题转化为求解差值的问题,并定义了相应的状态。最终得出了计算最大高度的解法。 ... [详细]
  • 本文详细介绍了在ASP.NET中获取插入记录的ID的几种方法,包括使用SCOPE_IDENTITY()和IDENT_CURRENT()函数,以及通过ExecuteReader方法执行SQL语句获取ID的步骤。同时,还提供了使用这些方法的示例代码和注意事项。对于需要获取表中最后一个插入操作所产生的ID或马上使用刚插入的新记录ID的开发者来说,本文提供了一些有用的技巧和建议。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文介绍了数据库的存储结构及其重要性,强调了关系数据库范例中将逻辑存储与物理存储分开的必要性。通过逻辑结构和物理结构的分离,可以实现对物理存储的重新组织和数据库的迁移,而应用程序不会察觉到任何更改。文章还展示了Oracle数据库的逻辑结构和物理结构,并介绍了表空间的概念和作用。 ... [详细]
  • 本文介绍了使用Java实现大数乘法的分治算法,包括输入数据的处理、普通大数乘法的结果和Karatsuba大数乘法的结果。通过改变long类型可以适应不同范围的大数乘法计算。 ... [详细]
  • 本文详细介绍了Java中vector的使用方法和相关知识,包括vector类的功能、构造方法和使用注意事项。通过使用vector类,可以方便地实现动态数组的功能,并且可以随意插入不同类型的对象,进行查找、插入和删除操作。这篇文章对于需要频繁进行查找、插入和删除操作的情况下,使用vector类是一个很好的选择。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
author-avatar
Hancl
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有