BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle, in sync with the upstream project)
Commit 2fd8deea (unverified)
Authored by wuhuanzhou on Oct 09, 2021; committed via GitHub on Oct 09, 2021.
C++ support register pass via PassDesc (#36095)
Supports registering a GeneratePass from C++, simplifying the development of subgraph optimizations such as fusion.
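For illustration, a minimal sketch of how a fusion pass can be registered with the macros added in this commit. It is adapted from generate_pass_tester.cc below; the pass name my_fc_fuse is hypothetical.

    #include "paddle/fluid/framework/ir/generate_pass.h"

    // Rewrite mul(X, Y) followed by elementwise_add(.., Z) into a single fc op.
    REGISTER_GENERATE_PASS(my_fc_fuse) {
      // pattern: the subgraph to match, written as a lambda over its input vars.
      SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
        auto mul = OP_(mul)({{"X", x}, {"Y", y}}).Out("Out");
        return OP_(elementwise_add)({{"X", mul}, {"Y", z}}).Out("Out");
      };
      // replace: the subgraph substituted for each match.
      SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
        return OP_(fc)({{"Input", x}, {"W", y}, {"Bias", z}}).Out("Out");
      };
      return {pattern, replace};
    }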
Parent: d8887afa

Showing 3 changed files with 314 additions and 216 deletions (+314 -216)
paddle/fluid/framework/ir/generate_pass.cc         +110  -0
paddle/fluid/framework/ir/generate_pass.h          +152  -1
paddle/fluid/framework/ir/generate_pass_tester.cc   +52  -215
paddle/fluid/framework/ir/generate_pass.cc

@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/generate_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"  // newly added include

namespace paddle {
namespace framework {

...

@@ -224,6 +225,115 @@ bool GeneratePass::VerifyGraph(const Graph& graph) {
  return true;
}

namespace generate_pass {

VarHelper::VarHelper(const char* name) : name_(name), type_(Type::kInput) {}
VarHelper::VarHelper(const std::string& name, Type type)
    : name_(name), type_(type) {}

OpHelper::OpHelper(const char* type, SubgraphHelper* subgraph_helper)
    : type_(type), subgraph_helper_(subgraph_helper) {
  op_desc_ = subgraph_helper_->ProgramDesc()->mutable_blocks(0)->add_ops();
  op_desc_->set_type(type_);
}

OpHelper::Arguments::Arguments(const char* parameter,
                               const VarHelper& var_helper)
    : parameter_(parameter) {
  var_helpers_.push_back(var_helper);
}

OpHelper::Arguments::Arguments(const char* parameter,
                               std::initializer_list<VarHelper> var_helpers)
    : parameter_(parameter), var_helpers_(var_helpers) {}

OpHelper& OpHelper::operator()(const Arguments& input) {
  proto::OpDesc::Var* var = op_desc_->add_inputs();
  var->set_parameter(input.parameter_);
  for (const VarHelper& var_helper : input.var_helpers_) {
    var->add_arguments()->assign(var_helper.name_);
    if (VarHelper::Type::kInput == var_helper.type_) {
      subgraph_helper_->AddInputVar(var_helper.name_);
    }
  }
  return *this;
}

OpHelper& OpHelper::operator()(std::initializer_list<Arguments> inputs) {
  for (const auto& input : inputs) {
    operator()(input);
  }
  return *this;
}

VarHelper OpHelper::Out(const char* name) {
  std::string argument = patterns::UniqueKey(type_);
  proto::OpDesc::Var* var = op_desc_->add_outputs();
  var->set_parameter(name);
  var->add_arguments()->assign(argument);
  return VarHelper(argument, VarHelper::Type::kOutput);
}

proto::ProgramDesc* SubgraphHelper::ProgramDesc() { return &program_desc_; }

const proto::ProgramDesc& SubgraphHelper::ProgramDesc() const {
  return program_desc_;
}

const std::vector<std::string>& SubgraphHelper::InputVars() const {
  return input_vars_;
}

const std::vector<std::string>& SubgraphHelper::OutputVars() const {
  return output_vars_;
}

void SubgraphHelper::AddInputVar(const std::string& name) {
  auto iter = std::find(input_vars_.begin(), input_vars_.end(), name);
  if (input_vars_.end() == iter) {
    input_vars_.push_back(name);
  }
}

void SubgraphHelper::AddOutputVars(const VarHelper& var_helper) {
  output_vars_.push_back(var_helper.name_);
}

}  // namespace generate_pass

PassPairs::PassPairs(const SubgraphType& pattern, const SubgraphType& replace) {
  AddPassDesc(pattern, replace);
}

void PassPairs::AddPassDesc(const SubgraphType& pattern,
                            const SubgraphType& replace) {
  proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs();
  pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc());
  pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc());
  PADDLE_ENFORCE_EQ(pattern.InputVars().size(), replace.InputVars().size(),
                    platform::errors::InvalidArgument(
                        "Size of lambda expression arguments is not equal "
                        "between pattern/replace subgraph."));
  for (size_t i = 0; i < pattern.InputVars().size(); i++) {
    proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
    var_map->set_pattern_var(pattern.InputVars()[i]);
    var_map->set_replace_var(replace.InputVars()[i]);
  }
  PADDLE_ENFORCE_EQ(pattern.OutputVars().size(), replace.OutputVars().size(),
                    platform::errors::InvalidArgument(
                        "Size of lambda expression returns is not equal "
                        "between pattern/replace subgraph."));
  for (size_t i = 0; i < pattern.OutputVars().size(); i++) {
    proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps();
    var_map->set_pattern_var(pattern.OutputVars()[i]);
    var_map->set_replace_var(replace.OutputVars()[i]);
  }
}

const proto::MultiPassDesc& PassPairs::MultiPassDesc() const {
  return multi_pass_desc_;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/generate_pass.h

@@ -13,7 +13,6 @@
// limitations under the License.

#pragma once

// (the #include of graph_pattern_detector.h is removed here; it moves into generate_pass.cc)
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/pass_desc.pb.h"

...

@@ -43,6 +42,158 @@ class GeneratePass : public Pass {
  proto::MultiPassDesc multi_pass_desc_;
};

namespace generate_pass {

class VarHelper;
class OpHelper;
class SubgraphHelper;

// VarHelper is used to represent a variable node.
struct VarHelper {
  enum class Type { kInput, kOutput };

  explicit VarHelper(const char* name);
  VarHelper(const std::string& name, Type type);

  std::string name_;
  Type type_;
};

// OpHelper is used to represent a operator node.
class OpHelper {
 public:
  // Convert multiple inputs.
  struct Arguments {
    Arguments(const char* parameter, const VarHelper& var_helper);
    Arguments(const char* parameter,
              std::initializer_list<VarHelper> var_helpers);

    std::string parameter_;
    std::vector<VarHelper> var_helpers_;
  };

  OpHelper(const char* type, SubgraphHelper* subgraph_helper);

  OpHelper& operator()(const Arguments& input);
  OpHelper& operator()(std::initializer_list<Arguments> inputs);

  VarHelper Out(const char* name);

 private:
  OpHelper() = delete;
  DISABLE_COPY_AND_ASSIGN(OpHelper);

  const char* type_;
  proto::OpDesc* op_desc_;
  SubgraphHelper* subgraph_helper_;
};

/*
 * SubgraphHelper is used to define pattern/replace subgraphs.
 *
 * Use lambda expression to define subgraph like Python. SubgraphHelper
 * converts lambda expression to ProgramDesc.
 *
 * In order to define a subgraph, user need to use VarHelper and OpHelper.
 * Use the macros instead of class names, so user can develop better and
 * don't need to know too much about underlying implementation.
 *
 * An example of defining a subgraph as follows:
 *
 *   SUBGRAPH_(subgraph)([subgraph=&subgraph](VAR_(x), VAR_(y), VAR_(z)) {
 *     auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
 *     auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
 *     return ewadd2;
 *   });
 *
 */
class SubgraphHelper {
 public:
  SubgraphHelper() = default;
  // The lambda expression is a prvalue expression.
  template <typename T>
  SubgraphHelper& operator=(const T&& f) {
    proto::BlockDesc* block = program_desc_.add_blocks();
    block->set_idx(0);
    block->set_parent_idx(0);
    AddOutputVars(f());
    return *this;
  }

  proto::ProgramDesc* ProgramDesc();
  const proto::ProgramDesc& ProgramDesc() const;
  const std::vector<std::string>& InputVars() const;
  const std::vector<std::string>& OutputVars() const;

  void AddInputVar(const std::string& name);

  void AddOutputVars(const VarHelper& var_helper);

  template <size_t i, typename... Ts,
            std::enable_if_t<i + 1 < sizeof...(Ts)>* = nullptr>
  void AddOutputVars(const std::tuple<Ts...>& outputs) {
    AddOutputVars(std::get<i>(outputs));
    AddOutputVars<i + 1>(outputs);
  }

  template <size_t i, typename... Ts,
            std::enable_if_t<i + 1 == sizeof...(Ts)>* = nullptr>
  void AddOutputVars(const std::tuple<Ts...>& outputs) {
    AddOutputVars(std::get<i>(outputs));
  }

  template <typename... Ts>
  void AddOutputVars(const std::tuple<Ts...>& outputs) {
    AddOutputVars<0>(outputs);
  }

 private:
  DISABLE_COPY_AND_ASSIGN(SubgraphHelper);
  std::vector<std::string> input_vars_;
  std::vector<std::string> output_vars_;
  proto::ProgramDesc program_desc_;
};

}  // namespace generate_pass

class PassPairs {
 public:
  using SubgraphType = generate_pass::SubgraphHelper;

  PassPairs() = default;
  PassPairs(const SubgraphType& pattern, const SubgraphType& replace);

  void AddPassDesc(const SubgraphType& pattern, const SubgraphType& replace);

  const proto::MultiPassDesc& MultiPassDesc() const;

 private:
  proto::MultiPassDesc multi_pass_desc_;
};

// Use function to register in CC.
template <PassPairs (*Functor)(void)>
class MacroPassHelper : public GeneratePass {
 public:
  MacroPassHelper() : GeneratePass(Functor().MultiPassDesc()) {}
};

#define VAR_(name)                                                \
  ::paddle::framework::ir::generate_pass::VarHelper name =       \
      ::paddle::framework::ir::generate_pass::VarHelper(#name)
#define OP_(type) \
  ::paddle::framework::ir::generate_pass::OpHelper(#type, subgraph)
#define SUBGRAPH_(name)                                           \
  ::paddle::framework::ir::generate_pass::SubgraphHelper name;   \
  name
#define REGISTER_GENERATE_PASS(pass_type)                                   \
  paddle::framework::ir::PassPairs register_##pass_type();                  \
  REGISTER_PASS(                                                            \
      pass_type,                                                            \
      ::paddle::framework::ir::MacroPassHelper<&register_##pass_type>);     \
  paddle::framework::ir::PassPairs register_##pass_type()

}  // namespace ir
}  // namespace framework
}  // namespace paddle
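As a reading aid, roughly what REGISTER_GENERATE_PASS(foo_fuse) { ... } expands to under the macros above; foo_fuse is a hypothetical pass name, and REGISTER_PASS is the pre-existing pass-registry macro:

    // Declaration of the user-supplied factory function.
    paddle::framework::ir::PassPairs register_foo_fuse();
    // Register a GeneratePass subclass whose constructor calls the factory and
    // feeds the resulting proto::MultiPassDesc to GeneratePass.
    REGISTER_PASS(foo_fuse,
                  ::paddle::framework::ir::MacroPassHelper<&register_foo_fuse>);
    // The user's `{ ... }` body becomes the definition of the factory function.
    paddle::framework::ir::PassPairs register_foo_fuse() { /* user body */ }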
paddle/fluid/framework/ir/generate_pass_tester.cc

@@ -16,234 +16,71 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

[Removed by this hunk: the CXXGeneratePass<&Functor> helper template, the local REGISTER_GENERATE_PASS(pass_type, function) macro, and the generate_fc_fuse(), generate_multi_add_to_addn(), and generate_combine_matmul() functions inside namespace paddle::framework::ir that assembled proto::MultiPassDesc by hand (ops, inputs/outputs, var maps, and attr maps field by field), together with the three REGISTER_GENERATE_PASS(...) invocations at the end of the file. The macro-based registrations that replace them are:]

REGISTER_GENERATE_PASS(generate_fc_fuse) {
  paddle::framework::ir::PassPairs pass_pairs;
  for (bool with_relu : {true, false}) {
    // pattern
    SUBGRAPH_(pattern) =
        [subgraph = &pattern, with_relu](VAR_(x), VAR_(y), VAR_(z)) {
          VLOG(3) << "exec lambda func.";
          auto mul = OP_(mul)({{"X", x}, {"Y", y}}).Out("Out");
          auto ewadd = OP_(elementwise_add)({{"X", mul}, {"Y", z}}).Out("Out");
          if (with_relu) {
            return OP_(relu)({"X", ewadd}).Out("Out");
          } else {
            return ewadd;
          }
        };
    // replace
    SUBGRAPH_(replace) =
        [subgraph = &replace, with_relu](VAR_(x), VAR_(y), VAR_(z)) {
          auto& fc = OP_(fc)({{"Input", x}, {"W", y}, {"Bias", z}});
          return fc.Out("Out");
        };
    pass_pairs.AddPassDesc(pattern, replace);
  }
  return pass_pairs;
}

REGISTER_GENERATE_PASS(generate_multi_add_to_addn) {
  // pattern
  SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
    auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
    auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
    return ewadd2;
  };
  // replace
  SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
    return OP_(sum)({"X", {x, y, z}}).Out("Out");
  };
  return {pattern, replace};
}

REGISTER_GENERATE_PASS(generate_combine_matmul) {
  // pattern
  SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
    auto matmul1 = OP_(matmul)({{"X", x}, {"Y", y}}).Out("Out");
    auto matmul2 = OP_(matmul)({{"X", x}, {"Y", z}}).Out("Out");
    return std::make_tuple(matmul1, matmul2);
  };
  // replace
  SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
    auto concat = OP_(concat)({"X", {y, z}}).Out("Out");
    auto matmul = OP_(matmul)({{"X", x}, {"Y", concat}}).Out("Out");
    auto slice1 = OP_(slice)({"X", matmul}).Out("Out");
    auto slice2 = OP_(slice)({"X", matmul}).Out("Out");
    return std::make_tuple(slice1, slice2);
  };
  return {pattern, replace};
}

namespace paddle {
namespace framework {
namespace ir {

TEST(GeneratePass, construct_with_string) {
  std::string binary_str;
  // previously: generate_fc_fuse().SerializeToString(&binary_str);
  register_generate_fc_fuse().MultiPassDesc().SerializeToString(&binary_str);
  GeneratePass generate_pass(binary_str);
}

...

@@ -318,7 +155,7 @@ TEST(GeneratePass, generate_multi_add_to_addn) {
  graph.reset(pass->Apply(graph.release()));
  int num_nodes_after = graph->Nodes().size();
  int num_addn_nodes_after = GetNumOpNodes(graph, "sum");  // previously: GetNumOpNodes(graph, "add_n")
  VLOG(3) << DebugString(graph);
  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 2,
...
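For reference, a rough sketch of how the registered pass is exercised in TEST(GeneratePass, generate_multi_add_to_addn), assuming a graph already built from the test program with the helpers in pass_tester_helper.h; the exact setup code is unchanged by this commit and not shown here:

    // `graph` is a std::unique_ptr<Graph> built from the test program; the
    // pass is looked up under the name it was registered with above.
    auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
        "generate_multi_add_to_addn");
    int num_nodes_before = graph->Nodes().size();
    graph.reset(pass->Apply(graph.release()));
    int num_nodes_after = graph->Nodes().size();
    int num_sum_nodes_after = GetNumOpNodes(graph, "sum");  // fused op count
    VLOG(3) << DebugString(graph);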