# test_imperative_star_gan_with_gradient_penalty.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
import numpy as np
import unittest

# On CUDA builds, request deterministic cuDNN kernels — presumably so the
# static-graph and dygraph losses, which the test compares with exact
# equality, come out bit-identical.
if fluid.is_compiled_with_cuda():
    fluid.core.globals()['FLAGS_cudnn_deterministic'] = True


class Config(object):
    """Hyper-parameters for the StarGAN gradient-penalty comparison test.

    Args:
        place: Paddle place to run on. A ``fluid.CPUPlace`` selects a tiny
            network (1 base channel, 1 repeated block) on 32x32 images; any
            other place selects 64 base channels, 6 repeated blocks and
            256x256 images.
        sort_sum_gradient: value later copied onto the dygraph
            ``BackwardStrategy``.
    """

    def __init__(self, place, sort_sum_gradient=True):
        self.place = place

        # CPU cases are extremely slow, so shrink the networks drastically.
        on_cpu = isinstance(place, fluid.CPUPlace)
        base_dims, repeat_num, image_size = ((1, 1, 32) if on_cpu else
                                             (64, 6, 256))

        self.g_base_dims = self.d_base_dims = base_dims
        self.g_repeat_num = self.d_repeat_num = repeat_num
        self.image_size = image_size

        # Width of the one-hot domain labels (MNIST has 10 digit classes).
        self.c_dim = 10
        self.batch_size = 1
        self.seed = 1

        # Loss weights for the reconstruction and gradient-penalty terms.
        self.lambda_rec = 10
        self.lambda_gp = 10

        # Number of batches the reader yields / training steps per model.
        self.iterations = 10

        self.sort_sum_gradient = sort_sum_gradient


def create_mnist_dataset(cfg):
    """Build a reader that yields StarGAN-shaped batches made from MNIST.

    Each digit is reshaped to 28x28, passed through ``np.resize`` to
    ``cfg.image_size`` x ``cfg.image_size`` (``np.resize`` repeats the pixel
    data rather than interpolating, which is fine for a numerical test), and
    replicated into 3 channels.

    Args:
        cfg: a ``Config``; reads ``c_dim``, ``image_size``, ``batch_size``
            and ``iterations``.

    Returns:
        A zero-argument callable yielding at most ``cfg.iterations`` tuples
        ``(image_real, label_org, label_trg)`` of float32 arrays shaped
        ``[batch_size, 3, image_size, image_size]``, ``[batch_size, c_dim]``
        and ``[batch_size, c_dim]``.
    """

    def create_target_label(label):
        # The target domain is the source domain itself; a distinct fake
        # target could instead be produced with ``(label + 1) % cfg.c_dim``.
        return label

    def create_one_hot(label):
        encoded = np.zeros([cfg.c_dim])
        encoded[label] = 1
        return encoded

    def __impl__():
        reader = paddle.dataset.mnist.train()
        images, orgs, trgs = [], [], []
        batches_emitted = 0

        for pixels, org in reader():
            gray = np.resize(
                np.reshape(np.array(pixels), [28, 28]),
                [cfg.image_size, cfg.image_size])
            images.append(np.array([gray] * 3))
            orgs.append(create_one_hot(org))
            trgs.append(create_one_hot(create_target_label(org)))

            if len(images) != cfg.batch_size:
                continue

            yield (np.array(images).astype('float32'),
                   np.array(orgs).astype('float32'),
                   np.array(trgs).astype('float32'))

            batches_emitted += 1
            if batches_emitted == cfg.iterations:
                break
            images, orgs, trgs = [], [], []

    return __impl__


class InstanceNorm(fluid.dygraph.Layer):
    """Instance normalization with a learnable per-channel scale and bias.

    Usable in both execution modes. In dygraph mode the ``instance_norm`` op
    is invoked directly with this layer's parameters. In static mode
    ``fluid.layers.instance_norm`` is built with ``ParamAttr`` objects named
    after those same parameters.

    Args:
        num_channels: length of the scale and bias vectors.
        epsilon: value forwarded as the op's ``epsilon`` attribute.
    """

    def __init__(self, num_channels, epsilon=1e-5):
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon

        self.scale = self.create_parameter(shape=[num_channels], is_bias=False)
        self.bias = self.create_parameter(shape=[num_channels], is_bias=True)

    def forward(self, input):
        if not fluid.in_dygraph_mode():
            # Point the static op at the already-created parameters by name.
            return fluid.layers.instance_norm(
                input,
                epsilon=self.epsilon,
                param_attr=fluid.ParamAttr(self.scale.name),
                bias_attr=fluid.ParamAttr(self.bias.name))

        op_inputs = {
            'X': [input],
            'Scale': [self.scale],
            'Bias': [self.bias],
        }
        op_outputs = fluid.core.ops.instance_norm(op_inputs,
                                                  {'epsilon': self.epsilon})
        return op_outputs['Y'][0]


