diff --git a/official/vision/configs/retinanet.py b/official/vision/configs/retinanet.py index cb45cbbec549d6af6de9c0d298b1d6232a7f4edb..1f582778f4bcaf4e726f54f6a3fa9fc3ccac874c 100644 --- a/official/vision/configs/retinanet.py +++ b/official/vision/configs/retinanet.py @@ -139,6 +139,8 @@ class DetectionGenerator(hyperparams.Config): ) # Return decoded boxes/scores even if apply_nms is set `True`. return_decoded: Optional[bool] = None + # Only works when nms_version='v2'. + use_class_agnostic_nms: Optional[bool] = False @dataclasses.dataclass diff --git a/official/vision/modeling/factory.py b/official/vision/modeling/factory.py index 4dbdea85bb08a278fdb02f619a23e64a30b72c9b..0c0b0852fe4cf9dac323e82581030922c3967cc7 100644 --- a/official/vision/modeling/factory.py +++ b/official/vision/modeling/factory.py @@ -324,6 +324,7 @@ def build_retinanet( soft_nms_sigma=generator_config.soft_nms_sigma, tflite_post_processing_config=tflite_post_processing_config, return_decoded=generator_config.return_decoded, + use_class_agnostic_nms=generator_config.use_class_agnostic_nms, ) model = retinanet_model.RetinaNetModel(