From 60526c1e29dc5a8780732ebd85923fc5996b6062 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 7 Jan 2021 13:36:53 +0800 Subject: [PATCH] fix mask_fpn coverge, test=dygraph (#2016) --- dygraph/ppdet/modeling/architectures/mask_rcnn.py | 4 ++-- dygraph/tools/train.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dygraph/ppdet/modeling/architectures/mask_rcnn.py b/dygraph/ppdet/modeling/architectures/mask_rcnn.py index 368daa727..dc6feb46f 100644 --- a/dygraph/ppdet/modeling/architectures/mask_rcnn.py +++ b/dygraph/ppdet/modeling/architectures/mask_rcnn.py @@ -87,7 +87,7 @@ class MaskRCNN(BaseArch): # compute targets here when training rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out) # BBox Head - bbox_feat, self.bbox_head_out, self.bbox_head_feat_func = self.bbox_head( + bbox_feat, self.bbox_head_out, bbox_head_feat_func = self.bbox_head( body_feats, rois, spatial_scale) rois_has_mask_int32 = None @@ -108,7 +108,7 @@ class MaskRCNN(BaseArch): # Mask Head self.mask_head_out = self.mask_head( self.inputs, body_feats, self.bboxes, bbox_feat, - rois_has_mask_int32, spatial_scale, self.bbox_head_feat_func) + rois_has_mask_int32, spatial_scale, bbox_head_feat_func) def get_loss(self, ): loss = {} diff --git a/dygraph/tools/train.py b/dygraph/tools/train.py index edf1c47d0..976501927 100755 --- a/dygraph/tools/train.py +++ b/dygraph/tools/train.py @@ -128,6 +128,16 @@ def run(FLAGS, cfg, place): # if sync_bn: # model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + # The parameter filter is temporary fix for training because of #28997 + # in Paddle. + def no_grad(param): + if param.name.startswith("conv1_") or param.name.startswith("res2a_") \ + or param.name.startswith("res2b_") or param.name.startswith("res2c_"): + return True + + for param in filter(no_grad, model.parameters()): + param.stop_gradient = True + # Parallel Model if ParallelEnv().nranks > 1: model = paddle.DataParallel(model) -- GitLab