Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c9d26423
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c9d26423
编写于
1月 15, 2021
作者:
Y
Yang Zhang
提交者:
GitHub
1月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix float64 bug in layer norm (#30454)
built-in `rsqrt` is shadowed
上级
badc6f22
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
7 addition
and
7 deletion
+7
-7
paddle/fluid/operators/layer_norm_op.cu
paddle/fluid/operators/layer_norm_op.cu
+7
-7
未找到文件。
paddle/fluid/operators/layer_norm_op.cu
浏览文件 @
c9d26423
...
...
@@ -108,23 +108,23 @@ struct PairForLayerNormAddFunctor {
};
template
<
typename
T
>
__inline__
__device__
T
rsqrt
(
const
T
val
)
{
__inline__
__device__
T
rsqrt
_
(
const
T
val
)
{
return
static_cast
<
T
>
(
1
)
/
sqrt
(
val
);
}
template
<
>
__inline__
__device__
float
rsqrt
(
const
float
val
)
{
__inline__
__device__
float
rsqrt
_
(
const
float
val
)
{
return
rsqrtf
(
val
);
}
template
<
>
__inline__
__device__
double
rsqrt
(
const
double
val
)
{
__inline__
__device__
double
rsqrt
_
(
const
double
val
)
{
return
rsqrt
(
val
);
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template
<
>
__inline__
__device__
half
rsqrt
(
const
half
val
)
{
__inline__
__device__
half
rsqrt
_
(
const
half
val
)
{
return
hrsqrt
(
val
);
}
#endif
...
...
@@ -161,7 +161,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
__syncthreads
();
mean_val
=
mean_share
;
U
invvar
=
rsqrt
<
U
>
(
var_share
+
static_cast
<
U
>
(
epsilon
));
U
invvar
=
rsqrt
_
<
U
>
(
var_share
+
static_cast
<
U
>
(
epsilon
));
// Step 2: Calculate y
if
(
scale
!=
nullptr
)
{
...
...
@@ -204,7 +204,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(
const
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
>=
i1_end
)
return
;
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
rsqrt
<
U
>
(
var
[
i1
]
+
epsilon
);
U
curr_invvar
=
rsqrt
_
<
U
>
(
var
[
i1
]
+
epsilon
);
for
(
int
k
=
0
;
k
<
VPT
;
++
k
)
{
const
int
i2
=
i2_off
+
k
;
const
int
load_idx
=
i1
*
n2
+
i2
;
...
...
@@ -352,7 +352,7 @@ __global__ void LayerNormBackwardComputeGradInput(
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
rsqrt
<
U
>
(
var
[
i1
]
+
epsilon
);
const
U
c_invvar
=
rsqrt
_
<
U
>
(
var
[
i1
]
+
epsilon
);
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_dout
=
dout
+
i1
*
n2
;
constexpr
int
numx
=
BDIMX
*
BDIMY
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录