提交 8a16038c 编写于 作者: S sunyanfang01

modify faster rcnn

上级 eca9870e
...@@ -79,8 +79,6 @@ class FasterRCNN(BaseAPI): ...@@ -79,8 +79,6 @@ class FasterRCNN(BaseAPI):
layers = 50 layers = 50
variant = 'd' variant = 'd'
norm_type = 'affine_channel' norm_type = 'affine_channel'
if self.bbox_loss_type != 'SmoothL1Loss':
norm_type = 'bn'
elif backbone_name == 'ResNet101': elif backbone_name == 'ResNet101':
layers = 101 layers = 101
variant = 'b' variant = 'b'
...@@ -89,14 +87,15 @@ class FasterRCNN(BaseAPI): ...@@ -89,14 +87,15 @@ class FasterRCNN(BaseAPI):
layers = 101 layers = 101
variant = 'd' variant = 'd'
norm_type = 'affine_channel' norm_type = 'affine_channel'
if self.bbox_loss_type != 'SmoothL1Loss':
norm_type = 'bn'
elif backbone_name == 'HRNet_W18': elif backbone_name == 'HRNet_W18':
backbone = paddlex.cv.nets.hrnet.HRNet( backbone = paddlex.cv.nets.hrnet.HRNet(
width=18, freeze_norm=True, norm_decay=0., freeze_at=0) width=18, freeze_norm=True, norm_decay=0., freeze_at=0)
if self.with_fpn is False: if self.with_fpn is False:
self.with_fpn = True self.with_fpn = True
return backbone return backbone
if backbone_name.startswith('ResNet'):
if self.bbox_loss_type != 'SmoothL1Loss':
norm_type = 'bn'
if self.with_fpn: if self.with_fpn:
backbone = paddlex.cv.nets.resnet.ResNet( backbone = paddlex.cv.nets.resnet.ResNet(
norm_type='bn' if norm_type is None else norm_type, norm_type='bn' if norm_type is None else norm_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册