Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
604b6fc0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
604b6fc0
编写于
11月 22, 2021
作者:
L
Li Min
提交者:
GitHub
11月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug to support dropout eval grad computing. (#37305) (#37331)
fix bug to support dropout eval grad computing. cherry-pick #37305.
上级
44db219a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
31 addition
and
28 deletion
+31
-28
paddle/fluid/operators/dropout_impl.cu.h
paddle/fluid/operators/dropout_impl.cu.h
+28
-20
paddle/fluid/operators/dropout_impl_util.h
paddle/fluid/operators/dropout_impl_util.h
+0
-3
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+3
-5
未找到文件。
paddle/fluid/operators/dropout_impl.cu.h
浏览文件 @
604b6fc0
...
...
@@ -244,34 +244,42 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const
std
::
string
dropout_implementation
,
float
dropout_prob
,
const
Tensor
&
grad_y
,
const
Tensor
&
mask
,
int64_t
size
,
Tensor
*
grad_x
)
{
auto
M
=
EigenVector
<
uint8_t
>::
Flatten
(
mask
);
Tensor
*
grad_x
,
bool
is_test
=
false
)
{
auto
dX
=
EigenVector
<
T
>::
Flatten
(
*
grad_x
);
auto
dY
=
EigenVector
<
T
>::
Flatten
(
grad_y
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
if
(
dropout_implementation
==
"upscale_in_train"
)
{
if
(
dropout_
prob
==
1.0
f
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
0
)
*
dY
;
if
(
is_test
)
{
if
(
dropout_
implementation
==
"upscale_in_train"
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
1
)
*
dY
;
}
else
{
int
vec_size
=
platform
::
GetVectorizedSize
<
T
>
(
grad_y
.
data
<
T
>
());
if
(
vec_size
==
4
&&
size
%
4
==
0
)
{
auto
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
auto
stream
=
dev_ctx
.
stream
();
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
dev_ctx
,
size
);
DropoutGradCUDAKernel
<
T
,
uint8_t
,
4
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
grad_y
.
data
<
T
>
(),
mask
.
data
<
uint8_t
>
(),
factor
,
size
,
grad_x
->
data
<
T
>
());
dX
.
device
(
place
)
=
dY
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
else
{
auto
M
=
EigenVector
<
uint8_t
>::
Flatten
(
mask
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
if
(
dropout_prob
==
1.0
f
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
0
)
*
dY
;
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
()
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
int
vec_size
=
platform
::
GetVectorizedSize
<
T
>
(
grad_y
.
data
<
T
>
());
if
(
vec_size
==
4
&&
size
%
4
==
0
)
{
auto
factor
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
auto
stream
=
dev_ctx
.
stream
();
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
dev_ctx
,
size
);
DropoutGradCUDAKernel
<
T
,
uint8_t
,
4
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
grad_y
.
data
<
T
>
(),
mask
.
data
<
uint8_t
>
(),
factor
,
size
,
grad_x
->
data
<
T
>
());
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
()
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
();
}
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
();
}
}
...
...
paddle/fluid/operators/dropout_impl_util.h
浏览文件 @
604b6fc0
...
...
@@ -34,9 +34,6 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
TensorCopySync
(
*
seed
,
platform
::
CPUPlace
(),
&
seed_cpu_tensor
);
*
seed_data
=
static_cast
<
uint64_t
>
(
seed_cpu_tensor
.
data
<
int
>
()[
0
]);
*
increment
=
offset
;
}
else
if
(
seed
&&
platform
::
is_cpu_place
(
seed
->
place
()))
{
*
seed_data
=
*
(
seed
->
data
<
int
>
());
*
increment
=
offset
;
}
else
if
(
gen_cuda
->
GetIsInitPy
()
&&
(
!
is_fix_seed
))
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
offset
);
*
seed_data
=
seed_offset
.
first
;
...
...
paddle/fluid/operators/dropout_op.cu
浏览文件 @
604b6fc0
...
...
@@ -58,10 +58,6 @@ template <typename DeviceContext, typename T>
class
GPUDropoutGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
!
context
.
Attr
<
bool
>
(
"is_test"
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"GradOp is only callable when is_test is false"
));
auto
*
grad_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grad_y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
...
...
@@ -71,10 +67,12 @@ class GPUDropoutGradKernel : public framework::OpKernel<T> {
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
DropoutGradGPUKernelDriver
<
T
>
(
dev_ctx
,
dropout_implementation
,
dropout_prob
,
*
grad_y
,
*
mask
,
size
,
grad_x
);
*
grad_y
,
*
mask
,
size
,
grad_x
,
is_test
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录