# generator_styleganv2.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# code was heavily based on https://github.com/rosinality/stylegan2-pytorch
# MIT License
# Copyright (c) 2019 Kim Seonghyeon

import math
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS
from ...modules.equalized import EqualLinear
from ...modules.fused_act import FusedLeakyReLU
from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur


class PixelNorm(nn.Layer):
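    """Pixel-wise feature normalization: rescales each feature vector along
    the channel axis to unit RMS norm. Used in front of the mapping network."""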
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return inputs * paddle.rsqrt(
            paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8)


class ModulatedConv2D(nn.Layer):
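    """StyleGAN2 modulated convolution.

    The convolution weight is scaled per sample by an affine projection of the
    style vector ("modulation") and optionally renormalized per output filter
    ("demodulation"); the upsample/downsample variants add an FIR blur.
    """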
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Upfirdn2dBlur(blur_kernel,
                                      pad=(pad0, pad1),
                                      upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * (kernel_size * kernel_size)
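        # Equalized learning rate: weights stay N(0, 1) and the He-init
        # constant 1/sqrt(fan_in) is applied at runtime via self.scale.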
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = self.create_parameter(
            (1, out_channel, in_channel, kernel_size, kernel_size),
            default_initializer=nn.initializer.Normal())

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})")

    def forward(self, inputs, style):
        batch, in_channel, height, width = inputs.shape

        style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
        weight = self.scale * self.weight * style

        if self.demodulate:
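            # Demodulation: normalize each per-sample output filter to unit
            # L2 norm so activation magnitudes stay roughly constant.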
            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))

        # Flatten the per-sample kernels into the output-channel axis so one
        # grouped convolution can process the whole batch at once.
        weight = weight.reshape((batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size))

        if self.upsample:
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            weight = weight.reshape((batch, self.out_channel, in_channel,
                                     self.kernel_size, self.kernel_size))
            weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                (batch * in_channel, self.out_channel, self.kernel_size,
                 self.kernel_size))
            out = F.conv2d_transpose(inputs,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))
            out = self.blur(out)

        elif self.downsample:
            inputs = self.blur(inputs)
            _, _, height, width = inputs.shape
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        else:
            inputs = inputs.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        return out


class NoiseInjection(nn.Layer):
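    """Per-pixel Gaussian noise scaled by a single learned weight (initialized
    to zero); the noise is added to, or concatenated with, the features."""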
    def __init__(self, is_concat=False):
        super().__init__()

        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))
        self.is_concat = is_concat

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = paddle.randn((batch, 1, height, width))
        if self.is_concat:
            return paddle.concat([image, self.weight * noise], axis=1)
        else:
            return image + self.weight * noise


class ConstantInput(nn.Layer):
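    """Learned constant feature map (4x4 by default) that replaces a latent
    input at the base resolution and is tiled across the batch."""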
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = self.create_parameter(
            (1, channel, size, size),
            default_initializer=nn.initializer.Normal())

    def forward(self, inputs):
        batch = inputs.shape[0]
        out = self.input.tile((batch, 1, 1, 1))

        return out


class StyledConv(nn.Layer):
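    """One synthesis layer: modulated convolution, noise injection, then a
    fused leaky-ReLU activation."""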
    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 style_dim,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 is_concat=False):
        super().__init__()

        self.conv = ModulatedConv2D(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection(is_concat=is_concat)
        self.activate = FusedLeakyReLU(out_channel *
                                       2 if is_concat else out_channel)

    def forward(self, inputs, style, noise=None):
        out = self.conv(inputs, style)
        out = self.noise(out, noise=noise)
        out = self.activate(out)

        return out


class ToRGB(nn.Layer):
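    """Projects features to RGB with a 1x1 modulated convolution (without
    demodulation) and accumulates the upsampled skip image."""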
    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upfirdn2dUpsample(blur_kernel)

        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        self.bias = self.create_parameter((1, 3, 1, 1),
                                          nn.initializer.Constant(0.0))

    def forward(self, inputs, style, skip=None):
        out = self.conv(inputs, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


@GENERATORS.register()
class StyleGANv2Generator(nn.Layer):
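    """StyleGAN v2 generator.

    A PixelNorm + n_mlp EqualLinear mapping network turns z into w; a
    synthesis network of StyledConv/ToRGB pairs grows a learned 4x4 constant
    up to `size` x `size`. forward() returns (image, latent-or-None).

    Args:
        size (int): output resolution, a power of 2.
        style_dim (int): dimension of the style/latent vectors.
        n_mlp (int): number of layers in the mapping network.
        channel_multiplier (int): width multiplier for resolutions >= 64.
        blur_kernel (list): FIR kernel for the up/downsampling blurs.
        lr_mlp (float): learning-rate multiplier for the mapping network.
        is_concat (bool): if True, noise maps are concatenated to features
            instead of added, doubling channel width after each layer.
    """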
    def __init__(self,
                 size,
                 style_dim,
                 n_mlp,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 is_concat=False):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                is_concat=is_concat)
        self.to_rgb1 = ToRGB(self.channels[4] *
                             2 if is_concat else self.channels[4],
                             style_dim,
                             upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.LayerList()
        self.upsamples = nn.LayerList()
        self.to_rgbs = nn.LayerList()
        self.noises = nn.Layer()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
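            # One fixed noise buffer per layer: 4x4 for the first layer, then
            # two layers at each of 8x8, 16x16, ...; used when
            # randomize_noise=False.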
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        paddle.randn(shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    in_channel * 2 if is_concat else in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    is_concat=is_concat,
                ))

            self.convs.append(
                StyledConv(out_channel * 2 if is_concat else out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           is_concat=is_concat))

            self.to_rgbs.append(
                ToRGB(out_channel * 2 if is_concat else out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
        self.is_concat = is_concat

    def make_noise(self):
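        """Sample a fresh set of noise maps at every synthesis resolution."""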
        noises = [paddle.randn((1, 1, 2**2, 2**2))]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(paddle.randn((1, 1, 2**i, 2**i)))

        return noises

    def mean_latent(self, n_latent):
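        """Estimate the mean style w by mapping n_latent random z vectors."""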
        latent_in = paddle.randn((n_latent, self.style_dim))
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, inputs):
        return self.style(inputs)

    def get_mean_style(self):
        mean_style = None
        with paddle.no_grad():
            for i in range(10):
                style = self.mean_latent(1024)
                if mean_style is None:
                    mean_style = style
                else:
                    mean_style += style

        mean_style /= 10
        return mean_style

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1.0,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}")
                    for i in range(self.num_layers)
                ]

        if truncation < 1.0:
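            # Truncation trick: interpolate each style toward the mean latent
            # to trade sample diversity for fidelity.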
            style_t = []
            if truncation_latent is None:
                truncation_latent = self.get_mean_style()
            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))

            else:
                latent = styles[0]

        else:
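            # Style mixing: the first style drives layers up to inject_index,
            # the second style drives the rest.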
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
            latent2 = styles[1].unsqueeze(1).tile(
                (1, self.n_latent - inject_index, 1))

            latent = paddle.concat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        if self.is_concat:
            noise_i = 1

            outs = []
            for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
                                            self.to_rgbs):
                out = conv1(out, latent[:, i],
                            noise=noise[(noise_i + 1) // 2])  # one noise map per two layers
                out = conv2(out,
                            latent[:, i + 1],
                            noise=noise[(noise_i + 2) // 2])  # one noise map per two layers
                skip = to_rgb(out, latent[:, i + 2], skip)

                i += 2
                noise_i += 2
        else:
            for conv1, conv2, noise1, noise2, to_rgb in zip(
                    self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
                    self.to_rgbs):
                out = conv1(out, latent[:, i], noise=noise1)
                out = conv2(out, latent[:, i + 1], noise=noise2)
                skip = to_rgb(out, latent[:, i + 2], skip)

                i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None
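

# Minimal usage sketch (illustrative only; assumes this module is imported
# through ppgan so the relative imports resolve):
#
#     gen = StyleGANv2Generator(size=256, style_dim=512, n_mlp=8)
#     z = paddle.randn((4, 512))
#     images, _ = gen([z])                       # images: (4, 3, 256, 256)
#     images, w = gen([z], return_latents=True)  # w: (4, 14, 512)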