test_imperative_gan.py 9.2 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

X
Xin Pan 已提交
17
import numpy as np
18
from test_imperative_base import new_program_scope
X
Xin Pan 已提交
19 20 21

import paddle
import paddle.fluid as fluid
M
minqiyang 已提交
22
import paddle.fluid.core as core
23
from paddle.fluid import Linear
L
lujun 已提交
24
from paddle.fluid.dygraph.base import to_variable
25
from paddle.fluid.framework import _test_eager_guard
26
from paddle.fluid.optimizer import SGDOptimizer
X
Xin Pan 已提交
27 28


29
class Discriminator(fluid.Layer):
    """Small MLP that maps a 1-feature input to a single unactivated score.

    Architecture: Linear(1 -> 32, ELU) followed by Linear(32 -> 1). The test
    feeds the output to ``sigmoid_cross_entropy_with_logits``, so it is used
    as a logit.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: the test matches parameters by name across
        # runs, and names are presumably assigned in creation order.
        self._fc1 = Linear(1, 32, act='elu')
        self._fc2 = Linear(32, 1)

    def forward(self, inputs):
        hidden = self._fc1(inputs)
        return self._fc2(hidden)
X
Xin Pan 已提交
39 40


41
class Generator(fluid.Layer):
    """Small MLP that maps a 2-feature input to a 1-feature output.

    Architecture: Linear(2 -> 64, ELU) -> Linear(64 -> 64, ELU) ->
    Linear(64 -> 1). The last layer has no activation.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: the test matches parameters by name across
        # runs, and names are presumably assigned in creation order.
        self._fc1 = Linear(2, 64, act='elu')
        self._fc2 = Linear(64, 64, act='elu')
        self._fc3 = Linear(64, 1)

    def forward(self, inputs):
        out = inputs
        for layer in (self._fc1, self._fc2, self._fc3):
            out = layer(out)
        return out
X
Xin Pan 已提交
53 54


