Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
fd1994b6
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fd1994b6
编写于
6月 02, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 02, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1759 Gpu Dropout kernel fix
Merge pull request !1759 from chenweifeng/dropout
上级
b3c6da90
cf0820aa
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
26 additions
and
24 deletions
+26
-24
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu
+5
-5
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
+2
-2
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
+3
-3
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
+1
-1
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
+3
-3
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h
+1
-1
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+1
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+10
-8
未找到文件。
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu
浏览文件 @
fd1994b6
...
...
@@ -19,10 +19,10 @@
#include "include/cuda_runtime.h"
// Element-wise inverted-dropout forward pass.
// On entry, mask[i] holds a uniform random value in (0, 1] (filled by
// curandGenerateUniform in the host-side Launch); it is overwritten in place
// with the binary keep decision (1.0f with probability keep_prob, else 0.0f).
// Uses a grid-stride loop so any launch configuration covers num_count elements.
// NOTE: the pre-fix revision scaled by 1/drop_prob and kept when
// mask[i] > drop_prob — both wrong; this is the corrected keep_prob form.
__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
                                     float keep_prob) {
  // Inverted dropout: survivors are scaled by 1/keep_prob so the expected
  // value of the output equals the input (keep_prob must be > 0).
  float scale = 1.f / keep_prob;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
    mask[i] = mask[i] <= keep_prob;  // bool result promoted to 0.0f / 1.0f
    output[i] = scale * input[i] * mask[i];
  }
}
...
...
@@ -34,8 +34,8 @@ void DropoutForward(const float *input, float *mask, float *output, size_t num_c
}
// Element-wise dropout backward pass: dx[i] = dy[i] * mask[i] / keep_prob.
// mask is the 0/1 keep mask produced by DropoutForwardKernel, so gradients
// flow only through elements that were kept, scaled to undo the forward
// rescaling. Grid-stride loop covers num_count elements for any launch shape.
// NOTE: the pre-fix revision received drop_prob and computed
// 1.f / (1.f - drop_prob); this is the renamed keep_prob form.
__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
                                      float keep_prob) {
  float scale = 1.f / keep_prob;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
    dx[i] = scale * dy[i] * mask[i];
  }
}
...
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
浏览文件 @
fd1994b6
...
...
@@ -18,9 +18,9 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
#include "device/gpu/cuda_common.h"
// Host-side launcher for DropoutForwardKernel (inverted dropout).
// input:      device buffer of num_count floats to be dropped out.
// mask:       device buffer pre-filled with uniform randoms in (0, 1];
//             overwritten with the 0/1 keep mask.
// output:     device buffer receiving input * mask / keep_prob.
// keep_prob:  probability an element is kept; must be in (0, 1].
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob,
                    cudaStream_t cuda_stream);
// Host-side launcher for DropoutBackwardKernel:
// dx[i] = dy[i] * mask[i] / keep_prob, using the mask saved by the forward pass.
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob,
                     cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
