# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from test_imperative_base import new_program_scope

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.optimizer import SGDOptimizer
from paddle.nn import Linear


class Discriminator(fluid.Layer):
    """Small MLP that scores a single-feature sample with one logit.

    Architecture: Linear(1 -> 32), ELU, Linear(32 -> 1). The output is left
    unactivated; callers feed it to ``binary_cross_entropy_with_logits``.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: parameter names are generated in this order.
        self._fc1 = Linear(1, 32)
        self._fc2 = Linear(32, 1)

    def forward(self, inputs):
        """Return the raw score (logit) for ``inputs``."""
        hidden = paddle.nn.functional.elu(self._fc1(inputs))
        return self._fc2(hidden)


class Generator(fluid.Layer):
    """Small MLP that maps a 2-feature noise vector to a 1-feature sample.

    Architecture: Linear(2 -> 64), ELU, Linear(64 -> 64), ELU,
    Linear(64 -> 1).
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: parameter names are generated in this order.
        self._fc1 = Linear(2, 64)
        self._fc2 = Linear(64, 64)
        self._fc3 = Linear(64, 1)

    def forward(self, inputs):
        """Return the generated sample for the noise ``inputs``."""
        activation = paddle.nn.functional.elu
        out = inputs
        # Every layer but the last is followed by an ELU.
        for hidden_layer in (self._fc1, self._fc2):
            out = activation(hidden_layer(out))
        return self._fc3(out)


