#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS
from ...modules.equalized import EqualLinear
from ...modules.fused_act import FusedLeakyReLU
from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur


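# PixelNorm rescales each latent vector to (roughly) unit RMS across the channel
# dimension; StyleGAN v2 applies it once to z before the mapping network.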
class PixelNorm(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * paddle.rsqrt(
            paddle.mean(input * input, 1, keepdim=True) + 1e-8)


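# ModulatedConv2D is the weight modulation / demodulation convolution from
# StyleGAN v2: a per-sample style vector scales the convolution weights, and an
# optional per-output-filter rescaling restores unit variance (demodulation).
# The upsample/downsample variants wrap the convolution in an upfirdn2d blur.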
class ModulatedConv2D(nn.Layer):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Upfirdn2dBlur(blur_kernel,
                                      pad=(pad0, pad1),
                                      upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * (kernel_size * kernel_size)
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = self.create_parameter(
            (1, out_channel, in_channel, kernel_size, kernel_size),
            default_initializer=nn.initializer.Normal())

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})")

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
        weight = self.scale * self.weight * style

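        # Demodulation: rescale every output filter so its modulated weights have
        # (approximately) unit L2 norm over the input channels and kernel window.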
        if self.demodulate:
            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))

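        # Fold the batch into the channel axis and run one grouped convolution
        # (groups=batch) so each sample is convolved with its own modulated weights.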
        weight = weight.reshape((batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size))

        if self.upsample:
            input = input.reshape((1, batch * in_channel, height, width))
            weight = weight.reshape((batch, self.out_channel, in_channel,
                                     self.kernel_size, self.kernel_size))
            weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                (batch * in_channel, self.out_channel, self.kernel_size,
                 self.kernel_size))
            out = F.conv2d_transpose(input,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        else:
            input = input.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.reshape((batch, self.out_channel, height, width))

        return out


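# NoiseInjection adds per-pixel Gaussian noise scaled by a single learned scalar;
# fresh noise is sampled whenever the caller does not pass a fixed noise map.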
class NoiseInjection(nn.Layer):
    def __init__(self):
        super().__init__()

        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = paddle.randn((batch, 1, height, width))

        return image + self.weight * noise


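# ConstantInput is the learned constant (4x4 by default) that replaces a sampled
# input at the lowest resolution; forward simply tiles it to the batch size.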
class ConstantInput(nn.Layer):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = self.create_parameter(
            (1, channel, size, size),
            default_initializer=nn.initializer.Normal())

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.tile((batch, 1, 1, 1))

        return out


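# StyledConv is one synthesis block: modulated convolution, noise injection,
# then a fused bias + leaky ReLU activation.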
class StyledConv(nn.Layer):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2D(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        out = self.activate(out)

        return out


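# ToRGB projects features to a 3-channel image with a 1x1 modulated convolution
# (demodulation disabled) and, when given a lower-resolution skip image,
# upsamples it and adds it in, building the generator's RGB skip path.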
class ToRGB(nn.Layer):
    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upfirdn2dUpsample(blur_kernel)

        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        self.bias = self.create_parameter((1, 3, 1, 1),
                                          nn.initializer.Constant(0.0))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


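# StyleGANv2Generator assembles the full model: an n_mlp-layer mapping network
# (PixelNorm followed by EqualLinear layers), a learned constant input, and a
# pyramid of StyledConv / ToRGB blocks that doubles the resolution up to `size`.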
@GENERATORS.register()
class StyleGANv2Generator(nn.Layer):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel)
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.LayerList()
        self.upsamples = nn.LayerList()
        self.to_rgbs = nn.LayerList()
        self.noises = nn.Layer()

        in_channel = self.channels[4]

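        # One fixed noise buffer per synthesis layer: layer 0 is 4x4 and the
        # spatial size doubles every two layers, so res = (layer_idx + 5) // 2
        # is the exponent of the side length.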
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        paddle.randn(shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                ))

            self.convs.append(
                StyledConv(out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel))

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        noises = [paddle.randn((1, 1, 2**2, 2**2))]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(paddle.randn((1, 1, 2**i, 2**i)))

        return noises

    def mean_latent(self, n_latent):
        latent_in = paddle.randn((n_latent, self.style_dim))
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}")
                    for i in range(self.num_layers)
                ]

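        # Truncation trick: pull each style towards the mean latent
        # (truncation_latent) by a factor of `truncation`, trading diversity
        # for average sample quality.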
        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t

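        # With one style, broadcast it to all n_latent layers (unless a per-layer
        # latent of shape [N, n_latent, style_dim] was passed directly). With two
        # styles, mix them: the first drives layers below inject_index, the
        # second drives the rest.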
        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).tile((1, inject_index, 1))
            latent2 = styles[1].unsqueeze(1).tile(
                (1, self.n_latent - inject_index, 1))

            latent = paddle.concat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

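        # Walk up the resolution pyramid: each step runs an upsampling StyledConv,
        # a same-resolution StyledConv, and accumulates RGB output through the
        # ToRGB skip connection; `i` indexes the per-layer latents.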
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2],
                                                        self.convs[1::2],
                                                        noise[1::2],
                                                        noise[2::2],
                                                        self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None
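

# A minimal usage sketch (not part of the original module), assuming this file is
# imported inside the ppgan package so the relative imports above resolve:
#
#     generator = StyleGANv2Generator(size=256, style_dim=512, n_mlp=8)
#     z = paddle.randn((4, 512))
#     images, _ = generator([z])                  # images: [4, 3, 256, 256]
#     w = generator.get_latent(z)                 # mapped w-space latents
#     images_w, _ = generator([w], input_is_latent=True)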