# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.fluid as fluid
import unittest

# All tests in this module run in dynamic-graph (dygraph) mode.
paddle.disable_static()

# Fix the NumPy and Paddle RNG seeds so the random inputs (np.random.rand)
# and Paddle random ops (e.g. paddle.rand) are reproducible between runs.
SEED = 2020
np.random.seed(SEED)
paddle.manual_seed(SEED)


class Generator(fluid.dygraph.Layer):
    """Minimal generator: one 3x3 convolution (3 -> 3 channels) plus tanh.

    With ``padding=1`` (and, presumably, the default stride of 1) the
    convolution keeps the spatial size of its input unchanged.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = paddle.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        """Apply the convolution, then squash the result with tanh."""
        x = self.conv1(x)
        x = fluid.layers.tanh(x)
        return x


class Discriminator(fluid.dygraph.Layer):
    """Minimal discriminator: one 1x1 convolution from 6 to 3 channels.

    In this test it is fed two 3-channel images concatenated along the
    channel axis (axis 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.convd = paddle.nn.Conv2d(6, 3, 1)

    def forward(self, x):
        """Return the raw (un-activated) convolution output."""
        x = self.convd(x)
        return x


class TestRetainGraph(unittest.TestCase):
    """Check that ``backward(retain_graph=...)`` controls graph reuse.

    ``run_retain`` back-propagates a discriminator loss whose gradient
    penalty depends on the (non-detached) generator output. It then
    back-propagates a generator loss through that same generator graph. The
    second backward pass is expected to succeed only when the first one was
    run with ``retain_graph=True``.
    """

    def cal_gradient_penalty(self,
                             netD,
                             real_data,
                             fake_data,
                             edge_data=None,
                             type='mixed',
                             constant=1.0,
                             lambda_gp=10.0):
        """Compute a WGAN-GP style gradient penalty for ``netD``.

        Args:
            netD: discriminator layer. It is called on ``real_data``
                concatenated (axis 1) with the penalized samples.
            real_data: real input tensor. NOTE: its ``stop_gradient`` flag
                is set to True as a side effect.
            fake_data: generated tensor, same shape as ``real_data``.
            edge_data: unused; kept for interface compatibility.
            type: which samples to penalize: ``'real'``, ``'fake'`` or
                ``'mixed'`` (a random per-sample interpolation of the two).
                The name shadows the builtin but is part of the keyword
                interface, so it is kept.
            constant: target value for the per-sample gradient L2 norm.
            lambda_gp: penalty weight. The penalty is computed only when it
                is > 0.

        Returns:
            ``(gradient_penalty, gradients)``. Here ``gradients`` is the
            gradient of ``netD``'s output w.r.t. its concatenated input,
            reshaped to ``[batch, -1]``. Returns ``(0.0, None)`` when
            ``lambda_gp`` is not > 0.

        Raises:
            NotImplementedError: if ``type`` is not one of the values above.
        """
        if lambda_gp > 0.0:
            if type == 'real':
                # Use a detached copy rather than ``real_data`` itself.
                # Otherwise ``interpolatesv`` and ``real_data`` are the same
                # object, and ``real_data.stop_gradient = True`` below would
                # re-disable gradients on the tensor being differentiated.
                interpolatesv = real_data.detach()
            elif type == 'fake':
                interpolatesv = fake_data
            elif type == 'mixed':
                batch_size = real_data.shape[0]
                # One random mixing weight per sample, broadcast over every
                # element of that sample.
                alpha = paddle.rand((batch_size, 1))
                alpha = paddle.expand(alpha, [
                    batch_size,
                    int(np.prod(real_data.shape)) // batch_size
                ])
                alpha = paddle.reshape(alpha, real_data.shape)
                interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
            else:
                raise NotImplementedError('{} not implemented'.format(type))
            interpolatesv.stop_gradient = False
            real_data.stop_gradient = True
            fake_AB = paddle.concat((real_data.detach(), interpolatesv), 1)
            disc_interpolates = netD(fake_AB)

            # Seed the backward pass with ones.
            outs = paddle.fill_constant(disc_interpolates.shape,
                                        disc_interpolates.dtype, 1.0)
            # create_graph=True makes the penalty itself differentiable.
            # retain_graph=True keeps the graph for the later
            # loss_d.backward() in run_retain.
            gradients = paddle.grad(
                outputs=disc_interpolates,
                inputs=fake_AB,
                grad_outputs=outs,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)

            gradients = paddle.reshape(gradients[0], [real_data.shape[0], -1])

            # Penalize deviation of each sample's L2 gradient norm from
            # ``constant``. The 1e-16 epsilon presumably avoids an all-zero
            # vector, where the norm's gradient is undefined.
            gradient_penalty = paddle.reduce_mean((paddle.norm(
                gradients + 1e-16, 2, 1) - constant)**2) * lambda_gp
            return gradient_penalty, gradients
        else:
            return 0.0, None

    def run_retain(self, need_retain):
        """Run one discriminator step, then one generator step.

        Args:
            need_retain: passed as ``retain_graph`` to ``loss_d.backward``.
                The generator step back-propagates through ``fakeB`` again,
                so with ``False`` that second backward is expected to raise
                (see ``test_retain``).
        """
        g = Generator()
        d = Discriminator()

        optim_g = paddle.optimizer.Adam(parameters=g.parameters())
        optim_d = paddle.optimizer.Adam(parameters=d.parameters())

        gan_criterion = paddle.nn.MSELoss()
        l1_criterion = paddle.nn.L1Loss()

        A = np.random.rand(2, 3, 32, 32).astype('float32')
        B = np.random.rand(2, 3, 32, 32).astype('float32')

        realA = paddle.to_variable(A)
        realB = paddle.to_variable(B)
        fakeB = g(realA)

        # --- Discriminator step ---
        optim_d.clear_gradients()
        fake_AB = paddle.concat((realA, fakeB), 1)
        G_pred_fake = d(fake_AB.detach())

        false_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 0.0)

        # The penalty uses the non-detached fakeB, so loss_d's graph reaches
        # back into the generator.
        G_gradient_penalty, _ = self.cal_gradient_penalty(
            d, realA, fakeB, lambda_gp=10.0)
        loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty

        loss_d.backward(retain_graph=need_retain)
        optim_d.minimize(loss_d)

        # --- Generator step: back-propagates through fakeB's graph again ---
        optim_g.clear_gradients()
        fake_AB = paddle.concat((realA, fakeB), 1)
        G_pred_fake = d(fake_AB)
        true_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 1.0)
        loss_g = l1_criterion(fakeB, realB) + gan_criterion(G_pred_fake,
                                                            true_target)

        loss_g.backward()
        optim_g.minimize(loss_g)

    def test_retain(self):
        """Retaining the graph succeeds; not retaining it must fail."""
        self.run_retain(need_retain=True)
        self.assertRaises(
            fluid.core.EnforceNotMet, self.run_retain, need_retain=False)

if __name__ == '__main__':
    # Discover and run the TestCase classes defined in this module.
    unittest.main()