# nlayers.py
import paddle
import functools
import numpy as np
import paddle.nn as nn

from ...modules.nn import ReflectionPad2d, LeakyReLU, Dropout, BCEWithLogitsLoss, Pad2D, MSELoss
from ...modules.norm import build_norm_layer

from .builder import DISCRIMINATORS


@DISCRIMINATORS.register()
class NLayerDiscriminator(paddle.fluid.dygraph.Layer):
    """Defines a PatchGAN discriminator.

    The network is a plain stack of 4x4 convolutions assembled into a single
    ``nn.Sequential``: one stride-2 input conv, ``n_layers - 1`` further
    stride-2 convs, one stride-1 conv, and a final stride-1 conv that
    produces a single output channel.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
        """Construct a PatchGAN discriminator

        Args:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer;
                               later layers use ``ndf * min(2 ** k, 8)``
            n_layers (int)  -- the number of stride-2 conv layers in the
                               discriminator (including the first one)
            norm_type (str) -- normalization layer type, passed to
                               ``build_norm_layer``
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = build_norm_layer(norm_type)

        # build_norm_layer may hand back either a layer class or a
        # functools.partial wrapping one; unwrap before comparing.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm
        else:
            use_bias = norm_layer == nn.InstanceNorm

        kw = 4  # kernel size shared by every conv in the stack
        padw = 1  # padding shared by every conv in the stack

        # Input conv: no normalization layer, default bias setting.
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            LeakyReLU(0.2, True),
        ]

        # Stride-2 stages; the channel multiplier doubles each stage and is
        # capped at 8.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev,
                          ndf * nf_mult,
                          kernel_size=kw,
                          stride=2,
                          padding=padw,
                          bias_attr=use_bias),
                norm_layer(ndf * nf_mult),
                LeakyReLU(0.2, True),
            ]

        # One more conv + norm + activation stage, this time with stride 1.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev,
                      ndf * nf_mult,
                      kernel_size=kw,
                      stride=1,
                      padding=padw,
                      bias_attr=use_bias),
            norm_layer(ndf * nf_mult),
            LeakyReLU(0.2, True),
        ]

        # Final conv maps to a single channel; no norm or activation follows.
        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: run ``input`` through the sequential model."""
        return self.model(input)