test_imperative_gan.py 9.4 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
M
minqiyang 已提交
21
import paddle.fluid.core as core
X
Xin Pan 已提交
22
from paddle.fluid.optimizer import SGDOptimizer
23
from paddle.fluid import Linear
X
Xin Pan 已提交
24
from test_imperative_base import new_program_scope
L
lujun 已提交
25
from paddle.fluid.dygraph.base import to_variable
26
from paddle.fluid.framework import _test_eager_guard
X
Xin Pan 已提交
27 28


29
class Discriminator(fluid.Layer):
    """Two-layer MLP mapping a 1-feature input to a single logit.

    ``Linear(1, 32, act='elu')`` followed by ``Linear(32, 1)``; the final
    layer has no activation, so the output is a raw logit.
    """

    def __init__(self):
        super().__init__()
        self._fc1 = Linear(1, 32, act='elu')
        self._fc2 = Linear(32, 1)

    def forward(self, inputs):
        """Return the logit for ``inputs`` (last dimension must be 1)."""
        return self._fc2(self._fc1(inputs))
X
Xin Pan 已提交
40 41


42
class Generator(fluid.Layer):
    """Three-layer MLP mapping a 2-feature noise input to a 1-feature sample.

    ``Linear(2, 64, elu) -> Linear(64, 64, elu) -> Linear(64, 1)``; the last
    layer has no activation.
    """

    def __init__(self):
        super().__init__()
        self._fc1 = Linear(2, 64, act='elu')
        self._fc2 = Linear(64, 64, act='elu')
        self._fc3 = Linear(64, 1)

    def forward(self, inputs):
        """Apply the three linear layers in sequence to ``inputs``."""
        hidden = inputs
        for layer in (self._fc1, self._fc2, self._fc3):
            hidden = layer(hidden)
        return hidden
X
Xin Pan 已提交
55 56


