Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b76f5a84
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b76f5a84
编写于
12月 22, 2020
作者:
Z
Zhang Ting
提交者:
GitHub
12月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the bug of dropout_grad (#29813)
上级
61820fd2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
6 addition
and
3 deletion
+6
-3
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+6
-3
未找到文件。
paddle/fluid/operators/dropout_op.h
浏览文件 @
b76f5a84
...
@@ -54,11 +54,14 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
...
@@ -54,11 +54,14 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
for
(
int
i
=
idx
*
VecSize
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
for
(
int
i
=
idx
*
VecSize
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
T
dout_vec
[
VecSize
];
T
dout_vec
[
VecSize
];
LoadT
*
value
=
reinterpret_cast
<
LoadT
*>
(
&
dout_vec
);
LoadT
*
dout_
value
=
reinterpret_cast
<
LoadT
*>
(
&
dout_vec
);
*
value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
dout
[
i
]);
*
dout_
value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
dout
[
i
]);
T
dx_vec
[
VecSize
];
MaskType
mask_vec
[
VecSize
];
MaskType
mask_vec
[
VecSize
];
MaskLoadT
*
mask_value
=
reinterpret_cast
<
MaskLoadT
*>
(
&
mask_vec
);
*
mask_value
=
*
reinterpret_cast
<
const
MaskLoadT
*>
(
&
mask
[
i
]);
T
dx_vec
[
VecSize
];
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录