diff --git a/ppdet/modeling/architectures/cascade_mask_rcnn.py b/ppdet/modeling/architectures/cascade_mask_rcnn.py index f353c7c1866e7c66418285b31b348e89b01065a4..97d9665d71e2dd894e99244a81c4b5ce93577f1a 100644 --- a/ppdet/modeling/architectures/cascade_mask_rcnn.py +++ b/ppdet/modeling/architectures/cascade_mask_rcnn.py @@ -437,4 +437,4 @@ class CascadeMaskRCNN(object): def test(self, feed_vars, exclude_nms=False): assert not exclude_nms, "exclude_nms for {} is not support currently".format( self.__class__.__name__) - return self.build(feed_vars, 'test', exclude_nms=exclude_nms) + return self.build(feed_vars, 'test') diff --git a/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py b/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py index d905e199b78e3198b52b4c9d570d3a1a36d48c89..a8773f3a8045bedbf4045fd5bfa62c5b96115e79 100644 --- a/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py +++ b/ppdet/modeling/architectures/cascade_rcnn_cls_aware.py @@ -319,7 +319,7 @@ class CascadeRCNNClsAware(object): return self.build_multi_scale(feed_vars) return self.build(feed_vars, 'test') - def test(self, feed_vars): + def test(self, feed_vars, exclude_nms=False): assert not exclude_nms, "exclude_nms for {} is not support currently".format( self.__class__.__name__) return self.build(feed_vars, 'test')