diff --git a/ppdet/modeling/architecture/mask_rcnn.py b/ppdet/modeling/architecture/mask_rcnn.py index 9c542be3fdbc9407b441d26d3ad5246ed06faf20..76b17998d2a6f2a6f2e1ef3e22f7a681170845c3 100644 --- a/ppdet/modeling/architecture/mask_rcnn.py +++ b/ppdet/modeling/architecture/mask_rcnn.py @@ -133,7 +133,7 @@ class MaskRCNN(BaseArch): loss.update({'loss': total_loss}) return loss - def get_pred(self, ): + def get_pred(self, return_numpy=True): mask = self.mask_post_process(self.bboxes, self.mask_head_out, self.inputs['im_shape'], self.inputs['scale_factor'])