未验证 提交 60526c1e 编写于 作者: W wangguanzhong 提交者: GitHub

fix mask_fpn coverge, test=dygraph (#2016)

上级 82999c65
......@@ -87,7 +87,7 @@ class MaskRCNN(BaseArch):
# compute targets here when training
rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out)
# BBox Head
bbox_feat, self.bbox_head_out, self.bbox_head_feat_func = self.bbox_head(
bbox_feat, self.bbox_head_out, bbox_head_feat_func = self.bbox_head(
body_feats, rois, spatial_scale)
rois_has_mask_int32 = None
......@@ -108,7 +108,7 @@ class MaskRCNN(BaseArch):
# Mask Head
self.mask_head_out = self.mask_head(
self.inputs, body_feats, self.bboxes, bbox_feat,
rois_has_mask_int32, spatial_scale, self.bbox_head_feat_func)
rois_has_mask_int32, spatial_scale, bbox_head_feat_func)
def get_loss(self, ):
loss = {}
......
......@@ -128,6 +128,16 @@ def run(FLAGS, cfg, place):
# if sync_bn:
# model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# The parameter filter is temporary fix for training because of #28997
# in Paddle.
def no_grad(param):
if param.name.startswith("conv1_") or param.name.startswith("res2a_") \
or param.name.startswith("res2b_") or param.name.startswith("res2c_"):
return True
for param in filter(no_grad, model.parameters()):
param.stop_gradient = True
# Parallel Model
if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册