diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc old mode 100644 new mode 100755 index 71fd83a13fedf81428479eca9d7c1e6a035b45f9..ca9d2091397a27dd61a5ec4c4c4b12e4f7f24784 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -150,6 +150,8 @@ namespace feed { return true; } return delta_id(epoch_id) % 24 == 0; + case ModelSaveWay::ModelSaveTrainCheckpointBase: + return is_last_epoch(epoch_id); } return false; } @@ -165,6 +167,7 @@ namespace feed { case ModelSaveWay::ModelSaveInferenceBase: return _trainer_context->file_system->path_join(_inference_model_path, string::format_string("%s/base", date.c_str())); + case ModelSaveWay::ModelSaveTrainCheckpointBase: case ModelSaveWay::ModelSaveTrainCheckpoint: return _trainer_context->file_system->path_join(_model_root_path, string::format_string("batch_model/%s", date_with_hour.c_str())); diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 93948db9b2458c1e434a86532dc725a2560debef..e169863b2bb05d956ee8ba7dfe18dbd98f19c011 100755 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -175,7 +175,11 @@ int LearnerProcess::run() { { wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); environment->barrier(EnvironmentRole::WORKER); - wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); + if (epoch_accessor->is_last_epoch(epoch_id)) { + wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase); + } else { + wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); + } environment->barrier(EnvironmentRole::WORKER); if (epoch_accessor->is_last_epoch(epoch_id) && environment->is_master_node(EnvironmentRole::WORKER)) { diff --git a/paddle/fluid/train/custom_trainer/feed/trainer_context.h b/paddle/fluid/train/custom_trainer/feed/trainer_context.h index cb01dd4f9ab1fde3994b49fc444c5847ba1f5dd3..71b98be3e77871d75f6b0510537767c4f41365ea 100755 --- a/paddle/fluid/train/custom_trainer/feed/trainer_context.h +++ b/paddle/fluid/train/custom_trainer/feed/trainer_context.h @@ -26,7 +26,8 @@ const uint32_t SecondsPerDay = 24 * 3600; enum class ModelSaveWay { ModelSaveTrainCheckpoint = 0, ModelSaveInferenceDelta = 1, - ModelSaveInferenceBase = 2 + ModelSaveInferenceBase = 2, + ModelSaveTrainCheckpointBase = 3, }; enum class TrainerStatus {