diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index a18aee1b86283cbb48f0b804ccfc476d7cd78f3b..936b2f0e9dda10e77686c1c1703978e54d81add0 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -154,6 +154,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel { int width = in_dims[3]; int rois_num = rois->dims()[0]; + if (rois_num == 0) return; + auto in_stride = framework::stride(in_dims); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out->dims()); @@ -278,6 +280,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { const T* out_grad_data = out_grad->data(); T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, in_grad, static_cast(0)); + auto in_stride = framework::stride(in->dims()); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims());