Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7bd02d24
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7bd02d24
编写于
8月 29, 2022
作者:
W
Wen Sun
提交者:
GitHub
8月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Completes basic dtypes for all_reduce api in eager mode (#45440)
上级
bb3e4e0c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
75 addition
and
18 deletion
+75
-18
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+14
-18
python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py
.../unittests/collective/collective_allreduce_api_dygraph.py
+36
-0
python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py
...sts/unittests/collective/test_collective_allreduce_api.py
+25
-0
未找到文件。
python/paddle/distributed/collective.py
浏览文件 @
7bd02d24
...
...
@@ -775,8 +775,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
"""
Reduce a tensor over all ranks so that all get the result.
As shown below,
4 GPUs each start 4 processes and the data on each GPU is repres
nted
by
the GPU number
. The reduce operator is sum. Through all_reduce operator,
As shown below,
one process is started with a GPU and the data of this process is represe
nted
by
its group rank
. The reduce operator is sum. Through all_reduce operator,
each GPU will have the sum of the data from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
...
...
@@ -786,8 +786,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32
or int64
.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
in
|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
IN
|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
Default to True.
...
...
@@ -799,21 +799,16 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import ReduceOp
from paddle.distributed import init_parallel_env
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array
([[4, 5, 6], [4, 5, 6]])
data = paddle.to_tensor
([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
paddle.distributed.all_reduce(data)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
...
...
@@ -849,9 +844,10 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
else
:
raise
ValueError
(
"Unknown parameter: {}."
.
format
(
op
))
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'all_reduce'
)
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
],
'all_reduce'
)
if
op
==
ReduceOp
.
SUM
:
op_type
=
'c_allreduce_sum'
elif
op
==
ReduceOp
.
MAX
:
...
...
@@ -888,7 +884,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32 or int64.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
in
|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
IN
|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
Default to True.
...
...
@@ -984,7 +980,7 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
"""
Gather tensors from all participators and all get the result. As shown
below, 4 GPUs each start
4 processes and the data on each GPU is repres
nted
below, 4 GPUs each start
s 4 processes and the data on each GPU is represe
nted
by the GPU number. Through the all_gather operator, each GPU will have data
from all GPUs.
...
...
@@ -2581,7 +2577,7 @@ def reduce_scatter(tensor,
Args:
tensor (Tensor): Output tensor.
tensor_list (list[Tensor]): List of tensors to reduce and scatter.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
in
|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
IN
|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global
default group. Default: None.
use_calc_stream (bool, optional): Whether this op should be an async op.
...
...
@@ -2654,7 +2650,7 @@ def _reduce_scatter_base(output,
Args:
output (Tensor): Output tensor.
input (Tensor): Input tensor that is of size output tensor size times world size
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
in
|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.M
IN
|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream (False).
...
...
python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py
0 → 100644
浏览文件 @
7bd02d24
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveAllreduceAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
paddle
.
distributed
.
all_reduce
(
tindata
)
return
[
tindata
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveAllreduceAPI
,
"allreduce"
)
python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py
浏览文件 @
7bd02d24
...
...
@@ -41,6 +41,31 @@ class TestCollectiveAllreduceAPI(TestDistBase):
self
.
check_with_place
(
"collective_allreduce_api.py"
,
"allreduce"
,
"gloo"
,
"2"
)
def
test_allreduce_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_allreduce_api_dygraph.py"
,
"allreduce"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
def
test_allreduce_gloo_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_allreduce_api_dygraph.py"
,
"allreduce"
,
"gloo"
,
"2"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录