generator_styleganv2.py 17.5 KB
Newer Older
H
Hecong Wu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wangna11BD 已提交
15 16 17 18
# code was heavily based on https://github.com/rosinality/stylegan2-pytorch
# MIT License
# Copyright (c) 2019 Kim Seonghyeon

H
Hecong Wu 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
import math
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS
from ...modules.equalized import EqualLinear
from ...modules.fused_act import FusedLeakyReLU
from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur


class PixelNorm(nn.Layer):
    """Rescale each sample to unit root-mean-square along axis 1.

    Computes ``inputs / sqrt(mean(inputs ** 2, axis=1, keepdim) + 1e-8)``.
    The generator places it at the head of its mapping network.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        mean_square = paddle.mean(inputs * inputs, 1, keepdim=True)
        return inputs * paddle.rsqrt(mean_square + 1e-8)
L
LielinJiang 已提交
38 39


H
Hecong Wu 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52
class ModulatedConv2D(nn.Layer):
    """Style-modulated 2D convolution with optional demodulation.

    A single learned weight of shape ``(1, out_channel, in_channel, k, k)`` is
    scaled per sample by a style vector, optionally demodulated, and applied to
    the whole batch at once as one grouped convolution (``groups=batch``).

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Square kernel size.
        style_dim (int): Input size of ``self.modulation`` (style -> in_channel).
        demodulate (bool): Normalize every per-sample output filter to unit
            L2 norm (plus ``eps``).
        upsample (bool): Stride-2 transposed conv followed by a blur.
        downsample (bool): Blur followed by a stride-2 conv.
        blur_kernel (sequence of int, optional): Kernel taps passed to
            ``Upfirdn2dBlur``. ``None`` selects ``[1, 3, 3, 1]``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=None,
    ):
        super().__init__()

        # Avoid a shared mutable default; same default kernel as before.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            # Padding for the blur that follows the stride-2 transposed conv.
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Upfirdn2dBlur(blur_kernel,
                                      pad=(pad0, pad1),
                                      upsample_factor=factor)

        if downsample:
            factor = 2
            # Padding for the blur that precedes the stride-2 conv.
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            # NOTE(review): if upsample and downsample are both set, this blur
            # replaces the one above while forward() takes the upsample path.
            self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * (kernel_size * kernel_size)
        # Weights are drawn from Normal() and rescaled by 1/sqrt(fan_in) at
        # runtime (equalized learning rate).
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = self.create_parameter(
            (1, out_channel, in_channel, kernel_size, kernel_size),
            default_initializer=nn.initializer.Normal())

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})")

    def forward(self, inputs, style, apply_modulation=False):
        """Apply the modulated convolution.

        Args:
            inputs (Tensor): Feature map of shape ``(batch, in_channel, H, W)``.
            style (Tensor): Per-sample style. If ``apply_modulation`` is False
                it must already be the modulated style, reshapable to
                ``(batch, 1, in_channel, 1, 1)``. If True it is first passed
                through ``self.modulation``.
            apply_modulation (bool): Whether to run ``self.modulation`` here.

        Returns:
            Tensor: Output of shape ``(batch, out_channel, H', W')``.
        """
        batch, in_channel, height, width = inputs.shape

        if apply_modulation:
            style = self.modulation(style)
        style = style.reshape((batch, 1, in_channel, 1, 1))
        # Per-sample weight: (batch, out_channel, in_channel, k, k).
        weight = self.scale * self.weight * style
        del style

        if self.demodulate:
            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))

        # Fold the batch into the output-channel axis for a grouped conv.
        weight = weight.reshape((batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size))

        if self.upsample:
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            # Transposed conv takes weights laid out as (in, out/groups, k, k).
            weight = weight.reshape((batch, self.out_channel, in_channel,
                                     self.kernel_size, self.kernel_size))
            weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                (batch * in_channel, self.out_channel, self.kernel_size,
                 self.kernel_size))
            out = F.conv2d_transpose(inputs,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))
            out = self.blur(out)

        elif self.downsample:
            inputs = self.blur(inputs)
            _, _, height, width = inputs.shape
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        else:
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        return out
L
LielinJiang 已提交
142 143


