# STGAN_network.py (last committed by lvmengsi)
#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .base_network import conv2d, deconv2d, norm_layer, linear
import paddle.fluid as fluid
import numpy as np

# Upper bound on the channel count of every conv / deconv layer built in this
# module (equal to 64 * 16).
MAX_DIM = 1024


class STGAN_model(object):
    """Builds the STGAN generator and discriminator graphs with paddle.fluid.

    The generator is an encoder (``Genc``), an optional stack of
    GRU-style transfer units (``GRU``), and a decoder (``Gdec``). It is
    conditioned on the difference between the target and original
    attribute vectors. The discriminator (``D``) has an adversarial head
    and an attribute-classification head.
    """

    def __init__(self):
        pass

    def network_G(self,
                  input,
                  label_org,
                  label_trg,
                  cfg,
                  name="generator",
                  is_test=False):
        """Build the generator and return ``(fake_image, rec_image)``.

        Args:
            input: image variable fed to the encoder.
            label_org: original attribute vector.
            label_trg: target attribute vector.
            cfg: config object. Reads ``n_layers``, ``g_base_dims``,
                ``gru_n_layers`` and ``use_gru``.
            name: prefix used for every sub-network name.
            is_test: forwarded to every layer.

        Returns:
            ``fake_image`` decoded with the ``label_trg - label_org``
            difference, and ``rec_image`` decoded with the all-zero
            ``label_org - label_org`` difference.
        """
        _a = label_org
        _b = label_trg
        z = self.Genc(
            input,
            name=name + '_Genc',
            n_layers=cfg.n_layers,
            dim=cfg.g_base_dims,
            is_test=is_test)

        # Attribute differences: target-vs-original drives the edit, and the
        # (all-zero) original-vs-original drives the reconstruction. Each is
        # computed once and reused by both the GRU and the decoder.
        diff_trg = fluid.layers.elementwise_sub(_b, _a)
        diff_rec = fluid.layers.elementwise_sub(_a, _a)

        def _transfer_and_decode(diff):
            # Both calls use the same `name` prefixes. Presumably this makes
            # the fake and reconstruction branches share GRU / decoder
            # weights, if base_network derives parameter names from `name`.
            # Confirm against base_network.
            if cfg.use_gru:
                feats = self.GRU(
                    z,
                    diff,
                    name=name + '_GRU',
                    dim=cfg.g_base_dims,
                    n_layers=cfg.gru_n_layers,
                    is_test=is_test)
            else:
                feats = z
            return self.Gdec(
                feats,
                diff,
                name=name + '_Gdec',
                dim=cfg.g_base_dims,
                n_layers=cfg.n_layers,
                is_test=is_test)

        fake_image = _transfer_and_decode(diff_trg)
        rec_image = _transfer_and_decode(diff_rec)
        return fake_image, rec_image

    def network_D(self, input, cfg, name="discriminator"):
        """Build the discriminator from ``cfg``.

        Reads ``c_dim``, ``d_base_dims``, ``d_fc_dim``, ``dis_norm`` and
        ``n_layers`` from ``cfg``.

        Returns:
            ``(logit_gan, logit_att)`` as returned by :meth:`D`.
        """
        return self.D(input,
                      n_atts=cfg.c_dim,
                      dim=cfg.d_base_dims,
                      fc_dim=cfg.d_fc_dim,
                      norm=cfg.dis_norm,
                      n_layers=cfg.n_layers,
                      name=name)

    def concat(self, z, a):
        """Tile attribute vector ``a`` spatially and concatenate it to ``z``.

        Assumes ``z`` is NCHW: dims 2 and 3 are used as the spatial size,
        and the concatenation is on axis 1. Assumes ``a`` is
        (batch, n_atts); TODO confirm against callers.
        """
        ones = fluid.layers.fill_constant_batch_size_like(
            z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
        # axis=0 explicitly aligns a's (batch, n_atts) dims with the leading
        # two dims of `ones`, so each label is broadcast over H x W. Plain
        # `ones * a` only did this when the implicit broadcast axis happened
        # to be inferred as 0.
        tiled = fluid.layers.elementwise_mul(ones, a, axis=0)
        return fluid.layers.concat([z, tiled], axis=1)

    def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
        """Encoder: ``n_layers`` conv + batch_norm + leaky_relu(0.2) blocks.

        Each block calls ``conv2d(z, d, 4, 2, ...)``. The 4 / 2 are
        presumably filter size / stride in base_network; TODO confirm.

        Returns:
            List of all ``n_layers`` intermediate feature maps, shallowest
            first. Layer ``i`` has ``min(dim * 2**i, MAX_DIM)`` channels.
        """
        z = input
        zs = []
        for i in range(n_layers):
            d = min(dim * 2**i, MAX_DIM)
            z = conv2d(
                z,
                d,
                4,
                2,
                padding_type='SAME',
                norm="batch_norm",
                activation_fn='leaky_relu',
                name=name + str(i),
                use_bias=False,
                relufactor=0.2,
                initial='kaiming',
                is_test=is_test)
            zs.append(z)

        return zs

    def GRU(self,
            zs,
            a,
            dim=64,
            n_layers=4,
            inject_layers=4,
            kernel_size=3,
            norm=None,
            pass_state='lstate',
            name='G_gru_',
            is_test=False):
        """Run ``gru_cell`` units from the deepest encoder feature upwards.

        The initial state is ``zs[-1]`` with ``a`` tiled onto it. Unit ``i``
        consumes ``zs[n_layers - 1 - i]`` and the previous state. For the
        first ``inject_layers`` units, ``a`` is re-concatenated onto the
        passed-on state.

        NOTE(review): the indexing assumes ``len(zs) == n_layers + 1``, i.e.
        ``zs[n_layers - 1]`` sits directly above ``zs[-1]``.

        Returns:
            List ``[out_0, ..., out_{n_layers-1}, zs[-1]]``, where ``out_k``
            is the transferred version of ``zs[k]``.
        """
        zs_ = [zs[-1]]
        state = self.concat(zs[-1], a)
        for i in range(n_layers):
            idx = n_layers - 1 - i
            d = min(dim * 2**idx, MAX_DIM)
            output, next_state = self.gru_cell(
                zs[idx],
                state,
                d,
                kernel_size=kernel_size,
                norm=norm,
                pass_state=pass_state,
                name=name + str(i),
                is_test=is_test)
            zs_.insert(0, output)
            if inject_layers > i:
                state = self.concat(next_state, a)
            else:
                state = next_state
        return zs_

    def Gdec(self,
             zs,
             a,
             dim=64,
             n_layers=5,
             shortcut_layers=4,
             inject_layers=4,
             name='G_dec_',
             is_test=False):
        """Decoder: ``n_layers`` deconv layers ending in a 3-channel tanh image.

        The first ``n_layers - 1`` layers are deconv + batch_norm + relu.
        After layer ``i``, ``zs[n_layers - 2 - i]`` is concatenated as a skip
        connection when ``i < shortcut_layers``, and ``a`` is re-tiled on
        when ``i < inject_layers``. Both counts are capped at
        ``n_layers - 1``.

        Raises:
            ValueError: if ``n_layers < 1``. With no layers, no output image
                would be produced.
        """
        if n_layers < 1:
            raise ValueError(
                "Gdec requires n_layers >= 1, got {}".format(n_layers))
        shortcut_layers = min(shortcut_layers, n_layers - 1)
        inject_layers = min(inject_layers, n_layers - 1)

        z = self.concat(zs[-1], a)
        for i in range(n_layers - 1):
            d = min(dim * 2**(n_layers - 1 - i), MAX_DIM)
            z = deconv2d(
                z,
                d,
                4,
                2,
                padding_type='SAME',
                name=name + str(i),
                norm='batch_norm',
                activation_fn='relu',
                use_bias=False,
                initial='kaiming',
                is_test=is_test)
            if shortcut_layers > i:
                z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
            if inject_layers > i:
                z = self.concat(z, a)

        # Final layer: project to 3 output channels with tanh, no norm.
        return deconv2d(
            z,
            3,
            4,
            2,
            padding_type='SAME',
            name=name + str(n_layers - 1),
            activation_fn='tanh',
            use_bias=True,
            initial='kaiming',
            is_test=is_test)

    def D(self,
          x,
          n_atts=13,
          dim=64,
          fc_dim=1024,
          n_layers=5,
          norm='instance_norm',
          name='D_'):
        """Discriminator with an adversarial head and an attribute head.

        The trunk has ``n_layers`` conv + ``norm`` + leaky_relu(0.2) blocks.
        Bias is used only when ``norm is None``. Two independent
        ``linear(fc_dim) -> linear(k)`` heads follow the trunk.

        Returns:
            ``(logit_gan, logit_att)`` with 1 and ``n_atts`` output units.
        """
        y = x
        for i in range(n_layers):
            d = min(dim * 2**i, MAX_DIM)
            y = conv2d(
                y,
                d,
                4,
                2,
                norm=norm,
                padding_type="SAME",
                activation_fn='leaky_relu',
                name=name + str(i),
                use_bias=norm is None,
                relufactor=0.2,
                initial='kaiming')

        logit_gan = linear(
            y,
            fc_dim,
            activation_fn='leaky_relu',
            name=name + 'fc_adv_1',
            initial='kaiming')
        logit_gan = linear(
            logit_gan, 1, name=name + 'fc_adv_2', initial='kaiming')

        logit_att = linear(
            y,
            fc_dim,
            activation_fn='leaky_relu',
            name=name + 'fc_cls_1',
            initial='kaiming')
        logit_att = linear(
            logit_att, n_atts, name=name + 'fc_cls_2', initial='kaiming')

        return logit_gan, logit_att

    def gru_cell(self,
                 in_data,
                 state,
                 out_channel,
                 kernel_size=3,
                 norm=None,
                 pass_state='lstate',
                 name='gru',
                 is_test=False):
        """One convolutional GRU-style unit.

        ``state`` is first passed through a deconv (presumably upsampling it
        to ``in_data``'s resolution) with ``out_channel`` channels. Then:

            r = sigmoid(conv([in_data, state_]))
            u = sigmoid(conv([in_data, state_]))
            n = tanh(conv([in_data, r * state_]))
            output = (1 - u) * state_ + u * n

        Returns:
            ``(output, passed_state)``. ``passed_state`` is ``output`` when
            ``pass_state == 'output'``, ``state_`` when
            ``pass_state == 'state'``, and ``r * state_`` otherwise.
        """
        state_ = deconv2d(
            state,
            out_channel,
            4,
            2,
            padding_type='SAME',
            name=name + '_deconv2d',
            use_bias=True,
            initial='kaiming',
            is_test=is_test,
        )  # upsample and make `channel` identical to `out_channel`

        # Both gates read the same input, so build the concat once.
        gate_input = fluid.layers.concat([in_data, state_], axis=1)
        reset_gate = conv2d(
            gate_input,
            out_channel,
            kernel_size,
            norm=norm,
            activation_fn='sigmoid',
            padding_type='SAME',
            use_bias=True,
            name=name + '_reset_gate',
            initial='kaiming',
            is_test=is_test)
        update_gate = conv2d(
            gate_input,
            out_channel,
            kernel_size,
            norm=norm,
            activation_fn='sigmoid',
            padding_type='SAME',
            use_bias=True,
            name=name + '_update_gate',
            initial='kaiming',
            is_test=is_test)

        left_state = reset_gate * state_
        new_info = conv2d(
            fluid.layers.concat(
                [in_data, left_state], axis=1),
            out_channel,
            kernel_size,
            norm=norm,
            activation_fn='tanh',
            name=name + '_info',
            padding_type='SAME',
            use_bias=True,
            initial='kaiming',
            is_test=is_test)

        output = (1 - update_gate) * state_ + update_gate * new_info
        if pass_state == 'output':
            return output, output
        elif pass_state == 'state':
            return output, state_
        else:
            return output, left_state