# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.fluid as fluid
import unittest

paddle.disable_static()
SEED = 2020
np.random.seed(SEED)
paddle.seed(SEED)


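# Tiny generator: a single 3x3 conv (3 -> 3 channels) followed by tanh.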
class Generator(fluid.dygraph.Layer):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = paddle.nn.Conv2D(3, 3, 3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = fluid.layers.tanh(x)
        return x


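# Tiny discriminator: a single 1x1 conv over the 6-channel concatenation of
# an input image and a (real or generated) output image.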
class Discriminator(fluid.dygraph.Layer):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.convd = paddle.nn.Conv2D(6, 3, 1)

    def forward(self, x):
        x = self.convd(x)
        return x


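# The test wires the generator and discriminator into a minimal GAN-style
# training step and calls backward twice through the same generator output:
# once for the discriminator loss (with retain_graph=need_retain) and once
# for the generator loss. The second backward is only legal when the first
# one retained the graph.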
class TestRetainGraph(unittest.TestCase):
    def cal_gradient_penalty(self,
                             netD,
                             real_data,
                             fake_data,
                             edge_data=None,
                             type='mixed',
                             constant=1.0,
                             lambda_gp=10.0):
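        # WGAN-GP style gradient penalty: take an interpolate of real and
        # fake samples (for type='mixed'), run the discriminator on it, and
        # penalize how far the gradient norm w.r.t. that input deviates from
        # `constant`. create_graph/retain_graph below keep the graph alive so
        # the penalty term itself stays differentiable.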
        if lambda_gp > 0.0:
            if type == 'real':
                interpolatesv = real_data
            elif type == 'fake':
                interpolatesv = fake_data
            elif type == 'mixed':
                alpha = paddle.rand((real_data.shape[0], 1))
                alpha = paddle.expand(alpha, [
                    real_data.shape[0],
                    np.prod(real_data.shape) // real_data.shape[0]
                ])
                alpha = paddle.reshape(alpha, real_data.shape)
                interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
            else:
                raise NotImplementedError('{} not implemented'.format(type))
            interpolatesv.stop_gradient = False
            real_data.stop_gradient = True
            fake_AB = paddle.concat((real_data.detach(), interpolatesv), 1)
            disc_interpolates = netD(fake_AB)

            outs = paddle.fluid.layers.fill_constant(
                disc_interpolates.shape, disc_interpolates.dtype, 1.0)
            gradients = paddle.grad(
                outputs=disc_interpolates,
                inputs=fake_AB,
                grad_outputs=outs,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)

            gradients = paddle.reshape(gradients[0], [real_data.shape[0], -1])

            gradient_penalty = paddle.mean(
                (paddle.norm(gradients + 1e-16, 2, 1) - constant)**2
            ) * lambda_gp  # added eps
            return gradient_penalty, gradients
        else:
            return 0.0, None

    def run_retain(self, need_retain):
        g = Generator()
        d = Discriminator()

        optim_g = paddle.optimizer.Adam(parameters=g.parameters())
        optim_d = paddle.optimizer.Adam(parameters=d.parameters())

        gan_criterion = paddle.nn.MSELoss()
        l1_criterion = paddle.nn.L1Loss()

        A = np.random.rand(2, 3, 32, 32).astype('float32')
        B = np.random.rand(2, 3, 32, 32).astype('float32')

        realA = paddle.to_tensor(A)
        realB = paddle.to_tensor(B)
        fakeB = g(realA)

        optim_d.clear_gradients()
        fake_AB = paddle.concat((realA, fakeB), 1)
        G_pred_fake = d(fake_AB.detach())

        false_target = paddle.fluid.layers.fill_constant(G_pred_fake.shape,
                                                         'float32', 0.0)

        G_gradient_penalty, _ = self.cal_gradient_penalty(
            d, realA, fakeB, lambda_gp=10.0)
        loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty

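        # First backward pass. With need_retain=False the autograd graph
        # through the generator output (fakeB) is freed here, so the
        # generator update below cannot backpropagate through it again.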
        loss_d.backward(retain_graph=need_retain)
        optim_d.minimize(loss_d)

        optim_g.clear_gradients()
        fake_AB = paddle.concat((realA, fakeB), 1)
        G_pred_fake = d(fake_AB)
        true_target = paddle.fluid.layers.fill_constant(G_pred_fake.shape,
                                                        'float32', 1.0)
        loss_g = l1_criterion(fakeB, realB) + gan_criterion(G_pred_fake,
                                                            true_target)

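        # Second backward pass through the same fakeB; this only succeeds if
        # the first backward retained the graph.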
        loss_g.backward()
        optim_g.minimize(loss_g)

    def func_retain(self):
        self.run_retain(need_retain=True)
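        # Without retaining the graph, the second backward in run_retain is
        # expected to raise; this is only asserted outside eager mode.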
        if not fluid.framework.in_dygraph_mode():
            self.assertRaises(RuntimeError, self.run_retain, need_retain=False)

    def test_retain(self):
        with fluid.framework._test_eager_guard():
            self.func_retain()
        self.func_retain()


if __name__ == '__main__':
    unittest.main()