未验证 提交 b76f5a84 编写于 作者: Z Zhang Ting 提交者: GitHub

fix the bug of dropout_grad (#29813)

上级 61820fd2
...@@ -54,11 +54,14 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask, ...@@ -54,11 +54,14 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
T dout_vec[VecSize]; T dout_vec[VecSize];
LoadT* value = reinterpret_cast<LoadT*>(&dout_vec); LoadT* dout_value = reinterpret_cast<LoadT*>(&dout_vec);
*value = *reinterpret_cast<const LoadT*>(&dout[i]); *dout_value = *reinterpret_cast<const LoadT*>(&dout[i]);
T dx_vec[VecSize];
MaskType mask_vec[VecSize]; MaskType mask_vec[VecSize];
MaskLoadT* mask_value = reinterpret_cast<MaskLoadT*>(&mask_vec);
*mask_value = *reinterpret_cast<const MaskLoadT*>(&mask[i]);
T dx_vec[VecSize];
#pragma unroll #pragma unroll
for (int ii = 0; ii < VecSize; ii++) { for (int ii = 0; ii < VecSize; ii++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册