热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

multitask训练torch_Pytorch多机多卡分布式训练

被这东西刁难两天了,终于想办法解决掉了,来造福下人民群众。关于Pytorch分布训练的话,大家一开始接触的往往是DataParallel&

被这东西刁难两天了,终于想办法解决掉了,来造福下人民群众。

关于Pytorch分布训练的话,大家一开始接触的往往是DataParallel,这个wrapper能够很方便的使用多张卡,而且将进程控制在一个。唯一的问题就在于,DataParallel只能满足一台机器上gpu的通信,而一台机器一般只能装8张卡,对于一些大任务,8张卡就很吃力了,这个时候我们就需要面对多机多卡分布式训练这个问题了,噩梦开始了。

官方pytorch(v1.0.10)在分布式上给出的api有这么几个比较重要的:

torch.nn.parallel.DistributedDataParallel :

这个从名字上就能看出来与DataParallel相类似,也是一个模型wrapper。这个包是实现多机多卡分布训练最核心东西,它可以帮助我们在不同机器的多个模型拷贝之间平均梯度。

2. torch.utils.data.distributed.DistributedSampler:

在多机多卡情况下分布式训练数据的读取也是一个问题,不同的卡读取到的数据应该是不同的。dataparallel的做法是直接将batch切分到不同的卡,这种方法对于多机来说不可取,因为多机之间直接进行数据传输会严重影响效率。于是有了利用sampler确保dataloader只会load到整个数据集的一个特定子集的做法。DistributedSampler就是做这件事的。它为每一个子进程划分出一部分数据集,以避免不同进程之间数据重复。

# 分布式训练示例

from torch.utils.data import Dataset, DataLoader

from torch.utils.data.distributed import DistributedSampler

from torch.nn.parallel import DistributedDataParallel

dataset = your_dataset()

datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=datasampler)

model = your_model()

model = DistributedDataPrallel(model, device_ids=[local_rank], output_device=local_rank)

其他部分就和正常训练代码无异了。

得提的几个点:

和dataparallel不同,dataparallel需要将batchsize设置成n倍的单卡batchsize,而distributedsampler使用的情况下,batchsize设置与单卡设置相同。

这里有几个新的参数:world size, rank, local rank, rank。world size指进程总数,在这里就是我们使用的卡数;rank指进程序号,local_rank指本地序号,两者的区别在于前者用于进程间通讯,后者用于本地设备分配。这个时候真正麻烦的地方来了:

想要使用DistributedDataParallel,得先完成多进程的初始化,就是这个:

torch.distributed.init_process_group()

看官方的说明:

gloo基本只支持cpu,不考虑。mpi需要在本地重新编译pytorch,感兴趣的朋友可以试试。nccl对gpu支持良好还不需要重新编译,在下和官方都强烈推荐这个作为backend。

pytorch作者推荐的初始化方式:

我最后的实现也是利用这种方式。但我面临的问题是:如何在我们的slurm集群上完成这个初始化并进行训练,那么问题就变成了如何在slurm集群上把你分配到的ip写进程序里。两个办法:

1.srun指定-n 进程总数以及 –ntasks-per-node 每个节点进程数,这样就可以通过os.environ获得每个进程的节点ip信息,全局rank以及local rank,有了这些就可以很方便很方便的完成初始化。推荐使用该方法(感谢评论区大佬指点)

2.salloc,这个就相对霸道一些,直接指定几个节点自己拿来用,这样就很容易选出来通信用的节点,再随便给个端口,我们就能完成初始化。相比1还是麻烦不少。

关于获取节点信息的详细代码:

import os

os.environ['SLURM_NTASKS'] #可用作world size

os.environ['SLURM_NODEID'] #node id

os.environ['SLURM_PROCID'] #可用作全局rank

os.environ['SLURM_LOCALID'] #local_rank

os.environ['SLURM_STEP_NODELIST'] #从中取得一个ip作为通讯ip

贴段差不多能跑的代码吧:

import torch

torch.multiprocessing.set_start_method('spawn')

import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

from torch.utils.data.distributed import DistributedSampler

from torch.nn.parallel import DistributedDataParallel

import os

def dist_init(host_addr, rank, local_rank, world_size, port=23456):

host_addr_full = 'tcp://' + host_addr + ':' + str(port)

torch.distributed.init_process_group("nccl", init_method=host_addr_full,

rank=rank, world_size=world_size)

num_gpus = torch.cuda.device_count()

torch.cuda.set_device(local_rank)

assert torch.distributed.is_initialized()

rank = int(os.environ['SLURM_PROCID'])

