diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 538002ca1c835c7e302018d37b051c4017687039..57bdfe0494093a5ab6183259ef841d1e0b8ffc07 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -156,6 +156,11 @@ copy(inference_lib_dist
         SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
         DSTS ${dst_dir} ${dst_dir}/lib)
 
+set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/threadpool")
+copy(inference_lib_dist
+        SRCS ${THREADPOOL_INCLUDE_DIR}/ThreadPool.h
+        DSTS ${dst_dir})
+
 copy(inference_lib_dist
         SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
         DSTS ${FLUID_INFERENCE_INSTALL_DIR})
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index bbe2e34650a8d85072a639645148e5c572108479..606e9caa0731b3c3a752f602572aa6ee2d9ca693 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -189,7 +189,7 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o
 if(WITH_DISTRIBUTE)
     cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
             dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
-            data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
+            data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
             pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
             device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
             lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
@@ -199,7 +199,7 @@ set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_CO
 else()
     cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
             dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
-            data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
+            data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
             pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
             device_context scope framework_proto data_feed_proto trainer_desc_proto glog
             lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index 69f13d9f7d75377c1892154d40398b9e0d3ef3dc..7926a9bfb9d5e6acdf3a90338a1763ff984c7b53 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -123,6 +123,12 @@ void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
   merge_size_ = merge_size;
 }
 
+template <typename T>
+void DatasetImpl<T>::SetGenerateUniqueFeasign(bool gen_uni_feasigns) {
+  gen_uni_feasigns_ = gen_uni_feasigns;
+  VLOG(3) << "Set generate unique feasigns: " << gen_uni_feasigns;
+}
+
 template <typename T>
 void DatasetImpl<T>::SetFeaEval(bool fea_eval, int record_candidate_size) {
   slots_shuffle_fea_eval_ = fea_eval;
@@ -640,6 +646,85 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
 // explicit instantiation
 template class DatasetImpl<Record>;
 
+void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
+                                                 int read_thread_num,
+                                                 int consume_thread_num,
+                                                 int shard_num) {
+  VLOG(3) << "MultiSlotDataset::GenerateUniqueFeasign begin";
+  if (!gen_uni_feasigns_) {
+    VLOG(3) << "generate_unique_feasign_=false, will not GenerateUniqueFeasign";
+    return;
+  }
+
+  CHECK(multi_output_channel_.size() != 0);  // NOLINT
+  auto fleet_ptr_ = FleetWrapper::GetInstance();
+  std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
+      local_map_tables = fleet_ptr_->GetLocalTable();
+  local_map_tables.resize(shard_num);
+  // read thread
+  int channel_num = multi_output_channel_.size();
+  if (read_thread_num < channel_num) {
+    read_thread_num = channel_num;
+  }
+  std::vector<std::thread> threads(read_thread_num);
+  consume_task_pool_.resize(consume_thread_num);
+  for (size_t i = 0; i < consume_task_pool_.size(); i++) {
+    consume_task_pool_[i].reset(new ::ThreadPool(1));
+  }
+  auto consume_func = [&local_map_tables](int shard_id, int feadim,
+                                          std::vector<uint64_t>& keys) {
+    for (auto k : keys) {
+      if (local_map_tables[shard_id].find(k) ==
+          local_map_tables[shard_id].end()) {
+        local_map_tables[shard_id][k] = std::vector<float>(feadim, 0);
+      }
+    }
+  };
+  auto gen_func = [this, &shard_num, &feadim, &local_map_tables,
+                   &consume_func](int i) {
+    std::vector<Record> vec_data;
+    std::vector<std::vector<uint64_t>> task_keys(shard_num);
+    std::vector<std::future<void>> task_futures;
+    this->multi_output_channel_[i]->Close();
+    this->multi_output_channel_[i]->ReadAll(vec_data);
+    for (size_t j = 0; j < vec_data.size(); j++) {
+      for (auto& feature : vec_data[j].uint64_feasigns_) {
+        int shard = feature.sign().uint64_feasign_ % shard_num;
+        task_keys[shard].push_back(feature.sign().uint64_feasign_);
+      }
+    }
+
+    for (int shard_id = 0; shard_id < shard_num; shard_id++) {
+      task_futures.emplace_back(consume_task_pool_[shard_id]->enqueue(
+          consume_func, shard_id, feadim, task_keys[shard_id]));
+    }
+
+    multi_output_channel_[i]->Open();
+    multi_output_channel_[i]->Write(std::move(vec_data));
+    vec_data.clear();
+    vec_data.shrink_to_fit();
+    for (auto& tk : task_keys) {
+      tk.clear();
+      std::vector<uint64_t>().swap(tk);
+    }
+    task_keys.clear();
+    std::vector<std::vector<uint64_t>>().swap(task_keys);
+    for (auto& tf : task_futures) {
+      tf.wait();
+    }
+  };
+  for (size_t i = 0; i < threads.size(); i++) {
+    threads[i] = std::thread(gen_func, i);
+  }
+  for (std::thread& t : threads) {
+    t.join();
+  }
+  for (size_t i = 0; i < consume_task_pool_.size(); i++) {
+    consume_task_pool_[i].reset();
+  }
+  consume_task_pool_.clear();
+  fleet_ptr_->PullSparseToLocal(table_id, feadim);
+}
+
 void MultiSlotDataset::MergeByInsId() {
   VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
   if (!merge_by_insid_) {
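`GenerateLocalTablesUnlock` above shards every uint64 feasign to a local table by `key % shard_num`, and each shard is owned by a single-threaded pool so no per-key locking is needed. A minimal Python model of that routing rule (illustrative only, not part of the patch):

```python
# Illustrative model of the sharding used by GenerateLocalTablesUnlock:
# each feasign is routed to shard `key % shard_num`; values start as
# zero-filled embeddings of length `feadim` and are filled later by
# PullSparseToLocal.
def route_to_shards(feasigns, shard_num, feadim):
    shards = [dict() for _ in range(shard_num)]
    for key in feasigns:
        shard = shards[key % shard_num]
        if key not in shard:
            shard[key] = [0.0] * feadim  # placeholder, filled by the pull
    return shards
```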
diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h
index aa82b66305eb000752a3a1f0ab88caf957087bd0..94424a5ffaf23067b86e66fe232dcdec6a189712 100644
--- a/paddle/fluid/framework/data_set.h
+++ b/paddle/fluid/framework/data_set.h
@@ -14,12 +14,14 @@
 
 #pragma once
 
+#include <fstream>
 #include <memory>
 #include <mutex>  // NOLINT
 #include <set>
 #include <string>
 #include <thread>  // NOLINT
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -63,6 +65,7 @@ class Dataset {
   virtual void SetParseContent(bool parse_content) = 0;
   // set merge by ins id
   virtual void SetMergeByInsId(int merge_size) = 0;
+  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
   // set fea eval mode
   virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
   // get file list
@@ -112,6 +115,11 @@ class Dataset {
   virtual int64_t GetShuffleDataSize() = 0;
   // merge by ins id
   virtual void MergeByInsId() = 0;
+  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
+                                         int read_thread_num,
+                                         int consume_thread_num,
+                                         int shard_num) = 0;
+  virtual void ClearLocalTables() = 0;
   // create preload readers
   virtual void CreatePreLoadReaders() = 0;
   // destroy preload readers after prelaod done
@@ -148,7 +156,7 @@ class DatasetImpl : public Dataset {
   virtual void SetParseInsId(bool parse_ins_id);
   virtual void SetParseContent(bool parse_content);
   virtual void SetMergeByInsId(int merge_size);
-
+  virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
   virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
   virtual const std::vector<std::string>& GetFileList() { return filelist_; }
   virtual int GetThreadNum() { return thread_num_; }
@@ -179,6 +187,11 @@ class DatasetImpl : public Dataset {
   virtual int64_t GetMemoryDataSize();
   virtual int64_t GetShuffleDataSize();
   virtual void MergeByInsId() {}
+  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
+                                         int read_thread_num,
+                                         int consume_thread_num,
+                                         int shard_num) {}
+  virtual void ClearLocalTables() {}
   virtual void CreatePreLoadReaders();
   virtual void DestroyPreLoadReaders();
   virtual void SetPreLoadThreadNum(int thread_num);
@@ -195,6 +208,7 @@ class DatasetImpl : public Dataset {
   int channel_num_;
   std::vector<paddle::framework::Channel<T>> multi_output_channel_;
   std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
+  std::vector<std::unordered_set<uint64_t>> local_tables_;
   // when read ins, we put ins from one channel to the other,
   // and when finish reading, we set cur_channel = 1 - cur_channel,
   // so if cur_channel=0, all data are in output_channel, else consume_channel
@@ -202,6 +216,7 @@ class DatasetImpl : public Dataset {
   std::vector<T> slots_shuffle_original_data_;
   RecordCandidateList slots_shuffle_rclist_;
   int thread_num_;
+  int pull_sparse_to_local_thread_num_;
   paddle::framework::DataFeedDesc data_feed_desc_;
   int trainer_num_;
   std::vector<std::string> filelist_;
@@ -217,9 +232,11 @@ class DatasetImpl : public Dataset {
   bool parse_content_;
   size_t merge_size_;
   bool slots_shuffle_fea_eval_ = false;
+  bool gen_uni_feasigns_ = false;
   int preload_thread_num_;
   std::mutex global_index_mutex_;
   int64_t global_index_ = 0;
+  std::vector<std::shared_ptr<::ThreadPool>> consume_task_pool_;
 };
 
 // use std::vector<MultiSlotType> or Record as data type
@@ -227,6 +244,16 @@ class MultiSlotDataset : public DatasetImpl<Record> {
  public:
   MultiSlotDataset() {}
   virtual void MergeByInsId();
+  virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
+                                         int read_thread_num,
+                                         int consume_thread_num, int shard_num);
+  virtual void ClearLocalTables() {
+    for (auto& t : local_tables_) {
+      t.clear();
+      std::unordered_set<uint64_t>().swap(t);
+    }
+    std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
+  }
   virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
   virtual void GetRandomData(const std::set<std::string>& slots_to_replace,
                              std::vector<Record>* result);
diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 3dbdda14cb578bfead39f7e7b45e9eadf872021c..946ecbf0f9f1621ff2cd83644c3d6de23bb26897 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -207,54 +207,80 @@ class DownpourWorker : public HogwildWorker {
   void CopySparseTable();
   void CopyDenseTable();
   void CopyDenseVars();
-
- private:
-  bool need_dump_param_;
-  std::vector<std::string> dump_param_;
-  bool need_to_push_dense_;
-  bool need_dump_field_;
-  bool dump_slot_;
-  bool need_to_push_sparse_;
-  std::vector<std::string> dump_fields_;
-  ChannelWriter<std::string> writer_;
+  std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
+  std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
+  bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
   DownpourWorkerParameter param_;
-  float scale_datanorm_;
-  // just save the value in param_ for easy access
-  std::map<uint64_t, std::string> label_var_name_;
-  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
-  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
-  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
-  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
-  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
+  // copy table
+  CopyTableConfig copy_table_config_;
+  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
+  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
   // actually pushed feasign of each table
   std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;
-
+  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
   // feasign
   std::map<uint64_t, std::vector<uint64_t>> features_;
-  // feasign stats
-  std::map<uint64_t, std::vector<float>> feature_labels_;
   // feasign embedding
   std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
+  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
+  // adjust ins weight
+  AdjustInsWeightConfig adjust_ins_weight_config_;
+  // check nan and inf during training
+  std::vector<std::string> check_nan_var_names_;
+  bool need_to_push_sparse_;
+  // feasign stats
+  std::map<uint64_t, std::vector<float>> feature_labels_;
+  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
   // feasign embedding gradient
   std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
+  std::vector<::std::future<int32_t>> push_sparse_status_;
+  bool dump_slot_;
+  bool need_to_push_dense_;
+  bool need_dump_field_;
+  bool need_dump_param_;
+  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
+  float scale_datanorm_;
+  std::vector<::std::future<int32_t>> push_dense_status_;
+  std::vector<std::string> dump_fields_;
+  ChannelWriter<std::string> writer_;
   // skipped ops
   std::vector<std::string> skip_ops_;
+  std::vector<std::string> dump_param_;
+  // just save the value in param_ for easy access
+  std::map<uint64_t, std::string> label_var_name_;
+  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
+  std::map<std::string, uint64_t> table_dependency_;
+  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
+
+ private:
+  // std::vector<std::string> dump_param_;
+  // just save the value in param_ for easy access
+  // std::map<uint64_t, std::string> label_var_name_;
+  // std::map<uint64_t, std::vector<std::string>> dense_value_names_;
   std::shared_ptr<PullDenseWorker> _pull_dense_worker;
-  std::vector<::std::future<int32_t>> push_sparse_status_;
-  std::vector<::std::future<int32_t>> push_dense_status_;
-  // adjust ins weight
-  AdjustInsWeightConfig adjust_ins_weight_config_;
   std::vector<float> nid_show_;
-  // check nan and inf during training
-  std::vector<std::string> check_nan_var_names_;
-  // copy table
-  CopyTableConfig copy_table_config_;
-  std::map<std::string, uint64_t> table_dependency_;
-  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
-  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
-  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
+  // std::map<std::string, uint64_t> table_dependency_;
+  // std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
+};
+
+class DownpourWorkerOpt : public DownpourWorker {
+ public:
+  DownpourWorkerOpt() {}
+  virtual ~DownpourWorkerOpt() {}
+  virtual void CreateDeviceResource(const ProgramDesc& main_prog);
+  virtual void Initialize(const TrainerDesc& desc);
+  virtual void TrainFiles();
+
+ protected:
+  void CreateThreadOperatorsWithRerank(const ProgramDesc& program);
+  std::vector<std::vector<OperatorBase*>> loss_ops_;
+  std::vector<std::vector<std::string>> loss_op_names_;
+  std::vector<std::string> loss_names_;
+  std::string async_wait_name_;
+  int async_index_ = -1;
+  uint64_t async_tid_ = 0;
 };
 
 #if defined(PADDLE_WITH_NCCL)
diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc
index e163b601d9eaeeb5d17f2eb8387bcd9bd5dd9417..80e4000c9dc686bc413b38fcf8298dc8b5399335 100644
--- a/paddle/fluid/framework/device_worker_factory.cc
+++ b/paddle/fluid/framework/device_worker_factory.cc
@@ -61,6 +61,7 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
 
 REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
 REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
+REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt);
 #if defined(PADDLE_WITH_NCCL)
 REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
 #endif
diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc
index ebe62c7d05a7c21ba316c60b6ebfc1cdfa998a47..763441d764dd2c44a8ed46e89cbd0c09b655bd99 100644
--- a/paddle/fluid/framework/downpour_worker.cc
+++ b/paddle/fluid/framework/downpour_worker.cc
@@ -157,7 +157,8 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
   return os.str();
 }
 
-std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
+std::string DownpourWorker::PrintLodTensor(LoDTensor* tensor, int64_t start,
+                                           int64_t end) {
   std::string out_val;
   if (tensor->type() == proto::VarType::FP32) {
     out_val = PrintLodTensorType<float>(tensor, start, end);
@@ -171,7 +172,8 @@ std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
   return out_val;
 }
 
-std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
+std::pair<int64_t, int64_t> DownpourWorker::GetTensorBound(LoDTensor* tensor,
+                                                           int index) {
   auto& dims = tensor->dims();
   if (tensor->lod().size() != 0) {
     auto& lod = tensor->lod()[0];
@@ -181,7 +183,7 @@ std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
   }
 }
 
-bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
+bool DownpourWorker::CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
   auto& dims = tensor->dims();
   if (dims.size() != 2) return false;
   if (tensor->lod().size() != 0) {
diff --git a/paddle/fluid/framework/downpour_worker_opt.cc b/paddle/fluid/framework/downpour_worker_opt.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0db2c7510c7d33b97d185036a89762c2e7b7c3ca
--- /dev/null
+++ b/paddle/fluid/framework/downpour_worker_opt.cc
@@ -0,0 +1,586 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/device_worker.h"
+#include "paddle/fluid/framework/device_worker_factory.h"
+#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
+#include "paddle/fluid/platform/cpu_helper.h"
+#include "paddle/fluid/platform/lodtensor_printer.h"
+
+namespace paddle {
+namespace framework {
+
+bool HasDependentOutput(const OpDesc& op_desc,
+                        const std::unordered_set<std::string>& dependent_vars) {
+  for (auto& var : op_desc.Outputs()) {
+    for (auto& argu : var.second) {
+      if (dependent_vars.count(argu) != 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool HasDependentInput(const OpDesc& op_desc,
+                       const std::unordered_set<std::string>& dependent_vars) {
+  for (auto& var : op_desc.Inputs()) {
+    for (auto& argu : var.second) {
+      if (dependent_vars.count(argu) != 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool OnlyHasDependentInput(
+    const OpDesc& op_desc,
+    const std::unordered_set<std::string>& dependent_vars) {
+  for (auto& var : op_desc.Inputs()) {
+    for (auto& argu : var.second) {
+      if (dependent_vars.count(argu) == 0) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+bool NotHasDependentOutput(
+    const OpDesc& op_desc,
+    const std::unordered_set<std::string>& dependent_vars) {
+  for (auto& var : op_desc.Outputs()) {
+    for (auto& argu : var.second) {
+      if (dependent_vars.count(argu) != 0) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+bool HasOutput(const OpDesc& op_desc, const std::string& name) {
+  for (auto& var : op_desc.Outputs()) {
+    for (auto& argu : var.second) {
+      if (argu == name) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+void AppendInputVar(const OpDesc& op_desc,
+                    std::unordered_set<std::string>* vars_set) {
+  for (auto& var : op_desc.Inputs()) {
+    for (auto& arg : var.second) {
+      vars_set->emplace(arg);
+    }
+  }
+}
+
+void AppendOutputVar(const OpDesc& op_desc,
+                     std::unordered_set<std::string>* vars_set) {
+  for (auto& var : op_desc.Outputs()) {
+    for (auto& arg : var.second) {
+      vars_set->emplace(arg);
+    }
+  }
+}
+
+void DownpourWorkerOpt::Initialize(const TrainerDesc& desc) {
+  param_ = desc.downpour_param();
+  for (int i = 0; i < param_.sparse_table_size(); ++i) {
+    uint64_t table_id =
+        static_cast<uint64_t>(param_.sparse_table(i).table_id());
+    TableParameter table = param_.sparse_table(i);
+    sparse_key_names_[table_id].resize(table.sparse_key_name_size());
+    for (int j = 0; j < table.sparse_key_name_size(); ++j) {
+      sparse_key_names_[table_id][j] = table.sparse_key_name(j);
+    }
+    sparse_value_names_[table_id].resize(table.sparse_value_name_size());
+    for (int j = 0; j < table.sparse_value_name_size(); ++j) {
+      sparse_value_names_[table_id][j] = table.sparse_value_name(j);
+    }
+    sparse_grad_names_[table_id].resize(table.sparse_grad_name_size());
+    for (int j = 0; j < table.sparse_grad_name_size(); ++j) {
+      sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
+    }
+    label_var_name_[table_id] = table.label_var_name();
+    sparse_push_keys_[table_id] = std::vector<uint64_t>();
+  }
+
+  for (int i = 0; i < param_.dense_table_size(); ++i) {
+    uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
+    auto table = param_.dense_table(i);
+    dense_value_names_[table_id].resize(table.dense_value_name_size());
+    for (int j = 0; j < table.dense_value_name_size(); ++j) {
+      dense_value_names_[table_id][j] = table.dense_value_name(j);
+    }
+    dense_grad_names_[table_id].resize(table.dense_grad_name_size());
+    for (int j = 0; j < table.dense_grad_name_size(); ++j) {
+      dense_grad_names_[table_id][j] = table.dense_grad_name(j);
+    }
+  }
+
+  skip_ops_.resize(param_.skip_ops_size());
+  for (int i = 0; i < param_.skip_ops_size(); ++i) {
+    skip_ops_[i] = param_.skip_ops(i);
+  }
+
+  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
+    stat_var_name_map_[param_.stat_var_names(i)] = 1;
+  }
+
+  need_to_push_sparse_ = param_.push_sparse();
+  need_to_push_dense_ = param_.push_dense();
+
+  fleet_ptr_ = FleetWrapper::GetInstance();
+  fetch_config_ = desc.fetch_config();
+  use_cvm_ = desc.use_cvm();
+  // for sparse value accessor, embedding only
+  no_cvm_ = desc.no_cvm();
+  scale_datanorm_ = desc.scale_datanorm();
+  dump_slot_ = desc.dump_slot();
+  dump_fields_.resize(desc.dump_fields_size());
+  for (int i = 0; i < desc.dump_fields_size(); ++i) {
+    dump_fields_[i] = desc.dump_fields(i);
+  }
+  adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
+  need_dump_param_ = false;
+  dump_param_.resize(desc.dump_param_size());
+  for (int i = 0; i < desc.dump_param_size(); ++i) {
+    dump_param_[i] = desc.dump_param(i);
+  }
+  if (desc.dump_param_size() != 0) {
+    need_dump_param_ = true;
+  }
+  for (int i = 0; i < desc.loss_names_size(); ++i) {
+    loss_names_.push_back(desc.loss_names(i));
+  }
+  for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
+    check_nan_var_names_.push_back(desc.check_nan_var_names(i));
+  }
+  copy_table_config_ = desc.copy_table_config();
+  for (int i = 0; i < copy_table_config_.src_sparse_tables_size(); ++i) {
+    uint64_t src_table = copy_table_config_.src_sparse_tables(i);
+    uint64_t dest_table = copy_table_config_.dest_sparse_tables(i);
+    VLOG(3) << "copy_sparse_tables_ push back " << src_table << "->"
+            << dest_table;
+    copy_sparse_tables_.push_back(std::make_pair(src_table, dest_table));
+  }
+  for (int i = 0; i < copy_table_config_.src_dense_tables_size(); ++i) {
+    uint64_t src_table = copy_table_config_.src_dense_tables(i);
+    uint64_t dest_table = copy_table_config_.dest_dense_tables(i);
+    VLOG(3) << "copy_dense_tables_ push back " << src_table << "->"
+            << dest_table;
+    copy_dense_tables_.push_back(std::make_pair(src_table, dest_table));
+  }
+  for (auto& m : copy_table_config_.table_denpendency_map()) {
+    if (sparse_key_names_.find(m.key()) != sparse_key_names_.end()) {
+      // currently only support one dependency
+      for (auto& value : m.values()) {
+        table_dependency_[m.key()] = value;
+      }
+    }
+  }
+}
+
+void DownpourWorkerOpt::CreateDeviceResource(const ProgramDesc& main_prog) {
+  CreateThreadScope(main_prog);
+  CreateThreadOperatorsWithRerank(main_prog);
+}
+
+void DownpourWorkerOpt::CreateThreadOperatorsWithRerank(
+    const ProgramDesc& program) {
+  auto& block = program.Block(0);
+  std::vector<OpDesc*> ops = block.AllOps();
+  // check whether losses are independent; if not, skip for now
+  int loss_num = loss_names_.size();
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      loss_input_map;
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      loss_output_map;
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      loss_input_grad_map;
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      loss_output_grad_map;
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      metric_input_map;
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      metric_output_map;
+  std::vector<std::string> loss_grad_names;
+  for (int i = 0; i < loss_num; i++) {
+    loss_grad_names.push_back(loss_names_[i] + "@GRAD");
+  }
+  // mark forward ops by loss
+  for (int i = 0; i < loss_num; i++) {
+    for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
+      auto& op_desc = *op_iter;
+      if (i > 0) {
+        for (int j = 0; j < i; j++) {
+          if (HasDependentInput(*op_desc, loss_input_map[loss_names_[j]])) {
+            VLOG(3) << "losses must be independent currently";
+            return;
+          }
+        }
+      }
+      if (HasOutput(*op_desc, loss_names_[i]) ||
+          HasOutput(*op_desc, loss_grad_names[i]) ||
+          HasDependentOutput(*op_desc, loss_input_map[loss_names_[i]])) {
+        AppendInputVar(*op_desc, &loss_input_map[loss_names_[i]]);
+        AppendOutputVar(*op_desc, &loss_output_map[loss_names_[i]]);
+      }
+    }
+  }
+
+  for (int i = 0; i < loss_num; i++) {
+    for (auto op_iter = ops.begin(); op_iter != ops.end(); ++op_iter) {
+      auto& op_desc = *op_iter;
+      if (HasOutput(*op_desc, loss_grad_names[i]) ||
+          HasDependentInput(*op_desc, loss_output_grad_map[loss_names_[i]])) {
+        AppendInputVar(*op_desc, &loss_input_grad_map[loss_names_[i]]);
+        AppendOutputVar(*op_desc, &loss_output_grad_map[loss_names_[i]]);
+      }
+    }
+  }
+
+  for (int i = 0; i < loss_num; i++) {
+    for (auto op_iter = ops.begin(); op_iter != ops.end(); ++op_iter) {
+      auto& op_desc = *op_iter;
+      if ((HasDependentInput(*op_desc, loss_output_map[loss_names_[i]]) &&
+           OnlyHasDependentInput(*op_desc, loss_output_map[loss_names_[i]]) &&
+           NotHasDependentOutput(*op_desc, loss_input_map[loss_names_[i]])) ||
+          HasDependentInput(*op_desc, metric_output_map[loss_names_[i]])) {
+        AppendInputVar(*op_desc, &metric_input_map[loss_names_[i]]);
+        AppendOutputVar(*op_desc, &metric_output_map[loss_names_[i]]);
+      }
+    }
+  }
+
+  for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
+       ++i) {
+    uint64_t tid =
+        static_cast<uint64_t>(param_.program_config(0).pull_sparse_table_id(i));
+    TableParameter table;
+    for (auto j : param_.sparse_table()) {
+      if (j.table_id() == tid) {
+        table = j;
+        break;
+      }
+    }
+    if (table.is_async()) {
+      async_tid_ = tid;
+      async_index_ = i;
+      async_wait_name_ = table.async_wait_op_name();
+    }
+  }
+  loss_op_names_.resize(loss_num);
+  loss_ops_.resize(loss_num);
+  std::string async_wait_flag = "async_wait_flag";
+  for (int i = 0; i < loss_num; i++) {
+    for (auto op_iter = ops.begin(); op_iter != ops.end(); ++op_iter) {
+      auto& op_desc = *op_iter;
+      if ((op_desc->Type() == "fill_constant" &&
+           HasDependentOutput(*op_desc,
+                              loss_output_grad_map[loss_names_[i]])) ||
+          (HasDependentInput(*op_desc, loss_input_map[loss_names_[i]]) &&
+           HasDependentOutput(*op_desc, loss_output_map[loss_names_[i]])) ||
+          (HasDependentInput(*op_desc, loss_input_grad_map[loss_names_[i]]) &&
+           HasDependentOutput(*op_desc,
+                              loss_output_grad_map[loss_names_[i]])) ||
+          (HasDependentInput(*op_desc, metric_input_map[loss_names_[i]]) &&
+           HasDependentOutput(*op_desc, metric_output_map[loss_names_[i]]))) {
+        std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
+        if (HasOutput(*op_desc, async_wait_name_)) {
+          loss_op_names_[i].push_back(async_wait_flag);
+        } else {
+          loss_op_names_[i].push_back(op_desc->Type());
+        }
+        OperatorBase* local_op_ptr = local_op.release();
+        loss_ops_[i].push_back(local_op_ptr);
+      }
+    }
+  }
+}
+
+void DownpourWorkerOpt::TrainFiles() {
+  VLOG(3) << "Begin to train files";
+  platform::SetNumThreads(1);
+  device_reader_->Start();
+  int batch_cnt = 0;
+  int cur_batch;
+  std::future<int32_t> pull_async_status;
+  std::string async_wait_name = "";
+  for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
+       ++i) {
+    uint64_t tid =
+        static_cast<uint64_t>(param_.program_config(0).pull_sparse_table_id(i));
+    TableParameter table;
+    for (auto j : param_.sparse_table()) {
+      if (j.table_id() == tid) {
+        table = j;
+        break;
+      }
+    }
+  }
+  // pre-defined for the first op run with async-pulled embedding
+  while ((cur_batch = device_reader_->Next()) > 0) {
+    if (copy_table_config_.need_copy()) {
+      if (copy_table_config_.sparse_copy_by_feasign()) {
+        for (size_t i = 0; i < copy_sparse_tables_.size(); ++i) {
+          uint64_t tid = copy_sparse_tables_[i].first;
+          feasign_set_[tid].insert(sparse_push_keys_[tid].begin(),
+                                   sparse_push_keys_[tid].end());
+        }
+      }
+      if (batch_cnt % copy_table_config_.batch_num() == 0) {
+        CopySparseTable();
+        CopyDenseTable();
+        CopyDenseVars();
+      }
+    }
+    // pull sparse here
+    for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
+         ++i) {
+      uint64_t tid = static_cast<uint64_t>(
+          param_.program_config(0).pull_sparse_table_id(i));
+      TableParameter table;
+      for (auto j : param_.sparse_table()) {
+        if (j.table_id() == tid) {
+          table = j;
+          break;
+        }
+      }
+      if (table.is_local()) {
+        fleet_ptr_->PullSparseVarsFromLocal(
+            *thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
+            &feature_values_[tid], table.fea_dim());
+        CollectLabelInfo(i);
+        continue;
+      } else if (table.is_async()) {
+        pull_async_status = fleet_ptr_->PullSparseVarsAsync(
+            *thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
+            &feature_values_[tid], table.fea_dim());
+        continue;
+      } else {
+        fleet_ptr_->PullSparseVarsSync(
+            *thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
+            &feature_values_[tid], table.fea_dim(), sparse_value_names_[tid]);
+      }
+      CollectLabelInfo(i);
+      FillSparseValue(i);
+      auto nid_iter = std::find(sparse_value_names_[tid].begin(),
+                                sparse_value_names_[tid].end(),
+                                adjust_ins_weight_config_.nid_slot());
+      if (nid_iter != sparse_value_names_[tid].end()) {
+        AdjustInsWeight();
+      }
+    }
+    VLOG(3) << "fill sparse value for all sparse table done.";
+
+    // do computation here
+    for (size_t loss_idx = 0; loss_idx < loss_ops_.size(); loss_idx++) {
+      int op_idx = 0;
+      for (auto& op : loss_ops_[loss_idx]) {
+        bool need_skip = false;
+        for (auto t = 0u; t < skip_ops_.size(); ++t) {
+          if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+            need_skip = true;
+            break;
+          }
+        }
+        if (!need_skip) {
+          if (loss_op_names_[loss_idx][op_idx] == async_wait_name_) {
+            pull_async_status.wait();
+            auto status = pull_async_status.get();
+            if (status != 0) {
+              LOG(ERROR) << "fleet pull sparse failed, status[" << status
+                         << "]";
+              sleep(1);
+              exit(-1);
+            } else {
+              // CollectLabelInfo(async_index);
+              FillSparseValue(async_index_);
+              auto nid_iter = std::find(sparse_value_names_[async_tid_].begin(),
+                                        sparse_value_names_[async_tid_].end(),
+                                        adjust_ins_weight_config_.nid_slot());
+              if (nid_iter != sparse_value_names_[async_tid_].end()) {
+                AdjustInsWeight();
+              }
+            }
+          }
+          op->Run(*thread_scope_, place_);
+        }
+        op_idx++;
+      }
+    }
+    // check inf and nan
+    for (std::string& var_name : check_nan_var_names_) {
+      Variable* var = thread_scope_->FindVar(var_name);
+      if (var == nullptr) {
+        continue;
+      }
+      LoDTensor* tensor = var->GetMutable<LoDTensor>();
+      if (tensor == nullptr) {
+        continue;
+      }
+      PADDLE_ENFORCE_EQ(
+          framework::TensorContainsInf(*tensor), false,
+          platform::errors::InvalidArgument("The target tensor %s contains Inf "
+                                            "should check some layers output.",
+                                            var_name));
+      PADDLE_ENFORCE_EQ(
+          framework::TensorContainsNAN(*tensor), false,
+          platform::errors::InvalidArgument("The target tensor %s contains Nan "
+                                            "should check some layers output.",
+                                            var_name));
+    }
+
+    if (need_to_push_sparse_) {
+      // push gradients here
+      for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size();
+           ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_sparse_table_id(i));
+        TableParameter table;
+        for (auto i : param_.sparse_table()) {
+          if (i.table_id() == tid) {
+            table = i;
+            break;
+          }
+        }
+        fleet_ptr_->PushSparseVarsWithLabelAsync(
+            *thread_scope_, tid, features_[tid], feature_labels_[tid],
+            sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
+            &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
+            dump_slot_, &sparse_push_keys_[tid], no_cvm_);
+      }
+    }
+
+    if (need_to_push_dense_) {
+      for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
+           ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_dense_table_id(i));
+        fleet_ptr_->PushDenseVarsAsync(
+            *thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_,
+            scale_datanorm_, cur_batch);
+      }
+      VLOG(3) << "push dense gradient done.";
+
+      // the following code should be more precise and clean
+      // TODO(guru4elephant)
+      int32_t tmp_push_dense_wait_times = -1;
+      static uint32_t push_dense_wait_times =
+          static_cast<uint32_t>(tmp_push_dense_wait_times);
+
+      if (push_dense_status_.size() >= push_dense_wait_times) {
+        for (auto& t : push_dense_status_) {
+          t.wait();
+        }
+        push_dense_status_.resize(0);
+      }
+
+      if (tmp_push_dense_wait_times == -1) {
+        push_dense_status_.resize(0);
+      }
+    }
+
+    if (need_to_push_sparse_) {
+      VLOG(3) << "push sparse gradient done.";
+      int32_t tmp_push_sparse_wait_times = -1;
+      static uint32_t push_sparse_wait_times =
+          static_cast<uint32_t>(tmp_push_sparse_wait_times);
+      if (push_sparse_status_.size() >= push_sparse_wait_times) {
+        for (auto& t : push_sparse_status_) {
+          t.wait();
+        }
+        push_sparse_status_.resize(0);
+      }
+
+      if (tmp_push_sparse_wait_times == -1) {
+        push_sparse_status_.resize(0);
+      }
+    }
+
+    if (need_to_push_dense_) {
+      for (int i = 0; i < param_.program_config(0).push_dense_table_id_size();
+           ++i) {
+        uint64_t tid = static_cast<uint64_t>(
+            param_.program_config(0).push_dense_table_id(i));
+        pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
+      }
+    }
+    if (need_dump_field_) {
+      size_t batch_size = device_reader_->GetCurBatchSize();
+      std::vector<std::string> ars(batch_size);
+      for (auto& ar : ars) {
+        ar.clear();
+      }
+      auto& ins_id_vec = device_reader_->GetInsIdVec();
+      auto& ins_content_vec = device_reader_->GetInsContentVec();
+      for (size_t i = 0; i < ins_id_vec.size(); i++) {
+        ars[i] += ins_id_vec[i];
+        ars[i] = ars[i] + "\t" + ins_content_vec[i];
+      }
+      for (auto& field : dump_fields_) {
+        Variable* var = thread_scope_->FindVar(field);
+        if (var == nullptr) {
+          continue;
+        }
+        LoDTensor* tensor = var->GetMutable<LoDTensor>();
+        if (!CheckValidOutput(tensor, batch_size)) {
+          continue;
+        }
+        for (size_t i = 0; i < batch_size; ++i) {
+          auto output_dim = tensor->dims()[1];
+          std::string output_dimstr =
+              boost::lexical_cast<std::string>(output_dim);
+          ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
+          auto bound = GetTensorBound(tensor, i);
+          ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
+        }
+      }
+      // #pragma omp parallel for
+      for (size_t i = 0; i < ars.size(); i++) {
+        if (ars[i].length() == 0) {
+          continue;
+        }
+        writer_ << ars[i];
+      }
+      if (need_dump_param_ && thread_id_ == 0) {
+        DumpParam();
+      }
+    }
+
+    PrintFetchVars();
+    thread_scope_->DropKids();
+    ++batch_cnt;
+  }
+  if (need_dump_field_) {
+    writer_.Flush();
+  }
+  if (copy_table_config_.need_copy()) {
+    CopySparseTable();
+    CopyDenseTable();
+    CopyDenseVars();
+  }
+}
+
+}  // end namespace framework
+}  // end namespace paddle
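For a table marked `is_async`, `DownpourWorkerOpt::TrainFiles` above issues the embedding pull as a future, runs the ops that do not depend on it, and blocks only at the op matched by `async_wait_op_name`. A hedged plain-Python sketch of that overlap pattern (names and structure are illustrative, not Paddle API):

```python
# Model of the async-pull overlap in TrainFiles: ops before the marked
# wait point run concurrently with the sparse pull.
from concurrent.futures import ThreadPoolExecutor

def train_step(ops, async_pull, wait_op_name):
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(async_pull)       # issue pull_sparse asynchronously
    for name, run in ops:                  # ops = [(name, callable), ...]
        if name == wait_op_name:
            status = future.result()       # block only at the marked op
            if status != 0:
                raise RuntimeError("fleet pull sparse failed, status[%d]" % status)
        run()                              # earlier ops overlap the pull
    pool.shutdown()
```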
diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc
index 29d1bf5b2e09ad669752ba3da4177ebe20f55e32..fc52e1a4c930bd41571de4416a6b413923f0e94e 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.cc
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -29,9 +29,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include <algorithm>
 #include <utility>
+#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/data_feed.h"
+#include "paddle/fluid/framework/io/fs.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/timer.h"
 
 namespace paddle {
 namespace framework {
@@ -151,6 +154,151 @@ void FleetWrapper::CreateClient2ClientConnection() {
 #endif
 }
 
+void FleetWrapper::PullSparseToLocal(const uint64_t table_id,
+                                     int fea_value_dim) {
+#ifdef PADDLE_WITH_PSLIB
+  size_t fea_keys_size = local_tables_.size();
+  if (fea_keys_size == 0) {
+    return;
+  }
+  local_table_shard_num_ = fea_keys_size;
+  platform::Timer timeline;
+  std::vector<std::thread> threads(fea_keys_size);
+  auto ptl_func = [this, &table_id](int i) {
+    size_t key_size = this->local_tables_[i].size();
+    std::vector<uint64_t> keys;
+    keys.reserve(key_size);
+    std::vector<float*> pull_result_ptr;
+    pull_result_ptr.reserve(key_size);
+
+    for (auto& kv : this->local_tables_[i]) {
+      keys.emplace_back(kv.first);
+      pull_result_ptr.emplace_back(kv.second.data());
+    }
+    auto tt = pslib_ptr_->_worker_ptr->pull_sparse(
+        pull_result_ptr.data(), table_id, keys.data(), key_size);
+    tt.wait();
+    auto status = tt.get();
+    if (status != 0) {
+      LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
+      sleep(sleep_seconds_before_fail_exit_);
+      exit(-1);
+    } else {
+      VLOG(3) << "FleetWrapper Pull sparse to local done with table size: "
+              << pull_result_ptr.size();
+    }
+  };
+  timeline.Start();
+  for (size_t i = 0; i < threads.size(); i++) {
+    threads[i] = std::thread(ptl_func, i);
+  }
+  for (std::thread& t : threads) {
+    t.join();
+  }
+  local_pull_pool_.reset(new ::ThreadPool(pull_local_thread_num_));
+  timeline.Pause();
+#endif
+}
+
+void FleetWrapper::PullSparseVarsFromLocal(
+    const Scope& scope, const uint64_t table_id,
+    const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
+    std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
+#ifdef PADDLE_WITH_PSLIB
+  fea_keys->clear();
+  fea_keys->resize(0);
+  fea_keys->reserve(MAX_FEASIGN_NUM);
+  for (auto name : var_names) {
+    Variable* var = scope.FindVar(name);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
+    int64_t* ids = tensor->data<int64_t>();
+    size_t len = tensor->numel();
+    for (auto i = 0u; i < len; ++i) {
+      if (ids[i] == 0u) {
+        continue;
+      }
+      fea_keys->push_back(static_cast<uint64_t>(ids[i]));
+    }
+  }
+  fea_values->resize(fea_keys->size() + 1);
+  for (auto& t : *fea_values) {
+    t.resize(fea_value_dim);
+  }
+  size_t key_length = fea_keys->size();
+  int local_step = key_length / pull_local_thread_num_;
+  std::vector<std::future<void>> task_futures;
+  task_futures.reserve(key_length / local_step + 1);
+  for (size_t i = 0; i < key_length; i += local_step) {
+    size_t end = i + local_step < key_length ? i + local_step : key_length;
+    auto pull_local_task = [this, i, end, &fea_values, &fea_keys,
+                            &fea_value_dim] {
+      for (size_t j = i; j < end; j++) {
+        std::memcpy((*fea_values)[j].data(),
+                    local_tables_[(*fea_keys)[j] % local_table_shard_num_]
+                                 [(*fea_keys)[j]]
+                        .data(),
+                    fea_value_dim * sizeof(float));
+      }
+    };
+    task_futures.emplace_back(
+        local_pull_pool_->enqueue(std::move(pull_local_task)));
+  }
+  for (auto& tf : task_futures) {
+    tf.wait();
+  }
+#endif
+}
+
+void FleetWrapper::ClearLocalTable() {
+#ifdef PADDLE_WITH_PSLIB
+  for (auto& t : local_tables_) {
+    t.clear();
+  }
+#endif
+}
+
+std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
+    const Scope& scope, const uint64_t table_id,
+    const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
+    std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
+#ifdef PADDLE_WITH_PSLIB
+  fea_keys->clear();
+  fea_keys->resize(0);
+  fea_keys->reserve(MAX_FEASIGN_NUM);
+  for (auto name : var_names) {
+    Variable* var = scope.FindVar(name);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
+    int64_t* ids = tensor->data<int64_t>();
+    size_t len = tensor->numel();
+    for (auto i = 0u; i < len; ++i) {
+      if (ids[i] == 0u) {
+        continue;
+      }
+      fea_keys->push_back(static_cast<uint64_t>(ids[i]));
+    }
+  }
+  fea_values->resize(fea_keys->size() + 1);
+  for (auto& t : *fea_values) {
+    t.resize(fea_value_dim);
+  }
+  std::vector<float*> pull_result_ptr;
+  for (auto& t : *fea_values) {
+    pull_result_ptr.push_back(t.data());
+  }
+  return pslib_ptr_->_worker_ptr->pull_sparse(
+      pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
+#endif
+  return std::future<int32_t>();
+}
+
 void FleetWrapper::PullSparseVarsSync(
     const Scope& scope, const uint64_t table_id,
     const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h
index 40c500c64722177d9823bb02147c26a820faebc6..5d831f31c7f6a6f7887e2d1f425a34416a6206ce 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.h
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.h
@@ -19,12 +19,15 @@ limitations under the License. */
 #include <archive.h>
 #include <pslib.h>
 #endif
+#include <ThreadPool.h>
 #include <atomic>
 #include <ctime>
 #include <map>
 #include <random>
 #include <string>
+#include <unordered_map>
 #include <vector>
+
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable_helper.h"
@@ -65,12 +68,16 @@ class FleetWrapper {
     client2client_connect_timeout_ms_ = 10000;
     // pslib request max retry
     client2client_max_retry_ = 3;
+    pull_local_thread_num_ = 25;
   }
 
   // set client to client communication config
   void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
                               int max_retry);
 
+  void SetPullLocalThreadNum(int thread_num) {
+    pull_local_thread_num_ = thread_num;
+  }
+
   // Pull sparse variables from server in sync mode
   // Param: scope, table_id, var_names, fea_keys, fea_dim
   // Param: fea_values
@@ -80,7 +87,11 @@ class FleetWrapper {
                           std::vector<std::vector<float>>* fea_values,
                           int fea_dim,
                           const std::vector<std::string>& var_emb_names);
-
+  std::future<int32_t> PullSparseVarsAsync(
+      const Scope& scope, const uint64_t table_id,
+      const std::vector<std::string>& var_names,
+      std::vector<uint64_t>* fea_keys,
+      std::vector<std::vector<float>>* fea_values, int fea_dim);
   // pull dense variables from server in sync mod
   void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
                          const std::vector<std::string>& var_names);
@@ -111,6 +122,18 @@ class FleetWrapper {
                          const std::vector<std::string>& var_names);
 
   // Push sparse variables with labels to server in async mode
+  std::vector<std::unordered_map<uint64_t, std::vector<float>>> local_tables_;
+  void PullSparseToLocal(const uint64_t table_id, int fea_value_dim);
+  void PullSparseVarsFromLocal(const Scope& scope, const uint64_t table_id,
+                               const std::vector<std::string>& var_names,
+                               std::vector<uint64_t>* fea_keys,
+                               std::vector<std::vector<float>>* fea_values,
+                               int fea_value_dim);
+  void ClearLocalTable();
+  std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
+  GetLocalTable() {
+    return local_tables_;
+  }
   // This is specially designed for click/show stats in server
   // Param: scope, table_id, fea_keys, fea_labels, sparse_key_names,
   //        sparse_grad_names, batch_size, use_cvm, dump_slot
@@ -237,6 +260,10 @@ class FleetWrapper {
   int client2client_request_timeout_ms_;
   int client2client_connect_timeout_ms_;
   int client2client_max_retry_;
+  std::unique_ptr<::ThreadPool> local_pull_pool_{nullptr};
+  int pull_local_thread_num_;
+  std::unique_ptr<::ThreadPool> pull_to_local_pool_{nullptr};
+  int local_table_shard_num_;
 
   DISABLE_COPY_AND_ASSIGN(FleetWrapper);
 };
diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto
index ec5bef64d7a7e6071c6041a88c5829c22aede024..43b2dd63e4035e8585549a0bb094dd9b14e5f52b 100644
--- a/paddle/fluid/framework/trainer_desc.proto
+++ b/paddle/fluid/framework/trainer_desc.proto
@@ -48,6 +48,7 @@ message TrainerDesc {
   optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
   optional bool no_cvm = 21 [ default = false ];
   optional bool thread_barrier = 22;
+  repeated string loss_names = 23;
 
   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
@@ -164,4 +165,9 @@ message TableParameter {
   optional int32 emb_dim = 10;
   optional int32 fea_dim = 11;
   optional string label_var_name = 12;
+  // if table will pull sparse to local first
+  optional bool is_local = 13 [ default = false ];
+  // if table will pull sparse asynchronously in worker
+  optional bool is_async = 14 [ default = false ];
+  optional string async_wait_op_name = 15;
 }
diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc
index dd513d4b85ccb0d18d2641239365523c5b0b7ea4..6435aea8a8811bc446b658841108d1b4ea0f00c4 100644
--- a/paddle/fluid/pybind/data_set_py.cc
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -247,6 +247,12 @@ void BindDataset(py::module *m) {
            py::call_guard<py::gil_scoped_release>())
       .def("merge_by_lineid", &framework::Dataset::MergeByInsId,
            py::call_guard<py::gil_scoped_release>())
+      .def("set_generate_unique_feasigns",
+           &framework::Dataset::SetGenerateUniqueFeasign,
+           py::call_guard<py::gil_scoped_release>())
+      .def("generate_local_tables_unlock",
+           &framework::Dataset::GenerateLocalTablesUnlock,
+           py::call_guard<py::gil_scoped_release>())
       .def("slots_shuffle", &framework::Dataset::SlotsShuffle,
           py::call_guard<py::gil_scoped_release>())
       .def("set_fea_eval", &framework::Dataset::SetFeaEval,
diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc
index 03e574bdd270819413e81a1019872ec8f1fdcd08..fac6de452aed018f1397c536c40d7a55b5f188b0 100644
--- a/paddle/fluid/pybind/fleet_wrapper_py.cc
+++ b/paddle/fluid/pybind/fleet_wrapper_py.cc
@@ -75,6 +75,8 @@ void BindFleetWrapper(py::module* m) {
       .def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable)
       .def("set_client2client_config",
            &framework::FleetWrapper::SetClient2ClientConfig)
+      .def("set_pull_local_thread_num",
+           &framework::FleetWrapper::SetPullLocalThreadNum)
       .def("copy_table", &framework::FleetWrapper::CopyTable)
       .def("copy_table_by_feasign",
            &framework::FleetWrapper::CopyTableByFeasign);
diff --git a/paddle/fluid/train/demo/CMakeLists.txt b/paddle/fluid/train/demo/CMakeLists.txt
index a15ddc9273fd12e3a06a88253c477e3010800539..57fda493a81102096ab6c152379c9557e50d6828 100644
--- a/paddle/fluid/train/demo/CMakeLists.txt
+++ b/paddle/fluid/train/demo/CMakeLists.txt
@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
 
 include_directories("${PADDLE_LIB}/third_party/boost")
 include_directories("${PADDLE_LIB}/third_party/eigen3")
+include_directories("${PADDLE_LIB}/third_party/threadpool")
 include_directories("${PADDLE_LIB}/third_party/dlpack")
 
 link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
diff --git a/paddle/fluid/train/imdb_demo/CMakeLists.txt b/paddle/fluid/train/imdb_demo/CMakeLists.txt
index 7cb4f0a3ec123bc249462501fea591daf0c02521..29d54d0d2fbf61cae9f1cf8505124a969c029005 100644
--- a/paddle/fluid/train/imdb_demo/CMakeLists.txt
+++ b/paddle/fluid/train/imdb_demo/CMakeLists.txt
@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
 
 include_directories("${PADDLE_LIB}/third_party/boost")
 include_directories("${PADDLE_LIB}/third_party/eigen3")
+include_directories("${PADDLE_LIB}/third_party/threadpool")
 include_directories("${PADDLE_LIB}/third_party/dlpack")
 
 link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 92c8b6f06b880fb700fa5cfd1a594904ba3095b1..b10ebcaa47e62f5f96467e105f189376a9ffb2a7 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -428,6 +428,16 @@ class InMemoryDataset(DatasetBase):
         self.merge_by_lineid = True
         self.parse_ins_id = True
 
+    def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
+        self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
+        self.gen_uni_feasigns = generate_uni_feasigns
+        self.local_shard_num = shard_num
+
+    def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num,
+                                     consume_thread_num, shard_num):
+        self.dataset.generate_local_tables_unlock(
+            table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
+
     def load_into_memory(self):
         """
         Load data into memory
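The two `InMemoryDataset` methods added above wire through to the pybind bindings earlier in this patch. A hedged usage sketch, following the argument pattern exercised by `test_dataset.py` at the end of this patch (file names and dimensions are illustrative):

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_filelist(["train_part0.txt", "train_part1.txt"])  # hypothetical files
dataset.load_into_memory()
# Shard unique feasigns into 15 local tables, then pull table 0
# (fea_dim=11) with 1 read thread and 25 consume threads over 15 shards.
dataset.set_generate_unique_feasigns(True, 15)
dataset.generate_local_tables_unlock(0, 11, 1, 25, 15)
```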
diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py
index 1a23d7281e32ec6b68790da35f695e677da6e5d1..f6ffd4fa7c5071af98e77a2469f17359691a9425 100644
--- a/python/paddle/fluid/device_worker.py
+++ b/python/paddle/fluid/device_worker.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 """Defination of device workers."""
 
-__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section']
+__all__ = [
+    'DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section', 'DownpourSGDOPT'
+]
 
 
 class DeviceWorker(object):
@@ -190,6 +192,112 @@ class DownpourSGD(DeviceWorker):
             downpour.push_sparse = False
 
 
+class DownpourSGDOPT(DeviceWorker):
+    """
+    DownpourSGDOPT is a kind of distributed SGD algorithm.
+    """
+
+    def __init__(self):
+        """
+        Init.
+        initialize downpourSGDOPT device worker
+        """
+        super(DownpourSGDOPT, self).__init__()
+
+    def _gen_worker_desc(self, trainer_desc):
+        """
+        Generator worker desc, which device worker is DownpourWorkerOpt.
+
+        Args:
+            trainer_desc(TrainerDesc): a TrainerDesc object
+        """
+        dense_table_set = set()
+        program_id = str(id(self._program))
+        if self._program == None:
+            print("program of current device worker is not configured")
+            exit(-1)
+        opt_info = self._program._fleet_opt
+        program_configs = opt_info["program_configs"]
+        downpour = trainer_desc.downpour_param
+
+        for pid in program_configs:
+            if pid == program_id:
+                pc = downpour.program_config.add()
+                pc.program_id = program_id
+                for i in program_configs[program_id]["push_sparse"]:
+                    pc.push_sparse_table_id.extend([i])
+                for i in program_configs[program_id]["push_dense"]:
+                    pc.push_dense_table_id.extend([i])
+                    dense_table_set.add(i)
+                for i in program_configs[program_id]["pull_sparse"]:
+                    pc.pull_sparse_table_id.extend([i])
+                for i in program_configs[program_id]["pull_dense"]:
+                    pc.pull_dense_table_id.extend([i])
+                    dense_table_set.add(i)
+                break
+
+        trainer_desc.device_worker_name = "DownpourWorkerOpt"
+        pull_thread = trainer_desc.pull_dense_param
+        pull_thread.device_num = trainer_desc.thread_num
+        if opt_info.get("program_id_to_worker") is None:
+            raise ValueError("opt_info must have program_id_to_worker")
+        prog_id_to_worker = opt_info["program_id_to_worker"]
+        if prog_id_to_worker.get(program_id) is None:
+            raise ValueError("%s not found in program_id_to_worker" %
+                             program_id)
+        worker = opt_info["program_id_to_worker"][program_id]
+        for i in worker.get_desc().dense_table:
+            if i.table_id in dense_table_set:
+                dense_table = pull_thread.dense_table.add()
+                dense_table.dense_value_name.extend(i.dense_variable_name)
+                dense_table.table_id = \
+                    i.table_id
+        sparse_len = len(worker.get_desc().sparse_table)
+        for i in range(sparse_len):
+            sparse_table = downpour.sparse_table.add()
+            sparse_table.table_id = worker.get_desc().sparse_table[i].table_id
+            sparse_table.sparse_key_name.extend(worker.get_desc().sparse_table[
+                i].slot_key)
+            sparse_table.sparse_value_name.extend(worker.get_desc()
+                                                  .sparse_table[i].slot_value)
+            sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
+                i].slot_gradient)
+            if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[
+                    "no_cvm"] == True:
+                sparse_table.emb_dim = \
+                    self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
+                        i].accessor.fea_dim
+                sparse_table.fea_dim = sparse_table.emb_dim
+            else:
+                sparse_table.emb_dim = \
+                    self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
+                        i].accessor.fea_dim - 2
+                sparse_table.fea_dim = sparse_table.emb_dim + 2
+            # TODO(guru4elephant): hard code here, need to improve
+            sparse_table.label_var_name = "click"
+            if "local_tables" in opt_info and sparse_table.table_id in opt_info[
+                    "local_tables"]:
+                sparse_table.is_local = True
+            if "async_tables" in opt_info and sparse_table.table_id in opt_info[
+                    "async_tables"]:
+                sparse_table.is_async = True
+        if opt_info["stat_var_names"]:
+            for i in opt_info["stat_var_names"]:
+                downpour.stat_var_names.extend([i])
+
+        for i in worker.get_desc().dense_table:
+            if i.table_id in dense_table_set:
+                dense_table = downpour.dense_table.add()
+                dense_table.table_id = i.table_id
+                dense_table.dense_value_name.extend(i.dense_variable_name)
+                dense_table.dense_grad_name.extend(
+                    i.dense_gradient_variable_name)
+        downpour.skip_ops.extend(worker.get_desc().skip_op)
+        if self._infer:
+            downpour.push_dense = False
+            downpour.push_sparse = False
+
+
 class Section(DeviceWorker):
     """SectionWorker."""
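`DownpourSGDOPT` is selected through the fleet strategy dict consumed by `DistributedAdam` in `optimizer_factory.py` below ("device_worker" defaults to "DownpourSGD"). A hedged sketch, where the table ids and the loss variable name are illustrative:

```python
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

strategy = {
    "device_worker": "DownpourSGDOPT",
    "local_tables": [0],            # table ids marked is_local in TableParameter
    "async_tables": [1],            # table ids marked is_async in TableParameter
    "loss_names": ["mean_0.tmp_0"], # hypothetical loss variable names
}
adam = fluid.optimizer.Adam(learning_rate=5e-6)
adam = fleet.distributed_optimizer(adam, strategy=strategy)
adam.minimize([loss])  # `loss` is a loss Variable built elsewhere
```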
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
index 362bb2c586f34d3e527ead38ee95a107cab4c685..ec5f6de81c9a7994dc0396c42380b0d6336002a2 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
@@ -51,6 +51,9 @@ class PSLib(Fleet):
         self._client2client_connect_timeout_ms = connect_timeout_ms
         self._client2client_max_retry = max_retry
 
+    def set_pull_local_thread_num(self, thread_num):
+        self._fleet_ptr.set_pull_local_thread_num(thread_num)
+
     def init_worker(self):
         """
         init_worker(): will be called by user. When a user knows current process is_server(), he/she
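`set_pull_local_thread_num` forwards to the C++ binding added in `fleet_wrapper_py.cc` above; the default of 25 is set in the `FleetWrapper` constructor. An illustrative call (the value is hypothetical):

```python
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

fleet.set_pull_local_thread_num(10)  # threads serving local-table lookups
```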
prog_id_to_inputs_dict[prog_id] = inputs_dict + # get outputs_dict + outputs_dict = self._find_distributed_lookup_table_outputs( + loss.block.program, sparse_table) + prog_id_to_outputs_dict[prog_id] = outputs_dict + + prog_id_to_worker[prog_id] = DownpourWorker(self._window) + + grads_dict = self._find_distributed_lookup_table_grads( + loss.block.program, sparse_table) + prog_id_to_sparse_grads[prog_id] = grads_dict # param_grads of program params_grads = sorted( fluid.backward.append_backward(loss, parameter_list, no_grad_set), key=lambda x: x[0].name) - prog_id_to_param_grads[prog_id] = params_grads + if prog_id not in prog_id_to_param_grads: + prog_id_to_param_grads[prog_id] = [] + prog_id_to_param_grads[prog_id].append(params_grads) - grads_dict = self._find_distributed_lookup_table_grads( - loss.block.program, sparse_table) - prog_id_to_sparse_grads[prog_id] = grads_dict + #if strategy.get("parallel_compute") # if user specify a fleet_desc.prototxt file, then load the file # instead of creating default fleet_desc.prototxt. @@ -251,90 +259,109 @@ class DistributedAdam(DistributedOptimizerImplBase): server.add_sparse_table(sparse_table_index, None) # each DownpourTrainerParameter add its own sparse tables + program_id_set.clear() for loss in losses: prog_id = str(id(loss.block.program)) - worker = prog_id_to_worker[prog_id] - inputs_dict = prog_id_to_inputs_dict[prog_id] - outputs_dict = prog_id_to_outputs_dict[prog_id] - for tn in prog_id_to_sparse_table[prog_id]: - sparse_table_index = sparse_table_to_index[tn] - grads_dict = prog_id_to_sparse_grads[prog_id] - worker.add_sparse_table(sparse_table_index, inputs_dict[tn], - outputs_dict[tn], grads_dict[tn]) + if prog_id not in program_id_set: + program_id_set.add(prog_id) + worker = prog_id_to_worker[prog_id] + inputs_dict = prog_id_to_inputs_dict[prog_id] + outputs_dict = prog_id_to_outputs_dict[prog_id] + for tn in prog_id_to_sparse_table[prog_id]: + sparse_table_index = sparse_table_to_index[tn] + grads_dict = prog_id_to_sparse_grads[prog_id] + worker.add_sparse_table(sparse_table_index, inputs_dict[tn], + outputs_dict[tn], grads_dict[tn]) dense_start_table_id = len(sparse_table_to_index) dense_table_index = len(sparse_table_to_index) program_configs = {} # ServerParameter add all dense tables # each DownpourTrainerParameter add its own dense tables + program_id_set.clear() for loss_index in range(len(losses)): program_id = str(id(losses[loss_index].block.program)) - worker = prog_id_to_worker[program_id] - sparse_table_names = prog_id_to_sparse_table[program_id] - sparse_table_index = \ - [sparse_table_to_index[i] for i in sparse_table_names] - - program_configs[program_id] = { - "pull_sparse": [t_index for t_index in sparse_table_index], - "push_sparse": [t_index for t_index in sparse_table_index] - } - - params_grads = prog_id_to_param_grads[program_id] - params = [] - grads = [] - data_norm_params = [] - data_norm_grads = [] - for i in params_grads: - is_data_norm_data = False - for data_norm_name in self.data_norm_name: - if i[0].name.endswith(data_norm_name): - is_data_norm_data = True - data_norm_params.append(i[0]) - if not is_data_norm_data: - params.append(i[0]) - - for i in params_grads: - is_data_norm_data = False - for data_norm_grad in self.data_norm_name: - if i[0].name.endswith(data_norm_grad): - is_data_norm_data = True - data_norm_grads.append(i[1]) - if not is_data_norm_data: - grads.append(i[1]) - - if strategy.get('dense_table') is not None: - server.add_dense_table(dense_table_index, params, grads, - 
strategy['dense_table'], - sparse_table_names) - else: - server.add_dense_table(dense_table_index, params, grads, None, - sparse_table_names) - worker.add_dense_table(dense_table_index, self._learning_rate, - params, grads, dense_start_table_id, - sparse_table_names) - program_configs[program_id]["pull_dense"] = [dense_table_index] - program_configs[program_id]["push_dense"] = [dense_table_index] - if len(data_norm_params) != 0 and len(data_norm_grads) != 0: - dense_table_index += 1 - if strategy.get('datanorm_table') is not None: - server.add_data_norm_table( - dense_table_index, self._learning_rate, - data_norm_params, data_norm_grads, - strategy['datanorm_table'], sparse_table_names) - else: - server.add_data_norm_table( - dense_table_index, self._learning_rate, - data_norm_params, data_norm_grads, None, - sparse_table_names) - - worker.add_dense_table(dense_table_index, self._learning_rate, - data_norm_params, data_norm_grads, - dense_start_table_id, sparse_table_names) - program_configs[program_id]["pull_dense"].extend( - [dense_table_index]) - program_configs[program_id]["push_dense"].extend( - [dense_table_index]) - dense_table_index += 1 + if program_id not in program_id_set: + program_id_set.add(program_id) + worker = prog_id_to_worker[program_id] + sparse_table_names = prog_id_to_sparse_table[program_id] + sparse_table_index = \ + [sparse_table_to_index[i] for i in sparse_table_names] + + program_configs[program_id] = { + "pull_sparse": [t_index for t_index in sparse_table_index], + "push_sparse": [t_index for t_index in sparse_table_index] + } + + params_grads = prog_id_to_param_grads[program_id] + for pg in params_grads: + params = [] + grads = [] + data_norm_params = [] + data_norm_grads = [] + for i in pg: + is_data_norm_data = False + for data_norm_name in self.data_norm_name: + if i[0].name.endswith(data_norm_name): + is_data_norm_data = True + data_norm_params.append(i[0]) + if not is_data_norm_data: + params.append(i[0]) + + for i in pg: + is_data_norm_data = False + for data_norm_grad in self.data_norm_name: + if i[0].name.endswith(data_norm_grad): + is_data_norm_data = True + data_norm_grads.append(i[1]) + if not is_data_norm_data: + grads.append(i[1]) + + if strategy.get('dense_table') is not None: + server.add_dense_table(dense_table_index, params, grads, + strategy['dense_table'], + sparse_table_names) + else: + server.add_dense_table(dense_table_index, params, grads, + None, sparse_table_names) + worker.add_dense_table( + dense_table_index, self._learning_rate, params, grads, + dense_start_table_id, sparse_table_names) + if "pull_dense" in program_configs[ + program_id] and "push_dense" in program_configs[ + program_id] and len(program_configs[program_id][ + "pull_dense"]) > 0: + program_configs[program_id]["pull_dense"].extend( + [dense_table_index]) + program_configs[program_id]["push_dense"].extend( + [dense_table_index]) + else: + program_configs[program_id][ + "pull_dense"] = [dense_table_index] + program_configs[program_id][ + "push_dense"] = [dense_table_index] + if len(data_norm_params) != 0 and len(data_norm_grads) != 0: + dense_table_index += 1 + if strategy.get('datanorm_table') is not None: + server.add_data_norm_table( + dense_table_index, self._learning_rate, + data_norm_params, data_norm_grads, + strategy['datanorm_table'], sparse_table_names) + else: + server.add_data_norm_table( + dense_table_index, self._learning_rate, + data_norm_params, data_norm_grads, None, + sparse_table_names) + + worker.add_dense_table( + dense_table_index, 
@@ -370,13 +397,15 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["program_id_to_worker"] = prog_id_to_worker
         opt_info["program_configs"] = program_configs
         opt_info["trainer"] = "DistMultiTrainer"
-        opt_info["device_worker"] = "DownpourSGD"
+        opt_info["device_worker"] = strategy.get("device_worker", "DownpourSGD")
         opt_info["optimizer"] = "DownpourSGD"
         opt_info["fleet_desc"] = ps_param
         opt_info["worker_skipped_ops"] = worker_skipped_ops
         opt_info["use_cvm"] = strategy.get("use_cvm", False)
         opt_info["no_cvm"] = strategy.get("no_cvm", False)
         opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
+        opt_info["local_tables"] = strategy.get("local_tables", [])
+        opt_info["async_tables"] = strategy.get("async_tables", [])
         opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
         opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
                                                        [])
@@ -391,6 +420,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         opt_info["dump_slot"] = True
         opt_info["adjust_ins_weight"] = strategy.get("adjust_ins_weight", {})
         opt_info["copy_table"] = strategy.get("copy_table", {})
+        opt_info["loss_names"] = strategy.get("loss_names", [])
 
         for loss in losses:
             loss.block.program._fleet_opt = opt_info
diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py
index ff115963216baa70cb1bede15f54bbc8c1415f68..6f13fd7220b8bd464a5797ba898e520ac329547e 100644
--- a/python/paddle/fluid/tests/unittests/test_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_dataset.py
@@ -177,7 +177,8 @@ class TestDataset(unittest.TestCase):
         dataset.set_fea_eval(10000, True)
         dataset.slots_shuffle(["slot1"])
         dataset.local_shuffle()
-
+        dataset.set_generate_unique_feasigns(True, 15)
+        dataset.generate_local_tables_unlock(0, 11, 1, 25, 15)
         exe = fluid.Executor(fluid.CPUPlace())
         exe.run(fluid.default_startup_program())
         if self.use_data_loader:
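
The two new dataset calls exercised above drive the local-table generation added on the C++ side; the five positional arguments of `generate_local_tables_unlock` appear to be the table id, feature dimension, read-thread count, consume-thread count, and shard count. A hedged standalone usage sketch, reusing the test's values; the filelist name is an assumption, not part of the patch:

    import paddle.fluid as fluid

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_filelist(["train_data/part-000"])  # assumed input file
    dataset.load_into_memory()
    # enable unique-feasign generation; 15 matches the shard count below
    dataset.set_generate_unique_feasigns(True, 15)
    # table_id=0, fea_dim=11, 1 read thread, 25 consume threads, 15 shards
    dataset.generate_local_tables_unlock(0, 11, 1, 25, 15)
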
diff --git a/python/paddle/fluid/tests/unittests/test_downpoursgd.py b/python/paddle/fluid/tests/unittests/test_downpoursgd.py
index 3564609e08e8a49b588a4c4d407843c58e7102ff..0d78a23e111cfdb0caf58afff8f942dd60b68ada 100644
--- a/python/paddle/fluid/tests/unittests/test_downpoursgd.py
+++ b/python/paddle/fluid/tests/unittests/test_downpoursgd.py
@@ -25,7 +25,7 @@ import unittest
 import sys
 from op_test import OpTest
 from paddle.fluid.trainer_desc import DistMultiTrainer
-from paddle.fluid.device_worker import DownpourSGD
+from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT
 from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker
 from google.protobuf import text_format
 import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
@@ -157,6 +157,66 @@ class TestListenAndServOp(unittest.TestCase):
         cmd = "rm fleet_desc.prototxt*"
         os.system(cmd)
 
+    def test_downpour_opt_work(self):
+        """Test the DownpourSGDOPT device worker."""
+        if sys.platform == 'win32':
+            pass
+        else:
+            print(sys.platform)
+            cmd = "wget --no-check-certificate https://pslib.bj.bcebos.com/fleet_desc.prototxt"
+            os.system(cmd)
+            x = fluid.layers.data(name='x', shape=[1], dtype='int64')
+            x_emb = fluid.layers.embedding(
+                input=x, size=[1, 2], is_distributed=True)
+            y_predict = fluid.layers.fc(input=x_emb, size=1, act=None)
+            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+            avg_cost = fluid.layers.mean(cost)
+
+            ps_param = pslib.PSParameter()
+            with open("fleet_desc.prototxt") as f:
+                text_format.Merge(f.read(), ps_param)
+            fleet_desc = ps_param
+            exe = fluid.Executor(fluid.CPUPlace())
+            exe.run(fluid.default_startup_program())
+
+            opt_info = {}
+            main_program = fluid.default_main_program()
+            program_id = str(id(avg_cost.block.program))
+            program_configs = {}
+            program_configs[program_id] = {
+                "pull_sparse": [0],
+                "push_sparse": [0]
+            }
+            program_configs[program_id]["pull_dense"] = [1]
+            program_configs[program_id]["push_dense"] = [1]
+
+            worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
+            opt_info["program_configs"] = program_configs
+            opt_info["trainer"] = "DistMultiTrainer"
+            opt_info["device_worker"] = "DownpourSGDOPT"
+            opt_info["optimizer"] = "DownpourSGD"
+            opt_info["fleet_desc"] = ps_param
+            opt_info["worker_skipped_ops"] = worker_skipped_ops
+            opt_info["use_cvm"] = False
+            opt_info["scale_datanorm"] = -1
+            opt_info["dump_slot"] = False
+            opt_info["stat_var_names"] = []
+            worker = DownpourWorker(None)
+            worker.get_desc().CopyFrom(ps_param.trainer_param[0])
+            opt_info["program_id_to_worker"] = {program_id: worker}
+
+            main_program._fleet_opt = opt_info
+            trainer = DistMultiTrainer()
+            trainer._set_program(main_program)
+            device_worker = DownpourSGDOPT()
+            device_worker._set_fleet_desc(fleet_desc)
+            trainer._set_device_worker(device_worker)
+            trainer._set_fleet_desc(fleet_desc)
+            trainer._gen_trainer_desc()
+            cmd = "rm fleet_desc.prototxt*"
+            os.system(cmd)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index d28d51dc73f912de525b15bf5218376af404b0c7..f61452f3425ac89acdb01007b57795d52cb4fdee 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -115,6 +115,10 @@ class TrainerDesc(object):
         for var in check_nan_var_names:
             self.proto_desc.check_nan_var_names.append(var)
 
+    def _set_loss_names(self, loss_names):
+        for loss in loss_names:
+            self.proto_desc.loss_names.append(loss)
+
     def _set_adjust_ins_weight(self, config_dict):
         self.proto_desc.adjust_ins_weight_config.need_adjust = \
             config_dict.get("need_adjust", False)
diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py
index b4ffcfa60535e39fe2abdace1c96864f1397f8f6..f426db3df91888ed3ca09ea3c3dbf7717119ee87 100644
--- a/python/paddle/fluid/trainer_factory.py
+++ b/python/paddle/fluid/trainer_factory.py
@@ -23,7 +23,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
 local_logger = logging.getLogger(__name__)
 
 from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
-from .device_worker import Hogwild, DownpourSGD, Section
+from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT
 from .framework import Variable
 from multiprocessing import Process, Manager
 
@@ -86,6 +86,8 @@ class TrainerFactory(object):
                     "check_nan_var_names"])
             if opt_info.get("dump_param") is not None:
                 trainer._set_dump_param(opt_info["dump_param"])
+            if opt_info.get("loss_names") is not None:
+                trainer._set_loss_names(opt_info["loss_names"])
             trainer._set_device_worker(device_worker)
         return trainer
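
Taken together, `trainer_desc.py` and `trainer_factory.py` plumb the new `loss_names` strategy key into the trainer proto, and `device_worker` now selects `DownpourSGDOPT` by name. A hedged sketch of a user-side strategy dict that would exercise this end to end; only the four keys come from the patch, while the `adam` optimizer, table ids, and fleet calls are assumptions about a typical pslib training script:

    # Hypothetical training-script fragment.
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

    strategy = {
        "device_worker": "DownpourSGDOPT",  # routed into opt_info["device_worker"]
        "loss_names": [avg_cost.name],      # copied into proto_desc.loss_names
        "local_tables": [0],                # tables generated/served locally
        "async_tables": [0],                # tables pushed asynchronously
    }
    optimizer = fleet.distributed_optimizer(adam, strategy)
    optimizer.minimize([avg_cost])
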