提交 e3637a3c 编写于 作者: C chenguowei01

add config

上级 1a3f44a5
# Base configuration for Cityscapes experiments: run-level settings, the train/val
# datasets with their transforms, the optimizer, the learning-rate schedule and the loss.
# Presumably this is the `_base_/cityscapes.yml` file that model configs inherit from -- confirm.
# NOTE(review): nesting indentation is not visible in this view. `learning_rate` occurs
# both as a scalar (0.01) and as a mapping with `value`/`decay`; confirm which one the
# config loader actually reads.
batch_size: 4
iters: 100000
learning_rate: 0.01
# Training split and its augmentation pipeline.
train_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: Normalize
mode: train
# Validation split: only normalization is applied.
val_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
mode: val
optimizer:
type: sgd
# Learning-rate schedule: initial value 0.01 with polynomial decay (power 0.9).
learning_rate:
value: 0.01
decay:
type: poly
power: 0.9
# Loss configuration: a list of loss types plus one coefficient per type.
loss:
types:
- type: CrossEntropyLoss
coef: [1]
# Base configuration for OpticDiscSeg experiments: run-level settings, the train/val
# datasets with their transforms, the optimizer, the learning-rate schedule and the loss.
# Presumably this is the `_base_/optic_disc_seg.yml` file that model configs inherit from -- confirm.
# NOTE(review): nesting indentation is not visible in this view. `learning_rate` occurs
# both as a scalar (0.01) and as a mapping with `value`/`decay`; confirm which one the
# config loader actually reads.
batch_size: 4
iters: 10000
learning_rate: 0.01
# Training split: resize to 512x512, random horizontal flip, then normalize.
train_dataset:
type: OpticDiscSeg
dataset_root: data/optic_disc_seg
transforms:
- type: Resize
target_size: [512, 512]
- type: RandomHorizontalFlip
- type: Normalize
mode: train
# Validation split: resize to 512x512, then normalize.
val_dataset:
type: OpticDiscSeg
dataset_root: data/optic_disc_seg
transforms:
- type: Resize
target_size: [512, 512]
- type: Normalize
mode: val
optimizer:
type: sgd
# Learning-rate schedule: initial value 0.01 with polynomial decay (power 0.9).
learning_rate:
value: 0.01
decay:
type: poly
power: 0.9
# Loss configuration: a list of loss types plus one coefficient per type.
loss:
types:
- type: CrossEntropyLoss
coef: [1]
# FCN with an HRNet_W18 backbone for Cityscapes (19 classes).
# Settings not listed here are inherited from the file named by `_base_`.
_base_: '../_base_/cityscapes.yml'
model:
type: FCN
backbone:
type: HRNet_W18
# Path of the pretrained backbone weights.
backbone_pretrained: pretrained_model/hrnet_w18_imagenet
num_classes: 19
# NOTE(review): presumably the channel count of the backbone output fed to the head -- confirm.
backbone_channels: [270]
# FCN with an HRNet_W18 backbone for OpticDiscSeg (2 classes).
# Settings not listed here are inherited from the file named by `_base_`.
_base_: '../_base_/optic_disc_seg.yml'
model:
type: FCN
backbone:
type: HRNet_W18
# Path of the pretrained backbone weights.
backbone_pretrained: pretrained_model/hrnet_w18_imagenet
num_classes: 2
# NOTE(review): presumably the channel count of the backbone output fed to the head -- confirm.
backbone_channels: [270]
......@@ -3,32 +3,33 @@ iters: 40000
train_dataset:
type: Cityscapes
dataset_root: datasets/cityscapes
dataset_root: data/cityscapes
transforms:
- type: RandomHorizontalFlip
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: Normalize
mode: train
val_dataset:
type: Cityscapes
dataset_root: datasets/cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
mode: val
# OCRNet model section (this view shows both the previous and the updated lines of the change).
model:
type: ocrnet
type: OCRNet
backbone:
type: HRNet_W18
pretrained: dygraph/pretrained_model/hrnet_w18_ssld/model
# NOTE(review): `backbone_pretrianed` looks like a misspelling of `backbone_pretrained`
# (the key used by the FCN model configs) -- confirm and rename.
# NOTE(review): YAML reads a bare `None` as the string "None", not as null -- confirm the
# loader treats it as "no pretrained weights".
backbone_pretrianed: None
num_classes: 19
in_channels: 270
model_pretrained: None
optimizer:
type: sgd
......
......@@ -14,11 +14,13 @@
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
# from paddle.incubate.hapi.distributed import DistributedBatchSampler
from paddle.io import DistributedBatchSampler
import paddle.nn.functional as F
import dygraph.utils.logger as logger
from dygraph.utils import load_pretrained_model
......@@ -27,6 +29,27 @@ from dygraph.utils import Timer, calculate_eta
from .val import evaluate
def check_logits_losses(logits, losses):
    """Verify that the model produced one logit per configured loss type.

    Args:
        logits: sized collection of model outputs.
        losses (dict): loss config; only its 'types' entry is inspected here.

    Raises:
        RuntimeError: if the number of logits differs from the number of
            entries in ``losses['types']``.
    """
    num_logits, num_loss_types = len(logits), len(losses['types'])
    if num_logits == num_loss_types:
        return
    message = 'The length of logits should equal to the types of loss config: {} != {}.'
    raise RuntimeError(message.format(num_logits, num_loss_types))
def loss_computation(logits, label, losses):
    """Combine the per-logit losses into a single weighted sum.

    Args:
        logits: sequence of model outputs, one per configured loss type.
        label: ground-truth tensor handed to every loss callable.
        losses (dict): loss config with 'types' (loss callables) and 'coef'
            (one weight per loss).

    Returns:
        The sum over all logits of ``coef[i] * types[i](logit_i, label)``.

    Raises:
        RuntimeError: if the number of logits and loss types differ.
    """
    check_logits_losses(logits, losses)
    total = 0
    for idx, logit in enumerate(logits):
        target_size = label.shape[-2:]
        # Bring the logit to the label's trailing two dimensions before
        # evaluating the loss.
        if logit.shape[-2:] != target_size:
            logit = F.resize_bilinear(logit, target_size)
        loss_fn = losses['types'][idx]
        current = loss_fn(logit, label)
        total += losses['coef'][idx] * current
    return total
