# generate.py: decode WMT14 samples with beam search using a trained model.
import gzip
import argparse
import distutils.util
import paddle.v2 as paddle

from network_conf import seqToseq_net


def _str_to_bool(value):
    """Convert a command-line truth string into a real ``bool``.

    Recognises (case-insensitively) ``y/yes/t/true/on/1`` as True and
    ``n/no/f/false/off/0`` as False, i.e. the spellings accepted by
    ``distutils.util.strtobool``, but returns ``True``/``False`` instead
    of ``1``/``0`` so the parsed value has the same type as the default.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a known spelling;
            argparse turns this into a usage error.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError("invalid truth value %r" % (value, ))


def parse_args(argv=None):
    """Parse the command-line options of the generation script.

    Args:
        argv: optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with ``model_path`` (str, required),
        ``beam_size`` (int, default 3), ``use_gpu`` (bool, default False)
        and ``trainer_count`` (int, default 1).
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Scheduled Sampling")
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help="The path for trained model to load.")
    parser.add_argument(
        '--beam_size',
        type=int,
        default=3,
        help='The width of beam expansion. (default: %(default)s)')
    parser.add_argument(
        "--use_gpu",
        # Local converter keeps the value a bool and avoids depending on
        # the deprecated distutils helper for parsing.
        type=_str_to_bool,
        default=False,
        help="Use gpu or not. (default: %(default)s)")
    parser.add_argument(
        "--trainer_count",
        type=int,
        default=1,
        help="Trainer number. (default: %(default)s)")

    return parser.parse_args(argv)


def generate(gen_data, dict_size, model_path, beam_size):
    """Run beam-search generation on ``gen_data`` and print the results.

    Args:
        gen_data: sequence of samples; each sample is indexed with ``[0]``
            to obtain the list of source word ids (the caller builds them
            as 1-tuples).
        dict_size: dictionary size passed to the network builder (for both
            source and target) and to ``wmt14.get_dict``.
        model_path: path of the gzip-compressed parameter tar to load.
        beam_size: beam width; also the number of candidates printed per
            source sample.

    Returns:
        None. For every sample the source sentence and ``beam_size``
        candidates with their probabilities are written to stdout.
    """
    beam_gen = seqToseq_net(dict_size, dict_size, beam_size, is_generating=True)

    with gzip.open(model_path, 'r') as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # prob is the prediction probabilities, and id is the prediction word.
    beam_result = paddle.infer(
        output_layer=beam_gen,
        parameters=parameters,
        input=gen_data,
        field=['prob', 'id'])

    # get the dictionary
    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)

    # the delimited element of generated sequences is -1,
    # the first element of each generated sequence is the sequence length
    seq_list = []
    seq = []
    for w in beam_result[1]:
        if w != -1:
            seq.append(w)
        else:
            # seq[0] is the length marker, so only seq[1:] are word ids.
            seq_list.append(' '.join([trg_dict.get(wid) for wid in seq[1:]]))
            seq = []

    prob = beam_result[0]
    # Iterate over the samples actually supplied instead of relying on a
    # module-level ``gen_num``: that global only exists when this file is
    # run as a script, and may exceed len(gen_data) for a short dataset.
    # Single-argument print(...) produces the same output under both the
    # Python 2 print statement and the Python 3 print function.
    for i, sample in enumerate(gen_data):
        print("\n*******************************************************\n")
        print("src: %s \n" % ' '.join([src_dict.get(wid) for wid in sample[0]]))
        for j in range(beam_size):
            print("prob = %f: %s" % (prob[i][j], seq_list[i * beam_size + j]))


if __name__ == '__main__':
    # Script entry point: parse options, initialise Paddle, collect a few
    # WMT14 generation samples and decode them with beam search.
    args = parse_args()

    dict_size = 30000

    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)

    # use the first 3 samples for generation
    gen_num = 3
    gen_data = []
    sample_reader = paddle.dataset.wmt14.gen(dict_size)
    for taken, sample in enumerate(sample_reader(), 1):
        # keep only the source word ids, wrapped as a 1-tuple
        gen_data.append((sample[0], ))
        if taken == gen_num:
            break

    generate(
        gen_data,
        dict_size=dict_size,
        model_path=args.model_path,
        beam_size=args.beam_size)