热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

第七天深入学习DGL框架:官方文档指导下的数据集下载与预处理技巧

在第七天的深度学习课程中,我们将重点探讨DGL框架的高级应用,特别是在官方文档指导下进行数据集的下载与预处理。通过详细的步骤说明和实用技巧,帮助读者高效地构建和优化图神经网络的数据管道。此外,我们还将介绍如何利用DGL提供的模块化工具,实现数据的快速加载和预处理,以提升模型训练的效率和准确性。

参考链接


  1. https://docs.dgl.ai/guide/data.html#guide-data-pipeline
  2. https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset

DGLDataset

DGL在 dgl.data 里实现了很多常用的图数据集。它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道。DGL推荐用户将图数据处理为 dgl.data.DGLDataset 的子类。该类为导入、处理和保存图数据提供了简单而清晰的解决方案。

DGLDataset的执行流程:


  1. 通过调用“has_cache()”判断磁盘上是否有已经处理好的数据集缓存。如果有,则跳转到第5步,直接加载数据集;

  2. 调用“download()”下载数据;

  3. 调用“process()”处理数据;

  4. 调用“save()”保存处理好的数据到磁盘,跳转到第6步;

  5. 调用“load()”从磁盘加载数据集;

  6. 完成。

下面给出了一个继承自DGLDataset类的例子。子类中必须实现process(), getitem(idx) 和 len()。同时官方建议也实现save()和load(),避免对大型数据集的重复处理。

from dgl.data import DGLDatasetclass MyDataset(DGLDataset):""" 用于在DGL中自定义图数据集的模板:Parameters----------url : str下载原始数据集的url。raw_dir : str指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/save_dir : str处理完成的数据集的保存目录。默认:raw_dir指定的值force_reload : bool是否重新导入数据集。默认:Falseverbose : bool是否打印进度信息。"""def __init__(self,url=None,raw_dir=None,save_dir=None,force_reload=False,verbose=False):super(MyDataset, self).__init__(name='dataset_name',url=url,raw_dir=raw_dir,save_dir=save_dir,force_reload=force_reload,verbose=verbose)def download(self):# 将原始数据下载到本地磁盘passdef process(self):# 将原始数据处理为图、标签和数据集划分的掩码passdef __getitem__(self, idx):# 通过idx得到与之对应的一个样本passdef __len__(self):# 数据样本的数量passdef save(self):# 将处理后的数据保存至 `self.save_path`passdef load(self):# 从 `self.save_path` 导入处理后的数据passdef has_cache(self):# 检查在 `self.save_path` 中是否存有处理后的数据pass

下载原始数据

这一段就是给实现“download()”举了两个例子。
从“self.url”链接下载到“self.raw_dir”目录下,保存为“self.name+格式后缀”:

import os
from dgl.data.utils import downloaddef download(self):# 存储文件的路径file_path = os.path.join(self.raw_dir, self.name + '.mat')# 下载文件download(self.url, path=file_path)

如果数据集是一个zip文件,可以直接继承 dgl.data.DGLBuiltinDataset 类,其支持解压缩zip文件。

如果文件是.gz、.tar、.tar.gz或.tgz文件,下载后需要用 extract_archive() 函数进行解压缩:

from dgl.data.utils import download, check_sha1def download(self):# 存储文件的路径,请确保使用与原始文件名相同的后缀gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')# 下载文件download(self.url, path=gz_file_path)# 检查 SHA-1if not check_sha1(gz_file_path, self._sha1_str):raise UserWarning('File {} is downloaded but the content hash does not match.''The repo may be outdated or download may be incomplete. ''Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))# 将文件解压缩到目录self.raw_dir下的self.name目录中self._extract_gz(gz_file_path, self.raw_path)

处理数据

假设数据已经下载到“self.raw_dir”目录下,接下来就可以处理数据了。根据图上的任务,分别从整图分类、节点分类和链接预测介绍。

整图分类

整图分类任务与传统机器学习任务类似,整图为特征,类别为标签。调用“process()”将数据集处理为 dgl.DGLGraph 对象的列表和标签张量的列表。

class QM7bDataset(DGLDataset):_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \'datasets/qm7b.mat'_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'def __init__(self, raw_dir=None, force_reload=False, verbose=False):super(QM7bDataset, self).__init__(name='qm7b',url=self._url,raw_dir=raw_dir,force_reload=force_reload,verbose=verbose)def process(self):mat_path = self.raw_path + '.mat'self.graphs, self.label = self._load_graph(mat_path)def _load_graph(self, filename):data = io.loadmat(filename)labels = F.tensor(data['T'], dtype=F.data_type_dict['float32'])feats = data['X']num_graphs = labels.shape[0]graphs = []for i in range(num_graphs):edge_list = feats[i].nonzero()g = dgl_graph(edge_list)g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),dtype=F.data_type_dict['float32'])graphs.append(g)return graphs, labelsdef save(self):"""save the graph list and the labels"""graph_path = os.path.join(self.save_path, 'dgl_graph.bin')save_graphs(str(graph_path), self.graphs, {'labels': self.label})def has_cache(self):graph_path = os.path.join(self.save_path, 'dgl_graph.bin')return os.path.exists(graph_path)def load(self):graphs, label_dict = load_graphs(os.path.join(self.save_path, 'dgl_graph.bin'))self.graphs = graphsself.label = label_dict['labels']def download(self):file_path = os.path.join(self.raw_dir, self.name + '.mat')download(self.url, path=file_path)if not check_sha1(file_path, self._sha1_str):raise UserWarning('File {} is downloaded but the content hash does not match.''The repo may be outdated or download may be incomplete. ''Otherwise you can create an issue for it.'.format(self.name))@propertydef num_labels(self):return 14def __getitem__(self, idx):return self.graphs[idx], self.label[idx]def __len__(self):return len(self.graphs)