def train(model,
train_dataset,
places=None,
......@@ -40,7 +63,8 @@ def train(model,
log_iters=10,
num_classes=None,
num_workers=8,
use_vdl=False):
use_vdl=False,
losses=None):
ignore_index = model.ignore_index
nranks = ParallelEnv().nranks
......@@ -90,13 +114,17 @@ def train(model,
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
loss = ddp_model(images, labels)
logits = ddp_model(images)
loss = loss_computation(logits, labels, losses)
# loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss)
loss.backward()
ddp_model.apply_collective_grads()
else:
loss = model(images, labels)
logits = model(images)
loss = loss_computation(logits, labels, losses)
# loss = model(images, labels)
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
......
......@@ -19,6 +19,8 @@ import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
import paddle.nn.functional as F
import paddle
import dygraph.utils.logger as logger
from dygraph.utils import ConfusionMatrix
......@@ -47,7 +49,9 @@ def evaluate(model,
for iter, (im, im_info, label) in tqdm.tqdm(
enumerate(eval_dataset), total=total_iters):
im = to_variable(im)
pred, _ = model(im)
# pred, _ = model(im)
logits = model(im)
pred = paddle.argmax(logits[0], axis=1)
pred = pred.numpy().astype('float32')
pred = np.squeeze(pred)
for info in im_info[::-1]:
......
......@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import manager
from . import param_init
......@@ -115,3 +115,4 @@ MODELS = ComponentManager()
# Module-level registries, one ComponentManager instance per kind of pluggable component.
BACKBONES = ComponentManager()
DATASETS = ComponentManager()
TRANSFORMS = ComponentManager()
# Registry for loss components.
LOSSES = ComponentManager()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
def constant_init(param, value=0.0):
    """Apply a ``fluid.initializer.Constant`` initializer to ``param``.

    Args:
        param: parameter to initialize; its ``block`` attribute is passed to
            the initializer.
        value (float): constant handed to the initializer. Default 0.0.
    """
    fluid.initializer.Constant(value)(param, param.block)
def normal_init(param, loc=0.0, scale=1.0, seed=0):
    """Apply a ``fluid.initializer.Normal`` initializer to ``param``.

    Args:
        param: parameter to initialize; its ``block`` attribute is passed to
            the initializer.
        loc (float): ``loc`` argument of the initializer. Default 0.0.
        scale (float): ``scale`` argument of the initializer. Default 1.0.
        seed (int): ``seed`` argument of the initializer. Default 0.
    """
    fluid.initializer.Normal(loc=loc, scale=scale, seed=seed)(param,
                                                              param.block)
......@@ -13,6 +13,7 @@
# limitations under the License.
from .architectures import *
from .losses import *
from .unet import UNet
from .deeplab import *
from .fcn import *
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import math
import os
import paddle
import paddle.fluid as fluid
......@@ -23,6 +24,8 @@ from paddle.fluid.initializer import Normal
from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.cvlibs import manager
from dygraph.utils import utils
from dygraph.cvlibs import param_init
__all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......@@ -36,6 +39,7 @@ class HRNet(fluid.dygraph.Layer):
https://arxiv.org/pdf/1908.07919.pdf.
Args:
backbone_pretrained (str): the path of pretrained model.
stage1_num_modules (int): number of modules for stage1. Default 1.
stage1_num_blocks (list): number of blocks per module for stage1. Default [4].
stage1_num_channels (list): number of channels per branch for stage1. Default [64].
......@@ -52,6 +56,7 @@ class HRNet(fluid.dygraph.Layer):
"""
def __init__(self,
backbone_pretrained=None,
stage1_num_modules=1,
stage1_num_blocks=[4],
stage1_num_channels=[64],
......@@ -141,6 +146,8 @@ class HRNet(fluid.dygraph.Layer):
has_se=self.has_se,
name="st4")
self.init_weight(backbone_pretrained)
def forward(self, x, label=None, mode='train'):
input_shape = x.shape[2:]
conv1 = self.conv_layer1_1(x)
......@@ -163,7 +170,31 @@ class HRNet(fluid.dygraph.Layer):
x3 = fluid.layers.resize_bilinear(st4[3], out_shape=(x0_h, x0_w))
x = fluid.layers.concat([st4[0], x1, x2, x3], axis=1)
return x
return [x]
def init_weight(self, pretrained_model=None):
    """
    Re-initialize batch-norm and convolution parameters, then optionally
    load pretrained weights.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    for param in self.parameters():
        name = param.name
        is_weight = 'w_0' in name
        # Batch-norm parameters: weight -> 1.0, bias -> 0.0.
        if 'batch_norm' in name:
            if is_weight:
                param_init.constant_init(param, 1.0)
            elif 'b_0' in name:
                param_init.constant_init(param, 0.0)
        # Convolution weights: normal initializer with scale 0.001.
        if 'conv' in name and is_weight:
            param_init.normal_init(param, scale=0.001)

    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
class ConvBNLayer(fluid.dygraph.Layer):
......@@ -184,18 +215,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
num_filters,
weight_attr=ParamAttr(
name=bn_name + '_scale',
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
bn_name + '_offset',
initializer=fluid.initializer.Constant(0.0)))
self._batch_norm = BatchNorm(num_filters)
self.act = act
def forward(self, input):
......
......@@ -17,8 +17,9 @@ from __future__ import division
from __future__ import print_function
import math
import numpy as np
import os
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -28,6 +29,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.models.architectures import layer_utils
from dygraph.cvlibs import manager
from dygraph.utils import utils
__all__ = [
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
......@@ -46,6 +48,7 @@ def make_divisible(v, divisor=8, min_value=None):
new_v += divisor
return new_v
def get_padding_same(kernel_size, dilation_rate):
"""
SAME padding implementation given kernel_size and dilation_rate.
......@@ -53,7 +56,7 @@ def get_padding_same(kernel_size, dilation_rate):
(F-(k+(k -1)*(r-1))+2*p)/s + 1 = F_new
where F: a feature map
k: kernel size, r: dilation rate, p: padding value, s: stride
F_new: new feature map
F_new: new feature map
Args:
kernel_size (int)
dilation_rate (int)
......@@ -63,12 +66,19 @@ def get_padding_same(kernel_size, dilation_rate):
"""
k = kernel_size
r = dilation_rate
padding_same = (k + (k - 1) * (r - 1) - 1)//2
padding_same = (k + (k - 1) * (r - 1) - 1) // 2
return padding_same
class MobileNetV3(fluid.dygraph.Layer):
def __init__(self, scale=1.0, model_name="small", class_dim=1000, output_stride=None, **kwargs):
def __init__(self,
backbone_pretrained=None,
scale=1.0,
model_name="small",
class_dim=1000,
output_stride=None,
**kwargs):
super(MobileNetV3, self).__init__()
inplanes = 16
......@@ -77,19 +87,21 @@ class MobileNetV3(fluid.dygraph.Layer):
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1], # output 1 -> out_index=2
[3, 72, 24, False, "relu", 1], # output 1 -> out_index=2
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1], # output 2 -> out_index=5
[5, 120, 40, True, "relu", 1], # output 2 -> out_index=5
[3, 240, 80, False, "hard_swish", 2],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1], # output 3 -> out_index=11
[3, 672, 112, True, "hard_swish",
1], # output 3 -> out_index=11
[5, 672, 160, True, "hard_swish", 2],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1], # output 3 -> out_index=14
[5, 960, 160, True, "hard_swish",
1], # output 3 -> out_index=14
]
self.out_indices = [2, 5, 11, 14]
......@@ -98,17 +110,17 @@ class MobileNetV3(fluid.dygraph.Layer):
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", 2], # output 1 -> out_index=0
[3, 16, 16, True, "relu", 2], # output 1 -> out_index=0
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1], # output 2 -> out_index=3
[3, 88, 24, False, "relu", 1], # output 2 -> out_index=3
[5, 96, 40, True, "hard_swish", 2],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7
[5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7
[5, 288, 96, True, "hard_swish", 2],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10
[5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10
]
self.out_indices = [0, 3, 7, 10]
......@@ -157,7 +169,6 @@ class MobileNetV3(fluid.dygraph.Layer):
self.add_sublayer(
sublayer=self.block_list[-1], name="conv" + str(i + 2))
inplanes = make_divisible(scale * c)
self.last_second_conv = ConvBNLayer(
in_c=inplanes,
......@@ -189,8 +200,10 @@ class MobileNetV3(fluid.dygraph.Layer):
param_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
self.init_weight(backbone_pretrained)
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is not None:
......@@ -201,9 +214,9 @@ class MobileNetV3(fluid.dygraph.Layer):
if stride > output_stride:
rate = rate * _cfg[-1]
self.cfg[i][-1] = 1
self.dilation_cfg[i] = rate
def forward(self, inputs, label=None, dropout_prob=0.2):
x = self.conv1(inputs)
# A feature list saves each downsampling feature.
......@@ -223,6 +236,19 @@ class MobileNetV3(fluid.dygraph.Layer):
return x, feat_list
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into this layer when a path is supplied.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    # Nothing to load: leave the parameters untouched.
    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
......@@ -240,7 +266,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = fluid.dygraph.Conv2D(
num_channels=in_c,
num_filters=out_c,
......@@ -263,7 +289,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
name=name + "_bn_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)))
self._act_op = layer_utils.Activation(act=None)
def forward(self, x):
......@@ -304,14 +330,15 @@ class ResidualUnit(fluid.dygraph.Layer):
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding= get_padding_same(filter_size, dilation), #int((filter_size - 1) // 2) + (dilation - 1),
padding=get_padding_same(
filter_size,
dilation), #int((filter_size - 1) // 2) + (dilation - 1),
dilation=dilation,
num_groups=mid_c,
if_act=True,
......@@ -329,6 +356,7 @@ class ResidualUnit(fluid.dygraph.Layer):
act=None,
name=name + "_linear")
self.dilation = dilation
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
......@@ -386,6 +414,7 @@ def MobileNetV3_small_x0_75(**kwargs):
model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_small_x1_0(**kwargs):
model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
......@@ -411,6 +440,7 @@ def MobileNetV3_large_x0_75(**kwargs):
model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_large_x1_0(**kwargs):
model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
......
......@@ -30,6 +30,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.utils import utils
from dygraph.models.architectures import layer_utils
from dygraph.cvlibs import manager
from dygraph.utils import utils
__all__ = [
"ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
......@@ -47,18 +48,23 @@ class ConvBNLayer(fluid.dygraph.Layer):
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
name=None,
):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D(
pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2 if dilation ==1 else 0,
padding=(filter_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation,
groups=groups,
act=None,
......@@ -125,19 +131,20 @@ class BottleneckBlock(fluid.dygraph.Layer):
num_filters=num_filters * 4,
filter_size=1,
stride=1,
is_vd_mode=False if if_first or stride==1 else True,
is_vd_mode=False if if_first or stride == 1 else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
####################################################################
# If given dilation rate > 1, using corresponding padding
if self.dilation > 1:
padding = self.dilation
y = fluid.layers.pad(y, [0,0,0,0,padding,padding,padding,padding])
y = fluid.layers.pad(
y, [0, 0, 0, 0, padding, padding, padding, padding])
#####################################################################
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
......@@ -196,15 +203,21 @@ class BasicBlock(fluid.dygraph.Layer):
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv1)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class ResNet_vd(fluid.dygraph.Layer):
def __init__(self, layers=50, class_dim=1000, output_stride=None, multi_grid=(1, 2, 4), **kwargs):
def __init__(self,
backbone_pretrained=None,
layers=50,
class_dim=1000,
output_stride=None,
multi_grid=(1, 2, 4),
**kwargs):
super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
......@@ -221,11 +234,11 @@ class ResNet_vd(fluid.dygraph.Layer):
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_channels = [64, 256, 512, 1024
] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
dilation_dict=None
dilation_dict = None
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
......@@ -254,13 +267,13 @@ class ResNet_vd(fluid.dygraph.Layer):
name="conv1_3")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
# self.block_list = []
self.stage_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list=[]
block_list = []
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
......@@ -269,11 +282,12 @@ class ResNet_vd(fluid.dygraph.Layer):
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
###############################################################################
# Add dilation rate for some segmentation tasks, if dilation_dict is not None.
dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
dilation_rate = dilation_dict[
block] if dilation_dict and block in dilation_dict else 1
# Actually block here is 'stage', and i is 'block' in 'stage'
# At stage 4, expand the dilation_rate using multi_grid, default (1, 2, 4)
if block == 3:
......@@ -284,9 +298,11 @@ class ResNet_vd(fluid.dygraph.Layer):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
stride=2 if i == 0 and block != 0
and dilation_rate == 1 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
......@@ -298,7 +314,7 @@ class ResNet_vd(fluid.dygraph.Layer):
else:
for block in range(len(depth)):
shortcut = False
block_list=[]
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
......@@ -330,6 +346,8 @@ class ResNet_vd(fluid.dygraph.Layer):
name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
self.init_weight(backbone_pretrained)
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
......@@ -343,7 +361,7 @@ class ResNet_vd(fluid.dygraph.Layer):
y = block(y)
#print("stage {} block {}".format(i+1, j+1), y.shape)
feat_list.append(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
......@@ -355,8 +373,18 @@ class ResNet_vd(fluid.dygraph.Layer):
# if os.path.exists(pretrained_model):
# utils.load_pretrained_model(self, pretrained_model)
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into this layer when a path is supplied.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    # Nothing to load: leave the parameters untouched.
    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
def ResNet18_vd(**args):
......@@ -368,11 +396,13 @@ def ResNet34_vd(**args):
model = ResNet_vd(layers=34, **args)
return model
@manager.BACKBONES.add_component
def ResNet50_vd(**args):
    """Build a ResNet_vd backbone with ``layers=50``.

    Args:
        **args: forwarded unchanged to ``ResNet_vd``.

    Returns:
        The constructed ``ResNet_vd`` instance.
    """
    return ResNet_vd(layers=50, **args)
@manager.BACKBONES.add_component
def ResNet101_vd(**args):
model = ResNet_vd(layers=101, **args)
......@@ -386,4 +416,4 @@ def ResNet152_vd(**args):
def ResNet200_vd(**args):
model = ResNet_vd(layers=200, **args)
return model
\ No newline at end of file
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -7,6 +23,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.models.architectures import layer_utils
from dygraph.cvlibs import manager
from dygraph.utils import utils
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
......@@ -86,11 +103,11 @@ class ConvBNLayer(fluid.dygraph.Layer):
momentum=0.99,
weight_attr=ParamAttr(name=name + "/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/BatchNorm/beta"))
self._act_op = layer_utils.Activation(act=act)
def forward(self, inputs):
return self._act_op(self._bn(self._conv(inputs)))
......@@ -121,7 +138,7 @@ class Seperate_Conv(fluid.dygraph.Layer):
momentum=0.99,
weight_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"))
self._act_op1 = layer_utils.Activation(act=act)
self._conv2 = Conv2D(
......@@ -139,9 +156,8 @@ class Seperate_Conv(fluid.dygraph.Layer):
momentum=0.99,
weight_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta"))
self._act_op2 = layer_utils.Activation(act=act)
def forward(self, inputs):
x = self._conv1(inputs)
......@@ -254,11 +270,16 @@ class Xception_Block(fluid.dygraph.Layer):
class XceptionDeeplab(fluid.dygraph.Layer):
#def __init__(self, backbone, class_dim=1000):
# add output_stride
def __init__(self, backbone, output_stride=16, class_dim=1000, **kwargs):
def __init__(self,
backbone,
backbone_pretrained=None,
output_stride=16,
class_dim=1000,
**kwargs):
super(XceptionDeeplab, self).__init__()
bottleneck_params = gen_bottleneck_params(backbone)
......@@ -280,7 +301,6 @@ class XceptionDeeplab(fluid.dygraph.Layer):
padding=1,
act="relu",
name=self.backbone + "/entry_flow/conv2")
"""
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
......@@ -381,6 +401,8 @@ class XceptionDeeplab(fluid.dygraph.Layer):
param_attr=ParamAttr(name="fc_weights"),
bias_attr=ParamAttr(name="fc_bias"))
self.init_weight(backbone_pretrained)
def forward(self, inputs):
x = self._conv1(inputs)
x = self._conv2(x)
......@@ -394,18 +416,32 @@ class XceptionDeeplab(fluid.dygraph.Layer):
x = self._exit_flow_1(x)
x = self._exit_flow_2(x)
feat_list.append(x)
x = self._drop(x)
x = self._pool(x)
x = fluid.layers.squeeze(x, axes=[2, 3])
x = self._fc(x)
return x, feat_list
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into this layer when a path is supplied.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    # Nothing to load: leave the parameters untouched.
    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
def Xception41_deeplab(**args):
    """Build an XceptionDeeplab backbone of the 'xception_41' variant.

    Args:
        **args: forwarded unchanged to ``XceptionDeeplab``.

    Returns:
        The constructed ``XceptionDeeplab`` instance.
    """
    return XceptionDeeplab('xception_41', **args)
@manager.BACKBONES.add_component
def Xception65_deeplab(**args):
model = XceptionDeeplab("xception_65", **args)
......@@ -414,4 +450,4 @@ def Xception65_deeplab(**args):
def Xception71_deeplab(**args):
model = XceptionDeeplab("xception_71", **args)
return model
\ No newline at end of file
return model
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dygraph.cvlibs import manager
......@@ -23,10 +22,12 @@ from paddle.fluid.dygraph import Conv2D
from dygraph.utils import utils
__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
"deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
"deeplabv3p_xception65_deeplab",
"deeplabv3p_mobilenetv3_large", "deeplabv3p_mobilenetv3_small"]
__all__ = [
'DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
"deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
"deeplabv3p_xception65_deeplab", "deeplabv3p_mobilenetv3_large",
"deeplabv3p_mobilenetv3_small"
]
class ImageAverage(dygraph.Layer):
......@@ -40,9 +41,8 @@ class ImageAverage(dygraph.Layer):
def __init__(self, num_channels):
super(ImageAverage, self).__init__()
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels,
num_filters=256,
filter_size=1)
self.conv_bn_relu = layer_utils.ConvBnRelu(
num_channels, num_filters=256, filter_size=1)
def forward(self, input):
x = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
......@@ -69,44 +69,49 @@ class ASPP(dygraph.Layer):
elif output_stride == 8:
aspp_ratios = (12, 24, 36)
else:
raise NotImplementedError("Only support output_stride is 8 or 16, but received{}".format(output_stride))
raise NotImplementedError(
"Only support output_stride is 8 or 16, but received{}".format(
output_stride))
self.image_average = ImageAverage(num_channels=in_channels)
# The first aspp using 1*1 conv
self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=1,
using_sep_conv=False)
self.aspp1 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=1,
using_sep_conv=False)
# The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0]
self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[0],
padding=aspp_ratios[0])
self.aspp2 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[0],
padding=aspp_ratios[0])
# The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1]
self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[1],
padding=aspp_ratios[1])
self.aspp3 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[1],
padding=aspp_ratios[1])
# The fourth aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2]
self.aspp4 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[2],
padding=aspp_ratios[2])
self.aspp4 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
dilation=aspp_ratios[2],
padding=aspp_ratios[2])
# After concat op, using 1*1 conv
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280,
num_filters=256,
filter_size=1)
self.conv_bn_relu = layer_utils.ConvBnRelu(
num_channels=1280, num_filters=256, filter_size=1)
def forward(self, x):
......@@ -136,23 +141,23 @@ class Decoder(dygraph.Layer):
def __init__(self, num_classes, in_channels, using_sep_conv=True):
super(Decoder, self).__init__()
self.conv_bn_relu1 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=48,
filter_size=1)
self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=304,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv_bn_relu3 = layer_utils.ConvBnRelu(num_channels=256,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv = Conv2D(num_channels=256,
num_filters=num_classes,
filter_size=1)
self.conv_bn_relu1 = layer_utils.ConvBnRelu(
num_channels=in_channels, num_filters=48, filter_size=1)
self.conv_bn_relu2 = layer_utils.ConvBnRelu(
num_channels=304,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv_bn_relu3 = layer_utils.ConvBnRelu(
num_channels=256,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv = Conv2D(
num_channels=256, num_filters=num_classes, filter_size=1)
def forward(self, x, low_level_feat):
low_level_feat = self.conv_bn_relu1(low_level_feat)
......@@ -164,6 +169,7 @@ class Decoder(dygraph.Layer):
return x
@manager.MODELS.add_component
class DeepLabV3P(dygraph.Layer):
"""
The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder
......@@ -173,9 +179,11 @@ class DeepLabV3P(dygraph.Layer):
(https://arxiv.org/abs/1802.02611)
Args:
backbone (str): backbone name, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
num_classes (int): the unique number of target classes.
backbone (paddle.nn.Layer): backbone networks, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
num_classes (int): the unique number of target classes. Default 2.
model_pretrained (str): the path of pretrained model.
output_stride (int): the ratio of input size and final feature size. Default 16.
......@@ -193,28 +201,29 @@ class DeepLabV3P(dygraph.Layer):
using_sep_conv (bool): a bool value indicates whether using separable convolutions
in ASPP and Decoder components. Default True.
pretrained_model (str): the pretrained_model path of backbone.
"""
def __init__(self,
num_classes,
backbone,
num_classes=2,
model_pretrained=None,
output_stride=16,
backbone_indices=(0, 3),
backbone_channels=(256, 2048),
ignore_index=255,
using_sep_conv=True,
pretrained_model=None):
using_sep_conv=True):
super(DeepLabV3P, self).__init__()
self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
# self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
self.backbone = backbone
self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv)
self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv)
self.decoder = Decoder(num_classes, backbone_channels[0],
using_sep_conv)
self.ignore_index = ignore_index
self.EPS = 1e-5
self.backbone_indices = backbone_indices
self.init_weight(pretrained_model)
self.init_weight(model_pretrained)
def forward(self, input, label=None):
......@@ -238,14 +247,14 @@ class DeepLabV3P(dygraph.Layer):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
# utils.load_pretrained_model(self, pretrained_model)
# for param in self.backbone.parameters():
# param.stop_gradient = True
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
def _get_loss(self, logit, label):
"""
......@@ -271,7 +280,7 @@ class DeepLabV3P(dygraph.Layer):
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
......@@ -290,52 +299,65 @@ def build_decoder(num_classes, using_sep_conv):
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd(*args, **kwargs):
    """Build a DeepLabV3P model with a ResNet101_vd backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['ResNet101_vd'](output_stride=output_stride)
    return DeepLabV3P(backbone=backbone, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd_os8(*args, **kwargs):
    """Build a DeepLabV3P model with a ResNet101_vd backbone at output stride 8.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride that is passed to the head.
    backbone = manager.BACKBONES['ResNet101_vd'](output_stride=8)
    return DeepLabV3P(backbone=backbone, output_stride=8, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd(*args, **kwargs):
    """Build a DeepLabV3P model with a ResNet50_vd backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['ResNet50_vd'](output_stride=output_stride)
    return DeepLabV3P(backbone=backbone, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd_os8(*args, **kwargs):
    """Build a DeepLabV3P model with a ResNet50_vd backbone at output stride 8.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride that is passed to the head.
    backbone = manager.BACKBONES['ResNet50_vd'](output_stride=8)
    return DeepLabV3P(backbone=backbone, output_stride=8, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_xception65_deeplab(*args, **kwargs):
    """Build a DeepLabV3P model with an Xception65_deeplab backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['Xception65_deeplab'](
        output_stride=output_stride)
    return DeepLabV3P(
        backbone=backbone,
        backbone_indices=(0, 1),
        backbone_channels=(128, 2048),
        **kwargs)
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_large(*args, **kwargs):
    """Build a DeepLabV3P model with a MobileNetV3_large_x1_0 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['MobileNetV3_large_x1_0'](
        output_stride=output_stride)
    return DeepLabV3P(
        backbone=backbone,
        backbone_indices=(0, 3),
        backbone_channels=(24, 160),
        **kwargs)
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_small(*args, **kwargs):
    """Build a DeepLabV3P model with a MobileNetV3_small_x1_0 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `DeepLabV3P` (use `model_pretrained` for a checkpoint path).
    """
    # DeepLabV3P stores and calls `backbone` as a layer, so instantiate it
    # from the registry with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['MobileNetV3_small_x1_0'](
        output_stride=output_stride)
    return DeepLabV3P(
        backbone=backbone,
        backbone_indices=(0, 3),
        backbone_channels=(16, 96),
        **kwargs)
......@@ -25,6 +25,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.cvlibs import manager
from dygraph import utils
from dygraph.cvlibs import param_init
__all__ = [
"fcn_hrnet_w18_small_v1", "fcn_hrnet_w18_small_v2", "fcn_hrnet_w18",
......@@ -33,114 +34,133 @@ __all__ = [
]
@manager.MODELS.add_component
class FCN(fluid.dygraph.Layer):
    """
    Fully Convolutional Networks for Semantic Segmentation.
    https://arxiv.org/abs/1411.4038

    Args:
        num_classes (int): the unique number of target classes.
        backbone (paddle.nn.Layer): backbone network; calling it must return a
            list of feature maps.
        model_pretrained (str): the path of pretrained model. Default None.
        backbone_indices (tuple): one value in the tuple indicates the index of
            the backbone output used by the head. Default (-1, ).
        backbone_channels (tuple): the same length as "backbone_indices"; entry i
            is the channel count of the feature selected by backbone_indices[i].
            Default (270, ).
        channels (int): channels after the conv layer before the last one.
            Defaults to the channel count of the selected backbone feature.
        ignore_index (int): the value of ground-truth mask would be ignored while
            computing loss or doing evaluation. Default 255.
        **kwargs: accepted for compatibility and ignored.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 model_pretrained=None,
                 backbone_indices=(-1, ),
                 backbone_channels=(270, ),
                 channels=None,
                 ignore_index=255,
                 **kwargs):
        super(FCN, self).__init__()

        self.num_classes = num_classes
        self.backbone_indices = backbone_indices
        self.ignore_index = ignore_index
        self.EPS = 1e-5

        # `backbone_channels` is parallel to `backbone_indices`, so the channel
        # count of the (single) selected feature is looked up by position, not
        # by the backbone feature index itself.
        in_channels = backbone_channels[0]
        if channels is None:
            channels = in_channels

        self.backbone = backbone
        self.conv_last_2 = ConvBNLayer(
            num_channels=in_channels,
            num_filters=channels,
            filter_size=1,
            stride=1)
        self.conv_last_1 = Conv2D(
            num_channels=channels,
            num_filters=self.num_classes,
            filter_size=1,
            stride=1,
            padding=0)
        self.init_weight(model_pretrained)

    def forward(self, x):
        """
        Run the backbone and the FCN head.

        Returns:
            list: a one-element list holding the logit resized to the spatial
                size of the input (`x.shape[2:]`).
        """
        input_shape = x.shape[2:]
        fea_list = self.backbone(x)
        x = fea_list[self.backbone_indices[0]]
        x = self.conv_last_2(x)
        logit = self.conv_last_1(x)
        logit = fluid.layers.resize_bilinear(logit, input_shape)
        return [logit]

    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

        Raises:
            Exception: if `pretrained_model` is given but does not exist.
        """
        # NOTE(review): `self.parameters()` presumably also yields the backbone's
        # parameters, so any whose names match below are re-initialized as
        # well - confirm this does not discard weights the backbone loaded.
        for param in self.parameters():
            param_name = param.name
            if 'batch_norm' in param_name:
                if 'w_0' in param_name:
                    param_init.constant_init(param, 1.0)
                elif 'b_0' in param_name:
                    param_init.constant_init(param, 0.0)
            if 'conv' in param_name and 'w_0' in param_name:
                param_init.normal_init(param, scale=0.001)

        if pretrained_model is not None:
            if os.path.exists(pretrained_model):
                # The checkpoint covers the whole model, not only the backbone.
                utils.load_pretrained_model(self, pretrained_model)
            else:
                raise Exception('Pretrained model is not found: {}'.format(
                    pretrained_model))
class ConvBNLayer(fluid.dygraph.Layer):
......@@ -150,8 +170,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
filter_size,
stride=1,
groups=1,
act="relu",
name=None):
act="relu"):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
......@@ -161,18 +180,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
num_filters,
weight_attr=ParamAttr(
name=bn_name + '_scale',
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
bn_name + '_offset',
initializer=fluid.initializer.Constant(0.0)))
self._batch_norm = BatchNorm(num_filters)
self.act = act
def forward(self, input):
......@@ -185,49 +194,49 @@ class ConvBNLayer(fluid.dygraph.Layer):
@manager.MODELS.add_component
def fcn_hrnet_w18_small_v1(*args, **kwargs):
    """Build an FCN model with an HRNet_W18_Small_V1 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W18_Small_V1']()
    # Trailing comma matters: `(240)` is the int 240, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(240, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w18_small_v2(*args, **kwargs):
    """Build an FCN model with an HRNet_W18_Small_V2 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W18_Small_V2']()
    # Trailing comma matters: `(270)` is the int 270, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(270, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w18(*args, **kwargs):
    """Build an FCN model with an HRNet_W18 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W18']()
    # Trailing comma matters: `(270)` is the int 270, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(270, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w30(*args, **kwargs):
    """Build an FCN model with an HRNet_W30 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W30']()
    # Trailing comma matters: `(450)` is the int 450, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(450, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w32(*args, **kwargs):
    """Build an FCN model with an HRNet_W32 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W32']()
    # Trailing comma matters: `(480)` is the int 480, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(480, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w40(*args, **kwargs):
    """Build an FCN model with an HRNet_W40 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W40']()
    # Trailing comma matters: `(600)` is the int 600, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(600, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w44(*args, **kwargs):
    """Build an FCN model with an HRNet_W44 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W44']()
    # Trailing comma matters: `(660)` is the int 660, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(660, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w48(*args, **kwargs):
    """Build an FCN model with an HRNet_W48 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W48']()
    # Trailing comma matters: `(720)` is the int 720, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(720, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w60(*args, **kwargs):
    """Build an FCN model with an HRNet_W60 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W60']()
    # Trailing comma matters: `(900)` is the int 900, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(900, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w64(*args, **kwargs):
    """Build an FCN model with an HRNet_W64 backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `FCN`.
    """
    # FCN takes an instantiated backbone layer, so create it from the registry.
    backbone = manager.BACKBONES['HRNet_W64']()
    # Trailing comma matters: `(960)` is the int 960, not a one-element tuple.
    return FCN(backbone=backbone, backbone_channels=(960, ), **kwargs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entroy_loss import CrossEntropyLoss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from dygraph.cvlibs import manager
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.CrossEntropyLoss):
    """
    Cross entropy loss that is rescaled by the fraction of non-ignored labels.

    Args:
        weight (Tensor): Weight tensor, a manual rescaling weight given
            to each class and the shape is (C). It has the same dimensions as class
            number and the data type is float32, float64. Default ``'None'``.
        ignore_index (int64): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
        reduction (str): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default ``'mean'``.
    """

    def __init__(self, weight=None, ignore_index=255, reduction='mean'):
        super(CrossEntropyLoss, self).__init__(
            weight=weight, ignore_index=ignore_index, reduction=reduction)
        # Keeps the denominator in forward() away from zero.
        self.EPS = 1e-5
        supported = ('sum', 'mean', 'none')
        if self.reduction not in supported:
            message = (
                "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
                " 'none', but received %s, which is not allowed." %
                self.reduction)
            raise ValueError(message)

    def forward(self, logit, label):
        """
        Forward computation.

        Args:
            logit (Tensor): logit tensor, the data type is float32, float64. Shape is
                (N, C), where C is number of classes, and if shape is more than 2D, this
                is (N, C, D1, D2,..., Dk), k >= 1.
            label (Variable): label tensor, the data type is int64. Shape is (N), where each
                value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
                (N, D1, D2,..., Dk), k >= 1.

        Returns:
            The cross entropy loss divided by (mean of the valid-label mask + EPS).
        """
        raw_loss = F.cross_entropy(
            logit,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction)

        # 1.0 where the label takes part in the loss, 0.0 where it is ignored.
        valid = paddle.cast(label != self.ignore_index, 'float32')
        scaled_loss = raw_loss / (paddle.mean(valid) + self.EPS)

        label.stop_gradient = True
        valid.stop_gradient = True
        return scaled_loss
......@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D
from dygraph.cvlibs import manager
from dygraph.models.architectures.layer_utils import ConvBnRelu
from dygraph import utils
class SpatialGatherBlock(fluid.dygraph.Layer):
......@@ -116,8 +119,9 @@ class ObjectAttentionBlock(fluid.dygraph.Layer):
class OCRNet(fluid.dygraph.Layer):
def __init__(self,
num_classes,
in_channels,
backbone,
model_pretrained=None,
in_channels=None,
ocr_mid_channels=512,
ocr_key_channels=256,
ignore_index=255):
......@@ -139,6 +143,8 @@ class OCRNet(fluid.dygraph.Layer):
ConvBnRelu(in_channels, in_channels, 3, padding=1),
Conv2D(in_channels, self.num_classes, 1))
self.init_weight(model_pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
......@@ -164,6 +170,19 @@ class OCRNet(fluid.dygraph.Layer):
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into the whole model when a path is given.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if `pretrained_model` is given but the path does not exist.
    """
    # Nothing to load.
    if pretrained_model is None:
        return

    if not os.path.exists(pretrained_model):
        raise Exception('Pretrained model is not found: {}'.format(
            pretrained_model))

    utils.load_pretrained_model(self, pretrained_model)
def _get_loss(self, logit, label):
"""
compute forward loss of the model
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.nn.functional as F
......@@ -29,15 +28,17 @@ class PSPNet(fluid.dygraph.Layer):
"""
The PSPNet implementation
The orginal artile refers to
Zhao, Hengshuang, et al. "Pyramid scene parsing network."
The orginal artile refers to
Zhao, Hengshuang, et al. "Pyramid scene parsing network."
Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
(https://openaccess.thecvf.com/content_cvpr_2017/papers/Zhao_Pyramid_Scene_Parsing_CVPR_2017_paper.pdf)
Args:
backbone (str): backbone name, currently support Resnet50/101.
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone name, currently support Resnet50/101.
num_classes (int): the unique number of target classes. Default 2.
model_pretrained (str): the path of pretrained model.
output_stride (int): the ratio of input size and final feature size. Default 16.
......@@ -57,42 +58,44 @@ class PSPNet(fluid.dygraph.Layer):
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True.
ignore_index (int): the value of ground-truth mask would be ignored while doing evaluation. Default to 255.
pretrained_model (str): the pretrained_model path of backbone.
"""
def __init__(self,
num_classes,
backbone,
num_classes=2,
model_pretrained=None,
output_stride=16,
backbone_indices=(2, 3),
backbone_channels=(1024, 2048),
pp_out_channels=1024,
bin_sizes=(1, 2, 3, 6),
enable_auxiliary_loss=True,
ignore_index=255,
pretrained_model=None):
ignore_index=255):
super(PSPNet, self).__init__()
self.backbone = manager.BACKBONES[backbone](output_stride=output_stride,
multi_grid=(1, 1, 1))
# self.backbone = manager.BACKBONES[backbone](output_stride=output_stride,
# multi_grid=(1, 1, 1))
self.backbone = backbone
self.backbone_indices = backbone_indices
self.psp_module = PPModule(in_channels=backbone_channels[1],
out_channels=pp_out_channels,
bin_sizes=bin_sizes)
self.psp_module = PPModule(
in_channels=backbone_channels[1],
out_channels=pp_out_channels,
bin_sizes=bin_sizes)
self.conv = Conv2D(num_channels=pp_out_channels,
num_filters=num_classes,
filter_size=1)
self.conv = Conv2D(
num_channels=pp_out_channels,
num_filters=num_classes,
filter_size=1)
if enable_auxiliary_loss:
self.fcn_head = model_utils.FCNHead(in_channels=backbone_channels[0], out_channels=num_classes)
self.fcn_head = model_utils.FCNHead(
in_channels=backbone_channels[0], out_channels=num_classes)
self.enable_auxiliary_loss = enable_auxiliary_loss
self.ignore_index = ignore_index
self.init_weight(pretrained_model)
self.init_weight(model_pretrained)
def forward(self, input, label=None):
......@@ -107,7 +110,8 @@ class PSPNet(fluid.dygraph.Layer):
if self.enable_auxiliary_loss:
auxiliary_feat = feat_list[self.backbone_indices[0]]
auxiliary_logit = self.fcn_head(auxiliary_feat)
auxiliary_logit = fluid.layers.resize_bilinear(auxiliary_logit, input.shape[2:])
auxiliary_logit = fluid.layers.resize_bilinear(
auxiliary_logit, input.shape[2:])
if self.training:
loss = model_utils.get_loss(logit, label)
......@@ -116,7 +120,6 @@ class PSPNet(fluid.dygraph.Layer):
loss += (0.4 * auxiliary_loss)
return loss
else:
pred, score_map = model_utils.get_pred_score_map(logit)
return pred, score_map
......@@ -124,14 +127,15 @@ class PSPNet(fluid.dygraph.Layer):
def init_weight(self, pretrained_model=None):
    """
    Initialize the parameters of model parts.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if `pretrained_model` is given but the path does not exist.
    """
    if pretrained_model is None:
        return

    if not os.path.exists(pretrained_model):
        raise Exception('Pretrained model is not found: {}'.format(
            pretrained_model))

    # The path now refers to a checkpoint of the whole model, so it is loaded
    # once into `self` instead of first being loaded into the backbone alone.
    utils.load_pretrained_model(self, pretrained_model)
class PPModule(fluid.dygraph.Layer):
......@@ -151,19 +155,21 @@ class PPModule(fluid.dygraph.Layer):
self.bin_sizes = bin_sizes
# we use dimension reduction after pooling mentioned in original implementation.
self.stages = fluid.dygraph.LayerList([self._make_stage(in_channels, size) for size in bin_sizes])
self.stages = fluid.dygraph.LayerList(
[self._make_stage(in_channels, size) for size in bin_sizes])
self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=in_channels * 2,
num_filters=out_channels,
filter_size=3,
padding=1)
self.conv_bn_relu2 = layer_utils.ConvBnRelu(
num_channels=in_channels * 2,
num_filters=out_channels,
filter_size=3,
padding=1)
def _make_stage(self, in_channels, size):
"""
Create one pooling layer.
In our implementation, we adopt the same dimention reduction as the original paper that might be
slightly different with other implementations.
slightly different with other implementations.
After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations
keep the channels to be same.
......@@ -180,9 +186,10 @@ class PPModule(fluid.dygraph.Layer):
# this paddle version does not support AdaptiveAvgPool2d, so skip it here.
# prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=in_channels // len(self.bin_sizes),
filter_size=1)
conv = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=in_channels // len(self.bin_sizes),
filter_size=1)
return conv
......@@ -190,7 +197,8 @@ class PPModule(fluid.dygraph.Layer):
cat_layers = []
for i, stage in enumerate(self.stages):
size = self.bin_sizes[i]
x = fluid.layers.adaptive_pool2d(input, pool_size=(size, size), pool_type="max")
x = fluid.layers.adaptive_pool2d(
input, pool_size=(size, size), pool_type="max")
x = stage(x)
x = fluid.layers.resize_bilinear(x, out_shape=input.shape[2:])
cat_layers.append(x)
......@@ -204,22 +212,32 @@ class PPModule(fluid.dygraph.Layer):
@manager.MODELS.add_component
def pspnet_resnet101_vd(*args, **kwargs):
    """Build a PSPNet model with a ResNet101_vd backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `PSPNet` (use `model_pretrained` for a checkpoint path).
    """
    # PSPNet stores `backbone` as a layer, so instantiate it from the registry
    # with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['ResNet101_vd'](
        output_stride=output_stride, multi_grid=(1, 1, 1))
    return PSPNet(backbone=backbone, **kwargs)
