Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b3cd2ae8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b3cd2ae8
编写于
8月 31, 2018
作者:
L
luotao1
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into ner_ut2
上级
07cb64ad
c709a04a
变更
35
隐藏空白更改
内联
并排
Showing
35 changed file
with
786 addition
and
277 deletion
+786
-277
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+4
-12
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+1
-1
paddle/fluid/framework/ir/fc_fuse_pass.cc
paddle/fluid/framework/ir/fc_fuse_pass.cc
+37
-75
paddle/fluid/framework/ir/fc_fuse_pass.h
paddle/fluid/framework/ir/fc_fuse_pass.h
+2
-1
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
+7
-2
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+4
-4
paddle/fluid/framework/ir/fuse_pass_base.h
paddle/fluid/framework/ir/fuse_pass_base.h
+18
-2
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+180
-1
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+81
-15
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
+33
-0
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+16
-2
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+2
-2
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+5
-5
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+3
-0
paddle/fluid/framework/op_proto_maker.h
paddle/fluid/framework/op_proto_maker.h
+1
-0
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+0
-1
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+19
-0
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+2
-1
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+2
-2
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
+4
-3
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+18
-9
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+3
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+77
-102
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+51
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+1
-1
paddle/fluid/pybind/const_value.cc
paddle/fluid/pybind/const_value.cc
+3
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+68
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+2
-2
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+38
-14
python/paddle/fluid/tests/unittests/test_dist_mnist.py
python/paddle/fluid/tests/unittests/test_dist_mnist.py
+21
-0
python/paddle/fluid/tests/unittests/test_name_scope.py
python/paddle/fluid/tests/unittests/test_name_scope.py
+45
-0
python/paddle/fluid/tests/unittests/test_operator_desc.py
python/paddle/fluid/tests/unittests/test_operator_desc.py
+4
-1
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
...dle/fluid/tests/unittests/test_parallel_executor_mnist.py
+14
-12
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+19
-6
未找到文件。
paddle/fluid/API.spec
浏览文件 @
b3cd2ae8
...
@@ -36,6 +36,7 @@ paddle.fluid.default_startup_program ArgSpec(args=[], varargs=None, keywords=Non
...
@@ -36,6 +36,7 @@ paddle.fluid.default_startup_program ArgSpec(args=[], varargs=None, keywords=Non
paddle.fluid.default_main_program ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.default_main_program ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.name_scope ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
...
@@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
)
const
{
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
)
const
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
communication_dev_ctx
=
nccl_ctxs_
?
nccl_ctxs_
->
DevCtx
(
places_
[
i
])
:
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
#else
auto
*
communication_dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
#endif
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
);
communication_dev_ctx
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// FIXME: Currently ScaleLossGradOp only use device_count as scale
...
@@ -744,7 +736,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -744,7 +736,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
.
emplace
(
varname
,
op_dev_id
);
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
PADDLE_
ENFORCE
(
PADDLE_
THROW
(
"the distribute training related op should be in [split_byref, "
"the distribute training related op should be in [split_byref, "
"concat]."
);
"concat]."
);
}
}
...
...
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
...
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
auto
handle
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handle
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
auto
*
while_pat_node
=
gpd
.
pattern
().
RetriveNode
(
"while"
);
auto
*
while_pat_node
=
gpd
.
pattern
().
Retri
e
veNode
(
"while"
);
auto
*
while_node
=
subgraph
.
at
(
while_pat_node
);
auto
*
while_node
=
subgraph
.
at
(
while_pat_node
);
marked_nodes
.
insert
(
while_node
);
marked_nodes
.
insert
(
while_node
);
};
};
...
...
paddle/fluid/framework/ir/fc_fuse_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
...
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
}
}
void
BuildFCPattern
(
PDPattern
*
pattern
)
{
void
BuildFCPattern
(
PDPattern
*
pattern
)
{
// make sure the selected MUL op has one input argument is a parameter.
// Create Operators
auto
*
mul_parameter_var
=
pattern
->
NewNode
(
auto
*
mul_op
=
pattern
->
NewNode
(
"mul"
)
->
assert_is_op
(
"mul"
);
[](
Node
*
node
)
{
auto
*
elementwise_add_op
=
return
node
->
IsVar
()
&&
node
->
outputs
.
size
()
==
1UL
&&
pattern
->
NewNode
(
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
);
node
->
outputs
.
front
()
->
Op
()
->
Type
()
==
"mul"
&&
node
->
Var
()
&&
// Create variables
node
->
Var
()
->
Persistable
();
// check is a parameter
// w
},
auto
*
mul_weight_var
=
pattern
->
NewNode
(
"mul_weight"
)
"mul_weight"
/*name*/
);
->
AsInput
()
->
assert_is_op_nth_input
(
"mul"
,
"Y"
,
0
);
auto
*
mul_tmp_input_var
=
pattern
->
NewNode
(
// x
[](
Node
*
node
)
{
auto
*
mul_tmp_var
=
pattern
->
NewNode
(
"mul_tmp_var"
)
bool
result
=
->
AsInput
()
node
->
IsVar
()
&&
node
->
outputs
.
size
()
>=
1UL
&&
node
->
Var
()
&&
->
assert_is_op_nth_input
(
"mul"
,
"X"
,
0
);
!
node
->
Var
()
->
Persistable
();
// this input is not an parameter.
// intermediate variable, will be removed in the IR after fuse.
if
(
!
result
)
return
false
;
auto
*
mul_out_var
=
pattern
->
NewNode
(
"mul_out"
)
// check whether one output is MUL op.
->
AsIntermediate
()
for
(
auto
*
op
:
node
->
outputs
)
{
->
assert_is_only_output_of_op
(
"mul"
)
if
(
op
->
IsOp
()
&&
op
->
Op
()
->
Type
()
==
"mul"
)
return
true
;
->
assert_is_op_input
(
"elementwise_add"
);
}
// bias
return
false
;
auto
*
elementwise_add_tmp_var
=
pattern
->
NewNode
(
"elementwise_add_tmpvar"
)
},
->
assert_is_op_input
(
"elementwise_add"
)
"mul_tmp_var"
/*name*/
);
->
AsInput
();
// output
// select a MUL op
auto
*
elementwise_add_out_var
=
pattern
->
NewNode
(
"elementwise_add_out"
)
auto
*
mul_op
=
pattern
->
NewNode
(
->
AsOutput
()
[](
Node
*
node
)
{
->
assert_is_op_output
(
"elementwise_add"
);
return
node
->
IsOp
()
&&
// start from an Op
node
->
Op
()
->
Type
()
==
"mul"
;
// type is mul
mul_op
->
LinksFrom
({
mul_weight_var
,
mul_tmp_var
}).
LinksTo
({
mul_out_var
});
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul"
/*name*/
);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto
*
mul_out_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
// starts from a Var
node
->
outputs
.
size
()
==
1UL
&&
// only has one consumer
node
->
outputs
.
front
()
->
IsOp
()
&&
// check basic logic
node
->
Var
()
&&
// not a ControlDepVar
node
->
outputs
.
front
()
->
Op
()
->
Type
()
==
"elementwise_add"
;
// a very strong validation
},
"mul_out"
);
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto
*
elementwise_add_tmp_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
outputs
.
size
()
>=
1UL
&&
node
->
Var
()
&&
VarOutLinksToOp
(
node
,
"elementwise_add"
);
},
"elementwise_add_tmpvar"
);
// select an ELEMENTWISE_ADD op
auto
*
elementwise_add_op
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"elementwise_add"
;
},
"elementwise_add"
/*name*/
);
// get the ELEMENTWISE_ADD op's output
auto
*
elementwise_add_out_var
=
pattern
->
NewNode
(
[](
Node
*
node
)
{
return
node
->
IsVar
()
&&
node
->
inputs
.
size
()
==
1UL
&&
node
->
Var
()
&&
node
->
inputs
.
front
()
->
Op
()
->
Type
()
==
"elementwise_add"
;
},
"elementwise_add_out"
);
mul_op
->
LinksFrom
({
mul_parameter_var
,
mul_tmp_input_var
})
.
LinksTo
({
mul_out_var
});
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
elementwise_add_tmp_var
})
elementwise_add_op
->
LinksFrom
({
mul_out_var
,
elementwise_add_tmp_var
})
.
LinksTo
({
elementwise_add_out_var
});
.
LinksTo
({
elementwise_add_out_var
});
}
}
...
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
...
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std
::
unique_ptr
<
ir
::
Graph
>
FCFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
FCFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
graph
.
get
());
PADDLE_ENFORCE
(
graph
.
get
());
FusePassBase
::
Init
(
"fc"
,
graph
.
get
());
std
::
unordered_set
<
Node
*>
nodes2delete
;
std
::
unordered_set
<
Node
*>
nodes2delete
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
BuildFCPattern
(
gpd
.
mutable_pattern
());
BuildFCPattern
(
gpd
.
mutable_pattern
());
#define GET_NODE(id) \
#define GET_NODE(id)
\
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().Retri
e
veNode(#id)), \
"pattern has no Node called %s", #id); \
"pattern has no Node called %s", #id);
\
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
auto* id = subgraph.at(gpd.pattern().Retri
e
veNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int
found_fc_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle FC fuse"
;
VLOG
(
4
)
<<
"handle FC fuse"
;
...
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
...
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
graph
->
RemoveNode
(
mul
);
graph
->
RemoveNode
(
mul
);
graph
->
RemoveNode
(
elementwise_add
);
graph
->
RemoveNode
(
elementwise_add
);
graph
->
RemoveNode
(
mul_out
);
// tmp variable
graph
->
RemoveNode
(
mul_out
);
// tmp variable
found_fc_count
++
;
};
};
gpd
(
graph
.
get
(),
handler
);
gpd
(
graph
.
get
(),
handler
);
AddStatis
(
found_fc_count
);
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/framework/ir/fc_fuse_pass.h
浏览文件 @
b3cd2ae8
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
...
@@ -23,7 +24,7 @@ namespace ir {
...
@@ -23,7 +24,7 @@ namespace ir {
/*
/*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/
*/
class
FCFusePass
:
public
Pass
{
class
FCFusePass
:
public
FusePassBase
{
public:
public:
virtual
~
FCFusePass
()
{}
virtual
~
FCFusePass
()
{}
...
...
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
浏览文件 @
b3cd2ae8
...
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
...
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
const
std
::
vector
<
std
::
string
>&
outputs
)
{
const
std
::
vector
<
std
::
string
>&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetType
(
type
);
op
->
SetInput
(
"Xs"
,
inputs
);
if
(
type
==
"mul"
)
{
op
->
SetOutput
(
"Ys"
,
outputs
);
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
}
else
if
(
type
==
"elementwise_add"
)
{
op
->
SetInput
(
"X"
,
inputs
);
}
op
->
SetOutput
(
"Out"
,
outputs
);
}
}
// a->OP0->b
// a->OP0->b
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
RetriveNode
(
"any_node"
));
auto
*
id
=
subgraph
.
at
(
gpd
.
pattern
().
Retri
e
veNode
(
"any_node"
));
marked_nodes
.
insert
(
id
);
marked_nodes
.
insert
(
id
);
};
};
gpd
(
graph
.
get
(),
handler
);
gpd
(
graph
.
get
(),
handler
);
...
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
...
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
#undef GET_NODE
#undef GET_NODE
#undef SET_IN
#undef SET_IN
LOG
(
INFO
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
VLOG
(
4
)
<<
"hidden_n: "
<<
hidden_n
->
Name
();
LOG
(
INFO
)
<<
"cell: "
<<
cell_n
->
Name
();
VLOG
(
4
)
<<
"cell: "
<<
cell_n
->
Name
();
LOG
(
INFO
)
<<
"xx: "
<<
xx_n
->
Name
();
VLOG
(
4
)
<<
"xx: "
<<
xx_n
->
Name
();
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
op_desc
.
SetInput
(
"C0"
,
{});
...
...
paddle/fluid/framework/ir/fuse_pass_base.h
浏览文件 @
b3cd2ae8
...
@@ -22,21 +22,37 @@ namespace paddle {
...
@@ -22,21 +22,37 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
const
char
kParamScopeAttr
[]
=
"param_scope"
;
static
const
char
kParamScopeAttr
[]
=
"__param_scope__"
;
static
const
char
kFuseStatisAttr
[]
=
"__fuse_statis__"
;
class
FusePassBase
:
public
Pass
{
class
FusePassBase
:
public
Pass
{
public:
public:
void
Init
(
Graph
*
graph
)
const
{
graph_
=
graph
;
}
void
Init
(
const
std
::
string
&
repr
,
Graph
*
graph
)
const
{
repr_
=
repr
;
graph_
=
graph
;
}
Scope
*
param_scope
()
const
{
Scope
*
param_scope
()
const
{
PADDLE_ENFORCE
(
graph_
->
Has
(
kParamScopeAttr
));
PADDLE_ENFORCE
(
graph_
->
Has
(
kParamScopeAttr
));
return
graph_
->
Get
<
framework
::
Scope
*>
(
kParamScopeAttr
);
return
graph_
->
Get
<
framework
::
Scope
*>
(
kParamScopeAttr
);
}
}
void
AddStatis
(
int
count_of_fused
)
const
{
PADDLE_ENFORCE
(
graph_
);
PADDLE_ENFORCE
(
!
repr_
.
empty
());
if
(
!
graph_
->
Has
(
kFuseStatisAttr
))
{
graph_
->
Set
(
kFuseStatisAttr
,
new
std
::
unordered_map
<
std
::
string
,
int
>
);
}
auto
&
info
=
graph_
->
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
kFuseStatisAttr
);
info
[
repr_
]
=
count_of_fused
;
}
virtual
~
FusePassBase
()
{}
virtual
~
FusePassBase
()
{}
protected:
protected:
mutable
Graph
*
graph_
;
mutable
Graph
*
graph_
;
mutable
std
::
string
repr_
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
b3cd2ae8
...
@@ -27,6 +27,19 @@ namespace ir {
...
@@ -27,6 +27,19 @@ namespace ir {
size_t
PDPattern
::
id_
=
0UL
;
size_t
PDPattern
::
id_
=
0UL
;
PDNode
*
PDPattern
::
NewNode
(
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
"PDNode's name should be unique, get duplicate [%s]"
,
name
);
}
nodes_
.
emplace_back
(
new
PDNode
(
this
,
name
));
auto
*
cur
=
nodes_
.
back
().
get
();
node_map_
[
name
]
=
cur
;
return
cur
;
}
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
PDNode
*
PDPattern
::
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
)
{
if
(
!
name
.
empty
())
{
if
(
!
name
.
empty
())
{
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
PADDLE_ENFORCE_EQ
(
node_map_
.
count
(
name
),
0
,
...
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
...
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
return
cur
;
return
cur
;
}
}
PDNode
*
PDPattern
::
RetriveNode
(
const
std
::
string
&
id
)
const
{
PDNode
*
PDPattern
::
Retri
e
veNode
(
const
std
::
string
&
id
)
const
{
auto
it
=
node_map_
.
find
(
id
);
auto
it
=
node_map_
.
find
(
id
);
if
(
it
==
node_map_
.
end
())
{
if
(
it
==
node_map_
.
end
())
{
return
nullptr
;
return
nullptr
;
...
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
...
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
auto
subgraphs
=
DetectPatterns
();
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
UniquePatterns
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
RemoveOverlappedMatch
(
&
subgraphs
);
ValidateByNodeRole
(
&
subgraphs
);
if
(
subgraphs
.
empty
())
return
;
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
int
id
=
0
;
int
id
=
0
;
for
(
auto
&
g
:
subgraphs
)
{
for
(
auto
&
g
:
subgraphs
)
{
...
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
...
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
}
}
}
}
}
}
// Check to early stop if some PDNode can't find matched Node.
for
(
auto
&
pdnode
:
pattern_
.
nodes
())
{
if
(
!
pdnodes2nodes_
.
count
(
pdnode
.
get
()))
{
VLOG
(
4
)
<<
pdnode
->
name
()
<<
" can't find matched Node, early stop"
;
return
false
;
}
}
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
VLOG
(
3
)
<<
pdnodes2nodes_
.
size
()
<<
" nodes marked"
;
return
!
pdnodes2nodes_
.
empty
();
return
!
pdnodes2nodes_
.
empty
();
}
}
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void
GraphPatternDetector
::
ValidateByNodeRole
(
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>*
subgraphs
)
{
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>
result
;
subgraphs
->
erase
(
std
::
remove_if
(
subgraphs
->
begin
(),
subgraphs
->
end
(),
[](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
bool
{
// Collect the inputs and outputs.
std
::
unordered_set
<
Node
*>
ios
;
for
(
auto
&
item
:
subgraph
)
{
if
(
!
item
.
first
->
IsIntermediate
())
{
ios
.
insert
(
item
.
second
);
}
}
for
(
auto
&
item
:
subgraph
)
{
if
(
item
.
first
->
IsIntermediate
())
{
for
(
auto
*
x
:
item
.
second
->
inputs
)
{
if
(
!
ios
.
count
(
x
))
{
return
true
;
}
}
for
(
auto
*
x
:
item
.
second
->
outputs
)
{
if
(
!
ios
.
count
(
x
))
{
return
true
;
}
}
}
}
return
false
;
}),
subgraphs
->
end
());
}
struct
HitGroup
{
struct
HitGroup
{
std
::
unordered_map
<
PDNode
*
,
Node
*>
roles
;
std
::
unordered_map
<
PDNode
*
,
Node
*>
roles
;
...
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
...
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
// in edges of PDNodes.
// in edges of PDNodes.
for
(
const
auto
&
edge
:
pattern_
.
edges
())
{
for
(
const
auto
&
edge
:
pattern_
.
edges
())
{
VLOG
(
4
)
<<
"check "
<<
edge
.
first
->
name
()
<<
" -> "
<<
edge
.
second
->
name
();
VLOG
(
4
)
<<
"check "
<<
edge
.
first
->
name
()
<<
" -> "
<<
edge
.
second
->
name
();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PDNodes, which indicates two roles.
// Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
// Detect two Nodes that can match these two roles and they are connected.
auto
&
pre_groups
=
bi_records
[
step
%
2
];
auto
&
pre_groups
=
bi_records
[
step
%
2
];
...
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
...
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
// source -> target
// source -> target
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
source
:
pdnodes2nodes_
[
edge
.
first
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
for
(
Node
*
target
:
pdnodes2nodes_
[
edge
.
second
])
{
VLOG
(
8
)
<<
"check "
<<
source
->
id
()
<<
" -- "
<<
target
->
id
();
// TODO(Superjomn) add some prune strategies.
// TODO(Superjomn) add some prune strategies.
for
(
const
auto
&
group
:
pre_groups
)
{
for
(
const
auto
&
group
:
pre_groups
)
{
HitGroup
new_group
=
group
;
HitGroup
new_group
=
group
;
...
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
...
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
}
}
}
}
VLOG
(
3
)
<<
"step "
<<
step
<<
" get records: "
<<
cur_groups
.
size
();
VLOG
(
3
)
<<
"step "
<<
step
<<
" get records: "
<<
cur_groups
.
size
();
for
(
auto
&
group
:
cur_groups
)
{
for
(
auto
&
item
:
group
.
roles
)
{
VLOG
(
4
)
<<
"node "
<<
item
.
second
->
id
()
<<
" as "
<<
item
.
first
->
name
();
}
VLOG
(
4
)
<<
"========================================================="
;
}
}
}
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
for
(
auto
&
group
:
bi_records
[
step
%
2
])
{
...
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
...
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
return
*
this
;
return
*
this
;
}
}
PDNode
*
PDNode
::
assert_is_op
()
{
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op
(
const
std
::
string
&
op_type
)
{
asserts_
.
emplace_back
([
this
,
op_type
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
Type
()
==
op_type
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_var
()
{
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
x
&&
x
->
IsVar
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_var_not_persistable
()
{
assert_is_var
();
asserts_
.
emplace_back
([
this
](
Node
*
x
)
{
return
!
x
->
Var
()
->
Persistable
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_persistable_var
()
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
Var
()
->
Persistable
();
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
assert_is_op_input
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
IsNthInput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
IsNthOutput
(
x
,
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
op
->
inputs
.
size
()
==
1
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
&&
op
->
outputs
.
size
()
==
1
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_output
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
)
{
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outputs
)
{
if
(
op
&&
op
->
IsOp
()
&&
op
->
Op
()
&&
op
->
Op
()
->
Type
()
==
op_type
)
{
return
true
;
}
}
return
false
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
inputs
.
size
()
==
n
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_op_has_n_outputs
(
const
std
::
string
&
op_type
,
size_t
n
)
{
assert_is_op
(
op_type
);
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
outputs
.
size
()
==
n
;
});
return
this
;
}
PDNode
*
PDNode
::
assert_more
(
PDNode
::
teller_t
&&
teller
)
{
asserts_
.
emplace_back
(
std
::
move
(
teller
));
return
this
;
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
b3cd2ae8
...
@@ -39,14 +39,24 @@ struct PDNode {
...
@@ -39,14 +39,24 @@ struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode.
// tell whether an ir::Node* is a candidation for a PDNode.
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
using
teller_t
=
std
::
function
<
bool
(
Node
*
)
>
;
enum
class
Type
{
kOp
,
kVar
};
enum
class
Type
{
kOp
,
kVar
};
enum
class
Role
{
kUnknown
,
// No role,
kInput
,
// an input and will be retained,
kOutput
,
// an output and will be retained,
kIntermediate
// will be removed after handler.
};
// this link to others
// this link to others
PDNode
&
LinksTo
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksTo
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksFrom
(
const
std
::
vector
<
PDNode
*>&
others
);
PDNode
&
LinksFrom
(
const
std
::
vector
<
PDNode
*>&
others
);
bool
Tell
(
Node
*
node
)
const
{
bool
Tell
(
Node
*
node
)
const
{
PADDLE_ENFORCE
(
teller_
!=
nullptr
,
"teller should be set for a PDNode"
);
if
(
teller_
)
return
teller_
(
node
);
return
teller_
(
node
);
for
(
auto
&
asrt
:
asserts_
)
{
if
(
!
asrt
(
node
))
return
false
;
}
return
true
;
}
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOp
;
}
bool
IsOp
()
const
{
return
type_
==
Type
::
kOp
;
}
...
@@ -54,10 +64,52 @@ struct PDNode {
...
@@ -54,10 +64,52 @@ struct PDNode {
const
std
::
string
&
name
()
const
{
return
name_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
PDNode
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
PDNode
&
operator
=
(
const
PDNode
&
)
=
delete
;
PDNode
(
const
PDNode
&
)
=
delete
;
// Mark this node is an Input of a subgraph and will be retained.
PDNode
*
AsInput
()
{
role_
=
Role
::
kInput
;
return
this
;
}
// Mark this node is an Output of a subgraph and will be retained.
PDNode
*
AsOutput
()
{
role_
=
Role
::
kOutput
;
return
this
;
}
// Mark this node will be removed, so all the links should be inside a matched
// sub-graph.
PDNode
*
AsIntermediate
()
{
role_
=
Role
::
kIntermediate
;
return
this
;
}
bool
IsIntermediate
()
const
{
return
role_
==
Role
::
kIntermediate
;
}
bool
IsInput
()
const
{
return
role_
==
Role
::
kInput
;
}
bool
IsOutput
()
const
{
return
role_
==
Role
::
kOutput
;
}
// Assertions, helper functions to simplify the pattern definition.
PDNode
*
assert_is_op
();
PDNode
*
assert_is_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_var
();
PDNode
*
assert_var_not_persistable
();
PDNode
*
assert_is_persistable_var
();
PDNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_op_nth_output
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
PDNode
*
assert_is_only_input_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_is_only_output_of_op
(
const
std
::
string
&
op_type
);
PDNode
*
assert_op_has_n_inputs
(
const
std
::
string
&
op_type
,
size_t
n
);
PDNode
*
assert_op_has_n_outputs
(
const
std
::
string
&
op_type
,
size_t
n
);
PDNode
*
assert_more
(
teller_t
&&
teller
);
private:
private:
PDNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
:
pattern_
(
pattern
),
name_
(
name
),
type_
(
type
)
{}
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
PDNode
(
teller_t
&&
teller
,
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
Type
type
=
Type
::
kVar
)
:
teller_
(
std
::
move
(
teller
)),
:
teller_
(
std
::
move
(
teller
)),
...
@@ -71,10 +123,13 @@ struct PDNode {
...
@@ -71,10 +123,13 @@ struct PDNode {
friend
class
PDPattern
;
friend
class
PDPattern
;
// Will be removed later.
teller_t
teller_
;
teller_t
teller_
;
std
::
vector
<
teller_t
>
asserts_
;
PDPattern
*
pattern_
;
PDPattern
*
pattern_
;
std
::
string
name_
;
std
::
string
name_
;
Type
type_
;
Type
type_
;
Role
role_
{
Role
::
kUnknown
};
};
};
/*
/*
...
@@ -87,19 +142,18 @@ struct PDNode {
...
@@ -87,19 +142,18 @@ struct PDNode {
* This pattern can be defined with the following pseudo code:
* This pattern can be defined with the following pseudo code:
*
*
* // Create two operator PDNodes.
* // Create two operator PDNodes.
* MUL = PDPattern.NewNode()
* MUL = PDPattern.NewNode()
.assert_is_op("mul");
* ELE = PDPattern.NewNode()
* ELE = PDPattern.NewNode()
.assert_is_op("elementwise_add");
* // Create the variable PDNodes.
* // Create the variable PDNodes.
* MUL_out = PDPattern.NewNode()
* MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
* // Add teller to define some rules that help to filter the target Nodes.
* .assert_is_op_input("elementwise_add") \
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
* .AsIntermediate();
* ELE.teller = lambda(node): \
* // Add relations.
* node->IsOp() && node->Op()->Type == "elementwise_add";
* MUL->LinksTo({MUL_out});
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
* MUL_out->LinksTo({ELE});
* && (ELE in node->outputs)
*
*
* One can add more specific
teller
s for PDNodes or edges, both the Operator
* One can add more specific
assert
s for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.
teller
.
* and Variable Nodes can be ruled in PDNode.
assert_more(...)
.
*
*
* PDPattern can record the general patterns, such as the pattern represents
* PDPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to find out the IO abnormal place.
* - Op in CPU -> Op in GPU -> Op in CPU, to find out the IO abnormal place.
...
@@ -112,7 +166,8 @@ class PDPattern {
...
@@ -112,7 +166,8 @@ class PDPattern {
void
AddEdge
(
PDNode
*
a
,
PDNode
*
b
);
void
AddEdge
(
PDNode
*
a
,
PDNode
*
b
);
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
NewNode
(
PDNode
::
teller_t
&&
teller
,
const
std
::
string
&
name
=
NewID
());
PDNode
*
RetriveNode
(
const
std
::
string
&
id
)
const
;
PDNode
*
NewNode
(
const
std
::
string
&
name
=
NewID
());
PDNode
*
RetrieveNode
(
const
std
::
string
&
id
)
const
;
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
std
::
unique_ptr
<
PDNode
>>&
nodes
()
const
{
return
nodes_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
const
std
::
vector
<
edge_t
>&
edges
()
const
{
return
edges_
;
}
...
@@ -185,6 +240,9 @@ class GraphPatternDetector {
...
@@ -185,6 +240,9 @@ class GraphPatternDetector {
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
void
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
void
RemoveOverlappedMatch
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
// Validate whether the intermediate nodes are linked by external nodes.
void
ValidateByNodeRole
(
std
::
vector
<
subgraph_t
>*
subgraphs
);
#ifdef PADDLE_WITH_TESTING
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
);
FRIEND_TEST
(
GraphPatternDetecter
,
MarkPDNodesInGraph
);
FRIEND_TEST
(
GraphPatternDetecter
,
DetectPatterns
);
FRIEND_TEST
(
GraphPatternDetecter
,
DetectPatterns
);
...
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
...
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
return
var
->
Name
()
==
op
->
Op
()
->
Input
(
argument
)[
nth
];
}
}
static
bool
IsNthOutput
(
Node
*
var
,
Node
*
op
,
const
std
::
string
&
argument
,
size_t
nth
)
{
PADDLE_ENFORCE
(
var
->
IsVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
op
->
inputs
.
size
()
<=
nth
)
return
false
;
return
var
->
Name
()
==
op
->
Op
()
->
Output
(
argument
)[
nth
];
}
static
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
static
void
GraphSafeRemoveNodes
(
Graph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
const
std
::
unordered_set
<
const
Node
*>&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
浏览文件 @
b3cd2ae8
...
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
...
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
ASSERT_LE
(
count
,
2
);
ASSERT_LE
(
count
,
2
);
}
}
TEST
(
GraphPatternDetector
,
IntermediateCheck
)
{
ProgramDesc
program
;
Graph
graph
(
program
);
BuildGraph
(
&
graph
);
// o2->v2->o3
// o2->v2->o4
// Check the o2+o3 fuse; it should fail because v2 also links to o4.
GraphPatternDetector
detector
;
auto
*
op2
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Name
()
==
"op2"
;
},
"op2"
);
auto
*
op3
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Name
()
==
"op3"
;
},
"op3"
);
auto
*
v2
=
detector
.
mutable_pattern
()
->
NewNode
(
[](
Node
*
x
)
{
return
x
&&
x
->
IsVar
()
&&
x
->
Name
()
==
"var2"
;
},
"var2"
)
->
AsIntermediate
();
v2
->
LinksFrom
({
op2
}).
LinksTo
({
op3
});
int
count
=
0
;
detector
(
&
graph
,
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
g
,
Graph
*
graph
)
{
++
count
;
});
EXPECT_EQ
(
count
,
0
);
count
=
0
;
v2
->
AsInput
();
detector
(
&
graph
,
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
g
,
Graph
*
graph
)
{
++
count
;
});
ASSERT_EQ
(
count
,
1
);
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -16,13 +16,27 @@ limitations under the License. */
...
@@ -16,13 +16,27 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
using
inference
::
analysis
::
Dot
;
using
inference
::
analysis
::
Dot
;
namespace
{
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
std
::
string
FormatName
(
const
Node
*
node
)
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
||
!
node
->
Op
()
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
()))
{
return
node
->
Name
();
}
const
std
::
string
full_scope
=
boost
::
get
<
std
::
string
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
()));
return
string
::
Sprintf
(
"%s%s"
,
full_scope
.
c_str
(),
node
->
Name
().
c_str
());
}
}
// namespace
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
...
@@ -54,7 +68,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
...
@@ -54,7 +68,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
auto
marked_nodes
=
ConsumeMarkedNodes
(
graph
.
get
());
auto
marked_nodes
=
ConsumeMarkedNodes
(
graph
.
get
());
// Create nodes
// Create nodes
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
std
::
string
node_id
=
n
->
Name
(
)
+
"("
+
std
::
to_string
(
n
->
id
())
+
")"
;
std
::
string
node_id
=
FormatName
(
n
)
+
"("
+
std
::
to_string
(
n
->
id
())
+
")"
;
if
(
n
->
IsOp
())
{
if
(
n
->
IsOp
())
{
decltype
(
op_attrs
)
attr
=
decltype
(
op_attrs
)
attr
=
marked_nodes
.
count
(
n
)
?
marked_op_attrs
:
op_attrs
;
marked_nodes
.
count
(
n
)
?
marked_op_attrs
:
op_attrs
;
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
b3cd2ae8
...
@@ -55,11 +55,11 @@ class Node {
...
@@ -55,11 +55,11 @@ class Node {
std
::
string
Name
()
const
{
return
name_
;
}
std
::
string
Name
()
const
{
return
name_
;
}
VarDesc
*
Var
()
{
VarDesc
*
Var
()
{
PADDLE_ENFORCE
(
type_
==
Type
::
kVariable
);
PADDLE_ENFORCE
(
IsVar
()
);
return
var_desc_
.
get
();
return
var_desc_
.
get
();
}
}
OpDesc
*
Op
()
{
OpDesc
*
Op
()
const
{
PADDLE_ENFORCE
(
IsOp
());
PADDLE_ENFORCE
(
IsOp
());
return
op_desc_
.
get
();
return
op_desc_
.
get
();
}
}
...
...
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
...
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
std
::
unique_ptr
<
ir
::
Graph
>
SeqConcatFcFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
SeqConcatFcFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
graph
.
get
());
FusePassBase
::
Init
(
"seq_concat_fc_fuse"
,
graph
.
get
());
GraphPatternDetector
detector
;
GraphPatternDetector
detector
;
auto
*
pattern
=
detector
.
mutable_pattern
();
auto
*
pattern
=
detector
.
mutable_pattern
();
auto
*
concat_out
=
BuildSeqExpandConcatPattern
(
pattern
);
auto
*
concat_out
=
BuildSeqExpandConcatPattern
(
pattern
);
BuildFCPattern
(
pattern
,
concat_out
);
BuildFCPattern
(
pattern
,
concat_out
);
#define GET_NODE(id, pattern) \
#define GET_NODE(id, pattern)
\
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
PADDLE_ENFORCE(subgraph.count(pattern.Retri
e
veNode(#id)), \
"pattern has no Node called %s", #id); \
"pattern has no Node called %s", #id);
\
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
auto* id = subgraph.at(pattern.Retri
e
veNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
detector
(
graph
.
get
(),
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
b3cd2ae8
...
@@ -129,6 +129,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
...
@@ -129,6 +129,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
"Optimized for variable"
)
"Optimized for variable"
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
std
::
string
>
(
OpNamescopeAttrName
(),
"Operator name with namesope."
)
.
SetDefault
(
""
);
Validate
();
Validate
();
}
}
...
...
paddle/fluid/framework/op_proto_maker.h
浏览文件 @
b3cd2ae8
...
@@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker {
...
@@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker {
public:
public:
static
const
char
*
OpRoleAttrName
()
{
return
"op_role"
;
}
static
const
char
*
OpRoleAttrName
()
{
return
"op_role"
;
}
static
const
char
*
OpRoleVarAttrName
()
{
return
"op_role_var"
;
}
static
const
char
*
OpRoleVarAttrName
()
{
return
"op_role_var"
;
}
static
const
char
*
OpNamescopeAttrName
()
{
return
"op_namescope"
;
}
void
operator
()(
proto
::
OpProto
*
proto
,
OpAttrChecker
*
attr_checker
);
void
operator
()(
proto
::
OpProto
*
proto
,
OpAttrChecker
*
attr_checker
);
...
...
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
b3cd2ae8
...
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
...
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
void
AddGraphvizDebugerPass
(
Pass
*
pass
)
{
void
AddGraphvizDebugerPass
(
Pass
*
pass
)
{
auto
*
debuger_pass
=
pass
->
CreateGraphvizDebugerPass
();
auto
*
debuger_pass
=
pass
->
CreateGraphvizDebugerPass
();
if
(
debuger_pass
)
{
if
(
debuger_pass
)
{
LOG
(
INFO
)
<<
" - register debug pass ["
<<
debuger_pass
->
repr
()
<<
"]"
;
Register
(
debuger_pass
->
repr
(),
debuger_pass
);
Register
(
debuger_pass
->
repr
(),
debuger_pass
);
}
}
}
}
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
b3cd2ae8
...
@@ -16,10 +16,13 @@
...
@@ -16,10 +16,13 @@
#include <google/protobuf/text_format.h>
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
...
@@ -31,6 +34,8 @@ namespace paddle {
...
@@ -31,6 +34,8 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
using
namespace
framework
;
TEST
(
Analyzer
,
analysis_without_tensorrt
)
{
TEST
(
Analyzer
,
analysis_without_tensorrt
)
{
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
Argument
argument
;
Argument
argument
;
...
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
}
}
}
}
if
(
use_analysis
&&
activate_ir
)
{
AnalysisPredictor
*
analysis_predictor
=
dynamic_cast
<
AnalysisPredictor
*>
(
predictor
.
get
());
auto
&
fuse_statis
=
analysis_predictor
->
analysis_argument
()
.
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
framework
::
ir
::
kFuseStatisAttr
);
for
(
auto
&
item
:
fuse_statis
)
{
LOG
(
INFO
)
<<
"fused "
<<
item
.
first
<<
" "
<<
item
.
second
;
}
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"fc"
),
1
);
}
}
}
// Directly infer with the original model.
// Directly infer with the original model.
...
...
paddle/fluid/inference/analysis/argument.h
浏览文件 @
b3cd2ae8
...
@@ -64,7 +64,8 @@ struct Argument {
...
@@ -64,7 +64,8 @@ struct Argument {
template
<
typename
T
>
template
<
typename
T
>
void
Set
(
const
std
::
string
&
key
,
T
*
data
)
{
void
Set
(
const
std
::
string
&
key
,
T
*
data
)
{
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE_NOT_NULL
(
data
);
PADDLE_ENFORCE
(
!
attrs_
.
count
(
key
),
"duplicate attr called %s"
,
key
);
PADDLE_ENFORCE
(
!
attrs_
.
count
(
key
),
"Duplicate set Argument's attr [%s]"
,
key
);
attrs_
[
key
]
=
data
;
attrs_
[
key
]
=
data
;
attr_deleters_
[
key
]
=
[
data
,
key
,
this
]()
{
attr_deleters_
[
key
]
=
[
data
,
key
,
this
]()
{
VLOG
(
3
)
<<
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
;
VLOG
(
3
)
<<
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
;
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include <vector>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
...
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
...
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
bool
DataFlowGraphToFluidPass
::
Initialize
(
Argument
*
argument
)
{
bool
DataFlowGraphToFluidPass
::
Initialize
(
Argument
*
argument
)
{
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
origin_program_desc
)
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
origin_program_desc
)
PADDLE_ENFORCE
(
!
argument
->
transformed_program_desc
);
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// from the original program desc. The operators of the main block(the first
// from the original program desc. The operators of the main block(the first
// block) should be rewritten by the data flow graph.
// block) should be rewritten by the data flow graph.
...
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
...
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
}
}
}
}
if
(
argument_
->
Has
(
"param_scope"
))
{
if
(
argument_
->
Has
(
framework
::
ir
::
kParamScopeAttr
))
{
LOG
(
WARNING
)
<<
"parameter changes in the scope takes effect"
;
LOG
(
WARNING
)
<<
"parameter changes in the scope takes effect"
;
}
}
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.cc
浏览文件 @
b3cd2ae8
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
...
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
...
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const
std
::
string
&
prog_file
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
)
{
const
std
::
string
&
param_file
)
{
PADDLE_ENFORCE
(
argument_
);
PADDLE_ENFORCE
(
argument_
);
argument_
->
Set
(
"param_scope"
,
new
framework
::
Scope
);
argument_
->
Set
(
framework
::
ir
::
kParamScopeAttr
,
new
framework
::
Scope
);
// Load parameters.
// Load parameters.
VLOG
(
3
)
<<
"Loading parameters from "
<<
model_dir
;
VLOG
(
3
)
<<
"Loading parameters from "
<<
model_dir
;
LoadParams
(
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
),
model_dir
,
LoadParams
(
&
argument_
->
Get
<
framework
::
Scope
>
(
framework
::
ir
::
kParamScopeAttr
)
,
prog_file
,
param_file
);
model_dir
,
prog_file
,
param_file
);
}
}
bool
FluidToIrPass
::
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
bool
FluidToIrPass
::
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
...
...
paddle/fluid/inference/analysis/fluid_to_ir_pass.h
浏览文件 @
b3cd2ae8
...
@@ -14,12 +14,14 @@
...
@@ -14,12 +14,14 @@
#pragma once
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
using
namespace
framework
;
static
const
char
kFluidToIrPassesAttr
[]
=
"__fluid_to_ir_passes__"
;
static
const
char
kFluidToIrPassesAttr
[]
=
"__fluid_to_ir_passes__"
;
...
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_program_path
);
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_program_path
);
// Load program.
// Load program.
auto
program
=
LoadProgramDesc
(
*
argument
->
fluid_model_program_path
);
auto
program
=
LoadProgramDesc
(
*
argument
->
fluid_model_program_path
);
argument
->
origin_program_desc
.
reset
(
argument
->
origin_program_desc
.
reset
(
new
proto
::
ProgramDesc
(
program
));
new
framework
::
proto
::
ProgramDesc
(
program
));
// Create main data flow graph.
// Create main data flow graph.
if
(
!
argument
->
main_dfg
)
{
if
(
!
argument
->
main_dfg
)
{
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
}
}
argument
->
Set
(
"ir_program_desc"
,
new
framework
::
ProgramDesc
(
program
));
argument
->
Set
(
"ir_program_desc"
,
new
ProgramDesc
(
program
));
LOG
(
INFO
)
<<
"Loading parameters"
;
LOG
(
INFO
)
<<
"Loading parameters"
;
// Load parameters to argument if needed.
// Load parameters to argument if needed.
...
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
void
Run
(
DataFlowGraph
*
graph
)
override
{
void
Run
(
DataFlowGraph
*
graph
)
override
{
// Call all the IR Passes
// Call all the IR Passes
IRPassManager
ir_passes
(
IRPassManager
ir_passes
(
argument_
->
Get
<
ProgramDesc
>
(
"ir_program_desc"
),
argument_
->
Get
<
framework
::
ProgramDesc
>
(
"ir_program_desc"
),
nullptr
);
nullptr
);
// Pass the scope from analysis to IR if needed.
// Pass the scope from analysis to IR if needed.
if
(
argument_
->
Has
(
"param_scope"
))
{
if
(
argument_
->
Has
(
ir
::
kParamScopeAttr
))
{
// Here the address is passed, attention that IR doesn't own the scope, so
// Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase.
// the real scope in analysis should live during the IR phase.
ir_passes
.
graph
().
Set
(
ir_passes
.
graph
().
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
ir
::
kParamScopeAttr
,
&
argument_
->
Get
<
framework
::
Scope
>
(
"param_scope"
)));
new
Scope
*
(
&
argument_
->
Get
<
Scope
>
(
ir
::
kParamScopeAttr
)));
}
}
const
auto
&
ir_passes_to_apply
=
const
auto
&
ir_passes_to_apply
=
...
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
PADDLE_ENFORCE
(
argument_
->
main_dfg
.
get
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
argument_
->
main_dfg
->
Build
(
ir_passes
.
graph
());
// inherit the arguments from ir.
if
(
ir_passes
.
graph
().
Has
(
ir
::
kFuseStatisAttr
))
{
argument_
->
Set
(
ir
::
kFuseStatisAttr
,
new
std
::
unordered_map
<
std
::
string
,
int
>
(
ir_passes
.
graph
().
Get
<
std
::
unordered_map
<
std
::
string
,
int
>>
(
ir
::
kFuseStatisAttr
)));
}
}
}
void
EnableParamModify
(
const
std
::
string
&
model_dir
,
void
EnableParamModify
(
const
std
::
string
&
model_dir
,
...
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
...
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private:
private:
// Load parameters from a single file or from a directory.
// Load parameters from a single file or from a directory.
bool
LoadParams
(
framework
::
Scope
*
scope
,
const
std
::
string
&
dir
,
bool
LoadParams
(
Scope
*
scope
,
const
std
::
string
&
dir
,
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
const
std
::
string
&
prog_file
,
const
std
::
string
&
param_file
);
private:
private:
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
b3cd2ae8
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
...
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
framework
::
Scope
*
scope
)
framework
::
Scope
*
scope
)
:
program_
(
program
)
{
:
program_
(
program
)
{
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
graph_
.
reset
(
new
framework
::
ir
::
Graph
(
program
));
if
(
scope
)
graph_
->
Set
(
"param_scope"
,
new
framework
::
Scope
*
(
scope
));
if
(
scope
)
graph_
->
Set
(
framework
::
ir
::
kParamScopeAttr
,
new
framework
::
Scope
*
(
scope
));
}
}
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
b3cd2ae8
...
@@ -12,121 +12,96 @@
...
@@ -12,121 +12,96 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace
paddle
{
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
bool
AnalysisPredictor
::
Init
(
using
inference
::
Singleton
;
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
using
inference
::
analysis
::
Analyzer
;
VLOG
(
3
)
<<
"Predictor::init()"
;
using
framework
::
proto
::
ProgramDesc
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
/* This predictor is based on the original native predictor with IR and Analysis
}
else
{
* support. It will optimize IR and Parameters in the runtime.
place_
=
paddle
::
platform
::
CPUPlace
();
* TODO(Superjomn) Replace the Native predictor?
*/
class
AnalysisPredictor
:
public
NativePaddlePredictor
{
public:
explicit
AnalysisPredictor
(
const
NativeConfig
&
config
)
:
NativePaddlePredictor
(
config
),
config_
(
config
)
{}
bool
Init
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
parent_scope
)
{
VLOG
(
3
)
<<
"Predictor::init()"
;
if
(
config_
.
use_gpu
)
{
place_
=
paddle
::
platform
::
CUDAPlace
(
config_
.
device
);
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
}
PADDLE_ENFORCE
(
!
parent_scope
);
if
(
parent_scope
)
{
scope_
=
parent_scope
;
sub_scope_
=
&
(
parent_scope
->
NewScope
());
}
else
{
paddle
::
framework
::
InitDevices
(
false
);
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
}
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
// Initialize the inference program
if
(
!
config_
.
model_dir
.
empty
())
{
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
model_dir
);
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
config_
.
prog_file
,
config_
.
param_file
);
}
else
{
LOG
(
ERROR
)
<<
"fail to load inference model."
;
return
false
;
}
OptimizeInferenceProgram
();
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
VLOG
(
5
)
<<
"to create variables"
;
PADDLE_ENFORCE
(
scope_
.
get
());
executor_
->
CreateVariables
(
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch
();
return
true
;
}
}
PADDLE_ENFORCE
(
!
parent_scope
);
bool
Run
(
const
std
::
vector
<
PaddleTensor
>&
inputs
,
if
(
parent_scope
)
{
std
::
vector
<
PaddleTensor
>*
output_data
,
scope_
=
parent_scope
;
int
batch_size
=
-
1
)
override
{
sub_scope_
=
&
(
parent_scope
->
NewScope
());
return
NativePaddlePredictor
::
Run
(
inputs
,
output_data
,
batch_size
);
}
else
{
paddle
::
framework
::
InitDevices
(
false
);
scope_
.
reset
(
new
paddle
::
framework
::
Scope
());
}
}
void
OptimizeInferenceProgram
()
{
executor_
.
reset
(
new
paddle
::
framework
::
Executor
(
place_
));
LOG
(
INFO
)
<<
"optimize begin"
;
FLAGS_IA_enable_ir
=
true
;
// Initialize the inference program
FLAGS_IA_enable_tensorrt_subgraph_engine
=
false
;
if
(
!
config_
.
model_dir
.
empty
())
{
FLAGS_IA_output_storage_path
=
""
;
// Don't output the model.
// Parameters are saved in separate files sited in
// Analyze inference_program
// the specified `dirname`.
Argument
argument
;
inference_program_
=
paddle
::
inference
::
Load
(
executor_
.
get
(),
scope_
.
get
(),
if
(
!
config_
.
model_dir
.
empty
())
{
config_
.
model_dir
);
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
}
else
if
(
!
config_
.
prog_file
.
empty
()
&&
!
config_
.
param_file
.
empty
())
{
}
else
{
// All parameters are saved in a single file.
PADDLE_ENFORCE
(
// The file names should be consistent with that used
!
config_
.
param_file
.
empty
(),
// in Python API `fluid.io.save_inference_model`.
"Either model_dir or (param_file, prog_file) should be set."
);
inference_program_
=
paddle
::
inference
::
Load
(
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
executor_
.
get
(),
scope_
.
get
(),
config_
.
prog_file
,
config_
.
param_file
);
argument
.
fluid_model_program_path
.
reset
(
}
else
{
new
std
::
string
(
config_
.
prog_file
));
LOG
(
ERROR
)
<<
"fail to load inference model."
;
argument
.
fluid_model_param_path
.
reset
(
return
false
;
new
std
::
string
(
config_
.
param_file
));
}
argument
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Singleton
<
Analyzer
>::
Global
().
Run
(
&
argument
);
CHECK
(
argument
.
transformed_program_desc
);
VLOG
(
5
)
<<
"to prepare executor"
;
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_
.
reset
(
new
framework
::
ProgramDesc
(
*
argument
.
transformed_program_desc
));
PADDLE_ENFORCE
(
argument
.
Has
(
"param_scope"
));
// Update scope.
scope_
.
reset
(
argument
.
Release
<
framework
::
Scope
>
(
"param_scope"
));
LOG
(
INFO
)
<<
"optimize end =="
;
}
}
private:
OptimizeInferenceProgram
();
NativeConfig
config_
;
ctx_
=
executor_
->
Prepare
(
*
inference_program_
,
0
);
};
VLOG
(
5
)
<<
"to create variables"
;
PADDLE_ENFORCE
(
scope_
.
get
());
executor_
->
CreateVariables
(
*
inference_program_
,
sub_scope_
?
sub_scope_
:
scope_
.
get
(),
0
);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch
();
return
true
;
}
// Runs the Analyzer over the current inference program and replaces the
// predictor's program and parameter scope with the results.
//
// inference_program_ is dereferenced below, so it must already be set when this
// is called. NOTE(review): the function writes the global FLAGS_IA_* flags, so
// it presumably affects any other analysis run in the same process — confirm.
void AnalysisPredictor::OptimizeInferenceProgram() {
  LOG(INFO) << "optimize begin";
  // Configure the analysis through its flags: IR enabled, TensorRT subgraph
  // engine disabled.
  FLAGS_IA_enable_ir = true;
  FLAGS_IA_enable_tensorrt_subgraph_engine = false;
  FLAGS_IA_output_storage_path = "";  // Don't output the model.
  // Analyze inference_program
  if (!config_.model_dir.empty()) {
    // The model location is given as a single directory.
    argument_.fluid_model_dir.reset(new std::string(config_.model_dir));
  } else {
    // Otherwise both the program file and the parameter file are required.
    PADDLE_ENFORCE(
        !config_.param_file.empty(),
        "Either model_dir or (param_file, prog_file) should be set.");
    PADDLE_ENFORCE(!config_.prog_file.empty());
    argument_.fluid_model_program_path.reset(
        new std::string(config_.prog_file));
    argument_.fluid_model_param_path.reset(
        new std::string(config_.param_file));
  }
  // Hand the analyzer a copy of the current program's proto as its input.
  argument_.origin_program_desc.reset(
      new ProgramDesc(*inference_program_->Proto()));
  Analyzer().Run(&argument_);
  // The analyzer is expected to have produced a transformed program.
  CHECK(argument_.transformed_program_desc);
  VLOG(5) << "to prepare executor";
  // LOG(INFO) << "transformed_program_desc " <<
  // argument.transformed_program_desc->DebugString();
  // Replace the predictor's program with one built from the transformed desc.
  inference_program_.reset(
      new framework::ProgramDesc(*argument_.transformed_program_desc));
  PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
  // Update scope.
  // NOTE(review): Release<> presumably transfers ownership of the scope out of
  // argument_ — confirm against Argument's definition.
  scope_.reset(
      argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
  LOG(INFO) << "optimize end ==";
}
template
<
>
template
<
>
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
std
::
unique_ptr
<
PaddlePredictor
>
CreatePaddlePredictor
<
...
...
paddle/fluid/inference/api/analysis_predictor.h
0 → 100644
浏览文件 @
b3cd2ae8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace
paddle
{
using
inference
::
analysis
::
Argument
;
using
inference
::
analysis
::
Analyzer
;
using
framework
::
proto
::
ProgramDesc
;
/* This predictor is based on the original native predictor with IR and Analysis
 * support. It will optimize the IR and parameters at runtime.
 * TODO(Superjomn) Replace the Native predictor?
 */
// Predictor that extends NativePaddlePredictor with an analysis step
// (OptimizeInferenceProgram) and keeps the analysis Argument it used.
class AnalysisPredictor : public NativePaddlePredictor {
 public:
  // Forwards `config` to the native predictor and also keeps its own copy in
  // config_.
  explicit AnalysisPredictor(const NativeConfig& config)
      : NativePaddlePredictor(config), config_(config) {}

  // Declared here; the definition is not part of this header.
  bool Init(const std::shared_ptr<framework::Scope>& parent_scope);

  // Delegates unchanged to NativePaddlePredictor::Run. `batch_size` defaults
  // to -1.
  bool Run(const std::vector<PaddleTensor>& inputs,
           std::vector<PaddleTensor>* output_data,
           int batch_size = -1) override {
    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
  }

  // Declared here; the definition is not part of this header.
  void OptimizeInferenceProgram();

  // Returns a mutable reference to the stored analysis argument.
  Argument& analysis_argument() { return argument_; }

 private:
  // Copy of the configuration passed to the constructor.
  NativeConfig config_;
  // Analysis argument exposed through analysis_argument().
  Argument argument_;
};
}
// namespace paddle
paddle/fluid/platform/device_context.h
浏览文件 @
b3cd2ae8
...
@@ -24,7 +24,7 @@ limitations under the License. */
...
@@ -24,7 +24,7 @@ limitations under the License. */
#endif
#endif
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
#include
<mkldnn.hpp>
#include
"mkldnn.hpp"
#endif
#endif
#include <map>
#include <map>
...
...
paddle/fluid/pybind/const_value.cc
浏览文件 @
b3cd2ae8
...
@@ -43,6 +43,9 @@ void BindConstValue(pybind11::module* m) {
...
@@ -43,6 +43,9 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker
.
def
(
op_proto_and_checker_maker
.
def
(
"kOpRoleVarAttrName"
,
"kOpRoleVarAttrName"
,
framework
::
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
);
framework
::
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
);
op_proto_and_checker_maker
.
def
(
"kOpNameScopeAttrName"
,
framework
::
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
);
}
}
}
// namespace pybind
}
// namespace pybind
...
...
python/paddle/fluid/framework.py
浏览文件 @
b3cd2ae8
...
@@ -43,6 +43,7 @@ __all__ = [
...
@@ -43,6 +43,7 @@ __all__ = [
'default_main_program'
,
'default_main_program'
,
'program_guard'
,
'program_guard'
,
'get_var'
,
'get_var'
,
'name_scope'
,
]
]
EMPTY_VAR_NAME
=
core
.
kEmptyVarName
()
EMPTY_VAR_NAME
=
core
.
kEmptyVarName
()
...
@@ -52,6 +53,70 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
...
@@ -52,6 +53,70 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX
=
core
.
kControlDepVarName
()
CONTROL_DEP_VAR_PREFIX
=
core
.
kControlDepVarName
()
class NameScope(object):
    """One node in the tree of hierarchical name scopes.

    Each node stores its own name, a link to its parent, and the children
    created so far, grouped by the prefix they were requested with.
    """

    def __init__(self, name="", parent=None):
        self._parent = parent
        self._name = name
        self._children = {}

    def child(self, prefix):
        """Create and return a new child scope for `prefix`.

        The first child requested with a given prefix is named exactly
        `prefix`; every later one is named `prefix_<k>`, where `k` is the
        number of children already created with that prefix.
        """
        siblings = self._children.setdefault(prefix, [])
        if siblings:
            scope_name = prefix + "_%d" % len(siblings)
        else:
            scope_name = prefix
        created = NameScope(scope_name, self)
        siblings.append(created)
        return created

    def parent(self):
        """Return the enclosing scope, or None for a root scope."""
        return self._parent

    def name(self):
        """Return this scope's own (unqualified) name."""
        return self._name


# Module-wide current scope; starts at an unnamed root.
_name_scope = NameScope()
@contextlib.contextmanager
def name_scope(prefix=None):
    """
    Generate hierarchical name prefix for the operators.

    Note: This should only be used for debugging and visualization purpose.
    Don't use it for serious analysis such as graph/program transformations.

    The module-level current scope is switched to a new child scope for the
    duration of the ``with`` body and is restored afterwards, including when
    the body raises.

    Args:
        prefix(str): prefix. Must be non-empty; the default of None is
            rejected by the assertion below.

    Raises:
        AssertionError: if `prefix` is empty or None.

    Examples:
        .. code-block:: python

          with name_scope("encoder"):
             ...
          with name_scope("decoder"):
             ...
             with name_scope("attention"):
                ...
    """
    # TODO(panyx0718): Only [0-9a-z].
    assert prefix, "namescope prefix cannot be empty."
    global _name_scope
    _name_scope = _name_scope.child(prefix)
    try:
        yield
    finally:
        # Without this, an exception in the body would leave the global scope
        # stuck on the child and mislabel every operator created afterwards.
        _name_scope = _name_scope.parent()
def _full_name_scope():
    """Return the path of the current name scope, outermost scope first.

    Every scope on the way from the root to the current one contributes its
    name followed by "/", so an unnamed root yields a leading "/".
    """
    segments = []
    node = _name_scope
    # Walk up from the current scope; the path is assembled root-first below.
    while node:
        segments.append(node.name() + "/")
        node = node.parent()
    segments.reverse()
    return "".join(segments)
def
generate_control_dev_var_name
():
def
generate_control_dev_var_name
():
import
random
import
random
return
CONTROL_DEP_VAR_PREFIX
+
"@"
+
str
(
random
.
random
())
return
CONTROL_DEP_VAR_PREFIX
+
"@"
+
str
(
random
.
random
())
...
@@ -515,6 +580,9 @@ class Operator(object):
...
@@ -515,6 +580,9 @@ class Operator(object):
self
.
desc
.
set_type
(
type
)
self
.
desc
.
set_type
(
type
)
proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
type
)
proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
type
)
namescope_var_name
=
op_maker
.
kOpNameScopeAttrName
()
op_attrs
[
namescope_var_name
]
=
_full_name_scope
()
def
find_name
(
var_list
,
name
):
def
find_name
(
var_list
,
name
):
for
var_name
in
var_list
:
for
var_name
in
var_list
:
if
var_list
[
var_name
]
is
not
None
and
var_name
==
name
:
if
var_list
[
var_name
]
is
not
None
and
var_name
==
name
:
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
b3cd2ae8
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
re
import
re
from
collections
import
defaultdict
from
collections
import
defaultdict
from
paddle.fluid.framework
import
Program
,
Variable
from
paddle.fluid.framework
import
Program
,
Variable
,
name_scope
from
.
import
framework
from
.
import
framework
from
.
import
layers
from
.
import
layers
from
.backward
import
append_backward
from
.backward
import
append_backward
...
@@ -237,7 +237,7 @@ class Optimizer(object):
...
@@ -237,7 +237,7 @@ class Optimizer(object):
if
param_and_grad
[
1
]
is
None
:
if
param_and_grad
[
1
]
is
None
:
continue
continue
with
param_and_grad
[
0
].
block
.
program
.
optimized_guard
(
with
param_and_grad
[
0
].
block
.
program
.
optimized_guard
(
param_and_grad
):
param_and_grad
)
,
name_scope
(
"optimizer"
)
:
if
param_and_grad
[
0
].
trainable
is
True
:
if
param_and_grad
[
0
].
trainable
is
True
:
optimize_op
=
self
.
_append_optimize_op
(
loss
.
block
,
optimize_op
=
self
.
_append_optimize_op
(
loss
.
block
,
param_and_grad
)
param_and_grad
)
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
b3cd2ae8
...
@@ -82,8 +82,18 @@ class TestDistRunnerBase(object):
...
@@ -82,8 +82,18 @@ class TestDistRunnerBase(object):
strategy
=
fluid
.
ExecutionStrategy
()
strategy
=
fluid
.
ExecutionStrategy
()
strategy
.
num_threads
=
1
strategy
.
num_threads
=
1
strategy
.
allow_op_delay
=
False
strategy
.
allow_op_delay
=
False
build_stra
=
fluid
.
BuildStrategy
()
if
args
.
use_reduce
:
build_stra
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
else
:
build_stra
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
exe
=
fluid
.
ParallelExecutor
(
exe
=
fluid
.
ParallelExecutor
(
True
,
loss_name
=
avg_cost
.
name
,
exec_strategy
=
strategy
)
True
,
loss_name
=
avg_cost
.
name
,
exec_strategy
=
strategy
,
build_strategy
=
build_stra
)
feed_var_list
=
[
feed_var_list
=
[
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
...
@@ -123,6 +133,7 @@ def runtime_main(test_class):
...
@@ -123,6 +133,7 @@ def runtime_main(test_class):
'--current_endpoint'
,
type
=
str
,
required
=
False
,
default
=
""
)
'--current_endpoint'
,
type
=
str
,
required
=
False
,
default
=
""
)
parser
.
add_argument
(
'--sync_mode'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--sync_mode'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--mem_opt'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--mem_opt'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--use_reduce'
,
action
=
'store_true'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase):
...
@@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase):
self
.
_python_interp
=
"python"
self
.
_python_interp
=
"python"
self
.
_sync_mode
=
True
self
.
_sync_mode
=
True
self
.
_mem_opt
=
False
self
.
_mem_opt
=
False
self
.
_use_reduce
=
False
self
.
_setup_config
()
self
.
_setup_config
()
def
start_pserver
(
self
,
model_file
,
check_error_log
):
def
start_pserver
(
self
,
model_file
,
check_error_log
):
ps0_ep
,
ps1_ep
=
self
.
_ps_endpoints
.
split
(
","
)
ps0_ep
,
ps1_ep
=
self
.
_ps_endpoints
.
split
(
","
)
ps_cmd
=
"%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s"
ps_cmd
=
"%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
sync_mode_str
=
"--sync_mode"
if
self
.
_sync_mode
else
""
mem_opt_str
=
"--mem_opt"
if
self
.
_mem_opt
else
""
ps0_cmd
=
ps_cmd
%
\
ps0_cmd
=
ps_cmd
%
\
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
ps0_ep
,
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
ps0_ep
,
self
.
_trainers
,
sync_mode_str
,
mem_opt_str
)
self
.
_trainers
)
ps1_cmd
=
ps_cmd
%
\
ps1_cmd
=
ps_cmd
%
\
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
ps1_ep
,
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
ps1_ep
,
self
.
_trainers
,
sync_mode_str
,
mem_opt_str
)
self
.
_trainers
)
if
self
.
_sync_mode
:
ps0_cmd
+=
" --sync_mode"
ps1_cmd
+=
" --sync_mode"
if
self
.
_mem_opt
:
ps0_cmd
+=
" --mem_opt"
ps1_cmd
+=
" --mem_opt"
ps0_pipe
=
subprocess
.
PIPE
ps0_pipe
=
subprocess
.
PIPE
ps1_pipe
=
subprocess
.
PIPE
ps1_pipe
=
subprocess
.
PIPE
...
@@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase):
...
@@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase):
self
.
_wait_ps_ready
(
ps1
.
pid
)
self
.
_wait_ps_ready
(
ps1
.
pid
)
ps0_ep
,
ps1_ep
=
self
.
_ps_endpoints
.
split
(
","
)
ps0_ep
,
ps1_ep
=
self
.
_ps_endpoints
.
split
(
","
)
tr_cmd
=
"%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s"
tr_cmd
=
"%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
sync_mode_str
=
"--sync_mode"
if
self
.
_sync_mode
else
""
mem_opt_str
=
"--mem_opt"
if
self
.
_mem_opt
else
""
tr0_cmd
=
tr_cmd
%
\
tr0_cmd
=
tr_cmd
%
\
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
0
,
ps0_ep
,
0
,
ps0_ep
,
self
.
_trainers
)
self
.
_trainers
,
sync_mode_str
,
mem_opt_str
)
tr1_cmd
=
tr_cmd
%
\
tr1_cmd
=
tr_cmd
%
\
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
(
self
.
_python_interp
,
model_file
,
self
.
_ps_endpoints
,
1
,
ps1_ep
,
1
,
ps1_ep
,
self
.
_trainers
)
self
.
_trainers
,
sync_mode_str
,
mem_opt_str
)
if
self
.
_sync_mode
:
tr0_cmd
+=
" --sync_mode"
tr1_cmd
+=
" --sync_mode"
if
self
.
_mem_opt
:
tr0_cmd
+=
" --mem_opt"
tr1_cmd
+=
" --mem_opt"
if
self
.
_use_reduce
:
tr0_cmd
+=
" --use_reduce"
tr1_cmd
+=
" --use_reduce"
env0
=
{
"CUDA_VISIBLE_DEVICES"
:
"0"
}
env0
=
{
"CUDA_VISIBLE_DEVICES"
:
"0"
}
env1
=
{
"CUDA_VISIBLE_DEVICES"
:
"1"
}
env1
=
{
"CUDA_VISIBLE_DEVICES"
:
"1"
}
...
@@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase):
...
@@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase):
# FIXME: use terminate() instead of sigkill.
# FIXME: use terminate() instead of sigkill.
os
.
kill
(
ps0
.
pid
,
signal
.
SIGKILL
)
os
.
kill
(
ps0
.
pid
,
signal
.
SIGKILL
)
os
.
kill
(
ps1
.
pid
,
signal
.
SIGKILL
)
os
.
kill
(
ps1
.
pid
,
signal
.
SIGKILL
)
ps0
.
terminate
()
ps1
.
terminate
()
ps0
.
wait
()
ps0
.
wait
()
ps1
.
wait
()
ps1
.
wait
()
FNULL
.
close
()
FNULL
.
close
()
...
...
python/paddle/fluid/tests/unittests/test_dist_mnist.py
浏览文件 @
b3cd2ae8
...
@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase
...
@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase
class
TestDistMnist2x2
(
TestDistBase
):
class
TestDistMnist2x2
(
TestDistBase
):
def
_setup_config
(
self
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_sync_mode
=
True
self
.
_use_reduce
=
False
def
test_se_resnext
(
self
):
def
test_se_resnext
(
self
):
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
1e-7
)
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
1e-7
)
...
@@ -37,10 +38,30 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
...
@@ -37,10 +38,30 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
class
TestDistMnistAsync
(
TestDistBase
):
class
TestDistMnistAsync
(
TestDistBase
):
def
_setup_config
(
self
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_sync_mode
=
False
self
.
_use_reduce
=
False
def
test_se_resnext
(
self
):
def
test_se_resnext
(
self
):
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
200
)
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
200
)
# FIXME(typhoonzero): enable these tests once we have
# 4 GPUs on the CI machine; the base class should be updated as well.
#
# class TestDistMnist2x2ReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = True
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=1e-7)
# class TestDistMnistAsyncReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = False
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=200)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_name_scope.py
0 → 100644
浏览文件 @
b3cd2ae8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
class TestNameScope(unittest.TestCase):
    """Checks the `op_namescope` attribute recorded on operators."""

    def test_name_scope(self):
        """Build ops under nested and repeated scopes and verify their paths."""
        with fluid.name_scope("s1"):
            data = fluid.layers.data(name='data', shape=[1], dtype='int32')
            added = data + 1
            with fluid.name_scope("s2"):
                multiplied = added * 1
            with fluid.name_scope("s3"):
                divided = multiplied / 1
        # Re-entering "s1" at the same level is expected to yield "s1_1".
        with fluid.name_scope("s1"):
            powered = fluid.layers.pow(divided, 2.0)
        with fluid.name_scope("s4"):
            subtracted = powered - 1

        # Expected scope path for each operator type of interest.
        expected_scopes = {
            'elementwise_add': '/s1/',
            'elementwise_mul': '/s1/s2/',
            'elementwise_div': '/s1/s3/',
            'elementwise_sub': '/s4/',
            'pow': '/s1_1/',
        }
        for op in fluid.default_main_program().block(0).ops:
            expected = expected_scopes.get(op.type)
            if expected is not None:
                self.assertEqual(op.desc.attr("op_namescope"), expected)
