diff --git a/dygraph/__init__.py b/dygraph/__init__.py index 9c52eaa7ed1463ca40036bd959610b0d1fd80fea..ab403ec3077181129a4641bee96683d9ac82cba6 100644 --- a/dygraph/__init__.py +++ b/dygraph/__init__.py @@ -12,4 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import dygraph.models \ No newline at end of file +from . import models +from . import datasets +from . import transforms diff --git a/dygraph/configs/_base_/cityscapes.yml b/dygraph/configs/_base_/cityscapes.yml new file mode 100644 index 0000000000000000000000000000000000000000..0ceee39221137a782e7008848240dbc5c31ea595 --- /dev/null +++ b/dygraph/configs/_base_/cityscapes.yml @@ -0,0 +1,39 @@ +batch_size: 4 +iters: 100000 +learning_rate: 0.01 + +train_dataset: + type: Cityscapes + dataset_root: data/cityscapes + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: Normalize + mode: train + +val_dataset: + type: Cityscapes + dataset_root: data/cityscapes + transforms: + - type: Normalize + mode: val + + +optimizer: + type: sgd + +learning_rate: + value: 0.01 + decay: + type: poly + power: 0.9 + +loss: + types: + - type: CrossEntropyLoss + coef: [1] diff --git a/dygraph/configs/_base_/optic_disc_seg.yml b/dygraph/configs/_base_/optic_disc_seg.yml new file mode 100644 index 0000000000000000000000000000000000000000..19b08188a5f2b3d1eb2837708c90243b88deb3ed --- /dev/null +++ b/dygraph/configs/_base_/optic_disc_seg.yml @@ -0,0 +1,37 @@ +batch_size: 4 +iters: 10000 +learning_rate: 0.01 + +train_dataset: + type: OpticDiscSeg + dataset_root: data/optic_disc_seg + transforms: + - type: Resize + target_size: [512, 512] + - type: RandomHorizontalFlip + - type: Normalize + mode: train + +val_dataset: + type: OpticDiscSeg + dataset_root: data/optic_disc_seg + transforms: + - type: Resize + target_size: [512, 512] + - type: Normalize + mode: val + + +optimizer: + type: sgd + +learning_rate: + value: 0.01 + decay: + type: poly + power: 0.9 + +loss: + types: + - type: CrossEntropyLoss + coef: [1] diff --git a/dygraph/configs/fcn_hrnet/fcn_hrnetw18_cityscapes_1024x512_100k.yml b/dygraph/configs/fcn_hrnet/fcn_hrnetw18_cityscapes_1024x512_100k.yml new file mode 100644 index 0000000000000000000000000000000000000000..cbb82ed86117f4c82a61d992f6aebeff56d005a9 --- /dev/null +++ b/dygraph/configs/fcn_hrnet/fcn_hrnetw18_cityscapes_1024x512_100k.yml @@ -0,0 +1,9 @@ +_base_: '../_base_/cityscapes.yml' + +model: + type: FCN + backbone: + type: HRNet_W18 + backbone_pretrained: pretrained_model/hrnet_w18_imagenet + num_classes: 19 + backbone_channels: [270] diff --git a/dygraph/configs/fcn_hrnet/fcn_hrnetw18_optic_disc_512x512_10k.yml b/dygraph/configs/fcn_hrnet/fcn_hrnetw18_optic_disc_512x512_10k.yml new file mode 100644 index 0000000000000000000000000000000000000000..11b394e2d54da5685003aad4954b4bff01dcaf2b --- /dev/null +++ b/dygraph/configs/fcn_hrnet/fcn_hrnetw18_optic_disc_512x512_10k.yml @@ -0,0 +1,9 @@ +_base_: '../_base_/optic_disc_seg.yml' + +model: + type: FCN + backbone: + type: HRNet_W18 + backbone_pretrained: pretrained_model/hrnet_w18_imagenet + num_classes: 2 + backbone_channels: [270] diff --git a/dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml b/dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml new file mode 100644 index 0000000000000000000000000000000000000000..e71271c1e59930cfbb0bcd68aab31b3706a82374 --- /dev/null +++ b/dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml @@ -0,0 +1,44 @@ +batch_size: 2 +iters: 40000 + +train_dataset: + type: Cityscapes + dataset_root: data/cityscapes + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: Normalize + mode: train + +val_dataset: + type: Cityscapes + dataset_root: data/cityscapes + transforms: + - type: Normalize + mode: val + +model: + type: OCRNet + backbone: + type: HRNet_W18 + backbone_pretrianed: None + num_classes: 19 + in_channels: 270 + model_pretrained: None + +optimizer: + type: sgd + +learning_rate: + value: 0.01 + decay: + type: poly + power: 0.9 + +loss: + type: CrossEntropy diff --git a/dygraph/core/train.py b/dygraph/core/train.py index 59df4fd595e658409d84854a1d28045f6e4297d3..f3578667ff6b9233dfa8b8d82407ba70b7221a40 100644 --- a/dygraph/core/train.py +++ b/dygraph/core/train.py @@ -14,11 +14,13 @@ import os +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.io import DataLoader # from paddle.incubate.hapi.distributed import DistributedBatchSampler from paddle.io import DistributedBatchSampler +import paddle.nn.functional as F import dygraph.utils.logger as logger from dygraph.utils import load_pretrained_model @@ -27,6 +29,27 @@ from dygraph.utils import Timer, calculate_eta from .val import evaluate +def check_logits_losses(logits, losses): + len_logits = len(logits) + len_losses = len(losses['types']) + if len_logits != len_losses: + raise RuntimeError( + 'The length of logits should equal to the types of loss config: {} != {}.' + .format(len_logits, len_losses)) + + +def loss_computation(logits, label, losses): + check_logits_losses(logits, losses) + loss = 0 + for i in range(len(logits)): + logit = logits[i] + if logit.shape[-2:] != label.shape[-2:]: + logit = F.resize_bilinear(logit, label.shape[-2:]) + loss_i = losses['types'][i](logit, label) + loss += losses['coef'][i] * loss_i + return loss + + def train(model, train_dataset, places=None, @@ -40,7 +63,8 @@ def train(model, log_iters=10, num_classes=None, num_workers=8, - use_vdl=False): + use_vdl=False, + losses=None): ignore_index = model.ignore_index nranks = ParallelEnv().nranks @@ -90,13 +114,17 @@ def train(model, images = data[0] labels = data[1].astype('int64') if nranks > 1: - loss = ddp_model(images, labels) + logits = ddp_model(images) + loss = loss_computation(logits, labels, losses) + # loss = ddp_model(images, labels) # apply_collective_grads sum grads over multiple gpus. loss = ddp_model.scale_loss(loss) loss.backward() ddp_model.apply_collective_grads() else: - loss = model(images, labels) + logits = model(images) + loss = loss_computation(logits, labels, losses) + # loss = model(images, labels) loss.backward() optimizer.minimize(loss) model.clear_gradients() diff --git a/dygraph/core/val.py b/dygraph/core/val.py index e5e8dd4bfaa502d510cacf0dabb67a42d76ac9d7..22e84a314cd4ffe8093f81dad724f3d7d12a05fe 100644 --- a/dygraph/core/val.py +++ b/dygraph/core/val.py @@ -19,6 +19,8 @@ import tqdm import cv2 from paddle.fluid.dygraph.base import to_variable import paddle.fluid as fluid +import paddle.nn.functional as F +import paddle import dygraph.utils.logger as logger from dygraph.utils import ConfusionMatrix @@ -47,7 +49,9 @@ def evaluate(model, for iter, (im, im_info, label) in tqdm.tqdm( enumerate(eval_dataset), total=total_iters): im = to_variable(im) - pred, _ = model(im) + # pred, _ = model(im) + logits = model(im) + pred = paddle.argmax(logits[0], axis=1) pred = pred.numpy().astype('float32') pred = np.squeeze(pred) for info in im_info[::-1]: diff --git a/dygraph/cvlibs/__init__.py b/dygraph/cvlibs/__init__.py index 4085427caed2dd8ebe15eda343456ca38b8011f6..18812001388cbfd1ecf7dc4d38398ddd91711af4 100644 --- a/dygraph/cvlibs/__init__.py +++ b/dygraph/cvlibs/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from . import manager +from . import param_init diff --git a/dygraph/cvlibs/manager.py b/dygraph/cvlibs/manager.py index e4a952b3cca0f451b18a7d5bb2c9c0c4654d8c11..339070069c7e39532ec7fe2c826851a8d0f53df6 100644 --- a/dygraph/cvlibs/manager.py +++ b/dygraph/cvlibs/manager.py @@ -44,19 +44,20 @@ class ComponentManager: def __init__(self): self._components_dict = dict() - + def __len__(self): return len(self._components_dict) def __repr__(self): - return "{}:{}".format(self.__class__.__name__, list(self._components_dict.keys())) + return "{}:{}".format(self.__class__.__name__, + list(self._components_dict.keys())) def __getitem__(self, item): if item not in self._components_dict.keys(): - raise KeyError("{} does not exist in the current {}".format(item, self)) + raise KeyError("{} does not exist in the current {}".format( + item, self)) return self._components_dict[item] - @property def components_dict(self): return self._components_dict @@ -74,7 +75,9 @@ class ComponentManager: # Currently only support class or function type if not (inspect.isclass(component) or inspect.isfunction(component)): - raise TypeError("Expect class/function type, but received {}".format(type(component))) + raise TypeError( + "Expect class/function type, but received {}".format( + type(component))) # Obtain the internal name of the component component_name = component.__name__ @@ -92,7 +95,7 @@ class ComponentManager: Args: components (function | class | list | tuple): support three types of components - + Returns: None """ @@ -104,8 +107,12 @@ class ComponentManager: else: component = components self._add_single_component(component) - + return components + MODELS = ComponentManager() -BACKBONES = ComponentManager() \ No newline at end of file +BACKBONES = ComponentManager() +DATASETS = ComponentManager() +TRANSFORMS = ComponentManager() +LOSSES = ComponentManager() diff --git a/dygraph/cvlibs/param_init.py b/dygraph/cvlibs/param_init.py new file mode 100644 index 0000000000000000000000000000000000000000..567399c0a0c7d2310931b1c0ccae13cd0d5422b1 --- /dev/null +++ b/dygraph/cvlibs/param_init.py @@ -0,0 +1,25 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid + + +def constant_init(param, **kwargs): + initializer = fluid.initializer.Constant(**kwargs) + initializer(param, param.block) + + +def normal_init(param, **kwargs): + initializer = fluid.initializer.Normal(**kwargs) + initializer(param, param.block) diff --git a/dygraph/datasets/ade.py b/dygraph/datasets/ade.py index 1c9065e38f677290d81bb0d8be5224b2a54c0adf..8cb8ec2cebfac98d52283ccd21796553db36bffe 100644 --- a/dygraph/datasets/ade.py +++ b/dygraph/datasets/ade.py @@ -19,11 +19,14 @@ from PIL import Image from .dataset import Dataset from dygraph.utils.download import download_file_and_uncompress +from dygraph.cvlibs import manager +from dygraph.transforms import Compose DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" +@manager.DATASETS.add_component class ADE20K(Dataset): """ADE20K dataset `http://sceneparsing.csail.mit.edu/`. Args: @@ -39,7 +42,7 @@ class ADE20K(Dataset): transforms=None, download=True): self.dataset_root = dataset_root - self.transforms = transforms + self.transforms = Compose(transforms) self.mode = mode self.file_list = list() self.num_classes = 150 diff --git a/dygraph/datasets/cityscapes.py b/dygraph/datasets/cityscapes.py index 7f6c8742ffa287b31dad7394b4cd89db2a72bfc9..ee28754d290ec9ca0526c34d10d9b0ccaa89e6b7 100644 --- a/dygraph/datasets/cityscapes.py +++ b/dygraph/datasets/cityscapes.py @@ -16,8 +16,11 @@ import os import glob from .dataset import Dataset +from dygraph.cvlibs import manager +from dygraph.transforms import Compose +@manager.DATASETS.add_component class Cityscapes(Dataset): """Cityscapes dataset `https://www.cityscapes-dataset.com/`. The folder structure is as follow: @@ -42,7 +45,7 @@ class Cityscapes(Dataset): def __init__(self, dataset_root, transforms=None, mode='train'): self.dataset_root = dataset_root - self.transforms = transforms + self.transforms = Compose(transforms) self.file_list = list() self.mode = mode self.num_classes = 19 diff --git a/dygraph/datasets/dataset.py b/dygraph/datasets/dataset.py index 5853103582cfb0ebefdaebe36e11d51231a0d493..c65e20fd2e97511baf4159a3a1eaf2661927a21e 100644 --- a/dygraph/datasets/dataset.py +++ b/dygraph/datasets/dataset.py @@ -17,8 +17,12 @@ import os import paddle.fluid as fluid import numpy as np from PIL import Image +from dygraph.cvlibs import manager +from dygraph.transforms import Compose + +@manager.DATASETS.add_component class Dataset(fluid.io.Dataset): """Pass in a custom dataset that conforms to the format. @@ -52,7 +56,7 @@ class Dataset(fluid.io.Dataset): separator=' ', transforms=None): self.dataset_root = dataset_root - self.transforms = transforms + self.transforms = Compose(transforms) self.file_list = list() self.mode = mode self.num_classes = num_classes diff --git a/dygraph/datasets/optic_disc_seg.py b/dygraph/datasets/optic_disc_seg.py index 82d18e8c5e51ec12b487a252b1c4ac1dc77838d1..2c6d2b2d56febbe4b45130528c970a43e53d0fd9 100644 --- a/dygraph/datasets/optic_disc_seg.py +++ b/dygraph/datasets/optic_disc_seg.py @@ -16,11 +16,14 @@ import os from .dataset import Dataset from dygraph.utils.download import download_file_and_uncompress +from dygraph.cvlibs import manager +from dygraph.transforms import Compose DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" +@manager.DATASETS.add_component class OpticDiscSeg(Dataset): def __init__(self, dataset_root=None, @@ -28,7 +31,7 @@ class OpticDiscSeg(Dataset): mode='train', download=True): self.dataset_root = dataset_root - self.transforms = transforms + self.transforms = Compose(transforms) self.file_list = list() self.mode = mode self.num_classes = 2 diff --git a/dygraph/datasets/voc.py b/dygraph/datasets/voc.py index d11f4c9e7e2d39577e75e0f5705814ce3a189c53..da1f9971ff440fbedf10ec2debc7ddaccd372226 100644 --- a/dygraph/datasets/voc.py +++ b/dygraph/datasets/voc.py @@ -13,13 +13,17 @@ # limitations under the License. import os + from .dataset import Dataset from dygraph.utils.download import download_file_and_uncompress +from dygraph.cvlibs import manager +from dygraph.transforms import Compose DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" +@manager.DATASETS.add_component class PascalVOC(Dataset): """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset, please run the voc_augment.py in tools. @@ -36,7 +40,7 @@ class PascalVOC(Dataset): transforms=None, download=True): self.dataset_root = dataset_root - self.transforms = transforms + self.transforms = Compose(transforms) self.mode = mode self.file_list = list() self.num_classes = 21 diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py index 52b73c3b7aa7e38868ca3588e0df6fd430431bf0..f3a62e3b39c80b47bb4d50e54f7dae4018cd2d32 100644 --- a/dygraph/models/__init__.py +++ b/dygraph/models/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. from .architectures import * +from .losses import * from .unet import UNet from .deeplab import * from .fcn import * from .pspnet import * +from .ocrnet import * diff --git a/dygraph/models/architectures/hrnet.py b/dygraph/models/architectures/hrnet.py index 3e12bc057fd9c90f412a86f0020fe6acbae7c89d..ea3db3b15f4bb2235761a3dea87ad3f49af3bbc9 100644 --- a/dygraph/models/architectures/hrnet.py +++ b/dygraph/models/architectures/hrnet.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import os import paddle import paddle.fluid as fluid @@ -23,6 +24,8 @@ from paddle.fluid.initializer import Normal from paddle.nn import SyncBatchNorm as BatchNorm from dygraph.cvlibs import manager +from dygraph.utils import utils +from dygraph.cvlibs import param_init __all__ = [ "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30", @@ -36,6 +39,7 @@ class HRNet(fluid.dygraph.Layer): https://arxiv.org/pdf/1908.07919.pdf. Args: + backbone_pretrained (str): the path of pretrained model. stage1_num_modules (int): number of modules for stage1. Default 1. stage1_num_blocks (list): number of blocks per module for stage1. Default [4]. stage1_num_channels (list): number of channels per branch for stage1. Default [64]. @@ -52,6 +56,7 @@ class HRNet(fluid.dygraph.Layer): """ def __init__(self, + backbone_pretrained=None, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -141,6 +146,9 @@ class HRNet(fluid.dygraph.Layer): has_se=self.has_se, name="st4") + if self.training: + self.init_weight(backbone_pretrained) + def forward(self, x, label=None, mode='train'): input_shape = x.shape[2:] conv1 = self.conv_layer1_1(x) @@ -163,7 +171,31 @@ class HRNet(fluid.dygraph.Layer): x3 = fluid.layers.resize_bilinear(st4[3], out_shape=(x0_h, x0_w)) x = fluid.layers.concat([st4[0], x1, x2, x3], axis=1) - return x + return [x] + + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. + """ + params = self.parameters() + for param in params: + param_name = param.name + if 'batch_norm' in param_name: + if 'w_0' in param_name: + param_init.constant_init(param, 1.0) + elif 'b_0' in param_name: + param_init.constant_init(param, 0.0) + if 'conv' in param_name and 'w_0' in param_name: + param_init.normal_init(param, scale=0.001) + + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) class ConvBNLayer(fluid.dygraph.Layer): @@ -184,18 +216,8 @@ class ConvBNLayer(fluid.dygraph.Layer): stride=stride, padding=(filter_size - 1) // 2, groups=groups, - param_attr=ParamAttr( - initializer=Normal(scale=0.001), name=name + "_weights"), bias_attr=False) - bn_name = name + '_bn' - self._batch_norm = BatchNorm( - num_filters, - weight_attr=ParamAttr( - name=bn_name + '_scale', - initializer=fluid.initializer.Constant(1.0)), - bias_attr=ParamAttr( - bn_name + '_offset', - initializer=fluid.initializer.Constant(0.0))) + self._batch_norm = BatchNorm(num_filters) self.act = act def forward(self, input): diff --git a/dygraph/models/architectures/mobilenetv3.py b/dygraph/models/architectures/mobilenetv3.py index 2899e3f76567cee638b07c86174896a19f51bd2f..07805c1b806d18f47d96b8ae1a35c734625f67b3 100644 --- a/dygraph/models/architectures/mobilenetv3.py +++ b/dygraph/models/architectures/mobilenetv3.py @@ -17,8 +17,9 @@ from __future__ import division from __future__ import print_function import math -import numpy as np +import os +import numpy as np import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr @@ -28,6 +29,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm from dygraph.models.architectures import layer_utils from dygraph.cvlibs import manager +from dygraph.utils import utils __all__ = [ "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5", @@ -46,6 +48,7 @@ def make_divisible(v, divisor=8, min_value=None): new_v += divisor return new_v + def get_padding_same(kernel_size, dilation_rate): """ SAME padding implementation given kernel_size and dilation_rate. @@ -53,7 +56,7 @@ def get_padding_same(kernel_size, dilation_rate): (F-(k+(k -1)*(r-1))+2*p)/s + 1 = F_new where F: a feature map k: kernel size, r: dilation rate, p: padding value, s: stride - F_new: new feature map + F_new: new feature map Args: kernel_size (int) dilation_rate (int) @@ -63,12 +66,19 @@ def get_padding_same(kernel_size, dilation_rate): """ k = kernel_size r = dilation_rate - padding_same = (k + (k - 1) * (r - 1) - 1)//2 + padding_same = (k + (k - 1) * (r - 1) - 1) // 2 return padding_same + class MobileNetV3(fluid.dygraph.Layer): - def __init__(self, scale=1.0, model_name="small", class_dim=1000, output_stride=None, **kwargs): + def __init__(self, + backbone_pretrained=None, + scale=1.0, + model_name="small", + class_dim=1000, + output_stride=None, + **kwargs): super(MobileNetV3, self).__init__() inplanes = 16 @@ -77,19 +87,21 @@ class MobileNetV3(fluid.dygraph.Layer): # k, exp, c, se, nl, s, [3, 16, 16, False, "relu", 1], [3, 64, 24, False, "relu", 2], - [3, 72, 24, False, "relu", 1], # output 1 -> out_index=2 + [3, 72, 24, False, "relu", 1], # output 1 -> out_index=2 [5, 72, 40, True, "relu", 2], [5, 120, 40, True, "relu", 1], - [5, 120, 40, True, "relu", 1], # output 2 -> out_index=5 + [5, 120, 40, True, "relu", 1], # output 2 -> out_index=5 [3, 240, 80, False, "hard_swish", 2], [3, 200, 80, False, "hard_swish", 1], [3, 184, 80, False, "hard_swish", 1], [3, 184, 80, False, "hard_swish", 1], [3, 480, 112, True, "hard_swish", 1], - [3, 672, 112, True, "hard_swish", 1], # output 3 -> out_index=11 + [3, 672, 112, True, "hard_swish", + 1], # output 3 -> out_index=11 [5, 672, 160, True, "hard_swish", 2], [5, 960, 160, True, "hard_swish", 1], - [5, 960, 160, True, "hard_swish", 1], # output 3 -> out_index=14 + [5, 960, 160, True, "hard_swish", + 1], # output 3 -> out_index=14 ] self.out_indices = [2, 5, 11, 14] @@ -98,17 +110,17 @@ class MobileNetV3(fluid.dygraph.Layer): elif model_name == "small": self.cfg = [ # k, exp, c, se, nl, s, - [3, 16, 16, True, "relu", 2], # output 1 -> out_index=0 + [3, 16, 16, True, "relu", 2], # output 1 -> out_index=0 [3, 72, 24, False, "relu", 2], - [3, 88, 24, False, "relu", 1], # output 2 -> out_index=3 + [3, 88, 24, False, "relu", 1], # output 2 -> out_index=3 [5, 96, 40, True, "hard_swish", 2], [5, 240, 40, True, "hard_swish", 1], [5, 240, 40, True, "hard_swish", 1], [5, 120, 48, True, "hard_swish", 1], - [5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7 + [5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7 [5, 288, 96, True, "hard_swish", 2], [5, 576, 96, True, "hard_swish", 1], - [5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10 + [5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10 ] self.out_indices = [0, 3, 7, 10] @@ -157,7 +169,6 @@ class MobileNetV3(fluid.dygraph.Layer): self.add_sublayer( sublayer=self.block_list[-1], name="conv" + str(i + 2)) inplanes = make_divisible(scale * c) - self.last_second_conv = ConvBNLayer( in_c=inplanes, @@ -189,8 +200,10 @@ class MobileNetV3(fluid.dygraph.Layer): param_attr=ParamAttr("fc_weights"), bias_attr=ParamAttr(name="fc_offset")) + self.init_weight(backbone_pretrained) + def modify_bottle_params(self, output_stride=None): - + if output_stride is not None and output_stride % 2 != 0: raise Exception("output stride must to be even number") if output_stride is not None: @@ -201,9 +214,9 @@ class MobileNetV3(fluid.dygraph.Layer): if stride > output_stride: rate = rate * _cfg[-1] self.cfg[i][-1] = 1 - + self.dilation_cfg[i] = rate - + def forward(self, inputs, label=None, dropout_prob=0.2): x = self.conv1(inputs) # A feature list saves each downsampling feature. @@ -223,6 +236,19 @@ class MobileNetV3(fluid.dygraph.Layer): return x, feat_list + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. + """ + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) + class ConvBNLayer(fluid.dygraph.Layer): def __init__(self, @@ -240,7 +266,7 @@ class ConvBNLayer(fluid.dygraph.Layer): super(ConvBNLayer, self).__init__() self.if_act = if_act self.act = act - + self.conv = fluid.dygraph.Conv2D( num_channels=in_c, num_filters=out_c, @@ -263,7 +289,7 @@ class ConvBNLayer(fluid.dygraph.Layer): name=name + "_bn_offset", regularizer=fluid.regularizer.L2DecayRegularizer( regularization_coeff=0.0))) - + self._act_op = layer_utils.Activation(act=None) def forward(self, x): @@ -304,14 +330,15 @@ class ResidualUnit(fluid.dygraph.Layer): if_act=True, act=act, name=name + "_expand") - - + self.bottleneck_conv = ConvBNLayer( in_c=mid_c, out_c=mid_c, filter_size=filter_size, stride=stride, - padding= get_padding_same(filter_size, dilation), #int((filter_size - 1) // 2) + (dilation - 1), + padding=get_padding_same( + filter_size, + dilation), #int((filter_size - 1) // 2) + (dilation - 1), dilation=dilation, num_groups=mid_c, if_act=True, @@ -329,6 +356,7 @@ class ResidualUnit(fluid.dygraph.Layer): act=None, name=name + "_linear") self.dilation = dilation + def forward(self, inputs): x = self.expand_conv(inputs) x = self.bottleneck_conv(x) @@ -386,6 +414,7 @@ def MobileNetV3_small_x0_75(**kwargs): model = MobileNetV3(model_name="small", scale=0.75, **kwargs) return model + @manager.BACKBONES.add_component def MobileNetV3_small_x1_0(**kwargs): model = MobileNetV3(model_name="small", scale=1.0, **kwargs) @@ -411,6 +440,7 @@ def MobileNetV3_large_x0_75(**kwargs): model = MobileNetV3(model_name="large", scale=0.75, **kwargs) return model + @manager.BACKBONES.add_component def MobileNetV3_large_x1_0(**kwargs): model = MobileNetV3(model_name="large", scale=1.0, **kwargs) diff --git a/dygraph/models/architectures/resnet_vd.py b/dygraph/models/architectures/resnet_vd.py index c27c810c46c0bbdc06053e747c7a7eaeb22be6e1..582934505385872c60ff92204fd862836e6ae7fb 100644 --- a/dygraph/models/architectures/resnet_vd.py +++ b/dygraph/models/architectures/resnet_vd.py @@ -30,6 +30,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm from dygraph.utils import utils from dygraph.models.architectures import layer_utils from dygraph.cvlibs import manager +from dygraph.utils import utils __all__ = [ "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd" @@ -47,18 +48,23 @@ class ConvBNLayer(fluid.dygraph.Layer): groups=1, is_vd_mode=False, act=None, - name=None, ): + name=None, + ): super(ConvBNLayer, self).__init__() self.is_vd_mode = is_vd_mode self._pool2d_avg = Pool2D( - pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg', ceil_mode=True) + pool_size=2, + pool_stride=2, + pool_padding=0, + pool_type='avg', + ceil_mode=True) self._conv = Conv2D( num_channels=num_channels, num_filters=num_filters, filter_size=filter_size, stride=stride, - padding=(filter_size - 1) // 2 if dilation ==1 else 0, + padding=(filter_size - 1) // 2 if dilation == 1 else 0, dilation=dilation, groups=groups, act=None, @@ -125,19 +131,20 @@ class BottleneckBlock(fluid.dygraph.Layer): num_filters=num_filters * 4, filter_size=1, stride=1, - is_vd_mode=False if if_first or stride==1 else True, + is_vd_mode=False if if_first or stride == 1 else True, name=name + "_branch1") self.shortcut = shortcut def forward(self, inputs): y = self.conv0(inputs) - + #################################################################### # If given dilation rate > 1, using corresponding padding if self.dilation > 1: padding = self.dilation - y = fluid.layers.pad(y, [0,0,0,0,padding,padding,padding,padding]) + y = fluid.layers.pad( + y, [0, 0, 0, 0, padding, padding, padding, padding]) ##################################################################### conv1 = self.conv1(y) conv2 = self.conv2(conv1) @@ -196,15 +203,21 @@ class BasicBlock(fluid.dygraph.Layer): else: short = self.short(inputs) y = fluid.layers.elementwise_add(x=short, y=conv1) - + layer_helper = LayerHelper(self.full_name(), act='relu') return layer_helper.append_activation(y) class ResNet_vd(fluid.dygraph.Layer): - def __init__(self, layers=50, class_dim=1000, output_stride=None, multi_grid=(1, 2, 4), **kwargs): + def __init__(self, + backbone_pretrained=None, + layers=50, + class_dim=1000, + output_stride=None, + multi_grid=(1, 2, 4), + **kwargs): super(ResNet_vd, self).__init__() - + self.layers = layers supported_layers = [18, 34, 50, 101, 152, 200] assert layers in supported_layers, \ @@ -221,11 +234,11 @@ class ResNet_vd(fluid.dygraph.Layer): depth = [3, 8, 36, 3] elif layers == 200: depth = [3, 12, 48, 3] - num_channels = [64, 256, 512, - 1024] if layers >= 50 else [64, 64, 128, 256] + num_channels = [64, 256, 512, 1024 + ] if layers >= 50 else [64, 64, 128, 256] num_filters = [64, 128, 256, 512] - dilation_dict=None + dilation_dict = None if output_stride == 8: dilation_dict = {2: 2, 3: 4} elif output_stride == 16: @@ -254,13 +267,13 @@ class ResNet_vd(fluid.dygraph.Layer): name="conv1_3") self.pool2d_max = Pool2D( pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') - + # self.block_list = [] self.stage_list = [] if layers >= 50: for block in range(len(depth)): shortcut = False - block_list=[] + block_list = [] for i in range(depth[block]): if layers in [101, 152] and block == 2: if i == 0: @@ -269,11 +282,12 @@ class ResNet_vd(fluid.dygraph.Layer): conv_name = "res" + str(block + 2) + "b" + str(i) else: conv_name = "res" + str(block + 2) + chr(97 + i) - + ############################################################################### # Add dilation rate for some segmentation tasks, if dilation_dict is not None. - dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1 - + dilation_rate = dilation_dict[ + block] if dilation_dict and block in dilation_dict else 1 + # Actually block here is 'stage', and i is 'block' in 'stage' # At the stage 4, expand the the dilation_rate using multi_grid, default (1, 2, 4) if block == 3: @@ -284,9 +298,11 @@ class ResNet_vd(fluid.dygraph.Layer): bottleneck_block = self.add_sublayer( 'bb_%d_%d' % (block, i), BottleneckBlock( - num_channels=num_channels[block] if i == 0 else num_filters[block] * 4, + num_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1, + stride=2 if i == 0 and block != 0 + and dilation_rate == 1 else 1, shortcut=shortcut, if_first=block == i == 0, name=conv_name, @@ -298,7 +314,7 @@ class ResNet_vd(fluid.dygraph.Layer): else: for block in range(len(depth)): shortcut = False - block_list=[] + block_list = [] for i in range(depth[block]): conv_name = "res" + str(block + 2) + chr(97 + i) basic_block = self.add_sublayer( @@ -330,6 +346,8 @@ class ResNet_vd(fluid.dygraph.Layer): name="fc_0.w_0"), bias_attr=ParamAttr(name="fc_0.b_0")) + self.init_weight(backbone_pretrained) + def forward(self, inputs): y = self.conv1_1(inputs) y = self.conv1_2(y) @@ -343,7 +361,7 @@ class ResNet_vd(fluid.dygraph.Layer): y = block(y) #print("stage {} block {}".format(i+1, j+1), y.shape) feat_list.append(y) - + y = self.pool2d_avg(y) y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels]) y = self.out(y) @@ -355,8 +373,18 @@ class ResNet_vd(fluid.dygraph.Layer): # if os.path.exists(pretrained_model): # utils.load_pretrained_model(self, pretrained_model) - - + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. + """ + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) def ResNet18_vd(**args): @@ -368,11 +396,13 @@ def ResNet34_vd(**args): model = ResNet_vd(layers=34, **args) return model + @manager.BACKBONES.add_component def ResNet50_vd(**args): model = ResNet_vd(layers=50, **args) return model + @manager.BACKBONES.add_component def ResNet101_vd(**args): model = ResNet_vd(layers=101, **args) @@ -386,4 +416,4 @@ def ResNet152_vd(**args): def ResNet200_vd(**args): model = ResNet_vd(layers=200, **args) - return model \ No newline at end of file + return model diff --git a/dygraph/models/architectures/xception_deeplab.py b/dygraph/models/architectures/xception_deeplab.py index f96dcb6936e25444c1d79b2461b941634fbb4c2f..4f7d97f837fcc2b7394be3ceef15b06387a5844a 100644 --- a/dygraph/models/architectures/xception_deeplab.py +++ b/dygraph/models/architectures/xception_deeplab.py @@ -1,3 +1,19 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr @@ -7,6 +23,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm from dygraph.models.architectures import layer_utils from dygraph.cvlibs import manager +from dygraph.utils import utils __all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"] @@ -86,11 +103,11 @@ class ConvBNLayer(fluid.dygraph.Layer): momentum=0.99, weight_attr=ParamAttr(name=name + "/BatchNorm/gamma"), bias_attr=ParamAttr(name=name + "/BatchNorm/beta")) - + self._act_op = layer_utils.Activation(act=act) def forward(self, inputs): - + return self._act_op(self._bn(self._conv(inputs))) @@ -121,7 +138,7 @@ class Seperate_Conv(fluid.dygraph.Layer): momentum=0.99, weight_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"), bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta")) - + self._act_op1 = layer_utils.Activation(act=act) self._conv2 = Conv2D( @@ -139,9 +156,8 @@ class Seperate_Conv(fluid.dygraph.Layer): momentum=0.99, weight_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"), bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta")) - + self._act_op2 = layer_utils.Activation(act=act) - def forward(self, inputs): x = self._conv1(inputs) @@ -254,11 +270,16 @@ class Xception_Block(fluid.dygraph.Layer): class XceptionDeeplab(fluid.dygraph.Layer): - + #def __init__(self, backbone, class_dim=1000): # add output_stride - def __init__(self, backbone, output_stride=16, class_dim=1000, **kwargs): - + def __init__(self, + backbone, + backbone_pretrained=None, + output_stride=16, + class_dim=1000, + **kwargs): + super(XceptionDeeplab, self).__init__() bottleneck_params = gen_bottleneck_params(backbone) @@ -280,7 +301,6 @@ class XceptionDeeplab(fluid.dygraph.Layer): padding=1, act="relu", name=self.backbone + "/entry_flow/conv2") - """ bottleneck_params = { "entry_flow": (3, [2, 2, 2], [128, 256, 728]), @@ -381,6 +401,8 @@ class XceptionDeeplab(fluid.dygraph.Layer): param_attr=ParamAttr(name="fc_weights"), bias_attr=ParamAttr(name="fc_bias")) + self.init_weight(backbone_pretrained) + def forward(self, inputs): x = self._conv1(inputs) x = self._conv2(x) @@ -394,18 +416,32 @@ class XceptionDeeplab(fluid.dygraph.Layer): x = self._exit_flow_1(x) x = self._exit_flow_2(x) feat_list.append(x) - + x = self._drop(x) x = self._pool(x) x = fluid.layers.squeeze(x, axes=[2, 3]) x = self._fc(x) return x, feat_list + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. + """ + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) + def Xception41_deeplab(**args): model = XceptionDeeplab('xception_41', **args) return model + @manager.BACKBONES.add_component def Xception65_deeplab(**args): model = XceptionDeeplab("xception_65", **args) @@ -414,4 +450,4 @@ def Xception65_deeplab(**args): def Xception71_deeplab(**args): model = XceptionDeeplab("xception_71", **args) - return model \ No newline at end of file + return model diff --git a/dygraph/models/deeplab.py b/dygraph/models/deeplab.py index e9a0167f044e658a3c494d2c465f3ed729ef9367..6911b63900d62b427e94a2b22e4919f6b664f250 100644 --- a/dygraph/models/deeplab.py +++ b/dygraph/models/deeplab.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os from dygraph.cvlibs import manager @@ -23,10 +22,12 @@ from paddle.fluid.dygraph import Conv2D from dygraph.utils import utils -__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8", - "deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8", - "deeplabv3p_xception65_deeplab", - "deeplabv3p_mobilenetv3_large", "deeplabv3p_mobilenetv3_small"] +__all__ = [ + 'DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8", + "deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8", + "deeplabv3p_xception65_deeplab", "deeplabv3p_mobilenetv3_large", + "deeplabv3p_mobilenetv3_small" +] class ImageAverage(dygraph.Layer): @@ -40,9 +41,8 @@ class ImageAverage(dygraph.Layer): def __init__(self, num_channels): super(ImageAverage, self).__init__() - self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels, - num_filters=256, - filter_size=1) + self.conv_bn_relu = layer_utils.ConvBnRelu( + num_channels, num_filters=256, filter_size=1) def forward(self, input): x = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True) @@ -69,44 +69,49 @@ class ASPP(dygraph.Layer): elif output_stride == 8: aspp_ratios = (12, 24, 36) else: - raise NotImplementedError("Only support output_stride is 8 or 16, but received{}".format(output_stride)) + raise NotImplementedError( + "Only support output_stride is 8 or 16, but received{}".format( + output_stride)) self.image_average = ImageAverage(num_channels=in_channels) # The first aspp using 1*1 conv - self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels, - num_filters=256, - filter_size=1, - using_sep_conv=False) + self.aspp1 = layer_utils.ConvBnRelu( + num_channels=in_channels, + num_filters=256, + filter_size=1, + using_sep_conv=False) # The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0] - self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels, - num_filters=256, - filter_size=3, - using_sep_conv=using_sep_conv, - dilation=aspp_ratios[0], - padding=aspp_ratios[0]) + self.aspp2 = layer_utils.ConvBnRelu( + num_channels=in_channels, + num_filters=256, + filter_size=3, + using_sep_conv=using_sep_conv, + dilation=aspp_ratios[0], + padding=aspp_ratios[0]) # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1] - self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels, - num_filters=256, - filter_size=3, - using_sep_conv=using_sep_conv, - dilation=aspp_ratios[1], - padding=aspp_ratios[1]) + self.aspp3 = layer_utils.ConvBnRelu( + num_channels=in_channels, + num_filters=256, + filter_size=3, + using_sep_conv=using_sep_conv, + dilation=aspp_ratios[1], + padding=aspp_ratios[1]) # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2] - self.aspp4 = layer_utils.ConvBnRelu(num_channels=in_channels, - num_filters=256, - filter_size=3, - using_sep_conv=using_sep_conv, - dilation=aspp_ratios[2], - padding=aspp_ratios[2]) + self.aspp4 = layer_utils.ConvBnRelu( + num_channels=in_channels, + num_filters=256, + filter_size=3, + using_sep_conv=using_sep_conv, + dilation=aspp_ratios[2], + padding=aspp_ratios[2]) # After concat op, using 1*1 conv - self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280, - num_filters=256, - filter_size=1) + self.conv_bn_relu = layer_utils.ConvBnRelu( + num_channels=1280, num_filters=256, filter_size=1) def forward(self, x): @@ -136,23 +141,23 @@ class Decoder(dygraph.Layer): def __init__(self, num_classes, in_channels, using_sep_conv=True): super(Decoder, self).__init__() - self.conv_bn_relu1 = layer_utils.ConvBnRelu(num_channels=in_channels, - num_filters=48, - filter_size=1) - - self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=304, - num_filters=256, - filter_size=3, - using_sep_conv=using_sep_conv, - padding=1) - self.conv_bn_relu3 = layer_utils.ConvBnRelu(num_channels=256, - num_filters=256, - filter_size=3, - using_sep_conv=using_sep_conv, - padding=1) - self.conv = Conv2D(num_channels=256, - num_filters=num_classes, - filter_size=1) + self.conv_bn_relu1 = layer_utils.ConvBnRelu( + num_channels=in_channels, num_filters=48, filter_size=1) + + self.conv_bn_relu2 = layer_utils.ConvBnRelu( + num_channels=304, + num_filters=256, + filter_size=3, + using_sep_conv=using_sep_conv, + padding=1) + self.conv_bn_relu3 = layer_utils.ConvBnRelu( + num_channels=256, + num_filters=256, + filter_size=3, + using_sep_conv=using_sep_conv, + padding=1) + self.conv = Conv2D( + num_channels=256, num_filters=num_classes, filter_size=1) def forward(self, x, low_level_feat): low_level_feat = self.conv_bn_relu1(low_level_feat) @@ -164,6 +169,7 @@ class Decoder(dygraph.Layer): return x +@manager.MODELS.add_component class DeepLabV3P(dygraph.Layer): """ The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder @@ -173,9 +179,11 @@ class DeepLabV3P(dygraph.Layer): (https://arxiv.org/abs/1802.02611) Args: - backbone (str): backbone name, currently support Xception65, Resnet101_vd. Default Resnet101_vd. + num_classes (int): the unique number of target classes. + + backbone (paddle.nn.Layer): backbone networks, currently support Xception65, Resnet101_vd. Default Resnet101_vd. - num_classes (int): the unique number of target classes. Default 2. + model_pretrained (str): the path of pretrained model. output_stride (int): the ratio of input size and final feature size. Default 16. @@ -193,28 +201,29 @@ class DeepLabV3P(dygraph.Layer): using_sep_conv (bool): a bool value indicates whether using separable convolutions in ASPP and Decoder components. Default True. - pretrained_model (str): the pretrained_model path of backbone. """ def __init__(self, + num_classes, backbone, - num_classes=2, + model_pretrained=None, output_stride=16, backbone_indices=(0, 3), backbone_channels=(256, 2048), ignore_index=255, - using_sep_conv=True, - pretrained_model=None): + using_sep_conv=True): super(DeepLabV3P, self).__init__() - self.backbone = manager.BACKBONES[backbone](output_stride=output_stride) + # self.backbone = manager.BACKBONES[backbone](output_stride=output_stride) + self.backbone = backbone self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv) - self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv) + self.decoder = Decoder(num_classes, backbone_channels[0], + using_sep_conv) self.ignore_index = ignore_index self.EPS = 1e-5 self.backbone_indices = backbone_indices - self.init_weight(pretrained_model) + self.init_weight(model_pretrained) def forward(self, input, label=None): @@ -238,14 +247,14 @@ class DeepLabV3P(dygraph.Layer): """ Initialize the parameters of model parts. Args: - pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. """ if pretrained_model is not None: if os.path.exists(pretrained_model): - utils.load_pretrained_model(self.backbone, pretrained_model) - # utils.load_pretrained_model(self, pretrained_model) - # for param in self.backbone.parameters(): - # param.stop_gradient = True + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) def _get_loss(self, logit, label): """ @@ -271,7 +280,7 @@ class DeepLabV3P(dygraph.Layer): loss = loss * mask avg_loss = fluid.layers.mean(loss) / ( - fluid.layers.mean(mask) + self.EPS) + fluid.layers.mean(mask) + self.EPS) label.stop_gradient = True mask.stop_gradient = True @@ -290,52 +299,65 @@ def build_decoder(num_classes, using_sep_conv): @manager.MODELS.add_component def deeplabv3p_resnet101_vd(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs) + return DeepLabV3P( + backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs) @manager.MODELS.add_component def deeplabv3p_resnet101_vd_os8(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs) + return DeepLabV3P( + backbone='ResNet101_vd', + output_stride=8, + pretrained_model=pretrained_model, + **kwargs) @manager.MODELS.add_component def deeplabv3p_resnet50_vd(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs) + return DeepLabV3P( + backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs) @manager.MODELS.add_component def deeplabv3p_resnet50_vd_os8(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs) + return DeepLabV3P( + backbone='ResNet50_vd', + output_stride=8, + pretrained_model=pretrained_model, + **kwargs) @manager.MODELS.add_component def deeplabv3p_xception65_deeplab(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='Xception65_deeplab', - pretrained_model=pretrained_model, - backbone_indices=(0, 1), - backbone_channels=(128, 2048), - **kwargs) + return DeepLabV3P( + backbone='Xception65_deeplab', + pretrained_model=pretrained_model, + backbone_indices=(0, 1), + backbone_channels=(128, 2048), + **kwargs) @manager.MODELS.add_component def deeplabv3p_mobilenetv3_large(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='MobileNetV3_large_x1_0', - pretrained_model=pretrained_model, - backbone_indices=(0, 3), - backbone_channels=(24, 160), - **kwargs) + return DeepLabV3P( + backbone='MobileNetV3_large_x1_0', + pretrained_model=pretrained_model, + backbone_indices=(0, 3), + backbone_channels=(24, 160), + **kwargs) @manager.MODELS.add_component def deeplabv3p_mobilenetv3_small(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='MobileNetV3_small_x1_0', - pretrained_model=pretrained_model, - backbone_indices=(0, 3), - backbone_channels=(16, 96), - **kwargs) + return DeepLabV3P( + backbone='MobileNetV3_small_x1_0', + pretrained_model=pretrained_model, + backbone_indices=(0, 3), + backbone_channels=(16, 96), + **kwargs) diff --git a/dygraph/models/fcn.py b/dygraph/models/fcn.py index ce1ab409dbc39f55c56a49be794d1bbb23e9336b..a852cff88bc6062c154c8c9df0d5c99b604f0abf 100644 --- a/dygraph/models/fcn.py +++ b/dygraph/models/fcn.py @@ -25,6 +25,7 @@ from paddle.nn import SyncBatchNorm as BatchNorm from dygraph.cvlibs import manager from dygraph import utils +from dygraph.cvlibs import param_init __all__ = [ "fcn_hrnet_w18_small_v1", "fcn_hrnet_w18_small_v2", "fcn_hrnet_w18", @@ -33,115 +34,105 @@ __all__ = [ ] +@manager.MODELS.add_component class FCN(fluid.dygraph.Layer): """ Fully Convolutional Networks for Semantic Segmentation. https://arxiv.org/abs/1411.4038 Args: - backbone (str): backbone name, num_classes (int): the unique number of target classes. - in_channels (int): the channels of input feature maps. + + backbone (paddle.nn.Layer): backbone networks. + + model_pretrained (str): the path of pretrained model. + + backbone_indices (tuple): one values in the tuple indicte the indices of output of backbone.Default -1. + + backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index. + channels (int): channels after conv layer before the last one. - pretrained_model (str): the path of pretrained model. + ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255. """ def __init__(self, - backbone, num_classes, - in_channels, + backbone, + model_pretrained=None, + backbone_indices=(-1, ), + backbone_channels=(270, ), channels=None, - pretrained_model=None, ignore_index=255, **kwargs): super(FCN, self).__init__() self.num_classes = num_classes + self.backbone_indices = backbone_indices self.ignore_index = ignore_index self.EPS = 1e-5 if channels is None: - channels = in_channels + channels = backbone_channels[backbone_indices[0]] - self.backbone = manager.BACKBONES[backbone](**kwargs) + self.backbone = backbone self.conv_last_2 = ConvBNLayer( - num_channels=in_channels, + num_channels=backbone_channels[backbone_indices[0]], num_filters=channels, filter_size=1, - stride=1, - name='conv-2') + stride=1) self.conv_last_1 = Conv2D( num_channels=channels, num_filters=self.num_classes, filter_size=1, stride=1, - padding=0, - param_attr=ParamAttr( - initializer=Normal(scale=0.001), name='conv-1_weights')) - self.init_weight(pretrained_model) + padding=0) + if self.training: + self.init_weight(model_pretrained) - def forward(self, x, label=None, mode='train'): + def forward(self, x): input_shape = x.shape[2:] - x = self.backbone(x) + fea_list = self.backbone(x) + x = fea_list[self.backbone_indices[0]] x = self.conv_last_2(x) logit = self.conv_last_1(x) logit = fluid.layers.resize_bilinear(logit, input_shape) - - if self.training: - if label is None: - raise Exception('Label is need during training') - return self._get_loss(logit, label) - else: - score_map = fluid.layers.softmax(logit, axis=1) - score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) - pred = fluid.layers.argmax(score_map, axis=3) - pred = fluid.layers.unsqueeze(pred, axes=[3]) - return pred, score_map + return [logit] + + # if self.training: + # if label is None: + # raise Exception('Label is need during training') + # return self._get_loss(logit, label) + # else: + # score_map = fluid.layers.softmax(logit, axis=1) + # score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) + # pred = fluid.layers.argmax(score_map, axis=3) + # pred = fluid.layers.unsqueeze(pred, axes=[3]) + # return pred, score_map def init_weight(self, pretrained_model=None): """ Initialize the parameters of model parts. Args: - pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. """ + params = self.parameters() + for param in params: + param_name = param.name + if 'batch_norm' in param_name: + if 'w_0' in param_name: + param_init.constant_init(param, 1.0) + elif 'b_0' in param_name: + param_init.constant_init(param, 0.0) + if 'conv' in param_name and 'w_0' in param_name: + param_init.normal_init(param, scale=0.001) + if pretrained_model is not None: if os.path.exists(pretrained_model): - utils.load_pretrained_model(self.backbone, pretrained_model) utils.load_pretrained_model(self, pretrained_model) else: raise Exception('Pretrained model is not found: {}'.format( pretrained_model)) - def _get_loss(self, logit, label): - """ - compute forward loss of the model - - Args: - logit (tensor): the logit of model output - label (tensor): ground truth - - Returns: - avg_loss (tensor): forward loss - """ - logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) - label = fluid.layers.transpose(label, [0, 2, 3, 1]) - mask = label != self.ignore_index - mask = fluid.layers.cast(mask, 'float32') - loss, probs = fluid.layers.softmax_with_cross_entropy( - logit, - label, - ignore_index=self.ignore_index, - return_softmax=True, - axis=-1) - - loss = loss * mask - avg_loss = fluid.layers.mean(loss) / ( - fluid.layers.mean(mask) + self.EPS) - - label.stop_gradient = True - mask.stop_gradient = True - return avg_loss - class ConvBNLayer(fluid.dygraph.Layer): def __init__(self, @@ -150,8 +141,7 @@ class ConvBNLayer(fluid.dygraph.Layer): filter_size, stride=1, groups=1, - act="relu", - name=None): + act="relu"): super(ConvBNLayer, self).__init__() self._conv = Conv2D( @@ -161,18 +151,8 @@ class ConvBNLayer(fluid.dygraph.Layer): stride=stride, padding=(filter_size - 1) // 2, groups=groups, - param_attr=ParamAttr( - initializer=Normal(scale=0.001), name=name + "_weights"), bias_attr=False) - bn_name = name + '_bn' - self._batch_norm = BatchNorm( - num_filters, - weight_attr=ParamAttr( - name=bn_name + '_scale', - initializer=fluid.initializer.Constant(1.0)), - bias_attr=ParamAttr( - bn_name + '_offset', - initializer=fluid.initializer.Constant(0.0))) + self._batch_norm = BatchNorm(num_filters) self.act = act def forward(self, input): @@ -185,49 +165,49 @@ class ConvBNLayer(fluid.dygraph.Layer): @manager.MODELS.add_component def fcn_hrnet_w18_small_v1(*args, **kwargs): - return FCN(backbone='HRNet_W18_Small_V1', in_channels=240, **kwargs) + return FCN(backbone='HRNet_W18_Small_V1', backbone_channels=(240), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w18_small_v2(*args, **kwargs): - return FCN(backbone='HRNet_W18_Small_V2', in_channels=270, **kwargs) + return FCN(backbone='HRNet_W18_Small_V2', backbone_channels=(270), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w18(*args, **kwargs): - return FCN(backbone='HRNet_W18', in_channels=270, **kwargs) + return FCN(backbone='HRNet_W18', backbone_channels=(270), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w30(*args, **kwargs): - return FCN(backbone='HRNet_W30', in_channels=450, **kwargs) + return FCN(backbone='HRNet_W30', backbone_channels=(450), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w32(*args, **kwargs): - return FCN(backbone='HRNet_W32', in_channels=480, **kwargs) + return FCN(backbone='HRNet_W32', backbone_channels=(480), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w40(*args, **kwargs): - return FCN(backbone='HRNet_W40', in_channels=600, **kwargs) + return FCN(backbone='HRNet_W40', backbone_channels=(600), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w44(*args, **kwargs): - return FCN(backbone='HRNet_W44', in_channels=660, **kwargs) + return FCN(backbone='HRNet_W44', backbone_channels=(660), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w48(*args, **kwargs): - return FCN(backbone='HRNet_W48', in_channels=720, **kwargs) + return FCN(backbone='HRNet_W48', backbone_channels=(720), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w60(*args, **kwargs): - return FCN(backbone='HRNet_W60', in_channels=900, **kwargs) + return FCN(backbone='HRNet_W60', backbone_channels=(900), **kwargs) @manager.MODELS.add_component def fcn_hrnet_w64(*args, **kwargs): - return FCN(backbone='HRNet_W64', in_channels=960, **kwargs) + return FCN(backbone='HRNet_W64', backbone_channels=(960), **kwargs) diff --git a/dygraph/models/losses/__init__.py b/dygraph/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f58a9fe1dccce025fa5ee9dec8887fbfc3b9deb8 --- /dev/null +++ b/dygraph/models/losses/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cross_entroy_loss import CrossEntropyLoss diff --git a/dygraph/models/losses/cross_entroy_loss.py b/dygraph/models/losses/cross_entroy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d37116d186556b1c9e3a3e2d71ebca84653348ca --- /dev/null +++ b/dygraph/models/losses/cross_entroy_loss.py @@ -0,0 +1,119 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +import paddle.nn.functional as F + +from dygraph.cvlibs import manager +''' +@manager.LOSSES.add_component +class CrossEntropyLoss(nn.CrossEntropyLoss): + """ + Implements the cross entropy loss function. + + Args: + weight (Tensor): Weight tensor, a manual rescaling weight given + to each class and the shape is (C). It has the same dimensions as class + number and the data type is float32, float64. Default ``'None'``. + ignore_index (int64): Specifies a target value that is ignored + and does not contribute to the input gradient. Default ``255``. + reduction (str): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned. + Default ``'mean'``. + + """ + + def __init__(self, weight=None, ignore_index=255, reduction='mean'): + self.weight = weight + self.ignore_index = ignore_index + self.reduction = reduction + self.EPS = 1e-5 + if self.reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or" + " 'none', but received %s, which is not allowed." % + self.reduction) + + def forward(self, logit, label): + """ + Forward computation. + Args: + logit (Tensor): logit tensor, the data type is float32, float64. Shape is + (N, C), where C is number of classes, and if shape is more than 2D, this + is (N, C, D1, D2,..., Dk), k >= 1. + label (Variable): label tensor, the data type is int64. Shape is (N), where each + value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is + (N, D1, D2,..., Dk), k >= 1. + """ + loss = paddle.nn.functional.cross_entropy( + logit, + label, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction) + + mask = label != self.ignore_index + mask = paddle.cast(mask, 'float32') + avg_loss = loss / (paddle.mean(mask) + self.EPS) + + label.stop_gradient = True + mask.stop_gradient = True + return avg_loss +''' + + +@manager.LOSSES.add_component +class CrossEntropyLoss(nn.Layer): + """ + Implements the cross entropy loss function. + + Args: + ignore_index (int64): Specifies a target value that is ignored + and does not contribute to the input gradient. Default ``255``. + """ + + def __init__(self, ignore_index=255): + super(CrossEntropyLoss, self).__init__() + self.ignore_index = ignore_index + self.EPS = 1e-5 + + def forward(self, logit, label): + """ + Forward computation. + Args: + logit (Tensor): logit tensor, the data type is float32, float64. Shape is + (N, C), where C is number of classes, and if shape is more than 2D, this + is (N, C, D1, D2,..., Dk), k >= 1. + label (Variable): label tensor, the data type is int64. Shape is (N), where each + value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is + (N, D1, D2,..., Dk), k >= 1. + """ + if len(label.shape) != len(logit.shape): + label = paddle.unsqueeze(label, 1) + + loss = F.softmax_with_cross_entropy( + logit, label, ignore_index=self.ignore_index, axis=1) + loss = paddle.reduce_mean(loss) + + mask = label != self.ignore_index + mask = paddle.cast(mask, 'float32') + avg_loss = loss / (paddle.mean(mask) + self.EPS) + + label.stop_gradient = True + mask.stop_gradient = True + return avg_loss diff --git a/dygraph/models/ocrnet.py b/dygraph/models/ocrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..bdadd6d5b2a1e1946a9207eaa166705fb51da06e --- /dev/null +++ b/dygraph/models/ocrnet.py @@ -0,0 +1,215 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle.fluid as fluid +from paddle.fluid.dygraph import Sequential, Conv2D + +from dygraph.cvlibs import manager +from dygraph.models.architectures.layer_utils import ConvBnRelu +from dygraph import utils + + +class SpatialGatherBlock(fluid.dygraph.Layer): + def forward(self, pixels, regions): + n, c, h, w = pixels.shape + _, k, _, _ = regions.shape + + # pixels: from (n, c, h, w) to (n, h*w, c) + pixels = fluid.layers.reshape(pixels, (n, c, h * w)) + pixels = fluid.layers.transpose(pixels, (0, 2, 1)) + + # regions: from (n, k, h, w) to (n, k, h*w) + regions = fluid.layers.reshape(regions, (n, k, h * w)) + regions = fluid.layers.softmax(regions, axis=2) + + # feats: from (n, k, c) to (n, c, k, 1) + feats = fluid.layers.matmul(regions, pixels) + feats = fluid.layers.transpose(feats, (0, 2, 1)) + feats = fluid.layers.unsqueeze(feats, axes=[-1]) + + return feats + + +class SpatialOCRModule(fluid.dygraph.Layer): + def __init__(self, + in_channels, + key_channels, + out_channels, + dropout_rate=0.1): + super(SpatialOCRModule, self).__init__() + + self.attention_block = ObjectAttentionBlock(in_channels, key_channels) + self.dropout_rate = dropout_rate + self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1) + + def forward(self, pixels, regions): + context = self.attention_block(pixels, regions) + feats = fluid.layers.concat([context, pixels], axis=1) + + feats = self.conv1x1(feats) + feats = fluid.layers.dropout(feats, self.dropout_rate) + + return feats + + +class ObjectAttentionBlock(fluid.dygraph.Layer): + def __init__(self, in_channels, key_channels): + super(ObjectAttentionBlock, self).__init__() + + self.in_channels = in_channels + self.key_channels = key_channels + + self.f_pixel = Sequential( + ConvBnRelu(in_channels, key_channels, 1), + ConvBnRelu(key_channels, key_channels, 1)) + + self.f_object = Sequential( + ConvBnRelu(in_channels, key_channels, 1), + ConvBnRelu(key_channels, key_channels, 1)) + + self.f_down = ConvBnRelu(in_channels, key_channels, 1) + + self.f_up = ConvBnRelu(key_channels, in_channels, 1) + + def forward(self, x, proxy): + n, _, h, w = x.shape + + # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels) + query = self.f_pixel(x) + query = fluid.layers.reshape(query, (n, self.key_channels, -1)) + query = fluid.layers.transpose(query, (0, 2, 1)) + + # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2) + key = self.f_object(proxy) + key = fluid.layers.reshape(key, (n, self.key_channels, -1)) + + # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels) + value = self.f_down(proxy) + value = fluid.layers.reshape(value, (n, self.key_channels, -1)) + value = fluid.layers.transpose(value, (0, 2, 1)) + + # sim_map (n, h1*w1, h2*w2) + sim_map = fluid.layers.matmul(query, key) + sim_map = (self.key_channels**-.5) * sim_map + sim_map = fluid.layers.softmax(sim_map, axis=-1) + + # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1) + context = fluid.layers.matmul(sim_map, value) + context = fluid.layers.transpose(context, (0, 2, 1)) + context = fluid.layers.reshape(context, (n, self.key_channels, h, w)) + context = self.f_up(context) + + return context + + +@manager.MODELS.add_component +class OCRNet(fluid.dygraph.Layer): + def __init__(self, + num_classes, + backbone, + model_pretrained=None, + in_channels=None, + ocr_mid_channels=512, + ocr_key_channels=256, + ignore_index=255): + super(OCRNet, self).__init__() + + self.ignore_index = ignore_index + self.num_classes = num_classes + self.EPS = 1e-5 + + self.backbone = backbone + self.spatial_gather = SpatialGatherBlock() + self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels, + ocr_mid_channels) + self.conv3x3_ocr = ConvBnRelu( + in_channels, ocr_mid_channels, 3, padding=1) + self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1) + + self.aux_head = Sequential( + ConvBnRelu(in_channels, in_channels, 3, padding=1), + Conv2D(in_channels, self.num_classes, 1)) + + self.init_weight(model_pretrained) + + def forward(self, x, label=None): + feats = self.backbone(x) + + soft_regions = self.aux_head(feats) + pixels = self.conv3x3_ocr(feats) + + object_regions = self.spatial_gather(pixels, soft_regions) + ocr = self.spatial_ocr(pixels, object_regions) + + logit = self.cls_head(ocr) + logit = fluid.layers.resize_bilinear(logit, x.shape[2:]) + + if self.training: + soft_regions = fluid.layers.resize_bilinear(soft_regions, + x.shape[2:]) + cls_loss = self._get_loss(logit, label) + aux_loss = self._get_loss(soft_regions, label) + return cls_loss + 0.4 * aux_loss + + score_map = fluid.layers.softmax(logit, axis=1) + score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) + pred = fluid.layers.argmax(score_map, axis=3) + pred = fluid.layers.unsqueeze(pred, axes=[3]) + return pred, score_map + + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the path of pretrained model.. Defaults to None. + """ + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) + + def _get_loss(self, logit, label): + """ + compute forward loss of the model + + Args: + logit (tensor): the logit of model output + label (tensor): ground truth + + Returns: + avg_loss (tensor): forward loss + """ + logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) + label = fluid.layers.transpose(label, [0, 2, 3, 1]) + mask = label != self.ignore_index + mask = fluid.layers.cast(mask, 'float32') + loss, probs = fluid.layers.softmax_with_cross_entropy( + logit, + label, + ignore_index=self.ignore_index, + return_softmax=True, + axis=-1) + + loss = loss * mask + avg_loss = fluid.layers.mean(loss) / ( + fluid.layers.mean(mask) + self.EPS) + + label.stop_gradient = True + mask.stop_gradient = True + + return avg_loss diff --git a/dygraph/models/pspnet.py b/dygraph/models/pspnet.py index 2b30d43c1c008d4aafb8f5eba8907da14e1dbfb1..d4457ed53435aa75257b68c476b55c15ab701c68 100644 --- a/dygraph/models/pspnet.py +++ b/dygraph/models/pspnet.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os import paddle.nn.functional as F @@ -29,15 +28,17 @@ class PSPNet(fluid.dygraph.Layer): """ The PSPNet implementation - The orginal artile refers to - Zhao, Hengshuang, et al. "Pyramid scene parsing network." + The orginal artile refers to + Zhao, Hengshuang, et al. "Pyramid scene parsing network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. (https://openaccess.thecvf.com/content_cvpr_2017/papers/Zhao_Pyramid_Scene_Parsing_CVPR_2017_paper.pdf) Args: - backbone (str): backbone name, currently support Resnet50/101. + num_classes (int): the unique number of target classes. + + backbone (Paddle.nn.Layer): backbone name, currently support Resnet50/101. - num_classes (int): the unique number of target classes. Default 2. + model_pretrained (str): the path of pretrained model. output_stride (int): the ratio of input size and final feature size. Default 16. @@ -57,42 +58,44 @@ class PSPNet(fluid.dygraph.Layer): enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True. ignore_index (int): the value of ground-truth mask would be ignored while doing evaluation. Default to 255. - - pretrained_model (str): the pretrained_model path of backbone. """ def __init__(self, + num_classes, backbone, - num_classes=2, + model_pretrained=None, output_stride=16, backbone_indices=(2, 3), backbone_channels=(1024, 2048), pp_out_channels=1024, bin_sizes=(1, 2, 3, 6), enable_auxiliary_loss=True, - ignore_index=255, - pretrained_model=None): + ignore_index=255): super(PSPNet, self).__init__() - self.backbone = manager.BACKBONES[backbone](output_stride=output_stride, - multi_grid=(1, 1, 1)) + # self.backbone = manager.BACKBONES[backbone](output_stride=output_stride, + # multi_grid=(1, 1, 1)) + self.backbone = backbone self.backbone_indices = backbone_indices - self.psp_module = PPModule(in_channels=backbone_channels[1], - out_channels=pp_out_channels, - bin_sizes=bin_sizes) + self.psp_module = PPModule( + in_channels=backbone_channels[1], + out_channels=pp_out_channels, + bin_sizes=bin_sizes) - self.conv = Conv2D(num_channels=pp_out_channels, - num_filters=num_classes, - filter_size=1) + self.conv = Conv2D( + num_channels=pp_out_channels, + num_filters=num_classes, + filter_size=1) if enable_auxiliary_loss: - self.fcn_head = model_utils.FCNHead(in_channels=backbone_channels[0], out_channels=num_classes) + self.fcn_head = model_utils.FCNHead( + in_channels=backbone_channels[0], out_channels=num_classes) self.enable_auxiliary_loss = enable_auxiliary_loss self.ignore_index = ignore_index - self.init_weight(pretrained_model) + self.init_weight(model_pretrained) def forward(self, input, label=None): @@ -107,7 +110,8 @@ class PSPNet(fluid.dygraph.Layer): if self.enable_auxiliary_loss: auxiliary_feat = feat_list[self.backbone_indices[0]] auxiliary_logit = self.fcn_head(auxiliary_feat) - auxiliary_logit = fluid.layers.resize_bilinear(auxiliary_logit, input.shape[2:]) + auxiliary_logit = fluid.layers.resize_bilinear( + auxiliary_logit, input.shape[2:]) if self.training: loss = model_utils.get_loss(logit, label) @@ -116,7 +120,6 @@ class PSPNet(fluid.dygraph.Layer): loss += (0.4 * auxiliary_loss) return loss - else: pred, score_map = model_utils.get_pred_score_map(logit) return pred, score_map @@ -124,14 +127,15 @@ class PSPNet(fluid.dygraph.Layer): def init_weight(self, pretrained_model=None): """ Initialize the parameters of model parts. - Args: - pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. """ - if pretrained_model is not None: if os.path.exists(pretrained_model): - utils.load_pretrained_model(self.backbone, pretrained_model) + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) class PPModule(fluid.dygraph.Layer): @@ -151,19 +155,21 @@ class PPModule(fluid.dygraph.Layer): self.bin_sizes = bin_sizes # we use dimension reduction after pooling mentioned in original implementation. - self.stages = fluid.dygraph.LayerList([self._make_stage(in_channels, size) for size in bin_sizes]) + self.stages = fluid.dygraph.LayerList( + [self._make_stage(in_channels, size) for size in bin_sizes]) - self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=in_channels * 2, - num_filters=out_channels, - filter_size=3, - padding=1) + self.conv_bn_relu2 = layer_utils.ConvBnRelu( + num_channels=in_channels * 2, + num_filters=out_channels, + filter_size=3, + padding=1) def _make_stage(self, in_channels, size): """ Create one pooling layer. In our implementation, we adopt the same dimention reduction as the original paper that might be - slightly different with other implementations. + slightly different with other implementations. After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations keep the channels to be same. @@ -180,9 +186,10 @@ class PPModule(fluid.dygraph.Layer): # this paddle version does not support AdaptiveAvgPool2d, so skip it here. # prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) - conv = layer_utils.ConvBnRelu(num_channels=in_channels, - num_filters=in_channels // len(self.bin_sizes), - filter_size=1) + conv = layer_utils.ConvBnRelu( + num_channels=in_channels, + num_filters=in_channels // len(self.bin_sizes), + filter_size=1) return conv @@ -190,7 +197,8 @@ class PPModule(fluid.dygraph.Layer): cat_layers = [] for i, stage in enumerate(self.stages): size = self.bin_sizes[i] - x = fluid.layers.adaptive_pool2d(input, pool_size=(size, size), pool_type="max") + x = fluid.layers.adaptive_pool2d( + input, pool_size=(size, size), pool_type="max") x = stage(x) x = fluid.layers.resize_bilinear(x, out_shape=input.shape[2:]) cat_layers.append(x) @@ -204,22 +212,32 @@ class PPModule(fluid.dygraph.Layer): @manager.MODELS.add_component def pspnet_resnet101_vd(*args, **kwargs): pretrained_model = None - return PSPNet(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs) + return PSPNet( + backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs) @manager.MODELS.add_component def pspnet_resnet101_vd_os8(*args, **kwargs): pretrained_model = None - return PSPNet(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs) + return PSPNet( + backbone='ResNet101_vd', + output_stride=8, + pretrained_model=pretrained_model, + **kwargs) @manager.MODELS.add_component def pspnet_resnet50_vd(*args, **kwargs): pretrained_model = None - return PSPNet(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs) + return PSPNet( + backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs) @manager.MODELS.add_component def pspnet_resnet50_vd_os8(*args, **kwargs): pretrained_model = None - return PSPNet(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs) + return PSPNet( + backbone='ResNet50_vd', + output_stride=8, + pretrained_model=pretrained_model, + **kwargs) diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py index 4ce28400c8f7dd64f40d651506d13988aa764c39..e2a7c007caa68a74deb322cc4d4d8b66a1b75035 100644 --- a/dygraph/models/unet.py +++ b/dygraph/models/unet.py @@ -33,7 +33,7 @@ class UNet(fluid.dygraph.Layer): ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255. """ - def __init__(self, num_classes, pretrained_model=None, ignore_index=255): + def __init__(self, num_classes, model_pretrained=None, ignore_index=255): super(UNet, self).__init__() self.encode = UnetEncoder() self.decode = UnetDecode() @@ -41,7 +41,7 @@ class UNet(fluid.dygraph.Layer): self.ignore_index = ignore_index self.EPS = 1e-5 - self.init_weight(pretrained_model) + self.init_weight(model_pretrained) def forward(self, x, label=None): encode_data, short_cuts = self.encode(x) @@ -60,7 +60,7 @@ class UNet(fluid.dygraph.Layer): """ Initialize the parameters of model parts. Args: - pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + pretrained_model ([str], optional): the path of pretrained model. Defaults to None. """ if pretrained_model is not None: if os.path.exists(pretrained_model): diff --git a/dygraph/train.py b/dygraph/train.py index 382a41a06afe03efc28ac8493ad6910cf74333ca..92ffbcb1efcc7904f04a790ee22735005bbd936c 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -17,78 +17,36 @@ import argparse import paddle.fluid as fluid from paddle.fluid.dygraph.parallel import ParallelEnv -from dygraph.datasets import DATASETS -import dygraph.transforms as T +import dygraph from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.utils import logger +from dygraph.utils import Config from dygraph.core import train def parse_args(): parser = argparse.ArgumentParser(description='Model training') - - # params of model - parser.add_argument( - '--model_name', - dest='model_name', - help='Model type for training, which is one of {}'.format( - str(list(manager.MODELS.components_dict.keys()))), - type=str, - default='UNet') - - # params of dataset - parser.add_argument( - '--dataset', - dest='dataset', - help="The dataset you want to train, which is one of {}".format( - str(list(DATASETS.keys()))), - type=str, - default='OpticDiscSeg') - parser.add_argument( - '--dataset_root', - dest='dataset_root', - help="dataset root directory", - type=str, - default=None) - # params of training parser.add_argument( - "--input_size", - dest="input_size", - help="The image size for net inputs.", - nargs=2, - default=[512, 512], - type=int) + "--config", dest="cfg", help="The config file.", default=None, type=str) parser.add_argument( '--iters', dest='iters', help='iters for training', type=int, - default=10000) + default=None) parser.add_argument( '--batch_size', dest='batch_size', help='Mini batch size of one gpu or cpu', type=int, - default=2) + default=None) parser.add_argument( '--learning_rate', dest='learning_rate', help='Learning rate', type=float, - default=0.01) - parser.add_argument( - '--pretrained_model', - dest='pretrained_model', - help='The path of pretrained model', - type=str, - default=None) - parser.add_argument( - '--resume_model', - dest='resume_model', - help='The path of resume model', - type=str, default=None) parser.add_argument( '--save_interval_iters', @@ -139,64 +97,36 @@ def main(args): if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \ else fluid.CPUPlace() - if args.dataset not in DATASETS: - raise Exception('`--dataset` is invalid. it should be one of {}'.format( - str(list(DATASETS.keys())))) - dataset = DATASETS[args.dataset] - with fluid.dygraph.guard(places): - # Creat dataset reader - train_transforms = T.Compose([ - T.Resize(args.input_size), - T.RandomHorizontalFlip(), - T.Normalize() - ]) - train_dataset = dataset( - dataset_root=args.dataset_root, - transforms=train_transforms, - mode='train') - - eval_dataset = None - if args.do_eval: - eval_transforms = T.Compose( - [T.Resize(args.input_size), - T.Normalize()]) - eval_dataset = dataset( - dataset_root=args.dataset_root, - transforms=eval_transforms, - mode='val') - - model = manager.MODELS[args.model_name]( - num_classes=train_dataset.num_classes, - pretrained_model=args.pretrained_model) - - # Creat optimizer - # todo, may less one than len(loader) - num_iters_each_epoch = len(train_dataset) // ( - args.batch_size * ParallelEnv().nranks) - lr_decay = fluid.layers.polynomial_decay( - args.learning_rate, args.iters, end_learning_rate=0, power=0.9) - optimizer = fluid.optimizer.Momentum( - lr_decay, - momentum=0.9, - parameter_list=model.parameters(), - regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5)) + if not args.cfg: + raise RuntimeError('No configuration file specified.') + + cfg = Config(args.cfg) + train_dataset = cfg.train_dataset + if not train_dataset: + raise RuntimeError( + 'The training dataset is not specified in the configuration file.' + ) + + val_dataset = cfg.val_dataset if args.do_eval else None + + losses = cfg.loss train( - model, + cfg.model, train_dataset, places=places, - eval_dataset=eval_dataset, - optimizer=optimizer, + eval_dataset=val_dataset, + optimizer=cfg.optimizer, save_dir=args.save_dir, - iters=args.iters, - batch_size=args.batch_size, - resume_model=args.resume_model, + iters=cfg.iters, + batch_size=cfg.batch_size, save_interval_iters=args.save_interval_iters, log_iters=args.log_iters, num_classes=train_dataset.num_classes, num_workers=args.num_workers, - use_vdl=args.use_vdl) + use_vdl=args.use_vdl, + losses=losses) if __name__ == '__main__': diff --git a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py index 935a2c0f8670eaa24b148844aa727efe6942e666..91404ade7d263c6df551ee8b15f74f9d1df96ae0 100644 --- a/dygraph/transforms/transforms.py +++ b/dygraph/transforms/transforms.py @@ -21,8 +21,10 @@ from PIL import Image import cv2 from .functional import * +from dygraph.cvlibs import manager +@manager.TRANSFORMS.add_component class Compose: def __init__(self, transforms, to_rgb=True): if not isinstance(transforms, list): @@ -53,11 +55,12 @@ class Compose: if len(outputs) == 3: label = outputs[2] im = permute(im) - if len(outputs) == 3: - label = label[np.newaxis, :, :] + # if len(outputs) == 3: + # label = label[np.newaxis, :, :] return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomHorizontalFlip: def __init__(self, prob=0.5): self.prob = prob @@ -73,6 +76,7 @@ class RandomHorizontalFlip: return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomVerticalFlip: def __init__(self, prob=0.1): self.prob = prob @@ -88,6 +92,7 @@ class RandomVerticalFlip: return (im, im_info, label) +@manager.TRANSFORMS.add_component class Resize: # The interpolation mode interp_dict = { @@ -137,6 +142,7 @@ class Resize: return (im, im_info, label) +@manager.TRANSFORMS.add_component class ResizeByLong: def __init__(self, long_size): self.long_size = long_size @@ -156,6 +162,7 @@ class ResizeByLong: return (im, im_info, label) +@manager.TRANSFORMS.add_component class ResizeRangeScaling: def __init__(self, min_value=400, max_value=600): if min_value > max_value: @@ -181,6 +188,7 @@ class ResizeRangeScaling: return (im, im_info, label) +@manager.TRANSFORMS.add_component class ResizeStepScaling: def __init__(self, min_scale_factor=0.75, @@ -224,6 +232,7 @@ class ResizeStepScaling: return (im, im_info, label) +@manager.TRANSFORMS.add_component class Normalize: def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): self.mean = mean @@ -245,6 +254,7 @@ class Normalize: return (im, im_info, label) +@manager.TRANSFORMS.add_component class Padding: def __init__(self, target_size, @@ -305,6 +315,7 @@ class Padding: return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomPaddingCrop: def __init__(self, crop_size=512, @@ -378,6 +389,7 @@ class RandomPaddingCrop: return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomBlur: def __init__(self, prob=0.1): self.prob = prob @@ -404,6 +416,7 @@ class RandomBlur: return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomRotation: def __init__(self, max_rotation=15, @@ -451,6 +464,7 @@ class RandomRotation: return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomScaleAspect: def __init__(self, min_scale=0.5, aspect_ratio=0.33): self.min_scale = min_scale @@ -492,6 +506,7 @@ class RandomScaleAspect: return (im, im_info, label) +@manager.TRANSFORMS.add_component class RandomDistort: def __init__(self, brightness_range=0.5, diff --git a/dygraph/utils/__init__.py b/dygraph/utils/__init__.py index e1e92959a70f240f6c59d999e1e135004d5b0de2..a22f9e5ec0ff32a5e42b6c2d7d6bed14a56994a1 100644 --- a/dygraph/utils/__init__.py +++ b/dygraph/utils/__init__.py @@ -18,3 +18,4 @@ from .metrics import ConfusionMatrix from .utils import * from .timer import Timer, calculate_eta from .get_environ_info import get_environ_info +from .config import Config diff --git a/dygraph/utils/config.py b/dygraph/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7a6aac4cd5ec667fdef606c3b85270abe9ea4d --- /dev/null +++ b/dygraph/utils/config.py @@ -0,0 +1,239 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import codecs +import os +from typing import Any, Callable + +import yaml +import paddle.fluid as fluid + +import dygraph.cvlibs.manager as manager + + +class Config(object): + ''' + Training config. + + Args: + path(str) : the path of config file, supports yaml format only + ''' + + def __init__(self, path: str): + if not os.path.exists(path): + raise FileNotFoundError('File {} does not exist'.format(path)) + + if path.endswith('yml') or path.endswith('yaml'): + dic = self._parse_from_yaml(path) + self._build(dic) + else: + raise RuntimeError('Config file should in yaml format!') + + def _update_dic(self, dic, base_dic): + """ + update dic from base_dic + """ + dic = dic.copy() + for key, val in base_dic.items(): + if isinstance(val, dict) and key in dic: + dic[key] = self._update_dic(dic[key], val) + else: + dic[key] = val + return dic + + def _parse_from_yaml(self, path: str): + '''Parse a yaml file and build config''' + with codecs.open(path, 'r', 'utf-8') as file: + dic = yaml.load(file, Loader=yaml.FullLoader) + if '_base_' in dic: + cfg_dir = os.path.dirname(path) + base_path = dic.pop('_base_') + base_path = os.path.join(cfg_dir, base_path) + base_dic = self._parse_from_yaml(base_path) + dic = self._update_dic(dic, base_dic) + return dic + + def _build(self, dic: dict): + '''Build config from dictionary''' + dic = dic.copy() + + self._batch_size = dic.get('batch_size', 1) + self._iters = dic.get('iters') + + if 'model' not in dic: + raise RuntimeError() + self._model_cfg = dic['model'] + self._model = None + + self._train_dataset = dic.get('train_dataset') + self._val_dataset = dic.get('val_dataset') + + self._learning_rate_cfg = dic.get('learning_rate', {}) + self._learning_rate = self._learning_rate_cfg.get('value') + self._decay = self._learning_rate_cfg.get('decay', { + 'type': 'poly', + 'power': 0.9 + }) + + self._loss_cfg = dic.get('loss', {}) + self._losses = None + + self._optimizer_cfg = dic.get('optimizer', {}) + + def update(self, + learning_rate: float = None, + batch_size: int = None, + iters: int = None): + '''Update config''' + if learning_rate: + self._learning_rate = learning_rate + + if batch_size: + self._batch_size = batch_size + + if iters: + self._iters = iters + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def iters(self) -> int: + if not self._iters: + raise RuntimeError('No iters specified in the configuration file.') + return self._iters + + @property + def learning_rate(self) -> float: + if not self._learning_rate: + raise RuntimeError( + 'No learning rate specified in the configuration file.') + + if self.decay_type == 'poly': + lr = self._learning_rate + args = self.decay_args + args.setdefault('decay_steps', self.iters) + return fluid.layers.polynomial_decay(lr, **args) + else: + raise RuntimeError('Only poly decay support.') + + @property + def optimizer(self) -> fluid.optimizer.Optimizer: + if self.optimizer_type == 'sgd': + lr = self.learning_rate + args = self.optimizer_args + args.setdefault('momentum', 0.9) + return fluid.optimizer.Momentum( + lr, parameter_list=self.model.parameters(), **args) + else: + raise RuntimeError('Only sgd optimizer support.') + + @property + def optimizer_type(self) -> str: + otype = self._optimizer_cfg.get('type') + if not otype: + raise RuntimeError( + 'No optimizer type specified in the configuration file.') + return otype + + @property + def optimizer_args(self) -> dict: + args = self._optimizer_cfg.copy() + args.pop('type') + return args + + @property + def decay_type(self) -> str: + return self._decay['type'] + + @property + def decay_args(self) -> dict: + args = self._decay.copy() + args.pop('type') + return args + + @property + def loss(self) -> list: + if not self._losses: + args = self._loss_cfg.copy() + self._losses = dict() + for key, val in args.items(): + if key == 'types': + self._losses['types'] = [] + for item in args['types']: + self._losses['types'].append(self._load_object(item)) + else: + self._losses[key] = val + if len(self._losses['coef']) != len(self._losses['types']): + raise RuntimeError( + 'The length of coef should equal to types in loss config: {} != {}.' + .format( + len(self._losses['coef']), len(self._losses['types']))) + return self._losses + + @property + def model(self) -> Callable: + if not self._model: + self._model = self._load_object(self._model_cfg) + return self._model + + @property + def train_dataset(self) -> Any: + if not self._train_dataset: + return None + return self._load_object(self._train_dataset) + + @property + def val_dataset(self) -> Any: + if not self._val_dataset: + return None + return self._load_object(self._val_dataset) + + def _load_component(self, com_name: str) -> Any: + com_list = [ + manager.MODELS, manager.BACKBONES, manager.DATASETS, + manager.TRANSFORMS, manager.LOSSES + ] + + for com in com_list: + if com_name in com.components_dict: + return com[com_name] + else: + raise RuntimeError( + 'The specified component was not found {}.'.format(com_name)) + + def _load_object(self, cfg: dict) -> Any: + cfg = cfg.copy() + if 'type' not in cfg: + raise RuntimeError('No object information in {}.'.format(cfg)) + + component = self._load_component(cfg.pop('type')) + + params = {} + for key, val in cfg.items(): + if self._is_meta_type(val): + params[key] = self._load_object(val) + elif isinstance(val, list): + params[key] = [ + self._load_object(item) + if self._is_meta_type(item) else item for item in val + ] + else: + params[key] = val + + return component(**params) + + def _is_meta_type(self, item: Any) -> bool: + return isinstance(item, dict) and 'type' in item diff --git a/dygraph/val.py b/dygraph/val.py index 044ee1eedfce98975918b6d374edce19d450a292..f4b7d6399c155d629add1888131c0d6bf7430421 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -17,48 +17,19 @@ import argparse import paddle.fluid as fluid from paddle.fluid.dygraph.parallel import ParallelEnv -from dygraph.datasets import DATASETS -import dygraph.transforms as T +import dygraph from dygraph.cvlibs import manager from dygraph.utils import get_environ_info +from dygraph.utils import Config from dygraph.core import evaluate def parse_args(): parser = argparse.ArgumentParser(description='Model evaluation') - # params of model - parser.add_argument( - '--model_name', - dest='model_name', - help='Model type for evaluation, which is one of {}'.format( - str(list(manager.MODELS.components_dict.keys()))), - type=str, - default='UNet') - - # params of dataset - parser.add_argument( - '--dataset', - dest='dataset', - help="The dataset you want to evaluation, which is one of {}".format( - str(list(DATASETS.keys()))), - type=str, - default='OpticDiscSeg') - parser.add_argument( - '--dataset_root', - dest='dataset_root', - help="dataset root directory", - type=str, - default=None) - # params of evaluate parser.add_argument( - "--input_size", - dest="input_size", - help="The image size for net inputs.", - nargs=2, - default=[512, 512], - type=int) + "--config", dest="cfg", help="The config file.", default=None, type=str) parser.add_argument( '--model_dir', dest='model_dir', @@ -75,26 +46,21 @@ def main(args): if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \ else fluid.CPUPlace() - if args.dataset not in DATASETS: - raise Exception('`--dataset` is invalid. it should be one of {}'.format( - str(list(DATASETS.keys())))) - dataset = DATASETS[args.dataset] - with fluid.dygraph.guard(places): - eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) - eval_dataset = dataset( - dataset_root=args.dataset_root, - transforms=eval_transforms, - mode='val') - - model = manager.MODELS[args.model_name]( - num_classes=eval_dataset.num_classes) - + if not args.cfg: + raise RuntimeError('No configuration file specified.') + + cfg = Config(args.cfg) + val_dataset = cfg.val_dataset + if not val_dataset: + raise RuntimeError( + 'The verification dataset is not specified in the configuration file.' + ) evaluate( - model, - eval_dataset, + cfg.model, + val_dataset, model_dir=args.model_dir, - num_classes=eval_dataset.num_classes) + num_classes=val_dataset.num_classes) if __name__ == '__main__':