PaddlePaddle / Paddle
Commit f16e1869 (unverified)
Authored Aug 16, 2023 by jiangfan06; committed via GitHub on Aug 16, 2023
[XPU] Add fast_layernorm_xpu_fuse_pass and fast_layernorm_xpu plugin (#56269)
Parent: be22021c
Showing 15 changed files with 827 additions and 6 deletions (+827 -6)
paddle/fluid/framework/ir/CMakeLists.txt                                  +2    -0
paddle/fluid/framework/ir/layer_norm_fuse_pass.cc                         +1    -1
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc                +3    -1
paddle/fluid/framework/ir/xpu/fast_layernorm_xpu_fuse_pass.cc             +187  -0
paddle/fluid/inference/api/paddle_pass_builder.cc                         +1    -0
paddle/phi/api/yaml/fused_ops.yaml                                        +9    -0
paddle/phi/backends/xpu/xpu2_op_list.cc                                   +2    -0
paddle/phi/infermeta/fusion.cc                                            +11   -0
paddle/phi/infermeta/fusion.h                                             +7    -0
paddle/phi/kernels/fusion/xpu/fast_layernorm_xpu_kernel.cc                +122  -0
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h                        +9    -0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_layer_norm.xpu   +243  -0
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_layer_norm.cpp             +153  -0
paddle/phi/kernels/xpu/take_along_axis_kernel.cc                          +0    -4
test/ir/inference/test_xpu_fast_layernorm_xpu_fuse_pass.py                +77   -0
paddle/fluid/framework/ir/CMakeLists.txt
@@ -284,6 +284,8 @@ if(WITH_XPU)
       ${XPU_PASS_DEPS})
   pass_library(gather_squeeze_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 cc_library(
paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
@@ -372,7 +372,7 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
   // ------------------ op creation and placement ---------------------------
-  OpDesc ln_op_desc;
+  OpDesc ln_op_desc(x_mean->Op()->Block());
   ln_op_desc.SetType("layer_norm");
   ln_op_desc.SetInput("X", {x->Name()});
   ln_op_desc.SetInput("Scale", {new_gamma_node->Name()});
paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
@@ -119,7 +119,9 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes(
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
       block = node->Op()->Block();
-      break;
+      if (block != nullptr) {
+        break;
+      }
     }
   }
   Scope& scope = graph->Get<framework::Scope>("__param_scope__");
paddle/fluid/framework/ir/xpu/fast_layernorm_xpu_fuse_pass.cc
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

/*
change layernorm op to fast_layernorm op
For example:
graph:
            x
            |
        layernorm
            |
          output
------------------------------------------------------
After the pass is applied:
            x
            |
    fast_layernorm_xpu
            |
          output
*/
struct FastLayernormXPUPattern : public PatternBase {
  FastLayernormXPUPattern(PDPattern* pattern, const std::string& name_scope);
  // declare operator node's name
  PATTERN_DECL_NODE(l_norm);
  // declare variable node's name
  PATTERN_DECL_NODE(norm_in);
  PATTERN_DECL_NODE(norm_bias);
  PATTERN_DECL_NODE(norm_scale);
  PATTERN_DECL_NODE(norm_mean);
  PATTERN_DECL_NODE(norm_variance);
  PATTERN_DECL_NODE(norm_out);
};

FastLayernormXPUPattern::FastLayernormXPUPattern(PDPattern* pattern,
                                                 const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm");
  auto norm_in = pattern->NewNode(norm_in_repr())
                     ->AsInput()
                     ->assert_is_op_input("layer_norm", "X");
  auto norm_bias = pattern->NewNode(norm_bias_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("layer_norm", "Bias");
  auto norm_scale = pattern->NewNode(norm_scale_repr())
                        ->AsInput()
                        ->assert_is_persistable_var()
                        ->assert_is_op_input("layer_norm", "Scale");
  auto norm_mean = pattern->NewNode(norm_mean_repr())
                       ->AsOutput()
                       ->assert_is_op_output("layer_norm", "Mean")
                       ->assert_more([](Node* node) {
                         return node->outputs.size() == 0;
                       });
  auto norm_variance = pattern->NewNode(norm_variance_repr())
                           ->AsOutput()
                           ->assert_is_op_output("layer_norm", "Variance")
                           ->assert_more([](Node* node) {
                             return node->outputs.size() == 0;
                           });
  auto norm_out = pattern->NewNode(norm_out_repr())
                      ->AsOutput()
                      ->assert_is_op_output("layer_norm", "Y");
  l_norm->LinksFrom({norm_in, norm_bias, norm_scale})
      .LinksTo({norm_out, norm_mean, norm_variance});
}

}  // namespace patterns

class FastLayernormXPUFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  void FuseFastLayernorm(ir::Graph* graph) const;

  const std::string name_scope_{"fast_layernorm_xpu_fuse_pass"};
};

void FastLayernormXPUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  FuseFastLayernorm(graph);
}

void FastLayernormXPUFusePass::FuseFastLayernorm(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FastLayernormXPUPattern pattern(gpd.mutable_pattern(), name_scope_);
  int found_subgraph_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle FastLayernormXPUFusePass";
    // declare operator node's name
    GET_IR_NODE(l_norm);
    // declare variable node's name
    GET_IR_NODE(norm_in);
    GET_IR_NODE(norm_bias);
    GET_IR_NODE(norm_scale);
    GET_IR_NODE(norm_mean);
    GET_IR_NODE(norm_variance);
    GET_IR_NODE(norm_out);

    auto* block = l_norm->Op()->Block();
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

    // delete useless node
    std::unordered_set<const Node*> delete_nodes;

    float eps = PADDLE_GET_CONST(float, l_norm->Op()->GetAttr("epsilon"));
    int begin_norm_axis =
        PADDLE_GET_CONST(int, l_norm->Op()->GetAttr("begin_norm_axis"));

    // Generate fast_layernorm_xpu op
    framework::OpDesc fused_op_desc(block);
    fused_op_desc.SetType("fast_layernorm_xpu");
    fused_op_desc.SetInput("x", {norm_in->Name()});
    fused_op_desc.SetInput("scale", {norm_scale->Name()});
    fused_op_desc.SetInput("bias", {norm_bias->Name()});
    fused_op_desc.SetAttr("epsilon", eps);
    fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
    fused_op_desc.SetOutput("out", {norm_out->Name()});

    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
    IR_NODE_LINK_TO(norm_in, fused_op);
    IR_NODE_LINK_TO(norm_scale, fused_op);
    IR_NODE_LINK_TO(norm_bias, fused_op);
    IR_NODE_LINK_TO(fused_op, norm_out);

    delete_nodes.insert({l_norm, norm_mean, norm_variance});
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fast_layernorm_xpu_fuse_pass,
              paddle::framework::ir::FastLayernormXPUFusePass);

REGISTER_PASS_CAPABILITY(fast_layernorm_xpu_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "layer_norm", 0));
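Note (not part of the diff): because the new pass is appended to XpuPassStrategy below, it runs by default whenever XPU inference is enabled. A minimal Python sketch, assuming the standard paddle.inference API; the model file names are placeholders, not from this commit:

# Hedged sketch: paths are placeholders, not from this commit.
import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_xpu()  # selects XpuPassStrategy, which now includes
                     # fast_layernorm_xpu_fuse_pass (see next file)
# config.delete_pass("fast_layernorm_xpu_fuse_pass")  # opt out if desired
predictor = paddle_infer.create_predictor(config)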
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -547,6 +547,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "conv2d_transpose_xpu_fuse_pass",
       "add_activation_xpu_fuse_pass",
       "add_layernorm_xpu_fuse_pass",
+      "fast_layernorm_xpu_fuse_pass",
       "yolo_box_xpu_fuse_pass",
       "fast_where_xpu_fuse_pass",
       "link_xpu_op_max_pass",
paddle/phi/api/yaml/fused_ops.yaml
@@ -63,6 +63,15 @@
     data_type : tables
   optional : mask, seq_lod, max_seq_len

+- op : fast_layernorm_xpu
+  args : (Tensor x, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon)
+  output : Tensor(out)
+  infer_meta :
+    func : FastLayernormXPUInferMeta
+  kernel :
+    func : fast_layernorm_xpu
+    data_type : x
+
 - op : fast_where_xpu
   args : (Tensor condition, Tensor x, Tensor y)
   output : Tensor(out)
paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -306,6 +306,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16})},
+    {"fast_layernorm_xpu",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fc_xpu",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fill",
paddle/phi/infermeta/fusion.cc
@@ -825,4 +825,15 @@ void FastWhereXPUInferMeta(const MetaTensor& condition,
   out->set_dtype(x.dtype());
 }

+void FastLayernormXPUInferMeta(const MetaTensor& x,
+                               const MetaTensor& scale,
+                               const MetaTensor& bias,
+                               int begin_norm_axis,
+                               float epsilon,
+                               MetaTensor* out) {
+  out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}
+
 }  // namespace phi
paddle/phi/infermeta/fusion.h
@@ -197,4 +197,11 @@ void FastWhereXPUInferMeta(const MetaTensor& condition,
                            const MetaTensor& y,
                            MetaTensor* out);

+void FastLayernormXPUInferMeta(const MetaTensor& x,
+                               const MetaTensor& scale,
+                               const MetaTensor& bias,
+                               int begin_norm_axis,
+                               float epsilon,
+                               MetaTensor* out);
+
 }  // namespace phi
paddle/phi/kernels/fusion/xpu/fast_layernorm_xpu_kernel.cc
new file mode 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {

template <typename T, typename Context>
void FastLayerNormXPUKernel(const Context& ctx,
                            const DenseTensor& x,
                            const paddle::optional<DenseTensor>& scale,
                            const paddle::optional<DenseTensor>& bias,
                            int begin_norm_axis,
                            float epsilon,
                            DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  const auto& x_dims = x.dims();
  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
  int left = static_cast<int>(matrix_dim[0]);
  int right = static_cast<int>(matrix_dim[1]);
  const auto* x_data = x.data<T>();
  xpu::ctx_guard RAII_GUARD(ctx.x_context());

  // scale
  const float* scale_data_fp32 = nullptr;
  const auto* scale_ptr = scale.get_ptr();
  if (scale_ptr == nullptr) {
    float* scale_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(right);
    int r = xpu::constant<float>(ctx.x_context(), scale_data_temp, right, 1.0f);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
    scale_data_fp32 = scale_data_temp;
  } else if (scale_ptr->dtype() ==
             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
    float* scale_data_temp =
        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
    int r = xpu::cast<XPUType, float>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
        scale_data_temp,
        scale_ptr->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    scale_data_fp32 = scale_data_temp;
  } else {
    // no need to cast
    scale_data_fp32 = scale_ptr->data<float>();
  }

  // bias
  const float* bias_data_fp32 = nullptr;
  const auto* bias_ptr = bias.get_ptr();
  if (bias_ptr == nullptr) {
    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(right);
    int r = xpu::constant<float>(ctx.x_context(), bias_data_temp, right, 0.0f);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
    bias_data_fp32 = bias_data_temp;
  } else if (bias_ptr->dtype() ==
             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
    int r = xpu::cast<XPUType, float>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
        bias_data_temp,
        bias_ptr->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    bias_data_fp32 = bias_data_temp;
  } else {
    // no need to cast
    bias_data_fp32 = bias_ptr->data<float>();
  }

  auto* out_data = ctx.template Alloc<T>(out);

#ifdef PADDLE_WITH_XPU_PLUGIN
  int r = xpu::plugin::fast_layer_norm(
      ctx.x_context(),
      reinterpret_cast<const XPUType*>(x_data),
      reinterpret_cast<XPUType*>(out_data),
      left,
      right,
      epsilon,
      scale_data_fp32,
      bias_data_fp32);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_layer_norm");
#else
  // int layer_norm(Context* ctx, const T* x, T* y, int64_t m, int64_t n, float
  // eps, const float* scale, const float* bias, float* mean, float* var, bool
  // is_rstd = false);
  int r = xpu::layer_norm(ctx.x_context(),
                          reinterpret_cast<const XPUType*>(x_data),
                          reinterpret_cast<XPUType*>(out_data),
                          left,
                          right,
                          epsilon,
                          scale_data_fp32,
                          bias_data_fp32,
                          nullptr,
                          nullptr);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
#endif
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(fast_layernorm_xpu,
                   XPU,
                   ALL_LAYOUT,
                   phi::fusion::FastLayerNormXPUKernel,
                   float,
                   phi::dtype::float16) {}
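Note (not part of the diff): the kernel flattens x to a [left, right] matrix at begin_norm_axis, substitutes scale = 1 and bias = 0 when they are absent, and casts float16 scale/bias to float32 before normalizing. A NumPy reference sketch of the same computation, for illustration only:

import numpy as np

def fast_layernorm_ref(x, scale=None, bias=None, begin_norm_axis=1, epsilon=1e-5):
    # Flatten to 2-D [left, right], mirroring phi::flatten_to_2d in the kernel.
    left = int(np.prod(x.shape[:begin_norm_axis]))
    right = int(np.prod(x.shape[begin_norm_axis:]))
    x2d = x.reshape(left, right).astype(np.float32)
    scale = np.ones(right, np.float32) if scale is None else scale.astype(np.float32)
    bias = np.zeros(right, np.float32) if bias is None else bias.astype(np.float32)
    mean = x2d.mean(axis=1, keepdims=True)
    var = x2d.var(axis=1, keepdims=True)  # biased variance, as in layer_norm
    y = (x2d - mean) / np.sqrt(var + epsilon) * scale + bias
    return y.reshape(x.shape).astype(x.dtype)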
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
@@ -66,6 +66,15 @@ DLL_EXPORT int take_along_axis(Context* ctx,
                                const std::vector<int64_t>& idxshape,
                                int64_t axis);

+template <typename T>
+DLL_EXPORT int fast_layer_norm(Context* ctx,
+                               const T* x,
+                               T* y,
+                               int64_t m,
+                               int64_t n,
+                               float eps,
+                               const float* scale,
+                               const float* bias);

 }  // namespace plugin
 }  // namespace api
 }  // namespace xpu
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_layer_norm.xpu
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
static inline __device__ float sum16(const float* ptr) {
float s0 = ptr[0] + ptr[8];
float s1 = ptr[1] + ptr[9];
float s2 = ptr[2] + ptr[10];
float s3 = ptr[3] + ptr[11];
float s4 = ptr[4] + ptr[12];
float s5 = ptr[5] + ptr[13];
float s6 = ptr[6] + ptr[14];
float s7 = ptr[7] + ptr[15];
s0 = s0 + s1;
s2 = s2 + s3;
s4 = s4 + s5;
s6 = s6 + s7;
s0 = s0 + s2;
s4 = s4 + s6;
return s0 + s4;
}
template <typename T>
static __device__ void update_sum_and_squaresum(T* a,
int size,
float* sum,
float* squaresum) {
__simd__ float sum_buf[16];
__simd__ float squaresum_buf[16];
float32x16_t al;
float32x16_t ah;
int rounddown_size = rounddown32(size - 1);
unsigned int mask = -1;
if ((size % 32) != 0) {
mask = ~(-1 << (size % 32));
}
vload2_lm_mz(a + rounddown_size, al, ah, mask);
float32x16_t vsum = vvadd_float32x16(al, ah);
al = vvmul_float32x16(al, al);
ah = vvmul_float32x16(ah, ah);
float32x16_t vsquaresum = vvadd_float32x16(al, ah);
for (int i = 0; i < rounddown_size; i += 32) {
vload2_lm(a + i, al, ah);
vsum = vvadd_float32x16(vsum, al);
vsum = vvadd_float32x16(vsum, ah);
al = vvmul_float32x16(al, al);
ah = vvmul_float32x16(ah, ah);
vsquaresum = vvadd_float32x16(vsquaresum, al);
vsquaresum = vvadd_float32x16(vsquaresum, ah);
}
vstore_lm_float32x16(sum_buf, vsum);
vstore_lm_float32x16(squaresum_buf, vsquaresum);
mfence_lm();
*sum = sum16(sum_buf);
*squaresum = sum16(squaresum_buf);
}
template <typename T>
static __device__ void vector_scale_and_bias_align32(
T* a,
int size,
float mean,
float var,
_shared_ptr_ const float* scale_sm,
_shared_ptr_ const float* bias_sm,
bool do_scale_bias) {
float32x16_t al;
float32x16_t ah;
float32x16_t bl;
float32x16_t bh;
mean = 0.0f - mean;
if (do_scale_bias) {
// ((a + b) - mean) * var * scale + bias
for (int i = 0; i < size; i += 32) {
vload2_lm(a + i, al, ah);
vload2_sm(scale_sm + i, bl, bh);
al = svadd_float32x16(mean, al);
ah = svadd_float32x16(mean, ah);
al = svmul_float32x16(var, al);
ah = svmul_float32x16(var, ah);
al = vvmul_float32x16(bl, al);
ah = vvmul_float32x16(bh, ah);
vload2_sm(bias_sm + i, bl, bh);
al = vvadd_float32x16(bl, al);
ah = vvadd_float32x16(bh, ah);
vstore2_lm(a + i, al, ah);
}
} else {
// ((a + b) - mean) * var
for (int i = 0; i < size; i += 32) {
vload2_lm(a + i, al, ah);
al = svadd_float32x16(mean, al);
ah = svadd_float32x16(mean, ah);
al = svmul_float32x16(var, al);
ah = svmul_float32x16(var, ah);
vstore2_lm(a + i, al, ah);
}
}
mfence_lm();
}
template <typename T>
__global__ void fast_layer_norm_tiny_align32(float epsilon,
int64_t m,
int64_t n,
const T* x,
T* y,
const float* scale,
const float* bias) {
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t mstart = 0;
int64_t mend = 0;
partition(tid, nthreads, m, 1, &mstart, &mend);
if (mstart >= mend) {
return;
}
float one_div_n = 1.0f / n;
constexpr int lm_buffer_size = 1664 * sizeof(float) / sizeof(T);
constexpr int sm_buffer_size = 1664 * 16;
__simd__ T xlm[lm_buffer_size];
__simd__ __shared__ float scale_sm[sm_buffer_size];
__simd__ __shared__ float bias_sm[sm_buffer_size];
int block_cnt = lm_buffer_size / n;
float sum = 0.0f;
float squaresum = 0.0f;
bool do_scale_bias = false;
if (scale != nullptr && bias != nullptr) {
do_scale_bias = true;
}
if (cid == 0 && do_scale_bias) {
GM2SM_ASYNC(scale, scale_sm, n * sizeof(float));
GM2SM(bias, bias_sm, n * sizeof(float));
}
sync_all();
for (int64_t i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * n, block_cnt * n);
GM2LM(x + i * n, xlm, readlen * sizeof(T));
for (int64_t j = 0; j < readlen; j += n) {
update_sum_and_squaresum<T>(xlm + j, n, &sum, &squaresum);
float sample_mean = sum * one_div_n;
float sample_var = squaresum * one_div_n - sample_mean * sample_mean;
float rstd = 1.0f / sqrt(sample_var + epsilon);
vector_scale_and_bias_align32<T>(
xlm + j, n, sample_mean, rstd, scale_sm, bias_sm, do_scale_bias);
}
LM2GM(xlm, y + i * n, readlen * sizeof(T));
}
}
template <typename T>
__global__ void fast_layer_norm_tiny_common(float epsilon,
int64_t m,
int64_t n,
const T* x,
T* y,
const float* scale,
const float* bias) {
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t mstart = 0;
int64_t mend = 0;
partition(tid, nthreads, m, 1, &mstart, &mend);
if (mstart >= mend) {
return;
}
float one_div_n = 1.0f / n;
constexpr int lm_buffer_size = 832 * sizeof(float) / sizeof(T);
constexpr int sm_buffer_size = 1664 * 16;
__simd__ T xlm[lm_buffer_size];
__simd__ __shared__ float scale_sm[sm_buffer_size];
__simd__ __shared__ float bias_sm[sm_buffer_size];
float sum = 0.0f;
float squaresum = 0.0f;
bool do_scale_bias = false;
if (scale != nullptr && bias != nullptr) {
do_scale_bias = true;
}
if (cid == 0 && do_scale_bias) {
GM2SM_ASYNC(scale, scale_sm, n * sizeof(float));
GM2SM(bias, bias_sm, n * sizeof(float));
}
sync_all();
for (int64_t i = mstart; i < mend; i += 1) {
GM2LM(x + i * n, xlm, n * sizeof(T));
update_sum_and_squaresum<T>(xlm, n, &sum, &squaresum);
float sample_mean = sum * one_div_n;
float sample_var = squaresum * one_div_n - sample_mean * sample_mean;
float rstd = 1.0f / sqrt(sample_var + epsilon);
vector_scale_and_bias_align32<T>(
xlm, n, sample_mean, rstd, scale_sm, bias_sm, do_scale_bias);
LM2GM(xlm, y + i * n, n * sizeof(T));
}
}
#define _XPU_DEF__FAST_LAYER_NORM_TINY_(DTYPE) \
template __global__ void fast_layer_norm_tiny_common<DTYPE>( \
float epsilon, \
int64_t m, \
int64_t n, \
const DTYPE* x, \
DTYPE* y, \
const float* scale, \
const float* bias); \
template __global__ void fast_layer_norm_tiny_align32<DTYPE>( \
float epsilon, \
int64_t m, \
int64_t n, \
const DTYPE* x, \
DTYPE* y, \
const float* scale, \
const float* bias);
_XPU_DEF__FAST_LAYER_NORM_TINY_(float16);
_XPU_DEF__FAST_LAYER_NORM_TINY_(float);
} // namespace plugin
} // namespace xpu2
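Note (not part of the diff): update_sum_and_squaresum accumulates a per-row sum and sum of squares in a single pass, and the statistics follow from the usual identity. For one row x_1, ..., x_n with scale gamma and bias beta (taken as 1 and 0 when absent):

\mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 - \mu^2, \qquad
y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i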
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_layer_norm.cpp
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {

template <typename T>
__attribute__((global)) void fast_layer_norm_tiny_common(float epsilon,
                                                         int64_t m,
                                                         int64_t n,
                                                         const T* x,
                                                         T* y,
                                                         const float* scale,
                                                         const float* bias);
template <typename T>
__attribute__((global)) void fast_layer_norm_tiny_align32(float epsilon,
                                                          int64_t m,
                                                          int64_t n,
                                                          const T* x,
                                                          T* y,
                                                          const float* scale,
                                                          const float* bias);

}  // namespace plugin
}  // namespace xpu2

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {

template <typename T>
static int cpu_wrapper(Context* ctx,
                       const T* x,
                       T* y,
                       int64_t m,
                       int64_t n,
                       float eps,
                       const float* scale,
                       const float* bias) {
  for (int64_t i = 0; i < m; i++) {
    float sum = 0.0f;
    float square_sum = 0.0f;
    for (int64_t j = 0; j < n; j++) {
      float v = static_cast<float>(x[i * n + j]);
      sum += v;
      square_sum += v * v;
    }
    float mean_value = sum / n;
    float var_value = square_sum / n - mean_value * mean_value;
    float rstd = 1.0f / std::sqrt(var_value + eps);
    for (int64_t j = 0; j < n; j++) {
      float v = static_cast<float>(x[i * n + j]);
      float scale_value = ((scale == nullptr) ? 1.0f : scale[j]);
      float bias_value = ((bias == nullptr) ? 0.0f : bias[j]);
      float out = (v - mean_value) * rstd * scale_value + bias_value;
      y[i * n + j] = static_cast<T>(out);
    }
  }
  return SUCCESS;
}

template <typename T>
static int xpu2_wrapper(Context* ctx,
                        const T* x,
                        T* y,
                        int64_t m,
                        int64_t n,
                        float eps,
                        const float* scale,
                        const float* bias) {
  if (n <= 832) {
    if (n % 32 == 0) {
      xpu2::plugin::fast_layer_norm_tiny_align32<T>
          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
              eps, m, n, x, y, scale, bias);
    } else {
      xpu2::plugin::fast_layer_norm_tiny_common<T>
          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
              eps, m, n, x, y, scale, bias);
    }
  } else {
    return layer_norm(ctx, x, y, m, n, eps, scale, bias, nullptr, nullptr);
  }

  return SUCCESS;
}

template <typename T>
int fast_layer_norm(Context* ctx,
                    const T* x,
                    T* y,
                    int64_t m,
                    int64_t n,
                    float eps,
                    const float* scale,
                    const float* bias) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_layer_norm", T);
  WRAPPER_DUMP_PARAM5(ctx, x, y, m, n, eps);
  WRAPPER_DUMP_PARAM2(ctx, scale, bias);
  WRAPPER_DUMP(ctx);
  int64_t xylen = -1;
  WRAPPER_CHECK_SHAPE(ctx, &xylen, {m, n});
  WRAPPER_CHECK_2PTRS(ctx, T, xylen, x, y);
  WRAPPER_ASSERT_GE(ctx, eps, 0);
  WRAPPER_CHECK_PTR_OR_NULL(ctx, float, n, scale);
  WRAPPER_CHECK_PTR_OR_NULL(ctx, float, n, bias);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper<T>(ctx, x, y, m, n, eps, scale, bias);
  }
  if (ctx->dev().type() == api::kXPU2) {
    return xpu2_wrapper<T>(ctx, x, y, m, n, eps, scale, bias);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}

template int fast_layer_norm(Context*,
                             const float*,
                             float*,
                             int64_t,
                             int64_t,
                             float,
                             const float*,
                             const float*);
template int fast_layer_norm(Context*,
                             const float16*,
                             float16*,
                             int64_t,
                             int64_t,
                             float,
                             const float*,
                             const float*);

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
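Note (not part of the diff): xpu2_wrapper launches the plugin kernels only for rows of at most 832 elements, picking the align32 variant when n is a multiple of 32, and otherwise falls back to the generic xpu::layer_norm. A small illustrative Python sketch of that dispatch rule (the helper name is hypothetical):

def pick_layernorm_kernel(n: int) -> str:
    # Mirrors the branch structure of xpu2_wrapper above.
    if n <= 832:
        return ("fast_layer_norm_tiny_align32" if n % 32 == 0
                else "fast_layer_norm_tiny_common")
    return "xpu::layer_norm"  # generic XDNN fallback for wide rows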
paddle/phi/kernels/xpu/take_along_axis_kernel.cc
@@ -64,10 +64,6 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   using XPUType = typename XPUTypeTrait<T>::Type;

   int r = XPU_SUCCESS;
 #ifndef PADDLE_WITH_XPU_PLUGIN
-  LOG(WARNING) << "Add -DWITH_XPU_PLUGIN=ON to build "
-                  "xpu::plugin::take_along_axis(), or use "
-                  "xpu::gather_element() instead, which leads low performance "
-                  "in some cases.";
   if (index_type == DataType::INT32) {
     r = xpu::gather_element<XPUType, int>(dev_ctx.x_context(),
test/ir/inference/test_xpu_fast_layernorm_xpu_fuse_pass.py
new file mode 100644
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestFastLayernormXPUFusePass(PassAutoScanTest):
    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_xpu=True)
        yield config, ["fast_layernorm_xpu"], (1e-3, 1e-3)

    def sample_program_config(self, draw):
        batch_size = draw(st.integers(min_value=1, max_value=50))
        x_shape = [batch_size, 16, 128]
        y_shape = x_shape
        axis = -1
        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
        begin_norm_axis = 2

        layer_norm_op = OpConfig(
            "layer_norm",
            inputs={
                "X": ["x"],
                "Scale": ["layer_norm_scale"],
                "Bias": ["layer_norm_bias"],
            },
            outputs={
                "Y": ["layer_norm_out"],
                "Mean": ["layer_norm_mean"],
                "Variance": ["layer_norm_var"],
            },
            begin_norm_axis=begin_norm_axis,
            epsilon=epsilon,
        )

        mini_graph = [layer_norm_op]

        program_config = ProgramConfig(
            ops=mini_graph,
            weights={
                "layer_norm_scale": TensorConfig(shape=[x_shape[2]]),
                "layer_norm_bias": TensorConfig(shape=[x_shape[2]]),
            },
            inputs={
                "x": TensorConfig(shape=x_shape),
            },
            outputs=mini_graph[-1].outputs["Y"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["fast_layernorm_xpu_fuse_pass"],
        )


if __name__ == "__main__":
    np.random.seed(200)
    unittest.main()