# generator_styleganv2.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .builder import GENERATORS
from ...modules.equalized import EqualLinear
from ...modules.fused_act import FusedLeakyReLU
from ...modules.upfirdn2d import Upfirdn2dUpsample, Upfirdn2dBlur


class PixelNorm(nn.Layer):
    """Rescale activations to unit root-mean-square along axis 1.

    The mean of the squared input is taken over axis 1 (kept as a singleton
    axis so it broadcasts back), and a constant 1e-8 is added before the
    reciprocal square root to avoid division by zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        mean_square = paddle.mean(input * input, 1, keepdim=True)
        inv_rms = paddle.rsqrt(mean_square + 1e-8)
        return input * inv_rms


class ModulatedConv2D(nn.Layer):
    """Style-modulated 2-D convolution.

    A style vector is projected (``self.modulation``) to ``in_channel`` scales
    that multiply a shared kernel, giving one kernel per sample. Optionally
    the modulated kernel is demodulated (each output channel rescaled by the
    reciprocal L2 norm of its weights). The batch is folded into the channel
    axis so all samples run through a single grouped convolution.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: spatial size of the square kernel.
        style_dim: size of the style vector given to ``forward``.
        demodulate: if True, demodulate the modulated kernel.
        upsample: if True, use a stride-2 transposed conv followed by a blur.
        downsample: if True, blur the input then use a stride-2 conv.
        blur_kernel: 1-D blur taps; ``None`` means ``[1, 3, 3, 1]``.

    Raises:
        ValueError: if both ``upsample`` and ``downsample`` are True.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=None,
    ):
        super().__init__()

        if upsample and downsample:
            # Both options install ``self.blur`` (the second overwriting the
            # first) while ``forward`` only ever takes the upsample branch,
            # so this combination could never produce a correct result.
            raise ValueError(
                "ModulatedConv2D: upsample and downsample are mutually exclusive")
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            # Blur padding for the output of the stride-2 transposed conv.
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Upfirdn2dBlur(blur_kernel,
                                      pad=(pad0, pad1),
                                      upsample_factor=factor)
        elif downsample:
            factor = 2
            # Blur padding applied to the input before the stride-2 conv.
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * (kernel_size * kernel_size)
        # The stored weight is scaled by 1/sqrt(fan_in) at forward time.
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = self.create_parameter(
            (1, out_channel, in_channel, kernel_size, kernel_size),
            default_initializer=nn.initializer.Normal())

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})")

    def _unfold_batch(self, out, batch):
        """Reshape grouped-conv output (1, batch*out, H, W) to (batch, out, H, W)."""
        _, _, height, width = out.shape
        return out.reshape((batch, self.out_channel, height, width))

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        # Modulation scales reshaped to (batch, 1, in, 1, 1) so that they
        # broadcast against the (1, out, in, k, k) base weight, yielding a
        # per-sample kernel of shape (batch, out, in, k, k).
        style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.reshape((batch, self.out_channel, 1, 1, 1))

        # Fold the batch into channels: one group per sample.
        input = input.reshape((1, batch * in_channel, height, width))

        if self.upsample:
            # Transposed-conv weights are laid out (in, out/groups, k, k) per
            # group, so swap the out/in axes before folding in the batch.
            weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
                (batch * in_channel, self.out_channel, self.kernel_size,
                 self.kernel_size))
            out = F.conv2d_transpose(input,
                                     weight,
                                     padding=0,
                                     stride=2,
                                     groups=batch)
            return self.blur(self._unfold_batch(out, batch))

        weight = weight.reshape((batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size))

        if self.downsample:
            # Blur in the original (batch, in, H, W) layout, then re-fold.
            input = self.blur(input.reshape((batch, in_channel, height, width)))
            _, _, height, width = input.shape
            input = input.reshape((1, batch * in_channel, height, width))
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
        else:
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)

        return self._unfold_batch(out, batch)


class NoiseInjection(nn.Layer):
    """Inject noise scaled by one learnable scalar into a feature map.

    ``weight`` is a single learnable value initialized to 0.

    Args:
        is_concat: if False (default), return ``image + weight * noise``.
            If True, return ``concat([image, weight * noise], axis=1)``.
    """

    def __init__(self, is_concat=False):
        super().__init__()

        self.weight = self.create_parameter(
            (1, ), default_initializer=nn.initializer.Constant(0.0))
        self.is_concat = is_concat

    def forward(self, image, noise=None):
        """Add or concatenate scaled noise to ``image``.

        When ``noise`` is None, standard-normal noise is sampled. In additive
        mode it has a single channel that broadcasts over ``image``'s
        channels. In concat mode it has as many channels as ``image``, so the
        result has exactly twice the input channel count. Downstream layers
        built with ``is_concat`` expect this width; a 1-channel map would give
        C + 1 channels.
        """
        if noise is None:
            batch, channels, height, width = image.shape
            noise_channels = channels if self.is_concat else 1
            noise = paddle.randn((batch, noise_channels, height, width))

        scaled_noise = self.weight * noise
        if self.is_concat:
            return paddle.concat([image, scaled_noise], axis=1)
        return image + scaled_noise


