From c6ac8a9ca1a1ddcc6ba12e6913e631e2da79b10c Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Fri, 6 Sep 2019 23:51:02 +0800 Subject: [PATCH] fix: inference classifier --- run_classifier.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/run_classifier.py b/run_classifier.py index 0de6702..47286f5 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -426,13 +426,14 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, save_dir = os.path.dirname(save_path) if not os.path.exists(save_dir): os.makedirs(save_dir) - else: - log.warning('save dir exsits: %s, will skip saving' % save_dir) - with open(save_path, 'w') as f: - for id, s, p in zip(qids, preds, probs): - f.write('{}\t{}\t{}\n'.format(id, s, p)) + if len(qids) == 0: + for s, p in zip(preds, probs): + f.write('{}\t{}\n'.format(s, p)) + else: + for id, s, p in zip(qids, preds, probs): + f.write('{}\t{}\t{}\n'.format(id, s, p)) if __name__ == '__main__': -- GitLab