From 8a16038c55bc00787743279462478bf7c49b98b7 Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Wed, 3 Jun 2020 17:06:25 +0800 Subject: [PATCH] modify faster rcnn --- paddlex/cv/models/faster_rcnn.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paddlex/cv/models/faster_rcnn.py b/paddlex/cv/models/faster_rcnn.py index 86d21db..1b73205 100644 --- a/paddlex/cv/models/faster_rcnn.py +++ b/paddlex/cv/models/faster_rcnn.py @@ -79,8 +79,6 @@ class FasterRCNN(BaseAPI): layers = 50 variant = 'd' norm_type = 'affine_channel' - if self.bbox_loss_type != 'SmoothL1Loss': - norm_type = 'bn' elif backbone_name == 'ResNet101': layers = 101 variant = 'b' @@ -89,14 +87,15 @@ class FasterRCNN(BaseAPI): layers = 101 variant = 'd' norm_type = 'affine_channel' - if self.bbox_loss_type != 'SmoothL1Loss': - norm_type = 'bn' elif backbone_name == 'HRNet_W18': backbone = paddlex.cv.nets.hrnet.HRNet( width=18, freeze_norm=True, norm_decay=0., freeze_at=0) if self.with_fpn is False: self.with_fpn = True return backbone + if backbone_name.startswith('ResNet'): + if self.bbox_loss_type != 'SmoothL1Loss': + norm_type = 'bn' if self.with_fpn: backbone = paddlex.cv.nets.resnet.ResNet( norm_type='bn' if norm_type is None else norm_type, -- GitLab