Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c3f6e0e8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c3f6e0e8
编写于
7月 20, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add namespace to Graph
上级
0b3465d2
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
82 addition
and
71 deletion
+82
-71
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+19
-16
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+15
-12
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+4
-4
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+4
-4
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+1
-1
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+3
-2
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+2
-2
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+4
-3
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+2
-2
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+7
-5
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+11
-11
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+4
-4
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+2
-2
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+1
-1
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
c3f6e0e8
...
@@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
...
@@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
Graph
*
result
,
ir
::
Node
*
node
,
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
...
@@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
...
@@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
// some optimizer ops might not depend on any nodes), we manually move all
// some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes.
// optimizer nodes after last backward nodes.
std
::
vector
<
ir
::
Node
*>
SortOpsAndDelayOptimizeOp
(
const
Graph
&
graph
)
{
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std
::
vector
<
ir
::
Node
*>
ret
=
ir
::
TopologySort
(
graph
);
std
::
vector
<
ir
::
Node
*>
SortOpsAndDelayOptimizeOp
(
const
ir
::
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
ret
=
ir
::
TopologySortOperations
(
graph
);
size_t
last_backward
=
0
;
size_t
last_backward
=
0
;
std
::
vector
<
ir
::
Node
*>
optimize_ops
;
std
::
vector
<
ir
::
Node
*>
optimize_ops
;
std
::
vector
<
ir
::
Node
*>
sorted_ret
;
std
::
vector
<
ir
::
Node
*>
sorted_ret
;
...
@@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
...
@@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
return
sorted_ret
;
return
sorted_ret
;
}
}
std
::
unique_ptr
<
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
// Rebuild the graph structure.
// Rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
std
::
move
(
graph
->
nodes
);
auto
nodes
=
std
::
move
(
graph
->
nodes
);
...
@@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
}
}
}
}
Graph
&
result
=
*
graph
;
ir
::
Graph
&
result
=
*
graph
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
...
@@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
...
@@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif
#endif
}
}
void
MultiDevSSAGraphBuilder
::
CreateBroadcastOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
{
size_t
src_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
...
@@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
...
@@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
...
@@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
}
}
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
...
@@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
...
@@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
}
}
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
...
@@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
...
@@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
Graph
*
result
)
const
{
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
...
@@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateComputationalOps
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
size_t
num_places
)
const
{
size_t
num_places
)
const
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
...
@@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
...
@@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
}
}
}
}
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
Graph
*
result
,
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
...
@@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
// Find the first occurence of `prev_op_name` and make current `op` depend
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
Graph
*
result
,
OpHandleBase
*
op
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
...
@@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
...
@@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
input_var_names
;
...
@@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
...
@@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
}
}
// Create RPC related op handles that connects its in ops and out ops.
// Create RPC related op handles that connects its in ops and out ops.
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
Graph
*
result
,
ir
::
Node
*
node
)
const
{
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
c3f6e0e8
...
@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
);
const
BuildStrategy
&
strategy
);
#endif
#endif
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
private:
private:
void
CreateOpHandleIOs
(
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
private:
private:
std
::
string
loss_var_name_
;
std
::
string
loss_var_name_
;
...
@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateDistTrainOp
(
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
/**
/**
* Is this operator as the end-point operator before/after send operator.
* Is this operator as the end-point operator before/after send operator.
...
@@ -79,16 +81,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -79,16 +81,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
&
nodes
)
const
;
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
&
nodes
)
const
;
void
ConnectOp
(
Graph
*
result
,
OpHandleBase
*
op
,
void
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
;
const
std
::
string
&
prev_op_name
)
const
;
void
CreateComputationalOps
(
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
num_places
)
const
;
size_t
num_places
)
const
;
void
CreateScaleLossGradOp
(
Graph
*
result
)
const
;
void
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
;
VarHandle
*
CreateReduceOp
(
Graph
*
result
,
const
std
::
string
&
og
,
VarHandle
*
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
;
int
dst_dev_id
)
const
;
void
CreateComputationalOp
(
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
;
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
;
bool
IsParameterGradientOnce
(
bool
IsParameterGradientOnce
(
const
std
::
string
&
og
,
const
std
::
string
&
og
,
...
@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
void
InsertAllReduceOp
(
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertDataBalanceOp
(
Graph
*
result
,
void
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
;
const
std
::
vector
<
std
::
string
>
&
datas
)
const
;
void
CreateBroadcastOp
(
Graph
*
result
,
const
std
::
string
&
p_name
,
void
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
;
size_t
src_dev_id
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
c3f6e0e8
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
Graph
*
graph
)
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
}
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
...
@@ -81,7 +81,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
...
@@ -81,7 +81,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return
var
;
return
var
;
}
}
void
SSAGraphBuilder
::
CreateOpOutput
(
Graph
*
graph
,
OpHandleBase
*
op_handle
,
void
SSAGraphBuilder
::
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
...
@@ -93,7 +93,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
...
@@ -93,7 +93,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
op_handle
->
AddOutput
(
var
);
op_handle
->
AddOutput
(
var
);
}
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
Graph
*
graph
)
{
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
c3f6e0e8
...
@@ -64,19 +64,19 @@ class SSAGraphBuilder : public ir::Pass {
...
@@ -64,19 +64,19 @@ class SSAGraphBuilder : public ir::Pass {
*
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
*/
static
void
PolishGraphToSupportDataHazards
(
Graph
*
graph
);
static
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
);
static
VarHandle
*
CreateOrGetLatestVarHandle
(
Graph
*
graph
,
ir
::
Node
*
node
,
static
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
size_t
place_offset
);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
// which belongs to graph
static
void
CreateOpOutput
(
Graph
*
graph
,
OpHandleBase
*
op_handle
,
static
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
Graph
*
graph
);
static
void
AddOutputToLeafOps
(
ir
::
Graph
*
graph
);
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
c3f6e0e8
...
@@ -20,7 +20,7 @@ namespace paddle {
...
@@ -20,7 +20,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
bool
SSAGraghBuilderWithChecker
::
IsValidGraph
(
const
Graph
*
graph
)
const
{
bool
SSAGraghBuilderWithChecker
::
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
std
::
unordered_set
<
VarHandleBase
*>
ready_vars
;
std
::
unordered_set
<
VarHandleBase
*>
ready_vars
;
...
...
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
c3f6e0e8
...
@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
...
@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
builder_
(
std
::
move
(
builder
))
{}
:
builder_
(
std
::
move
(
builder
))
{}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
new_graph
;
return
new_graph
;
...
@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
...
@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return
builder_
->
GetVarDeviceID
(
var_name
);
return
builder_
->
GetVarDeviceID
(
var_name
);
}
}
bool
IsValidGraph
(
const
Graph
*
graph
)
const
;
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
private:
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
c3f6e0e8
...
@@ -21,7 +21,7 @@ namespace framework {
...
@@ -21,7 +21,7 @@ namespace framework {
namespace
details
{
namespace
details
{
template
<
typename
Callback
>
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
Graph
&
graph
,
Callback
callback
)
{
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
...
@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
...
@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
}
}
}
}
void
GraphvizSSAGraphPrinter
::
Print
(
const
Graph
&
graph
,
void
GraphvizSSAGraphPrinter
::
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
{
std
::
ostream
&
sout
)
const
{
size_t
var_id
=
0
;
size_t
var_id
=
0
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
c3f6e0e8
...
@@ -25,12 +25,12 @@ namespace details {
...
@@ -25,12 +25,12 @@ namespace details {
class
SSAGraphPrinter
{
class
SSAGraphPrinter
{
public:
public:
virtual
~
SSAGraphPrinter
()
{}
virtual
~
SSAGraphPrinter
()
{}
virtual
void
Print
(
const
Graph
&
graph
,
std
::
ostream
&
sout
)
const
=
0
;
virtual
void
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
=
0
;
};
};
class
GraphvizSSAGraphPrinter
:
public
SSAGraphPrinter
{
class
GraphvizSSAGraphPrinter
:
public
SSAGraphPrinter
{
public:
public:
void
Print
(
const
Graph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
void
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
};
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
...
@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
...
@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_
(
std
::
move
(
sout
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
return
new_graph
;
return
new_graph
;
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
c3f6e0e8
...
@@ -21,7 +21,8 @@ namespace framework {
...
@@ -21,7 +21,8 @@ namespace framework {
namespace
details
{
namespace
details
{
ThreadedSSAGraphExecutor
::
ThreadedSSAGraphExecutor
(
ThreadedSSAGraphExecutor
::
ThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
Graph
>
&&
graph
)
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
)
:
graph_
(
std
::
move
(
graph
)),
:
graph_
(
std
::
move
(
graph
)),
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
:
nullptr
),
:
nullptr
),
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
c3f6e0e8
...
@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
ThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
Graph
>
&&
graph
);
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
// Run a SSAGraph by a thread pool
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
// Use topological sort algorithm
...
@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details
::
OpHandleBase
*
op
);
details
::
OpHandleBase
*
op
);
private:
private:
std
::
unique_ptr
<
Graph
>
graph_
;
std
::
unique_ptr
<
ir
::
Graph
>
graph_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
...
...
paddle/fluid/framework/ir/graph.cc
浏览文件 @
c3f6e0e8
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
/*
/*
namespace {
namespace {
void SortHelper(
void SortHelper(
...
@@ -41,7 +42,7 @@ void SortHelper(
...
@@ -41,7 +42,7 @@ void SortHelper(
ret->push_back(node);
ret->push_back(node);
}
}
std::vector<ir::Node*> TopologySort(
std::vector<ir::Node*> TopologySort
Operations
(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
std::vector<ir::Node *> ret;
...
@@ -156,7 +157,7 @@ bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
...
@@ -156,7 +157,7 @@ bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
return false;
return false;
}
}
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
std::map<ir::Node *, std::unordered_set<ir::Node *>> Build
Operation
AdjList(
const std::vector<ir::Node*> &nodes) {
const std::vector<ir::Node*> &nodes) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
...
@@ -178,17 +179,17 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
...
@@ -178,17 +179,17 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
return adj_list;
return adj_list;
}
}
std::vector<ir::Node *> TopologySortOperationFromInToOut(
std::vector<ir::Node *> TopologySortOperation
sOperation
FromInToOut(
const std::vector<std::unique_ptr<ir::Node>> &nodes) {
const std::vector<std::unique_ptr<ir::Node>> &nodes) {
std::vector<ir::Node*> tmp;
std::vector<ir::Node*> tmp;
for (auto& n : nodes) {
for (auto& n : nodes) {
tmp.push_back(n.get());
tmp.push_back(n.get());
}
}
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(tmp);
Build
Operation
AdjList(tmp);
PADDLE_ENFORCE(!HasCircle(adj_list));
PADDLE_ENFORCE(!HasCircle(adj_list));
std::vector<ir::Node*> ret = TopologySort(adj_list);
std::vector<ir::Node*> ret = TopologySort
Operations
(adj_list);
ir::Node *last_backward = nullptr;
ir::Node *last_backward = nullptr;
std::vector<ir::Node *> optimize_ops;
std::vector<ir::Node *> optimize_ops;
...
@@ -235,5 +236,6 @@ BuildAdjList(tmp);
...
@@ -235,5 +236,6 @@ BuildAdjList(tmp);
return ret;
return ret;
}*/
}*/
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph.h
浏览文件 @
c3f6e0e8
...
@@ -26,13 +26,13 @@ limitations under the License. */
...
@@ -26,13 +26,13 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
class
Graph
{
class
Graph
{
public:
public:
explicit
Graph
(
const
ProgramDesc
&
program
);
explicit
Graph
(
const
ProgramDesc
&
program
);
virtual
~
Graph
()
{
virtual
~
Graph
()
{
for
(
auto
&
attr
:
attrs_
)
{
for
(
auto
&
attr
:
attrs_
)
{
attr_dels_
[
attr
.
first
]();
attr_dels_
[
attr
.
first
]();
}
}
attrs_
.
clear
();
attrs_
.
clear
();
...
@@ -40,12 +40,12 @@ class Graph {
...
@@ -40,12 +40,12 @@ class Graph {
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
attrs_
[
attr_name
]
=
attr
;
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
...
@@ -54,17 +54,17 @@ class Graph {
...
@@ -54,17 +54,17 @@ class Graph {
};
};
}
}
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
var_desc
));
nodes
.
emplace_back
(
new
ir
::
Node
(
var_desc
));
return
nodes
.
back
().
get
();
return
nodes
.
back
().
get
();
}
}
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
op_desc
));
nodes
.
emplace_back
(
new
ir
::
Node
(
op_desc
));
return
nodes
.
back
().
get
();
return
nodes
.
back
().
get
();
}
}
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
name
,
type
));
nodes
.
emplace_back
(
new
ir
::
Node
(
name
,
type
));
return
nodes
.
back
().
get
();
return
nodes
.
back
().
get
();
}
}
...
@@ -73,10 +73,10 @@ class Graph {
...
@@ -73,10 +73,10 @@ class Graph {
private:
private:
// NOTE: program_ shouldn't be exposed to user.
// NOTE: program_ shouldn't be exposed to user.
const
ProgramDesc
&
program_
;
const
ProgramDesc
&
program_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
};
};
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
c3f6e0e8
...
@@ -64,7 +64,7 @@ bool HasCircleHelper(
...
@@ -64,7 +64,7 @@ bool HasCircleHelper(
bool
HasCircle
(
const
Graph
&
graph
)
{
bool
HasCircle
(
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
BuildAdjList
(
graph
);
Build
Operation
AdjList
(
graph
);
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
in_trace
;
std
::
unordered_set
<
ir
::
Node
*>
in_trace
;
...
@@ -76,9 +76,9 @@ bool HasCircle(const Graph &graph) {
...
@@ -76,9 +76,9 @@ bool HasCircle(const Graph &graph) {
return
false
;
return
false
;
}
}
std
::
vector
<
ir
::
Node
*>
TopologySort
(
const
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
TopologySort
Operations
(
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
BuildAdjList
(
graph
);
Build
Operation
AdjList
(
graph
);
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
vector
<
ir
::
Node
*>
ret
;
std
::
vector
<
ir
::
Node
*>
ret
;
for
(
auto
adj
:
adj_list
)
{
for
(
auto
adj
:
adj_list
)
{
...
@@ -89,7 +89,7 @@ std::vector<ir::Node *> TopologySort(const Graph &graph) {
...
@@ -89,7 +89,7 @@ std::vector<ir::Node *> TopologySort(const Graph &graph) {
return
ret
;
return
ret
;
}
}
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildAdjList
(
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
Build
Operation
AdjList
(
const
Graph
&
graph
)
{
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
;
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
;
...
...
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
c3f6e0e8
...
@@ -26,9 +26,9 @@ namespace framework {
...
@@ -26,9 +26,9 @@ namespace framework {
namespace
ir
{
namespace
ir
{
bool
HasCircle
(
const
Graph
&
graph
);
bool
HasCircle
(
const
Graph
&
graph
);
std
::
vector
<
ir
::
Node
*>
TopologySort
(
const
Graph
&
graph
);
std
::
vector
<
ir
::
Node
*>
TopologySort
Operations
(
const
Graph
&
graph
);
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildAdjList
(
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
Build
Operation
AdjList
(
const
Graph
&
graph
);
const
Graph
&
graph
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
c3f6e0e8
...
@@ -93,7 +93,7 @@ TEST(GraphTest, Basic) {
...
@@ -93,7 +93,7 @@ TEST(GraphTest, Basic) {
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test_out"
)
->
GetType
());
prog
.
MutableBlock
(
0
)
->
Var
(
"test_out"
)
->
GetType
());
std
::
unique_ptr
<
Graph
>
g
(
new
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
ASSERT_EQ
(
g
->
nodes
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
0
]
->
Name
(),
"test_a"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
0
]
->
Name
(),
"test_a"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
1
]
->
Name
(),
"test_b"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
1
]
->
Name
(),
"test_b"
);
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
c3f6e0e8
...
@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
...
@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
#endif
#endif
}
}
builder_
=
builder_factory
.
Create
();
builder_
=
builder_factory
.
Create
();
std
::
unique_ptr
<
Graph
>
graph
(
new
Graph
(
main_program
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录