diff --git a/tools/eval.py b/tools/eval.py index db5ce4eec1d55ae4a4c45f2f4def4b7429b7d4ca..e95b42e3914611482a55b1b9e1f052de606d9bef 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -73,7 +73,7 @@ def main(args): valid_dataloader.set_sample_list_generator(valid_reader, place) compiled_valid_prog = program.compile(config, valid_prog) - program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, 0, + program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1, 'valid') diff --git a/tools/program.py b/tools/program.py index 00fe3aeb6859b9d0a3b5ef0dac54223e2d79cd72..5a5745953ec500df408ad6417e7f62ea49a44952 100644 --- a/tools/program.py +++ b/tools/program.py @@ -385,7 +385,15 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): metric_list[i].update(m[0], len(batch[0])) fetchs_str = ''.join([str(m.value)+' ' for m in metric_list]+ [batch_time.value]) - logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format( + if epoch != -1: + logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format( epoch, mode, idx, fetchs_str)) + else: + logger.info("{:s} step:{:<4d} {:s}s".format( + mode, idx, fetchs_str)) + end_str = ''.join([str(m.mean)+' ' for m in metric_list] + [batch_time.total]) - logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) + if epoch!= -1: + logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) + else: + logger.info("END {:s} {:s}s".format(mode, end_str))