提交 c04554d9 编写于 作者: X xiexionghang

add shrink

上级 73429ba9
......@@ -26,6 +26,10 @@ public:
virtual int32_t create(::paddle::framework::Scope* scope) {
return 0;
}
// 裁剪,用于模型裁剪,base级调用
virtual int32_t shrink() {
return 0;
}
// 前向, 一般用于填充输入,在训练网络执行前调用
virtual int32_t forward(SampleInstance* samples, size_t num,
......
......@@ -253,6 +253,13 @@ public:
var_data[i] += pull_raw[i + 2];
}
}
// 裁剪,用于模型裁剪,base级调用
virtual int32_t shrink() {
auto* ps_client = _trainer_context->pslib->ps_client();
auto status = ps_client->shrink(_table_id);
return status.get();
}
virtual void post_process_input(float* var_data,
SparseInputVariable& variable, SampleInstance* samples, size_t num) {
......
......@@ -169,8 +169,23 @@ int LearnerProcess::run() {
//Step3. Dump Model For Delta&&Checkpoint
{
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
environment->barrier(EnvironmentRole::WORKER);
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
environment->barrier(EnvironmentRole::WORKER);
if (epoch_accessor->is_last_epoch(epoch_id) &&
environment->is_master_node(EnvironmentRole::WORKER)) {
paddle::platform::Timer timer;
timer.Start();
VLOG(2) << "Start shrink table";
for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) {
CHECK(itr.second[0]->shrink() == 0);
}
}
VLOG(2) << "End shrink table, cost" << timer.ElapsedSec();
}
environment->barrier(EnvironmentRole::WORKER);
epoch_accessor->epoch_done(epoch_id);
environment->barrier(EnvironmentRole::WORKER);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册