diff --git a/run_classifier.py b/run_classifier.py index 0de67020cfc65ce65fa4ee5dca0469ce253a17a5..47286f5a5761e57eb5bf3cf6894620f375f7a134 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -426,13 +426,14 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars, save_dir = os.path.dirname(save_path) if not os.path.exists(save_dir): os.makedirs(save_dir) - else: - log.warning('save dir exsits: %s, will skip saving' % save_dir) - with open(save_path, 'w') as f: - for id, s, p in zip(qids, preds, probs): - f.write('{}\t{}\t{}\n'.format(id, s, p)) + if len(qids) == 0: + for s, p in zip(preds, probs): + f.write('{}\t{}\n'.format(s, p)) + else: + for id, s, p in zip(qids, preds, probs): + f.write('{}\t{}\t{}\n'.format(id, s, p)) if __name__ == '__main__':