浏览文件 @
fd1994b6
...
...
@@ -23,7 +23,7 @@ DropoutGpuFwdKernel::DropoutGpuFwdKernel()
:
cudnn_handle_
(
nullptr
),
is_null_input_
(
false
),
num_count_
(
0
),
dro
p_prob_
(
0.0
),
kee
p_prob_
(
0.0
),
states_init_
(
false
),
mask_generator_
(
nullptr
)
{}
...
...
@@ -54,7 +54,7 @@ bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) {
for
(
size_t
x
:
input_shape
)
{
num_count_
*=
x
;
}
drop_prob_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"dro
p_prob"
));
keep_prob_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"kee
p_prob"
));
InitSizeLists
();
return
true
;
...
...
@@ -92,7 +92,7 @@ bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const st
}
curandGenerateUniform
(
mask_generator_
,
mask
,
num_count_
);
DropoutForward
(
input
,
mask
,
output
,
num_count_
,
dro
p_prob_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
DropoutForward
(
input
,
mask
,
output
,
num_count_
,
kee
p_prob_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
...
...
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
浏览文件 @
fd1994b6
...
...
@@ -52,7 +52,7 @@ class DropoutGpuFwdKernel : public GpuKernel {
cudnnHandle_t
cudnn_handle_
;
bool
is_null_input_
;
size_t
num_count_
;
float
dro
p_prob_
;
float
kee
p_prob_
;
bool
states_init_
;
curandGenerator_t
mask_generator_
;
std
::
vector
<
size_t
>
input_size_list_
;
...
...
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
浏览文件 @
fd1994b6
...
...
@@ -20,7 +20,7 @@
namespace
mindspore
{
namespace
kernel
{
// Default-constructs the dropout-grad kernel with members zeroed/nulled;
// real configuration happens in Init(kernel_node).
// NOTE: the pre-fix revision initialized drop_prob_(0.0); renamed to keep_prob_.
DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel()
    : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}

// Releases cuDNN resources acquired during Init.
DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); }
...
...
@@ -50,7 +50,7 @@ bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) {
for
(
size_t
x
:
input_shape
)
{
num_count_
*=
x
;
}
drop_prob_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"dro
p_prob"
));
keep_prob_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"kee
p_prob"
));
InitSizeLists
();
return
true
;
...
...
@@ -84,7 +84,7 @@ bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, cons
auto
*
mask
=
reinterpret_cast
<
float
*>
(
inputs
[
1
]
->
addr
);
auto
*
dx
=
reinterpret_cast
<
float
*>
(
outputs
[
0
]
->
addr
);
DropoutBackward
(
dy
,
mask
,
dx
,
num_count_
,
dro
p_prob_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
DropoutBackward
(
dy
,
mask
,
dx
,
num_count_
,
kee
p_prob_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
...
...
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h
浏览文件 @
fd1994b6
...
...
@@ -45,7 +45,7 @@ class DropoutGradGpuFwdKernel : public GpuKernel {
cudnnHandle_t
cudnn_handle_
;
bool
is_null_input_
;
size_t
num_count_
;
float
dro
p_prob_
;
float
kee
p_prob_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
...
...
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
fd1994b6
...
...
@@ -675,7 +675,7 @@ def get_bprop_binary_cross_entropy(self):
@
bprop_getters
.
register
(
P
.
Dropout
)
def
get_bprop_dropout
(
self
):
"""Grad definition for `Dropout` operation."""
grad
=
P
.
DropoutGrad
(
self
.
dro
p_prob
)
grad
=
P
.
DropoutGrad
(
self
.
kee
p_prob
)
def
bprop
(
x
,
out
,
dout
):
_
,
mask
=
out
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
fd1994b6
...
...
@@ -3227,7 +3227,8 @@ class Dropout(PrimitiveWithInfer):
During training, randomly zeroes some of the elements of the input tensor with probability.
Args:
drop_prob (float): probability of an element to be zeroed. Default: 0.
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
means dropping out 10% of input units.
Inputs:
- **shape** (tuple[int]) - The shape of target mask.
...
...
@@ -3236,14 +3237,14 @@ class Dropout(PrimitiveWithInfer):
Tensor, the value of generated mask for input shape.
Examples:
>>> dropout = P.Dropout(
dro
p_prob=0.5)
>>> dropout = P.Dropout(
kee
p_prob=0.5)
>>> in = Tensor((20, 16, 50, 50))
>>> out = dropout(in)
"""
@prim_attr_register
def __init__(self, keep_prob=0.5):
    # keep_prob is the probability an element is KEPT (not dropped); the GPU
    # kernel scales survivors by 1/keep_prob, so Rel.INC_RIGHT restricts the
    # range to (0, 1] and rules out division by zero.
    # NOTE: replaces the pre-fix drop_prob=0 signature, whose Rel.INC_BOTH
    # range [0, 1] permitted a zero scale denominator.
    self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
def
infer_shape
(
self
,
x_shape
):
validator
.
check_integer
(
"x_shape"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
...
...
@@ -3262,7 +3263,8 @@ class DropoutGrad(PrimitiveWithInfer):
of the input tensor with probability.
Args:
drop_prob (float): probability of an element to be zeroed. Default: 0.
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
means dropping out 10% of input units.
Inputs:
- **shape** (tuple[int]) - The shape of target mask.
...
...
@@ -3271,14 +3273,14 @@ class DropoutGrad(PrimitiveWithInfer):
Tensor, the value of generated mask for input shape.
Examples:
>>> dropout_grad = P.DropoutGrad(
dro
p_prob=0.5)
>>> dropout_grad = P.DropoutGrad(
kee
p_prob=0.5)
>>> in = Tensor((20, 16, 50, 50))
>>> out = dropout_grad(in)
"""
@prim_attr_register
def __init__(self, keep_prob=0.5):
    # Mirrors Dropout.__init__: keep_prob is the keep rate in (0, 1]
    # (Rel.INC_RIGHT excludes 0 so the 1/keep_prob gradient scale is finite).
    # NOTE: replaces the pre-fix drop_prob=0 signature with Rel.INC_BOTH.
    self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
def infer_shape(self, dy_shape, mask_shape):
    # The dropout gradient is element-wise, so dx has exactly the shape of the
    # incoming sensitivity dy; mask_shape is accepted but not constrained here.
    return dy_shape
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录