Repository: Paddle (fork of PaddlePaddle/Paddle)
Commit cdeffff4 (unverified), parent 18043ab5
Authored Jun 21, 2021 by zhiboniu; committed via GitHub, Jun 21, 2021

Fix GPT-2 train loss NaN problem by adding a __syncthreads() in BlockReduceSum (#33659)
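The NaN came from a data race, not from the math. blockReduceSum stages per-warp partial sums in a shared-memory array; when a kernel calls it twice in a row on the same array, the lane-0 write at the start of the second call can overwrite slots that slower warps are still reading at the end of the first call. The corrupted partial sums poisoned layer norm's mean and variance and surfaced as a NaN training loss in GPT-2. The one-line cure is a __syncthreads() in front of the write. A minimal sketch of the fixed pattern ("Sketch" names are illustrative, not Paddle's exact code; assumes warpSize == 32 and CUDA 9+ shuffle intrinsics):

#include <cuda_runtime.h>

template <typename T>
__forceinline__ __device__ T warpReduceSumSketch(T val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

template <typename T>
__forceinline__ __device__ T blockReduceSumSketch(T val) {
  static __shared__ T shared[32];  // one slot per warp, reused across calls
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = warpReduceSumSketch(val);  // per-warp partial sums
  __syncthreads();  // the added barrier: wait until every warp has finished
                    // reading shared[] from a previous call before overwriting
  if (lane == 0) shared[wid] = val;
  __syncthreads();  // make all partial sums visible to the whole block

  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane]
                                              : static_cast<T>(0);
  if (wid == 0) val = warpReduceSumSketch(val);  // final reduce in warp 0
  return val;
}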
Showing 3 changed files with 12 additions and 7 deletions (+12 −7)
paddle/fluid/operators/correlation_op.cu       +1 −0
paddle/fluid/operators/layer_norm_op.cu        +10 −7
paddle/fluid/operators/math/math_cuda_utils.h  +1 −0
paddle/fluid/operators/correlation_op.cu

@@ -42,6 +42,7 @@ __forceinline__ __device__ T blockReduceSum(T val) {
   int wid = threadIdx.x / warpSize;
 
   val = warpReduceSum(val);
+  __syncthreads();
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
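correlation_op.cu keeps its function-local static __shared__ buffer and only adds the barrier. That is sufficient because the hazard exists only between consecutive reductions in one kernel launch; a hypothetical kernel shape that triggers it (illustrative names, reusing blockReduceSumSketch from the sketch above):

template <typename T>
__device__ T blockReduceSumSketch(T val);  // defined in the sketch above

__global__ void meanVarSketch(const float *x, float *mean, float *var, int n) {
  float sum = 0.f, sq = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float t = x[blockIdx.x * n + i];
    sum += t;
    sq += t * t;
  }
  sum = blockReduceSumSketch(sum);  // call N: every warp reads shared[]
  sq = blockReduceSumSketch(sq);    // call N+1: lane 0 rewrites shared[];
                                    // without the leading barrier it can
                                    // clobber slots call N is still reading
  if (threadIdx.x == 0) {           // only thread 0 holds the full sums
    float m = sum / n;
    mean[blockIdx.x] = m;
    var[blockIdx.x] = sq / n - m * m;
  }
}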
paddle/fluid/operators/layer_norm_op.cu

@@ -64,17 +64,16 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
 }
 
 template <typename U>
-__forceinline__ __device__ U BlockReduceSum(U val) {
-  static __shared__ U shared[32];
+__forceinline__ __device__ U BlockReduceSum(U val, U *shared) {
   int lane = threadIdx.x % warpSize;
   int wid = threadIdx.x / warpSize;
 
   val = WarpReduceSum(val);  // Each warp performs partial reduction
 
+  __syncthreads();
   if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory
 
   __syncthreads();  // Wait for all partial reductions
   // read from shared memory only if that warp existed
   val =
       (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);

@@ -183,6 +182,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                  int64_t feature_size) {
   __shared__ U mean_share;
   __shared__ U var_share;
+  __shared__ U shared_mean[32];
+  __shared__ U shared_var[32];
   int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int64_t end_idx = (blockIdx.x + 1) * feature_size;

@@ -196,8 +197,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     var_val += (tmp * tmp);
   }
-  mean_val = BlockReduceSum<U>(mean_val);
-  var_val = BlockReduceSum<U>(var_val);
+  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = BlockReduceSum<U>(var_val, shared_var);
 
   if (threadIdx.x == 0) {
     auto scale = static_cast<float>(1.) / static_cast<float>(feature_size);

@@ -541,8 +542,10 @@ __global__ void LayerNormBackwardGradientAll(
     }
   }
-  d_scale_partial = BlockReduceSum<U>(d_scale_partial);
-  d_bias_partial = BlockReduceSum<U>(d_bias_partial);
+  __shared__ U shared_scale[32];
+  __shared__ U shared_bias[32];
+  d_scale_partial = BlockReduceSum<U>(d_scale_partial, shared_scale);
+  d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
 
   if (threadIdx.x == 0) {
     d_scale[blockIdx.x + col_offset] = d_scale_partial;
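In layer_norm_op.cu the commit goes one step further than the barrier: BlockReduceSum's scratch space becomes a caller-supplied parameter, and LayerNormForward hands the mean and variance reductions disjoint buffers (shared_mean, shared_var), so consecutive reductions no longer alias the same memory at all. A minimal sketch of the resulting shape (illustrative names, reusing warpReduceSumSketch from the first sketch):

template <typename U>
__device__ U warpReduceSumSketch(U val);  // defined in the first sketch

// Caller-owned scratch: with disjoint buffers, reduction N+1 cannot
// clobber the slots reduction N is still reading.
template <typename U>
__forceinline__ __device__ U blockReduceSumSketch2(U val, U *shared) {
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;
  val = warpReduceSumSketch(val);
  __syncthreads();                 // still guards reuse of the same buffer
  if (lane == 0) shared[wid] = val;
  __syncthreads();
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane]
                                              : static_cast<U>(0);
  if (wid == 0) val = warpReduceSumSketch(val);
  return val;
}

__global__ void layerNormStatsSketch(const float *x, float *mean, float *var,
                                     int64_t feature_size) {
  __shared__ float shared_mean[32];  // 32 slots cover up to 1024 threads
  __shared__ float shared_var[32];
  float m = 0.f, v = 0.f;
  for (int64_t i = threadIdx.x; i < feature_size; i += blockDim.x) {
    float t = x[blockIdx.x * feature_size + i];
    m += t;
    v += t * t;
  }
  m = blockReduceSumSketch2(m, shared_mean);
  v = blockReduceSumSketch2(v, shared_var);
  if (threadIdx.x == 0) {
    float inv_n = 1.f / static_cast<float>(feature_size);
    mean[blockIdx.x] = m * inv_n;
    var[blockIdx.x] = v * inv_n - (m * inv_n) * (m * inv_n);
  }
}

Keeping the leading __syncthreads() even with disjoint buffers is a cheap safety net for any caller that does reuse one buffer across loop iterations.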
paddle/fluid/operators/math/math_cuda_utils.h

@@ -188,6 +188,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   val = warpReduceSum<T>(val, mask);
+  __syncthreads();
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
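The header variant threads an explicit lane mask through to the warp primitive. A minimal sketch of the kind of masked warp reduction a blockReduceSum(T val, unsigned mask) could layer on (an assumption for illustration, not Paddle's exact implementation; requires CUDA 9+ *_sync intrinsics):

template <typename T>
__inline__ __device__ T warpReduceSumMaskedSketch(T val, unsigned mask) {
  for (int offset = warpSize / 2; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(mask, val, offset);  // butterfly exchange among
                                                // the lanes named by mask
  return val;
}

With the xor/butterfly exchange every participating lane ends up holding the full warp sum. The fix itself is identical to the other two files: one __syncthreads() between the warp-level reduction and the shared-memory write.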