Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
18d33346
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
18d33346
编写于
11月 09, 2022
作者:
L
LiYuRio
提交者:
GitHub
11月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
new mp_allreduce_sum_op (#47715)
上级
38ba5f2e
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
246 addition
and
17 deletion
+246
-17
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
+4
-6
paddle/fluid/operators/collective/mp_allreduce_sum_op.cc
paddle/fluid/operators/collective/mp_allreduce_sum_op.cc
+97
-0
paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc
paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc
+30
-0
paddle/fluid/operators/collective/mp_allreduce_sum_op.kps
paddle/fluid/operators/collective/mp_allreduce_sum_op.kps
+30
-0
paddle/fluid/operators/collective/mp_allreduce_sum_op_mlu.cc
paddle/fluid/operators/collective/mp_allreduce_sum_op_mlu.cc
+26
-0
paddle/fluid/operators/collective/mp_allreduce_sum_op_npu.cc
paddle/fluid/operators/collective/mp_allreduce_sum_op_npu.cc
+31
-0
paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc
paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc
+23
-0
python/paddle/distributed/fleet/layers/mpu/mp_ops.py
python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+3
-9
python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_static_mp_layers.py
...unittests/collective/fleet/test_fleet_static_mp_layers.py
+2
-2
未找到文件。
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
浏览文件 @
18d33346
...
...
@@ -62,12 +62,10 @@ DECLARE_INPLACE_OP_INFERER(AllreduceSumInplaceInferer, {"X", "Out"});
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP
ERATOR
(
c_allreduce_sum
,
REGISTER_OP
_WITHOUT_GRADIENT
(
c_allreduce_sum
,
ops
::
CAllReduceOp
,
ops
::
CAllReduceSumOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
CAllReduceSumOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
CAllReduceSumOpMaker
,
ops
::
AllreduceSumInplaceInferer
);
ops
::
AllreduceSumInplaceInferer
)
REGISTER_OP_CPU_KERNEL
(
c_allreduce_sum
,
ops
::
CAllReduceOpCPUKernel
<
ops
::
kRedSum
,
float
>
,
...
...
paddle/fluid/operators/collective/mp_allreduce_sum_op.cc
0 → 100644
浏览文件 @
18d33346
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
paddle
{
namespace
framework
{
class
OpDesc
;
}
// namespace framework
namespace
imperative
{
class
OpBase
;
}
// namespace imperative
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
MpAllReduceSumOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
}
};
class
MpAllReduceSumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor), tensor to be allreduced in model parallel."
);
AddOutput
(
"Out"
,
"(Tensor) the allreduced result in model parallel."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) communication ring id."
)
.
SetDefault
(
0
);
#if defined(PADDLE_WITH_ASCEND_CL)
AddAttr
<
std
::
string
>
(
"tag"
,
"(string default tag) tag for all reduce."
)
.
SetDefault
(
"tag"
);
#endif
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
string
::
Sprintf
(
R"DOC(
MpAllReduceSum Operator
Call collective AllReduceSum in model parallel. If input and output are
the same variable, in-place allreduce will be used.
Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce
)DOC"
));
}
};
template
<
typename
T
>
class
MpAllReduceSumOpGradMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
retv
)
const
override
{
retv
->
SetType
(
"c_identity"
);
retv
->
SetInput
(
"X"
,
this
->
OutputGrad
(
"Out"
));
retv
->
SetOutput
(
"Out"
,
this
->
InputGrad
(
"X"
));
retv
->
SetAttrMap
(
this
->
Attrs
());
}
};
DECLARE_INPLACE_OP_INFERER
(
MpAllReduceSumInplaceInferer
,
{
"X"
,
"Out"
});
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OPERATOR
(
mp_allreduce_sum
,
ops
::
MpAllReduceSumOp
,
ops
::
MpAllReduceSumOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
MpAllReduceSumOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
MpAllReduceSumOpMaker
,
ops
::
MpAllReduceSumInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
mp_allreduce_sum
,
ops
::
CAllReduceOpCPUKernel
<
ops
::
kRedSum
,
float
>
,
ops
::
CAllReduceOpCPUKernel
<
ops
::
kRedSum
,
double
>
,
ops
::
CAllReduceOpCPUKernel
<
ops
::
kRedSum
,
int
>
,
ops
::
CAllReduceOpCPUKernel
<
ops
::
kRedSum
,
int64_t
>
,
ops
::
CAllReduceOpCPUKernel
<
ops
::
kRedSum
,
plat
::
float16
>
)
paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc
0 → 100644
浏览文件 @
18d33346
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
mp_allreduce_sum
,
ops
::
CAllReduceOpCUDAKernel
<
ops
::
kRedSum
,
float
>
,
#if NCCL_VERSION_CODE >= 21000
ops
::
CAllReduceOpCUDAKernel
<
ops
::
kRedSum
,
plat
::
bfloat16
>
,
#endif
ops
::
CAllReduceOpCUDAKernel
<
ops
::
kRedSum
,
double
>
,
ops
::
CAllReduceOpCUDAKernel
<
ops
::
kRedSum
,
int
>
,
ops
::
CAllReduceOpCUDAKernel
<
ops
::
kRedSum
,
int64_t
>
,
ops
::
CAllReduceOpCUDAKernel
<
ops
::
kRedSum
,
plat
::
float16
>
)
paddle/fluid/operators/collective/mp_allreduce_sum_op.kps
0 → 100644
浏览文件 @
18d33346
#ifdef PADDLE_WITH_XPU_KP
// Please do not modify the following code
#if defined(__CUDA_ARCH__)
#undef __CUDA_ARCH__
#endif
#if defined(__CUDACC__)
#undef __CUDACC__
#endif
#if defined(__CUDA__)
#undef __CUDA__
#endif
#if defined(__NVCC__)
#undef __NVCC__
#endif
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(mp_allreduce_sum,
KP,
plat::XPUPlace,
ops::CAllReduceOpXPUKernel<ops::kRedSum, float>);
#endif
paddle/fluid/operators/collective/mp_allreduce_sum_op_mlu.cc
0 → 100644
浏览文件 @
18d33346
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_MLU_KERNEL
(
mp_allreduce_sum
,
ops
::
CAllReduceOpMLUKernel
<
ops
::
kRedSum
,
float
>
,
ops
::
CAllReduceOpMLUKernel
<
ops
::
kRedSum
,
plat
::
float16
>
,
ops
::
CAllReduceOpMLUKernel
<
ops
::
kRedSum
,
int
>
,
ops
::
CAllReduceOpMLUKernel
<
ops
::
kRedSum
,
int16_t
>
,
ops
::
CAllReduceOpMLUKernel
<
ops
::
kRedSum
,
int8_t
>
,
ops
::
CAllReduceOpMLUKernel
<
ops
::
kRedSum
,
uint8_t
>
)
paddle/fluid/operators/collective/mp_allreduce_sum_op_npu.cc
0 → 100644
浏览文件 @
18d33346
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
paddle
{
namespace
platform
{
struct
ASCENDPlace
;
}
// namespace platform
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
mp_allreduce_sum
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
int
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
int8_t
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
float
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
plat
::
float16
>
)
paddle/fluid/operators/collective/mp_allreduce_sum_op_xpu.cc
0 → 100644
浏览文件 @
18d33346
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_XPU_KERNEL
(
mp_allreduce_sum
,
ops
::
CAllReduceOpXPUKernel
<
ops
::
kRedSum
,
float
>
,
ops
::
CAllReduceOpXPUKernel
<
ops
::
kRedSum
,
plat
::
float16
>
,
ops
::
CAllReduceOpXPUKernel
<
ops
::
kRedSum
,
int
>
)
python/paddle/distributed/fleet/layers/mpu/mp_ops.py
浏览文件 @
18d33346
...
...
@@ -266,8 +266,6 @@ def _mp_allreduce(
use_calc_stream
,
'ring_id'
,
ring_id
,
"use_model_parallel"
,
use_model_parallel
,
)
@
staticmethod
...
...
@@ -289,19 +287,17 @@ def _mp_allreduce(
ring_id
=
0
if
group
is
None
else
group
.
id
if
_in_legacy_dygraph
():
if
op
==
ReduceOp
.
SUM
:
return
_legacy_C_ops
.
c
_allreduce_sum_
(
return
_legacy_C_ops
.
mp
_allreduce_sum_
(
tensor
,
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
"use_model_parallel"
,
use_model_parallel
,
)
else
:
raise
ValueError
(
"Unknown parameter: {}."
.
format
(
op
))
op_type
=
'
c
_allreduce_sum'
op_type
=
'
mp
_allreduce_sum'
helper
=
LayerHelper
(
op_type
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
tensor
.
dtype
)
...
...
@@ -319,7 +315,6 @@ def _mp_allreduce(
attrs
=
{
'ring_id'
:
ring_id
,
'use_calc_stream'
:
use_calc_stream
,
'use_model_parallel'
:
use_model_parallel
,
},
)
return
out
...
...
@@ -602,13 +597,12 @@ def _parallel_linear(
)
if
axis
==
0
:
main_block
.
append_op
(
type
=
'
c
_allreduce_sum'
,
type
=
'
mp
_allreduce_sum'
,
inputs
=
{
'X'
:
linear_out
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'ring_id'
:
ring_id
,
'use_calc_stream'
:
True
,
'use_model_parallel'
:
True
,
},
)
if
linear
.
bias
is
not
None
:
...
...
python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_static_mp_layers.py
浏览文件 @
18d33346
...
...
@@ -128,7 +128,7 @@ class TestDistTraning(unittest.TestCase):
ops
=
[
op
.
type
for
op
in
ops
]
self
.
assertEqual
(
ops
,
[
'c_split'
,
'matmul_v2'
,
'
c
_allreduce_sum'
,
'elementwise_add'
],
[
'c_split'
,
'matmul_v2'
,
'
mp
_allreduce_sum'
,
'elementwise_add'
],
)
weight
=
model_a
.
parallel_linear
.
weight
...
...
@@ -156,7 +156,7 @@ class TestDistTraning(unittest.TestCase):
# print(main_program)
ops
=
main_program
.
global_block
().
ops
ops
=
[
op
.
type
for
op
in
ops
]
self
.
assertEqual
(
ops
,
[
'c_embedding'
,
'
c
_allreduce_sum'
])
self
.
assertEqual
(
ops
,
[
'c_embedding'
,
'
mp
_allreduce_sum'
])
weight
=
model_a
.
embedding
.
weight
self
.
assertEqual
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录