#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import TracedLayer
from test_imperative_base import new_program_scope
import numpy as np
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph


class SimpleLSTMRNN(fluid.Layer):
    """Stacked LSTM unrolled for a fixed number of time steps.

    Every layer owns one fused gate weight of shape
    ``[2 * hidden_size, 4 * hidden_size]`` (applied to the concatenation of
    the step input and the previous hidden state) and one gate bias of shape
    ``[4 * hidden_size]``. ``forward`` returns the top-layer output of every
    step together with the final hidden and cell states of every layer.
    """

    def __init__(
        self, hidden_size, num_steps, num_layers=2, init_scale=0.1, dropout=None
    ):
        super().__init__()
        # Hyper-parameters.
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._init_scale = init_scale
        self._dropout = dropout
        self._num_steps = num_steps

        # Working state that forward() overwrites on every call.
        self._input = None
        self.cell_array = []
        self.hidden_array = []

        self._create_parameter()

    def _make_uniform_initializer(self):
        """Return a new uniform initializer over ``[-init_scale, init_scale]``."""
        scale = self._init_scale
        return fluid.initializer.UniformInitializer(low=-scale, high=scale)

    def _create_parameter(self):
        """Create and register per-layer gate weights ``w_<k>`` and biases ``b_<k>``."""
        self.weight_1_arr = []
        self.weight_2_arr = []  # never filled or read in this file
        self.bias_arr = []
        self.mask_array = []  # never filled or read in this file

        for layer in range(self._num_layers):
            # Weight and bias are created in interleaved order per layer.
            weight = self.create_parameter(
                attr=fluid.ParamAttr(
                    initializer=self._make_uniform_initializer()
                ),
                shape=[self._hidden_size * 2, self._hidden_size * 4],
                dtype="float32",
                default_initializer=self._make_uniform_initializer(),
            )
            self.weight_1_arr.append(self.add_parameter('w_%d' % layer, weight))

            bias = self.create_parameter(
                attr=fluid.ParamAttr(
                    initializer=self._make_uniform_initializer()
                ),
                shape=[self._hidden_size * 4],
                dtype="float32",
                default_initializer=fluid.initializer.Constant(0.0),
            )
            self.bias_arr.append(self.add_parameter('b_%d' % layer, bias))

    def _lstm_cell(self, step_input, pre_hidden, pre_cell, weight, bias):
        """Apply one LSTM cell update and return ``(new_hidden, new_cell)``."""
        gates = fluid.layers.matmul(
            x=fluid.layers.concat([step_input, pre_hidden], 1), y=weight
        )
        gates = fluid.layers.elementwise_add(gates, bias)
        # Fused projection -> input gate, candidate, forget gate, output gate.
        in_gate, candidate, forget_gate, out_gate = fluid.layers.split(
            gates, num_or_sections=4, dim=-1
        )
        sigmoid = paddle.nn.functional.sigmoid
        new_cell = pre_cell * sigmoid(forget_gate) + sigmoid(
            in_gate
        ) * paddle.tanh(candidate)
        new_hidden = paddle.tanh(new_cell) * sigmoid(out_gate)
        return new_hidden, new_cell

    def _stack_layer_states(self, states):
        """Stack per-layer ``[-1, hidden]`` states into ``[num_layers, -1, hidden]``."""
        stacked = fluid.layers.concat(states, 1)
        stacked = fluid.layers.reshape(
            stacked, shape=[-1, self._num_layers, self._hidden_size]
        )
        return fluid.layers.transpose(x=stacked, perm=[1, 0, 2])

    def forward(self, input_embedding, init_hidden=None, init_cell=None):
        """Unroll the LSTM over ``self._num_steps`` steps.

        Args:
            input_embedding: step inputs; sliced along axis 1 (one slice per
                step) and each slice reshaped to ``[-1, hidden_size]``.
            init_hidden: initial hidden states, sliced along axis 0 (one slice
                per layer). Required despite the ``None`` default, since it is
                passed straight to ``fluid.layers.slice``.
            init_cell: initial cell states, handled like ``init_hidden``.

        Returns:
            ``(real_res, last_hidden, last_cell)``. ``real_res`` holds the
            top-layer output of every step as ``[-1, num_steps, hidden_size]``.
            Both states are ``[num_layers, -1, hidden_size]``.
        """
        self.cell_array = []
        self.hidden_array = []

        for layer in range(self._num_layers):
            hidden_slice = fluid.layers.slice(
                init_hidden, axes=[0], starts=[layer], ends=[layer + 1]
            )
            cell_slice = fluid.layers.slice(
                init_cell, axes=[0], starts=[layer], ends=[layer + 1]
            )
            self.hidden_array.append(
                fluid.layers.reshape(hidden_slice, shape=[-1, self._hidden_size])
            )
            self.cell_array.append(
                fluid.layers.reshape(cell_slice, shape=[-1, self._hidden_size])
            )

        use_dropout = self._dropout is not None and self._dropout > 0.0

        step_outputs = []
        for step in range(self._num_steps):
            self._input = fluid.layers.slice(
                input_embedding, axes=[1], starts=[step], ends=[step + 1]
            )
            self._input = fluid.layers.reshape(
                self._input, shape=[-1, self._hidden_size]
            )
            for layer in range(self._num_layers):
                new_hidden, new_cell = self._lstm_cell(
                    self._input,
                    self.hidden_array[layer],
                    self.cell_array[layer],
                    self.weight_1_arr[layer],
                    self.bias_arr[layer],
                )
                self.hidden_array[layer] = new_hidden
                self.cell_array[layer] = new_cell
                # This layer's output feeds the next layer.
                self._input = new_hidden
                if use_dropout:
                    self._input = fluid.layers.dropout(
                        self._input,
                        dropout_prob=self._dropout,
                        dropout_implementation='upscale_in_train',
                    )
            step_outputs.append(
                fluid.layers.reshape(
                    self._input, shape=[1, -1, self._hidden_size]
                )
            )

        real_res = fluid.layers.transpose(
            x=fluid.layers.concat(step_outputs, 0), perm=[1, 0, 2]
        )
        last_hidden = self._stack_layer_states(self.hidden_array)
        last_cell = self._stack_layer_states(self.cell_array)
        return real_res, last_hidden, last_cell


