PaddlePaddle / PaddleDetection — commit cbdb8a17

Authored Apr 18, 2019 by gongweibao; committed via GitHub on Apr 18, 2019.
Polish DGC code (#16818)
Parent: dbf66dd0

Showing 18 changed files with 369 additions and 253 deletions (+369, −253).
Changed files:

  CMakeLists.txt                                                  +7    −1
  cmake/inference_lib.cmake                                       +0    −9
  paddle/fluid/framework/details/CMakeLists.txt                   +10   −6
  paddle/fluid/framework/details/all_reduce_op_handle.cc          +5    −180
  paddle/fluid/framework/details/all_reduce_op_handle.h           +5    −17
  paddle/fluid/framework/details/dgc_const_values.h               +32   −0   (new)
  paddle/fluid/framework/details/multi_devices_graph_pass.cc      +31   −11
  paddle/fluid/framework/details/multi_devices_graph_pass.h       +2    −2
  paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc   +188  −0   (new)
  paddle/fluid/framework/details/sparse_all_reduce_op_handle.h    +52   −0   (new)
  paddle/fluid/inference/CMakeLists.txt                           +0    −5
  paddle/fluid/operators/CMakeLists.txt                           +1    −1
  paddle/fluid/platform/CMakeLists.txt                            +4    −5
  paddle/fluid/platform/init.cc                                   +2    −2
  paddle/fluid/pybind/const_value.cc                              +16   −0
  python/paddle/fluid/optimizer.py                                +6    −6
  python/paddle/fluid/parallel_executor.py                        +5    −7
  python/paddle/fluid/tests/unittests/CMakeLists.txt              +3    −1
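In short, the commit turns DGC (Deep Gradient Compression) support into an explicit WITH_DGC build switch, splits the sparse (encoded) allreduce path out of AllReduceOpHandle into a new SparseAllReduceOpHandle subclass, centralizes the DGC variable-name suffixes in dgc_const_values.h, and exposes them to Python through a new core.dgc pybind submodule. For orientation, below is a minimal NumPy sketch of the idea behind the sparse allreduce the new op handle implements — purely illustrative, not the Paddle API: each rank keeps only its top-k gradient entries, the encoded (index, value) pairs of all ranks are exchanged (the C++ code uses ncclAllGather, hence its nranks * encode_size buffer), and every rank sums them into a dense result.

    import numpy as np

    def sparse_allreduce_sketch(local_grads, k):
        """Toy model of DGC sparse allreduce (illustrative only)."""
        numel = local_grads[0].size
        # Per-rank top-k encoding: k indices + k values, i.e. a "2 * k"
        # payload, matching the in_numel == 2 * k invariant enforced in
        # RunImplEncoded() below.
        encoded = []
        for g in local_grads:
            idx = np.argsort(np.abs(g))[-k:]   # indices of the k largest |g|
            encoded.append((idx, g[idx]))
        # Conceptual allgather + decode: every rank ends up with the same
        # dense sum of all ranks' sparse contributions.
        out = np.zeros(numel, dtype=np.float32)
        for idx, vals in encoded:
            out[idx] += vals
        return out

    grads = [np.random.randn(8).astype(np.float32) for _ in range(4)]  # 4 "ranks"
    print(sparse_allreduce_sketch(grads, k=2))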
CMakeLists.txt  (+7, −1)

@@ -77,6 +77,7 @@
 option(WITH_HIGH_LEVEL_API_TEST "Test fluid python high-level api interface" OFF)
 option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
 option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
+option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)

 # PY_VERSION
 if(NOT PY_VERSION)

@@ -196,9 +197,14 @@
   include(anakin_subgraph)
 endif()

-if(WITH_GPU AND NOT WIN32)
+if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
+  set(WITH_DGC OFF)
+endif()
+
+if(WITH_DGC)
   message(STATUS "add dgc lib.")
   include(external/dgc)
+  add_definitions(-DPADDLE_WITH_DGC)
 endif()

 if(WITH_MKL OR WITH_MKLML)
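(With this switch in place, a build that must avoid the DGC dependency can presumably be configured with `cmake -DWITH_DGC=OFF ..`; the guard above also forces it off automatically on Windows, macOS, CPU-only, and inference builds.)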
cmake/inference_lib.cmake  (+0, −9)

@@ -131,15 +131,6 @@
   )
 endif()

-if(WITH_GPU AND NOT WIN32)
-  set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
-  copy(dgc_lib
-    SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
-    DSTS ${dgc_dir} ${dgc_dir}
-    DEPS dgc)
-endif()
-
 if(WITH_MKLDNN)
   set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
   copy(mkldnn_lib
paddle/fluid/framework/details/CMakeLists.txt  (+10, −6)

@@ -24,15 +24,19 @@
   endif()
 endif()

+set(all_reduce_deps all_reduce_op_handle)
+
 if(WITH_GPU)
-    set(dgc_deps "")
-    if(NOT WIN32)
-        set(dgc_deps dgc)
-    endif()
     nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
-            dynload_cuda variable_visitor ${dgc_deps})
+            dynload_cuda variable_visitor)
     nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
             dynload_cuda variable_visitor)
+
+    if(WITH_DGC)
+        nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
+                lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
+        set(all_reduce_deps sparse_all_reduce_op_handle)
+    endif()
+
     if(WITH_DISTRIBUTE)
         nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
                 ddim dynload_cuda selected_rows_functor sendrecvop_rpc)

@@ -80,7 +84,7 @@
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
-        scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle
+        scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps}
         reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
 cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
paddle/fluid/framework/details/all_reduce_op_handle.cc  (+5, −180)

@@ -17,11 +17,6 @@
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
 #include "paddle/fluid/framework/operator.h"
-
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-#include "dgc/dgc.h"
-#endif
-
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/profiler.h"

@@ -40,23 +35,16 @@
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                      const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places,
-                                     const platform::NCCLContextMap *ctxs,
-                                     bool is_encoded, int nranks)
-    : OpHandleBase(node),
-      local_scopes_(local_scopes),
-      places_(places),
-      nccl_ctxs_(ctxs),
-      is_encoded_(is_encoded),
-      nranks_(nranks) {
+                                     const platform::NCCLContextMap *ctxs)
+    : OpHandleBase(node),
+      local_scopes_(local_scopes),
+      places_(places),
+      nccl_ctxs_(ctxs) {
   if (nccl_ctxs_) {
     for (auto &p : places_) {
       this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
     }
   }
-  // TODO(gongwb) :polish them!
-  if (is_encoded) {
-    VLOG(1) << "Use dgc allreduce mode";
-  }
 }
 #else
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,

@@ -66,92 +54,8 @@
 #endif

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-void AllReduceOpHandle::RunImplEncoded() {
-  platform::RecordEvent record_event(Name());
-
-  WaitInputVarGenerated();
-
-  auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
-  auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
-  PADDLE_ENFORCE_EQ(
-      in_var_handles.size(), places_.size(),
-      "The NoDummyInputSize should be equal to the number of places.");
-  PADDLE_ENFORCE_EQ(
-      in_var_handles.size(), out_var_handles.size(),
-      "The NoDummyInputSize and NoDummyOutputSize should be equal.");
-
-  std::vector<const LoDTensor *> ins;
-  std::vector<LoDTensor *> outs;
-  int k = -1;
-  for (size_t i = 0; i < local_scopes_.size(); ++i) {
-    auto &local_scope =
-        local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
-    auto original_name =
-        paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
-    auto encode_var_name = original_name + g_dgc_encoded;
-    auto *in_var = local_scope->FindVar(encode_var_name);
-    PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
-    auto &in = in_var->Get<LoDTensor>();
-    ins.emplace_back(&in);
-
-    auto *out = local_scope->FindVar(out_var_handles[i]->name())
-                    ->GetMutable<LoDTensor>();
-    outs.emplace_back(out);
-
-    if (k < 0) {
-      k = GetKValue(in_var_handles[i]->name());
-    }
-  }
-
-  PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
-  PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
-  PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
-
-  int dtype = -1;
-  size_t in_numel = 0;
-  size_t out_numel = 0;
-  PADDLE_ENFORCE(nranks_ > 1);
-  std::vector<std::function<void()>> all_reduce_calls;
-
-  for (size_t i = 0; i < local_scopes_.size(); ++i) {
-    auto &place = places_[i];
-    auto &in = *ins[i];
-    void *in_tensor_buf = const_cast<void *>(in.data<void>());
-
-    auto &out = *outs[i];
-    float *out_tensor_buf = out.data<float>();
-
-    dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
-    in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
-    PADDLE_ENFORCE(in_numel % 2 == 0);
-    PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
-    out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
-
-    int dev_id = boost::get<platform::CUDAPlace>(place).device;
-    auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-    auto stream = nccl_ctx.stream();
-    auto comm = nccl_ctx.comm_;
-
-    auto &allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
-    int encode_size = 2 * k * sizeof(int);
-    // dgc use ncclAllGather to get all the encoded data
-    // so the buffer need nranks.
-    int buf_size = nranks_ * encode_size;
-    auto tmp_ious_data = allocator.Allocate(buf_size);
-    void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
-
-    VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
-             << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
-             << ", k:" << k << ", place:" << place << ", dtype:" << dtype;
-
-    all_reduce_calls.emplace_back([=] {
-      PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
-          in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
-          stream));
-    });
-  }
-
+void AllReduceOpHandle::RunAllReduceFuncs(
+    const std::vector<std::function<void()>> &all_reduce_calls) {
   this->RunAndRecordEvent([&] {
     if (all_reduce_calls.size() == 1UL) {
       // Do not use NCCLGroup when manage NCCL by per thread per device

@@ -182,68 +86,9 @@
     }
   }
 }
-
-int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
-  auto original_name = paddle::framework::GradOriginalVarName(grad_name);
-  auto var_name = original_name + g_dgc_k;
-  PADDLE_ENFORCE(local_scopes_.size() > 0);
-
-  auto *scope = local_scopes_[0];
-  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  auto var = local_scope->FindVar(var_name);
-  PADDLE_ENFORCE_NOT_NULL(var);
-  auto tensor = var->Get<LoDTensor>().data<float>();
-  return *tensor;
-}
 #endif

-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-bool AllReduceOpHandle::IsEncoded() {
-  if (!is_encoded_) {
-    return false;
-  }
-  auto counter_name = g_dgc_counter_name;
-  auto step_name = g_dgc_rampup_begin_step;
-  PADDLE_ENFORCE(local_scopes_.size() > 0);
-
-  auto *scope = local_scopes_[0];
-  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  auto count_var = local_scope->FindVar(counter_name);
-  auto step_var = local_scope->FindVar(step_name);
-  if (count_var == nullptr || step_var == nullptr) {
-    PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
-                 step_var);
-  }
-
-  float count = *count_var->Get<LoDTensor>().data<float>();
-  float step = *step_var->Get<LoDTensor>().data<float>();
-  if (static_cast<int>(count) < static_cast<int>(step)) {
-    VLOG(10) << "in all_reduce currentstep:" << count
-             << " < rampup_begin_step:" << step
-             << " so not use sparse all reduce";
-    return false;
-  }
-
-  return true;
-}
-#else
-bool AllReduceOpHandle::IsEncoded() { return false; }
-#endif
-
-void AllReduceOpHandle::RunImpl() {
-  if (!IsEncoded()) {
-    RunImplNormal();
-    return;
-  }
-
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  RunImplEncoded();
-#else
-  PADDLE_THROW("Not compiled with CUDA");
-#endif
-}
-
-void AllReduceOpHandle::RunImplNormal() {
+void AllReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());

   WaitInputVarGenerated();

@@ -304,27 +149,7 @@
         comm, stream));
     });
   }
-
-  this->RunAndRecordEvent([&] {
-    if (all_reduce_calls.size() == 1UL) {
-      // Do not use NCCLGroup when manage NCCL by per thread per device
-      all_reduce_calls[0]();
-    } else {
-      platform::NCCLGroupGuard guard;
-      for (auto &call : all_reduce_calls) {
-        call();
-      }
-    }
-  });
-
-  if (FLAGS_sync_nccl_allreduce) {
-    for (auto &p : places_) {
-      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-      auto stream = nccl_ctx.stream();
-      cudaStreamSynchronize(stream);
-    }
-  }
+  RunAllReduceFuncs(all_reduce_calls);
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
paddle/fluid/framework/details/all_reduce_op_handle.h  (+5, −17)

@@ -28,19 +28,12 @@
 namespace framework {
 namespace details {

-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
-constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
-constexpr char g_dgc_encoded[] = "__dgc_encoded__";
-constexpr char g_dgc_k[] = "__dgc_k__";
-#endif
-
-struct AllReduceOpHandle : public OpHandleBase {
+class AllReduceOpHandle : public OpHandleBase {
+ public:
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
-                    const platform::NCCLContextMap *ctxs,
-                    bool is_encoded = false, int nranks = -1);
+                    const platform::NCCLContextMap *ctxs);
 #else
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);

@@ -54,18 +47,13 @@
  protected:
   void RunImpl() override;

- private:
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  void RunImplEncoded();
+  void RunAllReduceFuncs(
+      const std::vector<std::function<void()>> &all_reduce_calls);
   const platform::NCCLContextMap *nccl_ctxs_;
-  bool is_encoded_{false};
-  int nranks_{-1};
-  int GetKValue(const std::string &grad_name);
#endif
-
-  void RunImplNormal();
-  bool IsEncoded();
 };

 }  // namespace details
paddle/fluid/framework/details/dgc_const_values.h  (new file, +32)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

namespace paddle {
namespace framework {
namespace details {

constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
constexpr char g_dgc_u[] = "__dgc_u__";
constexpr char g_dgc_v[] = "__dgc_v__";
constexpr char g_dgc_k[] = "__dgc_k__";
constexpr char g_dgc_encoded[] = "__dgc_encoded__";

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/details/multi_devices_graph_pass.cc  (+31, −11)

@@ -34,6 +34,10 @@
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/operators/math/math_function.h"

+#if defined(PADDLE_WITH_DGC)
+#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
+#endif
+
 namespace paddle {
 namespace framework {
 namespace details {

@@ -438,12 +442,22 @@
   auto append_allreduce_op = [&](
       const std::vector<Scope *> &scopes,
       const std::vector<platform::Place> &places) -> OpHandleBase * {
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#if defined(PADDLE_WITH_DGC)
+    if (is_encoded) {
+      result->Get<GraphOps>(kGraphOps).emplace_back(new SparseAllReduceOpHandle(
+          result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+          scopes, places, nccl_ctxs_, is_encoded,
+          static_cast<int>(strategy_.trainers_endpoints_.size()) *
+              places_.size()));
+    } else {
+      result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+          result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+          scopes, places, nccl_ctxs_));
+    }
+#elif defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
         result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-        scopes, places, nccl_ctxs_, is_encoded,
-        static_cast<int>(strategy_.trainers_endpoints_.size()) *
-            places_.size()));
+        scopes, places, nccl_ctxs_));
 #else
     result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
         result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),

@@ -561,7 +575,11 @@
     CreateReduceOp(result, g_name, 0);
     CreateBroadcastOp(result, g_name, 0);
   } else {
+#if defined(PADDLE_WITH_DGC)
+    CreateAllReduceOp(result, g_name, IsEncoded(p_name));
+#else
     CreateAllReduceOp(result, g_name);
+#endif
   }
 }

@@ -965,8 +983,9 @@
   return op_dev_id;
 }

-bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
-  auto u_name = p_name + "__dgc_u__";
+#if defined(PADDLE_WITH_DGC)
+bool AllReduceSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
+  auto u_name = p_name + g_dgc_u;
   auto it = all_vars_.find(u_name);
   if (it == all_vars_.end()) {
     VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;

@@ -975,6 +994,11 @@
   return true;
 }
+#else
+bool AllReduceSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
+  return false;
+}
+#endif

 void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
                                              const std::string &p_name,

@@ -992,11 +1016,7 @@
     CreateReduceOp(result, g_name, 0);
     CreateBroadcastOp(result, g_name, 0);
   } else {
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    CreateAllReduceOp(result, g_name, IsEncoded(p_name));
-#else
-    PADDLE_ENFORCE(false, "Compiled withoud cuda!");
-#endif
+    CreateAllReduceOp(result, g_name);
   }
   break;
 default:
paddle/fluid/framework/details/multi_devices_graph_pass.h  (+2, −2)

@@ -113,6 +113,8 @@
                                  const std::string &g_name) const;

   virtual void InsertPostprocessOps(ir::Graph *result) const {}
+
+  bool IsEncoded(const std::string &p_name) const;
 };

 class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {

@@ -203,8 +205,6 @@
   mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
   mutable bool need_broadcast_var_{false};
-
-  bool IsEncoded(const std::string &p_name) const;
 };

 std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc  (new file, +188)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
#include <algorithm>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(sync_nccl_allreduce);

namespace paddle {
namespace framework {
namespace details {

SparseAllReduceOpHandle::SparseAllReduceOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    const platform::NCCLContextMap *ctxs, bool is_encoded, int nranks)
    : AllReduceOpHandle(node, local_scopes, places, ctxs),
      is_encoded_(is_encoded),
      nranks_(nranks) {
  // TODO(gongwb) :polish them!
  if (is_encoded) {
    VLOG(1) << "Use dgc allreduce mode";
  }
}

void SparseAllReduceOpHandle::RunImplEncoded() {
  platform::RecordEvent record_event(Name());

  WaitInputVarGenerated();

  auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
  auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
  PADDLE_ENFORCE_EQ(
      in_var_handles.size(), places_.size(),
      "The NoDummyInputSize should be equal to the number of places.");
  PADDLE_ENFORCE_EQ(
      in_var_handles.size(), out_var_handles.size(),
      "The NoDummyInputSize and NoDummyOutputSize should be equal.");

  std::vector<const LoDTensor *> ins;
  std::vector<LoDTensor *> outs;
  int k = -1;
  for (size_t i = 0; i < local_scopes_.size(); ++i) {
    auto &local_scope =
        local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
    auto original_name =
        paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
    auto encode_var_name = original_name + g_dgc_encoded;
    auto *in_var = local_scope->FindVar(encode_var_name);
    PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
    auto &in = in_var->Get<LoDTensor>();
    ins.emplace_back(&in);

    auto *out = local_scope->FindVar(out_var_handles[i]->name())
                    ->GetMutable<LoDTensor>();
    outs.emplace_back(out);

    if (k < 0) {
      k = GetKValue(in_var_handles[i]->name());
    }
  }

  PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
  PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
  PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");

  int dtype = -1;
  size_t in_numel = 0;
  size_t out_numel = 0;
  PADDLE_ENFORCE(nranks_ > 1);
  std::vector<std::function<void()>> all_reduce_calls;

  for (size_t i = 0; i < local_scopes_.size(); ++i) {
    auto &place = places_[i];
    auto &in = *ins[i];
    void *in_tensor_buf = const_cast<void *>(in.data<void>());

    auto &out = *outs[i];
    float *out_tensor_buf = out.data<float>();

    dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
    in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
    PADDLE_ENFORCE(in_numel % 2 == 0);
    PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
    out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;

    int dev_id = boost::get<platform::CUDAPlace>(place).device;
    auto &nccl_ctx = nccl_ctxs_->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    auto &allocator =
        platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
    int encode_size = 2 * k * sizeof(int);
    // dgc use ncclAllGather to get all the encoded data
    // so the buffer need nranks.
    int buf_size = nranks_ * encode_size;
    auto tmp_ious_data = allocator.Allocate(buf_size);
    void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());

    VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
             << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
             << ", k:" << k << ", place:" << place << ", dtype:" << dtype;

    all_reduce_calls.emplace_back([=] {
      PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
          in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
          stream));
    });
  }

  RunAllReduceFuncs(all_reduce_calls);
}

int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
  auto original_name = paddle::framework::GradOriginalVarName(grad_name);
  auto var_name = original_name + g_dgc_k;
  PADDLE_ENFORCE(local_scopes_.size() > 0);

  auto *scope = local_scopes_[0];
  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
  auto var = local_scope->FindVar(var_name);
  PADDLE_ENFORCE_NOT_NULL(var);
  auto tensor = var->Get<LoDTensor>().data<float>();
  return *tensor;
}

bool SparseAllReduceOpHandle::IsEncoded() {
  if (!is_encoded_) {
    return false;
  }
  auto counter_name = g_dgc_counter_name;
  auto step_name = g_dgc_rampup_begin_step;
  PADDLE_ENFORCE(local_scopes_.size() > 0);

  auto *scope = local_scopes_[0];
  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
  auto count_var = local_scope->FindVar(counter_name);
  auto step_var = local_scope->FindVar(step_name);
  if (count_var == nullptr || step_var == nullptr) {
    PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
                 step_var);
  }

  float count = *count_var->Get<LoDTensor>().data<float>();
  float step = *step_var->Get<LoDTensor>().data<float>();
  if (static_cast<int>(count) < static_cast<int>(step)) {
    VLOG(10) << "in all_reduce currentstep:" << count
             << " < rampup_begin_step:" << step
             << " so not use sparse all reduce";
    return false;
  }

  return true;
}

void SparseAllReduceOpHandle::RunImpl() {
  if (!IsEncoded()) {
    AllReduceOpHandle::RunImpl();
    return;
  }

  RunImplEncoded();
}

std::string SparseAllReduceOpHandle::Name() const {
  return "sparse_all_reduce";
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/details/sparse_all_reduce_op_handle.h  (new file, +52)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/dgc_const_values.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"

namespace paddle {
namespace framework {
namespace details {

class SparseAllReduceOpHandle : public AllReduceOpHandle {
 public:
  SparseAllReduceOpHandle(ir::Node *node,
                          const std::vector<Scope *> &local_scopes,
                          const std::vector<platform::Place> &places,
                          const platform::NCCLContextMap *ctxs,
                          bool is_encoded = false, int nranks = -1);
  std::string Name() const override;

 protected:
  void RunImpl() override;
  int GetKValue(const std::string &grad_name);
  bool IsEncoded();
  void RunImplEncoded();

 private:
  bool is_encoded_{false};
  int nranks_{-1};
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/CMakeLists.txt  (+0, −5)

@@ -49,11 +49,6 @@
     ${mkldnn_quantizer_src}
     ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc)

-# FIXME(gongwb): hidden libdgc.a
-if(WITH_GPU AND NOT WIN32)
-    set(fluid_modules ${fluid_modules} dgc)
-endif()
-
 if(WIN32)
     sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array
                 analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
paddle/fluid/operators/CMakeLists.txt  (+1, −1)

@@ -72,7 +72,7 @@
 set(COMMON_OP_DEPS ${OP_HEADER_DEPS})

-if (WITH_GPU AND NOT WIN32)
+if (WITH_DGC)
     op_library(dgc_op DEPS dgc)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n")
     set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
paddle/fluid/platform/CMakeLists.txt  (+4, −5)

@@ -45,13 +45,12 @@
 cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)

+set(dgc_deps "")
+IF(WITH_DGC)
+    set(dgc_deps dgc)
+ENDIF()

 IF(WITH_GPU)
     set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
-    if(NOT WIN32)
-        set(dgc_deps dgc)
-    endif()
-ELSE()
-    set(dgc_deps)
 ENDIF()

 IF(WITH_MKLDNN)
paddle/fluid/platform/init.cc  (+2, −2)

@@ -31,7 +31,7 @@
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/string/piece.h"

-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#if defined(PADDLE_WITH_DGC)
 #include "dgc/dgc.h"
 #endif

@@ -211,7 +211,7 @@
 #endif
 }

-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#if defined(PADDLE_WITH_DGC)
 void InitDGC() {
   std::call_once(dgc_init_flag, []() {
     PADDLE_ENFORCE(paddle::communication::dgc::dynloadNcclLib());
paddle/fluid/pybind/const_value.cc  (+16, −0)

@@ -17,6 +17,11 @@
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"

+#if defined(PADDLE_WITH_DGC)
+#include "paddle/fluid/framework/details/dgc_const_values.h"
+#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
+#endif
+
 namespace paddle {
 namespace pybind {

@@ -52,6 +57,17 @@
   op_proto_and_checker_maker.def(
       "kOpCreationCallstackAttrName",
       framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
+
+#if defined(PADDLE_WITH_DGC)
+  auto dgc = m->def_submodule("dgc");
+  dgc.def("kDGCUName", [] { return framework::details::g_dgc_u; });
+  dgc.def("kDGCVName", [] { return framework::details::g_dgc_v; });
+  dgc.def("kDGCKName", [] { return framework::details::g_dgc_k; });
+  dgc.def("kDGCEncodedName", [] { return framework::details::g_dgc_encoded; });
+  dgc.def("kDGCCounterName",
+          [] { return framework::details::g_dgc_counter_name; });
+  dgc.def("kDGCRampUpBeginStepName",
+          [] { return framework::details::g_dgc_rampup_begin_step; });
+#endif
 }
 }  // namespace pybind
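With these bindings, Python code can query the suffixes instead of hard-coding them. A hedged usage sketch (assumes a build configured with -DWITH_DGC=ON; the hasattr guard is illustrative):

    from paddle.fluid import core

    # The dgc submodule exists only in builds compiled with -DPADDLE_WITH_DGC.
    if hasattr(core, "dgc"):
        print(core.dgc.kDGCUName())                # "__dgc_u__"
        print(core.dgc.kDGCVName())                # "__dgc_v__"
        print(core.dgc.kDGCKName())                # "__dgc_k__"
        print(core.dgc.kDGCEncodedName())          # "__dgc_encoded__"
        print(core.dgc.kDGCCounterName())          # "__g_dgc_counter__"
        print(core.dgc.kDGCRampUpBeginStepName())  # "__g_rampup_begin_step__"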
python/paddle/fluid/optimizer.py  (+6, −6)

@@ -751,14 +751,14 @@ class DGCMomentumOptimizer(MomentumOptimizer):
         # step counter
         self._global_step_var = self._add_auto_increment_var(
-            counter_name='__g_dgc_counter__', begin=0)
+            counter_name=core.dgc.kDGCCounterName(), begin=0)

         # rampup begin step var for all_reduce_op_handle
         self._rampup_begin_step_var = tensor.create_global_var(
             shape=[1],
             dtype=core.VarDesc.VarType.FP32,
             persistable=True,
-            name='__g_rampup_begin_step__',
+            name=core.dgc.kDGCRampUpBeginStepName(),
             value=self._rampup_begin_step * 1.0,
             force_cpu=True)

@@ -774,20 +774,20 @@
             shape=param_var.shape,
             dtype=param_var.dtype,
             persistable=True,
-            name=param_var.name + "__dgc_u__",
+            name=param_var.name + core.dgc.kDGCUName(),
             value=0.0)

         v_var = tensor.create_global_var(
             shape=param_var.shape,
             dtype=param_var.dtype,
             persistable=True,
-            name=param_var.name + "__dgc_v__",
+            name=param_var.name + core.dgc.kDGCVName(),
             value=0.0)

         k_var = tensor.create_global_var(
             shape=[1],
             dtype=param_var.dtype,
             persistable=True,
-            name=param_var.name + "__dgc_k__",
+            name=param_var.name + core.dgc.kDGCKName(),
             value=0.0,
             force_cpu=True)

@@ -795,7 +795,7 @@
             shape=[1],
             dtype=param_var.dtype,
             persistable=True,
-            name=param_var.name + "__dgc_encoded__",
+            name=param_var.name + core.dgc.kDGCEncodedName(),
             value=0.0,
             force_cpu=False)
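The net effect of these edits is a single naming convention shared by Python and C++: each DGC slot variable is the parameter (or gradient) name plus one of the exported suffixes, which is exactly what SparseAllReduceOpHandle looks up on the C++ side. A small sketch of the convention (illustrative; "fc_0.w_0" is a hypothetical parameter name, and a DGC-enabled build is assumed):

    from paddle.fluid import core

    param = "fc_0.w_0"  # hypothetical parameter name
    dgc_slots = {
        "u":       param + core.dgc.kDGCUName(),        # fc_0.w_0__dgc_u__
        "v":       param + core.dgc.kDGCVName(),        # fc_0.w_0__dgc_v__
        "k":       param + core.dgc.kDGCKName(),        # fc_0.w_0__dgc_k__
        "encoded": param + core.dgc.kDGCEncodedName(),  # fc_0.w_0__dgc_encoded__
    }
    print(dgc_slots)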
python/paddle/fluid/parallel_executor.py  (+5, −7)

@@ -104,11 +104,13 @@ class ParallelExecutor(object):
         self._scope = scope if scope is not None else executor.global_scope()

         if main_program is not None and main_program._enable_dgc:
-            assert num_trainers > 1
-            assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce
-            assert use_cuda
+            assert num_trainers > 1, "dgc is not useful when num_trainers <= 1"
+            assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "dgc \
+                only used for allreduce"
+            assert num_trainers * len(self._places) > 1, "dgc is not useful for single card training"
+            assert use_cuda, "dgc only used under cuda"

         main_program = main_program if main_program is not None \
             else framework.default_main_program()

@@ -125,10 +127,6 @@
             share_vars_from=share_vars_from._compiled_program
             if share_vars_from else None)

-        # FIXME(gongwb): I will move dgc from dist mode to allreduce mode in next pr.
-        if main_program._enable_dgc:
-            self._compiled_program._build_strategy.is_distribution = True
-
         self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
         self._exe = executor.Executor(self._place)
         self._compiled_program._compile(place=self._place, scope=self._scope)
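The reworded asserts spell out DGC's preconditions. Restated as a standalone check (a sketch mirroring the asserts above, not a Paddle API; the string-valued reduce_strategy is a simplification of BuildStrategy.ReduceStrategy.AllReduce):

    def dgc_usable(num_trainers, cards_per_trainer, reduce_strategy, use_cuda):
        """Mirror of the ParallelExecutor asserts for DGC."""
        return (num_trainers > 1                       # multi-trainer only
                and reduce_strategy == "AllReduce"     # allreduce mode only
                and num_trainers * cards_per_trainer > 1  # more than one card
                and use_cuda)                          # CUDA only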
python/paddle/fluid/tests/unittests/CMakeLists.txt  (+3, −1)

@@ -97,11 +97,13 @@
 if(WITH_DISTRIBUTE)
     py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
     set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
+    if(WITH_DGC)
+        py_test_modules(test_dgc_op MODULES test_dgc_op)
+    endif()
     if(NOT APPLE)
         set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
         set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
         py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
-        py_test_modules(test_dgc_op MODULES test_dgc_op)
         set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
         py_test_modules(test_dist_se_resnext_nccl MODULES test_dist_se_resnext_nccl SERIAL)
         set_tests_properties(test_dist_se_resnext_nccl PROPERTIES TIMEOUT 1000)