# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.initializer import Normal
from paddle.nn import SyncBatchNorm as BatchNorm

from dygraph.cvlibs import manager
from dygraph import utils
from dygraph.cvlibs import param_init

# Public factory functions exported by this module (one per HRNet variant).
__all__ = [
    "fcn_hrnet_w18_small_v1",
    "fcn_hrnet_w18_small_v2",
    "fcn_hrnet_w18",
    "fcn_hrnet_w30",
    "fcn_hrnet_w32",
    "fcn_hrnet_w40",
    "fcn_hrnet_w44",
    "fcn_hrnet_w48",
    "fcn_hrnet_w60",
    "fcn_hrnet_w64",
]


@manager.MODELS.add_component
class FCN(fluid.dygraph.Layer):
    """
    Fully Convolutional Networks for Semantic Segmentation.
    https://arxiv.org/abs/1411.4038

    One feature map of the backbone goes through a 1x1 ConvBNLayer and a
    1x1 classification conv; the resulting logits are bilinearly resized to
    the spatial size of the input.

    Args:
        num_classes (int): the unique number of target classes.

        backbone (paddle.nn.Layer): backbone network. It is called on the input and its result is indexed with backbone_indices[0].

        model_pretrained (str): the path of pretrained model. Default None.

        backbone_indices (tuple): indices into the outputs of the backbone. Only the first entry is used. Default (-1, ).

        backbone_channels (tuple): the same length as "backbone_indices". Entry i is the channel count of the output selected by backbone_indices[i]. Default (270, ).

        channels (int): channels after conv layer before the last one. Defaults to the channel count of the selected backbone output.

        ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. It is only stored on the instance here. Default 255.

        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 model_pretrained=None,
                 backbone_indices=(-1, ),
                 backbone_channels=(270, ),
                 channels=None,
                 ignore_index=255,
                 **kwargs):
        super(FCN, self).__init__()

        self.num_classes = num_classes
        self.backbone_indices = backbone_indices
        self.ignore_index = ignore_index
        self.EPS = 1e-5

        # backbone_channels runs parallel to backbone_indices, so the channel
        # count of the feature map picked by backbone_indices[0] is stored at
        # position 0, not at position backbone_indices[0].
        in_channels = backbone_channels[0]
        if channels is None:
            channels = in_channels

        self.backbone = backbone
        self.conv_last_2 = ConvBNLayer(
            num_channels=in_channels,
            num_filters=channels,
            filter_size=1,
            stride=1)
        self.conv_last_1 = Conv2D(
            num_channels=channels,
            num_filters=self.num_classes,
            filter_size=1,
            stride=1,
            padding=0)
        if self.training:
            self.init_weight(model_pretrained)

    def forward(self, x):
        """
        Compute segmentation logits for a batch.

        Args:
            x: input tensor; x.shape[2:] is used as the target size of the
                resize (assumes an NCHW layout - TODO confirm).

        Returns:
            list: a one-element list holding the resized logit tensor.
        """
        input_shape = x.shape[2:]
        fea_list = self.backbone(x)
        x = fea_list[self.backbone_indices[0]]
        x = self.conv_last_2(x)
        logit = self.conv_last_1(x)
        logit = fluid.layers.resize_bilinear(logit, input_shape)
        return [logit]

    def init_weight(self, pretrained_model=None):
        """
        Initialize the parameters of model parts.

        Parameters are selected by substring match on their names; afterwards
        the pretrained weights, if given, are loaded on top.

        Args:
            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.

        Raises:
            Exception: if pretrained_model is given but the path does not exist.
        """
        for param in self.parameters():
            param_name = param.name
            if 'batch_norm' in param_name:
                if 'w_0' in param_name:
                    param_init.constant_init(param, 1.0)
                elif 'b_0' in param_name:
                    param_init.constant_init(param, 0.0)
            # Deliberately a separate `if`: conv weights are matched
            # independently of the batch-norm check above.
            if 'conv' in param_name and 'w_0' in param_name:
                param_init.normal_init(param, scale=0.001)

        if pretrained_model is not None:
            if os.path.exists(pretrained_model):
                utils.load_pretrained_model(self, pretrained_model)
            else:
                raise Exception('Pretrained model is not found: {}'.format(
                    pretrained_model))


class ConvBNLayer(fluid.dygraph.Layer):
    """
    Bias-free Conv2D followed by batch normalization and an optional ReLU.

    Args:
        num_channels (int): number of input channels of the convolution.

        num_filters (int): number of filters of the convolution; also the size passed to the batch-norm layer.

        filter_size (int): kernel size. The padding is (filter_size - 1) // 2.

        stride (int): stride of the convolution. Default 1.

        groups (int): groups of the convolution. Default 1.

        act (str): ReLU is applied when this equals 'relu'; any other value returns the batch-norm output unchanged. Default "relu".
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act="relu"):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            bias_attr=False)
        self._batch_norm = BatchNorm(num_filters)
        self.act = act

    def forward(self, input):
        """Apply conv, batch norm and (if act == 'relu') ReLU to input."""
        y = self._conv(input)
        y = self._batch_norm(y)
        if self.act == 'relu':
            y = fluid.layers.relu(y)
        return y