处理完数据后,就可以跟传统分类任务一样使用数据了。

import dgl
import torchfrom torch.utils.data import DataLoader# 数据导入
dataset = QM7bDataset()
num_labels = dataset.num_labels# 创建collate_fn函数
def _collate_fn(batch):graphs, labels = batchg = dgl.batch(graphs)labels = torch.tensor(labels, dtype=torch.long)return g, labels# 创建 dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)# 训练
for epoch in range(100):for g, labels in dataloader:# 用户自己的训练代码pass

节点分类

与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。 DGL建议使用节点掩码来指定数据集的划分,相当于给节点做一个标记,明确是为训练节点(“g.ndata[‘train_mask’]”)、验证节点(“g.ndata[‘val_mask’]”)还是测试节点(“g.ndata[‘test_mask’]”)。 本节以内置数据集 CitationGraphDataset 为例,支持’cora’, ‘citeseer’, 'pubmed’三个常用的数据集,DGL已经分别针对三个数据集构建了子类CoraGraphDataset、CiteseerGraphDataset和PubmedGraphDataset。

from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url, generate_mask_tensorclass CitationGraphDataset(DGLBuiltinDataset):_urls = {'cora_v2' : 'dataset/cora_v2.zip','citeseer' : 'dataset/citeseer.zip','pubmed' : 'dataset/pubmed.zip',}def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):assert name.lower() in ['cora', 'citeseer', 'pubmed']if name.lower() == 'cora':name = 'cora_v2'url = _get_dgl_url(self._urls[name])super(CitationGraphDataset, self).__init__(name,url=url,raw_dir=raw_dir,force_reload=force_reload,verbose=verbose)def process(self):# 跳过一些处理的代码# === 跳过数据处理 ===# 构建图g = dgl.graph(graph)# 划分掩码g.ndata['train_mask'] = generate_mask_tensor(train_mask)g.ndata['val_mask'] = generate_mask_tensor(val_mask)g.ndata['test_mask'] = generate_mask_tensor(test_mask)# 节点的标签g.ndata['label'] = torch.tensor(labels)# 节点的特征g.ndata['feat'] = torch.tensor(_preprocess_features(features),dtype=F.data_type_dict['float32'])self._num_labels = onehot_labels.shape[1]self._labels = labelsself._g = gdef __getitem__(self, idx):assert idx == 0, "这个数据集里只有一个图"return self._gdef __len__(self):return 1

由于数据集只有一个图,所以需要取第0个元素“dataset[0]”:

# 创建链接预测数据集示例
class KnowledgeGraphDataset(DGLBuiltinDataset):def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):self._name = nameself.reverse = reverseurl = _get_dgl_url('dataset/') + '{}.tgz'.format(name)super(KnowledgeGraphDataset, self).__init__(name,url=url,raw_dir=raw_dir,force_reload=force_reload,verbose=verbose)def process(self):# 跳过一些处理的代码# === 跳过数据处理 ===# 划分掩码g.edata['train_mask'] = train_maskg.edata['val_mask'] = val_maskg.edata['test_mask'] = test_mask# 边类型g.edata['etype'] = etype# 节点类型g.ndata['ntype'] = ntypeself._g = gdef __getitem__(self, idx):assert idx == 0, "这个数据集只有一个图"return self._gdef __len__(self):return 1

下面利用’FB15k-237’对应的子类 dgl.data.FB15k237Dataset 来做演示如何使用用于链路预测的数据集:

from dgl.data import FB15k237Dataset# 导入数据
dataset = FB15k237Dataset()
graph = dataset[0]# 获取训练集掩码
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask).squeeze()
src, dst = graph.edges(train_idx)# 获取训练集中的边类型
rel = graph.edata['etype'][train_idx]

保存和加载数据

DGL提供了4个函数:

  1. dgl.save_graphs(): 保存DGLGraph对象和标签到本地磁盘

  2. dgl.load_graphs():从本地磁盘读取它们

  3. dgl.data.utils.save_info(): 将数据集的有用信息(python dict对象)保存到本地磁盘

  4. dgl.data.utils.load_info()和从本地磁盘读取它们

