Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ae67dcea
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae67dcea
编写于
9月 12, 2018
作者:
T
Tao Luo
提交者:
GitHub
9月 12, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13366 from luotao1/fusion_lstm_bug
fix fusion_lstm unique_name bug
上级
f6cbe10a
b12322ce
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
8 addition
and
5 deletion
+8
-5
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+2
-3
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+6
-2
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
ae67dcea
...
@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
...
@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if
(
with_fc_bias
)
{
if
(
with_fc_bias
)
{
// Add FC-bias with LSTM-bias and create a new weight
// Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE
(
scope
);
PADDLE_ENFORCE
(
scope
);
const
std
::
string
&
new_bias_var
=
name_scope
+
"_bias.new"
;
const
std
::
string
&
new_bias_var
=
patterns
::
UniqueKey
(
"NewBias"
)
;
auto
*
bias_var
=
scope
->
Var
(
new_bias_var
);
auto
*
bias_var
=
scope
->
Var
(
new_bias_var
);
PADDLE_ENFORCE
(
bias_var
);
PADDLE_ENFORCE
(
bias_var
);
auto
*
bias_tensor
=
bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
bias_tensor
=
bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
...
@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
...
@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
lstm
,
lstm
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
lstm
,
lstm
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Weight
,
Weight
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Weight
,
Weight
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Bias
,
Bias
,
lstm_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
Bias
,
Bias
,
lstm_pattern
);
...
@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
...
@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
fc_bias
);
fc_bias
);
// Remove unneeded nodes.
// Remove unneeded nodes.
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
mul
,
lstm
,
elementwise_add
});
{
mul
,
lstm
,
elementwise_add
,
fc_bias
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
}
else
{
}
else
{
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
mul_out
,
fc_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
mul_out
,
fc_pattern
);
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
ae67dcea
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
...
@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
void
IRPassManager
::
Apply
(
const
std
::
vector
<
std
::
string
>
&
passes
)
{
// Apply all the passes
// Apply all the passes
std
::
string
pre_pass
;
std
::
string
pre_pass
;
int
pass_num
=
0
;
for
(
const
std
::
string
&
pass_name
:
passes
)
{
for
(
const
std
::
string
&
pass_name
:
passes
)
{
PrettyLogEndl
(
Style
::
H2
(),
"--- Running IR pass [%s]"
,
pass_name
);
PrettyLogEndl
(
Style
::
H2
(),
"--- Running IR pass [%s]"
,
pass_name
);
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_name
);
if
(
pass_name
==
"graph_viz_pass"
)
{
if
(
pass_name
==
"graph_viz_pass"
)
{
std
::
string
dot_file_path
=
std
::
string
dot_file_path
=
std
::
to_string
(
pass_num
)
+
"_ir_"
+
"ir_"
+
(
pre_pass
.
empty
()
?
"origin"
:
pre_pass
)
+
".dot"
;
(
pre_pass
.
empty
()
?
"origin"
:
pre_pass
)
+
".dot"
;
pass
->
Set
(
"graph_viz_path"
,
new
std
::
string
(
std
::
move
(
dot_file_path
)));
pass
->
Set
(
"graph_viz_path"
,
new
std
::
string
(
std
::
move
(
dot_file_path
)));
pass_num
++
;
}
}
graph_
=
pass
->
Apply
(
std
::
move
(
graph_
));
graph_
=
pass
->
Apply
(
std
::
move
(
graph_
));
pre_pass
=
pass_name
;
pre_pass
=
pass_name
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录