Commit fd1994b6, authored by mindspore-ci-bot, committed by Gitee

!1759 Gpu Dropout kernel fix

Merge pull request !1759 from chenweifeng/dropout
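In short, this patch renames the dropout attribute from drop_prob to keep_prob throughout the GPU path and fixes the semantics to match inverted dropout: an element survives when its uniform sample u satisfies u <= keep_prob, and surviving elements are rescaled by 1 / keep_prob. A quick sanity check of that rescaling (notation mine, not from the patch): with a mask m_i in {0, 1} where P(m_i = 1) = p_keep,

\[
y_i = \frac{m_i}{p_{\text{keep}}}\, x_i
\qquad\Longrightarrow\qquad
\mathbb{E}[y_i] = \frac{P(m_i = 1)}{p_{\text{keep}}}\, x_i = x_i,
\]

so activations keep the same expected value whether or not dropout is active. The backward pass reuses the saved mask with the same scale, which is why the old backward factor 1.f / (1.f - drop_prob) becomes 1.f / keep_prob in the hunks below.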
@@ -19,10 +19,10 @@
 #include "include/cuda_runtime.h"
 __global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
-                                     float drop_prob) {
-  float scale = 1.f / drop_prob;
+                                     float keep_prob) {
+  float scale = 1.f / keep_prob;
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
-    mask[i] = mask[i] > drop_prob;
+    mask[i] = mask[i] <= keep_prob;
     output[i] = scale * input[i] * mask[i];
   }
 }
@@ -34,8 +34,8 @@ void DropoutForward(const float *input, float *mask, float *output, size_t num_c
 }
 __global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
-                                      float drop_prob) {
-  float scale = 1.f / (1.f - drop_prob);
+                                      float keep_prob) {
+  float scale = 1.f / keep_prob;
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
     dx[i] = scale * dy[i] * mask[i];
   }
...
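To see the corrected forward kernel in isolation, here is a minimal standalone CUDA sketch (my own test harness, not part of this commit; the buffer size and keep_prob value are arbitrary) that mirrors DropoutForwardKernel and checks that the output mean stays near the input mean. Note that curandGenerateUniform samples from (0, 1], so P(u <= keep_prob) is exactly keep_prob:

// dropout_check.cu -- standalone sketch of the patched inverted-dropout semantics.
#include <cstdio>
#include <cuda_runtime.h>
#include <curand.h>

__global__ void DropoutForwardKernel(const float *input, float *mask, float *output,
                                     size_t num_count, float keep_prob) {
  float scale = 1.f / keep_prob;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count;
       i += blockDim.x * gridDim.x) {
    mask[i] = mask[i] <= keep_prob;  // mask holds uniforms on entry, 0/1 on exit
    output[i] = scale * input[i] * mask[i];
  }
}

int main() {
  const size_t n = 1 << 20;
  const float keep_prob = 0.9f;
  float *input, *mask, *output;
  cudaMallocManaged(&input, n * sizeof(float));
  cudaMallocManaged(&mask, n * sizeof(float));
  cudaMallocManaged(&output, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) input[i] = 1.f;

  // Fill the mask buffer with uniforms on (0, 1]; the kernel thresholds it in place,
  // exactly as the Launch() path below does with mask_generator_.
  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
  curandGenerateUniform(gen, mask, n);

  DropoutForwardKernel<<<256, 256>>>(input, mask, output, n, keep_prob);
  cudaDeviceSynchronize();

  double sum = 0;
  for (size_t i = 0; i < n; ++i) sum += output[i];
  printf("mean(output) = %f (expect ~1.0)\n", sum / n);
  return 0;
}

Compiled with nvcc dropout_check.cu -lcurand, this should print a mean within a fraction of a percent of 1.0 for n = 2^20.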
@@ -18,9 +18,9 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
 #include "device/gpu/cuda_common.h"
-void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
+void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob,
                     cudaStream_t cuda_stream);
-void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
+void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob,
                      cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
@@ -23,7 +23,7 @@ DropoutGpuFwdKernel::DropoutGpuFwdKernel()
     : cudnn_handle_(nullptr),
       is_null_input_(false),
       num_count_(0),
-      drop_prob_(0.0),
+      keep_prob_(0.0),
       states_init_(false),
       mask_generator_(nullptr) {}
@@ -54,7 +54,7 @@ bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) {
   for (size_t x : input_shape) {
     num_count_ *= x;
   }
-  drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob"));
+  keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));
   InitSizeLists();
   return true;
@@ -92,7 +92,7 @@ bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const st
   }
   curandGenerateUniform(mask_generator_, mask, num_count_);
-  DropoutForward(input, mask, output, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  DropoutForward(input, mask, output, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
   return true;
 }
...
@@ -52,7 +52,7 @@ class DropoutGpuFwdKernel : public GpuKernel {
   cudnnHandle_t cudnn_handle_;
   bool is_null_input_;
   size_t num_count_;
-  float drop_prob_;
+  float keep_prob_;
   bool states_init_;
   curandGenerator_t mask_generator_;
   std::vector<size_t> input_size_list_;
...
@@ -20,7 +20,7 @@
 namespace mindspore {
 namespace kernel {
 DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel()
-    : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), drop_prob_(0.0) {}
+    : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}
 DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); }
@@ -50,7 +50,7 @@ bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) {
   for (size_t x : input_shape) {
     num_count_ *= x;
   }
-  drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob"));
+  keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));
   InitSizeLists();
   return true;
@@ -84,7 +84,7 @@ bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, cons
   auto *mask = reinterpret_cast<float *>(inputs[1]->addr);
   auto *dx = reinterpret_cast<float *>(outputs[0]->addr);
-  DropoutBackward(dy, mask, dx, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
   return true;
 }
...
@@ -45,7 +45,7 @@ class DropoutGradGpuFwdKernel : public GpuKernel {
   cudnnHandle_t cudnn_handle_;
   bool is_null_input_;
   size_t num_count_;
-  float drop_prob_;
+  float keep_prob_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
...
@@ -675,7 +675,7 @@ def get_bprop_binary_cross_entropy(self):
 @bprop_getters.register(P.Dropout)
 def get_bprop_dropout(self):
     """Grad definition for `Dropout` operation."""
-    grad = P.DropoutGrad(self.drop_prob)
+    grad = P.DropoutGrad(self.keep_prob)
     def bprop(x, out, dout):
         _, mask = out
...
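The bprop change is consistent with the kernel change above: DropoutGrad must apply exactly the same mask and scale that the forward pass used. In the notation from earlier,

\[
\frac{\partial L}{\partial x_i} \;=\; \frac{m_i}{p_{\text{keep}}}\,\frac{\partial L}{\partial y_i},
\]

which is what DropoutBackwardKernel computes with scale = 1.f / keep_prob; constructing P.DropoutGrad from self.keep_prob (rather than the removed self.drop_prob) carries that scale through to the backward kernel.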
@@ -3227,7 +3227,8 @@ class Dropout(PrimitiveWithInfer):
     During training, randomly zeroes some of the elements of the input tensor with probability.
     Args:
-        drop_prob (float): probability of an element to be zeroed. Default: 0.
+        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
+            means dropping out 10% of input units.
     Inputs:
         - **shape** (tuple[int]) - The shape of target mask.
@@ -3236,14 +3237,14 @@ class Dropout(PrimitiveWithInfer):
         Tensor, the value of generated mask for input shape.
     Examples:
-        >>> dropout = P.Dropout(drop_prob=0.5)
+        >>> dropout = P.Dropout(keep_prob=0.5)
         >>> in = Tensor((20, 16, 50, 50))
         >>> out = dropout(in)
     """
     @prim_attr_register
-    def __init__(self, drop_prob=0):
-        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+    def __init__(self, keep_prob=0.5):
+        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
     def infer_shape(self, x_shape):
         validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
@@ -3262,7 +3263,8 @@ class DropoutGrad(PrimitiveWithInfer):
     of the input tensor with probability.
     Args:
-        drop_prob (float): probability of an element to be zeroed. Default: 0.
+        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
+            means dropping out 10% of input units.
     Inputs:
         - **shape** (tuple[int]) - The shape of target mask.
@@ -3271,14 +3273,14 @@ class DropoutGrad(PrimitiveWithInfer):
         Tensor, the value of generated mask for input shape.
     Examples:
-        >>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
+        >>> dropout_grad = P.DropoutGrad(keep_prob=0.5)
         >>> in = Tensor((20, 16, 50, 50))
         >>> out = dropout_grad(in)
     """
     @prim_attr_register
-    def __init__(self, drop_prob=0):
-        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+    def __init__(self, keep_prob=0.5):
+        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
     def infer_shape(self, dy_shape, mask_shape):
         return dy_shape
...
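One last detail in both primitives: the validator range moves from Rel.INC_BOTH to Rel.INC_RIGHT, i.e. keep_prob is now accepted on the half-open interval (0, 1] rather than [0, 1]. That matches the kernels above: keep_prob = 0 would turn the 1.f / keep_prob scale into a division by zero, while keep_prob = 1 (keep every element) remains legal.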