class Conv2DLayer(fluid.dygraph.Layer):
    """``Conv2D`` optionally followed by instance norm and leaky ReLU.

    Args:
        num_channels: number of input channels.
        num_filters: number of output channels.
        filter_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        norm: when not ``None``, an ``InstanceNorm`` over ``num_filters``
            channels is applied after the convolution.
        use_bias: when False (the default) the convolution is built with
            ``bias_attr=False``, i.e. without a bias.
        relufactor: when not ``None``, ``leaky_relu`` with this slope is
            applied last; a slope of ``0`` makes it a plain ReLU.
    """

    def __init__(self,
                 num_channels,
                 num_filters=64,
                 filter_size=7,
                 stride=1,
                 padding=0,
                 norm=None,
                 use_bias=False,
                 relufactor=None):
        super(Conv2DLayer, self).__init__()
        self._conv = fluid.dygraph.Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            bias_attr=None if use_bias else False)

        if norm is not None:
            self._norm = InstanceNorm(num_filters)
        else:
            self._norm = None

        self.relufactor = relufactor

    def forward(self, input):
        conv = self._conv(input)

        # Compare against None explicitly (matching the relufactor check)
        # instead of relying on the truthiness of a Layer object.
        if self._norm is not None:
            conv = self._norm(conv)

        if self.relufactor is not None:
            conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)

        return conv


class Deconv2DLayer(fluid.dygraph.Layer):
    """``Conv2DTranspose`` optionally followed by instance norm and leaky ReLU.

    Args:
        num_channels: number of input channels.
        num_filters: number of output channels.
        filter_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        norm: when not ``None``, an ``InstanceNorm`` over ``num_filters``
            channels is applied after the transposed convolution.
        use_bias: when False (the default) the layer is built with
            ``bias_attr=False``, i.e. without a bias.
        relufactor: when not ``None``, ``leaky_relu`` with this slope is
            applied last; a slope of ``0`` makes it a plain ReLU.
    """

    def __init__(self,
                 num_channels,
                 num_filters=64,
                 filter_size=7,
                 stride=1,
                 padding=0,
                 norm=None,
                 use_bias=False,
                 relufactor=None):
        super(Deconv2DLayer, self).__init__()

        self._deconv = fluid.dygraph.Conv2DTranspose(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            bias_attr=None if use_bias else False)

        if norm is not None:
            self._norm = InstanceNorm(num_filters)
        else:
            self._norm = None

        self.relufactor = relufactor

    def forward(self, input):
        deconv = self._deconv(input)

        # Compare against None explicitly (matching the relufactor check)
        # instead of relying on the truthiness of a Layer object.
        if self._norm is not None:
            deconv = self._norm(deconv)

        if self.relufactor is not None:
            deconv = fluid.layers.leaky_relu(deconv, alpha=self.relufactor)

        return deconv


class ResidualBlock(fluid.dygraph.Layer):
    """Two 3x3 conv + instance-norm layers wrapped by an identity shortcut.

    The first conv is followed by a ReLU (leaky slope 0); the second has no
    activation. The block returns ``input + conv1(conv0(input))``. The
    shortcut adds ``input`` to a ``num_filters``-channel tensor, so callers
    are expected to pass ``num_channels == num_filters``.
    """

    def __init__(self, num_channels, num_filters):
        super(ResidualBlock, self).__init__()
        common = dict(filter_size=3, stride=1, padding=1, norm=True)
        self._conv0 = Conv2DLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            relufactor=0,
            **common)
        self._conv1 = Conv2DLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            relufactor=None,
            **common)

    def forward(self, input):
        residual = self._conv1(self._conv0(input))
        return input + residual


