test_rnn_decode_api.py 27.8 KB
Newer Older
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
G
Guo Sheng 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

17
import random
G
Guo Sheng 已提交
18
import unittest
19
import numpy as np
G
Guo Sheng 已提交
20

21 22 23 24 25 26 27
import paddle
import paddle.nn as nn
from paddle import Model, set_device
from paddle.static import InputSpec as Input
from paddle.fluid.dygraph import Layer
from paddle.nn import BeamSearchDecoder, dynamic_decode

G
Guo Sheng 已提交
28 29 30 31 32 33 34
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core

from paddle.fluid.executor import Executor
from paddle.fluid import framework

35 36
# Switch Paddle to static-graph mode at import time; the model code below
# builds its graph with fluid Programs and runs it through an Executor.
paddle.enable_static()

G
Guo Sheng 已提交
37

38
class EncoderCell(layers.RNNCell):
    """Stack of ``num_layers`` LSTM cells applied one after another per step.

    When ``dropout_prob`` is greater than zero, dropout is applied to the
    output of every layer before it is fed to the next one.
    """

    def __init__(self, num_layers, hidden_size, dropout_prob=0.):
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.lstm_cells = [
            layers.LSTMCell(hidden_size) for _ in range(num_layers)
        ]

    def call(self, step_input, states):
        """Run a single time step through all stacked cells.

        Args:
            step_input: Input fed to the first cell for this step.
            states: Indexable collection; ``states[i]`` is the state handed
                to the i-th cell.

        Returns:
            ``(output, new_states)`` where ``output`` is the (optionally
            dropped-out) output of the last cell and ``new_states`` is the
            list of new states, one per cell.
        """
        new_states = []
        for i, cell in enumerate(self.lstm_cells):
            out, new_state = cell(step_input, states[i])
            # The output of this layer becomes the input of the next one.
            if self.dropout_prob > 0:
                step_input = layers.dropout(out, self.dropout_prob)
            else:
                step_input = out
            new_states.append(new_state)
        return step_input, new_states

    @property
    def state_shape(self):
        """List with the ``state_shape`` of every stacked LSTM cell."""
        return [cell.state_shape for cell in self.lstm_cells]


61
class DecoderCell(layers.RNNCell):
    """Stacked LSTM cells followed by dot-product attention over the encoder.

    The state consumed and produced by ``call`` is a pair
    ``[lstm_states, input_feed]``, where ``input_feed`` is the attention
    output of the previous step.
    """

    def __init__(self, num_layers, hidden_size, dropout_prob=0.):
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.lstm_cells = [
            layers.LSTMCell(hidden_size) for _ in range(num_layers)
        ]

    def attention(self, hidden, encoder_output, encoder_padding_mask):
        """Attend over ``encoder_output`` using ``hidden`` as the query.

        Args:
            hidden: Output of the top LSTM layer for the current step.
            encoder_output: Encoder outputs to attend over.
            encoder_padding_mask: Optional additive mask applied to the
                attention scores before the softmax; skipped when ``None``.

        Returns:
            The attention context concatenated with ``hidden`` and projected
            to ``hidden_size`` by a bias-free fc layer.
        """
        # Project the query to the encoder feature size so the two can be
        # multiplied together.
        query = layers.fc(hidden,
                          size=encoder_output.shape[-1],
                          bias_attr=False)
        attn_scores = layers.matmul(
            layers.unsqueeze(query, [1]), encoder_output, transpose_y=True)
        if encoder_padding_mask is not None:
            attn_scores = layers.elementwise_add(attn_scores,
                                                 encoder_padding_mask)
        attn_scores = layers.softmax(attn_scores)
        context = layers.squeeze(
            layers.matmul(attn_scores, encoder_output), [1])
        merged = layers.concat([context, hidden], 1)
        return layers.fc(merged, size=self.hidden_size, bias_attr=False)

    def call(self,
             step_input,
             states,
             encoder_output,
             encoder_padding_mask=None):
        """Run one decoding step.

        Args:
            step_input: Input for the current step.
            states: Pair ``(lstm_states, input_feed)``.
            encoder_output: Encoder outputs passed on to ``attention``.
            encoder_padding_mask: Optional mask passed on to ``attention``.

        Returns:
            ``(out, [new_lstm_states, out])``; the attention output is both
            the step output and the next step's ``input_feed``.
        """
        lstm_states, input_feed = states
        new_lstm_states = []
        # Input feeding: append the previous attention output to the input.
        step_input = layers.concat([step_input, input_feed], 1)
        for i, cell in enumerate(self.lstm_cells):
            out, new_lstm_state = cell(step_input, lstm_states[i])
            if self.dropout_prob > 0:
                step_input = layers.dropout(out, self.dropout_prob)
            else:
                step_input = out
            new_lstm_states.append(new_lstm_state)
        out = self.attention(step_input, encoder_output, encoder_padding_mask)
        return out, [new_lstm_states, out]


