Unverified commit 89951472, authored by lilong12, committed by GitHub

add the support for allreduce_prod for new dygraph (#42284)

Parent 8395d660
@@ -350,18 +350,19 @@ def new_group(ranks=None, backend=None):
         global _default_group_name
         gid = _new_ring_id()
         group_name = _default_group_name + str(gid)
-        global_group = _get_default_group()
-        global_rank = global_group.rank
-        global_ranks = global_group.ranks
-        backend = _default_backend if backend is None else backend
-        if ranks is None:
-            ranks = global_ranks
-        assert len(ranks) <= len(global_ranks), (
-            "Size of new group must be less than or "
-            "equal to that of the default global group.")
+        if ranks is None or len(ranks) > 1:
+            global_group = _get_default_group()
+            global_rank = global_group.rank
+            global_ranks = global_group.ranks
+            backend = _default_backend if backend is None else backend
+            if ranks is None:
+                ranks = global_ranks
+            assert len(ranks) <= len(global_ranks), (
+                "Size of new group must be less than or "
+                "equal to that of the default global group.")
         size = len(ranks)
         ranks = sorted(ranks)
-        if global_rank in ranks and size > 1:
+        if size > 1 and global_rank in ranks:
             rank = ranks.index(global_rank)
             pg = _new_process_group_impl(
                 backend,
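In this hunk, the default-group lookup now only runs when the new group will contain more than one rank, and the membership check tests `size > 1` first, so `global_rank` is never referenced for a single-rank group. A minimal, hypothetical usage sketch; the two-process launch and the rank lists are assumptions for illustration, not part of this diff:

    import paddle.distributed as dist

    dist.init_parallel_env()
    # With this change, a one-rank group skips the global-group bookkeeping above.
    single = dist.new_group(ranks=[0])
    # Multi-rank groups still go through the size and membership checks as before.
    pair = dist.new_group(ranks=[0, 1])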
@@ -642,6 +643,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
             op_type = core.ReduceOp.MAX
         elif op == ReduceOp.MIN:
             op_type = core.ReduceOp.MIN
+        elif op == ReduceOp.PROD:
+            op_type = core.ReduceOp.PRODUCT
         else:
             raise ValueError("Unknown reduce_op type for allreduce.")
         group = _get_default_group() if group is None else group
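With the PROD branch added, all_reduce in the new dygraph mode accepts ReduceOp.PROD. A hedged sketch of calling it, assuming a two-process launch (for example via paddle.distributed.launch); the tensor values are illustrative only:

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    # rank 0 contributes [2., 3.], rank 1 contributes [4., 5.]
    data = paddle.to_tensor([2.0, 3.0]) if dist.get_rank() == 0 else paddle.to_tensor([4.0, 5.0])
    dist.all_reduce(data, op=dist.ReduceOp.PROD)
    # every rank now holds the element-wise product [8., 15.]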
@@ -744,6 +747,8 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
             op_type = core.ReduceOp.MAX
         elif op == ReduceOp.MIN:
             op_type = core.ReduceOp.MIN
+        elif op == ReduceOp.PROD:
+            op_type = core.ReduceOp.PRODUCT
         else:
             raise ValueError("Unknown reduce_op type for reduce.")
         group = _get_default_group() if group is None else group
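reduce gains the same PROD branch; unlike all_reduce, only the destination rank is guaranteed to receive the product. A small companion sketch under the same assumed two-process launch:

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([2.0, 3.0]) if dist.get_rank() == 0 else paddle.to_tensor([4.0, 5.0])
    dist.reduce(data, dst=0, op=dist.ReduceOp.PROD)
    # only rank 0 (the dst) is guaranteed to end up with [8., 15.]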
@@ -219,8 +219,9 @@ def init_parallel_env():
                 "required to create a process group.")
         master_addr = os.getenv("MASTER_ADDR", None)
         master_port = os.getenv("MASTER_PORT", None)
-        endpoints = None
-        if not master_addr or not master_port:
+        endpoints = ":".join(
+            [master_addr, master_port]) if master_addr and master_port else None
+        if endpoints is None:
             endpoints = os.getenv("PADDLE_MASTER", None)
         if endpoints is None:
             endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
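The endpoint lookup now prefers MASTER_ADDR/MASTER_PORT, then PADDLE_MASTER, then the first entry of PADDLE_TRAINER_ENDPOINTS. A standalone sketch of that resolution order; the helper name is made up for illustration and is not a function in parallel.py:

    import os

    def resolve_master_endpoint():
        master_addr = os.getenv("MASTER_ADDR", None)
        master_port = os.getenv("MASTER_PORT", None)
        # 1) MASTER_ADDR:MASTER_PORT when both are set
        endpoints = ":".join(
            [master_addr, master_port]) if master_addr and master_port else None
        # 2) otherwise fall back to PADDLE_MASTER
        if endpoints is None:
            endpoints = os.getenv("PADDLE_MASTER", None)
        # 3) otherwise use the first trainer endpoint
        if endpoints is None:
            endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
        return endpoints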
@@ -122,6 +122,29 @@ class TestProcessGroupFp32(unittest.TestCase):
             print("test allreduce min api ok")

+            # test allreduce prod
+            # rank 0
+            x = np.random.random(self.shape).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            # rank 1
+            y = np.random.random(self.shape).astype(self.dtype)
+            tensor_y = paddle.to_tensor(y)
+
+            prod_result = np.multiply(x, y)
+
+            if pg.rank() == 0:
+                task = dist.all_reduce(
+                    tensor_x, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+                assert np.array_equal(tensor_x, prod_result)
+            else:
+                task = dist.all_reduce(
+                    tensor_y, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+                assert np.array_equal(tensor_y, prod_result)
+
+            print("test allreduce prod api ok")
+
             # test broadcast
             # rank 0
             x = np.random.random(self.shape).astype(self.dtype)
@@ -332,6 +355,27 @@ class TestProcessGroupFp32(unittest.TestCase):
             print("test reduce min api ok")

+            # test reduce product
+            # rank 0
+            x = np.random.random(self.shape).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            # rank 1
+            y = np.random.random(self.shape).astype(self.dtype)
+            tensor_y = paddle.to_tensor(y)
+
+            prod_result = np.multiply(x, y)
+
+            if pg.rank() == 0:
+                task = dist.reduce(
+                    tensor_x, 0, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+                assert np.array_equal(tensor_x, prod_result)
+            else:
+                task = dist.reduce(
+                    tensor_y, 0, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+
+            print("test reduce prod api ok")
+
             # test Scatter
             # rank 0
             in_shape = list(self.shape)