Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ca9e77a8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca9e77a8
编写于
3月 01, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
3月 01, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sum op support for fusion group (#22771)
* Add the codegen and auto fusion for sum Op in fusion group
上级
b681215a
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
76 addition
and
46 deletion
+76
-46
paddle/fluid/framework/ir/fusion_group/code_generator.cc
paddle/fluid/framework/ir/fusion_group/code_generator.cc
+12
-9
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
.../fluid/framework/ir/fusion_group/code_generator_helper.cc
+25
-2
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
...e/fluid/framework/ir/fusion_group/code_generator_helper.h
+2
-1
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
...d/framework/ir/fusion_group/elementwise_group_detector.cc
+7
-30
paddle/fluid/framework/ir/fusion_group/operation.cc
paddle/fluid/framework/ir/fusion_group/operation.cc
+15
-3
paddle/fluid/framework/ir/fusion_group/operation.h
paddle/fluid/framework/ir/fusion_group/operation.h
+2
-1
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
...dle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+13
-0
未找到文件。
paddle/fluid/framework/ir/fusion_group/code_generator.cc
浏览文件 @
ca9e77a8
...
...
@@ -60,18 +60,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// - X, Y in forward operations
// - X, Y, Out, out@GRAD in backward operations
std
::
vector
<
int
>
input_ids
;
std
::
vector
<
std
::
string
>
input_names
=
OperationMap
::
Instance
().
Get
(
op
->
Type
()).
input_names
;
auto
operation
=
OperationMap
::
Instance
().
Get
(
op
->
Type
());
std
::
vector
<
std
::
string
>
input_names
=
operation
.
input_names
;
for
(
auto
&
name
:
input_names
)
{
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if
(
HasInput
(
node
,
name
)
&&
op
->
Input
(
name
).
size
()
>=
1U
)
{
// TODO(liuyiqun): support duplicated input.
if
((
HasInput
(
node
,
name
)
&&
op
->
Input
(
name
).
size
()
>=
1U
))
{
for
(
size_t
i
=
0
;
i
<
op
->
Input
(
name
).
size
();
i
++
)
{
PADDLE_ENFORCE_NE
(
var_ids
.
find
(
op
->
Input
(
name
)[
0
]),
var_ids
.
end
(),
var_ids
.
find
(
op
->
Input
(
name
)[
i
]),
var_ids
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Input(%s) of operation %s is not set."
,
name
,
op
->
Type
()));
input_ids
.
push_back
(
var_ids
[
op
->
Input
(
name
)[
0
]]);
input_ids
.
push_back
(
var_ids
[
op
->
Input
(
name
)[
i
]]);
}
}
else
{
input_ids
.
push_back
(
-
1
);
}
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
浏览文件 @
ca9e77a8
...
...
@@ -33,9 +33,32 @@ static T StringTo(const std::string& str) {
return
value
;
}
static
std
::
string
ExpandMultivariateTemplate
(
const
std
::
string
rhs
,
const
size_t
input_size
)
{
int
start_pos
=
rhs
.
find
(
"["
,
0
);
int
end_pos
=
rhs
.
find
(
"]"
,
0
);
std
::
string
sum_rhs
=
rhs
.
substr
(
0
,
start_pos
);
std
::
string
sum_rhs_component
=
rhs
.
substr
(
start_pos
+
1
,
(
end_pos
-
start_pos
-
1
));
int
replace_pos
=
sum_rhs_component
.
find
(
"?"
,
0
);
for
(
size_t
i
=
1
;
i
<
input_size
;
i
++
)
{
std
::
string
append_str
=
sum_rhs_component
.
replace
(
replace_pos
,
1
,
std
::
to_string
(
i
));
sum_rhs
=
sum_rhs
+
append_str
;
}
return
sum_rhs
;
}
std
::
string
OperationExpression
::
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
size_t
i
)
const
{
auto
rhs
=
OperationMap
::
Instance
().
Get
(
op_type_
).
exprs
[
i
];
size_t
exprs_index
)
const
{
auto
rhs
=
OperationMap
::
Instance
().
Get
(
op_type_
).
exprs
[
exprs_index
];
auto
num_operands
=
OperationMap
::
Instance
().
Get
(
op_type_
).
num_operands
;
if
(
num_operands
==
-
1
)
{
size_t
input_size
=
input_ids_
.
size
();
rhs
=
ExpandMultivariateTemplate
(
rhs
,
input_size
);
}
for
(
size_t
i
=
0
;
i
<
rhs
.
size
();
i
++
)
{
size_t
pos
=
i
;
if
(
rhs
[
pos
]
==
'$'
&&
rhs
[
pos
+
1
]
==
'{'
)
{
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
浏览文件 @
ca9e77a8
...
...
@@ -52,7 +52,8 @@ class OperationExpression {
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset
std
::
string
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
size_t
i
=
0
)
const
;
std
::
string
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
size_t
exprs_index
=
0
)
const
;
std
::
string
GetLHS
(
size_t
i
=
0
)
const
;
private:
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
浏览文件 @
ca9e77a8
...
...
@@ -24,23 +24,13 @@ namespace framework {
namespace
ir
{
namespace
fusion_group
{
static
std
::
unordered_set
<
std
::
string
>
binary_op_types
;
static
std
::
unordered_set
<
std
::
string
>
unary_op_types
;
static
std
::
unordered_set
<
std
::
string
>
elementwise_op_types
;
static
std
::
unordered_set
<
std
::
string
>&
GetBinaryOpTypes
()
{
if
(
binary_op_types
.
empty
())
{
binary_op_types
=
OperationMap
::
Instance
().
Find
(
/* type= */
0
,
/* num_operands= */
2
);
static
std
::
unordered_set
<
std
::
string
>&
GetElementwiseOpTypes
()
{
if
(
elementwise_op_types
.
empty
())
{
elementwise_op_types
=
OperationMap
::
Instance
().
Find
(
/* type= */
0
);
}
return
binary_op_types
;
}
static
std
::
unordered_set
<
std
::
string
>&
GetUnaryOpTypes
()
{
if
(
unary_op_types
.
empty
())
{
unary_op_types
=
OperationMap
::
Instance
().
Find
(
/* type= */
0
,
/* num_operands= */
1
);
}
return
unary_op_types
;
return
elementwise_op_types
;
}
static
bool
IsSpecifiedOp
(
const
std
::
unordered_set
<
std
::
string
>&
op_types
,
...
...
@@ -70,13 +60,8 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return
l
.
size
()
!=
0U
&&
r
.
size
()
!=
0U
&&
l
==
r
;
}
static
bool
IsBinaryOp
(
const
Node
*
n
)
{
if
(
IsSpecifiedOp
(
GetBinaryOpTypes
(),
n
))
{
if
((
!
IsGradOp
(
n
)
&&
n
->
inputs
.
size
()
!=
2U
)
||
n
->
inputs
.
size
()
==
0U
)
{
return
false
;
}
// The shape of all inputs should be the same.
bool
ElementwiseGroupDetector
::
IsElementwiseOp
(
const
Node
*
n
)
{
if
(
IsSpecifiedOp
(
GetElementwiseOpTypes
(),
n
))
{
std
::
vector
<
int64_t
>
shape_0
;
for
(
size_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
auto
*
in_i
=
n
->
inputs
[
i
];
...
...
@@ -98,14 +83,6 @@ static bool IsBinaryOp(const Node* n) {
return
false
;
}
static
bool
IsUnaryOp
(
const
Node
*
n
)
{
return
IsSpecifiedOp
(
GetUnaryOpTypes
(),
n
);
}
bool
ElementwiseGroupDetector
::
IsElementwiseOp
(
const
Node
*
n
)
{
return
IsBinaryOp
(
n
)
||
IsUnaryOp
(
n
);
}
std
::
vector
<
std
::
vector
<
Node
*>>
ElementwiseGroupDetector
::
operator
()(
Graph
*
graph
)
{
auto
teller
=
[
&
](
const
Node
*
n
)
->
bool
{
return
IsElementwiseOp
(
n
);
};
...
...
paddle/fluid/framework/ir/fusion_group/operation.cc
浏览文件 @
ca9e77a8
...
...
@@ -25,13 +25,13 @@ OperationMap* OperationMap::map = nullptr;
OperationMap
::
OperationMap
()
{
InsertUnaryElementwiseOperations
();
InsertBinaryElementwiseOperations
();
InsertMultivariateElementwiseOperations
();
}
std
::
unordered_set
<
std
::
string
>
OperationMap
::
Find
(
int
type
,
int
num_operands
)
{
std
::
unordered_set
<
std
::
string
>
OperationMap
::
Find
(
int
type
)
{
std
::
unordered_set
<
std
::
string
>
res
;
for
(
auto
&
t
:
operations_
)
{
if
((
t
.
second
.
type
==
type
)
&&
(
num_operands
<
0
||
t
.
second
.
num_operands
==
num_operands
))
{
if
(
t
.
second
.
type
==
type
)
{
res
.
insert
(
t
.
first
);
}
}
...
...
@@ -153,6 +153,18 @@ void OperationMap::InsertBinaryElementwiseOperations() {
{
"${3} * (${0} > ${1})"
,
"${3} * (${0} <= ${1})"
});
}
void
OperationMap
::
InsertMultivariateElementwiseOperations
()
{
auto
insert_handler
=
[
&
](
std
::
string
op_type
,
std
::
string
expr
,
std
::
vector
<
std
::
string
>
grad_exprs
)
{
int
type
=
0
;
int
num_oprands
=
-
1
;
// here ... represent the number of input is changed
Insert
(
type
,
num_oprands
,
op_type
,
expr
,
grad_exprs
,
{
"X"
},
{
"Out"
});
};
insert_handler
(
"sum"
,
"${0}[ + ${?}]"
,
{});
}
}
// namespace fusion_group
}
// namespace ir
}
// namespace framework
...
...
paddle/fluid/framework/ir/fusion_group/operation.h
浏览文件 @
ca9e77a8
...
...
@@ -84,7 +84,7 @@ class OperationMap {
return
*
map
;
}
std
::
unordered_set
<
std
::
string
>
Find
(
int
type
,
int
num_operands
=
-
1
);
std
::
unordered_set
<
std
::
string
>
Find
(
int
type
);
bool
Has
(
std
::
string
op_type
)
{
return
operations_
.
find
(
op_type
)
!=
operations_
.
end
();
...
...
@@ -106,6 +106,7 @@ class OperationMap {
void
InsertUnaryElementwiseOperations
();
void
InsertBinaryElementwiseOperations
();
void
InsertMultivariateElementwiseOperations
();
private:
static
OperationMap
*
map
;
...
...
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
浏览文件 @
ca9e77a8
...
...
@@ -138,5 +138,18 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
self
.
num_fused_ops
=
1
class
FusionGroupPassSumTest
(
FusionGroupPassTest
):
def
build_program
(
self
,
dtype
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
5
)
tmp_0
=
layers
.
elementwise_add
(
self
.
feed_vars
[
0
],
self
.
feed_vars
[
1
])
tmp_1
=
layers
.
sum
([
tmp_0
,
self
.
feed_vars
[
2
],
self
.
feed_vars
[
3
]])
tmp_2
=
layers
.
sum
([
tmp_1
,
self
.
feed_vars
[
4
]])
self
.
fetch_list
=
[
tmp_0
,
tmp_1
]
self
.
num_fused_ops
=
1
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录