class Generator(fluid.dygraph.Layer):
    """StarGAN generator: encoder, residual bottleneck, decoder, tanh.

    The target-domain label is tiled over the spatial dimensions and
    concatenated to the image along axis 1 before the first convolution.

    Args:
        cfg: a ``Config``; reads ``c_dim``, ``g_base_dims`` and
            ``g_repeat_num``.
        num_channels: channel count of the input image.
    """

    def __init__(self, cfg, num_channels=3):
        super(Generator, self).__init__()
        base = cfg.g_base_dims

        # 7x7 stride-1 stem, then two stride-2 convs doubling the channels:
        # base -> 2*base -> 4*base.
        encoder = [
            Conv2DLayer(
                num_channels=cfg.c_dim + num_channels,
                num_filters=base,
                filter_size=7,
                stride=1,
                padding=3,
                norm=True,
                relufactor=0)
        ]
        encoder.extend(
            Conv2DLayer(
                num_channels=base * 2**i,
                num_filters=base * 2**(i + 1),
                filter_size=4,
                stride=2,
                padding=1,
                norm=True,
                relufactor=0) for i in range(2))
        self._conv0 = fluid.dygraph.Sequential(*encoder)

        bottleneck = base * 4
        self._res_block = fluid.dygraph.Sequential(*[
            ResidualBlock(num_channels=bottleneck, num_filters=bottleneck)
            for _ in range(cfg.g_repeat_num)
        ])

        # Two stride-2 transposed convs halving the channels:
        # 4*base -> 2*base -> base.
        self._deconv = fluid.dygraph.Sequential(*[
            Deconv2DLayer(
                num_channels=base * 2**(2 - i),
                num_filters=base * 2**(1 - i),
                filter_size=4,
                stride=2,
                padding=1,
                relufactor=0,
                norm=True) for i in range(2)
        ])

        self._conv1 = Conv2DLayer(
            num_channels=base,
            num_filters=3,
            filter_size=7,
            stride=1,
            padding=3,
            relufactor=None)

    def forward(self, input, label_trg):
        height, width = input.shape[2], input.shape[3]
        # label_trg is expected to be [N, C]; turn it into a [N, C, H, W] map.
        label_map = fluid.layers.reshape(label_trg,
                                         [-1, label_trg.shape[1], 1, 1])
        label_map = fluid.layers.expand(
            x=label_map, expand_times=[1, 1, height, width])

        hidden = fluid.layers.concat([input, label_map], 1)
        hidden = self._conv0(hidden)
        hidden = self._res_block(hidden)
        hidden = self._deconv(hidden)
        hidden = self._conv1(hidden)
        return fluid.layers.tanh(hidden)


class Discriminator(fluid.dygraph.Layer):
    """StarGAN discriminator with an adversarial head and a class head.

    A trunk of ``cfg.d_repeat_num`` stride-2 4x4 convs (leaky ReLU, slope
    0.2) doubles the channel count after the first conv. ``forward`` returns
    ``(out1, out2)``. ``out1`` is a 1-channel 3x3-conv score map. ``out2``
    has ``cfg.c_dim`` channels and comes from a conv whose kernel size is
    ``image_size / 2**d_repeat_num``.

    Args:
        cfg: a ``Config``; reads ``d_base_dims``, ``d_repeat_num``,
            ``image_size`` and ``c_dim``.
        num_channels: channel count of the input image.
    """

    def __init__(self, cfg, num_channels=3):
        super(Discriminator, self).__init__()

        repeat_num = cfg.d_repeat_num
        dims = cfg.d_base_dims

        trunk = [
            Conv2DLayer(
                num_channels=num_channels,
                num_filters=dims,
                filter_size=4,
                stride=2,
                padding=1,
                relufactor=0.2)
        ]
        for _ in range(1, repeat_num):
            trunk.append(
                Conv2DLayer(
                    num_channels=dims,
                    num_filters=dims * 2,
                    filter_size=4,
                    stride=2,
                    padding=1,
                    relufactor=0.2))
            dims *= 2
        self._conv0 = fluid.dygraph.Sequential(*trunk)

        # Each 4x4/stride-2/pad-1 conv halves an even spatial size, so this is
        # the spatial extent left after the trunk for suitably sized images.
        kernel_size = int(cfg.image_size / np.power(2, repeat_num))

        self._conv1 = Conv2DLayer(
            num_channels=dims,
            num_filters=1,
            filter_size=3,
            stride=1,
            padding=1)

        self._conv2 = Conv2DLayer(
            num_channels=dims,
            num_filters=cfg.c_dim,
            filter_size=kernel_size)

    def forward(self, input):
        features = self._conv0(input)
        return self._conv1(features), self._conv2(features)


