Unverified commit 750e7b77, authored by zhang wenhui, committed by GitHub

Merge pull request #1886 from frankwhzhang/fix_bug

fix infer.py style
@@ -34,7 +34,7 @@ def parse_args():
return args
-def infer(test_reader, vocab_tag, use_cuda, model_path):
+def infer(test_reader, vocab_tag, use_cuda, model_path, epoch):
""" inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
@@ -48,6 +48,7 @@ def infer(test_reader, vocab_tag, use_cuda, model_path):
all_num = 0
size = vocab_tag
value = []
print("epoch " + str(epoch) + " start")
for data in test_reader():
step_id += 1
lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place)
@@ -65,8 +66,7 @@ def infer(test_reader, vocab_tag, use_cuda, model_path):
if value.index(max(value)) == int(true_pos):
true_num += 1
value = []
-if step_id % 1000 == 0:
-    print(step_id, 1.0 * true_num / all_num)
+print("epoch:" + str(epoch) + "\tacc:" + str(1.0 * true_num / all_num))
t1 = time.time()
@@ -95,4 +95,5 @@ if __name__ == "__main__":
test_reader=test_reader,
vocab_tag=vocab_tag,
use_cuda=False,
-model_path=epoch_path)
+model_path=epoch_path,
+epoch=epoch)
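
Note (not part of the diff): a minimal sketch of how the updated infer signature is presumably driven from __main__, looping over saved per-epoch checkpoints. The checkpoint directory layout and the names model_dir, start_index, last_index are assumptions for illustration only and are not taken from this PR.

import os

def infer_all_epochs(test_reader, vocab_tag, model_dir, start_index, last_index):
    # Hypothetical driver: model_dir, start_index, last_index are assumed names.
    for epoch in range(start_index, last_index + 1):
        # Assumed checkpoint layout: one sub-directory per saved epoch.
        epoch_path = os.path.join(model_dir, "epoch_" + str(epoch))
        infer(
            test_reader=test_reader,
            vocab_tag=vocab_tag,
            use_cuda=False,
            model_path=epoch_path,
            epoch=epoch)  # new keyword argument added by this PR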