diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index 936b2f0e9dda10e77686c1c1703978e54d81add0..78befea2f87302769b1ddee51152ff98daff911c 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { auto spatial_scale = ctx.Attr("spatial_scale"); auto sampling_ratio = ctx.Attr("sampling_ratio"); auto in_dims = in->dims(); - if (!in_grad) { - return; - } + int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; + + if (!in_grad) { + return; + } Tensor roi_batch_id_list; roi_batch_id_list.Resize({rois_num}); int* roi_batch_id_data = @@ -275,15 +277,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { roi_batch_id_data[i] = n; } } + in_grad->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, in_grad, static_cast(0)); + + int output_grad_size = out_grad->numel(); + + if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) { + return; + } const T* rois_data = rois->data(); const T* out_grad_data = out_grad->data(); T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - math::SetConstant set_zero; - set_zero(dev_ctx, in_grad, static_cast(0)); - auto in_stride = framework::stride(in->dims()); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims());