提交 5fed50e8 编写于 作者: Z zhangwenhui03

fix infer.py style

上级 b4467d76
...@@ -34,7 +34,7 @@ def parse_args(): ...@@ -34,7 +34,7 @@ def parse_args():
return args return args
def infer(test_reader, vocab_tag, use_cuda, model_path): def infer(test_reader, vocab_tag, use_cuda, model_path, epoch):
""" inference function """ """ inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -48,6 +48,7 @@ def infer(test_reader, vocab_tag, use_cuda, model_path): ...@@ -48,6 +48,7 @@ def infer(test_reader, vocab_tag, use_cuda, model_path):
all_num = 0 all_num = 0
size = vocab_tag size = vocab_tag
value = [] value = []
print("epoch " + str(epoch) + " start")
for data in test_reader(): for data in test_reader():
step_id += 1 step_id += 1
lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place) lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place)
...@@ -65,8 +66,7 @@ def infer(test_reader, vocab_tag, use_cuda, model_path): ...@@ -65,8 +66,7 @@ def infer(test_reader, vocab_tag, use_cuda, model_path):
if value.index(max(value)) == int(true_pos): if value.index(max(value)) == int(true_pos):
true_num += 1 true_num += 1
value = [] value = []
if step_id % 1000 == 0: print("epoch:" + str(epoch) + "\tacc:" + str(1.0 * true_num / all_num))
print(step_id, 1.0 * true_num / all_num)
t1 = time.time() t1 = time.time()
...@@ -95,4 +95,5 @@ if __name__ == "__main__": ...@@ -95,4 +95,5 @@ if __name__ == "__main__":
test_reader=test_reader, test_reader=test_reader,
vocab_tag=vocab_tag, vocab_tag=vocab_tag,
use_cuda=False, use_cuda=False,
model_path=epoch_path) model_path=epoch_path,
epoch=epoch)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册