# generator_styleganv2.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# code was heavily based on https://github.com/rosinality/stylegan2-pytorch
# MIT License
# Copyright (c) 2019 Kim Seonghyeon

import math
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS
from ...modules.equalized import EqualLinear
from ...modules.fused_act import FusedLeakyReLU
from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur


class PixelNorm(nn.Layer):
    """Scale the input to unit root-mean-square along axis 1.

    The mean of squared values is taken over axis 1 (kept as a size-1 axis
    so it broadcasts), and a small epsilon guards against division by zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        mean_sq = (input * input).mean(axis=1, keepdim=True)
        inv_rms = paddle.rsqrt(mean_sq + 1e-8)
        return input * inv_rms


class ModulatedConv2D(nn.Layer):
    """Convolution whose kernel is modulated per sample by a style vector.

    The shared weight is scaled per input channel by ``self.modulation(style)``
    and, if ``demodulate`` is set, renormalised so every (sample, output
    filter) kernel has approximately unit L2 norm. The per-sample kernels are
    applied in a single grouped convolution with ``groups=batch``.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Square kernel size.
        style_dim (int): Dimension of the style vector fed to the modulation.
        demodulate (bool): Whether to demodulate the modulated weights.
        upsample (bool): Use a stride-2 transposed conv followed by a blur.
        downsample (bool): Blur, then use a stride-2 conv.
        blur_kernel (list[int] | None): Taps passed to ``Upfirdn2dBlur``;
            ``None`` means ``[1, 3, 3, 1]``.

    Note:
        If both ``upsample`` and ``downsample`` are True, the downsample blur
        overwrites ``self.blur`` while ``forward`` takes the upsample path.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=None,
    ):
        super().__init__()

        if blur_kernel is None:
            # Fresh list per call instead of a shared mutable default.
            blur_kernel = [1, 3, 3, 1]

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Upfirdn2dBlur(blur_kernel,
                                      pad=(pad0, pad1),
                                      upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))

        # The weight is stored unscaled and multiplied by 1/sqrt(fan_in) in
        # forward.
        fan_in = in_channel * (kernel_size * kernel_size)
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = self.create_parameter(
            (1, out_channel, in_channel, kernel_size, kernel_size),
            default_initializer=nn.initializer.Normal())

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})")

    def forward(self, input, style):
        """Apply the style-modulated convolution.

        Args:
            input (Tensor): Feature map of shape (batch, in_channel, H, W).
            style (Tensor): Per-sample style vectors; ``self.modulation`` must
                map them to ``batch * in_channel`` scales (presumably
                (batch, style_dim) input — confirm against callers).

        Returns:
            Tensor: Shape (batch, out_channel, H', W'); the spatial size is
            changed by the upsample / downsample path when enabled.
        """
        batch, in_channel, height, width = input.shape

        # Per-sample kernels: (batch, out_channel, in_channel, k, k).
        style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
        weight = self.scale * self.weight * style

        if self.demodulate:
            # Normalise each (sample, output filter) kernel; use the same eps
            # the layer declares instead of a duplicated literal.
            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))

        # Fold the batch into the channel axis so a grouped conv with
        # groups=batch applies each sample's own kernels.
        weight = weight.reshape((batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size))

        if self.upsample:
            input = input.reshape((1, batch * in_channel, height, width))
            # conv2d_transpose expects weights laid out as
            # (in_channels, out_channels / groups, k, k).
            weight = weight.reshape((batch, self.out_channel, in_channel,
                                     self.kernel_size, self.kernel_size))
            weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                (batch * in_channel, self.out_channel, self.kernel_size,
                 self.kernel_size))
            out = F.conv2d_transpose(input,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        else:
            input = input.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        return out


class NoiseInjection(nn.Layer):
    """Inject learnably scaled noise into a feature map.

    Args:
        is_concat (bool): If True, the scaled noise is concatenated to the
            input along axis 1 instead of being added to it.
    """

    def __init__(self, is_concat=False):
        super().__init__()

        # Single learnable scale, initialised to zero so the injected noise
        # initially contributes nothing.
        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))
        self.is_concat = is_concat

    def forward(self, image, noise=None):
        """Add or concatenate ``self.weight * noise`` to ``image``.

        Args:
            image (Tensor): Feature map of shape (batch, channel, H, W).
            noise (Tensor | None): Noise to inject. If None, standard normal
                noise is drawn: one channel (broadcast over all channels) in
                additive mode, ``channel`` channels in concat mode.

        Returns:
            Tensor: ``image + weight * noise`` in additive mode, or the
            axis-1 concatenation of ``image`` and ``weight * noise``.
        """
        if noise is None:
            batch, channel, height, width = image.shape
            # Concat mode must return 2 * channel channels, which is what the
            # following layers are sized for (e.g. StyledConv's activation).
            # A single noise channel would yield channel + 1 instead.
            noise_channels = channel if self.is_concat else 1
            noise = paddle.randn((batch, noise_channels, height, width))

        if self.is_concat:
            return paddle.concat([image, self.weight * noise], axis=1)
        return image + self.weight * noise


