提交 f660553d 编写于 作者: J jerrywgz

enhance nms for mask rcnn, test=develop

上级 88ee56d0
......@@ -405,7 +405,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
if (num_kept == 0) {
T* od = outs->mutable_data<T>({1, 1}, ctx.GetPlace());
od[0] = -1;
batch_starts.back() = 1;
batch_starts = {0, 1};
} else {
outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
if (score_dims.size() == 3) {
......@@ -443,7 +443,6 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
framework::LoD lod;
lod.emplace_back(batch_starts);
LOG(ERROR) << "c++ lod: " << lod;
outs->set_lod(lod);
}
......
......@@ -173,13 +173,15 @@ def lod_multiclass_nms(boxes, scores, background, score_threshold,
normalized,
shared=False)
if nmsed_num == 0:
lod.append(1)
#lod.append(1)
continue
lod.append(nmsed_num)
for c, indices in nmsed_outs.items():
for idx in indices:
xmin, ymin, xmax, ymax = box[idx, c, :]
det_outs.append([c, score[idx][c], xmin, ymin, xmax, ymax])
if len(lod) == 0:
lod.append(1)
return det_outs, lod
......@@ -208,7 +210,7 @@ def batched_multiclass_nms(boxes,
normalized,
shared=True)
if nmsed_num == 0:
lod.append(1)
# lod.append(1)
continue
lod.append(nmsed_num)
......@@ -221,7 +223,8 @@ def batched_multiclass_nms(boxes,
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[0], reverse=False)
det_outs.extend(sorted_det_out)
if len(lod) == 0:
lod += [1]
return det_outs, lod
......@@ -259,7 +262,6 @@ class TestMulticlassNMSOp(OpTest):
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
print('python lod: ', lod)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册