Commit 106e2852
Authored December 12, 2018 by Yancey1989

add unittest for parllelgraph mode test=develop

Parent: 5cc83f79
Showing 9 changed files with 164 additions and 128 deletions (+164, -128):
paddle/fluid/framework/details/multi_devices_graph_pass.cc                  +6   -2
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc               +8   -12
paddle/fluid/framework/parallel_executor.cc                                 +1   -1
paddle/fluid/operators/reader/ctr_reader.h                                  +1   -1
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py          +84  -80
python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py           +3   -0
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py         +24  -14
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py     +32  -17
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py   +5   -1
paddle/fluid/framework/details/multi_devices_graph_pass.cc (+6, -2)

@@ -300,7 +300,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
-  // int num_trainers = Get<int>(kNumTrainers);
+  int num_trainers = Get<int>(kNumTrainers);
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
@@ -387,7 +387,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
-      // if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
-      if (!is_forwarding && nccl_ctxs_->contexts_.size() > 1) {
+      // insert synchronous ops at the backpropagation; and
+      // insert synchronous ops if the graph contains mutilple places.
+      if (!is_forwarding &&
+          (places_.size() > 1 || num_trainers > 1 ||
+           (nccl_ctxs_ && nccl_ctxs_->contexts_.size() > 1))) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once.
         if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
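The rewritten condition above decides, per generated gradient, whether a synchronous all-reduce op is inserted into the graph. A minimal standalone sketch of that predicate in plain Python, with illustrative names rather than Paddle code:

# Standalone sketch of the new insertion condition (illustrative names only):
# gradients trigger an all-reduce only on the backward pass, and only when the
# graph spans multiple places, multiple trainers, or multiple NCCL contexts.
def needs_allreduce(is_forwarding, num_places, num_trainers, num_nccl_contexts):
    return (not is_forwarding) and (
        num_places > 1 or num_trainers > 1 or num_nccl_contexts > 1)

assert needs_allreduce(False, 1, 2, 1)       # multi-trainer now also inserts sync ops
assert not needs_allreduce(False, 1, 1, 1)   # single place, single trainer: no sync op
assert not needs_allreduce(True, 4, 1, 4)    # forward ops never insert all-reduce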
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc (+8, -12)

@@ -49,18 +49,18 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
   for (size_t i = 0; i < places_.size(); ++i) {
     auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
+      try {
         return executors_[i]->Run(fetch_tensors);
+      } catch (...) {
+        exception_holder_.Catch(std::current_exception());
+      }
+      return FeedFetchList();
     };

     if (pool_) {
       run_futures.emplace_back(pool_->enqueue(std::move(call)));
     } else {
-      try {
-        fetch_datas.emplace_back(std::move(call()));
-      } catch (...) {
-        exception_holder_.Catch(std::current_exception());
-        break;
-      }
+      call();
     }
   }
@@ -69,11 +69,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
     if (exception_holder_.IsCaught()) {
       f.wait();
     } else {
-      try {
         fetch_datas.emplace_back(std::move(f.get()));
-      } catch (...) {
-        exception_holder_.Catch(std::current_exception());
-      }
     }
   }
 }
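Both hunks move error handling into the per-device callable: the lambda now traps its own exception, records it in exception_holder_, and returns an empty FeedFetchList, so the result-gathering loop no longer needs its own try/catch. A rough Python analogue of that pattern, with assumed names and a thread pool standing in for pool_ (not Paddle code):

# Rough Python analogue: the worker traps its own exception and records it in a
# shared holder, so the gathering loop stays free of try/except blocks.
from concurrent.futures import ThreadPoolExecutor

class ExceptionHolder:
    def __init__(self):
        self.exc = None
    def catch(self, exc):
        if self.exc is None:
            self.exc = exc
    def is_caught(self):
        return self.exc is not None

holder = ExceptionHolder()

def run_one(device_id):
    try:
        return ["fetch_for_device_%d" % device_id]  # stand-in for executors_[i]->Run()
    except Exception as exc:
        holder.catch(exc)
        return []                                   # mirrors `return FeedFetchList();`

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(run_one, i) for i in range(2)]
    fetch_datas = []
    for f in futures:
        if holder.is_caught():
            f.result()                              # just drain the future, like f.wait()
        else:
            fetch_datas.append(f.result())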
paddle/fluid/framework/parallel_executor.cc (+1, -1)

@@ -87,7 +87,7 @@ ParallelExecutor::ParallelExecutor(
                    "the number of places must be greater than 1.");
     PADDLE_ENFORCE(exec_strategy.type_ != ExecutionStrategy::kParallelGraph,
                    "You should set build_strategy.reduce with 'AllReduce' for "
-                   "ParallelGraph executor type");
+                   "the ParallelGraph executor type");
   }

   // Step 1. Bcast the params to devs.
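The reworded message reflects the constraint the constructor enforces: the ParallelGraph executor type is only accepted together with the AllReduce reduce strategy. A hedged usage sketch of the Python side, using the ExecutorType handle this commit adds to the test base (fluid 1.x API on this fork, shown as an assumption rather than the public interface):

# Hedged configuration sketch matching the PADDLE_ENFORCE above:
import paddle.fluid as fluid

exec_strategy = fluid.ExecutionStrategy()
exec_strategy.executor_type = fluid.ExecutionStrategy().ExecutorType.ParallelGraph

build_strategy = fluid.BuildStrategy()
# Choosing ReduceStrategy.Reduce here would trip the
# "You should set build_strategy.reduce with 'AllReduce'" check.
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce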
paddle/fluid/operators/reader/ctr_reader.h (+1, -1)

@@ -48,7 +48,7 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
 class CTRReader : public framework::FileReader {
  public:
   explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
-                     int batch_size, int thread_num,
+                     int batch_size, size_t thread_num,
                      const std::vector<std::string>& slots,
                      const std::vector<std::string>& file_list)
       : batch_size_(batch_size), slots_(slots), file_list_(file_list) {
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py (+84, -80)

@@ -26,9 +26,12 @@ import sys
 __all__ = ['TestParallelExecutorBase']

+ExecutorType = fluid.ExecutionStrategy().ExecutorType
+

 class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   use_cuda=True,
                                   memory_opt=True,
@@ -41,7 +44,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                                   use_reduce=False,
                                   fuse_elewise_add_act_ops=False,
                                   optimizer=fluid.optimizer.Adam,
-                                  use_fast_executor=False,
+                                  exec_type=fluid.ExecutionStrategy().ExecutorType.Default,
                                   enable_sequential_execution=False):
         def run_executor(exe, feed, fetch_list, program=None):
             if isinstance(exe, fluid.ParallelExecutor):
@@ -58,6 +61,8 @@ class TestParallelExecutorBase(unittest.TestCase):
         startup = fluid.Program()
         startup.random_seed = 1  # Fix random seed
         main.random_seed = 1
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
             with fluid.program_guard(main, startup):
                 if seed is not None:
                     startup.random_seed = seed
@@ -75,8 +80,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                 startup_exe.run(startup)

                 exec_strategy = fluid.ExecutionStrategy()
                 exec_strategy.allow_op_delay = allow_op_delay
-                if use_fast_executor:
-                    exec_strategy.use_experimental_executor = True
+                exec_strategy.executor_type = exec_type

                 build_strategy = fluid.BuildStrategy()
                 build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
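Subclasses now select the executor through the new exec_type argument instead of the removed use_fast_executor flag. A hedged sketch of how a derived test would drive it, mirroring the MNIST changes below (my_net is a placeholder model function, not part of this commit):

# Hedged sketch of a derived test using the new exec_type knob.
from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType

def my_net(use_feed):          # placeholder model-building function (assumption)
    raise NotImplementedError

class TestMyNet(TestParallelExecutorBase):
    def test_my_net_executor_types(self):
        for exec_type in (ExecutorType.Default, ExecutorType.Experimental,
                          ExecutorType.ParallelGraph):
            self.check_network_convergence(
                method=my_net, use_cuda=True, exec_type=exec_type)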
python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py (+3, -0)

@@ -181,6 +181,9 @@ class TestCRFModel(unittest.TestCase):
         if core.is_compiled_with_cuda():
             self.check_network_convergence(
                 is_sparse=True, build_strategy=build_strategy, use_cuda=True)
+            self.check_network_convergence(
+                is_sparse=True, build_strategy=build_strategy, use_cuda=True)
         self.check_network_convergence(
             is_sparse=True, build_strategy=build_strategy, use_cuda=False)
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py (+24, -14)

@@ -20,7 +20,7 @@ import numpy as np
 import paddle.fluid.core as core
 import os
 import paddle.fluid as fluid
-from parallel_executor_test_base import TestParallelExecutorBase
+from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType


 def simple_fc_net(use_feed):
@@ -99,7 +99,10 @@ class TestMNIST(TestParallelExecutorBase):
             self.assertAlmostEqual(loss[0], loss[1], delta=1e-4)

     # simple_fc
-    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
+    def check_simple_fc_convergence(self,
+                                    use_cuda,
+                                    use_reduce=False,
+                                    exec_type=ExecutorType.Default):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -110,19 +113,21 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_reduce=use_reduce)
+            use_reduce=use_reduce,
+            exec_type=exec_type)

     def test_simple_fc(self):
         # use_cuda
-        self.check_simple_fc_convergence(True)
+        self.check_simple_fc_convergence(True, ExecutorType.Default)
+        self.check_simple_fc_convergence(True, ExecutorType.ParallelGraph)
         self.check_simple_fc_convergence(False)

     def test_simple_fc_with_new_strategy(self):
         # use_cuda, use_reduce
         self._compare_reduce_and_allreduce(simple_fc_net, True)
         self._compare_reduce_and_allreduce(simple_fc_net, False)

-    def check_simple_fc_parallel_accuracy(self, use_cuda):
+    def check_simple_fc_parallel_accuracy(self, use_cuda, exec_type):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -134,14 +139,16 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_parallel_executor=False)
+            use_parallel_executor=False,
+            exec_type=exec_type)
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
             method=simple_fc_net,
             seed=1,
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_parallel_executor=True)
+            use_parallel_executor=True,
+            exec_type=exec_type)

         self.assertAlmostEquals(
             np.mean(parallel_first_loss),
@@ -151,10 +158,12 @@ class TestMNIST(TestParallelExecutorBase):
             np.mean(parallel_last_loss), single_last_loss, delta=1e-6)

     def test_simple_fc_parallel_accuracy(self):
-        self.check_simple_fc_parallel_accuracy(True)
-        self.check_simple_fc_parallel_accuracy(False)
+        self.check_simple_fc_parallel_accuracy(True, ExecutorType.Default)
+        self.check_simple_fc_parallel_accuracy(True, ExecutorType.ParallelGraph)
+        # FIXME(Yancey1989): ParallelGraph executor type support CPU mode
+        self.check_simple_fc_parallel_accuracy(False, ExecutorType.Default)

-    def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor):
+    def check_batchnorm_fc_convergence(self, use_cuda, exec_type):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -165,12 +174,13 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_fast_executor=use_fast_executor)
+            exec_type=exec_type)

     def test_batchnorm_fc(self):
         for use_cuda in (False, True):
-            for use_fast_executor in (False, True):
-                self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor)
+            for exec_type in (ExecutorType.Default, ExecutorType.Experimental,
+                              ExecutorType.ParallelGraph):
+                self.check_batchnorm_fc_convergence(use_cuda, exec_type)

     def test_batchnorm_fc_with_new_strategy(self):
         # FIXME(zcd): close this test temporally.
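check_simple_fc_parallel_accuracy keeps the same structure for every executor type: train the seeded network once without the parallel executor and once with it, then require the mean losses to agree. A stripped-down sketch of that assertion, plain numpy only with illustrative values:

# Stripped-down sketch of the accuracy assertion used above (no Paddle):
import numpy as np

def assert_losses_match(single_loss, parallel_losses, delta=1e-6):
    assert abs(np.mean(parallel_losses) - single_loss) <= delta

assert_losses_match(0.693, [0.693, 0.693])   # e.g. two devices, identical seeded loss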
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py (+32, -17)

@@ -19,7 +19,7 @@ import paddle.fluid.layers.ops as ops
 from paddle.fluid.initializer import init_on_cpu
 from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
 import paddle.fluid.core as core
-from parallel_executor_test_base import TestParallelExecutorBase
+from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType
 import unittest
 import math
 import os
@@ -167,13 +167,17 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
     return decayed_lr


-def optimizer(learning_rate=0.01):
-    optimizer = fluid.optimizer.Momentum(
-        learning_rate=cosine_decay(
-            learning_rate=learning_rate, step_each_epoch=2, epochs=1),
-        momentum=0.9,
-        regularization=fluid.regularizer.L2Decay(1e-4))
-    return optimizer
+def optimizer(learning_rate=0.01, lr_scale=1.0):
+    def _opt():
+        return fluid.optimizer.Momentum(
+            learning_rate=cosine_decay(
+                learning_rate=learning_rate / lr_scale,
+                step_each_epoch=2,
+                epochs=1),
+            momentum=0.9,
+            regularization=fluid.regularizer.L2Decay(1e-4))
+
+    return _opt


 class TestResnet(TestParallelExecutorBase):
@@ -216,7 +220,7 @@ class TestResnet(TestParallelExecutorBase):
                 batch_size=batch_size,
                 use_cuda=use_cuda,
                 use_reduce=False,
-                optimizer=optimizer)
+                optimizer=optimizer())
             reduce_first_loss, reduce_last_loss = self.check_network_convergence(
                 model,
                 feed_dict={"image": img,
@@ -225,7 +229,7 @@ class TestResnet(TestParallelExecutorBase):
                 batch_size=batch_size,
                 use_cuda=use_cuda,
                 use_reduce=True,
-                optimizer=optimizer)
+                optimizer=optimizer())

             for loss in zip(all_reduce_first_loss, reduce_first_loss):
                 self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
@@ -243,7 +247,7 @@ class TestResnet(TestParallelExecutorBase):
                 batch_size=batch_size,
                 use_cuda=use_cuda,
                 use_reduce=False,
-                optimizer=optimizer,
+                optimizer=optimizer(),
                 enable_sequential_execution=True)

             reduce_first_loss_seq, reduce_last_loss_seq = self.check_network_convergence(
@@ -254,7 +258,7 @@ class TestResnet(TestParallelExecutorBase):
                 batch_size=batch_size,
                 use_cuda=use_cuda,
                 use_reduce=True,
-                optimizer=optimizer,
+                optimizer=optimizer(),
                 enable_sequential_execution=True)

             for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq):
@@ -277,7 +281,9 @@ class TestResnet(TestParallelExecutorBase):
                                    use_cuda=True,
                                    use_reduce=False,
                                    iter=20,
-                                   delta2=1e-6):
+                                   delta2=1e-6,
+                                   exec_type=ExecutorType.Default,
+                                   lr_scale=1.0):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -295,8 +301,9 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer,
-            use_parallel_executor=False)
+            optimizer=optimizer(),
+            use_parallel_executor=False,
+            exec_type=exec_type)
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
             model,
             feed_dict={"image": img,
@@ -305,7 +312,8 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer)
+            optimizer=optimizer(lr_scale=lr_scale),
+            exec_type=exec_type)

         self.assertAlmostEquals(
             np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
@@ -313,7 +321,14 @@ class TestResnet(TestParallelExecutorBase):
             np.mean(parallel_last_loss), single_last_loss[0], delta=delta2)

     def test_seresnext_with_learning_rate_decay(self):
-        self._check_resnet_convergence(model=SE_ResNeXt50Small, use_cuda=True)
+        if core.is_compiled_with_cuda():
+            self._check_resnet_convergence(
+                model=SE_ResNeXt50Small, use_cuda=True)
+            self._check_resnet_convergence(
+                model=SE_ResNeXt50Small,
+                use_cuda=True,
+                exec_type=ExecutorType.ParallelGraph,
+                lr_scale=core.get_cuda_device_count())
         self._check_resnet_convergence(
             model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3)
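Two details of this change: optimizer() now returns a factory, so each check_network_convergence call constructs a fresh Momentum optimizer, and the ParallelGraph run divides the base learning rate by the CUDA device count via lr_scale. A standalone sketch of the factory pattern, with a placeholder dict in place of fluid.optimizer.Momentum:

# Standalone sketch of the optimizer-factory pattern above (plain Python):
def optimizer(learning_rate=0.01, lr_scale=1.0):
    def _opt():
        # stand-in for fluid.optimizer.Momentum(...); lr is divided by lr_scale
        return {"lr": learning_rate / lr_scale, "momentum": 0.9}
    return _opt

opt_factory = optimizer(lr_scale=4.0)        # e.g. core.get_cuda_device_count() == 4
assert opt_factory() == {"lr": 0.0025, "momentum": 0.9}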
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py (+5, -1)

@@ -17,7 +17,7 @@ from __future__ import print_function
 import paddle.fluid as fluid
 import transformer_model
 import numpy as np
-from parallel_executor_test_base import TestParallelExecutorBase
+from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType
 import unittest
 import paddle
 import paddle.fluid.core as core
@@ -173,6 +173,10 @@ class TestTransformer(TestParallelExecutorBase):
     def test_main(self):
         if core.is_compiled_with_cuda():
             self.check_network_convergence(transformer, use_cuda=True)
+            self.check_network_convergence(
+                transformer,
+                use_cuda=True,
+                exec_type=ExecutorType.ParallelGraph)
             self.check_network_convergence(
                 transformer, use_cuda=True, enable_sequential_execution=True)
         self.check_network_convergence(transformer, use_cuda=False, iter=5)