103 104 105
class Encoder(object):
    """Runs an ``EncoderCell`` over a whole source sequence via ``layers.rnn``."""

    def __init__(self, num_layers, hidden_size, dropout_prob=0.):
        self.encoder_cell = EncoderCell(num_layers, hidden_size, dropout_prob)

    def __call__(self, src_emb, src_sequence_length):
        """Encode the embedded source sequence.

        Args:
            src_emb: Embedded source tokens used as the RNN inputs.
            src_sequence_length: Lengths forwarded to ``layers.rnn`` as
                ``sequence_length``.

        Returns:
            ``(encoder_output, encoder_final_state)`` as returned by
            ``layers.rnn``.
        """
        encoder_output, encoder_final_state = layers.rnn(
            cell=self.encoder_cell,
            inputs=src_emb,
            sequence_length=src_sequence_length,
            is_reverse=False)
        return encoder_output, encoder_final_state


class Decoder(object):
    """Decodes with a ``DecoderCell`` under one of four strategies.

    Supported ``decoding_strategy`` values are ``"train_greedy"``,
    ``"infer_sample"``, ``"infer_greedy"`` and ``"beam_search"``.
    """

    def __init__(self,
                 num_layers,
                 hidden_size,
                 dropout_prob,
                 decoding_strategy="infer_sample",
                 max_decoding_length=20):
        self.decoder_cell = DecoderCell(num_layers, hidden_size, dropout_prob)
        self.decoding_strategy = decoding_strategy
        # Teacher forcing is bounded by the target length, so no step limit
        # is passed to dynamic_decode in that mode.
        self.max_decoding_length = None if (
            self.decoding_strategy == "train_greedy") else max_decoding_length

    def __call__(self, decoder_initial_states, encoder_output,
                 encoder_padding_mask, **kwargs):
        """Build the decoder for the configured strategy and run it.

        Args:
            decoder_initial_states: Passed to ``dynamic_decode`` as ``inits``.
            encoder_output: Encoder outputs handed to the decoder cell.
            encoder_padding_mask: Attention mask handed to the decoder cell.
            **kwargs: ``output_layer`` (optional) is popped and used as the
                decoder's ``output_fn``; the remaining entries are forwarded
                to the helper or to ``BeamSearchDecoder``.

        Returns:
            ``(decoder_output, decoder_final_state, dec_seq_lengths)``.

        Raises:
            ValueError: If ``decoding_strategy`` is not a supported value.
        """
        strategy = self.decoding_strategy
        # Fail early with a clear message instead of hitting an unbound
        # `helper` further down.
        if strategy not in ("train_greedy", "infer_sample", "infer_greedy",
                            "beam_search"):
            raise ValueError(
                "Unsupported decoding_strategy: %r" % (strategy, ))

        output_layer = kwargs.pop("output_layer", None)
        use_beam_search = strategy == "beam_search"

        if use_beam_search:
            beam_size = kwargs.get("beam_size", 4)
            # Repeat the encoder tensors once per beam so they line up with
            # the merged (batch * beam) dimension used while decoding.
            encoder_output = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
                encoder_output, beam_size)
            encoder_padding_mask = layers.BeamSearchDecoder.tile_beam_merge_with_batch(
                encoder_padding_mask, beam_size)
            decoder = layers.BeamSearchDecoder(
                cell=self.decoder_cell, output_fn=output_layer, **kwargs)
        else:
            if strategy == "train_greedy":
                # Teacher-forcing MLE pre-training.
                helper = layers.TrainingHelper(**kwargs)
            elif strategy == "infer_sample":
                helper = layers.SampleEmbeddingHelper(**kwargs)
            else:  # "infer_greedy"
                helper = layers.GreedyEmbeddingHelper(**kwargs)
            decoder = layers.BasicDecoder(
                self.decoder_cell, helper, output_fn=output_layer)

        (decoder_output, decoder_final_state,
         dec_seq_lengths) = layers.dynamic_decode(
             decoder,
             inits=decoder_initial_states,
             max_step_num=self.max_decoding_length,
             encoder_output=encoder_output,
             encoder_padding_mask=encoder_padding_mask,
             # Beam search deliberately runs with impute_finished=False
             # (for test coverage) and is_test=True.
             impute_finished=not use_beam_search,
             is_test=use_beam_search,
             return_length=True)
        return decoder_output, decoder_final_state, dec_seq_lengths


