提交 fbce7ca3 编写于 作者: C chenguowei01

update unet fcn hrnet

上级 b5a12185
...@@ -32,6 +32,7 @@ learning_rate: ...@@ -32,6 +32,7 @@ learning_rate:
decay: decay:
type: poly type: poly
power: 0.9 power: 0.9
end_lr: 0.0
loss: loss:
types: types:
......
...@@ -30,6 +30,7 @@ learning_rate: ...@@ -30,6 +30,7 @@ learning_rate:
decay: decay:
type: poly type: poly
power: 0.9 power: 0.9
end_lr: 0
loss: loss:
types: types:
......
...@@ -4,9 +4,10 @@ model: ...@@ -4,9 +4,10 @@ model:
type: FCN type: FCN
backbone: backbone:
type: HRNet_W18 type: HRNet_W18
pretrained: pretrained_model/hrnet_w18_imagenet
num_classes: 19 num_classes: 19
backbone_channels: [270] pretrained: Null
backbone_pretrained: pretrained_model/hrnet_w18_imagenet backbone_indices: [-1]
optimizer: optimizer:
weight_decay: 0.0005 weight_decay: 0.0005
...@@ -4,6 +4,7 @@ model: ...@@ -4,6 +4,7 @@ model:
type: FCN type: FCN
backbone: backbone:
type: HRNet_W18 type: HRNet_W18
num_classes: 2 pretrained: pretrained_model/hrnet_w18_imagenet
backbone_channels: [270] num_classes: 19
backbone_pretrained: pretrained_model/hrnet_w18_imagenet pretrained: Null
backbone_indices: [-1]
...@@ -4,6 +4,10 @@ model: ...@@ -4,6 +4,10 @@ model:
type: FCN type: FCN
backbone: backbone:
type: HRNet_W48 type: HRNet_W48
pretrained: pretrained_model/hrnet_w48_imagenet
num_classes: 19 num_classes: 19
backbone_channels: [720] pretrained: Null
backbone_pretrained: pretrained_model/hrnet_w48_imagenet backbone_indices: [-1]
optimizer:
weight_decay: 0.0005
_base_: '../_base_/cityscapes.yml'
batch_size: 2
iters: 40000
model:
type: UNet
num_classes: 19
pretrained: Null
...@@ -57,6 +57,7 @@ class HRNet(nn.Layer): ...@@ -57,6 +57,7 @@ class HRNet(nn.Layer):
""" """
def __init__(self, def __init__(self,
pretrained=None,
stage1_num_modules=1, stage1_num_modules=1,
stage1_num_blocks=[4], stage1_num_blocks=[4],
stage1_num_channels=[64], stage1_num_channels=[64],
...@@ -71,7 +72,7 @@ class HRNet(nn.Layer): ...@@ -71,7 +72,7 @@ class HRNet(nn.Layer):
stage4_num_channels=[18, 36, 72, 144], stage4_num_channels=[18, 36, 72, 144],
has_se=False): has_se=False):
super(HRNet, self).__init__() super(HRNet, self).__init__()
self.pretrained = pretrained
self.stage1_num_modules = stage1_num_modules self.stage1_num_modules = stage1_num_modules
self.stage1_num_blocks = stage1_num_blocks self.stage1_num_blocks = stage1_num_blocks
self.stage1_num_channels = stage1_num_channels self.stage1_num_channels = stage1_num_channels
...@@ -85,6 +86,7 @@ class HRNet(nn.Layer): ...@@ -85,6 +86,7 @@ class HRNet(nn.Layer):
self.stage4_num_blocks = stage4_num_blocks self.stage4_num_blocks = stage4_num_blocks
self.stage4_num_channels = stage4_num_channels self.stage4_num_channels = stage4_num_channels
self.has_se = has_se self.has_se = has_se
self.feat_channels = [sum(stage4_num_channels)]
self.conv_layer1_1 = layer_libs.ConvBNReLU( self.conv_layer1_1 = layer_libs.ConvBNReLU(
in_channels=3, in_channels=3,
...@@ -145,6 +147,7 @@ class HRNet(nn.Layer): ...@@ -145,6 +147,7 @@ class HRNet(nn.Layer):
num_filters=self.stage4_num_channels, num_filters=self.stage4_num_channels,
has_se=self.has_se, has_se=self.has_se,
name="st4") name="st4")
self.init_weight()
def forward(self, x, label=None, mode='train'): def forward(self, x, label=None, mode='train'):
input_shape = x.shape[2:] input_shape = x.shape[2:]
...@@ -170,6 +173,20 @@ class HRNet(nn.Layer): ...@@ -170,6 +173,20 @@ class HRNet(nn.Layer):
return [x] return [x]
def init_weight(self):
params = self.parameters()
for param in params:
param_name = param.name
if 'batch_norm' in param_name:
if 'w_0' in param_name:
param_init.constant_init(param, value=1.0)
elif 'b_0' in param_name:
param_init.constant_init(param, value=0.0)
if 'conv' in param_name and 'w_0' in param_name:
param_init.normal_init(param, scale=0.001)
if self.pretrained is not None:
utils.load_pretrained_model(self, self.pretrained)
class Layer1(nn.Layer): class Layer1(nn.Layer):
def __init__(self, def __init__(self,
......
...@@ -36,65 +36,78 @@ __all__ = [ ...@@ -36,65 +36,78 @@ __all__ = [
@manager.MODELS.add_component @manager.MODELS.add_component
class FCN(nn.Layer): class FCN(nn.Layer):
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=(-1, ),
channels=None):
super(FCN, self).__init__()
self.backbone = backbone
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = FCNHead(num_classes, backbone_indices, backbone_channels,
channels)
utils.load_entire_model(self, pretrained)
def forward(self, input):
feat_list = self.backbone(input)
logit_list = self.head(feat_list)
return [
F.resize_bilinear(logit, input.shape[2:]) for logit in logit_list
]
class FCNHead(nn.Layer):
""" """
Fully Convolutional Networks for Semantic Segmentation. A simple implementation for Fully Convolutional Networks for Semantic Segmentation.
https://arxiv.org/abs/1411.4038 https://arxiv.org/abs/1411.4038
Args: Args:
num_classes (int): the unique number of target classes. num_classes (int): the unique number of target classes.
backbone (paddle.nn.Layer): backbone networks. backbone (paddle.nn.Layer): backbone networks.
model_pretrained (str): the path of pretrained model. model_pretrained (str): the path of pretrained model.
backbone_indices (tuple): one values in the tuple indicte the indices of output of backbone.Default -1. backbone_indices (tuple): one values in the tuple indicte the indices of output of backbone.Default -1.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index. backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
channels (int): channels after conv layer before the last one. channels (int): channels after conv layer before the last one.
""" """
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone,
backbone_pretrained=None,
model_pretrained=None,
backbone_indices=(-1, ), backbone_indices=(-1, ),
backbone_channels=(270, ), backbone_channels=(270, ),
channels=None): channels=None):
super(FCN, self).__init__() super(FCNHead, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.backbone_pretrained = backbone_pretrained
self.model_pretrained = model_pretrained
self.backbone_indices = backbone_indices self.backbone_indices = backbone_indices
if channels is None: if channels is None:
channels = backbone_channels[0] channels = backbone_channels[0]
self.backbone = backbone self.conv_1 = layer_libs.ConvBNReLU(
self.conv_last_2 = layer_libs.ConvBNReLU(
in_channels=backbone_channels[0], in_channels=backbone_channels[0],
out_channels=channels, out_channels=channels,
kernel_size=1, kernel_size=1,
padding='same', padding='same',
stride=1) stride=1)
self.conv_last_1 = Conv2d( self.cls = Conv2d(
in_channels=channels, in_channels=channels,
out_channels=self.num_classes, out_channels=self.num_classes,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
if self.training:
self.init_weight() self.init_weight()
def forward(self, x): def forward(self, feat_list):
input_shape = x.shape[2:] logit_list = []
fea_list = self.backbone(x) x = feat_list[self.backbone_indices[0]]
x = fea_list[self.backbone_indices[0]] x = self.conv_1(x)
x = self.conv_last_2(x) logit = self.cls(x)
logit = self.conv_last_1(x) logit_list.append(logit)
logit = F.resize_bilinear(logit, input_shape) return logit_list
return [logit]
def init_weight(self): def init_weight(self):
params = self.parameters() params = self.parameters()
...@@ -108,22 +121,6 @@ class FCN(nn.Layer): ...@@ -108,22 +121,6 @@ class FCN(nn.Layer):
if 'conv' in param_name and 'w_0' in param_name: if 'conv' in param_name and 'w_0' in param_name:
param_init.normal_init(param, scale=0.001) param_init.normal_init(param, scale=0.001)
if self.model_pretrained is not None:
if os.path.exists(self.model_pretrained):
utils.load_pretrained_model(self, self.model_pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
self.model_pretrained))
elif self.backbone_pretrained is not None:
if os.path.exists(self.backbone_pretrained):
utils.load_pretrained_model(self.backbone,
self.backbone_pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
self.backbone_pretrained))
else:
logger.warning('No pretrained model to load, train from scratch')
@manager.MODELS.add_component @manager.MODELS.add_component
def fcn_hrnet_w18_small_v1(*args, **kwargs): def fcn_hrnet_w18_small_v1(*args, **kwargs):
......
...@@ -33,38 +33,25 @@ class UNet(nn.Layer): ...@@ -33,38 +33,25 @@ class UNet(nn.Layer):
Args: Args:
num_classes (int): the unique number of target classes. num_classes (int): the unique number of target classes.
pretrained_model (str): the path of pretrained model. pretrained (str): the path of pretrained model for fine tuning.
ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
""" """
def __init__(self, num_classes, model_pretrained=None, ignore_index=255): def __init__(self, num_classes, pretrained=None):
super(UNet, self).__init__() super(UNet, self).__init__()
self.model_pretrained = model_pretrained
self.ignore_index = ignore_index
self.encode = UnetEncoder() self.encode = UnetEncoder()
self.decode = UnetDecode() self.decode = UnetDecode()
self.get_logit = GetLogit(64, num_classes) self.get_logit = GetLogit(64, num_classes)
self.EPS = 1e-5
self.init_weight() utils.load_entire_model(self, pretrained)
def forward(self, x, label=None): def forward(self, x, label=None):
logit_list = []
encode_data, short_cuts = self.encode(x) encode_data, short_cuts = self.encode(x)
decode_data = self.decode(encode_data, short_cuts) decode_data = self.decode(encode_data, short_cuts)
logit = self.get_logit(decode_data) logit = self.get_logit(decode_data)
return [logit] logit_list.append(logit)
return logit_list
def init_weight(self):
"""
Initialize the parameters of model parts.
"""
if self.model_pretrained is not None:
if os.path.exists(self.model_pretrained):
utils.load_pretrained_model(self, self.model_pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
self.model_pretrained))
class UnetEncoder(nn.Layer): class UnetEncoder(nn.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册