# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is adapted from:
https://github.com/LBH1024/CAN/models/densenet.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Standard-library and Paddle dependencies used by the layers below.
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class Bottleneck(nn.Layer):
    """Dense layer made of a 1x1 convolution followed by a 3x3 convolution.

    The branch output (``growthRate`` channels) is concatenated to the layer
    input along axis 1, so the result carries ``nChannels + growthRate``
    channels.

    Args:
        nChannels: Channel count of the incoming feature map.
        growthRate: Channel count produced by the branch and appended to
            the input.
        use_dropout: When truthy, ``Dropout(p=0.2)`` is applied after each
            conv -> batch-norm -> ReLU step.
    """

    def __init__(self, nChannels, growthRate, use_dropout):
        super(Bottleneck, self).__init__()
        # The 1x1 convolution maps to four times the growth rate; the 3x3
        # convolution then maps down to ``growthRate`` channels.
        inter_channels = 4 * growthRate

        # Sub-layers are created in a fixed order (bn1, conv1, bn2, conv2,
        # dropout); keep it stable so parameter registration is unchanged.
        self.bn1 = nn.BatchNorm2D(inter_channels)
        # NOTE(review): this call used to be annotated "Xavier
        # initialization", but bias_attr=None only selects Paddle's default
        # bias attribute - confirm which initializer is actually intended.
        self.conv1 = nn.Conv2D(
            in_channels=nChannels,
            out_channels=inter_channels,
            kernel_size=1,
            bias_attr=None)
        self.bn2 = nn.BatchNorm2D(growthRate)
        # NOTE(review): same remark as for conv1 regarding bias_attr=None.
        self.conv2 = nn.Conv2D(
            in_channels=inter_channels,
            out_channels=growthRate,
            kernel_size=3,
            padding=1,
            bias_attr=None)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return ``x`` concatenated with the branch output along axis 1."""
        # Stage 1: 1x1 convolution -> batch norm -> ReLU (-> dropout).
        mid = self.conv1(x)
        mid = F.relu(self.bn1(mid))
        if self.use_dropout:
            mid = self.dropout(mid)

        # Stage 2: 3x3 convolution -> batch norm -> ReLU (-> dropout).
        new_features = self.conv2(mid)
        new_features = F.relu(self.bn2(new_features))
        if self.use_dropout:
            new_features = self.dropout(new_features)

        # Dense connectivity: the input is kept and the new features are
        # appended after it.
        return paddle.concat([x, new_features], axis=1)


class SingleLayer(nn.Layer):
    """Dense layer with one 3x3 convolution applied to ``ReLU(input)``.

    The convolution output (``growthRate`` channels) is concatenated to the
    layer input along axis 1.

    Args:
        nChannels: Channel count of the incoming feature map.
        growthRate: Channel count produced by the convolution and appended
            to the input.
        use_dropout: When truthy, ``Dropout(p=0.2)`` is applied to the
            convolution output.
    """

    def __init__(self, nChannels, growthRate, use_dropout):
        super(SingleLayer, self).__init__()
        # NOTE(review): bn1 is registered as a sub-layer but forward() never
        # calls it. It is kept so the layer's parameter set stays the same;
        # confirm whether it was meant to precede the ReLU.
        self.bn1 = nn.BatchNorm2D(nChannels)
        self.conv1 = nn.Conv2D(
            in_channels=nChannels,
            out_channels=growthRate,
            kernel_size=3,
            padding=1,
            bias_attr=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return ``x`` concatenated with ``conv1(relu(x))`` along axis 1."""
        activated = F.relu(x)
        new_features = self.conv1(activated)
        if self.use_dropout:
            new_features = self.dropout(new_features)
        # The untouched input comes first, the new features after it.
        return paddle.concat([x, new_features], axis=1)


