Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7b5a8e46
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b5a8e46
编写于
11月 25, 2020
作者:
W
Wojciech Uss
提交者:
GitHub
11月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add multi_gru_fuse_pass and tests (#28601)
* Add multi_gru_fuse_pass and tests * fix date * cleaned up headers
上级
bb16c251
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
403 addition
and
5 deletion
+403
-5
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+51
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+23
-0
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc
+123
-0
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h
+42
-0
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc
...e/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc
+156
-0
paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
+5
-5
tools/static_mode_white_list.py
tools/static_mode_white_list.py
+1
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
7b5a8e46
...
...
@@ -111,6 +111,7 @@ if(WITH_MKLDNN)
pass_library
(
reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
matmul_transpose_reshape_fuse_pass inference DIR mkldnn
)
pass_library
(
batch_norm_act_fuse_pass inference DIR mkldnn
)
pass_library
(
multi_gru_fuse_pass inference DIR mkldnn
)
pass_library
(
multi_gru_seq_fuse_pass inference DIR mkldnn
)
endif
()
...
...
@@ -170,5 +171,6 @@ endif()
cc_test
(
test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass
)
cc_test
(
test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass
)
cc_test
(
test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass
)
cc_test
(
test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass
)
cc_test
(
test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass
)
endif
()
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
7b5a8e46
...
...
@@ -2511,6 +2511,57 @@ PDNode *patterns::FusionGru::operator()() {
return
out
;
}
PDNode
*
patterns
::
TwoFusionGruConcat
::
operator
()()
{
auto
x
=
pattern
->
NewNode
(
x_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"X"
);
auto
gru1
=
pattern
->
NewNode
(
gru1_repr
())
->
assert_is_op
(
"fusion_gru"
)
->
assert_more
([
&
](
Node
*
node
)
{
return
node
->
Op
()
->
GetAttrIfExists
<
bool
>
(
"is_reverse"
)
==
false
;
});
auto
gru2
=
pattern
->
NewNode
(
gru2_repr
())
->
assert_is_op
(
"fusion_gru"
)
->
assert_more
([
&
](
Node
*
node
)
{
return
node
->
Op
()
->
GetAttrIfExists
<
bool
>
(
"is_reverse"
)
==
true
;
});
auto
wh1
=
pattern
->
NewNode
(
wh1_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"WeightH"
);
auto
wh2
=
pattern
->
NewNode
(
wh2_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"WeightH"
);
auto
wx1
=
pattern
->
NewNode
(
wx1_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"WeightX"
);
auto
wx2
=
pattern
->
NewNode
(
wx2_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"WeightX"
);
auto
b1
=
pattern
->
NewNode
(
b1_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"Bias"
);
auto
b2
=
pattern
->
NewNode
(
b2_repr
())
->
AsInput
()
->
assert_is_op_input
(
"fusion_gru"
,
"Bias"
);
auto
h1
=
pattern
->
NewNode
(
h1_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"fusion_gru"
,
"Hidden"
)
->
assert_is_op_input
(
"concat"
)
->
AsIntermediate
();
auto
h2
=
pattern
->
NewNode
(
h2_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"fusion_gru"
,
"Hidden"
)
->
assert_is_op_input
(
"concat"
)
->
AsIntermediate
();
auto
concat
=
pattern
->
NewNode
(
concat_repr
())
->
assert_is_op
(
"concat"
);
auto
out
=
pattern
->
NewNode
(
out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"concat"
,
"Out"
);
gru1
->
LinksFrom
({
x
,
wh1
,
wx1
,
b1
}).
LinksTo
({
h1
});
gru2
->
LinksFrom
({
x
,
wh2
,
wx2
,
b2
}).
LinksTo
({
h2
});
concat
->
LinksFrom
({
h1
,
h2
}).
LinksTo
({
out
});
return
out
;
}
PDNode
*
patterns
::
MultiGruSeq
::
operator
()()
{
auto
x
=
pattern
->
NewNode
(
x_repr
())
->
AsInput
()
->
assert_is_op_input
(
"multi_gru"
,
"X"
);
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
7b5a8e46
...
...
@@ -1420,6 +1420,29 @@ struct FusionGru : public PatternBase {
PATTERN_DECL_NODE
(
out
);
};
// two concatenated fusion_gru ops
// Forward pass for fusion of two concatenated fusion_gru ops.
// concat_out is a result of the operator().
struct
TwoFusionGruConcat
:
public
PatternBase
{
TwoFusionGruConcat
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"bi_fusion_gru"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
gru1
);
PATTERN_DECL_NODE
(
gru2
);
PATTERN_DECL_NODE
(
wh1
);
PATTERN_DECL_NODE
(
wh2
);
PATTERN_DECL_NODE
(
wx1
);
PATTERN_DECL_NODE
(
wx2
);
PATTERN_DECL_NODE
(
b1
);
PATTERN_DECL_NODE
(
b2
);
PATTERN_DECL_NODE
(
h1
);
PATTERN_DECL_NODE
(
h2
);
PATTERN_DECL_NODE
(
concat
);
PATTERN_DECL_NODE
(
out
);
};
// two subsequent bi_fusion_gru ops
// Forward pass for fusion of two subsequent fusion_gru ops.
// Hidden of the last fusion_gru op is a result of the operator().
...
...
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc
0 → 100644
浏览文件 @
7b5a8e46
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
EigenVectorArrayMap
=
Eigen
::
Map
<
Eigen
::
Array
<
double
,
Eigen
::
Dynamic
,
1
>>
;
using
string
::
PrettyLogDetail
;
namespace
{
std
::
vector
<
std
::
string
>
JoinInputs
(
Node
*
op1
,
Node
*
op2
,
std
::
string
input_name
)
{
auto
in1
=
op1
->
Op
()
->
Input
(
input_name
);
auto
&
in2
=
op2
->
Op
()
->
Input
(
input_name
);
in1
.
insert
(
in1
.
end
(),
in2
.
begin
(),
in2
.
end
());
return
in1
;
}
}
// namespace
void
MultiGRUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
3
)
<<
"Fusing two concatenated multi_gru ops."
;
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Pointer to graph argument cannot be NULL."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
PADDLE_ENFORCE_NOT_NULL
(
param_scope
(),
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
GraphPatternDetector
gpd
;
patterns
::
TwoFusionGruConcat
pattern
{
gpd
.
mutable_pattern
(),
name_scope_
};
pattern
();
int
fused_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
x
,
x
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
gru1
,
gru1
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
gru2
,
gru2
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
wh1
,
wh1
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
wh2
,
wh2
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
wx1
,
wx1
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
wx2
,
wx2
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
b1
,
b1
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
b2
,
b2
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
h1
,
h1
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
h2
,
h2
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat
,
concat
,
pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out
,
out
,
pattern
);
if
(
gru1
->
Op
()
->
GetAttrIfExists
<
bool
>
(
"origin_mode"
)
!=
gru2
->
Op
()
->
GetAttrIfExists
<
bool
>
(
"origin_mode"
))
{
LOG
(
INFO
)
<<
"The two fusion_gru ops have different values of the "
"origin_mode attribute. Skipping fuse."
;
return
;
}
auto
wx
=
JoinInputs
(
gru1
,
gru2
,
"WeightX"
);
auto
wh
=
JoinInputs
(
gru1
,
gru2
,
"WeightH"
);
auto
b
=
JoinInputs
(
gru1
,
gru2
,
"Bias"
);
OpDesc
multi_gru_desc
;
multi_gru_desc
.
SetType
(
"multi_gru"
);
multi_gru_desc
.
SetInput
(
"X"
,
std
::
vector
<
std
::
string
>
({
x
->
Name
()}));
multi_gru_desc
.
SetInput
(
"WeightX"
,
wx
);
multi_gru_desc
.
SetInput
(
"WeightH"
,
wh
);
multi_gru_desc
.
SetInput
(
"Bias"
,
b
);
multi_gru_desc
.
SetOutput
(
"Hidden"
,
std
::
vector
<
std
::
string
>
({
out
->
Name
()}));
auto
attrs_to_skip
=
{
"is_reverse"
,
"use_seq"
};
for
(
auto
&
attr
:
gru1
->
Op
()
->
GetAttrMap
())
{
if
(
std
::
find
(
attrs_to_skip
.
begin
(),
attrs_to_skip
.
end
(),
attr
.
first
)
==
attrs_to_skip
.
end
())
multi_gru_desc
.
SetAttr
(
attr
.
first
,
attr
.
second
);
}
multi_gru_desc
.
SetAttr
(
"layers"
,
1
);
auto
multi_gru
=
g
->
CreateOpNode
(
&
multi_gru_desc
);
// OpDesc will be copied.
IR_NODE_LINK_TO
(
x
,
multi_gru
);
IR_NODE_LINK_TO
(
b1
,
multi_gru
);
IR_NODE_LINK_TO
(
b2
,
multi_gru
);
IR_NODE_LINK_TO
(
wh1
,
multi_gru
);
IR_NODE_LINK_TO
(
wh2
,
multi_gru
);
IR_NODE_LINK_TO
(
wx1
,
multi_gru
);
IR_NODE_LINK_TO
(
wx2
,
multi_gru
);
IR_NODE_LINK_TO
(
multi_gru
,
out
);
GraphSafeRemoveNodes
(
graph
,
{
gru1
,
gru2
,
h1
,
h2
,
concat
});
++
fused_count
;
};
gpd
(
graph
,
handler
);
AddStatis
(
fused_count
);
PrettyLogDetail
(
"--- fused %d pairs of concatenated multi_gru ops"
,
fused_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_gru_fuse_pass
,
paddle
::
framework
::
ir
::
MultiGRUFusePass
);
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h
0 → 100644
浏览文件 @
7b5a8e46
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// This pass fuses two concatenated fusion_gru ops into a single multi_gru op.
// It turns
// a -> fusion_gru -> c -> concat -> e
// \> fusion_gru -> d /
// into
// a -> multi_gru -> e
class
MultiGRUFusePass
:
public
FusePassBase
{
public:
virtual
~
MultiGRUFusePass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"multi_gru"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc
0 → 100644
浏览文件 @
7b5a8e46
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
bool
is_reverse
=
false
,
bool
origin_mode
=
false
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
if
(
type
==
"fusion_gru"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"WeightX"
,
{
inputs
[
1
]});
op
->
SetInput
(
"WeightH"
,
{
inputs
[
2
]});
op
->
SetInput
(
"Bias"
,
{
inputs
[
3
]});
op
->
SetOutput
(
"Hidden"
,
{
outputs
[
0
]});
op
->
SetAttr
(
"is_reverse"
,
is_reverse
);
op
->
SetAttr
(
"origin_mode"
,
origin_mode
);
}
else
if
(
type
==
"concat"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
],
inputs
[
1
]});
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
}
else
{
FAIL
()
<<
"Unexpected operator type."
;
}
}
static
const
std
::
initializer_list
<
std
::
string
>
variable_names
=
{
"x"
,
"wx1"
,
"wx2"
,
"wh1"
,
"wh2"
,
"b1"
,
"b2"
,
"h1"
,
"h2"
,
"out"
};
// (x, wx1, wh1, b1) -> fusion_gru1 -> h1
// (x, wx2, wh2, b2) -> fusion_gru2 -> h2
// (h1, h2) -> concat -> out
ProgramDesc
BuildProgramDesc
(
bool
origin_mode1
,
bool
origin_mode2
)
{
ProgramDesc
prog
;
for
(
auto
&
v
:
variable_names
)
{
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
}
SetOp
(
&
prog
,
"fusion_gru"
,
{
"x"
,
"wx1"
,
"wh1"
,
"b1"
},
{
"h1"
},
false
,
origin_mode1
);
SetOp
(
&
prog
,
"fusion_gru"
,
{
"x"
,
"wx2"
,
"wh2"
,
"b2"
},
{
"h2"
},
true
,
origin_mode2
);
SetOp
(
&
prog
,
"concat"
,
{
"h1"
,
"h2"
},
{
"out"
});
return
prog
;
}
void
MainTest
(
const
ProgramDesc
&
prog
,
int
removed_nodes_count
,
int
added_nodes_count
,
const
std
::
vector
<
std
::
string
>
multi_gru_inputs
,
const
std
::
string
multi_gru_output
,
bool
origin_mode
)
{
// Apply pass
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
Scope
scope
;
graph
->
SetNotOwned
(
kParamScopeAttr
,
&
scope
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"multi_gru_fuse_pass"
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
// Verify graph after fuse
int
count_multi_gru
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
Type
()
==
"multi_gru"
)
{
EXPECT_EQ
(
op
->
Input
(
"X"
)[
0
],
multi_gru_inputs
[
0
]);
EXPECT_EQ
(
op
->
Input
(
"WeightX"
).
size
(),
2u
);
EXPECT_EQ
(
op
->
Input
(
"WeightX"
)[
0
],
multi_gru_inputs
[
1
]);
EXPECT_EQ
(
op
->
Input
(
"WeightX"
)[
1
],
multi_gru_inputs
[
2
]);
EXPECT_EQ
(
op
->
Input
(
"WeightH"
).
size
(),
2u
);
EXPECT_EQ
(
op
->
Input
(
"WeightH"
)[
0
],
multi_gru_inputs
[
3
]);
EXPECT_EQ
(
op
->
Input
(
"WeightH"
)[
1
],
multi_gru_inputs
[
4
]);
EXPECT_EQ
(
op
->
Input
(
"Bias"
).
size
(),
2u
);
EXPECT_EQ
(
op
->
Input
(
"Bias"
)[
0
],
multi_gru_inputs
[
5
]);
EXPECT_EQ
(
op
->
Input
(
"Bias"
)[
1
],
multi_gru_inputs
[
6
]);
EXPECT_EQ
(
op
->
Output
(
"Hidden"
)[
0
],
multi_gru_output
);
EXPECT_EQ
(
op
->
GetAttrIfExists
<
int
>
(
"layers"
),
1
);
EXPECT_EQ
(
op
->
GetAttrIfExists
<
bool
>
(
"origin_mode"
),
origin_mode
);
++
count_multi_gru
;
}
}
}
EXPECT_EQ
(
original_nodes_num
-
removed_nodes_count
+
added_nodes_count
,
current_nodes_num
);
EXPECT_EQ
(
count_multi_gru
,
added_nodes_count
);
}
TEST
(
MultiGruFusePass
,
same_origin_modes_1
)
{
bool
origin_mode1
=
false
;
bool
origin_mode2
=
false
;
// nodes to be removed: 2x fusion_gru + 2x hidden(output) + concat
const
int
removed_nodes_count
=
5
;
// nodes to be added: multi_gru
const
int
added_nodes_count
=
1
;
const
std
::
initializer_list
<
std
::
string
>
multi_gru_inputs
=
{
"x"
,
"wx1"
,
"wx2"
,
"wh1"
,
"wh2"
,
"b1"
,
"b2"
};
MainTest
(
BuildProgramDesc
(
origin_mode1
,
origin_mode2
),
removed_nodes_count
,
added_nodes_count
,
multi_gru_inputs
,
"out"
,
origin_mode1
);
}
TEST
(
MultiGruFusePass
,
same_origin_modes_2
)
{
bool
origin_mode1
=
true
;
bool
origin_mode2
=
true
;
// nodes to be removed: 2x fusion_gru + 2x hidden(output) + concat
const
int
removed_nodes_count
=
5
;
// nodes to be added: multi_gru
const
int
added_nodes_count
=
1
;
const
std
::
initializer_list
<
std
::
string
>
multi_gru_inputs
=
{
"x"
,
"wx1"
,
"wx2"
,
"wh1"
,
"wh2"
,
"b1"
,
"b2"
};
MainTest
(
BuildProgramDesc
(
origin_mode1
,
origin_mode2
),
removed_nodes_count
,
added_nodes_count
,
multi_gru_inputs
,
"out"
,
origin_mode1
);
}
TEST
(
MultiGruFusePass
,
different_origin_modes
)
{
bool
origin_mode1
=
true
;
bool
origin_mode2
=
false
;
// the fuse should not be applied, so
// nodes to be removed: none
const
int
removed_nodes_count
=
0
;
// nodes to be added: none
const
int
added_nodes_count
=
0
;
const
std
::
initializer_list
<
std
::
string
>
multi_gru_inputs
=
{
"x"
,
"wx1"
,
"wx2"
,
"wh1"
,
"wh2"
,
"b1"
,
"b2"
};
MainTest
(
BuildProgramDesc
(
origin_mode1
,
origin_mode2
),
removed_nodes_count
,
added_nodes_count
,
multi_gru_inputs
,
"out"
,
origin_mode1
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
multi_gru_fuse_pass
);
paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
浏览文件 @
7b5a8e46
...
...
@@ -32,8 +32,8 @@ using string::PrettyLogDetail;
namespace
{
std
::
vector
<
std
::
string
>
join_i
nputs
(
Node
*
op1
,
Node
*
op2
,
std
::
string
input_name
)
{
std
::
vector
<
std
::
string
>
JoinI
nputs
(
Node
*
op1
,
Node
*
op2
,
std
::
string
input_name
)
{
auto
in1
=
op1
->
Op
()
->
Input
(
input_name
);
auto
&
in2
=
op2
->
Op
()
->
Input
(
input_name
);
in1
.
insert
(
in1
.
end
(),
in2
.
begin
(),
in2
.
end
());
...
...
@@ -83,9 +83,9 @@ void MultiGruSeqFusePass::ApplyImpl(ir::Graph* graph) const {
return
;
}
auto
wx
=
join_i
nputs
(
gru1
,
gru2
,
"WeightX"
);
auto
wh
=
join_i
nputs
(
gru1
,
gru2
,
"WeightH"
);
auto
b
=
join_i
nputs
(
gru1
,
gru2
,
"Bias"
);
auto
wx
=
JoinI
nputs
(
gru1
,
gru2
,
"WeightX"
);
auto
wh
=
JoinI
nputs
(
gru1
,
gru2
,
"WeightH"
);
auto
b
=
JoinI
nputs
(
gru1
,
gru2
,
"Bias"
);
OpDesc
multi_gru_desc
;
multi_gru_desc
.
SetType
(
"multi_gru"
);
...
...
tools/static_mode_white_list.py
浏览文件 @
7b5a8e46
...
...
@@ -603,6 +603,7 @@ STATIC_MODE_TESTING_LIST = [
'test_matmul_bf16_mkldnn_op'
,
'test_mul_int8_mkldnn_op'
,
'test_multi_gru_mkldnn_op'
,
'test_multi_gru_fuse_pass'
,
'test_multi_gru_seq_fuse_pass'
,
'test_pool2d_int8_mkldnn_op'
,
'test_pool2d_mkldnn_op'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录