Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
eb12739e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eb12739e
编写于
7月 07, 2023
作者:
W
wz1qqx
提交者:
GitHub
7月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Add layernorm fuse pass (#55154)
上级
6af85a81
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
464 addition
and
0 deletion
+464
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
+244
-0
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+1
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+9
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+2
-0
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+34
-0
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+12
-0
paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
+70
-0
test/ir/inference/test_xpu_add_layernorm_fuse_pass.py
test/ir/inference/test_xpu_add_layernorm_fuse_pass.py
+89
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
eb12739e
...
@@ -268,6 +268,8 @@ if(WITH_XPU)
...
@@ -268,6 +268,8 @@ if(WITH_XPU)
xpu DEPS
${
XPU_PASS_DEPS
}
)
xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
add_activation_xpu_fuse_pass inference DIR xpu DEPS
pass_library
(
add_activation_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
${
XPU_PASS_DEPS
}
)
pass_library
(
add_layernorm_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
xpu_delete_cast_op_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
xpu_delete_cast_op_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fold_interp_outsize_fuse_pass inference DIR xpu DEPS
pass_library
(
fold_interp_outsize_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
${
XPU_PASS_DEPS
}
)
...
...
paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
0 → 100644
浏览文件 @
eb12739e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
/*
fuse ele_add + activation block in to xpu_ele_fusion op
For example:
graph:
ele_x
|
elementwise_add -----ele_y
|
layernorm
|
output
------------------------------------------------------
After the pass is applied:
ele_x
| ele_y
| /
| /
scale---- add_layernorm_fusion ---- bias
/ | \ \
/ | \ \
variance | meam z_add
Output
*/
struct
AddLayernormXPUPattern
:
public
PatternBase
{
AddLayernormXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
);
// declare operator node's name
PATTERN_DECL_NODE
(
ele_add
);
PATTERN_DECL_NODE
(
l_norm
);
// declare variable node's name
PATTERN_DECL_NODE
(
ele_x
);
PATTERN_DECL_NODE
(
ele_y
);
PATTERN_DECL_NODE
(
ele_out
);
PATTERN_DECL_NODE
(
norm_bias
);
PATTERN_DECL_NODE
(
norm_scale
);
PATTERN_DECL_NODE
(
norm_mean
);
PATTERN_DECL_NODE
(
norm_variance
);
PATTERN_DECL_NODE
(
norm_out
);
};
AddLayernormXPUPattern
::
AddLayernormXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
)
{
auto
ele_add
=
pattern
->
NewNode
(
ele_add_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
ele_x
=
pattern
->
NewNode
(
ele_x_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsInput
();
auto
ele_y
=
pattern
->
NewNode
(
ele_y_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsInput
();
auto
ele_out
=
pattern
->
NewNode
(
ele_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"layer_norm"
,
"X"
)
->
assert_has_n_outputs
(
1
);
ele_add
->
LinksFrom
({
ele_x
,
ele_y
}).
LinksTo
({
ele_out
});
auto
l_norm
=
pattern
->
NewNode
(
l_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
norm_bias
=
pattern
->
NewNode
(
norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
norm_scale
=
pattern
->
NewNode
(
norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
norm_mean
=
pattern
->
NewNode
(
norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
norm_variance
=
pattern
->
NewNode
(
norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
norm_out
=
pattern
->
NewNode
(
norm_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
l_norm
->
LinksFrom
({
ele_out
,
norm_bias
,
norm_scale
})
.
LinksTo
({
norm_out
,
norm_mean
,
norm_variance
});
}
}
// namespace patterns
namespace
{
void
setIntermediateOut
(
OpDesc
*
desc
,
const
std
::
string
&
out_name
,
const
std
::
string
&
scope_name
)
{
std
::
string
new_name
=
scope_name
+
"/at."
+
out_name
+
".new"
;
desc
->
SetOutput
(
out_name
,
{
new_name
});
}
void
addIntermediateOut
(
Node
*
op_node
,
const
std
::
string
&
out_name
,
const
std
::
string
&
scope_name
,
Graph
*
graph
)
{
std
::
string
new_name
=
scope_name
+
"/at."
+
out_name
+
".new"
;
VarDesc
out_var
(
new_name
);
out_var
.
SetPersistable
(
false
);
auto
*
node_var
=
graph
->
CreateVarNode
(
&
out_var
);
IR_NODE_LINK_TO
(
op_node
,
node_var
);
}
}
// namespace
class
AddLayernormXPUFusePass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
FuseAddLayernorm
(
ir
::
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"add_layernorm_xpu_fuse_pass"
};
};
void
AddLayernormXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
FuseAddLayernorm
(
graph
);
}
void
AddLayernormXPUFusePass
::
FuseAddLayernorm
(
ir
::
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
patterns
::
AddLayernormXPUPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle AddLayernormXPUFusePass fuse"
;
// declare operator node's name
GET_IR_NODE
(
ele_add
);
GET_IR_NODE
(
l_norm
);
// declare variable node's name
GET_IR_NODE
(
ele_x
);
GET_IR_NODE
(
ele_y
);
GET_IR_NODE
(
ele_out
);
GET_IR_NODE
(
norm_bias
);
GET_IR_NODE
(
norm_scale
);
GET_IR_NODE
(
norm_mean
);
GET_IR_NODE
(
norm_variance
);
GET_IR_NODE
(
norm_out
);
auto
*
block
=
ele_add
->
Op
()
->
Block
();
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope cannot be nullptr."
));
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
float
eps
=
PADDLE_GET_CONST
(
float
,
l_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
int
begin_norm_axis
=
PADDLE_GET_CONST
(
int
,
l_norm
->
Op
()
->
GetAttr
(
"begin_norm_axis"
));
auto
layer_norm_x_dims
=
ele_out
->
Var
()
->
GetShape
();
auto
layer_norm_x_mat_dims
=
phi
::
flatten_to_2d
(
phi
::
make_ddim
(
layer_norm_x_dims
),
begin_norm_axis
);
int64_t
m
=
layer_norm_x_mat_dims
[
0
];
int64_t
n
=
layer_norm_x_mat_dims
[
1
];
std
::
string
fused_op_out_name
;
fused_op_out_name
=
norm_out
->
Name
();
// Generate add_layernorm fused op
framework
::
OpDesc
fused_op_desc
(
block
);
fused_op_desc
.
SetType
(
"add_layernorm_xpu"
);
// set attrs for fused op
fused_op_desc
.
SetInput
(
"x"
,
{
ele_x
->
Name
()});
fused_op_desc
.
SetInput
(
"y"
,
{
ele_y
->
Name
()});
fused_op_desc
.
SetInput
(
"scale"
,
{
norm_scale
->
Name
()});
fused_op_desc
.
SetInput
(
"bias"
,
{
norm_bias
->
Name
()});
fused_op_desc
.
SetAttr
(
"m"
,
m
);
fused_op_desc
.
SetAttr
(
"n"
,
n
);
fused_op_desc
.
SetAttr
(
"epsilon"
,
eps
);
fused_op_desc
.
SetOutput
(
"out"
,
{
fused_op_out_name
});
setIntermediateOut
(
&
fused_op_desc
,
"mean"
,
name_scope_
);
setIntermediateOut
(
&
fused_op_desc
,
"variance"
,
name_scope_
);
setIntermediateOut
(
&
fused_op_desc
,
"z_add"
,
name_scope_
);
// relink fused op
auto
*
fused_op
=
graph
->
CreateOpNode
(
&
fused_op_desc
);
IR_NODE_LINK_TO
(
ele_x
,
fused_op
);
IR_NODE_LINK_TO
(
ele_y
,
fused_op
);
IR_NODE_LINK_TO
(
norm_scale
,
fused_op
);
IR_NODE_LINK_TO
(
norm_bias
,
fused_op
);
IR_NODE_LINK_TO
(
fused_op
,
norm_out
);
addIntermediateOut
(
fused_op
,
"mean"
,
name_scope_
,
graph
);
addIntermediateOut
(
fused_op
,
"variance"
,
name_scope_
,
graph
);
addIntermediateOut
(
fused_op
,
"z_add"
,
name_scope_
,
graph
);
delete_nodes
.
insert
({
ele_add
,
l_norm
,
ele_out
,
norm_mean
,
norm_variance
});
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
add_layernorm_xpu_fuse_pass
,
paddle
::
framework
::
ir
::
AddLayernormXPUFusePass
);
REGISTER_PASS_CAPABILITY
(
add_layernorm_xpu_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"add_layernorm_xpu"
,
0
));
paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
浏览文件 @
eb12739e
...
@@ -260,6 +260,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -260,6 +260,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
"sigmoid"
,
"sigmoid"
,
"swish"
,
"swish"
,
"relu6"
,
"relu6"
,
"leaky_relu"
,
""
,
""
,
})
{
})
{
found_subgraph_count
+=
found_subgraph_count
+=
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
eb12739e
...
@@ -544,6 +544,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
...
@@ -544,6 +544,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"conv2d_xpu_fuse_pass"
,
"conv2d_xpu_fuse_pass"
,
"conv2d_transpose_xpu_fuse_pass"
,
"conv2d_transpose_xpu_fuse_pass"
,
"add_activation_xpu_fuse_pass"
,
"add_activation_xpu_fuse_pass"
,
"add_layernorm_xpu_fuse_pass"
,
"yolo_box_xpu_fuse_pass"
,
"yolo_box_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
"link_xpu_op_max_pass"
,
"inplace_op_var_pass"
,
"inplace_op_var_pass"
,
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
eb12739e
...
@@ -14,6 +14,15 @@
...
@@ -14,6 +14,15 @@
data_type
:
x
data_type
:
x
optional
:
x_max, y_max
optional
:
x_max, y_max
-
op
:
add_layernorm_xpu
args
:
(Tensor x, Tensor y, Tensor scale, Tensor bias, int64_t m, int64_t n, float epsilon)
output
:
Tensor(out), Tensor(mean), Tensor(variance), Tensor(z_add)
infer_meta
:
func
:
AddLayernormXPUInferMeta
kernel
:
func
:
add_layernorm_xpu
data_type
:
x
-
op
:
conv2d_transpose_xpu
-
op
:
conv2d_transpose_xpu
args
:
(Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format, bool has_bias, bool with_act, str act_type)
args
:
(Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format, bool has_bias, bool with_act, str act_type)
output
:
Tensor(out), Tensor(out_max)
output
:
Tensor(out), Tensor(out_max)
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
eb12739e
...
@@ -24,6 +24,8 @@ XPUOpMap& get_kl2_ops() {
...
@@ -24,6 +24,8 @@ XPUOpMap& get_kl2_ops() {
static
XPUOpMap
s_xpu2_kernels
{
static
XPUOpMap
s_xpu2_kernels
{
{
"add_act_xpu"
,
{
"add_act_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"add_layernorm_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"abs"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"abs"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"abs_grad"
,
{
"abs_grad"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
eb12739e
...
@@ -92,6 +92,40 @@ void AddActXPUInferMeta(const MetaTensor& x,
...
@@ -92,6 +92,40 @@ void AddActXPUInferMeta(const MetaTensor& x,
out_max
->
set_layout
(
x
.
layout
());
out_max
->
set_layout
(
x
.
layout
());
}
}
void
AddLayernormXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
const
MetaTensor
&
scale
,
const
MetaTensor
&
bias
,
int64_t
m
,
int64_t
n
,
float
epsilon
,
MetaTensor
*
out
,
MetaTensor
*
mean
,
MetaTensor
*
variance
,
MetaTensor
*
z_add
)
{
int
axis
=
-
1
;
auto
x_dims
=
x
.
dims
();
auto
y_dims
=
y
.
dims
();
if
(
x_dims
!=
y_dims
)
{
auto
out_dims
=
BroadCastInferShape
(
x_dims
,
y_dims
,
axis
);
out
->
set_dims
(
out_dims
);
}
else
{
out
->
set_dims
(
x_dims
);
}
out
->
set_dtype
(
x
.
dtype
());
out
->
set_layout
(
x
.
layout
());
out
->
share_lod
(
x
);
mean
->
set_dims
(
phi
::
make_ddim
({
m
}));
mean
->
set_dtype
(
DataType
::
FLOAT32
);
mean
->
set_layout
(
x
.
layout
());
variance
->
set_dims
(
phi
::
make_ddim
({
m
}));
variance
->
set_dtype
(
DataType
::
FLOAT32
);
variance
->
set_layout
(
x
.
layout
());
z_add
->
set_dims
(
phi
::
make_ddim
({
m
,
n
}));
z_add
->
set_dtype
(
x
.
dtype
());
z_add
->
set_layout
(
x
.
layout
());
}
inline
int
ConvOutSize
(
int
input_size
,
inline
int
ConvOutSize
(
int
input_size
,
int
filter_size
,
int
filter_size
,
int
dilation
,
int
dilation
,
...
...
paddle/phi/infermeta/fusion.h
浏览文件 @
eb12739e
...
@@ -30,6 +30,18 @@ void AddActXPUInferMeta(const MetaTensor& x,
...
@@ -30,6 +30,18 @@ void AddActXPUInferMeta(const MetaTensor& x,
MetaTensor
*
out
,
MetaTensor
*
out
,
MetaTensor
*
out_max
);
MetaTensor
*
out_max
);
void
AddLayernormXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
const
MetaTensor
&
scale
,
const
MetaTensor
&
bias
,
int64_t
m
,
int64_t
n
,
float
epsilon
,
MetaTensor
*
out
,
MetaTensor
*
mean
,
MetaTensor
*
variance
,
MetaTensor
*
z_add
);
void
Conv2dXPUInferMeta
(
const
MetaTensor
&
x
,
void
Conv2dXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
filter
,
const
MetaTensor
&
filter
,
...
...
paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
0 → 100644
浏览文件 @
eb12739e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
AddLayernormXPUKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
scale
,
const
DenseTensor
&
bias
,
int64_t
m
,
int64_t
n
,
float
epsilon
,
DenseTensor
*
out
,
DenseTensor
*
mean
,
DenseTensor
*
variance
,
DenseTensor
*
z_add
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
auto
*
x_data
=
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
());
auto
*
y_data
=
reinterpret_cast
<
const
XPUType
*>
(
y
.
data
<
T
>
());
const
float
*
scale_data
=
scale
.
data
<
float
>
();
const
float
*
bias_data
=
bias
.
data
<
float
>
();
auto
*
out_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
float
*
mean_data
=
ctx
.
template
Alloc
<
float
>(
mean
);
float
*
variance_data
=
ctx
.
template
Alloc
<
float
>(
variance
);
auto
*
z_add_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
z_add
));
int
r
=
xpu
::
add_layer_norm_fusion
<
XPUType
>
(
// T
/* baidu::xpu::api::Context* ctx */
ctx
.
x_context
(),
/* const T* x */
x_data
,
/* const T* y */
y_data
,
/* T* z */
out_data
,
/* int64_t m */
m
,
/* int64_t n */
n
,
/* float epsilon */
epsilon
,
/* const float* scale */
scale_data
,
/* const float* bias */
bias_data
,
/* float* mean */
mean_data
,
/* float* variance */
variance_data
,
/* T* z_add */
z_add_data
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"add_layernorm_xpu"
);
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
add_layernorm_xpu
,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
AddLayernormXPUKernel
,
float
,
phi
::
dtype
::
float16
)
{}
test/ir/inference/test_xpu_add_layernorm_fuse_pass.py
0 → 100644
浏览文件 @
eb12739e
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
class
TestAddLayernormXPUFusePass
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"add_layernorm_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
batch_size
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
50
))
x_shape
=
[
batch_size
,
16
,
128
]
y_shape
=
x_shape
axis
=
-
1
epsilon
=
draw
(
st
.
floats
(
min_value
=
0.0000001
,
max_value
=
0.001
))
# begin_norm_axis has to be 2
begin_norm_axis
=
2
# Here we will compose a program
# Still has some risks that the program is invalid or cause bug while running
# Use function `is_program_valid` to filter the invalid programs before running
# Use function `add_skip_pass_case` to ignore the programs even if they cause bug while runing
elementwise_op
=
OpConfig
(
type
=
'elementwise_add'
,
inputs
=
{
'X'
:
[
'eltwise_X'
],
'Y'
:
[
'eltwise_Y'
]},
outputs
=
{
'Out'
:
[
'eltwise_output'
]},
axis
=
axis
,
)
layer_norm_op
=
OpConfig
(
"layer_norm"
,
inputs
=
{
"X"
:
[
"eltwise_output"
],
"Scale"
:
[
"layer_norm_scale"
],
"Bias"
:
[
"layer_norm_bias"
],
},
outputs
=
{
"Y"
:
[
"layer_norm_out"
],
"Mean"
:
[
"layer_norm_mean"
],
"Variance"
:
[
"layer_norm_var"
],
},
begin_norm_axis
=
begin_norm_axis
,
epsilon
=
epsilon
,
)
mini_graph
=
[
elementwise_op
,
layer_norm_op
]
program_config
=
ProgramConfig
(
ops
=
mini_graph
,
weights
=
{
"layer_norm_scale"
:
TensorConfig
(
shape
=
[
x_shape
[
2
]]),
"layer_norm_bias"
:
TensorConfig
(
shape
=
[
x_shape
[
2
]]),
},
inputs
=
{
"eltwise_X"
:
TensorConfig
(
shape
=
x_shape
),
"eltwise_Y"
:
TensorConfig
(
shape
=
y_shape
),
},
outputs
=
mini_graph
[
-
1
].
outputs
[
"Y"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"add_layernorm_xpu_fuse_pass"
],
)
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
200
)
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录