infer.py 1.4 KB
Newer Older
G
gmcather 已提交
1 2 3 4
import sys
import time
import unittest
import contextlib
G
gmcather 已提交
5 6
import numpy as np

Y
Yibing Liu 已提交
7
import paddle
G
gmcather 已提交
8 9
import paddle.fluid as fluid

G
gmcather 已提交
10 11 12
import utils


G
gmcather 已提交
13
def infer(test_reader, use_cuda, model_path=None):
    """
    Evaluate a saved inference model on ``test_reader`` and print the result.

    The model stored at ``model_path`` is loaded with
    ``fluid.io.load_inference_model``, run once per batch produced by
    ``test_reader()``, and the first fetched value of every run is averaged
    over all batches, weighted by ``len(batch)``.

    Args:
        test_reader: Callable returning an iterable of batches. Each batch
            must support ``len()`` and be accepted by ``utils.data2tensor``.
        use_cuda: Run on ``fluid.CUDAPlace(0)`` when true, otherwise on
            ``fluid.CPUPlace()``.
        model_path: Location of the saved inference model. When ``None`` a
            message is printed and nothing is evaluated.

    Returns:
        None. The outcome is reported with ``print``.
    """
    if model_path is None:
        print(str(model_path) + " cannot be found")
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load and run inside a dedicated scope instead of the default one.
    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # The feed target names returned by the loader are not needed here:
        # utils.data2tensor builds the feed for each batch.
        inference_program, _, fetch_targets = fluid.io.load_inference_model(
            model_path, exe)

        total_acc = 0.0
        total_count = 0
        for data in test_reader():
            acc = exe.run(inference_program,
                          feed=utils.data2tensor(data, place),
                          fetch_list=fetch_targets,
                          return_numpy=True)
            # NOTE(review): acc[0] is presumably the per-batch accuracy
            # fetched from the model -- confirm against the saved targets.
            batch_size = len(data)
            total_acc += acc[0] * batch_size
            total_count += batch_size

        # An empty reader would otherwise divide by zero below; report it the
        # same way the missing-model case is reported (print and return).
        if total_count == 0:
            print("model_path: %s, no test samples to evaluate" % model_path)
            return

        avg_acc = total_acc / total_count
        print("model_path: %s, avg_acc: %f" % (model_path, avg_acc))
G
gmcather 已提交
41 42 43 44


if __name__ == "__main__":
    # Validate the command line before preparing the dataset (which may be
    # slow) so that a missing argument fails immediately with a usage
    # message instead of a bare IndexError afterwards.
    if len(sys.argv) < 2:
        sys.exit("usage: python %s <model_dir>" % sys.argv[0])

    model_path = sys.argv[1]

    # Only the test reader is used below; the other two return values are
    # kept under their original names.
    word_dict, train_reader, test_reader = utils.prepare_data(
        "imdb", self_dict=False, batch_size=128, buf_size=50000)

    # Evaluate <model_dir>/epoch0 through <model_dir>/epoch29 in order.
    num_epochs = 30
    for i in range(num_epochs):
        epoch_path = model_path + "/" + "epoch" + str(i)
        infer(test_reader, use_cuda=False, model_path=epoch_path)