seq2seq_dygraph_model.py 31.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
16 17
from seq2seq_utils import Seq2SeqModelHyperParams as args

18
import paddle
19
from paddle import fluid
20
from paddle.fluid import ParamAttr
21
from paddle.fluid.dygraph.base import to_variable
H
hjyp 已提交
22
from paddle.jit.api import to_static
23
from paddle.nn import Embedding, Layer
24

25
# Large finite value used as a stand-in for "infinity" when masking scores.
# NOTE(review): not referenced in this file's visible code (the models use
# `self.kinf` instead) — presumably kept for callers; confirm before removing.
INF = 1.0 * 1e5
# Length-penalty exponent (GNMT-style). NOTE(review): also unused in the
# visible code — presumably consumed by the beam-search driver; TODO confirm.
alpha = 0.6
# Factory for a Uniform(-x, x) weight initializer, shared by both models.
uniform_initializer = lambda x: paddle.nn.initializer.Uniform(low=-x, high=x)
# Shared all-zeros initializer for bias parameters.
zero_constant = paddle.nn.initializer.Constant(0.0)
29 30 31


class BasicLSTMUnit(Layer):
    """Single-step LSTM cell.

    Performs one LSTM transition from ``(input, pre_hidden, pre_cell)`` to
    ``(new_hidden, new_cell)``.  All four gates are produced by a single
    fused matmul over the concatenated ``[input ; pre_hidden]`` vector.

    Args:
        hidden_size (int): width of the hidden/cell state.
        input_size (int): width of the per-step input.
        param_attr: ParamAttr for the fused gate weight matrix.
        bias_attr: ParamAttr for the fused gate bias.
        gate_activation: gate activation (defaults to sigmoid).
        activation: cell/hidden activation (defaults to tanh).
        forget_bias (float): constant added to the forget-gate
            pre-activation so the cell initially remembers.
        dtype (str): dtype of the created parameters.
    """

    def __init__(
        self,
        hidden_size,
        input_size,
        param_attr=None,
        bias_attr=None,
        gate_activation=None,
        activation=None,
        forget_bias=1.0,
        dtype='float32',
    ):
        super().__init__(dtype)

        # Keep the original (misspelled) attribute name for compatibility.
        self._hiden_size = hidden_size
        self._input_size = input_size
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        self._forget_bias = forget_bias
        self._dtype = dtype

        # Default to the standard LSTM activations when none are supplied.
        self._gate_activation = gate_activation or paddle.nn.functional.sigmoid
        self._activation = activation or paddle.tanh

        # One fused weight/bias pair yielding all 4 gates: [i, j, f, o].
        self._weight = self.create_parameter(
            attr=self._param_attr,
            shape=[self._input_size + self._hiden_size, 4 * self._hiden_size],
            dtype=self._dtype,
        )
        self._bias = self.create_parameter(
            attr=self._bias_attr,
            shape=[4 * self._hiden_size],
            dtype=self._dtype,
            is_bias=True,
        )

    def forward(self, input, pre_hidden, pre_cell):
        """Run one LSTM step and return ``(new_hidden, new_cell)``."""
        sigmoid = paddle.nn.functional.sigmoid

        # Fused projection of [input ; pre_hidden] into the 4 gate slices.
        step_in = paddle.concat([input, pre_hidden], 1)
        gates = paddle.add(paddle.matmul(x=step_in, y=self._weight), self._bias)
        i, j, f, o = paddle.split(gates, num_or_sections=4, axis=-1)

        # c_t = c_{t-1} * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j)
        retained = paddle.multiply(pre_cell, sigmoid(f + self._forget_bias))
        candidate = paddle.multiply(sigmoid(i), paddle.tanh(j))
        new_cell = paddle.add(retained, candidate)

        # h_t = tanh(c_t) * sigmoid(o)
        new_hidden = paddle.tanh(new_cell) * sigmoid(o)

        return new_hidden, new_cell


