From cff5e2c173afc4431085b9382c716be7a9b91759 Mon Sep 17 00:00:00 2001 From: FDInSky <48318485+FDInSky@users.noreply.github.com> Date: Thu, 25 Jul 2019 19:42:09 +0800 Subject: [PATCH] fix roi_align_op cpu backward's bug (#18789) * test=develop fix cpu roi_align_op backward bug --- paddle/fluid/operators/roi_align_op.h | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index 936b2f0e9dd..78befea2f87 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { auto spatial_scale = ctx.Attr("spatial_scale"); auto sampling_ratio = ctx.Attr("sampling_ratio"); auto in_dims = in->dims(); - if (!in_grad) { - return; - } + int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; + + if (!in_grad) { + return; + } Tensor roi_batch_id_list; roi_batch_id_list.Resize({rois_num}); int* roi_batch_id_data = @@ -275,15 +277,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { roi_batch_id_data[i] = n; } } + in_grad->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, in_grad, static_cast(0)); + + int output_grad_size = out_grad->numel(); + + if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) { + return; + } const T* rois_data = rois->data(); const T* out_grad_data = out_grad->data(); T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - math::SetConstant set_zero; - set_zero(dev_ctx, in_grad, static_cast(0)); - auto in_stride = framework::stride(in->dims()); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims()); -- GitLab