Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
142e832d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
142e832d
编写于
7月 25, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pass registration
上级
5b183557
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
191 addition
and
104 deletion
+191
-104
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+9
-22
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+7
-20
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
+18
-15
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+3
-9
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+10
-24
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+5
-1
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+0
-6
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+8
-1
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+116
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+15
-5
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
142e832d
...
@@ -34,30 +34,16 @@ namespace paddle {
...
@@ -34,30 +34,16 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
loss_var_name_
=
Get
<
std
::
string
>
(
"loss_var_name"
);
places_
=
Get
<
std
::
vector
<
platform
::
Place
>>
(
"places"
);
local_scopes_
=
Get
<
std
::
vector
<
Scope
*>>
(
"local_scopes"
);
strategy_
=
Get
<
BuildStrategy
>
(
"strategy"
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
),
strategy_
(
strategy
)
{
#else
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#endif
#endif
for
(
auto
&
p
:
params
)
{
for
(
auto
&
p
:
Get
<
std
::
unordered_set
<
std
::
string
>>
(
"params"
))
{
grad_names_
.
insert
(
GradVarName
(
p
));
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
balance_vars_
.
resize
(
places_
.
size
(),
0
);
...
@@ -241,6 +227,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
...
@@ -241,6 +227,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
// Give the topology sort order and rebuild the graph structure.
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
graph
->
ReleaseNodes
();
auto
nodes
=
graph
->
ReleaseNodes
();
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
142e832d
...
@@ -32,20 +32,6 @@ namespace details {
...
@@ -32,20 +32,6 @@ namespace details {
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
public:
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
);
#else
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
);
#endif
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
...
@@ -53,15 +39,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -53,15 +39,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
private:
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
size_t
device_id
)
const
;
void
Init
()
const
;
private:
private:
std
::
string
loss_var_name_
;
mutable
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
mutable
std
::
unordered_set
<
std
::
string
>
grad_names_
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
#endif
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
...
@@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
private:
private:
BuildStrategy
strategy_
;
mutable
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
var_name_on_devices_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
var_name_on_devices_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
...
...
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
浏览文件 @
142e832d
...
@@ -22,26 +22,29 @@ namespace paddle {
...
@@ -22,26 +22,29 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
new
MultiDevSSAGraphBuilder
);
res
->
SetNotOwned
<
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places_
);
res
->
SetNotOwned
<
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name_
);
res
->
SetNotOwned
<
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names_
);
res
->
SetNotOwned
<
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes_
);
res
->
SetNotOwned
<
BuildStrategy
>
(
"strategy"
,
&
strategy_
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
res
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nccl_ctxs_
);
local_scopes_
,
nccl_ctxs_
,
strategy_
)
#else
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
strategy_
)
#endif
#endif
);
// NOLINT
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
SSAGraphBuilder
*
previous_pass
=
res
.
release
();
new
std
::
ofstream
(
strategy_
.
debug_graphviz_path_
)
);
res
.
reset
(
new
SSAGraghBuilderWithPrinter
);
PADDLE_ENFORCE
(
fout
->
good
()
);
res
->
Set
<
SSAGraphBuilder
>
(
"previous_pass"
,
previous_pass
);
std
::
unique_ptr
<
GraphvizSSAGraphPrinter
>
graphviz_printer
(
res
->
SetNotOwned
<
std
::
string
>
(
"debug_graphviz_path"
,
new
GraphvizSSAGraphPrinter
()
);
&
strategy_
.
debug_graphviz_path_
);
res
.
reset
(
new
SSAGraghBuilderWithPrinter
(
res
->
Set
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
std
::
move
(
fout
),
std
::
move
(
graphviz_printer
),
std
::
move
(
res
))
);
new
GraphvizSSAGraphPrinter
);
}
}
res
.
reset
(
new
SSAGraghBuilderWithChecker
(
std
::
move
(
res
)));
SSAGraphBuilder
*
previous_pass
=
res
.
release
();
res
.
reset
(
new
SSAGraghBuilderWithChecker
);
res
->
Set
<
SSAGraphBuilder
>
(
"previous_pass"
,
previous_pass
);
return
res
;
return
res
;
}
}
...
...
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
142e832d
...
@@ -24,25 +24,19 @@ namespace details {
...
@@ -24,25 +24,19 @@ namespace details {
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
public:
public:
explicit
SSAGraghBuilderWithChecker
(
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
builder_
(
std
::
move
(
builder
))
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
new_graph
;
return
new_graph
;
}
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
return
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
GetVarDeviceID
(
var_name
);
}
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
142e832d
...
@@ -14,7 +14,9 @@
...
@@ -14,7 +14,9 @@
#pragma once
#pragma once
#include <fstream>
#include <iosfwd>
#include <iosfwd>
#include <ostream>
#include <string>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
...
@@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
...
@@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
public:
public:
SSAGraghBuilderWithPrinter
(
std
::
ostream
&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ref_
(
sout
)
{}
SSAGraghBuilderWithPrinter
(
std
::
unique_ptr
<
std
::
ostream
>&&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
std
::
string
>
(
"debug_graphviz_path"
)));
PADDLE_ENFORCE
(
fout
->
good
());
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
new_graph
,
*
fout
);
return
new_graph
;
return
new_graph
;
}
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
return
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
GetVarDeviceID
(
var_name
);
}
}
private:
std
::
unique_ptr
<
SSAGraphPrinter
>
printer_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
std
::
ostream
>
stream_ptr_
;
std
::
ostream
&
stream_ref_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
142e832d
...
@@ -23,7 +23,8 @@ namespace ir {
...
@@ -23,7 +23,8 @@ namespace ir {
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path_
));
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
"graph_viz_path"
);
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
PADDLE_ENFORCE
(
fout
->
good
());
PADDLE_ENFORCE
(
fout
->
good
());
std
::
ostream
&
sout
=
*
fout
;
std
::
ostream
&
sout
=
*
fout
;
...
@@ -61,6 +62,9 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
...
@@ -61,6 +62,9 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
sout
<<
"}
\n
"
;
sout
<<
"}
\n
"
;
return
graph
;
return
graph
;
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
);
paddle/fluid/framework/ir/graph_viz_pass.h
浏览文件 @
142e832d
...
@@ -29,14 +29,8 @@ namespace ir {
...
@@ -29,14 +29,8 @@ namespace ir {
class
GraphVizPass
:
public
Pass
{
class
GraphVizPass
:
public
Pass
{
public:
public:
explicit
GraphVizPass
(
const
std
::
string
&
graph_viz_path
)
:
graph_viz_path_
(
graph_viz_path
)
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
private:
const
std
::
string
graph_viz_path_
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
142e832d
...
@@ -15,5 +15,12 @@ limitations under the License. */
...
@@ -15,5 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{}
// namespace framework
namespace
framework
{
namespace
ir
{
PassRegistry
&
PassRegistry
::
Instance
()
{
static
PassRegistry
g_pass_info_map
;
return
g_pass_info_map
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/pass.h
浏览文件 @
142e832d
...
@@ -14,9 +14,14 @@ limitations under the License. */
...
@@ -14,9 +14,14 @@ limitations under the License. */
#pragma once
#pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -25,10 +30,120 @@ namespace ir {
...
@@ -25,10 +30,120 @@ namespace ir {
class
Pass
{
class
Pass
{
public:
public:
Pass
()
=
default
;
Pass
()
=
default
;
virtual
~
Pass
()
{}
virtual
~
Pass
()
{
for
(
auto
&
attr
:
attrs_
)
{
if
(
attr_dels_
.
find
(
attr
.
first
)
!=
attr_dels_
.
end
())
{
attr_dels_
[
attr
.
first
]();
}
}
attrs_
.
clear
();
attr_dels_
.
clear
();
}
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
delete
attr
;
};
}
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
attrs_
[
attr_name
]
=
attr
;
}
private:
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
};
using
PassCreator
=
std
::
function
<
std
::
unique_ptr
<
Pass
>
()
>
;
class
Registrar
{
public:
// In our design, various kinds of passes,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_PASS macros to
// call this method. So, as long as the callee code calls USE_PASS, the global
// registrar variable won't be removed by the linker.
void
Touch
()
{}
};
};
class
PassRegistry
{
public:
static
PassRegistry
&
Instance
();
bool
Has
(
const
std
::
string
&
pass_type
)
const
{
return
map_
.
find
(
pass_type
)
!=
map_
.
end
();
}
void
Insert
(
const
std
::
string
&
type
,
const
PassCreator
&
pass_creator
)
{
PADDLE_ENFORCE
(
!
Has
(
type
),
"Pass %s has been registered"
,
type
);
map_
.
insert
({
type
,
pass_creator
});
}
std
::
unique_ptr
<
Pass
>
Get
(
const
std
::
string
&
type
)
const
{
PADDLE_ENFORCE
(
Has
(
type
),
"Pass %s has not been registered"
,
type
);
return
map_
.
at
(
type
)();
}
private:
PassRegistry
()
=
default
;
std
::
unordered_map
<
std
::
string
,
PassCreator
>
map_
;
DISABLE_COPY_AND_ASSIGN
(
PassRegistry
);
};
template
<
typename
PassType
>
struct
PassRegistrar
:
public
Registrar
{
explicit
PassRegistrar
(
const
char
*
pass_type
)
{
PADDLE_ENFORCE
(
!
PassRegistry
::
Instance
().
Has
(
pass_type
),
"'%s' is registered more than once."
,
pass_type
);
PassRegistry
::
Instance
().
Insert
(
pass_type
,
[]()
->
std
::
unique_ptr
<
Pass
>
{
return
std
::
unique_ptr
<
Pass
>
(
new
PassType
());
});
}
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
}
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
142e832d
...
@@ -132,19 +132,27 @@ ParallelExecutor::ParallelExecutor(
...
@@ -132,19 +132,27 @@ ParallelExecutor::ParallelExecutor(
PADDLE_THROW
(
"Not compiled with CUDA."
);
PADDLE_THROW
(
"Not compiled with CUDA."
);
#endif
#endif
}
}
builder_
=
builder_factory
.
Create
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
if
(
!
build_strategy
.
debug_graphviz_path_
.
empty
())
{
if
(
!
build_strategy
.
debug_graphviz_path_
.
empty
())
{
const
std
::
string
origin_graph_path
=
string
::
Sprintf
(
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
build_strategy
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
"%s%s"
,
build_strategy
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
graph
=
ir
::
GraphVizPass
(
origin_graph_path
).
Apply
(
std
::
move
(
graph
));
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
}
builder_
=
builder_factory
.
Create
();
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
if
(
!
build_strategy
.
debug_graphviz_path_
.
empty
())
{
if
(
!
build_strategy
.
debug_graphviz_path_
.
empty
())
{
const
std
::
string
origin_graph_path
=
string
::
Sprintf
(
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
build_strategy
.
debug_graphviz_path_
.
c_str
(),
"_before_exec"
);
"%s%s"
,
build_strategy
.
debug_graphviz_path_
.
c_str
(),
"_before_exec"
);
graph
=
ir
::
GraphVizPass
(
origin_graph_path
).
Apply
(
std
::
move
(
graph
));
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
}
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
...
@@ -297,3 +305,5 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -297,3 +305,5 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录