85
class BaseModel(paddle.nn.Layer):
    """Plain (no-attention) LSTM encoder/decoder seq2seq model.

    `forward` computes the teacher-forced training loss; `beam_search`
    decodes over `beam_size` beams.  Both entry points are decorated with
    @to_static, so their bodies deliberately avoid constructs the
    dygraph-to-static transformer cannot handle (nested lists,
    tensor-valued loop bounds) — see the NOTE/TODO comments inside before
    "simplifying" them.
    """

    def __init__(
        self,
        hidden_size,
        src_vocab_size,
        tar_vocab_size,
        batch_size,
        num_layers=1,
        init_scale=0.1,
        dropout=None,
        beam_size=1,
        beam_start_token=1,
        beam_end_token=2,
        beam_max_step_num=2,
        mode='train',
    ):
        """Build embeddings, stacked LSTM units and the output projection.

        Args:
            hidden_size: width of embeddings and LSTM states.
            src_vocab_size / tar_vocab_size: vocabulary sizes.
            batch_size: expected batch size (shrunk at runtime if the
                actual batch is smaller).
            num_layers: number of stacked LSTM layers on each side.
            init_scale: half-width of the Uniform weight initializer.
            dropout: inter-layer dropout probability, or None to disable.
            beam_size / beam_start_token / beam_end_token /
                beam_max_step_num: beam-search hyper-parameters.
            mode: stored but not read in this file.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.src_vocab_size = src_vocab_size
        self.tar_vocab_size = tar_vocab_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.init_scale = init_scale
        self.dropout = dropout
        self.beam_size = beam_size
        self.beam_start_token = beam_start_token
        self.beam_end_token = beam_end_token
        self.beam_max_step_num = beam_max_step_num
        self.mode = mode
        # Finite "infinity" used to suppress forbidden beam candidates.
        self.kinf = 1e9

        param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
        bias_attr = ParamAttr(initializer=zero_constant)
        forget_bias = 1.0

        # Source / target token embeddings.
        self.src_embeder = Embedding(
            self.src_vocab_size,
            self.hidden_size,
            weight_attr=fluid.ParamAttr(
                initializer=uniform_initializer(init_scale)
            ),
        )

        self.tar_embeder = Embedding(
            self.tar_vocab_size,
            self.hidden_size,
            sparse=False,
            weight_attr=fluid.ParamAttr(
                initializer=uniform_initializer(init_scale)
            ),
        )

        # Encoder LSTM stack, registered as sublayers so parameters are
        # tracked even though we keep them in a plain list.
        self.enc_units = []
        for i in range(num_layers):
            self.enc_units.append(
                self.add_sublayer(
                    "enc_units_%d" % i,
                    BasicLSTMUnit(
                        hidden_size=self.hidden_size,
                        input_size=self.hidden_size,
                        param_attr=param_attr,
                        bias_attr=bias_attr,
                        forget_bias=forget_bias,
                    ),
                )
            )

        # Decoder LSTM stack (same shape as the encoder stack).
        self.dec_units = []
        for i in range(num_layers):
            self.dec_units.append(
                self.add_sublayer(
                    "dec_units_%d" % i,
                    BasicLSTMUnit(
                        hidden_size=self.hidden_size,
                        input_size=self.hidden_size,
                        param_attr=param_attr,
                        bias_attr=bias_attr,
                        forget_bias=forget_bias,
                    ),
                )
            )

        # Output projection from decoder hidden state to target vocab logits.
        self.fc = paddle.nn.Linear(
            self.hidden_size,
            self.tar_vocab_size,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Uniform(
                    low=-self.init_scale, high=self.init_scale
                )
            ),
            bias_attr=False,
        )

    def _transpose_batch_time(self, x):
        # Swap the first two axes: [batch, time, ...] <-> [time, batch, ...].
        return paddle.transpose(x, [1, 0] + list(range(2, len(x.shape))))

    def _merge_batch_beams(self, x):
        # [batch, beam, d] -> [batch * beam, d].
        return paddle.reshape(x, shape=(-1, x.shape[2]))

    def _split_batch_beams(self, x):
        # [batch * beam, d] -> [batch, beam, d].
        return paddle.reshape(x, shape=(-1, self.beam_size, x.shape[1]))

    def _expand_to_beam_size(self, x):
        # Insert a beam axis and tile: [batch, ...] -> [batch, beam, ...].
        x = paddle.unsqueeze(x, [1])
        expand_shape = [-1] * len(x.shape)
        expand_shape[1] = self.beam_size * x.shape[1]
        x = paddle.expand(x, expand_shape)
        return x

    def _real_state(self, state, new_state, step_mask):
        # Per-row select: where step_mask == 1 take new_state, where
        # step_mask == 0 keep state (new*mask - state*(mask-1)).
        new_state = paddle.tensor.math._multiply_with_axis(
            new_state, step_mask, axis=0
        ) - paddle.tensor.math._multiply_with_axis(
            state, (step_mask - 1), axis=0
        )
        return new_state

    def _gather(self, x, indices, batch_pos):
        # Gather per-(batch, beam) entries of x via 2-D coordinates.
        topk_coordinates = paddle.stack([batch_pos, indices], axis=2)
        return paddle.gather_nd(x, topk_coordinates)

    @to_static
    def forward(self, inputs):
        """Teacher-forced training step.

        Args:
            inputs: 5-tuple ``(src, tar, label, src_sequence_length,
                tar_sequence_length)`` — id tensors are batch-major.

        Returns:
            Scalar loss: per-token cross entropy masked by target length,
            averaged over the batch axis then summed over time.
        """
        src, tar, label, src_sequence_length, tar_sequence_length = inputs
        # Shrink to the actual batch (last batch of an epoch may be short).
        if src.shape[0] < self.batch_size:
            self.batch_size = src.shape[0]

        src_emb = self.src_embeder(self._transpose_batch_time(src))

        # NOTE: modify model code about `enc_hidden` and `enc_cell` to transforme dygraph code successfully.
        # Because nested list can't be transformed now.
        enc_hidden_0 = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        enc_cell_0 = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        zero = paddle.zeros(shape=[1], dtype="int64")
        # Tensor arrays stand in for a per-layer list of states.
        enc_hidden = paddle.tensor.create_array(dtype="float32")
        enc_cell = paddle.tensor.create_array(dtype="float32")
        for i in range(self.num_layers):
            index = zero + i
            enc_hidden = paddle.tensor.array_write(
                enc_hidden_0, index, array=enc_hidden
            )
            enc_cell = paddle.tensor.array_write(
                enc_cell_0, index, array=enc_cell
            )

        max_seq_len = src_emb.shape[0]

        # [batch, time] 1/0 validity mask, then time-major for per-step use.
        enc_len_mask = paddle.static.nn.sequence_lod.sequence_mask(
            src_sequence_length, maxlen=max_seq_len, dtype="float32"
        )
        enc_len_mask = paddle.transpose(enc_len_mask, [1, 0])

        # TODO: Because diff exits if call while_loop in static graph.
        # In while block, a Variable created in parent block participates in the calculation of gradient,
        # the gradient is wrong because each step scope always returns the same value generated by last step.
        # NOTE: Replace max_seq_len(Tensor src_emb.shape[0]) with args.max_seq_len(int) to avoid this bug temporarily.
        for k in range(args.max_seq_len):
            enc_step_input = src_emb[k]
            step_mask = enc_len_mask[k]
            new_enc_hidden, new_enc_cell = [], []
            for i in range(self.num_layers):
                enc_new_hidden, enc_new_cell = self.enc_units[i](
                    enc_step_input, enc_hidden[i], enc_cell[i]
                )
                # Inter-layer dropout on the hidden output feeding layer i+1.
                if self.dropout is not None and self.dropout > 0.0:
                    enc_step_input = paddle.nn.functional.dropout(
                        enc_new_hidden,
                        p=self.dropout,
                        mode='upscale_in_train',
                    )
                else:
                    enc_step_input = enc_new_hidden

                # Freeze states of already-finished (padded) sequences.
                new_enc_hidden.append(
                    self._real_state(enc_hidden[i], enc_new_hidden, step_mask)
                )
                new_enc_cell.append(
                    self._real_state(enc_cell[i], enc_new_cell, step_mask)
                )

            enc_hidden, enc_cell = new_enc_hidden, new_enc_cell

        # Decoder starts from the final encoder states (teacher forcing).
        dec_hidden, dec_cell = enc_hidden, enc_cell
        tar_emb = self.tar_embeder(self._transpose_batch_time(tar))
        max_seq_len = tar_emb.shape[0]
        dec_output = []
        for step_idx in range(max_seq_len):
            # `+ 0` keeps the index a fresh value for the static transform.
            j = step_idx + 0
            step_input = tar_emb[j]
            new_dec_hidden, new_dec_cell = [], []
            for i in range(self.num_layers):
                new_hidden, new_cell = self.dec_units[i](
                    step_input, dec_hidden[i], dec_cell[i]
                )
                new_dec_hidden.append(new_hidden)
                new_dec_cell.append(new_cell)
                if self.dropout is not None and self.dropout > 0.0:
                    step_input = paddle.nn.functional.dropout(
                        new_hidden,
                        p=self.dropout,
                        mode='upscale_in_train',
                    )
                else:
                    step_input = new_hidden
            # NOTE(review): dec_hidden/dec_cell are never reassigned to
            # new_dec_hidden/new_dec_cell here, so every time step decodes
            # from the same (final encoder) state — confirm this is the
            # intended behavior and not a dropped statement.
            dec_output.append(step_input)

        dec_output = paddle.stack(dec_output)
        # Back to batch-major before projecting to vocab logits.
        dec_output = self.fc(self._transpose_batch_time(dec_output))
        loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=dec_output, label=label, soft_label=False
        )
        # NOTE(review): paddle.squeeze's keyword is `axis`; `axes` looks like
        # a leftover from the fluid.layers API — TODO confirm against the
        # installed paddle version.
        loss = paddle.squeeze(loss, axes=[2])
        max_tar_seq_len = paddle.shape(tar)[1]
        # Zero out loss contributions from target padding positions.
        tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
            tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
        )
        loss = loss * tar_mask
        loss = paddle.mean(loss, axis=[0])
        loss = paddle.sum(loss)

        return loss

    @to_static
    def beam_search(self, inputs):
        """Beam-search decoding.

        Encodes the source exactly as `forward` does, then expands each
        batch entry to `beam_size` beams and decodes step by step, keeping
        the top-k continuations per batch entry.

        Returns:
            predicted_ids: batch-major tensor of decoded token ids after
                `gather_tree` back-tracing.
        """
        src, tar, label, src_sequence_length, tar_sequence_length = inputs
        if src.shape[0] < self.batch_size:
            self.batch_size = src.shape[0]

        # --- encoder: identical unrolling to forward() ---
        src_emb = self.src_embeder(self._transpose_batch_time(src))
        enc_hidden_0 = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        enc_cell_0 = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        zero = paddle.zeros(shape=[1], dtype="int64")
        enc_hidden = paddle.tensor.create_array(dtype="float32")
        enc_cell = paddle.tensor.create_array(dtype="float32")
        for j in range(self.num_layers):
            index = zero + j
            enc_hidden = paddle.tensor.array_write(
                enc_hidden_0, index, array=enc_hidden
            )
            enc_cell = paddle.tensor.array_write(
                enc_cell_0, index, array=enc_cell
            )

        max_seq_len = src_emb.shape[0]

        enc_len_mask = paddle.static.nn.sequence_lod.sequence_mask(
            src_sequence_length, maxlen=max_seq_len, dtype="float32"
        )
        enc_len_mask = paddle.transpose(enc_len_mask, [1, 0])

        for k in range(args.max_seq_len):
            enc_step_input = src_emb[k]
            step_mask = enc_len_mask[k]

            new_enc_hidden, new_enc_cell = [], []

            for i in range(self.num_layers):
                enc_new_hidden, enc_new_cell = self.enc_units[i](
                    enc_step_input, enc_hidden[i], enc_cell[i]
                )
                if self.dropout is not None and self.dropout > 0.0:
                    enc_step_input = paddle.nn.functional.dropout(
                        enc_new_hidden,
                        p=self.dropout,
                        mode='upscale_in_train',
                    )
                else:
                    enc_step_input = enc_new_hidden

                new_enc_hidden.append(
                    self._real_state(enc_hidden[i], enc_new_hidden, step_mask)
                )
                new_enc_cell.append(
                    self._real_state(enc_cell[i], enc_new_cell, step_mask)
                )

            enc_hidden, enc_cell = new_enc_hidden, new_enc_cell

        # --- beam search proper ---
        batch_beam_shape = (self.batch_size, self.beam_size)
        vocab_size_tensor = to_variable(
            np.full((1), self.tar_vocab_size)
        ).astype("int64")
        start_token_tensor = to_variable(
            np.full(batch_beam_shape, self.beam_start_token, dtype='int64')
        )
        end_token_tensor = to_variable(
            np.full(batch_beam_shape, self.beam_end_token, dtype='int64')
        )
        step_input = self.tar_embeder(start_token_tensor)
        # 0.0 = running, 1.0 = finished, per (batch, beam) entry.
        beam_finished = to_variable(
            np.full(batch_beam_shape, 0, dtype='float32')
        )
        # Only beam 0 starts "alive"; the rest get -kinf so the first topk
        # cannot pick duplicates of the start token.
        beam_state_log_probs = to_variable(
            np.array(
                [[0.0] + [-self.kinf] * (self.beam_size - 1)], dtype="float32"
            )
        )
        beam_state_log_probs = paddle.expand(
            beam_state_log_probs,
            [self.batch_size * beam_state_log_probs.shape[0], -1],
        )
        # Tile encoder final states across the beam axis.
        dec_hidden, dec_cell = enc_hidden, enc_cell
        dec_hidden = [self._expand_to_beam_size(ele) for ele in dec_hidden]
        dec_cell = [self._expand_to_beam_size(ele) for ele in dec_cell]

        # [batch, beam] grid of batch indices, used to build gather_nd coords.
        batch_pos = paddle.expand(
            paddle.unsqueeze(
                to_variable(np.arange(0, self.batch_size, 1, dtype="int64")),
                [1],
            ),
            [-1, self.beam_size],
        )
        predicted_ids = []
        parent_ids = []

        # NOTE(review): the loop bound is a tensor wrapped in range();
        # presumably relies on the to_static loop conversion — confirm.
        for step_idx in range(paddle.to_tensor(self.beam_max_step_num)):
            # Stop early once every beam has emitted the end token.
            if paddle.sum(1 - beam_finished).numpy()[0] == 0:
                break
            step_input = self._merge_batch_beams(step_input)
            new_dec_hidden, new_dec_cell = [], []
            # NOTE(review): placeholder binding, presumably to pre-declare
            # the comprehension variable for the static transform — confirm
            # it is still required.
            state = 0
            dec_hidden = [
                self._merge_batch_beams(state) for state in dec_hidden
            ]
            dec_cell = [self._merge_batch_beams(state) for state in dec_cell]

            for i in range(self.num_layers):
                new_hidden, new_cell = self.dec_units[i](
                    step_input, dec_hidden[i], dec_cell[i]
                )
                new_dec_hidden.append(new_hidden)
                new_dec_cell.append(new_cell)
                if self.dropout is not None and self.dropout > 0.0:
                    step_input = paddle.nn.functional.dropout(
                        new_hidden,
                        p=self.dropout,
                        mode='upscale_in_train',
                    )
                else:
                    step_input = new_hidden

            cell_outputs = self._split_batch_beams(step_input)
            cell_outputs = self.fc(cell_outputs)

            step_log_probs = paddle.log(
                paddle.nn.functional.softmax(cell_outputs)
            )
            # Finished beams may only "emit" the end token (log-prob 0);
            # everything else is pushed to -kinf.
            noend_array = [-self.kinf] * self.tar_vocab_size
            noend_array[self.beam_end_token] = 0
            noend_mask_tensor = to_variable(
                np.array(noend_array, dtype='float32')
            )

            # Blend: finished beams take the noend mask, running beams keep
            # their real step log-probs.
            step_log_probs = paddle.multiply(
                paddle.expand(
                    paddle.unsqueeze(beam_finished, [2]),
                    [-1, -1, self.tar_vocab_size],
                ),
                noend_mask_tensor,
            ) - paddle.tensor.math._multiply_with_axis(
                step_log_probs, (beam_finished - 1), axis=0
            )
            log_probs = paddle.tensor.math._add_with_axis(
                x=step_log_probs, y=beam_state_log_probs, axis=0
            )
            # Flatten (beam, vocab) so topk picks across all continuations.
            scores = paddle.reshape(
                log_probs, [-1, self.beam_size * self.tar_vocab_size]
            )
            topk_scores, topk_indices = paddle.topk(x=scores, k=self.beam_size)

            # Decompose flat indices into (source beam, token id).
            beam_indices = paddle.floor_divide(topk_indices, vocab_size_tensor)
            token_indices = paddle.remainder(topk_indices, vocab_size_tensor)
            next_log_probs = self._gather(scores, topk_indices, batch_pos)

            # NOTE(review): placeholder binding as above — confirm needed.
            x = 0
            new_dec_hidden = [
                self._split_batch_beams(state) for state in new_dec_hidden
            ]
            new_dec_cell = [
                self._split_batch_beams(state) for state in new_dec_cell
            ]
            new_dec_hidden = [
                self._gather(x, beam_indices, batch_pos) for x in new_dec_hidden
            ]
            new_dec_cell = [
                self._gather(x, beam_indices, batch_pos) for x in new_dec_cell
            ]

            # NOTE(review): the two gathers below repeat the two gathers
            # immediately above with identical indices, reordering states a
            # second time — this looks like a duplicated block; confirm
            # against the reference implementation.
            new_dec_hidden = [
                self._gather(x, beam_indices, batch_pos) for x in new_dec_hidden
            ]
            new_dec_cell = [
                self._gather(x, beam_indices, batch_pos) for x in new_dec_cell
            ]
            next_finished = self._gather(beam_finished, beam_indices, batch_pos)
            next_finished = paddle.cast(next_finished, "bool")
            # A beam is finished once it (or its ancestor) emitted end_token.
            next_finished = paddle.logical_or(
                next_finished,
                paddle.equal(token_indices, end_token_tensor),
            )
            next_finished = paddle.cast(next_finished, "float32")

            dec_hidden, dec_cell = new_dec_hidden, new_dec_cell
            beam_finished = next_finished
            beam_state_log_probs = next_log_probs
            step_input = self.tar_embeder(token_indices)
            predicted_ids.append(token_indices)
            parent_ids.append(beam_indices)

        predicted_ids = paddle.stack(predicted_ids)
        parent_ids = paddle.stack(parent_ids)
        # Trace each beam back through its parents to full token sequences.
        predicted_ids = paddle.nn.functional.gather_tree(
            predicted_ids, parent_ids
        )
        predicted_ids = self._transpose_batch_time(predicted_ids)
        return predicted_ids
510 511


512
class AttentionModel(paddle.nn.Layer):
    """LSTM seq2seq model with dot-product attention and input feeding.

    Like `BaseModel`, but the decoder attends over the encoder outputs at
    every step and feeds the attention output back into the next step's
    input (`input_feed`).  `forward` is decorated with @to_static, so its
    body deliberately avoids constructs the dygraph-to-static transformer
    cannot handle — see the NOTE comments inside before refactoring.
    """

    def __init__(
        self,
        hidden_size,
        src_vocab_size,
        tar_vocab_size,
        batch_size,
        num_layers=1,
        init_scale=0.1,
        dropout=None,
        beam_size=1,
        beam_start_token=1,
        beam_end_token=2,
        beam_max_step_num=2,
        mode='train',
    ):
        """Build embeddings, LSTM stacks, attention and projection layers.

        Parameters mirror `BaseModel.__init__`; see there for details.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.src_vocab_size = src_vocab_size
        self.tar_vocab_size = tar_vocab_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.init_scale = init_scale
        self.dropout = dropout
        self.beam_size = beam_size
        self.beam_start_token = beam_start_token
        self.beam_end_token = beam_end_token
        self.beam_max_step_num = beam_max_step_num
        self.mode = mode
        # Finite "infinity" used for masking (kept for parity with BaseModel).
        self.kinf = 1e9

        param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale))
        bias_attr = ParamAttr(initializer=zero_constant)
        forget_bias = 1.0

        self.src_embeder = Embedding(
            self.src_vocab_size,
            self.hidden_size,
            weight_attr=fluid.ParamAttr(
                name='source_embedding',
                initializer=uniform_initializer(init_scale),
            ),
        )

        self.tar_embeder = Embedding(
            self.tar_vocab_size,
            self.hidden_size,
            sparse=False,
            weight_attr=fluid.ParamAttr(
                name='target_embedding',
                initializer=uniform_initializer(init_scale),
            ),
        )

        # Encoder LSTM stack.
        self.enc_units = []
        for i in range(num_layers):
            self.enc_units.append(
                self.add_sublayer(
                    "enc_units_%d" % i,
                    BasicLSTMUnit(
                        hidden_size=self.hidden_size,
                        input_size=self.hidden_size,
                        param_attr=param_attr,
                        bias_attr=bias_attr,
                        forget_bias=forget_bias,
                    ),
                )
            )

        # Decoder LSTM stack.  Layer 0 takes [target_emb ; input_feed], so
        # its input width is doubled; deeper layers take plain hidden_size.
        self.dec_units = []
        for i in range(num_layers):
            if i == 0:
                self.dec_units.append(
                    self.add_sublayer(
                        "dec_units_%d" % i,
                        BasicLSTMUnit(
                            hidden_size=self.hidden_size,
                            input_size=self.hidden_size * 2,
                            param_attr=ParamAttr(
                                name="dec_units_%d" % i,
                                initializer=uniform_initializer(
                                    self.init_scale
                                ),
                            ),
                            bias_attr=bias_attr,
                            forget_bias=forget_bias,
                        ),
                    )
                )
            else:
                self.dec_units.append(
                    self.add_sublayer(
                        "dec_units_%d" % i,
                        BasicLSTMUnit(
                            hidden_size=self.hidden_size,
                            input_size=self.hidden_size,
                            param_attr=ParamAttr(
                                name="dec_units_%d" % i,
                                initializer=uniform_initializer(
                                    self.init_scale
                                ),
                            ),
                            bias_attr=bias_attr,
                            forget_bias=forget_bias,
                        ),
                    )
                )

        # Projects encoder outputs into the attention "memory" space.
        self.attn_fc = paddle.nn.Linear(
            self.hidden_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="self_attn_fc",
                initializer=paddle.nn.initializer.Uniform(
                    low=-self.init_scale, high=self.init_scale
                ),
            ),
            bias_attr=False,
        )

        # Fuses [attention context ; decoder hidden] back to hidden_size.
        self.concat_fc = paddle.nn.Linear(
            2 * self.hidden_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="self_concat_fc",
                initializer=paddle.nn.initializer.Uniform(
                    low=-self.init_scale, high=self.init_scale
                ),
            ),
            bias_attr=False,
        )

        # Output projection to target vocabulary logits.
        self.fc = paddle.nn.Linear(
            self.hidden_size,
            self.tar_vocab_size,
            weight_attr=paddle.ParamAttr(
                name="self_fc",
                initializer=paddle.nn.initializer.Uniform(
                    low=-self.init_scale, high=self.init_scale
                ),
            ),
            bias_attr=False,
        )

    def _transpose_batch_time(self, x):
        # Swap the first two axes: [batch, time, ...] <-> [time, batch, ...].
        return paddle.transpose(x, [1, 0] + list(range(2, len(x.shape))))

    def _merge_batch_beams(self, x):
        # [batch, beam, d] -> [batch * beam, d].
        return paddle.reshape(x, shape=(-1, x.shape[2]))

    def tile_beam_merge_with_batch(self, x):
        """Tile x beam_size times and fold the beam axis into the batch axis.

        NOTE(review): not referenced in this file's visible code —
        presumably used by an external beam-search driver; confirm.
        """
        x = paddle.unsqueeze(x, [1])  # [batch_size, 1, ...]
        expand_shape = [-1] * len(x.shape)
        expand_shape[1] = self.beam_size * x.shape[1]
        x = paddle.expand(x, expand_shape)  # [batch_size, beam_size, ...]
        x = paddle.transpose(
            x, list(range(2, len(x.shape))) + [0, 1]
        )  # [..., batch_size, beam_size]
        # use 0 to copy to avoid wrong shape
        x = paddle.reshape(
            x, shape=[0] * (len(x.shape) - 2) + [-1]
        )  # [..., batch_size * beam_size]
        x = paddle.transpose(
            x, [len(x.shape) - 1] + list(range(0, len(x.shape) - 1))
        )  # [batch_size * beam_size, ...]
        return x

    def _split_batch_beams(self, x):
        # [batch * beam, d] -> [batch, beam, d].
        return paddle.reshape(x, shape=(-1, self.beam_size, x.shape[1]))

    def _expand_to_beam_size(self, x):
        # Insert a beam axis and tile: [batch, ...] -> [batch, beam, ...].
        x = paddle.unsqueeze(x, [1])
        expand_shape = [-1] * len(x.shape)
        expand_shape[1] = self.beam_size * x.shape[1]
        x = paddle.expand(x, expand_shape)
        return x

    def _real_state(self, state, new_state, step_mask):
        # Per-row select: where step_mask == 1 take new_state, where
        # step_mask == 0 keep state (new*mask - state*(mask-1)).
        new_state = paddle.tensor.math._multiply_with_axis(
            new_state, step_mask, axis=0
        ) - paddle.tensor.math._multiply_with_axis(
            state, (step_mask - 1), axis=0
        )
        return new_state

    def _gather(self, x, indices, batch_pos):
        # Gather per-(batch, beam) entries of x via 2-D coordinates.
        topk_coordinates = paddle.stack([batch_pos, indices], axis=2)
        return paddle.gather_nd(x, topk_coordinates)

    def attention(self, query, enc_output, mask=None):
        """Dot-product attention of `query` over projected encoder outputs.

        Args:
            query: decoder hidden state for the current step.
            enc_output: encoder outputs, batch-major.
            mask: additive padding mask (0 for valid, -1 for padding
                positions, as built in forward); scaled by 1e9 so padded
                scores vanish after softmax.

        Returns:
            Attention-weighted sum of the *projected* memory (attn_fc
            output), with a singleton middle axis.
        """
        query = paddle.unsqueeze(query, [1])
        memory = self.attn_fc(enc_output)
        attn = paddle.matmul(query, memory, transpose_y=True)

        if mask is not None:
            # Transpose so the additive mask broadcasts over the right axis,
            # then restore the original layout.
            attn = paddle.transpose(attn, [1, 0, 2])
            attn = paddle.add(attn, mask * 1000000000)
            attn = paddle.transpose(attn, [1, 0, 2])
        weight = paddle.nn.functional.softmax(attn)
        weight_memory = paddle.matmul(weight, memory)

        return weight_memory

    def _change_size_for_array(self, func, array):
        """Apply `func` to every element of a tensor array, in place.

        NOTE(review): the prints below look like leftover debugging output —
        confirm they can be removed.
        """
        print(" ^" * 10, "_change_size_for_array")
        print("array : ", array)
        for i, state in enumerate(array):
            paddle.tensor.array_write(func(state), i, array)

        return array

    @to_static
    def forward(self, inputs):
        """Teacher-forced training step with attention and input feeding.

        Args:
            inputs: 5-tuple ``(src, tar, label, src_sequence_length,
                tar_sequence_length)`` — id tensors are batch-major.

        Returns:
            Scalar loss: per-token cross entropy masked by target length,
            averaged over the batch axis then summed over time.
        """
        src, tar, label, src_sequence_length, tar_sequence_length = inputs
        # Shrink to the actual batch (last batch of an epoch may be short).
        if src.shape[0] < self.batch_size:
            self.batch_size = src.shape[0]

        src_emb = self.src_embeder(self._transpose_batch_time(src))

        # NOTE: modify model code about `enc_hidden` and `enc_cell` to transforme dygraph code successfully.
        # Because nested list can't be transformed now.
        enc_hidden_0 = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        enc_hidden_0.stop_gradient = True
        enc_cell_0 = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        # NOTE(review): this sets enc_hidden_0.stop_gradient a second time;
        # it presumably was meant to be enc_cell_0.stop_gradient — confirm.
        enc_hidden_0.stop_gradient = True
        zero = paddle.zeros(shape=[1], dtype="int64")
        # Tensor arrays stand in for a per-layer list of states.
        enc_hidden = paddle.tensor.create_array(dtype="float32")
        enc_cell = paddle.tensor.create_array(dtype="float32")
        for i in range(self.num_layers):
            index = zero + i
            enc_hidden = paddle.tensor.array_write(
                enc_hidden_0, index, array=enc_hidden
            )
            enc_cell = paddle.tensor.array_write(
                enc_cell_0, index, array=enc_cell
            )

        max_seq_len = src_emb.shape[0]

        enc_len_mask = paddle.static.nn.sequence_lod.sequence_mask(
            src_sequence_length, maxlen=max_seq_len, dtype="float32"
        )
        # Additive attention mask: 0 at valid positions, -1 at padding.
        enc_padding_mask = enc_len_mask - 1.0
        enc_len_mask = paddle.transpose(enc_len_mask, [1, 0])

        enc_outputs = []
        # TODO: Because diff exits if call while_loop in static graph.
        # In while block, a Variable created in parent block participates in the calculation of gradient,
        # the gradient is wrong because each step scope always returns the same value generated by last step.
        for p in range(max_seq_len):
            # `0 + p` keeps the index a fresh value for the static transform.
            k = 0 + p
            enc_step_input = src_emb[k]
            step_mask = enc_len_mask[k]
            new_enc_hidden, new_enc_cell = [], []
            for i in range(self.num_layers):
                enc_new_hidden, enc_new_cell = self.enc_units[i](
                    enc_step_input, enc_hidden[i], enc_cell[i]
                )
                # Inter-layer dropout on the hidden output feeding layer i+1.
                if self.dropout is not None and self.dropout > 0.0:
                    enc_step_input = paddle.nn.functional.dropout(
                        enc_new_hidden,
                        p=self.dropout,
                        mode='upscale_in_train',
                    )
                else:
                    enc_step_input = enc_new_hidden

                # Freeze states of already-finished (padded) sequences.
                new_enc_hidden.append(
                    self._real_state(enc_hidden[i], enc_new_hidden, step_mask)
                )
                new_enc_cell.append(
                    self._real_state(enc_cell[i], enc_new_cell, step_mask)
                )
            enc_outputs.append(enc_step_input)
            enc_hidden, enc_cell = new_enc_hidden, new_enc_cell

        enc_outputs = paddle.stack(enc_outputs)
        enc_outputs = self._transpose_batch_time(enc_outputs)

        # train
        # Input-feeding vector, initialized to zeros for the first step.
        input_feed = to_variable(
            np.zeros((self.batch_size, self.hidden_size), dtype='float32')
        )
        # NOTE: set stop_gradient here, otherwise grad var is null
        input_feed.stop_gradient = True
        dec_hidden, dec_cell = enc_hidden, enc_cell
        tar_emb = self.tar_embeder(self._transpose_batch_time(tar))
        max_seq_len = tar_emb.shape[0]
        dec_output = []

        for step_idx in range(max_seq_len):
            j = step_idx + 0
            step_input = tar_emb[j]
            # Input feeding: concatenate last step's attention output.
            step_input = paddle.concat([step_input, input_feed], 1)
            new_dec_hidden, new_dec_cell = [], []
            for i in range(self.num_layers):
                new_hidden, new_cell = self.dec_units[i](
                    step_input, dec_hidden[i], dec_cell[i]
                )
                new_dec_hidden.append(new_hidden)
                new_dec_cell.append(new_cell)
                if self.dropout is not None and self.dropout > 0.0:
                    step_input = paddle.nn.functional.dropout(
                        new_hidden,
                        p=self.dropout,
                        mode='upscale_in_train',
                    )
                else:
                    step_input = new_hidden
            # Attend over encoder outputs, fuse with the decoder state, and
            # feed the fused vector forward as the next step's input_feed.
            dec_att = self.attention(step_input, enc_outputs, enc_padding_mask)
            dec_att = paddle.squeeze(dec_att, [1])
            concat_att_out = paddle.concat([dec_att, step_input], 1)
            out = self.concat_fc(concat_att_out)
            input_feed = out
            dec_output.append(out)
            dec_hidden, dec_cell = new_dec_hidden, new_dec_cell

        dec_output = paddle.stack(dec_output)
        # Back to batch-major before projecting to vocab logits.
        dec_output = self.fc(self._transpose_batch_time(dec_output))
        loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=dec_output, label=label, soft_label=False
        )
        # NOTE(review): paddle.squeeze's keyword is `axis`; `axes` looks like
        # a leftover from the fluid.layers API — TODO confirm against the
        # installed paddle version.
        loss = paddle.squeeze(loss, axes=[2])
        max_tar_seq_len = paddle.shape(tar)[1]
        # Zero out loss contributions from target padding positions.
        tar_mask = paddle.static.nn.sequence_lod.sequence_mask(
            tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
        )
        loss = loss * tar_mask
        loss = paddle.mean(loss, axis=[0])
        # NOTE(review): BaseModel.forward uses paddle.sum here; the legacy
        # fluid.layers.reduce_sum call is inconsistent — confirm which API
        # the target paddle version still exposes.
        loss = fluid.layers.reduce_sum(loss)

        return loss