Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2c84debb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2c84debb
编写于
9月 27, 2020
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ut, test=develop
上级
3f170dd8
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
418 addition
and
49 deletion
+418
-49
paddle/fluid/operators/collective/gather_op_v2.cc
paddle/fluid/operators/collective/gather_op_v2.cc
+97
-0
paddle/fluid/operators/collective/gather_op_v2.cu.cc
paddle/fluid/operators/collective/gather_op_v2.cu.cc
+98
-0
paddle/fluid/operators/collective/gather_op_v2.h
paddle/fluid/operators/collective/gather_op_v2.h
+78
-0
paddle/fluid/operators/collective/scatter_op_v2.cc
paddle/fluid/operators/collective/scatter_op_v2.cc
+20
-14
paddle/fluid/operators/collective/scatter_op_v2.cu.cc
paddle/fluid/operators/collective/scatter_op_v2.cu.cc
+24
-32
paddle/fluid/operators/collective/scatter_op_v2.h
paddle/fluid/operators/collective/scatter_op_v2.h
+1
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/collective_gather_op.py
python/paddle/fluid/tests/unittests/collective_gather_op.py
+66
-0
python/paddle/fluid/tests/unittests/collective_scatter_op.py
python/paddle/fluid/tests/unittests/collective_scatter_op.py
+2
-2
python/paddle/fluid/tests/unittests/test_collective_gather.py
...on/paddle/fluid/tests/unittests/test_collective_gather.py
+31
-0
未找到文件。
paddle/fluid/operators/collective/gather_op_v2.cc
0 → 100644
浏览文件 @
2c84debb
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/gather_op_v2.h"
namespace
paddle
{
namespace
operators
{
class
GatherOpV2
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"CGather"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"CGather"
);
int
root_id
=
ctx
->
Attrs
().
Get
<
int
>
(
"root"
);
int
ring_id
=
ctx
->
Attrs
().
Get
<
int
>
(
"ring_id"
);
int
nranks
=
ctx
->
Attrs
().
Get
<
int
>
(
"nranks"
);
PADDLE_ENFORCE_GE
(
nranks
,
2
,
platform
::
errors
::
InvalidArgument
(
"The number of ranks (%d) must be greater than 1 "
"to use collective op (gather_op_v2)."
,
nranks
));
PADDLE_ENFORCE_GE
(
root_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The root_id (%d) for gather_op_v2 must be non-negative."
,
root_id
));
PADDLE_ENFORCE_LT
(
root_id
,
nranks
,
platform
::
errors
::
InvalidArgument
(
"The root_id (%d) for gather_op_v2 must be less than nranks (%d)."
,
root_id
,
nranks
));
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for gather_op_v2 must be non-negative."
,
root_id
));
framework
::
DDim
dim
=
ctx
->
GetInputDim
(
"X"
);
dim
[
0
]
=
dim
[
0
]
*
nranks
;
if
(
dim
[
0
]
<
0
)
dim
[
0
]
=
-
1
;
ctx
->
SetOutputDim
(
"Out"
,
dim
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
};
class
GatherOpV2Maker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) tensor to be gathered."
);
AddOutput
(
"Out"
,
"(Tensor) the result of gather."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) nccl communication ring id."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"root"
,
"(int default 0) root id for broadcasting."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"nranks"
,
"(int default 1) number of ranks."
).
SetDefault
(
0
);
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Gather Operator
Gather tensors from all participators.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
gather_v2
,
ops
::
GatherOpV2
,
ops
::
GatherOpV2Maker
);
REGISTER_OP_CPU_KERNEL
(
gather_v2
,
ops
::
CGatherOpV2CPUKernel
<
float
>
,
ops
::
GatherOpV2CPUKernel
<
double
>
,
ops
::
GatherOpV2CPUKernel
<
int
>
,
ops
::
GatherOpV2CPUKernel
<
int64_t
>
,
ops
::
GatherOpV2CPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/gather_op_v2.cu.cc
0 → 100644
浏览文件 @
2c84debb
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/gather_op_v2.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GatherOpV2CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
send_numel
=
x
->
numel
();
ncclDataType_t
dtype
=
platform
::
ToNCCLDataType
(
x
->
type
());
int
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
int
root_id
=
ctx
.
Attr
<
int
>
(
"root"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
PADDLE_ENFORCE_EQ
(
nranks
,
comm
->
nranks
(),
platform
::
errors
::
InvalidArgument
(
"The number of ranks (%d) you set of must "
"be equal to comm->nranks (%d)."
,
nranks
,
comm
->
nranks
()));
PADDLE_ENFORCE_GE
(
root_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The root_id (%d) for gather_op_v2 must be non-negative."
,
root_id
));
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for gather_op_v2 must be non-negative."
,
ring_id
));
cudaStream_t
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
framework
::
DDim
x_dims
=
x
->
dims
();
framework
::
DDim
out_dims
(
x_dims
);
out_dims
[
0
]
*=
nranks
;
auto
send_buf
=
x
->
data
<
T
>
();
auto
offset
=
0
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
send_buf
,
send_numel
,
dtype
,
root_id
,
comm
->
comm
(),
stream
));
if
(
root_id
==
comm
->
rank
())
{
auto
recv_buf
=
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
for
(
auto
i
=
0
;
i
<
nranks
;
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
recv_buf
+
offset
,
send_numel
,
dtype
,
i
,
comm
->
comm
(),
stream
));
offset
+=
send_numel
;
}
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
gather_v2
,
ops
::
GatherOpV2CUDAKernel
<
float
>
,
ops
::
GatherOpV2CUDAKernel
<
double
>
,
ops
::
GatherOpV2CUDAKernel
<
int
>
,
ops
::
GatherOpV2CUDAKernel
<
int64_t
>
,
ops
::
GatherOpV2CUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/gather_op_v2.h
0 → 100644
浏览文件 @
2c84debb
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/gather.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GatherOpV2CPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_GLOO)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
root_id
=
ctx
.
Attr
<
int
>
(
"root"
);
auto
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
auto
gloo
=
paddle
::
framework
::
GlooWrapper
::
GetInstance
();
PADDLE_ENFORCE_EQ
(
gloo
->
IsInitialized
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"You must initialize the gloo environment first to use it."
));
PADDLE_ENFORCE_EQ
(
nranks
,
gloo
->
Size
(),
platform
::
errors
::
InvalidArgument
(
"The number of ranks (%d) you set must "
"be equal to gloo->Size() (%d)."
,
nranks
,
gloo
->
Size
()));
int64_t
send_numel
=
in
->
numel
();
int64_t
recv_numel
=
out
->
numel
();
auto
in_dim
=
x
->
dims
();
auto
out_dim
=
framework
::
DDim
(
in_dim
);
out_dim
[
0
]
*=
nranks
;
auto
nranks
=
gloo
->
Size
();
auto
rank
=
gloo
->
Rank
();
gloo
::
GatherOptions
opts
(
gloo
->
GetContext
());
if
(
root_id
==
rank
)
{
T
*
recv_buff
=
out
->
mutable_data
<
T
>
(
place
,
out_dim
);
opts
.
setOutput
(
recv_buff
,
recv_numel
);
}
opts
.
setInput
(
send_buff
,
send_numel
);
opts
.
setRoot
(
root_id
);
gloo
::
gather
(
opts
);
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/collective/
c_scatter_op
.cc
→
paddle/fluid/operators/collective/
scatter_op_v2
.cc
浏览文件 @
2c84debb
...
...
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/
c_scatter_op
.h"
#include "paddle/fluid/operators/collective/
scatter_op_v2
.h"
namespace
paddle
{
namespace
operators
{
class
CScatterOp
:
public
framework
::
OperatorWithKernel
{
class
ScatterOpV2
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -30,18 +30,23 @@ class CScatterOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE
(
nranks
,
2
,
platform
::
errors
::
InvalidArgument
(
"The number of ranks (%d) must be greater than 1 "
"to use collective op (
c_scatter op
)."
,
"to use collective op (
scatter_op_v2
)."
,
nranks
));
PADDLE_ENFORCE_GE
(
root_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The root_id (%d) for
c_scatter_op
must be non-negative."
,
"The root_id (%d) for
scatter_op_v2
must be non-negative."
,
root_id
));
PADDLE_ENFORCE_LT
(
root_id
,
nranks
,
platform
::
errors
::
InvalidArgument
(
"The root_id (%d) for scatter_op_v2 must be less "
"than the number of ranks (%d)."
,
root_id
,
nranks
));
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for
c_scatter_op
must be non-negative."
,
r
oot
_id
));
"The ring_id (%d) for
scatter_op_v2
must be non-negative."
,
r
ing
_id
));
framework
::
DDim
dim
=
ctx
->
GetInputDim
(
"X"
);
dim
[
0
]
=
dim
[
0
]
/
nranks
;
if
(
dim
[
0
]
<
0
)
dim
[
0
]
=
-
1
;
...
...
@@ -56,7 +61,7 @@ class CScatterOp : public framework::OperatorWithKernel {
}
};
class
CScatterOp
Maker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
ScatterOpV2
Maker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) tensor to be broadcasted."
);
...
...
@@ -71,7 +76,7 @@ class CScatterOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool default false) eject CUDA operations to calculation stream."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
C
Scatter Operator
Scatter Operator
Scatter the source to all participators.
)DOC"
);
}
...
...
@@ -83,10 +88,11 @@ Scatter the source to all participators.
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
c_scatter
,
ops
::
CScatterOp
,
ops
::
CScatterOpMaker
);
REGISTER_OP_WITHOUT_GRADIENT
(
scatter_v2
,
ops
::
ScatterOpV2
,
ops
::
ScatterOpV2Maker
);
REGISTER_OP_CPU_KERNEL
(
c_scatter
,
ops
::
CScatterOp
CPUKernel
<
float
>
,
ops
::
CScatterOp
CPUKernel
<
double
>
,
ops
::
CScatterOp
CPUKernel
<
int
>
,
ops
::
CScatterOp
CPUKernel
<
int64_t
>
,
ops
::
CScatterOp
CPUKernel
<
plat
::
float16
>
);
REGISTER_OP_CPU_KERNEL
(
scatter_v2
,
ops
::
ScatterOpV2
CPUKernel
<
float
>
,
ops
::
ScatterOpV2
CPUKernel
<
double
>
,
ops
::
ScatterOpV2
CPUKernel
<
int
>
,
ops
::
ScatterOpV2
CPUKernel
<
int64_t
>
,
ops
::
ScatterOpV2
CPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/
c_scatter_op
.cu.cc
→
paddle/fluid/operators/collective/
scatter_op_v2
.cu.cc
浏览文件 @
2c84debb
...
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/
c_scatter_op
.h"
#include "paddle/fluid/operators/collective/
scatter_op_v2
.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
...
...
@@ -23,7 +23,7 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
class
CScatterOp
CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ScatterOpV2
CUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_NCCL)
...
...
@@ -39,7 +39,7 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
PADDLE_ENFORCE_EQ
(
nranks
,
comm
->
nranks
(),
platform
::
errors
::
InvalidArgument
(
"The number of ranks (%d) you set
of
must "
"The number of ranks (%d) you set must "
"be equal to comm->nranks (%d)."
,
nranks
,
comm
->
nranks
()));
PADDLE_ENFORCE_GE
(
...
...
@@ -63,33 +63,25 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
framework
::
DDim
x_dims
=
x
->
dims
();
framework
::
DDim
out_dims
(
x_dims
);
framework
::
Tensor
temp
;
auto
out_ptr
=
temp
.
mutable_data
<
T
>
(
out_dims
,
place
);
out_dims
[
0
]
/=
nranks
;
auto
send_buf
=
x
->
data
<
T
>
();
auto
send_numel
=
numel
/
nranks
;
auto
recv_buf
=
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
auto
offset
=
0
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
if
(
root_id
==
comm
->
rank
())
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclBcast
(
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
x
->
data
<
T
>
())),
numel
,
dtype
,
root_id
,
comm
->
comm
(),
stream
));
framework
::
TensorCopy
(
*
static_cast
<
const
framework
::
Tensor
*>
(
x
),
place
,
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
),
static_cast
<
framework
::
Tensor
*>
(
&
temp
));
}
else
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclBcast
(
out_ptr
,
numel
,
dtype
,
root_id
,
comm
->
comm
(),
stream
));
for
(
auto
i
=
0
;
i
<
nranks
;
++
i
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclSend
(
send_buf
+
offset
,
send_numel
,
dtype
,
root_id
,
comm
->
comm
(),
stream
));
offset
+=
send_numel
;
}
}
out_dims
[
0
]
=
out_dims
[
0
]
/
nranks
;
auto
start_index
=
out_dims
[
0
]
*
comm
->
rank
();
auto
end_index
=
start_index
+
out_dims
[
0
];
temp
=
temp
.
Slice
(
start_index
,
end_index
);
temp
.
Resize
(
out_dims
);
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
framework
::
TensorCopySync
(
*
static_cast
<
const
framework
::
Tensor
*>
(
&
temp
),
place
,
static_cast
<
framework
::
Tensor
*>
(
out
));
out
->
Resize
(
out_dims
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
recv_buf
,
send_numel
,
dtype
,
root_id
,
comm
->
comm
(),
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
#else
PADDLE_ENFORCE_EQ
(
true
,
false
,
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
...
...
@@ -101,8 +93,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
c_scatter
,
ops
::
CScatterOp
CUDAKernel
<
float
>
,
ops
::
CScatterOp
CUDAKernel
<
double
>
,
ops
::
CScatterOp
CUDAKernel
<
int
>
,
ops
::
CScatterOp
CUDAKernel
<
int64_t
>
,
ops
::
CScatterOp
CUDAKernel
<
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
scatter_v2
,
ops
::
ScatterOpV2
CUDAKernel
<
float
>
,
ops
::
ScatterOpV2
CUDAKernel
<
double
>
,
ops
::
ScatterOpV2
CUDAKernel
<
int
>
,
ops
::
ScatterOpV2
CUDAKernel
<
int64_t
>
,
ops
::
ScatterOpV2
CUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/
c_scatter_op
.h
→
paddle/fluid/operators/collective/
scatter_op_v2
.h
浏览文件 @
2c84debb
...
...
@@ -31,7 +31,7 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
class
CScatterOp
CPUKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ScatterOpV2
CPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_GLOO)
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
2c84debb
...
...
@@ -62,6 +62,7 @@ if(NOT WITH_GPU OR WIN32)
LIST
(
REMOVE_ITEM TEST_OPS test_broadcast
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_reduce
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_scatter
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_gather
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_reduce_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_scatter_api
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_barrier_api
)
...
...
python/paddle/fluid/tests/unittests/collective_gather_op.py
0 → 100644
浏览文件 @
2c84debb
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
argparse
import
os
import
sys
import
signal
import
time
import
socket
from
contextlib
import
closing
from
six
import
string_types
import
math
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.profiler
as
profiler
import
paddle.fluid.unique_name
as
nameGen
from
paddle.fluid
import
core
import
unittest
from
multiprocessing
import
Process
import
paddle.fluid.layers
as
layers
from
functools
import
reduce
from
test_collective_base
import
TestCollectiveRunnerBase
,
runtime_main
class
TestCollectiveGather
(
TestCollectiveRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
=
None
):
ring_id
=
0
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
layers
.
data
(
name
=
"tindata"
,
shape
=
[
10
,
1000
],
dtype
=
'float32'
)
toutdata
=
layers
.
data
(
name
=
"toutdata"
,
shape
=
[
20
,
1000
],
dtype
=
'float32'
)
main_prog
.
global_block
().
append_op
(
type
=
"gather_v2"
,
inputs
=
{
'X'
:
tindata
},
outputs
=
{
'Out'
:
toutdata
},
attrs
=
{
'ring_id'
:
ring_id
,
'nranks'
:
2
,
'root'
:
1
})
main_prog
.
global_block
().
append_op
(
type
=
"c_sync_comm_stream"
,
inputs
=
{
'X'
:
toutdata
},
outputs
=
{
'Out'
:
toutdata
},
attrs
=
{
'ring_id'
:
ring_id
})
return
toutdata
if
__name__
==
"__main__"
:
runtime_main
(
TestCollectiveGather
,
"gather"
,
0
)
python/paddle/fluid/tests/unittests/collective_scatter_op.py
浏览文件 @
2c84debb
...
...
@@ -49,13 +49,13 @@ class TestCollectiveScatter(TestCollectiveRunnerBase):
tindata
=
layers
.
data
(
name
=
"tindata"
,
shape
=
[
10
,
1000
],
dtype
=
'float32'
)
toutdata
=
main_prog
.
current_block
().
create_var
(
name
=
"
outofreduce
"
,
name
=
"
tinout
"
,
dtype
=
'float32'
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
False
)
main_prog
.
global_block
().
append_op
(
type
=
"
c_scatter
"
,
type
=
"
scatter_v2
"
,
inputs
=
{
'X'
:
tindata
},
attrs
=
{
'ring_id'
:
ring_id
,
'root'
:
rootid
,
...
...
python/paddle/fluid/tests/unittests/test_collective_gather.py
0 → 100644
浏览文件 @
2c84debb
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
test_collective_base
import
TestDistBase
class
TestCGatherOp
(
TestDistBase
):
def
_setup_config
(
self
):
pass
def
test_gather
(
self
):
self
.
check_with_place
(
"collective_gather_op.py"
,
"gather"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录