Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f4851f14
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f4851f14
编写于
5月 07, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean code
上级
22ab14c5
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
80 addition
and
48 deletion
+80
-48
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+60
-48
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+16
-0
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+4
-0
未找到文件。
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
f4851f14
...
...
@@ -14,8 +14,6 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
...
...
@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Should revisit it if overlapping is available.
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
auto
InsertPendingVar
=
[
&
pending_vars
,
&
ready_vars
](
VarHandleBase
&
var
)
{
pending_vars
.
insert
(
&
var
);
if
(
var
.
generated_op_
==
nullptr
)
{
ready_vars
.
Push
(
&
var
);
}
};
auto
InsertPendingOp
=
[
&
pending_ops
](
OpHandleBase
&
op_instance
)
{
pending_ops
.
insert
({
&
op_instance
,
op_instance
.
Inputs
().
size
()});
};
// Transform SSAGraph to pending_ops & pending_vars
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
*
version_pair
);
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
()
);
}
}
}
for
(
auto
&
var
:
graph_
->
dep_vars_
)
{
InsertPendingVar
(
*
var
);
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
()
);
}
for
(
auto
&
op
:
graph_
->
ops_
)
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
());
}
else
{
InsertPendingOp
(
*
op
);
InsertPendingOp
(
&
pending_ops
,
op
.
get
()
);
}
}
// Step 2. Insert FetchOps
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
fetch_ops
;
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
}
}
}
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
fetch_dependencies
;
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
[
i
];
auto
&
vars
=
fetched_vars
.
at
(
var_name
);
auto
*
op
=
new
FetchOpHandle
(
&
fetch_data
,
i
,
&
local_scopes_
);
fetch_ops
.
emplace_back
(
op
);
for
(
auto
&
p
:
places_
)
{
op
->
SetDeviceContext
(
p
,
fetch_ctxs_
.
Get
(
p
));
}
for
(
auto
*
var
:
vars
)
{
op
->
AddInput
(
var
);
}
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
auto
*
fetch_dummy
=
new
DummyVarHandle
();
op
->
AddOutput
(
fetch_dummy
);
fetch_dependencies
.
emplace
(
fetch_dummy
);
InsertPendingVar
(
*
fetch_dummy
);
InsertPendingOp
(
*
op
);
}
InsertFetchOps
(
fetch_tensors
,
&
fetch_ops
,
&
fetch_dependencies
,
&
pending_ops
,
&
pending_vars
,
&
ready_vars
,
&
fetch_data
);
auto
run_all_ops
=
[
&
](
std
::
unordered_set
<
OpHandleBase
*>
&
set
)
{
for
(
auto
*
op
:
set
)
{
...
...
@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return
fetch_data
;
}
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
*
fetch_ops
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
*
fetch_dependencies
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
)
{
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
}
}
}
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
[
i
];
auto
&
vars
=
fetched_vars
.
at
(
var_name
);
auto
*
op
=
new
FetchOpHandle
(
fetch_data
,
i
,
&
local_scopes_
);
fetch_ops
->
emplace_back
(
op
);
for
(
auto
&
p
:
places_
)
{
op
->
SetDeviceContext
(
p
,
fetch_ctxs_
.
Get
(
p
));
}
for
(
auto
*
var
:
vars
)
{
op
->
AddInput
(
var
);
}
auto
*
fetch_dummy
=
new
DummyVarHandle
();
op
->
AddOutput
(
fetch_dummy
);
fetch_dependencies
->
emplace
(
fetch_dummy
);
this
->
InsertPendingVar
(
pending_vars
,
ready_vars
,
fetch_dummy
);
this
->
InsertPendingOp
(
pending_ops
,
op
);
}
}
void
ThreadedSSAGraphExecutor
::
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
OpHandleBase
*
op_instance
)
const
{
pending_ops
->
insert
({
op_instance
,
op_instance
->
Inputs
().
size
()});
}
void
ThreadedSSAGraphExecutor
::
InsertPendingVar
(
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
VarHandleBase
*
var
)
const
{
pending_vars
->
insert
(
var
);
if
(
var
->
generated_op_
==
nullptr
)
{
ready_vars
->
Push
(
var
);
}
}
void
ThreadedSSAGraphExecutor
::
RunOp
(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
f4851f14
...
...
@@ -23,6 +23,7 @@
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace
paddle
{
...
...
@@ -58,6 +59,21 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
std
::
atomic
<
int
>
running_ops_
;
bool
allow_op_delay_
;
void
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
OpHandleBase
*
op_instance
)
const
;
void
InsertPendingVar
(
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
VarHandleBase
*
var
)
const
;
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
*
fetch_ops
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
*
fetch_dependencies
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
);
};
}
// namespace details
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
f4851f14
...
...
@@ -721,3 +721,7 @@ class TestCRFModel(unittest.TestCase):
def
test_update_dense_parameter
(
self
):
self
.
check_network_convergence
(
is_sparse
=
False
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录