From b76f5a8489402ce069dfaa9c3f4d172a2932bbad Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 22 Dec 2020 13:26:07 +0800 Subject: [PATCH] fix the bug of dropout_grad (#29813) --- paddle/fluid/operators/dropout_op.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 1f7f7ac224..d77193e485 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -54,11 +54,14 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask, for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { T dout_vec[VecSize]; - LoadT* value = reinterpret_cast(&dout_vec); - *value = *reinterpret_cast(&dout[i]); + LoadT* dout_value = reinterpret_cast(&dout_vec); + *dout_value = *reinterpret_cast(&dout[i]); - T dx_vec[VecSize]; MaskType mask_vec[VecSize]; + MaskLoadT* mask_value = reinterpret_cast(&mask_vec); + *mask_value = *reinterpret_cast(&mask[i]); + + T dx_vec[VecSize]; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { -- GitLab