diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h index b49eb469abe8338728c7fd36eeb86357df281c43..45c3d051a1d0d0584f193aad35f722dc96cf9371 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h @@ -26,6 +26,10 @@ public: virtual int32_t create(::paddle::framework::Scope* scope) { return 0; } + // 裁剪,用于模型裁剪,base级调用 + virtual int32_t shrink() { + return 0; + } // 前向, 一般用于填充输入,在训练网络执行前调用 virtual int32_t forward(SampleInstance* samples, size_t num, diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc index 35be38bf315d3745be01473f2637ea57807ef499..68fb7f78cbdb4ccb6b378713c223f6ffbaa9d6f1 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc @@ -253,6 +253,13 @@ public: var_data[i] += pull_raw[i + 2]; } } + + // 裁剪,用于模型裁剪,base级调用 + virtual int32_t shrink() { + auto* ps_client = _trainer_context->pslib->ps_client(); + auto status = ps_client->shrink(_table_id); + return status.get(); + } virtual void post_process_input(float* var_data, SparseInputVariable& variable, SampleInstance* samples, size_t num) { diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 0f7b204efc0f6509359dd6e6493dcf5164fba87a..367f340b20c434f5f103292b12f90f161e3cb555 100755 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -169,8 +169,23 @@ int LearnerProcess::run() { //Step3. Dump Model For Delta&&Checkpoint { wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); + environment->barrier(EnvironmentRole::WORKER); wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); environment->barrier(EnvironmentRole::WORKER); + if (epoch_accessor->is_last_epoch(epoch_id) && + environment->is_master_node(EnvironmentRole::WORKER)) { + paddle::platform::Timer timer; + timer.Start(); + VLOG(2) << "Start shrink table"; + for (auto& executor : _executors) { + const auto& table_accessors = executor->table_accessors(); + for (auto& itr : table_accessors) { + CHECK(itr.second[0]->shrink() == 0); + } + } + VLOG(2) << "End shrink table, cost" << timer.ElapsedSec(); + } + environment->barrier(EnvironmentRole::WORKER); epoch_accessor->epoch_done(epoch_id); environment->barrier(EnvironmentRole::WORKER);