diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index 936b2f0e9dda10e77686c1c1703978e54d81add0..79922501be5e26bcda920948f25d95ed6ace42c6 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { auto spatial_scale = ctx.Attr("spatial_scale"); auto sampling_ratio = ctx.Attr("sampling_ratio"); auto in_dims = in->dims(); - if (!in_grad) { - return; - } + int channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; int rois_num = rois->dims()[0]; + + if (!in_grad) { + return; + } Tensor roi_batch_id_list; roi_batch_id_list.Resize({rois_num}); int* roi_batch_id_data = @@ -276,14 +278,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { } } - const T* rois_data = rois->data(); - const T* out_grad_data = out_grad->data(); - T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); - + in_grad->mutable_data(ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); math::SetConstant set_zero; set_zero(dev_ctx, in_grad, static_cast(0)); + int output_grad_size = out_grad->numel(); + + if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) { + return; + } + + const T* rois_data = rois->data(); + const T* out_grad_data = out_grad->data(); + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + auto in_stride = framework::stride(in->dims()); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims());