Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8cc2e28c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8cc2e28c
编写于
5月 28, 2022
作者:
S
ShenLiang
提交者:
GitHub
5月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Bug Fix]Fix global_scatter/global_gather in ProcessGroup (#43027)
* fix alltoall * rename utest
上级
9eb18c75
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
376 addition
and
11 deletion
+376
-11
paddle/fluid/distributed/collective/ProcessGroup.h
paddle/fluid/distributed/collective/ProcessGroup.h
+13
-0
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+47
-0
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+8
-0
paddle/fluid/operators/collective/global_gather_op.cu.cc
paddle/fluid/operators/collective/global_gather_op.cu.cc
+129
-3
paddle/fluid/operators/collective/global_gather_op.h
paddle/fluid/operators/collective/global_gather_op.h
+11
-0
paddle/fluid/operators/collective/global_scatter_op.cu.cc
paddle/fluid/operators/collective/global_scatter_op.cu.cc
+127
-3
paddle/fluid/operators/collective/global_scatter_op.h
paddle/fluid/operators/collective/global_scatter_op.h
+11
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-2
python/paddle/fluid/tests/unittests/test_collective_api_base.py
.../paddle/fluid/tests/unittests/test_collective_api_base.py
+8
-1
python/paddle/fluid/tests/unittests/test_collective_global_gather.py
...le/fluid/tests/unittests/test_collective_global_gather.py
+10
-1
python/paddle/fluid/tests/unittests/test_collective_global_scatter.py
...e/fluid/tests/unittests/test_collective_global_scatter.py
+10
-1
未找到文件。
paddle/fluid/distributed/collective/ProcessGroup.h
浏览文件 @
8cc2e28c
...
...
@@ -113,6 +113,19 @@ class ProcessGroup {
"ProcessGroup%s does not support receive"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send_Partial
(
phi
::
DenseTensor
&
,
int
,
int
,
int
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support send"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv_Partial
(
phi
::
DenseTensor
&
tensors
,
int
,
int
,
int
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support receive"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
)
{
// NOLINT
...
...
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
浏览文件 @
8cc2e28c
...
...
@@ -428,6 +428,53 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Send_Partial
(
phi
::
DenseTensor
&
tensors
,
int
dst_rank
,
int
offset
,
int
length
)
{
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
phi
::
DenseTensor
flatten_tensor
;
flatten_tensor
.
ShareDataWith
(
tensors
).
Resize
({
tensors
.
numel
()});
phi
::
DenseTensor
shared_input
=
flatten_tensor
.
Slice
(
offset
,
offset
+
length
);
std
::
vector
<
phi
::
DenseTensor
>
shared_tensors
;
shared_tensors
.
push_back
(
shared_input
);
auto
task
=
PointToPoint
(
shared_tensors
,
[
&
](
phi
::
DenseTensor
&
input
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
,
int
dst_rank
)
{
return
platform
::
dynload
::
ncclSend
(
input
.
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
dtype
()),
dst_rank
,
comm
,
stream
);
},
dst_rank
,
CommType
::
SEND
);
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Recv_Partial
(
phi
::
DenseTensor
&
tensors
,
int
src_rank
,
int
offset
,
int
length
)
{
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);
phi
::
DenseTensor
flatten_tensor
;
flatten_tensor
.
ShareDataWith
(
tensors
).
Resize
({
tensors
.
numel
()});
phi
::
DenseTensor
shared_input
=
flatten_tensor
.
Slice
(
offset
,
offset
+
length
);
std
::
vector
<
phi
::
DenseTensor
>
shared_tensors
;
shared_tensors
.
push_back
(
shared_input
);
auto
task
=
PointToPoint
(
shared_tensors
,
[
&
](
phi
::
DenseTensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
,
int
src_rank
)
{
return
platform
::
dynload
::
ncclRecv
(
output
.
data
(),
output
.
numel
(),
platform
::
ToNCCLDataType
(
output
.
dtype
()),
src_rank
,
comm
,
stream
);
},
src_rank
,
CommType
::
RECV
);
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
{
...
...
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
浏览文件 @
8cc2e28c
...
...
@@ -102,6 +102,14 @@ class ProcessGroupNCCL : public ProcessGroup {
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
src_rank
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send_Partial
(
phi
::
DenseTensor
&
tensors
,
int
dst_rank
,
int
offset
,
int
length
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv_Partial
(
phi
::
DenseTensor
&
tensors
,
int
src_rank
,
int
offset
,
int
length
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
override
;
...
...
paddle/fluid/operators/collective/global_gather_op.cu.cc
浏览文件 @
8cc2e28c
...
...
@@ -22,10 +22,10 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GlobalGatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
struct
GlobalGatherFunctor
<
phi
::
GPUContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
...
...
@@ -137,6 +137,132 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
}
};
template
<
typename
T
>
struct
GlobalGatherProcessGroupFunctor
<
phi
::
GPUContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
local_count
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"local_count"
);
auto
global_count
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"global_count"
);
auto
local_count_type
=
framework
::
TransToProtoVarType
(
local_count
->
dtype
());
auto
global_count_type
=
framework
::
TransToProtoVarType
(
global_count
->
dtype
());
if
(
local_count_type
!=
framework
::
proto
::
VarType
::
INT64
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Please use int64 type in local_count."
));
}
if
(
global_count_type
!=
framework
::
proto
::
VarType
::
INT64
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Please use int64 type in global_count."
));
}
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
const
int64_t
*
cpu_local_count_data
;
const
int64_t
*
cpu_global_count_data
;
auto
local_count_len
=
0
;
framework
::
Tensor
cpu_local_count
;
if
(
platform
::
is_cpu_place
(
local_count
->
place
()))
{
cpu_local_count_data
=
local_count
->
data
<
int64_t
>
();
local_count_len
=
local_count
->
numel
();
}
else
{
framework
::
TensorCopySync
(
*
local_count
,
platform
::
CPUPlace
(),
&
cpu_local_count
);
cpu_local_count_data
=
cpu_local_count
.
data
<
int64_t
>
();
local_count_len
=
cpu_local_count
.
numel
();
}
framework
::
Tensor
cpu_global_count
;
if
(
platform
::
is_cpu_place
(
global_count
->
place
()))
{
cpu_global_count_data
=
global_count
->
data
<
int64_t
>
();
}
else
{
framework
::
TensorCopySync
(
*
global_count
,
platform
::
CPUPlace
(),
&
cpu_global_count
);
cpu_global_count_data
=
cpu_global_count
.
data
<
int64_t
>
();
}
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for global gather op must be non-negative."
,
ring_id
));
auto
place
=
ctx
.
GetPlace
();
auto
map
=
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
int
nranks
=
pg
->
GetSize
();
auto
in_feat
=
x
->
dims
()[
1
];
auto
n_expert
=
local_count
->
dims
()[
0
]
/
nranks
;
auto
fwd_count
=
0
;
for
(
auto
i
=
0
;
i
<
local_count_len
;
++
i
)
{
fwd_count
+=
cpu_local_count_data
[
i
];
}
framework
::
DDim
out_dims
=
phi
::
make_ddim
({
fwd_count
,
in_feat
});
int64_t
*
expert_ptr
=
new
int64_t
[
n_expert
*
nranks
];
expert_ptr
[
0
]
=
0
;
auto
tot_experts
=
n_expert
*
nranks
;
for
(
auto
i
=
1
;
i
<
tot_experts
;
++
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
cpu_local_count_data
[
i
-
1
];
}
auto
send_ptr
=
0
;
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
for
(
auto
i
=
0
;
i
<
n_expert
;
++
i
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
auto
j
=
0
;
j
<
nranks
;
++
j
)
{
int
idx
=
i
+
j
*
n_expert
;
if
(
cpu_global_count_data
[
idx
])
{
phi
::
DenseTensor
tmp
=
*
x
;
pg
->
Send_Partial
(
tmp
,
j
,
send_ptr
*
in_feat
,
cpu_global_count_data
[
idx
]
*
in_feat
);
send_ptr
+=
cpu_global_count_data
[
idx
];
}
if
(
cpu_local_count_data
[
idx
])
{
pg
->
Recv_Partial
(
*
out
,
j
,
expert_ptr
[
idx
]
*
in_feat
,
cpu_local_count_data
[
idx
]
*
in_feat
);
}
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
}
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaDeviceSynchronize
());
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
hipDeviceSynchronize
());
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"NCCL version >= 2.7.3 is needed."
));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
template
<
typename
T
>
class
GlobalGatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
map
=
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
if
(
map
->
has
(
rid
))
{
GlobalGatherProcessGroupFunctor
<
phi
::
GPUContext
,
T
>
functor_
;
functor_
(
ctx
);
}
else
{
GlobalGatherFunctor
<
phi
::
GPUContext
,
T
>
functor_
;
functor_
(
ctx
);
}
}
};
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/collective/global_gather_op.h
浏览文件 @
8cc2e28c
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -33,5 +34,15 @@ class GlobalGatherOpCPUKernel : public framework::OpKernel<T> {
}
};
template
<
typename
Context
,
typename
T
>
struct
GlobalGatherFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
);
};
template
<
typename
Context
,
typename
T
>
struct
GlobalGatherProcessGroupFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
);
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/collective/global_scatter_op.cu.cc
浏览文件 @
8cc2e28c
...
...
@@ -22,10 +22,10 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GlobalScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
struct
GlobalScatterFunctor
<
phi
::
GPUContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
...
...
@@ -137,6 +137,130 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
}
};
template
<
typename
T
>
struct
GlobalScatterProcessGroupFunctor
<
phi
::
GPUContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
local_count
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"local_count"
);
auto
global_count
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"global_count"
);
auto
local_count_type
=
framework
::
TransToProtoVarType
(
local_count
->
dtype
());
auto
global_count_type
=
framework
::
TransToProtoVarType
(
global_count
->
dtype
());
if
(
local_count_type
!=
framework
::
proto
::
VarType
::
INT64
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Please use int64 type in local_count."
));
}
if
(
global_count_type
!=
framework
::
proto
::
VarType
::
INT64
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Please use int64 type in global_count."
));
}
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
const
int64_t
*
cpu_local_count_data
;
const
int64_t
*
cpu_global_count_data
;
framework
::
Tensor
cpu_local_count
;
if
(
platform
::
is_cpu_place
(
local_count
->
place
()))
{
cpu_local_count_data
=
local_count
->
data
<
int64_t
>
();
}
else
{
framework
::
TensorCopySync
(
*
local_count
,
platform
::
CPUPlace
(),
&
cpu_local_count
);
cpu_local_count_data
=
cpu_local_count
.
data
<
int64_t
>
();
}
auto
global_count_len
=
0
;
framework
::
Tensor
cpu_global_count
;
if
(
platform
::
is_cpu_place
(
global_count
->
place
()))
{
cpu_global_count_data
=
global_count
->
data
<
int64_t
>
();
global_count_len
=
global_count
->
numel
();
}
else
{
framework
::
TensorCopySync
(
*
global_count
,
platform
::
CPUPlace
(),
&
cpu_global_count
);
cpu_global_count_data
=
cpu_global_count
.
data
<
int64_t
>
();
global_count_len
=
cpu_global_count
.
numel
();
}
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"The ring_id (%d) for global scatter op must be non-negative."
,
ring_id
));
auto
place
=
ctx
.
GetPlace
();
auto
map
=
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
int
nranks
=
pg
->
GetSize
();
auto
in_feat
=
x
->
dims
()[
1
];
auto
n_expert
=
local_count
->
dims
()[
0
]
/
nranks
;
int64_t
fwd_count
=
0
;
for
(
auto
i
=
0
;
i
<
global_count_len
;
++
i
)
{
fwd_count
+=
cpu_global_count_data
[
i
];
}
framework
::
DDim
out_dims
=
phi
::
make_ddim
({
fwd_count
,
in_feat
});
int64_t
*
expert_ptr
=
new
int64_t
[
n_expert
*
nranks
];
expert_ptr
[
0
]
=
0
;
auto
tot_experts
=
n_expert
*
nranks
;
for
(
auto
i
=
1
;
i
<
tot_experts
;
++
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
cpu_local_count_data
[
i
-
1
];
}
auto
recv_ptr
=
0
;
out
->
mutable_data
<
T
>
(
out_dims
,
place
);
for
(
auto
i
=
0
;
i
<
n_expert
;
++
i
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
auto
j
=
0
;
j
<
nranks
;
++
j
)
{
int
idx
=
i
+
j
*
n_expert
;
if
(
cpu_local_count_data
[
idx
])
{
phi
::
DenseTensor
tmp
=
*
x
;
pg
->
Send_Partial
(
tmp
,
j
,
expert_ptr
[
idx
]
*
in_feat
,
cpu_local_count_data
[
idx
]
*
in_feat
);
}
if
(
cpu_global_count_data
[
idx
])
{
pg
->
Recv_Partial
(
*
out
,
j
,
recv_ptr
*
in_feat
,
cpu_global_count_data
[
idx
]
*
in_feat
);
recv_ptr
+=
cpu_global_count_data
[
idx
];
}
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
}
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaDeviceSynchronize
());
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
hipDeviceSynchronize
());
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"NCCL version >= 2.7.3 is needed."
));
#endif
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
template
<
typename
T
>
class
GlobalScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
map
=
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
if
(
map
->
has
(
rid
))
{
GlobalScatterProcessGroupFunctor
<
phi
::
GPUContext
,
T
>
functor_
;
functor_
(
ctx
);
}
else
{
GlobalScatterFunctor
<
phi
::
GPUContext
,
T
>
functor_
;
functor_
(
ctx
);
}
}
};
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/collective/global_scatter_op.h
浏览文件 @
8cc2e28c
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -33,5 +34,15 @@ class GlobalScatterOpCPUKernel : public framework::OpKernel<T> {
}
};
template
<
typename
Context
,
typename
T
>
struct
GlobalScatterFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
);
};
template
<
typename
Context
,
typename
T
>
struct
GlobalScatterProcessGroupFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
);
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
8cc2e28c
...
...
@@ -1182,8 +1182,8 @@ endif()
if
((
WITH_ROCM OR WITH_GPU
)
AND NOT WIN32
)
set_tests_properties
(
test_collective_allgather_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_alltoall_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_global_gather PROPERTIES TIMEOUT
12
0
)
set_tests_properties
(
test_collective_global_scatter PROPERTIES TIMEOUT
12
0
)
set_tests_properties
(
test_collective_global_gather PROPERTIES TIMEOUT
20
0
)
set_tests_properties
(
test_collective_global_scatter PROPERTIES TIMEOUT
20
0
)
set_tests_properties
(
test_collective_sendrecv_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_broadcast_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_allreduce_api PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/test_collective_api_base.py
浏览文件 @
8cc2e28c
...
...
@@ -191,7 +191,8 @@ class TestDistBase(unittest.TestCase):
path_id
=
"0"
,
static_mode
=
"1"
,
check_error_log
=
False
,
need_envs
=
{}):
need_envs
=
{},
eager_mode
=
True
):
if
backend
==
"nccl"
or
backend
==
"bkcl"
:
with_gloo
=
'0'
else
:
...
...
@@ -216,6 +217,12 @@ class TestDistBase(unittest.TestCase):
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
required_envs
[
"GLOO_LOG_LEVEL"
]
=
"TRACE"
if
eager_mode
:
required_envs
[
"FLAGS_enable_eager_mode"
]
=
"%d"
%
0
else
:
required_envs
[
"FLAGS_enable_eager_mode"
]
=
"%d"
%
1
tr0_out
,
tr1_out
,
pid0
,
pid1
=
self
.
_run_cluster
(
model_file
,
required_envs
)
np
.
random
.
seed
(
pid0
)
...
...
python/paddle/fluid/tests/unittests/test_collective_global_gather.py
浏览文件 @
8cc2e28c
...
...
@@ -35,7 +35,16 @@ class TestCollectiveGlobalGatherAPI(TestDistBase):
"collective_global_gather_dygraph.py"
,
"global_gather"
,
"nccl"
,
static_mode
=
"0"
)
static_mode
=
"0"
,
eager_mode
=
False
)
def
test_global_gather_nccl_dygraph_eager
(
self
):
self
.
check_with_place
(
"collective_global_gather_dygraph.py"
,
"global_gather"
,
"nccl"
,
static_mode
=
"0"
,
eager_mode
=
True
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_collective_global_scatter.py
浏览文件 @
8cc2e28c
...
...
@@ -35,7 +35,16 @@ class TestCollectiveSelectScatterAPI(TestDistBase):
"collective_global_scatter_dygraph.py"
,
"global_scatter"
,
"nccl"
,
static_mode
=
"0"
)
static_mode
=
"0"
,
eager_mode
=
False
)
def
test_global_scatter_nccl_dygraph_eager
(
self
):
self
.
check_with_place
(
"collective_global_scatter_dygraph.py"
,
"global_scatter"
,
"nccl"
,
static_mode
=
"0"
,
eager_mode
=
True
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录