From ae0b57ba6de1db98891a2f9d38456bf9208c020e Mon Sep 17 00:00:00 2001 From: linan17 Date: Wed, 11 Sep 2019 20:19:57 +0800 Subject: [PATCH] update for save way=3 Change-Id: I3b0c3753bf1dea211bd41931db19cb2550f03f6e --- .../train/custom_trainer/feed/accessor/epoch_accessor.cc | 3 +++ .../train/custom_trainer/feed/process/learner_process.cc | 6 +++++- paddle/fluid/train/custom_trainer/feed/trainer_context.h | 3 ++- 3 files changed, 10 insertions(+), 2 deletions(-) mode change 100644 => 100755 paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc old mode 100644 new mode 100755 index 71fd83a1..ca9d2091 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -150,6 +150,8 @@ namespace feed { return true; } return delta_id(epoch_id) % 24 == 0; + case ModelSaveWay::ModelSaveTrainCheckpointBase: + return is_last_epoch(epoch_id); } return false; } @@ -165,6 +167,7 @@ namespace feed { case ModelSaveWay::ModelSaveInferenceBase: return _trainer_context->file_system->path_join(_inference_model_path, string::format_string("%s/base", date.c_str())); + case ModelSaveWay::ModelSaveTrainCheckpointBase: case ModelSaveWay::ModelSaveTrainCheckpoint: return _trainer_context->file_system->path_join(_model_root_path, string::format_string("batch_model/%s", date_with_hour.c_str())); diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 93948db9..e169863b 100755 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -175,7 +175,11 @@ int LearnerProcess::run() { { wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); environment->barrier(EnvironmentRole::WORKER); - wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); + if (epoch_accessor->is_last_epoch(epoch_id)) { + wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase); + } else { + wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); + } environment->barrier(EnvironmentRole::WORKER); if (epoch_accessor->is_last_epoch(epoch_id) && environment->is_master_node(EnvironmentRole::WORKER)) { diff --git a/paddle/fluid/train/custom_trainer/feed/trainer_context.h b/paddle/fluid/train/custom_trainer/feed/trainer_context.h index cb01dd4f..71b98be3 100755 --- a/paddle/fluid/train/custom_trainer/feed/trainer_context.h +++ b/paddle/fluid/train/custom_trainer/feed/trainer_context.h @@ -26,7 +26,8 @@ const uint32_t SecondsPerDay = 24 * 3600; enum class ModelSaveWay { ModelSaveTrainCheckpoint = 0, ModelSaveInferenceDelta = 1, - ModelSaveInferenceBase = 2 + ModelSaveInferenceBase = 2, + ModelSaveTrainCheckpointBase = 3, }; enum class TrainerStatus { -- GitLab