Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
55accdfc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
55accdfc
编写于
9月 27, 2022
作者:
W
wenbin
提交者:
GitHub
9月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
preln_residual_bias optimization (#46496)
* half2 * add epsilon
上级
4d772144
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
153 addition
and
24 deletion
+153
-24
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
...d/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
+153
-24
未找到文件。
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
浏览文件 @
55accdfc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -30,6 +31,116 @@ namespace paddle {
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#define FINAL_MASK 0xffffffff
template
<
typename
T
,
int
NUM
>
__inline__
__device__
T
warpReduceSumV2
(
T
*
val
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
val
[
i
]
+=
__shfl_xor_sync
(
FINAL_MASK
,
val
[
i
],
mask
,
32
);
}
return
(
T
)(
0.0
f
);
}
template
<
typename
T
,
int
NUM
>
__inline__
__device__
T
blockReduceSumV2
(
T
*
val
)
{
static
__shared__
T
shared
[
NUM
][
33
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
warpReduceSumV2
<
T
,
NUM
>
(
val
);
if
(
lane
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
shared
[
i
][
wid
]
=
val
[
i
];
}
}
__syncthreads
();
bool
is_mask
=
threadIdx
.
x
<
(
blockDim
.
x
/
32.
f
);
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
val
[
i
]
=
is_mask
?
shared
[
i
][
lane
]
:
(
T
)(
0.0
f
);
}
warpReduceSumV2
<
T
,
NUM
>
(
val
);
return
(
T
)
0.0
f
;
}
__global__
void
generalAddBiasResidualLayerNormOpt2
(
half2
*
normed_output
,
half2
*
output
,
const
half2
*
__restrict
bias
,
const
half2
*
__restrict
src
,
const
half2
*
__restrict
residual
,
const
half2
*
__restrict
gamma
,
const
half2
*
__restrict
beta
,
int
m
,
int
n
,
float
epsilon
)
{
__shared__
float
s_mean
;
__shared__
float
s_variance
;
float
x_sum
=
0.0
f
;
float
x2_sum
=
0.0
f
;
const
int
b_offset
=
blockIdx
.
x
*
n
;
#pragma unroll 2
for
(
int
i
=
threadIdx
.
x
;
i
<
n
;
i
+=
blockDim
.
x
)
{
const
int
index
=
b_offset
+
i
;
float
val_1
=
0.0
f
;
float
val_2
=
0.0
f
;
half2
tmp
;
if
(
bias
)
{
tmp
=
__ldg
(
&
bias
[
i
]);
val_1
+=
static_cast
<
float
>
(
tmp
.
x
);
val_2
+=
static_cast
<
float
>
(
tmp
.
y
);
}
{
tmp
=
__ldg
(
&
residual
[
index
]);
val_1
+=
static_cast
<
float
>
(
tmp
.
x
);
val_2
+=
static_cast
<
float
>
(
tmp
.
y
);
}
{
tmp
=
__ldg
(
&
src
[
index
]);
val_1
+=
static_cast
<
float
>
(
tmp
.
x
);
val_2
+=
static_cast
<
float
>
(
tmp
.
y
);
}
tmp
.
x
=
__float2half_rn
(
val_1
);
tmp
.
y
=
__float2half_rn
(
val_2
);
output
[
index
]
=
tmp
;
x_sum
+=
val_1
+
val_2
;
x2_sum
+=
val_1
*
val_1
+
val_2
*
val_2
;
}
float
sums
[
2
];
sums
[
0
]
=
x_sum
;
sums
[
1
]
=
x2_sum
;
blockReduceSumV2
<
float
,
2
>
(
sums
);
if
(
threadIdx
.
x
==
0
)
{
s_mean
=
sums
[
0
]
/
n
/
2
;
s_variance
=
rsqrtf
(
sums
[
1
]
/
n
/
2
-
s_mean
*
s_mean
+
epsilon
);
}
__syncthreads
();
half2
mean_2
=
__float2half2_rn
(
s_mean
);
half2
var_2
=
__float2half2_rn
(
s_variance
);
#pragma unroll 2
for
(
int
i
=
threadIdx
.
x
;
i
<
n
;
i
+=
blockDim
.
x
)
{
const
int
index
=
b_offset
+
i
;
half2
val
=
__hmul2
(
__hmul2
(
__hsub2
(
output
[
index
],
mean_2
),
var_2
),
__ldg
(
&
gamma
[
i
]));
if
(
beta
)
{
val
=
__hadd2
(
val
,
__ldg
(
&
beta
[
i
]));
}
normed_output
[
index
]
=
val
;
}
}
#endif
using
half
=
phi
::
dtype
::
float16
;
#if IS_TRT_VERSION_GE(6000)
...
...
@@ -306,30 +417,48 @@ int PrelnResidualBiasPluginDynamic::enqueue(
float
*
mean
=
nullptr
;
float
*
var
=
nullptr
;
const
int
VecSize
=
8
;
paddle
::
operators
::
FusedLayernormResidualDropoutBiasFunctor
<
half
,
uint8_t
,
VecSize
,
float
,
false
>
()(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
,
stream
);
// if odd
if
(
hidden
&
1
==
0
)
{
int
half_n
=
hidden
/
2
;
int
half_n_32
=
(
half_n
+
31
)
/
32
*
32
;
int
block
(
std
::
min
(
half_n_32
,
512
));
generalAddBiasResidualLayerNormOpt2
<<<
rows
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
half2
*>
(
layernorm_dst
),
reinterpret_cast
<
half2
*>
(
dst
),
(
const
half2
*
)
bias
,
(
const
half2
*
)
input2
,
(
const
half2
*
)
input1
,
(
const
half2
*
)
scale
,
(
const
half2
*
)
layernorm_bias
,
rows
,
half_n
,
epsilon
);
}
else
{
paddle
::
operators
::
FusedLayernormResidualDropoutBiasFunctor
<
half
,
uint8_t
,
VecSize
,
float
,
false
>
()(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
,
stream
);
}
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录