未验证 提交 6fbf4410 编写于 作者: W wangguanzhong 提交者: GitHub

enhance input check for roi_align, test=develop (#20238)

上级 c20b11ba
...@@ -258,7 +258,11 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -258,7 +258,11 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.Resize({rois_num}); roi_batch_id_list.Resize({rois_num});
auto cplace = platform::CPUPlace(); auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace); int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto rois_lod = rois->lod().back(); auto lod = rois->lod();
PADDLE_ENFORCE_EQ(
lod.empty(), false,
"Input(ROIs) Tensor of ROIAlignOp does not contain LoD information.");
auto rois_lod = lod.back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size, rois_batch_size, batch_size,
......
...@@ -166,7 +166,11 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -166,7 +166,11 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data = int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace()); roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back(); auto lod = rois->lod();
PADDLE_ENFORCE_EQ(
lod.empty(), false,
"Input(ROIs) Tensor of ROIAlignOp does not contain LoD information.");
auto rois_lod = lod.back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size, rois_batch_size, batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册