diff --git a/lite/kernels/x86/roi_align_compute.cc b/lite/kernels/x86/roi_align_compute.cc index 9642098b80a5fccf51379c178eb0fada986b4b22..26efd9160c59d0a45e53800d62e050bbfd941799 100644 --- a/lite/kernels/x86/roi_align_compute.cc +++ b/lite/kernels/x86/roi_align_compute.cc @@ -119,6 +119,7 @@ void RoiAlignCompute::Run() { auto rois_dims = rois->dims(); int rois_num = rois_dims[0]; auto out_dims = out->dims(); + if (rois_num == 0) { return; } @@ -138,14 +139,18 @@ void RoiAlignCompute::Run() { roi_batch_id_list.Resize({rois_num}); int* roi_batch_id_data = roi_batch_id_list.mutable_data(); - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - // CHECK_OR_FALSE(rois_batch_size == batch_size); - // int rois_num_with_lod = rois_lod[rois_batch_size]; - // CHECK_OR_FALSE(rois_num_with_lod == rois_num); - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; + if (rois->lod().empty() && in_dims[0] /* batch_size */ == 1) { + std::fill_n(roi_batch_id_data, rois_num, 0); + } else { + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + // CHECK_OR_FALSE(rois_batch_size == batch_size); + // int rois_num_with_lod = rois_lod[rois_batch_size]; + // CHECK_OR_FALSE(rois_num_with_lod == rois_num); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } } }