#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.optimizer import SGDOptimizer
from paddle.jit.api import to_static

PRINT_STEP = 20
SEED = 2020


class SimpleLSTMRNN(paddle.nn.Layer):
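    """Multi-layer LSTM, manually unrolled over ``num_steps`` time steps.

    Each layer holds one fused weight of shape [hidden * 2, hidden * 4]:
    the concatenation of the step input and the previous hidden state is
    projected once to produce the pre-activations of all four LSTM gates.
    """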
    def __init__(
        self, hidden_size, num_steps, num_layers=2, init_scale=0.1, dropout=None
    ):
        super().__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._init_scale = init_scale
        self._dropout = dropout
        self._num_steps = num_steps
        self.cell_array = []
        self.hidden_array = []

        self.weight_1_arr = []
        self.weight_2_arr = []
        self.bias_arr = []
        self.mask_array = []

        for i in range(self._num_layers):
            weight_1 = self.create_parameter(
                attr=fluid.ParamAttr(
                    initializer=paddle.nn.initializer.Uniform(
                        low=-self._init_scale, high=self._init_scale
                    )
                ),
                shape=[self._hidden_size * 2, self._hidden_size * 4],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale
                ),
            )
            self.weight_1_arr.append(self.add_parameter('w_%d' % i, weight_1))
            bias_1 = self.create_parameter(
                attr=fluid.ParamAttr(
                    initializer=paddle.nn.initializer.Uniform(
                        low=-self._init_scale, high=self._init_scale
                    )
                ),
                shape=[self._hidden_size * 4],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(0.0),
            )
            self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1))

    def forward(self, input_embedding, init_hidden=None, init_cell=None):
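        """Run the unrolled LSTM.

        input_embedding: [batch, num_steps, hidden]; init_hidden and
        init_cell: [num_layers, batch, hidden]. Returns the per-step
        outputs as [batch, num_steps, hidden] plus the final hidden and
        cell states, each as [num_layers, batch, hidden].
        """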
        cell_array = []
        hidden_array = []

        for i in range(self._num_layers):
            hidden_array.append(init_hidden[i])
            cell_array.append(init_cell[i])

        res = []
        for index in range(self._num_steps):
            step_input = input_embedding[:, index, :]
            for k in range(self._num_layers):
                pre_hidden = hidden_array[k]
                pre_cell = cell_array[k]
                weight_1 = self.weight_1_arr[k]
                bias = self.bias_arr[k]

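                # Fused gate pre-activation: a single matmul over the
                # concatenated [x_t, h_{t-1}] produces all four gates at
                # once, which are then split into i, j, f, o below.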
                nn = paddle.concat([step_input, pre_hidden], 1)
                gate_input = paddle.matmul(x=nn, y=weight_1)

                gate_input = paddle.add(gate_input, bias)
                i, j, f, o = paddle.split(
                    gate_input, num_or_sections=4, axis=-1
                )
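                # Standard LSTM state update:
                # c_t = sigmoid(f) * c_{t-1} + sigmoid(i) * tanh(j)
                # h_t = sigmoid(o) * tanh(c_t)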
                c = pre_cell * paddle.nn.functional.sigmoid(
                    f
                ) + paddle.nn.functional.sigmoid(i) * paddle.tanh(j)
                m = paddle.tanh(c) * paddle.nn.functional.sigmoid(o)
                hidden_array[k] = m
                cell_array[k] = c
                step_input = m

                if self._dropout is not None and self._dropout > 0.0:
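                    # 'upscale_in_train' scales kept activations by
                    # 1 / (1 - p) during training, so no rescaling is
                    # needed at inference time.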
                    step_input = paddle.nn.functional.dropout(
                        step_input,
                        p=self._dropout,
                        mode='upscale_in_train',
                    )
            res.append(step_input)
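        # Stack the per-step outputs into [batch, num_steps, hidden] and
        # the final per-layer states into [num_layers, batch, hidden].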
        real_res = paddle.concat(res, 1)
        real_res = paddle.reshape(
            real_res, [-1, self._num_steps, self._hidden_size]
        )
        last_hidden = paddle.concat(hidden_array, 1)
        last_hidden = paddle.reshape(
            last_hidden, shape=[-1, self._num_layers, self._hidden_size]
        )
        last_hidden = paddle.transpose(x=last_hidden, perm=[1, 0, 2])
        last_cell = paddle.concat(cell_array, 1)
        last_cell = paddle.reshape(
            last_cell, shape=[-1, self._num_layers, self._hidden_size]
        )
        last_cell = paddle.transpose(x=last_cell, perm=[1, 0, 2])
        return real_res, last_hidden, last_cell


class PtbModel(paddle.nn.Layer):
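    """PTB-style word-level language model: embedding lookup, multi-layer
    LSTM, then a linear projection to vocabulary logits trained with
    per-token cross entropy.
    """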
    def __init__(
        self,
        hidden_size,
        vocab_size,
        num_layers=2,
        num_steps=20,
        init_scale=0.1,
        dropout=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.init_scale = init_scale
        self.num_layers = num_layers
        self.num_steps = num_steps
        self.dropout = dropout
        self.simple_lstm_rnn = SimpleLSTMRNN(
            hidden_size,
            num_steps,
            num_layers=num_layers,
            init_scale=init_scale,
            dropout=dropout,
        )
        self.embedding = paddle.nn.Embedding(
            vocab_size,
            hidden_size,
            sparse=False,
            weight_attr=fluid.ParamAttr(
                name='embedding_para',
                initializer=paddle.nn.initializer.Uniform(
                    low=-init_scale, high=init_scale
                ),
            ),
        )
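        # Output projection mapping LSTM features to vocabulary logits.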
        self.softmax_weight = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.hidden_size, self.vocab_size],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Uniform(
                low=-self.init_scale, high=self.init_scale
            ),
        )
        self.softmax_bias = self.create_parameter(
            attr=fluid.ParamAttr(),
            shape=[self.vocab_size],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Uniform(
                low=-self.init_scale, high=self.init_scale
            ),
        )

    def build_once(self, input, label, init_hidden, init_cell):
        # Legacy dygraph hook; intentionally a no-op for this model.
        pass

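    # @to_static converts this dygraph forward into a static graph when
    # dynamic-to-static conversion is enabled (see train_static below).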
    @to_static
    def forward(self, input, label, init_hidden, init_cell):

        init_h = paddle.reshape(
            init_hidden, shape=[self.num_layers, -1, self.hidden_size]
        )

        init_c = paddle.reshape(
            init_cell, shape=[self.num_layers, -1, self.hidden_size]
        )

        x_emb = self.embedding(input)

        x_emb = paddle.reshape(
            x_emb, shape=[-1, self.num_steps, self.hidden_size]
        )
        if self.dropout is not None and self.dropout > 0.0:
            x_emb = paddle.nn.functional.dropout(
                x_emb,
                p=self.dropout,
                mode='upscale_in_train',
            )
        rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(
            x_emb, init_h, init_c
        )

        projection = paddle.matmul(rnn_out, self.softmax_weight)
        projection = paddle.add(projection, self.softmax_bias)

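        # Per-token cross entropy, averaged over the batch at each step
        # and summed over the num_steps positions of the sequence.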
        loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=projection, label=label, soft_label=False
        )
        loss = paddle.reshape(loss, shape=[-1, self.num_steps])
        loss = paddle.mean(loss, axis=[0])
        loss = paddle.sum(loss)

        return loss, last_hidden, last_cell

    def debug_emb(self):
        # Debug helper; not invoked by the test itself.
        np.save("emb_grad", self.x_emb.gradient())


def train(place):

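    # A deliberately tiny configuration so the test runs quickly while
    # still exercising the full model.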
    num_layers = 1
    batch_size = 4
    hidden_size = 10
    num_steps = 3
    init_scale = 0.1
    max_epoch = 1
    dropout = 0.0
    vocab_size = 1000
    batch_num = 200

    with fluid.dygraph.guard(place):
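        # Seed both the dygraph and the static-program random state so
        # the two execution modes are directly comparable.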
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)
        ptb_model = PtbModel(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_layers=num_layers,
            num_steps=num_steps,
            init_scale=init_scale,
            dropout=dropout,
        )

        sgd = SGDOptimizer(
            learning_rate=1e-3, parameter_list=ptb_model.parameters()
        )

        for epoch_id in range(max_epoch):

            total_loss = 0.0
            iters = 0.0
            total_sample = 0

            init_hidden_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32'
            )
            init_cell_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32'
            )

            init_hidden = to_variable(init_hidden_data)
            init_cell = to_variable(init_cell_data)
            for step_id in range(batch_num):
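                # Fixed synthetic batches: both execution modes consume
                # identical data, so their losses can be compared exactly.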
                x_data = np.arange(12).reshape(4, 3).astype('int64')
                y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
                y_data = y_data.reshape((-1, 1))

                x_data = x_data.reshape((-1, num_steps, 1))
                y_data = y_data.reshape((-1, num_steps, 1))

                x = to_variable(x_data)
                y = to_variable(y_data)

                dy_loss, last_hidden, last_cell = ptb_model(
                    x, y, init_hidden, init_cell
                )
                out_loss = dy_loss.numpy()

                dy_loss.backward()
                sgd.minimize(dy_loss)
                ptb_model.clear_gradients()

                total_loss += out_loss
                iters += num_steps
                total_sample += 1
                if step_id % PRINT_STEP == 0:
                    if step_id == 0:
                        logging.info(
                            "epoch %d | step %d, loss %0.3f"
                            % (epoch_id, step_id, total_loss / total_sample)
                        )
                        avg_batch_time = time.time()
                    else:
                        speed = PRINT_STEP / (time.time() - avg_batch_time)
                        logging.info(
                            "epoch %d | step %d, loss %0.3f, speed %.3f steps/s"
                            % (
                                epoch_id,
                                step_id,
                                total_loss / total_sample,
                                speed,
                            )
                        )
                        avg_batch_time = time.time()

        return out_loss, last_hidden.numpy(), last_cell.numpy()


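# The same training loop is run twice: once with dynamic-to-static
# conversion disabled (pure dygraph) and once with it enabled, and the
# test asserts that both produce the same results.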
def train_dygraph(place):
    paddle.jit.enable_to_static(False)
    return train(place)


def train_static(place):
    paddle.jit.enable_to_static(True)
    return train(place)


class TestPtb(unittest.TestCase):
    def setUp(self):
        self.place = (
            fluid.CUDAPlace(0)
            if fluid.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )

    def test_check_result(self):
        loss_1, hidden_1, cell_1 = train_static(self.place)
        loss_2, hidden_2, cell_2 = train_dygraph(self.place)

        np.testing.assert_allclose(loss_1, loss_2, rtol=1e-05)
        np.testing.assert_allclose(hidden_1, hidden_2, rtol=1e-05)
        np.testing.assert_allclose(cell_1, cell_2, rtol=1e-05)


if __name__ == '__main__':
    unittest.main()