test_imperative_gan.py 9.2 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

X
Xin Pan 已提交
17
import numpy as np
18
from test_imperative_base import new_program_scope
X
Xin Pan 已提交
19 20 21

import paddle
import paddle.fluid as fluid
M
minqiyang 已提交
22
import paddle.fluid.core as core
L
lujun 已提交
23
from paddle.fluid.dygraph.base import to_variable
24
from paddle.fluid.optimizer import SGDOptimizer
25
from paddle.nn import Linear
X
Xin Pan 已提交
26 27


28
class Discriminator(fluid.Layer):
    """Small MLP scoring inputs of feature size 1.

    Architecture: Linear(1, 32) -> ELU -> Linear(32, 1). The output is not
    passed through a sigmoid; callers in this file feed it to
    ``binary_cross_entropy_with_logits`` as a logit.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Creation order matters: it fixes the generated parameter names and
        # the sequence of random initializations under a fixed seed.
        self._fc1 = Linear(1, 32)
        self._fc2 = Linear(32, 1)

    def forward(self, inputs):
        activated = paddle.nn.functional.elu(self._fc1(inputs))
        return self._fc2(activated)
X
Xin Pan 已提交
39 40


41
class Generator(fluid.Layer):
    """Small MLP mapping inputs of feature size 2 to a single output.

    Architecture: Linear(2, 64) -> ELU -> Linear(64, 64) -> ELU ->
    Linear(64, 1), with no activation after the last layer.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Creation order matters: it fixes the generated parameter names and
        # the sequence of random initializations under a fixed seed.
        self._fc1 = Linear(2, 64)
        self._fc2 = Linear(64, 64)
        self._fc3 = Linear(64, 1)

    def forward(self, inputs):
        hidden = inputs
        # Every layer except the last is followed by an ELU activation.
        for layer in (self._fc1, self._fc2):
            hidden = paddle.nn.functional.elu(layer(hidden))
        return self._fc3(hidden)
X
Xin Pan 已提交
55 56