def loss_cls(cls, label, cfg):
    """Domain-classification loss.

    Flattens ``cls`` to ``[-1, shape[1] * shape[2] * shape[3]]``, applies
    sigmoid cross-entropy against ``label``, sums every element, and divides
    by ``cfg.batch_size``.
    """
    dims = cls.shape
    logits = fluid.layers.reshape(cls, [-1, dims[1] * dims[2] * dims[3]])
    per_element = fluid.layers.sigmoid_cross_entropy_with_logits(logits, label)
    return fluid.layers.reduce_sum(per_element) / cfg.batch_size


def calc_gradients(outputs, inputs, no_grad_set):
    """Differentiate ``outputs`` w.r.t. ``inputs`` in either execution mode.

    Static mode delegates to ``fluid.gradients``. Dygraph mode uses
    ``paddle.fluid.dygraph.base.grad`` with ``create_graph=True``, because
    the gradient penalty built from the result is itself back-propagated.
    """
    if not fluid.in_dygraph_mode():
        return fluid.gradients(
            targets=outputs, inputs=inputs, no_grad_set=no_grad_set)

    from paddle.fluid.dygraph.base import grad as dygraph_grad
    return dygraph_grad(
        outputs=outputs,
        inputs=inputs,
        no_grad_set=no_grad_set,
        create_graph=True)


def gradient_penalty(f, real, fake, no_grad_set, cfg):
    """WGAN-GP style penalty ``mean((||d f(x) / d x||_2 - 1)^2)``.

    ``x`` is a per-sample random interpolation
    ``alpha * real + (1 - alpha) * fake`` with ``alpha`` drawn uniformly from
    ``[0.1, 1.0]`` using ``cfg.seed``.

    Args:
        f: critic applied to ``x``. If it returns a tuple or list (as
            ``Discriminator`` does), only the first element is penalised.
        real: real batch; ``a.shape[0]`` is the batch size, and the gradient
            is assumed to be 4-D (e.g. ``[N, C, H, W]``) — TODO confirm for
            other critics.
        fake: generated batch with the same shape as ``real``.
        no_grad_set: variables excluded from the gradient computation.
        cfg: a ``Config``; only ``seed`` is read.

    Returns:
        The scalar penalty, or ``None`` when no gradient was produced.
    """

    def _interpolate(a, b):
        shape = [a.shape[0]]
        alpha = fluid.layers.uniform_random_batch_size_like(
            input=a, shape=shape, min=0.1, max=1.0, seed=cfg.seed)

        inner = fluid.layers.elementwise_mul(
            b, 1.0 - alpha, axis=0) + fluid.layers.elementwise_mul(
                a, alpha, axis=0)
        return inner

    x = _interpolate(real, fake)
    # The old ``pred, _ = f(x)`` made the tuple check dead code and failed for
    # critics returning a single tensor; accept both forms explicitly.
    pred = f(x)
    if isinstance(pred, (tuple, list)):
        pred = pred[0]

    gradient = calc_gradients(
        outputs=[pred], inputs=[x], no_grad_set=no_grad_set)

    if gradient is None or gradient[0] is None:
        return None

    gradient = gradient[0]
    grad_shape = gradient.shape

    gradient = fluid.layers.reshape(
        gradient, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]])

    # Keeps sqrt's derivative finite when a per-sample gradient norm is 0.
    epsilon = 1e-16
    norm = fluid.layers.sqrt(
        fluid.layers.reduce_sum(
            fluid.layers.square(gradient), dim=1) + epsilon)

    gp = fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0))
    return gp


def get_generator_loss(image_real, label_org, label_trg, generator,
                       discriminator, cfg):
    """Generator objective: adversarial + ``lambda_rec`` * L1 cycle + class.

    The real image is translated to ``label_trg`` and back to ``label_org``.
    The mean absolute difference from the original image is the
    reconstruction term. The adversarial term is the negated mean critic
    score of the translated image.
    """
    translated = generator(image_real, label_trg)
    reconstructed = generator(translated, label_org)
    rec_loss = fluid.layers.reduce_mean(
        fluid.layers.abs(
            fluid.layers.elementwise_sub(image_real, reconstructed)))

    critic_score, class_logits = discriminator(translated)

    adv_loss = -fluid.layers.mean(critic_score)
    cls_loss = loss_cls(class_logits, label_trg, cfg)
    return adv_loss + cfg.lambda_rec * rec_loss + cls_loss