H
Hecong Wu 已提交
144
class NoiseInjection(nn.Layer):
    """Inject noise scaled by a single learned scalar (initialized to 0).

    Args:
        is_concat (bool): If True, ``weight * noise`` is concatenated to the
            input along axis 1; otherwise it is added to the input.
    """

    def __init__(self, is_concat=False):
        super().__init__()

        # Learned noise strength; starts at zero.
        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))
        self.is_concat = is_concat

    def forward(self, image, noise=None):
        """Add or concatenate noise to ``image``.

        Args:
            image (Tensor): Feature map of shape ``(batch, channel, H, W)``.
            noise (Tensor, optional): Noise to inject. When ``None`` it is
                sampled from a standard normal. In add mode it has a single
                channel that broadcasts. In concat mode it has ``channel``
                channels, so the output has ``2 * channel`` channels.

        Returns:
            Tensor: ``image + weight * noise``, or the axis-1 concatenation of
            ``image`` and ``weight * noise`` when ``is_concat`` is set.
        """
        if noise is None:
            batch, channel, height, width = image.shape
            # Concat mode must yield 2 * channel outputs (StyledConv sizes its
            # activation as out_channel * 2), so sample full-width noise there.
            # A single channel would give channel + 1 and break the next layer.
            noise_channels = channel if self.is_concat else 1
            noise = paddle.randn((batch, noise_channels, height, width))
        if self.is_concat:
            return paddle.concat([image, self.weight * noise], axis=1)
        return image + self.weight * noise
L
LielinJiang 已提交
160 161


H
Hecong Wu 已提交
162 163 164
class ConstantInput(nn.Layer):
    """Learned constant feature map that seeds the synthesis network.

    Args:
        channel (int): Number of channels of the constant.
        size (int): Spatial height and width of the constant.
    """

    def __init__(self, channel, size=4):
        super().__init__()

        param_shape = (1, channel, size, size)
        self.input = self.create_parameter(
            param_shape, default_initializer=nn.initializer.Normal())

    def forward(self, batch):
        """Return the constant repeated ``batch`` times along axis 0."""
        repeats = (batch, 1, 1, 1)
        return self.input.tile(repeats)
L
LielinJiang 已提交
174 175