@manager.MODELS.add_component
def fcn_hrnet_w18_small_v1(*args, **kwargs):
    """
    Build an FCN for the HRNet_W18_Small_V1 backbone (240 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (240, ) is a one-element tuple; a bare (240) is an int and cannot be
    # subscripted by FCN.
    return FCN(
        backbone='HRNet_W18_Small_V1', backbone_channels=(240, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w18_small_v2(*args, **kwargs):
    """
    Build an FCN for the HRNet_W18_Small_V2 backbone (270 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (270, ) is a one-element tuple; a bare (270) is an int and cannot be
    # subscripted by FCN.
    return FCN(
        backbone='HRNet_W18_Small_V2', backbone_channels=(270, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w18(*args, **kwargs):
    """
    Build an FCN for the HRNet_W18 backbone (270 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (270, ) is a one-element tuple; a bare (270) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W18', backbone_channels=(270, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w30(*args, **kwargs):
    """
    Build an FCN for the HRNet_W30 backbone (450 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (450, ) is a one-element tuple; a bare (450) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W30', backbone_channels=(450, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w32(*args, **kwargs):
    """
    Build an FCN for the HRNet_W32 backbone (480 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (480, ) is a one-element tuple; a bare (480) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W32', backbone_channels=(480, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w40(*args, **kwargs):
    """
    Build an FCN for the HRNet_W40 backbone (600 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (600, ) is a one-element tuple; a bare (600) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W40', backbone_channels=(600, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w44(*args, **kwargs):
    """
    Build an FCN for the HRNet_W44 backbone (660 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (660, ) is a one-element tuple; a bare (660) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W44', backbone_channels=(660, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w48(*args, **kwargs):
    """
    Build an FCN for the HRNet_W48 backbone (720 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (720, ) is a one-element tuple; a bare (720) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W48', backbone_channels=(720, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w60(*args, **kwargs):
    """
    Build an FCN for the HRNet_W60 backbone (900 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (900, ) is a one-element tuple; a bare (900) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W60', backbone_channels=(900, ), **kwargs)


@manager.MODELS.add_component
def fcn_hrnet_w64(*args, **kwargs):
    """
    Build an FCN for the HRNet_W64 backbone (960 channels).

    Positional arguments are ignored; keyword arguments are forwarded to FCN.
    NOTE(review): backbone is passed as a name string while FCN.forward calls
    self.backbone(x) - confirm where the name is resolved to a layer.
    """
    # (960, ) is a one-element tuple; a bare (960) is an int and cannot be
    # subscripted by FCN.
    return FCN(backbone='HRNet_W64', backbone_channels=(960, ), **kwargs)