Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e72637dd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e72637dd
编写于
2月 09, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ThreadedSSAGraphExecutor support num_iteration_per_run test=develop
上级
b1fe8d45
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
24 addition
and
18 deletion
+24
-18
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+0
-16
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+23
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+1
-0
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
e72637dd
...
@@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
...
@@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
VLOG
(
3
)
<<
"build AsyncSSAGraphExecutor"
;
VLOG
(
3
)
<<
"build AsyncSSAGraphExecutor"
;
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
if
(
strategy_
.
num_iteration_per_run_
>
1
)
{
int
read_op_num
=
0
;
for
(
auto
*
node
:
graphs_
[
0
]
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Name
()
==
"read"
)
{
read_op_num
++
;
}
}
if
(
read_op_num
==
0
)
{
LOG
(
WARNING
)
<<
"when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!"
;
}
}
// set the correct size of thread pool to each device.
// set the correct size of thread pool to each device.
strategy_
.
num_threads_
=
strategy_
.
num_threads_
<
places_
.
size
()
strategy_
.
num_threads_
=
strategy_
.
num_threads_
<
places_
.
size
()
?
1UL
?
1UL
...
@@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
...
@@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
call
=
[
this
,
i
,
&
fetch_tensors
]()
->
FeedFetchList
{
auto
call
=
[
this
,
i
,
&
fetch_tensors
]()
->
FeedFetchList
{
try
{
try
{
for
(
size_t
j
=
0
;
j
<
strategy_
.
num_iteration_per_run_
-
1
;
++
j
)
{
executors_
[
i
]
->
Run
(
fetch_tensors
);
}
return
executors_
[
i
]
->
Run
(
fetch_tensors
);
return
executors_
[
i
]
->
Run
(
fetch_tensors
);
}
catch
(...)
{
}
catch
(...)
{
exception_holder_
.
Catch
(
std
::
current_exception
());
exception_holder_
.
Catch
(
std
::
current_exception
());
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
e72637dd
...
@@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
...
@@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
places_
(
places
),
places_
(
places
),
fetch_ctxs_
(
places
),
fetch_ctxs_
(
places
),
running_ops_
(
0
),
running_ops_
(
0
),
strategy_
(
strategy
)
{}
strategy_
(
strategy
)
{
if
(
strategy_
.
num_iteration_per_run_
>
1
)
{
int
read_op_num
=
0
;
for
(
auto
*
node
:
graph_
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Name
()
==
"read"
)
{
read_op_num
++
;
}
}
if
(
read_op_num
==
0
)
{
LOG
(
WARNING
)
<<
"when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!"
;
}
}
}
FeedFetchList
ThreadedSSAGraphExecutor
::
Run
(
inline
FeedFetchList
ThreadedSSAGraphExecutor
::
RunImpl
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
std
::
unique_ptr
<
platform
::
RecordEvent
>
event
(
std
::
unique_ptr
<
platform
::
RecordEvent
>
event
(
new
platform
::
RecordEvent
(
"ThreadedSSAGraphExecutorPrepare"
,
nullptr
));
new
platform
::
RecordEvent
(
"ThreadedSSAGraphExecutorPrepare"
,
nullptr
));
...
@@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return
fetch_data
;
return
fetch_data
;
}
}
FeedFetchList
ThreadedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
for
(
size_t
j
=
0
;
j
<
strategy_
.
num_iteration_per_run_
-
1
;
++
j
)
{
RunImpl
({});
}
return
RunImpl
(
fetch_tensors
);
}
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
FetchOpHandle
*>
*
fetch_ops
,
std
::
vector
<
FetchOpHandle
*>
*
fetch_ops
,
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
e72637dd
...
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~
ThreadedSSAGraphExecutor
()
final
=
default
;
~
ThreadedSSAGraphExecutor
()
final
=
default
;
private:
private:
inline
FeedFetchList
RunImpl
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
);
void
RunOp
(
const
std
::
shared_ptr
<
BlockingQueue
<
VarHandleBase
*>>
&
ready_var_q
,
void
RunOp
(
const
std
::
shared_ptr
<
BlockingQueue
<
VarHandleBase
*>>
&
ready_var_q
,
details
::
OpHandleBase
*
op
);
details
::
OpHandleBase
*
op
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录