Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7f6bb160
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7f6bb160
编写于
6月 21, 2023
作者:
C
csy0225
提交者:
GitHub
6月 21, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU][Inference] Delete redundant squeeze/unsqueeze op. (#54754)
上级
55704db5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
737 addition
and
0 deletion
+737
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/xpu/redundant_squeeze_unsqueeze_elimination_pass.cc
...rk/ir/xpu/redundant_squeeze_unsqueeze_elimination_pass.cc
+546
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
test/ir/inference/test_xpu_redundant_squeeze_unsqueeze_elimination.py
...rence/test_xpu_redundant_squeeze_unsqueeze_elimination.py
+188
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
7f6bb160
...
...
@@ -236,6 +236,8 @@ if(WITH_XPU)
set
(
XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils
)
pass_library
(
yolo_box_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv2d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fc_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
...
...
paddle/fluid/framework/ir/xpu/redundant_squeeze_unsqueeze_elimination_pass.cc
0 → 100644
浏览文件 @
7f6bb160
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
// Delete redundant squeeze/unsqueeze op
/*
For example:
graph:
Input
|
|
squeeze
|
|
squeeze out
|
|
activation(leaky_relu)
|
|
activation out
|
|
unsqueeze
|
|
Output
------------------------------------------------------
After the pass is applied:
Input
|
|
activation(leaky_relu)
|
|
Output
*/
struct
SqueezeActivationUnsqueezeEliminationPattern
:
public
PatternBase
{
SqueezeActivationUnsqueezeEliminationPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
squeeze
);
PATTERN_DECL_NODE
(
act
);
PATTERN_DECL_NODE
(
unsqueeze
);
// declare variable node's name
PATTERN_DECL_NODE
(
squeeze_input
);
PATTERN_DECL_NODE
(
squeeze_out
);
PATTERN_DECL_NODE
(
act_out
);
PATTERN_DECL_NODE
(
unsqueeze_out
);
private:
std
::
string
act_type_
;
};
SqueezeActivationUnsqueezeEliminationPattern
::
SqueezeActivationUnsqueezeEliminationPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act_type
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
act_type_
(
act_type
)
{
// squeeze
auto
squeeze
=
pattern
->
NewNode
(
squeeze_repr
())
->
assert_is_op
(
"squeeze2"
);
auto
squeeze_input
=
pattern
->
NewNode
(
squeeze_input_repr
())
->
assert_is_op_input
(
"squeeze2"
,
"X"
)
->
AsInput
();
auto
squeeze_out
=
pattern
->
NewNode
(
squeeze_out_repr
())
->
assert_is_op_output
(
"squeeze2"
,
"Out"
);
squeeze
->
LinksFrom
({
squeeze_input
}).
LinksTo
({
squeeze_out
});
// activation
auto
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
act_type_
);
auto
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_op_output
(
act_type_
,
"Out"
);
squeeze_out
->
assert_is_op_input
(
act_type_
,
"X"
);
act
->
LinksFrom
({
squeeze_out
}).
LinksTo
({
act_out
});
// unsqueeze
auto
unsqueeze
=
pattern
->
NewNode
(
unsqueeze_repr
())
->
assert_is_op
(
"unsqueeze2"
);
auto
unsqueeze_out
=
pattern
->
NewNode
(
unsqueeze_out_repr
())
->
assert_is_op_output
(
"unsqueeze2"
,
"Out"
)
->
AsOutput
();
act_out
->
assert_is_op_input
(
"unsqueeze2"
,
"X"
);
unsqueeze
->
LinksFrom
({
act_out
}).
LinksTo
({
unsqueeze_out
});
}
/*
Function Description:Delete redundant squeeze/unsqueeze op
Pattern: custom pattern
For example:
graph:
Input1
|
|
squeeze1
|
|
squeeze1 out Input2
| |
| |
activation1(leaky_relu) squeeze2
| |
| |
activation1 out squeeze2 out
| |
| |
- - - - elementwise operation(elementwise_add) - - - -
|
|
activation2(leaky_relu)
|
|
activation2 out
|
|
- - - - - - - - - - - - - - - - - - -
| | | |
| | | |
unsqueeze 1 ...... unsqueeze n-1 unsqueeze n
| | | |
| | | |
Output 1 ...... Output n-1 Output n
------------------------------------------------------
After the pass is applied:
Input1
|
|
activation1(leaky_relu)
|
|
activation1 out Input2
| |
| |
- - - - elementwise operation(elementwise_add) - - - -
|
|
activation2(leaky_relu)
|
|
activation2 out
|
|
- - - - - - - - - - - - - - - - - - -
| | | |
| | | |
Output 1 ...... Output n-1 Output n
*/
struct
CustomSqueezeUnsqueezeEliminationPattern
:
public
PatternBase
{
CustomSqueezeUnsqueezeEliminationPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act1_type
,
const
std
::
string
&
act2_type
,
const
std
::
string
&
elementwise_type
,
const
bool
act1_in_branch_x
);
// declare operator node's name
PATTERN_DECL_NODE
(
squeeze1
);
PATTERN_DECL_NODE
(
squeeze2
);
PATTERN_DECL_NODE
(
act1
);
PATTERN_DECL_NODE
(
elementwise
);
PATTERN_DECL_NODE
(
act2
);
// declare variable node's name
PATTERN_DECL_NODE
(
squeeze1_input
);
PATTERN_DECL_NODE
(
squeeze1_out
);
PATTERN_DECL_NODE
(
act1_out
);
PATTERN_DECL_NODE
(
squeeze2_input
);
PATTERN_DECL_NODE
(
squeeze2_out
);
PATTERN_DECL_NODE
(
elementwise_out
);
PATTERN_DECL_NODE
(
act2_out
);
private:
std
::
string
act1_type_
;
std
::
string
act2_type_
;
std
::
string
elementwise_type_
;
bool
act1_in_branch_x_
;
};
CustomSqueezeUnsqueezeEliminationPattern
::
CustomSqueezeUnsqueezeEliminationPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act1_type
,
const
std
::
string
&
act2_type
,
const
std
::
string
&
elementwise_type
,
const
bool
act1_in_branch_x
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
act1_type_
(
act1_type
),
act2_type_
(
act2_type
),
elementwise_type_
(
elementwise_type
),
act1_in_branch_x_
(
act1_in_branch_x
)
{
// squeeze1
auto
squeeze1
=
pattern
->
NewNode
(
squeeze1_repr
())
->
assert_is_op
(
"squeeze2"
);
auto
squeeze1_input
=
pattern
->
NewNode
(
squeeze1_input_repr
())
->
assert_is_op_input
(
"squeeze2"
,
"X"
)
->
AsInput
();
auto
squeeze1_out
=
pattern
->
NewNode
(
squeeze1_out_repr
())
->
assert_is_op_output
(
"squeeze2"
,
"Out"
);
squeeze1
->
LinksFrom
({
squeeze1_input
}).
LinksTo
({
squeeze1_out
});
// activation1
auto
act1
=
pattern
->
NewNode
(
act1_repr
())
->
assert_is_op
(
act1_type_
);
auto
act1_out
=
pattern
->
NewNode
(
act1_out_repr
())
->
assert_is_op_output
(
act1_type_
,
"Out"
);
squeeze1_out
->
assert_is_op_input
(
act1_type_
,
"X"
);
act1
->
LinksFrom
({
squeeze1_out
}).
LinksTo
({
act1_out
});
// squeeze2
auto
squeeze2
=
pattern
->
NewNode
(
squeeze2_repr
())
->
assert_is_op
(
"squeeze2"
);
auto
squeeze2_input
=
pattern
->
NewNode
(
squeeze2_input_repr
())
->
assert_is_op_input
(
"squeeze2"
,
"X"
)
->
AsInput
();
auto
squeeze2_out
=
pattern
->
NewNode
(
squeeze2_out_repr
())
->
assert_is_op_output
(
"squeeze2"
,
"Out"
);
squeeze2
->
LinksFrom
({
squeeze2_input
}).
LinksTo
({
squeeze2_out
});
// elementwise
auto
elementwise
=
pattern
->
NewNode
(
elementwise_repr
())
->
assert_is_op
(
elementwise_type_
);
auto
elementwise_out
=
pattern
->
NewNode
(
elementwise_out_repr
())
->
assert_is_op_output
(
elementwise_type_
,
"Out"
);
if
(
act1_in_branch_x_
)
{
act1_out
->
assert_is_op_input
(
elementwise_type_
,
"X"
);
squeeze2_out
->
assert_is_op_input
(
elementwise_type_
,
"Y"
);
}
else
{
act1_out
->
assert_is_op_input
(
elementwise_type_
,
"Y"
);
squeeze2_out
->
assert_is_op_input
(
elementwise_type_
,
"X"
);
}
elementwise
->
LinksFrom
({
act1_out
,
squeeze2_out
}).
LinksTo
({
elementwise_out
});
// activation2
auto
act2
=
pattern
->
NewNode
(
act2_repr
())
->
assert_is_op
(
act2_type_
);
auto
act2_out
=
pattern
->
NewNode
(
act2_out_repr
())
->
assert_is_op_output
(
act2_type_
,
"Out"
);
elementwise_out
->
assert_is_op_input
(
act2_type_
,
"X"
);
act2
->
LinksFrom
({
elementwise_out
}).
LinksTo
({
act2_out
});
act2_out
->
AsOutput
();
}
}
// namespace patterns
class
SqueezeActivationUnsqueezeEliminationPass
:
public
FusePassBase
{
public:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
int
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
;
const
std
::
string
name_scope_
{
"squeeze_activation_unsqueeze_elimination_pass"
};
};
void
SqueezeActivationUnsqueezeEliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
std
::
vector
<
std
::
string
>
support_act_type
{
"relu"
,
"sigmoid"
,
"tanh"
,
"gelu"
,
"leaky_relu"
,
"hard_swish"
,
"hard_sigmoid"
,
"relu6"
,
"swish"
};
int
found_subgraph_count
=
0
;
for
(
auto
act_type
:
support_act_type
)
{
found_subgraph_count
+=
ApplyImpl
(
graph
,
act_type
);
}
AddStatis
(
found_subgraph_count
);
}
int
SqueezeActivationUnsqueezeEliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
{
GraphPatternDetector
gpd
;
patterns
::
SqueezeActivationUnsqueezeEliminationPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
act_type
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle squeeze activation unsqueeze elimination."
;
/* Get operator node's name */
GET_IR_NODE
(
squeeze
);
GET_IR_NODE
(
act
);
GET_IR_NODE
(
unsqueeze
);
/* Get variable node's name*/
GET_IR_NODE
(
squeeze_input
);
GET_IR_NODE
(
squeeze_out
);
GET_IR_NODE
(
act_out
);
GET_IR_NODE
(
unsqueeze_out
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
// Judge squeeze1 && squeeze2 op shape is same or not, if axes is same, the
// shape is same too.
std
::
vector
<
int
>
squeeze_axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
squeeze
->
Op
()
->
GetAttr
(
"axes"
));
std
::
vector
<
int
>
unsqueeze_axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
unsqueeze
->
Op
()
->
GetAttr
(
"axes"
));
bool
elimination
=
(
squeeze_axes
==
unsqueeze_axes
);
if
(
!
elimination
)
return
;
// act
auto
act_op_desc
=
act
->
Op
();
act_op_desc
->
RenameInput
(
squeeze_out
->
Var
()
->
Name
(),
squeeze_input
->
Var
()
->
Name
());
act_out
->
Var
()
->
SetShape
(
squeeze_input
->
Var
()
->
GetShape
());
act_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
squeeze_input
,
act
);
// behind unsqueeze op node
auto
unsqueeze_out_link_nodes
=
unsqueeze_out
->
outputs
;
for
(
auto
out_link_node
:
unsqueeze_out_link_nodes
)
{
auto
op_desc
=
out_link_node
->
Op
();
op_desc
->
RenameInput
(
unsqueeze_out
->
Var
()
->
Name
(),
act_out
->
Var
()
->
Name
());
op_desc
->
Flush
();
IR_NODE_LINK_TO
(
act_out
,
out_link_node
);
}
std
::
unordered_set
<
const
Node
*>
delete_nodes
{
squeeze
,
squeeze_out
,
unsqueeze
,
unsqueeze_out
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
return
found_subgraph_count
;
}
class
CustomSqueezeUnsqueezeEliminationPass
:
public
FusePassBase
{
public:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
int
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act1_type
,
const
std
::
string
&
act2_type
,
const
std
::
string
&
elementwise_type
,
bool
act1_in_branch_x
)
const
;
const
std
::
string
name_scope_
{
"custom_squeeze_unsqueeze_elimination_pass"
};
};
void
CustomSqueezeUnsqueezeEliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
std
::
vector
<
std
::
string
>
support_act_type
{
"relu"
,
"sigmoid"
,
"tanh"
,
"gelu"
,
"leaky_relu"
,
"hard_swish"
,
"hard_sigmoid"
,
"relu6"
,
"swish"
};
std
::
vector
<
std
::
string
>
support_elementwise_type
{
"elementwise_add"
,
"elementwise_sub"
,
"elementwise_mul"
,
"elementwise_div"
};
int
found_subgraph_count
=
0
;
for
(
auto
act1_type
:
support_act_type
)
{
for
(
auto
act2_type
:
support_act_type
)
{
for
(
auto
elementwise_type
:
support_elementwise_type
)
{
for
(
auto
act1_in_branch_x
:
{
true
,
false
})
{
found_subgraph_count
+=
ApplyImpl
(
graph
,
act1_type
,
act2_type
,
elementwise_type
,
act1_in_branch_x
);
}
}
}
}
AddStatis
(
found_subgraph_count
);
}
int
CustomSqueezeUnsqueezeEliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act1_type
,
const
std
::
string
&
act2_type
,
const
std
::
string
&
elementwise_type
,
const
bool
act1_in_branch_x
)
const
{
GraphPatternDetector
gpd
;
patterns
::
CustomSqueezeUnsqueezeEliminationPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
act1_type
,
act2_type
,
elementwise_type
,
act1_in_branch_x
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle custom squeeze unsqueeze elimination pass."
;
/* Get operator node's name */
GET_IR_NODE
(
squeeze1
);
GET_IR_NODE
(
squeeze2
);
GET_IR_NODE
(
act1
);
GET_IR_NODE
(
elementwise
);
GET_IR_NODE
(
act2
);
/* Get variable node's name*/
GET_IR_NODE
(
squeeze1_input
);
GET_IR_NODE
(
squeeze1_out
);
GET_IR_NODE
(
act1_out
);
GET_IR_NODE
(
squeeze2_input
);
GET_IR_NODE
(
squeeze2_out
);
GET_IR_NODE
(
elementwise_out
);
GET_IR_NODE
(
act2_out
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
// Judge squeeze1 && squeeze2 op shape is same or not, if axes is same, the
// shape is same too.
std
::
vector
<
int
>
squeeze1_axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
squeeze1
->
Op
()
->
GetAttr
(
"axes"
));
std
::
vector
<
int
>
squeeze2_axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
squeeze2
->
Op
()
->
GetAttr
(
"axes"
));
bool
elimination
=
(
squeeze1_axes
==
squeeze2_axes
);
if
(
!
elimination
)
return
;
// act1
auto
act1_op_desc
=
act1
->
Op
();
std
::
string
squeeze1_input_var_name
=
squeeze1_input
->
Var
()
->
Name
();
std
::
string
squeeze1_out_var_name
=
squeeze1_out
->
Var
()
->
Name
();
act1_op_desc
->
RenameInput
(
squeeze1_out_var_name
,
squeeze1_input_var_name
);
act1_out
->
Var
()
->
SetShape
(
squeeze1_input
->
Var
()
->
GetShape
());
act1_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
squeeze1_input
,
act1
);
// elementwise
auto
elementwise_op_desc
=
elementwise
->
Op
();
std
::
string
squeeze2_input_var_name
=
squeeze2_input
->
Var
()
->
Name
();
std
::
string
squeeze2_out_var_name
=
squeeze2_out
->
Var
()
->
Name
();
elementwise_op_desc
->
RenameInput
(
squeeze2_out_var_name
,
squeeze2_input_var_name
);
elementwise_out
->
Var
()
->
SetShape
(
squeeze2_input
->
Var
()
->
GetShape
());
elementwise_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
squeeze2_input
,
elementwise
);
std
::
string
act2_out_var_name
=
act2_out
->
Var
()
->
Name
();
std
::
vector
<
Node
*>
remove_nodes
;
auto
act2_out_link_nodes
=
act2_out
->
outputs
;
for
(
auto
out_link_node
:
act2_out_link_nodes
)
{
auto
op_desc
=
out_link_node
->
Op
();
if
(
op_desc
->
Type
()
==
"unsqueeze2"
)
{
std
::
vector
<
int
>
unsqueeze_axes
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op_desc
->
GetAttr
(
"axes"
));
elimination
=
elimination
&&
(
unsqueeze_axes
==
squeeze1_axes
);
if
(
elimination
)
{
remove_nodes
.
push_back
(
out_link_node
);
delete_nodes
.
insert
(
out_link_node
);
}
}
}
if
(
!
elimination
)
return
;
act2_out
->
Var
()
->
SetShape
(
elementwise_out
->
Var
()
->
GetShape
());
for
(
auto
unsqueeze_node
:
remove_nodes
)
{
std
::
string
unsqueeze_out_var_name
=
unsqueeze_node
->
Op
()
->
Output
(
"Out"
)[
0
];
for
(
auto
unsqueeze_out_node
:
unsqueeze_node
->
outputs
)
{
// find unsqueeze "Out" var node
if
(
unsqueeze_out_node
->
Name
()
==
unsqueeze_out_var_name
)
{
// Do delete operation
delete_nodes
.
insert
(
unsqueeze_out_node
);
for
(
auto
next_node
:
unsqueeze_out_node
->
outputs
)
{
auto
next_op_desc
=
next_node
->
Op
();
next_op_desc
->
RenameInput
(
unsqueeze_out_var_name
,
act2_out_var_name
);
next_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
act2_out
,
next_node
);
}
}
}
}
if
(
elimination
)
{
delete_nodes
.
insert
(
squeeze1
);
delete_nodes
.
insert
(
squeeze2
);
delete_nodes
.
insert
(
squeeze1_out
);
delete_nodes
.
insert
(
squeeze2_out
);
}
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
return
found_subgraph_count
;
}
class
RedundantSqueezeUnsqueezeEliminationPass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
const
std
::
string
name_scope_
{
"redundant_squeeze_unsqueeze_elimination_pass"
};
};
void
RedundantSqueezeUnsqueezeEliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
4
)
<<
"handle redundant squeeze unsqueeze elimination."
;
SqueezeActivationUnsqueezeEliminationPass
squeeze_activation_unsqueeze_elimination_pass
;
squeeze_activation_unsqueeze_elimination_pass
.
ApplyImpl
(
graph
);
CustomSqueezeUnsqueezeEliminationPass
custom_squeeze_unsqueeze_elimination_pass
;
custom_squeeze_unsqueeze_elimination_pass
.
ApplyImpl
(
graph
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
redundant_squeeze_unsqueeze_elimination_pass
,
paddle
::
framework
::
ir
::
RedundantSqueezeUnsqueezeEliminationPass
);
REGISTER_PASS_CAPABILITY
(
redundant_squeeze_unsqueeze_elimination_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"squeeze2"
,
0
)
.
LE
(
"leaky_relu"
,
1
)
.
EQ
(
"unsqueeze2"
,
0
));
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
7f6bb160
...
...
@@ -531,6 +531,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"matmul_weight_trans_pass"
,
"map_matmulv2_to_matmul_xpu_pass"
,
"reshape2_matmul_xpu_fuse_pass"
,
"redundant_squeeze_unsqueeze_elimination_pass"
,
"fc_xpu_fuse_pass"
,
"conv2d_xpu_fuse_pass"
,
"add_activation_xpu_fuse_pass"
,
...
...
test/ir/inference/test_xpu_redundant_squeeze_unsqueeze_elimination.py
0 → 100644
浏览文件 @
7f6bb160
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
functools
import
partial
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
class
TestXpuRedundantSqueezeUnsqueezeEliminationPass
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"leaky_relu"
],
(
1e-5
,
1e-5
)
def
sample_program_config
(
self
,
draw
):
x_shape
=
draw
(
st
.
sampled_from
([[
1
,
32
,
1
,
4
]]))
alpha
=
0.009999999776482582
axes
=
[
2
]
squeeze_op
=
OpConfig
(
"squeeze2"
,
inputs
=
{
"X"
:
[
"squeeze_input"
],
},
outputs
=
{
"Out"
:
[
"squeeze_out"
]},
axes
=
axes
,
)
leaky_relu_op
=
OpConfig
(
"leaky_relu"
,
inputs
=
{
"X"
:
[
"squeeze_out"
],
},
outputs
=
{
"Out"
:
[
"leaky_relu_out"
]},
alpha
=
alpha
,
)
unsqueeze_op
=
OpConfig
(
"unsqueeze2"
,
inputs
=
{
"X"
:
[
"leaky_relu_out"
],
},
outputs
=
{
"Out"
:
[
"unsqueeze_out"
]},
axes
=
axes
,
)
ops
=
[
squeeze_op
,
leaky_relu_op
,
unsqueeze_op
]
def
generate_data
(
shape
):
return
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
program_config
=
ProgramConfig
(
ops
=
ops
,
inputs
=
{
"squeeze_input"
:
TensorConfig
(
data_gen
=
partial
(
generate_data
,
x_shape
)
),
},
weights
=
{},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
min_success_num
=
1
,
passes
=
[
"redundant_squeeze_unsqueeze_elimination_pass"
],
)
class
TestXpuRedundantSqueezeUnsqueezeEliminationPass2
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"leaky_relu"
,
"elementwise_add"
,
"leaky_relu"
],
(
1e-5
,
1e-5
,
)
def
sample_program_config
(
self
,
draw
):
x_shape
=
draw
(
st
.
sampled_from
([[
1
,
32
,
1
,
4
]]))
alpha
=
0.009999999776482582
axes
=
[
2
]
squeeze_op_1
=
OpConfig
(
"squeeze2"
,
inputs
=
{
"X"
:
[
"squeeze_1_input"
],
},
outputs
=
{
"Out"
:
[
"squeeze_1_out"
]},
axes
=
axes
,
)
leaky_relu_op_1
=
OpConfig
(
"leaky_relu"
,
inputs
=
{
"X"
:
[
"squeeze_1_out"
],
},
outputs
=
{
"Out"
:
[
"leaky_relu_1_out"
]},
alpha
=
alpha
,
)
squeeze_op_2
=
OpConfig
(
"squeeze2"
,
inputs
=
{
"X"
:
[
"squeeze_2_input"
],
},
outputs
=
{
"Out"
:
[
"squeeze_2_out"
]},
axes
=
axes
,
)
elementwise_add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"leaky_relu_1_out"
],
"Y"
:
[
"squeeze_2_out"
]},
outputs
=
{
"Out"
:
[
"elementwise_add_out"
]},
)
leaky_relu_op_2
=
OpConfig
(
"leaky_relu"
,
inputs
=
{
"X"
:
[
"elementwise_add_out"
],
},
outputs
=
{
"Out"
:
[
"leaky_relu_2_out"
]},
alpha
=
alpha
,
)
unsqueeze_op_1
=
OpConfig
(
"unsqueeze2"
,
inputs
=
{
"X"
:
[
"leaky_relu_2_out"
],
},
outputs
=
{
"Out"
:
[
"unsqueeze_1_out"
]},
axes
=
axes
,
)
unsqueeze_op_2
=
OpConfig
(
"unsqueeze2"
,
inputs
=
{
"X"
:
[
"leaky_relu_2_out"
],
},
outputs
=
{
"Out"
:
[
"unsqueeze_2_out"
]},
axes
=
axes
,
)
ops
=
[
squeeze_op_1
,
leaky_relu_op_1
,
squeeze_op_2
,
elementwise_add_op
,
leaky_relu_op_2
,
unsqueeze_op_1
,
unsqueeze_op_2
,
]
def
generate_data
(
shape
):
return
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
program_config
=
ProgramConfig
(
ops
=
ops
,
inputs
=
{
"squeeze_1_input"
:
TensorConfig
(
data_gen
=
partial
(
generate_data
,
x_shape
)
),
"squeeze_2_input"
:
TensorConfig
(
data_gen
=
partial
(
generate_data
,
x_shape
)
),
},
weights
=
{},
outputs
=
[
"unsqueeze_1_out"
,
"unsqueeze_2_out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
min_success_num
=
1
,
passes
=
[
"redundant_squeeze_unsqueeze_elimination_pass"
],
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录