Commit 1066debc
Authored by mindspore-ci-bot on June 29, 2020; committed by Gitee on June 29, 2020.
!2698 GPU dropout rewrite
Merge pull request !2698 from VectorSL/drop
Parents: 0e38672b 136b4569
Showing 7 changed files with 184 additions and 187 deletions (+184 −187)
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu    +40 -12
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh   +4  -3
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc     +8  -76
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h      +65 -14
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc    +8  -68
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h     +56 -14
mindspore/ops/operations/nn_ops.py                      +3  -0
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu
@@ -17,31 +17,59 @@
 #include <stdint.h>
 #include "dropout_impl.cuh"
 #include "include/cuda_runtime.h"
-__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
+template <typename T>
+__global__ void DropoutForwardKernel(const T *input, T *mask, T *output, float *mask_f, size_t num_count,
                                      float keep_prob) {
   float scale = 1.f / keep_prob;
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
-    mask[i] = mask[i] <= keep_prob;
-    output[i] = scale * input[i] * mask[i];
+    mask_f[i] = mask_f[i] <= keep_prob;
+    output[i] = scale * input[i] * mask_f[i];
+    mask[i] = mask_f[i];
   }
 }
-void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
+template <>
+__global__ void DropoutForwardKernel(const half *input, half *mask, half *output, float *mask_f, size_t num_count,
+                                     float keep_prob) {
+  half scale = __float2half(1.f / keep_prob);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
+    mask_f[i] = mask_f[i] <= keep_prob;
+    output[i] = scale * input[i] * __float2half(mask_f[i]);
+    mask[i] = __float2half(mask_f[i]);
+  }
+}
+template <typename T>
+void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float drop_prob,
                     cudaStream_t cuda_stream) {
-  DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, num_count, drop_prob);
+  DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, mask_f, num_count, drop_prob);
 }
-__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
+template <typename T>
+__global__ void DropoutBackwardKernel(const T *dy, const T *mask, T *dx, size_t num_count,
                                       float keep_prob) {
   float scale = 1.f / keep_prob;
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
     dx[i] = scale * dy[i] * mask[i];
   }
 }
-void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
+template <>
+__global__ void DropoutBackwardKernel(const half *dy, const half *mask, half *dx, size_t num_count,
+                                      float keep_prob) {
+  half scale = __float2half(1.f / keep_prob);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
+    dx[i] = scale * dy[i] * mask[i];
+  }
+}
+template <typename T>
+void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float drop_prob,
                      cudaStream_t cuda_stream) {
   DropoutBackwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dy, mask, dx, num_count, drop_prob);
 }
+
+template void DropoutForward<float>(const float *input, float *mask, float *output, float *mask_f, size_t num_count,
+                                    float drop_prob, cudaStream_t cuda_stream);
+template void DropoutForward<half>(const half *input, half *mask, half *output, float *mask_f, size_t num_count,
+                                   float drop_prob, cudaStream_t cuda_stream);
+template void DropoutBackward<float>(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
+                                     cudaStream_t cuda_stream);
+template void DropoutBackward<half>(const half *dy, const half *mask, half *dx, size_t num_count, float drop_prob,
+                                    cudaStream_t cuda_stream);
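For orientation, a minimal NumPy sketch of the arithmetic the new templated kernels implement, assuming uniform mask values in [0, 1); the function names and shapes below are illustrative only and are not part of this repository.

import numpy as np

def dropout_forward(x, keep_prob, rng):
    # mask_f mirrors the float workspace filled by curandGenerateUniform.
    mask_f = rng.random(x.shape, dtype=np.float32)
    # Threshold the uniform values, then cast the 0/1 mask to the input dtype (float16 or float32).
    mask = (mask_f <= keep_prob).astype(x.dtype)
    # Kept elements are rescaled by 1 / keep_prob so the expected value is preserved.
    scale = np.asarray(1.0 / keep_prob, dtype=x.dtype)
    return scale * x * mask, mask

def dropout_backward(dy, mask, keep_prob):
    # dx = scale * dy * mask, matching DropoutBackwardKernel.
    scale = np.asarray(1.0 / keep_prob, dtype=dy.dtype)
    return scale * dy * mask

rng = np.random.default_rng(0)
x = np.ones((2, 8), dtype=np.float16)
out, mask = dropout_forward(x, 0.9, rng)
dx = dropout_backward(np.ones_like(x), mask, 0.9)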
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
@@ -18,9 +18,10 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
 #include "device/gpu/cuda_common.h"
-void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob,
+template <typename T>
+void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float keep_prob,
                     cudaStream_t cuda_stream);
-void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob,
-                     cudaStream_t cuda_stream);
+template <typename T>
+void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float keep_prob,
+                     cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
@@ -15,84 +15,16 @@
 */
 #include "kernel/gpu/nn/dropout_gpu_kernel.h"
 #include "kernel/gpu/cuda_impl/dropout_impl.cuh"
 namespace mindspore {
 namespace kernel {
-DropoutGpuFwdKernel::DropoutGpuFwdKernel()
-    : cudnn_handle_(nullptr),
-      is_null_input_(false),
-      num_count_(0),
-      keep_prob_(0.0),
-      states_init_(false),
-      mask_generator_(nullptr) {}
-
-DropoutGpuFwdKernel::~DropoutGpuFwdKernel() { DestroyResource(); }
-
-const std::vector<size_t> &DropoutGpuFwdKernel::GetInputSizeList() const { return input_size_list_; }
-
-const std::vector<size_t> &DropoutGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; }
-
-const std::vector<size_t> &DropoutGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
-
-bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) {
-  InitResource();
-
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  if (input_num != 1) {
-    MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1.";
-  }
-
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  is_null_input_ = CHECK_NULL_INPUT(input_shape);
-  if (is_null_input_) {
-    InitSizeLists();
-    return true;
-  }
-
-  num_count_ = 1;
-  for (size_t x : input_shape) {
-    num_count_ *= x;
-  }
-  keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));
-
-  InitSizeLists();
-  return true;
-}
-
-void DropoutGpuFwdKernel::InitResource() {
-  cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-}
-
-void DropoutGpuFwdKernel::DestroyResource() noexcept {}
-
-void DropoutGpuFwdKernel::InitSizeLists() {
-  size_t input_size = num_count_ * sizeof(float);
-  input_size_list_.push_back(input_size);
-  output_size_list_.push_back(input_size);  // output size: the same with input size
-  output_size_list_.push_back(input_size);  // mask size: the same with input size
-}
-
-bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-                                 const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  if (is_null_input_) {
-    return true;
-  }
-
-  auto *input = reinterpret_cast<float *>(inputs[0]->addr);
-  auto *output = reinterpret_cast<float *>(outputs[0]->addr);
-  auto *mask = reinterpret_cast<float *>(outputs[1]->addr);
-
-  if (!states_init_) {
-    curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT);
-    curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL));
-    states_init_ = true;
-  }
-  curandGenerateUniform(mask_generator_, mask, num_count_);
-  DropoutForward(input, mask, output, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
-
-  return true;
-}
+MS_REG_GPU_KERNEL_ONE(
+  Dropout, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  DropoutGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Dropout, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  DropoutGpuFwdKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
@@ -20,35 +20,88 @@
 #include <vector>
 #include "kernel/gpu/gpu_kernel.h"
 #include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
 #include "include/curand.h"
 namespace mindspore {
 namespace kernel {
+template <typename T>
 class DropoutGpuFwdKernel : public GpuKernel {
  public:
-  DropoutGpuFwdKernel();
-  ~DropoutGpuFwdKernel() override;
-  const std::vector<size_t> &GetInputSizeList() const override;
-  const std::vector<size_t> &GetOutputSizeList() const override;
-  const std::vector<size_t> &GetWorkspaceSizeList() const override;
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
-  bool Init(const CNodePtr &kernel_node) override;
+  DropoutGpuFwdKernel()
+      : cudnn_handle_(nullptr),
+        is_null_input_(false),
+        num_count_(0),
+        keep_prob_(0.0),
+        states_init_(false),
+        mask_generator_(nullptr) {}
+  ~DropoutGpuFwdKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    T *mask = GetDeviceAddress<T>(outputs, 1);
+    float *mask_f = GetDeviceAddress<float>(workspace, 0);
+
+    if (!states_init_) {
+      curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT);
+      curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL));
+      states_init_ = true;
+    }
+    // curandGen only support float or double for mask.
+    curandGenerateUniform(mask_generator_, mask_f, num_count_);
+    DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1.";
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      InitSizeLists();
+      return true;
+    }
+
+    num_count_ = 1;
+    for (size_t x : input_shape) {
+      num_count_ *= x;
+    }
+    keep_prob_ = GetAttr<float>(kernel_node, "keep_prob");
+
+    InitSizeLists();
+    return true;
+  }
 
  protected:
-  void InitResource() override;
-  void InitSizeLists() override;
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+  }
+
+  void InitSizeLists() override {
+    size_t input_size = num_count_ * sizeof(T);
+    input_size_list_.push_back(input_size);
+    output_size_list_.push_back(input_size);  // output size: the same with input size
+    output_size_list_.push_back(input_size);  // mask size: the same with input size
+    workspace_size_list_.push_back(num_count_ * sizeof(float));  // temp mask_f for curandGen
+  }
 
  private:
-  void DestroyResource() noexcept;
-
   cudnnHandle_t cudnn_handle_;
   bool is_null_input_;
   size_t num_count_;
@@ -59,8 +112,6 @@ class DropoutGpuFwdKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 };
-
-MS_REG_GPU_KERNEL(Dropout, DropoutGpuFwdKernel)
 }  // namespace kernel
 }  // namespace mindspore
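As a concrete illustration of the sizing done in InitSizeLists (the shape and numbers below are illustrative, not taken from the commit): a float16 input of shape (32, 128) gives num_count_ = 4096, so the input, output, and mask buffers are 8192 bytes each, while the temporary float workspace for curand is 16384 bytes. A quick check in Python:

shape = (32, 128)
num_count = 1
for dim in shape:
    num_count *= dim                 # 4096 elements, as Init computes
sizeof_half, sizeof_float = 2, 4
print(num_count * sizeof_half)       # 8192 bytes each for input, output, and mask
print(num_count * sizeof_float)      # 16384 bytes for the mask_f workspace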
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
@@ -15,76 +15,16 @@
 */
 #include "kernel/gpu/nn/dropout_grad_kernel.h"
 #include "kernel/gpu/cuda_impl/dropout_impl.cuh"
 namespace mindspore {
 namespace kernel {
-DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel()
-    : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}
-
-DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); }
-
-const std::vector<size_t> &DropoutGradGpuFwdKernel::GetInputSizeList() const { return input_size_list_; }
-
-const std::vector<size_t> &DropoutGradGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; }
-
-const std::vector<size_t> &DropoutGradGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
-
-bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) {
-  InitResource();
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  if (input_num != 2) {
-    MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuFwdKernel needs 2.";
-    return false;
-  }
-  auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-  is_null_input_ = CHECK_NULL_INPUT(input_shape);
-  if (is_null_input_) {
-    InitSizeLists();
-    return true;
-  }
-  num_count_ = 1;
-  for (size_t x : input_shape) {
-    num_count_ *= x;
-  }
-  keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));
-  InitSizeLists();
-  return true;
-}
-
-void DropoutGradGpuFwdKernel::InitResource() {
-  cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-}
-
-void DropoutGradGpuFwdKernel::DestroyResource() noexcept {}
-
-void DropoutGradGpuFwdKernel::InitSizeLists() {
-  size_t dy_size = num_count_ * sizeof(float);
-  size_t mask_size = dy_size;
-  size_t dx_size = dy_size;
-  input_size_list_.push_back(dy_size);
-  input_size_list_.push_back(mask_size);
-  output_size_list_.push_back(dx_size);
-}
-
-bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-                                     const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  if (is_null_input_) {
-    return true;
-  }
-  auto *dy = reinterpret_cast<float *>(inputs[0]->addr);
-  auto *mask = reinterpret_cast<float *>(inputs[1]->addr);
-  auto *dx = reinterpret_cast<float *>(outputs[0]->addr);
-  DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
-  return true;
-}
+MS_REG_GPU_KERNEL_ONE(
+  DropoutGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  DropoutGradGpuBwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  DropoutGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  DropoutGradGpuBwdKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h
@@ -20,28 +20,72 @@
 #include <vector>
 #include "kernel/gpu/gpu_kernel.h"
 #include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
 namespace mindspore {
 namespace kernel {
-class DropoutGradGpuFwdKernel : public GpuKernel {
+template <typename T>
+class DropoutGradGpuBwdKernel : public GpuKernel {
  public:
-  DropoutGradGpuFwdKernel();
-  ~DropoutGradGpuFwdKernel() override;
-  const std::vector<size_t> &GetInputSizeList() const override;
-  const std::vector<size_t> &GetOutputSizeList() const override;
-  const std::vector<size_t> &GetWorkspaceSizeList() const override;
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
-  bool Init(const CNodePtr &kernel_node) override;
+  DropoutGradGpuBwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}
+  ~DropoutGradGpuBwdKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *dy = GetDeviceAddress<T>(inputs, 0);
+    T *mask = GetDeviceAddress<T>(inputs, 1);
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+    DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuBwdKernel needs 2.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      InitSizeLists();
+      return true;
+    }
+    num_count_ = 1;
+    for (size_t x : input_shape) {
+      num_count_ *= x;
+    }
+    keep_prob_ = GetAttr<float>(kernel_node, "keep_prob");
+    InitSizeLists();
+    return true;
+  }
 
  protected:
-  void InitResource() override;
-  void InitSizeLists() override;
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+  }
+
+  void InitSizeLists() override {
+    size_t dy_size = num_count_ * sizeof(T);
+    size_t mask_size = dy_size;
+    size_t dx_size = dy_size;
+    input_size_list_.push_back(dy_size);
+    input_size_list_.push_back(mask_size);
+    output_size_list_.push_back(dx_size);
+  }
 
  private:
-  void DestroyResource() noexcept;
-
   cudnnHandle_t cudnn_handle_;
   bool is_null_input_;
   size_t num_count_;
@@ -50,8 +94,6 @@ class DropoutGradGpuFwdKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 };
-
-MS_REG_GPU_KERNEL(DropoutGrad, DropoutGradGpuFwdKernel)
 }  // namespace kernel
 }  // namespace mindspore
mindspore/ops/operations/nn_ops.py
@@ -4460,6 +4460,7 @@ class Dropout(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype):
         valid_types = (mstype.float16, mstype.float32)
+        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
         validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
         return x_dtype, x_dtype
@@ -4494,6 +4495,8 @@ class DropoutGrad(PrimitiveWithInfer):
     def infer_dtype(self, dy_dtype, mask_dtype):
         valid_types = (mstype.float16, mstype.float32)
+        validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
+        validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
         validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
         return dy_dtype
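For context, a hedged usage sketch of the Dropout primitive after this change; it assumes a GPU build of this 2020-era MindSpore running in PyNative mode, and the shape and keep_prob below are illustrative only.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

dropout = P.Dropout(keep_prob=0.9)
x = Tensor(np.ones((2, 16)).astype(np.float16))  # float16 now passes infer_dtype and dispatches to the half kernel
output, mask = dropout(x)                        # both outputs keep x's dtype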