未验证 提交 fa5b0a0e 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #376 from wuyefeilin/dygraph

......@@ -12,4 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dygraph.models
\ No newline at end of file
from . import models
from . import datasets
from . import transforms
batch_size: 4
iters: 100000
learning_rate: 0.01
train_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: Normalize
mode: train
val_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
mode: val
optimizer:
type: sgd
learning_rate:
value: 0.01
decay:
type: poly
power: 0.9
loss:
types:
- type: CrossEntropyLoss
coef: [1]
batch_size: 4
iters: 10000
learning_rate: 0.01
train_dataset:
type: OpticDiscSeg
dataset_root: data/optic_disc_seg
transforms:
- type: Resize
target_size: [512, 512]
- type: RandomHorizontalFlip
- type: Normalize
mode: train
val_dataset:
type: OpticDiscSeg
dataset_root: data/optic_disc_seg
transforms:
- type: Resize
target_size: [512, 512]
- type: Normalize
mode: val
optimizer:
type: sgd
learning_rate:
value: 0.01
decay:
type: poly
power: 0.9
loss:
types:
- type: CrossEntropyLoss
coef: [1]
_base_: '../_base_/cityscapes.yml'
model:
type: FCN
backbone:
type: HRNet_W18
backbone_pretrained: pretrained_model/hrnet_w18_imagenet
num_classes: 19
backbone_channels: [270]
_base_: '../_base_/optic_disc_seg.yml'
model:
type: FCN
backbone:
type: HRNet_W18
backbone_pretrained: pretrained_model/hrnet_w18_imagenet
num_classes: 2
backbone_channels: [270]
batch_size: 2
iters: 40000
train_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: RandomHorizontalFlip
- type: Normalize
mode: train
val_dataset:
type: Cityscapes
dataset_root: data/cityscapes
transforms:
- type: Normalize
mode: val
model:
type: OCRNet
backbone:
type: HRNet_W18
backbone_pretrained: None
num_classes: 19
in_channels: 270
model_pretrained: None
optimizer:
type: sgd
learning_rate:
value: 0.01
decay:
type: poly
power: 0.9
loss:
type: CrossEntropy
......@@ -14,11 +14,13 @@
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
# from paddle.incubate.hapi.distributed import DistributedBatchSampler
from paddle.io import DistributedBatchSampler
import paddle.nn.functional as F
import dygraph.utils.logger as logger
from dygraph.utils import load_pretrained_model
......@@ -27,6 +29,27 @@ from dygraph.utils import Timer, calculate_eta
from .val import evaluate
def check_logits_losses(logits, losses):
    """Verify that the model produced one logit per configured loss.

    Args:
        logits: sized collection of model outputs.
        losses (dict): loss config; only the length of ``losses['types']``
            is inspected here.

    Raises:
        RuntimeError: if ``len(logits)`` differs from
            ``len(losses['types'])``.
    """
    n_logits, n_losses = len(logits), len(losses['types'])
    if n_logits == n_losses:
        return
    message = 'The length of logits should equal to the types of loss config: {} != {}.'
    raise RuntimeError(message.format(n_logits, n_losses))
def loss_computation(logits, label, losses):
    """Accumulate the weighted sum of the configured losses.

    The i-th logit is scored by ``losses['types'][i]`` and scaled by
    ``losses['coef'][i]``. A logit whose last two dimensions differ from
    those of ``label`` is first resized with ``F.resize_bilinear`` to the
    last two dimensions of ``label``.

    Args:
        logits: sequence of model outputs, one per configured loss.
        label: target passed unchanged to every loss callable.
        losses (dict): config with a ``'types'`` list of loss callables and
            a ``'coef'`` list of per-loss weights.

    Returns:
        The accumulated loss; the integer ``0`` when ``logits`` is empty.

    Raises:
        RuntimeError: if the number of logits does not match the number of
            loss types (raised by ``check_logits_losses``).
    """
    check_logits_losses(logits, losses)
    total = 0
    for idx, logit in enumerate(logits):
        if logit.shape[-2:] != label.shape[-2:]:
            # Bring the prediction to the label's spatial size before scoring.
            logit = F.resize_bilinear(logit, label.shape[-2:])
        criterion = losses['types'][idx]
        term = criterion(logit, label)
        total += losses['coef'][idx] * term
    return total
def train(model,
train_dataset,
places=None,
......@@ -40,7 +63,8 @@ def train(model,
log_iters=10,
num_classes=None,
num_workers=8,
use_vdl=False):
use_vdl=False,
losses=None):
ignore_index = model.ignore_index
nranks = ParallelEnv().nranks
......@@ -90,13 +114,17 @@ def train(model,
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
loss = ddp_model(images, labels)
logits = ddp_model(images)
loss = loss_computation(logits, labels, losses)
# loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss)
loss.backward()
ddp_model.apply_collective_grads()
else:
loss = model(images, labels)
logits = model(images)
loss = loss_computation(logits, labels, losses)
# loss = model(images, labels)
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
......
......@@ -19,6 +19,8 @@ import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
import paddle.nn.functional as F
import paddle
import dygraph.utils.logger as logger
from dygraph.utils import ConfusionMatrix
......@@ -47,7 +49,9 @@ def evaluate(model,
for iter, (im, im_info, label) in tqdm.tqdm(
enumerate(eval_dataset), total=total_iters):
im = to_variable(im)
pred, _ = model(im)
# pred, _ = model(im)
logits = model(im)
pred = paddle.argmax(logits[0], axis=1)
pred = pred.numpy().astype('float32')
pred = np.squeeze(pred)
for info in im_info[::-1]:
......
......@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import manager
from . import param_init
......@@ -49,14 +49,15 @@ class ComponentManager:
return len(self._components_dict)
def __repr__(self):
return "{}:{}".format(self.__class__.__name__, list(self._components_dict.keys()))
return "{}:{}".format(self.__class__.__name__,
list(self._components_dict.keys()))
def __getitem__(self, item):
if item not in self._components_dict.keys():
raise KeyError("{} does not exist in the current {}".format(item, self))
raise KeyError("{} does not exist in the current {}".format(
item, self))
return self._components_dict[item]
@property
def components_dict(self):
return self._components_dict
......@@ -74,7 +75,9 @@ class ComponentManager:
# Currently only support class or function type
if not (inspect.isclass(component) or inspect.isfunction(component)):
raise TypeError("Expect class/function type, but received {}".format(type(component)))
raise TypeError(
"Expect class/function type, but received {}".format(
type(component)))
# Obtain the internal name of the component
component_name = component.__name__
......@@ -107,5 +110,9 @@ class ComponentManager:
return components
MODELS = ComponentManager()
BACKBONES = ComponentManager()
DATASETS = ComponentManager()
TRANSFORMS = ComponentManager()
LOSSES = ComponentManager()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
def constant_init(param, **kwargs):
    """Apply ``fluid.initializer.Constant`` to ``param``.

    Args:
        param: parameter to initialize; its ``block`` attribute is handed
            to the initializer together with the parameter itself.
        **kwargs: forwarded verbatim to ``fluid.initializer.Constant``.
            Only keyword arguments are accepted after ``param``.
    """
    fluid.initializer.Constant(**kwargs)(param, param.block)
def normal_init(param, **kwargs):
    """Apply ``fluid.initializer.Normal`` to ``param``.

    Args:
        param: parameter to initialize; its ``block`` attribute is handed
            to the initializer together with the parameter itself.
        **kwargs: forwarded verbatim to ``fluid.initializer.Normal``.
            Only keyword arguments are accepted after ``param``.
    """
    fluid.initializer.Normal(**kwargs)(param, param.block)
......@@ -19,11 +19,14 @@ from PIL import Image
from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
@manager.DATASETS.add_component
class ADE20K(Dataset):
"""ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
Args:
......@@ -39,7 +42,7 @@ class ADE20K(Dataset):
transforms=None,
download=True):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.mode = mode
self.file_list = list()
self.num_classes = 150
......
......@@ -16,8 +16,11 @@ import os
import glob
from .dataset import Dataset
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
@manager.DATASETS.add_component
class Cityscapes(Dataset):
"""Cityscapes dataset `https://www.cityscapes-dataset.com/`.
The folder structure is as follow:
......@@ -42,7 +45,7 @@ class Cityscapes(Dataset):
def __init__(self, dataset_root, transforms=None, mode='train'):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = 19
......
......@@ -17,8 +17,12 @@ import os
import paddle.fluid as fluid
import numpy as np
from PIL import Image
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
@manager.DATASETS.add_component
class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format.
......@@ -52,7 +56,7 @@ class Dataset(fluid.io.Dataset):
separator=' ',
transforms=None):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = num_classes
......
......@@ -16,11 +16,14 @@ import os
from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
@manager.DATASETS.add_component
class OpticDiscSeg(Dataset):
def __init__(self,
dataset_root=None,
......@@ -28,7 +31,7 @@ class OpticDiscSeg(Dataset):
mode='train',
download=True):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = 2
......
......@@ -13,13 +13,17 @@
# limitations under the License.
import os
from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
@manager.DATASETS.add_component
class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools.
......@@ -36,7 +40,7 @@ class PascalVOC(Dataset):
transforms=None,
download=True):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.mode = mode
self.file_list = list()
self.num_classes = 21
......
......@@ -13,7 +13,9 @@
# limitations under the License.
from .architectures import *
from .losses import *
from .unet import UNet
from .deeplab import *
from .fcn import *
from .pspnet import *
from .ocrnet import *
......@@ -13,6 +13,7 @@
# limitations under the License.
import math
import os
import paddle
import paddle.fluid as fluid
......@@ -23,6 +24,8 @@ from paddle.fluid.initializer import Normal
from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.cvlibs import manager
from dygraph.utils import utils
from dygraph.cvlibs import param_init
__all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......@@ -36,6 +39,7 @@ class HRNet(fluid.dygraph.Layer):
https://arxiv.org/pdf/1908.07919.pdf.
Args:
backbone_pretrained (str): the path of pretrained model.
stage1_num_modules (int): number of modules for stage1. Default 1.
stage1_num_blocks (list): number of blocks per module for stage1. Default [4].
stage1_num_channels (list): number of channels per branch for stage1. Default [64].
......@@ -52,6 +56,7 @@ class HRNet(fluid.dygraph.Layer):
"""
def __init__(self,
backbone_pretrained=None,
stage1_num_modules=1,
stage1_num_blocks=[4],
stage1_num_channels=[64],
......@@ -141,6 +146,9 @@ class HRNet(fluid.dygraph.Layer):
has_se=self.has_se,
name="st4")
if self.training:
self.init_weight(backbone_pretrained)
def forward(self, x, label=None, mode='train'):
input_shape = x.shape[2:]
conv1 = self.conv_layer1_1(x)
......@@ -163,7 +171,31 @@ class HRNet(fluid.dygraph.Layer):
x3 = fluid.layers.resize_bilinear(st4[3], out_shape=(x0_h, x0_w))
x = fluid.layers.concat([st4[0], x1, x2, x3], axis=1)
return x
return [x]
def init_weight(self, pretrained_model=None):
    """
    Initialize the parameters of model parts.

    Parameters are selected by substring matches on ``param.name``:
    names containing 'batch_norm' and 'w_0' are set to the constant 1.0,
    names containing 'batch_norm' and 'b_0' are set to the constant 0.0,
    and names containing both 'conv' and 'w_0' are drawn from a normal
    initializer with ``scale=0.001``. Afterwards, when a pretrained model
    path is given, its weights are loaded into this layer.

    Args:
        pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    for param in self.parameters():
        param_name = param.name
        if 'batch_norm' in param_name:
            # constant_init(param, **kwargs) takes keyword arguments only,
            # so the fill value must be passed as `value=`; a positional
            # value raises TypeError.
            if 'w_0' in param_name:
                param_init.constant_init(param, value=1.0)
            elif 'b_0' in param_name:
                param_init.constant_init(param, value=0.0)
        if 'conv' in param_name and 'w_0' in param_name:
            param_init.normal_init(param, scale=0.001)

    if pretrained_model is not None:
        if os.path.exists(pretrained_model):
            utils.load_pretrained_model(self, pretrained_model)
        else:
            raise Exception('Pretrained model is not found: {}'.format(
                pretrained_model))
