# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from paddle.tensor import random
import numpy as np
import unittest
from paddle import _legacy_C_ops
from paddle.fluid.framework import _test_eager_guard

# cuDNN must run deterministically here: the tests compare per-iteration
# losses across execution modes with assertEqual, so any nondeterminism in
# the convolution kernels would make them flaky.
if fluid.is_compiled_with_cuda():
    fluid.core.globals()['FLAGS_cudnn_deterministic'] = True


class Config:
    def __init__(self, place, sort_sum_gradient=True):
        self.place = place

        if isinstance(place, fluid.CPUPlace):
            # CPU cases are extremely slow
            self.g_base_dims = 1
            self.d_base_dims = 1

            self.g_repeat_num = 1
            self.d_repeat_num = 1

            self.image_size = 32
        else:
            self.g_base_dims = 64
            self.d_base_dims = 64

            self.g_repeat_num = 6
            self.d_repeat_num = 6

            self.image_size = 256

        self.c_dim = 10
        self.batch_size = 1

        self.seed = 1

        self.lambda_rec = 10
        self.lambda_gp = 10

        self.iterations = 10

        self.sort_sum_gradient = sort_sum_gradient


def create_mnist_dataset(cfg):
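    # Reader factory: resizes each MNIST digit to cfg.image_size, tiles it to
    # 3 channels, and yields (image, one-hot original label, one-hot target
    # label) batches, stopping after cfg.iterations batches. The target label
    # is kept identical to the original here (a shifted fake target is left
    # commented out below).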
    def create_target_label(label):
        return label
        # return (label + 1) % cfg.c_dim # fake label target

    def create_one_hot(label):
        ret = np.zeros([cfg.c_dim])
        ret[label] = 1
        return ret

    def __impl__():
        dataset = paddle.dataset.mnist.train()
        image_reals = []
        label_orgs = []
        label_trgs = []
        num = 0

        for image_real, label_org in dataset():
            image_real = np.reshape(np.array(image_real), [28, 28])
            image_real = np.resize(image_real, [cfg.image_size, cfg.image_size])
            image_real = np.array([image_real] * 3)

            label_trg = create_target_label(label_org)

            image_reals.append(np.array(image_real))
            label_orgs.append(create_one_hot(label_org))
            label_trgs.append(create_one_hot(label_trg))

            if len(image_reals) == cfg.batch_size:
                image_real_np = np.array(image_reals).astype('float32')
                label_org_np = np.array(label_orgs).astype('float32')
                label_trg_np = np.array(label_trgs).astype('float32')

                yield image_real_np, label_org_np, label_trg_np

                num += 1
                if num == cfg.iterations:
                    break

                image_reals = []
                label_orgs = []
                label_trgs = []

    return __impl__


class InstanceNorm(fluid.dygraph.Layer):
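    # Per-sample, per-channel feature normalization with a learnable scale
    # and bias. Dispatches to the fused legacy C++ op in dynamic-graph mode
    # and falls back to the fluid static-graph layer otherwise.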
    def __init__(self, num_channels, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

        self.scale = self.create_parameter(shape=[num_channels], is_bias=False)
        self.bias = self.create_parameter(shape=[num_channels], is_bias=True)

    def forward(self, input):
        if fluid._non_static_mode():
            out, _, _ = _legacy_C_ops.instance_norm(
                input, self.scale, self.bias, 'epsilon', self.epsilon
            )
            return out
        else:
            return fluid.layers.instance_norm(
                input,
                epsilon=self.epsilon,
                param_attr=fluid.ParamAttr(self.scale.name),
                bias_attr=fluid.ParamAttr(self.bias.name),
            )


class Conv2DLayer(fluid.dygraph.Layer):
    def __init__(
        self,
        num_channels,
        num_filters=64,
        filter_size=7,
        stride=1,
        padding=0,
        norm=None,
        use_bias=False,
        relufactor=None,
    ):
        super().__init__()
        self._conv = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            bias_attr=None if use_bias else False,
        )

        if norm is not None:
            self._norm = InstanceNorm(num_filters)
        else:
            self._norm = None

        self.relufactor = relufactor

    def forward(self, input):
        conv = self._conv(input)

        if self._norm:
            conv = self._norm(conv)

        if self.relufactor is not None:
            conv = paddle.nn.functional.leaky_relu(conv, self.relufactor)

        return conv


class Deconv2DLayer(fluid.dygraph.Layer):
    def __init__(
        self,
        num_channels,
        num_filters=64,
        filter_size=7,
        stride=1,
        padding=0,
        norm=None,
        use_bias=False,
        relufactor=None,
    ):
        super().__init__()

        self._deconv = fluid.dygraph.Conv2DTranspose(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            bias_attr=None if use_bias else False,
        )

        if norm is not None:
            self._norm = InstanceNorm(num_filters)
        else:
            self._norm = None

        self.relufactor = relufactor

    def forward(self, input):
        deconv = self._deconv(input)

        if self._norm:
            deconv = self._norm(deconv)

        if self.relufactor is not None:
            deconv = paddle.nn.functional.leaky_relu(deconv, self.relufactor)

        return deconv


class ResidualBlock(fluid.dygraph.Layer):
    def __init__(self, num_channels, num_filters):
        super().__init__()
        self._conv0 = Conv2DLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=1,
            padding=1,
            norm=True,
            relufactor=0,
        )

        self._conv1 = Conv2DLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=1,
            padding=1,
            norm=True,
            relufactor=None,
        )

    def forward(self, input):
        conv0 = self._conv0(input)
        conv1 = self._conv1(conv0)
        return input + conv1


class Generator(fluid.dygraph.Layer):
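    # StarGAN generator: the one-hot target-domain label is broadcast and
    # concatenated to the input channels, followed by two stride-2
    # downsampling convolutions, a stack of residual blocks, two stride-2
    # deconvolutions back to full resolution, and a final tanh.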
    def __init__(self, cfg, num_channels=3):
        super().__init__()
        conv_base = Conv2DLayer(
            num_channels=cfg.c_dim + num_channels,
            num_filters=cfg.g_base_dims,
            filter_size=7,
            stride=1,
            padding=3,
            norm=True,
            relufactor=0,
        )

        sub_layers = [conv_base]
        cur_channels = cfg.g_base_dims
        for i in range(2):
            sub_layer = Conv2DLayer(
                num_channels=cur_channels,
                num_filters=cur_channels * 2,
                filter_size=4,
                stride=2,
                padding=1,
                norm=True,
                relufactor=0,
            )

            cur_channels *= 2
            sub_layers.append(sub_layer)

        self._conv0 = paddle.nn.Sequential(*sub_layers)

        repeat_num = cfg.g_repeat_num
        sub_layers = []
        for i in range(repeat_num):
            res_block = ResidualBlock(
                num_channels=cur_channels, num_filters=cfg.g_base_dims * 4
            )
            sub_layers.append(res_block)

        self._res_block = paddle.nn.Sequential(*sub_layers)

        cur_channels = cfg.g_base_dims * 4
        sub_layers = []
        for i in range(2):
            rate = 2 ** (1 - i)
            deconv = Deconv2DLayer(
                num_channels=cur_channels,
                num_filters=cfg.g_base_dims * rate,
                filter_size=4,
                stride=2,
                padding=1,
                relufactor=0,
                norm=True,
            )
            cur_channels = cfg.g_base_dims * rate
            sub_layers.append(deconv)

        self._deconv = paddle.nn.Sequential(*sub_layers)

        self._conv1 = Conv2DLayer(
            num_channels=cur_channels,
            num_filters=3,
            filter_size=7,
            stride=1,
            padding=3,
            relufactor=None,
        )

    def forward(self, input, label_trg):
        shape = input.shape
        label_trg_e = paddle.reshape(label_trg, [-1, label_trg.shape[1], 1, 1])
        label_trg_e = paddle.expand(label_trg_e, [-1, -1, shape[2], shape[3]])

        input1 = fluid.layers.concat([input, label_trg_e], 1)

        conv0 = self._conv0(input1)
        res_block = self._res_block(conv0)
        deconv = self._deconv(res_block)
        conv1 = self._conv1(deconv)
        out = paddle.tanh(conv1)
        return out


class Discriminator(fluid.dygraph.Layer):
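    # StarGAN discriminator with a shared convolutional trunk and two heads:
    # _conv1 produces patch-wise real/fake scores and _conv2 produces
    # domain-classification logits over cfg.c_dim classes.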
    def __init__(self, cfg, num_channels=3):
        super().__init__()

        cur_dim = cfg.d_base_dims

        conv_base = Conv2DLayer(
            num_channels=num_channels,
            num_filters=cur_dim,
            filter_size=4,
            stride=2,
            padding=1,
            relufactor=0.2,
        )

        repeat_num = cfg.d_repeat_num
        sub_layers = [conv_base]
        for i in range(1, repeat_num):
            sub_layer = Conv2DLayer(
                num_channels=cur_dim,
                num_filters=cur_dim * 2,
                filter_size=4,
                stride=2,
                padding=1,
                relufactor=0.2,
            )
            cur_dim *= 2
            sub_layers.append(sub_layer)

        self._conv0 = paddle.nn.Sequential(*sub_layers)

        kernel_size = int(cfg.image_size / np.power(2, repeat_num))

        self._conv1 = Conv2DLayer(
            num_channels=cur_dim,
            num_filters=1,
            filter_size=3,
            stride=1,
            padding=1,
        )

        self._conv2 = Conv2DLayer(
            num_channels=cur_dim, num_filters=cfg.c_dim, filter_size=kernel_size
        )

    def forward(self, input):
        conv = self._conv0(input)
        out1 = self._conv1(conv)
        out2 = self._conv2(conv)
        return out1, out2


def loss_cls(cls, label, cfg):
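    # Domain-classification loss: flatten the classifier head's logits and
    # average the summed sigmoid cross-entropy over the batch.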
    cls_shape = cls.shape
    cls = paddle.reshape(cls, [-1, cls_shape[1] * cls_shape[2] * cls_shape[3]])
    return (
        fluid.layers.reduce_sum(
            fluid.layers.sigmoid_cross_entropy_with_logits(cls, label)
        )
        / cfg.batch_size
    )


def calc_gradients(outputs, inputs, no_grad_set):
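    # Gradient of outputs w.r.t. inputs. In dynamic-graph mode,
    # create_graph=True keeps the returned gradient on the autograd tape so
    # the gradient penalty built from it can itself be backpropagated; in
    # static mode, fluid.gradients appends backward ops to the program.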
    if fluid._non_static_mode():
        return fluid.dygraph.grad(
            outputs=outputs,
            inputs=inputs,
            no_grad_vars=no_grad_set,
            create_graph=True,
        )
    else:
        return fluid.gradients(
            targets=outputs, inputs=inputs, no_grad_set=no_grad_set
        )


def gradient_penalty(f, real, fake, no_grad_set, cfg):
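    # WGAN-GP-style penalty: sample a point x_hat on the line between a real
    # and a fake image, evaluate the critic f at x_hat, and penalize the
    # squared deviation of the gradient's L2 norm from 1:
    #     gp = mean((||grad_{x_hat} f(x_hat)||_2 - 1)^2)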
    def _interpolate(a, b):
        shape = [a.shape[0]]
        alpha = random.uniform_random_batch_size_like(
            input=a, shape=shape, min=0.1, max=1.0, seed=cfg.seed
        )

        inner = fluid.layers.elementwise_mul(
            b, 1.0 - alpha, axis=0
        ) + fluid.layers.elementwise_mul(a, alpha, axis=0)
        return inner

    x = _interpolate(real, fake)
    pred, _ = f(x)
    if isinstance(pred, tuple):
        pred = pred[0]

    gradient = calc_gradients(
        outputs=[pred], inputs=[x], no_grad_set=no_grad_set
    )

    if gradient is None:
        return None

    gradient = gradient[0]
    grad_shape = gradient.shape

    gradient = paddle.reshape(
        gradient, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]]
    )

    epsilon = 1e-16
    norm = paddle.sqrt(paddle.sum(paddle.square(gradient), axis=1) + epsilon)

    gp = paddle.mean(paddle.square(norm - 1.0))
    return gp


def get_generator_loss(
    image_real, label_org, label_trg, generator, discriminator, cfg
):
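    # Generator objective: adversarial term on the translated image, a
    # cycle-reconstruction L1 term weighted by cfg.lambda_rec, and the
    # target-domain classification term.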
    fake_img = generator(image_real, label_trg)
    rec_img = generator(fake_img, label_org)
    g_loss_rec = fluid.layers.reduce_mean(
        paddle.abs(paddle.subtract(image_real, rec_img))
    )

    pred_fake, cls_fake = discriminator(fake_img)

    g_loss_fake = -paddle.mean(pred_fake)
    g_loss_cls = loss_cls(cls_fake, label_trg, cfg)
    g_loss = g_loss_fake + cfg.lambda_rec * g_loss_rec + g_loss_cls
    return g_loss


def get_discriminator_loss(
    image_real, label_org, label_trg, generator, discriminator, cfg
):
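    # Discriminator objective: WGAN critic terms -mean(D(real)) + mean(D(fake)),
    # the original-domain classification term, and the gradient penalty
    # weighted by cfg.lambda_gp.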
    fake_img = generator(image_real, label_trg)
    pred_real, cls_real = discriminator(image_real)
    pred_fake, _ = discriminator(fake_img)
    d_loss_cls = loss_cls(cls_real, label_org, cfg)
    d_loss_fake = paddle.mean(pred_fake)
    d_loss_real = -paddle.mean(pred_real)
    d_loss = d_loss_real + d_loss_fake + d_loss_cls

    d_loss_gp = gradient_penalty(
        discriminator,
        image_real,
        fake_img,
        set(discriminator.parameters()),
        cfg,
    )
    if d_loss_gp is not None:
        d_loss += cfg.lambda_gp * d_loss_gp

    return d_loss


def build_optimizer(layer, cfg, loss=None):
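    # Same Adam hyperparameters in both modes. In static-graph mode the
    # backward pass and parameter updates are wired into the program here via
    # minimize(); in dygraph mode minimize() is called once per step instead.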
    learning_rate = 1e-3
    beta1 = 0.5
    beta2 = 0.999
    if fluid._non_static_mode():
        return fluid.optimizer.Adam(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            parameter_list=layer.parameters(),
        )
    else:
        optimizer = fluid.optimizer.Adam(
            learning_rate=learning_rate, beta1=beta1, beta2=beta2
        )

        optimizer.minimize(loss, parameter_list=layer.parameters())
        return optimizer


class DyGraphTrainModel:
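    # Imperative (dygraph) trainer: builds both networks and one Adam
    # optimizer per network, with seeds fixed so its losses can be compared
    # exactly against the other execution modes.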
    def __init__(self, cfg):
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)

        self.generator = Generator(cfg)
        self.discriminator = Discriminator(cfg)

        self.g_optimizer = build_optimizer(self.generator, cfg)
        self.d_optimizer = build_optimizer(self.discriminator, cfg)

        self.cfg = cfg

        fluid.set_flags({'FLAGS_sort_sum_gradient': cfg.sort_sum_gradient})

    def clear_gradients(self):
        if self.g_optimizer:
            self.g_optimizer.clear_gradients()

        if self.d_optimizer:
            self.d_optimizer.clear_gradients()

    def run(self, image_real, label_org, label_trg):
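        # One training step: update the generator first, then the
        # discriminator, clearing accumulated gradients after each update.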
        image_real = fluid.dygraph.to_variable(image_real)
        label_org = fluid.dygraph.to_variable(label_org)
        label_trg = fluid.dygraph.to_variable(label_trg)

        g_loss = get_generator_loss(
            image_real,
            label_org,
            label_trg,
            self.generator,
            self.discriminator,
            self.cfg,
        )
        g_loss.backward()
        if self.g_optimizer:
            self.g_optimizer.minimize(g_loss)

        self.clear_gradients()

        d_loss = get_discriminator_loss(
            image_real,
            label_org,
            label_trg,
            self.generator,
            self.discriminator,
            self.cfg,
        )
        d_loss.backward()
        if self.d_optimizer:
            self.d_optimizer.minimize(d_loss)

        self.clear_gradients()

        return g_loss.numpy()[0], d_loss.numpy()[0]


class StaticGraphTrainModel:
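    # Declarative (static-graph) trainer: the generator and discriminator
    # losses are built into two separate programs that share feed variable
    # names, each constructed under its own unique_name.guard so parameter
    # naming starts fresh in each program.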
    def __init__(self, cfg):
        self.cfg = cfg

        def create_data_layer():
            image_real = fluid.data(
                shape=[None, 3, cfg.image_size, cfg.image_size],
                dtype='float32',
                name='image_real',
            )
            label_org = fluid.data(
                shape=[None, cfg.c_dim], dtype='float32', name='label_org'
            )
            label_trg = fluid.data(
                shape=[None, cfg.c_dim], dtype='float32', name='label_trg'
            )
            return image_real, label_org, label_trg

        paddle.seed(cfg.seed)
        paddle.framework.random._manual_program_seed(cfg.seed)
        self.gen_program = fluid.Program()
        gen_startup_program = fluid.Program()

        with fluid.program_guard(self.gen_program, gen_startup_program):
            with fluid.unique_name.guard():
                image_real, label_org, label_trg = create_data_layer()
                generator = Generator(cfg)
                discriminator = Discriminator(cfg)
                g_loss = get_generator_loss(
                    image_real,
                    label_org,
                    label_trg,
                    generator,
                    discriminator,
                    cfg,
                )
                build_optimizer(generator, cfg, loss=g_loss)

        self.dis_program = fluid.Program()
        dis_startup_program = fluid.Program()
        with fluid.program_guard(self.dis_program, dis_startup_program):
            with fluid.unique_name.guard():
                image_real, label_org, label_trg = create_data_layer()
                generator = Generator(cfg)
                discriminator = Discriminator(cfg)
                d_loss = get_discriminator_loss(
                    image_real,
                    label_org,
                    label_trg,
                    generator,
                    discriminator,
                    cfg,
                )
                build_optimizer(discriminator, cfg, loss=d_loss)

        self.executor = fluid.Executor(cfg.place)
        self.scope = fluid.Scope()

        with fluid.scope_guard(self.scope):
            self.executor.run(gen_startup_program)
            self.executor.run(dis_startup_program)

        self.g_loss = g_loss
        self.d_loss = d_loss

    def run(self, image_real, label_org, label_trg):
        feed = {
            'image_real': image_real,
            'label_org': label_org,
            'label_trg': label_trg,
        }
        with fluid.scope_guard(self.scope):
            g_loss_val = self.executor.run(
                self.gen_program, feed=feed, fetch_list=[self.g_loss]
            )[0]
            d_loss_val = self.executor.run(
                self.dis_program, feed=feed, fetch_list=[self.d_loss]
            )[0]
            return g_loss_val[0], d_loss_val[0]


class TestStarGANWithGradientPenalty(unittest.TestCase):
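    # Trains the same model under the legacy dygraph mode and under eager
    # mode and checks that per-iteration generator/discriminator losses match
    # exactly (relying on fixed seeds and the deterministic cuDNN flag).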
    def func_main(self):
        self.place_test(fluid.CPUPlace())

        if fluid.is_compiled_with_cuda():
            self.place_test(fluid.CUDAPlace(0))

    def place_test(self, place):
        cfg = Config(place, False)

        dataset = create_mnist_dataset(cfg)
        dataset = paddle.reader.cache(dataset)

        fluid_dygraph_loss = []
        with fluid.dygraph.guard(cfg.place):
            fluid_dygraph_model = DyGraphTrainModel(cfg)
            for batch_id, (image_real, label_org, label_trg) in enumerate(
                dataset()
            ):
                loss = fluid_dygraph_model.run(image_real, label_org, label_trg)
                fluid_dygraph_loss.append(loss)

        eager_dygraph_loss = []
        with _test_eager_guard():
            with fluid.dygraph.guard(cfg.place):
                eager_dygraph_model = DyGraphTrainModel(cfg)
                for batch_id, (image_real, label_org, label_trg) in enumerate(
                    dataset()
                ):
                    loss = eager_dygraph_model.run(
                        image_real, label_org, label_trg
                    )
673 674
                    eager_dygraph_loss.append(loss)

        for (g_loss_f, d_loss_f), (g_loss_e, d_loss_e) in zip(
            fluid_dygraph_loss, eager_dygraph_loss
        ):
            self.assertEqual(g_loss_f, g_loss_e)
            self.assertEqual(d_loss_f, d_loss_e)

    def test_all_cases(self):
        self.func_main()


class TestStarGANWithGradientPenaltyLegacy(unittest.TestCase):
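    # Trains the same model under static graph and under dygraph and checks
    # that per-iteration losses match exactly.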
    def func_main(self):
        self.place_test(fluid.CPUPlace())

        if fluid.is_compiled_with_cuda():
            self.place_test(fluid.CUDAPlace(0))

    def place_test(self, place):
        cfg = Config(place)

        dataset = create_mnist_dataset(cfg)
        dataset = paddle.reader.cache(dataset)

        static_graph_model = StaticGraphTrainModel(cfg)
        static_loss = []
        for batch_id, (image_real, label_org, label_trg) in enumerate(
            dataset()
        ):
            loss = static_graph_model.run(image_real, label_org, label_trg)
            static_loss.append(loss)

        dygraph_loss = []
        with fluid.dygraph.guard(cfg.place):
            dygraph_model = DyGraphTrainModel(cfg)
            for batch_id, (image_real, label_org, label_trg) in enumerate(
                dataset()
            ):
                loss = dygraph_model.run(image_real, label_org, label_trg)
                dygraph_loss.append(loss)

        for (g_loss_s, d_loss_s), (g_loss_d, d_loss_d) in zip(
            static_loss, dygraph_loss
        ):
            self.assertEqual(g_loss_s, g_loss_d)
            self.assertEqual(d_loss_s, d_loss_d)

    def test_all_cases(self):
        self.func_main()


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()