def get_discriminator_loss(image_real, label_org, label_trg, generator,
                           discriminator, cfg):
    """Discriminator objective: critic loss + class loss + gradient penalty.

    Combines ``mean(D(fake)) - mean(D(real))`` with the classification loss
    on real images against ``label_org``. When ``gradient_penalty`` yields a
    value, ``cfg.lambda_gp`` times that penalty is added as well.
    """
    translated = generator(image_real, label_trg)
    real_score, real_logits = discriminator(image_real)
    fake_score, _ = discriminator(translated)

    cls_loss = loss_cls(real_logits, label_org, cfg)
    fake_term = fluid.layers.mean(fake_score)
    real_term = -fluid.layers.mean(real_score)
    total = real_term + fake_term + cls_loss

    penalty = gradient_penalty(discriminator, image_real, translated,
                               discriminator.parameters(), cfg)
    if penalty is None:
        return total

    total += cfg.lambda_gp * penalty
    return total


def build_optimizer(layer, cfg, loss=None):
    """Create an Adam optimizer (lr 1e-3, beta1 0.5, beta2 0.999) for ``layer``.

    In dygraph mode the optimizer is bound to ``layer.parameters()`` and
    returned; the caller invokes ``minimize`` every step. In static mode
    ``optimizer.minimize(loss, parameter_list=layer.parameters())`` is called
    immediately.

    Args:
        layer: the sub-network whose parameters are trained.
        cfg: unused; kept for interface compatibility.
        loss: static-graph loss variable; required outside dygraph mode.

    Returns:
        The ``fluid.optimizer.Adam`` instance.

    Raises:
        ValueError: in static-graph mode when ``loss`` is ``None``.
    """
    learning_rate = 1e-3
    beta1 = 0.5
    beta2 = 0.999
    if fluid.in_dygraph_mode():
        return fluid.optimizer.Adam(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            parameter_list=layer.parameters())

    # Fail loudly instead of letting minimize(None) break somewhere deep.
    if loss is None:
        raise ValueError(
            "build_optimizer requires `loss` in static-graph mode")

    optimizer = fluid.optimizer.Adam(
        learning_rate=learning_rate, beta1=beta1, beta2=beta2)

    optimizer.minimize(loss, parameter_list=layer.parameters())
    return optimizer


class DyGraphTrainModel(object):
    """Dygraph trainer doing one generator and one discriminator step per batch.

    Must be constructed and run inside a ``fluid.dygraph.guard``.

    Args:
        cfg: a ``Config`` providing ``seed``, ``sort_sum_gradient`` and the
            network/loss hyper-parameters.
    """

    def __init__(self, cfg):
        fluid.default_startup_program().random_seed = cfg.seed
        fluid.default_main_program().random_seed = cfg.seed

        self.generator = Generator(cfg)
        self.discriminator = Discriminator(cfg)

        self.g_optimizer = build_optimizer(self.generator, cfg)
        self.d_optimizer = build_optimizer(self.discriminator, cfg)

        self.cfg = cfg

        self.backward_strategy = fluid.dygraph.BackwardStrategy()
        self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient

    def _clear_gradients(self):
        """Reset the accumulated gradients of both networks."""
        self.generator.clear_gradients()
        self.discriminator.clear_gradients()

    def run(self, image_real, label_org, label_trg):
        """Run one G step then one D step; return ``(g_loss, d_loss)`` scalars."""
        image_real = fluid.dygraph.to_variable(image_real)
        label_org = fluid.dygraph.to_variable(label_org)
        label_trg = fluid.dygraph.to_variable(label_trg)

        g_loss = get_generator_loss(image_real, label_org, label_trg,
                                    self.generator, self.discriminator,
                                    self.cfg)
        g_loss.backward(self.backward_strategy)
        if self.g_optimizer:
            self.g_optimizer.minimize(g_loss)
        # g_loss also flows through the discriminator (and d_loss through the
        # generator), so clear BOTH nets; otherwise stale gradients leak into
        # the other network's next update. The static-graph model computes
        # each loss in its own program, so it never mixes them.
        self._clear_gradients()

        d_loss = get_discriminator_loss(image_real, label_org, label_trg,
                                        self.generator, self.discriminator,
                                        self.cfg)
        d_loss.backward(self.backward_strategy)
        if self.d_optimizer:
            self.d_optimizer.minimize(d_loss)
        self._clear_gradients()

        return g_loss.numpy()[0], d_loss.numpy()[0]