@manager.MODELS.add_component
def pspnet_resnet101_vd_os8(*args, **kwargs):
    """Build a PSPNet model with a ResNet101_vd backbone at output stride 8.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `PSPNet` (use `model_pretrained` for a checkpoint path).
    """
    # PSPNet stores `backbone` as a layer, so instantiate it from the registry
    # with the same stride that is passed to the head.
    backbone = manager.BACKBONES['ResNet101_vd'](
        output_stride=8, multi_grid=(1, 1, 1))
    return PSPNet(backbone=backbone, output_stride=8, **kwargs)
@manager.MODELS.add_component
def pspnet_resnet50_vd(*args, **kwargs):
    """Build a PSPNet model with a ResNet50_vd backbone.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `PSPNet` (use `model_pretrained` for a checkpoint path).
    """
    # PSPNet stores `backbone` as a layer, so instantiate it from the registry
    # with the same stride the head will be built for.
    output_stride = kwargs.get('output_stride', 16)
    backbone = manager.BACKBONES['ResNet50_vd'](
        output_stride=output_stride, multi_grid=(1, 1, 1))
    return PSPNet(backbone=backbone, **kwargs)
@manager.MODELS.add_component
def pspnet_resnet50_vd_os8(*args, **kwargs):
    """Build a PSPNet model with a ResNet50_vd backbone at output stride 8.

    Positional arguments are accepted but ignored; keyword arguments are
    forwarded to `PSPNet` (use `model_pretrained` for a checkpoint path).
    """
    # PSPNet stores `backbone` as a layer, so instantiate it from the registry
    # with the same stride that is passed to the head.
    backbone = manager.BACKBONES['ResNet50_vd'](
        output_stride=8, multi_grid=(1, 1, 1))
    return PSPNet(backbone=backbone, output_stride=8, **kwargs)
