unet.py 4.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

C
chenguowei01 已提交
17 18 19 20
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d
21 22 23 24
from paddle.nn import SyncBatchNorm as BatchNorm

from paddleseg.cvlibs import manager
from paddleseg import utils
C
chenguowei01 已提交
25
from paddleseg.models.common import layer_libs
26 27


C
chenguowei01 已提交
28 29
@manager.MODELS.add_component
class UNet(nn.Layer):
    """
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/abs/1505.04597

    The model is an encoder (UnetEncoder), a decoder (UnetDecode) fed with the
    encoder's short cuts, and a final convolution (GetLogit) that maps the
    decoder's 64 output channels to `num_classes` channels.

    Args:
        num_classes (int): The unique number of target classes.
        pretrained (str, optional): The path of pretrained model for fine tuning.
            Passed straight to ``utils.load_entire_model``. Default: None.
    """

    def __init__(self, num_classes, pretrained=None):
        super(UNet, self).__init__()

        self.encode = UnetEncoder()
        self.decode = UnetDecode()
        # The decoder's last stage is Up(64, 64), so it emits 64 channels.
        self.get_logit = GetLogit(64, num_classes)

        # Called unconditionally, even when `pretrained` is None.
        # NOTE(review): presumably a no-op for None -- confirm against
        # utils.load_entire_model.
        utils.load_entire_model(self, pretrained)

    def forward(self, x, label=None):
        """
        Run the encoder-decoder and return the segmentation logits.

        Args:
            x (Tensor): Input image batch. The encoder's first convolution has
                3 input channels, and the decoder concatenates on axis 1, so an
                NCHW layout is assumed -- TODO confirm with callers.
            label: Unused by this method.

        Returns:
            list: A single-element list holding the logit tensor, which has
            `num_classes` channels.
        """
        encode_data, short_cuts = self.encode(x)
        decode_data = self.decode(encode_data, short_cuts)
        logit = self.get_logit(decode_data)
        return [logit]
C
chenguowei01 已提交
55 56 57


class UnetEncoder(nn.Layer):
    """
    U-Net contracting path.

    The path is an initial DoubleConv followed by four Down stages. Channel
    counts go num_channels -> 64 -> 128 -> 256 -> 512 -> 512.

    Args:
        num_channels (int, optional): Number of channels of the input image.
            Default: 3.
    """

    def __init__(self, num_channels=3):
        super(UnetEncoder, self).__init__()
        self.double_conv = DoubleConv(num_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)

    def forward(self, x):
        """
        Encode `x` and collect the intermediate feature maps.

        Args:
            x (Tensor): Input image batch with `num_channels` channels.

        Returns:
            tuple: ``(x, short_cuts)``.
                - ``x`` is the output of ``down4`` (512 channels).
                - ``short_cuts`` is a list of the four earlier feature maps
                  (64, 128, 256, 512 channels), ordered from the first stage
                  to ``down3``.
                The output of ``down4`` itself is not included in
                ``short_cuts``.
        """
        short_cuts = []
        x = self.double_conv(x)
        short_cuts.append(x)
        x = self.down1(x)
        short_cuts.append(x)
        x = self.down2(x)
        short_cuts.append(x)
        x = self.down3(x)
        short_cuts.append(x)
        x = self.down4(x)
        return x, short_cuts


C
chenguowei01 已提交
80
class UnetDecode(nn.Layer):
    """
    U-Net expansive path.

    The path is four Up stages. Each stage resizes its input to a short cut's
    spatial size, concatenates the short cut, and applies a DoubleConv.
    Channel counts go 512 -> 256 -> 128 -> 64 -> 64.
    """

    def __init__(self):
        super(UnetDecode, self).__init__()
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.up4 = Up(64, 64)

    def forward(self, x, short_cuts):
        """
        Decode the deepest feature map using the encoder short cuts.

        Args:
            x (Tensor): Deepest encoder feature map (512 channels).
            short_cuts (list): Four encoder feature maps in the order produced
                by UnetEncoder (64, 128, 256, 512 channels). They are consumed
                here in reverse order, deepest first.

        Returns:
            Tensor: Feature map with 64 channels, resized to the spatial size
            of ``short_cuts[0]``.
        """
        x = self.up1(x, short_cuts[3])
        x = self.up2(x, short_cuts[2])
        x = self.up3(x, short_cuts[1])
        x = self.up4(x, short_cuts[0])
        return x


