提交 540c5dc0 编写于 作者: R rensilin

save_model_params_local

Change-Id: I65ba0979c822db14c45a9c9fd6b00bc54e630cf3
上级 76e8be34
...@@ -70,6 +70,52 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) { ...@@ -70,6 +70,52 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) {
int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) { paddle::framework::Scope* scope) {
collect_persistables(scope);
if (_need_async_pull) {
++_pull_request_num;
}
return 0;
}
int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
if (!_need_gradient) {
return 0;
}
size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) {
auto* tensor = scope->Var(variable.gradient_name)->
GetMutable<paddle::framework::LoDTensor>();
auto* grad_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
regions.emplace_back(grad_data, variable.dim);
}
auto* ps_client = _trainer_context->pslib->ps_client();
auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id);
//push_status.get();
if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
std::stringstream ssm;
for (auto& variable : _x_variables) {
ssm.str("");
if (variable.name != FLAGS_feed_trainer_debug_dense_name) {
continue;
}
auto& tensor = scope->Var(variable.gradient_name)->
Get<paddle::framework::LoDTensor>();
const auto* var_data = tensor.data<float>();
for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) {
if (data_idx > 0)
ssm << ",";
ssm << var_data[data_idx];
}
VLOG(2) << "[DEBUG]push_dense: " << ssm.str();
}
}
return 0;
}
int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) {
// 首次同步pull,之后异步pull // 首次同步pull,之后异步pull
if (_data_buffer == nullptr) { if (_data_buffer == nullptr) {
_pull_mutex.lock(); _pull_mutex.lock();
...@@ -95,7 +141,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, ...@@ -95,7 +141,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
paddle::framework::DDim ddim(shape_ptr, variable.shape.size()); paddle::framework::DDim ddim(shape_ptr, variable.shape.size());
auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim); auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim);
auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim); auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim);
VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name; VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name
<< ", data_buffer: " << _data_buffer + data_buffer_idx
<< ", dim: " << variable.dim * sizeof(float);
auto* var_data = tensor->mutable_data<float>(_trainer_context->cpu_place); auto* var_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
memcpy(var_data, _data_buffer + data_buffer_idx, variable.dim * sizeof(float)); memcpy(var_data, _data_buffer + data_buffer_idx, variable.dim * sizeof(float));
data_buffer_idx += variable.dim; data_buffer_idx += variable.dim;
...@@ -120,45 +168,12 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, ...@@ -120,45 +168,12 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
VLOG(2) << "[DEBUG]pull_dense: " << ssm.str(); VLOG(2) << "[DEBUG]pull_dense: " << ssm.str();
} }
} }
if (_need_async_pull) {
++_pull_request_num;
}
return 0; return 0;
} }
int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>& persistables) {
paddle::framework::Scope* scope) {
if (!_need_gradient) {
return 0;
}
size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) { for (auto& variable : _x_variables) {
auto* tensor = scope->Var(variable.gradient_name)-> persistables.push_back(variable.name);
GetMutable<paddle::framework::LoDTensor>();
auto* grad_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
regions.emplace_back(grad_data, variable.dim);
}
auto* ps_client = _trainer_context->pslib->ps_client();
auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id);
//push_status.get();
if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
std::stringstream ssm;
for (auto& variable : _x_variables) {
ssm.str("");
if (variable.name != FLAGS_feed_trainer_debug_dense_name) {
continue;
}
auto& tensor = scope->Var(variable.gradient_name)->
Get<paddle::framework::LoDTensor>();
const auto* var_data = tensor.data<float>();
for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) {
if (data_idx > 0)
ssm << ",";
ssm << var_data[data_idx];
}
VLOG(2) << "[DEBUG]push_dense: " << ssm.str();
}
} }
return 0; return 0;
} }
......
...@@ -38,6 +38,12 @@ public: ...@@ -38,6 +38,12 @@ public:
// 后向,一般用于更新梯度,在训练网络执行后调用 // 后向,一般用于更新梯度,在训练网络执行后调用
virtual int32_t backward(SampleInstance* samples, size_t num, virtual int32_t backward(SampleInstance* samples, size_t num,
::paddle::framework::Scope* scope) = 0; ::paddle::framework::Scope* scope) = 0;
// 收集持久化变量的名称, 并将值拷贝到Scope
virtual int32_t collect_persistables_name(std::vector<std::string>& persistables) {return 0;}
// 填充持久化变量的值,用于保存
virtual int32_t collect_persistables(paddle::framework::Scope* scope) {return 0;}
protected: protected:
size_t _table_id = 0; size_t _table_id = 0;
bool _need_gradient = false; bool _need_gradient = false;
...@@ -144,6 +150,11 @@ public: ...@@ -144,6 +150,11 @@ public:
virtual int32_t backward(SampleInstance* samples, size_t num, virtual int32_t backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope); paddle::framework::Scope* scope);
virtual int32_t collect_persistables_name(std::vector<std::string>& persistables);
virtual int32_t collect_persistables(paddle::framework::Scope* scope);
protected: protected:
virtual int32_t pull_dense(size_t table_id); virtual int32_t pull_dense(size_t table_id);
......
...@@ -52,6 +52,7 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, ...@@ -52,6 +52,7 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK(_trainer_context->file_system->exists(model_config_path)) CHECK(_trainer_context->file_system->exists(model_config_path))
<< "miss model config file:" << model_config_path; << "miss model config file:" << model_config_path;
_model_config = YAML::LoadFile(model_config_path); _model_config = YAML::LoadFile(model_config_path);
_persistables.clear();
for (const auto& accessor_config : _model_config["input_accessor"]) { for (const auto& accessor_config : _model_config["input_accessor"]) {
auto accessor_class = accessor_config["class"].as<std::string>(); auto accessor_class = accessor_config["class"].as<std::string>();
auto* accessor_ptr = CREATE_INSTANCE(DataInputAccessor, accessor_class); auto* accessor_ptr = CREATE_INSTANCE(DataInputAccessor, accessor_class);
...@@ -66,7 +67,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, ...@@ -66,7 +67,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
_table_to_accessors[table_id] = {accessor_ptr}; _table_to_accessors[table_id] = {accessor_ptr};
} }
} }
CHECK(accessor_ptr->collect_persistables_name(_persistables) == 0)
<< "collect_persistables Failed, class:" << accessor_class;
} }
std::sort(_persistables.begin(), _persistables.end()); // 持久化变量名一定要排序
// Monitor组件 // Monitor组件
for (const auto& monitor_config : _model_config["monitor"]) { for (const auto& monitor_config : _model_config["monitor"]) {
...@@ -79,6 +83,27 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, ...@@ -79,6 +83,27 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
return ret; return ret;
} }
int32_t MultiThreadExecutor::save_persistables(const std::string& filename) {
// auto fs = _trainer_context->file_system;
// fs->mkdir(fs->path_split(filename).first);
auto scope_obj = _scope_obj_pool->get();
for (size_t i = 0; i < _input_accessors.size(); ++i) {
_input_accessors[i]->collect_persistables(scope_obj.get());
}
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", _persistables);
op->SetAttr("file_path", filename);
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(prog, scope_obj.get(), 0, true, true);
return 0;
}
paddle::framework::Channel<DataItem> MultiThreadExecutor::run( paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
paddle::framework::Channel<DataItem> input, const DataParser* parser) { paddle::framework::Channel<DataItem> input, const DataParser* parser) {
......
...@@ -47,6 +47,8 @@ public: ...@@ -47,6 +47,8 @@ public:
virtual paddle::framework::Channel<DataItem> run( virtual paddle::framework::Channel<DataItem> run(
paddle::framework::Channel<DataItem> input, const DataParser* parser); paddle::framework::Channel<DataItem> input, const DataParser* parser);
virtual int32_t save_persistables(const std::string& filename);
virtual bool is_dump_all_model() { virtual bool is_dump_all_model() {
return _need_dump_all_model; return _need_dump_all_model;
} }
...@@ -79,6 +81,7 @@ protected: ...@@ -79,6 +81,7 @@ protected:
std::vector<std::shared_ptr<DataInputAccessor>> _input_accessors; std::vector<std::shared_ptr<DataInputAccessor>> _input_accessors;
std::map<uint32_t, std::vector<DataInputAccessor*>> _table_to_accessors; std::map<uint32_t, std::vector<DataInputAccessor*>> _table_to_accessors;
std::shared_ptr<paddle::ps::ObjectPool<::paddle::framework::Scope>> _scope_obj_pool; std::shared_ptr<paddle::ps::ObjectPool<::paddle::framework::Scope>> _scope_obj_pool;
std::vector<std::string> _persistables;
}; };
} // namespace feed } // namespace feed
......
...@@ -25,6 +25,10 @@ public: ...@@ -25,6 +25,10 @@ public:
virtual bool exists(const std::string& path) = 0; virtual bool exists(const std::string& path) = 0;
virtual void mkdir(const std::string& path) = 0; virtual void mkdir(const std::string& path) = 0;
virtual std::string path_join(const std::string& dir, const std::string& path); virtual std::string path_join(const std::string& dir, const std::string& path);
template<class... STRS>
std::string path_join(const std::string& dir, const std::string& path, const STRS&... paths) {
return path_join(path_join(dir, path), paths...);
}
virtual std::pair<std::string, std::string> path_split(const std::string& path); virtual std::pair<std::string, std::string> path_split(const std::string& path);
protected: protected:
}; };
......
...@@ -27,6 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { ...@@ -27,6 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
} }
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) { int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client(); auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get(); auto* environment = _context_ptr->environment.get();
auto* epoch_accessor = _context_ptr->epoch_accessor.get(); auto* epoch_accessor = _context_ptr->epoch_accessor.get();
...@@ -39,18 +40,21 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) { ...@@ -39,18 +40,21 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
paddle::platform::Timer timer; paddle::platform::Timer timer;
timer.Start(); timer.Start();
std::set<uint32_t> table_set; std::set<uint32_t> table_set;
auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
for (auto& executor : _executors) { for (auto& executor : _executors) {
const auto& table_accessors = executor->table_accessors(); const auto& table_accessors = executor->table_accessors();
for (auto& itr : table_accessors) { for (auto& itr : table_accessors) {
table_set.insert(itr.first); table_set.insert(itr.first);
} }
auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param");
VLOG(2) << "Start save model, save_path:" << save_path;
executor->save_persistables(save_path);
} }
int ret_size = 0; int ret_size = 0;
auto table_num = table_set.size(); auto table_num = table_set.size();
std::future<int> rets[table_num]; std::future<int> rets[table_num];
for (auto table_id : table_set) { for (auto table_id : table_set) {
VLOG(2) << "Start save model, table_id:" << table_id; VLOG(2) << "Start save model, table_id:" << table_id;
auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way)); rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way));
} }
int all_ret = 0; int all_ret = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册