class Seq2SeqModel(object):
    """Seq2Seq model: RNN encoder-decoder with attention"""

    def __init__(self,
                 num_layers,
                 hidden_size,
                 dropout_prob,
                 src_vocab_size,
                 trg_vocab_size,
                 start_token,
                 end_token,
                 decoding_strategy="infer_sample",
                 max_decoding_length=20,
                 beam_size=4):
        self.start_token, self.end_token = start_token, end_token
        self.max_decoding_length, self.beam_size = max_decoding_length, beam_size
        # Embedding lookups are kept as callables so the decoder helpers can
        # receive the target embedding as `embedding_fn`; the fixed parameter
        # names make repeated calls refer to the same parameter.
        self.src_embeder = lambda x: fluid.embedding(
            input=x,
            size=[src_vocab_size, hidden_size],
            dtype="float32",
            param_attr=fluid.ParamAttr(name="source_embedding"))
        self.trg_embeder = lambda x: fluid.embedding(
            input=x,
            size=[trg_vocab_size, hidden_size],
            dtype="float32",
            param_attr=fluid.ParamAttr(name="target_embedding"))
        self.encoder = Encoder(num_layers, hidden_size, dropout_prob)
        self.decoder = Decoder(num_layers, hidden_size, dropout_prob,
                               decoding_strategy, max_decoding_length)
        # Bias-free projection of the last dimension onto the target vocab.
        self.output_layer = lambda x: layers.fc(
            x,
            size=trg_vocab_size,
            num_flatten_dims=len(x.shape) - 1,
            param_attr=fluid.ParamAttr(name="output_w"),
            bias_attr=False)

    def __call__(self, src, src_length, trg=None, trg_length=None):
        """Encode ``src`` and decode with the configured strategy.

        Args:
            src: Source token ids.
            src_length: Source sequence lengths.
            trg: Target token ids; only read for ``"train_greedy"``.
            trg_length: Target lengths; only read for ``"train_greedy"``.

        Returns:
            For ``"beam_search"`` the raw decoder output; otherwise
            ``(probs, samples, sample_length)`` where ``probs`` is the
            softmax of the decoder's cell outputs.
        """
        # encoder
        encoder_output, encoder_final_state = self.encoder(
            self.src_embeder(src), src_length)

        decoder_initial_states = [
            encoder_final_state,
            self.decoder.decoder_cell.get_initial_states(
                batch_ref=encoder_output, shape=[encoder_output.shape[-1]])
        ]
        src_mask = layers.sequence_mask(
            src_length, maxlen=layers.shape(src)[1], dtype="float32")
        # Additive attention mask: a mask value of 1 maps to 0 and a mask
        # value of 0 maps to -1e9.
        encoder_padding_mask = (src_mask - 1.0) * 1e9
        encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])

        # decoder
        strategy = self.decoder.decoding_strategy
        if strategy == "train_greedy":
            decoder_kwargs = {
                "inputs": self.trg_embeder(trg),
                "sequence_length": trg_length,
            }
        elif strategy == "beam_search":
            decoder_kwargs = {
                "embedding_fn": self.trg_embeder,
                "beam_size": self.beam_size,
                "start_token": self.start_token,
                "end_token": self.end_token
            }
        else:
            decoder_kwargs = {
                "embedding_fn": self.trg_embeder,
                "start_tokens": layers.fill_constant_batch_size_like(
                    input=encoder_output,
                    shape=[-1],
                    dtype=src.dtype,
                    value=self.start_token),
                "end_token": self.end_token
            }
        decoder_kwargs["output_layer"] = self.output_layer

        (decoder_output, decoder_final_state,
         dec_seq_lengths) = self.decoder(decoder_initial_states, encoder_output,
                                         encoder_padding_mask, **decoder_kwargs)
        if strategy == "beam_search":  # for inference
            return decoder_output
        logits = decoder_output.cell_outputs
        samples = decoder_output.sample_ids
        sample_length = dec_seq_lengths
        probs = layers.softmax(logits)
        return probs, samples, sample_length


