From 229964e46efdf3cd187da6935749485c8fab6e5e Mon Sep 17 00:00:00 2001
From: xiexionghang
Date: Thu, 12 Sep 2019 11:25:14 +0800
Subject: [PATCH] add force-dump in startup

---
 .../train/custom_trainer/feed/process/learner_process.cc | 9 +++++----
 .../train/custom_trainer/feed/process/learner_process.h  | 5 +++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
index e169863b..954c1011 100755
--- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
+++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
@@ -16,6 +16,7 @@ namespace feed {
 int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
     int ret = Process::initialize(context_ptr);
     auto& config = _context_ptr->trainer_config;
+    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
     if (config["executor"]) {
         _executors.resize(config["executor"].size());
         for (size_t i = 0; i < _executors.size(); ++i) {
@@ -26,7 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
     return 0;
 }
 
-int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
+int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
     auto fs = _context_ptr->file_system;
     auto* ps_client = _context_ptr->pslib->ps_client();
     auto* environment = _context_ptr->environment.get();
@@ -34,7 +35,7 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
     if (!environment->is_master_node(EnvironmentRole::WORKER)) {
         return 0;
     }
-    if (!epoch_accessor->need_save_model(epoch_id, way)) {
+    if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
         return 0;
     }
     paddle::platform::Timer timer;
@@ -112,8 +113,8 @@ int LearnerProcess::run() {
     CHECK(load_model(epoch_id) == 0);
     environment->barrier(EnvironmentRole::WORKER);
 
-    // decide whether to dump the base model first
-    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
+    // decide whether to dump the base model first TODO
+    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
     environment->barrier(EnvironmentRole::WORKER);
 
     while (true) {
diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.h b/paddle/fluid/train/custom_trainer/feed/process/learner_process.h
index 49b69562..2f0ccf4b 100644
--- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.h
+++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.h
@@ -20,10 +20,11 @@ public:
 protected:
 // load all models
 virtual int load_model(uint64_t epoch_id);
-// synchronously save all models
-virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way);
+// synchronously save all models; is_force_dump: skip the dump-condition check and force the model dump
+virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
 
 private:
+    bool _startup_dump_inference_base; // dump an inference base immediately on startup
     std::vector> _executors;
 };
 
-- 
GitLab
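
The new switch is read from the trainer config with a default of false, so leaving the key out keeps the old behavior and the startup dump is skipped. Below is a minimal standalone sketch of that read, assuming the trainer config is a yaml-cpp node (Node::as<T>(fallback) returns the fallback when the key is absent); the main() harness and the inline config string are illustrative, not part of the patch.

    #include <iostream>
    #include <yaml-cpp/yaml.h>

    int main() {
        // Hypothetical trainer config snippet; setting the flag to true forces an
        // inference-base dump at startup regardless of need_save_model().
        YAML::Node config = YAML::Load("startup_dump_inference_base: true");

        // Mirrors the read in LearnerProcess::initialize: a missing key falls back to false.
        bool startup_dump_inference_base =
            config["startup_dump_inference_base"].as<bool>(false);

        std::cout << std::boolalpha << startup_dump_inference_base << std::endl;  // prints "true"
        return 0;
    }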