python/paddle/fluid/tests/unittests/test_operator_desc.py
浏览文件 @
b3cd2ae8
...
@@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase):
...
@@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase):
self
.
assertEqual
(
mul_op
.
output
(
"Out"
),
[
"mul.out"
])
self
.
assertEqual
(
mul_op
.
output
(
"Out"
),
[
"mul.out"
])
self
.
assertEqual
(
self
.
assertEqual
(
set
(
mul_op
.
attr_names
),
set
(
mul_op
.
attr_names
),
set
([
"x_num_col_dims"
,
"y_num_col_dims"
,
"op_role"
,
"op_role_var"
]))
set
([
"x_num_col_dims"
,
"y_num_col_dims"
,
"op_role"
,
"op_role_var"
,
"op_namescope"
]))
self
.
assertEqual
(
mul_op
.
has_attr
(
"x_num_col_dims"
),
True
)
self
.
assertEqual
(
mul_op
.
has_attr
(
"x_num_col_dims"
),
True
)
self
.
assertEqual
(
mul_op
.
attr_type
(
"x_num_col_dims"
),
core
.
AttrType
.
INT
)
self
.
assertEqual
(
mul_op
.
attr_type
(
"x_num_col_dims"
),
core
.
AttrType
.
INT
)
self
.
assertEqual
(
mul_op
.
attr
(
"x_num_col_dims"
),
1
)
self
.
assertEqual
(
mul_op
.
attr
(
"x_num_col_dims"
),
1
)
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
浏览文件 @
b3cd2ae8
...
@@ -67,18 +67,20 @@ def fc_with_batchnorm(use_feed):
...
@@ -67,18 +67,20 @@ def fc_with_batchnorm(use_feed):
hidden
=
img
hidden
=
img
for
_
in
range
(
1
):
for
_
in
range
(
1
):
hidden
=
fluid
.
layers
.
fc
(
with
fluid
.
name_scope
(
"hidden"
):
hidden
,
hidden
=
fluid
.
layers
.
fc
(
size
=
200
,
hidden
,
act
=
'tanh'
,
size
=
200
,
bias_attr
=
fluid
.
ParamAttr
(
act
=
'tanh'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
)))
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
)))
hidden
=
fluid
.
layers
.
batch_norm
(
input
=
hidden
)
hidden
=
fluid
.
layers
.
batch_norm
(
input
=
hidden
)
prediction
=
fluid
.
layers
.
fc
(
hidden
,
size
=
10
,
act
=
'softmax'
)
with
fluid
.
name_scope
(
"fc_layer"
):
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
prediction
=
fluid
.
layers
.
fc
(
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
mean
(
loss
)
with
fluid
.
name_scope
(
"loss"
):
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
return
loss
return
loss
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b3cd2ae8
...
@@ -273,6 +273,10 @@ class DistributeTranspiler(object):
...
@@ -273,6 +273,10 @@ class DistributeTranspiler(object):
name
=
framework
.
generate_control_dev_var_name
())
name
=
framework
.
generate_control_dev_var_name
())
grad_name_to_send_dummy_out
[
grad_varname
]
=
dummy_output
grad_name_to_send_dummy_out
[
grad_varname
]
=
dummy_output
# Get the send op's op_role_var: if the grad is not split, it should have the
# .trainer suffix; if it is split, it should be the original grad var name
# (split_by_ref and send will be on the same place). ParallelExecutor
# will use op_role_var to get the expected device place to run this op.
program
.
global_block
().
_insert_op
(
program
.
global_block
().
_insert_op
(
index
=
index
+
1
,
index
=
index
+
1
,
type
=
"send"
,
type
=
"send"
,
...
@@ -281,8 +285,10 @@ class DistributeTranspiler(object):
...
@@ -281,8 +285,10 @@ class DistributeTranspiler(object):
attrs
=
{
attrs
=
{
"epmap"
:
eplist
,
"epmap"
:
eplist
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
OP_ROLE_VAR_ATTR_NAME
:
[
[
self
.
grad_name_to_param_name
[
grad_varname
],
grad_varname
],
self
.
grad_name_to_param_name
[
grad_varname
],
splited_grad_varname
],
"sync_mode"
:
not
self
.
sync_mode
,
"sync_mode"
:
not
self
.
sync_mode
,
})
})
for
_
,
var
in
enumerate
(
splited_vars
):
for
_
,
var
in
enumerate
(
splited_vars
):
...
@@ -326,6 +332,15 @@ class DistributeTranspiler(object):
...
@@ -326,6 +332,15 @@ class DistributeTranspiler(object):
recv_dep_in
=
grad_name_to_send_dummy_out
[
recv_dep_in
=
grad_name_to_send_dummy_out
[
self
.
param_name_to_grad_name
[
param_varname
]]
self
.
param_name_to_grad_name
[
param_varname
]]
all_recv_outputs
.
extend
(
splited_var
)
all_recv_outputs
.
extend
(
splited_var
)
# Get the recv op's op_role_var: if the grad is not split, it should have the
# .trainer suffix; if it is split, it should be the original grad var name.
# ParallelExecutor will use op_role_var to get the expected device place to run this op.
orig_grad_name
=
self
.
param_name_to_grad_name
[
param_varname
]
recv_op_role_var_name
=
orig_grad_name
splited_trainer_grad
=
self
.
grad_var_mapping
[
orig_grad_name
]
if
len
(
splited_trainer_grad
)
==
1
:
recv_op_role_var_name
=
splited_trainer_grad
[
0
].
name
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"recv"
,
type
=
"recv"
,
inputs
=
{
"X"
:
[
recv_dep_in
]},
inputs
=
{
"X"
:
[
recv_dep_in
]},
...
@@ -333,10 +348,8 @@ class DistributeTranspiler(object):
...
@@ -333,10 +348,8 @@ class DistributeTranspiler(object):
attrs
=
{
attrs
=
{
"epmap"
:
eps
,
"epmap"
:
eps
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
OP_ROLE_VAR_ATTR_NAME
:
param_varname
,
[
param_varname
,
recv_op_role_var_name
],
self
.
param_name_to_grad_name
[
param_varname
]
],
"sync_mode"
:
not
self
.
sync_mode
"sync_mode"
:
not
self
.
sync_mode
})
})
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录