#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .base_network import conv2d, deconv2d, norm_layer, linear
import paddle.fluid as fluid
import numpy as np

MAX_DIM = 64 * 16


class AttGAN_model(object):
    """AttGAN network definitions.

    Builds an attribute-conditioned encoder/decoder generator and a
    discriminator with two heads (adversarial + attribute classification),
    using the conv2d/deconv2d/linear helpers from base_network.
    """

    def __init__(self):
        pass

    def network_G(self,
                  input,
                  label_org,
                  label_trg,
                  cfg,
                  name="generator",
                  is_test=False):
        """Build the generator graph.

        Encodes `input` once with Genc, then decodes twice with Gdec:
        with the target attributes `label_trg` to get the edited image,
        and with the original attributes `label_org` to get a
        reconstruction of the input.

        Args:
            input: input image tensor.
            label_org: attribute vector of the input image.
            label_trg: target attribute vector.
            cfg: config object providing g_base_dims and n_layers.
            name: parameter-name prefix for the generator.
            is_test: True to build the inference-mode graph.

        Returns:
            (fake_image, rec_image) tuple.
        """
        _a = label_org
        _b = label_trg
        z = self.Genc(
            input,
            name=name + '_Genc',
            dim=cfg.g_base_dims,
            n_layers=cfg.n_layers,
            is_test=is_test)
        # NOTE(review): Genc receives cfg.n_layers but both Gdec calls rely
        # on Gdec's default n_layers=5; confirm cfg.n_layers is always 5,
        # otherwise Gdec's shortcut indexing into zs would be inconsistent.
        fake_image = self.Gdec(
            z, _b, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test)
        rec_image = self.Gdec(
            z, _a, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test)
        return fake_image, rec_image

    def network_D(self, input, cfg, name="discriminator"):
        """Build the discriminator graph.

        Returns:
            (logit_gan, logit_att) from D.
        """
        return self.D(input,
                      n_atts=cfg.c_dim,
                      name=name,
                      dim=cfg.d_base_dims,
                      fc_dim=cfg.d_fc_dim,
                      norm=cfg.dis_norm,
                      n_layers=cfg.n_layers)

    def concat(self, z, a):
        """Concatenate attribute vector on feature map axis.

        Broadcasts `a` (shape [N, C_a]) over the spatial dims of `z`
        (shape [N, C_z, H, W]) and appends it along the channel axis.
        """
        ones = fluid.layers.fill_constant_batch_size_like(
            z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
        return fluid.layers.concat(
            [z, fluid.layers.elementwise_mul(
                ones, a, axis=0)], axis=1)

    def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
        """Encoder: n_layers stride-2 convolutions.

        Returns the list of ALL intermediate feature maps so Gdec can use
        them as shortcut (skip) connections.
        """
        z = input
        zs = []
        for i in range(n_layers):
            # Channel count doubles each layer, capped at MAX_DIM.
            d = min(dim * 2**i, MAX_DIM)
            # SAME padding
            z = conv2d(
                input=z,
                num_filters=d,
                filter_size=4,
                stride=2,
                padding_type='SAME',
                norm='batch_norm',
                activation_fn='leaky_relu',
                name=name + str(i),
                use_bias=False,
                relufactor=0.01,
                initial='kaiming',
                is_test=is_test)
            zs.append(z)

        return zs

    def Gdec(self,
             zs,
             a,
             dim=64,
             n_layers=5,
             shortcut_layers=1,
             inject_layers=1,
             name='G_dec_',
             is_test=False):
        """Decoder: upsample the deepest encoder feature back to an image.

        Args:
            zs: list of encoder feature maps produced by Genc.
            a: attribute vector conditioning the decoding.
            dim: base channel count (mirrors Genc's dim).
            n_layers: number of deconv layers.
            shortcut_layers: leading layers that concat the matching encoder
                feature map (U-Net-style skip connection).
            inject_layers: leading layers that re-inject the attribute
                vector via concat().
            name: parameter-name prefix.
            is_test: True to build the inference-mode graph.

        Returns:
            tanh-activated 3-channel image tensor.
        """
        shortcut_layers = min(shortcut_layers, n_layers - 1)
        inject_layers = min(inject_layers, n_layers - 1)

        # Start from the deepest feature, conditioned on the attributes.
        z = self.concat(zs[-1], a)
        for i in range(n_layers):
            if i < n_layers - 1:
                d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
                z = deconv2d(
                    input=z,
                    num_filters=d,
                    filter_size=4,
                    stride=2,
                    padding_type='SAME',
                    name=name + str(i),
                    norm='batch_norm',
                    activation_fn='relu',
                    use_bias=False,
                    initial='kaiming',
                    is_test=is_test)
                if shortcut_layers > i:
                    z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
                if inject_layers > i:
                    z = self.concat(z, a)
            else:
                # Final layer maps to a 3-channel image in [-1, 1] via tanh.
                x = z = deconv2d(
                    input=z,
                    num_filters=3,
                    filter_size=4,
                    stride=2,
                    padding_type='SAME',
                    name=name + str(i),
                    activation_fn='tanh',
                    use_bias=True,
                    initial='kaiming',
                    is_test=is_test)
        return x

    def D(self,
          x,
          n_atts=13,
          dim=64,
          fc_dim=1024,
          n_layers=5,
          norm='instance_norm',
          name='D_'):
        """Discriminator: conv trunk followed by two fully connected heads.

        Args:
            x: input image tensor.
            n_atts: number of attributes for the classification head.
            dim: base channel count, doubled per layer (capped at MAX_DIM).
            fc_dim: hidden size of both FC heads.
            n_layers: number of stride-2 conv layers in the trunk.
            norm: normalization type for the trunk (None disables it).
            name: parameter-name prefix.

        Returns:
            (logit_gan, logit_att): 1-dim adversarial logit and
            n_atts-dim attribute-classification logit.
        """
        y = x
        for i in range(n_layers):
            d = min(dim * 2**i, MAX_DIM)
            y = conv2d(
                input=y,
                num_filters=d,
                filter_size=4,
                stride=2,
                norm=norm,
                padding=1,
                activation_fn='leaky_relu',
                name=name + str(i),
                # Bias is redundant when a normalization layer follows.
                use_bias=(norm is None),
                relufactor=0.01,
                initial='kaiming')

        logit_gan = linear(
            input=y,
            output_size=fc_dim,
            activation_fn='relu',
            name=name + 'fc_adv_1',
            initial='kaiming')
        logit_gan = linear(
            logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')

        logit_att = linear(
            input=y,
            output_size=fc_dim,
            activation_fn='relu',
            name=name + 'fc_cls_1',
            initial='kaiming')
        # NOTE(review): unlike fc_adv_2 this layer omits initial='kaiming';
        # kept as-is to preserve behavior — confirm whether intentional.
        logit_att = linear(logit_att, n_atts, name=name + 'fc_cls_2')

        return logit_gan, logit_att