#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import paddle
import paddle.nn as nn

from ...modules.norm import build_norm_layer
from .builder import GENERATORS


@GENERATORS.register()
class UnetGenerator(nn.Layer):
    """Create a Unet-based generator"""
    def __init__(self,
                 input_nc,
                 output_nc,
                 num_downs,
                 ngf=64,
                 norm_type='batch',
                 use_dropout=False):
        """Construct a Unet generator
        Args:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        norm_layer = build_norm_layer(norm_type)
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True)  # add the innermost layer
        # add intermediate layers with ngf * 8 filters
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8,
                                                 ngf * 8,
                                                 input_nc=None,
                                                 submodule=unet_block,
                                                 norm_layer=norm_layer,
                                                 use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4,
                                             ngf * 8,
                                             input_nc=None,
                                             submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2,
                                             ngf * 4,
                                             input_nc=None,
                                             submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf,
                                             ngf * 2,
                                             input_nc=None,
                                             submodule=unet_block,
                                             norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(
            output_nc,
            ngf,
            input_nc=input_nc,
            submodule=unet_block,
            outermost=True,
            norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
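
# A minimal usage sketch (illustrative, not part of the original module; the
# import path assumes this file lives at ppgan/models/generators/unet.py as
# in PaddleGAN). With num_downs=8 the spatial size is halved eight times, so
# a 256x256 input reaches a 1x1 bottleneck:
#
#     import paddle
#     from ppgan.models.generators.unet import UnetGenerator
#
#     net = UnetGenerator(input_nc=3, output_nc=3, num_downs=8)
#     x = paddle.randn([1, 3, 256, 256])
#     y = net(x)  # same spatial size as the input: [1, 3, 256, 256]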


class UnetSkipConnectionBlock(nn.Layer):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """
    def __init__(self,
                 outer_nc,
                 inner_nc,
                 input_nc=None,
                 submodule=None,
                 outermost=False,
                 innermost=False,
                 norm_layer=nn.BatchNorm,
                 use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
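        # BatchNorm learns an additive shift, which makes a bias in the
        # preceding conv redundant; InstanceNorm2D is typically built without
        # affine parameters, so the conv keeps its bias in that case.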
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2D
        else:
            use_bias = norm_layer == nn.InstanceNorm2D
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2D(input_nc,
                             inner_nc,
                             kernel_size=4,
                             stride=2,
                             padding=1,
                             bias_attr=use_bias)
        downrelu = nn.LeakyReLU(0.2)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU()
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.Conv2DTranspose(inner_nc * 2,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.Conv2DTranspose(inner_nc,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1,
                                        bias_attr=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.Conv2DTranspose(inner_nc * 2,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1,
                                        bias_attr=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return paddle.concat([x, self.model(x)], 1)
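
# A short shape sketch (illustrative only): every non-outermost block returns
# concat([x, self.model(x)], axis=1), so its output carries
# input_nc + outer_nc channels. This is why the enclosing block builds its
# Conv2DTranspose with inner_nc * 2 input channels.
#
#     import paddle
#
#     block = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128,
#                                     innermost=True)
#     feat = paddle.randn([1, 64, 32, 32])
#     out = block(feat)  # skip concat: [1, 64 + 64, 32, 32]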