From 544441421f1ec650fc874cbf8532d94b993906bb Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 26 Aug 2020 17:13:02 +0800 Subject: [PATCH] make hrnet as backbone and add fcn --- dygraph/core/train.py | 5 +- dygraph/infer.py | 11 +- dygraph/models/__init__.py | 1 + dygraph/models/architectures/__init__.py | 19 ++ dygraph/models/{ => architectures}/hrnet.py | 314 +++----------------- dygraph/models/fcn.py | 230 ++++++++++++++ dygraph/train.py | 6 +- dygraph/val.py | 11 +- 8 files changed, 305 insertions(+), 292 deletions(-) create mode 100644 dygraph/models/architectures/__init__.py rename dygraph/models/{ => architectures}/hrnet.py (76%) create mode 100644 dygraph/models/fcn.py diff --git a/dygraph/core/train.py b/dygraph/core/train.py index 9f3f83c4..2bce9a16 100644 --- a/dygraph/core/train.py +++ b/dygraph/core/train.py @@ -34,7 +34,6 @@ def train(model, save_dir='output', iters=10000, batch_size=2, - pretrained_model=None, resume_model=None, save_interval_iters=1000, log_iters=10, @@ -47,8 +46,6 @@ def train(model, start_iter = 0 if resume_model is not None: start_iter = resume(model, optimizer, resume_model) - elif pretrained_model is not None: - load_pretrained_model(model, pretrained_model) if not os.path.isdir(save_dir): if os.path.exists(save_dir): @@ -126,7 +123,6 @@ def train(model, log_writer.add_scalar('Train/reader_cost', avg_train_reader_cost, iter) avg_loss = 0.0 - timer.restart() if (iter % save_interval_iters == 0 or iter == iters) and ParallelEnv().local_rank == 0: @@ -162,5 +158,6 @@ def train(model, log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter) log_writer.add_scalar('Evaluate/aAcc', avg_acc, iter) model.train() + timer.restart() if use_vdl: log_writer.close() diff --git a/dygraph/infer.py b/dygraph/infer.py index 9d05571b..2e6aa3f5 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -19,7 +19,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.core import infer @@ -32,7 +32,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for testing, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -99,11 +99,8 @@ def main(args): transforms=test_transforms, mode='test') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=test_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=test_dataset.num_classes) infer( model, diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py index 750e77ac..6af6df34 100644 --- a/dygraph/models/__init__.py +++ b/dygraph/models/__init__.py @@ -16,6 +16,7 @@ from .architectures import * from .unet import UNet from .hrnet import * from .deeplab import * +from .fcn import * # MODELS = { # "UNet": UNet, diff --git a/dygraph/models/architectures/__init__.py b/dygraph/models/architectures/__init__.py new file mode 100644 index 00000000..730c8f97 --- /dev/null +++ b/dygraph/models/architectures/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import layer_utils +from .hrnet import * +from .resnet_vd import * +from .xception_deeplab import * +from .mobilenetv3 import * diff --git a/dygraph/models/hrnet.py b/dygraph/models/architectures/hrnet.py similarity index 76% rename from dygraph/models/hrnet.py rename to dygraph/models/architectures/hrnet.py index 2019900d..4b4750ee 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/architectures/hrnet.py @@ -20,20 +20,22 @@ from paddle.fluid.param_attr import ParamAttr from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.initializer import Normal -from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm +from paddle.nn import SyncBatchNorm as BatchNorm + +from dygraph.cvlibs import manager __all__ = [ "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30", - "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", - "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", - "SE_HRNet_W18", "SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40", - "SE_HRNet_W44", "SE_HRNet_W48", "SE_HRNet_W60", "SE_HRNet_W64" + "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64" ] class HRNet(fluid.dygraph.Layer): + """ + HRNet: + """ + def __init__(self, - num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -46,11 +48,9 @@ class HRNet(fluid.dygraph.Layer): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[18, 36, 72, 144], - has_se=False, - ignore_index=255): + has_se=False): super(HRNet, self).__init__() - self.num_classes = num_classes self.stage1_num_modules = stage1_num_modules self.stage1_num_blocks = stage1_num_blocks self.stage1_num_channels = stage1_num_channels @@ -64,8 +64,6 @@ class HRNet(fluid.dygraph.Layer): self.stage4_num_blocks = stage4_num_blocks self.stage4_num_channels = stage4_num_channels self.has_se = has_se - self.ignore_index = ignore_index - self.EPS = 1e-5 self.conv_layer1_1 = ConvBNLayer( num_channels=3, @@ -112,6 +110,7 @@ class HRNet(fluid.dygraph.Layer): num_modules=self.stage3_num_modules, num_blocks=self.stage3_num_blocks, num_filters=self.stage3_num_channels, + has_se=self.has_se, name="st3") self.tr3 = TransitionLayer( @@ -123,24 +122,9 @@ class HRNet(fluid.dygraph.Layer): num_modules=self.stage4_num_modules, num_blocks=self.stage4_num_blocks, num_filters=self.stage4_num_channels, + has_se=self.has_se, name="st4") - last_inp_channels = sum(self.stage4_num_channels) - self.conv_last_2 = ConvBNLayer( - num_channels=last_inp_channels, - num_filters=last_inp_channels, - filter_size=1, - stride=1, - name='conv-2') - self.conv_last_1 = Conv2D( - num_channels=last_inp_channels, - num_filters=self.num_classes, - filter_size=1, - stride=1, - padding=0, - param_attr=ParamAttr( - initializer=Normal(scale=0.001), name='conv-1_weights')) - def forward(self, x, label=None, mode='train'): input_shape = x.shape[2:] conv1 = self.conv_layer1_1(x) @@ -162,40 +146,8 @@ class HRNet(fluid.dygraph.Layer): x2 = fluid.layers.resize_bilinear(st4[2], out_shape=(x0_h, x0_w)) x3 = fluid.layers.resize_bilinear(st4[3], out_shape=(x0_h, x0_w)) x = fluid.layers.concat([st4[0], x1, x2, x3], axis=1) - x = self.conv_last_2(x) - logit = self.conv_last_1(x) - logit = fluid.layers.resize_bilinear(logit, input_shape) - - if self.training: - if label is None: - raise Exception('Label is need during training') - return self._get_loss(logit, label) - else: - score_map = fluid.layers.softmax(logit, axis=1) - score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) - pred = fluid.layers.argmax(score_map, axis=3) - pred = fluid.layers.unsqueeze(pred, axes=[3]) - return pred, score_map - - def _get_loss(self, logit, label): - logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) - label = fluid.layers.transpose(label, [0, 2, 3, 1]) - mask = label != self.ignore_index - mask = fluid.layers.cast(mask, 'float32') - loss, probs = fluid.layers.softmax_with_cross_entropy( - logit, - label, - ignore_index=self.ignore_index, - return_softmax=True, - axis=-1) - - loss = loss * mask - avg_loss = fluid.layers.mean(loss) / ( - fluid.layers.mean(mask) + self.EPS) - - label.stop_gradient = True - mask.stop_gradient = True - return avg_loss + + return x class ConvBNLayer(fluid.dygraph.Layer): @@ -698,189 +650,9 @@ class LastClsOut(fluid.dygraph.Layer): return outs -def HRNet_W18_Small_V1(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[1], - stage1_num_channels=[32], - stage2_num_modules=1, - stage2_num_blocks=[2, 2], - stage2_num_channels=[16, 32], - stage3_num_modules=1, - stage3_num_blocks=[2, 2, 2], - stage3_num_channels=[16, 32, 64], - stage4_num_modules=1, - stage4_num_blocks=[2, 2, 2, 2], - stage4_num_channels=[16, 32, 64, 128]) - return model - - -def HRNet_W18_Small_V2(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[2], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[2, 2], - stage2_num_channels=[18, 36], - stage3_num_modules=1, - stage3_num_blocks=[2, 2, 2], - stage3_num_channels=[18, 36, 72], - stage4_num_modules=1, - stage4_num_blocks=[2, 2, 2, 2], - stage4_num_channels=[18, 36, 72, 144]) - return model - - -def HRNet_W18(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[18, 36], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[18, 36, 72], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[18, 36, 72, 144]) - return model - - -def HRNet_W30(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[30, 60], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[30, 60, 120], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[30, 60, 120, 240]) - return model - - -def HRNet_W32(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[32, 64], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[32, 64, 128], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[32, 64, 128, 256]) - return model - - -def HRNet_W40(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[40, 80], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[40, 80, 160], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[40, 80, 160, 320]) - return model - - -def HRNet_W44(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[44, 88], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[44, 88, 176], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[44, 88, 176, 352]) - return model - - -def HRNet_W48(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[48, 96], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[48, 96, 192], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[48, 96, 192, 384]) - return model - - -def HRNet_W60(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[60, 120], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[60, 120, 240], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[60, 120, 240, 480]) - return model - - -def HRNet_W64(num_classes): - model = HRNet( - num_classes=num_classes, - stage1_num_modules=1, - stage1_num_blocks=[4], - stage1_num_channels=[64], - stage2_num_modules=1, - stage2_num_blocks=[4, 4], - stage2_num_channels=[64, 128], - stage3_num_modules=4, - stage3_num_blocks=[4, 4, 4], - stage3_num_channels=[64, 128, 256], - stage4_num_modules=3, - stage4_num_blocks=[4, 4, 4, 4], - stage4_num_channels=[64, 128, 256, 512]) - return model - - -def SE_HRNet_W18_Small_V1(num_classes): +@manager.BACKBONES.add_component +def HRNet_W18_Small_V1(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[1], stage1_num_channels=[32], @@ -893,13 +665,13 @@ def SE_HRNet_W18_Small_V1(num_classes): stage4_num_modules=1, stage4_num_blocks=[2, 2, 2, 2], stage4_num_channels=[16, 32, 64, 128], - has_se=True) + **kwargs) return model -def SE_HRNet_W18_Small_V2(num_classes): +@manager.BACKBONES.add_component +def HRNet_W18_Small_V2(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[2], stage1_num_channels=[64], @@ -912,13 +684,13 @@ def SE_HRNet_W18_Small_V2(num_classes): stage4_num_modules=1, stage4_num_blocks=[2, 2, 2, 2], stage4_num_channels=[18, 36, 72, 144], - has_se=True) + **kwargs) return model -def SE_HRNet_W18(num_classes): +@manager.BACKBONES.add_component +def HRNet_W18(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -931,13 +703,13 @@ def SE_HRNet_W18(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[18, 36, 72, 144], - has_se=True) + **kwargs) return model -def SE_HRNet_W30(num_classes): +@manager.BACKBONES.add_component +def HRNet_W30(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -950,13 +722,13 @@ def SE_HRNet_W30(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[30, 60, 120, 240], - has_se=True) + **kwargs) return model -def SE_HRNet_W32(num_classes): +@manager.BACKBONES.add_component +def HRNet_W32(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -969,13 +741,13 @@ def SE_HRNet_W32(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[32, 64, 128, 256], - has_se=True) + **kwargs) return model -def SE_HRNet_W40(num_classes): +@manager.BACKBONES.add_component +def HRNet_W40(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -988,13 +760,13 @@ def SE_HRNet_W40(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[40, 80, 160, 320], - has_se=True) + **kwargs) return model -def SE_HRNet_W44(num_classes): +@manager.BACKBONES.add_component +def HRNet_W44(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -1007,13 +779,13 @@ def SE_HRNet_W44(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[44, 88, 176, 352], - has_se=True) + **kwargs) return model -def SE_HRNet_W48(num_classes): +@manager.BACKBONES.add_component +def HRNet_W48(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -1026,13 +798,13 @@ def SE_HRNet_W48(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[48, 96, 192, 384], - has_se=True) + **kwargs) return model -def SE_HRNet_W60(num_classes): +@manager.BACKBONES.add_component +def HRNet_W60(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -1045,13 +817,13 @@ def SE_HRNet_W60(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[60, 120, 240, 480], - has_se=True) + **kwargs) return model -def SE_HRNet_W64(num_classes): +@manager.BACKBONES.add_component +def HRNet_W64(**kwargs): model = HRNet( - num_classes=num_classes, stage1_num_modules=1, stage1_num_blocks=[4], stage1_num_channels=[64], @@ -1064,5 +836,5 @@ def SE_HRNet_W64(num_classes): stage4_num_modules=3, stage4_num_blocks=[4, 4, 4, 4], stage4_num_channels=[64, 128, 256, 512], - has_se=True) + **kwargs) return model diff --git a/dygraph/models/fcn.py b/dygraph/models/fcn.py new file mode 100644 index 00000000..1dccffbc --- /dev/null +++ b/dygraph/models/fcn.py @@ -0,0 +1,230 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear +from paddle.fluid.initializer import Normal +from paddle.nn import SyncBatchNorm as BatchNorm + +from dygraph.cvlibs import manager +from dygraph import utils + +__all__ = [ + "fcn_hrnet_w18_small_v1", "fcn_hrnet_w18_small_v2", "fcn_hrnet_w18", + "fcn_hrnet_w30", "fcn_hrnet_w32", "fcn_hrnet_w40", "fcn_hrnet_w44", + "fcn_hrnet_w48", "fcn_hrnet_w60", "fcn_hrnet_w64" +] + + +class FCN(fluid.dygraph.Layer): + """ + Fully Convolutional Networks for Semantic Segmentation. + https://arxiv.org/abs/1411.4038 + + Args: + backbone (str): backbone name, + num_classes (int): the unique number of target classes. + """ + + def __init__(self, + backbone, + num_classes, + in_channels, + channels=None, + pretrained_model=None, + has_se=False, + ignore_index=255, + **kwargs): + super(FCN, self).__init__() + + self.num_classes = num_classes + self.ignore_index = ignore_index + self.EPS = 1e-5 + if channels is None: + channels = in_channels + + self.backbone = manager.BACKBONES[backbone](**kwargs) + self.conv_last_2 = ConvBNLayer( + num_channels=in_channels, + num_filters=channels, + filter_size=1, + stride=1, + name='conv-2') + self.conv_last_1 = Conv2D( + num_channels=channels, + num_filters=self.num_classes, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr( + initializer=Normal(scale=0.001), name='conv-1_weights')) + self.init_weight(pretrained_model) + + def forward(self, x, label=None, mode='train'): + input_shape = x.shape[2:] + x = self.backbone(x) + x = self.conv_last_2(x) + logit = self.conv_last_1(x) + logit = fluid.layers.resize_bilinear(logit, input_shape) + + if self.training: + if label is None: + raise Exception('Label is need during training') + return self._get_loss(logit, label) + else: + score_map = fluid.layers.softmax(logit, axis=1) + score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) + pred = fluid.layers.argmax(score_map, axis=3) + pred = fluid.layers.unsqueeze(pred, axes=[3]) + return pred, score_map + + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + """ + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self.backbone, pretrained_model) + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) + + def _get_loss(self, logit, label): + """ + compute forward loss of the model + + Args: + logit (tensor): the logit of model output + label (tensor): ground truth + + Returns: + avg_loss (tensor): forward loss + """ + logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) + label = fluid.layers.transpose(label, [0, 2, 3, 1]) + mask = label != self.ignore_index + mask = fluid.layers.cast(mask, 'float32') + loss, probs = fluid.layers.softmax_with_cross_entropy( + logit, + label, + ignore_index=self.ignore_index, + return_softmax=True, + axis=-1) + + loss = loss * mask + avg_loss = fluid.layers.mean(loss) / ( + fluid.layers.mean(mask) + self.EPS) + + label.stop_gradient = True + mask.stop_gradient = True + return avg_loss + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act="relu", + name=None): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + param_attr=ParamAttr( + initializer=Normal(scale=0.001), name=name + "_weights"), + bias_attr=False) + bn_name = name + '_bn' + self._batch_norm = BatchNorm( + num_filters, + weight_attr=ParamAttr( + name=bn_name + '_scale', + initializer=fluid.initializer.Constant(1.0)), + bias_attr=ParamAttr( + bn_name + '_offset', + initializer=fluid.initializer.Constant(0.0))) + self.act = act + + def forward(self, input): + y = self._conv(input) + y = self._batch_norm(y) + if self.act == 'relu': + y = fluid.layers.relu(y) + return y + + +@manager.MODELS.add_component +def fcn_hrnet_w18_small_v1(*args, **kwargs): + return FCN(backbone='HRNet_W18_Small_V1', in_channels=240, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w18_small_v2(*args, **kwargs): + return FCN(backbone='HRNet_W18_Small_V2', in_channels=270, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w18(*args, **kwargs): + return FCN(backbone='HRNet_W18', in_channels=270, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w30(*args, **kwargs): + return FCN(backbone='HRNet_W30', in_channels=450, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w32(*args, **kwargs): + return FCN(backbone='HRNet_W32', in_channels=480, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w40(*args, **kwargs): + return FCN(backbone='HRNet_W40', in_channels=600, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w44(*args, **kwargs): + return FCN(backbone='HRNet_W44', in_channels=660, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w48(*args, **kwargs): + return FCN(backbone='HRNet_W48', in_channels=720, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w60(*args, **kwargs): + return FCN(backbone='HRNet_W60', in_channels=900, **kwargs) + + +@manager.MODELS.add_component +def fcn_hrnet_w64(*args, **kwargs): + return FCN(backbone='HRNet_W64', in_channels=960, **kwargs) diff --git a/dygraph/train.py b/dygraph/train.py index 16f678c5..eb16e996 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -167,8 +167,9 @@ def main(args): transforms=eval_transforms, mode='val') - - model = manager.MODELS[args.model_name](num_classes=train_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=train_dataset.num_classes, + pretrained_model=args.pretrained_model) # Creat optimizer # todo, may less one than len(loader) @@ -191,7 +192,6 @@ def main(args): save_dir=args.save_dir, iters=args.iters, batch_size=args.batch_size, - pretrained_model=args.pretrained_model, resume_model=args.resume_model, save_interval_iters=args.save_interval_iters, log_iters=args.log_iters, diff --git a/dygraph/val.py b/dygraph/val.py index c4ea97d6..044ee1ee 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -19,7 +19,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.core import evaluate @@ -32,7 +32,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for evaluation, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -87,11 +87,8 @@ def main(args): transforms=eval_transforms, mode='val') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=eval_dataset.num_classes) + model = manager.MODELS[args.model_name]( + num_classes=eval_dataset.num_classes) evaluate( model, -- GitLab