local_rank = int(os.environ['SLURM_LOCALID'])

world_size = int(os.environ['SLURM_NTASKS'])

# get_ip函数自己写一下 不同服务器这个字符串形式不一样

# 保证所有task拿到的是同一个ip就成

ip = get_ip(os.environ['SLURM_STEP_NODELIST'])

dist_init(ip, rank, local_rank, world_size)

# 接下来是写dataset和dataloader,这个网上有很多教程

# 我这给的也只是个形式,按自己需求写好就ok

dataset = your_dataset() #主要是把这写好

datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=source_sampler)

model = your_model() #也是按自己的模型写

model = DistributedDataPrallel(model, device_ids=[local_rank], output_device=local_rank)

# 此后训练流程与普通模型无异

照上面写好train.py之后(叫啥都行,这儿就叫train.py吧),slrum指令写这样:

# 这里是3台机器,每台机器8张卡的样子

srun -n24 --gres=gpu:8 --ntasks-per-node=8 python train.py

—————————————分割线—————————————–

以下是关于方法2的补充,1已经足够用了,保留这部分只是为了文章完整。

然而,事情并没有这么简单。

如果一次salloc多个节点,那么接下来的srun指令默认是每个节点执行一个拷贝,也就是说,我们的rank是无法保证两个节点上是不一样的。怎么办?

参考这个

需要多少个节点,就开多少个窗口,每个窗口salloc一个节点,就能解决上面的问题。

最后,如何方便的得到多个rank不一样的进程,其实pytorch已经有模块可以很好的解决这个问题:

参考:

采用torch.distributed.launch模块可以很容易为每个进程得到不一样的localrank,再在不同的节点指定不同的rank初始值,slurm群组上就可以开心的pytorch多机分布式训练啦。



推荐阅读
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 本文介绍了计算机网络的定义和通信流程,包括客户端编译文件、二进制转换、三层路由设备等。同时,还介绍了计算机网络中常用的关键词,如MAC地址和IP地址。 ... [详细]
  • 本文介绍了如何在Mac上使用Pillow库加载不同于默认字体和大小的字体,并提供了一个简单的示例代码。通过该示例,读者可以了解如何在Python中使用Pillow库来写入不同字体的文本。同时,本文也解决了在Mac上使用Pillow库加载字体时可能遇到的问题。读者可以根据本文提供的示例代码,轻松实现在Mac上使用Pillow库加载不同字体的功能。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 计算机存储系统的层次结构及其优势
    本文介绍了计算机存储系统的层次结构,包括高速缓存、主存储器和辅助存储器三个层次。通过分层存储数据可以提高程序的执行效率。计算机存储系统的层次结构将各种不同存储容量、存取速度和价格的存储器有机组合成整体,形成可寻址存储空间比主存储器空间大得多的存储整体。由于辅助存储器容量大、价格低,使得整体存储系统的平均价格降低。同时,高速缓存的存取速度可以和CPU的工作速度相匹配,进一步提高程序执行效率。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • 关键词:Golang, Cookie, 跟踪位置, net/http/cookiejar, package main, golang.org/x/net/publicsuffix, io/ioutil, log, net/http, net/http/cookiejar ... [详细]
  • 本文由编程笔记小编整理,主要介绍了使用Junit和黄瓜进行自动化测试中步骤缺失的问题。文章首先介绍了使用cucumber和Junit创建Runner类的代码,然后详细说明了黄瓜功能中的步骤和Steps类的实现。本文对于需要使用Junit和黄瓜进行自动化测试的开发者具有一定的参考价值。摘要长度:187字。 ... [详细]
  • 重入锁(ReentrantLock)学习及实现原理
    本文介绍了重入锁(ReentrantLock)的学习及实现原理。在学习synchronized的基础上,重入锁提供了更多的灵活性和功能。文章详细介绍了重入锁的特性、使用方法和实现原理,并提供了类图和测试代码供读者参考。重入锁支持重入和公平与非公平两种实现方式,通过对比和分析,读者可以更好地理解和应用重入锁。 ... [详细]
  • NotSupportedException无法将类型“System.DateTime”强制转换为类型“System.Object”
    本文介绍了在使用LINQ to Entities时出现的NotSupportedException异常,该异常是由于无法将类型“System.DateTime”强制转换为类型“System.Object”所导致的。同时还介绍了相关的错误信息和解决方法。 ... [详细]
  • 1Lock与ReadWriteLock1.1LockpublicinterfaceLock{voidlock();voidlockInterruptibl ... [详细]
author-avatar
飛仔2502897013
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有