resnet.py 6.6 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16 17 18 19 20 21 22 23 24
import paddle
import paddle.nn as nn
import functools

from ...modules.norm import build_norm_layer

from .builder import GENERATORS


@GENERATORS.register()
class ResnetGenerator(nn.Layer):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """
    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=64,
                 norm_type='instance',
                 use_dropout=False,
                 n_blocks=6,
                 padding_type='reflect'):
        """Construct a Resnet-based generator

        Args:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- base filter count: output channels of the first
                                   conv and input channels of the final conv
            norm_type (str)     -- name handed to build_norm_layer to obtain the
                                   normalization layer factory
            use_dropout (bool)  -- if use dropout layers in the ResNet blocks
            n_blocks (int)      -- the number of ResNet blocks (must be >= 0)
            padding_type (str)  -- padding used inside the ResNet blocks:
                                   reflect | replicate | zero. The first and
                                   last 7x7 convs always use reflect padding.
        """
        assert n_blocks >= 0, 'n_blocks must be non-negative'
        super(ResnetGenerator, self).__init__()

        norm_layer = build_norm_layer(norm_type)
        # The factory may be a functools.partial wrapping the layer class;
        # unwrap it before comparing. Convs get a bias only when the
        # normalization layer is InstanceNorm2D.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2D
        else:
            use_bias = norm_layer == nn.InstanceNorm2D

        # Stem: reflect-pad by 3 so the unpadded 7x7 conv keeps spatial size.
        model = [
            nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
            nn.Conv2D(input_nc,
                      ngf,
                      kernel_size=7,
                      padding=0,
                      bias_attr=use_bias),
            norm_layer(ngf),
            nn.ReLU()
        ]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2**i
            # Each stride-2 conv doubles the channel count.
            model += [
                nn.Conv2D(ngf * mult,
                          ngf * mult * 2,
                          kernel_size=3,
                          stride=2,
                          padding=1,
                          bias_attr=use_bias),
                norm_layer(ngf * mult * 2),
                nn.ReLU()
            ]

        mult = 2**n_downsampling
        for _ in range(n_blocks):  # add ResNet blocks
            model += [
                ResnetBlock(ngf * mult,
                            padding_type=padding_type,
                            norm_layer=norm_layer,
                            use_dropout=use_dropout,
                            use_bias=use_bias)
            ]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2**(n_downsampling - i)
            # Each stride-2 transposed conv halves the channel count.
            out_channels = int(ngf * mult / 2)
            model += [
                nn.Conv2DTranspose(ngf * mult,
                                   out_channels,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1,
                                   bias_attr=use_bias),
                norm_layer(out_channels),
                nn.ReLU()
            ]

        # Output head: reflect-pad, 7x7 conv to output_nc channels, then Tanh.
        # This conv does not pass bias_attr, so it uses the Conv2D default.
        model += [nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")]
        model += [nn.Conv2D(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Standard forward: run x through the sequential model."""
        return self.model(x)


L
fix nan  
LielinJiang 已提交
118
class ResnetBlock(nn.Layer):
    """Define a Resnet block"""
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf

        Parameters are forwarded unchanged to build_conv_block.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
                                                use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,
                         use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer factory, called as norm_layer(dim)
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns an nn.Sequential of two pad/conv/norm stages, with a ReLU
        (and Dropout(0.5) when use_dropout is set) after the first stage.

        Raises:
            NotImplementedError -- if padding_type is not one of the three
                                   supported names.
        """
        # Resolve the padding strategy once; both conv stages share it.
        if padding_type in ('reflect', 'replicate'):
            # Padding is done by an explicit Pad2D layer, so the conv itself
            # is unpadded.
            explicit_pad = True
            conv_padding = 0
        elif padding_type == 'zero':
            # Zero padding is handled by the conv's own padding argument.
            explicit_pad = False
            conv_padding = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' %
                                      padding_type)

        def padded_conv_norm():
            """Build one [Pad2D,] Conv2D, norm stage as a list of layers."""
            stage = []
            if explicit_pad:
                stage.append(
                    nn.Pad2D(padding=[1, 1, 1, 1], mode=padding_type))
            stage.append(
                nn.Conv2D(dim,
                          dim,
                          kernel_size=3,
                          padding=conv_padding,
                          bias_attr=use_bias))
            stage.append(norm_layer(dim))
            return stage

        conv_block = padded_conv_norm()
        conv_block.append(nn.ReLU())
        if use_dropout:
            conv_block.append(nn.Dropout(0.5))

        # The second stage has no trailing activation.
        conv_block += padded_conv_norm()

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out