class PolicyGradient(object):
    """Policy-gradient training objective optimised with Adam."""

    def __init__(self, lr=None):
        self.lr = lr

    def learn(self, act_prob, action, reward, length=None):
        """Add the policy-gradient loss and its Adam update to the program.

        Args:
            act_prob: Action probabilities produced by the model.
            action: Sampled actions, used as cross-entropy labels.
            reward: Variable that receives the output of ``reward_func``
                through ``py_func``.
            length: Sample lengths. Forwarded to ``reward_func``; when not
                ``None`` the summed cost is divided by the summed lengths,
                otherwise the mean cost is used.

        Returns:
            The scalar cost variable being minimised.
        """
        # Compute the reward in Python from the sampled actions.
        self.reward = fluid.layers.py_func(
            func=reward_func, x=[action, length], out=reward)
        neg_log_prob = layers.cross_entropy(act_prob, action)
        cost = neg_log_prob * reward
        if length is not None:
            cost = layers.reduce_sum(cost) / layers.reduce_sum(length)
        else:
            cost = layers.reduce_mean(cost)
        optimizer = fluid.optimizer.Adam(self.lr)
        optimizer.minimize(cost)
        return cost


def reward_func(samples, sample_length):
    """Toy reward used by the policy-gradient tests.

    Each sample gets a length reward ``5 - |length - 5|``. Samples longer
    than one token whose first ``length - 1`` tokens all equal the first
    token are penalised by 10. The per-sample reward is then repeated along
    the time axis and zeroed at positions at or beyond the sample's length.

    Args:
        samples: Array-like of sampled token ids, indexed as
            ``samples[i][t]``.
        sample_length: Array-like with the length of every sample.

    Returns:
        A ``float32`` array of shape ``[batch_size, max(sample_length)]``.
    """

    def discount_reward(reward, sequence_length, discount=1.):
        return discount_reward_1d(reward, sequence_length, discount)

    def discount_reward_1d(reward, sequence_length, discount=1., dtype=None):
        if sequence_length is None:
            raise ValueError(
                'sequence_length must not be `None` for 1D reward.')
        reward = np.array(reward)
        sequence_length = np.array(sequence_length)
        batch_size = reward.shape[0]
        max_seq_length = np.max(sequence_length)
        dtype = dtype or reward.dtype
        if discount == 1.:
            dmat = np.ones([batch_size, max_seq_length], dtype=dtype)
        else:
            steps = np.tile(np.arange(max_seq_length), [batch_size, 1])
            mask = np.asarray(
                steps < (sequence_length - 1)[:, None], dtype=dtype)
            # Make each row = [discount, ..., discount, 1, ..., 1]
            dmat = mask * discount + (1 - mask)
            # Reversed cumulative product: entry t holds the product of all
            # factors from t to the end of the row.
            dmat = np.cumprod(dmat[:, ::-1], axis=1)[:, ::-1]
        disc_reward = dmat * reward[:, None]
        return mask_sequences(disc_reward, sequence_length, dtype=dtype)

    def mask_sequences(sequence, sequence_length, dtype=None, time_major=False):
        sequence = np.array(sequence)
        sequence_length = np.array(sequence_length)
        rank = sequence.ndim
        if rank < 2:
            raise ValueError("`sequence` must be 2D or higher order.")
        dtype = dtype or sequence.dtype
        if time_major:
            # Work batch-major internally; swapping the first two axes works
            # for any rank >= 2.
            sequence = np.swapaxes(sequence, 0, 1)
        # Sizes are read after the optional swap so they always describe the
        # batch-major layout.
        batch_size = sequence.shape[0]
        max_time = sequence.shape[1]
        steps = np.tile(np.arange(max_time), [batch_size, 1])
        mask = np.asarray(steps < sequence_length[:, None], dtype=dtype)
        # Add trailing singleton axes so the mask broadcasts over features.
        for _ in range(2, rank):
            mask = np.expand_dims(mask, -1)
        sequence = sequence * mask
        if time_major:
            sequence = np.swapaxes(sequence, 0, 1)
        return sequence

    samples = np.array(samples)
    sample_length = np.array(sample_length)
    # length reward
    reward = (5 - np.abs(sample_length - 5)).astype("float32")
    # Punish degenerate samples that repeat a single word, which indicates
    # being trapped in a local minimum. Drawing several samples with beam
    # search may also help to avoid this.
    for i, length in enumerate(sample_length):
        if length > 1 and np.all(samples[i][:length - 1] == samples[i][0]):
            reward[i] -= 10
    return discount_reward(reward, sample_length, discount=1.).astype("float32")


