Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2fd8deea
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
2fd8deea
编写于
10月 09, 2021
作者:
W
wuhuanzhou
提交者:
GitHub
10月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
C++ support register pass via PassDesc (#36095)
支持C++开发注册GeneratePass,简化针对fusion等子图优化场景开发方式。
上级
d8887afa
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
314 addition
and
216 deletion
+314
-216
paddle/fluid/framework/ir/generate_pass.cc
paddle/fluid/framework/ir/generate_pass.cc
+110
-0
paddle/fluid/framework/ir/generate_pass.h
paddle/fluid/framework/ir/generate_pass.h
+152
-1
paddle/fluid/framework/ir/generate_pass_tester.cc
paddle/fluid/framework/ir/generate_pass_tester.cc
+52
-215
未找到文件。
paddle/fluid/framework/ir/generate_pass.cc
浏览文件 @
2fd8deea
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -224,6 +225,115 @@ bool GeneratePass::VerifyGraph(const Graph& graph) {
return
true
;
}
namespace
generate_pass
{
// Builds a variable node called `name`. Variables created through this
// single-argument form are always tagged as subgraph inputs (Type::kInput).
VarHelper::VarHelper(const char* name)
    : VarHelper(std::string(name), Type::kInput) {}
// Builds a variable node with an explicit name and role (input or output).
VarHelper::VarHelper(const std::string& name, Type type)
    : name_(name),
      type_(type) {
}
// Creates an operator node of kind `type` and appends its OpDesc to block 0
// of the program owned by `subgraph_helper`.
//
// Block 0 must already exist in that program: mutable_blocks(0) is called
// unconditionally.
OpHelper::OpHelper(const char* type, SubgraphHelper* subgraph_helper)
    : type_(type), subgraph_helper_(subgraph_helper) {
  auto* program = subgraph_helper_->ProgramDesc();
  auto* first_block = program->mutable_blocks(0);
  op_desc_ = first_block->add_ops();
  op_desc_->set_type(type_);
}
// Binds a single variable to the operator parameter slot `parameter`.
OpHelper::Arguments::Arguments(const char* parameter,
                               const VarHelper& var_helper)
    : parameter_(parameter), var_helpers_(1, var_helper) {}
// Binds a list of variables (kept in the given order) to the operator
// parameter slot `parameter`.
OpHelper::Arguments::Arguments(const char* parameter,
                               std::initializer_list<VarHelper> var_helpers)
    : parameter_(parameter),
      var_helpers_(var_helpers.begin(), var_helpers.end()) {}
// Adds one input slot to the operator.
//
// A new input entry is appended to the OpDesc, labelled with
// `input.parameter_`, and every variable in `input.var_helpers_` is recorded
// as an argument of that slot. Variables tagged Type::kInput are additionally
// reported to the owning SubgraphHelper via AddInputVar.
//
// Returns *this so that calls can be chained.
OpHelper& OpHelper::operator()(const Arguments& input) {
  proto::OpDesc::Var* slot = op_desc_->add_inputs();
  slot->set_parameter(input.parameter_);

  for (const VarHelper& arg : input.var_helpers_) {
    *slot->add_arguments() = arg.name_;
    if (arg.type_ != VarHelper::Type::kInput) {
      continue;  // Only input-tagged variables are registered on the subgraph.
    }
    subgraph_helper_->AddInputVar(arg.name_);
  }
  return *this;
}
// Adds several input slots at once by forwarding each element, in order, to
// the single-Arguments overload. Returns *this so that calls can be chained.
OpHelper& OpHelper::operator()(std::initializer_list<Arguments> inputs) {
  for (auto it = inputs.begin(); it != inputs.end(); ++it) {
    (*this)(*it);
  }
  return *this;
}
// Declares an output slot named `name` on the operator.
//
// The output variable's name is generated by patterns::UniqueKey from the
// operator type. The returned VarHelper carries that generated name and is
// tagged Type::kOutput.
VarHelper OpHelper::Out(const char* name) {
  const std::string generated_name = patterns::UniqueKey(type_);

  proto::OpDesc::Var* slot = op_desc_->add_outputs();
  slot->set_parameter(name);
  *slot->add_arguments() = generated_name;

  return VarHelper(generated_name, VarHelper::Type::kOutput);
}
// Mutable access to the program description being built for this subgraph.
proto::ProgramDesc* SubgraphHelper::ProgramDesc() {
  proto::ProgramDesc* program = &program_desc_;
  return program;
}
// Read-only access to the program description built for this subgraph.
const proto::ProgramDesc& SubgraphHelper::ProgramDesc() const {
  return this->program_desc_;
}
const
std
::
vector
<
std
::
string
>&
SubgraphHelper
::
InputVars
()
const
{
return
input_vars_
;
}
const
std
::
vector
<
std
::
string
>&
SubgraphHelper
::
OutputVars
()
const
{
return
output_vars_
;
}
// Registers `name` as a subgraph input. Each name is stored at most once;
// repeated registrations of an already-known name are ignored, so the list
// keeps first-seen order.
void SubgraphHelper::AddInputVar(const std::string& name) {
  for (const std::string& known : input_vars_) {
    if (known == name) {
      return;  // Already registered.
    }
  }
  input_vars_.push_back(name);
}
void
SubgraphHelper
::
AddOutputVars
(
const
VarHelper
&
var_helper
)
{
output_vars_
.
push_back
(
var_helper
.
name_
);
}
}
// namespace generate_pass
// Convenience constructor: starts with exactly one pattern/replace pair,
// equivalent to default construction followed by AddPassDesc.
PassPairs::PassPairs(const SubgraphType& pattern,
                     const SubgraphType& replace) {
  this->AddPassDesc(pattern, replace);
}
// Appends one pattern -> replace rewrite rule to the accumulated
// MultiPassDesc.
//
// The rule consists of copies of both subgraphs' ProgramDescs plus a list of
// variable mappings: first one mapping per input variable, then one per
// output variable, each pairing the i-th pattern variable with the i-th
// replace variable.
//
// `pattern` and `replace` must expose the same number of input variables and
// the same number of output variables; otherwise PADDLE_ENFORCE_EQ reports an
// InvalidArgument error.
void PassPairs::AddPassDesc(const SubgraphType& pattern,
                            const SubgraphType& replace) {
  const auto& pattern_inputs = pattern.InputVars();
  const auto& replace_inputs = replace.InputVars();
  const auto& pattern_outputs = pattern.OutputVars();
  const auto& replace_outputs = replace.OutputVars();

  // Validate everything before touching multi_pass_desc_, so that a failed
  // check cannot leave a partially populated PassDesc behind.
  // NOTE(review): this relies on PADDLE_ENFORCE_EQ leaving this function on
  // mismatch (presumably by throwing) - confirm against the macro definition.
  PADDLE_ENFORCE_EQ(pattern_inputs.size(), replace_inputs.size(),
                    platform::errors::InvalidArgument(
                        "Size of lambda expression arguments is not equal "
                        "between pattern/replace subgraph."));
  PADDLE_ENFORCE_EQ(pattern_outputs.size(), replace_outputs.size(),
                    platform::errors::InvalidArgument(
                        "Size of lambda expression returns is not equal "
                        "between pattern/replace subgraph."));

  proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs();
  pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc());
  pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc());

  // Input mappings come first ...
  for (size_t i = 0; i < pattern_inputs.size(); ++i) {
    proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
    var_map->set_pattern_var(pattern_inputs[i]);
    var_map->set_replace_var(replace_inputs[i]);
  }
  // ... followed by output mappings, matching the original ordering.
  for (size_t i = 0; i < pattern_outputs.size(); ++i) {
    proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
    var_map->set_pattern_var(pattern_outputs[i]);
    var_map->set_replace_var(replace_outputs[i]);
  }
}
// Read-only access to every rewrite rule added so far.
const proto::MultiPassDesc& PassPairs::MultiPassDesc() const {
  return this->multi_pass_desc_;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/generate_pass.h
浏览文件 @
2fd8deea
...
...
@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/pass_desc.pb.h"
...
...
@@ -43,6 +42,158 @@ class GeneratePass : public Pass {
proto
::
MultiPassDesc
multi_pass_desc_
;
};
namespace
generate_pass
{
// Forward declarations. VarHelper is defined as a struct below, so it is
// declared with the same class-key here to avoid mismatched-tag diagnostics.
struct VarHelper;
class OpHelper;
class SubgraphHelper;
// VarHelper is used to represent a variable node.
//
// It is a plain value type: a variable name plus a tag saying whether the
// variable is a subgraph input or was produced as an operator output.
struct VarHelper {
  // Role of the variable inside the subgraph being described.
  enum class Type { kInput, kOutput };

  // Creates a variable named `name`, tagged as Type::kInput.
  explicit VarHelper(const char* name);
  // Creates a variable with an explicit name and role.
  VarHelper(const std::string& name, Type type);

  // Name of the variable as it appears in operator argument lists.
  std::string name_;
  // Whether this variable is an input or an operator output.
  Type type_;
};
// OpHelper is used to represent an operator node.
//
// Constructing an OpHelper appends a new OpDesc to block 0 of the owning
// SubgraphHelper's program; the call operators then add input slots and
// Out() adds an output slot.
class OpHelper {
 public:
  // Convert multiple inputs.
  // One Arguments value describes a single input slot: the parameter name
  // and the variable(s) bound to it.
  struct Arguments {
    // Binds a single variable to `parameter`.
    Arguments(const char* parameter, const VarHelper& var_helper);
    // Binds several variables to `parameter`.
    Arguments(const char* parameter,
              std::initializer_list<VarHelper> var_helpers);

    // Parameter (slot) name on the operator.
    std::string parameter_;
    // Variables bound to the slot, in the order given.
    std::vector<VarHelper> var_helpers_;
  };

  // `type` is the operator type name; `subgraph_helper` is the subgraph that
  // receives the new operator. The pointer is stored, not owned.
  OpHelper(const char* type, SubgraphHelper* subgraph_helper);

  // Adds one input slot. Returns *this for chaining.
  OpHelper& operator()(const Arguments& input);
  // Adds several input slots, in order. Returns *this for chaining.
  OpHelper& operator()(std::initializer_list<Arguments> inputs);

  // Adds an output slot named `name` and returns the variable bound to it.
  VarHelper Out(const char* name);

 private:
  OpHelper() = delete;
  DISABLE_COPY_AND_ASSIGN(OpHelper);

  // Operator type name; the pointer is stored as given, not copied.
  const char* type_;
  // OpDesc created for this operator inside the subgraph's program.
  proto::OpDesc* op_desc_;
  // Subgraph this operator belongs to (not owned).
  SubgraphHelper* subgraph_helper_;
};
/*
 * SubgraphHelper is used to define pattern/replace subgraphs.
 *
 * Use a lambda expression to define a subgraph, much like in Python.
 * SubgraphHelper converts the lambda expression into a ProgramDesc.
 *
 * To define a subgraph, users need VarHelper and OpHelper. Use the macros
 * instead of the class names, so that users do not need to know much about
 * the underlying implementation.
 *
 * The lambda is assigned to the helper (operator=) and must capture a
 * pointer to the helper under the name `subgraph`, because OP_ refers to it.
 *
 * An example of defining a subgraph:
 *
 * SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
 *   auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
 *   auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
 *   return ewadd2;
 * };
 *
 */
class SubgraphHelper {
 public:
  SubgraphHelper() = default;
  // The lambda expression is a prvalue expression.
  // Assigning a callable adds a block (idx 0, parent idx 0) to the program,
  // invokes the callable with no arguments, and records whatever it returns
  // (a single VarHelper or a std::tuple of them) as the subgraph outputs.
  template <typename T>
  SubgraphHelper& operator=(const T&& f) {
    proto::BlockDesc* block = program_desc_.add_blocks();
    block->set_idx(0);
    block->set_parent_idx(0);
    AddOutputVars(f());
    return *this;
  }

  // Mutable / read-only access to the program description being built.
  proto::ProgramDesc* ProgramDesc();
  const proto::ProgramDesc& ProgramDesc() const;
  // Input and output variable names recorded so far.
  const std::vector<std::string>& InputVars() const;
  const std::vector<std::string>& OutputVars() const;

  // Registers an input variable name.
  void AddInputVar(const std::string& name);

  // Registers a single output variable.
  void AddOutputVars(const VarHelper& var_helper);

  // Recursive step for tuple outputs: registers element `i`, then continues
  // with element `i + 1`. Selected while `i` is not the last index.
  template <size_t i, typename... Ts,
            std::enable_if_t<i + 1 < sizeof...(Ts)>* = nullptr>
  void AddOutputVars(const std::tuple<Ts...>& outputs) {
    AddOutputVars(std::get<i>(outputs));
    AddOutputVars<i + 1>(outputs);
  }

  // Terminal step for tuple outputs: registers the last element.
  template <size_t i, typename... Ts,
            std::enable_if_t<i + 1 == sizeof...(Ts)>* = nullptr>
  void AddOutputVars(const std::tuple<Ts...>& outputs) {
    AddOutputVars(std::get<i>(outputs));
  }

  // Entry point for tuple outputs: registers every element in index order.
  template <typename... Ts>
  void AddOutputVars(const std::tuple<Ts...>& outputs) {
    AddOutputVars<0>(outputs);
  }

 private:
  DISABLE_COPY_AND_ASSIGN(SubgraphHelper);
  // Names of the subgraph's input variables.
  std::vector<std::string> input_vars_;
  // Names of the subgraph's output variables.
  std::vector<std::string> output_vars_;
  // Program description holding the operators added through OpHelper.
  proto::ProgramDesc program_desc_;
};
}
// namespace generate_pass
// PassPairs accumulates pattern/replace subgraph pairs into a single
// proto::MultiPassDesc.
class PassPairs {
 public:
  using SubgraphType = generate_pass::SubgraphHelper;

  PassPairs() = default;
  // Starts with one pattern/replace pair already added.
  PassPairs(const SubgraphType& pattern, const SubgraphType& replace);

  // Adds one more pattern/replace pair.
  void AddPassDesc(const SubgraphType& pattern, const SubgraphType& replace);

  // Read-only access to everything added so far.
  const proto::MultiPassDesc& MultiPassDesc() const;

 private:
  proto::MultiPassDesc multi_pass_desc_;
};
// Adapter that lets a plain function be registered as a pass from a .cc
// file: `Functor` is called once at construction and the MultiPassDesc of the
// PassPairs it returns is handed to the GeneratePass base class.
template <PassPairs (*Functor)()>
class MacroPassHelper : public GeneratePass {
 public:
  MacroPassHelper()
      : GeneratePass(Functor().MultiPassDesc()) {
  }
};
// VAR_(name): declares a VarHelper called `name` whose variable name is the
// stringized identifier. Used in lambda parameter lists, where it acts as a
// parameter with a default argument.
#define VAR_(name)                                               \
  ::paddle::framework::ir::generate_pass::VarHelper name =       \
      ::paddle::framework::ir::generate_pass::VarHelper(#name)
// OP_(type): creates an OpHelper for operator `type`. Requires a
// SubgraphHelper pointer named `subgraph` to be in scope.
#define OP_(type) \
  ::paddle::framework::ir::generate_pass::OpHelper(#type, subgraph)
// SUBGRAPH_(name): declares a SubgraphHelper called `name` and leaves `name`
// as the trailing token so that a lambda can be assigned to it directly.
#define SUBGRAPH_(name)                                        \
  ::paddle::framework::ir::generate_pass::SubgraphHelper name; \
  name
// REGISTER_GENERATE_PASS(pass_type): forward-declares
// `register_<pass_type>()`, registers a MacroPassHelper built on that
// function under the name `pass_type`, and then opens the function's
// definition, so the macro must be followed directly by the function body
// returning a PassPairs.
//
// The template argument is the address of the function; the scraped text
// had `&reg` mangled into a "registered" sign.
#define REGISTER_GENERATE_PASS(pass_type)                                \
  paddle::framework::ir::PassPairs register_##pass_type();              \
  REGISTER_PASS(                                                         \
      pass_type,                                                         \
      ::paddle::framework::ir::MacroPassHelper<&register_##pass_type>); \
  paddle::framework::ir::PassPairs register_##pass_type()
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/generate_pass_tester.cc
浏览文件 @
2fd8deea
...
...
@@ -16,234 +16,71 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
template
<
proto
::
MultiPassDesc
(
*
Functor
)(
void
)>
class
CXXGeneratePass
:
public
GeneratePass
{
public:
CXXGeneratePass
()
:
GeneratePass
(
Functor
())
{}
};
#define REGISTER_GENERATE_PASS(pass_type, function) \
REGISTER_PASS(pass_type, ::paddle::framework::ir::CXXGeneratePass<&function>)
proto
::
MultiPassDesc
generate_fc_fuse
()
{
proto
::
MultiPassDesc
multi_pass_desc
;
REGISTER_GENERATE_PASS
(
generate_fc_fuse
)
{
paddle
::
framework
::
ir
::
PassPairs
pass_pairs
;
for
(
bool
with_relu
:
{
true
,
false
})
{
proto
::
PassDesc
*
pass_desc
=
multi_pass_desc
.
add_pass_descs
();
proto
::
BlockDesc
*
pattern
=
pass_desc
->
mutable_pattern
()
->
add_blocks
();
pattern
->
set_idx
(
0
);
pattern
->
set_parent_idx
(
0
);
proto
::
OpDesc
*
mul
=
pattern
->
add_ops
();
mul
->
set_type
(
"mul"
);
proto
::
OpDesc
::
Var
*
mul_x
=
mul
->
add_inputs
();
mul_x
->
set_parameter
(
"X"
);
mul_x
->
add_arguments
()
->
assign
(
"x"
);
proto
::
OpDesc
::
Var
*
mul_y
=
mul
->
add_inputs
();
mul_y
->
set_parameter
(
"Y"
);
mul_y
->
add_arguments
()
->
assign
(
"w"
);
proto
::
OpDesc
::
Var
*
mul_out
=
mul
->
add_outputs
();
mul_out
->
set_parameter
(
"Out"
);
mul_out
->
add_arguments
()
->
assign
(
"mul_out"
);
proto
::
OpDesc
*
ewadd
=
pattern
->
add_ops
();
ewadd
->
set_type
(
"elementwise_add"
);
proto
::
OpDesc
::
Var
*
ewadd_x
=
ewadd
->
add_inputs
();
ewadd_x
->
set_parameter
(
"X"
);
ewadd_x
->
add_arguments
()
->
assign
(
"mul_out"
);
proto
::
OpDesc
::
Var
*
ewadd_y
=
ewadd
->
add_inputs
();
ewadd_y
->
set_parameter
(
"Y"
);
ewadd_y
->
add_arguments
()
->
assign
(
"b"
);
proto
::
OpDesc
::
Var
*
ewadd_out
=
ewadd
->
add_outputs
();
ewadd_out
->
set_parameter
(
"Out"
);
ewadd_out
->
add_arguments
()
->
assign
(
"ewadd_out"
);
proto
::
OpDesc
*
relu
=
nullptr
;
proto
::
BlockDesc
*
replace
=
pass_desc
->
mutable_replace
()
->
add_blocks
();
replace
->
set_idx
(
0
);
replace
->
set_parent_idx
(
0
);
proto
::
OpDesc
*
fc
=
replace
->
add_ops
();
fc
->
set_type
(
"fc"
);
proto
::
OpDesc
::
Var
*
fc_x
=
fc
->
add_inputs
();
fc_x
->
set_parameter
(
"Input"
);
fc_x
->
add_arguments
()
->
assign
(
"x"
);
proto
::
OpDesc
::
Var
*
fc_w
=
fc
->
add_inputs
();
fc_w
->
set_parameter
(
"W"
);
fc_w
->
add_arguments
()
->
assign
(
"w"
);
proto
::
OpDesc
::
Var
*
fc_b
=
fc
->
add_inputs
();
fc_b
->
set_parameter
(
"Bias"
);
fc_b
->
add_arguments
()
->
assign
(
"b"
);
proto
::
OpDesc
::
Var
*
fc_out
=
fc
->
add_outputs
();
fc_out
->
set_parameter
(
"Out"
);
fc_out
->
add_arguments
()
->
assign
(
"fc_out"
);
for
(
const
char
*
var
:
{
"x"
,
"w"
,
"b"
,
"fc_out"
})
{
proto
::
PassDesc
::
VarMap
*
var_map
=
pass_desc
->
add_var_maps
();
var_map
->
set_pattern_var
(
var
);
var_map
->
set_replace_var
(
var
);
}
proto
::
PassDesc
::
AttrMap
*
attr_map
=
pass_desc
->
add_attr_maps
();
attr_map
->
set_pattern_op_idx
(
0
);
attr_map
->
set_pattern_name
(
"x_num_col_dims"
);
attr_map
->
set_replace_op_idx
(
0
);
attr_map
->
set_replace_name
(
"in_num_col_dims"
);
if
(
with_relu
)
{
relu
=
pattern
->
add_ops
();
relu
->
set_type
(
"relu"
);
proto
::
OpDesc
::
Var
*
relu_x
=
relu
->
add_inputs
();
relu_x
->
set_parameter
(
"X"
);
relu_x
->
add_arguments
()
->
assign
(
"ewadd_out"
);
proto
::
OpDesc
::
Var
*
relu_out
=
relu
->
add_outputs
();
relu_out
->
set_parameter
(
"Out"
);
relu_out
->
add_arguments
()
->
assign
(
"relu_out"
);
pass_desc
->
mutable_var_maps
(
3
)
->
set_pattern_var
(
"relu_out"
);
proto
::
OpDesc
::
Attr
*
attr
=
fc
->
add_attrs
();
attr
->
set_name
(
"activation_type"
);
attr
->
set_type
(
proto
::
AttrType
::
STRING
);
attr
->
set_s
(
"relu"
);
}
else
{
pass_desc
->
mutable_var_maps
(
3
)
->
set_pattern_var
(
"ewadd_out"
);
}
// pattern
SUBGRAPH_
(
pattern
)
=
[
subgraph
=
&
pattern
,
with_relu
](
VAR_
(
x
),
VAR_
(
y
),
VAR_
(
z
))
{
VLOG
(
3
)
<<
"exec lambda func."
;
auto
mul
=
OP_
(
mul
)({{
"X"
,
x
},
{
"Y"
,
y
}}).
Out
(
"Out"
);
auto
ewadd
=
OP_
(
elementwise_add
)({{
"X"
,
mul
},
{
"Y"
,
z
}}).
Out
(
"Out"
);
if
(
with_relu
)
{
return
OP_
(
relu
)({
"X"
,
ewadd
}).
Out
(
"Out"
);
}
else
{
return
ewadd
;
}
};
// replace
SUBGRAPH_
(
replace
)
=
[
subgraph
=
&
replace
,
with_relu
](
VAR_
(
x
),
VAR_
(
y
),
VAR_
(
z
))
{
auto
&
fc
=
OP_
(
fc
)({{
"Input"
,
x
},
{
"W"
,
y
},
{
"Bias"
,
z
}});
return
fc
.
Out
(
"Out"
);
};
pass_pairs
.
AddPassDesc
(
pattern
,
replace
);
}
return
multi_pass_desc
;
return
pass_pairs
;
}
proto
::
MultiPassDesc
generate_multi_add_to_addn
()
{
proto
::
MultiPassDesc
multi_pass_desc
;
proto
::
PassDesc
*
pass_desc
=
multi_pass_desc
.
add_pass_descs
();
proto
::
BlockDesc
*
pattern
=
pass_desc
->
mutable_pattern
()
->
add_blocks
();
proto
::
OpDesc
*
ewadd_0
=
pattern
->
add_ops
();
ewadd_0
->
set_type
(
"elementwise_add"
);
proto
::
OpDesc
::
Var
*
ewadd_0_x
=
ewadd_0
->
add_inputs
();
ewadd_0_x
->
set_parameter
(
"X"
);
ewadd_0_x
->
add_arguments
()
->
assign
(
"a"
);
proto
::
OpDesc
::
Var
*
ewadd_0_y
=
ewadd_0
->
add_inputs
();
ewadd_0_y
->
set_parameter
(
"Y"
);
ewadd_0_y
->
add_arguments
()
->
assign
(
"b"
);
proto
::
OpDesc
::
Var
*
ewadd_0_out
=
ewadd_0
->
add_outputs
();
ewadd_0_out
->
set_parameter
(
"Out"
);
ewadd_0_out
->
add_arguments
()
->
assign
(
"ewadd_out_0"
);
proto
::
OpDesc
*
ewadd_1
=
pattern
->
add_ops
();
ewadd_1
->
set_type
(
"elementwise_add"
);
proto
::
OpDesc
::
Var
*
ewadd_1_x
=
ewadd_1
->
add_inputs
();
ewadd_1_x
->
set_parameter
(
"X"
);
ewadd_1_x
->
add_arguments
()
->
assign
(
"ewadd_out_0"
);
proto
::
OpDesc
::
Var
*
ewadd_1_y
=
ewadd_1
->
add_inputs
();
ewadd_1_y
->
set_parameter
(
"Y"
);
ewadd_1_y
->
add_arguments
()
->
assign
(
"c"
);
proto
::
OpDesc
::
Var
*
ewadd_1_out
=
ewadd_1
->
add_outputs
();
ewadd_1_out
->
set_parameter
(
"Out"
);
ewadd_1_out
->
add_arguments
()
->
assign
(
"ewadd_out_1"
);
proto
::
BlockDesc
*
replace
=
pass_desc
->
mutable_replace
()
->
add_blocks
();
proto
::
OpDesc
*
addn
=
replace
->
add_ops
();
addn
->
set_type
(
"add_n"
);
proto
::
OpDesc
::
Var
*
addn_x
=
addn
->
add_inputs
();
addn_x
->
set_parameter
(
"X"
);
addn_x
->
add_arguments
()
->
assign
(
"a"
);
addn_x
->
add_arguments
()
->
assign
(
"b"
);
addn_x
->
add_arguments
()
->
assign
(
"c"
);
proto
::
OpDesc
::
Var
*
addn_out
=
addn
->
add_outputs
();
addn_out
->
set_parameter
(
"Out"
);
addn_out
->
add_arguments
()
->
assign
(
"addn_out"
);
for
(
const
char
*
var
:
{
"a"
,
"b"
,
"c"
,
"ewadd_out_1"
})
{
proto
::
PassDesc
::
VarMap
*
var_map
=
pass_desc
->
add_var_maps
();
var_map
->
set_pattern_var
(
var
);
var_map
->
set_replace_var
(
var
);
}
pass_desc
->
mutable_var_maps
(
3
)
->
set_replace_var
(
"addn_out"
);
return
multi_pass_desc
;
REGISTER_GENERATE_PASS
(
generate_multi_add_to_addn
)
{
// pattern
SUBGRAPH_
(
pattern
)
=
[
subgraph
=
&
pattern
](
VAR_
(
x
),
VAR_
(
y
),
VAR_
(
z
))
{
auto
ewadd1
=
OP_
(
elementwise_add
)({{
"X"
,
x
},
{
"Y"
,
y
}}).
Out
(
"Out"
);
auto
ewadd2
=
OP_
(
elementwise_add
)({{
"X"
,
ewadd1
},
{
"Y"
,
z
}}).
Out
(
"Out"
);
return
ewadd2
;
};
// replace
SUBGRAPH_
(
replace
)
=
[
subgraph
=
&
replace
](
VAR_
(
x
),
VAR_
(
y
),
VAR_
(
z
))
{
return
OP_
(
sum
)({
"X"
,
{
x
,
y
,
z
}}).
Out
(
"Out"
);
};
return
{
pattern
,
replace
};
}
proto
::
MultiPassDesc
generate_combine_matmul
()
{
proto
::
MultiPassDesc
multi_pass_desc
;
proto
::
PassDesc
*
pass_desc
=
multi_pass_desc
.
add_pass_descs
();
proto
::
BlockDesc
*
pattern
=
pass_desc
->
mutable_pattern
()
->
add_blocks
();
proto
::
OpDesc
*
matmul_0
=
pattern
->
add_ops
();
matmul_0
->
set_type
(
"matmul"
);
proto
::
OpDesc
::
Var
*
matmul_0_x
=
matmul_0
->
add_inputs
();
matmul_0_x
->
set_parameter
(
"X"
);
matmul_0_x
->
add_arguments
()
->
assign
(
"a"
);
proto
::
OpDesc
::
Var
*
matmul_0_y
=
matmul_0
->
add_inputs
();
matmul_0_y
->
set_parameter
(
"Y"
);
matmul_0_y
->
add_arguments
()
->
assign
(
"b"
);
proto
::
OpDesc
::
Var
*
matmul_0_out
=
matmul_0
->
add_outputs
();
matmul_0_out
->
set_parameter
(
"Out"
);
matmul_0_out
->
add_arguments
()
->
assign
(
"matmul_out_0"
);
proto
::
OpDesc
*
matmul_1
=
pattern
->
add_ops
();
matmul_1
->
set_type
(
"matmul"
);
proto
::
OpDesc
::
Var
*
matmul_1_x
=
matmul_1
->
add_inputs
();
matmul_1_x
->
set_parameter
(
"X"
);
matmul_1_x
->
add_arguments
()
->
assign
(
"a"
);
proto
::
OpDesc
::
Var
*
matmul_1_y
=
matmul_1
->
add_inputs
();
matmul_1_y
->
set_parameter
(
"Y"
);
matmul_1_y
->
add_arguments
()
->
assign
(
"c"
);
proto
::
OpDesc
::
Var
*
matmul_1_out
=
matmul_1
->
add_outputs
();
matmul_1_out
->
set_parameter
(
"Out"
);
matmul_1_out
->
add_arguments
()
->
assign
(
"matmul_out_1"
);
proto
::
BlockDesc
*
replace
=
pass_desc
->
mutable_replace
()
->
add_blocks
();
proto
::
OpDesc
*
concat
=
replace
->
add_ops
();
concat
->
set_type
(
"concat"
);
proto
::
OpDesc
::
Var
*
concat_x
=
concat
->
add_inputs
();
concat_x
->
set_parameter
(
"X"
);
concat_x
->
add_arguments
()
->
assign
(
"b"
);
concat_x
->
add_arguments
()
->
assign
(
"c"
);
proto
::
OpDesc
::
Var
*
concat_out
=
concat
->
add_outputs
();
concat_out
->
set_parameter
(
"Out"
);
concat_out
->
add_arguments
()
->
assign
(
"concat_out"
);
proto
::
OpDesc
*
matmul
=
replace
->
add_ops
();
matmul
->
set_type
(
"matmul"
);
proto
::
OpDesc
::
Var
*
matmul_x
=
matmul
->
add_inputs
();
matmul_x
->
set_parameter
(
"X"
);
matmul_x
->
add_arguments
()
->
assign
(
"a"
);
proto
::
OpDesc
::
Var
*
matmul_y
=
matmul
->
add_inputs
();
matmul_y
->
set_parameter
(
"Y"
);
matmul_y
->
add_arguments
()
->
assign
(
"concat_out"
);
proto
::
OpDesc
::
Var
*
matmul_out
=
matmul
->
add_outputs
();
matmul_out
->
set_parameter
(
"Out"
);
matmul_out
->
add_arguments
()
->
assign
(
"matmul_out"
);
proto
::
OpDesc
*
slice_0
=
replace
->
add_ops
();
slice_0
->
set_type
(
"slice"
);
proto
::
OpDesc
::
Var
*
slice_0_x
=
slice_0
->
add_inputs
();
slice_0_x
->
set_parameter
(
"X"
);
slice_0_x
->
add_arguments
()
->
assign
(
"matmul_out"
);
proto
::
OpDesc
::
Var
*
slice_0_out
=
slice_0
->
add_outputs
();
slice_0_out
->
set_parameter
(
"Out"
);
slice_0_out
->
add_arguments
()
->
assign
(
"slice_out_0"
);
proto
::
OpDesc
*
slice_1
=
replace
->
add_ops
();
slice_1
->
set_type
(
"slice"
);
proto
::
OpDesc
::
Var
*
slice_1_x
=
slice_1
->
add_inputs
();
slice_1_x
->
set_parameter
(
"X"
);
slice_1_x
->
add_arguments
()
->
assign
(
"matmul_out"
);
proto
::
OpDesc
::
Var
*
slice_1_out
=
slice_1
->
add_outputs
();
slice_1_out
->
set_parameter
(
"Out"
);
slice_1_out
->
add_arguments
()
->
assign
(
"slice_out_1"
);
for
(
const
char
*
var
:
{
"a"
,
"b"
,
"c"
,
"matmul_out_0"
,
"matmul_out_1"
})
{
proto
::
PassDesc
::
VarMap
*
var_map
=
pass_desc
->
add_var_maps
();
var_map
->
set_pattern_var
(
var
);
var_map
->
set_replace_var
(
var
);
}
pass_desc
->
mutable_var_maps
(
3
)
->
set_replace_var
(
"slice_out_0"
);
pass_desc
->
mutable_var_maps
(
4
)
->
set_replace_var
(
"slice_out_1"
);
return
multi_pass_desc
;
REGISTER_GENERATE_PASS
(
generate_combine_matmul
)
{
// pattern
SUBGRAPH_
(
pattern
)
=
[
subgraph
=
&
pattern
](
VAR_
(
x
),
VAR_
(
y
),
VAR_
(
z
))
{
auto
matmul1
=
OP_
(
matmul
)({{
"X"
,
x
},
{
"Y"
,
y
}}).
Out
(
"Out"
);
auto
matmul2
=
OP_
(
matmul
)({{
"X"
,
x
},
{
"Y"
,
z
}}).
Out
(
"Out"
);
return
std
::
make_tuple
(
matmul1
,
matmul2
);
};
// replace
SUBGRAPH_
(
replace
)
=
[
subgraph
=
&
replace
](
VAR_
(
x
),
VAR_
(
y
),
VAR_
(
z
))
{
auto
concat
=
OP_
(
concat
)({
"X"
,
{
y
,
z
}}).
Out
(
"Out"
);
auto
matmul
=
OP_
(
matmul
)({{
"X"
,
x
},
{
"Y"
,
concat
}}).
Out
(
"Out"
);
auto
slice1
=
OP_
(
slice
)({
"X"
,
matmul
}).
Out
(
"Out"
);
auto
slice2
=
OP_
(
slice
)({
"X"
,
matmul
}).
Out
(
"Out"
);
return
std
::
make_tuple
(
slice1
,
slice2
);
};
return
{
pattern
,
replace
};
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_GENERATE_PASS
(
generate_fc_fuse
,
paddle
::
framework
::
ir
::
generate_fc_fuse
);
REGISTER_GENERATE_PASS
(
generate_multi_add_to_addn
,
paddle
::
framework
::
ir
::
generate_multi_add_to_addn
);
REGISTER_GENERATE_PASS
(
generate_combine_matmul
,
paddle
::
framework
::
ir
::
generate_combine_matmul
);
namespace
paddle
{
namespace
framework
{
namespace
ir
{
TEST
(
GeneratePass
,
construct_with_string
)
{
std
::
string
binary_str
;
generate_fc_fuse
().
SerializeToString
(
&
binary_str
);
register_generate_fc_fuse
().
MultiPassDesc
().
SerializeToString
(
&
binary_str
);
GeneratePass
generate_pass
(
binary_str
);
}
...
...
@@ -318,7 +155,7 @@ TEST(GeneratePass, generate_multi_add_to_addn) {
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
int
num_addn_nodes_after
=
GetNumOpNodes
(
graph
,
"
add_n
"
);
int
num_addn_nodes_after
=
GetNumOpNodes
(
graph
,
"
sum
"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
2
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录