diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 9064989ad467d3104adf50d64300b7903376e7ce..d3a02982a5eb85f5270310847df997e8bfa3be4c 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -420,7 +420,7 @@ class RunnerBase(object): dirname = envs.get_global_env(name + "save_checkpoint_path", None) if dirname is None or dirname == "": return - dirname = os.path.join(dirname, str(epoch_id)) + dirname = os.path.join(dirname, "epoch_" + str(epoch_id)) logging.info("\tsave epoch_id:%d model into: \"%s\"" % (epoch_id, dirname)) if is_fleet: @@ -436,9 +436,11 @@ class RunnerBase(object): dirname = envs.get_global_env(name + "save_step_path", None) if dirname is None or dirname == "": return - dirname = os.path.join(dirname, str(batch_id)) - logging.info("\tsave batch_id:%d model into: \"%s\"" % - (batch_id, dirname)) + dirname = os.path.join(dirname, + "epoch_" + str(context["current_epoch"]) + + "_batch_" + str(batch_id)) + logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" % + (context["current_epoch"], batch_id, dirname)) if is_fleet: if context["fleet"].worker_index() == 0: context["fleet"].save_persistables(context["exe"], dirname)