提交 a79ca771 编写于 作者: Y Yu-hui Chen 提交者: TF Object Detection Team

Fixed the issue of NMS not working well with CenterNet in the multiclass

scenario. See b/218767303#comment12 for more detailed explanation.

PiperOrigin-RevId: 457641652
上级 5670119e
......@@ -4235,6 +4235,15 @@ class CenterNetMetaArch(model.DetectionModel):
axis=-2)
multiclass_scores = postprocess_dict[
fields.DetectionResultFields.detection_multiclass_scores]
num_classes = tf.shape(multiclass_scores)[2]
class_mask = tf.cast(
tf.one_hot(
postprocess_dict[fields.DetectionResultFields.detection_classes],
depth=num_classes), tf.bool)
# Surpress the scores of those unselected classes to be zeros. Otherwise,
# the downstream NMS ops might be confused and introduce issues.
multiclass_scores = tf.where(
class_mask, multiclass_scores, tf.zeros_like(multiclass_scores))
num_valid_boxes = postprocess_dict.pop(
fields.DetectionResultFields.num_detections)
# Remove scores and classes as NMS will compute these form multiclass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册