nlayers.py 2.3 KB
Newer Older
L
LielinJiang 已提交
1 2 3
import functools
import numpy as np

L
fix nan  
LielinJiang 已提交
4 5
import paddle
import paddle.nn as nn
L
LielinJiang 已提交
6 7 8 9 10 11
from ...modules.norm import build_norm_layer

from .builder import DISCRIMINATORS


@DISCRIMINATORS.register()
class NLayerDiscriminator(nn.Layer):
    """Defines a PatchGAN discriminator.

    The network is built as:
      * one stride-2 4x4 conv (input_nc -> ndf) + LeakyReLU(0.2), no norm;
      * (n_layers - 1) stride-2 4x4 conv + norm + LeakyReLU(0.2) stages whose
        width grows as ndf * min(2**n, 8);
      * one stride-1 4x4 conv + norm + LeakyReLU(0.2) stage;
      * a final stride-1 4x4 conv down to a single output channel.
    The output is therefore a spatial map of scores, not a single scalar.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
        """Construct a PatchGAN discriminator.

        Args:
            input_nc (int): number of channels in the input images.
            ndf (int): number of filters in the FIRST conv layer; later
                layers use ndf multiplied by min(2**n, 8).
            n_layers (int): number of stride-2 (downsampling) conv layers.
                Two stride-1 conv layers are appended after them, so the
                network contains n_layers + 2 conv layers in total.
            norm_type (str): normalization layer type, forwarded to
                ``build_norm_layer``.
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = build_norm_layer(norm_type)

        # build_norm_layer may hand back a functools.partial around the norm
        # class; unwrap it before comparing. Conv bias is enabled only when
        # the norm is InstanceNorm (presumably because that norm is built
        # without its own learnable shift -- confirm in build_norm_layer).
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm
        else:
            use_bias = norm_layer == nn.InstanceNorm

        kw = 4  # kernel size of every conv
        padw = 1  # padding of every conv

        # First stage: no normalization layer.
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2)
        ]

        # Downsampling stages: channel multiplier doubles per stage, capped at 8.
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev,
                          ndf * nf_mult,
                          kernel_size=kw,
                          stride=2,
                          padding=padw,
                          bias_attr=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2)
            ]

        # One more conv + norm stage at stride 1 (no further downsampling).
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev,
                      ndf * nf_mult,
                      kernel_size=kw,
                      stride=1,
                      padding=padw,
                      bias_attr=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2)
        ]

        # Final projection to a single-channel score map.
        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Run the discriminator and return its single-channel score map.

        ``input`` is expected to carry ``input_nc`` channels (presumably an
        NCHW image batch -- confirm against callers).
        """
        return self.model(input)