未验证 提交 d8cb6707 编写于 作者: W Wenyu 提交者: GitHub

fix picodet postprocess for none-det case (#4462)

上级 397a1d57
...@@ -193,13 +193,17 @@ class PicoDetPostProcess(object): ...@@ -193,13 +193,17 @@ class PicoDetPostProcess(object):
top_k=self.keep_top_k, ) top_k=self.keep_top_k, )
picked_box_probs.append(box_probs) picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0]) picked_labels.extend([class_index] * box_probs.shape[0])
if not picked_box_probs:
return np.array([]), np.array([]), np.array([]) if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
out_boxes_num.append(0)
else:
picked_box_probs = np.concatenate(picked_box_probs) picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes # resize output boxes
picked_box_probs[:, :4] = self.warp_boxes(picked_box_probs[:, :4], picked_box_probs[:, :4] = self.warp_boxes(
self.ori_shape[batch_id]) picked_box_probs[:, :4], self.ori_shape[batch_id])
im_scale = np.concatenate([ im_scale = np.concatenate([
self.scale_factor[batch_id][::-1], self.scale_factor[batch_id][::-1],
self.scale_factor[batch_id][::-1] self.scale_factor[batch_id][::-1]
...@@ -210,7 +214,8 @@ class PicoDetPostProcess(object): ...@@ -210,7 +214,8 @@ class PicoDetPostProcess(object):
np.concatenate( np.concatenate(
[ [
np.expand_dims( np.expand_dims(
np.array(picked_labels), axis=-1), np.expand_dims( np.array(picked_labels),
axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1), picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4] picked_box_probs[:, :4]
], ],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册