# test_imperative_ptb_rnn.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import shutil
import tempfile
import unittest

import numpy as np
import six

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import TracedLayer
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode

from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program


class SimpleLSTMRNN(fluid.Layer):
    """A stack of LSTM layers unrolled for a fixed number of time steps.

    Every layer ``k`` owns one fused gate weight ``w_k`` of shape
    ``[2 * hidden_size, 4 * hidden_size]`` and one gate bias ``b_k`` of shape
    ``[4 * hidden_size]``. The weight multiplies the concatenation of the
    step input and the layer's previous hidden state.
    """

    def __init__(self,
                 hidden_size,
                 num_steps,
                 num_layers=2,
                 init_scale=0.1,
                 dropout=None):
        super(SimpleLSTMRNN, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._init_scale = init_scale
        self._dropout = dropout

        self._input = None
        self._num_steps = num_steps

        # Per-layer running states; rebuilt at the start of every forward().
        self.cell_array = []
        self.hidden_array = []
        self._create_parameter()

    def _create_parameter(self):
        """Create and register the gate weight and bias of every layer."""
        self.weight_1_arr = []
        self.weight_2_arr = []  # not read by forward()
        self.bias_arr = []
        self.mask_array = []  # not read by forward()

        scale = self._init_scale
        gate_width = self._hidden_size * 4

        def uniform():
            return fluid.initializer.UniformInitializer(low=-scale, high=scale)

        for layer in range(self._num_layers):
            weight = self.create_parameter(
                attr=fluid.ParamAttr(initializer=uniform()),
                shape=[self._hidden_size * 2, gate_width],
                dtype="float32",
                default_initializer=uniform())
            self.weight_1_arr.append(self.add_parameter('w_%d' % layer, weight))

            # NOTE(review): ``attr`` also carries a uniform initializer, which
            # presumably takes precedence over the zero default_initializer;
            # confirm whether a zero-initialized bias was intended.
            bias = self.create_parameter(
                attr=fluid.ParamAttr(initializer=uniform()),
                shape=[gate_width],
                dtype="float32",
                default_initializer=fluid.initializer.Constant(0.0))
            self.bias_arr.append(self.add_parameter('b_%d' % layer, bias))

    def _lstm_cell(self, step_input, pre_hidden, pre_cell, weight, bias):
        """Apply one LSTM cell update and return ``(new_hidden, new_cell)``."""
        gates = fluid.layers.matmul(
            x=fluid.layers.concat([step_input, pre_hidden], 1), y=weight)
        gates = fluid.layers.elementwise_add(gates, bias)
        # Along the last axis: input gate, candidate, forget gate, output gate.
        in_gate, candidate, forget_gate, out_gate = fluid.layers.split(
            gates, num_or_sections=4, dim=-1)
        kept = pre_cell * fluid.layers.sigmoid(forget_gate)
        written = fluid.layers.sigmoid(in_gate) * fluid.layers.tanh(candidate)
        new_cell = kept + written
        new_hidden = fluid.layers.tanh(new_cell) * fluid.layers.sigmoid(
            out_gate)
        return new_hidden, new_cell

    def _stack_layer_states(self, states):
        """Stack per-layer ``[N, hidden]`` states to ``[num_layers, N, hidden]``."""
        stacked = fluid.layers.concat(states, 1)
        stacked = fluid.layers.reshape(
            stacked, shape=[-1, self._num_layers, self._hidden_size])
        return fluid.layers.transpose(x=stacked, perm=[1, 0, 2])

    def forward(self, input_embedding, init_hidden=None, init_cell=None):
        """Run every layer over ``num_steps`` time steps.

        Args:
            input_embedding: step inputs.
                - Sliced along axis 1 (one slice per time step).
                - Each slice is reshaped to ``[-1, hidden_size]``.
            init_hidden, init_cell: initial states, sliced along axis 0 (one
                slice per layer). Despite the ``None`` defaults both must be
                given: they are sliced unconditionally.

        Returns:
            ``(outputs, last_hidden, last_cell)``:
                - ``outputs`` has shape ``[N, num_steps, hidden_size]``.
                - Both states have shape ``[num_layers, N, hidden_size]``.
        """
        hidden_size = self._hidden_size
        self.cell_array = []
        self.hidden_array = []

        for layer in range(self._num_layers):
            h0, c0 = [
                fluid.layers.slice(
                    state, axes=[0], starts=[layer], ends=[layer + 1])
                for state in (init_hidden, init_cell)
            ]
            self.hidden_array.append(
                fluid.layers.reshape(h0, shape=[-1, hidden_size]))
            self.cell_array.append(
                fluid.layers.reshape(c0, shape=[-1, hidden_size]))

        step_outputs = []
        for step in range(self._num_steps):
            self._input = fluid.layers.slice(
                input_embedding, axes=[1], starts=[step], ends=[step + 1])
            self._input = fluid.layers.reshape(
                self._input, shape=[-1, hidden_size])
            for k in range(self._num_layers):
                hidden, cell = self._lstm_cell(
                    self._input, self.hidden_array[k], self.cell_array[k],
                    self.weight_1_arr[k], self.bias_arr[k])
                self.hidden_array[k] = hidden
                self.cell_array[k] = cell
                # The (optionally dropped-out) hidden state feeds the next layer.
                self._input = hidden
                if self._dropout is not None and self._dropout > 0.0:
                    self._input = fluid.layers.dropout(
                        self._input,
                        dropout_prob=self._dropout,
                        dropout_implementation='upscale_in_train')
            step_outputs.append(
                fluid.layers.reshape(
                    self._input, shape=[1, -1, hidden_size]))

        # [num_steps, N, hidden] -> [N, num_steps, hidden]
        outputs = fluid.layers.transpose(
            x=fluid.layers.concat(step_outputs, 0), perm=[1, 0, 2])
        last_hidden = self._stack_layer_states(self.hidden_array)
        last_cell = self._stack_layer_states(self.cell_array)
        return outputs, last_hidden, last_cell


class PtbModel(fluid.Layer):
    """PTB language model: embedding -> ``SimpleLSTMRNN`` -> softmax projection.

    ``forward`` returns three values:
        - The cross-entropy loss, averaged over the leading dimension and
          summed over the ``num_steps`` positions.
        - The final LSTM hidden state.
        - The final LSTM cell state.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 num_layers=2,
                 num_steps=20,
                 init_scale=0.1,
                 is_sparse=False,
                 dropout=None):
        super(PtbModel, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.init_scale = init_scale
        self.num_layers = num_layers
        self.num_steps = num_steps
        # Dropout probability applied to the embeddings (and inside the RNN);
        # ``None`` or a non-positive value disables it.
        self.dropout = dropout
        # Keep the sub-layer/parameter creation order below unchanged: the
        # dygraph and static runs in this file compare parameters by name.
        self.simple_lstm_rnn = SimpleLSTMRNN(
            hidden_size,
            num_steps,
            num_layers=num_layers,
            init_scale=init_scale,
            dropout=dropout)
        self.embedding = Embedding(
            size=[vocab_size, hidden_size],
            dtype='float32',
            is_sparse=is_sparse,
            param_attr=fluid.ParamAttr(
                name='embedding_para',
                initializer=fluid.initializer.UniformInitializer(
                    low=-init_scale, high=init_scale)))
        self.softmax_weight = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.hidden_size, self.vocab_size],
            dtype="float32",
            default_initializer=fluid.initializer.UniformInitializer(
                low=-self.init_scale, high=self.init_scale))
        self.softmax_bias = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.vocab_size],
            dtype="float32",
            default_initializer=fluid.initializer.UniformInitializer(
                low=-self.init_scale, high=self.init_scale))

    def forward(self, input, label, init_hidden, init_cell):
        """Compute the training loss for one batch.

        Args:
            input: token ids. After embedding they are reshaped to
                ``[-1, num_steps, hidden_size]``.
            label: hard (non-soft) target ids, one per flattened position.
            init_hidden, init_cell: initial LSTM states, reshaped to
                ``[num_layers, -1, hidden_size]``.

        Returns:
            ``(loss, last_hidden, last_cell)``.
        """
        init_h = fluid.layers.reshape(
            init_hidden, shape=[self.num_layers, -1, self.hidden_size])

        init_c = fluid.layers.reshape(
            init_cell, shape=[self.num_layers, -1, self.hidden_size])

        x_emb = self.embedding(input)
        x_emb = fluid.layers.reshape(
            x_emb, shape=[-1, self.num_steps, self.hidden_size])
        if self.dropout is not None and self.dropout > 0.0:
            x_emb = fluid.layers.dropout(
                x_emb,
                dropout_prob=self.dropout,
                dropout_implementation='upscale_in_train')
        rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(x_emb, init_h,
                                                               init_c)
        rnn_out = fluid.layers.reshape(
            rnn_out, shape=[-1, self.num_steps, self.hidden_size])
        projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
        projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
        projection = fluid.layers.reshape(
            projection, shape=[-1, self.vocab_size])
        loss = fluid.layers.softmax_with_cross_entropy(
            logits=projection, label=label, soft_label=False)
        # Per-position losses -> mean over dim 0 -> sum over the time steps.
        loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
        loss = fluid.layers.reduce_mean(loss, dim=[0])
        loss = fluid.layers.reduce_sum(loss)

        return loss, last_hidden, last_cell


class TestDygraphPtbRnn(unittest.TestCase):
    """Train ``PtbModel`` in dygraph and in static mode and compare results."""

    def func_test_ptb_rnn(self):
        for is_sparse in [True, False]:
            self.ptb_rnn_cpu_float32(is_sparse)

    def test_ptb_rnn(self):
        # Run once under the eager guard and once without it.
        with _test_eager_guard():
            self.func_test_ptb_rnn()
        self.func_test_ptb_rnn()

    def ptb_rnn_cpu_float32(self, is_sparse):
        """Check dygraph and static training give exactly equal results.

        Both runs use the same seed, model configuration and input batch for
        ``batch_num`` SGD steps. Exact equality is required for:
            - the final loss;
            - the final LSTM states;
            - the initial parameters;
            - the updated parameters.

        Args:
            is_sparse: forwarded to ``PtbModel`` (embedding sparsity).
        """
        seed = 90
        hidden_size = 10
        vocab_size = 1000
        num_layers = 1
        num_steps = 3
        init_scale = 0.1
        batch_size = 4
        batch_num = 200
        num_tokens = batch_size * num_steps

        def make_batch():
            """Return fresh numpy inputs ``(x, y, init_hidden, init_cell)``.

            ``x`` holds token ids ``0..num_tokens-1`` as
            ``[batch_size, num_steps]``; ``y`` holds the same ids shifted by
            one as a ``[num_tokens, 1]`` column; both states are zeros.
            """
            x_data = np.arange(num_tokens).reshape(
                batch_size, num_steps).astype('int64')
            y_data = np.arange(1, num_tokens + 1).reshape(
                (-1, 1)).astype('int64')
            state_shape = (num_layers, batch_size, hidden_size)
            init_hidden_data = np.zeros(state_shape, dtype='float32')
            init_cell_data = np.zeros(state_shape, dtype='float32')
            return x_data, y_data, init_hidden_data, init_cell_data

        # Traced inference models are saved into a scratch directory that is
        # removed after the test, instead of littering the working directory.
        save_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, save_dir, ignore_errors=True)
        model_path = os.path.join(save_dir, 'infe_imperative_ptb_rnn')

        with fluid.dygraph.guard():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)
            ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale,
                is_sparse=is_sparse)

            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=ptb_model.parameters())
            dy_param_updated = dict()
            dy_param_init = dict()
            dy_loss = None
            last_hidden = None
            last_cell = None

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None

            for i in range(batch_num):
                x_data, y_data, init_hidden_data, init_cell_data = make_batch()
                x = to_variable(x_data)
                y = to_variable(y_data)
                init_hidden = to_variable(init_hidden_data)
                init_cell = to_variable(init_cell_data)
                if i % 5 == 0 and (not _in_eager_mode()):
                    # Periodically trace the model to a static program.
                    # Compare its outputs with the dygraph ones, and check
                    # the traced program does not change between iterations.
                    outs, traced_layer = TracedLayer.trace(
                        ptb_model, [x, y, init_hidden, init_cell])
                    outs_static = traced_layer([x, y, init_hidden, init_cell])
                    helper.assertEachVar(outs, outs_static)

                    if program is not None:
                        self.assertTrue(
                            is_equal_program(traced_layer.program, program))

                    program = traced_layer.program

                    traced_layer.save_inference_model(
                        model_path, feed=list(range(4)))
                else:
                    outs = ptb_model(x, y, init_hidden, init_cell)

                dy_loss, last_hidden, last_cell = outs

                if i == 0:
                    for param in ptb_model.parameters():
                        dy_param_init[param.name] = param.numpy()
                dy_loss.backward()
                sgd.minimize(dy_loss)
                ptb_model.clear_gradients()
                if i == batch_num - 1:
                    for param in ptb_model.parameters():
                        dy_param_updated[param.name] = param.numpy()

            dy_loss_value = dy_loss.numpy()
            dy_last_cell_value = last_cell.numpy()
            dy_last_hidden_value = last_hidden.numpy()

        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)
            ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale,
                is_sparse=is_sparse)

            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
            sgd = SGDOptimizer(learning_rate=1e-3)
            x = fluid.layers.data(
                name="x", shape=[-1, num_steps], dtype='int64')
            # Hard class-id labels: declared int64 to match the fed data and
            # softmax_with_cross_entropy(soft_label=False).
            y = fluid.layers.data(name="y", shape=[-1, 1], dtype='int64')
            init_hidden = fluid.layers.data(
                name="init_hidden", shape=[1], dtype='float32')
            init_cell = fluid.layers.data(
                name="init_cell", shape=[1], dtype='float32')

            static_loss, static_last_hidden, static_last_cell = ptb_model(
                x, y, init_hidden, init_cell)
            sgd.minimize(static_loss)
            static_param_name_list = [
                param.name for param in ptb_model.parameters()
            ]

            out = exe.run(framework.default_startup_program(),
                          fetch_list=static_param_name_list)
            static_param_init = dict(zip(static_param_name_list, out))

            static_param_updated = dict()
            static_loss_value = None
            static_last_cell_value = None
            static_last_hidden_value = None
            # Fetch the loss and final states, followed by every parameter.
            fetch_list = [static_loss, static_last_hidden, static_last_cell]
            fetch_list.extend(static_param_name_list)
            for i in range(batch_num):
                x_data, y_data, init_hidden_data, init_cell_data = make_batch()
                out = exe.run(fluid.default_main_program(),
                              feed={
                                  "x": x_data.reshape((-1, num_steps, 1)),
                                  "y": y_data,
                                  "init_hidden": init_hidden_data,
                                  "init_cell": init_cell_data
                              },
                              fetch_list=fetch_list)
                static_loss_value = out[0]
                static_last_hidden_value = out[1]
                static_last_cell_value = out[2]

                if i == batch_num - 1:
                    static_param_updated = dict(
                        zip(static_param_name_list, out[3:]))

        self.assertTrue(np.array_equal(static_loss_value, dy_loss_value))
        self.assertTrue(
            np.array_equal(static_last_cell_value, dy_last_cell_value))
        self.assertTrue(
            np.array_equal(static_last_hidden_value, dy_last_hidden_value))
        for key, value in six.iteritems(static_param_init):
            self.assertTrue(
                np.array_equal(value, dy_param_init[key]),
                msg="initial value of parameter %s differs" % key)
        for key, value in six.iteritems(static_param_updated):
            self.assertTrue(
                np.array_equal(value, dy_param_updated[key]),
                msg="updated value of parameter %s differs" % key)


if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()