diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index cbcdcefc450cd679466208b2fd9fc32701b38731..b7f7370b4bbbe023b2ffe7a18d77255dc275d44b 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -140,8 +140,22 @@ def get_registered_modules(): def make_partial(cls): - op_module = importlib.import_module(cls.__op__.__module__) - op = getattr(op_module, cls.__op__.__name__) + if isinstance(cls.__op__, str): + sep = cls.__op__.split('.') + op_name = sep[-1] + op_module = importlib.import_module('.'.join(sep[:-1])) + else: + op_name = cls.__op__.__name__ + op_module = importlib.import_module(cls.__op__.__module__) + + if not hasattr(op_module, op_name): + import logging + logger = logging.getLogger(__name__) + logger.warn('{} OP not found, maybe a newer version of paddle ' + 'is required.'.format(cls.__op__)) + return cls + + op = getattr(op_module, op_name) cls.__category__ = getattr(cls, '__category__', None) or 'op' def partial_apply(self, *args, **kwargs): diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 56509ab84e3aff91da0ee6f5bf9e00850b8cb213..d456b4c097b7fedd48910d759e3eb7b2c64ca603 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -30,7 +30,7 @@ __all__ = [ 'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead', 'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', - 'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner' + 'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner' ] @@ -492,6 +492,32 @@ class MultiClassNMS(object): self.background_label = background_label +@register +@serializable +class MatrixNMS(object): + __op__ = 'paddle.fluid.layers.matrix_nms' + __append_doc__ = True + + def __init__(self, + score_threshold=.05, + post_threshold=.05, + nms_top_k=-1, + keep_top_k=100, + use_gaussian=False, + gaussian_sigma=2., + normalized=False, + background_label=0): + super(MultiClassNMS, self).__init__() + self.score_threshold = score_threshold + self.post_threshold = post_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + self.normalized = normalized + self.use_gaussian = use_gaussian + self.gaussian_sigma = gaussian_sigma + self.background_label = background_label + + @register @serializable class MultiClassSoftNMS(object):