class ConvBNLayer(fluid.dygraph.Layer):
......@@ -184,18 +216,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
num_filters,
weight_attr=ParamAttr(
name=bn_name + '_scale',
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
bn_name + '_offset',
initializer=fluid.initializer.Constant(0.0)))
self._batch_norm = BatchNorm(num_filters)
self.act = act
def forward(self, input):
......
......@@ -17,8 +17,9 @@ from __future__ import division
from __future__ import print_function
import math
import numpy as np
import os
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -28,6 +29,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.models.architectures import layer_utils
from dygraph.cvlibs import manager
from dygraph.utils import utils
__all__ = [
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
......@@ -46,6 +48,7 @@ def make_divisible(v, divisor=8, min_value=None):
new_v += divisor
return new_v
def get_padding_same(kernel_size, dilation_rate):
"""
SAME padding implementation given kernel_size and dilation_rate.
......@@ -63,12 +66,19 @@ def get_padding_same(kernel_size, dilation_rate):
"""
k = kernel_size
r = dilation_rate
padding_same = (k + (k - 1) * (r - 1) - 1)//2
padding_same = (k + (k - 1) * (r - 1) - 1) // 2
return padding_same
class MobileNetV3(fluid.dygraph.Layer):
def __init__(self, scale=1.0, model_name="small", class_dim=1000, output_stride=None, **kwargs):
def __init__(self,
backbone_pretrained=None,
scale=1.0,
model_name="small",
class_dim=1000,
output_stride=None,
**kwargs):
super(MobileNetV3, self).__init__()
inplanes = 16
......@@ -86,10 +96,12 @@ class MobileNetV3(fluid.dygraph.Layer):
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1], # output 3 -> out_index=11
[3, 672, 112, True, "hard_swish",
1], # output 3 -> out_index=11
[5, 672, 160, True, "hard_swish", 2],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1], # output 3 -> out_index=14
[5, 960, 160, True, "hard_swish",
1], # output 3 -> out_index=14
]
self.out_indices = [2, 5, 11, 14]
......@@ -158,7 +170,6 @@ class MobileNetV3(fluid.dygraph.Layer):
sublayer=self.block_list[-1], name="conv" + str(i + 2))
inplanes = make_divisible(scale * c)
self.last_second_conv = ConvBNLayer(
in_c=inplanes,
out_c=make_divisible(scale * self.cls_ch_squeeze),
......@@ -189,6 +200,8 @@ class MobileNetV3(fluid.dygraph.Layer):
param_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
self.init_weight(backbone_pretrained)
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
......@@ -223,6 +236,19 @@ class MobileNetV3(fluid.dygraph.Layer):
return x, feat_list
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into this layer when a path is supplied.

    Args:
        pretrained_model ([str], optional): the path of pretrained model.
            Nothing happens when it is None. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
......@@ -305,13 +331,14 @@ class ResidualUnit(fluid.dygraph.Layer):
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding= get_padding_same(filter_size, dilation), #int((filter_size - 1) // 2) + (dilation - 1),
padding=get_padding_same(
filter_size,
dilation), #int((filter_size - 1) // 2) + (dilation - 1),
dilation=dilation,
num_groups=mid_c,
if_act=True,
......@@ -329,6 +356,7 @@ class ResidualUnit(fluid.dygraph.Layer):
act=None,
name=name + "_linear")
self.dilation = dilation
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
......@@ -386,6 +414,7 @@ def MobileNetV3_small_x0_75(**kwargs):
model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_small_x1_0(**kwargs):
model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
......@@ -411,6 +440,7 @@ def MobileNetV3_large_x0_75(**kwargs):
model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_large_x1_0(**kwargs):
model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
......
......@@ -30,6 +30,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.utils import utils
from dygraph.models.architectures import layer_utils
from dygraph.cvlibs import manager
from dygraph.utils import utils
__all__ = [
"ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
......@@ -47,18 +48,23 @@ class ConvBNLayer(fluid.dygraph.Layer):
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
name=None,
):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D(
pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True)
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2 if dilation ==1 else 0,
padding=(filter_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation,
groups=groups,
act=None,
......@@ -125,7 +131,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
num_filters=num_filters * 4,
filter_size=1,
stride=1,
is_vd_mode=False if if_first or stride==1 else True,
is_vd_mode=False if if_first or stride == 1 else True,
name=name + "_branch1")
self.shortcut = shortcut
......@@ -137,7 +143,8 @@ class BottleneckBlock(fluid.dygraph.Layer):
# If given dilation rate > 1, using corresponding padding
if self.dilation > 1:
padding = self.dilation
y = fluid.layers.pad(y, [0,0,0,0,padding,padding,padding,padding])
y = fluid.layers.pad(
y, [0, 0, 0, 0, padding, padding, padding, padding])
#####################################################################
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
......@@ -202,7 +209,13 @@ class BasicBlock(fluid.dygraph.Layer):
class ResNet_vd(fluid.dygraph.Layer):
def __init__(self, layers=50, class_dim=1000, output_stride=None, multi_grid=(1, 2, 4), **kwargs):
def __init__(self,
backbone_pretrained=None,
layers=50,
class_dim=1000,
output_stride=None,
multi_grid=(1, 2, 4),
**kwargs):
super(ResNet_vd, self).__init__()
self.layers = layers
......@@ -221,11 +234,11 @@ class ResNet_vd(fluid.dygraph.Layer):
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_channels = [64, 256, 512, 1024
] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
dilation_dict=None
dilation_dict = None
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
......@@ -260,7 +273,7 @@ class ResNet_vd(fluid.dygraph.Layer):
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list=[]
block_list = []
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
......@@ -272,7 +285,8 @@ class ResNet_vd(fluid.dygraph.Layer):
###############################################################################
# Add dilation rate for some segmentation tasks, if dilation_dict is not None.
dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
dilation_rate = dilation_dict[
block] if dilation_dict and block in dilation_dict else 1
# Actually block here is 'stage', and i is 'block' in 'stage'
# At stage 4, expand the dilation_rate using multi_grid, default (1, 2, 4)
......@@ -284,9 +298,11 @@ class ResNet_vd(fluid.dygraph.Layer):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
stride=2 if i == 0 and block != 0
and dilation_rate == 1 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
......@@ -298,7 +314,7 @@ class ResNet_vd(fluid.dygraph.Layer):
else:
for block in range(len(depth)):
shortcut = False
block_list=[]
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
......@@ -330,6 +346,8 @@ class ResNet_vd(fluid.dygraph.Layer):
name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
self.init_weight(backbone_pretrained)
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
......@@ -355,8 +373,18 @@ class ResNet_vd(fluid.dygraph.Layer):
# if os.path.exists(pretrained_model):
# utils.load_pretrained_model(self, pretrained_model)
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into this layer when a path is supplied.

    Args:
        pretrained_model ([str], optional): the path of pretrained model.
            Nothing happens when it is None. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
