#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import TracedLayer
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph

from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program


class SimpleLSTMRNN(fluid.Layer):
    """A stacked LSTM unrolled step by step over a fixed number of steps.

    Every layer ``k`` owns a fused gate weight ``w_k`` of shape
    ``[2 * hidden_size, 4 * hidden_size]`` and a gate bias ``b_k`` of shape
    ``[4 * hidden_size]``. The weight is applied to the concatenation of the
    step input and the layer's previous hidden state.

    Args:
        hidden_size: width of the hidden and cell states. Each per-step input
            is also reshaped to ``[-1, hidden_size]``.
        num_steps: number of time steps unrolled by ``forward``.
        num_layers: number of stacked LSTM layers.
        init_scale: bound of the ``UniformInitializer`` set in each
            parameter's ``ParamAttr``.
        dropout: if not None and > 0, dropout probability applied to every
            layer output (``upscale_in_train``).
    """

    def __init__(self,
                 hidden_size,
                 num_steps,
                 num_layers=2,
                 init_scale=0.1,
                 dropout=None):
        super(SimpleLSTMRNN, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._init_scale = init_scale
        self._dropout = dropout
        self._input = None
        self._num_steps = num_steps
        # Per-layer running states, (re)initialised on every forward call.
        self.cell_array = []
        self.hidden_array = []
        self._create_parameter()

    def _create_parameter(self):
        """Create and register the per-layer gate weights and biases."""
        self.weight_1_arr = []
        # weight_2_arr and mask_array are never populated in this file; they
        # are kept only so existing attribute accesses keep working.
        self.weight_2_arr = []
        self.bias_arr = []
        self.mask_array = []

        for i in range(self._num_layers):
            weight_1 = self.create_parameter(
                attr=fluid.ParamAttr(
                    initializer=fluid.initializer.UniformInitializer(
                        low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size * 2, self._hidden_size * 4],
                dtype="float32",
                default_initializer=fluid.initializer.UniformInitializer(
                    low=-self._init_scale, high=self._init_scale))
            self.weight_1_arr.append(self.add_parameter('w_%d' % i, weight_1))
            # NOTE(review): the ParamAttr also sets an initializer, so the
            # Constant(0.0) default is presumably not used; confirm against
            # Paddle's create_parameter semantics.
            bias_1 = self.create_parameter(
                attr=fluid.ParamAttr(
                    initializer=fluid.initializer.UniformInitializer(
                        low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size * 4],
                dtype="float32",
                default_initializer=fluid.initializer.Constant(0.0))
            self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1))

    def forward(self, input_embedding, init_hidden=None, init_cell=None):
        """Run the unrolled LSTM.

        Args:
            input_embedding: step inputs. The tensor is sliced along axis 1
                into ``num_steps`` pieces, and each piece is reshaped to
                ``[-1, hidden_size]`` (presumably the input is
                ``[batch, num_steps, hidden_size]``).
            init_hidden: initial hidden states. Slice ``i`` along axis 0 is
                the state of layer ``i``.
            init_cell: initial cell states, laid out like ``init_hidden``.

        Returns:
            ``(real_res, last_hidden, last_cell)``:
            - ``real_res`` stacks the top-layer output of every step as
              ``[batch, num_steps, hidden_size]``.
            - ``last_hidden`` and ``last_cell`` are the final states as
              ``[num_layers, batch, hidden_size]``.

        Raises:
            ValueError: if ``init_hidden`` or ``init_cell`` is None. The
                defaults exist for signature compatibility only.
        """
        if init_hidden is None or init_cell is None:
            raise ValueError(
                "SimpleLSTMRNN.forward requires both init_hidden and init_cell")

        self.cell_array = []
        self.hidden_array = []

        for layer in range(self._num_layers):
            pre_hidden = fluid.layers.slice(init_hidden,
                                            axes=[0],
                                            starts=[layer],
                                            ends=[layer + 1])
            pre_cell = fluid.layers.slice(init_cell,
                                          axes=[0],
                                          starts=[layer],
                                          ends=[layer + 1])
            pre_hidden = fluid.layers.reshape(pre_hidden,
                                              shape=[-1, self._hidden_size])
            pre_cell = fluid.layers.reshape(pre_cell,
                                            shape=[-1, self._hidden_size])
            self.hidden_array.append(pre_hidden)
            self.cell_array.append(pre_cell)

        res = []
        for index in range(self._num_steps):
            self._input = fluid.layers.slice(input_embedding,
                                             axes=[1],
                                             starts=[index],
                                             ends=[index + 1])
            self._input = fluid.layers.reshape(self._input,
                                               shape=[-1, self._hidden_size])
            for k in range(self._num_layers):
                pre_hidden = self.hidden_array[k]
                pre_cell = self.cell_array[k]
                weight_1 = self.weight_1_arr[k]
                bias = self.bias_arr[k]

                nn = fluid.layers.concat([self._input, pre_hidden], 1)
                gate_input = fluid.layers.matmul(x=nn, y=weight_1)

                gate_input = fluid.layers.elementwise_add(gate_input, bias)
                # Fused projection split into input (i), candidate (j),
                # forget (f) and output (o) gate pre-activations.
                gate_i, gate_j, gate_f, gate_o = fluid.layers.split(
                    gate_input, num_or_sections=4, dim=-1)
                c = pre_cell * fluid.layers.sigmoid(
                    gate_f) + fluid.layers.sigmoid(gate_i) * fluid.layers.tanh(
                        gate_j)
                m = fluid.layers.tanh(c) * fluid.layers.sigmoid(gate_o)
                self.hidden_array[k] = m
                self.cell_array[k] = c
                # This layer's output is the next layer's input.
                self._input = m

                if self._dropout is not None and self._dropout > 0.0:
                    self._input = fluid.layers.dropout(
                        self._input,
                        dropout_prob=self._dropout,
                        dropout_implementation='upscale_in_train')
            res.append(
                fluid.layers.reshape(self._input,
                                     shape=[1, -1, self._hidden_size]))
        # [num_steps, batch, hidden] -> [batch, num_steps, hidden]
        real_res = fluid.layers.concat(res, 0)
        real_res = fluid.layers.transpose(x=real_res, perm=[1, 0, 2])
        # [batch, num_layers * hidden] -> [num_layers, batch, hidden]
        last_hidden = fluid.layers.concat(self.hidden_array, 1)
        last_hidden = fluid.layers.reshape(
            last_hidden, shape=[-1, self._num_layers, self._hidden_size])
        last_hidden = fluid.layers.transpose(x=last_hidden, perm=[1, 0, 2])
        last_cell = fluid.layers.concat(self.cell_array, 1)
        last_cell = fluid.layers.reshape(
            last_cell, shape=[-1, self._num_layers, self._hidden_size])
        last_cell = fluid.layers.transpose(x=last_cell, perm=[1, 0, 2])
        return real_res, last_hidden, last_cell


class PtbModel(fluid.Layer):
    """PTB language model: embedding -> stacked LSTM -> softmax projection.

    Args:
        hidden_size: embedding width and LSTM hidden size.
        vocab_size: number of tokens (embedding rows and output classes).
        num_layers: number of stacked LSTM layers.
        num_steps: sequence length the model is unrolled for.
        init_scale: bound of the uniform initializers.
        is_sparse: forwarded to ``Embedding``.
        dropout: if not None and > 0, dropout probability applied to the
            embeddings and inside the LSTM (``upscale_in_train``).
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 num_layers=2,
                 num_steps=20,
                 init_scale=0.1,
                 is_sparse=False,
                 dropout=None):
        super(PtbModel, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.init_scale = init_scale
        self.num_layers = num_layers
        self.num_steps = num_steps
        self.dropout = dropout
        self.simple_lstm_rnn = SimpleLSTMRNN(hidden_size,
                                             num_steps,
                                             num_layers=num_layers,
                                             init_scale=init_scale,
                                             dropout=dropout)
        self.embedding = Embedding(
            size=[vocab_size, hidden_size],
            dtype='float32',
            is_sparse=is_sparse,
            param_attr=fluid.ParamAttr(
                name='embedding_para',
                initializer=fluid.initializer.UniformInitializer(
                    low=-init_scale, high=init_scale)))
        self.softmax_weight = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.hidden_size, self.vocab_size],
            dtype="float32",
            default_initializer=fluid.initializer.UniformInitializer(
                low=-self.init_scale, high=self.init_scale))
        self.softmax_bias = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.vocab_size],
            dtype="float32",
            default_initializer=fluid.initializer.UniformInitializer(
                low=-self.init_scale, high=self.init_scale))

    def forward(self, input, label, init_hidden, init_cell):
        """Compute the language-model loss for one batch.

        Args:
            input: token ids fed to the embedding. The embeddings are
                reshaped to ``[-1, num_steps, hidden_size]``.
            label: hard (non-soft) target ids for
                ``softmax_with_cross_entropy``.
            init_hidden: initial hidden states, reshaped to
                ``[num_layers, -1, hidden_size]``.
            init_cell: initial cell states, reshaped like ``init_hidden``.

        Returns:
            ``(loss, last_hidden, last_cell)``. ``loss`` is the per-step
            cross entropy averaged over axis 0 (the batch) and summed over
            the steps. The states come from ``SimpleLSTMRNN``.
        """
        init_h = fluid.layers.reshape(
            init_hidden, shape=[self.num_layers, -1, self.hidden_size])

        init_c = fluid.layers.reshape(
            init_cell, shape=[self.num_layers, -1, self.hidden_size])

        x_emb = self.embedding(input)
        x_emb = fluid.layers.reshape(
            x_emb, shape=[-1, self.num_steps, self.hidden_size])
        if self.dropout is not None and self.dropout > 0.0:
            # The attribute is `dropout`; the old `self.drop_out` raised
            # AttributeError whenever dropout was enabled.
            x_emb = fluid.layers.dropout(
                x_emb,
                dropout_prob=self.dropout,
                dropout_implementation='upscale_in_train')
        rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(
            x_emb, init_h, init_c)
        rnn_out = fluid.layers.reshape(
            rnn_out, shape=[-1, self.num_steps, self.hidden_size])
        projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
        projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
        projection = fluid.layers.reshape(projection,
                                          shape=[-1, self.vocab_size])
        loss = fluid.layers.softmax_with_cross_entropy(logits=projection,
                                                       label=label,
                                                       soft_label=False)
        loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
        loss = fluid.layers.reduce_mean(loss, dim=[0])
        loss = fluid.layers.reduce_sum(loss)

        return loss, last_hidden, last_cell


class TestDygraphPtbRnn(unittest.TestCase):
    """Train PtbModel in dygraph and in static-graph mode, then compare.

    Both runs use the same seeds, hyper-parameters and fixed input batches.
    The final loss, the final LSTM states, and the initial and updated
    parameters are asserted to be exactly equal.
    """

    # Hyper-parameters shared by the dygraph and the static-graph run.
    SEED = 90
    HIDDEN_SIZE = 10
    VOCAB_SIZE = 1000
    NUM_LAYERS = 1
    NUM_STEPS = 3
    INIT_SCALE = 0.1
    BATCH_SIZE = 4
    BATCH_NUM = 200
    # In legacy dygraph, every TRACE_INTERVAL-th step runs via TracedLayer.
    TRACE_INTERVAL = 5

    def func_test_ptb_rnn(self):
        """Run the comparison with sparse and with dense embeddings."""
        for is_sparse in [True, False]:
            self.ptb_rnn_cpu_float32(is_sparse)

    def test_ptb_rnn(self):
        """Run the comparison under _test_eager_guard, then outside it."""
        with _test_eager_guard():
            self.func_test_ptb_rnn()
        self.func_test_ptb_rnn()

    def ptb_rnn_cpu_float32(self, is_sparse):
        """Assert dygraph and static-graph training of PtbModel agree exactly.

        Args:
            is_sparse: passed to PtbModel (and on to its Embedding).
        """
        (dy_loss_value, dy_last_hidden_value, dy_last_cell_value,
         dy_param_init, dy_param_updated) = self._train_dygraph(is_sparse)
        (static_loss_value, static_last_hidden_value, static_last_cell_value,
         static_param_init,
         static_param_updated) = self._train_static(is_sparse)

        np.testing.assert_array_equal(static_loss_value, dy_loss_value)
        np.testing.assert_array_equal(static_last_cell_value,
                                      dy_last_cell_value)
        np.testing.assert_array_equal(static_last_hidden_value,
                                      dy_last_hidden_value)
        for key, value in static_param_init.items():
            np.testing.assert_array_equal(value, dy_param_init[key])
        for key, value in static_param_updated.items():
            np.testing.assert_array_equal(value, dy_param_updated[key])

    def _new_model(self, is_sparse):
        """Seed Paddle's RNGs and build a fresh PtbModel from the config."""
        paddle.seed(self.SEED)
        paddle.framework.random._manual_program_seed(self.SEED)
        return PtbModel(hidden_size=self.HIDDEN_SIZE,
                        vocab_size=self.VOCAB_SIZE,
                        num_layers=self.NUM_LAYERS,
                        num_steps=self.NUM_STEPS,
                        init_scale=self.INIT_SCALE,
                        is_sparse=is_sparse)

    def _make_batch(self):
        """Return fresh numpy arrays ``(x, y, init_hidden, init_cell)``.

        - ``x`` is ``[BATCH_SIZE, NUM_STEPS]`` int64 ids ``0 .. N-1``.
        - ``y`` holds ids ``1 .. N`` as int64, shaped ``[N, 1]``.
        - The initial states are float32 zeros of shape
          ``[NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE]``.
        """
        count = self.BATCH_SIZE * self.NUM_STEPS
        x_data = np.arange(count).reshape(self.BATCH_SIZE,
                                          self.NUM_STEPS).astype('int64')
        y_data = np.arange(1, count + 1).reshape(-1, 1).astype('int64')
        state_shape = (self.NUM_LAYERS, self.BATCH_SIZE, self.HIDDEN_SIZE)
        init_hidden_data = np.zeros(state_shape, dtype='float32')
        init_cell_data = np.zeros(state_shape, dtype='float32')
        return x_data, y_data, init_hidden_data, init_cell_data

    def _train_dygraph(self, is_sparse):
        """Train in dygraph mode, periodically checking TracedLayer parity.

        Returns:
            ``(loss, last_hidden, last_cell, param_init, param_updated)``.
            The first three are numpy values from the final batch. The two
            dicts map parameter names to numpy values after the first
            forward pass and after the last update.
        """
        with fluid.dygraph.guard():
            ptb_model = self._new_model(is_sparse)
            sgd = SGDOptimizer(learning_rate=1e-3,
                               parameter_list=ptb_model.parameters())
            helper = DyGraphProgramDescTracerTestHelper(self)
            param_init = {}
            param_updated = {}
            program = None
            dy_loss = last_hidden = last_cell = None

            # The traced model is saved to a temporary directory so the test
            # does not leave artifacts in the working directory.
            with tempfile.TemporaryDirectory() as save_dir:
                save_path = os.path.join(save_dir, 'infe_imperative_ptb_rnn')
                for step in range(self.BATCH_NUM):
                    x_data, y_data, hidden_data, cell_data = self._make_batch()
                    x = to_variable(x_data)
                    y = to_variable(y_data)
                    init_hidden = to_variable(hidden_data)
                    init_cell = to_variable(cell_data)

                    if step % self.TRACE_INTERVAL == 0 and _in_legacy_dygraph():
                        outs, traced_layer = TracedLayer.trace(
                            ptb_model, [x, y, init_hidden, init_cell])
                        outs_static = traced_layer(
                            [x, y, init_hidden, init_cell])
                        helper.assertEachVar(outs, outs_static)

                        # Every trace must produce the same program.
                        if program is not None:
                            self.assertTrue(
                                is_equal_program(traced_layer.program, program))
                        program = traced_layer.program

                        traced_layer.save_inference_model(
                            save_path, feed=list(range(4)))
                    else:
                        outs = ptb_model(x, y, init_hidden, init_cell)

                    dy_loss, last_hidden, last_cell = outs

                    if step == 0:
                        param_init = {
                            param.name: param.numpy()
                            for param in ptb_model.parameters()
                        }
                    dy_loss.backward()
                    sgd.minimize(dy_loss)
                    ptb_model.clear_gradients()
                    if step == self.BATCH_NUM - 1:
                        param_updated = {
                            param.name: param.numpy()
                            for param in ptb_model.parameters()
                        }

            return (dy_loss.numpy(), last_hidden.numpy(), last_cell.numpy(),
                    param_init, param_updated)

    def _train_static(self, is_sparse):
        """Train the same model as a static program with an Executor.

        Returns:
            ``(loss, last_hidden, last_cell, param_init, param_updated)``,
            laid out like the return value of ``_train_dygraph``.
        """
        with new_program_scope():
            ptb_model = self._new_model(is_sparse)

            place = (fluid.CUDAPlace(0)
                     if core.is_compiled_with_cuda() else fluid.CPUPlace())
            exe = fluid.Executor(place)
            sgd = SGDOptimizer(learning_rate=1e-3)
            x = fluid.layers.data(name="x",
                                  shape=[-1, self.NUM_STEPS],
                                  dtype='int64')
            # NOTE(review): y is declared float32, but the fed labels are
            # int64. Kept as in the original; confirm this is intended.
            y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
            init_hidden = fluid.layers.data(name="init_hidden",
                                            shape=[1],
                                            dtype='float32')
            init_cell = fluid.layers.data(name="init_cell",
                                          shape=[1],
                                          dtype='float32')

            static_loss, static_last_hidden, static_last_cell = ptb_model(
                x, y, init_hidden, init_cell)
            sgd.minimize(static_loss)

            param_names = [param.name for param in ptb_model.parameters()]
            init_values = exe.run(framework.default_startup_program(),
                                  fetch_list=param_names)
            param_init = dict(zip(param_names, init_values))
            param_updated = {}

            # The first three fetched values are the model outputs; the rest
            # are the parameters, in param_names order.
            fetch_list = [static_loss, static_last_hidden, static_last_cell]
            fetch_list.extend(param_names)
            loss_value = last_hidden_value = last_cell_value = None
            for step in range(self.BATCH_NUM):
                x_data, y_data, hidden_data, cell_data = self._make_batch()
                out = exe.run(fluid.default_main_program(),
                              feed={
                                  "x": x_data.reshape((-1, self.NUM_STEPS, 1)),
                                  "y": y_data,
                                  "init_hidden": hidden_data,
                                  "init_cell": cell_data
                              },
                              fetch_list=fetch_list)
                loss_value, last_hidden_value, last_cell_value = out[:3]

                if step == self.BATCH_NUM - 1:
                    param_updated = dict(zip(param_names, out[3:]))

        return (loss_value, last_hidden_value, last_cell_value, param_init,
                param_updated)


if __name__ == "__main__":
    # Allow running this module directly as a unittest script.
    unittest.main()