# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

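# Tests StarGAN training with a WGAN-GP style gradient penalty by running two
# identically seeded imperative (dygraph) training sessions; a static-graph
# implementation of the same model is also defined below.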
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
from paddle import _legacy_C_ops
from paddle.tensor import random

if fluid.is_compiled_with_cuda():
    fluid.core.globals()['FLAGS_cudnn_deterministic'] = True


class Config:
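    """Hyper-parameters for the test; CPU runs get a tiny model and small
    images because CPU cases are extremely slow."""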
    def __init__(self, place, sort_sum_gradient=True):
        self.place = place

        if isinstance(place, fluid.CPUPlace):
            # CPU cases are extremely slow
            self.g_base_dims = 1
            self.d_base_dims = 1

            self.g_repeat_num = 1
            self.d_repeat_num = 1

            self.image_size = 32
        else:
            self.g_base_dims = 64
            self.d_base_dims = 64

            self.g_repeat_num = 6
            self.d_repeat_num = 6

            self.image_size = 256

        self.c_dim = 10
        self.batch_size = 1

        self.seed = 1

        self.lambda_rec = 10
        self.lambda_gp = 10

        self.iterations = 10

        self.sort_sum_gradient = sort_sum_gradient


def create_mnist_dataset(cfg):
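    """Build a reader that yields batches of (image, one-hot original label,
    one-hot target label), resizing MNIST digits to 3 x image_size x
    image_size and stopping after cfg.iterations batches."""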
    def create_target_label(label):
        # Keep the original label as the target so the run stays
        # deterministic; a real StarGAN run would pick a different target
        # domain, e.g. (label + 1) % cfg.c_dim.
        return label

    def create_one_hot(label):
        ret = np.zeros([cfg.c_dim])
        ret[label] = 1
        return ret

    def __impl__():
        dataset = paddle.dataset.mnist.train()
        image_reals = []
        label_orgs = []
        label_trgs = []
        num = 0

        for image_real, label_org in dataset():
            image_real = np.reshape(np.array(image_real), [28, 28])
            image_real = np.resize(image_real, [cfg.image_size, cfg.image_size])
            image_real = np.array([image_real] * 3)

            label_trg = create_target_label(label_org)

            image_reals.append(np.array(image_real))
            label_orgs.append(create_one_hot(label_org))
            label_trgs.append(create_one_hot(label_trg))

            if len(image_reals) == cfg.batch_size:
                image_real_np = np.array(image_reals).astype('float32')
                label_org_np = np.array(label_orgs).astype('float32')
                label_trg_np = np.array(label_trgs).astype('float32')

                yield image_real_np, label_org_np, label_trg_np

                num += 1
                if num == cfg.iterations:
                    break

                image_reals = []
                label_orgs = []
                label_trgs = []

    return __impl__


class InstanceNorm(fluid.dygraph.Layer):
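    """Minimal instance-normalization layer: calls the legacy C++ op in
    dygraph mode and falls back to the static-graph API otherwise."""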
    def __init__(self, num_channels, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

        self.scale = self.create_parameter(shape=[num_channels], is_bias=False)
        self.bias = self.create_parameter(shape=[num_channels], is_bias=True)

    def forward(self, input):
        if fluid._non_static_mode():
            out, _, _ = _legacy_C_ops.instance_norm(
                input, self.scale, self.bias, 'epsilon', self.epsilon
            )
            return out
        else:
            return paddle.static.nn.instance_norm(
                input,
                epsilon=self.epsilon,
                param_attr=fluid.ParamAttr(self.scale.name),
                bias_attr=fluid.ParamAttr(self.bias.name),
            )


class Conv2DLayer(fluid.dygraph.Layer):
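    """Conv2D followed by optional instance norm and optional leaky ReLU."""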
    def __init__(
        self,
        num_channels,
        num_filters=64,
        filter_size=7,
        stride=1,
        padding=0,
        norm=None,
        use_bias=False,
        relufactor=None,
    ):
        super().__init__()
        self._conv = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            bias_attr=None if use_bias else False,
        )

        if norm is not None:
            self._norm = InstanceNorm(num_filters)
        else:
            self._norm = None

        self.relufactor = relufactor

    def forward(self, input):
        conv = self._conv(input)

        if self._norm:
            conv = self._norm(conv)

        if self.relufactor is not None:
            conv = paddle.nn.functional.leaky_relu(conv, self.relufactor)

        return conv


class Deconv2DLayer(fluid.dygraph.Layer):
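    """Conv2DTranspose followed by optional instance norm and optional
    leaky ReLU."""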
    def __init__(
        self,
        num_channels,
        num_filters=64,
        filter_size=7,
        stride=1,
        padding=0,
        norm=None,
        use_bias=False,
        relufactor=None,
    ):
        super().__init__()

        self._deconv = paddle.nn.Conv2DTranspose(
            num_channels,
            num_filters,
            filter_size,
            stride=stride,
            padding=padding,
            bias_attr=None if use_bias else False,
        )

        if norm is not None:
            self._norm = InstanceNorm(num_filters)
        else:
            self._norm = None

        self.relufactor = relufactor

    def forward(self, input):
        deconv = self._deconv(input)

        if self._norm:
            deconv = self._norm(deconv)

        if self.relufactor is not None:
            deconv = paddle.nn.functional.leaky_relu(deconv, self.relufactor)

        return deconv


class ResidualBlock(fluid.dygraph.Layer):
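    """Two 3x3 convolutions with a residual (skip) connection."""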
    def __init__(self, num_channels, num_filters):
        super().__init__()
        self._conv0 = Conv2DLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=1,
            padding=1,
            norm=True,
            relufactor=0,
        )

        self._conv1 = Conv2DLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=1,
            padding=1,
            norm=True,
            relufactor=None,
        )

    def forward(self, input):
        conv0 = self._conv0(input)
        conv1 = self._conv1(conv0)
        return input + conv1


class Generator(fluid.dygraph.Layer):
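    """StarGAN generator: the target-domain label is tiled over the spatial
    dimensions and concatenated to the input channels, then passed through
    down-sampling convolutions, residual blocks and up-sampling
    deconvolutions to produce the translated image."""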
    def __init__(self, cfg, num_channels=3):
        super().__init__()
        conv_base = Conv2DLayer(
            num_channels=cfg.c_dim + num_channels,
            num_filters=cfg.g_base_dims,
            filter_size=7,
            stride=1,
            padding=3,
            norm=True,
            relufactor=0,
        )

        sub_layers = [conv_base]
        cur_channels = cfg.g_base_dims
        for i in range(2):
            sub_layer = Conv2DLayer(
                num_channels=cur_channels,
                num_filters=cur_channels * 2,
                filter_size=4,
                stride=2,
                padding=1,
                norm=True,
                relufactor=0,
            )

            cur_channels *= 2
            sub_layers.append(sub_layer)

        self._conv0 = paddle.nn.Sequential(*sub_layers)

        repeat_num = cfg.g_repeat_num
        sub_layers = []
        for i in range(repeat_num):
            res_block = ResidualBlock(
                num_channels=cur_channels, num_filters=cfg.g_base_dims * 4
            )
            sub_layers.append(res_block)

        self._res_block = paddle.nn.Sequential(*sub_layers)

        cur_channels = cfg.g_base_dims * 4
        sub_layers = []
        for i in range(2):
            rate = 2 ** (1 - i)
            deconv = Deconv2DLayer(
                num_channels=cur_channels,
                num_filters=cfg.g_base_dims * rate,
                filter_size=4,
                stride=2,
                padding=1,
                relufactor=0,
                norm=True,
            )
            cur_channels = cfg.g_base_dims * rate
            sub_layers.append(deconv)

        self._deconv = paddle.nn.Sequential(*sub_layers)

        self._conv1 = Conv2DLayer(
            num_channels=cur_channels,
            num_filters=3,
            filter_size=7,
            stride=1,
            padding=3,
            relufactor=None,
        )

    def forward(self, input, label_trg):
        shape = input.shape
        label_trg_e = paddle.reshape(label_trg, [-1, label_trg.shape[1], 1, 1])
        label_trg_e = paddle.expand(label_trg_e, [-1, -1, shape[2], shape[3]])

        input1 = fluid.layers.concat([input, label_trg_e], 1)

        conv0 = self._conv0(input1)
        res_block = self._res_block(conv0)
        deconv = self._deconv(res_block)
        conv1 = self._conv1(deconv)
        out = paddle.tanh(conv1)
        return out


class Discriminator(fluid.dygraph.Layer):
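    """StarGAN discriminator with two heads: a PatchGAN-style real/fake
    score map (_conv1) and a domain-classification output (_conv2)."""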
    def __init__(self, cfg, num_channels=3):
        super().__init__()

        cur_dim = cfg.d_base_dims

        conv_base = Conv2DLayer(
            num_channels=num_channels,
            num_filters=cur_dim,
            filter_size=4,
            stride=2,
            padding=1,
            relufactor=0.2,
        )

        repeat_num = cfg.d_repeat_num
        sub_layers = [conv_base]
        for i in range(1, repeat_num):
            sub_layer = Conv2DLayer(
                num_channels=cur_dim,
                num_filters=cur_dim * 2,
                filter_size=4,
                stride=2,
                padding=1,
                relufactor=0.2,
            )
            cur_dim *= 2
            sub_layers.append(sub_layer)

        self._conv0 = paddle.nn.Sequential(*sub_layers)

        kernel_size = int(cfg.image_size / np.power(2, repeat_num))

        self._conv1 = Conv2DLayer(
            num_channels=cur_dim,
            num_filters=1,
            filter_size=3,
            stride=1,
            padding=1,
        )

        self._conv2 = Conv2DLayer(
            num_channels=cur_dim, num_filters=cfg.c_dim, filter_size=kernel_size
        )

    def forward(self, input):
        conv = self._conv0(input)
        out1 = self._conv1(conv)
        out2 = self._conv2(conv)
        return out1, out2


def loss_cls(cls, label, cfg):
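    """Sum of sigmoid cross-entropy between flattened classification logits
    and one-hot domain labels, averaged over the batch size."""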
    cls_shape = cls.shape
    cls = paddle.reshape(cls, [-1, cls_shape[1] * cls_shape[2] * cls_shape[3]])
    return (
        paddle.sum(
            paddle.nn.functional.binary_cross_entropy_with_logits(cls, label)
        )
        / cfg.batch_size
    )


def calc_gradients(outputs, inputs, no_grad_set):
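    """Gradients of outputs w.r.t. inputs; dygraph mode uses
    create_graph=True so the gradient penalty itself remains
    differentiable."""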
    if fluid._non_static_mode():
        return fluid.dygraph.grad(
            outputs=outputs,
            inputs=inputs,
            no_grad_vars=no_grad_set,
            create_graph=True,
        )
    else:
        return fluid.gradients(
            targets=outputs, inputs=inputs, no_grad_set=no_grad_set
        )


def gradient_penalty(f, real, fake, no_grad_set, cfg):
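    """WGAN-GP gradient penalty: draw random interpolates x_hat between real
    and fake samples, then penalize the critic's gradient norm deviating
    from 1, i.e. gp = mean((||grad_{x_hat} f(x_hat)||_2 - 1)^2)."""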
    def _interpolate(a, b):
        shape = [a.shape[0]]
        alpha = random.uniform_random_batch_size_like(
            input=a, shape=shape, min=0.1, max=1.0, seed=cfg.seed
        )

        inner = paddle.tensor.math._multiply_with_axis(
            b, 1.0 - alpha, axis=0
        ) + paddle.tensor.math._multiply_with_axis(a, alpha, axis=0)
        return inner

    x = _interpolate(real, fake)
    pred, _ = f(x)
    if isinstance(pred, tuple):
        pred = pred[0]

    gradient = calc_gradients(
        outputs=[pred], inputs=[x], no_grad_set=no_grad_set
    )

    if gradient is None:
        return None

    gradient = gradient[0]
    grad_shape = gradient.shape

    gradient = paddle.reshape(
        gradient, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]]
    )

    epsilon = 1e-16
    norm = paddle.sqrt(paddle.sum(paddle.square(gradient), axis=1) + epsilon)

    gp = paddle.mean(paddle.square(norm - 1.0))
    return gp


def get_generator_loss(
    image_real, label_org, label_trg, generator, discriminator, cfg
):
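    """Generator loss: adversarial term on the fake image, cycle
    reconstruction term weighted by lambda_rec, and target-domain
    classification term."""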
    fake_img = generator(image_real, label_trg)
    rec_img = generator(fake_img, label_org)
    g_loss_rec = paddle.mean(paddle.abs(paddle.subtract(image_real, rec_img)))

    pred_fake, cls_fake = discriminator(fake_img)

    g_loss_fake = -paddle.mean(pred_fake)
    g_loss_cls = loss_cls(cls_fake, label_trg, cfg)
    g_loss = g_loss_fake + cfg.lambda_rec * g_loss_rec + g_loss_cls
    return g_loss


def get_discriminator_loss(
    image_real, label_org, label_trg, generator, discriminator, cfg
):
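    """Discriminator loss: WGAN real/fake terms, source-domain
    classification term, and the gradient penalty weighted by lambda_gp."""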
    fake_img = generator(image_real, label_trg)
    pred_real, cls_real = discriminator(image_real)
    pred_fake, _ = discriminator(fake_img)
    d_loss_cls = loss_cls(cls_real, label_org, cfg)
    d_loss_fake = paddle.mean(pred_fake)
    d_loss_real = -paddle.mean(pred_real)
    d_loss = d_loss_real + d_loss_fake + d_loss_cls

    d_loss_gp = gradient_penalty(
        discriminator,
        image_real,
        fake_img,
        set(discriminator.parameters()),
        cfg,
    )
    if d_loss_gp is not None:
        d_loss += cfg.lambda_gp * d_loss_gp

    return d_loss


def build_optimizer(layer, cfg, loss=None):
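    """Adam optimizer over the layer's parameters; in static-graph mode the
    minimize() call is issued here against the given loss."""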
    learning_rate = 1e-3
    beta1 = 0.5
    beta2 = 0.999
    if fluid._non_static_mode():
        return fluid.optimizer.Adam(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            parameter_list=layer.parameters(),
        )
    else:
        optimizer = fluid.optimizer.Adam(
            learning_rate=learning_rate, beta1=beta1, beta2=beta2
        )

        optimizer.minimize(loss, parameter_list=layer.parameters())
        return optimizer


class DyGraphTrainModel:
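    """Imperative trainer: each run() call takes one generator step followed
    by one discriminator step and returns both loss values."""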
    def __init__(self, cfg):
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)

        self.generator = Generator(cfg)
        self.discriminator = Discriminator(cfg)

        self.g_optimizer = build_optimizer(self.generator, cfg)
        self.d_optimizer = build_optimizer(self.discriminator, cfg)

        self.cfg = cfg

        fluid.set_flags({'FLAGS_sort_sum_gradient': cfg.sort_sum_gradient})

    def clear_gradients(self):
        if self.g_optimizer:
            self.g_optimizer.clear_gradients()

        if self.d_optimizer:
            self.d_optimizer.clear_gradients()

    def run(self, image_real, label_org, label_trg):
        image_real = fluid.dygraph.to_variable(image_real)
        label_org = fluid.dygraph.to_variable(label_org)
        label_trg = fluid.dygraph.to_variable(label_trg)

        g_loss = get_generator_loss(
            image_real,
            label_org,
            label_trg,
            self.generator,
            self.discriminator,
            self.cfg,
        )
        g_loss.backward()
        if self.g_optimizer:
            self.g_optimizer.minimize(g_loss)

        self.clear_gradients()

        d_loss = get_discriminator_loss(
            image_real,
            label_org,
            label_trg,
            self.generator,
            self.discriminator,
            self.cfg,
        )
        d_loss.backward()
        if self.d_optimizer:
            self.d_optimizer.minimize(d_loss)

        self.clear_gradients()

        return g_loss.numpy()[0], d_loss.numpy()[0]


class StaticGraphTrainModel:
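    """Static-graph counterpart: separate generator and discriminator
    programs share one scope, with unique_name.guard() keeping parameter
    names aligned across the two programs."""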
    def __init__(self, cfg):
        self.cfg = cfg

        def create_data_layer():
            image_real = fluid.data(
                shape=[None, 3, cfg.image_size, cfg.image_size],
                dtype='float32',
                name='image_real',
            )
            label_org = fluid.data(
                shape=[None, cfg.c_dim], dtype='float32', name='label_org'
            )
            label_trg = fluid.data(
                shape=[None, cfg.c_dim], dtype='float32', name='label_trg'
            )
            return image_real, label_org, label_trg

        paddle.seed(cfg.seed)
        paddle.framework.random._manual_program_seed(cfg.seed)
        self.gen_program = fluid.Program()
        gen_startup_program = fluid.Program()

        with fluid.program_guard(self.gen_program, gen_startup_program):
            with fluid.unique_name.guard():
                image_real, label_org, label_trg = create_data_layer()
                generator = Generator(cfg)
                discriminator = Discriminator(cfg)
                g_loss = get_generator_loss(
                    image_real,
                    label_org,
                    label_trg,
                    generator,
                    discriminator,
                    cfg,
                )
                build_optimizer(generator, cfg, loss=g_loss)

        self.dis_program = fluid.Program()
        dis_startup_program = fluid.Program()
        with fluid.program_guard(self.dis_program, dis_startup_program):
            with fluid.unique_name.guard():
                image_real, label_org, label_trg = create_data_layer()
                generator = Generator(cfg)
                discriminator = Discriminator(cfg)
                d_loss = get_discriminator_loss(
                    image_real,
                    label_org,
                    label_trg,
                    generator,
                    discriminator,
                    cfg,
                )
                build_optimizer(discriminator, cfg, loss=d_loss)

        self.executor = fluid.Executor(cfg.place)
        self.scope = fluid.Scope()

        with fluid.scope_guard(self.scope):
            self.executor.run(gen_startup_program)
            self.executor.run(dis_startup_program)

        self.g_loss = g_loss
        self.d_loss = d_loss

    def run(self, image_real, label_org, label_trg):
        feed = {
            'image_real': image_real,
            'label_org': label_org,
            'label_trg': label_trg,
        }
        with fluid.scope_guard(self.scope):
            g_loss_val = self.executor.run(
                self.gen_program, feed=feed, fetch_list=[self.g_loss]
            )[0]
            d_loss_val = self.executor.run(
                self.dis_program, feed=feed, fetch_list=[self.d_loss]
            )[0]
            return g_loss_val[0], d_loss_val[0]


class TestStarGANWithGradientPenalty(unittest.TestCase):
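    """Runs two identically seeded dygraph training sessions on CPU (and on
    GPU when CUDA is available) and checks that their losses match."""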
    def func_main(self):
        self.place_test(fluid.CPUPlace())

        if fluid.is_compiled_with_cuda():
            self.place_test(fluid.CUDAPlace(0))

    def place_test(self, place):
        cfg = Config(place, False)

        dataset = create_mnist_dataset(cfg)
        dataset = paddle.reader.cache(dataset)

        fluid_dygraph_loss = []
        with fluid.dygraph.guard(cfg.place):
            fluid_dygraph_model = DyGraphTrainModel(cfg)
            for batch_id, (image_real, label_org, label_trg) in enumerate(
                dataset()
            ):
                loss = fluid_dygraph_model.run(image_real, label_org, label_trg)
                fluid_dygraph_loss.append(loss)

        eager_dygraph_loss = []
        with fluid.dygraph.guard(cfg.place):
            eager_dygraph_model = DyGraphTrainModel(cfg)
            for batch_id, (image_real, label_org, label_trg) in enumerate(
                dataset()
            ):
                loss = eager_dygraph_model.run(image_real, label_org, label_trg)
                eager_dygraph_loss.append(loss)

        # Both runs are seeded identically and cuDNN is deterministic, so
        # the two loss sequences should match exactly.
        for (g_loss_f, d_loss_f), (g_loss_e, d_loss_e) in zip(
            fluid_dygraph_loss, eager_dygraph_loss
        ):
            self.assertEqual(g_loss_f, g_loss_e)
            self.assertEqual(d_loss_f, d_loss_e)

    def test_all_cases(self):
        self.func_main()


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()