Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d55f3b6f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d55f3b6f
编写于
6月 23, 2021
作者:
王
王明冬
提交者:
GitHub
6月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add compat precondition for attention_lstm_fuse_pass, test=develop (#33711)
上级
10171806
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
78 addition
and
3 deletion
+78
-3
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+60
-1
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
+6
-0
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
+1
-1
paddle/fluid/framework/ir/op_compat_sensible_pass.h
paddle/fluid/framework/ir/op_compat_sensible_pass.h
+11
-0
paddle/fluid/operators/compat/fill_constant.pbtxt
paddle/fluid/operators/compat/fill_constant.pbtxt
+0
-1
未找到文件。
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
浏览文件 @
d55f3b6f
...
...
@@ -23,6 +23,61 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
AttentionLSTMFusePass
::
AttentionLSTMFusePass
()
{
AddOpCompat
(
OpCompat
(
"while"
))
.
AddInput
(
"X"
)
// A set of variables, unconstrained
.
End
()
.
AddInput
(
"Condition"
)
// An scalar
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// A set of variables, unconstrained
.
End
()
.
AddOutput
(
"StepScopes"
)
// A vector of local scope, unconstrained
.
End
()
.
AddAttr
(
"sub_block"
)
.
IsType
<
framework
::
BlockDesc
*>
()
.
End
();
AddOpCompat
(
OpCompat
(
"fill_constant"
))
.
AddInput
(
"ValueTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensorList"
)
// vector<Tensor<int>>
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"dtype"
)
.
IsNumGE
(
0
)
.
IsNumLE
(
25
)
.
End
()
.
AddAttr
(
"shape"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"value"
)
.
IsType
<
float
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"sequence_expand"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"ref_level"
)
.
IsNumGE
(
-
1
)
.
End
();
}
struct
Param
{
std
::
string
X
=
"concat_0.tmp_0"
;
std
::
string
C0
=
"cell_init"
;
...
...
@@ -43,7 +98,7 @@ struct Param {
void
PrepareParameters
(
Graph
*
graph
,
const
Param
&
param
,
ir
::
Node
*
lstm_op
);
void
FindWhileOp
(
Graph
*
graph
)
{
void
AttentionLSTMFusePass
::
FindWhileOp
(
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
std
::
unordered_set
<
int
>
fused_external_ops
(
{
35
,
36
,
37
,
38
,
43
,
44
,
49
,
45
,
46
,
47
,
41
,
42
,
53
,
54
,
48
,
...
...
@@ -60,6 +115,10 @@ void FindWhileOp(Graph* graph) {
auto
handle
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"Pass in op compat failed."
;
return
;
}
auto
*
while_pat_node
=
gpd
.
pattern
().
RetrieveNode
(
"while"
);
auto
*
while_node
=
subgraph
.
at
(
while_pat_node
);
marked_nodes
.
insert
(
while_node
);
...
...
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
浏览文件 @
d55f3b6f
...
...
@@ -23,8 +23,14 @@ namespace ir {
class
Graph
;
class
AttentionLSTMFusePass
:
public
FusePassBase
{
public:
AttentionLSTMFusePass
();
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
FindWhileOp
(
Graph
*
graph
)
const
;
};
}
// namespace ir
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass.cc
浏览文件 @
d55f3b6f
...
...
@@ -260,7 +260,7 @@ bool OpCompatSensiblePass::IsCompat(
auto
op_type
=
node_pair
.
second
->
Op
()
->
Type
();
if
(
!
op_compat_judgers_
.
count
(
op_type
))
{
if
(
HasOpDef
(
op_type
))
{
LOG
(
WARNING
)
<<
op_type
<<
"compat not registered!"
;
LOG
(
WARNING
)
<<
op_type
<<
"
compat not registered!"
;
return
false
;
}
continue
;
...
...
paddle/fluid/framework/ir/op_compat_sensible_pass.h
浏览文件 @
d55f3b6f
...
...
@@ -31,6 +31,10 @@ class AttrCompat {
AttrCompat
(
const
std
::
string
&
attr_name
,
OpCompat
*
op_compat
)
:
optional_
(
false
),
attr_name_
(
attr_name
),
op_compat_
(
op_compat
)
{}
//! Assert the attribute type is `T`.
template
<
typename
T
>
AttrCompat
&
IsType
();
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat
&
IsStringIn
(
const
std
::
set
<
std
::
string
>&
candidates
);
...
...
@@ -207,6 +211,13 @@ class OpCompatSensiblePass : public Pass {
std
::
map
<
std
::
string
,
std
::
unique_ptr
<
OpCompat
>>
op_compat_judgers_
;
};
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsType
()
{
conditions_
.
emplace_back
(
[](
const
Attribute
&
attr
)
->
bool
{
return
attr
.
type
()
==
typeid
(
T
);
});
return
*
this
;
}
template
<
typename
T
>
AttrCompat
&
AttrCompat
::
IsNumGT
(
T
v
)
{
conditions_
.
emplace_back
([
v
](
const
Attribute
&
attr
)
->
bool
{
...
...
paddle/fluid/operators/compat/fill_constant.pbtxt
浏览文件 @
d55f3b6f
...
...
@@ -24,7 +24,6 @@ def {
name: "value"
type: FLOAT
}
}
extra {
attrs {
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录