class StaticGraphTrainModel(object):
    """Static-graph counterpart of ``DyGraphTrainModel``.

    Builds a generator-training program and a discriminator-training program.
    Each is created under a fresh ``unique_name.guard()`` with the generator
    built before the discriminator, and both startup programs are run in one
    shared ``fluid.Scope``. Presumably this makes both programs resolve to
    the same parameter variables — verify if construction order changes.

    Args:
        cfg: a ``Config`` providing ``place``, ``seed`` and the
            network/loss hyper-parameters.
    """

    def __init__(self, cfg):
        self.cfg = cfg

        def create_data_layer():
            image_real = fluid.data(
                shape=[None, 3, cfg.image_size, cfg.image_size],
                dtype='float32',
                name='image_real')
            label_org = fluid.data(
                shape=[None, cfg.c_dim], dtype='float32', name='label_org')
            label_trg = fluid.data(
                shape=[None, cfg.c_dim], dtype='float32', name='label_trg')
            return image_real, label_org, label_trg

        def build_program(get_loss, train_generator):
            # Returns (main_program, startup_program, loss_variable).
            main_program = fluid.Program()
            startup_program = fluid.Program()
            with fluid.program_guard(main_program, startup_program):
                main_program.random_seed = cfg.seed
                startup_program.random_seed = cfg.seed
                with fluid.unique_name.guard():
                    image_real, label_org, label_trg = create_data_layer()
                    generator = Generator(cfg)
                    discriminator = Discriminator(cfg)
                    loss = get_loss(image_real, label_org, label_trg,
                                    generator, discriminator, cfg)
                    trainable = generator if train_generator else discriminator
                    build_optimizer(trainable, cfg, loss=loss)
            return main_program, startup_program, loss

        self.gen_program, gen_startup_program, self.g_loss = build_program(
            get_generator_loss, train_generator=True)
        self.dis_program, dis_startup_program, self.d_loss = build_program(
            get_discriminator_loss, train_generator=False)

        self.executor = fluid.Executor(cfg.place)
        self.scope = fluid.Scope()

        with fluid.scope_guard(self.scope):
            for startup_program in (gen_startup_program, dis_startup_program):
                self.executor.run(startup_program)

    def run(self, image_real, label_org, label_trg):
        """Run one G step then one D step; return ``(g_loss, d_loss)`` scalars."""
        feed = dict(
            image_real=image_real, label_org=label_org, label_trg=label_trg)
        steps = ((self.gen_program, self.g_loss),
                 (self.dis_program, self.d_loss))
        with fluid.scope_guard(self.scope):
            losses = []
            for program, loss in steps:
                fetched = self.executor.run(program,
                                            feed=feed,
                                            fetch_list=[loss])[0]
                losses.append(fetched[0])
            return tuple(losses)


class TestStarGANWithGradientPenalty(unittest.TestCase):
    """Static-graph and dygraph StarGAN training must produce equal losses."""

    def test_main(self):
        self.place_test(fluid.CPUPlace())

        if fluid.is_compiled_with_cuda():
            self.place_test(fluid.CUDAPlace(0))

    def place_test(self, place):
        """Train both models on identical batches and compare every loss."""
        cfg = Config(place)

        dataset = create_mnist_dataset(cfg)
        # Wrapped with fluid.io.cache, presumably so the second pass replays
        # the same batches instead of re-reading MNIST.
        dataset = fluid.io.cache(dataset)

        static_graph_model = StaticGraphTrainModel(cfg)
        static_loss = [
            static_graph_model.run(image_real, label_org, label_trg)
            for image_real, label_org, label_trg in dataset()
        ]

        with fluid.dygraph.guard(cfg.place):
            dygraph_model = DyGraphTrainModel(cfg)
            dygraph_loss = [
                dygraph_model.run(image_real, label_org, label_trg)
                for image_real, label_org, label_trg in dataset()
            ]

        # zip() truncates silently: make sure the comparison is not vacuous
        # and that both models processed every batch.
        self.assertGreater(len(static_loss), 0)
        self.assertEqual(len(static_loss), len(dygraph_loss))
        for (g_loss_s, d_loss_s), (g_loss_d, d_loss_d) in zip(static_loss,
                                                              dygraph_loss):
            self.assertEqual(g_loss_s, g_loss_d)
            self.assertEqual(d_loss_s, d_loss_d)


# Allow running this file directly with the Python interpreter.
if __name__ == '__main__':
    unittest.main()