From c04554d9e677e0c954bf337584dfbd152cd0830c Mon Sep 17 00:00:00 2001 From: xiexionghang Date: Mon, 2 Sep 2019 13:33:46 +0800 Subject: [PATCH] add shrink --- .../feed/accessor/input_data_accessor.h | 4 ++++ .../feed/accessor/sparse_input_accessor.cc | 7 +++++++ .../feed/process/learner_process.cc | 15 +++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h index b49eb469..45c3d051 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h @@ -26,6 +26,10 @@ public: virtual int32_t create(::paddle::framework::Scope* scope) { return 0; } + // Shrink hook used for model pruning; invoked at the base level. This default implementation does nothing and returns 0. + virtual int32_t shrink() { + return 0; + } // 前向, 一般用于填充输入,在训练网络执行前调用 virtual int32_t forward(SampleInstance* samples, size_t num, diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc index 35be38bf..68fb7f78 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc @@ -253,6 +253,13 @@ public: var_data[i] += pull_raw[i + 2]; } } + + // Shrink hook used for model pruning; invoked at the base level. Asks the pslib ps_client to shrink table _table_id and returns the result of get() on the returned status. + virtual int32_t shrink() { + auto* ps_client = _trainer_context->pslib->ps_client(); + auto status = ps_client->shrink(_table_id); + return status.get(); + } virtual void post_process_input(float* var_data, SparseInputVariable& variable, SampleInstance* samples, size_t num) { diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 0f7b204e..367f340b 100755 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ 
-169,8 +169,23 @@ int LearnerProcess::run() { //Step3. Dump Model For Delta&&Checkpoint { wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); + environment->barrier(EnvironmentRole::WORKER); wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); environment->barrier(EnvironmentRole::WORKER); + if (epoch_accessor->is_last_epoch(epoch_id) && + environment->is_master_node(EnvironmentRole::WORKER)) { + paddle::platform::Timer timer; + timer.Start(); + VLOG(2) << "Start shrink table"; + for (auto& executor : _executors) { + const auto& table_accessors = executor->table_accessors(); + for (auto& itr : table_accessors) { + CHECK(itr.second[0]->shrink() == 0); + } + } + VLOG(2) << "End shrink table, cost" << timer.ElapsedSec(); + } + environment->barrier(EnvironmentRole::WORKER); epoch_accessor->epoch_done(epoch_id); environment->barrier(EnvironmentRole::WORKER); -- GitLab