# infer.py (3.2 KB) — blame view
# blame: frankwhzhang committed line 1
import argparse
# blame: frankwhzhang committed lines 2-13
import sys
import time
import math
import unittest
import contextlib
import numpy as np
import six
import paddle.fluid as fluid
import paddle

import utils

# blame: zhangwenhui03 committed line 14 ("fix bug")

# blame: frankwhzhang committed lines 15-19
def parse_args(argv=None):
    """Parse the command-line options for the gru4rec recall@20 evaluation.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behavior; passing a list is useful for tests and programmatic use.

    Returns:
        argparse.Namespace with ``test_dir``, ``start_index``, ``last_index``,
        ``model_dir``, ``use_cuda``, ``batch_size`` and ``vocab_path``.
    """
    parser = argparse.ArgumentParser("gru4rec benchmark.")
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    # Integer options get real integer defaults instead of string literals
    # that argparse would otherwise have to run through ``type=int``.
    parser.add_argument(
        '--start_index', type=int, default=1, help='start index')
    parser.add_argument(
        '--last_index', type=int, default=10, help='end index')
    parser.add_argument(
        '--model_dir', type=str, default='model_recall20', help='model dir')
    parser.add_argument(
        '--use_cuda', type=int, default=0, help='whether use cuda')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch_size')
    parser.add_argument(
        '--vocab_path', type=str, default='vocab.txt', help='vocab file')
    return parser.parse_args(argv)
# blame: frankwhzhang committed line 33

# blame: zhangwenhui03 committed line 34 ("fix bug")

# blame: frankwhzhang committed lines 35-64
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0.

    Keeps the running recall average from raising ZeroDivisionError when no
    label tokens have been seen yet (e.g. an empty test reader).
    """
    return numerator / denominator if denominator else float("nan")


def infer(test_reader, use_cuda, model_path, print_interval=1):
    """Run a saved inference model over the test data and report recall@20.

    The reported recall is the per-batch value read from the second fetch
    variable, weighted by the number of label tokens in each batch.

    Args:
        test_reader: Callable returning an iterable of batches. Each batch is
            a sequence of samples where ``sample[0]`` is the source word
            sequence and ``sample[1]`` the label sequence. Both are handed to
            ``utils.to_lodtensor``.
        use_cuda: Run on ``fluid.CUDAPlace(0)`` when truthy, else on the CPU.
        model_path: Directory passed to ``fluid.io.load_inference_model``.
        print_interval: Print the running recall every this many batches.
            The default of 1 prints every batch; ``0``/``None`` disables
            per-batch printing.

    Returns:
        The overall token-weighted recall@20 (``nan`` if no label tokens
        were seen).
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the model into a fresh scope on every call; presumably so that
    # successive checkpoints evaluated in one process do not share state.
    with fluid.scope_guard(fluid.core.Scope()):
        infer_program, _feed_target_names, fetch_vars = fluid.io.load_inference_model(
            model_path, exe)
        accum_num_recall = 0.0
        accum_num_sum = 0.0
        t0 = time.time()
        for step_id, data in enumerate(test_reader(), start=1):
            src_wordseq = utils.to_lodtensor([dat[0] for dat in data], place)
            label_data = [dat[1] for dat in data]
            dst_wordseq = utils.to_lodtensor(label_data, place)
            para = exe.run(
                infer_program,
                feed={"src_wordseq": src_wordseq,
                      "dst_wordseq": dst_wordseq},
                fetch_list=fetch_vars,
                return_numpy=False)

            # para[1] is read as the batch recall (printed as recall@20);
            # weight it by the batch's label-token count so the running value
            # is a per-token average. Counting lengths directly avoids
            # building and casting a concatenated array just to take len().
            acc_ = para[1]._get_float_element(0)
            data_length = sum(len(labels) for labels in label_data)
            accum_num_sum += data_length
            accum_num_recall += data_length * acc_
            if print_interval and step_id % print_interval == 0:
                print("step:%d  recall@20:%.4f" %
                      (step_id, _safe_ratio(accum_num_recall, accum_num_sum)))
        t1 = time.time()
        recall = _safe_ratio(accum_num_recall, accum_num_sum)
        print("model:%s recall@20:%.3f time_cost(s):%.2f" %
              (model_path, recall, t1 - t0))
    return recall


def main():
    """Evaluate every saved epoch checkpoint from start_index to last_index.

    Reads options via ``parse_args``, builds the test reader once with
    ``utils.prepare_data`` and calls ``infer`` on
    ``<model_dir>/epoch_<n>`` for each epoch in the inclusive range.
    """
    import os  # local import keeps the file's top-level import block untouched

    args = parse_args()
    start_index = args.start_index
    last_index = args.last_index
    use_cuda = bool(args.use_cuda)
    print("start index: ", start_index, " last_index:", last_index)
    # prepare_data also returns the vocabulary size, which inference does not use.
    _vocab_size, test_reader = utils.prepare_data(
        args.test_dir,
        args.vocab_path,
        batch_size=args.batch_size,
        buffer_size=1000,
        word_freq_threshold=0,
        is_train=False)

    # last_index is inclusive.
    for epoch in range(start_index, last_index + 1):
        epoch_path = os.path.join(args.model_dir, "epoch_" + str(epoch))
        infer(test_reader=test_reader, use_cuda=use_cuda, model_path=epoch_path)


if __name__ == "__main__":
    main()