提交 ec01bc12 编写于 作者: Z zhangwen31

[host][kernel]ref: matrix_nms code cleanup

上级 d784eb9e
......@@ -273,8 +273,8 @@ void MatrixNmsCompute::Run() {
auto out_dim = box_dim + 2;
Tensor boxes_slice, scores_slice;
size_t num_out = 0;
std::vector<uint64_t> offsets = {0};
int64_t num_out = 0;
std::vector<int64_t> offsets = {0};
std::vector<float> detections;
std::vector<int> indices;
detections.reserve(out_dim * num_boxes * batch_size);
......@@ -298,23 +298,25 @@ void MatrixNmsCompute::Run() {
post_threshold,
use_gaussian,
gaussian_sigma);
offsets.push_back(offsets.back() + static_cast<uint64_t>(num_out));
offsets.push_back(offsets.back() + num_out);
}
uint64_t num_kept = offsets.back();
int64_t num_kept = offsets.back();
if (num_kept == 0) {
outs->Resize({0, out_dim});
outs->mutable_data<float>();
index->Resize({0, 1});
index->mutable_data<int>();
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
index->Resize({static_cast<int64_t>(num_kept), 1});
outs->Resize({num_kept, out_dim});
index->Resize({num_kept, 1});
std::copy(
detections.begin(), detections.end(), outs->mutable_data<float>());
std::copy(indices.begin(), indices.end(), index->mutable_data<int>());
}
LoD lod;
lod.emplace_back(offsets);
lod.emplace_back(std::vector<uint64_t>(offsets.begin(), offsets.end()));
outs->set_lod(lod);
index->set_lod(lod);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册