Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5a3c8bf8
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5a3c8bf8
编写于
6月 09, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix in c++ side
上级
a56dcf51
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
127 addition
and
49 deletion
+127
-49
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-0
paddle/fluid/framework/details/graph_builder_factory.h
paddle/fluid/framework/details/graph_builder_factory.h
+5
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+33
-11
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+19
-4
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+11
-3
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+25
-15
python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
...luid/tests/unittests/test_parallel_executor_fetch_feed.py
+16
-8
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
...dle/fluid/tests/unittests/test_parallel_executor_mnist.py
+2
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
...ests/unittests/test_parallel_executor_test_while_train.py
+12
-6
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
...uid/tests/unittests/test_parallel_executor_transformer.py
+2
-1
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
5a3c8bf8
...
...
@@ -19,6 +19,8 @@ if(WITH_GPU)
nv_library
(
broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda
)
else
()
cc_library
(
nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor
)
set
(
multi_devices_graph_builder_deps
)
cc_library
(
reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim
)
cc_library
(
broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
...
...
paddle/fluid/framework/details/graph_builder_factory.h
浏览文件 @
5a3c8bf8
...
...
@@ -40,7 +40,11 @@ class SSAGraphBuilderFactory {
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{}
strategy_
(
strategy
)
{
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
nullptr
;
#endif
}
#ifdef PADDLE_WITH_CUDA
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
5a3c8bf8
...
...
@@ -20,16 +20,13 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
namespace
paddle
{
namespace
framework
{
namespace
details
{
...
...
@@ -305,7 +302,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
auto
*
out_var
=
new
VarHandle
(
vars
.
size
(),
i
,
p_name
,
p
);
vars
.
emplace_back
(
out_var
);
op_handle
->
AddOutput
(
out_var
);
#ifndef ADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
nccl_ctxs_
==
nullptr
)
{
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
}
#else
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
#endif
...
...
@@ -324,7 +326,10 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph
*
result
,
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
ops_
.
emplace_back
(
new
NCCLAllReduceOpHandle
(
local_scopes_
,
places_
,
*
nccl_ctxs_
));
new
NCCLAllReduceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
ops_
.
emplace_back
(
new
NCCLAllReduceOpHandle
(
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
ops_
.
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
...
...
@@ -334,13 +339,23 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
#ifdef PADDLE_WITH_CUDA
if
(
nccl_ctxs_
==
nullptr
)
{
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
}
#else
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
#endif
VLOG
(
4
)
<<
"NCCL - - - "
<<
p
;
op_handle
->
DeviceContext
(
p
)
->
Wait
();
VLOG
(
4
)
<<
"NCCL - - - "
<<
p
<<
" "
<<
op_handle
->
DeviceContext
(
p
);
auto
var
=
new
VarHandle
(
vars
.
size
()
-
1
,
i
,
og
,
p
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
#else
PADDLE_ENFORCE
(
"Not implemented"
);
#endif
}
bool
MultiDevSSAGraphBuilder
::
IsParameterGradientOnce
(
...
...
@@ -379,7 +394,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
auto
*
communication_dev_ctx
=
nccl_ctxs_
->
DevCtx
(
places_
[
i
]);
auto
*
communication_dev_ctx
=
nccl_ctxs_
?
nccl_ctxs_
->
DevCtx
(
places_
[
i
])
:
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
#else
auto
*
communication_dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
...
...
@@ -425,8 +442,13 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
vars
=
result
->
vars_
[
i
][
og
];
#ifndef PADDLE_WITH_CUDA
auto
&
p
=
places_
[
i
];
#ifdef PADDLE_WITH_CUDA
if
(
nccl_ctxs_
==
nullptr
)
{
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
}
#else
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
#endif
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
浏览文件 @
5a3c8bf8
...
...
@@ -21,15 +21,25 @@
namespace
paddle
{
namespace
framework
{
namespace
details
{
#ifdef PADDLE_WITH_CUDA
NCCLAllReduceOpHandle
::
NCCLAllReduceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
NCCLContextMap
&
ctxs
)
const
platform
::
NCCLContextMap
*
ctxs
)
:
local_scopes_
(
local_scopes
),
places_
(
places
),
nccl_ctxs_
(
ctxs
)
{
for
(
auto
&
p
:
places_
)
{
this
->
dev_ctxes_
[
p
]
=
nccl_ctxs_
.
DevCtx
(
p
);
if
(
ctxs
)
{
for
(
auto
&
p
:
places_
)
{
this
->
dev_ctxes_
[
p
]
=
nccl_ctxs_
->
DevCtx
(
p
);
}
}
}
#else
NCCLAllReduceOpHandle
::
NCCLAllReduceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
)
:
local_scopes_
(
local_scopes
),
places_
(
places
)
{}
#endif
void
NCCLAllReduceOpHandle
::
RunImpl
()
{
if
(
NoDummyInputSize
()
==
1
)
{
...
...
@@ -58,6 +68,8 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
if
(
platform
::
is_gpu_place
(
lod_tensors
[
0
]
->
place
()))
{
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
nccl_ctxs_
);
int
dtype
=
-
1
;
size_t
numel
=
0
;
std
::
vector
<
std
::
function
<
void
()
>>
all_reduce_calls
;
...
...
@@ -75,7 +87,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
auto
&
nccl_ctx
=
nccl_ctxs_
.
at
(
dev_id
);
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
stream
=
nccl_ctx
.
stream
();
auto
comm
=
nccl_ctx
.
comm_
;
all_reduce_calls
.
emplace_back
([
=
]
{
...
...
@@ -90,6 +102,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
call
();
}
});
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
}
else
{
// Special handle CPU only Operator's gradient. Like CRF
auto
&
trg
=
*
this
->
local_scopes_
[
0
]
->
FindVar
(
kLocalExecScopeName
)
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
浏览文件 @
5a3c8bf8
...
...
@@ -20,17 +20,23 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
namespace
details
{
struct
NCCLAllReduceOpHandle
:
public
OpHandleBase
{
#ifdef PADDLE_WITH_CUDA
NCCLAllReduceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
NCCLContextMap
&
ctxs
);
const
platform
::
NCCLContextMap
*
ctxs
);
#else
NCCLAllReduceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
);
#endif
std
::
string
Name
()
const
override
;
// Delay and buffer nccl_all_reduce together can significantly increase
...
...
@@ -43,7 +49,9 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
private:
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
const
platform
::
NCCLContextMap
&
nccl_ctxs_
;
#ifdef PADDLE_WITH_CUDA
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
};
}
// namespace details
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
5a3c8bf8
...
...
@@ -44,6 +44,7 @@ class ParallelExecutorPrivate {
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
nccl_ctxs_
;
#endif
bool
own_local_scope
;
bool
use_cuda
;
};
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
...
...
@@ -60,6 +61,7 @@ ParallelExecutor::ParallelExecutor(
size_t
num_trainers
,
size_t
trainer_id
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
member_
->
global_scope_
=
scope
;
member_
->
use_cuda
=
exec_strategy
.
use_event_
;
// Step 1. Bcast the params to devs.
// Create local scopes
...
...
@@ -77,18 +79,22 @@ ParallelExecutor::ParallelExecutor(
}
}
if
(
member_
->
use_cuda
)
{
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
ncclUniqueId
*
nccl_id
=
nullptr
;
if
(
nccl_id_var
!=
nullptr
)
{
nccl_id
=
nccl_id_var
->
GetMutable
<
ncclUniqueId
>
();
}
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
,
nccl_id
,
num_trainers
,
trainer_id
));
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
ncclUniqueId
*
nccl_id
=
nullptr
;
if
(
nccl_id_var
!=
nullptr
)
{
nccl_id
=
nccl_id_var
->
GetMutable
<
ncclUniqueId
>
();
}
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
,
nccl_id
,
num_trainers
,
trainer_id
));
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
if
(
platform
::
is_gpu_place
(
places
[
0
])
&&
member_
->
local_scopes_
.
size
()
!=
1
&&
local_scopes
.
empty
())
{
// Is CUDA
}
if
(
member_
->
local_scopes_
.
size
()
!=
1
&&
local_scopes
.
empty
())
{
BCastParamsToGPUs
(
bcast_vars
);
}
// Startup Program has been run. All local scopes has correct parameters.
...
...
@@ -108,9 +114,13 @@ ParallelExecutor::ParallelExecutor(
details
::
SSAGraphBuilderFactory
builder_factory
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
if
(
member_
->
use_cuda
)
{
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
}
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
...
...
@@ -123,7 +133,6 @@ ParallelExecutor::ParallelExecutor(
void
ParallelExecutor
::
BCastParamsToGPUs
(
const
std
::
unordered_set
<
std
::
string
>
&
vars
)
const
{
#ifdef PADDLE_WITH_CUDA
auto
*
main_scope
=
member_
->
local_scopes_
[
0
];
for
(
auto
&
var
:
vars
)
{
...
...
@@ -135,6 +144,7 @@ void ParallelExecutor::BCastParamsToGPUs(
auto
&
main_tensor
=
main_var
->
Get
<
LoDTensor
>
();
auto
&
dims
=
main_tensor
.
dims
();
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
#ifdef PADDLE_WITH_CUDA
size_t
numel
=
main_tensor
.
numel
();
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
platform
::
NCCLGroupGuard
guard
;
...
...
@@ -153,6 +163,10 @@ void ParallelExecutor::BCastParamsToGPUs(
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
0
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
}
member_
->
nccl_ctxs_
->
WaitAll
();
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
}
else
{
platform
::
CPUPlace
cpu
;
for
(
size_t
i
=
1
;
i
<
member_
->
places_
.
size
();
++
i
)
{
...
...
@@ -163,11 +177,7 @@ void ParallelExecutor::BCastParamsToGPUs(
paddle
::
framework
::
TensorCopy
(
main_tensor
,
cpu
,
t
);
}
}
member_
->
nccl_ctxs_
->
WaitAll
();
}
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
}
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
浏览文件 @
5a3c8bf8
...
...
@@ -35,7 +35,7 @@ def Lenet(data, class_dim):
class
TestFetchOp
(
unittest
.
TestCase
):
def
parallel_exe
(
self
,
train_inputs
,
seed
):
def
parallel_exe
(
self
,
train_inputs
,
seed
,
use_cuda
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
startup
.
random_seed
=
seed
...
...
@@ -59,13 +59,13 @@ class TestFetchOp(unittest.TestCase):
# conv2d_1.b_0@GRAD. Those variables should not be pruned.
# fluid.memory_optimize(main)
place
=
fluid
.
CUDAPlace
(
0
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
data
,
label
])
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
loss
.
name
,
main_program
=
main
)
use_cuda
=
use_cuda
,
loss_name
=
loss
.
name
,
main_program
=
main
)
fetch_list
=
[]
all_vars
=
main
.
global_block
().
vars
...
...
@@ -88,14 +88,15 @@ class TestFetchOp(unittest.TestCase):
for
i
in
range
(
iters
):
train_inputs
.
append
(
tst_reader_iter
.
next
())
self
.
parallel_exe
(
train_inputs
,
seed
=
1
)
self
.
parallel_exe
(
train_inputs
,
seed
=
1
,
use_cuda
=
True
)
self
.
parallel_exe
(
train_inputs
,
seed
=
1
,
use_cuda
=
False
)
class
TestFeedParallel
(
unittest
.
TestCase
):
def
test_main
(
self
):
def
parallel_exe
(
self
,
use_cuda
,
seed
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
startup
.
random_seed
=
1
startup
.
random_seed
=
seed
with
fluid
.
scope_guard
(
fluid
.
core
.
Scope
()):
with
fluid
.
program_guard
(
main
,
startup
):
data
=
fluid
.
layers
.
data
(
...
...
@@ -111,15 +112,18 @@ class TestFeedParallel(unittest.TestCase):
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-4
))
opt
.
minimize
(
loss
)
place
=
fluid
.
CUDAPlace
(
0
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
data
,
label
])
reader
=
feeder
.
decorate_reader
(
paddle
.
batch
(
flowers
.
train
(),
batch_size
=
16
),
multi_devices
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
loss
.
name
,
main_program
=
main
)
use_cuda
=
use_cuda
,
loss_name
=
loss
.
name
,
main_program
=
main
)
for
batch_id
,
data
in
enumerate
(
reader
()):
loss_np
=
np
.
array
(
pe
.
run
(
feed
=
data
,
fetch_list
=
[
loss
.
name
])[
0
])
...
...
@@ -127,6 +131,10 @@ class TestFeedParallel(unittest.TestCase):
if
batch_id
==
2
:
break
def
test_feed_op
(
self
):
self
.
parallel_exe
(
use_cuda
=
True
,
seed
=
1
)
self
.
parallel_exe
(
use_cuda
=
False
,
seed
=
1
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
浏览文件 @
5a3c8bf8
...
...
@@ -117,9 +117,11 @@ class TestMNIST(TestParallelExecutorBase):
def
test_simple_fc
(
self
):
self
.
check_simple_fc_convergence
(
False
,
use_cuda
=
True
)
self
.
check_simple_fc_convergence
(
False
,
use_cuda
=
False
)
def
test_simple_fc_with_new_strategy
(
self
):
self
.
check_simple_fc_convergence
(
True
,
use_cuda
=
True
)
self
.
check_simple_fc_convergence
(
True
,
use_cuda
=
False
)
def
check_simple_fc_parallel_accuracy
(
self
,
balance_parameter_opt_between_cards
,
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
浏览文件 @
5a3c8bf8
...
...
@@ -35,7 +35,7 @@ def simple_fc_net():
class
ParallelExecutorTestingDuringTraining
(
unittest
.
TestCase
):
def
check_network_convergence
(
self
,
build_strategy
=
None
):
def
check_network_convergence
(
self
,
use_cuda
,
build_strategy
=
None
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -49,19 +49,19 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
image
=
np
.
random
.
normal
(
size
=
(
batch_size
,
784
)).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
10
,
(
batch_size
,
1
),
dtype
=
"int64"
)
place
=
fluid
.
CUDAPlace
(
0
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
feed_dict
=
{
'image'
:
image
,
'label'
:
label
}
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
use_cuda
=
use_cuda
,
loss_name
=
loss
.
name
,
main_program
=
main
,
build_strategy
=
build_strategy
)
test_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
use_cuda
=
use_cuda
,
main_program
=
test_program
,
share_vars_from
=
train_exe
,
build_strategy
=
build_strategy
)
...
...
@@ -81,12 +81,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def
test_parallel_testing
(
self
):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
self
.
check_network_convergence
(
build_strategy
)
self
.
check_network_convergence
(
use_cuda
=
True
,
build_strategy
=
build_strategy
)
self
.
check_network_convergence
(
use_cuda
=
False
,
build_strategy
=
build_strategy
)
def
test_parallel_testing_with_new_strategy
(
self
):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
self
.
check_network_convergence
(
build_strategy
)
self
.
check_network_convergence
(
use_cuda
=
True
,
build_strategy
=
build_strategy
)
self
.
check_network_convergence
(
use_cuda
=
False
,
build_strategy
=
build_strategy
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
浏览文件 @
5a3c8bf8
...
...
@@ -167,7 +167,8 @@ class TestTransformer(TestParallelExecutorBase):
@
unittest
.
skip
(
"transformer is buggy in multi gpu"
)
def
test_main
(
self
):
self
.
check_network_convergence
(
transformer
)
self
.
check_network_convergence
(
transformer
,
use_cuda
=
True
)
self
.
check_network_convergence
(
transformer
,
use_cuda
=
False
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录