diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu index 940c64ea53dcc120d66154305a6e8b0ebf8d3730..019d71d740b6d0d97d091948936ab7352958d08b 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu @@ -19,10 +19,10 @@ #include "include/cuda_runtime.h" __global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count, - float drop_prob) { - float scale = 1.f / drop_prob; + float keep_prob) { + float scale = 1.f / keep_prob; for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { - mask[i] = mask[i] > drop_prob; + mask[i] = mask[i] <= keep_prob; output[i] = scale * input[i] * mask[i]; } } @@ -34,8 +34,8 @@ void DropoutForward(const float *input, float *mask, float *output, size_t num_c } __global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count, - float drop_prob) { - float scale = 1.f / (1.f - drop_prob); + float keep_prob) { + float scale = 1.f / keep_prob; for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { dx[i] = scale * dy[i] * mask[i]; } diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh index 9aa05d6a08752ec02ebee0351e880222f56a6c2b..bd3de6524d6d82024958f78ac1fe109046d6e24d 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh @@ -18,9 +18,9 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ #include "device/gpu/cuda_common.h" -void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob, +void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob, cudaStream_t cuda_stream); -void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob, +void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc index 0d2a6be9c8e1d7d24b0685029c736a46226be383..87783add614c4f598bbf2702d2717e3231b0e469 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc @@ -23,7 +23,7 @@ DropoutGpuFwdKernel::DropoutGpuFwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), - drop_prob_(0.0), + keep_prob_(0.0), states_init_(false), mask_generator_(nullptr) {} @@ -54,7 +54,7 @@ bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) { for (size_t x : input_shape) { num_count_ *= x; } - drop_prob_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob")); + keep_prob_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob")); InitSizeLists(); return true; @@ -92,7 +92,7 @@ bool DropoutGpuFwdKernel::Launch(const std::vector &inputs, const st } curandGenerateUniform(mask_generator_, mask, num_count_); - DropoutForward(input, mask, output, num_count_, drop_prob_, reinterpret_cast(stream_ptr)); + DropoutForward(input, mask, output, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h index accff17429e3a5a8ef630edb8973fb9c2bfa057d..81eb78c880884e21055bb9c3953117788ccd49eb 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h @@ -52,7 +52,7 @@ class DropoutGpuFwdKernel : public GpuKernel { cudnnHandle_t cudnn_handle_; bool is_null_input_; size_t num_count_; - float drop_prob_; + float keep_prob_; bool states_init_; curandGenerator_t mask_generator_; std::vector input_size_list_; diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc index 44f603f02d483678703796c1a83db3442ea9f16f..4517f1bb30add366ab6d6dee2e0e893ac66e73d4 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace kernel { DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel() - : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), drop_prob_(0.0) {} + : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {} DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); } @@ -50,7 +50,7 @@ bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) { for (size_t x : input_shape) { num_count_ *= x; } - drop_prob_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob")); + keep_prob_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob")); InitSizeLists(); return true; @@ -84,7 +84,7 @@ bool DropoutGradGpuFwdKernel::Launch(const std::vector &inputs, cons auto *mask = reinterpret_cast(inputs[1]->addr); auto *dx = reinterpret_cast(outputs[0]->addr); - DropoutBackward(dy, mask, dx, num_count_, drop_prob_, reinterpret_cast(stream_ptr)); + DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h index 79d4117b58b9a0785d4ecd15888da1280c61836f..4991b9dad59a332d823ef7ed9c49d6aba4ff3532 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h @@ -45,7 +45,7 @@ class DropoutGradGpuFwdKernel : public GpuKernel { cudnnHandle_t cudnn_handle_; bool is_null_input_; size_t num_count_; - float drop_prob_; + float keep_prob_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 260f3c509f141be2f29dcb4dbca6468c75cb3b4a..2af3d84c678e2e09cf35920285554c9464784fba 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -675,7 +675,7 @@ def get_bprop_binary_cross_entropy(self): @bprop_getters.register(P.Dropout) def get_bprop_dropout(self): """Grad definition for `Dropout` operation.""" - grad = P.DropoutGrad(self.drop_prob) + grad = P.DropoutGrad(self.keep_prob) def bprop(x, out, dout): _, mask = out diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 5426afb53c3cd1363f71b8be5296f06ead4b08a2..0701db8c3bef4a84bf4adf4c62bfb096b54e11f2 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3227,7 +3227,8 @@ class Dropout(PrimitiveWithInfer): During training, randomly zeroes some of the elements of the input tensor with probability. Args: - drop_prob (float): probability of an element to be zeroed. Default: 0. + keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, + means dropping out 10% of input units. Inputs: - **shape** (tuple[int]) - The shape of target mask. @@ -3236,14 +3237,14 @@ class Dropout(PrimitiveWithInfer): Tensor, the value of generated mask for input shape. Examples: - >>> dropout = P.Dropout(drop_prob=0.5) + >>> dropout = P.Dropout(keep_prob=0.5) >>> in = Tensor((20, 16, 50, 50)) >>> out = dropout(in) """ @prim_attr_register - def __init__(self, drop_prob=0): - self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name) + def __init__(self, keep_prob=0.5): + self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) def infer_shape(self, x_shape): validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) @@ -3262,7 +3263,8 @@ class DropoutGrad(PrimitiveWithInfer): of the input tensor with probability. Args: - drop_prob (float): probability of an element to be zeroed. Default: 0. + keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, + means dropping out 10% of input units. Inputs: - **shape** (tuple[int]) - The shape of target mask. @@ -3271,14 +3273,14 @@ class DropoutGrad(PrimitiveWithInfer): Tensor, the value of generated mask for input shape. Examples: - >>> dropout_grad = P.DropoutGrad(drop_prob=0.5) + >>> dropout_grad = P.DropoutGrad(keep_prob=0.5) >>> in = Tensor((20, 16, 50, 50)) >>> out = dropout_grad(in) """ @prim_attr_register - def __init__(self, drop_prob=0): - self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name) + def __init__(self, keep_prob=0.5): + self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) def infer_shape(self, dy_shape, mask_shape): return dy_shape