class MLE(object):
    """Maximum-likelihood objective for teacher-forced training."""

    def __init__(self, lr=None):
        self.lr = lr

    def learn(self, probs, label, weight=None, length=None):
        """Add the masked cross-entropy loss and its Adam update.

        The per-token cross entropy between ``probs`` and ``label`` is
        multiplied by a sequence mask built from ``length``, averaged over
        dimension 0 and summed over what remains. ``weight`` is accepted
        but not used.

        Returns:
            The loss variable being minimised.
        """
        token_loss = layers.cross_entropy(
            input=probs, label=label, soft_label=False)
        valid_steps = layers.sequence_mask(
            length, maxlen=layers.shape(probs)[1], dtype="float32")
        masked_loss = token_loss * valid_steps
        total_loss = layers.reduce_sum(
            layers.reduce_mean(masked_loss, dim=[0]))
        fluid.optimizer.Adam(self.lr).minimize(total_loss)
        return total_loss


class SeqPGAgent(object):
    """Builds the training and prediction programs for a sequence model.

    The constructor creates (or adopts) a main and a startup program,
    optionally seeds both, and builds the model plus the learning algorithm
    inside them. ``predict`` and ``learn`` run the resulting programs with
    the supplied executor.
    """

    def __init__(self,
                 model_cls,
                 alg_cls=PolicyGradient,
                 model_hparams=None,
                 alg_hparams=None,
                 executor=None,
                 main_program=None,
                 startup_program=None,
                 seed=None):
        """
        Args:
            model_cls: Model class, instantiated with ``**model_hparams``.
            alg_cls: Algorithm class, instantiated with ``**alg_hparams``.
            model_hparams: Keyword arguments for ``model_cls``; ``None``
                means no arguments.
            alg_hparams: Keyword arguments for ``alg_cls``; ``None`` means
                no arguments.
            executor: Executor used by ``predict`` and ``learn``.
            main_program: Program to build into; a new one if ``None``.
            startup_program: Startup program; a new one if ``None``.
            seed: When not ``None``, set as ``random_seed`` on both programs.
        """
        # None instead of `{}` as default avoids sharing one dict object
        # between all calls.
        model_hparams = {} if model_hparams is None else model_hparams
        alg_hparams = {} if alg_hparams is None else alg_hparams
        self.main_program = fluid.Program(
        ) if main_program is None else main_program
        self.startup_program = fluid.Program(
        ) if startup_program is None else startup_program
        if seed is not None:
            self.main_program.random_seed = seed
            self.startup_program.random_seed = seed
        self.build_program(model_cls, alg_cls, model_hparams, alg_hparams)
        self.executor = executor

    def build_program(self, model_cls, alg_cls, model_hparams, alg_hparams):
        """Declare the inputs, build model and algorithm, derive pred_program."""
        with fluid.program_guard(self.main_program, self.startup_program):
            source = fluid.data(name="src", shape=[None, None], dtype="int64")
            source_length = fluid.data(
                name="src_sequence_length", shape=[None], dtype="int64")
            # only for teacher-forcing MLE training
            target = fluid.data(name="trg", shape=[None, None], dtype="int64")
            target_length = fluid.data(
                name="trg_sequence_length", shape=[None], dtype="int64")
            # Declared so the "label" input exists in the program; it is not
            # passed to the model or the algorithm here.
            label = fluid.data(
                name="label", shape=[None, None, 1], dtype="int64")
            self.model = model_cls(**model_hparams)
            self.alg = alg_cls(**alg_hparams)
            self.probs, self.samples, self.sample_length = self.model(
                source, source_length, target, target_length)
            self.samples.stop_gradient = True
            self.reward = fluid.data(
                name="reward",
                shape=[None, None],  # batch_size, seq_len
                dtype=self.probs.dtype)
            self.samples.stop_gradient = False
            self.cost = self.alg.learn(self.probs, self.samples, self.reward,
                                       self.sample_length)

        # to define the same parameters between different programs
        self.pred_program = self.main_program._prune_with_input(
            [source.name, source_length.name],
            [self.probs, self.samples, self.sample_length])

    def predict(self, feed_dict):
        """Run ``pred_program`` and return ``(samples, sample_length)``."""
        samples, sample_length = self.executor.run(
            self.pred_program,
            feed=feed_dict,
            fetch_list=[self.samples, self.sample_length])
        return samples, sample_length

    def learn(self, feed_dict, fetch_list):
        """Run ``main_program`` once and return the fetched results."""
        return self.executor.run(self.main_program,
                                 feed=feed_dict,
                                 fetch_list=fetch_list)


class TestDynamicDecode(unittest.TestCase):
    """Smoke tests that build and run Seq2SeqModel under each strategy."""

    def setUp(self):
        np.random.seed(123)
        self.model_hparams = {
            "num_layers": 2,
            "hidden_size": 32,
            "dropout_prob": 0.1,
            "src_vocab_size": 100,
            "trg_vocab_size": 100,
            "start_token": 0,
            "end_token": 1,
            "decoding_strategy": "infer_greedy",
            "max_decoding_length": 10
        }

        self.iter_num = iter_num = 2
        self.batch_size = batch_size = 4
        src_seq_len = 10
        trg_seq_len = 12
        # Token ids start at 2 so random data never contains the start (0)
        # or end (1) token. The order of the randint calls is kept fixed so
        # the seeded data stays reproducible.
        self.data = {
            "src": np.random.randint(
                2, self.model_hparams["src_vocab_size"],
                (iter_num * batch_size, src_seq_len)).astype("int64"),
            "src_sequence_length": np.random.randint(
                1, src_seq_len, (iter_num * batch_size, )).astype("int64"),
            "trg": np.random.randint(
                2, self.model_hparams["trg_vocab_size"],
                (iter_num * batch_size, trg_seq_len)).astype("int64"),
            "trg_sequence_length": np.random.randint(
                1, trg_seq_len, (iter_num * batch_size, )).astype("int64"),
            "label": np.random.randint(
                2, self.model_hparams["trg_vocab_size"],
                (iter_num * batch_size, trg_seq_len, 1)).astype("int64"),
        }

        place = core.CUDAPlace(0) if core.is_compiled_with_cuda(
        ) else core.CPUPlace()
        self.exe = Executor(place)

    def _batch_feed(self, iter_idx, keys):
        """Return the feed dict holding batch ``iter_idx`` of each key."""
        start = iter_idx * self.batch_size
        stop = start + self.batch_size
        return {key: self.data[key][start:stop] for key in keys}

    def _make_agent(self, decoding_strategy, alg_cls):
        """Build a seeded agent for the strategy and run its startup program."""
        paddle.enable_static()
        self.model_hparams["decoding_strategy"] = decoding_strategy
        agent = SeqPGAgent(
            model_cls=Seq2SeqModel,
            alg_cls=alg_cls,
            model_hparams=self.model_hparams,
            alg_hparams={"lr": 0.001},
            executor=self.exe,
            main_program=fluid.Program(),
            startup_program=fluid.Program(),
            seed=123)
        self.exe.run(agent.startup_program)
        return agent

    def test_mle_train(self):
        agent = self._make_agent("train_greedy", MLE)
        feed_keys = ("src", "src_sequence_length", "trg",
                     "trg_sequence_length", "label")
        for iter_idx in range(self.iter_num):
            # MLE does not compute a reward, so the cost is fetched twice to
            # keep the same report format as the policy-gradient tests.
            reward, cost = agent.learn(
                self._batch_feed(iter_idx, feed_keys),
                fetch_list=[agent.cost, agent.cost])
            print("iter_idx: %d, reward: %f, cost: %f" %
                  (iter_idx, reward.mean(), cost))

    def test_greedy_train(self):
        agent = self._make_agent("infer_greedy", PolicyGradient)
        for iter_idx in range(self.iter_num):
            reward, cost = agent.learn(
                self._batch_feed(iter_idx, ("src", "src_sequence_length")),
                fetch_list=[agent.reward, agent.cost])
            print("iter_idx: %d, reward: %f, cost: %f" %
                  (iter_idx, reward.mean(), cost))

    def test_sample_train(self):
        agent = self._make_agent("infer_sample", PolicyGradient)
        for iter_idx in range(self.iter_num):
            reward, cost = agent.learn(
                self._batch_feed(iter_idx, ("src", "src_sequence_length")),
                fetch_list=[agent.reward, agent.cost])
            print("iter_idx: %d, reward: %f, cost: %f" %
                  (iter_idx, reward.mean(), cost))

    def test_beam_search_infer(self):
        paddle.set_default_dtype("float32")
        paddle.enable_static()
        self.model_hparams["decoding_strategy"] = "beam_search"
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            source = fluid.data(name="src", shape=[None, None], dtype="int64")
            source_length = fluid.data(
                name="src_sequence_length", shape=[None], dtype="int64")
            model = Seq2SeqModel(**self.model_hparams)
            output = model(source, source_length)

        self.exe.run(startup_program)
        for iter_idx in range(self.iter_num):
            trans_ids = self.exe.run(
                program=main_program,
                feed=self._batch_feed(iter_idx,
                                      ("src", "src_sequence_length")),
                fetch_list=[output])[0]
            self.assertIsNotNone(trans_ids)