class Transition(nn.Layer):
    """Layer placed between dense stages: 1x1 convolution, then pooling.

    Args:
        nChannels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the 1x1 convolution.
        use_dropout: When truthy, ``Dropout(p=0.2)`` is applied before the
            average pooling.
    """

    def __init__(self, nChannels, out_channels, use_dropout):
        super(Transition, self).__init__()
        # Batch norm runs on the convolution output, hence ``out_channels``.
        self.bn1 = nn.BatchNorm2D(out_channels)
        self.conv1 = nn.Conv2D(
            in_channels=nChannels,
            out_channels=out_channels,
            kernel_size=1,
            bias_attr=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU (-> dropout) -> average pool."""
        reduced = self.conv1(x)
        reduced = F.relu(self.bn1(reduced))
        if self.use_dropout:
            reduced = self.dropout(reduced)
        # Kernel size 2 with ceil_mode=True; exclusive=False is passed
        # explicitly and must stay as is.
        return F.avg_pool2d(
            reduced, kernel_size=2, ceil_mode=True, exclusive=False)


class DenseNet(nn.Layer):
    """DenseNet backbone: a stem, three dense stages and two transitions.

    Layout, in execution order: 7x7 stride-2 convolution, ReLU, max pooling,
    dense stage 1, transition 1, dense stage 2, transition 2, dense stage 3.
    Every dense stage stacks 16 layers, each adding ``growthRate`` channels.

    Args:
        growthRate: Channels appended by every dense layer.
        reduction: Factor applied to the channel count in each transition;
            the result is floored to an int.
        bottleneck: When truthy the dense stages are built from
            ``Bottleneck`` layers, otherwise from ``SingleLayer`` layers.
        use_dropout: Forwarded to every dense layer and transition.
        input_channel: Channel count of the input image tensor.
        **kwargs: Accepted and ignored.

    Attributes:
        out_channels: Channel count of the feature map returned by
            ``forward`` (i.e. after the third dense stage).
    """

    def __init__(self, growthRate, reduction, bottleneck, use_dropout,
                 input_channel, **kwargs):
        super(DenseNet, self).__init__()

        nDenseBlocks = 16
        # Each dense stage widens its input by this many channels.
        stage_growth = nDenseBlocks * growthRate
        nChannels = 2 * growthRate

        # NOTE: sub-layers are created in execution order; keep that order
        # stable so parameter registration does not change.
        self.conv1 = nn.Conv2D(
            input_channel,
            nChannels,
            kernel_size=7,
            padding=3,
            stride=2,
            bias_attr=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
                                       bottleneck, use_dropout)
        nChannels += stage_growth
        out_channels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, out_channels, use_dropout)

        nChannels = out_channels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
                                       bottleneck, use_dropout)
        nChannels += stage_growth
        out_channels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, out_channels, use_dropout)

        nChannels = out_channels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
                                       bottleneck, use_dropout)
        # The third stage has no transition after it, so the feature map
        # returned by forward() is wider than trans2's output. Previously
        # the pre-dense3 width was stored here, which under-reported the
        # real output channel count by ``nDenseBlocks * growthRate``.
        nChannels += stage_growth
        self.out_channels = nChannels

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
                    use_dropout):
        """Build one dense stage as an ``nn.Sequential``.

        Args:
            nChannels: Channel count entering the stage.
            growthRate: Channels appended by every layer of the stage.
            nDenseBlocks: Number of layers in the stage (cast to int).
            bottleneck: Selects ``Bottleneck`` (truthy) or ``SingleLayer``.
            use_dropout: Forwarded to every layer.

        Returns:
            ``nn.Sequential`` holding the layers in creation order.
        """
        layer_cls = Bottleneck if bottleneck else SingleLayer
        # Layer ``i`` sees the stage input plus the ``i * growthRate``
        # channels appended by the layers before it.
        layers = [
            layer_cls(nChannels + i * growthRate, growthRate, use_dropout)
            for i in range(int(nDenseBlocks))
        ]
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the backbone on the first element of ``inputs``.

        Args:
            inputs: A 3-element sequence ``(x, x_m, y)``. Only ``x`` is fed
                through the network; ``x_m`` and ``y`` are returned
                untouched (presumably an image mask and the labels -
                confirm against the caller).

        Returns:
            Tuple ``(features, x_m, y)`` where ``features`` is the output of
            the third dense stage.
        """
        x, x_m, y = inputs
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, ceil_mode=True)
        out = self.dense1(out)
        out = self.trans1(out)
        out = self.dense2(out)
        out = self.trans2(out)
        out = self.dense3(out)
        return out, x_m, y