Unverified · Commit 18d55e54, authored by Quentin Anthony, committed by GitHub

Update barrier and reduce_scatter_base to conform to PyTorch signatures (#2570)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 731965db
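A rough usage sketch of the updated barrier and reduce-scatter signatures (illustrative only, not part of the commit; it assumes deepspeed.comm re-exports these helpers and that the script runs under a multi-process launcher such as torchrun or the deepspeed launcher):

    # Sketch: calling the updated deepspeed.comm wrappers with the new keyword arguments.
    import os
    import torch
    import deepspeed
    import deepspeed.comm as dist

    deepspeed.init_distributed()  # sets up the default process group
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # barrier now forwards group/async_op/device_ids to torch.distributed.barrier
    dist.barrier(group=None, async_op=False, device_ids=None)

    # reduce_scatter_fn now accepts an explicit reduction op, matching torch.distributed
    flat = torch.full((4 * world_size,), float(rank + 1), device="cuda")
    out = torch.empty(4, device="cuda")
    dist.reduce_scatter_fn(out, flat, op=dist.ReduceOp.SUM)
    # each rank now holds the sum over ranks of its 4-element chunk,
    # i.e. every element equals 1 + 2 + ... + world_size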
@@ -255,6 +255,7 @@ def has_reduce_scatter_base():
def reduce_scatter_fn(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
@@ -265,6 +266,7 @@ def reduce_scatter_fn(output_tensor,
if cdb.has_reduce_scatter_base:
return reduce_scatter_base(output_tensor,
tensor,
op=op,
group=group,
async_op=async_op,
prof=prof,
@@ -279,6 +281,7 @@ def reduce_scatter_fn(output_tensor,
input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
return reduce_scatter(output_tensor,
input_tensor_lst,
op=op,
group=group,
async_op=async_op,
prof=prof,
@@ -288,6 +291,7 @@ def reduce_scatter_fn(output_tensor,
@timed_op
def reduce_scatter_base(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
@@ -296,6 +300,7 @@ def reduce_scatter_base(output_tensor,
global cdb
return cdb.reduce_scatter_base(output_tensor=output_tensor,
input_tensor=tensor,
op=op,
group=group,
async_op=async_op)
@@ -453,9 +458,25 @@ def scatter(tensor,
@timed_op
def barrier(group=None, prof=False, log_name='barrier', debug=get_caller_func()):
def barrier(group=None,
async_op=False,
device_ids=None,
prof=False,
log_name='barrier',
debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@timed_op
def monitored_barrier(group=None,
timeout=None,
wait_all_ranks=False,
prof=False,
log_name='monitored_barrier',
debug=get_caller_func()):
global cdb
return cdb.barrier()
return cdb.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
def log_summary():
......
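For reference, the fallback branch above chunks the flat input into world_size pieces and hands the list to the regular reduce_scatter, while the preferred branch passes the flat tensor straight to reduce_scatter_base. A minimal torch.distributed-only sketch of the two equivalent paths (illustrative, not from the commit):

    # Sketch: fused vs. chunked reduce-scatter, mirrored in raw torch.distributed.
    # Launch with e.g.: torchrun --nproc_per_node=2 this_script.py
    import os
    import torch
    import torch.distributed as td

    td.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = td.get_world_size()

    flat_input = torch.arange(4 * world_size, dtype=torch.float32, device="cuda")
    output = torch.empty(4, dtype=torch.float32, device="cuda")

    if hasattr(td, "_reduce_scatter_base"):
        # Fused path: one flat input tensor, as in reduce_scatter_base above.
        td._reduce_scatter_base(output, flat_input, op=td.ReduceOp.SUM)
    else:
        # Fallback path: chunk the flat tensor per rank, as in the torch.chunk branch above.
        td.reduce_scatter(output, list(torch.chunk(flat_input, world_size)), op=td.ReduceOp.SUM)

    # Either way, rank r ends up with the element-wise sum of chunk r across all ranks.
    td.destroy_process_group()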
@@ -95,11 +95,13 @@ class TorchBackend(Backend):
def reduce_scatter_base(self,
output_tensor,
input_tensor,
op=ReduceOp.SUM,
group=None,
async_op=False):
if self.has_reduce_scatter_base:
return torch.distributed._reduce_scatter_base(output_tensor,
input_tensor,
op=self._reduce_op(op),
group=group,
async_op=async_op)
else:
@@ -149,8 +151,25 @@ class TorchBackend(Backend):
group=group,
async_op=async_op)
def barrier(self):
return torch.distributed.barrier()
def barrier(self,
group=torch.distributed.GroupMember.WORLD,
async_op=False,
device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group,
async_op=async_op,
device_ids=device_ids)
def monitored_barrier(self,
group=torch.distributed.GroupMember.WORLD,
timeout=None,
wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.monitored_barrier(group=group,
timeout=timeout,
wait_all_ranks=wait_all_ranks)
def get_rank(self, group=None):
return torch.distributed.get_rank(group=group)
......
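The monitored_barrier wrapper added to TorchBackend above is a thin passthrough to torch.distributed.monitored_barrier, which, unlike a plain barrier, raises on rank 0 with the ranks that never arrived. A small debugging sketch (hedged: upstream PyTorch implements monitored_barrier only for the GLOO backend, and the timeout below is illustrative):

    # Sketch: using monitored_barrier to find straggler ranks (GLOO-only in PyTorch).
    # Assumes a default process group is already initialized.
    import datetime
    import torch.distributed as td

    gloo_group = td.new_group(backend="gloo")  # monitored_barrier needs a GLOO group
    try:
        # wait_all_ranks=True makes rank 0 report *all* missing ranks instead of
        # failing on the first one it notices.
        td.monitored_barrier(group=gloo_group,
                             timeout=datetime.timedelta(seconds=60),
                             wait_all_ranks=True)
    except RuntimeError as err:
        print(f"Ranks missing from the barrier: {err}")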