# infer.py (1.9 KB)
# Last committed by Yi Liu
import contextlib
import math
import os
import sys
import time
import unittest

import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle

import utils


def infer(test_reader, use_cuda, model_path):
    """Evaluate a saved inference model on a test set and print its perplexity.

    Loads the inference program stored at ``model_path`` into a fresh scope,
    runs it over every batch produced by ``test_reader`` and prints the
    word-level perplexity ``exp(total_cost / total_words)`` together with the
    wall-clock time spent on the evaluation loop.

    Args:
        test_reader: callable returning an iterable of batches; each sample in
            a batch provides the source word sequence at index 0 and the
            destination word sequence at index 1 (fed as ``src_wordseq`` and
            ``dst_wordseq`` respectively).
        use_cuda: run on ``fluid.CUDAPlace(0)`` when true, else on the CPU.
        model_path: directory passed to ``fluid.io.load_inference_model``.

    Raises:
        ValueError: if the reader yields no words, so perplexity is undefined.
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Use a fresh scope so the parameters loaded by this call are kept
    # separate from those loaded by any other call.
    with fluid.scope_guard(fluid.core.Scope()):
        infer_program, _, fetch_vars = fluid.io.load_inference_model(
            model_path, exe)

        accum_cost = 0.0
        accum_words = 0
        t0 = time.time()
        for data in test_reader():
            # List comprehensions (not map) so a list is passed on Py2 and Py3.
            src_wordseq = utils.to_lodtensor(
                [sample[0] for sample in data], place)
            dst_wordseq = utils.to_lodtensor(
                [sample[1] for sample in data], place)
            fetched = exe.run(
                infer_program,
                feed={"src_wordseq": src_wordseq,
                      "dst_wordseq": dst_wordseq},
                fetch_list=fetch_vars)

            # The last level-0 LoD offset is the total number of source words
            # in this batch.
            nwords = src_wordseq.lod()[0][-1]

            # The fetched value is treated as an average cost per word, so it
            # is scaled back up by the word count. Converting it to a Python
            # float (assumes fetch_vars holds exactly one scalar cost — TODO
            # confirm) keeps the running sum in double precision instead of
            # the fetched array's dtype.
            batch_avg_cost = np.asarray(fetched).item()
            accum_cost += batch_avg_cost * nwords
            accum_words += nwords

        t1 = time.time()
        if accum_words == 0:
            raise ValueError(
                "test_reader yielded no words; perplexity is undefined "
                "for model %s" % model_path)
        ppl = math.exp(accum_cost / accum_words)
        print("model:%s ppl:%.3f time_cost(s):%.2f" %
              (model_path, ppl, t1 - t0))


if __name__ == "__main__":
    # Command-line driver: evaluate every saved epoch model in an inclusive
    # range, e.g. `python infer.py models 1 10` evaluates models/epoch_1 ..
    # models/epoch_10.
    usage = ("Usage: %s model_dir start_epoch last_epoch(inclusive)"
             % sys.argv[0])
    if len(sys.argv) != 4:
        sys.stderr.write(usage + "\n")
        sys.exit(1)

    model_dir = sys.argv[1]
    try:
        start_index = int(sys.argv[2])
        last_index = int(sys.argv[3])
    except ValueError:
        # Epoch bounds must be integers.
        sys.stderr.write(usage + "\n")
        sys.exit(1)

    # Only the test reader is used here; the vocabulary and train reader are
    # returned by prepare_data but not needed for evaluation.
    vocab, train_reader, test_reader = utils.prepare_data(
        batch_size=20, buffer_size=1000, word_freq_threshold=0)

    for epoch in range(start_index, last_index + 1):
        epoch_path = os.path.join(model_dir, "epoch_" + str(epoch))
        # Evaluation always runs on the GPU (use_cuda is hard-coded).
        infer(test_reader=test_reader, use_cuda=True, model_path=epoch_path)