Commit e3637a3c authored by chenguowei01

add config

Parent 1a3f44a5
batch_size: 4
iters: 100000
learning_rate: 0.01

train_dataset:
  type: Cityscapes
  dataset_root: data/cityscapes
  transforms:
    - type: ResizeStepScaling
      min_scale_factor: 0.5
      max_scale_factor: 2.0
      scale_step_size: 0.25
    - type: RandomPaddingCrop
      crop_size: [1024, 512]
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train

val_dataset:
  type: Cityscapes
  dataset_root: data/cityscapes
  transforms:
    - type: Normalize
  mode: val

optimizer:
  type: sgd
  learning_rate:
    value: 0.01
    decay:
      type: poly
      power: 0.9

loss:
  types:
    - type: CrossEntropyLoss
  coef: [1]
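Since this base file is plain YAML, any YAML parser can inspect it. The repo wires configs through its own loader, but a minimal PyYAML sketch (the file path is an assumption) shows the shape of the data:

import yaml

# Hypothetical path; adjust to wherever the repo keeps its base configs.
with open('configs/_base_/cityscapes.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['batch_size'])                              # 4
print(cfg['train_dataset']['transforms'][0]['type'])  # ResizeStepScaling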
batch_size: 4
iters: 10000
learning_rate: 0.01

train_dataset:
  type: OpticDiscSeg
  dataset_root: data/optic_disc_seg
  transforms:
    - type: Resize
      target_size: [512, 512]
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train

val_dataset:
  type: OpticDiscSeg
  dataset_root: data/optic_disc_seg
  transforms:
    - type: Resize
      target_size: [512, 512]
    - type: Normalize
  mode: val

optimizer:
  type: sgd
  learning_rate:
    value: 0.01
    decay:
      type: poly
      power: 0.9

loss:
  types:
    - type: CrossEntropyLoss
  coef: [1]
_base_: '../_base_/cityscapes.yml'

model:
  type: FCN
  backbone:
    type: HRNet_W18
    backbone_pretrained: pretrained_model/hrnet_w18_imagenet
  num_classes: 19
  backbone_channels: [270]
_base_: '../_base_/optic_disc_seg.yml'

model:
  type: FCN
  backbone:
    type: HRNet_W18
    backbone_pretrained: pretrained_model/hrnet_w18_imagenet
  num_classes: 2
  backbone_channels: [270]
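These two model configs rely on `_base_:` to inherit the dataset and solver settings above. The repo's loader resolves this itself; the sketch below shows one way such inheritance can work, assuming child keys override base keys and nested dicts merge recursively:

import os
import yaml

def merge(base, child):
    # Child keys win; nested dicts are merged key by key.
    out = dict(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out

def load_config(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base_rel = cfg.pop('_base_', None)
    if base_rel:
        base = load_config(os.path.join(os.path.dirname(path), base_rel))
        cfg = merge(base, cfg)
    return cfg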
@@ -3,32 +3,33 @@ iters: 40000
 train_dataset:
   type: Cityscapes
-  dataset_root: datasets/cityscapes
+  dataset_root: data/cityscapes
   transforms:
+    - type: RandomHorizontalFlip
     - type: ResizeStepScaling
       min_scale_factor: 0.5
       max_scale_factor: 2.0
       scale_step_size: 0.25
     - type: RandomPaddingCrop
       crop_size: [1024, 512]
-    - type: RandomHorizontalFlip
     - type: Normalize
   mode: train

 val_dataset:
   type: Cityscapes
-  dataset_root: datasets/cityscapes
+  dataset_root: data/cityscapes
   transforms:
     - type: Normalize
   mode: val

 model:
-  type: ocrnet
+  type: OCRNet
   backbone:
     type: HRNet_W18
-    pretrained: dygraph/pretrained_model/hrnet_w18_ssld/model
+    backbone_pretrained: None
   num_classes: 19
   in_channels: 270
+  model_pretrained: None

 optimizer:
   type: sgd
...
@@ -14,11 +14,13 @@
 import os

+import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import DataLoader
 # from paddle.incubate.hapi.distributed import DistributedBatchSampler
 from paddle.io import DistributedBatchSampler
+import paddle.nn.functional as F

 import dygraph.utils.logger as logger
 from dygraph.utils import load_pretrained_model
@@ -27,6 +29,27 @@ from dygraph.utils import Timer, calculate_eta
 from .val import evaluate


+def check_logits_losses(logits, losses):
+    len_logits = len(logits)
+    len_losses = len(losses['types'])
+    if len_logits != len_losses:
+        raise RuntimeError(
+            'The length of logits should equal to the types of loss config: {} != {}.'
+            .format(len_logits, len_losses))
+
+
+def loss_computation(logits, label, losses):
+    check_logits_losses(logits, losses)
+    loss = 0
+    for i in range(len(logits)):
+        logit = logits[i]
+        if logit.shape[-2:] != label.shape[-2:]:
+            logit = F.resize_bilinear(logit, label.shape[-2:])
+        loss_i = losses['types'][i](logit, label)
+        loss += losses['coef'][i] * loss_i
+    return loss
+
+
 def train(model,
           train_dataset,
           places=None,
@@ -40,7 +63,8 @@ def train(model,
           log_iters=10,
           num_classes=None,
           num_workers=8,
-          use_vdl=False):
+          use_vdl=False,
+          losses=None):
     ignore_index = model.ignore_index
     nranks = ParallelEnv().nranks
@@ -90,13 +114,17 @@ def train(model,
             images = data[0]
             labels = data[1].astype('int64')
             if nranks > 1:
-                loss = ddp_model(images, labels)
+                logits = ddp_model(images)
+                loss = loss_computation(logits, labels, losses)
+                # loss = ddp_model(images, labels)
                 # apply_collective_grads sum grads over multiple gpus.
                 loss = ddp_model.scale_loss(loss)
                 loss.backward()
                 ddp_model.apply_collective_grads()
             else:
-                loss = model(images, labels)
+                logits = model(images)
+                loss = loss_computation(logits, labels, losses)
+                # loss = model(images, labels)
                 loss.backward()
             optimizer.minimize(loss)
             model.clear_gradients()
...
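The new `loss_computation` helper pairs each logit tensor returned by the model with one entry of the loss config and sums the results weighted by `coef`. Stripped of tensors, the bookkeeping reduces to the following (stand-in callables, not the repo's loss classes; the resize step is omitted):

losses = {
    'types': [lambda logit, label: (logit - label) ** 2],  # stand-in loss
    'coef': [1.0],
}
logits = [2.0]   # one "logit" per model output
label = 1.5

total = sum(coef * loss_fn(logit, label)
            for coef, loss_fn, logit
            in zip(losses['coef'], losses['types'], logits))
print(total)     # 0.25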
@@ -19,6 +19,8 @@ import tqdm
 import cv2
 from paddle.fluid.dygraph.base import to_variable
 import paddle.fluid as fluid
+import paddle.nn.functional as F
+import paddle

 import dygraph.utils.logger as logger
 from dygraph.utils import ConfusionMatrix
@@ -47,7 +49,9 @@ def evaluate(model,
     for iter, (im, im_info, label) in tqdm.tqdm(
             enumerate(eval_dataset), total=total_iters):
         im = to_variable(im)
-        pred, _ = model(im)
+        # pred, _ = model(im)
+        logits = model(im)
+        pred = paddle.argmax(logits[0], axis=1)
         pred = pred.numpy().astype('float32')
         pred = np.squeeze(pred)
         for info in im_info[::-1]:
...
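The new evaluation path takes the first logit tensor and reduces its channel axis with argmax. In numbers, with NumPy standing in for the paddle call:

import numpy as np

logits0 = np.array([[[[0.1, 0.9]],    # class-0 scores, shape (1, 2, 1, 2)
                     [[0.8, 0.2]]]])  # class-1 scores
pred = np.argmax(logits0, axis=1)     # (N, C, H, W) -> (N, H, W)
print(pred)                           # [[[1 0]]]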
@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from . import manager
+from . import param_init
@@ -115,3 +115,4 @@ MODELS = ComponentManager()
 BACKBONES = ComponentManager()
 DATASETS = ComponentManager()
 TRANSFORMS = ComponentManager()
+LOSSES = ComponentManager()
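`LOSSES` joins the other component registries, so loss classes can be referenced from configs by name. A miniature of the registry pattern, as a sketch only (the real `ComponentManager` lives in dygraph/cvlibs/manager.py and may differ in detail):

class ComponentManager:
    def __init__(self):
        self._components = {}

    def add_component(self, cls):
        # Register under the class name; returning cls makes this a decorator.
        self._components[cls.__name__] = cls
        return cls

    def __getitem__(self, name):
        return self._components[name]

LOSSES = ComponentManager()

@LOSSES.add_component
class CrossEntropyLoss:
    pass

print(LOSSES['CrossEntropyLoss'])  # <class '__main__.CrossEntropyLoss'>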
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid


def constant_init(param, value=0.0):
    initializer = fluid.initializer.Constant(value)
    initializer(param, param.block)


def normal_init(param, loc=0.0, scale=1.0, seed=0):
    initializer = fluid.initializer.Normal(loc=loc, scale=scale, seed=seed)
    initializer(param, param.block)
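These helpers rewrite a parameter in place through a fluid initializer, which is how the name-based re-initialization in the backbones below works. A hedged usage sketch, assuming a paddle 1.x fluid dygraph environment and this repo on the import path:

import paddle.fluid as fluid

from dygraph.cvlibs import param_init  # module added by this commit

with fluid.dygraph.guard():
    conv = fluid.dygraph.Conv2D(num_channels=3, num_filters=8, filter_size=3)
    for param in conv.parameters():
        if 'w_0' in param.name:        # convolution weight
            param_init.normal_init(param, scale=0.001)
        elif 'b_0' in param.name:      # bias
            param_init.constant_init(param, 0.0)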
@@ -13,6 +13,7 @@
 # limitations under the License.

 from .architectures import *
+from .losses import *
 from .unet import UNet
 from .deeplab import *
 from .fcn import *
...
@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+import os

 import paddle
 import paddle.fluid as fluid
@@ -23,6 +24,8 @@ from paddle.fluid.initializer import Normal
 from paddle.nn import SyncBatchNorm as BatchNorm

 from dygraph.cvlibs import manager
+from dygraph.utils import utils
+from dygraph.cvlibs import param_init

 __all__ = [
     "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
@@ -36,6 +39,7 @@ class HRNet(fluid.dygraph.Layer):
     https://arxiv.org/pdf/1908.07919.pdf.

     Args:
+        backbone_pretrained (str): the path of pretrained model.
         stage1_num_modules (int): number of modules for stage1. Default 1.
         stage1_num_blocks (list): number of blocks per module for stage1. Default [4].
         stage1_num_channels (list): number of channels per branch for stage1. Default [64].
@@ -52,6 +56,7 @@ class HRNet(fluid.dygraph.Layer):
     """

     def __init__(self,
+                 backbone_pretrained=None,
                  stage1_num_modules=1,
                  stage1_num_blocks=[4],
                  stage1_num_channels=[64],
@@ -141,6 +146,8 @@ class HRNet(fluid.dygraph.Layer):
             has_se=self.has_se,
             name="st4")

+        self.init_weight(backbone_pretrained)
+
     def forward(self, x, label=None, mode='train'):
         input_shape = x.shape[2:]
         conv1 = self.conv_layer1_1(x)
@@ -163,7 +170,31 @@ class HRNet(fluid.dygraph.Layer):
         x3 = fluid.layers.resize_bilinear(st4[3], out_shape=(x0_h, x0_w))
         x = fluid.layers.concat([st4[0], x1, x2, x3], axis=1)

-        return x
+        return [x]
+
+    def init_weight(self, pretrained_model=None):
+        """
+        Initialize the parameters of model parts.
+        Args:
+            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
+        """
+        params = self.parameters()
+        for param in params:
+            param_name = param.name
+            if 'batch_norm' in param_name:
+                if 'w_0' in param_name:
+                    param_init.constant_init(param, 1.0)
+                elif 'b_0' in param_name:
+                    param_init.constant_init(param, 0.0)
+            if 'conv' in param_name and 'w_0' in param_name:
+                param_init.normal_init(param, scale=0.001)
+
+        if pretrained_model is not None:
+            if os.path.exists(pretrained_model):
+                utils.load_pretrained_model(self, pretrained_model)
+            else:
+                raise Exception('Pretrained model is not found: {}'.format(
+                    pretrained_model))


 class ConvBNLayer(fluid.dygraph.Layer):
@@ -184,18 +215,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
-            param_attr=ParamAttr(
-                initializer=Normal(scale=0.001), name=name + "_weights"),
             bias_attr=False)
-        bn_name = name + '_bn'
-        self._batch_norm = BatchNorm(
-            num_filters,
-            weight_attr=ParamAttr(
-                name=bn_name + '_scale',
-                initializer=fluid.initializer.Constant(1.0)),
-            bias_attr=ParamAttr(
-                bn_name + '_offset',
-                initializer=fluid.initializer.Constant(0.0)))
+        self._batch_norm = BatchNorm(num_filters)
         self.act = act

     def forward(self, input):
...
@@ -17,8 +17,9 @@ from __future__ import division
 from __future__ import print_function

 import math
-import numpy as np
+import os

+import numpy as np
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
@@ -28,6 +29,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm

 from dygraph.models.architectures import layer_utils
 from dygraph.cvlibs import manager
+from dygraph.utils import utils

 __all__ = [
     "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
@@ -46,6 +48,7 @@ def make_divisible(v, divisor=8, min_value=None):
         new_v += divisor
     return new_v

+
 def get_padding_same(kernel_size, dilation_rate):
     """
     SAME padding implementation given kernel_size and dilation_rate.
@@ -63,12 +66,19 @@ def get_padding_same(kernel_size, dilation_rate):
     """
     k = kernel_size
     r = dilation_rate
-    padding_same = (k + (k - 1) * (r - 1) - 1)//2
+    padding_same = (k + (k - 1) * (r - 1) - 1) // 2

     return padding_same


 class MobileNetV3(fluid.dygraph.Layer):
-    def __init__(self, scale=1.0, model_name="small", class_dim=1000, output_stride=None, **kwargs):
+    def __init__(self,
+                 backbone_pretrained=None,
+                 scale=1.0,
+                 model_name="small",
+                 class_dim=1000,
+                 output_stride=None,
+                 **kwargs):
         super(MobileNetV3, self).__init__()

         inplanes = 16
@@ -86,10 +96,12 @@ class MobileNetV3(fluid.dygraph.Layer):
                 [3, 184, 80, False, "hard_swish", 1],
                 [3, 184, 80, False, "hard_swish", 1],
                 [3, 480, 112, True, "hard_swish", 1],
-                [3, 672, 112, True, "hard_swish", 1],  # output 3 -> out_index=11
+                [3, 672, 112, True, "hard_swish",
+                 1],  # output 3 -> out_index=11
                 [5, 672, 160, True, "hard_swish", 2],
                 [5, 960, 160, True, "hard_swish", 1],
-                [5, 960, 160, True, "hard_swish", 1],  # output 3 -> out_index=14
+                [5, 960, 160, True, "hard_swish",
+                 1],  # output 3 -> out_index=14
             ]
             self.out_indices = [2, 5, 11, 14]
@@ -158,7 +170,6 @@ class MobileNetV3(fluid.dygraph.Layer):
                 sublayer=self.block_list[-1], name="conv" + str(i + 2))
             inplanes = make_divisible(scale * c)

-
         self.last_second_conv = ConvBNLayer(
             in_c=inplanes,
             out_c=make_divisible(scale * self.cls_ch_squeeze),
@@ -189,6 +200,8 @@ class MobileNetV3(fluid.dygraph.Layer):
             param_attr=ParamAttr("fc_weights"),
             bias_attr=ParamAttr(name="fc_offset"))

+        self.init_weight(backbone_pretrained)
+
     def modify_bottle_params(self, output_stride=None):

         if output_stride is not None and output_stride % 2 != 0:
@@ -223,6 +236,19 @@ class MobileNetV3(fluid.dygraph.Layer):

         return x, feat_list

+    def init_weight(self, pretrained_model=None):
+        """
+        Initialize the parameters of model parts.
+        Args:
+            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
+        """
+        if pretrained_model is not None:
+            if os.path.exists(pretrained_model):
+                utils.load_pretrained_model(self, pretrained_model)
+            else:
+                raise Exception('Pretrained model is not found: {}'.format(
+                    pretrained_model))
+

 class ConvBNLayer(fluid.dygraph.Layer):
     def __init__(self,
@@ -305,13 +331,14 @@ class ResidualUnit(fluid.dygraph.Layer):
             act=act,
             name=name + "_expand")
-
         self.bottleneck_conv = ConvBNLayer(
             in_c=mid_c,
             out_c=mid_c,
             filter_size=filter_size,
             stride=stride,
-            padding= get_padding_same(filter_size, dilation), #int((filter_size - 1) // 2) + (dilation - 1),
+            padding=get_padding_same(
+                filter_size,
+                dilation),  #int((filter_size - 1) // 2) + (dilation - 1),
             dilation=dilation,
             num_groups=mid_c,
             if_act=True,
@@ -329,6 +356,7 @@ class ResidualUnit(fluid.dygraph.Layer):
             act=None,
             name=name + "_linear")
         self.dilation = dilation
+
     def forward(self, inputs):
         x = self.expand_conv(inputs)
         x = self.bottleneck_conv(x)
@@ -386,6 +414,7 @@ def MobileNetV3_small_x0_75(**kwargs):
     model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
     return model

+
 @manager.BACKBONES.add_component
 def MobileNetV3_small_x1_0(**kwargs):
     model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
@@ -411,6 +440,7 @@ def MobileNetV3_large_x0_75(**kwargs):
     model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
     return model

+
 @manager.BACKBONES.add_component
 def MobileNetV3_large_x1_0(**kwargs):
     model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
...
@@ -30,6 +30,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
 from dygraph.utils import utils
 from dygraph.models.architectures import layer_utils
 from dygraph.cvlibs import manager
+from dygraph.utils import utils

 __all__ = [
     "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
@@ -47,18 +48,23 @@ class ConvBNLayer(fluid.dygraph.Layer):
                  groups=1,
                  is_vd_mode=False,
                  act=None,
-                 name=None, ):
+                 name=None,
+                 ):
         super(ConvBNLayer, self).__init__()

         self.is_vd_mode = is_vd_mode
         self._pool2d_avg = Pool2D(
-            pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
+            pool_size=2,
+            pool_stride=2,
+            pool_padding=0,
+            pool_type='avg',
+            ceil_mode=True)
         self._conv = Conv2D(
             num_channels=num_channels,
             num_filters=num_filters,
             filter_size=filter_size,
             stride=stride,
-            padding=(filter_size - 1) // 2 if dilation ==1 else 0,
+            padding=(filter_size - 1) // 2 if dilation == 1 else 0,
             dilation=dilation,
             groups=groups,
             act=None,
@@ -125,7 +131,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
             num_filters=num_filters * 4,
             filter_size=1,
             stride=1,
-            is_vd_mode=False if if_first or stride==1 else True,
+            is_vd_mode=False if if_first or stride == 1 else True,
             name=name + "_branch1")

         self.shortcut = shortcut
@@ -137,7 +143,8 @@ class BottleneckBlock(fluid.dygraph.Layer):
         # If given dilation rate > 1, using corresponding padding
         if self.dilation > 1:
             padding = self.dilation
-            y = fluid.layers.pad(y, [0,0,0,0,padding,padding,padding,padding])
+            y = fluid.layers.pad(
+                y, [0, 0, 0, 0, padding, padding, padding, padding])
         #####################################################################
         conv1 = self.conv1(y)
         conv2 = self.conv2(conv1)
@@ -202,7 +209,13 @@ class BasicBlock(fluid.dygraph.Layer):

 class ResNet_vd(fluid.dygraph.Layer):
-    def __init__(self, layers=50, class_dim=1000, output_stride=None, multi_grid=(1, 2, 4), **kwargs):
+    def __init__(self,
+                 backbone_pretrained=None,
+                 layers=50,
+                 class_dim=1000,
+                 output_stride=None,
+                 multi_grid=(1, 2, 4),
+                 **kwargs):
         super(ResNet_vd, self).__init__()

         self.layers = layers
@@ -221,11 +234,11 @@ class ResNet_vd(fluid.dygraph.Layer):
             depth = [3, 8, 36, 3]
         elif layers == 200:
             depth = [3, 12, 48, 3]
-        num_channels = [64, 256, 512,
-                        1024] if layers >= 50 else [64, 64, 128, 256]
+        num_channels = [64, 256, 512, 1024
+                        ] if layers >= 50 else [64, 64, 128, 256]
         num_filters = [64, 128, 256, 512]

-        dilation_dict=None
+        dilation_dict = None
         if output_stride == 8:
             dilation_dict = {2: 2, 3: 4}
         elif output_stride == 16:
@@ -260,7 +273,7 @@ class ResNet_vd(fluid.dygraph.Layer):
         if layers >= 50:
             for block in range(len(depth)):
                 shortcut = False
-                block_list=[]
+                block_list = []
                 for i in range(depth[block]):
                     if layers in [101, 152] and block == 2:
                         if i == 0:
@@ -272,7 +285,8 @@ class ResNet_vd(fluid.dygraph.Layer):

                     ###############################################################################
                     # Add dilation rate for some segmentation tasks, if dilation_dict is not None.
-                    dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
+                    dilation_rate = dilation_dict[
+                        block] if dilation_dict and block in dilation_dict else 1

                     # Actually block here is 'stage', and i is 'block' in 'stage'
                     # At the stage 4, expand the the dilation_rate using multi_grid, default (1, 2, 4)
@@ -284,9 +298,11 @@ class ResNet_vd(fluid.dygraph.Layer):
                     bottleneck_block = self.add_sublayer(
                         'bb_%d_%d' % (block, i),
                         BottleneckBlock(
-                            num_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
+                            num_channels=num_channels[block]
+                            if i == 0 else num_filters[block] * 4,
                             num_filters=num_filters[block],
-                            stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
+                            stride=2 if i == 0 and block != 0
+                            and dilation_rate == 1 else 1,
                             shortcut=shortcut,
                             if_first=block == i == 0,
                             name=conv_name,
@@ -298,7 +314,7 @@ class ResNet_vd(fluid.dygraph.Layer):
         else:
             for block in range(len(depth)):
                 shortcut = False
-                block_list=[]
+                block_list = []
                 for i in range(depth[block]):
                     conv_name = "res" + str(block + 2) + chr(97 + i)
                     basic_block = self.add_sublayer(
@@ -330,6 +346,8 @@ class ResNet_vd(fluid.dygraph.Layer):
                 name="fc_0.w_0"),
             bias_attr=ParamAttr(name="fc_0.b_0"))

+        self.init_weight(backbone_pretrained)
+
     def forward(self, inputs):
         y = self.conv1_1(inputs)
         y = self.conv1_2(y)
@@ -355,8 +373,18 @@ class ResNet_vd(fluid.dygraph.Layer):
     #     if os.path.exists(pretrained_model):
     #         utils.load_pretrained_model(self, pretrained_model)

+    def init_weight(self, pretrained_model=None):
+        """
+        Initialize the parameters of model parts.
+        Args:
+            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
+        """
+        if pretrained_model is not None:
+            if os.path.exists(pretrained_model):
+                utils.load_pretrained_model(self, pretrained_model)
+            else:
+                raise Exception('Pretrained model is not found: {}'.format(
+                    pretrained_model))


 def ResNet18_vd(**args):
@@ -368,11 +396,13 @@ def ResNet34_vd(**args):
     model = ResNet_vd(layers=34, **args)
     return model

+
 @manager.BACKBONES.add_component
 def ResNet50_vd(**args):
     model = ResNet_vd(layers=50, **args)
     return model

+
 @manager.BACKBONES.add_component
 def ResNet101_vd(**args):
     model = ResNet_vd(layers=101, **args)
...
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
@@ -7,6 +23,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm

 from dygraph.models.architectures import layer_utils
 from dygraph.cvlibs import manager
+from dygraph.utils import utils

 __all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
@@ -142,7 +159,6 @@ class Seperate_Conv(fluid.dygraph.Layer):
         self._act_op2 = layer_utils.Activation(act=act)

-
     def forward(self, inputs):
         x = self._conv1(inputs)
         x = self._bn1(x)
@@ -257,7 +273,12 @@ class XceptionDeeplab(fluid.dygraph.Layer):

     #def __init__(self, backbone, class_dim=1000):
     # add output_stride
-    def __init__(self, backbone, output_stride=16, class_dim=1000, **kwargs):
+    def __init__(self,
+                 backbone,
+                 backbone_pretrained=None,
+                 output_stride=16,
+                 class_dim=1000,
+                 **kwargs):
         super(XceptionDeeplab, self).__init__()
@@ -280,7 +301,6 @@ class XceptionDeeplab(fluid.dygraph.Layer):
             padding=1,
             act="relu",
             name=self.backbone + "/entry_flow/conv2")
-
         """
         bottleneck_params = {
             "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
@@ -381,6 +401,8 @@ class XceptionDeeplab(fluid.dygraph.Layer):
             param_attr=ParamAttr(name="fc_weights"),
             bias_attr=ParamAttr(name="fc_bias"))

+        self.init_weight(backbone_pretrained)
+
     def forward(self, inputs):
         x = self._conv1(inputs)
         x = self._conv2(x)
@@ -401,11 +423,25 @@ class XceptionDeeplab(fluid.dygraph.Layer):
         x = self._fc(x)
         return x, feat_list

+    def init_weight(self, pretrained_model=None):
+        """
+        Initialize the parameters of model parts.
+        Args:
+            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
+        """
+        if pretrained_model is not None:
+            if os.path.exists(pretrained_model):
+                utils.load_pretrained_model(self, pretrained_model)
+            else:
+                raise Exception('Pretrained model is not found: {}'.format(
+                    pretrained_model))
+

 def Xception41_deeplab(**args):
     model = XceptionDeeplab('xception_41', **args)
     return model

+
 @manager.BACKBONES.add_component
 def Xception65_deeplab(**args):
     model = XceptionDeeplab("xception_65", **args)
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import os
-
 from dygraph.cvlibs import manager
@@ -23,10 +22,12 @@ from paddle.fluid.dygraph import Conv2D
 from dygraph.utils import utils

-__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
-           "deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
-           "deeplabv3p_xception65_deeplab",
-           "deeplabv3p_mobilenetv3_large", "deeplabv3p_mobilenetv3_small"]
+__all__ = [
+    'DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
+    "deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
+    "deeplabv3p_xception65_deeplab", "deeplabv3p_mobilenetv3_large",
+    "deeplabv3p_mobilenetv3_small"
+]


 class ImageAverage(dygraph.Layer):
@@ -40,9 +41,8 @@ class ImageAverage(dygraph.Layer):

     def __init__(self, num_channels):
         super(ImageAverage, self).__init__()
-        self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels,
-                                                   num_filters=256,
-                                                   filter_size=1)
+        self.conv_bn_relu = layer_utils.ConvBnRelu(
+            num_channels, num_filters=256, filter_size=1)

     def forward(self, input):
         x = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
@@ -69,18 +69,22 @@ class ASPP(dygraph.Layer):
         elif output_stride == 8:
             aspp_ratios = (12, 24, 36)
         else:
-            raise NotImplementedError("Only support output_stride is 8 or 16, but received{}".format(output_stride))
+            raise NotImplementedError(
+                "Only support output_stride is 8 or 16, but received{}".format(
+                    output_stride))

         self.image_average = ImageAverage(num_channels=in_channels)

         # The first aspp using 1*1 conv
-        self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels,
+        self.aspp1 = layer_utils.ConvBnRelu(
+            num_channels=in_channels,
             num_filters=256,
             filter_size=1,
             using_sep_conv=False)

         # The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0]
-        self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels,
+        self.aspp2 = layer_utils.ConvBnRelu(
+            num_channels=in_channels,
             num_filters=256,
             filter_size=3,
             using_sep_conv=using_sep_conv,
@@ -88,7 +92,8 @@ class ASPP(dygraph.Layer):
             padding=aspp_ratios[0])

         # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1]
-        self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels,
+        self.aspp3 = layer_utils.ConvBnRelu(
+            num_channels=in_channels,
             num_filters=256,
             filter_size=3,
             using_sep_conv=using_sep_conv,
@@ -96,7 +101,8 @@ class ASPP(dygraph.Layer):
             padding=aspp_ratios[1])

         # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2]
-        self.aspp4 = layer_utils.ConvBnRelu(num_channels=in_channels,
+        self.aspp4 = layer_utils.ConvBnRelu(
+            num_channels=in_channels,
             num_filters=256,
             filter_size=3,
             using_sep_conv=using_sep_conv,
@@ -104,9 +110,8 @@ class ASPP(dygraph.Layer):
             padding=aspp_ratios[2])

         # After concat op, using 1*1 conv
-        self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280,
-                                                   num_filters=256,
-                                                   filter_size=1)
+        self.conv_bn_relu = layer_utils.ConvBnRelu(
+            num_channels=1280, num_filters=256, filter_size=1)

     def forward(self, x):
@@ -136,23 +141,23 @@ class Decoder(dygraph.Layer):

     def __init__(self, num_classes, in_channels, using_sep_conv=True):
         super(Decoder, self).__init__()
-        self.conv_bn_relu1 = layer_utils.ConvBnRelu(num_channels=in_channels,
-                                                    num_filters=48,
-                                                    filter_size=1)
+        self.conv_bn_relu1 = layer_utils.ConvBnRelu(
+            num_channels=in_channels, num_filters=48, filter_size=1)

-        self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=304,
+        self.conv_bn_relu2 = layer_utils.ConvBnRelu(
+            num_channels=304,
             num_filters=256,
             filter_size=3,
             using_sep_conv=using_sep_conv,
             padding=1)
-        self.conv_bn_relu3 = layer_utils.ConvBnRelu(num_channels=256,
+        self.conv_bn_relu3 = layer_utils.ConvBnRelu(
+            num_channels=256,
             num_filters=256,
             filter_size=3,
             using_sep_conv=using_sep_conv,
             padding=1)
-        self.conv = Conv2D(num_channels=256,
-                           num_filters=num_classes,
-                           filter_size=1)
+        self.conv = Conv2D(
+            num_channels=256, num_filters=num_classes, filter_size=1)

     def forward(self, x, low_level_feat):
         low_level_feat = self.conv_bn_relu1(low_level_feat)
@@ -164,6 +169,7 @@ class Decoder(dygraph.Layer):
         return x


+@manager.MODELS.add_component
 class DeepLabV3P(dygraph.Layer):
     """
     The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder
@@ -173,9 +179,11 @@ class DeepLabV3P(dygraph.Layer):
     (https://arxiv.org/abs/1802.02611)

     Args:
-        backbone (str): backbone name, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
-        num_classes (int): the unique number of target classes. Default 2.
+        num_classes (int): the unique number of target classes.
+        backbone (paddle.nn.Layer): backbone networks, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
+        model_pretrained (str): the path of pretrained model.
         output_stride (int): the ratio of input size and final feature size. Default 16.
@@ -193,28 +201,29 @@ class DeepLabV3P(dygraph.Layer):
         using_sep_conv (bool): a bool value indicates whether using separable convolutions
             in ASPP and Decoder components. Default True.
-        pretrained_model (str): the pretrained_model path of backbone.
     """

     def __init__(self,
+                 num_classes,
                  backbone,
-                 num_classes=2,
+                 model_pretrained=None,
                  output_stride=16,
                  backbone_indices=(0, 3),
                  backbone_channels=(256, 2048),
                  ignore_index=255,
-                 using_sep_conv=True,
-                 pretrained_model=None):
+                 using_sep_conv=True):

         super(DeepLabV3P, self).__init__()

-        self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
+        # self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
+        self.backbone = backbone
         self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv)
-        self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv)
+        self.decoder = Decoder(num_classes, backbone_channels[0],
+                               using_sep_conv)
         self.ignore_index = ignore_index
         self.EPS = 1e-5
         self.backbone_indices = backbone_indices
-        self.init_weight(pretrained_model)
+        self.init_weight(model_pretrained)

     def forward(self, input, label=None):
@@ -238,14 +247,14 @@ class DeepLabV3P(dygraph.Layer):
         """
         Initialize the parameters of model parts.
         Args:
-            pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
+            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
         """
         if pretrained_model is not None:
             if os.path.exists(pretrained_model):
-                utils.load_pretrained_model(self.backbone, pretrained_model)
-                # utils.load_pretrained_model(self, pretrained_model)
-                # for param in self.backbone.parameters():
-                #     param.stop_gradient = True
+                utils.load_pretrained_model(self, pretrained_model)
+            else:
+                raise Exception('Pretrained model is not found: {}'.format(
+                    pretrained_model))

     def _get_loss(self, logit, label):
         """
@@ -290,31 +299,42 @@ def build_decoder(num_classes, using_sep_conv):
 @manager.MODELS.add_component
 def deeplabv3p_resnet101_vd(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs)
+    return DeepLabV3P(
+        backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs)


 @manager.MODELS.add_component
 def deeplabv3p_resnet101_vd_os8(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
+    return DeepLabV3P(
+        backbone='ResNet101_vd',
+        output_stride=8,
+        pretrained_model=pretrained_model,
+        **kwargs)


 @manager.MODELS.add_component
 def deeplabv3p_resnet50_vd(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs)
+    return DeepLabV3P(
+        backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs)


 @manager.MODELS.add_component
 def deeplabv3p_resnet50_vd_os8(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
+    return DeepLabV3P(
+        backbone='ResNet50_vd',
+        output_stride=8,
+        pretrained_model=pretrained_model,
+        **kwargs)


 @manager.MODELS.add_component
 def deeplabv3p_xception65_deeplab(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='Xception65_deeplab',
+    return DeepLabV3P(
+        backbone='Xception65_deeplab',
         pretrained_model=pretrained_model,
         backbone_indices=(0, 1),
         backbone_channels=(128, 2048),
@@ -324,7 +344,8 @@ def deeplabv3p_xception65_deeplab(*args, **kwargs):
 @manager.MODELS.add_component
 def deeplabv3p_mobilenetv3_large(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='MobileNetV3_large_x1_0',
+    return DeepLabV3P(
+        backbone='MobileNetV3_large_x1_0',
         pretrained_model=pretrained_model,
         backbone_indices=(0, 3),
         backbone_channels=(24, 160),
@@ -334,7 +355,8 @@ def deeplabv3p_mobilenetv3_large(*args, **kwargs):
 @manager.MODELS.add_component
 def deeplabv3p_mobilenetv3_small(*args, **kwargs):
     pretrained_model = None
-    return DeepLabV3P(backbone='MobileNetV3_small_x1_0',
+    return DeepLabV3P(
+        backbone='MobileNetV3_small_x1_0',
         pretrained_model=pretrained_model,
        backbone_indices=(0, 3),
         backbone_channels=(16, 96),
...
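`DeepLabV3P` now receives a constructed backbone object instead of a backbone name, so the caller controls how the backbone is built. A hedged construction sketch (import paths assumed from this diff; not the repo's documented entry point), assuming a paddle 1.x fluid dygraph environment:

import paddle.fluid as fluid

from dygraph.models import DeepLabV3P
from dygraph.models.architectures.resnet_vd import ResNet50_vd  # path assumed

with fluid.dygraph.guard():
    backbone = ResNet50_vd(output_stride=16)
    model = DeepLabV3P(
        num_classes=19,
        backbone=backbone,
        backbone_indices=(0, 3),
        backbone_channels=(256, 2048))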
@@ -25,6 +25,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm

 from dygraph.cvlibs import manager
 from dygraph import utils
+from dygraph.cvlibs import param_init

 __all__ = [
     "fcn_hrnet_w18_small_v1", "fcn_hrnet_w18_small_v2", "fcn_hrnet_w18",
@@ -33,114 +34,133 @@ __all__ = [
 ]


+@manager.MODELS.add_component
 class FCN(fluid.dygraph.Layer):
     """
     Fully Convolutional Networks for Semantic Segmentation.
     https://arxiv.org/abs/1411.4038

     Args:
-        backbone (str): backbone name,
         num_classes (int): the unique number of target classes.
-        in_channels (int): the channels of input feature maps.
+        backbone (paddle.nn.Layer): backbone networks.
+        model_pretrained (str): the path of pretrained model.
+        backbone_indices (tuple): the values in the tuple indicate the indices of output of backbone. Default (-1, ).
+        backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
         channels (int): channels after conv layer before the last one.
-        pretrained_model (str): the path of pretrained model.
         ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
     """

     def __init__(self,
-                 backbone,
                  num_classes,
-                 in_channels,
+                 backbone,
+                 model_pretrained=None,
+                 backbone_indices=(-1, ),
+                 backbone_channels=(270, ),
                  channels=None,
-                 pretrained_model=None,
                  ignore_index=255,
                  **kwargs):
         super(FCN, self).__init__()

         self.num_classes = num_classes
+        self.backbone_indices = backbone_indices
         self.ignore_index = ignore_index
         self.EPS = 1e-5

         if channels is None:
-            channels = in_channels
+            channels = backbone_channels[backbone_indices[0]]

-        self.backbone = manager.BACKBONES[backbone](**kwargs)
+        self.backbone = backbone
         self.conv_last_2 = ConvBNLayer(
-            num_channels=in_channels,
+            num_channels=backbone_channels[backbone_indices[0]],
             num_filters=channels,
             filter_size=1,
-            stride=1,
-            name='conv-2')
+            stride=1)
         self.conv_last_1 = Conv2D(
             num_channels=channels,
             num_filters=self.num_classes,
             filter_size=1,
             stride=1,
-            padding=0,
-            param_attr=ParamAttr(
-                initializer=Normal(scale=0.001), name='conv-1_weights'))
-        self.init_weight(pretrained_model)
+            padding=0)
+        self.init_weight(model_pretrained)

-    def forward(self, x, label=None, mode='train'):
+    def forward(self, x):
         input_shape = x.shape[2:]
-        x = self.backbone(x)
+        fea_list = self.backbone(x)
+        x = fea_list[self.backbone_indices[0]]
         x = self.conv_last_2(x)
         logit = self.conv_last_1(x)
         logit = fluid.layers.resize_bilinear(logit, input_shape)
+        return [logit]

-        if self.training:
-            if label is None:
-                raise Exception('Label is need during training')
-            return self._get_loss(logit, label)
-        else:
-            score_map = fluid.layers.softmax(logit, axis=1)
-            score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
-            pred = fluid.layers.argmax(score_map, axis=3)
-            pred = fluid.layers.unsqueeze(pred, axes=[3])
-            return pred, score_map
+        # if self.training:
+        #     if label is None:
+        #         raise Exception('Label is need during training')
+        #     return self._get_loss(logit, label)
+        # else:
+        #     score_map = fluid.layers.softmax(logit, axis=1)
+        #     score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
+        #     pred = fluid.layers.argmax(score_map, axis=3)
+        #     pred = fluid.layers.unsqueeze(pred, axes=[3])
+        #     return pred, score_map

     def init_weight(self, pretrained_model=None):
         """
         Initialize the parameters of model parts.
         Args:
-            pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
+            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
         """
+        params = self.parameters()
+        for param in params:
+            param_name = param.name
+            if 'batch_norm' in param_name:
+                if 'w_0' in param_name:
+                    param_init.constant_init(param, 1.0)
+                elif 'b_0' in param_name:
+                    param_init.constant_init(param, 0.0)
+            if 'conv' in param_name and 'w_0' in param_name:
+                param_init.normal_init(param, scale=0.001)
+
         if pretrained_model is not None:
             if os.path.exists(pretrained_model):
-                utils.load_pretrained_model(self.backbone, pretrained_model)
                 utils.load_pretrained_model(self, pretrained_model)
             else:
                 raise Exception('Pretrained model is not found: {}'.format(
                     pretrained_model))

-    def _get_loss(self, logit, label):
-        """
-        compute forward loss of the model
-        Args:
-            logit (tensor): the logit of model output
-            label (tensor): ground truth
-        Returns:
-            avg_loss (tensor): forward loss
-        """
-        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
-        label = fluid.layers.transpose(label, [0, 2, 3, 1])
-        mask = label != self.ignore_index
-        mask = fluid.layers.cast(mask, 'float32')
-        loss, probs = fluid.layers.softmax_with_cross_entropy(
-            logit,
-            label,
-            ignore_index=self.ignore_index,
-            return_softmax=True,
-            axis=-1)
-        loss = loss * mask
-        avg_loss = fluid.layers.mean(loss) / (
-            fluid.layers.mean(mask) + self.EPS)
-        label.stop_gradient = True
-        mask.stop_gradient = True
-        return avg_loss
+    # def _get_loss(self, logit, label):
+    #     """
+    #     compute forward loss of the model
+    #     Args:
+    #         logit (tensor): the logit of model output
+    #         label (tensor): ground truth
+    #     Returns:
+    #         avg_loss (tensor): forward loss
+    #     """
+    #     logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+    #     label = fluid.layers.transpose(label, [0, 2, 3, 1])
+    #     mask = label != self.ignore_index
+    #     mask = fluid.layers.cast(mask, 'float32')
+    #     loss, probs = fluid.layers.softmax_with_cross_entropy(
+    #         logit,
+    #         label,
+    #         ignore_index=self.ignore_index,
+    #         return_softmax=True,
+    #         axis=-1)
+    #     loss = loss * mask
+    #     avg_loss = fluid.layers.mean(loss) / (
+    #         fluid.layers.mean(mask) + self.EPS)
+    #     label.stop_gradient = True
+    #     mask.stop_gradient = True
+    #     return avg_loss


 class ConvBNLayer(fluid.dygraph.Layer):
@@ -150,8 +170,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
                  filter_size,
                  stride=1,
                  groups=1,
-                 act="relu",
-                 name=None):
+                 act="relu"):
         super(ConvBNLayer, self).__init__()

         self._conv = Conv2D(
@@ -161,18 +180,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
-            param_attr=ParamAttr(
-                initializer=Normal(scale=0.001), name=name + "_weights"),
             bias_attr=False)
-        bn_name = name + '_bn'
-        self._batch_norm = BatchNorm(
-            num_filters,
-            weight_attr=ParamAttr(
-                name=bn_name + '_scale',
-                initializer=fluid.initializer.Constant(1.0)),
-            bias_attr=ParamAttr(
-                bn_name + '_offset',
-                initializer=fluid.initializer.Constant(0.0)))
+        self._batch_norm = BatchNorm(num_filters)
         self.act = act

     def forward(self, input):
@@ -185,49 +194,49 @@ class ConvBNLayer(fluid.dygraph.Layer):

 @manager.MODELS.add_component
 def fcn_hrnet_w18_small_v1(*args, **kwargs):
-    return FCN(backbone='HRNet_W18_Small_V1', in_channels=240, **kwargs)
+    return FCN(backbone='HRNet_W18_Small_V1', backbone_channels=(240, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w18_small_v2(*args, **kwargs):
-    return FCN(backbone='HRNet_W18_Small_V2', in_channels=270, **kwargs)
+    return FCN(backbone='HRNet_W18_Small_V2', backbone_channels=(270, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w18(*args, **kwargs):
-    return FCN(backbone='HRNet_W18', in_channels=270, **kwargs)
+    return FCN(backbone='HRNet_W18', backbone_channels=(270, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w30(*args, **kwargs):
-    return FCN(backbone='HRNet_W30', in_channels=450, **kwargs)
+    return FCN(backbone='HRNet_W30', backbone_channels=(450, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w32(*args, **kwargs):
-    return FCN(backbone='HRNet_W32', in_channels=480, **kwargs)
+    return FCN(backbone='HRNet_W32', backbone_channels=(480, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w40(*args, **kwargs):
-    return FCN(backbone='HRNet_W40', in_channels=600, **kwargs)
+    return FCN(backbone='HRNet_W40', backbone_channels=(600, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w44(*args, **kwargs):
-    return FCN(backbone='HRNet_W44', in_channels=660, **kwargs)
+    return FCN(backbone='HRNet_W44', backbone_channels=(660, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w48(*args, **kwargs):
-    return FCN(backbone='HRNet_W48', in_channels=720, **kwargs)
+    return FCN(backbone='HRNet_W48', backbone_channels=(720, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w60(*args, **kwargs):
-    return FCN(backbone='HRNet_W60', in_channels=900, **kwargs)
+    return FCN(backbone='HRNet_W60', backbone_channels=(900, ), **kwargs)


 @manager.MODELS.add_component
 def fcn_hrnet_w64(*args, **kwargs):
-    return FCN(backbone='HRNet_W64', in_channels=960, **kwargs)
+    return FCN(backbone='HRNet_W64', backbone_channels=(960, ), **kwargs)
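`backbone_indices` selects which entries of the backbone's returned feature list feed the FCN head; with the default `(-1, )` that is the last (fused) feature map. In plain Python:

# Stand-ins for feature maps returned by a backbone's forward().
fea_list = ['stage0', 'stage1', 'stage2', 'concat']
backbone_indices = (-1, )
x = fea_list[backbone_indices[0]]
print(x)  # 'concat' -- the fused HRNet output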
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entroy_loss import CrossEntropyLoss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from dygraph.cvlibs import manager
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.CrossEntropyLoss):
    """
    Implements the cross entropy loss function.

    Args:
        weight (Tensor): Weight tensor, a manual rescaling weight given
            to each class. Its shape is (C), matching the number of classes,
            and the data type is float32 or float64. Default ``None``.
        ignore_index (int64): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
        reduction (str): Indicates how to average the loss by batch_size;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned;
            if :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default ``'mean'``.
    """

    def __init__(self, weight=None, ignore_index=255, reduction='mean'):
        super(CrossEntropyLoss, self).__init__(
            weight=weight, ignore_index=ignore_index, reduction=reduction)
        self.EPS = 1e-5
        if self.reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
                " 'none', but received %s, which is not allowed." %
                self.reduction)

    def forward(self, logit, label):
        """
        Forward computation.

        Args:
            logit (Tensor): logit tensor, the data type is float32 or float64. Shape is
                (N, C), where C is the number of classes, and if shape is more than 2D, this
                is (N, C, D1, D2,..., Dk), k >= 1.
            label (Tensor): label tensor, the data type is int64. Shape is (N), where each
                value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
                (N, D1, D2,..., Dk), k >= 1.
        """
        loss = paddle.nn.functional.cross_entropy(
            logit,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction)

        # Rescale by the fraction of valid (non-ignored) elements so that
        # positions carrying ignore_index do not dilute the mean loss.
        mask = label != self.ignore_index
        mask = paddle.cast(mask, 'float32')
        avg_loss = loss / (paddle.mean(mask) + self.EPS)

        label.stop_gradient = True
        mask.stop_gradient = True
        return avg_loss
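
To see what the rescaling buys: assuming the framework zeroes the loss at ignored positions but still averages over all of them, dividing by the valid fraction recovers the mean over valid pixels only. A plain-Python illustration with made-up numbers:

reduced_loss = 0.30    # mean cross entropy over all positions; ignored ones contribute 0
valid_fraction = 0.60  # paddle.mean(mask): 60% of the pixels are not ignore_index
EPS = 1e-5

avg_loss = reduced_loss / (valid_fraction + EPS)
print(avg_loss)  # ~0.50, i.e. the mean over just the valid pixels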
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D

from dygraph.cvlibs import manager
from dygraph.models.architectures.layer_utils import ConvBnRelu
from dygraph import utils
class SpatialGatherBlock(fluid.dygraph.Layer):

@@ -116,8 +119,9 @@ class ObjectAttentionBlock(fluid.dygraph.Layer):
class OCRNet(fluid.dygraph.Layer):
    def __init__(self,
                 num_classes,
                 backbone,
                 model_pretrained=None,
                 in_channels=None,
                 ocr_mid_channels=512,
                 ocr_key_channels=256,
                 ignore_index=255):

@@ -139,6 +143,8 @@ class OCRNet(fluid.dygraph.Layer):
            ConvBnRelu(in_channels, in_channels, 3, padding=1),
            Conv2D(in_channels, self.num_classes, 1))

        self.init_weight(model_pretrained)
    def forward(self, x, label=None):
        feats = self.backbone(x)

@@ -164,6 +170,19 @@ class OCRNet(fluid.dygraph.Layer):
            pred = fluid.layers.unsqueeze(pred, axes=[3])
            return pred, score_map
    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained_model ([str], optional): the path of the pretrained model. Defaults to None.
        """
        if pretrained_model is not None:
            if os.path.exists(pretrained_model):
                utils.load_pretrained_model(self, pretrained_model)
            else:
                raise Exception('Pretrained model is not found: {}'.format(
                    pretrained_model))
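
With the backbone now injected from outside, wiring OCRNet up looks roughly like the sketch below (the registry lookup, the no-argument backbone call, and the channel count are assumptions for illustration):

backbone = manager.BACKBONES['HRNet_W18']()
model = OCRNet(num_classes=19, backbone=backbone, in_channels=270)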
    def _get_loss(self, logit, label):
        """
        compute forward loss of the model
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle.nn.functional as F
@@ -35,9 +34,11 @@ class PSPNet(fluid.dygraph.Layer):
    (https://openaccess.thecvf.com/content_cvpr_2017/papers/Zhao_Pyramid_Scene_Parsing_CVPR_2017_paper.pdf)

    Args:
        num_classes (int): the unique number of target classes.
        backbone (paddle.nn.Layer): a backbone network; currently ResNet50/101 are supported.
        model_pretrained (str): the path of the pretrained model.
        output_stride (int): the ratio of input size and final feature size. Default 16.

@@ -57,42 +58,44 @@ class PSPNet(fluid.dygraph.Layer):
        enable_auxiliary_loss (bool): a bool value that indicates whether to add the auxiliary loss. Default to True.
        ignore_index (int): the value in the ground-truth mask that is ignored while doing evaluation. Default to 255.
    """
    def __init__(self,
                 num_classes,
                 backbone,
                 model_pretrained=None,
                 output_stride=16,
                 backbone_indices=(2, 3),
                 backbone_channels=(1024, 2048),
                 pp_out_channels=1024,
                 bin_sizes=(1, 2, 3, 6),
                 enable_auxiliary_loss=True,
                 ignore_index=255):

        super(PSPNet, self).__init__()

        # The backbone is now passed in as a layer instead of being built
        # here from a name:
        # self.backbone = manager.BACKBONES[backbone](output_stride=output_stride,
        #                                             multi_grid=(1, 1, 1))
        self.backbone = backbone
        self.backbone_indices = backbone_indices

        self.psp_module = PPModule(
            in_channels=backbone_channels[1],
            out_channels=pp_out_channels,
            bin_sizes=bin_sizes)

        self.conv = Conv2D(
            num_channels=pp_out_channels,
            num_filters=num_classes,
            filter_size=1)

        if enable_auxiliary_loss:
            self.fcn_head = model_utils.FCNHead(
                in_channels=backbone_channels[0], out_channels=num_classes)

        self.enable_auxiliary_loss = enable_auxiliary_loss
        self.ignore_index = ignore_index

        self.init_weight(model_pretrained)
    def forward(self, input, label=None):

@@ -107,7 +110,8 @@ class PSPNet(fluid.dygraph.Layer):
        if self.enable_auxiliary_loss:
            auxiliary_feat = feat_list[self.backbone_indices[0]]
            auxiliary_logit = self.fcn_head(auxiliary_feat)
            auxiliary_logit = fluid.layers.resize_bilinear(
                auxiliary_logit, input.shape[2:])

        if self.training:
            loss = model_utils.get_loss(logit, label)

@@ -116,7 +120,6 @@ class PSPNet(fluid.dygraph.Layer):
                loss += (0.4 * auxiliary_loss)
            return loss
        else:
            pred, score_map = model_utils.get_pred_score_map(logit)
            return pred, score_map
@@ -124,14 +127,15 @@ class PSPNet(fluid.dygraph.Layer):
    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained_model ([str], optional): the path of the pretrained model. Defaults to None.
        """
        if pretrained_model is not None:
            if os.path.exists(pretrained_model):
                utils.load_pretrained_model(self, pretrained_model)
            else:
                raise Exception('Pretrained model is not found: {}'.format(
                    pretrained_model))
class PPModule(fluid.dygraph.Layer):

@@ -151,9 +155,11 @@ class PPModule(fluid.dygraph.Layer):
        self.bin_sizes = bin_sizes

        # we use dimension reduction after pooling mentioned in original implementation.
        self.stages = fluid.dygraph.LayerList(
            [self._make_stage(in_channels, size) for size in bin_sizes])

        self.conv_bn_relu2 = layer_utils.ConvBnRelu(
            num_channels=in_channels * 2,
            num_filters=out_channels,
            filter_size=3,
            padding=1)
@@ -180,7 +186,8 @@ class PPModule(fluid.dygraph.Layer):
        # this paddle version does not support AdaptiveAvgPool2d, so skip it here.
        # prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = layer_utils.ConvBnRelu(
            num_channels=in_channels,
            num_filters=in_channels // len(self.bin_sizes),
            filter_size=1)
@@ -190,7 +197,8 @@ class PPModule(fluid.dygraph.Layer):
        cat_layers = []
        for i, stage in enumerate(self.stages):
            size = self.bin_sizes[i]
            x = fluid.layers.adaptive_pool2d(
                input, pool_size=(size, size), pool_type="max")
            x = stage(x)
            x = fluid.layers.resize_bilinear(x, out_shape=input.shape[2:])
            cat_layers.append(x)

@@ -204,22 +212,32 @@
@manager.MODELS.add_component
def pspnet_resnet101_vd(*args, **kwargs):
    pretrained_model = None
    # Note: PSPNet's keyword was renamed to model_pretrained above; passing
    # pretrained_model= here would raise a TypeError.
    return PSPNet(
        backbone='ResNet101_vd', model_pretrained=pretrained_model, **kwargs)


@manager.MODELS.add_component
def pspnet_resnet101_vd_os8(*args, **kwargs):
    pretrained_model = None
    return PSPNet(
        backbone='ResNet101_vd',
        output_stride=8,
        model_pretrained=pretrained_model,
        **kwargs)


@manager.MODELS.add_component
def pspnet_resnet50_vd(*args, **kwargs):
    pretrained_model = None
    return PSPNet(
        backbone='ResNet50_vd', model_pretrained=pretrained_model, **kwargs)


@manager.MODELS.add_component
def pspnet_resnet50_vd_os8(*args, **kwargs):
    pretrained_model = None
    return PSPNet(
        backbone='ResNet50_vd',
        output_stride=8,
        model_pretrained=pretrained_model,
        **kwargs)
@@ -33,7 +33,7 @@ class UNet(fluid.dygraph.Layer):
        ignore_index (int): the value in the ground-truth mask that is ignored while computing loss or doing evaluation. Default 255.
    """
    def __init__(self, num_classes, model_pretrained=None, ignore_index=255):
        super(UNet, self).__init__()
        self.encode = UnetEncoder()
        self.decode = UnetDecode()

@@ -41,7 +41,7 @@ class UNet(fluid.dygraph.Layer):
        self.ignore_index = ignore_index
        self.EPS = 1e-5

        self.init_weight(model_pretrained)
    def forward(self, x, label=None):
        encode_data, short_cuts = self.encode(x)

@@ -60,7 +60,7 @@ class UNet(fluid.dygraph.Layer):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained_model ([str], optional): the path of the pretrained model. Defaults to None.
        """
        if pretrained_model is not None:
            if os.path.exists(pretrained_model):
@@ -110,6 +110,8 @@ def main(args):
    val_dataset = cfg.val_dataset if args.do_eval else None

    losses = cfg.loss

    train(
        cfg.model,
        train_dataset,

@@ -123,7 +125,8 @@ def main(args):
        log_iters=args.log_iters,
        num_classes=train_dataset.num_classes,
        num_workers=args.num_workers,
        use_vdl=args.use_vdl,
        losses=losses)


if __name__ == '__main__':
@@ -55,8 +55,8 @@ class Compose:
        if len(outputs) == 3:
            label = outputs[2]
        im = permute(im)
        # The label used to be given an explicit channel axis here:
        # if len(outputs) == 3:
        #     label = label[np.newaxis, :, :]
        return (im, im_info, label)
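
The practical effect of dropping those two lines, sketched with numpy (the shapes are illustrative only, not taken from the library):

import numpy as np

label = np.zeros((512, 512), dtype='int64')
# before this change: label[np.newaxis, :, :] -> shape (1, 512, 512)
# after this change the label keeps its 2D shape -> (512, 512)
print(label.shape)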
@@ -35,15 +35,34 @@ class Config(object):
            raise FileNotFoundError('File {} does not exist'.format(path))

        if path.endswith('yml') or path.endswith('yaml'):
            dic = self._parse_from_yaml(path)
            self._build(dic)
        else:
            raise RuntimeError('Config file should be in yaml format!')
    def _update_dic(self, dic, base_dic):
        """
        Merge dic into base_dic: keys present in dic override those in
        base_dic, and nested dicts are merged recursively, so a derived
        config only needs to restate the values it changes.
        """
        base_dic = base_dic.copy()
        for key, val in dic.items():
            if isinstance(val, dict) and isinstance(base_dic.get(key), dict):
                base_dic[key] = self._update_dic(val, base_dic[key])
            else:
                base_dic[key] = val
        return base_dic
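
A quick sanity check of the merge semantics with plain dicts (a standalone mirror of the method above; the keys and values are made up for illustration):

def update_dic(dic, base_dic):
    # dic overrides base_dic; nested dicts merge recursively
    base_dic = base_dic.copy()
    for key, val in dic.items():
        if isinstance(val, dict) and isinstance(base_dic.get(key), dict):
            base_dic[key] = update_dic(val, base_dic[key])
        else:
            base_dic[key] = val
    return base_dic

base = {'batch_size': 4, 'optimizer': {'type': 'sgd', 'learning_rate': 0.01}}
child = {'batch_size': 8, 'optimizer': {'learning_rate': 0.001}}
print(update_dic(child, base))
# {'batch_size': 8, 'optimizer': {'type': 'sgd', 'learning_rate': 0.001}}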
    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file into a dict, resolving any _base_ file it inherits from.'''
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)

        if '_base_' in dic:
            cfg_dir = os.path.dirname(path)
            base_path = dic.pop('_base_')
            base_path = os.path.join(cfg_dir, base_path)
            # Base files may themselves declare _base_, so recurse.
            base_dic = self._parse_from_yaml(base_path)
            dic = self._update_dic(dic, base_dic)
        return dic
    def _build(self, dic: dict):
        '''Build config from dictionary'''

@@ -68,6 +87,7 @@ class Config(object):
        })

        self._loss_cfg = dic.get('loss', {})
        self._losses = None

        self._optimizer_cfg = dic.get('optimizer', {})

@@ -145,14 +165,23 @@ class Config(object):
        return args
    @property
    def loss(self) -> dict:
        if not self._losses:
            args = self._loss_cfg.copy()
            self._losses = dict()
            for key, val in args.items():
                if key == 'types':
                    # Instantiate each configured loss through the component registry.
                    self._losses['types'] = []
                    for item in args['types']:
                        self._losses['types'].append(self._load_object(item))
                else:
                    self._losses[key] = val
            if len(self._losses['coef']) != len(self._losses['types']):
                raise RuntimeError(
                    'The length of coef should equal the length of types in the loss config: {} != {}.'
                    .format(
                        len(self._losses['coef']), len(self._losses['types'])))
        return self._losses
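
The returned dict pairs each instantiated loss with a coefficient. Presumably the training loop combines them as a weighted sum over the corresponding logits; a sketch under that assumption (logit and label are placeholders):

losses = cfg.loss  # e.g. {'types': [CrossEntropyLoss()], 'coef': [1]}
total_loss = 0
for coef, loss_fn in zip(losses['coef'], losses['types']):
    total_loss = total_loss + coef * loss_fn(logit, label)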
    @property
    def model(self) -> Callable:

@@ -175,7 +204,7 @@ class Config(object):
    def _load_component(self, com_name: str) -> Any:
        com_list = [
            manager.MODELS, manager.BACKBONES, manager.DATASETS,
            manager.TRANSFORMS, manager.LOSSES
        ]
        for com in com_list: