"""
Inference script for the text classification example.

For http://wiki.baidu.com/display/LegoNet/Text+Classification

Usage: pass the directory of a saved inference model as the first
command-line argument; the script evaluates it on the "imdb" test reader
and prints the accuracy.
"""
import paddle.fluid as fluid
import paddle.v2 as paddle
import numpy as np
import sys
import time
import unittest
import contextlib

import utils


def infer(test_reader, use_cuda, save_dirname=None):
    """
    Run a saved inference model over ``test_reader`` and print its accuracy.

    Args:
        test_reader: zero-argument callable returning an iterable of batches;
            each batch must support ``len()`` and is converted to feed
            tensors by ``utils.data2tensor``.
        use_cuda: when true, run on ``fluid.CUDAPlace(0)``; otherwise on
            ``fluid.CPUPlace()``.
        save_dirname: directory handed to ``fluid.io.load_inference_model``.
            When ``None`` a message is printed and nothing is run.

    Returns:
        None. The result is reported on stdout as ``test_acc: <value>``.
    """
    if save_dirname is None:
        print(str(save_dirname) + " cannot be found")
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load and run the model in its own scope.
    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        total_acc = 0.0
        total_count = 0
        for data in test_reader():
            acc = exe.run(inference_program,
                          feed=utils.data2tensor(data, place),
                          fetch_list=fetch_targets,
                          return_numpy=True)
            # Weight each batch by its size so that a smaller batch does not
            # skew the average.
            # NOTE(review): assumes the first fetch target is the per-batch
            # accuracy -- confirm against the saved model.
            batch_size = len(data)
            total_acc += acc[0] * batch_size
            total_count += batch_size

        # An empty reader would otherwise raise ZeroDivisionError below.
        if total_count == 0:
            print("test_reader yielded no data, test_acc is undefined")
            return

        print("test_acc: %f" % (total_acc / total_count))


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("usage: %s <model_path>" % sys.argv[0])
        sys.exit(1)

    word_dict, train_reader, test_reader = utils.prepare_data(
        "imdb", self_dict=False, batch_size=128, buf_size=50000)

    model_path = sys.argv[1]
    # The original called an undefined name `test`; `infer` is the
    # evaluation function defined in this file.
    infer(test_reader, use_cuda=True, save_dirname=model_path)