From 6e326ca2c6b5b2011d2612ea0b4165b7cdb1c819 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 20 Aug 2019 15:55:08 +0800 Subject: [PATCH] optimize the realization of cuda dropout (#19136) * cuda optimie for dropout * remove tmp swp file * fix compile error test=develop * test=develop optimize the cuda realization of dropout op * remove unsed code test=develop * remove tmp file test=develop --- paddle/fluid/operators/dropout_op.cu | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index e26eba68f1..e3d758c3a2 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -11,14 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include #include #include #include #include #include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/platform/dynload/curand.h" #include "paddle/fluid/platform/float16.h" - namespace paddle { namespace operators { @@ -27,10 +29,7 @@ __global__ void RandomGenerator(const size_t n, const int seed, const float dropout_prob, const T* src, MaskType* mask_data, T* dst, bool is_upscale_in_train) { - thrust::minstd_rand rng; - rng.seed(seed); - thrust::uniform_real_distribution dist(0, 1); - + curandStatePhilox4_32_10_t state; int idx = blockDim.x * blockIdx.x + threadIdx.x; int step_size = 0; @@ -39,12 +38,12 @@ __global__ void RandomGenerator(const size_t n, const int seed, for (; idx < n; idx += blockDim.x * gridDim.x) { T s = src[idx]; if (step_size == 0) { - rng.discard(idx); + curand_init(seed, idx, idx, &state); step_size = blockDim.x * gridDim.x; } else { - rng.discard(step_size); + curand_init(seed, idx, step_size, &state); } - if (dist(rng) < dropout_prob) { + if (curand_uniform(&state) < dropout_prob) { mask = 0; dest = 0; } else { -- GitLab