L
lujun 已提交
57
class TestDygraphGAN(unittest.TestCase):
    """Check that one GAN training step matches across execution modes.

    The same two-step update is run three times from the same seeds:
    as static-graph programs, in dygraph mode, and in dygraph mode with
    ``FLAGS_sort_sum_gradient`` enabled. Each run performs one SGD step on
    the discriminator loss, then one SGD step on the generator loss. The
    dygraph losses must equal the static ones and the dygraph parameters must
    match the static parameters to ``rtol=1e-05``.
    """

    @staticmethod
    def _bce_loss(logit, label):
        """Return ``paddle.mean`` of the sigmoid cross-entropy of logit/label."""
        return paddle.mean(
            paddle.nn.functional.binary_cross_entropy_with_logits(
                logit=logit, label=label
            )
        )

    def _run_static(self):
        """Build and run the discriminator and generator steps statically.

        Returns:
            ``(d_loss, g_loss, params)``: the fetched discriminator loss, the
            fetched generator loss, and a dict mapping every parameter name of
            the generator program to its value (numpy array) read from the
            scope after both steps have run.
        """
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)
        startup = fluid.Program()
        discriminate_p = fluid.Program()
        generate_p = fluid.Program()

        # Both programs are built against the same startup program and are
        # executed in the same scope below.
        scope = fluid.core.Scope()
        with new_program_scope(
            main=discriminate_p, startup=startup, scope=scope
        ):
            discriminator = Discriminator()
            generator = Generator()

            img = fluid.layers.data(
                name="img", shape=[2, 1], append_batch_size=False
            )
            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False
            )

            d_real = discriminator(img)
            d_loss_real = self._bce_loss(
                d_real,
                fluid.layers.fill_constant(
                    shape=[2, 1], dtype='float32', value=1.0
                ),
            )
            d_fake = discriminator(generator(noise))
            d_loss_fake = self._bce_loss(
                d_fake,
                fluid.layers.fill_constant(
                    shape=[2, 1], dtype='float32', value=0.0
                ),
            )
            d_loss = d_loss_real + d_loss_fake

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(d_loss)

        with new_program_scope(main=generate_p, startup=startup, scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False
            )

            d_fake = discriminator(generator(noise))
            g_loss = self._bce_loss(
                d_fake,
                fluid.layers.fill_constant(
                    shape=[2, 1], dtype='float32', value=1.0
                ),
            )

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(g_loss)

        exe = fluid.Executor(
            fluid.CPUPlace()
            if not core.is_compiled_with_cuda()
            else fluid.CUDAPlace(0)
        )
        params = {}
        with fluid.scope_guard(scope):
            img_np = np.ones([2, 1], np.float32)
            noise_np = np.ones([2, 2], np.float32)
            exe.run(startup)
            d_loss_value = exe.run(
                discriminate_p,
                feed={'img': img_np, 'noise': noise_np},
                fetch_list=[d_loss],
            )[0]
            g_loss_value = exe.run(
                generate_p, feed={'noise': noise_np}, fetch_list=[g_loss]
            )[0]

            # generate_p contains all parameters needed.
            for param in generate_p.global_block().all_parameters():
                params[param.name] = np.array(
                    scope.find_var(param.name).get_tensor()
                )
        return d_loss_value, g_loss_value, params

    def _run_dygraph(self, sort_sum_gradient=False):
        """Run the same discriminator and generator steps in dygraph mode.

        Args:
            sort_sum_gradient: when True, ``FLAGS_sort_sum_gradient`` is set
                for this run and restored to its previous value afterwards,
                so the process-global flag does not leak into other tests.

        Returns:
            ``(d_loss, g_loss, params)`` as numpy values, where ``params``
            maps the name of every discriminator and generator parameter to
            its value after both steps.
        """
        saved_flags = (
            fluid.get_flags(['FLAGS_sort_sum_gradient'])
            if sort_sum_gradient
            else None
        )
        try:
            with fluid.dygraph.guard():
                if sort_sum_gradient:
                    fluid.set_flags({'FLAGS_sort_sum_gradient': True})
                paddle.seed(1)
                paddle.framework.random._manual_program_seed(1)

                # Same creation order as the static run so parameter names
                # line up with the keys of the static parameter dict.
                discriminator = Discriminator()
                generator = Generator()
                sgd = SGDOptimizer(
                    learning_rate=1e-3,
                    parameter_list=(
                        discriminator.parameters() + generator.parameters()
                    ),
                )

                # Discriminator step: real input labelled 1, fake labelled 0.
                d_real = discriminator(
                    to_variable(np.ones([2, 1], np.float32))
                )
                d_loss_real = self._bce_loss(
                    d_real, to_variable(np.ones([2, 1], np.float32))
                )
                d_fake = discriminator(
                    generator(to_variable(np.ones([2, 2], np.float32)))
                )
                d_loss_fake = self._bce_loss(
                    d_fake, to_variable(np.zeros([2, 1], np.float32))
                )
                d_loss = d_loss_real + d_loss_fake
                d_loss.backward()
                sgd.minimize(d_loss)
                discriminator.clear_gradients()
                generator.clear_gradients()

                # Generator step: fake input labelled 1.
                d_fake = discriminator(
                    generator(to_variable(np.ones([2, 2], np.float32)))
                )
                g_loss = self._bce_loss(
                    d_fake, to_variable(np.ones([2, 1], np.float32))
                )
                g_loss.backward()
                sgd.minimize(g_loss)

                # Collect from THIS run's models (the original copy-pasted
                # second run read the first run's generator by mistake).
                params = {
                    p.name: p.numpy()
                    for p in discriminator.parameters()
                    + generator.parameters()
                }
                return d_loss.numpy(), g_loss.numpy(), params
        finally:
            if saved_flags is not None:
                fluid.set_flags(saved_flags)

    def test_gan_float32(self):
        static_d_loss, static_g_loss, static_params = self._run_static()
        dygraph_results = [
            (False, self._run_dygraph()),
            (True, self._run_dygraph(sort_sum_gradient=True)),
        ]

        for sort_sum_gradient, (dy_d_loss, dy_g_loss, dy_params) in (
            dygraph_results
        ):
            with self.subTest(sort_sum_gradient=sort_sum_gradient):
                self.assertEqual(dy_g_loss, static_g_loss)
                self.assertEqual(dy_d_loss, static_d_loss)
                for k, v in dy_params.items():
                    np.testing.assert_allclose(
                        v, static_params[k], rtol=1e-05
                    )
272

X
Xin Pan 已提交
273 274 275

if __name__ == '__main__':
    # Run the test cases defined in this file when it is executed directly.
    unittest.main(module='__main__')