Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ee76ea72
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ee76ea72
编写于
2月 24, 2021
作者:
Q
Qi Li
提交者:
GitHub
2月 24, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] update fluid collective op for rocm, test=develop (#31075)
上级
d8fa65a3
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
82 addition
and
43 deletion
+82
-43
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
+0
-2
paddle/fluid/operators/benchmark/op_tester.cc
paddle/fluid/operators/benchmark/op_tester.cc
+1
-1
paddle/fluid/operators/collective/CMakeLists.txt
paddle/fluid/operators/collective/CMakeLists.txt
+1
-1
paddle/fluid/operators/collective/allreduce_op.h
paddle/fluid/operators/collective/allreduce_op.h
+6
-2
paddle/fluid/operators/collective/barrier_op.cu.cc
paddle/fluid/operators/collective/barrier_op.cu.cc
+6
-2
paddle/fluid/operators/collective/broadcast_op.cu.cc
paddle/fluid/operators/collective/broadcast_op.cu.cc
+6
-2
paddle/fluid/operators/collective/c_allgather_op.cu.cc
paddle/fluid/operators/collective/c_allgather_op.cu.cc
+3
-3
paddle/fluid/operators/collective/c_allreduce_op.h
paddle/fluid/operators/collective/c_allreduce_op.h
+3
-3
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+3
-3
paddle/fluid/operators/collective/c_comm_init_all_op.cc
paddle/fluid/operators/collective/c_comm_init_all_op.cc
+2
-2
paddle/fluid/operators/collective/c_comm_init_op.cc
paddle/fluid/operators/collective/c_comm_init_op.cc
+6
-2
paddle/fluid/operators/collective/c_reduce_op.h
paddle/fluid/operators/collective/c_reduce_op.h
+3
-3
paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+3
-3
paddle/fluid/operators/collective/c_scatter_op.cu.cc
paddle/fluid/operators/collective/c_scatter_op.cu.cc
+3
-3
paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
+5
-1
paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
+6
-2
paddle/fluid/operators/collective/recv_v2_op.cu.cc
paddle/fluid/operators/collective/recv_v2_op.cu.cc
+13
-3
paddle/fluid/operators/collective/send_v2_op.cu.cc
paddle/fluid/operators/collective/send_v2_op.cu.cc
+10
-3
paddle/fluid/operators/detail/strided_memcpy.h
paddle/fluid/operators/detail/strided_memcpy.h
+2
-2
未找到文件。
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
浏览文件 @
ee76ea72
...
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda.h>
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
...
...
paddle/fluid/operators/benchmark/op_tester.cc
浏览文件 @
ee76ea72
...
...
@@ -77,7 +77,7 @@ void OpTester::Run() {
if
(
platform
::
is_cpu_place
(
place_
))
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
}
else
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
platform
::
SetDeviceId
(
config_
.
device_id
);
#else
...
...
paddle/fluid/operators/collective/CMakeLists.txt
浏览文件 @
ee76ea72
...
...
@@ -13,7 +13,7 @@ endforeach()
register_operators
(
EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
if
(
WITH_NCCL
)
if
(
WITH_NCCL
OR WITH_RCCL
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
nccl_common collective_helper
)
op_library
(
c_gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
op_library
(
gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
...
...
paddle/fluid/operators/collective/allreduce_op.h
浏览文件 @
ee76ea72
...
...
@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -36,7 +36,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"AllReduce op can run on gpu place only for now."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
...
@@ -73,7 +73,11 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
comm
,
stream
));
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
...
...
paddle/fluid/operators/collective/barrier_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/barrier_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,7 @@ template <typename T>
class
BarrierOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
...
@@ -45,7 +45,11 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
sendbuff
,
recvbuff
,
numel
,
dtype
,
nccl_red_type
,
comm
->
comm
(),
stream
));
auto
comm_stream
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
)
->
stream
();
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
comm_stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
comm_stream
));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with NCCL."
));
...
...
paddle/fluid/operators/collective/broadcast_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -33,7 +33,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
platform
::
errors
::
PreconditionNotMet
(
"The place of ExecutionContext should be CUDAPlace."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
int
dev_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
()).
device
;
int
root_dev_id
=
ctx
.
Attr
<
int
>
(
"root"
);
...
...
@@ -62,7 +62,11 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
<<
" From "
<<
root_dev_id
<<
" to "
<<
dev_id
;
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
...
...
paddle/fluid/operators/collective/c_allgather_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,7 @@ template <typename T>
class
CAllGatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
ncclDataType_t
dtype
=
platform
::
ToNCCLDataType
(
in
->
type
());
...
...
@@ -48,7 +48,7 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
const
T
*
send_buff
=
in
->
data
<
T
>
();
T
*
recv_buff
=
out
->
data
<
T
>
();
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
...
...
paddle/fluid/operators/collective/c_allreduce_op.h
浏览文件 @
ee76ea72
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -109,7 +109,7 @@ template <ReduceType red_type, typename T>
class
CAllReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
...
@@ -123,7 +123,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
...
...
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,7 @@ template <typename T>
class
CBroadcastOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
numel
=
x
->
numel
();
...
...
@@ -36,7 +36,7 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
...
...
paddle/fluid/operators/collective/c_comm_init_all_op.cc
浏览文件 @
ee76ea72
...
...
@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -52,7 +52,7 @@ class CCommInitAllOp : public framework::OperatorBase {
platform
::
errors
::
PreconditionNotMet
(
"CCommInitAllOp can run on gpu place only"
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
std
::
vector
<
int
>
devices
=
Attr
<
std
::
vector
<
int
>>
(
"devices"
);
if
(
devices
.
empty
())
{
devices
=
platform
::
GetSelectedDevices
();
...
...
paddle/fluid/operators/collective/c_comm_init_op.cc
浏览文件 @
ee76ea72
...
...
@@ -14,6 +14,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL)
#include <nccl.h>
#endif
#if defined(PADDLE_WITH_RCCL)
#include <rccl.h>
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif
...
...
@@ -26,7 +29,8 @@ namespace framework {
class
Scope
;
}
// namespace framework
}
// namespace paddle
#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
...
...
@@ -50,7 +54,7 @@ class CCommInitOp : public framework::OperatorBase {
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
InvalidArgument
(
"Input con not be empty."
));
if
(
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
ncclUniqueId
*
nccl_id
=
var
->
GetMutable
<
ncclUniqueId
>
();
int
nranks
=
Attr
<
int
>
(
"nranks"
);
...
...
paddle/fluid/operators/collective/c_reduce_op.h
浏览文件 @
ee76ea72
...
...
@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -114,7 +114,7 @@ template <ReduceType red_type, typename T>
class
CReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
...
@@ -129,7 +129,7 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
int
root
=
ctx
.
Attr
<
int
>
(
"root_id"
);
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
...
...
paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,7 @@ template <typename T>
class
CReduceScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
...
@@ -49,7 +49,7 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
T
*
recv_buff
=
out
->
data
<
T
>
();
int
dtype
=
platform
::
ToNCCLDataType
(
in
->
type
());
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
...
...
paddle/fluid/operators/collective/c_scatter_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_scatter_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,7 @@ template <typename T>
class
CScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
numel
=
x
->
numel
();
...
...
@@ -53,7 +53,7 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
"The ring_id (%d) for c_scatter_op must be non-negative."
,
ring_id
));
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
...
...
paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
浏览文件 @
ee76ea72
...
...
@@ -37,10 +37,14 @@ class CSyncCalcStreamOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Sync stream op can run on gpu place only for now."
));
#if
defined(PADDLE_WITH_CUDA
) && !defined(_WIN32)
#if
(defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
) && !defined(_WIN32)
auto
dev_ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
dev_ctx
->
stream
()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
dev_ctx
->
stream
()));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with GPU."
));
...
...
paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
浏览文件 @
ee76ea72
...
...
@@ -19,7 +19,7 @@ namespace framework {
class
Scope
;
}
// namespace framework
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
...
...
@@ -40,11 +40,15 @@ class CSyncCommStreamOp : public framework::OperatorBase {
platform
::
errors
::
PreconditionNotMet
(
"Sync stream op can run on gpu place only for now."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
auto
stream
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
)
->
stream
();
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with GPU."
));
...
...
paddle/fluid/operators/collective/recv_v2_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/recv_v2_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,8 @@ template <typename T>
class
RecvOpV2CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
rid
,
0
,
...
...
@@ -45,7 +46,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
framework
::
proto
::
VarType
::
Type
type
=
framework
::
proto
::
VarType
::
Type
(
data_type
);
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
...
...
@@ -65,12 +66,21 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
// Recv the number of elements to receive first
int
numel
=
0
;
int
*
numel_ptr
=
nullptr
;
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMalloc
(
&
numel_ptr
,
sizeof
(
int
)));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
(
&
numel_ptr
,
sizeof
(
int
)));
#endif
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
static_cast
<
void
*>
(
numel_ptr
),
1
,
ncclInt
,
peer
,
comm
->
comm
(),
stream
));
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemcpy
(
&
numel
,
numel_ptr
,
sizeof
(
int
),
hipMemcpyDeviceToHost
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
&
numel
,
numel_ptr
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
#endif
int
rest_numel
=
1
;
for
(
int
i
=
1
;
i
<
out_dims
.
size
();
++
i
)
{
...
...
paddle/fluid/operators/collective/send_v2_op.cu.cc
浏览文件 @
ee76ea72
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/send_v2_op.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -26,7 +26,8 @@ template <typename T>
class
SendOpV2CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
int
numel
=
x
->
numel
();
...
...
@@ -41,7 +42,7 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
peer
,
0
,
platform
::
errors
::
InvalidArgument
(
"The peer (%d) for send_v2 op must be non-negative."
,
peer
));
cuda
Stream_t
stream
=
nullptr
;
gpu
Stream_t
stream
=
nullptr
;
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
...
...
@@ -59,9 +60,15 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
// Send number of elements to the receiver, as the receiver may have
// no information of the Tensor size.
int
*
numel_ptr
=
nullptr
;
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMalloc
(
&
numel_ptr
,
sizeof
(
int
)));
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemcpy
(
numel_ptr
,
&
numel
,
sizeof
(
int
),
hipMemcpyHostToDevice
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
(
&
numel_ptr
,
sizeof
(
int
)));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
numel_ptr
,
&
numel
,
sizeof
(
int
),
cudaMemcpyHostToDevice
));
#endif
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
numel_ptr
,
1
,
ncclInt
,
peer
,
comm
->
comm
(),
stream
));
...
...
paddle/fluid/operators/detail/strided_memcpy.h
浏览文件 @
ee76ea72
...
...
@@ -34,7 +34,7 @@ struct StridedMemcpyFunctor<T, 0> {
auto
&
cpu_place
=
BOOST_GET_CONST
(
platform
::
CPUPlace
,
place
);
memory
::
Copy
(
cpu_place
,
dst
,
cpu_place
,
src
,
sizeof
(
T
));
}
else
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto
&
gpu_place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
);
auto
&
cuda_ctx
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
...
...
@@ -58,7 +58,7 @@ struct StridedMemcpyFunctor<T, 1> {
auto
&
cpu_place
=
BOOST_GET_CONST
(
platform
::
CPUPlace
,
place
);
memory
::
Copy
(
cpu_place
,
dst
,
cpu_place
,
src
,
sizeof
(
T
)
*
dst_dim
[
0
]);
}
else
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto
&
gpu_place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
);
auto
&
cuda_ctx
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录