import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import makedirs, save_info, load_infodef save(self):# 保存图和标签graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')save_graphs(graph_path, self.graphs, {'labels': self.labels})# 在Python字典里保存其他信息info_path = os.path.join(self.save_path, self.mode + '_info.pkl')save_info(info_path, {'num_classes': self.num_classes})def load(self):# 从目录 `self.save_path` 里读取处理过的数据graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')self.graphs, label_dict = load_graphs(graph_path)self.labels = label_dict['labels']info_path = os.path.join(self.save_path, self.mode + '_info.pkl')self.num_classes = load_info(info_path)['num_classes']def has_cache(self):# 检查在 `self.save_path` 里是否有处理过的数据文件graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')info_path = os.path.join(self.save_path, self.mode + '_info.pkl')return os.path.exists(graph_path) and os.path.exists(info_path)

当处理过的数据比较大时,在 getitem(idx) 中处理每个数据实例是更高效的方法。

使用ogb包导入OGB数据集

OGB(Open Graph Benchmark)是一个图深度学习的基准数据集。 官方的 ogb 包提供了用于下载和处理OGB数据集到 dgl.data.DGLGraph 对象的API。

首先需要使用“pip install ogb”安装这个包,接着就可以根据任务从里面加载数据集了。

图属性预测任务(Graph Property Prediction)

类的命名十分统一,只需要执行“dataset = DglGraphPropPredDataset(name=‘ogbg-molhiv’)”即可得到相应的数据集,然后与传统机器学习任务类似,将数据处理为(graph, label)的形式。

# 载入OGB的Graph Property Prediction数据集
import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoaderdef _collate_fn(batch):# 小批次是一个元组(graph, label)列表graphs = [e[0] for e in batch]g = dgl.batch(graphs)labels = [e[1] for e in batch]labels = torch.stack(labels, 0)return g, labels# 载入数据集
dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
# dataloader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)

节点属性预测任务(Node Property Prediction)

类似地,执行“dataset = DglNodePropPredDataset(name=‘ogbn-proteins’)”即可获取数据集,这种数据集只有一个图对象。

# 载入OGB的Node Property Prediction数据集
from ogb.nodeproppred import DglNodePropPredDatasetdataset = DglNodePropPredDataset(name='ogbn-proteins')
split_idx = dataset.get_idx_split()# there is only one graph in Node Property Prediction datasets
# 在Node Property Prediction数据集里只有一个图
g, labels = dataset[0]
# 获取划分的标签
train_label = dataset.labels[split_idx['train']]
valid_label = dataset.labels[split_idx['valid']]
test_label = dataset.labels[split_idx['test']]

链接属性预测任务(Link Property Prediction)

通过执行“dataset = DglLinkPropPredDataset(name=‘ogbl-ppa’)”获取数据集,同样是单图。

# 载入OGB的Link Property Prediction数据集
from ogb.linkproppred import DglLinkPropPredDatasetdataset = DglLinkPropPredDataset(name='ogbl-ppa')
split_edge = dataset.get_edge_split()graph = dataset[0]
print(split_edge['train'].keys())
print(split_edge['valid'].keys())
print(split_edge['test'].keys())


推荐阅读
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 深入理解 SQL 视图、存储过程与事务
    本文详细介绍了SQL中的视图、存储过程和事务的概念及应用。视图为用户提供了一种灵活的数据查询方式,存储过程则封装了复杂的SQL逻辑,而事务确保了数据库操作的完整性和一致性。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • UNP 第9章:主机名与地址转换
    本章探讨了用于在主机名和数值地址之间进行转换的函数,如gethostbyname和gethostbyaddr。此外,还介绍了getservbyname和getservbyport函数,用于在服务器名和端口号之间进行转换。 ... [详细]
  • 掌握远程执行Linux脚本和命令的技巧
    本文将详细介绍如何利用Python的Paramiko库实现远程执行Linux脚本和命令,帮助读者快速掌握这一实用技能。通过具体的示例和详尽的解释,让初学者也能轻松上手。 ... [详细]
  • 基于KVM的SRIOV直通配置及性能测试
    SRIOV介绍、VF直通配置,以及包转发率性能测试小慢哥的原创文章,欢迎转载目录?1.SRIOV介绍?2.环境说明?3.开启SRIOV?4.生成VF?5.VF ... [详细]
  • 深入探讨CPU虚拟化与KVM内存管理
    本文详细介绍了现代服务器架构中的CPU虚拟化技术,包括SMP、NUMA和MPP三种多处理器结构,并深入探讨了KVM的内存虚拟化机制。通过对比不同架构的特点和应用场景,帮助读者理解如何选择最适合的架构以优化性能。 ... [详细]
  • 本题通过将每个矩形视为一个节点,根据其相对位置构建拓扑图,并利用深度优先搜索(DFS)或状态压缩动态规划(DP)求解最小涂色次数。本文详细解析了该问题的建模思路与算法实现。 ... [详细]
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
author-avatar
gerardlong
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有