diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 38eafa5fe8fc6fb1437caa98245d853e0e1566cb..0d77c7f3a79fc491dfdc54d74c7cfebd85a5992e 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -305,10 +305,10 @@ class GenerateMaskLabelsKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(gt_segms->lod()[0].size() - 1, n); int mask_dim = num_classes * resolution * resolution; - - mask_rois->mutable_data({rois->numel(), kBoxDim}, ctx.GetPlace()); - roi_has_mask_int32->mutable_data({rois->numel(), 1}, ctx.GetPlace()); - mask_int32->mutable_data({rois->numel(), mask_dim}, ctx.GetPlace()); + int roi_num = rois->lod().back()[n]; + mask_rois->mutable_data({roi_num, kBoxDim}, ctx.GetPlace()); + roi_has_mask_int32->mutable_data({roi_num, 1}, ctx.GetPlace()); + mask_int32->mutable_data({roi_num, mask_dim}, ctx.GetPlace()); framework::LoD lod; std::vector lod0(1, 0);