Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ca552933
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca552933
编写于
11月 30, 2022
作者:
Z
zhangbo9674
提交者:
GitHub
11月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fuse_act_add_grad_pass (#48346)
* add fuse act add grad pass * polish code * refine code * add test * refine code
上级
e337d280
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
272 addition
and
12 deletion
+272
-12
paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
+128
-1
paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
+8
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+27
-2
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+21
-0
paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
...fluid/framework/new_executor/interpreter/data_transfer.cc
+0
-1
paddle/fluid/operators/fused/fused_elemwise_activation_op.h
paddle/fluid/operators/fused/fused_elemwise_activation_op.h
+20
-5
python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py
...e/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py
+68
-3
未找到文件。
paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
浏览文件 @
ca552933
...
...
@@ -31,6 +31,7 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
{
std
::
unordered_set
<
std
::
string
>
in_place_act_types
=
{
"relu_grad"
};
graph
=
FuseElewiseAddActInplaceGrad
(
graph
,
in_place_act_types
);
graph
=
FuseActElewiseAddInplaceGrad
(
graph
,
in_place_act_types
);
}
// Remove the removable intermediate_out.
...
...
@@ -110,7 +111,7 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle Fuse
ElewiseAddAct
fuse"
;
VLOG
(
4
)
<<
"handle Fuse
ActElewiseAdd
fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
act_out
,
act_out
,
act_elewise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ele_x
,
ele_x
,
act_elewise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
...
...
@@ -220,6 +221,86 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
return
graph
;
}
// the backward of act(ele_add(x,y))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
ir
::
Graph
*
FuseElewiseAddActPass
::
FuseActElewiseAddInplaceGrad
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>
&
act_types
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
FusePassBase
::
Init
(
"act_elewise_add_grad"
,
graph
);
GraphPatternDetector
gpd
;
auto
*
d_out_var
=
gpd
.
mutable_pattern
()
->
NewNode
(
"act_elewise_add_grad_inplace/d_out_var"
)
->
AsInput
()
->
assert_is_ops_input
({
"elementwise_add_grad"
},
GradVarName
(
"Out"
));
patterns
::
ActElewiseAddInplaceGrad
act_elewise_add_grad_pattern
(
gpd
.
mutable_pattern
(),
"act_elewise_add_grad_inplace"
);
act_elewise_add_grad_pattern
(
d_out_var
,
act_types
);
int
found_elewise_add_act_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle ActFuseElewiseAddGrad1 fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
ele_add_grad_op
,
ele_add_grad_op
,
act_elewise_add_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act_grad_op
,
act_grad_op
,
act_elewise_add_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
intermediate_var
,
intermediate_var
,
act_elewise_add_grad_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
d_intermediate_var
,
d_intermediate_var
,
act_elewise_add_grad_pattern
);
std
::
string
d_out_var_n
=
subgraph
.
at
(
d_out_var
)
->
Name
();
std
::
string
intermediate_var_n
=
intermediate_var
->
Name
();
std
::
string
d_intermediate_var_n
=
d_intermediate_var
->
Name
();
OpDesc
desc
;
desc
.
SetType
(
"fused_elemwise_add_activation_grad"
);
desc
.
SetInput
(
"IntermediateOut"
,
std
::
vector
<
std
::
string
>
({
intermediate_var_n
}));
desc
.
SetInput
(
"X"
,
{});
desc
.
SetInput
(
"Y"
,
ele_add_grad_op
->
Op
()
->
Input
(
"X"
));
desc
.
SetInput
(
"Out"
,
{});
desc
.
SetInput
(
GradVarName
(
"Out"
),
std
::
vector
<
std
::
string
>
({
d_out_var_n
}));
desc
.
SetOutput
(
GradVarName
(
"X"
),
act_grad_op
->
Op
()
->
Output
(
GradVarName
(
"X"
)));
desc
.
SetOutput
(
GradVarName
(
"Y"
),
ele_add_grad_op
->
Op
()
->
Output
(
GradVarName
(
"X"
)));
desc
.
SetOutput
(
GradVarName
(
"IntermediateOut"
),
std
::
vector
<
std
::
string
>
({
d_intermediate_var_n
}));
desc
.
SetAttr
(
"save_intermediate_out"
,
false
);
desc
.
SetAttr
(
"functor_list"
,
std
::
vector
<
std
::
string
>
({
ele_add_grad_op
->
Op
()
->
Type
(),
act_grad_op
->
Op
()
->
Type
()}));
for
(
auto
&
n
:
{
ele_add_grad_op
->
Op
(),
act_grad_op
->
Op
()})
{
for
(
auto
&
m_ele
:
n
->
GetAttrMap
())
{
desc
.
SetAttr
(
m_ele
.
first
,
m_ele
.
second
);
}
}
auto
fused_node
=
g
->
CreateOpNode
(
&
desc
);
VLOG
(
4
)
<<
"
\n\t
"
<<
d_out_var_n
<<
" -> "
<<
ele_add_grad_op
->
Name
()
<<
" -> "
<<
d_intermediate_var_n
<<
"
\n\t
"
<<
intermediate_var_n
<<
" and "
<<
d_intermediate_var_n
<<
" -> "
<<
act_grad_op
->
Name
();
ReLinkNodes2
(
g
,
d_intermediate_var
,
ele_add_grad_op
,
act_grad_op
,
fused_node
);
found_elewise_add_act_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_elewise_add_act_count
);
return
graph
;
}
Node
*
FuseElewiseAddActPass
::
CreateFuseElewiseAddActNode
(
Graph
*
g
,
const
Node
*
op_1
,
...
...
@@ -364,6 +445,52 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
GraphSafeRemoveNodes
(
graph
,
nodes2delete
);
}
void
FuseElewiseAddActPass
::
ReLinkNodes2
(
Graph
*
graph
,
const
Node
*
intermediate_out
,
Node
*
op_1
,
Node
*
op_2
,
Node
*
fused_op
)
const
{
// delete act
for
(
auto
&
in
:
op_1
->
inputs
)
{
fused_op
->
inputs
.
emplace_back
(
in
);
in
->
outputs
=
this
->
ReplaceNode
(
op_1
,
fused_op
,
in
->
outputs
);
}
std
::
unordered_set
<
const
Node
*>
nodes2delete
;
for
(
auto
&
out
:
op_1
->
outputs
)
{
if
(
out
->
IsCtrlVar
())
{
auto
result_iter
=
std
::
find_if
(
op_2
->
inputs
.
begin
(),
op_2
->
inputs
.
end
(),
[
&
out
](
const
Node
*
node
)
->
bool
{
return
node
==
out
;
});
if
(
result_iter
==
op_2
->
inputs
.
end
())
{
IR_OP_VAR_LINK
(
fused_op
,
out
);
}
else
{
nodes2delete
.
emplace
(
out
);
}
}
else
{
IR_OP_VAR_LINK
(
fused_op
,
out
);
}
}
for
(
auto
&
in
:
op_2
->
inputs
)
{
if
(
in
==
intermediate_out
||
nodes2delete
.
count
(
in
))
{
continue
;
}
fused_op
->
inputs
.
emplace_back
(
in
);
in
->
outputs
=
this
->
ReplaceNode
(
op_2
,
fused_op
,
in
->
outputs
);
}
for
(
auto
&
out
:
op_2
->
outputs
)
{
IR_OP_VAR_LINK
(
fused_op
,
out
);
}
nodes2delete
.
insert
(
std
::
move
(
op_1
));
nodes2delete
.
insert
(
std
::
move
(
op_2
));
GraphSafeRemoveNodes
(
graph
,
nodes2delete
);
}
std
::
vector
<
Node
*>
FuseElewiseAddActPass
::
ReplaceNode
(
Node
*
cur_node
,
Node
*
new_node
,
const
std
::
vector
<
Node
*>
&
nodes
)
const
{
std
::
vector
<
Node
*>
new_list
(
nodes
.
size
());
...
...
paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
浏览文件 @
ca552933
...
...
@@ -49,6 +49,9 @@ class FuseElewiseAddActPass : public FusePassBase {
ir
::
Graph
*
FuseElewiseAddActInplaceGrad
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>
&
act_types
)
const
;
ir
::
Graph
*
FuseActElewiseAddInplaceGrad
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>
&
act_types
)
const
;
/**
* Remove the removable intermediate_out.
* - If the intermediate_out is only used by the backward op, but the
...
...
@@ -69,6 +72,11 @@ class FuseElewiseAddActPass : public FusePassBase {
Node
*
op_1
,
Node
*
op_2
,
Node
*
fused_op
)
const
;
void
ReLinkNodes2
(
Graph
*
graph
,
const
Node
*
intermediate_out
,
Node
*
op_1
,
Node
*
op_2
,
Node
*
fused_op
)
const
;
Node
*
CreateFuseElewiseAddActNode
(
Graph
*
g
,
const
Node
*
op_1
,
const
Node
*
op_2
,
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
ca552933
...
...
@@ -91,7 +91,6 @@ void GraphPatternDetector::operator()(Graph *graph,
if
(
!
MarkPDNodesInGraph
(
*
graph
))
{
return
;
}
auto
subgraphs
=
DetectPatterns
();
UniquePatterns
(
&
subgraphs
);
SortSubgraphs
(
&
subgraphs
);
...
...
@@ -99,7 +98,6 @@ void GraphPatternDetector::operator()(Graph *graph,
ValidateByNodeRole
(
&
subgraphs
);
if
(
subgraphs
.
empty
())
return
;
int
id
=
0
;
for
(
auto
&
g
:
subgraphs
)
{
VLOG
(
3
)
<<
"optimizing #"
<<
id
++
<<
" subgraph"
;
...
...
@@ -1613,6 +1611,33 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return
ele_add_grad
;
}
PDNode
*
patterns
::
ActElewiseAddInplaceGrad
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
d_out_var
,
std
::
unordered_set
<
std
::
string
>
act_types
)
{
VLOG
(
4
)
<<
"ActElewiseAddInplaceGrad::operator"
;
auto
*
ele_add_grad_op
=
pattern
->
NewNode
(
ele_add_grad_op_repr
())
->
assert_is_op
(
"elementwise_add_grad"
);
auto
*
act_grad_op
=
pattern
->
NewNode
(
act_grad_op_repr
())
->
assert_is_ops
(
act_types
);
auto
*
d_intermediate_out_var
=
pattern
->
NewNode
(
d_intermediate_var_repr
())
->
assert_is_op_output
(
"elementwise_add_grad"
,
GradVarName
(
"Y"
))
->
assert_is_ops_input
(
act_types
,
GradVarName
(
"Out"
));
auto
*
intermediate_out_var
=
pattern
->
NewNode
(
intermediate_var_repr
())
->
assert_is_op_input
(
"elementwise_add_grad"
,
"Y"
)
->
assert_is_ops_input
(
act_types
,
"Out"
);
ele_add_grad_op
->
LinksFrom
({
d_out_var
});
d_intermediate_out_var
->
LinksFrom
({
ele_add_grad_op
}).
LinksTo
({
act_grad_op
});
intermediate_out_var
->
LinksTo
({
ele_add_grad_op
});
intermediate_out_var
->
LinksTo
({
act_grad_op
});
return
act_grad_op
;
}
PDNode
*
patterns
::
ElewiseAddAct
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
ele_x_var
,
std
::
unordered_set
<
std
::
string
>
act_types
)
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
ca552933
...
...
@@ -928,6 +928,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
PATTERN_DECL_NODE
(
ele_y
);
};
// the backward of ele_add(act(x), y)
// the act is inplace.
// op: elementwise_add_grad + act_grad
// named nodes: elementwise_add_grad, act_grad
// ele_y, d_ele_y, d_intermeiate_out, intermediate_out, d_x
struct
ActElewiseAddInplaceGrad
:
public
PatternBase
{
ActElewiseAddInplaceGrad
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"act_elewise_add_grad1"
)
{}
// ele_add_grad: in["Y", "Out@GRAD"], out["IntermediateOut@GRAD", "Y@GRAD"]
// act_grad: in["IntermediateOut", "IntermediateOut@GRAD"], out["X@GRAD"]
PDNode
*
operator
()(
PDNode
*
d_out_var
,
std
::
unordered_set
<
std
::
string
>
acts
);
// declare operator node's name
PATTERN_DECL_NODE
(
ele_add_grad_op
);
PATTERN_DECL_NODE
(
act_grad_op
);
// // declare variable node's name
PATTERN_DECL_NODE
(
intermediate_var
);
PATTERN_DECL_NODE
(
d_intermediate_var
);
};
// The following patterns are used to fuse linear and act (ReLu or GeLU)
// formula: act(F.linear(x))
// op: matmul_v2 + elementwise_add + act
...
...
paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
浏览文件 @
ca552933
...
...
@@ -462,7 +462,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
for
(
auto
&
var_name_item
:
*
ins_map_temp
)
{
bool
should_skip_input
=
no_buffer_ins
&&
no_buffer_ins
->
count
(
var_name_item
.
first
)
>
0
;
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
var
=
var_name_item
.
second
[
i
];
auto
var_name
=
new_ins
[
var_name_item
.
first
].
at
(
i
);
...
...
paddle/fluid/operators/fused/fused_elemwise_activation_op.h
浏览文件 @
ca552933
...
...
@@ -664,11 +664,9 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
in_y
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) should not be nullptr."
));
auto
in_out
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Out"
);
PADDLE_ENFORCE_NE
(
in_out
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Input(Out) should not be nullptr."
));
phi
::
DenseTensor
*
in_out
=
const_cast
<
phi
::
DenseTensor
*>
(
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Out"
));
auto
in_out_grad
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Out"
));
PADDLE_ENFORCE_NE
(
in_out_grad
,
...
...
@@ -726,6 +724,23 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
in_x
=
const_cast
<
phi
::
DenseTensor
*>
(
in_out_grad
);
}
// Get in_Out
if
(
ctx
.
HasInput
(
"Out"
))
{
PADDLE_ENFORCE_NE
(
in_out
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Input(X) should not be null."
));
}
else
{
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, in_y and in_out.
PADDLE_ENFORCE_EQ
(
InputXCanBeAbsent
(
functor_list
),
true
,
platform
::
errors
::
InvalidArgument
(
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent."
));
in_out
=
const_cast
<
phi
::
DenseTensor
*>
(
in_out_grad
);
}
bool
has_in_place
=
HasInPlaceUnary
(
functor_list
);
if
(
has_in_place
)
{
RunGradFunctors
<
DeviceContext
,
T
,
true
/*InPlace*/
>
(
ctx
,
...
...
python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py
浏览文件 @
ca552933
...
...
@@ -14,10 +14,11 @@
import
os
import
unittest
import
numpy
from
parallel_executor_test_base
import
DeviceType
,
TestParallelExecutorBase
from
simple_nets
import
fc_with_batchnorm
,
init_data
,
simple_fc_net
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
...
...
@@ -89,8 +90,72 @@ class TestMNIST(TestParallelExecutorBase):
)
if
__name__
==
'__main__'
:
import
paddle
class
TestFuseActElewiseAddInplaceGradPass
(
unittest
.
TestCase
):
def
build_program
(
self
,
main_program
,
startup_program
):
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
X
=
fluid
.
data
(
name
=
"X"
,
shape
=
[
3
,
3
],
dtype
=
'float32'
)
Y
=
fluid
.
data
(
name
=
"Y"
,
shape
=
[
3
,
3
],
dtype
=
'float32'
)
Out1
=
X
*
5
Out2
=
fluid
.
layers
.
relu
(
Out1
)
prediction
=
fluid
.
layers
.
elementwise_add
(
Y
,
Out2
,
axis
=
1
)
loss
=
paddle
.
mean
(
prediction
)
sgd
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd
.
minimize
(
loss
)
return
X
,
Y
,
loss
def
check
(
self
,
place
):
paddle
.
seed
(
1
)
numpy
.
random
.
seed
(
1
)
paddle
.
framework
.
random
.
_manual_program_seed
(
1
)
main_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
X
,
Y
,
loss
=
self
.
build_program
(
main_program
,
startup_program
)
exe
=
fluid
.
Executor
(
place
)
x
=
numpy
.
random
.
random
(
size
=
(
3
,
3
)).
astype
(
'float32'
)
y
=
numpy
.
random
.
random
(
size
=
(
3
,
3
)).
astype
(
'float32'
)
label
=
numpy
.
random
.
random
(
size
=
(
3
,
3
)).
astype
(
'float32'
)
# open fused_pass
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
fuse_elewise_add_act_ops
=
True
compiled_prog_fused
=
paddle
.
static
.
CompiledProgram
(
main_program
,
build_strategy
=
build_strategy
)
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
exe
.
run
(
startup_program
)
loss_data_fused
=
exe
.
run
(
compiled_prog_fused
,
feed
=
{
"X"
:
x
,
"Y"
:
y
},
fetch_list
=
[
loss
.
name
],
)
# close fused_pass
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
fuse_elewise_add_act_ops
=
False
compiled_prog
=
paddle
.
static
.
CompiledProgram
(
main_program
,
build_strategy
=
build_strategy
)
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
exe
.
run
(
startup_program
)
loss_data
=
exe
.
run
(
compiled_prog
,
feed
=
{
"X"
:
x
,
"Y"
:
y
},
fetch_list
=
[
loss
.
name
]
)
self
.
assertEqual
(
loss_data_fused
,
loss_data
)
def
test_fuse_act_add_grad_pass_cpu
(
self
):
place
=
fluid
.
CPUPlace
()
self
.
check
(
place
)
def
test_fuse_act_add_grad_pass_cuda
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
place
=
fluid
.
CUDAPlace
(
0
)
self
.
check
(
place
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录