Unverified commit 89951472, authored by lilong12 and committed by GitHub

add support for allreduce_prod in new dygraph mode (#42284)
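For orientation, here is a minimal usage sketch of the two APIs this commit extends. It assumes a multi-process launch that has already called dist.init_parallel_env(); the tensor values are illustrative only and are not taken from the patch.

    import paddle
    import paddle.distributed as dist

    # Each rank holds its own copy of t after init_parallel_env().
    t = paddle.to_tensor([2.0, 3.0])
    # Element-wise product across all ranks; every rank receives the result.
    dist.all_reduce(t, op=dist.ReduceOp.PROD)
    # Same reduction, but only destination rank 0 receives the result.
    dist.reduce(t, dst=0, op=dist.ReduceOp.PROD)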

Parent: 8395d660
@@ -350,6 +350,7 @@ def new_group(ranks=None, backend=None):
         global _default_group_name
         gid = _new_ring_id()
         group_name = _default_group_name + str(gid)
+        if ranks is None or len(ranks) > 1:
             global_group = _get_default_group()
             global_rank = global_group.rank
             global_ranks = global_group.ranks
@@ -361,7 +362,7 @@ def new_group(ranks=None, backend=None):
                 "equal to that of the default global group.")
         size = len(ranks)
         ranks = sorted(ranks)
-        if global_rank in ranks and size > 1:
+        if size > 1 and global_rank in ranks:
             rank = ranks.index(global_rank)
             pg = _new_process_group_impl(
                 backend,
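The two new_group hunks above make the default-group lookup conditional on `ranks is None or len(ranks) > 1` and reorder the guard so that `size > 1` short-circuits before the membership test. A hypothetical call pattern that exercises both branches (assuming at least two launched ranks):

    import paddle.distributed as dist

    dist.init_parallel_env()
    # len(ranks) > 1: resolves the default group, then builds a new process group.
    pair = dist.new_group(ranks=[0, 1])
    # Single-rank group: the reordered guard skips process-group construction.
    solo = dist.new_group(ranks=[dist.get_rank()])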
@@ -642,6 +643,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
             op_type = core.ReduceOp.MAX
         elif op == ReduceOp.MIN:
             op_type = core.ReduceOp.MIN
+        elif op == ReduceOp.PROD:
+            op_type = core.ReduceOp.PRODUCT
         else:
             raise ValueError("Unknown reduce_op type for allreduce.")
         group = _get_default_group() if group is None else group
@@ -744,6 +747,8 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
             op_type = core.ReduceOp.MAX
         elif op == ReduceOp.MIN:
             op_type = core.ReduceOp.MIN
+        elif op == ReduceOp.PROD:
+            op_type = core.ReduceOp.PRODUCT
         else:
             raise ValueError("Unknown reduce_op type for reduce.")
         group = _get_default_group() if group is None else group
@@ -219,8 +219,9 @@ def init_parallel_env():
                 "required to create a process group.")
         master_addr = os.getenv("MASTER_ADDR", None)
         master_port = os.getenv("MASTER_PORT", None)
-        endpoints = None
-        if not master_addr or not master_port:
+        endpoints = ":".join(
+            [master_addr, master_port]) if master_addr and master_port else None
+        if endpoints is None:
             endpoints = os.getenv("PADDLE_MASTER", None)
         if endpoints is None:
             endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
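The init_parallel_env hunk above replaces the endpoint lookup with a three-step fallback: MASTER_ADDR plus MASTER_PORT first, then PADDLE_MASTER, then the first entry of PADDLE_TRAINER_ENDPOINTS. A standalone sketch of that precedence, using hypothetical environment values:

    import os

    # Hypothetical values for illustration.
    os.environ["MASTER_ADDR"] = "10.0.0.1"
    os.environ["MASTER_PORT"] = "6170"
    os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6170,127.0.0.1:6171"

    master_addr = os.getenv("MASTER_ADDR", None)
    master_port = os.getenv("MASTER_PORT", None)
    endpoints = ":".join(
        [master_addr, master_port]) if master_addr and master_port else None
    if endpoints is None:
        endpoints = os.getenv("PADDLE_MASTER", None)
    if endpoints is None:
        endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
    print(endpoints)  # -> 10.0.0.1:6170, since MASTER_* wins over the endpoint list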
@@ -122,6 +122,29 @@ class TestProcessGroupFp32(unittest.TestCase):
             print("test allreduce min api ok")
 
+            # test allreduce prod
+            # rank 0
+            x = np.random.random(self.shape).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            # rank 1
+            y = np.random.random(self.shape).astype(self.dtype)
+            tensor_y = paddle.to_tensor(y)
+
+            prod_result = np.multiply(x, y)
+
+            if pg.rank() == 0:
+                task = dist.all_reduce(
+                    tensor_x, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+                assert np.array_equal(tensor_x, prod_result)
+            else:
+                task = dist.all_reduce(
+                    tensor_y, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+                assert np.array_equal(tensor_y, prod_result)
+
+            print("test allreduce prod api ok")
+
             # test broadcast
             # rank 0
             x = np.random.random(self.shape).astype(self.dtype)
@@ -332,6 +355,27 @@ class TestProcessGroupFp32(unittest.TestCase):
             print("test reduce min api ok")
 
+            # test reduce product
+            # rank 0
+            x = np.random.random(self.shape).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            # rank 1
+            y = np.random.random(self.shape).astype(self.dtype)
+            tensor_y = paddle.to_tensor(y)
+
+            prod_result = np.multiply(x, y)
+
+            if pg.rank() == 0:
+                task = dist.reduce(
+                    tensor_x, 0, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+                assert np.array_equal(tensor_x, prod_result)
+            else:
+                task = dist.reduce(
+                    tensor_y, 0, dist.ReduceOp.PROD, use_calc_stream=False)
+                task.wait()
+
+            print("test reduce prod api ok")
+
             # test Scatter
             # rank 0
             in_shape = list(self.shape)
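The tests above assume two externally launched ranks. For a self-contained check of the new PROD reduction, one option is paddle.distributed.spawn; the worker below is a hypothetical driver written for this note, not part of the test file:

    import numpy as np
    import paddle
    import paddle.distributed as dist

    def worker():
        dist.init_parallel_env()
        # rank 0 contributes [2., 3.]; rank 1 contributes [4., 6.]
        data = np.array([2.0, 3.0]) * (dist.get_rank() + 1)
        t = paddle.to_tensor(data)
        dist.all_reduce(t, op=dist.ReduceOp.PROD)
        # Every rank should now hold the element-wise product [8., 18.]
        assert np.allclose(t.numpy(), [8.0, 18.0])

    if __name__ == "__main__":
        dist.spawn(worker, nprocs=2)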