C
chenguowei01 已提交
96
class DoubleConv(nn.Layer):
    """
    Two stacked (3x3 Conv2d -> BatchNorm -> ReLU) units.

    Both convolutions use stride 1 and padding 1. `BatchNorm` here is
    paddle's SyncBatchNorm, imported under that alias at module level.

    Args:
        num_channels (int): Number of input channels.
        num_filters (int): Number of output channels of both convolutions.
    """

    def __init__(self, num_channels, num_filters):
        super(DoubleConv, self).__init__()
        self.conv0 = Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=1,
            padding=1)
        self.bn0 = BatchNorm(num_filters)
        self.conv1 = Conv2d(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=1,
            padding=1)
        self.bn1 = BatchNorm(num_filters)

    def forward(self, x):
        """Apply conv0 -> bn0 -> ReLU -> conv1 -> bn1 -> ReLU to `x`."""
        x = self.conv0(x)
        x = self.bn0(x)
        x = F.relu(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        return x


C
chenguowei01 已提交
124
class Down(nn.Layer):
    """
    Downsampling stage: 2x2 max pooling with stride 2, then a DoubleConv.

    Args:
        num_channels (int): Number of input channels.
        num_filters (int): Number of output channels.
    """

    def __init__(self, num_channels, num_filters):
        super(Down, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.double_conv = DoubleConv(num_channels, num_filters)

    def forward(self, x):
        """Max-pool `x`, then apply the DoubleConv."""
        x = self.max_pool(x)
        x = self.double_conv(x)
        return x


C
chenguowei01 已提交
136
class Up(nn.Layer):
    """
    Upsampling stage.

    The stage bilinearly resizes `x` to the short cut's spatial size,
    concatenates the two along axis 1, and applies a DoubleConv.

    Args:
        num_channels (int): The concatenation of `x` and `short_cut` must
            have ``2 * num_channels`` channels. In UnetDecode each input
            carries `num_channels` channels.
        num_filters (int): Number of output channels.
    """

    def __init__(self, num_channels, num_filters):
        super(Up, self).__init__()
        self.double_conv = DoubleConv(2 * num_channels, num_filters)

    def forward(self, x, short_cut):
        """
        Fuse the upsampled `x` with `short_cut`.

        Args:
            x (Tensor): Lower-resolution feature map.
            short_cut (Tensor): Encoder feature map whose spatial size
                (``shape[2:]``) is the resize target.

        Returns:
            Tensor: Feature map with `num_filters` channels at the spatial
            size of `short_cut`.
        """
        # Match short_cut's trailing (spatial) dims so the channel-axis
        # concat below is well-formed; NCHW layout assumed -- TODO confirm.
        x = F.resize_bilinear(x, short_cut.shape[2:])
        x = paddle.concat([x, short_cut], axis=1)
        x = self.double_conv(x)
        return x


C
chenguowei01 已提交
148
class GetLogit(nn.Layer):
    """
    Final classifier layer.

    A single 3x3 convolution (stride 1, padding 1) maps `num_channels`
    feature channels to `num_classes` logit channels. No activation is
    applied.

    Args:
        num_channels (int): Number of input feature channels.
        num_classes (int): Number of output (class) channels.
    """

    def __init__(self, num_channels, num_classes):
        super(GetLogit, self).__init__()
        self.conv = Conv2d(
            in_channels=num_channels,
            out_channels=num_classes,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Return the raw (pre-activation) class logits for `x`."""
        x = self.conv(x)
        return x