import functools
import paddle
import paddle.nn as nn

from ...modules.norm import build_norm_layer
from .builder import GENERATORS


@GENERATORS.register()
class UnetGenerator(nn.Layer):
    """Create a Unet-based generator"""
    def __init__(self,
                 input_nc,
                 output_nc,
                 num_downs,
                 ngf=64,
                 norm_type='batch',
                 use_dropout=False):
        """Construct a Unet generator
        Args:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        norm_layer = build_norm_layer(norm_type)
        # construct unet structure
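        # Illustrative note: each UnetSkipConnectionBlock halves the spatial
        # size on the way down and doubles it on the way up, so an input with
        # 2 ** num_downs pixels per side (e.g. 256x256 when num_downs == 8)
        # shrinks to 1x1 at the bottleneck.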
        unet_block = UnetSkipConnectionBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True)  # add the innermost layer
        # add intermediate layers with ngf * 8 filters
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8,
                                                 ngf * 8,
                                                 input_nc=None,
                                                 submodule=unet_block,
                                                 norm_layer=norm_layer,
                                                 use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4,
                                             ngf * 8,
                                             input_nc=None,
                                             submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2,
                                             ngf * 4,
                                             input_nc=None,
                                             submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf,
                                             ngf * 2,
                                             input_nc=None,
                                             submodule=unet_block,
                                             norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(
            output_nc,
            ngf,
            input_nc=input_nc,
            submodule=unet_block,
            outermost=True,
            norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Layer):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """
    def __init__(self,
                 outer_nc,
                 inner_nc,
                 input_nc=None,
                 submodule=None,
                 outermost=False,
                 innermost=False,
                 norm_layer=nn.BatchNorm,
                 use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- whether this module is the outermost module
            innermost (bool)    -- whether this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
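        # Rationale (mirrors the original pix2pix implementation): the
        # instance-norm layers used here carry no learnable affine shift, so
        # the conv layers keep their own bias, whereas BatchNorm's affine
        # parameters make a conv bias redundant.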
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2D
        else:
            use_bias = norm_layer == nn.InstanceNorm2D
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2D(input_nc,
                             inner_nc,
                             kernel_size=4,
                             stride=2,
                             padding=1,
                             bias_attr=use_bias)
        downrelu = nn.LeakyReLU(0.2)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU()
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.Conv2DTranspose(inner_nc * 2,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.Conv2DTranspose(inner_nc,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1,
                                        bias_attr=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.Conv2DTranspose(inner_nc * 2,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1,
                                        bias_attr=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return paddle.concat([x, self.model(x)], 1)
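

# A minimal usage sketch (illustrative, kept as a comment: the import path is
# assumed from this file's location in the package, and the relative imports
# above mean the module must be imported as part of the package rather than
# run as a script):
#
#     import paddle
#     from ppgan.models.generators import UnetGenerator
#
#     net = UnetGenerator(input_nc=3, output_nc=3, num_downs=8)
#     x = paddle.randn([1, 3, 256, 256])
#     y = net(x)  # same spatial size as the input; Tanh bounds values to (-1, 1)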