class ConstantInput(nn.Layer):
    """Learned constant tensor used as the starting feature map.

    Stores a ``(1, channel, size, size)`` parameter drawn from a normal
    initializer. ``forward`` only reads the batch size (``shape[0]``) of its
    argument and tiles the constant that many times along axis 0.
    """

    def __init__(self, channel, size=4):
        super().__init__()

        param_shape = (1, channel, size, size)
        self.input = self.create_parameter(
            param_shape, default_initializer=nn.initializer.Normal())

    def forward(self, input):
        repeats = (input.shape[0], 1, 1, 1)
        return self.input.tile(repeats)


class StyledConv(nn.Layer):
    """Modulated convolution, then noise injection, then a fused leaky ReLU.

    Args:
        in_channel: number of input channels.
        out_channel: number of channels produced by the modulated conv.
        kernel_size: spatial size of the conv kernel.
        style_dim: size of the style vector given to ``forward``.
        upsample: forwarded to ``ModulatedConv2D``.
        blur_kernel: forwarded to ``ModulatedConv2D``.
        demodulate: forwarded to ``ModulatedConv2D``.
        is_concat: forwarded to ``NoiseInjection``. When True, the activation
            is built for ``2 * out_channel`` channels because noise is
            concatenated rather than added.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
        is_concat=False
    ):
        super().__init__()

        self.conv = ModulatedConv2D(in_channel,
                                    out_channel,
                                    kernel_size,
                                    style_dim,
                                    upsample=upsample,
                                    blur_kernel=blur_kernel,
                                    demodulate=demodulate)
        self.noise = NoiseInjection(is_concat=is_concat)

        act_channels = 2 * out_channel if is_concat else out_channel
        self.activate = FusedLeakyReLU(act_channels)

    def forward(self, input, style, noise=None):
        modulated = self.conv(input, style)
        with_noise = self.noise(modulated, noise=noise)
        return self.activate(with_noise)


class ToRGB(nn.Layer):
    """Project features to a 3-channel image and accumulate a skip image.

    A 1x1 ``ModulatedConv2D`` (without demodulation) plus a learnable
    ``(1, 3, 1, 1)`` bias produces the output. If a ``skip`` image is given,
    it is added to the output. When the layer was built with
    ``upsample=True``, the skip is first passed through
    ``Upfirdn2dUpsample``.

    Args:
        in_channel: number of input feature channels.
        style_dim: size of the style vector given to ``forward``.
        upsample: whether ``skip`` is upsampled before being added.
        blur_kernel: taps for the upsampler; ``None`` means ``[1, 3, 3, 1]``.
    """

    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=None):
        super().__init__()

        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        # Without an upsampler the skip must already match the output size;
        # previously this attribute was missing and any skip raised
        # AttributeError.
        self.upsample = Upfirdn2dUpsample(blur_kernel) if upsample else None

        self.conv = ModulatedConv2D(in_channel,
                                    3,
                                    1,
                                    style_dim,
                                    demodulate=False)
        # Pass the initializer by keyword: create_parameter's second
        # positional parameter is ``attr``, not the initializer.
        self.bias = self.create_parameter(
            (1, 3, 1, 1), default_initializer=nn.initializer.Constant(0.0))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style) + self.bias

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)
            out = out + skip

        return out


@GENERATORS.register()
class StyleGANv2Generator(nn.Layer):
    """StyleGAN2 generator: a mapping MLP followed by a synthesis network.

    Args:
        size: output resolution. ``int(log2(size))`` is used to build the
            stages, and ``self.channels[2**i]`` is looked up for each stage,
            so ``size`` must be a power of two between 8 and 1024.
        style_dim: size of the latent / style vectors.
        n_mlp: number of ``EqualLinear`` layers in the mapping network.
        channel_multiplier: width multiplier for resolutions >= 64.
        blur_kernel: 1-D blur taps; ``None`` means ``[1, 3, 3, 1]``.
        lr_mlp: learning-rate multiplier for the mapping layers.
        is_concat: if True, noise is concatenated to feature maps instead of
            added, and layers consuming StyledConv output are built with
            doubled input width.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=None,
        lr_mlp=0.01,
        is_concat=False
    ):
        super().__init__()

        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.size = size
        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp fused-lrelu layers.
        layers = [PixelNorm()]
        layers.extend(
            EqualLinear(style_dim,
                        style_dim,
                        lr_mul=lr_mlp,
                        activation="fused_lrelu") for _ in range(n_mlp))
        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        def widen(channels):
            # In concat mode StyledConv emits twice its nominal channel count,
            # so every layer consuming that output needs doubled input width.
            return channels * 2 if is_concat else channels

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                is_concat=is_concat)
        self.to_rgb1 = ToRGB(widen(self.channels[4]), style_dim, upsample=False)

        # math.log2 is exact for powers of two, unlike math.log(size, 2).
        self.log_size = int(math.log2(size))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.LayerList()
        self.upsamples = nn.LayerList()  # never populated in this class
        self.to_rgbs = nn.LayerList()
        self.noises = nn.Layer()

        # Fixed single-channel noise buffers (used when randomize_noise is
        # False): noise_0 is 4x4, followed by two buffers per resolution
        # from 8x8 up to size x size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        paddle.randn(shape))

        in_channel = self.channels[4]
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(widen(in_channel),
                           out_channel,
                           3,
                           style_dim,
                           upsample=True,
                           blur_kernel=blur_kernel,
                           is_concat=is_concat))
            self.convs.append(
                StyledConv(widen(out_channel),
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           is_concat=is_concat))
            self.to_rgbs.append(ToRGB(widen(out_channel), style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
        self.is_concat = is_concat

    def make_noise(self):
        """Sample fresh single-channel noise maps, one per synthesis layer.

        Returns one 4x4 map, then two maps for each resolution from 8x8 up to
        ``2**log_size``.
        """
        noises = [paddle.randn((1, 1, 2**2, 2**2))]
        noises.extend(
            paddle.randn((1, 1, 2**i, 2**i))
            for i in range(3, self.log_size + 1) for _ in range(2))
        return noises

    def mean_latent(self, n_latent):
        """Average the mapped styles of ``n_latent`` random latents."""
        latent_in = paddle.randn((n_latent, self.style_dim))
        return self.style(latent_in).mean(0, keepdim=True)

    def get_latent(self, input):
        """Map ``input`` through the mapping network."""
        return self.style(input)

    def _build_latent(self, styles, inject_index):
        """Expand one or two style codes to ``n_latent`` per-layer codes.

        With a single style, ``inject_index`` is ignored. A style with fewer
        than 3 dims is repeated ``n_latent`` times along a new axis 1;
        otherwise it is returned unchanged. (Presumably it is already
        per-layer in that case — TODO confirm with callers.) With two or more
        styles, the first ``inject_index`` codes come from ``styles[0]`` and
        the rest from ``styles[1]``. A random index in
        ``[1, n_latent - 1]`` is drawn if ``inject_index`` is None.
        """
        if len(styles) < 2:
            if styles[0].ndim < 3:
                return styles[0].unsqueeze(1).tile((1, self.n_latent, 1))
            return styles[0]

        if inject_index is None:
            inject_index = random.randint(1, self.n_latent - 1)

        head = styles[0].unsqueeze(1).tile((1, inject_index, 1))
        tail = styles[1].unsqueeze(1).tile((1, self.n_latent - inject_index, 1))
        return paddle.concat([head, tail], 1)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Synthesize an image from a list of style / latent codes.

        Args:
            styles: list of style tensors. They are passed through the
                mapping network unless ``input_is_latent`` is True.
            return_latents: if True, also return the per-layer latent.
            inject_index: style-mixing crossover point (see ``_build_latent``).
            truncation: if < 1, each style is interpolated toward
                ``truncation_latent``, which must then be provided.
            truncation_latent: center used for truncation.
            input_is_latent: skip the mapping network.
            noise: list of per-layer noise. When None, it is either all None
                (fresh noise per call) or the fixed buffers, depending on
                ``randomize_noise``. In additive mode, conv1 uses ``noise[0]``
                and stage k uses ``noise[2k-1]`` and ``noise[2k]``. In concat
                mode, stage k uses ``noise[k]`` for both of its convs.
            randomize_noise: see ``noise``.

        Returns:
            ``(image, latent)`` if ``return_latents`` else ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}")
                    for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        latent = self._build_latent(styles, inject_index)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        if self.is_concat:
            # One noise map per stage, shared by both of its convs.
            n_stages = len(self.to_rgbs)
            stage_noise = ((noise[k], noise[k])
                           for k in range(1, n_stages + 1))
        else:
            # Two consecutive noise maps per stage.
            stage_noise = zip(noise[1::2], noise[2::2])

        for stage, (conv1, conv2, to_rgb, (noise1, noise2)) in enumerate(
                zip(self.convs[::2], self.convs[1::2], self.to_rgbs,
                    stage_noise)):
            i = 1 + 2 * stage
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

        image = skip
        return image, (latent if return_latents else None)