PyTorch Distributed Training

Datasets keep getting larger, and training has gone from multiple GPUs on one machine to multiple machines. Since version 1.0.0 PyTorch's support for distributed training has become quite solid; the official site has a ready-made tutorial, and there are plenty of introductions to this topic online. This post simply records my own walk through that tutorial, so there are likely plenty of gaps; corrections are welcome.

Base Code

The official tutorial starts from a tiny piece of code and gradually adds functionality.

"""run.py:"""
#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, size):
""" Distributed function to be implemented later. """
pass

def init_process(rank, size, fn, backend='gloo'):
""" Initialize the distributed environment. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend, rank=rank, world_size=size)
fn(rank, size)


if __name__ == "__main__":
size = 2
processes = []
for rank in range(size):
p = Process(target=init_process, args=(rank, size, run))
p.start()
processes.append(p)

for p in processes:
p.join()

The Process class from torch.multiprocessing works much like Python's built-in multiprocessing. The important call is init_process_group, which initializes the process group so that the processes in it can exchange data; MASTER_ADDR and MASTER_PORT tell every process where to reach rank 0 for the initial rendezvous.
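An alternative way to launch the per-rank processes is torch.multiprocessing.spawn, which calls the target function once per process and passes the process index as its first argument. A minimal sketch of the same setup using spawn (my own variant, not part of the tutorial snippet above):

"""spawn_run.py (sketch): same setup as above, launched via mp.spawn."""
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    """ Distributed function to be implemented later. """
    pass

def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    # spawn calls init_process(rank, size, run) once per process,
    # filling in rank with the process index 0..size-1
    mp.spawn(init_process, args=(size, run), nprocs=size, join=True)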

P2P Communication

Next, the tutorial introduces the two point-to-point communication functions send() and recv(). Both are blocking: the calling process waits for the communication to complete before it continues. There are corresponding non-blocking versions, isend() and irecv(), which return a request object immediately; compared with the blocking calls, the thing to remember is to call wait() on that request to make sure the transfer has actually completed before touching the tensor.

"""Blocking point-to-point communication."""

def run(rank, size):
tensor = torch.zeros(1)
if rank == 0:
tensor += 1
# Send the tensor to process 1
dist.send(tensor=tensor, dst=1)
else:
# Receive tensor from process 0
dist.recv(tensor=tensor, src=0)
print('Rank ', rank, ' has data ', tensor[0])

Output:

Rank  0  has data  tensor(1.)
Rank  1  has data  tensor(1.)
"""Non-blocking point-to-point communication."""

def run(rank, size):
tensor = torch.zeros(1)
req = None
if rank == 0:
tensor += 1
# Send the tensor to process 1
req = dist.isend(tensor=tensor, dst=1)
print('Rank 0 started sending')
else:
# Receive tensor from process 0
req = dist.irecv(tensor=tensor, src=0)
print('Rank 1 started receiving')
req.wait()
print('Rank ', rank, ' has data ', tensor[0])

Output:

Rank 1 started receiving
Rank 0 started sending
Rank  0  has data  tensor(1.)
Rank  1  has data  tensor(1.)

Collective Communication

Collective communication refers to operations such as scatter, gather, broadcast, and reduce, which involve every process in a group. The diagrams in the official tutorial make these patterns much clearer.

To run a collective among a specific set of ranks, first create a group by passing the list of member ranks to dist.new_group(group).

""" All-Reduce example."""
def run(rank, size):
""" Simple point-to-point communication. """
group = dist.new_group([0, 1])
tensor = torch.ones(1)
dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
print('Rank ', rank, ' has data ', tensor[0])

The call dist.all_reduce(tensor, op, group) above applies op to tensor across all processes in group; swapping SUM for another op works the same way (a small variant is sketched after the list). The supported ops are:

dist.reduce_op.SUM,
dist.reduce_op.PRODUCT,
dist.reduce_op.MAX,
dist.reduce_op.MIN.
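For instance, a hypothetical variant of the example above that takes the element-wise maximum instead of the sum; since every rank contributes a 1, each rank ends up with tensor(1.):

""" All-Reduce with MAX (hypothetical variant of the example above)."""
def run(rank, size):
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    # Element-wise maximum across the group: every rank keeps tensor(1.)
    dist.all_reduce(tensor, op=dist.reduce_op.MAX, group=group)
    print('Rank ', rank, ' has data ', tensor[0])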

Other collective functions similar to all_reduce include the following (a short scatter/gather sketch follows the list):

dist.broadcast(tensor, src, group): Copies tensor from src to all other processes.
dist.reduce(tensor, dst, op, group): Applies op to every tensor and stores the result in dst.
dist.all_reduce(tensor, op, group): Same as reduce, but the result is stored in all processes.
dist.scatter(tensor, src, scatter_list, group): Copies the ith tensor scatter_list[i] to the ith process.
dist.gather(tensor, dst, gather_list, group): Gathers tensor from all processes into gather_list on dst.
dist.all_gather(tensor_list, tensor, group): Copies tensor from all processes to tensor_list, on all processes.
dist.barrier(group): Blocks all processes in group until each one has entered this function.
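
As a quick illustration of scatter and gather (a hypothetical example of my own, not from the tutorial): rank 0 hands one tensor to each rank, every rank doubles its piece locally, and rank 0 collects the results back.

"""Scatter/gather sketch (run with the same init_process launcher as above)."""
def run(rank, size):
    tensor = torch.zeros(1)
    # Only the source rank provides the list of tensors to scatter
    scatter_list = [torch.ones(1) * (i + 1) for i in range(size)] if rank == 0 else None
    dist.scatter(tensor, scatter_list=scatter_list, src=0)
    tensor *= 2  # each rank does some local work on its piece
    # Only the destination rank provides the list to gather into
    gather_list = [torch.zeros(1) for _ in range(size)] if rank == 0 else None
    dist.gather(tensor, gather_list=gather_list, dst=0)
    if rank == 0:
        print('Rank 0 gathered', gather_list)  # [tensor([2.]), tensor([4.])]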

Distributed Training
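
The tutorial goes on to build distributed SGD out of exactly these primitives: after each backward pass, every rank averages its gradients across the group with all_reduce before stepping the optimizer. A minimal sketch of that gradient-averaging step, simplified from the tutorial's idea (the data partitioning and the surrounding training loop are omitted):

def average_gradients(model):
    """ Average gradients across all ranks; call between loss.backward() and optimizer.step(). """
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
        param.grad.data /= size

Since every rank then applies the same averaged gradients, the model replicas stay in sync across processes.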

