BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit f6a85db9 (unverified)
Authored Oct 11, 2022 by Chenxiao Niu · Committed by GitHub on Oct 11, 2022
[MLU] add int64 support for allgather. (#46830)
Parent: 7541579a
Showing 1 changed file with 44 additions and 6 deletions:

paddle/fluid/operators/collective/c_allgather_op_mlu.cc (+44 −6)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_allgather_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
 
 #if defined(PADDLE_WITH_CNCL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -27,15 +28,14 @@ template <typename T>
 class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto place = ctx.GetPlace();
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
 #if defined(PADDLE_WITH_CNCL)
     auto x = ctx.Input<phi::DenseTensor>("X");
     auto out = ctx.Output<phi::DenseTensor>("Out");
-    cnclDataType_t dtype =
-        platform::ToCNCLDataType(framework::TransToProtoVarType(x->dtype()));
+
     int nranks = ctx.Attr<int>("nranks");
     int rid = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
     auto comm = platform::CNCLCommContext::Instance().Get(rid, place);
+
     PADDLE_ENFORCE_EQ(
         nranks,
@@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(out_dims, place);
 
     uint32_t send_numel = x->numel();
-    void* send_buff = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
-    void* recv_buff = reinterpret_cast<void*>(out->data<T>());
+    void* send_buff;
+    void* recv_buff;
+    phi::DenseTensor in_tensor, out_tensor;
+    if (framework::TransToProtoVarType(x->dtype()) ==
+        framework::proto::VarType::INT64) {
+      // cast from int64 to int32 since cncl do not support int64
+      in_tensor.mutable_data<int32_t>(x->dims(), place);
+      out_tensor.mutable_data<int32_t>(out->dims(), place);
+      MLUCnnlTensorDesc x_int64_desc(*x);
+      MLUCnnlTensorDesc x_int32_desc(in_tensor);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32);
+      MLUCnnl::Cast(ctx,
+                    cast_type,
+                    x_int64_desc.get(),
+                    GetBasePtr(x),
+                    x_int32_desc.get(),
+                    GetBasePtr(&in_tensor));
+      send_buff = reinterpret_cast<void*>(in_tensor.data<int32_t>());
+      recv_buff = reinterpret_cast<void*>(out_tensor.data<int32_t>());
+    } else {
+      in_tensor.ShareDataWith(*x);
+      out_tensor.ShareDataWith(*out);
+      send_buff = reinterpret_cast<void*>(in_tensor.data<T>());
+      recv_buff = reinterpret_cast<void*>(out_tensor.data<T>());
+    }
 
     mluStream stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream();
     } else {
       stream = comm->stream();
     }
+    cnclDataType_t dtype = platform::ToCNCLDataType(
+        framework::TransToProtoVarType(in_tensor.dtype()));
 
     PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(
         send_buff, recv_buff, send_numel, dtype, comm->comm(), stream));
+    if (framework::TransToProtoVarType(x->dtype()) ==
+        framework::proto::VarType::INT64) {
+      // cast back from int64 out_tensor to out
+      MLUCnnlTensorDesc out_int64_desc(*out);
+      MLUCnnlTensorDesc out_int32_desc(out_tensor);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
+      MLUCnnl::Cast(ctx,
+                    cast_type,
+                    out_int32_desc.get(),
+                    GetBasePtr(&out_tensor),
+                    out_int64_desc.get(),
+                    GetBasePtr(out));
+    }
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with MLU."));
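The core of the change is a work-around rather than native support: CNCL has no int64 all-gather, so the kernel casts the int64 input down to int32, gathers, and casts the gathered result back up to int64. What follows is a minimal host-only C++ sketch of that down-cast/gather/up-cast pattern, not Paddle or CNCL code: fake_allgather_int32 is a hypothetical stand-in for cnclAllGather (one rank, so it just copies), and allgather_int64 plays the role of the kernel's INT64 branch.

// Minimal, self-contained sketch of the cast -> gather -> cast-back
// pattern above. NOT Paddle/CNCL code: `fake_allgather_int32` is a
// hypothetical stand-in for cnclAllGather running on the host with a
// single "rank". It shares the real kernel's caveat: values outside
// the int32 range are silently truncated by the down-cast.
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for an int32-only collective: the gathered result is the
// concatenation of every rank's buffer; with one rank it is a copy.
void fake_allgather_int32(const std::vector<int32_t>& send,
                          std::vector<int32_t>* recv) {
  *recv = send;
}

// Gather int64 data through the int32-only collective:
// down-cast, communicate, cast back up.
std::vector<int64_t> allgather_int64(const std::vector<int64_t>& x) {
  std::vector<int32_t> in(x.size());
  std::vector<int32_t> out;
  for (size_t i = 0; i < x.size(); ++i) {
    in[i] = static_cast<int32_t>(x[i]);  // lossy if x[i] exceeds int32 range
  }
  fake_allgather_int32(in, &out);  // the collective only ever sees int32
  std::vector<int64_t> result(out.size());
  for (size_t i = 0; i < out.size(); ++i) {
    result[i] = static_cast<int64_t>(out[i]);  // cast back to int64
  }
  return result;
}

int main() {
  const std::vector<int64_t> x = {1, 2, 3};
  for (int64_t v : allgather_int64(x)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 2 3
}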
@@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather,
                        ops::CAllGatherOpMLUKernel<int>,
                        ops::CAllGatherOpMLUKernel<int8_t>,
                        ops::CAllGatherOpMLUKernel<int16_t>,
+                       ops::CAllGatherOpMLUKernel<int64_t>,
                        ops::CAllGatherOpMLUKernel<plat::float16>);
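Two details in the hunks above are easy to miss. First, the registration list now includes ops::CAllGatherOpMLUKernel<int64_t>, so int64 inputs dispatch to this kernel at all instead of failing to find a registered MLU kernel. Second, the cnclDataType_t is now derived from in_tensor after the optional cast, rather than from the raw input, so the collective sees INT32 on the int64 path. The down-cast is lossy for values outside the int32 range; the inline comment ("cncl do not support int64") records the limitation that motivates the trade-off.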