提交 229964e4 编写于 作者: X xiexionghang

add force-dump in startup

上级 19b00fb2
...@@ -16,6 +16,7 @@ namespace feed { ...@@ -16,6 +16,7 @@ namespace feed {
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
int ret = Process::initialize(context_ptr); int ret = Process::initialize(context_ptr);
auto& config = _context_ptr->trainer_config; auto& config = _context_ptr->trainer_config;
_startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
if (config["executor"]) { if (config["executor"]) {
_executors.resize(config["executor"].size()); _executors.resize(config["executor"].size());
for (size_t i = 0; i < _executors.size(); ++i) { for (size_t i = 0; i < _executors.size(); ++i) {
...@@ -26,7 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { ...@@ -26,7 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return 0; return 0;
} }
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) { int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
auto fs = _context_ptr->file_system; auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client(); auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get(); auto* environment = _context_ptr->environment.get();
...@@ -34,7 +35,7 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) { ...@@ -34,7 +35,7 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) { if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0; return 0;
} }
if (!epoch_accessor->need_save_model(epoch_id, way)) { if (!is_force_dump && !epoch_accessor->need_save_model(epoch_id, way)) {
return 0; return 0;
} }
paddle::platform::Timer timer; paddle::platform::Timer timer;
...@@ -112,8 +113,8 @@ int LearnerProcess::run() { ...@@ -112,8 +113,8 @@ int LearnerProcess::run() {
CHECK(load_model(epoch_id) == 0); CHECK(load_model(epoch_id) == 0);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
//判断是否先dump出base //判断是否先dump出base TODO
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase, _startup_dump_inference_base);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
while (true) { while (true) {
......
...@@ -20,10 +20,11 @@ public: ...@@ -20,10 +20,11 @@ public:
protected: protected:
// 加载所有模型 // 加载所有模型
virtual int load_model(uint64_t epoch_id); virtual int load_model(uint64_t epoch_id);
// 同步保存所有模型 // 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way); virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
private: private:
bool _startup_dump_inference_base; //启动立即dump base
std::vector<std::shared_ptr<MultiThreadExecutor>> _executors; std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册