From 79b93f7600fb6458cf0a52d13aa306ee9ec27f52 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 26 Aug 2020 21:31:39 +0800 Subject: [PATCH] add comments --- dygraph/README.md | 6 ++-- dygraph/models/architectures/hrnet.py | 18 ++++++++++- dygraph/models/fcn.py | 5 ++- dygraph/models/unet.py | 44 ++++++++++++++++++++++++--- 4 files changed, 64 insertions(+), 9 deletions(-) diff --git a/dygraph/README.md b/dygraph/README.md index 897e86a9..01d4374e 100644 --- a/dygraph/README.md +++ b/dygraph/README.md @@ -6,7 +6,7 @@ export PYTHONPATH=$PYTHONPATH:`pwd` ## 训练 ``` -python3 train.py --model_name UNet \ +python3 train.py --model_name unet \ --dataset OpticDiscSeg \ --input_size 192 192 \ --iters 10 \ @@ -17,7 +17,7 @@ python3 train.py --model_name UNet \ ## 评估 ``` -python3 val.py --model_name UNet \ +python3 val.py --model_name unet \ --dataset OpticDiscSeg \ --input_size 192 192 \ --model_dir output/best_model @@ -25,7 +25,7 @@ python3 val.py --model_name UNet \ ## 预测 ``` -python3 infer.py --model_name UNet \ +python3 infer.py --model_name unet \ --dataset OpticDiscSeg \ --model_dir output/best_model \ --input_size 192 192 diff --git a/dygraph/models/architectures/hrnet.py b/dygraph/models/architectures/hrnet.py index 4b4750ee..3e12bc05 100644 --- a/dygraph/models/architectures/hrnet.py +++ b/dygraph/models/architectures/hrnet.py @@ -32,7 +32,23 @@ __all__ = [ class HRNet(fluid.dygraph.Layer): """ - HRNet: + HRNet:Deep High-Resolution Representation Learning for Visual Recognition + https://arxiv.org/pdf/1908.07919.pdf. + + Args: + stage1_num_modules (int): number of modules for stage1. Default 1. + stage1_num_blocks (list): number of blocks per module for stage1. Default [4]. + stage1_num_channels (list): number of channels per branch for stage1. Default [64]. + stage2_num_modules (int): number of modules for stage2. Default 1. + stage2_num_blocks (list): number of blocks per module for stage2. Default [4, 4] + stage2_num_channels (list): number of channels per branch for stage2. Default [18, 36]. + stage3_num_modules (int): number of modules for stage3. Default 4. + stage3_num_blocks (list): number of blocks per module for stage3. Default [4, 4, 4] + stage3_num_channels (list): number of channels per branch for stage3. Default [18, 36, 72]. + stage4_num_modules (int): number of modules for stage4. Default 3. + stage4_num_blocks (list): number of blocks per module for stage4. Default [4, 4, 4, 4] + stage4_num_channels (list): number of channels per branch for stage4. Default [18, 36, 72. 144]. + has_se (bool): whether to use Squeeze-and-Excitation module. Default False. """ def __init__(self, diff --git a/dygraph/models/fcn.py b/dygraph/models/fcn.py index 1dccffbc..ce1ab409 100644 --- a/dygraph/models/fcn.py +++ b/dygraph/models/fcn.py @@ -41,6 +41,10 @@ class FCN(fluid.dygraph.Layer): Args: backbone (str): backbone name, num_classes (int): the unique number of target classes. + in_channels (int): the channels of input feature maps. + channels (int): channels after conv layer before the last one. + pretrained_model (str): the path of pretrained model. + ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255. """ def __init__(self, @@ -49,7 +53,6 @@ class FCN(fluid.dygraph.Layer): in_channels, channels=None, pretrained_model=None, - has_se=False, ignore_index=255, **kwargs): super(FCN, self).__init__() diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py index 6f04643a..7a2b80bd 100644 --- a/dygraph/models/unet.py +++ b/dygraph/models/unet.py @@ -12,13 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import paddle.fluid as fluid from paddle.fluid.dygraph import Conv2D, Pool2D -from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm +from paddle.nn import SyncBatchNorm as BatchNorm + +from dygraph.cvlibs import manager +from dygraph import utils class UNet(fluid.dygraph.Layer): - def __init__(self, num_classes, ignore_index=255): + """ + U-Net: Convolutional Networks for Biomedical Image Segmentation. + https://arxiv.org/abs/1505.04597 + + Args: + num_classes (int): the unique number of target classes. + pretrained_model (str): the path of pretrained model. + ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255. + """ + + def __init__(self, num_classes, pretrained_model=None, ignore_index=255): super(UNet, self).__init__() self.encode = UnetEncoder() self.decode = UnetDecode() @@ -26,6 +41,8 @@ class UNet(fluid.dygraph.Layer): self.ignore_index = ignore_index self.EPS = 1e-5 + self.init_weight(pretrained_model) + def forward(self, x, label=None): encode_data, short_cuts = self.encode(x) decode_data = self.decode(encode_data, short_cuts) @@ -39,6 +56,20 @@ class UNet(fluid.dygraph.Layer): pred = fluid.layers.unsqueeze(pred, axes=[3]) return pred, score_map + def init_weight(self, pretrained_model=None): + """ + Initialize the parameters of model parts. + Args: + pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None. + """ + if pretrained_model is not None: + if os.path.exists(pretrained_model): + utils.load_pretrained_model(self.backbone, pretrained_model) + utils.load_pretrained_model(self, pretrained_model) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained_model)) + def _get_loss(self, logit, label): logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) label = fluid.layers.transpose(label, [0, 2, 3, 1]) @@ -108,14 +139,14 @@ class DoubleConv(fluid.dygraph.Layer): filter_size=3, stride=1, padding=1) - self.bn0 = BatchNorm(num_channels=num_filters) + self.bn0 = BatchNorm(num_filters) self.conv1 = Conv2D( num_channels=num_filters, num_filters=num_filters, filter_size=3, stride=1, padding=1) - self.bn1 = BatchNorm(num_channels=num_filters) + self.bn1 = BatchNorm(num_filters) def forward(self, x): x = self.conv0(x) @@ -166,3 +197,8 @@ class GetLogit(fluid.dygraph.Layer): def forward(self, x): x = self.conv(x) return x + + +@manager.MODELS.add_component +def unet(*args, **kwargs): + return UNet(*args, **kwargs) -- GitLab