From 175954d894f0f7158c4f062764589096c8da11e9 Mon Sep 17 00:00:00 2001
From: hutuxian
Date: Tue, 25 Feb 2020 00:10:36 +0800
Subject: [PATCH] PaddleBox Framework Part2 (#22466)

* Add two types of Metric Calculator: MultiTaskCalculator & CmatchRankCalculator.
* Add a config to the DynamicAdjustChannelNum function to denote whether to discard the remaining instances when they cannot be distributed evenly.
* Remove CPU code in Pull/PushSparse; it will be added back once it has been tested fully.
* Fix some known issues, such as copying persistable vars after one epoch of running.
---
 cmake/external/box_ps.cmake | 2 +-
 paddle/fluid/framework/CMakeLists.txt | 4 +-
 paddle/fluid/framework/channel.h | 1 +
 paddle/fluid/framework/data_set.cc | 15 +-
 paddle/fluid/framework/data_set.h | 8 +-
 paddle/fluid/framework/fleet/CMakeLists.txt | 2 +-
 paddle/fluid/framework/fleet/box_wrapper.cc | 83 +----
 paddle/fluid/framework/fleet/box_wrapper.h | 341 ++++++++++++++++--
 paddle/fluid/framework/pipeline_trainer.cc | 7 +-
 paddle/fluid/framework/section_worker.cc | 31 ++
 paddle/fluid/framework/trainer.h | 1 +
 paddle/fluid/pybind/box_helper_py.cc | 35 ++
 paddle/fluid/pybind/box_helper_py.h | 3 +
 paddle/fluid/pybind/pybind.cc | 7 +
 python/paddle/fluid/dataset.py | 18 +-
 .../fluid/tests/unittests/CMakeLists.txt | 1 +
 .../fluid/tests/unittests/test_boxps.py | 14 +-
 python/paddle/fluid/transpiler/collective.py | 2 +-
 18 files changed, 436 insertions(+), 139 deletions(-)

diff --git a/cmake/external/box_ps.cmake b/cmake/external/box_ps.cmake index c6716d13f1..adfc6dba1f 100644 --- a/cmake/external/box_ps.cmake +++ b/cmake/external/box_ps.cmake @@ -19,7 +19,7 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL)) MESSAGE(STATUS "use pre defined download url") SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE) SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE) - SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps_stub.tar.gz" CACHE STRING "" FORCE) + SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE) ENDIF() MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}") SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps")
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index ede111edb6..0b408e5b09 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -193,7 +193,7 @@ if(WITH_DISTRIBUTE) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry - device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer + device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") @@ -204,7 +204,7 @@ else() data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog - lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method +
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index d186ef1274..64a645bf8b 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -43,6 +43,7 @@ class ChannelObject { capacity_ = (std::min)(MaxCapacity(), capacity); } + const std::deque& GetData() const { return data_; } void Clear() { std::unique_lock lock(mutex_); data_.clear(); diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 7926a9bfb9..7c5f9351d2 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -390,7 +390,8 @@ void DatasetImpl::GlobalShuffle(int thread_num) { } template -void DatasetImpl::DynamicAdjustChannelNum(int channel_num) { +void DatasetImpl::DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins) { if (channel_num_ == channel_num) { VLOG(3) << "DatasetImpl::DynamicAdjustChannelNum channel_num_=" << channel_num_ << ", channel_num_=channel_num, no need to adjust"; @@ -439,13 +440,13 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num) { total_data_channel->Write(std::move(local_vec)); } total_data_channel->Close(); - total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num + - 1); - // will discard the remaining instances, - // TODO(hutuxian): should add a config here to choose how to deal with - // remaining instances + if (static_cast(total_data_channel->Size()) >= channel_num) { + total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num + + (discard_remaining_ins ? 0 : 1)); + } if (static_cast(input_channel_->Size()) >= channel_num) { - input_channel_->SetBlockSize(input_channel_->Size() / channel_num); + input_channel_->SetBlockSize(input_channel_->Size() / channel_num + + (discard_remaining_ins ? 
0 : 1)); } for (int i = 0; i < channel_num; ++i) { diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index d82035c03e..f244cd76f6 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -126,8 +126,9 @@ class Dataset { virtual void DestroyPreLoadReaders() = 0; // set preload thread num virtual void SetPreLoadThreadNum(int thread_num) = 0; - // separate train thread and dataset thread - virtual void DynamicAdjustChannelNum(int channel_num) = 0; + // seperate train thread and dataset thread + virtual void DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins = false) = 0; virtual void DynamicAdjustReadersNum(int thread_num) = 0; // set fleet send sleep seconds virtual void SetFleetSendSleepSeconds(int seconds) = 0; @@ -195,7 +196,8 @@ class DatasetImpl : public Dataset { virtual void CreatePreLoadReaders(); virtual void DestroyPreLoadReaders(); virtual void SetPreLoadThreadNum(int thread_num); - virtual void DynamicAdjustChannelNum(int channel_num); + virtual void DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins = false); virtual void DynamicAdjustReadersNum(int thread_num); virtual void SetFleetSendSleepSeconds(int seconds); diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt index 78e9cb10d5..6922f92c8f 100644 --- a/paddle/fluid/framework/fleet/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/CMakeLists.txt @@ -8,7 +8,7 @@ if(WITH_NCCL) cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope) endif() if(WITH_BOX_PS) - cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor box_ps) + nv_library(box_wrapper SRCS box_wrapper.cc box_wrapper.cu DEPS framework_proto lod_tensor box_ps) else() cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor) endif(WITH_BOX_PS) diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index 40f17621da..5172964949 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -91,6 +91,7 @@ void BasicAucCalculator::calculate_bucket_error() { _bucket_error = error_count > 0 ? 
error_sum / error_count : 0.0; } +// Deprecated: should use BeginFeedPass & EndFeedPass void BoxWrapper::FeedPass(int date, const std::vector& feasgin_to_box) const { int ret = boxps_ptr_->FeedPass(date, feasgin_to_box); @@ -140,47 +141,8 @@ void BoxWrapper::PullSparse(const paddle::platform::Place& place, reinterpret_cast(buf->ptr()); if (platform::is_cpu_place(place)) { - // Note: Only GPU is supported in paddlebox now, and following code have not - // be tested fully yet - LoDTensor total_keys_tensor; - uint64_t* total_keys = reinterpret_cast( - total_keys_tensor.mutable_data({total_length, 1}, place)); - int64_t offset = 0; - VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; - for (size_t i = 0; i < keys.size(); ++i) { - memory::Copy(boost::get(place), total_keys + offset, - boost::get(place), keys[i], - slot_lengths[i] * sizeof(uint64_t)); - offset += slot_lengths[i]; - } - - VLOG(3) << "Begin call PullSparseCPU in BoxPS"; - pull_boxps_timer.Start(); - // TODO(hutuxian): should use boxps::FeatureValue in the future - int ret = boxps_ptr_->PullSparseCPU(total_keys, total_values_gpu, - static_cast(total_length)); - PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( - "PullSparseCPU failed in BoxPS.")); - pull_boxps_timer.Pause(); - - VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length - << "]"; - offset = 0; - for (size_t i = 0; i < values.size(); ++i) { - int64_t fea_num = slot_lengths[i]; - VLOG(3) << "Begin Copy slot[" << i << "] fea_num[" << fea_num << "]"; - for (auto j = 0; j < fea_num; ++j) { - // Copy the emb from BoxPS to paddle tensor. Since - // 'show','click','emb' - // are continuous in memory, so we copy here using the 'show' address - memory::Copy( - boost::get(place), values[i] + j * hidden_size, - boost::get(place), - reinterpret_cast(&((total_values_gpu + offset)->show)), - sizeof(float) * hidden_size); - ++offset; - } - } + PADDLE_THROW(platform::errors::Unimplemented( + "Warning:: CPUPlace is not supported in PaddleBox now.")); } else if (platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; @@ -253,43 +215,8 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, boxps::FeaturePushValueGpu* total_grad_values_gpu = reinterpret_cast(buf->ptr()); if (platform::is_cpu_place(place)) { - // Note: only GPU is supported in paddlebox now, and following code have not - // be tested fully yet - LoDTensor total_keys_tensor; - uint64_t* total_keys = reinterpret_cast( - total_keys_tensor.mutable_data({total_length, 1}, place)); - int64_t offset = 0; - VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; - for (size_t i = 0; i < keys.size(); ++i) { - memory::Copy(boost::get(place), total_keys + offset, - boost::get(place), keys[i], - slot_lengths[i] * sizeof(uint64_t)); - offset += slot_lengths[i]; - } - offset = 0; - VLOG(3) << "Begin copy grad tensor to BoxPS struct"; - for (size_t i = 0; i < grad_values.size(); ++i) { - int64_t fea_num = slot_lengths[i]; - for (auto j = 0; j < fea_num; ++j) { - // Copy the emb grad from paddle tensor to BoxPS. 
Since - // 'show','click','emb' are continuous in memory, here we copy - // using 'show' address - memory::Copy( - boost::get(place), - reinterpret_cast(&((total_grad_values_gpu + offset)->show)), - boost::get(place), - grad_values[i] + j * hidden_size, sizeof(float) * hidden_size); - ++offset; - } - } - - VLOG(3) << "Begin call PushSparseCPU in BoxPS"; - push_boxps_timer.Start(); - int ret = boxps_ptr_->PushSparseCPU(total_keys, total_grad_values_gpu, - static_cast(total_length)); - PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( - "PushSparseCPU failed in BoxPS.")); - push_boxps_timer.Pause(); + PADDLE_THROW(platform::errors::Unimplemented( + "Warning:: CPUPlace is not supported in PaddleBox now.")); } else if (platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) int device_id = boost::get(place).GetDeviceId(); diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index f83ecc36bb..2f49cbe0a6 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -27,12 +27,15 @@ limitations under the License. */ #include // NOLINT #include #include +#include #include #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/timer.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace framework { @@ -188,21 +191,23 @@ class BoxWrapper { } } - void SaveBase(const char* batch_model_path, const char* xbox_model_path, - boxps::SaveModelStat& stat) { // NOLINT + const std::string SaveBase(const char* batch_model_path, + const char* xbox_model_path) { VLOG(3) << "Begin SaveBase"; - if (nullptr != s_instance_) { - s_instance_->boxps_ptr_->SaveBase(batch_model_path, xbox_model_path, - stat); - } + std::string ret_str; + int ret = boxps_ptr_->SaveBase(batch_model_path, xbox_model_path, ret_str); + PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( + "SaveBase failed in BoxPS.")); + return ret_str; } - void SaveDelta(const char* xbox_model_path, - boxps::SaveModelStat& stat) { // NOLINT + const std::string SaveDelta(const char* xbox_model_path) { VLOG(3) << "Begin SaveDelta"; - if (nullptr != s_instance_) { - s_instance_->boxps_ptr_->SaveDelta(xbox_model_path, stat); - } + std::string ret_str; + int ret = boxps_ptr_->SaveDelta(xbox_model_path, ret_str); + PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( + "SaveDelta failed in BoxPS.")); + return ret_str; } static std::shared_ptr GetInstance() { @@ -223,7 +228,7 @@ class BoxWrapper { return slot_name_omited_in_feedpass_; } - struct MetricMsg { + class MetricMsg { public: MetricMsg() {} MetricMsg(const std::string& label_varname, const std::string& pred_varname, @@ -234,29 +239,212 @@ class BoxWrapper { calculator = new BasicAucCalculator(); calculator->init(bucket_size); } - const std::string& LabelVarname() const { return label_varname_; } - const std::string& PredVarname() const { return pred_varname_; } + virtual ~MetricMsg() {} + int IsJoin() const { return is_join_; } BasicAucCalculator* GetCalculator() { return calculator; } + virtual void add_data(const Scope* exe_scope) { + std::vector label_data; + get_data(exe_scope, label_varname_, &label_data); + std::vector pred_data; + get_data(exe_scope, pred_varname_, &pred_data); + auto cal = GetCalculator(); + auto batch_size = 
label_data.size(); + for (size_t i = 0; i < batch_size; ++i) { + cal->add_data(pred_data[i], label_data[i]); + } + } + template + static void get_data(const Scope* exe_scope, const std::string& varname, + std::vector* data) { + auto* var = exe_scope->FindVar(varname.c_str()); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound( + "Error: var %s is not found in scope.", varname.c_str())); + auto& gpu_tensor = var->Get(); + auto* gpu_data = gpu_tensor.data(); + auto len = gpu_tensor.numel(); + data->resize(len); + cudaMemcpy(data->data(), gpu_data, sizeof(T) * len, + cudaMemcpyDeviceToHost); + } + static inline std::pair parse_cmatch_rank(uint64_t x) { + // first 32 bit store cmatch and second 32 bit store rank + return std::make_pair(static_cast(x >> 32), + static_cast(x & 0xff)); + } - private: + protected: std::string label_varname_; std::string pred_varname_; int is_join_; BasicAucCalculator* calculator; }; + class MultiTaskMetricMsg : public MetricMsg { + public: + MultiTaskMetricMsg(const std::string& label_varname, + const std::string& pred_varname_list, int is_join, + const std::string& cmatch_rank_group, + const std::string& cmatch_rank_varname, + int bucket_size = 1000000) { + label_varname_ = label_varname; + cmatch_rank_varname_ = cmatch_rank_varname; + is_join_ = is_join; + calculator = new BasicAucCalculator(); + calculator->init(bucket_size); + for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { + const std::vector& cur_cmatch_rank = + string::split_string(cmatch_rank, "_"); + PADDLE_ENFORCE_EQ( + cur_cmatch_rank.size(), 2, + platform::errors::PreconditionNotMet( + "illegal multitask auc spec: %s", cmatch_rank.c_str())); + cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()), + atoi(cur_cmatch_rank[1].c_str())); + } + for (const auto& pred_varname : string::split_string(pred_varname_list)) { + pred_v.emplace_back(pred_varname); + } + PADDLE_ENFORCE_EQ(cmatch_rank_v.size(), pred_v.size(), + platform::errors::PreconditionNotMet( + "cmatch_rank's size [%lu] should be equal to pred " + "list's size [%lu], but ther are not equal", + cmatch_rank_v.size(), pred_v.size())); + } + virtual ~MultiTaskMetricMsg() {} + void add_data(const Scope* exe_scope) override { + std::vector cmatch_rank_data; + get_data(exe_scope, cmatch_rank_varname_, &cmatch_rank_data); + std::vector label_data; + get_data(exe_scope, label_varname_, &label_data); + size_t batch_size = cmatch_rank_data.size(); + PADDLE_ENFORCE_EQ( + batch_size, label_data.size(), + platform::errors::PreconditionNotMet( + "illegal batch size: batch_size[%lu] and label_data[%lu]", + batch_size, label_data.size())); + + std::vector> pred_data_list(pred_v.size()); + for (size_t i = 0; i < pred_v.size(); ++i) { + get_data(exe_scope, pred_v[i], &pred_data_list[i]); + } + for (size_t i = 0; i < pred_data_list.size(); ++i) { + PADDLE_ENFORCE_EQ( + batch_size, pred_data_list[i].size(), + platform::errors::PreconditionNotMet( + "illegal batch size: batch_size[%lu] and pred_data[%lu]", + batch_size, pred_data_list[i].size())); + } + auto cal = GetCalculator(); + for (size_t i = 0; i < batch_size; ++i) { + auto cmatch_rank_it = + std::find(cmatch_rank_v.begin(), cmatch_rank_v.end(), + parse_cmatch_rank(cmatch_rank_data[i])); + if (cmatch_rank_it != cmatch_rank_v.end()) { + cal->add_data(pred_data_list[std::distance(cmatch_rank_v.begin(), + cmatch_rank_it)][i], + label_data[i]); + } + } + } + + protected: + std::vector> cmatch_rank_v; + std::vector pred_v; + std::string cmatch_rank_varname_; + }; + class 
CmatchRankMetricMsg : public MetricMsg { + public: + CmatchRankMetricMsg(const std::string& label_varname, + const std::string& pred_varname, int is_join, + const std::string& cmatch_rank_group, + const std::string& cmatch_rank_varname, + int bucket_size = 1000000) { + label_varname_ = label_varname; + pred_varname_ = pred_varname; + cmatch_rank_varname_ = cmatch_rank_varname; + is_join_ = is_join; + calculator = new BasicAucCalculator(); + calculator->init(bucket_size); + for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { + const std::vector& cur_cmatch_rank = + string::split_string(cmatch_rank, "_"); + PADDLE_ENFORCE_EQ( + cur_cmatch_rank.size(), 2, + platform::errors::PreconditionNotMet( + "illegal cmatch_rank auc spec: %s", cmatch_rank.c_str())); + cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()), + atoi(cur_cmatch_rank[1].c_str())); + } + } + virtual ~CmatchRankMetricMsg() {} + void add_data(const Scope* exe_scope) override { + std::vector cmatch_rank_data; + get_data(exe_scope, cmatch_rank_varname_, &cmatch_rank_data); + std::vector label_data; + get_data(exe_scope, label_varname_, &label_data); + std::vector pred_data; + get_data(exe_scope, pred_varname_, &pred_data); + size_t batch_size = cmatch_rank_data.size(); + PADDLE_ENFORCE_EQ( + batch_size, label_data.size(), + platform::errors::PreconditionNotMet( + "illegal batch size: cmatch_rank[%lu] and label_data[%lu]", + batch_size, label_data.size())); + PADDLE_ENFORCE_EQ( + batch_size, pred_data.size(), + platform::errors::PreconditionNotMet( + "illegal batch size: cmatch_rank[%lu] and pred_data[%lu]", + batch_size, pred_data.size())); + auto cal = GetCalculator(); + for (size_t i = 0; i < batch_size; ++i) { + const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]); + for (size_t j = 0; j < cmatch_rank_v.size(); ++j) { + if (cmatch_rank_v[j] == cur_cmatch_rank) { + cal->add_data(pred_data[i], label_data[i]); + break; + } + } + } + } + + protected: + std::vector> cmatch_rank_v; + std::string cmatch_rank_varname_; + }; + const std::vector& GetMetricNameList() const { + return metric_name_list_; + } int PassFlag() const { return pass_flag_; } void FlipPassFlag() { pass_flag_ = 1 - pass_flag_; } - bool NeedMetric() const { return need_metric_; } - std::map& GetMetricList() { return metric_lists_; } + std::map& GetMetricList() { return metric_lists_; } - void InitMetric(const std::string& name, const std::string& label_varname, - const std::string& pred_varname, bool is_join, + void InitMetric(const std::string& method, const std::string& name, + const std::string& label_varname, + const std::string& pred_varname, + const std::string& cmatch_rank_varname, bool is_join, + const std::string& cmatch_rank_group, int bucket_size = 1000000) { - metric_lists_.emplace(name, MetricMsg(label_varname, pred_varname, - is_join ? 1 : 0, bucket_size)); - need_metric_ = true; + if (method == "AucCalculator") { + metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname, + is_join ? 1 : 0, bucket_size)); + } else if (method == "MultiTaskAucCalculator") { + metric_lists_.emplace( + name, new MultiTaskMetricMsg(label_varname, pred_varname, + is_join ? 1 : 0, cmatch_rank_group, + cmatch_rank_varname, bucket_size)); + } else if (method == "CmatchRankAucCalculator") { + metric_lists_.emplace( + name, new CmatchRankMetricMsg(label_varname, pred_varname, + is_join ? 
1 : 0, cmatch_rank_group, + cmatch_rank_varname, bucket_size)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "PaddleBox only support AucCalculator, MultiTaskAucCalculator and " + "CmatchRankAucCalculator")); + } + metric_name_list_.emplace_back(name); } const std::vector GetMetricMsg(const std::string& name) { @@ -265,7 +453,7 @@ class BoxWrapper { platform::errors::InvalidArgument( "The metric name you provided is not registered.")); std::vector metric_return_values_(8, 0.0); - auto* auc_cal_ = iter->second.GetCalculator(); + auto* auc_cal_ = iter->second->GetCalculator(); auc_cal_->calculate_bucket_error(); auc_cal_->compute(); metric_return_values_[0] = auc_cal_->auc(); @@ -285,14 +473,15 @@ class BoxWrapper { static cudaStream_t stream_list_[8]; static std::shared_ptr boxps_ptr_; boxps::PSAgentBase* p_agent_ = nullptr; + // TODO(hutuxian): magic number, will add a config to specify const int feedpass_thread_num_ = 30; // magic number static std::shared_ptr s_instance_; std::unordered_set slot_name_omited_in_feedpass_; // Metric Related int pass_flag_ = 1; // join: 1, update: 0 - bool need_metric_ = false; - std::map metric_lists_; + std::map metric_lists_; + std::vector metric_name_list_; std::vector slot_vector_; std::vector keys_tensor; // Cache for pull_sparse }; @@ -303,13 +492,17 @@ class BoxHelper { explicit BoxHelper(paddle::framework::Dataset* dataset) : dataset_(dataset) {} virtual ~BoxHelper() {} + void SetDate(int year, int month, int day) { + year_ = year; + month_ = month; + day_ = day; + } void BeginPass() { #ifdef PADDLE_WITH_BOX_PS auto box_ptr = BoxWrapper::GetInstance(); box_ptr->BeginPass(); #endif } - void EndPass() { #ifdef PADDLE_WITH_BOX_PS auto box_ptr = BoxWrapper::GetInstance(); @@ -317,8 +510,18 @@ class BoxHelper { #endif } void LoadIntoMemory() { + platform::Timer timer; + VLOG(3) << "Begin LoadIntoMemory(), dataset[" << dataset_ << "]"; + timer.Start(); dataset_->LoadIntoMemory(); + timer.Pause(); + VLOG(0) << "download + parse cost: " << timer.ElapsedSec() << "s"; + + timer.Start(); FeedPass(); + timer.Pause(); + VLOG(0) << "FeedPass cost: " << timer.ElapsedSec() << " s"; + VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]"; } void PreLoadIntoMemory() { dataset_->PreLoadIntoMemory(); @@ -326,33 +529,91 @@ class BoxHelper { dataset_->WaitPreLoadDone(); FeedPass(); })); + VLOG(3) << "After PreLoadIntoMemory()"; } void WaitFeedPassDone() { feed_data_thread_->join(); } - private: - Dataset* dataset_; - std::shared_ptr feed_data_thread_; +#ifdef PADDLE_WITH_BOX_PS // notify boxps to feed this pass feasigns from SSD to memory + static void FeedPassThread(const std::deque& t, int begin_index, + int end_index, boxps::PSAgentBase* p_agent, + const std::unordered_set& index_map, + int thread_id) { + p_agent->AddKey(0ul, thread_id); + for (auto iter = t.begin() + begin_index; iter != t.begin() + end_index; + iter++) { + const auto& ins = *iter; + const auto& feasign_v = ins.uint64_feasigns_; + for (const auto feasign : feasign_v) { + if (index_map.find(feasign.slot()) != index_map.end()) { + continue; + } + p_agent->AddKey(feasign.sign().uint64_feasign_, thread_id); + } + } + } +#endif void FeedPass() { + VLOG(3) << "Begin FeedPass"; #ifdef PADDLE_WITH_BOX_PS + struct std::tm b; + b.tm_year = year_ - 1900; + b.tm_mon = month_ - 1; + b.tm_mday = day_; + b.tm_min = b.tm_hour = b.tm_sec = 0; + std::time_t x = std::mktime(&b); + auto box_ptr = BoxWrapper::GetInstance(); auto input_channel_ = dynamic_cast(dataset_)->GetInputChannel(); - 
std::vector pass_data; - std::vector feasign_to_box; - input_channel_->ReadAll(pass_data); - for (const auto& ins : pass_data) { - const auto& feasign_v = ins.uint64_feasigns_; - for (const auto feasign : feasign_v) { - feasign_to_box.push_back(feasign.sign().uint64_feasign_); + const std::deque& pass_data = input_channel_->GetData(); + + // get feasigns that FeedPass doesn't need + const std::unordered_set& slot_name_omited_in_feedpass_ = + box_ptr->GetOmitedSlot(); + std::unordered_set slot_id_omited_in_feedpass_; + const auto& all_readers = dataset_->GetReaders(); + PADDLE_ENFORCE_GT(all_readers.size(), 0, + platform::errors::PreconditionNotMet( + "Readers number must be greater than 0.")); + const auto& all_slots_name = all_readers[0]->GetAllSlotAlias(); + for (size_t i = 0; i < all_slots_name.size(); ++i) { + if (slot_name_omited_in_feedpass_.find(all_slots_name[i]) != + slot_name_omited_in_feedpass_.end()) { + slot_id_omited_in_feedpass_.insert(i); } } - input_channel_->Open(); - input_channel_->Write(pass_data); - input_channel_->Close(); - box_ptr->FeedPass(feasign_to_box); + const size_t tnum = box_ptr->GetFeedpassThreadNum(); + boxps::PSAgentBase* p_agent = box_ptr->GetAgent(); + VLOG(3) << "Begin call BeginFeedPass in BoxPS"; + box_ptr->BeginFeedPass(x / 86400, &p_agent); + + std::vector threads; + size_t len = pass_data.size(); + size_t len_per_thread = len / tnum; + auto remain = len % tnum; + size_t begin = 0; + for (size_t i = 0; i < tnum; i++) { + threads.push_back( + std::thread(FeedPassThread, std::ref(pass_data), begin, + begin + len_per_thread + (i < remain ? 1 : 0), p_agent, + std::ref(slot_id_omited_in_feedpass_), i)); + begin += len_per_thread + (i < remain ? 1 : 0); + } + for (size_t i = 0; i < tnum; ++i) { + threads[i].join(); + } + VLOG(3) << "Begin call EndFeedPass in BoxPS"; + box_ptr->EndFeedPass(p_agent); #endif } + + private: + Dataset* dataset_; + std::shared_ptr feed_data_thread_; + int year_; + int month_; + int day_; }; } // end namespace framework diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index 88c1d83ff8..478d8c6143 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -168,6 +168,11 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program, SectionWorker::cpu_id_.store(pipeline_config_.start_cpu_core_id()); scope_queues_.resize(section_num_); pipeline_scopes_.resize(pipeline_num_); + for (auto& var : main_program.Block(0).AllVars()) { + if (var->Persistable()) { + persistable_vars_.push_back(var->Name()); + } + } VLOG(3) << "Init ScopeQueues and create all scopes"; for (int i = 0; i < section_num_; ++i) { @@ -266,7 +271,7 @@ void PipelineTrainer::Finalize() { for (auto& th : section_threads_) { th.join(); } - for (const auto& var : *param_need_sync_) { + for (const auto& var : persistable_vars_) { auto* root_tensor = root_scope_->Var(var)->GetMutable(); // TODO(hutuxian): Add a final all-reduce? const auto& thread_tensor = diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index cd5204b490..01d07f9b2e 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "google/protobuf/text_format.h" #include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/trainer_desc.pb.h" #include "paddle/fluid/platform/cpu_helper.h" @@ -146,6 +147,9 @@ void SectionWorker::TrainFiles() { int64_t accum_num = 0; int batch_size = 0; Scope* scope = nullptr; + if (device_reader_ != nullptr) { + device_reader_->Start(); + } while (in_scope_queue_->Receive(&scope)) { if (device_reader_ != nullptr) { device_reader_->AssignFeedVar(*scope); @@ -202,6 +206,17 @@ void SectionWorker::TrainFiles() { // No effect when it is a CPUDeviceContext dev_ctx_->Wait(); +#ifdef PADDLE_WITH_BOX_PS + auto box_ptr = BoxWrapper::GetInstance(); + auto& metric_list = box_ptr->GetMetricList(); + for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) { + auto* metric_msg = iter->second; + if (metric_msg->IsJoin() != box_ptr->PassFlag()) { + continue; + } + metric_msg->add_data(exe_scope); + } +#endif if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) { // FIXME: Temporarily we assume two adjacent sections are in different // places, @@ -273,6 +288,9 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[i] = 0.0; } platform::Timer timeline; + if (device_reader_ != nullptr) { + device_reader_->Start(); + } bool started = false; while (in_scope_queue_->Receive(&scope)) { @@ -330,9 +348,11 @@ void SectionWorker::TrainFilesWithProfiler() { SEC_LOG << "begin running ops"; cal_timer.Resume(); int op_id = 0; + dev_ctx_->Wait(); for (auto& op : ops_) { timeline.Start(); op->Run(*exe_scope, place_); + dev_ctx_->Wait(); timeline.Pause(); op_total_time[op_id++] += timeline.ElapsedUS(); } @@ -342,6 +362,17 @@ void SectionWorker::TrainFilesWithProfiler() { // No effect when it is a CPUDeviceContext dev_ctx_->Wait(); cal_timer.Pause(); +#ifdef PADDLE_WITH_BOX_PS + auto box_ptr = BoxWrapper::GetInstance(); + auto& metric_list = box_ptr->GetMetricList(); + for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) { + auto* metric_msg = iter->second; + if (metric_msg->IsJoin() != box_ptr->PassFlag()) { + continue; + } + metric_msg->add_data(exe_scope); + } +#endif if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) { // FIXME: Temporarily we assume two adjacent sections are in different diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index c769dbb350..5ad1762b4d 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -157,6 +157,7 @@ class PipelineTrainer : public TrainerBase { // The parameters that should be syncronized between different cards using // nccl all-reduce std::shared_ptr> param_need_sync_; + std::vector persistable_vars_; std::vector> sync_functors_; std::shared_ptr nccl_ctx_map_; diff --git a/paddle/fluid/pybind/box_helper_py.cc b/paddle/fluid/pybind/box_helper_py.cc index e90445175f..287de7e6be 100644 --- a/paddle/fluid/pybind/box_helper_py.cc +++ b/paddle/fluid/pybind/box_helper_py.cc @@ -29,6 +29,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/pybind/box_helper_py.h" +#ifdef PADDLE_WITH_BOX_PS +#include +#endif namespace py = pybind11; @@ -40,6 +43,8 @@ void BindBoxHelper(py::module* m) { .def(py::init([](paddle::framework::Dataset* dataset) { return std::make_shared(dataset); })) + .def("set_date", &framework::BoxHelper::SetDate, + py::call_guard()) .def("begin_pass", &framework::BoxHelper::BeginPass, py::call_guard()) .def("end_pass", &framework::BoxHelper::EndPass, @@ -51,5 +56,35 @@ void BindBoxHelper(py::module* m) { .def("load_into_memory", &framework::BoxHelper::LoadIntoMemory, py::call_guard()); } // end BoxHelper + +#ifdef PADDLE_WITH_BOX_PS +void BindBoxWrapper(py::module* m) { + py::class_>( + *m, "BoxWrapper") + .def(py::init([]() { + // return std::make_shared(dataset); + return framework::BoxWrapper::GetInstance(); + })) + .def("save_base", &framework::BoxWrapper::SaveBase, + py::call_guard()) + .def("feed_pass", &framework::BoxWrapper::FeedPass, + py::call_guard()) + .def("save_delta", &framework::BoxWrapper::SaveDelta, + py::call_guard()) + .def("initialize_gpu", &framework::BoxWrapper::InitializeGPU, + py::call_guard()) + .def("init_metric", &framework::BoxWrapper::InitMetric, + py::call_guard()) + .def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg, + py::call_guard()) + .def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList, + py::call_guard()) + .def("flip_pass_flag", &framework::BoxWrapper::FlipPassFlag, + py::call_guard()) + .def("finalize", &framework::BoxWrapper::Finalize, + py::call_guard()); +} // end BoxWrapper +#endif + } // end namespace pybind } // end namespace paddle diff --git a/paddle/fluid/pybind/box_helper_py.h b/paddle/fluid/pybind/box_helper_py.h index 33072dd5a3..7bc36516c6 100644 --- a/paddle/fluid/pybind/box_helper_py.h +++ b/paddle/fluid/pybind/box_helper_py.h @@ -23,6 +23,9 @@ namespace paddle { namespace pybind { void BindBoxHelper(py::module* m); +#ifdef PADDLE_WITH_BOX_PS +void BindBoxWrapper(py::module* m); +#endif } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c080d9219f..8120ac6a00 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h" #include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/load_op_lib.h" @@ -1456,6 +1457,9 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_brpc", IsCompiledWithBrpc); m.def("is_compiled_with_dist", IsCompiledWithDIST); + m.def("run_cmd", [](const std::string &cmd) -> const std::string { + return paddle::framework::shell_get_command_output(cmd); + }); #ifdef PADDLE_WITH_CUDA m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool { // Only GPUs with Compute Capability >= 53 support float16 @@ -2245,6 +2249,9 @@ All parameter, weight, gradient are variables in Paddle. 
BindFleetWrapper(&m); BindGlooWrapper(&m); BindBoxHelper(&m); +#ifdef PADDLE_WITH_BOX_PS + BindBoxWrapper(&m); +#endif #ifdef PADDLE_WITH_NCCL BindNCCLWrapper(&m); #endif diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 60dd4eb383..6861d86684 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -314,12 +314,12 @@ class InMemoryDataset(DatasetBase): def _dynamic_adjust_before_train(self, thread_num): if not self.is_user_set_queue_num: - self.dataset.dynamic_adjust_channel_num(thread_num) + self.dataset.dynamic_adjust_channel_num(thread_num, False) self.dataset.dynamic_adjust_readers_num(thread_num) def _dynamic_adjust_after_train(self): if not self.is_user_set_queue_num: - self.dataset.dynamic_adjust_channel_num(self.thread_num) + self.dataset.dynamic_adjust_channel_num(self.thread_num, False) self.dataset.dynamic_adjust_readers_num(self.thread_num) def set_queue_num(self, queue_num): @@ -793,6 +793,15 @@ class BoxPSDataset(InMemoryDataset): super(BoxPSDataset, self).__init__() self.boxps = core.BoxPS(self.dataset) + def set_date(self, date): + """ + Workaround for date + """ + year = int(date[:4]) + month = int(date[4:6]) + day = int(date[6:]) + self.boxps.set_date(year, month, day) + def begin_pass(self): """ Begin Pass @@ -865,3 +874,8 @@ class BoxPSDataset(InMemoryDataset): """ self._prepare_to_run() self.boxps.preload_into_memory() + + def _dynamic_adjust_before_train(self, thread_num): + if not self.is_user_set_queue_num: + self.dataset.dynamic_adjust_channel_num(thread_num, True) + self.dataset.dynamic_adjust_readers_num(thread_num) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index d047df4411..f972627cde 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -57,6 +57,7 @@ endif() if(NOT WITH_GPU OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_pipeline) + LIST(REMOVE_ITEM TEST_OPS test_boxps) endif() list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185 diff --git a/python/paddle/fluid/tests/unittests/test_boxps.py b/python/paddle/fluid/tests/unittests/test_boxps.py index b03b83ce62..c914abbf23 100644 --- a/python/paddle/fluid/tests/unittests/test_boxps.py +++ b/python/paddle/fluid/tests/unittests/test_boxps.py @@ -90,7 +90,6 @@ class TestBoxPSPreload(unittest.TestCase): y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0) emb_x, emb_y = _pull_box_sparse([x, y], size=2) emb_xp = _pull_box_sparse(x, size=2) - layers.Print(emb_xp) concat = layers.concat([emb_x, emb_y], axis=1) fc = layers.fc(input=concat, name="fc", @@ -102,7 +101,6 @@ class TestBoxPSPreload(unittest.TestCase): place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda( ) else fluid.CUDAPlace(0) exe = fluid.Executor(place) - optimizer = fluid.optimizer.SGD(learning_rate=0.5) batch_size = 2 def binary_print(slot, fout): @@ -125,6 +123,7 @@ class TestBoxPSPreload(unittest.TestCase): def create_dataset(): dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + dataset.set_date("20190930") dataset.set_use_var([x, y]) dataset.set_batch_size(2) dataset.set_thread(1) @@ -134,6 +133,14 @@ class TestBoxPSPreload(unittest.TestCase): datasets = [] datasets.append(create_dataset()) datasets.append(create_dataset()) + optimizer = 
fluid.optimizer.SGD(learning_rate=0.5) + optimizer = fluid.optimizer.PipelineOptimizer( + optimizer, + cut_list=[], + place_list=[place], + concurrency_list=[1], + queue_size=1, + sync_steps=-1) optimizer.minimize(loss) exe.run(fluid.default_startup_program()) datasets[0].load_into_memory() @@ -149,7 +156,8 @@ class TestBoxPSPreload(unittest.TestCase): exe.train_from_dataset( program=fluid.default_main_program(), dataset=datasets[1], - print_period=1) + print_period=1, + debug=True) datasets[1].end_pass() for f in filelist: os.remove(f) diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index 1b445583d2..42623337de 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -379,7 +379,7 @@ class SingleProcessMultiThread(GradAllReduce): ''' def __init__(self): - GradAllReduce.__init__(self, -1) + GradAllReduce.__init__(self, 1) self.mode = "single_process_multi_thread" def _transpile_startup_program(self): -- GitLab
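
For reference, the end-to-end BoxPSDataset pass flow that test_boxps.py exercises with this patch looks roughly as follows. This is a minimal sketch: the network (x, y, loss), the executor exe, the optimizer setup, and the "part-000" file list are placeholders defined elsewhere, not part of the patch.

import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.set_date("20190930")        # new in this patch: forwarded to BoxHelper::SetDate
dataset.set_use_var([x, y])         # x, y: input variables defined elsewhere
dataset.set_batch_size(2)
dataset.set_thread(1)
dataset.set_filelist(["part-000"])  # placeholder file list

dataset.load_into_memory()          # download + parse, then FeedPass into BoxPS
dataset.begin_pass()
exe.train_from_dataset(program=fluid.default_main_program(),
                       dataset=dataset, print_period=1)
dataset.end_pass()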
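BoxHelper::FeedPass turns the date set through set_date into the day index handed to BeginFeedPass by filling a std::tm and dividing the mktime result by 86400. The same computation as a rough Python sketch, ignoring the local-timezone dependence of mktime:

import datetime

def date_to_day_index(date_str):
    # "20190930" -> days since 1970-01-01, mirroring mktime(&tm) / 86400 in FeedPass
    d = datetime.date(int(date_str[:4]), int(date_str[4:6]), int(date_str[6:]))
    return (d - datetime.date(1970, 1, 1)).days

print(date_to_day_index("20190930"))  # 18169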
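The new discard_remaining_ins argument of DynamicAdjustChannelNum only changes how the channel block size is rounded: rounding up keeps every instance, rounding down drops the remainder that cannot be split evenly across channels. The arithmetic, restated as a small standalone sketch:

def block_size(total_ins, channel_num, discard_remaining_ins):
    # mirrors SetBlockSize(size / channel_num + (discard_remaining_ins ? 0 : 1))
    return total_ins // channel_num + (0 if discard_remaining_ins else 1)

assert block_size(10, 4, discard_remaining_ins=False) == 3  # channels get 3,3,3,1 -> all 10 kept
assert block_size(10, 4, discard_remaining_ins=True) == 2   # channels get 2,2,2,2 -> last 2 dropped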
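The two new calculators are selected through the method string that InitMetric now takes and that BindBoxWrapper exposes as init_metric. Below is a hedged sketch of what a registration call might look like from Python, following the positional order of the C++ signature (method, name, label var, pred var(s), cmatch_rank var, is_join, cmatch_rank group, bucket size); the variable names and the space-separated group format are assumptions rather than values taken from the patch.

from paddle.fluid import core

box = core.BoxWrapper()  # only bound when Paddle is compiled with PADDLE_WITH_BOX_PS
# Plain AUC over one prediction tensor.
box.init_metric("AucCalculator", "join_auc",
                "label", "ctr_pred", "", True, "", 1000000)
# AUC restricted to the (cmatch, rank) pairs listed in the group string (assumed format).
box.init_metric("CmatchRankAucCalculator", "cr_auc",
                "label", "ctr_pred", "cmatch_rank", True, "1_1 2_1", 1000000)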
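Both new calculators match instances through parse_cmatch_rank, which unpacks the cmatch id from the high 32 bits of the packed value and the rank from the low byte (the code masks with 0xff, although the comment mentions 32 bits). A tiny illustration of that unpacking:

def parse_cmatch_rank(x):
    # mirrors MetricMsg::parse_cmatch_rank: cmatch in the high 32 bits, rank in the low byte
    return (x >> 32) & 0xffffffff, x & 0xff

assert parse_cmatch_rank((3 << 32) | 1) == (3, 1)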