class PtbModel(fluid.Layer):
    """PTB language model: embedding -> SimpleLSTMRNN -> softmax projection.

    ``forward`` returns a loss (per-step softmax cross entropy averaged over
    the leading dimension and summed over the ``num_steps`` steps) together
    with the final hidden and cell states produced by ``SimpleLSTMRNN``.
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        num_layers=2,
        num_steps=20,
        init_scale=0.1,
        is_sparse=False,
        dropout=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.init_scale = init_scale
        self.num_layers = num_layers
        self.num_steps = num_steps
        self.dropout = dropout
        self.simple_lstm_rnn = SimpleLSTMRNN(
            hidden_size,
            num_steps,
            num_layers=num_layers,
            init_scale=init_scale,
            dropout=dropout,
        )
        self.embedding = Embedding(
            size=[vocab_size, hidden_size],
            dtype='float32',
            is_sparse=is_sparse,
            param_attr=fluid.ParamAttr(
                name='embedding_para',
                initializer=fluid.initializer.UniformInitializer(
                    low=-init_scale, high=init_scale
                ),
            ),
        )
        # Output projection from hidden_size to vocab_size logits.
        self.softmax_weight = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.hidden_size, self.vocab_size],
            dtype="float32",
            default_initializer=fluid.initializer.UniformInitializer(
                low=-self.init_scale, high=self.init_scale
            ),
        )
        self.softmax_bias = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.vocab_size],
            dtype="float32",
            default_initializer=fluid.initializer.UniformInitializer(
                low=-self.init_scale, high=self.init_scale
            ),
        )

    def forward(self, input, label, init_hidden, init_cell):
        """Compute the loss and final RNN states for one batch.

        Args:
            input: ids looked up in ``self.embedding``. The embeddings are then
                reshaped to ``[-1, num_steps, hidden_size]``. The test in this
                file feeds a ``(4, 3)`` int64 array.
            label: hard class labels for ``softmax_with_cross_entropy``
                (``soft_label=False``), one per projected row.
            init_hidden: initial hidden state, reshaped to
                ``[num_layers, -1, hidden_size]``.
            init_cell: initial cell state, reshaped like ``init_hidden``.

        Returns:
            ``(loss, last_hidden, last_cell)``.
        """
        init_h = fluid.layers.reshape(
            init_hidden, shape=[self.num_layers, -1, self.hidden_size]
        )

        init_c = fluid.layers.reshape(
            init_cell, shape=[self.num_layers, -1, self.hidden_size]
        )

        x_emb = self.embedding(input)
        x_emb = fluid.layers.reshape(
            x_emb, shape=[-1, self.num_steps, self.hidden_size]
        )
        if self.dropout is not None and self.dropout > 0.0:
            x_emb = fluid.layers.dropout(
                x_emb,
                # Was ``self.drop_out``, an attribute that is never set, so
                # any positive dropout raised AttributeError here.
                dropout_prob=self.dropout,
                dropout_implementation='upscale_in_train',
            )
        rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(
            x_emb, init_h, init_c
        )
        rnn_out = fluid.layers.reshape(
            rnn_out, shape=[-1, self.num_steps, self.hidden_size]
        )
        projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
        projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
        projection = fluid.layers.reshape(
            projection, shape=[-1, self.vocab_size]
        )
        loss = fluid.layers.softmax_with_cross_entropy(
            logits=projection, label=label, soft_label=False
        )
        # Average over the leading dimension, then sum over the time steps.
        loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
        loss = fluid.layers.reduce_mean(loss, dim=[0])
        loss = fluid.layers.reduce_sum(loss)

        return loss, last_hidden, last_cell


class TestDygraphPtbRnn(unittest.TestCase):
    """Trains PtbModel in dygraph and static mode and requires identical results."""

    def setUp(self):
        import tempfile

        # TracedLayer.save_inference_model() writes files. Keep them in a
        # per-test temporary directory rather than the current working
        # directory, so runs leave no artifacts behind and cannot collide.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def func_test_ptb_rnn(self):
        for is_sparse in [True, False]:
            self.ptb_rnn_cpu_float32(is_sparse)

    def test_ptb_rnn(self):
        # Run once under the eager-mode guard and once without it.
        with _test_eager_guard():
            self.func_test_ptb_rnn()
        self.func_test_ptb_rnn()

    def ptb_rnn_cpu_float32(self, is_sparse):
        """Train PtbModel for ``batch_num`` steps in dygraph and static mode.

        Compares the last loss, the final hidden/cell states, the initial
        parameters and the parameters after the last step for exact equality.
        Despite the name, the static executor runs on CUDAPlace(0) when Paddle
        is compiled with CUDA.
        """
        import os

        seed = 90
        hidden_size = 10
        vocab_size = 1000
        num_layers = 1
        num_steps = 3
        init_scale = 0.1
        batch_size = 4
        batch_num = 200
        save_path = os.path.join(
            self.temp_dir.name, 'infe_imperative_ptb_rnn'
        )

        # Both runs feed the same fixed batch on every step, so build it once.
        x_data = (
            np.arange(batch_size * num_steps)
            .reshape(batch_size, num_steps)
            .astype('int64')
        )
        y_data = (
            np.arange(1, batch_size * num_steps + 1)
            .reshape((-1, 1))
            .astype('int64')
        )
        init_hidden_data = np.zeros(
            (num_layers, batch_size, hidden_size), dtype='float32'
        )
        init_cell_data = np.zeros(
            (num_layers, batch_size, hidden_size), dtype='float32'
        )

        traced_layer = None

        with fluid.dygraph.guard():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)
            # TODO: marsyang1993 Change seed to
            ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale,
                is_sparse=is_sparse,
            )

            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=ptb_model.parameters()
            )
            dy_param_updated = dict()
            dy_param_init = dict()
            dy_loss = None
            last_hidden = None
            last_cell = None

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None

            for i in range(batch_num):
                x = to_variable(x_data)
                y = to_variable(y_data)
                init_hidden = to_variable(init_hidden_data)
                init_cell = to_variable(init_cell_data)
                if i % 5 == 0 and _in_legacy_dygraph():
                    # Every 5th step (legacy dygraph only): trace the model,
                    # compare traced outputs via helper.assertEachVar, require
                    # the same program as the previous trace, and exercise
                    # saving it as an inference model.
                    outs, traced_layer = TracedLayer.trace(
                        ptb_model, [x, y, init_hidden, init_cell]
                    )
                    outs_static = traced_layer([x, y, init_hidden, init_cell])
                    helper.assertEachVar(outs, outs_static)

                    if program is not None:
                        self.assertTrue(
                            is_equal_program(traced_layer.program, program)
                        )

                    program = traced_layer.program

                    traced_layer.save_inference_model(
                        save_path, feed=list(range(4))
                    )
                else:
                    outs = ptb_model(x, y, init_hidden, init_cell)

                dy_loss, last_hidden, last_cell = outs

                if i == 0:
                    for param in ptb_model.parameters():
                        dy_param_init[param.name] = param.numpy()
                dy_loss.backward()
                sgd.minimize(dy_loss)
                ptb_model.clear_gradients()
                if i == batch_num - 1:
                    for param in ptb_model.parameters():
                        dy_param_updated[param.name] = param.numpy()

            dy_loss_value = dy_loss.numpy()
            dy_last_cell_value = last_cell.numpy()
            dy_last_hidden_value = last_hidden.numpy()

        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)
            ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale,
                is_sparse=is_sparse,
            )

            exe = fluid.Executor(
                fluid.CPUPlace()
                if not core.is_compiled_with_cuda()
                else fluid.CUDAPlace(0)
            )
            sgd = SGDOptimizer(learning_rate=1e-3)
            x = fluid.layers.data(
                name="x", shape=[-1, num_steps], dtype='int64'
            )
            # NOTE(review): declared float32 while the fed labels (y_data) are
            # int64; left unchanged - confirm this is intended.
            y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
            init_hidden = fluid.layers.data(
                name="init_hidden", shape=[1], dtype='float32'
            )
            init_cell = fluid.layers.data(
                name="init_cell", shape=[1], dtype='float32'
            )

            static_loss, static_last_hidden, static_last_cell = ptb_model(
                x, y, init_hidden, init_cell
            )
            sgd.minimize(static_loss)
            static_param_updated = dict()
            static_param_init = dict()
            static_param_name_list = [
                param.name for param in ptb_model.parameters()
            ]

            out = exe.run(
                framework.default_startup_program(),
                fetch_list=static_param_name_list,
            )
            static_param_init.update(zip(static_param_name_list, out))

            static_loss_value = None
            static_last_cell_value = None
            static_last_hidden_value = None
            static_x_data = x_data.reshape((-1, num_steps, 1))
            # Fetch loss and final states first, then every parameter.
            fetch_list = [static_loss, static_last_hidden, static_last_cell]
            fetch_list.extend(static_param_name_list)
            for i in range(batch_num):
                out = exe.run(
                    fluid.default_main_program(),
                    feed={
                        "x": static_x_data,
                        "y": y_data,
                        "init_hidden": init_hidden_data,
                        "init_cell": init_cell_data,
                    },
                    fetch_list=fetch_list,
                )
                static_loss_value = out[0]
                static_last_hidden_value = out[1]
                static_last_cell_value = out[2]

                if i == batch_num - 1:
                    static_param_updated.update(
                        zip(static_param_name_list, out[3:])
                    )

        np.testing.assert_array_equal(static_loss_value, dy_loss_value)
        np.testing.assert_array_equal(
            static_last_cell_value, dy_last_cell_value
        )
        np.testing.assert_array_equal(
            static_last_hidden_value, dy_last_hidden_value
        )
        for key, value in static_param_init.items():
            np.testing.assert_array_equal(value, dy_param_init[key])
        for key, value in static_param_updated.items():
            np.testing.assert_array_equal(value, dy_param_updated[key])


if __name__ == "__main__":
    # Load and run the TestCase classes of this module when run as a script.
    unittest.main()