```python
def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])
```
Output:
```
Rank 0 has data tensor(1.)
Rank 1 has data tensor(1.)
```
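The run function above only performs the communication; it still has to be executed by two processes that have joined the same process group. Below is a minimal launcher sketch, assuming the run function above is defined in the same file, the gloo backend, a single machine, and placeholder MASTER_ADDR/MASTER_PORT values (the init_process helper name is an illustration, not a library API):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, size, fn, backend='gloo'):
    """Join the process group, then call fn(rank, size)."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # assumed single-machine rendezvous
    os.environ['MASTER_PORT'] = '29500'      # arbitrary free port
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        # `run` is the point-to-point function defined above
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```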
"""Non-blocking point-to-point communication."""
defrun(rank, size): tensor = torch.zeros(1) req = None if rank == 0: tensor += 1 # Send the tensor to process 1 req = dist.isend(tensor=tensor, dst=1) print('Rank 0 started sending') else: # Receive tensor from process 0 req = dist.irecv(tensor=tensor, src=0) print('Rank 1 started receiving') req.wait() print('Rank ', rank, ' has data ', tensor[0])
Output:
```
Rank 1 started receiving
Rank 0 started sending
Rank 0 has data tensor(1.)
Rank 1 has data tensor(1.)
```
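The point of isend/irecv is that both ranks can keep computing on other data while the transfer is in flight; the only rule is that the communicated tensor must not be read or written before req.wait() returns. The sketch below illustrates that overlap pattern; the extra `other` tensor and its arithmetic are purely illustrative:

```python
def run(rank, size):
    tensor = torch.zeros(1)
    other = torch.ones(1)
    if rank == 0:
        tensor += 1
        req = dist.isend(tensor=tensor, dst=1)
    else:
        req = dist.irecv(tensor=tensor, src=0)
    # Overlap: work on unrelated data while the transfer is in flight.
    # Do NOT touch `tensor` here -- its contents are undefined until
    # wait() returns.
    other = other * 2
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0], ' and other ', other[0])
```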
- dist.broadcast(tensor, src, group): Copies tensor from src to all other processes.
- dist.reduce(tensor, dst, op, group): Applies op to every tensor and stores the result in dst.
- dist.all_reduce(tensor, op, group): Same as reduce, but the result is stored in all processes.
- dist.scatter(tensor, src, scatter_list, group): Copies the i-th tensor scatter_list[i] to the i-th process.
- dist.gather(tensor, dst, gather_list, group): Copies tensor from all processes in dst.
- dist.all_gather(tensor_list, tensor, group): Copies tensor from all processes to tensor_list, on all processes.
- dist.barrier(group): Blocks all processes in group until each one has entered this function.
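As a concrete use of the collectives listed above, the sketch below sums a one-element tensor across ranks with dist.all_reduce; the explicit group of ranks [0, 1] and the SUM reduce op are assumptions for a two-process job:

```python
def run(rank, size):
    """Simple collective communication: sum a tensor across all ranks."""
    group = dist.new_group([0, 1])       # subgroup containing both ranks
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])
```

With two processes each starting from torch.ones(1), every rank finishes with tensor(2.), since all_reduce leaves the reduced result on every process.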