class ConstantInput(nn.Layer):
    """Learned constant tensor repeated across the batch.

    Args:
        channel (int): Number of channels of the constant.
        size (int): Spatial height and width of the constant.
    """

    def __init__(self, channel, size=4):
        super().__init__()

        const_shape = (1, channel, size, size)
        self.input = self.create_parameter(
            const_shape, default_initializer=nn.initializer.Normal())

    def forward(self, input):
        # Only the leading (batch) dimension of ``input`` is used.
        repeats = (input.shape[0], 1, 1, 1)
        return self.input.tile(repeats)


class StyledConv(nn.Layer):
    """Modulated convolution, then noise injection, then fused leaky ReLU.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels of the convolution.
        kernel_size (int): Square kernel size.
        style_dim (int): Dimension of the style vector.
        upsample (bool): Upsample by 2 inside the modulated convolution.
        blur_kernel (list[int] | None): Blur taps for the convolution;
            ``None`` means ``[1, 3, 3, 1]``.
        demodulate (bool): Whether the convolution demodulates its weights.
        is_concat (bool): Concatenate noise instead of adding it; the
            activation is then sized for ``2 * out_channel`` channels.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 style_dim,
                 upsample=False,
                 blur_kernel=None,
                 demodulate=True,
                 is_concat=False):
        super().__init__()

        if blur_kernel is None:
            # Fresh list per call instead of a shared mutable default.
            blur_kernel = [1, 3, 3, 1]

        self.conv = ModulatedConv2D(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection(is_concat=is_concat)
        # In concat mode the injected noise is expected to carry out_channel
        # channels, so the activation sees twice as many channels.
        self.activate = FusedLeakyReLU(
            (out_channel * 2) if is_concat else out_channel)

    def forward(self, input, style, noise=None):
        """Run conv -> noise -> activation.

        Args:
            input (Tensor): Input feature map.
            style (Tensor): Style vectors for the modulated convolution.
            noise (Tensor | None): Noise forwarded to ``NoiseInjection``.

        Returns:
            Tensor: The activated feature map.
        """
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        out = self.activate(out)

        return out


class ToRGB(nn.Layer):
    """Project features to 3 channels and add an optional skip image.

    Args:
        in_channel (int): Number of input feature channels.
        style_dim (int): Dimension of the style vector.
        upsample (bool): If True, ``skip`` is upsampled before being added.
            If False, ``skip`` is added as-is and must already match the
            output size.
        blur_kernel (list[int] | None): Taps for the skip upsampler;
            ``None`` means ``[1, 3, 3, 1]``.
    """

    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=None):
        super().__init__()

        if blur_kernel is None:
            # Fresh list per call instead of a shared mutable default.
            blur_kernel = [1, 3, 3, 1]

        # Always define the attribute (None when disabled) so forward can
        # test it instead of raising AttributeError. The name is unchanged,
        # so any state-dict keys under ``upsample`` are preserved.
        self.upsample = Upfirdn2dUpsample(blur_kernel) if upsample else None

        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        self.bias = self.create_parameter(
            (1, 3, 1, 1), default_initializer=nn.initializer.Constant(0.0))

    def forward(self, input, style, skip=None):
        """Return ``conv(input, style) + bias``, plus ``skip`` when given.

        Args:
            input (Tensor): Input feature map.
            style (Tensor): Style vectors for the 1x1 modulated convolution.
            skip (Tensor | None): Previous 3-channel output. It is upsampled
                first when this layer was built with ``upsample=True``.

        Returns:
            Tensor: 3-channel output.
        """
        out = self.conv(input, style) + self.bias

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)
            out = out + skip

        return out


@GENERATORS.register()
class StyleGANv2Generator(nn.Layer):
    """StyleGAN2 generator.

    Consists of a mapping MLP (``self.style``) and a synthesis network. The
    synthesis network starts from a learned constant, doubles the resolution
    at each block and accumulates 3-channel outputs through ``ToRGB`` skips.

    Args:
        size (int): Output resolution; expected to be a power of two that is
            a key of ``self.channels`` (4 .. 1024).
        style_dim (int): Latent / style vector dimension.
        n_mlp (int): Number of EqualLinear layers in the mapping network.
        channel_multiplier (int): Width multiplier for resolutions >= 64.
        blur_kernel (list[int] | None): Blur taps for the styled convs;
            ``None`` means ``[1, 3, 3, 1]``.
        lr_mlp (float): Passed as ``lr_mul`` to the mapping layers.
        is_concat (bool): Concatenate noise to features instead of adding it,
            doubling the channel count fed to the following layers.
    """

    def __init__(self,
                 size,
                 style_dim,
                 n_mlp,
                 channel_multiplier=2,
                 blur_kernel=None,
                 lr_mlp=0.01,
                 is_concat=False):
        super().__init__()

        if blur_kernel is None:
            # Fresh list per call instead of a shared mutable default.
            blur_kernel = [1, 3, 3, 1]

        self.size = size

        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp EqualLinear layers.
        layers = [PixelNorm()]

        for _ in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        # Feature-map width for each resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # Concatenated noise doubles the channels each StyledConv emits.
        concat_factor = 2 if is_concat else 1

        # NOTE: creation order of sublayers/parameters/buffers below is kept
        # as-is so random initialisation consumes the RNG identically.
        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                is_concat=is_concat)
        self.to_rgb1 = ToRGB(self.channels[4] * concat_factor,
                             style_dim,
                             upsample=False)

        # math.log2 is exact for powers of two, unlike math.log(size, 2).
        self.log_size = int(math.log2(size))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.LayerList()
        self.upsamples = nn.LayerList()
        self.to_rgbs = nn.LayerList()
        self.noises = nn.Layer()

        in_channel = self.channels[4]

        # Fixed single-channel noise maps, used when randomize_noise=False.
        # Layer resolutions are 4, 8, 8, 16, 16, ...
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        paddle.randn(shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    in_channel * concat_factor,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    is_concat=is_concat,
                ))

            self.convs.append(
                StyledConv(out_channel * concat_factor,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           is_concat=is_concat))

            # NOTE(review): blur_kernel is not forwarded here, so the skip
            # upsampler always uses ToRGB's default taps (kept for checkpoint
            # and behaviour compatibility).
            self.to_rgbs.append(
                ToRGB(out_channel * concat_factor, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
        self.is_concat = is_concat

    def make_noise(self):
        """Return a fresh list of single-channel noise maps, one per layer."""
        noises = [paddle.randn((1, 1, 2**2, 2**2))]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(paddle.randn((1, 1, 2**i, 2**i)))

        return noises

    def mean_latent(self, n_latent):
        """Return the mean of ``n_latent`` random latents after mapping."""
        latent_in = paddle.randn((n_latent, self.style_dim))
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        """Map ``input`` through the mapping network."""
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Synthesize images from one or more style codes.

        Args:
            styles (list[Tensor]): Latent codes. They are mapped through
                ``self.style`` unless ``input_is_latent`` is True. With two
                or more entries, the first two are mixed at ``inject_index``.
            return_latents (bool): Also return the per-layer latent tensor.
            inject_index (int | None): Number of layers driven by
                ``styles[0]`` when mixing; random in [1, n_latent - 1] if
                None. Overwritten with ``n_latent`` for a single style.
            truncation (float): Values < 1 pull each style towards
                ``truncation_latent``, which must then be provided.
            truncation_latent (Tensor | None): Centre used for truncation.
            input_is_latent (bool): Treat ``styles`` as already mapped.
            noise (list | None): Per-layer noise. If None, see
                ``randomize_noise``.
            randomize_noise (bool): When ``noise`` is None, let each layer
                draw fresh noise (True) or use the fixed buffers (False).

        Returns:
            tuple: ``(image, latent)`` if ``return_latents`` else
            ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                # NOTE(review): these buffers have one channel, while concat
                # layers are sized for noise with as many channels as the
                # features. Confirm before combining is_concat with
                # randomize_noise=False.
                noise = [
                    getattr(self.noises, f"noise_{i}")
                    for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
            latent2 = styles[1].unsqueeze(1).tile(
                (1, self.n_latent - inject_index, 1))

            latent = paddle.concat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        # Index of the latent for the next conv; each block consumes two.
        i = 1
        blocks = zip(self.convs[::2], self.convs[1::2], self.to_rgbs)
        if self.is_concat:
            # One noise tensor per resolution, shared by both convs of the
            # block: noise[k] for the k-th block, k starting at 1.
            for k, (conv1, conv2, to_rgb) in enumerate(blocks, start=1):
                out = conv1(out, latent[:, i], noise=noise[k])
                out = conv2(out, latent[:, i + 1], noise=noise[k])
                skip = to_rgb(out, latent[:, i + 2], skip)

                i += 2
        else:
            # One noise tensor per conv: noise[1], noise[2], ... in order.
            for (conv1, conv2, to_rgb), noise1, noise2 in zip(
                    blocks, noise[1::2], noise[2::2]):
                out = conv1(out, latent[:, i], noise=noise1)
                out = conv2(out, latent[:, i + 1], noise=noise2)
                skip = to_rgb(out, latent[:, i + 2], skip)

                i += 2

        image = skip

        if return_latents:
            return image, latent

        return image, None