Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4b8ae523
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
4b8ae523
编写于
7月 30, 2018
作者:
X
Xin Pan
提交者:
GitHub
7月 30, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12367 from panyx0718/ir_pass
Ir pass
上级
297cbeb1
99c0c204
变更
26
显示空白变更内容
内联
并排
Showing
26 changed file
with
764 addition
and
310 deletion
+764
-310
doc/fluid/design/ir/draft.md
doc/fluid/design/ir/draft.md
+97
-1
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+0
-3
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+81
-69
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+12
-25
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+3
-0
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+7
-6
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+6
-2
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
+0
-50
paddle/fluid/framework/details/ssa_graph_builder_factory.h
paddle/fluid/framework/details/ssa_graph_builder_factory.h
+0
-71
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+10
-3
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+4
-16
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+3
-1
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+6
-3
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+9
-30
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+4
-4
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+1
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+6
-3
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+8
-1
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+72
-0
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+38
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+28
-1
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+168
-2
paddle/fluid/framework/ir/pass_test.cc
paddle/fluid/framework/ir/pass_test.cc
+112
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+88
-17
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+0
-1
未找到文件。
doc/fluid/design/ir/draft.md
浏览文件 @
4b8ae523
...
...
@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
the
`Graph`
or
`Graph`
nodes.
`Attribute`
can be passed
across
`Pass`
. However, it should be used with care.
```
cpp
class
Graph
{
public:
explicit
Graph
(
const
ProgramDesc
&
program
);
bool
Has
(
const
std
::
string
&
attr_name
)
const
;
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
;
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
);
// Create a normal runnable operator with OpDesc.
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
);
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir
::
Node
*
CreateControlDepVar
();
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
);
// Clear all node information of the graph and return the ownership of the
// nodes.
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ReleaseNodes
();
};
```
#### Pass
`Pass`
represents a transformation of
`Graph`
. Its input
...
...
@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
a
`Pass`
can simply print out the
`Graph`
. A
`Pass`
can also fuse some
`Graph`
's
`Node`
s.
```
cpp
class
Pass
{
public:
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
// Some correctness check.
auto
new_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// Some correctness check.
return
new_graph
;
}
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
};
// In my_pass.cc
class
MyPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
// do something.
return
graph
;
}
}
REGISTER_PASS
(
my_pass
,
MyPass
)
.
RequirePassAttr
(
"places"
)
.
RequireGraphAttr
(
"dep_vars"
);
// To use the pass.
auto
my_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"my_pass"
);
graph
=
my_pass
->
Apply
(
std
::
move
(
graph
));
// Note: to force link my_pass.cc, in the code:
USE_PASS
(
my_pass
);
```
#### Optimize
`Optimize`
contains a series of
`Pass`
with defined order.
...
...
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
*
Graph is transformed from raw model logic to a
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
```
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
auto graph = Graph(program);
graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
// For more complex Pass, Optimize Process can provide Pass attributes.
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
Executor exe;
exe.Run(graph);
```
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
4b8ae523
...
...
@@ -99,7 +99,7 @@ else()
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
4b8ae523
...
...
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
4b8ae523
...
...
@@ -34,30 +34,22 @@ namespace paddle {
namespace
framework
{
namespace
details
{
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kParams
[]
=
"params"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kStrategy
[]
=
"strategy"
;
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
);
strategy_
=
Get
<
const
BuildStrategy
>
(
kStrategy
);
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
),
strategy_
(
strategy
)
{
#else
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
#endif
for
(
auto
&
p
:
params
)
{
for
(
auto
&
p
:
Get
<
const
std
::
unordered_set
<
std
::
string
>>
(
kParams
))
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
...
...
@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
...
@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return
sorted_ret
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
graph
->
ReleaseNodes
();
...
...
@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
Set
(
"vars"
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
"dep_vars"
,
new
GraphDepVars
);
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kShardedVarDevice
,
new
ShardedVarDevice
);
// find send/recv vars so that we can place the distributed training
// related op in the place 0
...
...
@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// the block.
is_forwarding
=
false
;
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
);
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
var_name_on_devices_
.
emplace
(
n
->
Name
(),
op_dev_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
else
{
// This op runs on all devices, and its output may have parameter's
...
...
@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
var_name_on_devices_
.
emplace
(
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
...
...
@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
);
#endif
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
auto
*
in
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
i
).
at
(
p_name
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
i
).
at
(
p_name
);
auto
*
out_var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
p_name
,
p
);
...
...
@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
...
...
@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
...
...
@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
d_name
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
...
...
@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return
is_pg_once
;
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
)
const
{
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
}
...
...
@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
]);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
{
auto
got
=
var_name_on_devices_
.
find
(
varname
);
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
{
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
);
auto
got
=
sharded_var_device
.
find
(
varname
);
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
...
...
@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
...
...
@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
...
...
@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
dst_dev_id
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
...
...
@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
}
}
...
...
@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
PADDLE_ENFORCE
(
...
...
@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"fetch_barrier"
);
}
}
...
...
@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
());
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
...
...
@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
...
...
@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
// send_barrier and fetch_barrier op can be scheduled on device 0
...
...
@@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
node
->
Op
()
->
Type
());
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
RPCOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
RPCOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
*
node
->
Op
(),
local_scopes_
[
op_dev_id
],
node
->
Op
()
->
Type
(),
places_
[
op_dev_id
]));
// TODO(panyx0718): This might not be needed anymore.
if
(
node
->
Op
()
->
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"send"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send_barrier"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"recv"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"recv"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// do nothing
}
else
{
...
...
@@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kParams
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
);
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
4b8ae523
...
...
@@ -31,39 +31,27 @@ class Scope;
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
);
#else
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
);
#endif
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
private:
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
void
Init
()
const
;
private:
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
mutable
std
::
string
loss_var_name_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
mutable
std
::
unordered_set
<
std
::
string
>
grad_names_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
...
...
@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
og
,
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
...
...
@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
private:
BuildStrategy
strategy_
;
mutable
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
var_name_on_devices_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
4b8ae523
...
...
@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
const
ir
::
Graph
&
Graph
()
const
{
return
underlying_executor_
->
Graph
();
}
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
private:
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
4b8ae523
...
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
...
...
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
...
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
...
...
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
new_node
->
Name
()];
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
...
...
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
4b8ae523
...
...
@@ -39,21 +39,25 @@ namespace details {
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
GraphVars
;
const
char
kGraphVars
[]
=
"vars"
;
// aux variables to represent dependency. Useful to resolve data hazard.
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
GraphDepVars
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
...
...
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
已删除
100644 → 0
浏览文件 @
297cbeb1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
#ifdef PADDLE_WITH_CUDA
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
nccl_ctxs_
,
strategy_
)
#else
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
strategy_
)
#endif
);
// NOLINT
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
strategy_
.
debug_graphviz_path_
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
unique_ptr
<
GraphvizSSAGraphPrinter
>
graphviz_printer
(
new
GraphvizSSAGraphPrinter
());
res
.
reset
(
new
SSAGraghBuilderWithPrinter
(
std
::
move
(
fout
),
std
::
move
(
graphviz_printer
),
std
::
move
(
res
)));
}
res
.
reset
(
new
SSAGraghBuilderWithChecker
(
std
::
move
(
res
)));
return
res
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_builder_factory.h
已删除
100644 → 0
浏览文件 @
297cbeb1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
class
SSAGraphBuilderFactory
{
public:
SSAGraphBuilderFactory
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
nullptr
;
#endif
}
#ifdef PADDLE_WITH_CUDA
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
nccl_ctxs_
=
nccl_ctxs
;
}
#endif
std
::
unique_ptr
<
SSAGraphBuilder
>
Create
();
private:
std
::
vector
<
platform
::
Place
>
places_
;
std
::
string
loss_var_name_
;
std
::
unordered_set
<
std
::
string
>
param_names_
;
std
::
vector
<
Scope
*>
local_scopes_
;
BuildStrategy
strategy_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
4b8ae523
...
...
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
};
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
.
get
());
...
...
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
insert_pending_var
(
var
.
get
());
}
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
ready_ops
.
insert
(
op
.
get
());
}
else
{
...
...
@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphOps
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kShardedVarDevice
);
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
4b8ae523
...
...
@@ -23,26 +23,14 @@ namespace framework {
namespace
details
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
public:
explicit
SSAGraghBuilderWithChecker
(
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
builder_
(
std
::
move
(
builder
))
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
new_graph
;
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
return
graph
;
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
};
}
// namespace details
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
4b8ae523
...
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual
~
SSAGraphExecutor
();
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
virtual
const
ir
::
Graph
&
Graph
()
const
=
0
;
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
4b8ae523
...
...
@@ -22,7 +22,7 @@ namespace details {
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
...
...
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
callback
(
*
var
);
}
}
...
...
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
kGraphOps
))
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
...
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_print_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithPrinter
);
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
4b8ae523
...
...
@@ -14,7 +14,9 @@
#pragma once
#include <fstream>
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
...
...
@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
public:
SSAGraghBuilderWithPrinter
(
std
::
ostream
&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ref_
(
sout
)
{}
SSAGraghBuilderWithPrinter
(
std
::
unique_ptr
<
std
::
ostream
>&&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
return
new_graph
;
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
const
std
::
string
>
(
"debug_graphviz_path"
)));
PADDLE_ENFORCE
(
fout
->
good
());
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
return
graph
;
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
}
private:
std
::
unique_ptr
<
SSAGraphPrinter
>
printer_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
std
::
ostream
>
stream_ptr_
;
std
::
ostream
&
stream_ref_
;
};
}
// namespace details
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
4b8ae523
...
...
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
// Transform SSAGraph to pending_ops & pending_vars
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
());
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
());
}
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
details
::
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
());
}
else
{
...
...
@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
4b8ae523
...
...
@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
const
ir
::
Graph
&
Graph
()
const
{
return
*
graph_
;
}
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
4b8ae523
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
pass SRCS pass.cc DEPS graph node
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry
)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
paddle/fluid/framework/ir/graph.h
浏览文件 @
4b8ae523
...
...
@@ -40,14 +40,21 @@ class Graph {
attr_dels_
.
clear
();
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
return
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
();
}
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
Has
(
attr_name
),
"%s attr not registered for graph."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the graph"
,
attr_name
);
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
0 → 100644
浏览文件 @
4b8ae523
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphVizPath
);
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
ostream
&
sout
=
*
fout
;
size_t
var_id
=
0
;
std
::
unordered_map
<
const
ir
::
Node
*
,
size_t
>
vars
;
sout
<<
"digraph G {
\n
"
;
for
(
const
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kVariable
)
continue
;
size_t
cur_var_id
=
var_id
++
;
vars
[
n
]
=
cur_var_id
;
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
n
->
Name
()
<<
"
\"
]"
<<
std
::
endl
;
}
size_t
op_id
=
0
;
for
(
const
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
n
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
n
->
inputs
)
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
for
(
auto
out
:
n
->
outputs
)
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
}
}
sout
<<
"}
\n
"
;
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraphVizPath
);
paddle/fluid/framework/ir/graph_viz_pass.h
0 → 100644
浏览文件 @
4b8ae523
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
GraphVizPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/pass.cc
浏览文件 @
4b8ae523
...
...
@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
framework
{}
// namespace framework
namespace
framework
{
namespace
ir
{
std
::
unique_ptr
<
Graph
>
Pass
::
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
!
applied_
,
"Pass can only Apply() once."
);
PADDLE_ENFORCE
(
graph
.
get
(),
"graph passed to Pass::Apply() cannot be empty."
);
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
"Required pass atrribute %s not set."
,
attr
);
}
for
(
const
std
::
string
&
attr
:
required_graph_attrs_
)
{
PADDLE_ENFORCE
(
graph
->
Has
(
attr
),
"Required graph atrribute %s not set."
,
attr
);
}
auto
applied_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE
(
!
HasCircle
(
*
applied_graph
),
"Illegal Pass. Generated graph shouldn't has cycle."
);
applied_
=
true
;
return
applied_graph
;
}
PassRegistry
&
PassRegistry
::
Instance
()
{
static
PassRegistry
g_pass_info_map
;
return
g_pass_info_map
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/pass.h
浏览文件 @
4b8ae523
...
...
@@ -14,21 +14,187 @@ limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
template
<
typename
PassType
>
struct
PassRegistrar
;
class
Pass
{
public:
Pass
()
=
default
;
virtual
~
Pass
()
{}
virtual
~
Pass
()
{
for
(
auto
&
attr
:
attrs_
)
{
if
(
attr_dels_
.
find
(
attr
.
first
)
!=
attr_dels_
.
end
())
{
attr_dels_
[
attr
.
first
]();
}
}
attrs_
.
clear
();
attr_dels_
.
clear
();
}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
;
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
"%s attr not registered for pass."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the pass"
,
attr_name
);
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
delete
attr
;
};
}
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
attrs_
[
attr_name
]
=
attr
;
}
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
private:
template
<
typename
PassType
>
friend
struct
PassRegistrar
;
void
RegisterRequiredPassAttrs
(
const
std
::
unordered_set
<
std
::
string
>
&
attrs
)
{
required_pass_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
void
RegisterRequiredGraphAttrs
(
const
std
::
unordered_set
<
std
::
string
>
&
attrs
)
{
required_graph_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
mutable
bool
applied_
{
false
};
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
};
using
PassCreator
=
std
::
function
<
std
::
unique_ptr
<
Pass
>
()
>
;
class
Registrar
{
public:
// In our design, various kinds of passes,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_PASS macros to
// call this method. So, as long as the callee code calls USE_PASS, the global
// registrar variable won't be removed by the linker.
void
Touch
()
{}
};
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
class
PassRegistry
{
public:
static
PassRegistry
&
Instance
();
bool
Has
(
const
std
::
string
&
pass_type
)
const
{
return
map_
.
find
(
pass_type
)
!=
map_
.
end
();
}
void
Insert
(
const
std
::
string
&
pass_type
,
const
PassCreator
&
pass_creator
)
{
PADDLE_ENFORCE
(
!
Has
(
pass_type
),
"Pass %s has been registered"
,
pass_type
);
map_
.
insert
({
pass_type
,
pass_creator
});
}
std
::
unique_ptr
<
Pass
>
Get
(
const
std
::
string
&
pass_type
)
const
{
PADDLE_ENFORCE
(
Has
(
pass_type
),
"Pass %s has not been registered"
,
pass_type
);
return
map_
.
at
(
pass_type
)();
}
private:
PassRegistry
()
=
default
;
std
::
unordered_map
<
std
::
string
,
PassCreator
>
map_
;
DISABLE_COPY_AND_ASSIGN
(
PassRegistry
);
};
template
<
typename
PassType
>
struct
PassRegistrar
:
public
Registrar
{
explicit
PassRegistrar
(
const
char
*
pass_type
)
{
PADDLE_ENFORCE
(
!
PassRegistry
::
Instance
().
Has
(
pass_type
),
"'%s' is registered more than once."
,
pass_type
);
PassRegistry
::
Instance
().
Insert
(
pass_type
,
[
this
]()
->
std
::
unique_ptr
<
Pass
>
{
std
::
unique_ptr
<
Pass
>
pass
(
new
PassType
());
pass
->
RegisterRequiredPassAttrs
(
this
->
required_pass_attrs_
);
pass
->
RegisterRequiredGraphAttrs
(
this
->
required_graph_attrs_
);
return
pass
;
});
}
PassRegistrar
<
PassType
>
&
RequirePassAttr
(
const
std
::
string
&
attr
)
{
required_pass_attrs_
.
insert
(
attr
);
return
*
this
;
}
PassRegistrar
<
PassType
>
&
RequireGraphAttr
(
const
std
::
string
&
attr
)
{
required_graph_attrs_
.
insert
(
attr
);
return
*
this
;
}
private:
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/pass_test.cc
0 → 100644
浏览文件 @
4b8ae523
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
BuildCircleGraph
(
Graph
*
g
)
{
ir
::
Node
*
o1
=
g
->
CreateEmptyNode
(
"op1"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o2
=
g
->
CreateEmptyNode
(
"op2"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateEmptyNode
(
"var1"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v2
=
g
->
CreateEmptyNode
(
"var2"
,
Node
::
Type
::
kVariable
);
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
o2
->
outputs
.
push_back
(
v2
);
o1
->
inputs
.
push_back
(
v2
);
v2
->
inputs
.
push_back
(
o2
);
v2
->
outputs
.
push_back
(
o1
);
}
class
TestPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
graph
->
Set
<
int
>
(
"copy_test_pass_attr"
,
new
int
);
graph
->
Set
<
int
>
(
"copy_test_graph_attr"
,
new
int
);
int
test_pass_attr
=
this
->
Get
<
int
>
(
"test_pass_attr"
);
graph
->
Get
<
int
>
(
"copy_test_pass_attr"
)
=
test_pass_attr
+
1
;
int
test_graph_attr
=
graph
->
Get
<
int
>
(
"test_graph_attr"
);
graph
->
Get
<
int
>
(
"copy_test_graph_attr"
)
=
test_graph_attr
+
1
;
return
graph
;
}
};
TEST
(
PassTest
,
TestPassAttrCheck
)
{
ProgramDesc
prog
;
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass"
);
std
::
unique_ptr
<
Graph
>
graph
(
new
Graph
(
prog
));
std
::
string
exception
;
try
{
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"test_pass_attr not set"
)
!=
exception
.
npos
);
int
val
=
1
;
graph
.
reset
(
new
Graph
(
prog
));
pass
->
SetNotOwned
<
int
>
(
"test_pass_attr"
,
&
val
);
try
{
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"test_graph_attr not set"
)
!=
exception
.
npos
);
graph
.
reset
(
new
Graph
(
prog
));
graph
->
Set
<
int
>
(
"test_graph_attr"
,
new
int
);
graph
->
Get
<
int
>
(
"test_graph_attr"
)
=
1
;
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
ASSERT_EQ
(
graph
->
Get
<
int
>
(
"copy_test_pass_attr"
),
2
);
ASSERT_EQ
(
graph
->
Get
<
int
>
(
"copy_test_graph_attr"
),
2
);
try
{
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"Pass can only Apply() once"
)
!=
exception
.
npos
);
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass"
);
pass
->
SetNotOwned
<
int
>
(
"test_pass_attr"
,
&
val
);
graph
.
reset
(
new
Graph
(
prog
));
BuildCircleGraph
(
graph
.
get
());
graph
->
Set
<
int
>
(
"test_graph_attr"
,
new
int
);
graph
->
Get
<
int
>
(
"test_graph_attr"
)
=
2
;
try
{
auto
tmp
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"shouldn't has cycle"
)
!=
exception
.
npos
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
test_pass
,
paddle
::
framework
::
ir
::
TestPass
)
.
RequirePassAttr
(
"test_pass_attr"
)
.
RequireGraphAttr
(
"test_graph_attr"
);
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
4b8ae523
...
...
@@ -19,19 +19,80 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
framework
{
std
::
unique_ptr
<
ir
::
Graph
>
ApplyParallelExecutorPass
(
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
param_names
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
bool
use_cuda
,
#ifdef PADDLE_WITH_CUDA
const
BuildStrategy
&
strategy
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
#else
const
BuildStrategy
&
strategy
)
{
#endif
// Convert the program to graph.
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
// Apply a graph viz pass to record a graph.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
// Convert graph to run on multi-devices.
auto
multi_device_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_pass"
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
multi_device_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
multi_device_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
multi_device_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
multi_device_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
graph
=
multi_device_pass
->
Apply
(
std
::
move
(
graph
));
// Apply a graph print pass to record a graph with device info.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
multi_device_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_print_pass"
);
multi_device_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy
.
debug_graphviz_path_
);
multi_device_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
details
::
GraphvizSSAGraphPrinter
);
graph
=
multi_device_print_pass
->
Apply
(
std
::
move
(
graph
));
}
// Verify that the graph is correct for multi-device executor.
auto
multi_device_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_check_pass"
);
graph
=
multi_device_check_pass
->
Apply
(
std
::
move
(
graph
));
return
graph
;
}
class
ParallelExecutorPrivate
{
public:
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
...
...
@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor(
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
details
::
SSAGraphBuilderFactory
builder_factory
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
if
(
member_
->
use_cuda_
)
{
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
,
member_
->
nccl_ctxs_
.
get
());
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
);
#endif
}
builder_
=
builder_factory
.
Create
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
...
...
@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices(
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool
initializing
=
builder_
.
get
()
==
nullptr
?
true
:
false
;
bool
initializing
=
member_
->
executor_
?
false
:
true
;
for
(
auto
&
var
:
vars
)
{
int
var_dev_id
=
builder_
.
get
()
==
nullptr
?
-
1
:
builder_
->
GetVarDeviceID
(
var
);
int
var_dev_id
=
-
1
;
if
(
member_
->
executor_
)
{
auto
&
sharded_var_device
=
member_
->
executor_
->
Graph
().
Get
<
details
::
ShardedVarDevice
>
(
details
::
kShardedVarDevice
);
if
(
sharded_var_device
.
find
(
var
)
!=
sharded_var_device
.
end
())
{
var_dev_id
=
sharded_var_device
.
at
(
var
);
}
}
if
(
!
initializing
&&
var_dev_id
==
-
1
)
continue
;
framework
::
Variable
*
main_var
=
nullptr
;
...
...
@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_device_pass
);
USE_PASS
(
multi_device_check_pass
);
USE_PASS
(
multi_device_print_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
4b8ae523
...
...
@@ -70,7 +70,6 @@ class ParallelExecutor {
private:
ParallelExecutorPrivate
*
member_
;
std
::
unique_ptr
<
details
::
SSAGraphBuilder
>
builder_
;
};
}
// namespace framework
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录