Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
27f245cd
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
27f245cd
编写于
9月 04, 2020
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix alltoall ut, test=develop
上级
47f51e07
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
235 addition
and
10 deletion
+235
-10
paddle/fluid/operators/collective/c_alltoall_op.cc
paddle/fluid/operators/collective/c_alltoall_op.cc
+77
-0
paddle/fluid/operators/collective/c_alltoall_op.cu.cc
paddle/fluid/operators/collective/c_alltoall_op.cu.cc
+86
-0
paddle/fluid/operators/collective/c_alltoall_op.h
paddle/fluid/operators/collective/c_alltoall_op.h
+68
-0
python/paddle/fluid/tests/unittests/test_collective_base.py
python/paddle/fluid/tests/unittests/test_collective_base.py
+4
-10
未找到文件。
paddle/fluid/operators/collective/c_alltoall_op.cc
0 → 100644
浏览文件 @
27f245cd
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_alltoall_op.h"
namespace
paddle
{
namespace
operators
{
class
CAllToAllOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"CAllToAll"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"CAllToAll"
);
int
ring_id
=
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for c_scatter_op must be non-negative."
,
ring_id
));
framework
::
DDim
dim
=
ctx
->
GetInputDim
(
"X"
);
if
(
dim
[
0
]
<
0
)
dim
[
0
]
=
-
1
;
ctx
->
SetOutputDim
(
"Out"
,
dim
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
};
class
CAllToAllOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) tensor send."
);
AddOutput
(
"Out"
,
"(Tensor) the result of alltoall."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) nccl communication ring id."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
CAllToAll Operator
Gather tensors from all participators to all participators.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
c_alltoall
,
ops
::
CAllToAllOp
,
ops
::
CAllToAllOpMaker
);
REGISTER_OP_CPU_KERNEL
(
c_alltoall
,
ops
::
CAllToAllOpCPUKernel
<
float
>
,
ops
::
CAllToAllOpCPUKernel
<
double
>
,
ops
::
CAllToAllOpCPUKernel
<
int
>
,
ops
::
CAllToAllOpCPUKernel
<
int64_t
>
,
ops
::
CAllToAllOpCPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/c_alltoall_op.cu.cc
0 → 100644
浏览文件 @
27f245cd
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_alltoall_op.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
CAllToAllOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
send_numel
=
x
->
numel
();
ncclDataType_t
dtype
=
platform
::
ToNCCLDataType
(
x
->
type
());
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
int
nranks
=
comm
->
nranks
();
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for c_scatter_op must be non-negative."
,
ring_id
));
cudaStream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
framework
::
DDim
x_dims
=
x
->
dims
();
framework
::
DDim
out_dims
(
x_dims
);
auto
send_buf
=
x
->
data
<
T
>
();
auto
recv_buf
=
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
size_t
offset
=
0
;
send_numel
/=
nranks
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
auto
i
=
0
;
i
<
nranks
;
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
send_buf
+
offset
,
send_numel
,
dtype
,
i
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
recv_buf
+
offset
,
send_numel
,
dtype
,
i
,
comm
->
comm
(),
stream
));
offset
+=
send_numel
;
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
#else
PADDLE_ENFORCE_EQ
(
true
,
false
,
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
c_alltoall
,
ops
::
CAllToAllOpCUDAKernel
<
float
>
,
ops
::
CAllToAllOpCUDAKernel
<
double
>
,
ops
::
CAllToAllOpCUDAKernel
<
int
>
,
ops
::
CAllToAllOpCUDAKernel
<
int64_t
>
,
ops
::
CAllToAllOpCUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/c_alltoall_op.h
0 → 100644
浏览文件 @
27f245cd
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/gather.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
CAllToAllOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_GLOO)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
root_id
=
ctx
.
Attr
<
int
>
(
"root"
);
auto
gloo
=
paddle
::
framework
::
GlooWrapper
::
GetInstance
();
PADDLE_ENFORCE_EQ
(
gloo
->
IsInitialized
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"You must initialize the gloo environment first to use it."
));
int64_t
send_numel
=
in
->
numel
();
int64_t
recv_numel
=
out
->
numel
();
auto
nranks
=
gloo
->
Size
();
auto
rank
=
gloo
->
Rank
();
T
*
recv_buff
=
out
->
data
<
T
>
();
T
*
send_buff
=
in
->
data
<
T
>
();
gloo
::
GatherOptions
opts
(
gloo
->
GetContext
());
opts
.
setOutput
(
recv_buff
,
recv_numel
);
opts
.
setInput
(
send_buff
,
send_numel
);
opts
.
setRoot
(
root_id
);
gloo
::
alltoall
(
opts
);
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_collective_base.py
浏览文件 @
27f245cd
...
...
@@ -267,16 +267,10 @@ class TestDistBase(unittest.TestCase):
elif
col_type
==
"alltoall"
:
temp11
,
temp12
=
np
.
split
(
input1
,
2
)
temp21
,
temp22
=
np
.
split
(
input2
,
2
)
need_result1
=
np
.
hstack
((
temp11
,
temp21
))
need_result2
=
np
.
hstack
((
temp12
,
temp22
))
print
(
"input1:"
,
input1
)
print
(
"input2:"
,
input2
)
print
(
"need_result1:"
,
need_result1
)
print
(
"need_result2:"
,
need_result2
)
print
(
"tr0_out:"
,
tr0_out
)
print
(
"tr1_out:"
,
tr1_out
)
self
.
assertTrue
(
np
.
allclose
(
tr1_out
,
need_result1
))
self
.
assertTrue
(
np
.
allclose
(
tr2_out
,
need_result2
))
need_result1
=
np
.
vstack
((
temp11
,
temp21
))
need_result2
=
np
.
vstack
((
temp12
,
temp22
))
self
.
assertTrue
(
np
.
allclose
(
tr0_out
,
need_result1
))
self
.
assertTrue
(
np
.
allclose
(
tr1_out
,
need_result2
))
elif
col_type
==
"reduce_scatter"
:
tmp
=
input1
+
input2
need_result1
=
tmp
[
0
:
tmp
.
shape
[
0
]
//
2
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录