未验证 提交 15dd0fab 编写于 作者: Y Yang Zhang 提交者: GitHub

Add `matrix_nms` op (#645)

* Add `matrix_nms` op

* Make OP registration more robust
上级 4d7ce643
...@@ -140,8 +140,22 @@ def get_registered_modules(): ...@@ -140,8 +140,22 @@ def get_registered_modules():
def make_partial(cls): def make_partial(cls):
op_module = importlib.import_module(cls.__op__.__module__) if isinstance(cls.__op__, str):
op = getattr(op_module, cls.__op__.__name__) sep = cls.__op__.split('.')
op_name = sep[-1]
op_module = importlib.import_module('.'.join(sep[:-1]))
else:
op_name = cls.__op__.__name__
op_module = importlib.import_module(cls.__op__.__module__)
if not hasattr(op_module, op_name):
import logging
logger = logging.getLogger(__name__)
logger.warn('{} OP not found, maybe a newer version of paddle '
'is required.'.format(cls.__op__))
return cls
op = getattr(op_module, op_name)
cls.__category__ = getattr(cls, '__category__', None) or 'op' cls.__category__ = getattr(cls, '__category__', None) or 'op'
def partial_apply(self, *args, **kwargs): def partial_apply(self, *args, **kwargs):
......
...@@ -30,7 +30,7 @@ __all__ = [ ...@@ -30,7 +30,7 @@ __all__ = [
'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead', 'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', 'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner' 'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner'
] ]
...@@ -492,6 +492,32 @@ class MultiClassNMS(object): ...@@ -492,6 +492,32 @@ class MultiClassNMS(object):
self.background_label = background_label self.background_label = background_label
@register
@serializable
class MatrixNMS(object):
__op__ = 'paddle.fluid.layers.matrix_nms'
__append_doc__ = True
def __init__(self,
score_threshold=.05,
post_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
use_gaussian=False,
gaussian_sigma=2.,
normalized=False,
background_label=0):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.post_threshold = post_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.normalized = normalized
self.use_gaussian = use_gaussian
self.gaussian_sigma = gaussian_sigma
self.background_label = background_label
@register @register
@serializable @serializable
class MultiClassSoftNMS(object): class MultiClassSoftNMS(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册