infer.py 3.2 KB
Newer Older
F
frankwhzhang 已提交
1
import argparse
F
frankwhzhang 已提交
2 3 4 5 6 7 8 9 10 11 12 13
import sys
import time
import math
import unittest
import contextlib
import numpy as np
import six
import paddle.fluid as fluid
import paddle

import utils

Z
fix bug  
zhangwenhui03 已提交
14

F
frankwhzhang 已提交
15 16 17 18 19
def parse_args(argv=None):
    """Parse the command-line options of the GRU4Rec inference script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            command-line behaviour; passing a list makes the parser usable
            from tests or other scripts.

    Returns:
        argparse.Namespace with the attributes ``test_dir``, ``start_index``,
        ``last_index``, ``model_dir``, ``use_cuda``, ``batch_size`` and
        ``vocab_path``.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the name
    # shown in the usage line); this text is a description of the tool.
    parser = argparse.ArgumentParser(description="gru4rec benchmark.")
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    # Integer options get integer defaults directly instead of strings that
    # argparse would have to re-parse through ``type``.
    parser.add_argument(
        '--start_index', type=int, default=1, help='start index')
    parser.add_argument(
        '--last_index', type=int, default=10, help='end index')
    parser.add_argument(
        '--model_dir', type=str, default='model_recall20', help='model dir')
    # Integer flag: any non-zero value selects the GPU.
    parser.add_argument(
        '--use_cuda', type=int, default=0, help='whether use cuda')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch_size')
    parser.add_argument(
        '--vocab_path', type=str, default='vocab.txt', help='vocab file')
    return parser.parse_args(argv)
F
frankwhzhang 已提交
33

Z
fix bug  
zhangwenhui03 已提交
34

F
frankwhzhang 已提交
35 36 37 38 39
def infer(test_reader, use_cuda, model_path):
    """Evaluate one saved inference model on the test set and print recall@20.

    Loads the model stored at ``model_path`` into a fresh fluid scope and runs
    it on every batch produced by ``test_reader()``. For each batch it reads
    the metric at fetch index 1 and weights it by the number of label tokens
    in the batch. It prints the running weighted average after each batch and
    the final average plus elapsed wall time at the end.

    Args:
        test_reader: Callable returning an iterable of batches; each batch is
            a sequence of samples whose element 0 is the source word sequence
            and element 1 the target (label) word sequence.
        use_cuda: Run on ``CUDAPlace(0)`` when truthy, otherwise on CPU.
        model_path: Directory handed to ``fluid.io.load_inference_model``.

    Returns:
        None; results are only printed.
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # NOTE(review): presumably a fresh Scope is used so that each call (one per
    # epoch checkpoint) loads its parameters in isolation — confirm.
    with fluid.scope_guard(fluid.Scope()):
        infer_program, _, fetch_vars = fluid.io.load_inference_model(
            model_path, exe)
        accum_num_recall = 0.0
        accum_num_sum = 0.0
        t0 = time.time()
        for step_id, data in enumerate(test_reader(), 1):
            src_wordseq = utils.to_lodtensor([dat[0] for dat in data], place)
            label_data = [dat[1] for dat in data]
            dst_wordseq = utils.to_lodtensor(label_data, place)
            para = exe.run(
                infer_program,
                feed={"src_wordseq": src_wordseq,
                      "dst_wordseq": dst_wordseq},
                fetch_list=fetch_vars,
                return_numpy=False)

            # NOTE(review): assumes fetch index 1 holds the batch-level
            # recall@20 saved with the model — confirm against the trainer.
            acc_ = para[1]._get_float_element(0)
            # Number of label tokens in the batch (assumes each label is a
            # 1-D sequence); used to weight the batch metric.
            data_length = sum(len(label) for label in label_data)
            accum_num_sum += data_length
            accum_num_recall += data_length * acc_
            # Guard against batches that contributed no label tokens yet.
            if accum_num_sum > 0:
                print("step:%d  recall@20:%.4f" %
                      (step_id, accum_num_recall / accum_num_sum))
        t1 = time.time()
        if accum_num_sum > 0:
            print("model:%s recall@20:%.3f time_cost(s):%.2f" %
                  (model_path, accum_num_recall / accum_num_sum, t1 - t0))
        else:
            # Empty reader (or only empty labels): no recall can be computed.
            print("model:%s no labelled test samples, time_cost(s):%.2f" %
                  (model_path, t1 - t0))


def main():
    """Run inference for every epoch checkpoint in [start_index, last_index].

    Checks the framework version, parses the command line, and builds the
    test reader once. It then calls ``infer`` on ``<model_dir>/epoch_<n>``
    for each epoch ``n`` in the inclusive index range.
    """
    utils.check_version()
    args = parse_args()
    use_cuda = bool(args.use_cuda)
    print("start index: ", args.start_index, " last_index:", args.last_index)
    # The first return value of prepare_data (bound as vocab_size in the
    # original script) is not needed for inference.
    _, test_reader = utils.prepare_data(
        args.test_dir,
        args.vocab_path,
        batch_size=args.batch_size,
        buffer_size=1000,
        word_freq_threshold=0,
        is_train=False)

    # last_index is inclusive.
    for epoch in range(args.start_index, args.last_index + 1):
        epoch_path = args.model_dir + "/epoch_" + str(epoch)
        infer(test_reader=test_reader, use_cuda=use_cuda, model_path=epoch_path)


if __name__ == "__main__":
    main()