# generate.py: sequence generation for NMT.
import gzip
import logging
import os

import numpy as np

from network_conf import seq2seq_net

# Raise the threshold of the "paddle" logger so that it only emits records of
# level WARNING and above.
logger = logging.getLogger("paddle")
logger.setLevel(logging.WARNING)


def infer_a_batch(inferer, test_batch, beam_size, src_dict, trg_dict):
    """
    Run inference on one batch and print the generated sentences.

    For every sample the source sentence is printed first, followed by one
    line per candidate in the form ``<score>\\t<generated words>``, and
    finally a blank separator.

    :param inferer: object whose ``infer(input=..., field=["prob", "id"])``
        returns a pair: the candidate scores, indexed ``[sample][candidate]``,
        and a flat sequence of generated word ids in which every candidate
        is terminated by ``-1``
    :param test_batch: list of samples; ``sample[0]`` is the sequence of
        source word ids including the start and ending marks
    :type test_batch: list
    :param beam_size: number of candidates generated for each sample
    :type beam_size: int
    :param src_dict: mapping from source word id to word
    :param trg_dict: mapping from target word id to word
    """
    beam_result = inferer.infer(input=test_batch, field=["prob", "id"])
    scores = beam_result[0]
    ids = beam_result[1]

    # Positions of the -1 terminators; there must be exactly one per candidate.
    gen_sen_idx = np.where(ids == -1)[0]
    assert len(gen_sen_idx) == len(test_batch) * beam_size

    # Start at 1 to skip the first element of the first candidate
    # (presumably its start mark -- confirm against the network output).
    start_pos = 1
    for i, sample in enumerate(test_batch):
        # skip the start and ending mark when printing the source sentence
        print(" ".join(src_dict[w] for w in sample[0][1:-1]))

        for j in range(beam_size):
            end_pos = gen_sen_idx[i * beam_size + j]
            words = " ".join(trg_dict[w] for w in ids[start_pos:end_pos])
            print("%.4f\t%s" % (scores[i][j], words))
            # Jump over the -1 terminator and the first element of the next
            # candidate.
            start_pos = end_pos + 2

        print("\n")


def generate(source_dict_dim, target_dict_dim, model_path, beam_size,
             batch_size):
    """
    Sequence generation for NMT.

    Loads the trained parameters, builds the generating network and prints
    the generated sentences for the WMT14 data returned by
    ``paddle.dataset.wmt14.gen``, one batch at a time.

    :param source_dict_dim: size of source dictionary
    :type source_dict_dim: int
    :param target_dict_dim: size of target dictionary
    :type target_dict_dim: int
    :param model_path: path of the trained model (a gzip-compressed tar file)
    :type model_path: string
    :param beam_size: the expansion width in each generation step
    :type beam_size: int
    :param batch_size: the number of samples passed to one inference call
    :type batch_size: int
    :raises AssertionError: if ``model_path`` does not exist
    """
    # Raised explicitly (instead of via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is unchanged.
    if not os.path.exists(model_path):
        raise AssertionError("trained model does not exist.")

    # NOTE(review): ``paddle`` is not imported in the visible part of this
    # file; presumably ``import paddle.v2 as paddle`` is needed -- confirm.

    # step 1: prepare dictionary
    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(source_dict_dim)

    # step 2: load the trained model
    paddle.init(use_gpu=False, trainer_count=1)
    with gzip.open(model_path, "rb") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)
    beam_gen = seq2seq_net(
        source_dict_dim,
        target_dict_dim,
        beam_size=beam_size,
        max_length=100,
        is_generating=True)
    inferer = paddle.inference.Inference(
        output_layer=beam_gen, parameters=parameters)

    # step 3: iterating over the testing dataset
    test_batch = []
    for item in paddle.dataset.wmt14.gen(source_dict_dim)():
        # Only the first field of each item is fed to the network.
        test_batch.append([item[0]])
        if len(test_batch) == batch_size:
            infer_a_batch(inferer, test_batch, beam_size, src_dict, trg_dict)
            test_batch = []

    # Flush the last, possibly incomplete, batch.
    if test_batch:
        infer_a_batch(inferer, test_batch, beam_size, src_dict, trg_dict)


# Script entry point: generate with the default dictionary sizes, batch size,
# beam width and model file.
if __name__ == "__main__":
    generate(
        source_dict_dim=30000,
        target_dict_dim=30000,
        batch_size=20,
        beam_size=3,
        model_path="models/nmt_without_att_params_batch_00100.tar.gz")