diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 58df49d324bb23c3f8425ca556de4a9ed2b0a863..737df3dd6b5b2d60d563a5c56b20a7f4bee5f1d0 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -20,7 +20,10 @@ limitations under the License. */
 #include <memory>
 #include <mutex>  // NOLINT
 #include <string>
-#include <thread>  // NOLINT
+#include <thread>         // NOLINT
+#include <unordered_map>  // NOLINT
+#include <unordered_set>  // NOLINT
+#include <utility>        // NOLINT
 #include <vector>
 
 #include "paddle/fluid/framework/data_feed.h"
@@ -195,6 +198,9 @@ class DownpourWorker : public HogwildWorker {
   void CollectLabelInfo(size_t table_id);
   void AdjustInsWeight();
   void DumpParam();
+  void CopySparseTable();
+  void CopyDenseTable();
+  void CopyDenseVars();
 
  private:
   bool need_dump_param_;
@@ -237,6 +243,12 @@ class DownpourWorker : public HogwildWorker {
   std::vector<float> nid_show_;
   // check nan and inf during training
   std::vector<std::string> check_nan_var_names_;
+  // copy table
+  CopyTableConfig copy_table_config_;
+  std::map<uint64_t, uint64_t> table_dependency_;
+  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
+  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
+  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
 };
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc
index 248855b795f340ce21939335eec6ffe41645b763..cedf22bd9f05382c2e4946ee86910ad0205d03a4 100644
--- a/paddle/fluid/framework/downpour_worker.cc
+++ b/paddle/fluid/framework/downpour_worker.cc
@@ -93,6 +93,29 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
   for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
     check_nan_var_names_.push_back(desc.check_nan_var_names(i));
   }
+  copy_table_config_ = desc.copy_table_config();
+  for (int i = 0; i < copy_table_config_.src_sparse_tables_size(); ++i) {
+    uint64_t src_table = copy_table_config_.src_sparse_tables(i);
+    uint64_t dest_table = copy_table_config_.dest_sparse_tables(i);
+    VLOG(3) << "copy_sparse_tables_ push back " << src_table << "->"
+            << dest_table;
+    copy_sparse_tables_.push_back(std::make_pair(src_table, dest_table));
+  }
+  for (int i = 0; i < copy_table_config_.src_dense_tables_size(); ++i) {
+    uint64_t src_table = copy_table_config_.src_dense_tables(i);
+    uint64_t dest_table = copy_table_config_.dest_dense_tables(i);
+    VLOG(3) << "copy_dense_tables_ push back " << src_table << "->"
+            << dest_table;
+    copy_dense_tables_.push_back(std::make_pair(src_table, dest_table));
+  }
+  for (auto& m : copy_table_config_.table_denpendency_map()) {
+    if (sparse_key_names_.find(m.key()) != sparse_key_names_.end()) {
+      // currently only supports one dependency
+      for (auto& value : m.values()) {
+        table_dependency_[m.key()] = value;
+      }
+    }
+  }
 }
 
 void DownpourWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
@@ -404,6 +427,102 @@ void DownpourWorker::AdjustInsWeight() {
 #endif
 }
 
+void DownpourWorker::CopySparseTable() {
+  for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
+    int64_t src_table = copy_sparse_tables_[i].first;
+    int64_t dest_table = copy_sparse_tables_[i].second;
+    int32_t feanum = 0;
+    if (src_table == dest_table) {
+      continue;
+    } else if (!copy_table_config_.sparse_copy_by_feasign()) {
+      if (feasign_set_.find(src_table) == feasign_set_.end()) {
+        continue;
+      } else if (feasign_set_[src_table].size() == 0) {
+        continue;
+      }
+      feanum = fleet_ptr_->CopyTable(src_table, dest_table);
+    } else {
+      std::vector<uint64_t> fea_vec(feasign_set_[src_table].begin(),
+                                    feasign_set_[src_table].end());
+      feanum = fleet_ptr_->CopyTableByFeasign(src_table, dest_table, fea_vec);
+      fea_vec.clear();
+      std::vector<uint64_t>().swap(fea_vec);
+    }
+    VLOG(3) << "copy feasign from table " << src_table << " to table "
+            << dest_table << ", feasign num=" << feanum;
+    feasign_set_[src_table].clear();
+    std::unordered_set<uint64_t>().swap(feasign_set_[src_table]);
+  }
+  feasign_set_.clear();
+}
+
+void DownpourWorker::CopyDenseTable() {
+  if (thread_id_ != 0) {
+    return;
+  }
+  thread_local std::vector<std::future<int32_t>> pull_dense_status;
+  for (size_t i = 0; i < copy_dense_tables_.size(); ++i) {
+    uint64_t src_table = copy_dense_tables_[i].first;
+    uint64_t dest_table = copy_dense_tables_[i].second;
+    if (src_table == dest_table) {
+      continue;
+    }
+    int32_t dim = fleet_ptr_->CopyTable(src_table, dest_table);
+    VLOG(3) << "copy param from table " << src_table << " to table "
+            << dest_table << ", dim=" << dim;
+    if (copy_table_config_.dense_pull_after_copy()) {
+      VLOG(3) << "dense pull after copy, table=" << dest_table;
+      pull_dense_status.resize(0);
+      fleet_ptr_->PullDenseVarsAsync(*root_scope_, dest_table,
+                                     dense_value_names_[dest_table],
+                                     &pull_dense_status);
+      for (auto& t : pull_dense_status) {
+        t.wait();
+        auto status = t.get();
+        if (status != 0) {
+          LOG(WARNING) << "pull dense after copy table failed,"
+                       << " table=" << dest_table;
+        }
+      }
+    }
+  }
+}
+
+void DownpourWorker::CopyDenseVars() {
+  if (thread_id_ != 0) {
+    return;
+  }
+  for (int i = 0; i < copy_table_config_.src_var_list_size(); ++i) {
+    auto& src_var_name = copy_table_config_.src_var_list(i);
+    auto& dest_var_name = copy_table_config_.dest_var_list(i);
+    if (src_var_name == dest_var_name) {
+      continue;
+    }
+    VLOG(3) << "copy dense var from " << src_var_name << " to "
+            << dest_var_name;
+    Variable* src_var = thread_scope_->FindVar(src_var_name);
+    CHECK(src_var != nullptr) << src_var_name << " not found";  // NOLINT
+    LoDTensor* src_tensor = src_var->GetMutable<LoDTensor>();
+    CHECK(src_tensor != nullptr) << src_var_name
+                                 << " tensor is null";  // NOLINT
+    float* src_data = src_tensor->data<float>();
+
+    Variable* dest_var = thread_scope_->FindVar(dest_var_name);
+    CHECK(dest_var != nullptr) << dest_var_name << " not found";  // NOLINT
+    LoDTensor* dest_tensor = dest_var->GetMutable<LoDTensor>();
+    CHECK(dest_tensor != nullptr) << dest_var_name
+                                  << " tensor is null";  // NOLINT
+    float* dest_data = dest_tensor->data<float>();
+
+    CHECK(src_tensor->numel() == dest_tensor->numel())
+        << "tensor numel not equal," << src_tensor->numel() << " vs "
+        << dest_tensor->numel();
+    for (int i = 0; i < src_tensor->numel(); i++) {
+      dest_data[i] = src_data[i];
+    }
+  }
+}
+
 void DownpourWorker::TrainFilesWithProfiler() {
   VLOG(3) << "Begin to train files with profiler";
   platform::SetNumThreads(1);
@@ -437,6 +556,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
   double fill_sparse_time = 0.0;
   double push_sparse_time = 0.0;
   double push_dense_time = 0.0;
+  double copy_table_time = 0.0;
   int cur_batch;
   int batch_cnt = 0;
   uint64_t total_inst = 0;
@@ -445,6 +565,27 @@
     timeline.Pause();
     read_time += timeline.ElapsedSec();
     total_time += timeline.ElapsedSec();
+
+    timeline.Start();
+    if (copy_table_config_.need_copy()) {
+      VLOG(3) << "copy_sparse_tables_.size " << copy_sparse_tables_.size();
+      if (copy_table_config_.sparse_copy_by_feasign()) {
+        for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
+          uint64_t tid = copy_sparse_tables_[i].first;
+          feasign_set_[tid].insert(sparse_push_keys_[tid].begin(),
+                                   sparse_push_keys_[tid].end());
+        }
+      }
+      if (batch_cnt % copy_table_config_.batch_num() == 0) {
+        CopySparseTable();
+        CopyDenseTable();
+        CopyDenseVars();
+      }
+    }
+    timeline.Pause();
+    copy_table_time += timeline.ElapsedSec();
+    total_time += timeline.ElapsedSec();
+
     VLOG(3) << "program config size: " << param_.program_config_size();
     for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
          ++i) {
@@ -641,6 +782,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
               collect_label_time / batch_cnt);
       fprintf(stderr, "adjust ins weight time: %fs\n",
               adjust_ins_weight_time / batch_cnt);
+      fprintf(stderr, "copy table time: %fs\n", copy_table_time / batch_cnt);
       fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
       fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
       fprintf(stderr, "op run percent: %f\n", op_sum_time / total_time * 100);
@@ -648,6 +790,8 @@
               pull_sparse_time / total_time * 100);
       fprintf(stderr, "adjust ins weight time percent: %f\n",
               adjust_ins_weight_time / total_time * 100);
+      fprintf(stderr, "copy table time percent: %f\n",
+              copy_table_time / total_time * 100);
       fprintf(stderr, "collect label time percent: %f\n",
               collect_label_time / total_time * 100);
       fprintf(stderr, "fill sparse time percent: %f\n",
@@ -661,6 +805,11 @@ void DownpourWorker::TrainFilesWithProfiler() {
     }
     timeline.Start();
   }
+  if (copy_table_config_.need_copy()) {
+    CopySparseTable();
+    CopyDenseTable();
+    CopyDenseVars();
+  }
 }
 
 void DownpourWorker::TrainFiles() {
@@ -670,6 +819,20 @@ void DownpourWorker::TrainFiles() {
   int batch_cnt = 0;
   int cur_batch;
   while ((cur_batch = device_reader_->Next()) > 0) {
+    if (copy_table_config_.need_copy()) {
+      if (copy_table_config_.sparse_copy_by_feasign()) {
+        for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
+          uint64_t tid = copy_sparse_tables_[i].first;
+          feasign_set_[tid].insert(sparse_push_keys_[tid].begin(),
+                                   sparse_push_keys_[tid].end());
+        }
+      }
+      if (batch_cnt % copy_table_config_.batch_num() == 0) {
+        CopySparseTable();
+        CopyDenseTable();
+        CopyDenseVars();
+      }
+    }
     // pull sparse here
     for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
          ++i) {
@@ -850,6 +1013,11 @@ void DownpourWorker::TrainFiles() {
   if (need_dump_field_) {
     writer_.Flush();
   }
+  if (copy_table_config_.need_copy()) {
+    CopySparseTable();
+    CopyDenseTable();
+    CopyDenseVars();
+  }
 }
 
 }  // end namespace framework
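Taken together, the two training-loop hunks above install the same cadence in both TrainFiles() and TrainFilesWithProfiler(): on each batch the worker first records the feasigns it pushed for every source table (only when sparse_copy_by_feasign is set), triggers CopySparseTable/CopyDenseTable/CopyDenseVars every batch_num batches (batch 0 included, since 0 % batch_num == 0), and triggers them once more after the reader is exhausted. A minimal Python sketch of that control flow, purely for orientation; every name here is a placeholder, the real logic is the C++ above:

from collections import defaultdict

def train_loop(batches, cfg, copy_sparse_tables, sparse_push_keys, do_copy):
    """Mirrors the copy cadence added to TrainFiles()/TrainFilesWithProfiler()."""
    feasign_set = defaultdict(set)
    for batch_cnt, batch in enumerate(batches):
        if cfg["need_copy"]:
            if cfg["sparse_copy_by_feasign"]:
                # remember which feasigns this worker pushed, per source table
                for src, _dest in copy_sparse_tables:
                    feasign_set[src].update(sparse_push_keys.get(src, []))
            if batch_cnt % cfg["batch_num"] == 0:
                do_copy(feasign_set)  # CopySparseTable/CopyDenseTable/CopyDenseVars
                feasign_set.clear()   # CopySparseTable() clears the accumulated sets
        # ... pull sparse, run ops, push gradients for `batch` ...
    if cfg["need_copy"]:
        do_copy(feasign_set)          # one last copy once the reader is exhausted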
diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc
index 93672e88963eaff089257181fdbe69c52561acf2..7aa1a6fc93825f45825341669c74acdc6de6841e 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.cc
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -40,28 +40,6 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
 std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
 bool FleetWrapper::is_initialized_ = false;
 
-#ifdef PADDLE_WITH_PSLIB
-template <class AR>
-paddle::ps::Archive<AR>& operator<<(paddle::ps::Archive<AR>& ar,
-                                    const MultiSlotType& ins) {
-  ar << ins.GetType();
-  ar << ins.GetOffset();
-  ar << ins.GetFloatData();
-  ar << ins.GetUint64Data();
-  return ar;
-}
-
-template <class AR>
-paddle::ps::Archive<AR>& operator>>(paddle::ps::Archive<AR>& ar,
-                                    MultiSlotType& ins) {
-  ar >> ins.MutableType();
-  ar >> ins.MutableOffset();
-  ar >> ins.MutableFloatData();
-  ar >> ins.MutableUint64Data();
-  return ar;
-}
-#endif
-
 #ifdef PADDLE_WITH_PSLIB
 std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
 #endif
@@ -729,40 +707,6 @@ std::future<int32_t> FleetWrapper::SendClientToClientMsg(
   return std::future<int32_t>();
 }
 
-template <typename T>
-void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
-#ifdef PADDLE_WITH_PSLIB
-  paddle::ps::BinaryArchive ar;
-  for (size_t i = 0; i < t.size(); ++i) {
-    ar << *(t[i]);
-  }
-  *str = std::string(ar.buffer(), ar.length());
-#else
-  VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
-#endif
-}
-
-template <typename T>
-void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
-#ifdef PADDLE_WITH_PSLIB
-  if (str.length() == 0) {
-    return;
-  }
-  paddle::ps::BinaryArchive ar;
-  ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
-  if (ar.cursor() == ar.finish()) {
-    return;
-  }
-  while (ar.cursor() < ar.finish()) {
-    t->push_back(ar.get<T>());
-  }
-  CHECK(ar.cursor() == ar.finish());
-  VLOG(3) << "Deserialize size " << t->size();
-#else
-  VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
-#endif
-}
-
 std::default_random_engine& FleetWrapper::LocalRandomEngine() {
   struct engine_wrapper_t {
     std::default_random_engine engine;
@@ -781,10 +725,43 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
   return r.engine;
 }
 
-template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
-    const std::vector<std::vector<MultiSlotType>*>&, std::string*);
-template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
-    std::vector<std::vector<MultiSlotType>>*, const std::string&);
+int32_t FleetWrapper::CopyTable(const uint64_t src_table_id,
+                                const uint64_t dest_table_id) {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->copy_table(src_table_id, dest_table_id);
+  ret.wait();
+  int32_t feasign_cnt = ret.get();
+  if (feasign_cnt == -1) {
+    LOG(ERROR) << "copy table failed";
+    sleep(sleep_seconds_before_fail_exit_);
+    exit(-1);
+  }
+  return feasign_cnt;
+#else
+  VLOG(0) << "FleetWrapper::CopyTable does nothing when no pslib";
+  return 0;
+#endif
+}
+
+int32_t FleetWrapper::CopyTableByFeasign(
+    const uint64_t src_table_id, const uint64_t dest_table_id,
+    const std::vector<uint64_t>& feasign_list) {
+#ifdef PADDLE_WITH_PSLIB
+  auto ret = pslib_ptr_->_worker_ptr->copy_table_by_feasign(
+      src_table_id, dest_table_id, feasign_list.data(), feasign_list.size());
+  ret.wait();
+  int32_t feasign_cnt = ret.get();
+  if (feasign_cnt == -1) {
+    LOG(ERROR) << "copy table by feasign failed";
+    sleep(sleep_seconds_before_fail_exit_);
+    exit(-1);
+  }
+  return feasign_cnt;
+#else
+  VLOG(0) << "FleetWrapper::CopyTableByFeasign does nothing when no pslib";
+  return 0;
+#endif
+}
 
 }  // end namespace framework
 }  // end namespace paddle
diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h
index 0307b28aba7ccfd916be7687032a6e191c61c93e..8ad860cbd2243d1da54e12367df6e02712079291 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.h
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.h
@@ -67,11 +67,12 @@ class FleetWrapper {
     client2client_max_retry_ = 3;
   }
 
+  // set client to client communication config
   void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
                               int max_retry);
 
-  // Pull sparse variables from server in Sync mode
-  // Param: scope, table_id, var_names, fea_keys
+  // Pull sparse variables from server in sync mode
+  // Param: scope, table_id, var_names, fea_keys, fea_dim
   // Param: fea_values
   void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
                           const std::vector<std::string>& var_names,
@@ -80,19 +81,24 @@
                           int fea_dim,
                           const std::vector<std::string>& var_emb_names);
 
+  // pull dense variables from server in sync mode
   void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
                          const std::vector<std::string>& var_names);
 
+  // pull dense variables from server in async mode
+  // Param: scope, table_id, var_names
+  // Param: pull_dense_status
   void PullDenseVarsAsync(
       const Scope& scope, const uint64_t table_id,
       const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* pull_dense_status);
 
+  // push dense parameters(not gradients) to server in sync mode
   void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
                           const std::vector<std::string>& var_names);
 
   // Push dense variables to server in async mode
-  // Param: scope, table_id, var_names,
+  // Param: scope, table_id, var_names, scale_datanorm, batch_size
   // Param: push_sparse_status
   void PushDenseVarsAsync(
       const Scope& scope, const uint64_t table_id,
@@ -100,13 +106,14 @@
       std::vector<::std::future<int32_t>>* push_sparse_status,
       float scale_datanorm, int batch_size);
 
+  // push dense variables to server in sync mode
   void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
                          const std::vector<std::string>& var_names);
 
-  // Push sparse variables with labels to server in Async mode
+  // Push sparse variables with labels to server in async mode
   // This is specially designed for click/show stats in server
-  // Param: scope, table_id, var_grad_names,
-  //        fea_keys, fea_labels, sparse_grad_names
+  // Param: scope, table_id, fea_keys, fea_labels, sparse_key_names,
+  //        sparse_grad_names, batch_size, use_cvm, dump_slot
   // Param: push_values, push_sparse_status
   void PushSparseVarsWithLabelAsync(
       const Scope& scope, const uint64_t table_id,
@@ -132,12 +139,17 @@
       std::vector<::std::future<int32_t>>* push_sparse_status);
   */
 
+  // init server
   void InitServer(const std::string& dist_desc, int index);
+  // init trainer
   void InitWorker(const std::string& dist_desc,
                   const std::vector<uint64_t>& host_sign_list, int node_num,
                   int index);
+  // stop server
   void StopServer();
+  // run server
  uint64_t RunServer();
+  // gather server ip
   void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
   // gather client ip
   void GatherClients(const std::vector<uint64_t>& host_sign_list);
@@ -145,7 +157,6 @@
   std::vector<std::string> GetClientsInfo();
   // create client to client connection
   void CreateClient2ClientConnection();
-
   // flush all push requests
   void ClientFlush();
   // load from paddle model
@@ -164,37 +175,42 @@
   // mode = 0, save all feature
   // mode = 1, save delta feature, which means save diff
   void SaveModel(const std::string& path, const int mode);
-
+  // get save cache threshold
   double GetCacheThreshold();
+  // shuffle cache model between servers
   void CacheShuffle(int table_id, const std::string& path, const int mode,
                     const double cache_threshold);
+  // save cache model
+  // cache model can speed up online predict
   int32_t SaveCache(int table_id, const std::string& path, const int mode);
-
+  // copy feasign key/value from src_table_id to dest_table_id
+  int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
+  // copy only the listed feasigns (key/value) from src_table_id to
+  // dest_table_id
+  int32_t CopyTableByFeasign(const uint64_t src_table_id,
+                             const uint64_t dest_table_id,
+                             const std::vector<uint64_t>& feasign_list);
+  // clear all models, release their memory
   void ClearModel();
-
+  // shrink sparse table
   void ShrinkSparseTable(int table_id);
+  // shrink dense table
   void ShrinkDenseTable(int table_id, Scope* scope,
                         std::vector<std::string> var_list, float decay,
                         int emb_dim);
-  // register client to client communication
   typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
+  // register client to client communication
   int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
   // send client to client message
   std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
                                              const std::string& msg);
-
-  template <typename T>
-  void Serialize(const std::vector<T*>& t, std::string* str);
-  template <typename T>
-  void Deserialize(std::vector<T>* t, const std::string& str);
+  // FleetWrapper singleton
   static std::shared_ptr<FleetWrapper> GetInstance() {
     if (NULL == s_instance_) {
       s_instance_.reset(new paddle::framework::FleetWrapper());
     }
     return s_instance_;
   }
-
   // this performs better than rand_r, especially large data
   std::default_random_engine& LocalRandomEngine();
diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto
index 59f2cd9d327f6b6af8504a0e7e4af8ce135c4233..5212c09b65c15e33e3e022fdbe8584686767cdba 100644
--- a/paddle/fluid/framework/trainer_desc.proto
+++ b/paddle/fluid/framework/trainer_desc.proto
@@ -40,10 +40,12 @@ message TrainerDesc {
   repeated string dump_fields = 13;
   optional string dump_converter = 14;
   repeated string dump_param = 15;
-
   optional int32 mpi_size = 16 [ default = -1 ];
   optional int32 dump_file_num = 17 [ default = 16 ];
   repeated string check_nan_var_names = 18;
+  optional CopyTableConfig copy_table_config = 19;
+  // adjust ins weight
+  optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
 
   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
@@ -52,8 +54,6 @@
   optional SectionWorkerParameter section_param = 104;
   // datafeed desc
   optional DataFeedDesc data_desc = 201;
-  // adjust ins weight
-  optional AdjustInsWeightConfig adjust_ins_weight_config = 301;
 }
 
 message HogwildWorkerParameter { repeated string skip_ops = 1; }
@@ -108,6 +108,29 @@ message AdjustInsWeightConfig {
   optional string ins_weight_slot = 5 [ default = "" ];
 }
 
+message TableDependencyMap {
+  required int32 key = 1;
+  repeated int32 values = 2;
+}
+
+message CopyTableConfig {
+  optional bool need_copy = 1 [ default = false ];
+  optional int32 batch_num = 2 [ default = 100 ];
+  repeated int32 src_sparse_tables = 3;
+  repeated int32 dest_sparse_tables = 4;
+  repeated int32 src_dense_tables = 5;
+  repeated int32 dest_dense_tables = 6;
+  repeated string src_var_list = 7;
+  repeated string dest_var_list = 8;
+  // when the dest dense table has no grad, it should be pulled explicitly
+  optional bool dense_pull_after_copy = 9 [ default = false ];
+  // copy feasigns or copy the whole table
+  optional bool sparse_copy_by_feasign = 10 [ default = true ];
+  // table dependency for pull/push
+  optional bool enable_dependency = 11 [ default = false ];
+  repeated TableDependencyMap table_denpendency_map = 12;
+}
+
 message ProgramConfig {
   required string program_id = 1;
   repeated int32 push_sparse_table_id = 2;
diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc
index e7c7750c27de617ba8f339302ffa0fde95a794af..31268f5e1826a6be63a23cbe29e8a960b1ac5705 100644
--- a/paddle/fluid/pybind/fleet_wrapper_py.cc
+++ b/paddle/fluid/pybind/fleet_wrapper_py.cc
@@ -67,7 +67,10 @@ void BindFleetWrapper(py::module* m) {
            &framework::FleetWrapper::LoadFromPaddleModel)
       .def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable)
       .def("set_client2client_config",
-           &framework::FleetWrapper::SetClient2ClientConfig);
+           &framework::FleetWrapper::SetClient2ClientConfig)
+      .def("copy_table", &framework::FleetWrapper::CopyTable)
+      .def("copy_table_by_feasign",
+           &framework::FleetWrapper::CopyTableByFeasign);
 }  // end FleetWrapper
 }  // end namespace pybind
 }  // end namespace paddle
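With the bindings above, both copy operations are also reachable from Python. A hedged sketch: it assumes an initialized PSLib fleet and that fleet._fleet_ptr is the bound core.Fleet object (which holds for the pslib fleet implementation); the table ids and feasigns are placeholders:

# Hypothetical direct use of the new bindings; in normal training the
# DownpourWorker drives these through CopyTableConfig instead.
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

# ... fleet.init_worker() must have completed first ...
fleet._fleet_ptr.copy_table(1, 2)  # whole-table copy: table 1 -> table 2
fleet._fleet_ptr.copy_table_by_feasign(1, 2, [1001, 1002])  # listed feasigns only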
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py
index facecd8167919a0ca8f28747fb71cf4ac0067519..902f515d0bb466a32c7ab7b2531c3c7d8b4dd7e8 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py
@@ -343,7 +343,7 @@ class DownpourWorker(Worker):
         target_table = None
         for table in self._worker.sparse_table:
             if table.table_id == table_id:
-                keys = self._worker.sparse_table[table_id].slot_key
+                keys = table.slot_key
                 key_names = [var.name for var in sorted_slot_key_vars]
                 for key_name in key_names:
                     if key_name not in keys:
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py
index 59b220c10808b994aaff9605db42d27aa2193788..5affa9b59c2bdb559312caa771502b08ef308c74 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py
@@ -372,6 +372,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
                 0].accessor.accessor_class == "DownpourCtrAccessor":
             opt_info["dump_slot"] = True
         opt_info["adjust_ins_weight"] = strategy.get("adjust_ins_weight", {})
+        opt_info["copy_table"] = strategy.get("copy_table", {})
         for loss in losses:
             loss.block.program._fleet_opt = opt_info
diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py
index 6813b76789b899b886a88fb3c6a9b3551976ee89..0f106a75253ee254b315c011e77c72bdb02d789a 100644
--- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py
+++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py
@@ -820,8 +820,9 @@ class FleetUtil(object):
         """
         fleet._role_maker._barrier_worker()
         if fleet._role_maker.is_first_worker():
-            tables = fleet._dist_desc.trainer_param.dense_table
             prog_id = str(id(program))
+            tables = fleet._opt_info["program_id_to_worker"][prog_id].\
+                get_desc().dense_table
             prog_conf = fleet._opt_info['program_configs'][prog_id]
             prog_tables = {}
             for key in prog_conf:
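The strategy.get("copy_table", {}) line above is the user-facing switch for the whole feature. A configuration sketch under stated assumptions (table ids and values are illustrative placeholders; the recognized keys are exactly those read by _set_copy_table_config in trainer_desc.py below):

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

strategy = {
    "copy_table": {
        "need_copy": True,
        "batch_num": 100,                # copy every 100 batches
        "src_sparse_tables": [0],        # sparse table 0 -> sparse table 1
        "dest_sparse_tables": [1],
        "src_dense_tables": [2],         # dense table 2 -> dense table 3
        "dest_dense_tables": [3],
        "sparse_copy_by_feasign": True,  # copy only feasigns seen locally
        "dense_pull_after_copy": True,   # re-pull dest dense table after copy
    },
}

optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)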
+"""Defination of trainers.""" import sys from os import path @@ -120,6 +121,78 @@ class TrainerDesc(object): self.proto_desc.adjust_ins_weight_config.ins_weight_slot = \ config_dict.get("ins_weight_slot", "") + def _set_copy_table_config(self, config_dict): + config = self.proto_desc.copy_table_config + config.need_copy = config_dict.get("need_copy", False) + config.batch_num = config_dict.get("batch_num", 100) + + src_sparse_tables = config_dict.get("src_sparse_tables", []) + if not isinstance(src_sparse_tables, list): + src_sparse_tables = [src_sparse_tables] + dest_sparse_tables = config_dict.get("dest_sparse_tables", []) + if not isinstance(dest_sparse_tables, list): + dest_sparse_tables = [dest_sparse_tables] + if len(src_sparse_tables) != len(dest_sparse_tables): + raise ValueError( + "len(src_sparse_tables) != len(dest_sparse_tables)," \ + " %s vs %s" % (len(src_sparse_tables), \ + len(dest_sparse_tables))) + for i in src_sparse_tables: + config.src_sparse_tables.append(i) + for i in dest_sparse_tables: + config.dest_sparse_tables.append(i) + + src_dense_tables = config_dict.get("src_dense_tables", []) + if not isinstance(src_dense_tables, list): + src_dense_tables = [src_dense_tables] + dest_dense_tables = config_dict.get("dest_dense_tables", []) + if not isinstance(dest_dense_tables, list): + dest_dense_tables = [dest_dense_tables] + if len(src_dense_tables) != len(dest_dense_tables): + raise ValueError( + "len(src_dense_tables) != len(dest_dense_tables)," \ + " %s vs %s" % (len(src_dense_tables), \ + len(dest_dense_tables))) + for i in src_dense_tables: + config.src_dense_tables.append(i) + for i in dest_dense_tables: + config.dest_dense_tables.append(i) + + # user can also specify dense variables to copy, + # instead of copy dense table + src_var_list = config_dict.get("src_var_list", []) + if not isinstance(src_var_list, list): + src_var_list = [src_var_list] + dest_var_list = config_dict.get("dest_var_list", []) + if not isinstance(dest_var_list, list): + dest_var_list = [dest_var_list] + if len(src_var_list) != len(dest_var_list): + raise ValueError( + "len(src_var_list) != len(dest_var_list), %s vs" \ + " %s" % (len(src_var_list), len(dest_var_list))) + for i in src_var_list: + config.src_var_list.append(i) + for i in dest_var_list: + config.dest_var_list.append(i) + + dependency_map = config_dict.get("dependency_map", {}) + for key in dependency_map: + m = config.table_denpendency_map.add() + m.key = key + values = dependency_map[key] + if not isinstance(values, list): + values = [values] + if len(values) != 1: + raise ValueError("dependency len %s != 1" % len(values)) + for value in values: + m.values.append(value) + config.dense_pull_after_copy = \ + config_dict.get("dense_pull_after_copy", True) + config.enable_dependency = \ + config_dict.get("enable_dependency", False) + config.sparse_copy_by_feasign = \ + config_dict.get("sparse_copy_by_feasign", True) + def _desc(self): from google.protobuf import text_format return self.proto_desc.SerializeToString() @@ -151,6 +224,11 @@ class MultiTrainer(TrainerDesc): class DistMultiTrainer(TrainerDesc): + """ + Implement of DistMultiTrainer. + It's for Distributed training. + """ + def __init__(self): super(DistMultiTrainer, self).__init__() pass @@ -170,6 +248,11 @@ class DistMultiTrainer(TrainerDesc): class PipelineTrainer(TrainerDesc): + """ + Implement of PipelineTrainer. + It's for Pipeline. 
+ """ + def __init__(self): super(PipelineTrainer, self).__init__() pass diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 1469af3d182bd3182e8c9e811bfa11a34f29b7c1..70154e383a9ed0ecff7ada928992ec1f36a65ff1 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Defination of TrainerFactory.""" import threading import time @@ -24,6 +25,12 @@ __all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"] class TrainerFactory(object): + """ + Create trainer and device worker. + If opt_info is not None, it will get configs from opt_info, + otherwise create MultiTrainer and Hogwild. + """ + def __init__(self): pass @@ -43,24 +50,44 @@ class TrainerFactory(object): if "fleet_desc" in opt_info: device_worker._set_fleet_desc(opt_info["fleet_desc"]) trainer._set_fleet_desc(opt_info["fleet_desc"]) - trainer._set_use_cvm(opt_info["use_cvm"]) - trainer._set_scale_datanorm(opt_info["scale_datanorm"]) - trainer._set_dump_slot(opt_info["dump_slot"]) - trainer._set_mpi_rank(opt_info["mpi_rank"]) - trainer._set_mpi_size(opt_info["mpi_size"]) - trainer._set_dump_fields(opt_info["dump_fields"]) - trainer._set_dump_fields_path(opt_info["dump_fields_path"]) - trainer._set_dump_file_num(opt_info["dump_file_num"]) - trainer._set_dump_converter(opt_info["dump_converter"]) - trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"]) - trainer._set_dump_param(opt_info["dump_param"]) - trainer._set_check_nan_var_names(opt_info[ - "check_nan_var_names"]) + if opt_info.get("use_cvm") is not None: + trainer._set_use_cvm(opt_info["use_cvm"]) + if opt_info.get("scale_datanorm") is not None: + trainer._set_scale_datanorm(opt_info["scale_datanorm"]) + if opt_info.get("dump_slot") is not None: + trainer._set_dump_slot(opt_info["dump_slot"]) + if opt_info.get("mpi_rank") is not None: + trainer._set_mpi_rank(opt_info["mpi_rank"]) + if opt_info.get("mpi_size") is not None: + trainer._set_mpi_size(opt_info["mpi_size"]) + if opt_info.get("dump_fields") is not None: + trainer._set_dump_fields(opt_info["dump_fields"]) + if opt_info.get("dump_fields_path") is not None: + trainer._set_dump_fields_path(opt_info["dump_fields_path"]) + if opt_info.get("dump_file_num") is not None: + trainer._set_dump_file_num(opt_info["dump_file_num"]) + if opt_info.get("dump_converter") is not None: + trainer._set_dump_converter(opt_info["dump_converter"]) + if opt_info.get("adjust_ins_weight") is not None: + trainer._set_adjust_ins_weight(opt_info[ + "adjust_ins_weight"]) + if opt_info.get("copy_table") is not None: + trainer._set_copy_table_config(opt_info["copy_table"]) + if opt_info.get("check_nan_var_names") is not None: + trainer._set_check_nan_var_names(opt_info[ + "check_nan_var_names"]) + if opt_info.get("dump_param") is not None: + trainer._set_dump_param(opt_info["dump_param"]) trainer._set_device_worker(device_worker) return trainer class FetchHandlerMonitor(object): + """ + Defination of FetchHandlerMonitor class, + it's for fetch handler. + """ + def __init__(self, scope, handler): self.fetch_instance = handler self.fetch_thread = threading.Thread( @@ -69,11 +96,21 @@ class FetchHandlerMonitor(object): self.running = False def start(self): + """ + start monitor, + it will start a monitor thread. 
+ """ self.running = True self.fetch_thread.setDaemon(True) self.fetch_thread.start() def handler_decorator(self, fetch_scope, fetch_handler): + """ + decorator of handler, + Args: + fetch_scope(Scope): fetch scope + fetch_handler(Handler): fetch handler + """ fetch_target_names = self.fetch_instance.fetch_target_names period_secs = self.fetch_instance.period_secs