......@@ -33,7 +33,7 @@ class UNet(fluid.dygraph.Layer):
ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
"""
def __init__(self, num_classes, pretrained_model=None, ignore_index=255):
def __init__(self, num_classes, model_pretrained=None, ignore_index=255):
super(UNet, self).__init__()
self.encode = UnetEncoder()
self.decode = UnetDecode()
......@@ -41,7 +41,7 @@ class UNet(fluid.dygraph.Layer):
self.ignore_index = ignore_index
self.EPS = 1e-5
self.init_weight(pretrained_model)
self.init_weight(model_pretrained)
def forward(self, x, label=None):
encode_data, short_cuts = self.encode(x)
......@@ -60,7 +60,7 @@ class UNet(fluid.dygraph.Layer):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
......
......@@ -110,6 +110,8 @@ def main(args):
val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss
train(
cfg.model,
train_dataset,
......@@ -123,7 +125,8 @@ def main(args):
log_iters=args.log_iters,
num_classes=train_dataset.num_classes,
num_workers=args.num_workers,
use_vdl=args.use_vdl)
use_vdl=args.use_vdl,
losses=losses)
if __name__ == '__main__':
......
......@@ -55,8 +55,8 @@ class Compose:
if len(outputs) == 3:
label = outputs[2]
im = permute(im)
if len(outputs) == 3:
label = label[np.newaxis, :, :]
# if len(outputs) == 3:
# label = label[np.newaxis, :, :]
return (im, im_info, label)
......
......@@ -35,15 +35,34 @@ class Config(object):
raise FileNotFoundError('File {} does not exist'.format(path))
if path.endswith('yml') or path.endswith('yaml'):
self._parse_from_yaml(path)
dic = self._parse_from_yaml(path)
self._build(dic)
else:
raise RuntimeError('Config file should in yaml format!')
def _update_dic(self, dic, base_dic):
"""
update dic from base_dic
"""
dic = dic.copy()
for key, val in base_dic.items():
if isinstance(val, dict) and key in dic:
dic[key] = self._update_dic(dic[key], val)
else:
dic[key] = val
return dic
def _parse_from_yaml(self, path: str):
    '''
    Parse a yaml file and return its content as a dictionary.

    If the file has a `_base_` entry, that path is resolved relative to the
    directory of `path`, parsed the same way, and merged with this file
    through `self._update_dic`. The config is not built here; the caller is
    expected to build it once from the returned dictionary.

    Args:
        path (str): path of the yaml file.

    Returns:
        dict: the parsed (and, when `_base_` is present, merged) config.
    '''
    with codecs.open(path, 'r', 'utf-8') as file:
        # An empty yaml document loads as None; treat it as an empty config.
        dic = yaml.load(file, Loader=yaml.FullLoader) or {}

    if '_base_' in dic:
        cfg_dir = os.path.dirname(path)
        base_path = os.path.join(cfg_dir, dic.pop('_base_'))
        base_dic = self._parse_from_yaml(base_path)
        dic = self._update_dic(dic, base_dic)
    return dic
def _build(self, dic: dict):
'''Build config from dictionary'''
......@@ -68,6 +87,7 @@ class Config(object):
})
self._loss_cfg = dic.get('loss', {})
self._losses = None
self._optimizer_cfg = dic.get('optimizer', {})
......@@ -145,14 +165,23 @@ class Config(object):
return args
@property
def loss_type(self) -> str:
...
@property
def loss_args(self) -> dict:
args = self._loss_cfg.copy()
args.pop('type')
return args
def loss(self) -> dict:
    """
    Return the loss configuration with every entry of `types` instantiated.

    The result is built from `self._loss_cfg` on first access and cached in
    `self._losses`. Each item under the `types` key is passed through
    `self._load_object`; every other key is copied unchanged.

    Returns:
        dict: the loss settings, with `types` holding the loaded objects.

    Raises:
        RuntimeError: if `coef` and `types` do not have the same length.
    """
    if not self._losses:
        losses = dict()
        for key, val in self._loss_cfg.items():
            if key == 'types':
                losses['types'] = [self._load_object(item) for item in val]
            else:
                losses[key] = val

        if len(losses['coef']) != len(losses['types']):
            raise RuntimeError(
                'The length of coef should equal to types in loss config: {} != {}.'
                .format(len(losses['coef']), len(losses['types'])))

        # Cache only after validation, so an invalid config keeps raising
        # instead of being returned silently on the next access.
        self._losses = losses
    return self._losses
@property
def model(self) -> Callable:
......@@ -175,7 +204,7 @@ class Config(object):
def _load_component(self, com_name: str) -> Any:
com_list = [
manager.MODELS, manager.BACKBONES, manager.DATASETS,
manager.TRANSFORMS
manager.TRANSFORMS, manager.LOSSES
]
for com in com_list:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册