提交 d0f89f21 编写于 作者: S shipengchao 提交者: MaxwellDing

fix roi align x86 kernel to support batchsize==1 and lod empty

上级 b5e42c30
......@@ -119,6 +119,7 @@ void RoiAlignCompute::Run() {
auto rois_dims = rois->dims();
int rois_num = rois_dims[0];
auto out_dims = out->dims();
if (rois_num == 0) {
return;
}
......@@ -138,14 +139,18 @@ void RoiAlignCompute::Run() {
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>();
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
// CHECK_OR_FALSE(rois_batch_size == batch_size);
// int rois_num_with_lod = rois_lod[rois_batch_size];
// CHECK_OR_FALSE(rois_num_with_lod == rois_num);
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
if (rois->lod().empty() && in_dims[0] /* batch_size */ == 1) {
std::fill_n(roi_batch_id_data, rois_num, 0);
} else {
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
// CHECK_OR_FALSE(rois_batch_size == batch_size);
// int rois_num_with_lod = rois_lod[rois_batch_size];
// CHECK_OR_FALSE(rois_num_with_lod == rois_num);
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册