Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
addd5fce
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
addd5fce
编写于
8月 11, 2021
作者:
W
wenbin
提交者:
GitHub
8月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
miss format (#34771)
上级
4d2994cb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
11 addition
and
3 deletion
+11
-3
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+11
-3
未找到文件。
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
addd5fce
...
...
@@ -25,6 +25,14 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
>
__device__
__forceinline__
T
local_rsqrt
(
T
num
)
{
return
rsqrt
(
static_cast
<
float
>
(
num
));
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
__device__
__forceinline__
half
local_rsqrt
(
half
num
)
{
return
hrsqrt
(
num
);
}
#endif
template
<
typename
T
,
int
TPB
>
__device__
inline
void
LayerNormSmall
(
T
val
,
const
kvp
<
T
>
&
thread_data
,
const
int
ld
,
const
int
idx
,
...
...
@@ -39,7 +47,7 @@ __device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
if
(
threadIdx
.
x
==
0
)
{
mu
=
sum_kv
.
key
;
rsigma
=
rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
eps
);
rsigma
=
local_
rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
eps
);
}
__syncthreads
();
...
...
@@ -63,7 +71,7 @@ __device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
if
(
threadIdx
.
x
==
0
)
{
mu
=
sum_kv
.
key
;
rsigma
=
rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
eps
);
rsigma
=
local_
rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
eps
);
}
__syncthreads
();
...
...
@@ -89,7 +97,7 @@ __device__ inline void LayerNorm2(const kvp<T> &thread_data, const int ld,
if
(
threadIdx
.
x
==
0
)
{
mu
=
sum_kv
.
key
;
rsigma
=
rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
eps
);
rsigma
=
local_
rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
eps
);
}
__syncthreads
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录