From deee78abf5a333960559312c31eb33091ab54d21 Mon Sep 17 00:00:00 2001 From: FDInSky <48318485+FDInSky@users.noreply.github.com> Date: Fri, 26 Jul 2019 16:52:42 +0800 Subject: [PATCH] fix roi_align_op cpu backward's bug (#18825) [cherry pick]fix roi_align_op cpu backward's bug --- paddle/fluid/operators/roi_align_op.h | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index 936b2f0e9dd..79922501be5 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { auto spatial_scale = ctx.Attr("spatial_scale"); auto sampling_ratio = ctx.Attr("sampling_ratio"); auto in_dims = in->dims(); - if (!in_grad) { - return; - } + int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; + + if (!in_grad) { + return; + } Tensor roi_batch_id_list; roi_batch_id_list.Resize({rois_num}); int* roi_batch_id_data = @@ -276,14 +278,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { } } - const T* rois_data = rois->data(); - const T* out_grad_data = out_grad->data(); - T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); - + in_grad->mutable_data(ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); math::SetConstant set_zero; set_zero(dev_ctx, in_grad, static_cast(0)); + int output_grad_size = out_grad->numel(); + + if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) { + return; + } + + const T* rois_data = rois->data(); + const T* out_grad_data = out_grad->data(); + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + auto in_stride = framework::stride(in->dims()); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims()); -- GitLab