Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2e149999
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2e149999
编写于
11月 05, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean1
test=develop
上级
34b401fc
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
84 addition
and
65 deletion
+84
-65
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
...uid/framework/details/fast_threaded_ssa_graph_executor.cc
+4
-4
paddle/fluid/framework/details/gather_op_handle_test.cc
paddle/fluid/framework/details/gather_op_handle_test.cc
+6
-6
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
...fluid/framework/details/multi_devices_graph_check_pass.cc
+4
-4
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+13
-13
paddle/fluid/framework/details/multi_devices_helper.h
paddle/fluid/framework/details/multi_devices_helper.h
+3
-4
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+3
-1
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+10
-12
paddle/fluid/framework/details/ssa_graph_executor.cc
paddle/fluid/framework/details/ssa_graph_executor.cc
+1
-2
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+1
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+9
-9
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+7
-7
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+3
-1
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+20
-0
未找到文件。
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
浏览文件 @
2e149999
...
...
@@ -36,9 +36,9 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
for
(
auto
&
op
:
ops
)
{
int
dep
=
static_cast
<
int
>
(
op
->
NotReadyInputSize
());
op_deps_
.
emplace
(
op
.
get
()
,
dep
);
op_deps_
.
emplace
(
op
,
dep
);
if
(
dep
==
0
)
{
bootstrap_ops_
.
emplace_back
(
op
.
get
()
);
bootstrap_ops_
.
emplace_back
(
op
);
}
}
...
...
@@ -54,13 +54,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle
::
framework
::
FeedFetchList
fetches
;
fetches
.
resize
(
fetch_tensors
.
size
());
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>
>
fetch_ops
;
std
::
vector
<
FetchOpHandle
*
>
fetch_ops
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
fetched_vars
[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
}
}
}
...
...
paddle/fluid/framework/details/gather_op_handle_test.cc
浏览文件 @
2e149999
...
...
@@ -31,8 +31,8 @@ struct TestGatherOpHandle {
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
param_scopes_
;
Scope
g_scope_
;
std
::
unique_ptr
<
OpHandleBase
>
op_handle_
;
std
::
vector
<
std
::
unique_ptr
<
VarHandleBase
>
>
vars_
;
OpHandleBase
*
op_handle_
;
std
::
vector
<
VarHandleBase
*
>
vars_
;
std
::
vector
<
p
::
Place
>
gpu_list_
;
void
WaitAll
()
{
...
...
@@ -84,8 +84,8 @@ struct TestGatherOpHandle {
nodes
.
emplace_back
(
ir
::
CreateNodeForTest
(
"node"
,
ir
::
Node
::
Type
::
kOperation
).
release
());
op_handle_
.
reset
(
new
GatherOpHandle
(
nodes
.
back
().
get
(),
local_scopes_
,
gpu_list_
)
)
;
op_handle_
=
new
GatherOpHandle
(
nodes
.
back
().
get
(),
local_scopes_
,
gpu_list_
);
// add input
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
op_handle_
->
SetDeviceContext
(
gpu_list_
[
j
],
ctxs_
[
j
].
get
());
...
...
@@ -102,7 +102,7 @@ struct TestGatherOpHandle {
ir
::
CreateNodeForTest
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
.
back
().
get
()));
DummyVarHandle
*
in_dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
()
.
get
()
);
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
());
in_dummy_var_handle
->
ClearGeneratedOp
();
op_handle_
->
AddInput
(
in_dummy_var_handle
);
...
...
@@ -119,7 +119,7 @@ struct TestGatherOpHandle {
ir
::
CreateNodeForTest
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
).
release
());
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
.
back
().
get
()));
DummyVarHandle
*
dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
()
.
get
()
);
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
());
op_handle_
->
AddOutput
(
dummy_var_handle
);
}
...
...
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
浏览文件 @
2e149999
...
...
@@ -36,20 +36,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
.
get
()
);
insert_pending_var
(
version_pair
);
}
}
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
insert_pending_var
(
var
.
get
()
);
insert_pending_var
(
var
);
}
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
ready_ops
.
insert
(
op
.
get
()
);
ready_ops
.
insert
(
op
);
}
else
{
pending_ops
.
insert
({
op
.
get
(),
op
.
get
()
->
NoDupInputSize
()});
pending_ops
.
insert
({
op
,
op
->
NoDupInputSize
()});
}
}
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
2e149999
...
...
@@ -93,7 +93,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
}
var_holder
.
emplace_back
(
var
);
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
var
=
*
var_holder
.
rbegin
();
}
return
var
;
}
...
...
@@ -155,7 +155,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
...
@@ -498,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
auto
*
in
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
()
.
get
()
;
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
();
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
...
...
@@ -535,7 +535,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_varnames
.
size
();
++
dev_id
)
{
for
(
auto
&
p_name
:
bcast_varnames
[
dev_id
])
{
auto
*
in
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
dev_id
).
at
(
p_name
).
back
()
.
get
()
;
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
dev_id
).
at
(
p_name
).
back
();
op_handle
->
AddInput
(
in
);
for
(
size_t
out_dev_id
=
0
;
out_dev_id
<
places_
.
size
();
++
out_dev_id
)
{
auto
&
p
=
places_
[
out_dev_id
];
...
...
@@ -571,7 +571,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
...
...
@@ -579,7 +579,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
()
);
op_handle
->
AddInput
(
prev_grad
);
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
...
...
@@ -600,14 +600,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
()
.
get
()
);
op_handle
->
AddInput
(
vars
.
back
());
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
d_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
d_name
,
p
);
...
...
@@ -691,7 +691,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
...
...
@@ -699,7 +699,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
()
);
op_handle
->
AddInput
(
prev_grad
);
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
var
=
...
...
@@ -760,14 +760,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
}
void
SetOpInputsAllPlaces
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
num_places
)
{
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
VarHandle
*
var
=
nullptr
;
for
(
int
place_offset
=
0
;
place_offset
<
num_places
;
++
place_offset
)
{
auto
&
var_holders
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
input
->
Name
()];
if
(
!
var_holder
.
empty
())
{
var
=
var_holder
.
rbegin
()
->
get
();
var
=
*
var_holder
.
rbegin
();
op_handle
->
AddInput
(
var
);
}
}
...
...
@@ -840,7 +840,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from
// all places
auto
p
=
places_
[
op_dev_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
()
.
get
()
;
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
();
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
...
paddle/fluid/framework/details/multi_devices_helper.h
浏览文件 @
2e149999
...
...
@@ -36,18 +36,17 @@ namespace details {
// map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandle
*>>>
GraphVars
;
const
char
kGraphVars
[]
=
"vars"
;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>
>
GraphDepVars
;
typedef
std
::
unordered_set
<
VarHandleBase
*
>
GraphDepVars
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>
>
GraphOps
;
typedef
std
::
vector
<
OpHandleBase
*
>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
2e149999
...
...
@@ -31,7 +31,9 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node.
class
OpHandleBase
{
public:
explicit
OpHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{}
explicit
OpHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{
node_
->
WrappedBy
(
this
);
}
virtual
~
OpHandleBase
();
...
...
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
2e149999
...
...
@@ -71,14 +71,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unique_ptr
<
ReferenceCountOpHandle
>
>
std
::
unordered_map
<
OpHandleBase
*
,
ReferenceCountOpHandle
*
>
compute_ref_cnt_map
;
auto
get_ref_cnts_from_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
std
::
vector
<
std
::
string
>
var_names_in_op
;
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
()
);
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
==
nullptr
||
!
platform
::
is_gpu_place
(
compute_op
->
GetPlace
()))
return
var_names_in_op
;
...
...
@@ -121,9 +120,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
};
auto
update_ref_cnts_from_non_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
if
(
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
())
!=
nullptr
)
return
;
OpHandleBase
*
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
if
(
dynamic_cast
<
ComputationOpHandle
*>
(
op
)
!=
nullptr
)
return
;
for
(
VarHandleBase
*
var_handle_base
:
vars
)
{
auto
*
var_handle
=
dynamic_cast
<
VarHandle
*>
(
var_handle_base
);
if
(
var_handle
==
nullptr
||
!
var_handle
->
Node
()
->
IsVar
())
continue
;
...
...
@@ -151,7 +149,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
AddDependencyBetween
(
next_compute_op
,
ref_cnt_handle
,
graph
.
get
());
compute_ref_cnt_map
[
next_compute_op
]
.
reset
(
ref_cnt_handle
)
;
compute_ref_cnt_map
[
next_compute_op
]
=
ref_cnt_handle
;
}
}
}
...
...
@@ -165,7 +163,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if
(
in_var_names
.
empty
()
&&
out_var_names
.
empty
())
continue
;
in_var_names
.
insert
(
in_var_names
.
end
(),
out_var_names
.
begin
(),
out_var_names
.
end
());
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
.
get
()
);
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
compute_op
->
GetPlace
());
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
...
...
@@ -173,7 +171,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
AddDependencyBetween
(
compute_op
,
ref_cnt_handle
,
graph
.
get
());
compute_ref_cnt_map
[
compute_op
]
.
reset
(
ref_cnt_handle
)
;
compute_ref_cnt_map
[
compute_op
]
=
ref_cnt_handle
;
}
for
(
auto
&
op
:
all_ops
)
{
...
...
@@ -181,11 +179,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op
(
op
,
op
->
Outputs
());
}
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>
>
new_all_ops
;
std
::
vector
<
OpHandleBase
*
>
new_all_ops
;
new_all_ops
.
reserve
(
compute_ref_cnt_map
.
size
()
+
all_ops
.
size
());
for
(
auto
&
op
:
all_ops
)
{
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
()
.
get
()
);
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
());
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
// Add LeafNode to ReferenceCountOpHandle
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
...
...
paddle/fluid/framework/details/ssa_graph_executor.cc
浏览文件 @
2e149999
...
...
@@ -19,8 +19,7 @@ namespace framework {
namespace
details
{
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>*
fetch_ops
)
{
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
FetchOpHandle
*>*
fetch_ops
)
{
if
(
fetch_ops
->
empty
())
return
;
for
(
auto
&
op
:
*
fetch_ops
)
{
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
2e149999
...
...
@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>*
fetch_ops
);
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
FetchOpHandle
*>*
fetch_ops
);
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
2e149999
...
...
@@ -51,25 +51,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
version_pair
.
get
()
);
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
version_pair
);
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
var
.
get
()
);
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
var
);
}
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
details
::
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
()
);
ready_ops
.
insert
(
op
);
}
else
{
InsertPendingOp
(
&
pending_ops
,
op
.
get
()
);
InsertPendingOp
(
&
pending_ops
,
op
);
}
}
// Step 2. Insert FetchOps
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>
>
fetch_ops
;
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>
>
fetch_dependencies
;
std
::
vector
<
FetchOpHandle
*
>
fetch_ops
;
std
::
unordered_set
<
VarHandleBase
*
>
fetch_dependencies
;
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
InsertFetchOps
(
fetch_tensors
,
&
fetch_ops
,
&
fetch_dependencies
,
&
pending_ops
,
...
...
@@ -140,8 +140,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>
>
*
fetch_ops
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>
>
*
fetch_dependencies
,
std
::
vector
<
FetchOpHandle
*
>
*
fetch_ops
,
std
::
unordered_set
<
VarHandleBase
*
>
*
fetch_dependencies
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
)
{
...
...
@@ -151,7 +151,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
fetched_vars
[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
}
}
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
2e149999
...
...
@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
VarHandleBase
*
var
)
const
;
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensor
s
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
*
fetch_op
s
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
*
fetch_dependencie
s
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_op
s
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending
_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
);
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
FetchOpHandle
*>
*
fetch_op
s
,
std
::
unordered_set
<
VarHandleBase
*>
*
fetch_dependencie
s
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_op
s
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_var
s
,
BlockingQueue
<
VarHandleBase
*>
*
ready
_vars
,
FeedFetchList
*
fetch_data
);
private:
ExecutionStrategy
strategy_
;
...
...
paddle/fluid/framework/details/var_handle.h
浏览文件 @
2e149999
...
...
@@ -35,7 +35,9 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e.
// This is a single assignment graph.
struct
VarHandleBase
{
explicit
VarHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{}
explicit
VarHandleBase
(
ir
::
Node
*
node
)
:
node_
(
node
)
{
node_
->
WrappedBy
(
this
);
}
virtual
~
VarHandleBase
();
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
2e149999
...
...
@@ -27,6 +27,8 @@ namespace ir {
// Node should normally created by Graph::CreateXXXNode().
class
Node
{
public:
virtual
~
Node
()
{}
enum
class
Type
{
kOperation
,
kVariable
};
static
constexpr
char
kControlDepVarName
[]
=
"__control_var"
;
...
...
@@ -44,6 +46,20 @@ class Node {
return
op_desc_
.
get
();
}
template
<
typename
T
>
void
WrappedBy
(
T
*
wrapper
)
{
if
(
!
wrapper_
.
empty
())
{
wrapper_deleter_
();
}
wrapper_
=
wrapper
;
wrapper_deleter_
=
[
wrapper
]()
{
delete
wrapper
;
};
}
template
<
typename
T
>
T
&
Wrapper
()
{
return
*
boost
::
any_cast
<
T
*>
(
wrapper_
);
}
// Please don't use this API!
int
id
()
const
{
return
id_
;
}
...
...
@@ -95,6 +111,10 @@ class Node {
static
int
count_
;
// Please don't use this API or make this public.
static
void
ResetId
()
{
count_
=
0
;
}
boost
::
any
wrapper_
;
std
::
function
<
void
(
void
)
>
wrapper_deleter_
;
DISABLE_COPY_AND_ASSIGN
(
Node
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录