From b8ee0d6ab8a557e981ba36ff4eac405087463e8a Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Sun, 27 Sep 2020 16:07:15 +0800 Subject: [PATCH] format save_step output information --- core/trainers/framework/runner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 9064989a..d3a02982 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -420,7 +420,7 @@ class RunnerBase(object): dirname = envs.get_global_env(name + "save_checkpoint_path", None) if dirname is None or dirname == "": return - dirname = os.path.join(dirname, str(epoch_id)) + dirname = os.path.join(dirname, "epoch_" + str(epoch_id)) logging.info("\tsave epoch_id:%d model into: \"%s\"" % (epoch_id, dirname)) if is_fleet: @@ -436,9 +436,11 @@ class RunnerBase(object): dirname = envs.get_global_env(name + "save_step_path", None) if dirname is None or dirname == "": return - dirname = os.path.join(dirname, str(batch_id)) - logging.info("\tsave batch_id:%d model into: \"%s\"" % - (batch_id, dirname)) + dirname = os.path.join(dirname, + "epoch_" + str(context["current_epoch"]) + + "_batch_" + str(batch_id)) + logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" % + (context["current_epoch"], batch_id, dirname)) if is_fleet: if context["fleet"].worker_index() == 0: context["fleet"].save_persistables(context["exe"], dirname) -- GitLab