import os
import sys
import gzip
import logging
import numpy as np
import click

import reader
import paddle.v2 as paddle
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.WARNING)


def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
    """Run beam-search generation for one batch and write results to fout.

    For every source sample, writes the source sentence (start/end marks
    stripped) on one line, then ``beam_size`` lines of the form
    ``"<log-prob>\t<generated words>"``, then a blank separator line.

    :param inferer: object exposing ``infer(input=..., field=["prob", "id"])``
        that returns ``[probs, ids]`` where ``probs[i][j]`` is the score of
        the j-th beam for sample i and ``ids`` is the flattened id stream of
        all generated sentences, each terminated by -1.
    :param test_batch: list of samples; ``sample[0]`` is the source word-id
        sequence including the begin/end marks.
    :param beam_size: number of generated candidates per source sample.
    :param id_to_text: dict mapping a word id to its surface string.
    :param fout: writable text stream for the generation results.
    """
    beam_result = inferer.infer(input=test_batch, field=["prob", "id"])
    # Each generated sentence in the flattened id stream ends with -1, so the
    # number of -1 markers must equal (samples * beam_size).
    gen_sen_idx = np.where(beam_result[1] == -1)[0]
    assert len(gen_sen_idx) == len(test_batch) * beam_size, ("%d vs. %d" % (
        len(gen_sen_idx), len(test_batch) * beam_size))

    start_pos, end_pos = 1, 0
    for i, sample in enumerate(test_batch):
        fout.write("%s\n" % (
            " ".join([id_to_text[w] for w in sample[0][1:-1]])
        ))  # skip the start and ending mark when print the source sentence
        # `range` instead of Py2-only `xrange`: beam_size is small, so the
        # materialized list is negligible and the code runs on Py2 and Py3.
        for j in range(beam_size):
            end_pos = gen_sen_idx[i * beam_size + j]
            # ids[start_pos : end_pos - 1] drops the <e> mark before the -1
            # terminator of this candidate sentence.
            fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
                id_to_text[w] for w in beam_result[1][start_pos:end_pos - 1]))))
            # skip the -1 terminator and the <s> mark of the next candidate
            start_pos = end_pos + 2
        fout.write("\n")
        fout.flush()  # was `fout.flush` (bare attribute) — never flushed


@click.command("generate")
@click.option(
    "--model_path",
    default="",
    help="The path of the trained model for generation.")
@click.option(
    "--word_dict_path", required=True, help="The path of word dictionary.")
@click.option(
    "--test_data_path",
    required=True,
    help="The path of input data for generation.")
@click.option(
    "--batch_size",
    default=1,
    help="The number of testing examples in one forward pass in generation.")
@click.option(
    "--beam_size", default=5, help="The beam expansion in beam search.")
@click.option(
    "--save_file",
    required=True,
    help="The file path to save the generated results.")
@click.option(
    "--use_gpu", default=False, help="Whether to use GPU in generation.")
def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
             save_file, use_gpu):
    """Generate sequences for the test data with a trained model.

    Loads the trained parameters, rebuilds the encoder-decoder network in
    generation mode, runs beam search over the test data batch by batch, and
    writes the candidates to ``save_file`` via :func:`infer_a_batch`.
    """
    assert os.path.exists(model_path), "The given model does not exist."
    assert os.path.exists(test_data_path), "The given test data does not exist."

    with gzip.open(model_path, "r") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # Build the id -> word mapping; each dictionary line is
    # "<word>\t..." and the line number is the word id.
    id_to_text = {}
    assert os.path.exists(
        word_dict_path), "The given word dictionary path does not exist."
    with open(word_dict_path, "r") as f:
        for i, line in enumerate(f):
            id_to_text[i] = line.strip().split("\t")[0]

    paddle.init(use_gpu=use_gpu, trainer_count=1)
    # NOTE(review): the network hyper-parameters below (emb_dim, depths,
    # hidden dims, max_length) must match those used at training time —
    # confirm against the training configuration.
    beam_gen = encoder_decoder_network(
        word_count=len(id_to_text),
        emb_dim=512,
        encoder_depth=3,
        encoder_hidden_dim=512,
        decoder_depth=3,
        decoder_hidden_dim=512,
        bos_id=0,
        eos_id=1,
        max_length=9,
        beam_size=beam_size,
        is_generating=True)

    inferer = paddle.inference.Inference(
        output_layer=beam_gen, parameters=parameters)

    test_batch = []
    with open(save_file, "w") as fout:
        for idx, item in enumerate(
                reader.gen_reader(test_data_path, word_dict_path)()):
            test_batch.append([item])
            if len(test_batch) == batch_size:
                infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout)
                test_batch = []

        # Flush the last, possibly partial, batch.
        if len(test_batch):
            infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout)
            test_batch = []


if __name__ == "__main__":
    # click parses the CLI options and invokes generate() with them.
    generate()