L
lujun 已提交
57
class TestDygraphGAN(unittest.TestCase):
    """Check that a tiny GAN trains identically in static and dygraph mode.

    One discriminator update followed by one generator update is executed
    three times: as two static-graph programs sharing one scope, in dygraph
    mode, and in dygraph mode with ``FLAGS_sort_sum_gradient`` enabled.
    Both dygraph runs must reproduce the static losses exactly and the
    static parameter values to ``rtol=1e-05``.
    """

    def _run_static(self):
        """Build and execute the static-graph discriminator/generator programs.

        Returns:
            tuple: ``(d_loss, g_loss, params)``. The losses are the values
            fetched from the discriminator and generator programs; ``params``
            maps every parameter name of the generator program to its value
            after both programs have run once.
        """
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)
        startup = fluid.Program()
        discriminate_p = fluid.Program()
        generate_p = fluid.Program()

        scope = fluid.core.Scope()
        with new_program_scope(main=discriminate_p,
                               startup=startup,
                               scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            img = fluid.layers.data(name="img",
                                    shape=[2, 1],
                                    append_batch_size=False)
            noise = fluid.layers.data(name="noise",
                                      shape=[2, 2],
                                      append_batch_size=False)

            d_real = discriminator(img)
            d_loss_real = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_real,
                    label=fluid.layers.fill_constant(shape=[2, 1],
                                                     dtype='float32',
                                                     value=1.0)))

            d_fake = discriminator(generator(noise))
            d_loss_fake = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake,
                    label=fluid.layers.fill_constant(shape=[2, 1],
                                                     dtype='float32',
                                                     value=0.0)))

            d_loss = d_loss_real + d_loss_fake

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(d_loss)

        with new_program_scope(main=generate_p, startup=startup, scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            noise = fluid.layers.data(name="noise",
                                      shape=[2, 2],
                                      append_batch_size=False)

            d_fake = discriminator(generator(noise))
            g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake,
                    label=fluid.layers.fill_constant(shape=[2, 1],
                                                     dtype='float32',
                                                     value=1.0)))

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(g_loss)

        exe = fluid.Executor(fluid.CPUPlace(
        ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
        static_params = dict()
        with fluid.scope_guard(scope):
            img_np = np.ones([2, 1], np.float32)
            noise_np = np.ones([2, 2], np.float32)
            exe.run(startup)
            static_d_loss = exe.run(discriminate_p,
                                    feed={
                                        'img': img_np,
                                        'noise': noise_np
                                    },
                                    fetch_list=[d_loss])[0]
            static_g_loss = exe.run(generate_p,
                                    feed={'noise': noise_np},
                                    fetch_list=[g_loss])[0]

            # generate_p contains all parameters needed.
            for param in generate_p.global_block().all_parameters():
                static_params[param.name] = np.array(
                    scope.find_var(param.name).get_tensor())
        return static_d_loss, static_g_loss, static_params

    def _run_dygraph(self):
        """Run one discriminator step and one generator step in dygraph mode.

        Intended to be called inside ``fluid.dygraph.guard()``. Re-seeds the
        RNGs with the same seed as ``_run_static`` before creating layers.

        Returns:
            tuple: ``(d_loss, g_loss, params)`` as numpy values; ``params``
            maps every discriminator and generator parameter name to its value
            after both updates.
        """
        paddle.seed(1)
        paddle.framework.random._manual_program_seed(1)

        discriminator = Discriminator()
        generator = Generator()
        sgd = SGDOptimizer(learning_rate=1e-3,
                           parameter_list=(discriminator.parameters() +
                                           generator.parameters()))

        # Discriminator step: real samples labelled 1, generated samples 0.
        d_real = discriminator(to_variable(np.ones([2, 1], np.float32)))
        d_loss_real = fluid.layers.reduce_mean(
            fluid.layers.sigmoid_cross_entropy_with_logits(
                x=d_real, label=to_variable(np.ones([2, 1], np.float32))))

        d_fake = discriminator(
            generator(to_variable(np.ones([2, 2], np.float32))))
        d_loss_fake = fluid.layers.reduce_mean(
            fluid.layers.sigmoid_cross_entropy_with_logits(
                x=d_fake, label=to_variable(np.zeros([2, 1], np.float32))))

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        sgd.minimize(d_loss)
        # Clear so the generator step does not accumulate the D-step grads.
        discriminator.clear_gradients()
        generator.clear_gradients()

        # Generator step: generated samples labelled 1.
        d_fake = discriminator(
            generator(to_variable(np.ones([2, 2], np.float32))))
        g_loss = fluid.layers.reduce_mean(
            fluid.layers.sigmoid_cross_entropy_with_logits(
                x=d_fake, label=to_variable(np.ones([2, 1], np.float32))))
        g_loss.backward()
        sgd.minimize(g_loss)

        params = dict()
        for layer in (discriminator, generator):
            for p in layer.parameters():
                params[p.name] = p.numpy()

        return d_loss.numpy(), g_loss.numpy(), params

    def func_test_gan_float32(self):
        """Compare both dygraph runs against the static-graph reference."""
        static_d_loss, static_g_loss, static_params = self._run_static()

        with fluid.dygraph.guard():
            dy_d_loss, dy_g_loss, dy_params = self._run_dygraph()

        with fluid.dygraph.guard():
            # Same computation with sorted gradient accumulation; results must
            # not change. Restore the flag afterwards so it does not leak into
            # later runs (test_gan_float32 calls this method twice).
            saved_flags = fluid.get_flags(['FLAGS_sort_sum_gradient'])
            fluid.set_flags({'FLAGS_sort_sum_gradient': True})
            try:
                dy_d_loss2, dy_g_loss2, dy_params2 = self._run_dygraph()
            finally:
                fluid.set_flags(saved_flags)

        dygraph_runs = (
            (dy_d_loss, dy_g_loss, dy_params),
            (dy_d_loss2, dy_g_loss2, dy_params2),
        )
        for dy_d, dy_g, params in dygraph_runs:
            self.assertEqual(dy_g, static_g_loss)
            self.assertEqual(dy_d, static_d_loss)
            for name, value in params.items():
                np.testing.assert_allclose(value,
                                           static_params[name],
                                           rtol=1e-05)

    def test_gan_float32(self):
        with _test_eager_guard():
            self.func_test_gan_float32()
        self.func_test_gan_float32()

X
Xin Pan 已提交
244 245 246

if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()