G
Guo Sheng 已提交
558 559


560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619
class ModuleApiTest(unittest.TestCase):
    """Base class that checks a module gives matching dygraph/static outputs.

    Subclasses provide ``model_init`` / ``model_forward`` (turned into a
    ``Layer`` subclass in ``setUpClass``), ``make_inputs`` and a ``setUp``
    that fills ``inputs``, ``attrs``, ``param_states`` and ``outputs``.
    """

    @classmethod
    def setUpClass(cls):
        # Remember the global RNG states so tearDownClass can restore them.
        cls._np_rand_state = np.random.get_state()
        cls._py_rand_state = random.getstate()
        cls._random_seed = 123
        np.random.seed(cls._random_seed)
        random.seed(cls._random_seed)

        # Build a Layer subclass whose __init__/forward delegate to the
        # hooks defined on the test class.
        cls.model_cls = type(cls.__name__ + "Model", (Layer, ), {
            "__init__": cls.model_init_wrapper(cls.model_init),
            "forward": cls.model_forward
        })

    @classmethod
    def tearDownClass(cls):
        np.random.set_state(cls._np_rand_state)
        random.setstate(cls._py_rand_state)

    @staticmethod
    def model_init_wrapper(func):
        """Wrap ``func`` so that ``Layer.__init__`` runs before it."""

        def __impl__(self, *args, **kwargs):
            Layer.__init__(self)
            func(self, *args, **kwargs)

        return __impl__

    @staticmethod
    def model_init(model, *args, **kwargs):
        raise NotImplementedError(
            "model_init acts as `Model.__init__`, thus must implement it")

    @staticmethod
    def model_forward(model, *args, **kwargs):
        return model.module(*args, **kwargs)

    def make_inputs(self):
        # TODO(guosheng): add default from `self.inputs`
        raise NotImplementedError(
            "model_inputs makes inputs for model, thus must implement it")

    def setUp(self):
        """
        For the model which wraps the module to be tested:
            Set input data by `self.inputs` list
            Set init argument values by `self.attrs` list/dict
            Set model parameter values by `self.param_states` dict
            Set expected output data by `self.outputs` list
        We can create a model instance and run once with these.
        """
        self.inputs = []
        self.attrs = {}
        self.param_states = {}
        self.outputs = []

    def _calc_output(self, place, mode="test", dygraph=True):
        """Build the model in the requested graph mode and predict one batch.

        ``mode`` is accepted but not used by this implementation.
        """
        if dygraph:
            fluid.enable_dygraph(place)
        else:
            fluid.disable_dygraph()
        # Re-seed before every build so both graph modes start from the same
        # random state.
        gen = paddle.seed(self._random_seed)
        gen._is_init_py = False
        paddle.framework.random._manual_program_seed(self._random_seed)
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            # `attrs` may be a dict of keyword arguments or a positional list.
            if isinstance(self.attrs, dict):
                layer = self.model_cls(**self.attrs)
            else:
                layer = self.model_cls(*self.attrs)
            model = Model(layer, inputs=self.make_inputs())
            model.prepare()
            if self.param_states:
                model.load(self.param_states, optim_state=None)
            return model.predict_batch(self.inputs)

    def check_output_with_place(self, place, mode="test"):
        """Compare dygraph vs. static outputs, then vs. ``self.outputs`` if set."""
        dygraph_output = self._calc_output(place, mode, dygraph=True)
        stgraph_output = self._calc_output(place, mode, dygraph=False)
        expect_output = getattr(self, "outputs", None)
        for actual_t, expect_t in zip(dygraph_output, stgraph_output):
            self.assertTrue(np.allclose(actual_t, expect_t, rtol=1e-5, atol=0))
        if expect_output:
            for actual_t, expect_t in zip(dygraph_output, expect_output):
                self.assertTrue(
                    np.allclose(
                        actual_t, expect_t, rtol=1e-5, atol=0))

    def check_output(self):
        """Run the output check on CPU, and on GPU when compiled with CUDA."""
        devices = ["CPU", "GPU"] if fluid.is_compiled_with_cuda() else ["CPU"]
        for device in devices:
            place = set_device(device)
            self.check_output_with_place(place)


