Commit 5f2a8263, authored by wat3rBro, committed by Francisco Massa

use all_gather to gather results from all gpus (#383)

Parent 9b53d15c
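At a glance, this commit replaces the temp-file based `scatter_gather` with an `all_gather` helper that exchanges arbitrary picklable objects through `torch.distributed.all_gather`. A minimal sketch of the new primitive, assuming the default process group has already been initialized (e.g. via `torch.distributed.launch`); the example object is made up:

    from maskrcnn_benchmark.utils.comm import all_gather, get_rank

    # each rank contributes its own picklable object ...
    local_result = {"rank": get_rank(), "num_images": 100}
    # ... and gets back a list with one entry per rank, in rank order
    gathered = all_gather(local_result)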
@@ -9,7 +9,7 @@ from tqdm import tqdm
 from maskrcnn_benchmark.data.datasets.evaluation import evaluate
 from ..utils.comm import is_main_process
-from ..utils.comm import scatter_gather
+from ..utils.comm import all_gather
 from ..utils.comm import synchronize
@@ -30,7 +30,7 @@ def compute_on_dataset(model, data_loader, device):
 def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
-    all_predictions = scatter_gather(predictions_per_gpu)
+    all_predictions = all_gather(predictions_per_gpu)
     if not is_main_process():
         return
     # merge the list of dicts
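For context, the merging body of `_accumulate_predictions_from_multiple_gpus` is collapsed in the hunk above. A minimal sketch of the pattern, assuming each rank produced a dict mapping image id to prediction; the merging details here are illustrative, not the exact collapsed code:

    from maskrcnn_benchmark.utils.comm import all_gather, is_main_process

    def gather_and_merge_predictions(predictions_per_gpu):
        # every rank contributes its dict; after all_gather each rank
        # holds the full list of per-rank dicts
        all_predictions = all_gather(predictions_per_gpu)
        if not is_main_process():
            return
        # merge the list of dicts into a single dict keyed by image id
        merged = {}
        for p in all_predictions:
            merged.update(p)
        return [merged[i] for i in sorted(merged.keys())]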
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 """
 This file contains primitives for multi-gpu communication.
 This is useful when doing distributed training.
 """
-import os
 import pickle
-import tempfile
 import time

 import torch
+import torch.distributed as dist


 def get_world_size():
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return 1
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return 1
-    return torch.distributed.get_world_size()
+    return dist.get_world_size()


 def get_rank():
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return 0
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return 0
-    return torch.distributed.get_rank()
+    return dist.get_rank()


 def is_main_process():
-    if not torch.distributed.is_available():
-        return True
-    if not torch.distributed.is_initialized():
-        return True
-    return torch.distributed.get_rank() == 0
+    return get_rank() == 0


 def synchronize():
     """
-    Helper function to synchronize between multiple processes when
+    Helper function to synchronize (barrier) among all processes when
     using distributed training
     """
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return
-    world_size = torch.distributed.get_world_size()
-    rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
     if world_size == 1:
         return
@@ -55,7 +49,7 @@ def synchronize():
             tensor = torch.tensor(0, device="cuda")
         else:
             tensor = torch.tensor(1, device="cuda")
-        torch.distributed.broadcast(tensor, r)
+        dist.broadcast(tensor, r)
         while tensor.item() == 1:
             time.sleep(1)
@@ -64,94 +58,73 @@ def synchronize():
     _send_and_wait(1)


-def _encode(encoded_data, data):
-    # gets a byte representation for the data
-    encoded_bytes = pickle.dumps(data)
-    # convert this byte string into a byte tensor
-    storage = torch.ByteStorage.from_buffer(encoded_bytes)
-    tensor = torch.ByteTensor(storage).to("cuda")
-    # encoding: first byte is the size and then rest is the data
-    s = tensor.numel()
-    assert s <= 255, "Can't encode data greater than 255 bytes"
-    # put the encoded data in encoded_data
-    encoded_data[0] = s
-    encoded_data[1 : (s + 1)] = tensor
-
-
-def _decode(encoded_data):
-    size = encoded_data[0]
-    encoded_tensor = encoded_data[1 : (size + 1)].to("cpu")
-    return pickle.loads(bytearray(encoded_tensor.tolist()))
-
-
-# TODO try to use tensor in shared-memory instead of serializing to disk
-# this involves getting the all_gather to work
-def scatter_gather(data):
-    """
-    This function gathers data from multiple processes, and returns them
-    in a list, as they were obtained from each process.
-
-    This function is useful for retrieving data from multiple processes,
-    when launching the code with torch.distributed.launch
-
-    Note: this function is slow and should not be used in tight loops, i.e.,
-    do not use it in the training loop.
-
-    Arguments:
-        data: the object to be gathered from multiple processes.
-            It must be serializable
-
-    Returns:
-        result (list): a list with as many elements as there are processes,
-            where each element i in the list corresponds to the data that was
-            gathered from the process of rank i.
-    """
-    # strategy: the main process creates a temporary directory, and communicates
-    # the location of the temporary directory to all other processes.
-    # each process will then serialize the data to the folder defined by
-    # the main process, and then the main process reads all of the serialized
-    # files and returns them in a list
-    if not torch.distributed.is_available():
-        return [data]
-    if not torch.distributed.is_initialized():
-        return [data]
-    synchronize()
-    # get rank of the current process
-    rank = torch.distributed.get_rank()
-
-    # the data to communicate should be small
-    data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda")
-    if rank == 0:
-        # manually creates a temporary directory, that needs to be cleaned
-        # afterwards
-        tmp_dir = tempfile.mkdtemp()
-        _encode(data_to_communicate, tmp_dir)
-
-    synchronize()
-    # the main process (rank=0) communicates the data to all processes
-    torch.distributed.broadcast(data_to_communicate, 0)
-
-    # get the data that was communicated
-    tmp_dir = _decode(data_to_communicate)
-
-    # each process serializes to a different file
-    file_template = "file{}.pth"
-    tmp_file = os.path.join(tmp_dir, file_template.format(rank))
-    torch.save(data, tmp_file)
-
-    # synchronize before loading the data
-    synchronize()
-
-    # only the master process returns the data
-    if rank == 0:
-        data_list = []
-        world_size = torch.distributed.get_world_size()
-        for r in range(world_size):
-            file_path = os.path.join(tmp_dir, file_template.format(r))
-            d = torch.load(file_path)
-            data_list.append(d)
-            # cleanup
-            os.remove(file_path)
-        # cleanup
-        os.rmdir(tmp_dir)
-        return data_list
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with
+    rank 0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
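The added `reduce_dict` helper has no call site in the hunks shown here; a hypothetical use is averaging per-rank loss dicts for logging, assuming each rank holds scalar CUDA loss tensors (the loss names and values below are made up):

    import torch
    from maskrcnn_benchmark.utils.comm import reduce_dict, is_main_process

    loss_dict = {
        "loss_classifier": torch.tensor(0.7, device="cuda"),
        "loss_box_reg": torch.tensor(0.3, device="cuda"),
    }
    # after the reduce, only rank 0 holds the averaged values
    loss_dict_reduced = reduce_dict(loss_dict, average=True)
    if is_main_process():
        print({k: round(v.item(), 4) for k, v in loss_dict_reduced.items()})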