Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
71e350c5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
71e350c5
编写于
7月 27, 2020
作者:
A
Adam
提交者:
GitHub
7月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix FC + GRU fuse pass (#25733)
上级
2d7e7759
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
38 addition
and
31 deletion
+38
-31
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
+34
-29
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+0
-2
paddle/fluid/operators/fused/fusion_gru_op.cc
paddle/fluid/operators/fused/fusion_gru_op.cc
+4
-0
未找到文件。
paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
浏览文件 @
71e350c5
...
...
@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
FC
fc_pattern
(
pattern
,
name_scope
);
patterns
::
GRU
gru_pattern
(
pattern
,
name_scope
);
PDNode
*
x
=
pattern
->
NewNode
(
patterns
::
UniqueKey
(
"x"
))
->
assert_var_not_persistable
();
// Create pattern.
patterns
::
FC
fc_pattern
(
pattern
,
name_scope
);
auto
*
fc_out
=
fc_pattern
(
x
,
with_fc_bias
,
/* with_relu */
false
);
fc_out
->
AsIntermediate
();
// fc_out is a tmp var, will be removed after fuse.
patterns
::
GRU
gru_pattern
(
pattern
,
name_scope
);
gru_pattern
(
fc_out
);
// Create New OpDesc
...
...
@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
SET_IN
(
X
,
x
);
SET_IN
(
WeightX
,
weight_x
);
SET_IN
(
WeightH
,
weight_h
);
if
(
with_fc_bias
)
{
op_desc
.
SetInput
(
"Bias"
,
{
NEW_NAME
(
bias
)
+
bias
->
Name
()});
}
else
{
SET_IN
(
Bias
,
bias
);
}
SET_IN
(
Bias
,
bias
);
#undef SET_IN
// TODO(grygielski): Add H0 to the pass
op_desc
.
SetInput
(
"H0"
,
{});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden
->
Name
()});
op_desc
.
SetAttr
(
"is_reverse"
,
gru
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"origin_mode"
,
gru
->
Op
()
->
GetAttrIfExists
<
bool
>
(
"origin_mode"
));
// TODO(TJ): This should be a option for infer
op_desc
.
SetAttr
(
"use_seq"
,
true
);
op_desc
.
SetAttr
(
"activation"
,
gru
->
Op
()
->
GetAttr
(
"activation"
));
op_desc
.
SetAttr
(
"gate_activation"
,
gru
->
Op
()
->
GetAttr
(
"gate_activation"
));
#define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)})
SET_IMTERMEDIATE_OUT
(
ReorderedH0
);
...
...
@@ -68,26 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef SET_IMTERMEDIATE_OUT
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
PADDLE_ENFORCE
(
graph
->
Has
(
kParamScopeAttr
));
auto
&
scope
=
graph
->
Get
<
Scope
>
(
kParamScopeAttr
);
if
(
with_fc_bias
)
{
// Fusion GRU bias = fcbias + grubias
auto
*
fusion_bias_var
=
scope
.
Var
(
NEW_NAME
(
bias
)
+
bias
->
Name
());
auto
*
out_bias_tensor
=
fusion_bias_var
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE
(
fusion_bias_var
);
auto
*
gru_bias_var
=
scope
.
FindVar
(
bias
->
Name
());
auto
*
fc_bias_var
=
scope
.
FindVar
(
fc_bias
->
Name
());
PADDLE_ENFORCE
(
gru_bias_var
);
PADDLE_ENFORCE
(
fc_bias_var
);
const
auto
&
gru_bias_tenosr
=
gru_bias_var
->
Get
<
framework
::
LoDTensor
>
();
const
auto
&
fc_bias_tensor
=
fc_bias_var
->
Get
<
framework
::
LoDTensor
>
();
// new bias = fc bias + gru bias
out_bias_tensor
->
Resize
(
gru_bias_tenosr
.
dims
());
auto
*
data
=
out_bias_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
out_bias_tensor
->
numel
();
i
++
)
{
data
[
i
]
=
fc_bias_tensor
.
data
<
float
>
()[
i
]
+
gru_bias_tenosr
.
data
<
float
>
()[
i
];
auto
*
gru_bias_var
=
scope
->
FindVar
(
bias
->
Name
());
auto
*
fc_bias_var
=
scope
->
FindVar
(
fc_bias
->
Name
());
PADDLE_ENFORCE_NE
(
gru_bias_var
,
nullptr
,
platform
::
errors
::
NotFound
(
"GRU bias var has not been found."
));
PADDLE_ENFORCE_NE
(
fc_bias_var
,
nullptr
,
platform
::
errors
::
NotFound
(
"FC bias var has not been found."
));
auto
*
gru_bias_tensor
=
gru_bias_var
->
GetMutable
<
LoDTensor
>
();
auto
*
fc_bias_tensor
=
fc_bias_var
->
GetMutable
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
gru_bias_tensor
->
numel
(),
fc_bias_tensor
->
numel
(),
platform
::
errors
::
PreconditionNotMet
(
"GRU and FC biases have to have equal number of elements."
));
auto
gru_bias_data
=
gru_bias_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
fc_bias_data
=
fc_bias_tensor
->
data
<
float
>
();
// Recompute GRU bias
for
(
int
i
=
0
;
i
<
gru_bias_tensor
->
numel
();
++
i
)
{
gru_bias_data
[
i
]
+=
fc_bias_data
[
i
];
}
}
#undef GET_NODE
...
...
@@ -108,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
IR_NODE_LINK_TO
(
x
,
op
);
IR_NODE_LINK_TO
(
weight_x
,
op
);
IR_NODE_LINK_TO
(
weight_h
,
op
);
IR_NODE_LINK_TO
(
bias
,
op
);
// actually should link to new bias if have
IR_NODE_LINK_TO
(
bias
,
op
);
IR_NODE_LINK_TO
(
op
,
hidden
);
// h0?
return
op
;
...
...
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
浏览文件 @
71e350c5
...
...
@@ -56,8 +56,6 @@ void SetConfig(AnalysisConfig *cfg) {
cfg
->
DisableGpu
();
cfg
->
SwitchIrDebug
();
cfg
->
SwitchSpecifyInputNames
(
false
);
// TODO(TJ): fix fusion gru
cfg
->
pass_builder
()
->
DeletePass
(
"fc_gru_fuse_pass"
);
}
void
SetInput
(
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
*
inputs
)
{
...
...
paddle/fluid/operators/fused/fusion_gru_op.cc
浏览文件 @
71e350c5
...
...
@@ -183,6 +183,10 @@ void FusionGRUOpMaker::Make() {
"(bool, default: True) "
"whether to use seq mode to compute GRU."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"origin_mode"
,
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录