class TestBeamSearch(ModuleApiTest):
    """Checks ``BeamSearchDecoder`` + ``dynamic_decode`` in both graph modes."""

    def setUp(self):
        paddle.set_default_dtype("float64")
        state_shape = (8, 32)
        # Two random float64 arrays: the initial hidden and cell states.
        self.inputs = [
            np.random.random(state_shape).astype("float64")
            for _ in range(2)
        ]
        self.outputs = None
        self.attrs = dict(vocab_size=100, embed_dim=32, hidden_size=32)
        self.param_states = {}

    @staticmethod
    def model_init(self,
                   vocab_size,
                   embed_dim,
                   hidden_size,
                   bos_id=0,
                   eos_id=1,
                   beam_size=2,
                   max_step_num=2):
        """Create the embedding, output projection, cell and decoder."""
        self.max_step_num = max_step_num
        # Sub-layers are created in a fixed order: embedding, projection,
        # cell.
        token_embedding = paddle.fluid.dygraph.Embedding(
            size=[vocab_size, embed_dim], dtype="float64")
        vocab_projection = nn.Linear(hidden_size, vocab_size)
        lstm_cell = nn.LSTMCell(embed_dim, hidden_size)
        self.beam_search_decoder = BeamSearchDecoder(
            lstm_cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=token_embedding,
            output_fn=vocab_projection)

    @staticmethod
    def model_forward(model, init_hidden, init_cell):
        """Decode from the given initial states and return the first result."""
        decoded = dynamic_decode(
            model.beam_search_decoder, [init_hidden, init_cell],
            max_step_num=model.max_step_num,
            impute_finished=True,
            is_test=True)
        return decoded[0]

    def make_inputs(self):
        """Describe both inputs; the last dim is taken from the data."""
        return [
            Input([None, array.shape[-1]], "float64", name)
            for array, name in zip(self.inputs[:2],
                                   ("init_hidden", "init_cell"))
        ]

    def test_check_output(self):
        self.check_output()


G
Guo Sheng 已提交
709 710
# Run all test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()