From 0073f9bdb0b43a8d298346e28a2b403fe351bac3 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Mon, 23 Nov 2020 20:00:36 +0800 Subject: [PATCH] support ps-gpu (#28752) * ps gpu transpile * ps gpu * remove op * gps trainer * local ps * add macro * HeterBox * def cuda * tab * code style * style Co-authored-by: Thunderbrook --- paddle/fluid/framework/CMakeLists.txt | 47 +- paddle/fluid/framework/device_worker.h | 108 ++- .../fluid/framework/device_worker_factory.cc | 1 + paddle/fluid/framework/fleet/fleet_wrapper.cc | 5 +- paddle/fluid/framework/fleet/fleet_wrapper.h | 4 +- paddle/fluid/framework/fleet/heter_wrapper.h | 6 +- paddle/fluid/framework/heter_service.h | 9 + paddle/fluid/framework/heterbox_trainer.cc | 260 ++++++ paddle/fluid/framework/heterbox_worker.cc | 753 ++++++++++++++++++ paddle/fluid/framework/hetercpu_worker.cc | 12 +- paddle/fluid/framework/heterxpu_trainer.cc | 2 +- paddle/fluid/framework/pull_dense_worker.cc | 17 + paddle/fluid/framework/trainer.h | 50 ++ paddle/fluid/framework/trainer_desc.proto | 2 + paddle/fluid/framework/trainer_factory.cc | 1 + python/paddle/fluid/executor.py | 2 + .../fleet/parameter_server/pslib/__init__.py | 78 ++ .../pslib/optimizer_factory.py | 41 +- .../fluid/incubate/fleet/utils/fleet_util.py | 26 +- .../paddle/fluid/incubate/fleet/utils/hdfs.py | 1 - python/paddle/fluid/trainer_desc.py | 29 +- python/paddle/fluid/trainer_factory.py | 4 +- 22 files changed, 1415 insertions(+), 43 deletions(-) create mode 100644 paddle/fluid/framework/heterbox_trainer.cc create mode 100644 paddle/fluid/framework/heterbox_worker.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6b724b656dd..55e56bf2ecc 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -200,23 +200,41 @@ cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector) if(WITH_DISTRIBUTE) - cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc - dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc - heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc - pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry - device_context scope framework_proto trainer_desc_proto glog fs shell - fleet_wrapper heter_wrapper box_wrapper lodtensor_printer - lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} - graph_to_program_pass variable_helper data_feed_proto timer monitor - heter_service_proto) - set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") - set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + if(WITH_PSLIB) + cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc + dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc + heterxpu_trainer.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc + pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry + device_context scope framework_proto trainer_desc_proto glog fs 
shell + fleet_wrapper heter_wrapper box_wrapper lodtensor_printer + lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} + graph_to_program_pass variable_helper data_feed_proto timer monitor + heter_service_proto pslib_brpc) + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + else() + cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc + dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc + heterxpu_trainer.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc + pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry + device_context scope framework_proto trainer_desc_proto glog fs shell + fleet_wrapper heter_wrapper box_wrapper lodtensor_printer + lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} + graph_to_program_pass variable_helper data_feed_proto timer monitor + heter_service_proto) + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + endif() elseif(WITH_PSLIB) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method @@ -229,7 +247,8 @@ else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc - data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 4951ada9bd5..a254248feaf 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -92,6 +92,7 @@ class PullDenseWorker { void Wait(std::vector<::std::future>* status_vec); void PullDense(bool force_update = false); void CreatePinVar(); + void MergeDenseParam(); int GetThreadIdByScope(const Scope* scope); void SetThreadIdByScope(const Scope* scope, int tid); static std::shared_ptr 
GetInstance() { @@ -164,7 +165,12 @@ class DeviceWorker { virtual void SetDataFeed(DataFeed* data_feed); virtual void SetWorkerNum(int num) {} virtual void CacheProgram(const ProgramDesc& main_program) {} + virtual void ProduceTasks() {} virtual void GetXpuOpIndex() {} +#ifdef PADDLE_WITH_CUDA + virtual void SetStream(const cudaStream_t stream) {} + virtual void SetEvent(const cudaEvent_t event) {} +#endif virtual void SetNeedDumpField(bool need_dump_field) { need_dump_field_ = need_dump_field; } @@ -187,6 +193,7 @@ class DeviceWorker { device_reader_->SetPlace(place); } virtual Scope* GetThreadScope() { return thread_scope_; } + DataFeed* device_reader_ = nullptr; protected: virtual void DumpParam(const Scope& scope, const int batch_id); @@ -195,7 +202,6 @@ class DeviceWorker { Scope* root_scope_ = nullptr; Scope* thread_scope_; paddle::platform::Place place_; - DataFeed* device_reader_ = nullptr; int64_t batch_num_; FetchConfig fetch_config_; bool use_cvm_; @@ -431,6 +437,106 @@ class HeterCpuWorker : public HogwildWorker { }; #endif +#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ + (defined PADDLE_WITH_PSLIB) +class HeterBoxWorker : public HogwildWorker { + public: + HeterBoxWorker() {} + virtual ~HeterBoxWorker() {} + virtual void Initialize(const TrainerDesc& desc); + virtual void TrainFiles(); + virtual void SetNeedDump(bool need_dump_field); + virtual void SetChannelWriter(ChannelObject* queue); + virtual void SetWorkerNum(int num) { worker_num_ = num; } + virtual void CacheProgram(const ProgramDesc& main_program) { + new (&program_) ProgramDesc(main_program); + } + virtual void ProduceTasks() override; + virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } + virtual void SetEvent(const cudaEvent_t event) { event_ = event; } + virtual void TrainFilesWithProfiler() {} + void ResetStat(); + + protected: + std::shared_ptr fleet_ptr_; + void FillSparseValue(std::shared_ptr task, size_t table_id); + void PushGradients(); + void CollectLabelInfo(std::shared_ptr task, size_t table_id); + void AdjustInsWeight(std::shared_ptr task); + void DumpParam(); + void CopySparseTable(); + void CopyDenseTable(); + void CopyDenseVars(); + + private: + int mpi_rank_; + std::mutex mutex_; + std::vector send_var_list_; + int worker_num_; + ProgramDesc program_; + HeterObjectPool object_pool_; + bool need_dump_param_; + std::vector dump_param_; + bool need_to_push_dense_; + bool need_dump_field_; + bool dump_slot_; + bool need_to_push_sparse_; + std::vector dump_fields_; + ChannelWriter writer_; + DownpourWorkerParameter param_; + float scale_datanorm_; + // just save the value in param_ for easy access + std::map label_var_name_; + std::map> sparse_key_names_; + std::map> sparse_value_names_; + std::map> sparse_grad_names_; + std::map> dense_value_names_; + std::map> dense_grad_names_; + platform::Place root_place_; + // actually pushed feasign of each table + std::map> sparse_push_keys_; + + // skipped ops + std::vector skip_ops_; + + std::vector<::std::future> push_sparse_status_; + std::vector<::std::future> push_dense_status_; + + // adjust ins weight + AdjustInsWeightConfig adjust_ins_weight_config_; + std::vector nid_show_; + // check nan and inf during training + std::vector check_nan_var_names_; + // copy table + CopyTableConfig copy_table_config_; + std::map table_dependency_; + std::vector> copy_sparse_tables_; + std::vector> copy_dense_tables_; + std::unordered_map> feasign_set_; + paddle::framework::Channel> pull_queue_; + paddle::framework::Channel> 
push_queue_; + cudaEvent_t event_; + cudaStream_t copy_stream_; + int batch_cnt_{0}; + std::atomic done_cnt_{0}; + + double total_time_; + double read_time_; + double pack_time_; + double pull_sparse_local_time_; + double op_all_time_; + double xpu_op_time_; + double xpu_wait_time_; + double cpu_op_time_; + double collect_label_time_; + double fill_sparse_time_; + double push_sparse_time_; + double gpu_2_cpu_time_; + double cpu_2_gpu_time_; + uint64_t total_inst_; +}; +#endif + #if defined(PADDLE_WITH_NCCL) class SectionWorker : public DeviceWorker { public: diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index 3b60cb65e34..ca5a035b4ab 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -66,6 +66,7 @@ REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt); #ifdef PADDLE_WITH_PSLIB REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker); +REGISTER_DEVICE_WORKER_CLASS(HeterBoxWorker); #endif #if defined(PADDLE_WITH_NCCL) REGISTER_DEVICE_WORKER_CLASS(SectionWorker); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 25086001598..84683b76e98 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -214,12 +214,11 @@ void FleetWrapper::HeterPullSparseVars( } void FleetWrapper::HeterPushSparseVars( - std::shared_ptr task, const uint64_t table_id, - const std::vector& sparse_key_names, + std::shared_ptr task, const Scope& scope, + const uint64_t table_id, const std::vector& sparse_key_names, const std::vector& sparse_grad_names, const int emb_dim, std::vector<::std::future>* push_sparse_status, const bool use_cvm, const bool dump_slot, const bool no_cvm) { - auto& scope = *(task->scope_); int batch_size = task->cur_batch_; int offset = 2; int slot_offset = 0; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index ae86835f38d..c2f89e336a4 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -95,8 +95,8 @@ class FleetWrapper { const std::vector& var_emb_names); void HeterPushSparseVars( - std::shared_ptr task, const uint64_t table_id, - const std::vector& sparse_key_names, + std::shared_ptr task, const Scope& scope, + const uint64_t table_id, const std::vector& sparse_key_names, const std::vector& sparse_grad_names, const int emb_dim, std::vector<::std::future>* push_sparse_status, const bool use_cvm, const bool dump_slot, const bool no_cvm); diff --git a/paddle/fluid/framework/fleet/heter_wrapper.h b/paddle/fluid/framework/fleet/heter_wrapper.h index 6ba4e00fc85..55ad218198e 100644 --- a/paddle/fluid/framework/fleet/heter_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_wrapper.h @@ -88,12 +88,10 @@ class HeterWrapper { #ifdef PADDLE_WITH_CUDA void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var, - platform::Place place, - cudaStream_t stream = nullptr); -#else + platform::Place place, cudaStream_t stream); +#endif void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var, platform::Place place); -#endif // HeterWrapper singleton static std::shared_ptr GetInstance() { if (NULL == s_instance_) { diff --git a/paddle/fluid/framework/heter_service.h b/paddle/fluid/framework/heter_service.h index 8662e460aa3..a6687f9a650 100644 --- a/paddle/fluid/framework/heter_service.h +++ 
b/paddle/fluid/framework/heter_service.h @@ -29,6 +29,7 @@ limitations under the License. */ #include "brpc/channel.h" #include "brpc/controller.h" #include "brpc/server.h" +#include "paddle/fluid/platform/timer.h" namespace paddle { namespace framework { @@ -100,6 +101,9 @@ class HeterTask { collect_label_time = 0; fill_sparse_time = 0; push_sparse_time = 0; + gpu_2_cpu_time = 0; + cpu_2_gpu_time = 0; + timeline.Reset(); } void Show() { std::cout << "features size " << features_.size() << std::endl; @@ -110,6 +114,8 @@ class HeterTask { } void PackTask(Scope* scope, int taskid, DataFeed* reader, int cur_batch, const ProgramDesc& program); + void PackGpuTask(Scope* thread_scope, DataFeed* reader, + const ProgramDesc& program); Scope* scope_{nullptr}; int taskid_; @@ -132,6 +138,9 @@ class HeterTask { double collect_label_time{0}; double fill_sparse_time{0}; double push_sparse_time{0}; + double gpu_2_cpu_time{0}; + double cpu_2_gpu_time{0}; + platform::Timer timeline; }; template diff --git a/paddle/fluid/framework/heterbox_trainer.cc b/paddle/fluid/framework/heterbox_trainer.cc new file mode 100644 index 00000000000..3e55576b846 --- /dev/null +++ b/paddle/fluid/framework/heterbox_trainer.cc @@ -0,0 +1,260 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include "io/fs.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/trainer.h" +#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ + (defined PADDLE_WITH_PSLIB) +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_device_guard.h" +#endif +namespace paddle { +namespace framework { + +void HeterBoxTrainer::Initialize(const TrainerDesc& trainer_desc, + Dataset* dataset) { + thread_num_ = trainer_desc.thread_num(); + param_ = trainer_desc.downpour_param(); + for (int i = 0; i < param_.dense_table_size(); ++i) { + uint64_t table_id = static_cast(param_.dense_table(i).table_id()); + auto table = param_.dense_table(i); + dense_grad_names_[table_id].resize(table.dense_grad_name_size()); + for (int j = 0; j < table.dense_grad_name_size(); ++j) { + dense_grad_names_[table_id][j] = table.dense_grad_name(j); + } + } + RegisterHeterCallback(); + scale_datanorm_ = trainer_desc.scale_datanorm(); + int place_num = trainer_desc.worker_places_size(); + const std::vector readers = + dataset->GetReaders(); + for (int i = 0; i < place_num; ++i) { + int num = trainer_desc.worker_places(i); +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace place = platform::CUDAPlace(num); + platform::CUDADeviceGuard guard(place.device); + cudaStream_t stream; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); + copy_streams_.push_back(stream); + places_.push_back(place); + cudaEvent_t event; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + events_.push_back(event); +#endif +#ifdef PADDLE_WITH_XPU + platform::XPUPlace place = platform::XPUPlace(num); + places_.push_back(place); +#endif + } + for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size(); + i++) { + need_merge_var_names_.push_back( + trainer_desc.downpour_param().stat_var_names(i)); + } + VLOG(3) << "going to initialize pull dense worker"; + pull_dense_worker_ = PullDenseWorker::GetInstance(); + pull_dense_worker_->Initialize(trainer_desc); + VLOG(3) << "initialize pull dense worker"; + SetDebug(trainer_desc.debug()); + fleet_ptr_ = FleetWrapper::GetInstance(); + trainer_desc_ = trainer_desc; + workers_.resize(place_num); + for (int i = 0; i < place_num; ++i) { + workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc.device_worker_name()); + workers_[i]->SetDeviceIndex(i); + workers_[i]->SetDataFeed(readers[i]); + workers_[i]->Initialize(trainer_desc); + workers_[i]->SetWorkerNum(place_num); + } +} + +void HeterBoxTrainer::DumpWork(int tid) {} + +void HeterBoxTrainer::RegisterHeterCallback() { + auto fleet_ptr = FleetWrapper::GetInstance(); + fleet_ptr->RegisterHeterCallback([this](int worker, int taskid) { + // workers_[worker]->Schedule(taskid); + }); +} + +void HeterBoxTrainer::InitTrainerEnv(const ProgramDesc& main_program, + const platform::Place& place) { + for (size_t i = 0; i < places_.size(); ++i) { + workers_[i]->SetPlace(places_[i]); + workers_[i]->SetStream(copy_streams_[i]); + workers_[i]->SetEvent(events_[i]); + workers_[i]->SetReaderPlace(platform::CPUPlace()); + workers_[i]->SetRootScope(root_scope_); + workers_[i]->CreateDeviceResource(main_program); // Program + workers_[i]->BindingDataFeedMemory(); +#ifdef PADDLE_WITH_PSLIB + workers_[i]->CacheProgram(main_program); +#endif + } + for (size_t num = 0; num < 
places_.size(); ++num) { + auto place = places_[num]; + Scope* scope = workers_[num]->GetThreadScope(); + auto stream = copy_streams_[num]; + auto event = events_[num]; + auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device; + platform::CUDADeviceGuard guard(dev_id); + auto& block = main_program.Block(0); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto name = var->Name(); + Variable* root_var = root_scope_->FindVar(name); + if (!root_var) { + continue; + } + LoDTensor* root_tensor = root_var->GetMutable(); + auto* ptr = scope->Var(name); + InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + LoDTensor* thread_tensor = ptr->GetMutable(); + +#define HeterMemcpyFunc(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + HeterMemCpy(thread_tensor, root_tensor, place, stream); \ + } \ + } while (0) + _ForEachDataType_(HeterMemcpyFunc); + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream)); + cudaEventSynchronize(event); + } + place_ = place; +} + +template +void HeterBoxTrainer::HeterMemCpy(LoDTensor* thread_tensor, + LoDTensor* root_tensor, + const paddle::platform::Place& thread_place, + cudaStream_t stream) { + T* thread_ptr = + thread_tensor->mutable_data(root_tensor->dims(), thread_place); + T* root_ptr = root_tensor->data(); + if (platform::is_cpu_place(root_tensor->place())) { + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, thread_place), thread_ptr, + platform::CPUPlace(), root_ptr, + sizeof(T) * root_tensor->numel(), stream); + } else { + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, thread_place), thread_ptr, + BOOST_GET_CONST(platform::CUDAPlace, root_tensor->place()), + root_ptr, sizeof(T) * root_tensor->numel(), stream); + } +} + +void HeterBoxTrainer::InitOtherEnv(const ProgramDesc& main_program) { + pull_dense_worker_->SetRootScope(root_scope_); + pull_dense_worker_->CreatePinVar(); + for (size_t i = 0; i < places_.size(); ++i) { + pull_dense_worker_->AddThreadScope(workers_[i]->GetThreadScope()); + pull_dense_worker_->AddPlace(places_[i]); +#ifdef PADDLE_WITH_CUDA + pull_dense_worker_->AddStream(copy_streams_[i]); +#endif + } + VLOG(3) << "init other env done."; +} + +void HeterBoxTrainer::Run() { + int pull_thread_num = 3 * places_.size(); + for (size_t thidx = 0; thidx < places_.size(); ++thidx) { + workers_[thidx]->device_reader_->Start(); + std::dynamic_pointer_cast( + workers_[thidx]) + ->ResetStat(); + } + for (int i = 0; i < pull_thread_num; ++i) { + int worker_id = i % places_.size(); + pull_threads_.push_back( + std::thread(&DeviceWorker::ProduceTasks, workers_[worker_id].get())); + } + for (size_t thidx = 0; thidx < places_.size(); ++thidx) { + threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get())); + } +} + +template +void HeterBoxTrainer::MergeToRootScope(LoDTensor* root_tensor, + LoDTensor* tensor) { + LoDTensor tmp_root; + TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root); + T* tmp_root_data = tmp_root.data(); + LoDTensor tmp_tensor; + TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor); + T* data = tmp_tensor.data(); + for (int i = 0; i < tmp_tensor.numel(); i++) { + tmp_root_data[i] += data[i]; + } + TensorCopy(tmp_root, platform::CPUPlace(), root_tensor); +} + +Scope* HeterBoxTrainer::GetWorkerScope(int thread_id) { return nullptr; } + +void HeterBoxTrainer::Finalize() { + for (auto& th : pull_threads_) { + th.join(); + } + for (auto& th : threads_) { + th.join(); + } + for (size_t i = 0; i < need_merge_var_names_.size(); i++) { + 
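// sum every worker's thread-scope copy of this stat var into the root-scope tensor +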
Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]); + if (root_var == nullptr) { + continue; + } + LoDTensor* root_tensor = root_var->GetMutable(); + + for (size_t j = 0; j < places_.size(); j++) { + Scope* cur_thread_scope = workers_[j]->GetThreadScope(); + Variable* thread_var = + cur_thread_scope->FindVar(need_merge_var_names_[i]); + if (thread_var == nullptr) { + continue; + } + LoDTensor* thread_tensor = thread_var->GetMutable(); +#define MergeCallback(cpp_type, proto_type) \ + do { \ + if (root_tensor->type() == proto_type) { \ + if (thread_tensor->type() != proto_type) { \ + VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \ + << "] " << need_merge_var_names_[i] \ + << ", root tensor type=" << root_tensor->type() \ + << ", thread tensor type=" << thread_tensor->type(); \ + exit(-1); \ + } \ + MergeToRootScope(root_tensor, thread_tensor); \ + } \ + } while (0) + _ForEachDataType_(MergeCallback); + } + } + pull_dense_worker_->MergeDenseParam(); + root_scope_->DropKids(); +} +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/heterbox_worker.cc b/paddle/fluid/framework/heterbox_worker.cc new file mode 100644 index 00000000000..726b651fcf4 --- /dev/null +++ b/paddle/fluid/framework/heterbox_worker.cc @@ -0,0 +1,753 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/string/string_helper.h" + +#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ + (defined PADDLE_WITH_PSLIB) +#include "paddle/fluid/platform/cuda_device_guard.h" + +#if defined _WIN32 || defined __APPLE__ +#else +#define _LINUX +#endif + +namespace paddle { +namespace framework { + +void HeterBoxWorker::Initialize(const TrainerDesc& desc) { + param_ = desc.downpour_param(); + mpi_rank_ = desc.mpi_rank(); + trainer_desc_ = desc; + for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) { + send_var_list_.push_back(trainer_desc_.xpu_recv_list(i)); + } + for (int i = 0; i < param_.sparse_table_size(); ++i) { + uint64_t table_id = + static_cast(param_.sparse_table(i).table_id()); + TableParameter table = param_.sparse_table(i); + sparse_key_names_[table_id].resize(table.sparse_key_name_size()); + for (int j = 0; j < table.sparse_key_name_size(); ++j) { + sparse_key_names_[table_id][j] = table.sparse_key_name(j); + } + sparse_value_names_[table_id].resize(table.sparse_value_name_size()); + for (int j = 0; j < table.sparse_value_name_size(); ++j) { + sparse_value_names_[table_id][j] = table.sparse_value_name(j); + } + sparse_grad_names_[table_id].resize(table.sparse_grad_name_size()); + for (int j = 0; j < table.sparse_grad_name_size(); ++j) { + sparse_grad_names_[table_id][j] = table.sparse_grad_name(j); + } + label_var_name_[table_id] = table.label_var_name(); + sparse_push_keys_[table_id] = std::vector(); + } + + for (int i = 0; i < param_.dense_table_size(); ++i) { + uint64_t table_id = static_cast(param_.dense_table(i).table_id()); + auto table = param_.dense_table(i); + dense_value_names_[table_id].resize(table.dense_value_name_size()); + for (int j = 0; j < table.dense_value_name_size(); ++j) { + dense_value_names_[table_id][j] = table.dense_value_name(j); + } + dense_grad_names_[table_id].resize(table.dense_grad_name_size()); + for (int j = 0; j < table.dense_grad_name_size(); ++j) { + dense_grad_names_[table_id][j] = table.dense_grad_name(j); + } + } + + skip_ops_.resize(param_.skip_ops_size()); + for (int i = 0; i < param_.skip_ops_size(); ++i) { + skip_ops_[i] = param_.skip_ops(i); + } + for (int i = 0; i < param_.stat_var_names_size(); ++i) { + stat_var_name_map_[param_.stat_var_names(i)] = 1; + } + + need_to_push_sparse_ = param_.push_sparse(); + need_to_push_dense_ = param_.push_dense(); + + fleet_ptr_ = FleetWrapper::GetInstance(); + fetch_config_ = desc.fetch_config(); + use_cvm_ = desc.use_cvm(); + // for sparse value accessor, embedding only + no_cvm_ = desc.no_cvm(); + scale_datanorm_ = desc.scale_datanorm(); + dump_slot_ = desc.dump_slot(); + dump_fields_.resize(desc.dump_fields_size()); + for (int i = 0; i < desc.dump_fields_size(); ++i) { + dump_fields_[i] = desc.dump_fields(i); + } + adjust_ins_weight_config_ = desc.adjust_ins_weight_config(); + need_dump_param_ = false; + dump_param_.resize(desc.dump_param_size()); + for (int i = 0; i < desc.dump_param_size(); ++i) { + dump_param_[i] = desc.dump_param(i); + } + if (desc.dump_param_size() != 0) { + need_dump_param_ = true; + } + for (int i = 0; i < desc.check_nan_var_names_size(); ++i) { + check_nan_var_names_.push_back(desc.check_nan_var_names(i)); + } + copy_table_config_ = desc.copy_table_config(); + for (int 
i = 0; i < copy_table_config_.src_sparse_tables_size(); ++i) { + uint64_t src_table = copy_table_config_.src_sparse_tables(i); + uint64_t dest_table = copy_table_config_.dest_sparse_tables(i); + VLOG(3) << "copy_sparse_tables_ push back " << src_table << "->" + << dest_table; + copy_sparse_tables_.push_back(std::make_pair(src_table, dest_table)); + } + for (int i = 0; i < copy_table_config_.src_dense_tables_size(); ++i) { + uint64_t src_table = copy_table_config_.src_dense_tables(i); + uint64_t dest_table = copy_table_config_.dest_dense_tables(i); + VLOG(3) << "copy_dense_tables_ push back " << src_table << "->" + << dest_table; + copy_dense_tables_.push_back(std::make_pair(src_table, dest_table)); + } + for (auto& m : copy_table_config_.table_denpendency_map()) { + if (sparse_key_names_.find(m.key()) != sparse_key_names_.end()) { + // currently only support one dependency + for (auto& value : m.values()) { + table_dependency_[m.key()] = value; + } + } + } + pull_queue_ = paddle::framework::MakeChannel>(); + push_queue_ = paddle::framework::MakeChannel>(); +} + +void HeterBoxWorker::SetChannelWriter(ChannelObject* queue) { + writer_.Reset(queue); +} + +void HeterBoxWorker::SetNeedDump(bool need_dump_field) { + need_dump_field_ = need_dump_field; +} + +void HeterBoxWorker::DumpParam() {} + +void HeterBoxWorker::CollectLabelInfo(std::shared_ptr task, + size_t table_idx) { + if (no_cvm_) { + return; + } + uint64_t table_id = static_cast( + param_.program_config(0).pull_sparse_table_id(table_idx)); + + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == table_id) { + table = i; + break; + } + } + auto& feature = (task->features_)[table_id]; + auto& feature_label = (task->feature_labels_)[table_id]; + Scope* scope = task->scope_; + feature_label.resize(feature.size()); + Variable* var = scope->FindVar(label_var_name_[table_id]); + LoDTensor* tensor = var->GetMutable(); + int64_t* label_ptr = tensor->data(); + + size_t global_index = 0; + for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) { + VLOG(3) << "sparse_key_names_[" << i + << "]: " << sparse_key_names_[table_id][i]; + Variable* fea_var = scope->FindVar(sparse_key_names_[table_id][i]); + if (fea_var == nullptr) { + continue; + } + LoDTensor* tensor = fea_var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " + << sparse_key_names_[table_id][i] << " is null"; + + // skip slots which do not have embedding + Variable* emb_var = scope->FindVar(sparse_value_names_[table_id][i]); + if (emb_var == nullptr) { + continue; + } + int64_t* ids = tensor->data(); + size_t fea_idx = 0; + // tensor->lod()[0].size() == batch_size + 1 + for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) { + for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) { + // should be skipped feasign defined in protobuf + if (ids[fea_idx] == 0u) { + continue; + } + feature_label[global_index++] = + static_cast(label_ptr[lod_idx - 1]); + } + } + } + CHECK(global_index == feature.size()) + << "expect fea info size:" << feature.size() << " real:" << global_index; +} + +void HeterBoxWorker::FillSparseValue(std::shared_ptr task, + size_t table_idx) { + uint64_t table_id = static_cast( + param_.program_config(0).pull_sparse_table_id(table_idx)); + + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == table_id) { + table = i; + break; + } + } + + auto& fea_value = (task->feature_values_)[table_id]; + Scope* scope = task->scope_; + auto fea_idx = 0u; + + std::vector 
init_value(table.fea_dim()); + for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) { + std::string slot_name = sparse_key_names_[table_id][i]; + std::string emb_slot_name = sparse_value_names_[table_id][i]; + Variable* var = scope->FindVar(slot_name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << slot_name << " is null"; + int64_t* ids = tensor->data(); + int len = tensor->numel(); + Variable* var_emb = scope->FindVar(emb_slot_name); + if (var_emb == nullptr) { + continue; + } + LoDTensor* tensor_emb = var_emb->GetMutable(); + float* ptr = tensor_emb->mutable_data({len, table.emb_dim()}, + platform::CPUPlace()); + // memset(ptr, 0, sizeof(float) * len * table.emb_dim()); + auto& tensor_lod = tensor->lod()[0]; + LoD data_lod{tensor_lod}; + tensor_emb->set_lod(data_lod); + + bool is_nid = (adjust_ins_weight_config_.need_adjust() && + adjust_ins_weight_config_.nid_slot() == emb_slot_name); + if (is_nid) { + nid_show_.clear(); + } + int nid_ins_index = 0; + + for (int index = 0; index < len; ++index) { + if (use_cvm_ || no_cvm_) { + if (ids[index] == 0u) { + memcpy(ptr + table.emb_dim() * index, init_value.data(), + sizeof(float) * table.emb_dim()); + if (is_nid) { + nid_show_.push_back(-1); + ++nid_ins_index; + } + continue; + } + memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data(), + sizeof(float) * table.emb_dim()); + if (is_nid && + static_cast(index) == tensor->lod()[0][nid_ins_index]) { + nid_show_.push_back(fea_value[fea_idx][0]); + ++nid_ins_index; + } + fea_idx++; + } else { + if (ids[index] == 0u) { + memcpy(ptr + table.emb_dim() * index, init_value.data() + 2, + sizeof(float) * table.emb_dim()); + if (is_nid) { + nid_show_.push_back(-1); + ++nid_ins_index; + } + continue; + } + memcpy(ptr + table.emb_dim() * index, fea_value[fea_idx].data() + 2, + sizeof(float) * table.emb_dim()); + if (is_nid && + static_cast(index) == tensor->lod()[0][nid_ins_index]) { + nid_show_.push_back(fea_value[fea_idx][0]); + ++nid_ins_index; + } + fea_idx++; + } + } + } +} + +void HeterBoxWorker::AdjustInsWeight(std::shared_ptr task) { +#ifdef _LINUX + // check var and tensor not null + Scope* scope = task->scope_; + if (!adjust_ins_weight_config_.need_adjust()) { + VLOG(0) << "need_adjust=false, skip adjust ins weight"; + return; + } + Variable* nid_var = scope->FindVar(adjust_ins_weight_config_.nid_slot()); + if (nid_var == nullptr) { + VLOG(0) << "nid slot var " << adjust_ins_weight_config_.nid_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + LoDTensor* nid_tensor = nid_var->GetMutable(); + if (nid_tensor == nullptr) { + VLOG(0) << "tensor of nid slot var " << adjust_ins_weight_config_.nid_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + Variable* ins_weight_var = + scope->FindVar(adjust_ins_weight_config_.ins_weight_slot()); + if (ins_weight_var == nullptr) { + VLOG(0) << "ins weight var " << adjust_ins_weight_config_.ins_weight_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + LoDTensor* ins_weight_tensor = ins_weight_var->GetMutable(); + if (ins_weight_tensor == nullptr) { + VLOG(0) << "tensor of ins weight tensor " + << adjust_ins_weight_config_.ins_weight_slot() + << " is nullptr, skip adjust ins weight"; + return; + } + + float* ins_weights = ins_weight_tensor->data(); + size_t len = ins_weight_tensor->numel(); // len = batch size + // here we assume nid_show slot only has one feasign in each instance + CHECK(len == nid_show_.size()) << 
"ins_weight size should be equal to " + << "nid_show size, " << len << " vs " + << nid_show_.size(); + float nid_adjw_threshold = adjust_ins_weight_config_.nid_adjw_threshold(); + float nid_adjw_ratio = adjust_ins_weight_config_.nid_adjw_ratio(); + int64_t nid_adjw_num = 0; + double nid_adjw_weight = 0.0; + size_t ins_index = 0; + for (size_t i = 0; i < len; ++i) { + float nid_show = nid_show_[i]; + VLOG(3) << "nid_show " << nid_show; + if (nid_show < 0) { + VLOG(3) << "nid_show < 0, continue"; + continue; + } + float ins_weight = 1.0; + if (nid_show >= 0 && nid_show < nid_adjw_threshold) { + ins_weight = log(M_E + + (nid_adjw_threshold - nid_show) / nid_adjw_threshold * + nid_adjw_ratio); + // count nid adjw insnum and weight + ++nid_adjw_num; + nid_adjw_weight += ins_weight; + // choose large ins weight + VLOG(3) << "ins weight new " << ins_weight << ", ins weight origin " + << ins_weights[ins_index]; + if (ins_weight > ins_weights[ins_index]) { + VLOG(3) << "ins " << ins_index << " weight changes to " << ins_weight; + ins_weights[ins_index] = ins_weight; + } + ++ins_index; + } + } + VLOG(3) << "nid adjw info: total_adjw_num: " << nid_adjw_num + << ", avg_adjw_weight: " << nid_adjw_weight; +#endif +} + +void HeterBoxWorker::TrainFiles() { + VLOG(3) << "Begin to train files"; + platform::SetNumThreads(1); + need_to_push_dense_ = false; + while (1) { + VLOG(3) << "before heter task"; + std::shared_ptr task; + + if (!pull_queue_->Get(task)) { + VLOG(3) << "get task"; + break; + } + VLOG(3) << "get task done"; + Scope* scope = task->scope_->kids().front(); + VLOG(3) << "get kid done"; + // do computation here + task->timeline.Start(); + for (auto& op : ops_) { + if (op->HasAttr("op_device")) { + auto device = op->Attr("op_device"); + if (device != "gpu") { + continue; + } + } + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + op->Run(*(scope), place_); + } + } + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + task->timeline.Pause(); + task->xpu_op_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + push_queue_->Put(task); + } +} + +void HeterTask::PackGpuTask(Scope* thread_scope, DataFeed* reader, + const ProgramDesc& program) { + auto& block = program.Block(0); + if (!scope_) { + scope_ = &(thread_scope->NewScope()); + for (auto& var : block.AllVars()) { + if (!var->Persistable()) { + auto* ptr = scope_->Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + } + } + } + reader->AssignFeedVar(*scope_); + cur_batch_ = reader->Next(); +} + +void HeterBoxWorker::ResetStat() { + total_time_ = 0; + read_time_ = 0; + pack_time_ = 0; + pull_sparse_local_time_ = 0; + op_all_time_ = 0; + xpu_op_time_ = 0; + xpu_wait_time_ = 0; + cpu_op_time_ = 0; + collect_label_time_ = 0; + fill_sparse_time_ = 0; + push_sparse_time_ = 0; + gpu_2_cpu_time_ = 0; + cpu_2_gpu_time_ = 0; + total_inst_ = 0; +} + +void HeterBoxWorker::ProduceTasks() { + need_to_push_dense_ = false; + while (1) { + std::shared_ptr task; + task = object_pool_.Get(); + task->Reset(); + { + std::lock_guard lock(mutex_); + task->timeline.Start(); + task->PackGpuTask(thread_scope_, device_reader_, program_); + task->timeline.Pause(); + task->pack_time = task->timeline.ElapsedSec(); + task->total_time += task->pack_time; + if (task->cur_batch_ <= 0) { + if (!pull_queue_->Closed() && batch_cnt_ == done_cnt_) { + pull_queue_->Close(); + } + break; 
+ } + batch_cnt_ += 1; + } + for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).pull_sparse_table_id(i)); + TableParameter table; + for (auto j : param_.sparse_table()) { + if (j.table_id() == tid) { + table = j; + break; + } + } + task->timeline.Start(); + fleet_ptr_->HeterPullSparseVars(thread_id_, task, tid, + sparse_key_names_[tid], table.fea_dim(), + sparse_value_names_[tid]); + task->timeline.Pause(); + task->pull_sparse_local_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + + task->timeline.Start(); + CollectLabelInfo(task, i); + task->timeline.Pause(); + task->collect_label_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + + task->timeline.Start(); + FillSparseValue(task, i); + task->timeline.Pause(); + task->fill_sparse_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + + auto nid_iter = std::find(sparse_value_names_[tid].begin(), + sparse_value_names_[tid].end(), + adjust_ins_weight_config_.nid_slot()); + if (nid_iter != sparse_value_names_[tid].end()) { + AdjustInsWeight(task); + } + } + + task->timeline.Start(); + size_t op_index = 0; + for (; op_index < ops_.size(); ++op_index) { + auto& op = ops_[op_index]; + if (op->HasAttr("op_device")) { + auto device = op->Attr("op_device"); + if (device == "gpu") { + break; + } + } + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + op->Run(*(task->scope_), platform::CPUPlace()); + } + } + + task->timeline.Pause(); + task->cpu_op_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + + task->timeline.Start(); + // prepare for gpu + Scope* cpu_scope = task->scope_; + Scope* gpu_scope = nullptr; + if (cpu_scope->kids().empty()) { + gpu_scope = &cpu_scope->NewScope(); + } else { + gpu_scope = cpu_scope->kids().front(); + } + for (const std::string& name : send_var_list_) { + const LoDTensor& cpu_tensor = cpu_scope->FindVar(name)->Get(); + LoDTensor* gpu_tensor = gpu_scope->Var(name)->GetMutable(); + gpu_tensor->set_lod(cpu_tensor.lod()); + gpu_tensor->Resize(cpu_tensor.dims()); + gpu_tensor->set_layout(cpu_tensor.layout()); + void* gpu_ptr = gpu_tensor->mutable_data(place_, cpu_tensor.type()); + const void* cpu_ptr = cpu_tensor.data(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place_), gpu_ptr, + platform::CPUPlace(), cpu_ptr, + cpu_tensor.numel() * SizeOfType(cpu_tensor.type()), + copy_stream_); + } + task->timeline.Pause(); + task->cpu_2_gpu_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + pull_queue_->Put(task); + push_queue_->Get(task); + + int need_copy_grad = 1; + task->timeline.Start(); + for (; op_index < ops_.size(); ++op_index) { + auto& op = ops_[op_index]; + if (op->HasAttr("op_device")) { + auto device = op->Attr("op_device"); + if (device == "gpu") { + continue; + } + } + bool need_skip = false; + for (auto t = 0u; t < skip_ops_.size(); ++t) { + if (op->Type().find(skip_ops_[t]) != std::string::npos) { + need_skip = true; + break; + } + } + if (!need_skip) { + need_copy_grad = 0; + op->Run(*(task->scope_), platform::CPUPlace()); + } + } + task->timeline.Pause(); + task->cpu_op_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + + VLOG(3) << "fill sparse value 
for all sparse table done."; + for (std::string& var_name : check_nan_var_names_) { + Variable* var = (task->scope_)->FindVar(var_name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + if (tensor == nullptr) { + continue; + } + PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, + platform::errors::InvalidArgument( + "Tensor %s contains Inf.", var_name)); + PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, + platform::errors::InvalidArgument( + "Tensor %s contains NAN.", var_name)); + } + + if (need_to_push_sparse_) { + // push gradients here + for (int i = 0; i < param_.program_config(0).push_sparse_table_id_size(); + ++i) { + uint64_t tid = static_cast( + param_.program_config(0).push_sparse_table_id(i)); + TableParameter table; + for (auto i : param_.sparse_table()) { + if (i.table_id() == tid) { + table = i; + break; + } + } + Scope* src_scope = task->scope_; + Scope* dest_scope = nullptr; + task->timeline.Start(); + if (need_copy_grad) { + if (cpu_scope->kids().empty()) { + dest_scope = &src_scope->NewScope(); + } else { + dest_scope = src_scope->kids().front(); + } + auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device; + platform::CUDADeviceGuard guard(dev_id); + + for (const std::string& name : sparse_grad_names_[tid]) { + const LoDTensor& src_tensor = + src_scope->FindVar(name)->Get(); + LoDTensor* dest_tensor = + dest_scope->Var(name)->GetMutable(); + dest_tensor->set_lod(src_tensor.lod()); + dest_tensor->Resize(src_tensor.dims()); + dest_tensor->set_layout(src_tensor.layout()); + void* dest_ptr = dest_tensor->mutable_data(platform::CPUPlace(), + src_tensor.type()); + const void* src_ptr = src_tensor.data(); + memory::Copy(platform::CPUPlace(), dest_ptr, + BOOST_GET_CONST(platform::CUDAPlace, place_), src_ptr, + src_tensor.numel() * SizeOfType(src_tensor.type()), + copy_stream_); + } + } else { + dest_scope = task->scope_; + } + task->timeline.Pause(); + task->gpu_2_cpu_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + + task->timeline.Start(); + fleet_ptr_->HeterPushSparseVars( + task, *(dest_scope), tid, sparse_key_names_[tid], + sparse_grad_names_[tid], table.emb_dim(), &push_sparse_status_, + use_cvm_, dump_slot_, no_cvm_); + task->timeline.Pause(); + task->push_sparse_time += task->timeline.ElapsedSec(); + task->total_time += task->timeline.ElapsedSec(); + } + } + + if (need_to_push_sparse_) { + VLOG(3) << "push sparse gradient done."; + int32_t tmp_push_sparse_wait_times = -1; + static uint32_t push_sparse_wait_times = + static_cast(tmp_push_sparse_wait_times); + if (push_sparse_status_.size() >= push_sparse_wait_times) { + for (auto& t : push_sparse_status_) { + t.wait(); + } + push_sparse_status_.resize(0); + } + + if (tmp_push_sparse_wait_times == -1) { + push_sparse_status_.resize(0); + } + } + { + std::lock_guard lock(mutex_); + total_time_ += task->total_time; + read_time_ += task->read_time; + pack_time_ += task->pack_time; + pull_sparse_local_time_ += task->pull_sparse_local_time; + op_all_time_ += task->op_all_time; + xpu_op_time_ += task->xpu_op_time; + xpu_wait_time_ += task->xpu_wait_time; + cpu_op_time_ += task->cpu_op_time; + collect_label_time_ += task->collect_label_time; + fill_sparse_time_ += task->fill_sparse_time; + push_sparse_time_ += task->push_sparse_time; + gpu_2_cpu_time_ += task->gpu_2_cpu_time; + cpu_2_gpu_time_ += task->cpu_2_gpu_time; + total_inst_ += task->cur_batch_; + } + done_cnt_.fetch_add(1, std::memory_order_relaxed); + if 
(thread_id_ == 0) { + // should be configured here + if (done_cnt_ > 0 && done_cnt_ % 100 == 0) { + fprintf(stderr, "cpu_2_gpu total time: %fs\n", + cpu_2_gpu_time_ / done_cnt_); + fprintf(stderr, "gpu_2_cpu run total time: %fs\n", + gpu_2_cpu_time_ / done_cnt_); + fprintf(stderr, "cpu op run total time: %fs\n", + cpu_op_time_ / done_cnt_); + fprintf(stderr, "xpu op run total time: %fs\n", + xpu_op_time_ / done_cnt_); + fprintf(stderr, "xpu wait total time: %fs\n", + xpu_wait_time_ / done_cnt_); + fprintf(stderr, "pack task time: %fs\n", pack_time_ / done_cnt_); + fprintf(stderr, "train total time: %fs\n", total_time_ / done_cnt_); + fprintf(stderr, "pull sparse local time: %fs\n", + pull_sparse_local_time_ / done_cnt_); + fprintf(stderr, "fill sparse time: %fs\n", + fill_sparse_time_ / done_cnt_); + fprintf(stderr, "push sparse time: %fs\n", + push_sparse_time_ / done_cnt_); + fprintf(stderr, "collect label time: %fs\n", + collect_label_time_ / done_cnt_); + fprintf(stderr, "mean read time: %fs\n", read_time_ / done_cnt_); + fprintf(stderr, "IO percent: %f\n", read_time_ / total_time_ * 100); + fprintf(stderr, "cpu_2_gpu run percent: %f\n", + cpu_2_gpu_time_ / total_time_ * 100); + fprintf(stderr, "gpu_2_cpu run percent: %f\n", + gpu_2_cpu_time_ / total_time_ * 100); + fprintf(stderr, "cpu op run percent: %f\n", + cpu_op_time_ / total_time_ * 100); + fprintf(stderr, "xpu op run percent: %f\n", + xpu_op_time_ / total_time_ * 100); + fprintf(stderr, "xpu wait percent: %f\n", + xpu_wait_time_ / total_time_ * 100); + fprintf(stderr, "pack task percent: %f\n", + pack_time_ / total_time_ * 100); + fprintf(stderr, "pull sparse local time percent: %f\n", + pull_sparse_local_time_ / total_time_ * 100); + fprintf(stderr, "collect label time percent: %f\n", + collect_label_time_ / total_time_ * 100); + fprintf(stderr, "fill sparse time percent: %f\n", + fill_sparse_time_ / total_time_ * 100); + fprintf(stderr, "push sparse time percent: %f\n", + push_sparse_time_ / total_time_ * 100); + fprintf(stderr, "%6.2f instances/s\n", total_inst_ / total_time_); + } + } + + VLOG(3) << "done taskid = " << task->taskid_; + task->scope_->DropKids(); + object_pool_.Push(task); + } +} + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc index 83838f4df67..f50cc2769e9 100644 --- a/paddle/fluid/framework/hetercpu_worker.cc +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -811,9 +811,9 @@ void HeterCpuWorker::TrainFilesWithProfiler() { } timeline.Start(); fleet_ptr_->HeterPushSparseVars( - task, tid, sparse_key_names_[tid], sparse_grad_names_[tid], - table.emb_dim(), &push_sparse_status_, use_cvm_, dump_slot_, - no_cvm_); + task, *(task->scope_), tid, sparse_key_names_[tid], + sparse_grad_names_[tid], table.emb_dim(), &push_sparse_status_, + use_cvm_, dump_slot_, no_cvm_); timeline.Pause(); task->push_sparse_time += timeline.ElapsedSec(); task->total_time += timeline.ElapsedSec(); @@ -1074,9 +1074,9 @@ void HeterCpuWorker::TrainFiles() { } } fleet_ptr_->HeterPushSparseVars( - task, tid, sparse_key_names_[tid], sparse_grad_names_[tid], - table.emb_dim(), &push_sparse_status_, use_cvm_, dump_slot_, - no_cvm_); + task, *(task->scope_), tid, sparse_key_names_[tid], + sparse_grad_names_[tid], table.emb_dim(), &push_sparse_status_, + use_cvm_, dump_slot_, no_cvm_); } } diff --git a/paddle/fluid/framework/heterxpu_trainer.cc b/paddle/fluid/framework/heterxpu_trainer.cc index 6bbbaacdde3..5e1fabf2038 100644 --- 
a/paddle/fluid/framework/heterxpu_trainer.cc +++ b/paddle/fluid/framework/heterxpu_trainer.cc @@ -415,7 +415,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request, std::shared_ptr context = object_pool_.Get(); if (!context->scope_) { - int num = rand_r() % places_.size(); + int num = rand() % places_.size(); context->place_num_ = num; auto place = places_[num]; context->scope_ = &(place_scopes_[num]->NewScope()); diff --git a/paddle/fluid/framework/pull_dense_worker.cc b/paddle/fluid/framework/pull_dense_worker.cc index bfb5aa4a26a..093b0dfe8fa 100644 --- a/paddle/fluid/framework/pull_dense_worker.cc +++ b/paddle/fluid/framework/pull_dense_worker.cc @@ -225,5 +225,22 @@ void PullDenseWorker::SetThreadIdByScope(const Scope* scope, int tid) { scope_to_thread_id_[scope] = tid; } +void PullDenseWorker::MergeDenseParam() { + for (int x = 0; x < dwp_param_.program_config(0).pull_dense_table_id_size(); + ++x) { + uint64_t tid = static_cast( + dwp_param_.program_config(0).pull_dense_table_id(x)); + for (size_t j = 0; j < dense_value_names_[tid].size(); j++) { + auto& name = dense_value_names_[tid][j]; + + Variable* root_var = root_scope_->FindVar(name); + LoDTensor* root_tensor = root_var->GetMutable(); + Variable* var = thread_scopes_[0]->FindVar(name); + LoDTensor* tensor = var->GetMutable(); + TensorCopy((*tensor), root_tensor->place(), root_tensor); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index ecaec49aa46..88dbe9c748d 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -224,6 +224,56 @@ class HeterXpuTrainer : public TrainerBase { std::vector events_; #endif }; + +class HeterBoxTrainer : public TrainerBase { + public: + HeterBoxTrainer() {} + virtual ~HeterBoxTrainer() {} + virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); + virtual void InitTrainerEnv(const ProgramDesc& main_program, + const platform::Place& place); + virtual void InitOtherEnv(const ProgramDesc& main_program); + virtual void Run(); + virtual void Finalize(); + virtual void RegisterHeterCallback(); + virtual void DumpWork(int tid); + virtual Scope* GetWorkerScope(int thread_id); + virtual void CacheProgram(const ProgramDesc& main_program) { + new (&program_) ProgramDesc(main_program); + } + virtual std::string GetDumpPath(int tid) { return ""; } + virtual void InitDumpEnv() {} + template +#ifdef PADDLE_WITH_CUDA + void HeterMemCpy(LoDTensor* tensor, LoDTensor* root_tensor, + const paddle::platform::Place& thread_place, + cudaStream_t stream); +#endif + void CreateThreadParam(const ProgramDesc& program, int num); + template + void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); + + protected: + DownpourWorkerParameter param_; + std::map> dense_grad_names_; + std::vector need_merge_var_names_; + float scale_datanorm_; + paddle::platform::Place place_; + ProgramDesc program_; + std::shared_ptr fleet_ptr_; + std::shared_ptr pull_dense_worker_; + std::vector> workers_; + std::vector places_; + // ps-gpu + std::vector pull_threads_; + std::vector threads_; + int use_ps_gpu_; + int thread_num_; +#ifdef PADDLE_WITH_CUDA + std::vector copy_streams_; + std::vector events_; +#endif +}; #endif #if defined(PADDLE_WITH_NCCL) diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 87de436617e..4d2e6d9b3a2 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ 
b/paddle/fluid/framework/trainer_desc.proto @@ -59,6 +59,8 @@ message TrainerDesc { optional int32 xpu_start_idx = 30; optional int32 xpu_end_idx = 31; + optional bool use_ps_gpu = 32 [ default = false ]; + // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; optional DownpourWorkerParameter downpour_param = 103; diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index cc92c50cc42..087d1ea0af8 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -66,6 +66,7 @@ REGISTER_TRAINER_CLASS(DistMultiTrainer); #if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ (defined PADDLE_WITH_PSLIB) REGISTER_TRAINER_CLASS(HeterXpuTrainer); +REGISTER_TRAINER_CLASS(HeterBoxTrainer); #endif #if defined(PADDLE_WITH_NCCL) REGISTER_TRAINER_CLASS(PipelineTrainer); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 90851e6d864..b4dfb9a914c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1368,6 +1368,8 @@ class Executor(object): is_heter = 1 if program._fleet_opt.get("trainer", "") == "HeterXpuTrainer": is_heter = 1 + if program._fleet_opt.get("use_ps_gpu", ""): + is_heter = 1 if scope is None: scope = global_scope() if fetch_list is None: diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index f3563808d23..6bc0b60650f 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -987,12 +987,64 @@ class DownpourOptimizer(DistributedOptimizer): """ raise NotImplementedError() + def _remove_collective_ops(self, program, name): + """ + colective init op should call once, so remove other call. 
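+ Removes the first op whose type matches `name` from the program's global block.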
+ """ + block = program.global_block() + for ids, op in list(enumerate(block.ops)): + if op.type == name: + block._remove_op(ids) + return + def apply_gradients(self, params_grads): """ Currently, apply_gradients function can not be called through DistributedOptimizer """ raise NotImplementedError() + def get_dist_env(self): + trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0')) + trainer_endpoints = '' + current_endpoint = '' + num_trainers = 0 + if os.getenv('PADDLE_TRAINER_ENDPOINTS') and os.getenv( + 'PADDLE_CURRENT_ENDPOINT'): + trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS') + current_endpoint = os.getenv('PADDLE_CURRENT_ENDPOINT') + num_trainers = len(trainer_endpoints.split(',')) + + return { + 'trainer_id': trainer_id, + 'num_trainers': num_trainers, + 'current_endpoint': current_endpoint, + 'trainer_endpoints': trainer_endpoints + } + + def _remove_collective_op_for_embedding(self, loss, table_name): + """ + find multi-sparse-table + """ + table_name = [name + "@GRAD" for name in table_name] + need_remove_op_index = [] + block = loss.block.program.global_block() + collective_ops = ["c_sync_calc_stream", "c_allreduce_sum"] + for ids, op in list(enumerate(block.ops)): + if op.type in collective_ops: + if op.input("X")[0] in table_name: + need_remove_op_index.append(ids) + if op.type == "lookup_table_grad": + need_remove_op_index.append(ids) + try: + if op.output("Out")[0] in table_name: + need_remove_op_index.append(ids) + except: + pass + + need_remove_op_index.sort(reverse=True) + for index in need_remove_op_index: + block._remove_op(index) + def minimize(self, losses, scopes=None, @@ -1043,5 +1095,31 @@ class DownpourOptimizer(DistributedOptimizer): fleet._main_programs = programs fleet._scopes = scopes + if opt_info["use_ps_gpu"]: + from paddle.fluid.transpiler.collective import SingleProcessMultiThread + # check start program + + env = self.get_dist_env() + if not isinstance(losses, list): + startup_programs = [startup_programs] + for i in range(0, len(startup_programs)): + t = SingleProcessMultiThread() + start_program = startup_programs[i] + main_program = programs[i] + t.transpile( + startup_program=start_program, + main_program=main_program, + rank=env["trainer_id"], + endpoints=env["trainer_endpoints"], + current_endpoint=env['current_endpoint'], + wait_port=False) + if i > 0: + self._remove_collective_ops(start_program, + "c_comm_init_all") + for i in range(0, len(losses)): + loss = losses[i] + embedding_table = self._distributed_optimizer._find_multi_distributed_lookup_table( + [loss]) + self._remove_collective_op_for_embedding(loss, embedding_table) return [optimize_ops, param_grads] diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 0189bc2bd74..61fbc7fdf66 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -15,14 +15,17 @@ __all__ = ["DistributedAdam", "FLEET_GLOBAL_DICT"] import paddle.fluid as fluid +from paddle.fluid import core from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs from google.protobuf import text_format from collections import OrderedDict +import copy from .node import DownpourWorker, 
 from . import ps_pb2 as pslib

+OpRole = core.op_proto_and_checker_maker.OpRole
+
 # this dict is for store info about pull/push sparse ops.
 FLEET_GLOBAL_DICT = {
     # global settings
@@ -87,6 +90,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
         self.supported_embedding_grad_types = [
             "lookup_table_grad", "push_sparse", "push_sparse_v2"
         ]
+        op_maker = core.op_proto_and_checker_maker
+        self.op_role_key = op_maker.kOpRoleAttrName()

     def _find_distributed_lookup_table_inputs(self, program, table_names):
         """
@@ -145,6 +150,26 @@ class DistributedAdam(DistributedOptimizerImplBase):
                     [local_vars[name] for name in op.input("Out@GRAD")])
         return grads_dict

+    def _is_optimizer_op(self, op):
+        return self.op_role_key in op.attr_names and \
+                int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize)
+
+    def _remove_optimize_op_for_embedding(self, loss, table_name):
+        """
+        Remove the optimizer ops that update the distributed embedding tables.
+        """
+        table_name = [name + "@GRAD" for name in table_name]
+        need_remove_op_index = []
+        block = loss.block.program.global_block()
+        for ids, op in list(enumerate(block.ops)):
+            if self._is_optimizer_op(op):
+                if op.input("Grad")[0] in table_name:
+                    need_remove_op_index.append(ids)
+
+        need_remove_op_index.sort(reverse=True)
+        for index in need_remove_op_index:
+            block._remove_op(index)
+
     def _find_multi_distributed_lookup_table(self, losses):
         """
         find multi-sparse-table
         """
@@ -314,7 +339,8 @@ class DistributedAdam(DistributedOptimizerImplBase):

         sparse_table_to_index = OrderedDict()
         sparse_table_index = 0
-        for loss in losses:
+        for num in range(len(losses)):
+            loss = losses[num]
             prog_id = str(id(loss.block.program))
             # param_grads of program
             params_grads = sorted(
@@ -322,6 +348,18 @@ class DistributedAdam(DistributedOptimizerImplBase):
                                        no_grad_set),
                 key=lambda x: x[0].name)

+            flag_use_ps_gpu = strategy.get("use_ps_gpu", False)
+            if flag_use_ps_gpu:
+                if not isinstance(startup_program, list):
+                    startup_program = [startup_program]
+                optimizer = copy.deepcopy(self._optimizer)
+                optimize_ops = optimizer.apply_optimize(
+                    loss,
+                    startup_program=startup_program[num],
+                    params_grads=params_grads)
+                embedding_table = self._find_multi_distributed_lookup_table(
+                    [loss])
+                self._remove_optimize_op_for_embedding(loss, embedding_table)
             # has condition_block op means multi-task
             flag_multi_task = self._has_conditional_block(loss)
             if flag_multi_task:
@@ -725,6 +763,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
             opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
             opt_info["dump_param"] = strategy.get("dump_param", [])
             opt_info["worker_places"] = strategy.get("worker_places", [])
+            opt_info["use_ps_gpu"] = strategy.get("use_ps_gpu", False)
             if server._server.downpour_server_param.downpour_table_param[
                     0].accessor.accessor_class in [
                         "DownpourCtrAccessor", "DownpourCtrDoubleAccessor",
diff --git a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py
index 58313c46c3c..c126f06de9d 100644
--- a/python/paddle/fluid/incubate/fleet/utils/fleet_util.py
+++ b/python/paddle/fluid/incubate/fleet/utils/fleet_util.py
@@ -23,18 +23,18 @@
 import os
 import sys
 import time
 import paddle.fluid as fluid
+from paddle.fluid import core
 from paddle.fluid.log_helper import get_logger
-from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet as fleet_pslib
-from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
 from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient
 from . import utils
+
+OpRole = core.op_proto_and_checker_maker.OpRole

 __all__ = ["FleetUtil"]

 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')

-fleet = fleet_pslib
+fleet = None


 class FleetUtil(object):
@@ -52,9 +52,13 @@ class FleetUtil(object):

     def __init__(self, mode="pslib"):
         global fleet
+        op_maker = core.op_proto_and_checker_maker
+        self.op_role_key = op_maker.kOpRoleAttrName()
         if mode == "pslib":
+            from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet as fleet_pslib
             fleet = fleet_pslib
         elif mode == "transpiler":
+            from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
             fleet = fleet_transpiler
         else:
             raise ValueError(
@@ -1616,20 +1620,26 @@ class FleetUtil(object):
         program = utils.load_program(prog_path, is_text)
         utils.parse_program(program, output_dir)

+    def _is_optimizer_op(self, op):
+        return self.op_role_key in op.attr_names and \
+                int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize)
+
     def split_program_by_device(self, program):
         ops_list = []
         type_list = []
         pre = None
         type_cpu = "cpu"
         for op in program.global_block().ops:
+            if self._is_optimizer_op(op):
+                break
             if op.has_attr("op_device"):
-                if pre is None or pre != op.attr("op_device"):
+                cur_attr = op.attr("op_device") if op.attr(
+                    "op_device") != "" else type_cpu
+                if pre is None or pre != cur_attr:
                     ops_list.append([])
-                    type_list.append(
-                        op.attr("op_device")
-                        if op.attr("op_device") != "" else type_cpu)
+                    type_list.append(cur_attr)
                 ops_list[-1].append(op)
-                pre = op.attr("op_device")
+                pre = cur_attr
         l = len(type_list)
         i = 0
         type_heter = None
diff --git a/python/paddle/fluid/incubate/fleet/utils/hdfs.py b/python/paddle/fluid/incubate/fleet/utils/hdfs.py
index b136b3853ad..4d343ffaf14 100644
--- a/python/paddle/fluid/incubate/fleet/utils/hdfs.py
+++ b/python/paddle/fluid/incubate/fleet/utils/hdfs.py
@@ -79,7 +79,6 @@ class HDFSClient(FS):
             time_out=5 * 60 * 1000,  #ms
             sleep_inter=1000):  #ms
         # Raise exception if JAVA_HOME not exists.
-        java_home = os.environ["JAVA_HOME"]

         self.pre_commands = []
         hadoop_bin = '%s/bin/hadoop' % hadoop_home
diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index 9f0089f68ab..ac7c8c0a687 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -17,7 +17,7 @@
 import sys
 import os

 __all__ = [
     'TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer',
-    'HeterXpuTrainer'
+    'HeterXpuTrainer', 'HeterBoxTrainer'
 ]
@@ -166,6 +166,9 @@ class TrainerDesc(object):
         for place in worker_places:
             self.proto_desc.worker_places.append(place)

+    def _set_use_ps_gpu(self, use_ps_gpu=False):
+        self.proto_desc.use_ps_gpu = use_ps_gpu
+
     def _set_thread_barrier(self, thread_barrier):
         self.proto_desc.thread_barrier = thread_barrier

@@ -340,6 +343,30 @@ class HeterXpuTrainer(TrainerDesc):
         self._device_worker._gen_worker_desc(self.proto_desc)


+class HeterBoxTrainer(TrainerDesc):
+    """
+    Implementation of HeterBoxTrainer.
+    It is used for distributed training.
+ """ + + def __init__(self): + super(HeterBoxTrainer, self).__init__() + pass + + def _set_program(self, program): + super(HeterBoxTrainer, self)._set_program(program) + self._program = program + + def _gen_trainer_desc(self): + super(HeterBoxTrainer, self)._gen_trainer_desc() + self.proto_desc.class_name = "HeterBoxTrainer" + if self._program == None: + raise RuntimeError("None Program") + self._device_worker._set_infer(self._infer) + self._device_worker._set_program(self._program) + self._device_worker._gen_worker_desc(self.proto_desc) + + class PipelineTrainer(TrainerDesc): """ Implement of PipelineTrainer. diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index f7573f6045d..5aff7811330 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -22,7 +22,7 @@ from paddle.fluid.log_helper import get_logger local_logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') -from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer +from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, HeterBoxTrainer from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT from .framework import Variable from multiprocessing import Process, Manager @@ -77,6 +77,8 @@ class TrainerFactory(object): trainer._set_dump_param(opt_info["dump_param"]) if opt_info.get("worker_places") is not None: trainer._set_worker_places(opt_info["worker_places"]) + if opt_info.get("use_ps_gpu") is not None: + trainer._set_use_ps_gpu(opt_info["use_ps_gpu"]) if opt_info.get("enable_random_dump") is not None: trainer._set_enable_random_dump(opt_info[ "enable_random_dump"]) -- GitLab