# STGAN_network.py (9.4 KB)
# NOTE: text recovered from a git blame view; lines committed by lvmengsi.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .base_network import conv2d, deconv2d, norm_layer, linear
import paddle.fluid as fluid
import numpy as np

# Upper bound on the channel count of any conv/deconv layer built below
# (layer widths are computed as min(dim * 2**k, MAX_DIM)).
MAX_DIM = 16 * 64


class STGAN_model(object):
    """Builds the STGAN generator and discriminator graphs with paddle.fluid.

    The generator has three parts:

    * an encoder (``Genc``);
    * an optional stack of convolutional GRU cells (``GRU`` / ``gru_cell``);
    * a decoder (``Gdec``) conditioned on an attribute-difference vector.

    ``D`` is a discriminator with two heads: an adversarial logit and
    attribute-classification logits.

    Feature maps are treated as NCHW. Channels are concatenated on axis 1,
    and spatial sizes are read from ``shape[2]`` / ``shape[3]``.
    """

    def __init__(self):
        pass

    def network_G(self,
                  input,
                  label_org,
                  label_trg,
                  cfg,
                  name="generator",
                  is_test=False):
        """Build the generator for both the edited and the reconstructed image.

        Args:
            input: image batch fed to the encoder.
            label_org: attribute vector of ``input``.
            label_trg: desired attribute vector.
            cfg: config object. Reads ``n_layers``, ``g_base_dims``,
                ``gru_n_layers`` and ``use_gru``.
            name: prefix for the layer names. Both branches pass the same
                names to ``GRU`` / ``Gdec``, so they are meant to share
                parameters.
            is_test: forwarded to every layer builder.

        Returns:
            ``(fake_image, rec_image)``:

            * ``fake_image`` is decoded with ``label_trg - label_org``.
            * ``rec_image`` is decoded with a zero difference
              (``label_org - label_org``).
        """
        z = self.Genc(
            input,
            name=name + '_Genc',
            n_layers=cfg.n_layers,
            dim=cfg.g_base_dims,
            is_test=is_test)

        def decode(delta):
            # Optionally refine the encoder features with the GRU stack,
            # then decode. The attribute difference is computed once per
            # branch and shared by the GRU and the decoder.
            if cfg.use_gru:
                feats = self.GRU(z,
                                 delta,
                                 name=name + '_GRU',
                                 dim=cfg.g_base_dims,
                                 n_layers=cfg.gru_n_layers,
                                 is_test=is_test)
            else:
                feats = z
            return self.Gdec(
                feats,
                delta,
                name=name + '_Gdec',
                dim=cfg.g_base_dims,
                n_layers=cfg.n_layers,
                is_test=is_test)

        fake_image = decode(fluid.layers.elementwise_sub(label_trg, label_org))
        rec_image = decode(fluid.layers.elementwise_sub(label_org, label_org))
        return fake_image, rec_image

    def network_D(self, input, cfg, name="discriminator"):
        """Build the discriminator (see ``D``) from config values.

        Reads ``c_dim``, ``d_base_dims``, ``d_fc_dim``, ``dis_norm`` and
        ``n_layers`` from ``cfg``. Returns ``(logit_gan, logit_att)``.
        """
        return self.D(input,
                      n_atts=cfg.c_dim,
                      dim=cfg.d_base_dims,
                      fc_dim=cfg.d_fc_dim,
                      norm=cfg.dis_norm,
                      n_layers=cfg.n_layers,
                      name=name)

    def concat(self, z, a):
        """Tile attribute vector ``a`` over the spatial dims of ``z`` and
        concatenate it to ``z`` along the channel axis.

        Args:
            z: NCHW feature map.
            a: attribute tensor whose ``shape[1]`` is the number of
                attributes. Presumably ``(batch, n_att)``; TODO confirm
                with the label pipeline.

        Returns:
            ``z`` with ``a.shape[1]`` extra channels holding the tiled
            attribute values.
        """
        ones = fluid.layers.fill_constant_batch_size_like(
            z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
        # axis=0 aligns a's (batch, n_att) dims with the first two dims of
        # `ones`, so each attribute value is broadcast over H x W.
        # The plain `*` operator uses axis=-1, which would instead try to
        # align `a` with the trailing (H, W) dims.
        tiled = fluid.layers.elementwise_mul(ones, a, axis=0)
        return fluid.layers.concat([z, tiled], axis=1)

    def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
        """Encode ``input`` with ``n_layers`` stride-2 convolutions.

        Layer ``i`` has the following settings:

        * ``min(dim * 2**i, MAX_DIM)`` filters with a 4x4 kernel;
        * batch norm and leaky ReLU (``relufactor=0.2``);
        * no bias.

        Returns:
            list of the ``n_layers`` feature maps, in the order they were
            produced (shallowest first).
        """
        features = []
        z = input
        for i in range(n_layers):
            z = conv2d(
                input=z,
                num_filters=min(dim * 2**i, MAX_DIM),
                filter_size=4,
                stride=2,
                padding_type='SAME',
                norm="batch_norm",
                activation_fn='leaky_relu',
                name=name + str(i),
                use_bias=False,
                relufactor=0.2,
                initial='kaiming',
                is_test=is_test)
            features.append(z)
        return features

    def GRU(self,
            zs,
            a,
            dim=64,
            n_layers=4,
            inject_layers=4,
            kernel_size=3,
            norm=None,
            pass_state='lstate',
            name='G_gru_',
            is_test=False):
        """Refine encoder features with a chain of convolutional GRU cells.

        The initial state is the deepest map ``zs[-1]`` with ``a`` tiled
        onto it. Cell ``i`` takes ``zs[n_layers - 1 - i]`` as input and
        passes its state on to the next cell. ``a`` is re-concatenated onto
        the state for the first ``inject_layers`` cells.

        NOTE(review): the indexing assumes ``len(zs) == n_layers + 1``
        (e.g. a 5-layer encoder with a 4-layer GRU). Each cell upsamples
        the incoming state by a stride-2 deconv, so ``zs[n_layers - 1 - i]``
        must be at twice the state's resolution. Confirm against the config.

        Returns:
            list of the ``n_layers`` refined maps (shallowest first),
            followed by the unchanged ``zs[-1]``.
        """
        refined = [zs[-1]]
        state = self.concat(zs[-1], a)
        for i in range(n_layers):
            level = n_layers - 1 - i
            output, next_state = self.gru_cell(
                in_data=zs[level],
                state=state,
                out_channel=min(dim * 2**level, MAX_DIM),
                kernel_size=kernel_size,
                norm=norm,
                pass_state=pass_state,
                name=name + str(i),
                is_test=is_test)
            refined.insert(0, output)
            if inject_layers > i:
                state = self.concat(next_state, a)
            else:
                state = next_state
        return refined

    def Gdec(self,
             zs,
             a,
             dim=64,
             n_layers=5,
             shortcut_layers=4,
             inject_layers=4,
             name='G_dec_',
             is_test=False):
        """Decode the feature list ``zs`` into a 3-channel image.

        Decoding starts from ``zs[-1]`` with ``a`` tiled onto it. The first
        ``n_layers - 1`` layers are stride-2 deconvolutions with:

        * ``min(dim * 2**(n_layers - 1 - i), MAX_DIM)`` filters;
        * batch norm, ReLU and no bias.

        After layer ``i``:

        * ``zs[n_layers - 2 - i]`` is concatenated (skip connection) when
          ``i < shortcut_layers``;
        * ``a`` is re-injected when ``i < inject_layers``.

        Both counts are clipped to ``n_layers - 1``. The last layer is a
        stride-2 deconvolution to 3 channels with tanh and a bias.

        Raises:
            ValueError: if ``n_layers < 1``. The original code failed here
                with an ``UnboundLocalError``.
        """
        if n_layers < 1:
            raise ValueError("Gdec requires n_layers >= 1, got %r" % n_layers)
        shortcut_layers = min(shortcut_layers, n_layers - 1)
        inject_layers = min(inject_layers, n_layers - 1)

        z = self.concat(zs[-1], a)
        for i in range(n_layers - 1):
            z = deconv2d(
                input=z,
                num_filters=min(dim * 2**(n_layers - 1 - i), MAX_DIM),
                filter_size=4,
                stride=2,
                padding_type='SAME',
                name=name + str(i),
                norm='batch_norm',
                activation_fn='relu',
                use_bias=False,
                initial='kaiming',
                is_test=is_test)
            if shortcut_layers > i:
                z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
            if inject_layers > i:
                z = self.concat(z, a)

        # Output layer: 3 channels, tanh, with bias and no `norm` argument.
        return deconv2d(
            input=z,
            num_filters=3,
            filter_size=4,
            stride=2,
            padding_type='SAME',
            name=name + str(n_layers - 1),
            activation_fn='tanh',
            use_bias=True,
            initial='kaiming',
            is_test=is_test)

    def _fc_head(self, y, fc_dim, output_size, prefix):
        """Two fully connected layers: ``fc_dim`` units with leaky ReLU,
        then a linear layer with ``output_size`` units.

        The layers are named ``prefix + '_1'`` and ``prefix + '_2'``.
        """
        hidden = linear(
            input=y,
            output_size=fc_dim,
            activation_fn='leaky_relu',
            name=prefix + '_1',
            initial='kaiming')
        return linear(
            input=hidden,
            output_size=output_size,
            name=prefix + '_2',
            initial='kaiming')

    def D(self,
          x,
          n_atts=13,
          dim=64,
          fc_dim=1024,
          n_layers=5,
          norm='instance_norm',
          name='D_'):
        """Discriminator with an adversarial head and an attribute head.

        The trunk is ``n_layers`` stride-2 4x4 convolutions, each with:

        * ``min(dim * 2**i, MAX_DIM)`` filters;
        * leaky ReLU (``relufactor=0.2``);
        * the given ``norm``.

        A bias is used only when ``norm`` is None.

        Returns:
            ``(logit_gan, logit_att)``:

            * ``logit_gan`` comes from a head with 1 output unit;
            * ``logit_att`` comes from a head with ``n_atts`` output units.
        """
        y = x
        for i in range(n_layers):
            y = conv2d(
                input=y,
                num_filters=min(dim * 2**i, MAX_DIM),
                filter_size=4,
                stride=2,
                norm=norm,
                padding_type="SAME",
                activation_fn='leaky_relu',
                name=name + str(i),
                use_bias=norm is None,
                relufactor=0.2,
                initial='kaiming')

        logit_gan = self._fc_head(y, fc_dim, 1, name + 'fc_adv')
        logit_att = self._fc_head(y, fc_dim, n_atts, name + 'fc_cls')
        return logit_gan, logit_att

    def gru_cell(self,
                 in_data,
                 state,
                 out_channel,
                 kernel_size=3,
                 norm=None,
                 pass_state='lstate',
                 name='gru',
                 is_test=False):
        """One convolutional GRU step.

        Steps:

        1. The incoming ``state`` is upsampled by a stride-2 deconv to
           ``out_channel`` channels, giving ``state_``.
        2. Reset and update gates (sigmoid convs) are computed from
           ``[in_data, state_]``.
        3. The candidate ``new_info`` (tanh conv) is computed from
           ``[in_data, reset * state_]``.
        4. ``output = (1 - update) * state_ + update * new_info``.

        Args:
            pass_state: selects the second return value:

                * ``'output'``: the cell output;
                * ``'state'``: the upsampled state;
                * any other value (default ``'lstate'``): the reset-gated
                  state.

        Returns:
            ``(output, state_to_pass_on)``.
        """
        # Upsample and make the channel count identical to `out_channel`.
        state_ = deconv2d(
            input=state,
            num_filters=out_channel,
            filter_size=4,
            stride=2,
            padding_type='SAME',
            name=name + '_deconv2d',
            use_bias=True,
            initial='kaiming',
            is_test=is_test)

        # Both gates read the same input; build the concat once.
        gate_input = fluid.layers.concat([in_data, state_], axis=1)
        reset_gate = conv2d(
            input=gate_input,
            num_filters=out_channel,
            filter_size=kernel_size,
            norm=norm,
            activation_fn='sigmoid',
            padding_type='SAME',
            use_bias=True,
            name=name + '_reset_gate',
            initial='kaiming',
            is_test=is_test)
        update_gate = conv2d(
            input=gate_input,
            num_filters=out_channel,
            filter_size=kernel_size,
            norm=norm,
            activation_fn='sigmoid',
            padding_type='SAME',
            use_bias=True,
            name=name + '_update_gate',
            initial='kaiming',
            is_test=is_test)

        left_state = reset_gate * state_
        new_info = conv2d(
            input=fluid.layers.concat([in_data, left_state], axis=1),
            num_filters=out_channel,
            filter_size=kernel_size,
            norm=norm,
            activation_fn='tanh',
            name=name + '_info',
            padding_type='SAME',
            use_bias=True,
            initial='kaiming',
            is_test=is_test)

        output = (1 - update_gate) * state_ + update_gate * new_info
        if pass_state == 'output':
            return output, output
        if pass_state == 'state':
            return output, state_
        return output, left_state