提交 d0f89f21 编写于 作者: S shipengchao 提交者: MaxwellDing

fix roi align x86 kernel to support batchsize==1 and lod empty

上级 b5e42c30
...@@ -119,6 +119,7 @@ void RoiAlignCompute::Run() { ...@@ -119,6 +119,7 @@ void RoiAlignCompute::Run() {
auto rois_dims = rois->dims(); auto rois_dims = rois->dims();
int rois_num = rois_dims[0]; int rois_num = rois_dims[0];
auto out_dims = out->dims(); auto out_dims = out->dims();
if (rois_num == 0) { if (rois_num == 0) {
return; return;
} }
...@@ -138,14 +139,18 @@ void RoiAlignCompute::Run() { ...@@ -138,14 +139,18 @@ void RoiAlignCompute::Run() {
roi_batch_id_list.Resize({rois_num}); roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(); int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>();
auto rois_lod = rois->lod().back(); if (rois->lod().empty() && in_dims[0] /* batch_size */ == 1) {
int rois_batch_size = rois_lod.size() - 1; std::fill_n(roi_batch_id_data, rois_num, 0);
// CHECK_OR_FALSE(rois_batch_size == batch_size); } else {
// int rois_num_with_lod = rois_lod[rois_batch_size]; auto rois_lod = rois->lod().back();
// CHECK_OR_FALSE(rois_num_with_lod == rois_num); int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) { // CHECK_OR_FALSE(rois_batch_size == batch_size);
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { // int rois_num_with_lod = rois_lod[rois_batch_size];
roi_batch_id_data[i] = n; // CHECK_OR_FALSE(rois_num_with_lod == rois_num);
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册