Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0d65233b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0d65233b
编写于
9月 27, 2020
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add send, recv and alltoall ops, test=develop
上级
3f170dd8
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
631 addition
and
0 deletion
+631
-0
paddle/fluid/operators/collective/alltoall_op.cc
paddle/fluid/operators/collective/alltoall_op.cc
+75
-0
paddle/fluid/operators/collective/alltoall_op.cu.cc
paddle/fluid/operators/collective/alltoall_op.cu.cc
+86
-0
paddle/fluid/operators/collective/alltoall_op.h
paddle/fluid/operators/collective/alltoall_op.h
+43
-0
paddle/fluid/operators/collective/recv_op_v2.cc
paddle/fluid/operators/collective/recv_op_v2.cc
+93
-0
paddle/fluid/operators/collective/recv_op_v2.cu.cc
paddle/fluid/operators/collective/recv_op_v2.cu.cc
+96
-0
paddle/fluid/operators/collective/recv_op_v2.h
paddle/fluid/operators/collective/recv_op_v2.h
+38
-0
paddle/fluid/operators/collective/send_op_v2.cc
paddle/fluid/operators/collective/send_op_v2.cc
+77
-0
paddle/fluid/operators/collective/send_op_v2.cu.cc
paddle/fluid/operators/collective/send_op_v2.cu.cc
+85
-0
paddle/fluid/operators/collective/send_op_v2.h
paddle/fluid/operators/collective/send_op_v2.h
+38
-0
未找到文件。
paddle/fluid/operators/collective/alltoall_op.cc
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/alltoall_op.h"
namespace
paddle
{
namespace
operators
{
class
AllToAllOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"CAllToAll"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"CAllToAll"
);
int
ring_id
=
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for alltoall_op must be non-negative."
,
ring_id
));
framework
::
DDim
dim
=
ctx
->
GetInputDim
(
"X"
);
if
(
dim
[
0
]
<
0
)
dim
[
0
]
=
-
1
;
ctx
->
SetOutputDim
(
"Out"
,
dim
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
};
class
AllToAllOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) tensor send."
);
AddOutput
(
"Out"
,
"(Tensor) the result of alltoall."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) nccl communication ring id."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AllToAll Operator
Gather tensors from all participators to all participators.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
alltoall
,
ops
::
AllToAllOp
,
ops
::
AllToAllOpMaker
);
REGISTER_OP_CPU_KERNEL
(
alltoall
,
ops
::
AllToAllOpCPUKernel
<
float
>
,
ops
::
AllToAllOpCPUKernel
<
double
>
,
ops
::
AllToAllOpCPUKernel
<
int
>
,
ops
::
AllToAllOpCPUKernel
<
int64_t
>
,
ops
::
AllToAllOpCPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/alltoall_op.cu.cc
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/alltoall_op.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
AllToAllOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
send_numel
=
x
->
numel
();
ncclDataType_t
dtype
=
platform
::
ToNCCLDataType
(
x
->
type
());
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
int
nranks
=
comm
->
nranks
();
cudaStream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
framework
::
DDim
x_dims
=
x
->
dims
();
framework
::
DDim
out_dims
(
x_dims
);
PADDLE_ENFORCE_EQ
(
x_dims
[
0
]
%
nranks
,
0
,
platform
::
errors
::
InvalidArgument
(
"The first dimension size (%d) of the input tensor must be "
"divisible by the number of ranks (%d)."
,
x_dims
[
0
],
nranks
));
auto
send_buf
=
x
->
data
<
T
>
();
auto
recv_buf
=
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
size_t
offset
=
0
;
send_numel
/=
nranks
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
auto
i
=
0
;
i
<
nranks
;
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
send_buf
+
offset
,
send_numel
,
dtype
,
i
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
recv_buf
+
offset
,
send_numel
,
dtype
,
i
,
comm
->
comm
(),
stream
));
offset
+=
send_numel
;
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
alltoall
,
ops
::
AllToAllOpCUDAKernel
<
float
>
,
ops
::
AllToAllOpCUDAKernel
<
double
>
,
ops
::
AllToAllOpCUDAKernel
<
int
>
,
ops
::
AllToAllOpCUDAKernel
<
int64_t
>
,
ops
::
AllToAllOpCUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/alltoall_op.h
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/gather.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
AllToAllOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Do not support alltoall for cpu kernel now."
));
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/collective/recv_op_v2.cc
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/recv_op_v2.h"
#include <string>
namespace
paddle
{
namespace
operators
{
class
RecvOpV2
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"CRecv"
);
int
peer
=
ctx
->
Attrs
().
Get
<
int
>
(
"peer"
);
int
ring_id
=
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
peer
,
0
,
platform
::
errors
::
InvalidArgument
(
"The peer (%d) for send_op_v2 must be non-negative."
,
peer
));
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for send_op_v2 must be non-negative."
,
ring_id
));
auto
out_shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"out_shape"
);
PADDLE_ENFORCE_GE
(
out_shape
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The size of the output shape must be greater than 0 "
"but the value given is %d."
,
out_shape
.
size
()));
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_shape
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
VLOG
(
0
)
<<
"wow1"
;
int
dtype
=
ctx
.
Attr
<
int
>
(
"dtype"
);
framework
::
proto
::
VarType
::
Type
type
=
framework
::
proto
::
VarType
::
Type
(
data_type
);
return
framework
::
OpKernelType
(
type
,
ctx
.
GetPlace
());
}
};
class
RecvOpV2Maker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddOutput
(
"Out"
,
"(Tensor) tensor to receive."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) nccl communication ring id."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"peer"
,
"(int default 0) rank id for sender."
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"dtype"
,
"(std::string default 5(float32)) data type of tensor."
)
.
SetDefault
(
5
);
AddAttr
<
std
::
vector
<
int
>>
(
"out_shape"
,
"shape of the output tensor."
)
.
SetDefault
(
std
::
vector
<
int
>
());
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Recv Operator
Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
recv_v2
,
ops
::
RecvOpV2
,
ops
::
RecvOpV2Maker
);
REGISTER_OP_CPU_KERNEL
(
recv_v2
,
ops
::
RecvOpV2CPUKernel
<
float
>
,
ops
::
RecvOpV2CPUKernel
<
double
>
,
ops
::
RecvOpV2CPUKernel
<
int
>
,
ops
::
RecvOpV2CPUKernel
<
int64_t
>
,
ops
::
RecvOpV2CPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/recv_op_v2.cu.cc
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/send_op_v2.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
RecvOpV2CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
data_type
=
ctx
.
Attr
<
int
>
(
"dtype"
);
framework
::
proto
::
VarType
::
Type
type
=
framework
::
proto
::
VarType
::
Type
(
data_type
);
ncclDataType_t
dtype
=
platform
::
ToNCCLDataType
(
type
);
auto
out_dims
=
out
->
dims
();
// Recv the number of element first
int
numel
=
0
;
int
*
numel_ptr
=
nullptr
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
(
&
numel_ptr
,
sizeof
(
int
)));
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
int
peer
=
ctx
.
Attr
<
int
>
(
"peer"
);
PADDLE_ENFORCE_LT
(
peer
,
comm
->
nranks
(),
platform
::
errors
::
InvalidArgument
(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d)."
,
peer
,
comm
->
nranks
()));
cudaStream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
static_cast
<
void
*>
(
numel_ptr
),
1
,
ncclInt
,
peer
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
&
numel
,
numel_ptr
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
int
rest_numel
=
1
;
for
(
size_t
i
=
1
;
i
<
out_dims
.
size
();
++
i
)
{
rest_numel
=
rest_numel
*
out_dims
[
i
];
}
out_dims
[
0
]
=
numel
/
rest_numel
;
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
out
->
data
<
T
>
(),
numel
,
dtype
,
peer
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
VLOG
(
3
)
<<
"rank "
<<
comm
->
rank
()
<<
" recv "
<<
framework
::
product
(
out
->
dims
())
<<
" from "
<<
peer
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
recv_v2
,
ops
::
RecvOpV2CUDAKernel
<
float
>
,
ops
::
RecvOpV2CUDAKernel
<
double
>
,
ops
::
RecvOpV2CUDAKernel
<
int
>
,
ops
::
RecvOpV2CUDAKernel
<
int64_t
>
,
ops
::
RecvOpV2CUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/recv_op_v2.h
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
RecvOpV2CPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Do not support recv for cpu kernel now."
));
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/collective/send_op_v2.cc
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/send_op_v2.h"
namespace
paddle
{
namespace
operators
{
class
SendOpV2
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"CSend"
);
int
peer
=
ctx
->
Attrs
().
Get
<
int
>
(
"peer"
);
int
ring_id
=
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
peer
,
0
,
platform
::
errors
::
InvalidArgument
(
"The peer (%d) for send_op_v2 must be non-negative."
,
peer
));
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for send_op_v2 must be non-negative."
,
ring_id
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
};
class
SendOpV2Maker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) tensor to be sent."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) nccl communication ring id."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"peer"
,
"(int default 0) rank id for receiver."
).
SetDefault
(
0
);
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Send Operator
Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
send_v2
,
ops
::
SendOpV2
,
ops
::
SendOpV2Maker
);
REGISTER_OP_CPU_KERNEL
(
send_v2
,
ops
::
SendOpV2CPUKernel
<
float
>
,
ops
::
SendOpV2CPUKernel
<
double
>
,
ops
::
SendOpV2CPUKernel
<
int
>
,
ops
::
SendOpV2CPUKernel
<
int64_t
>
,
ops
::
SendOpV2CPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/send_op_v2.cu.cc
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/send_op_v2.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
SendOpV2CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
int
numel
=
x
->
numel
();
ncclDataType_t
dtype
=
platform
::
ToNCCLDataType
(
x
->
type
());
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
cudaStream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
int
peer
=
ctx
.
Attr
<
int
>
(
"peer"
);
PADDLE_ENFORCE_LT
(
peer
,
comm
->
nranks
(),
platform
::
errors
::
InvalidArgument
(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d)."
,
peer
,
comm
->
nranks
()));
// Send number of elements to the receiver, as the receiver may have
// no information of the Tensor size.
int
*
numel_ptr
=
nullptr
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
(
&
numel_ptr
,
sizeof
(
int
)));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
numel_ptr
,
&
numel
,
sizeof
(
int
),
cudaMemcpyHostToDevice
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
numel_ptr
,
1
,
ncclInt
,
peer
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
x
->
data
<
T
>
(),
numel
,
dtype
,
peer
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
VLOG
(
3
)
<<
"rank "
<<
comm
->
rank
()
<<
" send "
<<
framework
::
product
(
x
->
dims
())
<<
" to "
<<
peer
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
send_v2
,
ops
::
SendOpV2CUDAKernel
<
float
>
,
ops
::
SendOpV2CUDAKernel
<
double
>
,
ops
::
SendOpV2CUDAKernel
<
int
>
,
ops
::
SendOpV2CUDAKernel
<
int64_t
>
,
ops
::
SendOpV2CUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/send_op_v2.h
0 → 100644
浏览文件 @
0d65233b
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
SendOpV2CPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Do not support send for cpu kernel now."
));
}
};
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录