From 7af67f9ff0a9b7fbeed855aa2aa58de35c7774e4 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 25 Jul 2019 14:22:35 +0800 Subject: [PATCH] Fix CPU implementation of roi_align_op backward (#18728) (#18742) --- paddle/fluid/operators/roi_align_op.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index a18aee1b86..936b2f0e9d 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -154,6 +154,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel { int width = in_dims[3]; int rois_num = rois->dims()[0]; + if (rois_num == 0) return; + auto in_stride = framework::stride(in_dims); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out->dims()); @@ -278,6 +280,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { const T* out_grad_data = out_grad->data(); T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, in_grad, static_cast(0)); + auto in_stride = framework::stride(in->dims()); auto roi_stride = framework::stride(rois->dims()); auto out_stride = framework::stride(out_grad->dims()); -- GitLab