generate.py 2.4 KB
Newer Older
C
caoying03 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
#!/usr/bin/env python
import os
from network_conf import *


def infer_a_batch(inferer, test_batch, beam_size, src_dict, trg_dict):
    """
    Run beam search on one batch and print the candidates of every sample.

    For each sample the source sentence is printed first, followed by
    ``beam_size`` lines of the form ``prob = <score>: <generated words>``.
    Nothing is returned.

    :param inferer: object whose ``infer(input=..., field=["prob", "id"])``
        result holds the scores at index 0 and the generated ids at index 1
    :param test_batch: samples to decode; ``sample[0]`` of every sample is
        the sequence of source word ids
    :type test_batch: list
    :param beam_size: number of generated candidates printed per sample
    :type beam_size: int
    :param src_dict: mapping from source word id to source word
    :type src_dict: dict
    :param trg_dict: mapping from target word id to target word
    :type trg_dict: dict
    """
    beam_result = inferer.infer(input=test_batch, field=["prob", "id"])
    prob, flat_ids = beam_result[0], beam_result[1]

    # All generated sequences arrive concatenated in one flat id stream.
    # Every sequence is terminated by -1, and its first element is the
    # sequence length rather than a word id, so that element is skipped.
    sentences, current = [], []
    for word_id in flat_ids:
        if word_id != -1:
            current.append(word_id)
            continue
        sentences.append(" ".join([trg_dict.get(t) for t in current[1:]]))
        current = []

    # Each line is formatted into a single string before printing so the
    # output is the same on Python 2 (where a multi-argument ``print(...)``
    # would print a tuple) and on Python 3.
    for i, sample in enumerate(test_batch):
        source = " ".join([src_dict.get(t) for t in sample[0]])
        print("src: %s \n" % source)
        # Candidates of sample i occupy the i-th run of beam_size sentences.
        for j in range(beam_size):
            print("prob = %f: %s" % (prob[i][j],
                                     sentences[i * beam_size + j]))
        print("\n")


def generate(source_dict_dim,
             target_dict_dim,
             model_path,
             batch_size,
             beam_size=5,
             max_length=100,
             use_gpu=True):
    """
    Generating function for NMT

    Loads trained parameters, builds the generating network and runs beam
    search over the samples yielded by ``paddle.dataset.wmt14.gen``,
    printing the results batch by batch.

    :param source_dict_dim: size of source dictionary
    :type source_dict_dim: int
    :param target_dict_dim: size of target dictionary
    :type target_dict_dim: int
    :param model_path: path of the trained model, a gzip-compressed tar
        archive of parameters
    :type model_path: string
    :param batch_size: number of samples passed to the inferer at once
    :type batch_size: int
    :param beam_size: beam width given to the network; also the number of
        candidates printed per sample
    :type beam_size: int
    :param max_length: ``max_length`` argument given to ``seq2seq_net``
    :type max_length: int
    :param use_gpu: forwarded to ``paddle.init``
    :type use_gpu: bool
    """
    # gzip is not imported at the top of this file; importing it here keeps
    # the function from depending on the star import to provide the name.
    import gzip

    assert os.path.exists(model_path), "trained model does not exist."

    # step 1: prepare dictionary
    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(source_dict_dim)

    # step 2: load the trained model
    paddle.init(use_gpu=use_gpu, trainer_count=1)
    with gzip.open(model_path) as f:
        parameters = paddle.parameters.Parameters.from_tar(f)
    beam_gen = seq2seq_net(
        source_dict_dim,
        target_dict_dim,
        beam_size=beam_size,
        max_length=max_length,
        is_generating=True)
    inferer = paddle.inference.Inference(
        output_layer=beam_gen, parameters=parameters)

    # step 3: iterating over the testing dataset
    test_batch = []
    for item in paddle.dataset.wmt14.gen(source_dict_dim)():
        # only the first field of each item is fed to the network
        test_batch.append([item[0]])
        if len(test_batch) == batch_size:
            infer_a_batch(inferer, test_batch, beam_size, src_dict, trg_dict)
            test_batch = []

    # decode the final, partially filled batch (if any)
    if test_batch:
        infer_a_batch(inferer, test_batch, beam_size, src_dict, trg_dict)


if __name__ == "__main__":
    # Script entry point: run generation with 3000-entry source and target
    # dictionaries, five samples per batch.
    generate(
        model_path="models/nmt_without_att_params_batch_00001.tar.gz",
        batch_size=5,
        source_dict_dim=3000,
        target_dict_dim=3000)