class TestDygraphGAN(unittest.TestCase):
    """Compare one GAN training step in static-graph and dygraph modes.

    A discriminator SGD step followed by a generator SGD step is run once as
    static programs and twice in dygraph mode. The second dygraph run has
    ``FLAGS_sort_sum_gradient`` enabled. Both dygraph runs must reproduce the
    static losses exactly and the static parameters to ``rtol=1e-05``.
    """

    def _run_static(self):
        """Build and run the static-graph discriminator/generator programs.

        Returns:
            tuple: ``(d_loss, g_loss, params)``. These are the fetched
            discriminator and generator losses, plus a dict mapping each
            parameter name in ``generate_p`` to its value after both steps.
        """
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)
        startup = fluid.Program()
        discriminate_p = fluid.Program()
        generate_p = fluid.Program()

        scope = fluid.core.Scope()
        with new_program_scope(
            main=discriminate_p, startup=startup, scope=scope
        ):
            discriminator = Discriminator()
            generator = Generator()

            img = fluid.layers.data(
                name="img", shape=[2, 1], append_batch_size=False
            )
            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False
            )

            # Real samples are labelled 1, generated samples 0.
            d_real = discriminator(img)
            d_loss_real = paddle.mean(
                paddle.nn.functional.binary_cross_entropy_with_logits(
                    logit=d_real,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=1.0
                    ),
                )
            )

            d_fake = discriminator(generator(noise))
            d_loss_fake = paddle.mean(
                paddle.nn.functional.binary_cross_entropy_with_logits(
                    logit=d_fake,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=0.0
                    ),
                )
            )

            d_loss = d_loss_real + d_loss_fake

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(d_loss)

        # NOTE(review): the layers are re-created for the generator program.
        # This presumably relies on identical generated parameter names so
        # both programs share variables through `scope`. TODO confirm.
        with new_program_scope(main=generate_p, startup=startup, scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False
            )

            # The generator wants its samples scored as real (label 1).
            d_fake = discriminator(generator(noise))
            g_loss = paddle.mean(
                paddle.nn.functional.binary_cross_entropy_with_logits(
                    logit=d_fake,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=1.0
                    ),
                )
            )

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(g_loss)

        exe = fluid.Executor(
            fluid.CPUPlace()
            if not core.is_compiled_with_cuda()
            else fluid.CUDAPlace(0)
        )
        static_params = dict()
        with fluid.scope_guard(scope):
            img = np.ones([2, 1], np.float32)
            noise = np.ones([2, 2], np.float32)
            exe.run(startup)
            static_d_loss = exe.run(
                discriminate_p,
                feed={'img': img, 'noise': noise},
                fetch_list=[d_loss],
            )[0]
            static_g_loss = exe.run(
                generate_p, feed={'noise': noise}, fetch_list=[g_loss]
            )[0]

            # generate_p contains all parameters needed.
            for param in generate_p.global_block().all_parameters():
                static_params[param.name] = np.array(
                    scope.find_var(param.name).get_tensor()
                )
        return static_d_loss, static_g_loss, static_params

    def _run_dygraph(self, sort_sum_gradient=False):
        """Run the same discriminator and generator SGD steps in dygraph mode.

        Args:
            sort_sum_gradient (bool): If True, enable
                ``FLAGS_sort_sum_gradient`` for this run. The flag's previous
                value is restored afterwards so it cannot leak into later
                runs or other tests.

        Returns:
            tuple: ``(d_loss, g_loss, params)`` as numpy values. ``params``
            maps every discriminator and generator parameter name to its
            value after both steps.
        """
        saved_flags = fluid.get_flags(['FLAGS_sort_sum_gradient'])
        try:
            with fluid.dygraph.guard():
                if sort_sum_gradient:
                    fluid.set_flags({'FLAGS_sort_sum_gradient': True})
                # Re-seed so every run initializes identical weights.
                paddle.seed(1)
                paddle.framework.random._manual_program_seed(1)

                discriminator = Discriminator()
                generator = Generator()
                sgd = SGDOptimizer(
                    learning_rate=1e-3,
                    parameter_list=(
                        discriminator.parameters() + generator.parameters()
                    ),
                )

                # Discriminator step: real labelled 1, generated labelled 0.
                d_real = discriminator(to_variable(np.ones([2, 1], np.float32)))
                d_loss_real = paddle.mean(
                    paddle.nn.functional.binary_cross_entropy_with_logits(
                        logit=d_real,
                        label=to_variable(np.ones([2, 1], np.float32)),
                    )
                )
                d_fake = discriminator(
                    generator(to_variable(np.ones([2, 2], np.float32)))
                )
                d_loss_fake = paddle.mean(
                    paddle.nn.functional.binary_cross_entropy_with_logits(
                        logit=d_fake,
                        label=to_variable(np.zeros([2, 1], np.float32)),
                    )
                )
                d_loss = d_loss_real + d_loss_fake
                d_loss.backward()
                sgd.minimize(d_loss)
                discriminator.clear_gradients()
                generator.clear_gradients()

                # Generator step: generated samples labelled 1.
                d_fake = discriminator(
                    generator(to_variable(np.ones([2, 2], np.float32)))
                )
                g_loss = paddle.mean(
                    paddle.nn.functional.binary_cross_entropy_with_logits(
                        logit=d_fake,
                        label=to_variable(np.ones([2, 1], np.float32)),
                    )
                )
                g_loss.backward()
                sgd.minimize(g_loss)

                # Collect from THIS run's layers. The previous copy-pasted
                # version read the first run's generator by mistake.
                params = {
                    p.name: p.numpy()
                    for p in discriminator.parameters() + generator.parameters()
                }
                return d_loss.numpy(), g_loss.numpy(), params
        finally:
            fluid.set_flags(saved_flags)

    def func_test_gan_float32(self):
        """Assert dygraph losses/parameters match the static-graph run."""
        static_d_loss, static_g_loss, static_params = self._run_static()
        # First run: FLAGS_sort_sum_gradient left at its current value.
        # Second run: the flag enabled.
        dy_results = [
            self._run_dygraph(),
            self._run_dygraph(sort_sum_gradient=True),
        ]
        for dy_d_loss, dy_g_loss, dy_params in dy_results:
            self.assertEqual(dy_g_loss, static_g_loss)
            self.assertEqual(dy_d_loss, static_d_loss)
            for k, v in dy_params.items():
                np.testing.assert_allclose(v, static_params[k], rtol=1e-05)

    def test_gan_float32(self):
        # Exercise both the eager and the legacy dygraph code paths.
        with _test_eager_guard():
            self.func_test_gan_float32()
        self.func_test_gan_float32()


if __name__ == '__main__':
    # Allow running this test file directly as a script.
    unittest.main()