未验证 提交 37428952 编写于 作者: W wangguanzhong 提交者: GitHub

fix generate mask fpn, test=develop (#19301)

上级 3fdecc19
......@@ -305,10 +305,10 @@ class GenerateMaskLabelsKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(gt_segms->lod()[0].size() - 1, n);
int mask_dim = num_classes * resolution * resolution;
mask_rois->mutable_data<T>({rois->numel(), kBoxDim}, ctx.GetPlace());
roi_has_mask_int32->mutable_data<int>({rois->numel(), 1}, ctx.GetPlace());
mask_int32->mutable_data<int>({rois->numel(), mask_dim}, ctx.GetPlace());
int roi_num = rois->lod().back()[n];
mask_rois->mutable_data<T>({roi_num, kBoxDim}, ctx.GetPlace());
roi_has_mask_int32->mutable_data<int>({roi_num, 1}, ctx.GetPlace());
mask_int32->mutable_data<int>({roi_num, mask_dim}, ctx.GetPlace());
framework::LoD lod;
std::vector<size_t> lod0(1, 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册