未验证 提交 f0d69531 编写于 作者: W wangguanzhong 提交者: GitHub

fix mask_fpn coverge, test=dygraph (#2015)

上级 2e8b4e14
...@@ -87,7 +87,7 @@ class MaskRCNN(BaseArch): ...@@ -87,7 +87,7 @@ class MaskRCNN(BaseArch):
# compute targets here when training # compute targets here when training
rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out) rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out)
# BBox Head # BBox Head
bbox_feat, self.bbox_head_out, self.bbox_head_feat_func = self.bbox_head( bbox_feat, self.bbox_head_out, bbox_head_feat_func = self.bbox_head(
body_feats, rois, spatial_scale) body_feats, rois, spatial_scale)
rois_has_mask_int32 = None rois_has_mask_int32 = None
...@@ -108,7 +108,7 @@ class MaskRCNN(BaseArch): ...@@ -108,7 +108,7 @@ class MaskRCNN(BaseArch):
# Mask Head # Mask Head
self.mask_head_out = self.mask_head( self.mask_head_out = self.mask_head(
self.inputs, body_feats, self.bboxes, bbox_feat, self.inputs, body_feats, self.bboxes, bbox_feat,
rois_has_mask_int32, spatial_scale, self.bbox_head_feat_func) rois_has_mask_int32, spatial_scale, bbox_head_feat_func)
def get_loss(self, ): def get_loss(self, ):
loss = {} loss = {}
......
...@@ -128,6 +128,16 @@ def run(FLAGS, cfg, place): ...@@ -128,6 +128,16 @@ def run(FLAGS, cfg, place):
# if sync_bn: # if sync_bn:
# model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) # model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# The parameter filter is temporary fix for training because of #28997
# in Paddle.
def no_grad(param):
if param.name.startswith("conv1_") or param.name.startswith("res2a_") \
or param.name.startswith("res2b_") or param.name.startswith("res2c_"):
return True
for param in filter(no_grad, model.parameters()):
param.stop_gradient = True
# Parallel Model # Parallel Model
if ParallelEnv().nranks > 1: if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册