Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
cb40c331
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cb40c331
编写于
3月 26, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update unittest
上级
ee97687f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
70 addition
and
32 deletion
+70
-32
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+1
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+29
-0
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+3
-0
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+37
-31
未找到文件。
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
cb40c331
...
@@ -33,7 +33,7 @@ void ComputationOpHandle::RunImpl() {
...
@@ -33,7 +33,7 @@ void ComputationOpHandle::RunImpl() {
}
}
}
}
op_
->
Run
(
*
scope_
,
place_
);
op_
->
Run
(
*
scope_
->
FindVar
(
"@TMP_SCOPE@"
)
->
Get
<
Scope
*>
()
,
place_
);
}
}
std
::
string
ComputationOpHandle
::
Name
()
const
{
return
op_
->
Type
();
}
std
::
string
ComputationOpHandle
::
Name
()
const
{
return
op_
->
Type
();
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
cb40c331
...
@@ -112,6 +112,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -112,6 +112,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
ready_ops
.
clear
();
ready_ops
.
clear
();
};
};
// Create local scopes.
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
local_scope
=
scope
->
NewScope
();
*
scope
->
Var
(
"@TMP_SCOPE@"
)
->
GetMutable
<
Scope
*>
()
=
&
local_scope
;
}
// Step 3. Execution
// Step 3. Execution
while
(
!
pending_vars
.
empty
())
{
while
(
!
pending_vars
.
empty
())
{
// 1. Run All Ready ops
// 1. Run All Ready ops
...
@@ -156,9 +162,32 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -156,9 +162,32 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Keep loop until all vars are ready.
// Keep loop until all vars are ready.
}
}
++
computation_count_
;
auto
sync_computation
=
[
&
]
{
computation_count_
=
0
;
// Wait All computational streams
for
(
auto
p
:
this
->
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
}
// NOTE: the temp scope can be dropped lazily if needed.
// Drop tmp scopes;
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
kid
=
*
scope
->
Var
(
"@TMP_SCOPE@"
)
->
GetMutable
<
Scope
*>
();
kid
=
nullptr
;
scope
->
DropKids
();
}
};
// Wait FetchOps.
// Wait FetchOps.
for
(
auto
&
fetch_op
:
fetch_ops
)
{
for
(
auto
&
fetch_op
:
fetch_ops
)
{
fetch_op
.
WaitAndMergeCPUTensors
();
fetch_op
.
WaitAndMergeCPUTensors
();
sync_computation
();
}
if
(
computation_count_
==
max_async_computation
)
{
sync_computation
();
}
}
return
fetch_data
;
return
fetch_data
;
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
cb40c331
...
@@ -48,6 +48,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -48,6 +48,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform
::
DeviceContextPool
fetch_ctxs_
;
platform
::
DeviceContextPool
fetch_ctxs_
;
const
bool
use_event_
;
const
bool
use_event_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
size_t
computation_count_
{
0
};
size_t
max_async_computation
{
100
};
};
};
}
// namespace details
}
// namespace details
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
cb40c331
...
@@ -178,7 +178,32 @@ def SE_ResNeXt152():
...
@@ -178,7 +178,32 @@ def SE_ResNeXt152():
return
loss
return
loss
class
ParallelExecutor
(
unittest
.
TestCase
):
class
TestParallelExecutorBase
(
unittest
.
TestCase
):
def
check_network_convergence
(
self
,
method
,
memory_opt
=
True
,
iter
=
10
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
loss
=
method
()
adam
=
fluid
.
optimizer
.
Adam
()
adam
.
minimize
(
loss
)
if
memory_opt
:
fluid
.
memory_optimize
(
main
)
exe
=
fluid
.
ParallelExecutor
(
loss_name
=
loss
.
name
,
use_cuda
=
True
)
first_loss
,
=
exe
.
run
([
loss
.
name
])
first_loss
=
numpy
.
array
(
first_loss
)
for
i
in
xrange
(
iter
):
exe
.
run
([])
last_loss
,
=
exe
.
run
([
loss
.
name
])
last_loss
=
numpy
.
array
(
last_loss
)
print
first_loss
,
last_loss
self
.
assertGreater
(
first_loss
[
0
],
last_loss
[
0
])
class
TestMNIST
(
TestParallelExecutorBase
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
# Convert mnist to recordio file
# Convert mnist to recordio file
...
@@ -195,6 +220,16 @@ class ParallelExecutor(unittest.TestCase):
...
@@ -195,6 +220,16 @@ class ParallelExecutor(unittest.TestCase):
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
'./mnist.recordio'
,
reader
,
feeder
)
'./mnist.recordio'
,
reader
,
feeder
)
def
test_simple_fc
(
self
):
self
.
check_network_convergence
(
simple_fc_net
)
def
test_batchnorm_fc
(
self
):
self
.
check_network_convergence
(
fc_with_batchnorm
)
class
TestResnet
(
TestParallelExecutorBase
):
@
classmethod
def
setUpClass
(
cls
):
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
reader
=
paddle
.
batch
(
flowers
.
train
(),
batch_size
=
4
)
reader
=
paddle
.
batch
(
flowers
.
train
(),
batch_size
=
4
)
feeder
=
fluid
.
DataFeeder
(
feeder
=
fluid
.
DataFeeder
(
...
@@ -208,34 +243,5 @@ class ParallelExecutor(unittest.TestCase):
...
@@ -208,34 +243,5 @@ class ParallelExecutor(unittest.TestCase):
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
"./flowers.recordio"
,
reader
,
feeder
)
"./flowers.recordio"
,
reader
,
feeder
)
def
test_simple_fc
(
self
):
self
.
check_network_convergence
(
simple_fc_net
)
def
test_batchnorm_fc
(
self
):
self
.
check_network_convergence
(
fc_with_batchnorm
)
def
check_network_convergence
(
self
,
method
,
memory_opt
=
True
,
iter
=
10
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
loss
=
method
()
adam
=
fluid
.
optimizer
.
Adam
()
adam
.
minimize
(
loss
)
if
memory_opt
:
fluid
.
memory_optimize
(
main
)
exe
=
fluid
.
ParallelExecutor
(
loss_name
=
loss
.
name
,
use_cuda
=
True
)
first_loss
,
=
exe
.
run
([
loss
.
name
])
first_loss
=
numpy
.
array
(
first_loss
)
for
i
in
xrange
(
iter
):
exe
.
run
([])
last_loss
,
=
exe
.
run
([
loss
.
name
])
last_loss
=
numpy
.
array
(
last_loss
)
print
first_loss
,
last_loss
self
.
assertGreater
(
first_loss
[
0
],
last_loss
[
0
])
def
test_resnet
(
self
):
def
test_resnet
(
self
):
self
.
check_network_convergence
(
SE_ResNeXt152
,
iter
=
20
)
self
.
check_network_convergence
(
SE_ResNeXt152
,
iter
=
20
0
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录