From 37428952c6eb155d50099c25cb6d1d4288e87a43 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 21 Aug 2019 10:31:38 +0800 Subject: [PATCH] fix generate mask fpn, test=develop (#19301) --- .../fluid/operators/detection/generate_mask_labels_op.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 38eafa5fe8f..0d77c7f3a79 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -305,10 +305,10 @@ class GenerateMaskLabelsKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(gt_segms->lod()[0].size() - 1, n); int mask_dim = num_classes * resolution * resolution; - - mask_rois->mutable_data({rois->numel(), kBoxDim}, ctx.GetPlace()); - roi_has_mask_int32->mutable_data({rois->numel(), 1}, ctx.GetPlace()); - mask_int32->mutable_data({rois->numel(), mask_dim}, ctx.GetPlace()); + int roi_num = rois->lod().back()[n]; + mask_rois->mutable_data({roi_num, kBoxDim}, ctx.GetPlace()); + roi_has_mask_int32->mutable_data({roi_num, 1}, ctx.GetPlace()); + mask_int32->mutable_data({roi_num, mask_dim}, ctx.GetPlace()); framework::LoD lod; std::vector lod0(1, 0); -- GitLab