提交 c6ac8a9c 编写于 作者: C chenxuyi

fix: inference classifier

上级 9f6ab586
......@@ -426,13 +426,14 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
else:
log.warning('save dir exsits: %s, will skip saving' % save_dir)
with open(save_path, 'w') as f:
for id, s, p in zip(qids, preds, probs):
f.write('{}\t{}\t{}\n'.format(id, s, p))
if len(qids) == 0:
for s, p in zip(preds, probs):
f.write('{}\t{}\n'.format(s, p))
else:
for id, s, p in zip(qids, preds, probs):
f.write('{}\t{}\t{}\n'.format(id, s, p))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册