H
Hecong Wu 已提交
176
class StyledConv(nn.Layer):
    """Modulated convolution, then noise injection, then fused leaky ReLU.

    Args:
        in_channel (int): Input channels.
        out_channel (int): Output channels of the convolution.
        kernel_size (int): Square kernel size.
        style_dim (int): Style size for the convolution's modulation layer.
        upsample (bool): Upsample by 2 inside the convolution.
        blur_kernel (sequence of int, optional): Blur taps forwarded to
            ``ModulatedConv2D``. ``None`` selects ``[1, 3, 3, 1]``.
        demodulate (bool): Whether the convolution demodulates its weights.
        is_concat (bool): Concatenate noise instead of adding it. The
            activation is then built for ``2 * out_channel`` channels, which
            assumes the concatenated noise has ``out_channel`` channels.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 style_dim,
                 upsample=False,
                 blur_kernel=None,
                 demodulate=True,
                 is_concat=False):
        super().__init__()

        # Avoid a shared mutable default; same default kernel as before.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.conv = ModulatedConv2D(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection(is_concat=is_concat)
        act_channels = (out_channel * 2) if is_concat else out_channel
        self.activate = FusedLeakyReLU(act_channels)

    def forward(self, inputs, style, noise=None):
        """Run conv -> noise -> activation; ``style`` is passed to the conv as-is."""
        out = self.conv(inputs, style)
        out = self.noise(out, noise=noise)
        out = self.activate(out)

        return out
L
LielinJiang 已提交
208 209


H
Hecong Wu 已提交
210
class ToRGB(nn.Layer):
    """Project features to a 3-channel image and accumulate a skip image.

    Args:
        in_channel (int): Input channels.
        style_dim (int): Style size for the 1x1 modulated conv.
        upsample (bool): Upsample ``skip`` before adding it. If False, a
            given ``skip`` is added at its current resolution.
        blur_kernel (sequence of int, optional): Kernel for
            ``Upfirdn2dUpsample``. ``None`` selects ``[1, 3, 3, 1]``.
    """

    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=None):
        super().__init__()

        # Avoid a shared mutable default; same default kernel as before.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        # Always define the attribute so forward() does not raise
        # AttributeError when a skip is passed to a non-upsampling ToRGB.
        self.upsample = Upfirdn2dUpsample(blur_kernel) if upsample else None

        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        self.bias = self.create_parameter(
            (1, 3, 1, 1), default_initializer=nn.initializer.Constant(0.0))

    def forward(self, inputs, style, skip=None):
        """Return ``conv(inputs, style) + bias``, plus the (upsampled) skip if given."""
        out = self.conv(inputs, style)
        out = out + self.bias

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)
            out = out + skip

        return out
L
LielinJiang 已提交
239 240


H
Hecong Wu 已提交
241 242
@GENERATORS.register()
class StyleGANv2Generator(nn.Layer):
    """StyleGAN2 generator: mapping network plus modulated synthesis network.

    Args:
        size (int): Output resolution. ``log2(size)`` sets the number of
            synthesis blocks; the channel table covers 4 through 1024.
        style_dim (int): Latent/style vector size.
        n_mlp (int): Number of ``EqualLinear`` layers in the mapping network.
        channel_multiplier (int): Width multiplier for resolutions >= 64.
        blur_kernel (sequence of int, optional): Blur taps for the styled
            convolutions. ``None`` selects ``[1, 3, 3, 1]``.
        lr_mlp (float): Learning-rate multiplier for the mapping network.
        is_concat (bool): Concatenate noise instead of adding it. This doubles
            the input channels expected by the following layers.
    """

    def __init__(self,
                 size,
                 style_dim,
                 n_mlp,
                 channel_multiplier=2,
                 blur_kernel=None,
                 lr_mlp=0.01,
                 is_concat=False):
        super().__init__()

        # Avoid a shared mutable default; same default kernel as before.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.size = size
        self.style_dim = style_dim
        # log2 is exact for powers of two, unlike log(size, 2).
        self.log_size = int(math.log2(size))
        self.num_layers = (self.log_size - 2) * 2 + 1

        # Mapping network: PixelNorm followed by n_mlp equalized linears.
        layers = [PixelNorm()]
        for _ in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))
        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }
        self.channels_lst = []
        # W+ latent index used by each entry returned by style_affine():
        # conv1, to_rgb1, then (conv, conv, to_rgb) per resolution.
        self.w_idx_lst = [
            0, 1,  # 4
            1, 2, 3,  # 8
            3, 4, 5,  # 16
            5, 6, 7,  # 32
            7, 8, 9,  # 64
            9, 10, 11,  # 128
            11, 12, 13,  # 256
            13, 14, 15,  # 512
            15, 16, 17,  # 1024
        ]
        # Positions in the style_affine() output that belong to StyledConv
        # layers; the ToRGB positions (1, 4, 7, ..., 25) are omitted.
        self.style_layers = [
            0,
            2, 3,
            5, 6,
            8, 9,
            11, 12,
            14, 15,
            17, 18,
            20, 21,
            23, 24,
        ]

        if self.log_size != 10:
            self.w_idx_lst = self.w_idx_lst[:-(3 * (10 - self.log_size))]
            self.style_layers = self.style_layers[:-(2 * (10 - self.log_size))]

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                is_concat=is_concat)
        self.to_rgb1 = ToRGB(
            (self.channels[4] * 2) if is_concat else self.channels[4],
            style_dim,
            upsample=False)
        self.channels_lst.extend([self.channels[4], self.channels[4]])

        self.convs = nn.LayerList()
        self.upsamples = nn.LayerList()
        self.to_rgbs = nn.LayerList()
        self.noises = nn.Layer()

        in_channel = self.channels[4]

        # Fixed noise maps, used when synthesis() runs with
        # randomize_noise=False.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        paddle.randn(shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    (in_channel * 2) if is_concat else in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    is_concat=is_concat,
                ))

            self.convs.append(
                StyledConv((out_channel * 2) if is_concat else out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           is_concat=is_concat))

            self.to_rgbs.append(
                ToRGB((out_channel * 2) if is_concat else out_channel,
                      style_dim))

            self.channels_lst.extend([in_channel, out_channel, out_channel])
            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
        self.is_concat = is_concat

    def make_noise(self):
        """Sample one fresh noise map per synthesis layer (4x4, then two per resolution)."""
        noises = [paddle.randn((1, 1, 2**2, 2**2))]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(paddle.randn((1, 1, 2**i, 2**i)))

        return noises

    def mean_latent(self, n_latent):
        """Mean mapping-network output over ``n_latent`` random z vectors, shape (1, style_dim)."""
        latent_in = paddle.randn((n_latent, self.style_dim))
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, inputs):
        """Run the mapping network on ``inputs``."""
        return self.style(inputs)

    @staticmethod
    def _truncate(style, truncation, truncation_cutoff, truncation_latent):
        """Pull ``style`` toward ``truncation_latent`` by factor ``truncation``.

        Without ``truncation_cutoff`` the whole tensor is interpolated. With it,
        only ``style[:, :truncation_cutoff]`` is interpolated, and the write
        goes to a copy so a caller-owned ``style`` tensor is not mutated.
        """
        if truncation_cutoff is None:
            return truncation_latent + truncation * (style - truncation_latent)
        cutoff = truncation_cutoff
        style = style.clone()
        style[:, :cutoff] = truncation_latent[:, :cutoff] + \
            truncation * (style[:, :cutoff] - truncation_latent[:, :cutoff])
        return style

    def get_latents(
        self,
        inputs,
        truncation=1.0,
        truncation_cutoff=None,
        truncation_latent=None,
        input_is_latent=False,
    ):
        """Map ``inputs`` to (optionally truncated) latents.

        Args:
            inputs (Tensor): z vectors, or latents if ``input_is_latent``.
            truncation (float): Interpolation factor toward the mean style;
                values >= 1 disable truncation.
            truncation_cutoff (int, optional): Truncate only ``[:, :cutoff]``.
            truncation_latent (Tensor, optional): Target of the truncation;
                defaults to ``self.get_mean_style()``.
            input_is_latent (bool): Skip the mapping network.

        Returns:
            Tensor: The resulting latents.
        """
        assert truncation >= 0, "truncation should be a float in range [0, 1]"

        # Previously `style` was unbound when input_is_latent was True.
        style = inputs if input_is_latent else self.style(inputs)
        if truncation < 1.0:
            if truncation_latent is None:
                truncation_latent = self.get_mean_style()
            style = self._truncate(style, truncation, truncation_cutoff,
                                   truncation_latent)
        return style

    @paddle.no_grad()
    def get_mean_style(self, n_sample=10, n_latent=1024):
        """Average ``n_sample`` calls of ``mean_latent(n_latent)`` without gradients."""
        mean_style = None
        for _ in range(n_sample):
            style = self.mean_latent(n_latent)
            if mean_style is None:
                mean_style = style
            else:
                mean_style += style

        mean_style /= n_sample
        return mean_style

    def get_latent_S(self, inputs):
        """Map z vectors straight to per-layer modulated styles."""
        return self.style_affine(self.style(inputs))

    def style_affine(self, latent):
        """Apply each layer's modulation to its slice of the latent.

        Args:
            latent (Tensor): ``(batch, style_dim)``, which is broadcast to all
                ``n_latent`` slots, or ``(batch, n_latent, style_dim)``.

        Returns:
            list[Tensor]: Modulated styles ordered conv1, to_rgb1, then
            (conv, conv, to_rgb) per resolution; see ``self.w_idx_lst``.
        """
        if latent.ndim < 3:
            latent = latent.unsqueeze(1).tile((1, self.n_latent, 1))
        latent_ = []
        latent_.append(self.conv1.conv.modulation(latent[:, 0]))
        latent_.append(self.to_rgb1.conv.modulation(latent[:, 1]))

        i = 1
        for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
                                        self.to_rgbs):
            latent_.append(conv1.conv.modulation(latent[:, i + 0]))
            latent_.append(conv2.conv.modulation(latent[:, i + 1]))
            latent_.append(to_rgb.conv.modulation(latent[:, i + 2]))
            # Consecutive resolutions share one latent slot (to_rgb / next conv).
            i += 2
        return latent_

    def synthesis(self,
                  latent,
                  noise=None,
                  randomize_noise=True,
                  is_w_latent=False):
        """Render images from the per-layer styles produced by ``style_affine``.

        Args:
            latent (list[Tensor]): Modulated styles in ``style_affine`` order.
            noise (list, optional): Per-layer noise. If ``None``, either fresh
                noise (``randomize_noise``) or the registered fixed buffers.
            randomize_noise (bool): See ``noise``.
            is_w_latent (bool): Accepted for compatibility; unused here.

        Returns:
            Tensor: The accumulated RGB skip image.
        """
        out = self.input(latent[0].shape[0])
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}")
                    for i in range(self.num_layers)
                ]

        out = self.conv1(out, latent[0], noise=noise[0])

        skip = self.to_rgb1(out, latent[1])

        i = 2
        if self.is_concat:
            # Concat mode: both convs of the k-th block use noise[k]
            # (noise_i = 1, 3, 5, ... -> indices 1, 2, 3, ...).
            noise_i = 1

            for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
                                            self.to_rgbs):
                out = conv1(out, latent[i], noise=noise[(noise_i + 1) // 2])
                out = conv2(out, latent[i + 1],
                            noise=noise[(noise_i + 2) // 2])
                skip = to_rgb(out, latent[i + 2], skip)

                i += 3
                noise_i += 2
        else:
            for conv1, conv2, noise1, noise2, to_rgb in zip(
                    self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
                    self.to_rgbs):
                out = conv1(out, latent[i], noise=noise1)
                out = conv2(out, latent[i + 1], noise=noise2)
                skip = to_rgb(out, latent[i + 2], skip)

                i += 3

        return skip

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1.0,
        truncation_cutoff=None,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Generate images from a list of z vectors or latents.

        With two or more entries in ``styles``, the first two are mixed at
        ``inject_index``, which is drawn at random when it is ``None``.

        Returns:
            tuple: ``(image, latent)`` if ``return_latents`` else ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if truncation < 1.0:
            if truncation_latent is None:
                truncation_latent = self.get_mean_style()
            styles = [
                self._truncate(style, truncation, truncation_cutoff,
                               truncation_latent) for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
            latent2 = styles[1].unsqueeze(1).tile(
                (1, self.n_latent - inject_index, 1))

            latent = paddle.concat([latent, latent2], 1)

        styles = self.style_affine(latent)

        image = self.synthesis(styles, noise, randomize_noise)

        if return_latents:
            return image, latent
        return image, None