test_imperative_gan.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import unittest
import numpy as np
import six
import sys

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid import Conv2D, Pool2D, Linear
from test_imperative_base import new_program_scope
from paddle.fluid.dygraph.base import to_variable


class Discriminator(fluid.Layer):
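    # Tiny MLP discriminator: Linear(1, 32) with ELU activation, then Linear(32, 1)
    # producing a single logit per sample.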
    def __init__(self):
        super(Discriminator, self).__init__()
        self._fc1 = Linear(1, 32, act='elu')
        self._fc2 = Linear(32, 1)

    def forward(self, inputs):
        x = self._fc1(inputs)
        x = self._fc2(x)
        return x


class Generator(fluid.Layer):
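    # Tiny MLP generator: maps 2-d noise through 2 -> 64 -> 64 -> 1, with ELU
    # activations on the two hidden layers.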
    def __init__(self):
        super(Generator, self).__init__()
        self._fc1 = Linear(2, 64, act='elu')
        self._fc2 = Linear(64, 64, act='elu')
        self._fc3 = Linear(64, 1)

    def forward(self, inputs):
        x = self._fc1(inputs)
        x = self._fc2(x)
        x = self._fc3(x)
        return x


class TestDygraphGAN(unittest.TestCase):
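    # Runs one training step of the same tiny GAN in static graph mode and in
    # dygraph (imperative) mode, then checks that losses and parameters match.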
    def test_gan_float32(self):
        seed = 90
        paddle.manual_seed(1)
        paddle.framework.random._manual_program_seed(1)
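        # Static graph setup: build separate programs for the discriminator update
        # and the generator update, sharing one startup program and one scope.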
        startup = fluid.Program()
        discriminate_p = fluid.Program()
        generate_p = fluid.Program()

        scope = fluid.core.Scope()
        with new_program_scope(
                main=discriminate_p, startup=startup, scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            img = fluid.layers.data(
                name="img", shape=[2, 1], append_batch_size=False)
            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False)

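            # Discriminator loss: real inputs (img) are labeled 1, generated samples are labeled 0.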
            d_real = discriminator(img)
            d_loss_real = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_real,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=1.0)))

            d_fake = discriminator(generator(noise))
            d_loss_fake = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=0.0)))

            d_loss = d_loss_real + d_loss_fake

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(d_loss)

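        # Generator program: the generator is updated so that the discriminator
        # assigns label 1 (real) to its fake samples.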
        with new_program_scope(main=generate_p, startup=startup, scope=scope):
            discriminator = Discriminator()
            generator = Generator()

            noise = fluid.layers.data(
                name="noise", shape=[2, 2], append_batch_size=False)

            d_fake = discriminator(generator(noise))
            g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake,
                    label=fluid.layers.fill_constant(
                        shape=[2, 1], dtype='float32', value=1.0)))

            sgd = SGDOptimizer(learning_rate=1e-3)
            sgd.minimize(g_loss)

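        # Run the startup program, then one update of each program on constant
        # all-ones inputs, and record the losses and parameter values for comparison.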
        exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda()
                             else fluid.CUDAPlace(0))
        static_params = dict()
        with fluid.scope_guard(scope):
            img = np.ones([2, 1], np.float32)
            noise = np.ones([2, 2], np.float32)
            exe.run(startup)
            static_d_loss = exe.run(discriminate_p,
                                    feed={'img': img,
                                          'noise': noise},
                                    fetch_list=[d_loss])[0]
            static_g_loss = exe.run(generate_p,
                                    feed={'noise': noise},
                                    fetch_list=[g_loss])[0]

            # generate_p contains all parameters needed.
            for param in generate_p.global_block().all_parameters():
                static_params[param.name] = np.array(
                    scope.find_var(param.name).get_tensor())

        dy_params = dict()
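        # Repeat the same single training step in dygraph (imperative) mode with
        # identical seeds and the same constant inputs.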
        with fluid.dygraph.guard():
            paddle.manual_seed(1)
            paddle.framework.random._manual_program_seed(1)

            discriminator = Discriminator()
            generator = Generator()
            sgd = SGDOptimizer(
                learning_rate=1e-3,
                parameter_list=(
                    discriminator.parameters() + generator.parameters()))

            d_real = discriminator(to_variable(np.ones([2, 1], np.float32)))
            d_loss_real = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_real, label=to_variable(np.ones([2, 1], np.float32))))

            d_fake = discriminator(
                generator(to_variable(np.ones([2, 2], np.float32))))
            d_loss_fake = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake, label=to_variable(np.zeros([2, 1], np.float32))))

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            sgd.minimize(d_loss)
            discriminator.clear_gradients()
            generator.clear_gradients()

            d_fake = discriminator(
                generator(to_variable(np.ones([2, 2], np.float32))))
            g_loss = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake, label=to_variable(np.ones([2, 1], np.float32))))
            g_loss.backward()
            sgd.minimize(g_loss)
            for p in discriminator.parameters():
                dy_params[p.name] = p.numpy()
            for p in generator.parameters():
                dy_params[p.name] = p.numpy()

            dy_g_loss = g_loss.numpy()
            dy_d_loss = d_loss.numpy()

        dy_params2 = dict()
        with fluid.dygraph.guard():
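            # Second dygraph run: the same step again with FLAGS_sort_sum_gradient
            # enabled, exercising the sorted gradient-accumulation path in backward.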
            fluid.set_flags({'FLAGS_sort_sum_gradient': True})
            paddle.manual_seed(1)
            paddle.framework.random._manual_program_seed(1)
            discriminator2 = Discriminator()
            generator2 = Generator()
            sgd2 = SGDOptimizer(
                learning_rate=1e-3,
                parameter_list=(
                    discriminator2.parameters() + generator2.parameters()))

            d_real2 = discriminator2(to_variable(np.ones([2, 1], np.float32)))
            d_loss_real2 = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_real2, label=to_variable(np.ones([2, 1], np.float32))))

            d_fake2 = discriminator2(
                generator2(to_variable(np.ones([2, 2], np.float32))))
            d_loss_fake2 = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake2, label=to_variable(np.zeros([2, 1], np.float32))))

            d_loss2 = d_loss_real2 + d_loss_fake2
            d_loss2.backward()
            sgd2.minimize(d_loss2)
            discriminator2.clear_gradients()
            generator2.clear_gradients()

            d_fake2 = discriminator2(
                generator2(to_variable(np.ones([2, 2], np.float32))))
            g_loss2 = fluid.layers.reduce_mean(
                fluid.layers.sigmoid_cross_entropy_with_logits(
                    x=d_fake2, label=to_variable(np.ones([2, 1], np.float32))))
            g_loss2.backward()
            sgd2.minimize(g_loss2)
            for p in discriminator2.parameters():
                dy_params2[p.name] = p.numpy()
            for p in generator2.parameters():
                dy_params2[p.name] = p.numpy()

            dy_g_loss2 = g_loss2.numpy()
            dy_d_loss2 = d_loss2.numpy()

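        # Both dygraph runs must reproduce the static graph losses exactly and the
        # parameters within np.allclose tolerance.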
        self.assertEqual(dy_g_loss, static_g_loss)
        self.assertEqual(dy_d_loss, static_d_loss)
        for k, v in six.iteritems(dy_params):
            self.assertTrue(np.allclose(v, static_params[k]))

        self.assertEqual(dy_g_loss2, static_g_loss)
        self.assertEqual(dy_d_loss2, static_d_loss)
        for k, v in six.iteritems(dy_params2):
            self.assertTrue(np.allclose(v, static_params[k]))


if __name__ == '__main__':
    unittest.main()