diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc index b9a40b0a032377b660a859398e51a12a7c808749..1957b695c17a73c876fc36c0586a8b7559b75e56 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc @@ -70,6 +70,52 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) { int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, paddle::framework::Scope* scope) { + collect_persistables(scope); + + if (_need_async_pull) { + ++_pull_request_num; + } + return 0; +} + +int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, + paddle::framework::Scope* scope) { + if (!_need_gradient) { + return 0; + } + size_t data_buffer_idx = 0; + std::vector regions; + for (auto& variable : _x_variables) { + auto* tensor = scope->Var(variable.gradient_name)-> + GetMutable(); + auto* grad_data = tensor->mutable_data(_trainer_context->cpu_place); + regions.emplace_back(grad_data, variable.dim); + } + auto* ps_client = _trainer_context->pslib->ps_client(); + auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id); + //push_status.get(); + if (!FLAGS_feed_trainer_debug_dense_name.empty()) { + std::stringstream ssm; + for (auto& variable : _x_variables) { + ssm.str(""); + if (variable.name != FLAGS_feed_trainer_debug_dense_name) { + continue; + } + auto& tensor = scope->Var(variable.gradient_name)-> + Get(); + const auto* var_data = tensor.data(); + for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) { + if (data_idx > 0) + ssm << ","; + ssm << var_data[data_idx]; + } + VLOG(2) << "[DEBUG]push_dense: " << ssm.str(); + } + } + return 0; +} + +int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) { // 首次同步pull,之后异步pull if (_data_buffer == nullptr) { _pull_mutex.lock(); @@ -95,7 +141,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, paddle::framework::DDim ddim(shape_ptr, variable.shape.size()); auto* tensor = ScopeHelper::resize_lod_tensor(scope, variable.name, ddim); auto* grad_tensor = ScopeHelper::resize_lod_tensor(scope, variable.gradient_name, ddim); - VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name; + VLOG(5) << "fill scope variable:" << variable.name << ", " << variable.gradient_name + << ", data_buffer: " << _data_buffer + data_buffer_idx + << ", dim: " << variable.dim * sizeof(float); auto* var_data = tensor->mutable_data(_trainer_context->cpu_place); memcpy(var_data, _data_buffer + data_buffer_idx, variable.dim * sizeof(float)); data_buffer_idx += variable.dim; @@ -120,45 +168,12 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, VLOG(2) << "[DEBUG]pull_dense: " << ssm.str(); } } - if (_need_async_pull) { - ++_pull_request_num; - } return 0; } -int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, - paddle::framework::Scope* scope) { - if (!_need_gradient) { - return 0; - } - size_t data_buffer_idx = 0; - std::vector regions; +int32_t DenseInputAccessor::collect_persistables_name(std::vector& persistables) { for (auto& variable : _x_variables) { - auto* tensor = scope->Var(variable.gradient_name)-> - GetMutable(); - auto* grad_data = tensor->mutable_data(_trainer_context->cpu_place); - regions.emplace_back(grad_data, variable.dim); - } - auto* ps_client = _trainer_context->pslib->ps_client(); - auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id); - //push_status.get(); - if (!FLAGS_feed_trainer_debug_dense_name.empty()) { - std::stringstream ssm; - for (auto& variable : _x_variables) { - ssm.str(""); - if (variable.name != FLAGS_feed_trainer_debug_dense_name) { - continue; - } - auto& tensor = scope->Var(variable.gradient_name)-> - Get(); - const auto* var_data = tensor.data(); - for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) { - if (data_idx > 0) - ssm << ","; - ssm << var_data[data_idx]; - } - VLOG(2) << "[DEBUG]push_dense: " << ssm.str(); - } + persistables.push_back(variable.name); } return 0; } diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h index 45c3d051a1d0d0584f193aad35f722dc96cf9371..52d07b6b0598f2a1f4f87f44bba2d88b874b7582 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h @@ -38,6 +38,12 @@ public: // 后向,一般用于更新梯度,在训练网络执行后调用 virtual int32_t backward(SampleInstance* samples, size_t num, ::paddle::framework::Scope* scope) = 0; + + // 收集持久化变量的名称, 并将值拷贝到Scope + virtual int32_t collect_persistables_name(std::vector& persistables) {return 0;} + + // 填充持久化变量的值,用于保存 + virtual int32_t collect_persistables(paddle::framework::Scope* scope) {return 0;} protected: size_t _table_id = 0; bool _need_gradient = false; @@ -144,6 +150,11 @@ public: virtual int32_t backward(SampleInstance* samples, size_t num, paddle::framework::Scope* scope); + + + virtual int32_t collect_persistables_name(std::vector& persistables); + + virtual int32_t collect_persistables(paddle::framework::Scope* scope); protected: virtual int32_t pull_dense(size_t table_id); diff --git a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc index af3d4f7cbf18bde80dadf10fcd11ba5135110b97..e7ef07b0fd4ad5acc5f7990a49dc4826a797d6f8 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc @@ -52,6 +52,7 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, CHECK(_trainer_context->file_system->exists(model_config_path)) << "miss model config file:" << model_config_path; _model_config = YAML::LoadFile(model_config_path); + _persistables.clear(); for (const auto& accessor_config : _model_config["input_accessor"]) { auto accessor_class = accessor_config["class"].as(); auto* accessor_ptr = CREATE_INSTANCE(DataInputAccessor, accessor_class); @@ -66,7 +67,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, _table_to_accessors[table_id] = {accessor_ptr}; } } - } + CHECK(accessor_ptr->collect_persistables_name(_persistables) == 0) + << "collect_persistables Failed, class:" << accessor_class; + } + std::sort(_persistables.begin(), _persistables.end()); // 持久化变量名一定要排序 // Monitor组件 for (const auto& monitor_config : _model_config["monitor"]) { @@ -79,6 +83,27 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config, return ret; } +int32_t MultiThreadExecutor::save_persistables(const std::string& filename) { + // auto fs = _trainer_context->file_system; + // fs->mkdir(fs->path_split(filename).first); + auto scope_obj = _scope_obj_pool->get(); + for (size_t i = 0; i < _input_accessors.size(); ++i) { + _input_accessors[i]->collect_persistables(scope_obj.get()); + } + framework::ProgramDesc prog; + auto* block = prog.MutableBlock(0); + auto* op = block->AppendOp(); + op->SetType("save_combine"); + op->SetInput("X", _persistables); + op->SetAttr("file_path", filename); + op->CheckAttrs(); + + platform::CPUPlace place; + framework::Executor exe(place); + exe.Run(prog, scope_obj.get(), 0, true, true); + return 0; +} + paddle::framework::Channel MultiThreadExecutor::run( paddle::framework::Channel input, const DataParser* parser) { diff --git a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h index e1fbc42d699ed852d7bee9f9c6d64221e55b824b..becfabea5c32db1ea766fbb09ad9b83f34f0d7cd 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h +++ b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h @@ -46,6 +46,8 @@ public: //执行训练 virtual paddle::framework::Channel run( paddle::framework::Channel input, const DataParser* parser); + + virtual int32_t save_persistables(const std::string& filename); virtual bool is_dump_all_model() { return _need_dump_all_model; @@ -79,6 +81,7 @@ protected: std::vector> _input_accessors; std::map> _table_to_accessors; std::shared_ptr> _scope_obj_pool; + std::vector _persistables; }; } // namespace feed diff --git a/paddle/fluid/train/custom_trainer/feed/io/file_system.h b/paddle/fluid/train/custom_trainer/feed/io/file_system.h index 0ef5a37b0c0a3e04d2f20d2b036ff2541b4f0f48..4531cbf1125d0a3af54a503fab82e1b9db970f2d 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/file_system.h +++ b/paddle/fluid/train/custom_trainer/feed/io/file_system.h @@ -25,6 +25,10 @@ public: virtual bool exists(const std::string& path) = 0; virtual void mkdir(const std::string& path) = 0; virtual std::string path_join(const std::string& dir, const std::string& path); + template + std::string path_join(const std::string& dir, const std::string& path, const STRS&... paths) { + return path_join(path_join(dir, path), paths...); + } virtual std::pair path_split(const std::string& path); protected: }; diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 367f340b20c434f5f103292b12f90f161e3cb555..40cd0651512ba0a007416176ef336474e6feb129 100755 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -27,6 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr context_ptr) { } int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) { + auto fs = _context_ptr->file_system; auto* ps_client = _context_ptr->pslib->ps_client(); auto* environment = _context_ptr->environment.get(); auto* epoch_accessor = _context_ptr->epoch_accessor.get(); @@ -39,18 +40,21 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) { paddle::platform::Timer timer; timer.Start(); std::set table_set; + auto model_dir = epoch_accessor->model_save_path(epoch_id, way); for (auto& executor : _executors) { const auto& table_accessors = executor->table_accessors(); for (auto& itr : table_accessors) { table_set.insert(itr.first); } + auto save_path = fs->path_join(model_dir, executor->train_exe_name() + "_param"); + VLOG(2) << "Start save model, save_path:" << save_path; + executor->save_persistables(save_path); } int ret_size = 0; auto table_num = table_set.size(); std::future rets[table_num]; for (auto table_id : table_set) { VLOG(2) << "Start save model, table_id:" << table_id; - auto model_dir = epoch_accessor->model_save_path(epoch_id, way); rets[ret_size++] = ps_client->save(table_id, model_dir, std::to_string((int)way)); } int all_ret = 0;