# infer.py (committed by gmcather)
"""
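Inference script for the text classification example.

Loads a saved inference model and prints its accuracy on the IMDB test data.
Reference: http://wiki.baidu.com/display/LegoNet/Text+Classification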
"""
import paddle.fluid as fluid
import paddle.v2 as paddle
import numpy as np
import sys
import time
import unittest
import contextlib
import utils


def infer(test_reader, use_cuda,
        save_dirname=None):
    """
    Run a saved inference model over a test reader and print its accuracy.

    Args:
        test_reader: zero-argument callable returning an iterable of batches.
            Each batch must support ``len()`` and be accepted by
            ``utils.data2tensor``.
        use_cuda: when true run on ``fluid.CUDAPlace(0)``, otherwise on
            ``fluid.CPUPlace()``.
        save_dirname: directory handed to ``fluid.io.load_inference_model``.
            When it is ``None`` a message is printed and nothing is evaluated.

    Returns:
        None. The result is written to stdout as ``test_acc: <value>``.
    """
    if save_dirname is None:
        print(str(save_dirname) + " cannot be found")
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load and run the model inside a freshly created Scope instead of the
    # default one.
    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        inference_program, feed_target_names, fetch_targets = (
            fluid.io.load_inference_model(save_dirname, exe))

        total_acc = 0.0
        total_count = 0
        for data in test_reader():
            batch_size = len(data)
            acc = exe.run(inference_program,
                          feed=utils.data2tensor(data, place),
                          fetch_list=fetch_targets,
                          return_numpy=True)
            # acc[0] is treated as the accuracy of this batch (first fetch
            # target -- TODO confirm against the saved model); weight it by
            # the batch size so the final figure is a per-sample average.
            total_acc += acc[0] * batch_size
            total_count += batch_size

        # An empty reader would otherwise divide by zero below.
        if total_count == 0:
            print("test_acc: no samples were read from test_reader")
            return

        print("test_acc: %f" % (total_acc / total_count))


if __name__ == "__main__":
    # The only command-line argument is the directory of the saved inference
    # model. Check it before preparing the dataset so a missing argument fails
    # fast with a usage message instead of an IndexError.
    if len(sys.argv) < 2:
        sys.exit("usage: python infer.py <model_path>")
    model_path = sys.argv[1]

    # Only the test reader is needed for inference; the dictionary and the
    # train reader returned alongside it are unused here.
    word_dict, train_reader, test_reader = utils.prepare_data(
        "imdb", self_dict=False,
        batch_size=128, buf_size=50000)

    infer(test_reader, use_cuda=False,
          save_dirname=model_path)