def ResNet18_vd(**args):
......@@ -368,11 +396,13 @@ def ResNet34_vd(**args):
model = ResNet_vd(layers=34, **args)
return model
@manager.BACKBONES.add_component
def ResNet50_vd(**args):
    """Build a ``ResNet_vd`` with ``layers=50``.

    Args:
        **args: forwarded unchanged to the ``ResNet_vd`` constructor.

    Returns:
        The constructed ``ResNet_vd`` instance.
    """
    return ResNet_vd(layers=50, **args)
@manager.BACKBONES.add_component
def ResNet101_vd(**args):
model = ResNet_vd(layers=101, **args)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -7,6 +23,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.models.architectures import layer_utils
from dygraph.cvlibs import manager
from dygraph.utils import utils
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
......@@ -142,7 +159,6 @@ class Seperate_Conv(fluid.dygraph.Layer):
self._act_op2 = layer_utils.Activation(act=act)
def forward(self, inputs):
x = self._conv1(inputs)
x = self._bn1(x)
......@@ -257,7 +273,12 @@ class XceptionDeeplab(fluid.dygraph.Layer):
#def __init__(self, backbone, class_dim=1000):
# add output_stride
def __init__(self, backbone, output_stride=16, class_dim=1000, **kwargs):
def __init__(self,
backbone,
backbone_pretrained=None,
output_stride=16,
class_dim=1000,
**kwargs):
super(XceptionDeeplab, self).__init__()
......@@ -280,7 +301,6 @@ class XceptionDeeplab(fluid.dygraph.Layer):
padding=1,
act="relu",
name=self.backbone + "/entry_flow/conv2")
"""
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
......@@ -381,6 +401,8 @@ class XceptionDeeplab(fluid.dygraph.Layer):
param_attr=ParamAttr(name="fc_weights"),
bias_attr=ParamAttr(name="fc_bias"))
self.init_weight(backbone_pretrained)
def forward(self, inputs):
x = self._conv1(inputs)
x = self._conv2(x)
......@@ -401,11 +423,25 @@ class XceptionDeeplab(fluid.dygraph.Layer):
x = self._fc(x)
return x, feat_list
def init_weight(self, pretrained_model=None):
    """
    Load pretrained weights into this layer when a path is supplied.

    Args:
        pretrained_model ([str], optional): the path of pretrained model.
            Nothing happens when it is None. Defaults to None.

    Raises:
        Exception: if ``pretrained_model`` is given but the path does not exist.
    """
    if pretrained_model is None:
        return
    if not os.path.exists(pretrained_model):
        raise Exception(
            'Pretrained model is not found: {}'.format(pretrained_model))
    utils.load_pretrained_model(self, pretrained_model)
def Xception41_deeplab(**args):
    """Build an ``XceptionDeeplab`` with the 'xception_41' backbone name.

    Args:
        **args: forwarded unchanged to the ``XceptionDeeplab`` constructor.

    Returns:
        The constructed ``XceptionDeeplab`` instance.
    """
    return XceptionDeeplab('xception_41', **args)
@manager.BACKBONES.add_component
def Xception65_deeplab(**args):
model = XceptionDeeplab("xception_65", **args)
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dygraph.cvlibs import manager
......@@ -23,10 +22,12 @@ from paddle.fluid.dygraph import Conv2D
from dygraph.utils import utils
__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
__all__ = [
'DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
"deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
"deeplabv3p_xception65_deeplab",
"deeplabv3p_mobilenetv3_large", "deeplabv3p_mobilenetv3_small"]
"deeplabv3p_xception65_deeplab", "deeplabv3p_mobilenetv3_large",
"deeplabv3p_mobilenetv3_small"
]
class ImageAverage(dygraph.Layer):
......@@ -40,9 +41,8 @@ class ImageAverage(dygraph.Layer):
def __init__(self, num_channels):
super(ImageAverage, self).__init__()
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels,
num_filters=256,
filter_size=1)
self.conv_bn_relu = layer_utils.ConvBnRelu(
num_channels, num_filters=256, filter_size=1)
def forward(self, input):
x = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
......@@ -69,18 +69,22 @@ class ASPP(dygraph.Layer):
elif output_stride == 8:
aspp_ratios = (12, 24, 36)
else:
raise NotImplementedError("Only support output_stride is 8 or 16, but received{}".format(output_stride))
raise NotImplementedError(
"Only support output_stride is 8 or 16, but received{}".format(
output_stride))
self.image_average = ImageAverage(num_channels=in_channels)
# The first aspp using 1*1 conv
self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels,
self.aspp1 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=1,
using_sep_conv=False)
# The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0]
self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels,
self.aspp2 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
......@@ -88,7 +92,8 @@ class ASPP(dygraph.Layer):
padding=aspp_ratios[0])
# The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1]
self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels,
self.aspp3 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
......@@ -96,7 +101,8 @@ class ASPP(dygraph.Layer):
padding=aspp_ratios[1])
# The fourth aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2]
self.aspp4 = layer_utils.ConvBnRelu(num_channels=in_channels,
self.aspp4 = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
......@@ -104,9 +110,8 @@ class ASPP(dygraph.Layer):
padding=aspp_ratios[2])
# After concat op, using 1*1 conv
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280,
num_filters=256,
filter_size=1)
self.conv_bn_relu = layer_utils.ConvBnRelu(
num_channels=1280, num_filters=256, filter_size=1)
def forward(self, x):
......@@ -136,23 +141,23 @@ class Decoder(dygraph.Layer):
def __init__(self, num_classes, in_channels, using_sep_conv=True):
super(Decoder, self).__init__()
self.conv_bn_relu1 = layer_utils.ConvBnRelu(num_channels=in_channels,
num_filters=48,
filter_size=1)
self.conv_bn_relu1 = layer_utils.ConvBnRelu(
num_channels=in_channels, num_filters=48, filter_size=1)
self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=304,
self.conv_bn_relu2 = layer_utils.ConvBnRelu(
num_channels=304,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv_bn_relu3 = layer_utils.ConvBnRelu(num_channels=256,
self.conv_bn_relu3 = layer_utils.ConvBnRelu(
num_channels=256,
num_filters=256,
filter_size=3,
using_sep_conv=using_sep_conv,
padding=1)
self.conv = Conv2D(num_channels=256,
num_filters=num_classes,
filter_size=1)
self.conv = Conv2D(
num_channels=256, num_filters=num_classes, filter_size=1)
def forward(self, x, low_level_feat):
low_level_feat = self.conv_bn_relu1(low_level_feat)
......@@ -164,6 +169,7 @@ class Decoder(dygraph.Layer):
return x
@manager.MODELS.add_component
class DeepLabV3P(dygraph.Layer):
"""
The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder
......@@ -173,9 +179,11 @@ class DeepLabV3P(dygraph.Layer):
(https://arxiv.org/abs/1802.02611)
Args:
backbone (str): backbone name, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
num_classes (int): the unique number of target classes.
backbone (paddle.nn.Layer): backbone networks, currently support Xception65, Resnet101_vd. Default Resnet101_vd.
num_classes (int): the unique number of target classes. Default 2.
model_pretrained (str): the path of pretrained model.
output_stride (int): the ratio of input size and final feature size. Default 16.
......@@ -193,28 +201,29 @@ class DeepLabV3P(dygraph.Layer):
using_sep_conv (bool): a bool value indicates whether using separable convolutions
in ASPP and Decoder components. Default True.
pretrained_model (str): the pretrained_model path of backbone.
"""
def __init__(self,
num_classes,
backbone,
num_classes=2,
model_pretrained=None,
output_stride=16,
backbone_indices=(0, 3),
backbone_channels=(256, 2048),
ignore_index=255,
using_sep_conv=True,
pretrained_model=None):
using_sep_conv=True):
super(DeepLabV3P, self).__init__()
self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
# self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
self.backbone = backbone
self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv)
self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv)
self.decoder = Decoder(num_classes, backbone_channels[0],
using_sep_conv)
self.ignore_index = ignore_index
self.EPS = 1e-5
self.backbone_indices = backbone_indices
self.init_weight(pretrained_model)
self.init_weight(model_pretrained)
def forward(self, input, label=None):
......@@ -238,14 +247,14 @@ class DeepLabV3P(dygraph.Layer):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
# utils.load_pretrained_model(self, pretrained_model)
# for param in self.backbone.parameters():
# param.stop_gradient = True
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
def _get_loss(self, logit, label):
"""
......@@ -290,31 +299,42 @@ def build_decoder(num_classes, using_sep_conv):
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd(*args, **kwargs):
    """
    Build a DeepLabV3P model for the 'ResNet101_vd' backbone.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to DeepLabV3P unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): DeepLabV3P now stores `backbone` as given; passing a
    # backbone name string here may be stale -- confirm against callers.
    return DeepLabV3P(backbone='ResNet101_vd', **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd_os8(*args, **kwargs):
    """
    Build a DeepLabV3P model for the 'ResNet101_vd' backbone with output_stride=8.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to DeepLabV3P unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): DeepLabV3P now stores `backbone` as given; passing a
    # backbone name string here may be stale -- confirm against callers.
    return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd(*args, **kwargs):
    """
    Build a DeepLabV3P model for the 'ResNet50_vd' backbone.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to DeepLabV3P unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): DeepLabV3P now stores `backbone` as given; passing a
    # backbone name string here may be stale -- confirm against callers.
    return DeepLabV3P(backbone='ResNet50_vd', **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd_os8(*args, **kwargs):
    """
    Build a DeepLabV3P model for the 'ResNet50_vd' backbone with output_stride=8.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to DeepLabV3P unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): DeepLabV3P now stores `backbone` as given; passing a
    # backbone name string here may be stale -- confirm against callers.
    return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_xception65_deeplab(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='Xception65_deeplab',
return DeepLabV3P(
backbone='Xception65_deeplab',
pretrained_model=pretrained_model,
backbone_indices=(0, 1),
backbone_channels=(128, 2048),
......@@ -324,7 +344,8 @@ def deeplabv3p_xception65_deeplab(*args, **kwargs):
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_large(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='MobileNetV3_large_x1_0',
return DeepLabV3P(
backbone='MobileNetV3_large_x1_0',
pretrained_model=pretrained_model,
backbone_indices=(0, 3),
backbone_channels=(24, 160),
......@@ -334,7 +355,8 @@ def deeplabv3p_mobilenetv3_large(*args, **kwargs):
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_small(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='MobileNetV3_small_x1_0',
return DeepLabV3P(
backbone='MobileNetV3_small_x1_0',
pretrained_model=pretrained_model,
backbone_indices=(0, 3),
backbone_channels=(16, 96),
......
......@@ -25,6 +25,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm
from dygraph.cvlibs import manager
from dygraph import utils
from dygraph.cvlibs import param_init
__all__ = [
"fcn_hrnet_w18_small_v1", "fcn_hrnet_w18_small_v2", "fcn_hrnet_w18",
......@@ -33,115 +34,105 @@ __all__ = [
]
@manager.MODELS.add_component
class FCN(fluid.dygraph.Layer):
"""
Fully Convolutional Networks for Semantic Segmentation.
https://arxiv.org/abs/1411.4038
Args:
backbone (str): backbone name,
num_classes (int): the unique number of target classes.
in_channels (int): the channels of input feature maps.
backbone (paddle.nn.Layer): backbone networks.
model_pretrained (str): the path of pretrained model.
backbone_indices (tuple): one values in the tuple indicte the indices of output of backbone.Default -1.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
channels (int): channels after conv layer before the last one.
pretrained_model (str): the path of pretrained model.
ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
"""
def __init__(self,
backbone,
num_classes,
in_channels,
backbone,
model_pretrained=None,
backbone_indices=(-1, ),
backbone_channels=(270, ),
channels=None,
pretrained_model=None,
ignore_index=255,
**kwargs):
super(FCN, self).__init__()
self.num_classes = num_classes
self.backbone_indices = backbone_indices
self.ignore_index = ignore_index
self.EPS = 1e-5
if channels is None:
channels = in_channels
channels = backbone_channels[backbone_indices[0]]
self.backbone = manager.BACKBONES[backbone](**kwargs)
self.backbone = backbone
self.conv_last_2 = ConvBNLayer(
num_channels=in_channels,
num_channels=backbone_channels[backbone_indices[0]],
num_filters=channels,
filter_size=1,
stride=1,
name='conv-2')
stride=1)
self.conv_last_1 = Conv2D(
num_channels=channels,
num_filters=self.num_classes,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name='conv-1_weights'))
self.init_weight(pretrained_model)
padding=0)
if self.training:
self.init_weight(model_pretrained)
def forward(self, x, label=None, mode='train'):
def forward(self, x):
input_shape = x.shape[2:]
x = self.backbone(x)
fea_list = self.backbone(x)
x = fea_list[self.backbone_indices[0]]
x = self.conv_last_2(x)
logit = self.conv_last_1(x)
logit = fluid.layers.resize_bilinear(logit, input_shape)
if self.training:
if label is None:
raise Exception('Label is need during training')
return self._get_loss(logit, label)
else:
score_map = fluid.layers.softmax(logit, axis=1)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
return [logit]
# if self.training:
# if label is None:
# raise Exception('Label is need during training')
# return self._get_loss(logit, label)
# else:
# score_map = fluid.layers.softmax(logit, axis=1)
# score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
# pred = fluid.layers.argmax(score_map, axis=3)
# pred = fluid.layers.unsqueeze(pred, axes=[3])
# return pred, score_map
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
params = self.parameters()
for param in params:
param_name = param.name
if 'batch_norm' in param_name:
if 'w_0' in param_name:
param_init.constant_init(param, 1.0)
elif 'b_0' in param_name:
param_init.constant_init(param, 0.0)
if 'conv' in param_name and 'w_0' in param_name:
param_init.normal_init(param, scale=0.001)
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
def _get_loss(self, logit, label):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
......@@ -150,8 +141,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
filter_size,
stride=1,
groups=1,
act="relu",
name=None):
act="relu"):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
......@@ -161,18 +151,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
num_filters,
weight_attr=ParamAttr(
name=bn_name + '_scale',
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
bn_name + '_offset',
initializer=fluid.initializer.Constant(0.0)))
self._batch_norm = BatchNorm(num_filters)
self.act = act
def forward(self, input):
......@@ -185,49 +165,49 @@ class ConvBNLayer(fluid.dygraph.Layer):
@manager.MODELS.add_component
def fcn_hrnet_w18_small_v1(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W18_Small_V1' backbone (240 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(240, )` is a one-element tuple. A bare `(240)` is just the int 240,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(
        backbone='HRNet_W18_Small_V1', backbone_channels=(240, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w18_small_v2(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W18_Small_V2' backbone (270 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(270, )` is a one-element tuple. A bare `(270)` is just the int 270,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(
        backbone='HRNet_W18_Small_V2', backbone_channels=(270, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w18(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W18' backbone (270 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(270, )` is a one-element tuple. A bare `(270)` is just the int 270,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W18', backbone_channels=(270, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w30(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W30' backbone (450 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(450, )` is a one-element tuple. A bare `(450)` is just the int 450,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W30', backbone_channels=(450, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w32(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W32' backbone (480 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(480, )` is a one-element tuple. A bare `(480)` is just the int 480,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W32', backbone_channels=(480, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w40(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W40' backbone (600 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(600, )` is a one-element tuple. A bare `(600)` is just the int 600,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W40', backbone_channels=(600, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w44(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W44' backbone (660 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(660, )` is a one-element tuple. A bare `(660)` is just the int 660,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W44', backbone_channels=(660, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w48(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W48' backbone (720 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(720, )` is a one-element tuple. A bare `(720)` is just the int 720,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W48', backbone_channels=(720, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w60(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W60' backbone (900 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(900, )` is a one-element tuple. A bare `(900)` is just the int 900,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W60', backbone_channels=(900, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w64(*args, **kwargs):
    """
    Build an FCN for the 'HRNet_W64' backbone (960 feature channels).

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to FCN unchanged.
    """
    # `(960, )` is a one-element tuple. A bare `(960)` is just the int 960,
    # which FCN cannot index with backbone_channels[backbone_indices[0]].
    # NOTE(review): FCN now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return FCN(backbone='HRNet_W64', backbone_channels=(960, ), **kwargs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entroy_loss import CrossEntropyLoss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from dygraph.cvlibs import manager
'''
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.CrossEntropyLoss):
"""
Implements the cross entropy loss function.
Args:
weight (Tensor): Weight tensor, a manual rescaling weight given
to each class and the shape is (C). It has the same dimensions as class
number and the data type is float32, float64. Default ``'None'``.
ignore_index (int64): Specifies a target value that is ignored
and does not contribute to the input gradient. Default ``255``.
reduction (str): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default ``'mean'``.
"""
def __init__(self, weight=None, ignore_index=255, reduction='mean'):
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
self.EPS = 1e-5
if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
" 'none', but received %s, which is not allowed." %
self.reduction)
def forward(self, logit, label):
"""
Forward computation.
Args:
logit (Tensor): logit tensor, the data type is float32, float64. Shape is
(N, C), where C is number of classes, and if shape is more than 2D, this
is (N, C, D1, D2,..., Dk), k >= 1.
label (Variable): label tensor, the data type is int64. Shape is (N), where each
value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
(N, D1, D2,..., Dk), k >= 1.
"""
loss = paddle.nn.functional.cross_entropy(
logit,
label,
weight=self.weight,
ignore_index=self.ignore_index,
reduction=self.reduction)
mask = label != self.ignore_index
mask = paddle.cast(mask, 'float32')
avg_loss = loss / (paddle.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
'''
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.Layer):
    """
    Softmax cross entropy loss rescaled by the share of non-ignored labels.

    The mean of the per-element softmax cross entropy is divided by the
    fraction of label entries that differ from ``ignore_index`` (plus a
    small epsilon to avoid division by zero).

    Args:
        ignore_index (int64): Label value passed to the cross entropy op as
            the value to ignore, and excluded when counting valid labels.
            Default ``255``.
    """

    def __init__(self, ignore_index=255):
        super(CrossEntropyLoss, self).__init__()
        self.EPS = 1e-5
        self.ignore_index = ignore_index

    def forward(self, logit, label):
        """
        Compute the loss for a batch.

        Args:
            logit (Tensor): Class scores with the class axis at position 1.
            label (Tensor): Ground-truth class indices. When its rank differs
                from that of ``logit``, an axis is inserted at position 1.

        Returns:
            Tensor: mean cross entropy divided by (valid-label fraction + EPS).
        """
        # Insert the class axis on the label when the ranks do not match.
        if len(logit.shape) != len(label.shape):
            label = paddle.unsqueeze(label, 1)

        elementwise_loss = F.softmax_with_cross_entropy(
            logit, label, ignore_index=self.ignore_index, axis=1)
        mean_loss = paddle.reduce_mean(elementwise_loss)

        # 1.0 where the label is a real class, 0.0 where it is ignore_index.
        valid = paddle.cast(label != self.ignore_index, 'float32')
        avg_loss = mean_loss / (paddle.mean(valid) + self.EPS)

        # Neither the label nor the validity mask should receive gradients.
        label.stop_gradient = True
        valid.stop_gradient = True
        return avg_loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D
from dygraph.cvlibs import manager
from dygraph.models.architectures.layer_utils import ConvBnRelu
from dygraph import utils
class SpatialGatherBlock(fluid.dygraph.Layer):
    """
    Pool pixel features into one feature vector per region.

    Each region map is softmax-normalised over its flattened spatial
    positions and used as weights for a weighted sum of the pixel features.
    """

    def forward(self, pixels, regions):
        """
        Args:
            pixels (Tensor): features unpacked as (n, c, h, w).
            regions (Tensor): region maps unpacked as (_, k, _, _); reshaped
                to (n, k, h*w), so it must share n, h and w with ``pixels``.

        Returns:
            Tensor: region features of shape (n, c, k, 1).
        """
        batch, channels, height, width = pixels.shape
        _, num_regions, _, _ = regions.shape
        flat_size = height * width

        # pixels: (n, c, h, w) -> (n, h*w, c)
        pixels = fluid.layers.transpose(
            fluid.layers.reshape(pixels, (batch, channels, flat_size)),
            (0, 2, 1))

        # regions: (n, k, h, w) -> (n, k, h*w), normalised over positions
        regions = fluid.layers.softmax(
            fluid.layers.reshape(regions, (batch, num_regions, flat_size)),
            axis=2)

        # (n, k, h*w) x (n, h*w, c) -> (n, k, c) -> (n, c, k) -> (n, c, k, 1)
        gathered = fluid.layers.matmul(regions, pixels)
        gathered = fluid.layers.transpose(gathered, (0, 2, 1))
        return fluid.layers.unsqueeze(gathered, axes=[-1])
class SpatialOCRModule(fluid.dygraph.Layer):
    """
    Fuse pixel features with the object context computed by attention.

    Args:
        in_channels (int): channels of the pixel features.
        key_channels (int): channels used inside the attention block.
        out_channels (int): channels produced by the fusing 1x1 convolution.
        dropout_rate (float): rate passed to dropout on the output. Default 0.1.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 dropout_rate=0.1):
        super(SpatialOCRModule, self).__init__()
        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.dropout_rate = dropout_rate
        # Input is the concatenation of context and pixels, hence 2x channels.
        self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)

    def forward(self, pixels, regions):
        """
        Args:
            pixels (Tensor): pixel features.
            regions (Tensor): object region features used as attention proxy.

        Returns:
            Tensor: fused features after the 1x1 convolution and dropout.
        """
        context = self.attention_block(pixels, regions)
        fused = self.conv1x1(fluid.layers.concat([context, pixels], axis=1))
        return fluid.layers.dropout(fused, self.dropout_rate)
class ObjectAttentionBlock(fluid.dygraph.Layer):
    """
    Scaled dot-product attention from pixel features to object proxies.

    Args:
        in_channels (int): channels of both the pixel and proxy features.
        key_channels (int): channels of the query/key/value projections.
    """

    def __init__(self, in_channels, key_channels):
        super(ObjectAttentionBlock, self).__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels

        # Query projection (applied to pixels).
        self.f_pixel = Sequential(
            ConvBnRelu(in_channels, key_channels, 1),
            ConvBnRelu(key_channels, key_channels, 1))
        # Key projection (applied to the proxy).
        self.f_object = Sequential(
            ConvBnRelu(in_channels, key_channels, 1),
            ConvBnRelu(key_channels, key_channels, 1))
        # Value projection (applied to the proxy).
        self.f_down = ConvBnRelu(in_channels, key_channels, 1)
        # Projects the attended context back to in_channels.
        self.f_up = ConvBnRelu(key_channels, in_channels, 1)

    def forward(self, x, proxy):
        """
        Args:
            x (Tensor): pixel features unpacked as (n, _, h, w).
            proxy (Tensor): object features; projected and flattened to
                (n, key_channels, -1).

        Returns:
            Tensor: context features produced by ``f_up`` from a
                (n, key_channels, h, w) tensor.
        """
        n, _, h, w = x.shape
        flat_shape = (n, self.key_channels, -1)

        # query: (n, c1, h1, w1) -> (n, h1*w1, key_channels)
        query = fluid.layers.transpose(
            fluid.layers.reshape(self.f_pixel(x), flat_shape), (0, 2, 1))

        # key: (n, c2, h2, w2) -> (n, key_channels, h2*w2)
        key = fluid.layers.reshape(self.f_object(proxy), flat_shape)

        # value: (n, c2, h2, w2) -> (n, h2*w2, key_channels)
        value = fluid.layers.transpose(
            fluid.layers.reshape(self.f_down(proxy), flat_shape), (0, 2, 1))

        # sim_map: (n, h1*w1, h2*w2), scaled by key_channels ** -0.5
        sim_map = fluid.layers.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = fluid.layers.softmax(sim_map, axis=-1)

        # context: (n, h1*w1, key_channels) -> (n, key_channels, h1, w1)
        context = fluid.layers.matmul(sim_map, value)
        context = fluid.layers.transpose(context, (0, 2, 1))
        context = fluid.layers.reshape(context, (n, self.key_channels, h, w))
        return self.f_up(context)
@manager.MODELS.add_component
class OCRNet(fluid.dygraph.Layer):
    """
    Segmentation model that refines backbone features with object context.

    An auxiliary head predicts soft region maps from the backbone features;
    those maps gather per-region features, which the spatial OCR module
    attends to before the final classification head.

    Args:
        num_classes (int): number of output classes.
        backbone: callable layer applied to the input; its output is fed
            directly to the auxiliary head and the 3x3 convolution.
        model_pretrained (str): path of a pretrained model to load. Default None.
        in_channels (int): channels of the backbone output.
            NOTE(review): defaults to None but is passed straight to the
            layer constructors, so callers presumably must supply it -- confirm.
        ocr_mid_channels (int): channels of the OCR features. Default 512.
        ocr_key_channels (int): channels inside the attention block. Default 256.
        ignore_index (int): label value ignored by the loss. Default 255.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 model_pretrained=None,
                 in_channels=None,
                 ocr_mid_channels=512,
                 ocr_key_channels=256,
                 ignore_index=255):
        super(OCRNet, self).__init__()
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.EPS = 1e-5

        self.backbone = backbone
        self.spatial_gather = SpatialGatherBlock()
        self.spatial_ocr = SpatialOCRModule(
            in_channels=ocr_mid_channels,
            key_channels=ocr_key_channels,
            out_channels=ocr_mid_channels)
        self.conv3x3_ocr = ConvBnRelu(
            in_channels, ocr_mid_channels, 3, padding=1)
        self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
        self.aux_head = Sequential(
            ConvBnRelu(in_channels, in_channels, 3, padding=1),
            Conv2D(in_channels, self.num_classes, 1))

        self.init_weight(model_pretrained)

    def forward(self, x, label=None):
        """
        Run the network.

        Args:
            x (Tensor): input batch; its trailing two dims set the output size.
            label (Tensor, optional): ground truth, used only in training mode.

        Returns:
            In training mode, the classification loss plus 0.4 times the
            auxiliary loss. Otherwise a tuple ``(pred, score_map)``.
        """
        out_size = x.shape[2:]

        feats = self.backbone(x)
        soft_regions = self.aux_head(feats)
        pixels = self.conv3x3_ocr(feats)
        object_regions = self.spatial_gather(pixels, soft_regions)
        logit = self.cls_head(self.spatial_ocr(pixels, object_regions))
        logit = fluid.layers.resize_bilinear(logit, out_size)

        if not self.training:
            # Channels-last softmax scores and the arg-max class per pixel.
            score_map = fluid.layers.softmax(logit, axis=1)
            score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
            pred = fluid.layers.argmax(score_map, axis=3)
            pred = fluid.layers.unsqueeze(pred, axes=[3])
            return pred, score_map

        # Training: supervise both the final logits and the auxiliary maps.
        soft_regions = fluid.layers.resize_bilinear(soft_regions, out_size)
        cls_loss = self._get_loss(logit, label)
        aux_loss = self._get_loss(soft_regions, label)
        return cls_loss + 0.4 * aux_loss

    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

        Raises:
            Exception: if a path is given but does not exist.
        """
        if pretrained_model is None:
            return
        if not os.path.exists(pretrained_model):
            raise Exception('Pretrained model is not found: {}'.format(
                pretrained_model))
        utils.load_pretrained_model(self, pretrained_model)

    def _get_loss(self, logit, label):
        """
        Compute the masked cross entropy loss.

        Args:
            logit (tensor): the logit of model output, transposed here with
                permutation [0, 2, 3, 1].
            label (tensor): ground truth, transposed the same way.

        Returns:
            avg_loss (tensor): mean masked loss divided by (mean mask + EPS).
        """
        perm = [0, 2, 3, 1]
        logit = fluid.layers.transpose(logit, perm)
        label = fluid.layers.transpose(label, perm)

        # 1.0 where the label is a real class, 0.0 where it is ignore_index.
        mask = fluid.layers.cast(label != self.ignore_index, 'float32')

        loss, _ = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            return_softmax=True,
            axis=-1)
        loss = loss * mask

        avg_loss = fluid.layers.mean(loss) / (
            fluid.layers.mean(mask) + self.EPS)

        label.stop_gradient = True
        mask.stop_gradient = True
        return avg_loss
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.nn.functional as F
......@@ -35,9 +34,11 @@ class PSPNet(fluid.dygraph.Layer):
(https://openaccess.thecvf.com/content_cvpr_2017/papers/Zhao_Pyramid_Scene_Parsing_CVPR_2017_paper.pdf)
Args:
backbone (str): backbone name, currently support Resnet50/101.
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone name, currently support Resnet50/101.
num_classes (int): the unique number of target classes. Default 2.
model_pretrained (str): the path of pretrained model.
output_stride (int): the ratio of input size and final feature size. Default 16.
......@@ -57,42 +58,44 @@ class PSPNet(fluid.dygraph.Layer):
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True.
ignore_index (int): the value of ground-truth mask would be ignored while doing evaluation. Default to 255.
pretrained_model (str): the pretrained_model path of backbone.
"""
def __init__(self,
num_classes,
backbone,
num_classes=2,
model_pretrained=None,
output_stride=16,
backbone_indices=(2, 3),
backbone_channels=(1024, 2048),
pp_out_channels=1024,
bin_sizes=(1, 2, 3, 6),
enable_auxiliary_loss=True,
ignore_index=255,
pretrained_model=None):
ignore_index=255):
super(PSPNet, self).__init__()
self.backbone = manager.BACKBONES[backbone](output_stride=output_stride,
multi_grid=(1, 1, 1))
# self.backbone = manager.BACKBONES[backbone](output_stride=output_stride,
# multi_grid=(1, 1, 1))
self.backbone = backbone
self.backbone_indices = backbone_indices
self.psp_module = PPModule(in_channels=backbone_channels[1],
self.psp_module = PPModule(
in_channels=backbone_channels[1],
out_channels=pp_out_channels,
bin_sizes=bin_sizes)
self.conv = Conv2D(num_channels=pp_out_channels,
self.conv = Conv2D(
num_channels=pp_out_channels,
num_filters=num_classes,
filter_size=1)
if enable_auxiliary_loss:
self.fcn_head = model_utils.FCNHead(in_channels=backbone_channels[0], out_channels=num_classes)
self.fcn_head = model_utils.FCNHead(
in_channels=backbone_channels[0], out_channels=num_classes)
self.enable_auxiliary_loss = enable_auxiliary_loss
self.ignore_index = ignore_index
self.init_weight(pretrained_model)
self.init_weight(model_pretrained)
def forward(self, input, label=None):
......@@ -107,7 +110,8 @@ class PSPNet(fluid.dygraph.Layer):
if self.enable_auxiliary_loss:
auxiliary_feat = feat_list[self.backbone_indices[0]]
auxiliary_logit = self.fcn_head(auxiliary_feat)
auxiliary_logit = fluid.layers.resize_bilinear(auxiliary_logit, input.shape[2:])
auxiliary_logit = fluid.layers.resize_bilinear(
auxiliary_logit, input.shape[2:])
if self.training:
loss = model_utils.get_loss(logit, label)
......@@ -116,7 +120,6 @@ class PSPNet(fluid.dygraph.Layer):
loss += (0.4 * auxiliary_loss)
return loss
else:
pred, score_map = model_utils.get_pred_score_map(logit)
return pred, score_map
......@@ -124,14 +127,15 @@ class PSPNet(fluid.dygraph.Layer):
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
class PPModule(fluid.dygraph.Layer):
......@@ -151,9 +155,11 @@ class PPModule(fluid.dygraph.Layer):
self.bin_sizes = bin_sizes
# we use dimension reduction after pooling mentioned in original implementation.
self.stages = fluid.dygraph.LayerList([self._make_stage(in_channels, size) for size in bin_sizes])
self.stages = fluid.dygraph.LayerList(
[self._make_stage(in_channels, size) for size in bin_sizes])
self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=in_channels * 2,
self.conv_bn_relu2 = layer_utils.ConvBnRelu(
num_channels=in_channels * 2,
num_filters=out_channels,
filter_size=3,
padding=1)
......@@ -180,7 +186,8 @@ class PPModule(fluid.dygraph.Layer):
# this paddle version does not support AdaptiveAvgPool2d, so skip it here.
# prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv = layer_utils.ConvBnRelu(num_channels=in_channels,
conv = layer_utils.ConvBnRelu(
num_channels=in_channels,
num_filters=in_channels // len(self.bin_sizes),
filter_size=1)
......@@ -190,7 +197,8 @@ class PPModule(fluid.dygraph.Layer):
cat_layers = []
for i, stage in enumerate(self.stages):
size = self.bin_sizes[i]
x = fluid.layers.adaptive_pool2d(input, pool_size=(size, size), pool_type="max")
x = fluid.layers.adaptive_pool2d(
input, pool_size=(size, size), pool_type="max")
x = stage(x)
x = fluid.layers.resize_bilinear(x, out_shape=input.shape[2:])
cat_layers.append(x)
......@@ -204,22 +212,32 @@ class PPModule(fluid.dygraph.Layer):
@manager.MODELS.add_component
def pspnet_resnet101_vd(*args, **kwargs):
    """
    Build a PSPNet model for the 'ResNet101_vd' backbone.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to PSPNet unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): PSPNet now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return PSPNet(backbone='ResNet101_vd', **kwargs)
@manager.MODELS.add_component
def pspnet_resnet101_vd_os8(*args, **kwargs):
    """
    Build a PSPNet model for the 'ResNet101_vd' backbone with output_stride=8.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to PSPNet unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): PSPNet now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return PSPNet(backbone='ResNet101_vd', output_stride=8, **kwargs)
@manager.MODELS.add_component
def pspnet_resnet50_vd(*args, **kwargs):
    """
    Build a PSPNet model for the 'ResNet50_vd' backbone.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to PSPNet unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): PSPNet now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return PSPNet(backbone='ResNet50_vd', **kwargs)
@manager.MODELS.add_component
def pspnet_resnet50_vd_os8(*args, **kwargs):
    """
    Build a PSPNet model for the 'ResNet50_vd' backbone with output_stride=8.

    Positional arguments are accepted but not forwarded; keyword arguments
    are passed through to PSPNet unchanged.
    """
    # The constructor no longer accepts `pretrained_model` (it was renamed to
    # `model_pretrained`), so the always-None argument is not passed; callers
    # can still supply `model_pretrained` through kwargs.
    # NOTE(review): PSPNet now stores `backbone` as given; passing a backbone
    # name string here may be stale -- confirm against callers.
    return PSPNet(backbone='ResNet50_vd', output_stride=8, **kwargs)
......@@ -33,7 +33,7 @@ class UNet(fluid.dygraph.Layer):
ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
"""
def __init__(self, num_classes, pretrained_model=None, ignore_index=255):
def __init__(self, num_classes, model_pretrained=None, ignore_index=255):
super(UNet, self).__init__()
self.encode = UnetEncoder()
self.decode = UnetDecode()
......@@ -41,7 +41,7 @@ class UNet(fluid.dygraph.Layer):
self.ignore_index = ignore_index
self.EPS = 1e-5
self.init_weight(pretrained_model)
self.init_weight(model_pretrained)
def forward(self, x, label=None):
encode_data, short_cuts = self.encode(x)
......@@ -60,7 +60,7 @@ class UNet(fluid.dygraph.Layer):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
......
......@@ -17,78 +17,36 @@ import argparse
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
import dygraph
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import logger
from dygraph.utils import Config
from dygraph.core import train
def parse_args():
parser = argparse.ArgumentParser(description='Model training')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help='Model type for training, which is one of {}'.format(
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help="The dataset you want to train, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
"--config", dest="cfg", help="The config file.", default=None, type=str)
parser.add_argument(
'--iters',
dest='iters',
help='iters for training',
type=int,
default=10000)
default=None)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size of one gpu or cpu',
type=int,
default=2)
default=None)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=0.01)
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrained model',
type=str,
default=None)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument(
'--save_interval_iters',
......@@ -139,64 +97,36 @@ def main(args):
if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
else fluid.CPUPlace()
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places):
# Creat dataset reader
train_transforms = T.Compose([
T.Resize(args.input_size),
T.RandomHorizontalFlip(),
T.Normalize()
])
train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
model = manager.MODELS[args.model_name](
num_classes=train_dataset.num_classes,
pretrained_model=args.pretrained_model)
# Creat optimizer
# todo, may less one than len(loader)
num_iters_each_epoch = len(train_dataset) // (
args.batch_size * ParallelEnv().nranks)
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, args.iters, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
if not args.cfg:
raise RuntimeError('No configuration file specified.')
cfg = Config(args.cfg)
train_dataset = cfg.train_dataset
if not train_dataset:
raise RuntimeError(
'The training dataset is not specified in the configuration file.'
)
val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss
train(
model,
cfg.model,
train_dataset,
places=places,
eval_dataset=eval_dataset,
optimizer=optimizer,
eval_dataset=val_dataset,
optimizer=cfg.optimizer,
save_dir=args.save_dir,
iters=args.iters,
batch_size=args.batch_size,
resume_model=args.resume_model,
iters=cfg.iters,
batch_size=cfg.batch_size,
save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters,
num_classes=train_dataset.num_classes,
num_workers=args.num_workers,
use_vdl=args.use_vdl)
use_vdl=args.use_vdl,
losses=losses)
if __name__ == '__main__':
......
......@@ -21,8 +21,10 @@ from PIL import Image
import cv2
from .functional import *
from dygraph.cvlibs import manager
@manager.TRANSFORMS.add_component
class Compose:
def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list):
......@@ -53,11 +55,12 @@ class Compose:
if len(outputs) == 3:
label = outputs[2]
im = permute(im)
if len(outputs) == 3:
label = label[np.newaxis, :, :]
# if len(outputs) == 3:
# label = label[np.newaxis, :, :]
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomHorizontalFlip:
def __init__(self, prob=0.5):
self.prob = prob
......@@ -73,6 +76,7 @@ class RandomHorizontalFlip:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomVerticalFlip:
def __init__(self, prob=0.1):
self.prob = prob
......@@ -88,6 +92,7 @@ class RandomVerticalFlip:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Resize:
# The interpolation mode
interp_dict = {
......@@ -137,6 +142,7 @@ class Resize:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeByLong:
def __init__(self, long_size):
self.long_size = long_size
......@@ -156,6 +162,7 @@ class ResizeByLong:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeRangeScaling:
def __init__(self, min_value=400, max_value=600):
if min_value > max_value:
......@@ -181,6 +188,7 @@ class ResizeRangeScaling:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeStepScaling:
def __init__(self,
min_scale_factor=0.75,
......@@ -224,6 +232,7 @@ class ResizeStepScaling:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Normalize:
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean
......@@ -245,6 +254,7 @@ class Normalize:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Padding:
def __init__(self,
target_size,
......@@ -305,6 +315,7 @@ class Padding:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomPaddingCrop:
def __init__(self,
crop_size=512,
......@@ -378,6 +389,7 @@ class RandomPaddingCrop:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomBlur:
def __init__(self, prob=0.1):
self.prob = prob
......@@ -404,6 +416,7 @@ class RandomBlur:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomRotation:
def __init__(self,
max_rotation=15,
......@@ -451,6 +464,7 @@ class RandomRotation:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomScaleAspect:
def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale
......@@ -492,6 +506,7 @@ class RandomScaleAspect:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomDistort:
def __init__(self,
brightness_range=0.5,
......
......@@ -18,3 +18,4 @@ from .metrics import ConfusionMatrix
from .utils import *
from .timer import Timer, calculate_eta
from .get_environ_info import get_environ_info
from .config import Config
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import os
from typing import Any, Callable
import yaml
import paddle.fluid as fluid
import dygraph.cvlibs.manager as manager
class Config(object):
    '''
    Training config, parsed from a yaml file.

    A config file may name a parent file with the `_base_` key (a path
    relative to the config file itself). The parent is parsed first and the
    values of the current file are then merged on top of it, so a key set in
    the current file always takes precedence over the same key in the parent.

    Args:
        path(str) : the path of config file, supports yaml format only

    Raises:
        FileNotFoundError: if `path` does not exist.
        RuntimeError: if the file is not a yaml file or has no `model` entry.
    '''

    def __init__(self, path: str):
        if not os.path.exists(path):
            raise FileNotFoundError('File {} does not exist'.format(path))

        if path.endswith(('yml', 'yaml')):
            dic = self._parse_from_yaml(path)
            self._build(dic)
        else:
            raise RuntimeError('Config file should in yaml format!')

    def _update_dic(self, dic: dict, base_dic: dict) -> dict:
        '''
        Merge `dic` on top of `base_dic` and return the result as a new dict.

        Keys present only in `base_dic` are inherited. Keys present in `dic`
        override the base value, except when both values are dicts, in which
        case they are merged recursively. Neither input is modified at the
        top level.
        '''
        merged = base_dic.copy()
        for key, val in dic.items():
            base_val = merged.get(key)
            if isinstance(val, dict) and isinstance(base_val, dict):
                merged[key] = self._update_dic(val, base_val)
            else:
                # Scalars, lists and type-mismatched values replace the
                # inherited value wholesale.
                merged[key] = val
        return merged

    def _parse_from_yaml(self, path: str) -> dict:
        '''Parse a yaml file, resolving its `_base_` parent chain if any.'''
        with codecs.open(path, 'r', 'utf-8') as file:
            # NOTE(review): config files are assumed to be trusted. If they
            # can come from an untrusted source, consider yaml.safe_load.
            dic = yaml.load(file, Loader=yaml.FullLoader)

        # An empty yaml document is loaded as None.
        if dic is None:
            dic = {}

        if '_base_' in dic:
            cfg_dir = os.path.dirname(path)
            base_path = os.path.join(cfg_dir, dic.pop('_base_'))
            base_dic = self._parse_from_yaml(base_path)
            dic = self._update_dic(dic, base_dic)
        return dic

    def _build(self, dic: dict):
        '''
        Build config from dictionary.

        Only stores the raw sections; the model, datasets, losses and
        optimizer are instantiated lazily by the corresponding properties.

        Raises:
            RuntimeError: if `dic` has no `model` entry.
        '''
        dic = dic.copy()

        self._batch_size = dic.get('batch_size', 1)
        self._iters = dic.get('iters')

        if 'model' not in dic:
            raise RuntimeError(
                'No model specified in the configuration file.')
        self._model_cfg = dic['model']
        self._model = None

        self._train_dataset = dic.get('train_dataset')
        self._val_dataset = dic.get('val_dataset')

        lr_cfg = dic.get('learning_rate', {})
        if not isinstance(lr_cfg, dict):
            # Accept the short form `learning_rate: 0.01` as well as the
            # mapping form with `value` / `decay` keys.
            lr_cfg = {'value': lr_cfg}
        self._learning_rate_cfg = lr_cfg
        self._learning_rate = lr_cfg.get('value')
        self._decay = lr_cfg.get('decay', {'type': 'poly', 'power': 0.9})

        self._loss_cfg = dic.get('loss', {})
        self._losses = None

        self._optimizer_cfg = dic.get('optimizer', {})

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Update config; arguments that are None (or falsy) are ignored.'''
        if learning_rate:
            self._learning_rate = learning_rate

        if batch_size:
            self._batch_size = batch_size

        if iters:
            self._iters = iters

    @property
    def batch_size(self) -> int:
        '''Batch size, 1 when the config file does not set it.'''
        return self._batch_size

    @property
    def iters(self) -> int:
        '''Number of iterations; raises RuntimeError when not configured.'''
        if not self._iters:
            raise RuntimeError('No iters specified in the configuration file.')
        return self._iters

    @property
    def learning_rate(self) -> Any:
        '''
        Learning rate wrapped in the configured decay.

        Returns the result of `fluid.layers.polynomial_decay`, called with the
        base learning rate and the decay arguments (`decay_steps` defaults to
        `iters`).

        Raises:
            RuntimeError: if no learning rate is configured or the decay type
                is not `poly`.
        '''
        if not self._learning_rate:
            raise RuntimeError(
                'No learning rate specified in the configuration file.')

        if self.decay_type == 'poly':
            lr = self._learning_rate
            args = self.decay_args
            args.setdefault('decay_steps', self.iters)
            return fluid.layers.polynomial_decay(lr, **args)
        else:
            raise RuntimeError('Only poly decay support.')

    @property
    def optimizer(self) -> 'fluid.optimizer.Optimizer':
        '''
        Optimizer built over the parameters of `self.model`.

        Only the `sgd` type is handled; it maps to `fluid.optimizer.Momentum`
        with `momentum` defaulting to 0.9. Remaining keys of the optimizer
        section are forwarded as keyword arguments.

        Raises:
            RuntimeError: if the optimizer type is missing or not `sgd`.
        '''
        if self.optimizer_type == 'sgd':
            lr = self.learning_rate
            args = self.optimizer_args
            args.setdefault('momentum', 0.9)
            return fluid.optimizer.Momentum(
                lr, parameter_list=self.model.parameters(), **args)
        else:
            raise RuntimeError('Only sgd optimizer support.')

    @property
    def optimizer_type(self) -> str:
        '''The `type` of the optimizer section; raises when it is missing.'''
        otype = self._optimizer_cfg.get('type')
        if not otype:
            raise RuntimeError(
                'No optimizer type specified in the configuration file.')
        return otype

    @property
    def optimizer_args(self) -> dict:
        '''A copy of the optimizer section without its `type` key.'''
        args = self._optimizer_cfg.copy()
        args.pop('type', None)
        return args

    @property
    def decay_type(self) -> str:
        '''The `type` of the learning rate decay section.'''
        return self._decay['type']

    @property
    def decay_args(self) -> dict:
        '''A copy of the decay section without its `type` key.'''
        args = self._decay.copy()
        args.pop('type', None)
        return args

    @property
    def loss(self) -> dict:
        '''
        Loss configuration with the entries of `types` instantiated.

        The result is a dict holding the loss section of the config file,
        where every item of `types` has been replaced by the object built
        from it. The dict is built once and cached, and only after it passed
        validation.

        Raises:
            RuntimeError: if `types` or `coef` is missing, or their lengths
                differ.
        '''
        if self._losses is None:
            losses = {}
            for key, val in self._loss_cfg.items():
                if key == 'types':
                    losses['types'] = [
                        self._load_object(item) for item in val
                    ]
                else:
                    losses[key] = val

            if 'types' not in losses or 'coef' not in losses:
                raise RuntimeError(
                    'Both types and coef should be specified in loss config.')

            if len(losses['coef']) != len(losses['types']):
                raise RuntimeError(
                    'The length of coef should equal to types in loss config: {} != {}.'
                    .format(len(losses['coef']), len(losses['types'])))

            # Cache only a validated result, so that a failed access keeps
            # failing instead of handing out a half-checked dict.
            self._losses = losses
        return self._losses

    @property
    def model(self) -> Callable:
        '''The model object, built from the `model` section on first access.'''
        if self._model is None:
            self._model = self._load_object(self._model_cfg)
        return self._model

    @property
    def train_dataset(self) -> Any:
        '''A dataset built from `train_dataset`, or None when not configured.'''
        if not self._train_dataset:
            return None
        return self._load_object(self._train_dataset)

    @property
    def val_dataset(self) -> Any:
        '''A dataset built from `val_dataset`, or None when not configured.'''
        if not self._val_dataset:
            return None
        return self._load_object(self._val_dataset)

    def _load_component(self, com_name: str) -> Any:
        '''
        Look `com_name` up in the component managers and return the first
        match, searching models, backbones, datasets, transforms and losses
        in that order.

        Raises:
            RuntimeError: if no manager holds a component of that name.
        '''
        com_list = [
            manager.MODELS, manager.BACKBONES, manager.DATASETS,
            manager.TRANSFORMS, manager.LOSSES
        ]

        for com in com_list:
            if com_name in com.components_dict:
                return com[com_name]

        raise RuntimeError(
            'The specified component was not found {}.'.format(com_name))

    def _load_object(self, cfg: dict) -> Any:
        '''
        Instantiate the component described by `cfg`.

        `cfg['type']` names the component; every other key becomes a keyword
        argument. Values (or list items) that are themselves dicts with a
        `type` key are instantiated recursively first.

        Raises:
            RuntimeError: if `cfg` has no `type` key.
        '''
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))

        component = self._load_component(cfg.pop('type'))

        params = {}
        for key, val in cfg.items():
            if self._is_meta_type(val):
                params[key] = self._load_object(val)
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        return component(**params)

    def _is_meta_type(self, item: Any) -> bool:
        '''Whether `item` is a dict describing a component (has a `type`).'''
        return isinstance(item, dict) and 'type' in item
......@@ -17,48 +17,19 @@ import argparse
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
import dygraph
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import Config
from dygraph.core import evaluate
def parse_args():
parser = argparse.ArgumentParser(description='Model evaluation')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help='Model type for evaluation, which is one of {}'.format(
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help="The dataset you want to evaluation, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of evaluate
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
"--config", dest="cfg", help="The config file.", default=None, type=str)
parser.add_argument(
'--model_dir',
dest='model_dir',
......@@ -75,26 +46,21 @@ def main(args):
if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
else fluid.CPUPlace()
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
model = manager.MODELS[args.model_name](
num_classes=eval_dataset.num_classes)
if not args.cfg:
raise RuntimeError('No configuration file specified.')
cfg = Config(args.cfg)
val_dataset = cfg.val_dataset
if not val_dataset:
raise RuntimeError(
'The verification dataset is not specified in the configuration file.'
)
evaluate(
model,
eval_dataset,
cfg.model,
val_dataset,
model_dir=args.model_dir,
num_classes=eval_dataset.num_classes)
num_classes=val_dataset.num_classes)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册