From 71101c9cf72a0c158f159d4b9c1ccd7002fa761c Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 7 Mar 2019 12:27:45 +0000 Subject: [PATCH] fix input_grad not set zero. test=develop --- paddle/fluid/operators/temporal_shift_op.cu | 3 +++ paddle/fluid/operators/temporal_shift_op.h | 1 + 2 files changed, 4 insertions(+) diff --git a/paddle/fluid/operators/temporal_shift_op.cu b/paddle/fluid/operators/temporal_shift_op.cu index b62b4703e2c..b555c08c223 100644 --- a/paddle/fluid/operators/temporal_shift_op.cu +++ b/paddle/fluid/operators/temporal_shift_op.cu @@ -129,6 +129,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel { const T* output_grad_data = output_grad->data(); T* input_grad_data = input_grad->mutable_data({nt, c, h, w}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), input_grad, + static_cast(0)); int pixelNum = nt * chw; int grid_dim = (pixelNum + 512 - 1) / 512; diff --git a/paddle/fluid/operators/temporal_shift_op.h b/paddle/fluid/operators/temporal_shift_op.h index 9b96def3c72..3342a8b4a1b 100644 --- a/paddle/fluid/operators/temporal_shift_op.h +++ b/paddle/fluid/operators/temporal_shift_op.h @@ -88,6 +88,7 @@ class TemporalShiftGradKernel : public framework::OpKernel { const T* output_grad_data = output_grad->data(); T* input_grad_data = input_grad->mutable_data({nt, c, h, w}, ctx.GetPlace()); + memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); int src_it = 0; for (int i = 0; i < output_grad->numel(); i++) { -- GitLab