L
lujun 已提交
55
class TestDygraphGAN(unittest.TestCase):
    """Static-graph and dygraph GAN training steps must agree.

    Each mode builds a ``Discriminator``/``Generator`` pair under the same
    seeds. It then takes one SGD step on the discriminator loss and one on
    the generator loss, using all-ones inputs. The resulting losses and
    parameters are compared. The dygraph run is done twice: once with the
    current ``FLAGS_sort_sum_gradient`` setting and once with it set to True.
    """

    def _run_static(self):
        """Build and run the two-step GAN update as static programs.

        Returns:
            tuple: ``(d_loss, g_loss, params)``. ``d_loss`` and ``g_loss`` are
            the fetched discriminator and generator losses. ``params`` maps
            each parameter name of the generator program to its value (numpy
            array) read from the scope after both steps.
        """
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)
        startup = fluid.Program()
        discriminate_p = fluid.Program()
        generate_p = fluid.Program()

        # Both programs share `startup` and `scope`. Presumably the
        # parameter names coincide across the two programs, so the generator
        # step sees the weights updated by the discriminator step; confirm
        # against new_program_scope.
        scope = fluid.core.Scope()
        with new_program_scope(
            main=discriminate_p, startup=startup, scope=scope
        ):
            discriminator = Discriminator()
            generator = Generator()

            img = fluid.layers.data(
                name="img", shape=[2, 1], append_batch_size=False
            )
            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False
            )

            d_real = discriminator(img)
            d_loss_real = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_real,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=1.0
                    ),
                )
            )

            d_fake = discriminator(generator(noise))
            d_loss_fake = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=0.0
                    ),
                )
            )

            d_loss = d_loss_real + d_loss_fake

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(d_loss)

        with new_program_scope(main=generate_p, startup=startup, scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False
            )

            d_fake = discriminator(generator(noise))
            g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=1.0
                    ),
                )
            )

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(g_loss)

        exe = fluid.Executor(
            fluid.CPUPlace()
            if not core.is_compiled_with_cuda()
            else fluid.CUDAPlace(0)
        )
        static_params = dict()
        with fluid.scope_guard(scope):
            img = np.ones([2, 1], np.float32)
            noise = np.ones([2, 2], np.float32)
            exe.run(startup)
            static_d_loss = exe.run(
                discriminate_p,
                feed={'img': img, 'noise': noise},
                fetch_list=[d_loss],
            )[0]
            static_g_loss = exe.run(
                generate_p, feed={'noise': noise}, fetch_list=[g_loss]
            )[0]

            # generate_p contains all parameters needed.
            for param in generate_p.global_block().all_parameters():
                static_params[param.name] = np.array(
                    scope.find_var(param.name).get_tensor()
                )
        return static_d_loss, static_g_loss, static_params

    def _run_dygraph(self):
        """Run the same two-step GAN update in dygraph mode.

        Uses the flag settings in effect when called.

        Returns:
            tuple: ``(d_loss, g_loss, params)``. ``d_loss`` and ``g_loss`` are
            the two losses as numpy arrays. ``params`` maps every
            discriminator and generator parameter name of THIS run's models
            to its value after both steps.
        """
        with fluid.dygraph.guard():
            paddle.seed(1)
            paddle.framework.random._manual_program_seed(1)

            discriminator = Discriminator()
            generator = Generator()
            sgd = SGDOptimizer(
                learning_rate=1e-3,
                parameter_list=(
                    discriminator.parameters() + generator.parameters()
                ),
            )

            # Discriminator step: label 1 for the all-ones input, label 0 for
            # the generator's output.
            d_real = discriminator(to_variable(np.ones([2, 1], np.float32)))
            d_loss_real = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_real, label=to_variable(np.ones([2, 1], np.float32))
                )
            )

            d_fake = discriminator(
                generator(to_variable(np.ones([2, 2], np.float32)))
            )
            d_loss_fake = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake, label=to_variable(np.zeros([2, 1], np.float32))
                )
            )

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            sgd.minimize(d_loss)
            discriminator.clear_gradients()
            generator.clear_gradients()

            # Generator step: label 1 for the generator's output.
            d_fake = discriminator(
                generator(to_variable(np.ones([2, 2], np.float32)))
            )
            g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake, label=to_variable(np.ones([2, 1], np.float32))
                )
            )
            g_loss.backward()
            sgd.minimize(g_loss)

            dy_params = dict()
            for p in discriminator.parameters():
                dy_params[p.name] = p.numpy()
            for p in generator.parameters():
                dy_params[p.name] = p.numpy()

            return d_loss.numpy(), g_loss.numpy(), dy_params

    def func_test_gan_float32(self):
        """Compare static losses/params against two dygraph runs."""
        static_d_loss, static_g_loss, static_params = self._run_static()

        dy_d_loss, dy_g_loss, dy_params = self._run_dygraph()

        # Second dygraph run with FLAGS_sort_sum_gradient enabled. The flag is
        # process-global, so restore its previous value afterwards; otherwise
        # it leaks into the next call of this method and into other tests.
        saved_flags = fluid.get_flags('FLAGS_sort_sum_gradient')
        fluid.set_flags({'FLAGS_sort_sum_gradient': True})
        try:
            dy_d_loss2, dy_g_loss2, dy_params2 = self._run_dygraph()
        finally:
            fluid.set_flags(saved_flags)

        self.assertEqual(dy_g_loss, static_g_loss)
        self.assertEqual(dy_d_loss, static_d_loss)
        for k, v in dy_params.items():
            np.testing.assert_allclose(v, static_params[k], rtol=1e-05)

        self.assertEqual(dy_g_loss2, static_g_loss)
        self.assertEqual(dy_d_loss2, static_d_loss)
        for k, v in dy_params2.items():
            np.testing.assert_allclose(v, static_params[k], rtol=1e-05)

    def test_gan_float32(self):
        """Run the comparison under the eager guard and then without it."""
        with _test_eager_guard():
            self.func_test_gan_float32()
        self.func_test_gan_float32()

X
Xin Pan 已提交
272 273 274

if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script
    # ('__main__' is unittest.main's default module argument).
    unittest.main(module='__main__')