Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dc4b48f6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dc4b48f6
编写于
8月 03, 2023
作者:
W
wz1qqx
提交者:
GitHub
8月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
eliminate small pattern (#55843)
上级
c4694c15
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
455 addition
and
284 deletion
+455
-284
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-2
paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
+23
-26
paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
+6
-6
paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc
...d/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc
+0
-209
paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc
...rk/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc
+330
-0
paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h
...ork/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h
+25
-25
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-1
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+1
-1
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+1
-2
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+7
-4
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+1
-2
paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
+58
-6
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
dc4b48f6
...
...
@@ -240,8 +240,8 @@ if(WITH_XPU)
pass_library
(
yolo_box_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv1d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv2d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
redundant_
onnx_ops_elimination_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
redundant_
unsqueeze_squeeze_elimination_pass inference DIR xpu
DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv2d_transpose_xpu_fuse_pass inference DIR xpu DEPS
...
...
paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
浏览文件 @
dc4b48f6
...
...
@@ -43,17 +43,17 @@ namespace patterns {
fuse ele_add + activation block in to xpu_ele_fusion op
For example:
graph:
ele
_x
add
_x
|
elementwise_add -----
ele
_y
elementwise_add -----
add
_y
|
layernorm
|
output
------------------------------------------------------
After the pass is applied:
ele
_x
|
ele
_y
add
_x
|
add
_y
| /
| /
scale---- add_layernorm_fusion ---- bias
...
...
@@ -68,8 +68,8 @@ struct AddLayernormXPUPattern : public PatternBase {
PATTERN_DECL_NODE
(
ele_add
);
PATTERN_DECL_NODE
(
l_norm
);
// declare variable node's name
PATTERN_DECL_NODE
(
ele
_x
);
PATTERN_DECL_NODE
(
ele
_y
);
PATTERN_DECL_NODE
(
add
_x
);
PATTERN_DECL_NODE
(
add
_y
);
PATTERN_DECL_NODE
(
ele_out
);
PATTERN_DECL_NODE
(
norm_bias
);
PATTERN_DECL_NODE
(
norm_scale
);
...
...
@@ -83,17 +83,16 @@ AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern,
:
PatternBase
(
pattern
,
name_scope
,
name_scope
)
{
auto
ele_add
=
pattern
->
NewNode
(
ele_add_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
ele_x
=
pattern
->
NewNode
(
ele
_x_repr
())
auto
add_x
=
pattern
->
NewNode
(
add
_x_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsInput
();
auto
ele_y
=
pattern
->
NewNode
(
ele
_y_repr
())
auto
add_y
=
pattern
->
NewNode
(
add
_y_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsInput
();
auto
ele_out
=
pattern
->
NewNode
(
ele_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"layer_norm"
,
"X"
)
->
assert_has_n_outputs
(
1
);
ele_add
->
LinksFrom
({
ele_x
,
ele_y
}).
LinksTo
({
ele_out
});
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
ele_add
->
LinksFrom
({
add_x
,
add_y
}).
LinksTo
({
ele_out
});
auto
l_norm
=
pattern
->
NewNode
(
l_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
norm_bias
=
pattern
->
NewNode
(
norm_bias_repr
())
->
AsInput
()
...
...
@@ -169,8 +168,8 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
GET_IR_NODE
(
ele_add
);
GET_IR_NODE
(
l_norm
);
// declare variable node's name
GET_IR_NODE
(
ele
_x
);
GET_IR_NODE
(
ele
_y
);
GET_IR_NODE
(
add
_x
);
GET_IR_NODE
(
add
_y
);
GET_IR_NODE
(
ele_out
);
GET_IR_NODE
(
norm_bias
);
GET_IR_NODE
(
norm_scale
);
...
...
@@ -178,21 +177,21 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
GET_IR_NODE
(
norm_variance
);
GET_IR_NODE
(
norm_out
);
auto
*
block
=
ele_add
->
Op
()
->
Block
();
auto
*
block
=
l_norm
->
Op
()
->
Block
();
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
auto
x_shape
=
add_x
->
Var
()
->
GetShape
();
auto
x_rank
=
x_shape
.
size
();
auto
y_shape
=
add_y
->
Var
()
->
GetShape
();
auto
y_rank
=
y_shape
.
size
();
if
(
x_rank
!=
y_rank
)
return
;
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
float
eps
=
PADDLE_GET_CONST
(
float
,
l_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
int
begin_norm_axis
=
PADDLE_GET_CONST
(
int
,
l_norm
->
Op
()
->
GetAttr
(
"begin_norm_axis"
));
auto
layer_norm_x_dims
=
ele_out
->
Var
()
->
GetShape
();
auto
layer_norm_x_mat_dims
=
phi
::
flatten_to_2d
(
phi
::
make_ddim
(
layer_norm_x_dims
),
begin_norm_axis
);
int64_t
m
=
layer_norm_x_mat_dims
[
0
];
int64_t
n
=
layer_norm_x_mat_dims
[
1
];
std
::
string
fused_op_out_name
;
fused_op_out_name
=
norm_out
->
Name
();
...
...
@@ -200,28 +199,26 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
framework
::
OpDesc
fused_op_desc
(
block
);
fused_op_desc
.
SetType
(
"add_layernorm_xpu"
);
// set attrs for fused op
fused_op_desc
.
SetInput
(
"x"
,
{
ele
_x
->
Name
()});
fused_op_desc
.
SetInput
(
"y"
,
{
ele
_y
->
Name
()});
fused_op_desc
.
SetInput
(
"x"
,
{
add
_x
->
Name
()});
fused_op_desc
.
SetInput
(
"y"
,
{
add
_y
->
Name
()});
fused_op_desc
.
SetInput
(
"scale"
,
{
norm_scale
->
Name
()});
fused_op_desc
.
SetInput
(
"bias"
,
{
norm_bias
->
Name
()});
fused_op_desc
.
SetAttr
(
"m"
,
m
);
fused_op_desc
.
SetAttr
(
"n"
,
n
);
fused_op_desc
.
SetAttr
(
"epsilon"
,
eps
);
fused_op_desc
.
SetAttr
(
"begin_norm_axis"
,
begin_norm_axis
);
fused_op_desc
.
SetOutput
(
"out"
,
{
fused_op_out_name
});
setIntermediateOut
(
&
fused_op_desc
,
"mean"
,
name_scope_
);
setIntermediateOut
(
&
fused_op_desc
,
"variance"
,
name_scope_
);
setIntermediateOut
(
&
fused_op_desc
,
"z_add"
,
name_scope_
);
// relink fused op
auto
*
fused_op
=
graph
->
CreateOpNode
(
&
fused_op_desc
);
IR_NODE_LINK_TO
(
ele
_x
,
fused_op
);
IR_NODE_LINK_TO
(
ele
_y
,
fused_op
);
IR_NODE_LINK_TO
(
add
_x
,
fused_op
);
IR_NODE_LINK_TO
(
add
_y
,
fused_op
);
IR_NODE_LINK_TO
(
norm_scale
,
fused_op
);
IR_NODE_LINK_TO
(
norm_bias
,
fused_op
);
IR_NODE_LINK_TO
(
fused_op
,
norm_out
);
addIntermediateOut
(
fused_op
,
"mean"
,
name_scope_
,
graph
);
addIntermediateOut
(
fused_op
,
"variance"
,
name_scope_
,
graph
);
addIntermediateOut
(
fused_op
,
"z_add"
,
name_scope_
,
graph
);
delete_nodes
.
insert
({
ele_add
,
l_norm
,
ele_out
,
norm_mean
,
norm_variance
});
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
...
...
paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
浏览文件 @
dc4b48f6
...
...
@@ -88,7 +88,7 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern,
auto
*
op_desc
=
node
->
Op
();
auto
input_var
=
node
->
inputs
[
0
]
->
Var
();
auto
pool2d_x_shape
=
input_var
->
GetShape
();
std
::
vector
<
int
>
HW
=
{
static_cast
<
int
>
(
pool2d_x_shape
[
2
]),
std
::
vector
<
int
>
hw
=
{
static_cast
<
int
>
(
pool2d_x_shape
[
2
]),
static_cast
<
int
>
(
pool2d_x_shape
[
3
])};
auto
pool_type
=
op_desc
->
GetAttrIfExists
<
std
::
string
>
(
"pooling_type"
);
...
...
@@ -98,8 +98,8 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern,
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"strides"
);
auto
paddings_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"paddings"
);
return
pool_type
==
"max"
&&
ksize_array
==
HW
&&
strides_array
==
HW
&&
return
pool_type
==
"max"
&&
ksize_array
==
hw
&&
strides_array
==
hw
&&
paddings_array
==
std
::
vector
<
int
>
{
0
,
0
};
});
auto
*
pool2d_out
=
pattern
->
NewNode
(
pool2d_out_repr
())
...
...
@@ -181,7 +181,7 @@ ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern,
auto
*
op_desc
=
node
->
Op
();
auto
input_var
=
node
->
inputs
[
0
]
->
Var
();
auto
pool2d_x_shape
=
input_var
->
GetShape
();
std
::
vector
<
int
>
HW
=
{
static_cast
<
int
>
(
pool2d_x_shape
[
2
]),
std
::
vector
<
int
>
hw
=
{
static_cast
<
int
>
(
pool2d_x_shape
[
2
]),
static_cast
<
int
>
(
pool2d_x_shape
[
3
])};
auto
pool_type
=
op_desc
->
GetAttrIfExists
<
std
::
string
>
(
"pooling_type"
);
...
...
@@ -191,8 +191,8 @@ ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern,
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"strides"
);
auto
paddings_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"paddings"
);
return
pool_type
==
"avg"
&&
ksize_array
==
HW
&&
strides_array
==
HW
&&
return
pool_type
==
"avg"
&&
ksize_array
==
hw
&&
strides_array
==
hw
&&
paddings_array
==
std
::
vector
<
int
>
{
0
,
0
};
});
auto
*
pool2d_out
=
pattern
->
NewNode
(
pool2d_out_repr
())
...
...
paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.cc
已删除
100644 → 0
浏览文件 @
c4694c15
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
FoldConv1dSqueeze2Pattern
:
public
PatternBase
{
FoldConv1dSqueeze2Pattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
squeeze2
);
PATTERN_DECL_NODE
(
bn
);
PATTERN_DECL_NODE
(
act
);
PATTERN_DECL_NODE
(
unsqueeze2
);
// declare variable node's name
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
squeeze2_out
);
PATTERN_DECL_NODE
(
bn_bias
);
PATTERN_DECL_NODE
(
bn_mean
);
PATTERN_DECL_NODE
(
bn_scale
);
PATTERN_DECL_NODE
(
bn_var
);
PATTERN_DECL_NODE
(
bn_out
);
PATTERN_DECL_NODE
(
bn_mean_out
);
PATTERN_DECL_NODE
(
bn_saved_mean
);
PATTERN_DECL_NODE
(
bn_saved_var
);
PATTERN_DECL_NODE
(
bn_var_out
);
PATTERN_DECL_NODE
(
act_out
);
PATTERN_DECL_NODE
(
unsqueeze2_out
);
private:
std
::
string
act_type_
;
};
FoldConv1dSqueeze2Pattern
::
FoldConv1dSqueeze2Pattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act_type
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
act_type_
(
act_type
)
{
auto
*
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
"squeeze2"
,
"X"
)
->
assert_more
([](
Node
*
node
)
{
auto
x_shape
=
node
->
Var
()
->
GetShape
();
size_t
x_rank
=
x_shape
.
size
();
return
x_rank
==
4
&&
x_shape
[
2
]
==
1
;
});
auto
*
squeeze2
=
pattern
->
NewNode
(
squeeze2_repr
())
->
assert_is_op
(
"squeeze2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axes_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axes"
);
return
axes_array
==
std
::
vector
<
int
>
{
-
2
};
});
auto
*
squeeze2_out
=
pattern
->
NewNode
(
squeeze2_out_repr
())
->
assert_is_op_output
(
"squeeze2"
,
"Out"
)
->
assert_is_op_input
(
"batch_norm"
,
"X"
);
squeeze2
->
LinksFrom
({
x
}).
LinksTo
({
squeeze2_out
});
auto
*
bn_bias
=
pattern
->
NewNode
(
bn_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"batch_norm"
,
"Bias"
)
->
assert_has_n_outputs
(
1
);
auto
*
bn_mean
=
pattern
->
NewNode
(
bn_mean_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"batch_norm"
,
"Mean"
)
->
assert_has_n_outputs
(
1
);
auto
*
bn_scale
=
pattern
->
NewNode
(
bn_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"batch_norm"
,
"Scale"
)
->
assert_has_n_outputs
(
1
);
auto
*
bn_var
=
pattern
->
NewNode
(
bn_var_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"batch_norm"
,
"Variance"
)
->
assert_has_n_outputs
(
1
);
auto
*
bn
=
pattern
->
NewNode
(
bn_repr
())
->
assert_is_op
(
"batch_norm"
);
auto
*
bn_out
=
pattern
->
NewNode
(
bn_out_repr
())
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
act_type_
,
"X"
);
auto
*
bn_mean_out
=
pattern
->
NewNode
(
bn_mean_out_repr
())
->
assert_is_op_output
(
"batch_norm"
,
"MeanOut"
);
auto
*
bn_saved_mean
=
pattern
->
NewNode
(
bn_saved_mean_repr
())
->
assert_is_op_output
(
"batch_norm"
,
"SavedMean"
);
auto
*
bn_var_out
=
pattern
->
NewNode
(
bn_var_out_repr
())
->
assert_is_op_output
(
"batch_norm"
,
"VarianceOut"
);
auto
*
bn_saved_var
=
pattern
->
NewNode
(
bn_saved_var_repr
())
->
assert_is_op_output
(
"batch_norm"
,
"SavedVariance"
);
bn
->
LinksFrom
({
squeeze2_out
,
bn_bias
,
bn_mean
,
bn_scale
,
bn_var
})
.
LinksTo
({
bn_out
,
bn_mean_out
,
bn_var_out
,
bn_saved_mean
,
bn_saved_var
});
auto
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
act_type_
);
auto
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_op_output
(
act_type_
,
"Out"
)
->
assert_is_op_input
(
"unsqueeze2"
,
"X"
);
act
->
LinksFrom
({
bn_out
}).
LinksTo
({
act_out
});
auto
*
unsqueeze2
=
pattern
->
NewNode
(
unsqueeze2_repr
())
->
assert_is_op
(
"unsqueeze2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axes_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axes"
);
return
axes_array
==
std
::
vector
<
int
>
{
-
2
}
||
axes_array
==
std
::
vector
<
int
>
{
2
};
});
auto
*
unsqueeze2_out
=
pattern
->
NewNode
(
unsqueeze2_out_repr
())
->
assert_is_op_output
(
"unsqueeze2"
,
"Out"
);
unsqueeze2
->
LinksFrom
({
act_out
}).
LinksTo
({
unsqueeze2_out
});
}
}
// namespace patterns
void
RedundantOnnxOpsEliminationPass
::
FoldConv1dSqueeze2Ops
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
{
GraphPatternDetector
gpd
;
patterns
::
FoldConv1dSqueeze2Pattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
act_type
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FoldConv1dSqueeze2Ops"
;
// declare operator node's name
GET_IR_NODE
(
squeeze2
);
GET_IR_NODE
(
bn
);
GET_IR_NODE
(
act
);
GET_IR_NODE
(
unsqueeze2
);
// declare variable node's name
GET_IR_NODE
(
x
);
GET_IR_NODE
(
squeeze2_out
);
GET_IR_NODE
(
bn_out
);
GET_IR_NODE
(
act_out
);
GET_IR_NODE
(
unsqueeze2_out
);
auto
bn_op_desc
=
bn
->
Op
();
bn_op_desc
->
RenameInput
(
squeeze2_out
->
Var
()
->
Name
(),
x
->
Var
()
->
Name
());
bn_out
->
Var
()
->
SetShape
(
x
->
Var
()
->
GetShape
());
act_out
->
Var
()
->
SetShape
(
x
->
Var
()
->
GetShape
());
bn_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
x
,
bn
);
// behind unsqueeze op node
auto
unsqueeze_out_link_nodes
=
unsqueeze2_out
->
outputs
;
for
(
auto
out_link_node
:
unsqueeze_out_link_nodes
)
{
auto
op_desc
=
out_link_node
->
Op
();
op_desc
->
RenameInput
(
unsqueeze2_out
->
Var
()
->
Name
(),
act_out
->
Var
()
->
Name
());
op_desc
->
Flush
();
IR_NODE_LINK_TO
(
act_out
,
out_link_node
);
}
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
=
{
squeeze2
,
squeeze2_out
,
unsqueeze2
,
unsqueeze2_out
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
void
RedundantOnnxOpsEliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
for
(
auto
act_type
:
{
"leaky_relu"
,
"elu"
})
{
FoldConv1dSqueeze2Ops
(
graph
,
act_type
);
}
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
redundant_onnx_ops_elimination_pass
,
paddle
::
framework
::
ir
::
RedundantOnnxOpsEliminationPass
);
REGISTER_PASS_CAPABILITY
(
redundant_onnx_ops_elimination_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"conv2d"
,
0
));
paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc
0 → 100644
浏览文件 @
dc4b48f6
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
FoldTranspose2OpsPattern
:
public
PatternBase
{
FoldTranspose2OpsPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
transpose2_1
);
PATTERN_DECL_NODE
(
unsqueeze2
);
PATTERN_DECL_NODE
(
reduce_sum
);
PATTERN_DECL_NODE
(
act
);
PATTERN_DECL_NODE
(
transpose2_2
);
// declare variable node's name
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
transpose2_1_out
);
PATTERN_DECL_NODE
(
unsqueeze2_out
);
PATTERN_DECL_NODE
(
sum_out
);
PATTERN_DECL_NODE
(
act_out
);
PATTERN_DECL_NODE
(
transpose2_2_out
);
private:
std
::
string
act_type_
;
};
FoldTranspose2OpsPattern
::
FoldTranspose2OpsPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
const
std
::
string
&
act_type
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
act_type_
(
act_type
)
{
auto
*
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
"transpose2"
,
"X"
)
->
assert_more
([](
Node
*
node
)
{
auto
x_shape
=
node
->
Var
()
->
GetShape
();
size_t
x_rank
=
x_shape
.
size
();
return
x_rank
==
3
;
});
auto
*
transpose2_1
=
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axis_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axis"
);
return
axis_array
==
std
::
vector
<
int
>
{
0
,
2
,
1
};
});
auto
*
transpose2_1_out
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
->
assert_is_op_output
(
"transpose2"
,
"Out"
)
->
assert_is_op_input
(
"unsqueeze2"
,
"X"
);
transpose2_1
->
LinksFrom
({
x
}).
LinksTo
({
transpose2_1_out
});
auto
*
unsqueeze2
=
pattern
->
NewNode
(
unsqueeze2_repr
())
->
assert_is_op
(
"unsqueeze2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axes_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axes"
);
return
axes_array
==
std
::
vector
<
int
>
{
-
2
};
});
auto
*
unsqueeze2_out
=
pattern
->
NewNode
(
unsqueeze2_out_repr
())
->
assert_is_op_output
(
"unsqueeze2"
,
"Out"
)
->
assert_is_op_input
(
"reduce_sum"
,
"X"
);
unsqueeze2
->
LinksFrom
({
transpose2_1_out
}).
LinksTo
({
unsqueeze2_out
});
auto
*
reduce_sum
=
pattern
->
NewNode
(
reduce_sum_repr
())
->
assert_is_op
(
"reduce_sum"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
keep_dim
=
op_desc
->
GetAttrIfExists
<
bool
>
(
"keep_dim"
);
auto
dim_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"dim"
);
return
dim_array
==
std
::
vector
<
int
>
{
-
2
}
&&
!
keep_dim
;
});
auto
*
sum_out
=
pattern
->
NewNode
(
sum_out_repr
())
->
assert_is_op_output
(
"reduce_sum"
,
"Out"
)
->
assert_is_op_input
(
act_type_
,
"X"
);
reduce_sum
->
LinksFrom
({
unsqueeze2_out
}).
LinksTo
({
sum_out
});
auto
*
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
act_type_
);
auto
*
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_op_output
(
act_type_
,
"Out"
)
->
assert_is_op_input
(
"transpose2"
,
"X"
);
act
->
LinksFrom
({
sum_out
}).
LinksTo
({
act_out
});
auto
*
transpose2_2
=
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axis_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axis"
);
return
axis_array
==
std
::
vector
<
int
>
{
0
,
2
,
1
};
});
auto
*
transpose2_2_out
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
,
"Out"
);
transpose2_2
->
LinksFrom
({
act_out
}).
LinksTo
({
transpose2_2_out
});
}
struct
FoldGatherSqueeze2Pattern
:
public
PatternBase
{
FoldGatherSqueeze2Pattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
);
// declare operator node's name
PATTERN_DECL_NODE
(
unsqueeze2_op
);
PATTERN_DECL_NODE
(
gather_op
);
PATTERN_DECL_NODE
(
squeeze2_op
);
// declare variable node's name
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
unsqueeze2_op_out
);
PATTERN_DECL_NODE
(
gather_i
);
PATTERN_DECL_NODE
(
gather_op_out
);
PATTERN_DECL_NODE
(
squeeze2_op_out
);
};
FoldGatherSqueeze2Pattern
::
FoldGatherSqueeze2Pattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
)
{
auto
*
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
"unsqueeze2"
,
"X"
);
auto
*
unsqueeze2_op
=
pattern
->
NewNode
(
unsqueeze2_op_repr
())
->
assert_is_op
(
"unsqueeze2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axes_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axes"
);
return
axes_array
.
size
()
==
1
;
});
auto
*
unsqueeze2_op_out
=
pattern
->
NewNode
(
unsqueeze2_op_out_repr
())
->
assert_is_op_output
(
"unsqueeze2"
,
"Out"
)
->
assert_is_op_input
(
"gather"
,
"X"
);
unsqueeze2_op
->
LinksFrom
({
x
}).
LinksTo
({
unsqueeze2_op_out
});
auto
*
gather_op
=
pattern
->
NewNode
(
gather_op_repr
())
->
assert_is_op
(
"gather"
);
auto
*
gather_i
=
pattern
->
NewNode
(
gather_i_repr
())
->
assert_is_op_input
(
"gather"
,
"Index"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
auto
i_shape
=
node
->
Var
()
->
GetShape
();
size_t
i_rank
=
i_shape
.
size
();
return
i_rank
==
1
;
});
auto
*
gather_op_out
=
pattern
->
NewNode
(
gather_op_out_repr
())
->
assert_is_op_output
(
"gather"
,
"Out"
)
->
assert_is_op_input
(
"squeeze2"
,
"X"
);
gather_op
->
LinksFrom
({
unsqueeze2_op_out
,
gather_i
}).
LinksTo
({
gather_op_out
});
auto
*
squeeze2_op
=
pattern
->
NewNode
(
squeeze2_op_repr
())
->
assert_is_op
(
"squeeze2"
)
->
assert_more
([](
Node
*
node
)
{
auto
*
op_desc
=
node
->
Op
();
auto
axes_array
=
op_desc
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"axes"
);
return
axes_array
.
size
()
==
1
;
});
auto
*
squeeze2_op_out
=
pattern
->
NewNode
(
squeeze2_op_out_repr
())
->
assert_is_op_output
(
"squeeze2"
,
"Out"
);
squeeze2_op
->
LinksFrom
({
gather_op_out
}).
LinksTo
({
squeeze2_op_out
});
}
}
// namespace patterns
void
RedundantUnsqueeze2EliminationPass
::
FoldTranspose2Ops
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
{
GraphPatternDetector
gpd
;
patterns
::
FoldTranspose2OpsPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
act_type
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FoldTranspose2Ops"
;
// declare operator node's name
GET_IR_NODE
(
transpose2_1
);
GET_IR_NODE
(
unsqueeze2
);
GET_IR_NODE
(
reduce_sum
);
GET_IR_NODE
(
act
);
GET_IR_NODE
(
transpose2_2
);
// declare variable node's name
GET_IR_NODE
(
x
);
GET_IR_NODE
(
transpose2_1_out
);
GET_IR_NODE
(
unsqueeze2_out
);
GET_IR_NODE
(
sum_out
);
GET_IR_NODE
(
act_out
);
GET_IR_NODE
(
transpose2_2_out
);
auto
act_op_desc
=
act
->
Op
();
act_op_desc
->
RenameInput
(
sum_out
->
Var
()
->
Name
(),
x
->
Var
()
->
Name
());
act_out
->
Var
()
->
SetShape
(
x
->
Var
()
->
GetShape
());
act_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
x
,
act
);
// behind unsqueeze op node
auto
final_out_link_nodes
=
transpose2_2_out
->
outputs
;
for
(
auto
out_link_node
:
final_out_link_nodes
)
{
auto
op_desc
=
out_link_node
->
Op
();
op_desc
->
RenameInput
(
transpose2_2_out
->
Var
()
->
Name
(),
act_out
->
Var
()
->
Name
());
op_desc
->
Flush
();
IR_NODE_LINK_TO
(
act_out
,
out_link_node
);
}
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
=
{
transpose2_1
,
transpose2_1_out
,
unsqueeze2
,
unsqueeze2_out
,
reduce_sum
,
sum_out
,
transpose2_2
,
transpose2_2_out
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
void
RedundantUnsqueeze2EliminationPass
::
FoldGatherSqueeze2Ops
(
ir
::
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
patterns
::
FoldGatherSqueeze2Pattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FoldGatherSqueeze2Ops"
;
// declare operator node's name
GET_IR_NODE
(
unsqueeze2_op
);
GET_IR_NODE
(
gather_op
);
GET_IR_NODE
(
squeeze2_op
);
// declare variable node's name
GET_IR_NODE
(
x
);
GET_IR_NODE
(
unsqueeze2_op_out
);
GET_IR_NODE
(
gather_i
);
GET_IR_NODE
(
gather_op_out
);
GET_IR_NODE
(
squeeze2_op_out
);
bool
flag
=
true
;
auto
x_shape
=
x
->
Var
()
->
GetShape
();
auto
x_rank
=
static_cast
<
int
>
(
x_shape
.
size
());
std
::
vector
<
int
>
unsqueeze_axes_attr
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
unsqueeze2_op
->
Op
()
->
GetAttr
(
"axes"
));
auto
unsqueeze_axes
=
unsqueeze_axes_attr
.
front
();
unsqueeze_axes
=
unsqueeze_axes
<
0
?
unsqueeze_axes
+
x_rank
:
unsqueeze_axes
;
auto
gather_axis
=
PADDLE_GET_CONST
(
int
,
gather_op
->
Op
()
->
GetAttr
(
"axis"
));
gather_axis
=
gather_axis
<
0
?
gather_axis
+
x_rank
+
1
:
gather_axis
;
std
::
vector
<
int
>
squeeze_axes_attr
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
squeeze2_op
->
Op
()
->
GetAttr
(
"axes"
));
auto
squeeze_axes
=
squeeze_axes_attr
.
front
();
squeeze_axes
=
squeeze_axes
<
0
?
squeeze_axes
+
x_rank
+
1
:
squeeze_axes
;
flag
&=
(
unsqueeze_axes
>=
0
&&
unsqueeze_axes
<
x_rank
);
flag
&=
((
gather_axis
==
unsqueeze_axes
+
1
)
&&
(
squeeze_axes
==
gather_axis
));
if
(
!
flag
)
return
;
// x->gather->squeeze2_op_out
auto
gather_op_desc
=
gather_op
->
Op
();
gather_op_desc
->
RenameInput
(
unsqueeze2_op_out
->
Var
()
->
Name
(),
x
->
Var
()
->
Name
());
gather_op_desc
->
SetAttr
(
"axis"
,
gather_axis
-
1
);
gather_op_out
->
Var
()
->
SetShape
(
squeeze2_op_out
->
Var
()
->
GetShape
());
gather_op_desc
->
Flush
();
IR_NODE_LINK_TO
(
x
,
gather_op
);
// behind squeeze op node
auto
squeeze_out_link_nodes
=
squeeze2_op_out
->
outputs
;
for
(
auto
out_link_node
:
squeeze_out_link_nodes
)
{
auto
op_desc
=
out_link_node
->
Op
();
op_desc
->
RenameInput
(
squeeze2_op_out
->
Var
()
->
Name
(),
gather_op_out
->
Var
()
->
Name
());
op_desc
->
Flush
();
IR_NODE_LINK_TO
(
gather_op_out
,
out_link_node
);
}
std
::
unordered_set
<
const
Node
*>
delete_nodes
{
squeeze2_op
,
squeeze2_op_out
,
unsqueeze2_op
,
unsqueeze2_op_out
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
void
RedundantUnsqueeze2EliminationPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
for
(
auto
act_type
:
{
"relu"
})
{
FoldTranspose2Ops
(
graph
,
act_type
);
}
FoldGatherSqueeze2Ops
(
graph
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
redundant_unsqueeze_squeeze_elimination_pass
,
paddle
::
framework
::
ir
::
RedundantUnsqueeze2EliminationPass
);
REGISTER_PASS_CAPABILITY
(
redundant_unsqueeze_squeeze_elimination_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"conv2d"
,
0
));
paddle/fluid/framework/ir/xpu/redundant_
onnx_ops
_elimination_pass.h
→
paddle/fluid/framework/ir/xpu/redundant_
unsqueeze_squeeze
_elimination_pass.h
浏览文件 @
dc4b48f6
...
...
@@ -31,51 +31,51 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
class
Redundant
OnnxOps
EliminationPass
:
public
FusePassBase
{
class
Redundant
Unsqueeze2
EliminationPass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
/*
Origin subgraph:
x filter
| |
unsqueeze2(axes={-2}) unsqueeze2(axes={-2})
\ /
\ /
conv2d(conv1d)
x
|
elementwise_add
transpose2
|
squeeze2(axes={-2})
un
squeeze2(axes={-2})
|
batch_nor
m
reduce_su
m
|
act
|
unsqueez
e2
transpos
e2
|
conv2d(conv1d)
Fused subgraph:
x filter
| |
unsqueeze2(axes={-2}) unsqueeze2(axes={-2})
\ /
\ /
conv2d(conv1d)
x
|
elementwise_add
act
|
batch_norm
*/
void
FoldTranspose2Ops
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
;
/*
Origin subgraph:
x
|
act
unsqueeze2(axes={-2})
|
gather
|
squeeze2
|
Fused subgraph:
x
|
gather
|
conv2d(conv1d)
*/
void
FoldConv1dSqueeze2Ops
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
;
void
FoldGatherSqueeze2Ops
(
ir
::
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"redundant_
onnx_ops
_elimination_pass"
};
const
std
::
string
name_scope_
{
"redundant_
unsqueeze_squeeze
_elimination_pass"
};
};
}
// namespace ir
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
dc4b48f6
...
...
@@ -527,7 +527,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fold_interp_outsize_fuse_pass"
,
"fold_two_squeeze2_fuse_pass"
,
"conv1d_xpu_fuse_pass"
,
"redundant_
onnx_ops
_elimination_pass"
,
"redundant_
unsqueeze_squeeze
_elimination_pass"
,
"reduce_ops_fuse_pass"
,
"delete_cast_op_pass"
,
"xpu_delete_cast_op_pass"
,
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
dc4b48f6
...
...
@@ -15,7 +15,7 @@
optional
:
x_max, y_max
-
op
:
add_layernorm_xpu
args
:
(Tensor x, Tensor y, Tensor scale, Tensor bias, int
64_t m, int64_t n
, float epsilon)
args
:
(Tensor x, Tensor y, Tensor scale, Tensor bias, int
begin_norm_axis
, float epsilon)
output
:
Tensor(out), Tensor(mean), Tensor(variance), Tensor(z_add)
infer_meta
:
func
:
AddLayernormXPUInferMeta
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
dc4b48f6
...
...
@@ -24,8 +24,7 @@ XPUOpMap& get_kl2_ops() {
static
XPUOpMap
s_xpu2_kernels
{
{
"add_act_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"add_layernorm_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"add_layernorm_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"abs"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"abs_grad"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
dc4b48f6
...
...
@@ -96,8 +96,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
const
MetaTensor
&
y
,
const
MetaTensor
&
scale
,
const
MetaTensor
&
bias
,
int64_t
m
,
int64_t
n
,
int
begin_norm_axis
,
float
epsilon
,
MetaTensor
*
out
,
MetaTensor
*
mean
,
...
...
@@ -106,12 +105,16 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
int
axis
=
-
1
;
auto
x_dims
=
x
.
dims
();
auto
y_dims
=
y
.
dims
();
auto
out_dims
=
x_dims
;
if
(
x_dims
!=
y_dims
)
{
auto
out_dims
=
BroadCastInferShape
(
x_dims
,
y_dims
,
axis
);
out_dims
=
BroadCastInferShape
(
x_dims
,
y_dims
,
axis
);
out
->
set_dims
(
out_dims
);
}
else
{
out
->
set_dims
(
x
_dims
);
out
->
set_dims
(
out
_dims
);
}
auto
layer_norm_x_mat_dims
=
phi
::
flatten_to_2d
(
out_dims
,
begin_norm_axis
);
int64_t
m
=
layer_norm_x_mat_dims
[
0
];
int64_t
n
=
layer_norm_x_mat_dims
[
1
];
out
->
set_dtype
(
x
.
dtype
());
out
->
set_layout
(
x
.
layout
());
out
->
share_lod
(
x
);
...
...
paddle/phi/infermeta/fusion.h
浏览文件 @
dc4b48f6
...
...
@@ -34,8 +34,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
const
MetaTensor
&
y
,
const
MetaTensor
&
scale
,
const
MetaTensor
&
bias
,
int64_t
m
,
int64_t
n
,
int
begin_norm_axis
,
float
epsilon
,
MetaTensor
*
out
,
MetaTensor
*
mean
,
...
...
paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
浏览文件 @
dc4b48f6
...
...
@@ -13,19 +13,65 @@
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace
phi
{
namespace
fusion
{
static
phi
::
DDim
BroadCastInferShape
(
const
DDim
x_dims
,
const
DDim
y_dims
,
int
axis
)
{
std
::
vector
<
int
>
out_dims_array
(
x_dims
.
size
(),
-
1
);
if
(
x_dims
!=
y_dims
)
{
int
max_dim
=
std
::
max
(
x_dims
.
size
(),
y_dims
.
size
());
if
(
x_dims
.
size
()
==
y_dims
.
size
())
{
PADDLE_ENFORCE_EQ
((
axis
==
-
1
)
||
(
axis
==
0
),
true
,
phi
::
errors
::
InvalidArgument
(
"axis should be -1 or 0 while the dimension of "
"tensor X (%s) is equal to the dimension of "
"tensor Y (%s), but received axis: %s"
,
x_dims
.
size
(),
y_dims
.
size
(),
axis
));
}
PADDLE_ENFORCE_EQ
((
axis
>=
(
-
1
*
max_dim
))
&&
(
axis
<
max_dim
),
true
,
phi
::
errors
::
InvalidArgument
(
"The axis range must be [%s, %s), but axis is %s. "
"Please set the axis again."
,
-
1
*
max_dim
,
max_dim
,
axis
));
axis
=
(
axis
<
0
?
(
std
::
abs
(
x_dims
.
size
()
-
y_dims
.
size
())
+
axis
+
1
)
:
axis
);
std
::
vector
<
int
>
x_dims_array
(
max_dim
);
std
::
vector
<
int
>
y_dims_array
(
max_dim
);
out_dims_array
.
resize
(
max_dim
);
phi
::
funcs
::
GetBroadcastDimsArrays
(
x_dims
,
y_dims
,
x_dims_array
.
data
(),
y_dims_array
.
data
(),
out_dims_array
.
data
(),
max_dim
,
axis
);
return
phi
::
make_ddim
(
out_dims_array
);
}
return
x_dims
;
}
template
<
typename
T
,
typename
Context
>
void
AddLayernormXPUKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
scale
,
const
DenseTensor
&
bias
,
int64_t
m
,
int64_t
n
,
int
begin_norm_axis
,
float
epsilon
,
DenseTensor
*
out
,
DenseTensor
*
mean
,
...
...
@@ -37,12 +83,19 @@ void AddLayernormXPUKernel(const Context& ctx,
auto
*
y_data
=
reinterpret_cast
<
const
XPUType
*>
(
y
.
data
<
T
>
());
const
float
*
scale_data
=
scale
.
data
<
float
>
();
const
float
*
bias_data
=
bias
.
data
<
float
>
();
auto
*
out_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
float
*
mean_data
=
ctx
.
template
Alloc
<
float
>(
mean
);
float
*
variance_data
=
ctx
.
template
Alloc
<
float
>(
variance
);
auto
*
z_add_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
z_add
));
auto
x_dims
=
x
.
dims
();
auto
y_dims
=
y
.
dims
();
auto
out_dims
=
BroadCastInferShape
(
x_dims
,
y_dims
,
-
1
);
auto
layer_norm_x_mat_dims
=
phi
::
flatten_to_2d
(
out_dims
,
begin_norm_axis
);
int64_t
m
=
layer_norm_x_mat_dims
[
0
];
int64_t
n
=
layer_norm_x_mat_dims
[
1
];
auto
*
out_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
int
r
=
xpu
::
add_layer_norm_fusion
<
XPUType
>
(
// T
/* baidu::xpu::api::Context* ctx */
ctx
.
x_context
(),
/* const T* x */
x_data
,
...
...
@@ -66,5 +119,4 @@ PD_REGISTER_KERNEL(add_layernorm_xpu,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
AddLayernormXPUKernel
,
float
,
phi
::
dtype
::
float16
)
{}
float
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录