From fbbc339428cda9c4120dc9b8bd75a955aaa84c59 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Mon, 10 May 2021 20:07:21 +0800 Subject: [PATCH] [pslib] pslib with cmake (#32800) * pslib with cmake * heter util * vlog * heter server test * add dtor * cmake --- cmake/third_party.cmake | 8 + paddle/fluid/framework/CMakeLists.txt | 21 +- paddle/fluid/framework/device_worker.h | 2 +- paddle/fluid/framework/executor.h | 2 - paddle/fluid/framework/executor_cache.h | 2 + paddle/fluid/framework/fleet/CMakeLists.txt | 23 +- paddle/fluid/framework/fleet/fleet_wrapper.h | 2 +- .../framework/fleet/heter_ps/hashtable.h | 4 +- paddle/fluid/framework/fleet/heter_wrapper.h | 1 + paddle/fluid/framework/heter_service.h | 293 ---------------- paddle/fluid/framework/heter_util.h | 329 ++++++++++++++++++ paddle/fluid/framework/heterbox_worker.cc | 2 +- paddle/fluid/framework/heterxpu_trainer.cc | 1 + paddle/fluid/framework/multi_trainer.cc | 2 + paddle/fluid/framework/ps_gpu_worker.cc | 3 - paddle/fluid/framework/trainer.h | 9 +- .../controlflow/conditional_block_op_helper.h | 1 + .../operators/pscore/heter_server_test.cc | 2 + paddle/fluid/pybind/CMakeLists.txt | 8 + 19 files changed, 403 insertions(+), 312 deletions(-) create mode 100644 paddle/fluid/framework/heter_util.h diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index f90fa3509d..56edaff2a5 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -261,6 +261,14 @@ if(WITH_PSLIB) if(WITH_PSLIB_BRPC) include(external/pslib_brpc) # download, build, install pslib_brpc list(APPEND third_party_deps extern_pslib_brpc) + else() + include(external/snappy) + list(APPEND third_party_deps extern_snappy) + + include(external/leveldb) + list(APPEND third_party_deps extern_leveldb) + include(external/brpc) + list(APPEND third_party_deps extern_brpc) endif() endif(WITH_PSLIB) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e55fca403a..4644e674ba 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -100,8 +100,16 @@ if (WITH_GPU) endif() cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits) +set(BRPC_DEPS "") +if(WITH_PSLIB OR WITH_PSCORE) + set(BRPC_DEPS brpc) + if(WITH_PSLIB_BRPC) + set(BRPC_DEPS pslib_brpc) + endif() +endif() + cc_library(scope SRCS scope.cc DEPS glog threadpool xxhash var_type_traits) -cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor scope) +cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor scope ${BRPC_DEPS}) cc_test(device_worker_test SRCS device_worker_test.cc DEPS device_worker) cc_library(scope_pool SRCS scope_pool.cc DEPS scope) @@ -243,9 +251,16 @@ if(WITH_DISTRIBUTE) fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto timer monitor - heter_service_proto pslib_brpc) + heter_service_proto ${BRPC_DEP}) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS + "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") + endif() set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(device_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(hetercpu_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(heterxpu_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) elseif(WITH_PSCORE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc @@ -280,7 +295,7 @@ elseif(WITH_PSLIB) pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method - graph_to_program_pass variable_helper timer monitor pslib_brpc ) + graph_to_program_pass variable_helper timer monitor ${BRPC_DEP}) else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index d6c422415f..8436901147 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -29,7 +29,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/executor_gc_helper.h" -#include "paddle/fluid/framework/heter_service.h" +#include "paddle/fluid/framework/heter_util.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/program_desc.h" diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 7593b60abf..9c9f29520d 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -20,14 +20,12 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/trainer.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index 782018d1cf..3beeacb101 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -22,8 +22,10 @@ #include #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt index 03dd2cff65..a9e4691dd0 100644 --- a/paddle/fluid/framework/fleet/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/CMakeLists.txt @@ -1,5 +1,10 @@ if(WITH_PSLIB) - cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib) + if(WITH_PSLIB_BRPC) + set(BRPC_DEPS pslib_brpc) + else() + set(BRPC_DEPS brpc) + endif(WITH_PSLIB_BRPC) + cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope ${BRPC_DEPS} pslib) else() cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope) endif(WITH_PSLIB) @@ -7,11 +12,11 @@ endif(WITH_PSLIB) if(WITH_HETERPS) if(WITH_NCCL) nv_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc - DEPS heter_ps) + DEPS heter_ps ${BRPC_DEPS}) add_subdirectory(heter_ps) elseif(WITH_RCCL) hip_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc - DEPS heter_ps) + DEPS heter_ps ${BRPC_DEPS}) add_subdirectory(heter_ps) endif(WITH_NCCL) else() @@ -39,7 +44,17 @@ else() cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope) endif(WITH_GLOO) -cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context heter_service_proto) +if(WITH_PSLIB) +set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") +if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS + "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") +endif() +set_source_files_properties(heter_wrapper.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +endif() + +cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto +device_context heter_service_proto ${BRPC_DEPS}) cc_test(test_fleet_cc SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 613b280363..09f7801b19 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -28,7 +28,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/heter_service.h" +#include "paddle/fluid/framework/heter_util.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 089130f6da..3782e14ad4 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -17,16 +17,16 @@ limitations under the License. */ #include #include #include -#ifdef PADDLE_WTIH_PSLIB +#ifdef PADDLE_WITH_PSLIB #include "common_value.h" // NOLINT #endif #ifdef PADDLE_WITH_PSCORE +#include "paddle/fluid/distributed/table/depends/large_scale_kv.h" #endif #include "thrust/pair.h" //#include "cudf/concurrent_unordered_map.cuh.h" #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h" #ifdef PADDLE_WITH_HETERPS -#include "paddle/fluid/distributed/table/depends/large_scale_kv.h" #include "paddle/fluid/platform/type_defs.h" namespace paddle { diff --git a/paddle/fluid/framework/fleet/heter_wrapper.h b/paddle/fluid/framework/fleet/heter_wrapper.h index 871d2e251b..4e529de077 100644 --- a/paddle/fluid/framework/fleet/heter_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_wrapper.h @@ -25,6 +25,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_PSLIB #include "paddle/fluid/framework/heter_service.h" +#include "paddle/fluid/framework/heter_util.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" diff --git a/paddle/fluid/framework/heter_service.h b/paddle/fluid/framework/heter_service.h index 3f65eaf3aa..7e5bf138d9 100644 --- a/paddle/fluid/framework/heter_service.h +++ b/paddle/fluid/framework/heter_service.h @@ -72,299 +72,6 @@ class HeterXpuService : public HeterService { std::unordered_map handler_map_; }; -enum HeterTaskState { PULL_SPARSE, OP_RUN, XPU, OP_RUN_END, PUSH_GRAD, DONE }; - -class HeterTask { - public: - void Update() { - if (state_ == PULL_SPARSE) { - state_ = OP_RUN; - } else if (state_ == OP_RUN) { - state_ = XPU; - // state_ = PUSH_GRAD; - // state_ = PUSH_GRAD; - } else if (state_ == XPU) { - state_ = OP_RUN_END; - } else if (state_ == OP_RUN_END) { - state_ = PUSH_GRAD; - } else if (state_ == PUSH_GRAD) { - state_ = DONE; - } - } - void Reset() { - total_time = 0; - read_time = 0; - pack_time = 0; - pull_sparse_local_time = 0; - op_all_time = 0; - xpu_op_time = 0; - xpu_wait_time = 0; - cpu_op_time = 0; - collect_label_time = 0; - fill_sparse_time = 0; - push_sparse_time = 0; - gpu_2_cpu_time = 0; - cpu_2_gpu_time = 0; - timeline.Reset(); - } - void Show() { - std::cout << "features size " << features_.size() << std::endl; - for (size_t i = 0; i < features_.size(); ++i) { - std::cout << "features[" << i << "] size " << features_[i].size() - << std::endl; - } - } - void PackTask(Scope* scope, int taskid, DataFeed* reader, int cur_batch, - const ProgramDesc& program); - void PackGpuTask(Scope* thread_scope, DataFeed* reader, - const ProgramDesc& program); - - Scope* scope_{nullptr}; - int taskid_; - int cur_batch_; - HeterTaskState state_; - // cache - std::map> features_; - std::map> feature_labels_; - std::map>> feature_values_; - std::map>> feature_grads_; - std::map> sparse_push_keys_; - double total_time{0}; - double read_time{0}; - double pack_time{0}; - double pull_sparse_local_time{0}; - double op_all_time{0}; - double xpu_op_time{0}; - double xpu_wait_time{0}; - double cpu_op_time{0}; - double collect_label_time{0}; - double fill_sparse_time{0}; - double push_sparse_time{0}; - double gpu_2_cpu_time{0}; - double cpu_2_gpu_time{0}; - platform::Timer timeline; -}; -#endif -template -class HeterObjectPool { - public: - HeterObjectPool() {} - virtual ~HeterObjectPool(){}; - std::shared_ptr Get() { - std::lock_guard lock(mutex_); - if (pool_.empty()) { - num_ += 1; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - VLOG(3) << "pool construct size: " << num_; -#endif - return std::make_shared(); - } else { - auto ret = pool_.back(); - pool_.pop_back(); - return ret; - } - } - void Push(std::shared_ptr data) { - std::lock_guard lock(mutex_); - pool_.push_back(std::move(data)); - } - int Size() { - std::lock_guard lock(mutex_); - return pool_.size(); - } - std::shared_ptr& GetElement(int i) { return pool_[i]; } - - private: - std::vector> pool_; - std::mutex mutex_; - int num_{0}; -}; - -#ifdef PADDLE_WITH_PSLIB -struct BthreadMutextGuard { - BthreadMutextGuard(bthread_mutex_t* rho) { - mutex_ = rho; - bthread_mutex_lock(mutex_); - } - ~BthreadMutextGuard() { bthread_mutex_unlock(mutex_); } - bthread_mutex_t* mutex_; -}; - -template -class BtObjectPool { - public: - BtObjectPool() { - bthread_mutex_init(&mutex_, NULL); - bthread_cond_init(&cond_, NULL); - } - - virtual ~BtObjectPool() { - bthread_cond_destroy(&cond_); - bthread_mutex_destroy(&mutex_); - }; - - std::shared_ptr Get() { - BthreadMutextGuard guard(&mutex_); - while (pool_.empty()) { - bthread_cond_wait(&cond_, &mutex_); - } - auto ret = pool_.back(); - pool_.pop_back(); - return ret; - } - - void Push(std::shared_ptr data) { - BthreadMutextGuard guard(&mutex_); - pool_.push_back(std::move(data)); - bthread_cond_signal(&cond_); - } - - int Size() { return pool_.size(); } - - std::shared_ptr& GetElement(int i) { return pool_[i]; } - - private: - std::vector> pool_; - bthread_mutex_t mutex_; - bthread_cond_t cond_; - int num_{0}; -}; - -template -struct HeterNode { - K key; - T value; - HeterNode* prev; - HeterNode* next; -}; - -template -class HeterList { - public: - HeterList() : head_(new HeterNode), tail_(new HeterNode) { - head_->prev = NULL; - head_->next = tail_; - tail_->prev = head_; - tail_->next = NULL; - size = 0; - cap_ = 1e9; - } - - ~HeterList() { - delete head_; - delete tail_; - } - - void SetCap(int num) { cap_ = num; } - - bool TryPut(K& key, T& value) { - std::unique_lock lock(mutex_); - cond_.wait(lock, [this] { return size < cap_; }); - if (task_map_.find(key) != task_map_.end()) { - task_map_.erase(key); - return false; - } else { - HeterNode* node = new HeterNode; - node->key = key; - node->value = value; - map_[node->key] = node; - attach(node); - return true; - } - } - - bool Put(K& key, T& value) { - std::unique_lock lock(mutex_); - cond_.wait(lock, [this] { return size < cap_; }); - HeterNode* node = new HeterNode; - node->key = key; - node->value = value; - map_[node->key] = node; - attach(node); - return true; - } - - T TryGet(const K& key) { - std::lock_guard lock(mutex_); - auto iter = map_.find(key); - if (iter != map_.end()) { - HeterNode* node = iter->second; - detach(node); - cond_.notify_one(); - T ret = std::move(node->value); - map_.erase(key); - delete node; - return ret; - } - task_map_.insert(key); - return nullptr; - } - - T Get(const K& key) { - std::lock_guard lock(mutex_); - auto iter = map_.find(key); - if (iter != map_.end()) { - HeterNode* node = iter->second; - detach(node); - cond_.notify_one(); - T ret = std::move(node->value); - map_.erase(key); - delete node; - return ret; - } - return nullptr; - } - - T Get() { - std::lock_guard lock(mutex_); - HeterNode* node = head_->next; - if (node == tail_) { - return nullptr; - } else { - detach(node); - cond_.notify_one(); - T ret = std::move(node->value); - map_.erase(node->key); - delete node; - return ret; - } - } - - bool Empty() { - std::lock_guard lock(mutex_); - return head_->next == tail_; - } - - int Size() { - std::lock_guard lock(mutex_); - return size; - } - - private: - void detach(HeterNode* node) { - node->prev->next = node->next; - node->next->prev = node->prev; - size--; - } - - void attach(HeterNode* node) { - node->prev = head_; - node->next = head_->next; - head_->next->prev = node; - head_->next = node; - size++; - } - - private: - HeterNode* head_; - HeterNode* tail_; - std::unordered_map*> map_; - std::unordered_set task_map_; - std::mutex mutex_; - std::condition_variable cond_; - int cap_; - int size; -}; #endif } // namespace framework diff --git a/paddle/fluid/framework/heter_util.h b/paddle/fluid/framework/heter_util.h new file mode 100644 index 0000000000..a08f08428d --- /dev/null +++ b/paddle/fluid/framework/heter_util.h @@ -0,0 +1,329 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_PSLIB +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include +#include "bthread/bthread.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace framework { +class DataFeed; +enum HeterTaskState { PULL_SPARSE, OP_RUN, XPU, OP_RUN_END, PUSH_GRAD, DONE }; + +class HeterTask { + public: + HeterTask() {} + virtual ~HeterTask(){}; + + void Update() { + if (state_ == PULL_SPARSE) { + state_ = OP_RUN; + } else if (state_ == OP_RUN) { + state_ = XPU; + // state_ = PUSH_GRAD; + // state_ = PUSH_GRAD; + } else if (state_ == XPU) { + state_ = OP_RUN_END; + } else if (state_ == OP_RUN_END) { + state_ = PUSH_GRAD; + } else if (state_ == PUSH_GRAD) { + state_ = DONE; + } + } + void Reset() { + total_time = 0; + read_time = 0; + pack_time = 0; + pull_sparse_local_time = 0; + op_all_time = 0; + xpu_op_time = 0; + xpu_wait_time = 0; + cpu_op_time = 0; + collect_label_time = 0; + fill_sparse_time = 0; + push_sparse_time = 0; + gpu_2_cpu_time = 0; + cpu_2_gpu_time = 0; + timeline.Reset(); + } + void Show() { + std::cout << "features size " << features_.size() << std::endl; + for (size_t i = 0; i < features_.size(); ++i) { + std::cout << "features[" << i << "] size " << features_[i].size() + << std::endl; + } + } + void PackTask(Scope* scope, int taskid, DataFeed* reader, int cur_batch, + const ProgramDesc& program); + void PackGpuTask(Scope* thread_scope, DataFeed* reader, + const ProgramDesc& program); + + Scope* scope_{nullptr}; + int taskid_; + int cur_batch_; + HeterTaskState state_; + // cache + std::map> features_; + std::map> feature_labels_; + std::map>> feature_values_; + std::map>> feature_grads_; + std::map> sparse_push_keys_; + double total_time{0}; + double read_time{0}; + double pack_time{0}; + double pull_sparse_local_time{0}; + double op_all_time{0}; + double xpu_op_time{0}; + double xpu_wait_time{0}; + double cpu_op_time{0}; + double collect_label_time{0}; + double fill_sparse_time{0}; + double push_sparse_time{0}; + double gpu_2_cpu_time{0}; + double cpu_2_gpu_time{0}; + platform::Timer timeline; +}; +#endif +template +class HeterObjectPool { + public: + HeterObjectPool() {} + virtual ~HeterObjectPool(){}; + std::shared_ptr Get() { + std::lock_guard lock(mutex_); + if (pool_.empty()) { + num_ += 1; + return std::make_shared(); + } else { + auto ret = pool_.back(); + pool_.pop_back(); + return ret; + } + } + void Push(std::shared_ptr data) { + std::lock_guard lock(mutex_); + pool_.push_back(std::move(data)); + } + int Size() { + std::lock_guard lock(mutex_); + return pool_.size(); + } + std::shared_ptr& GetElement(int i) { return pool_[i]; } + + private: + std::vector> pool_; + std::mutex mutex_; + int num_{0}; +}; + +#ifdef PADDLE_WITH_PSLIB +struct BthreadMutextGuard { + BthreadMutextGuard(bthread_mutex_t* rho) { + mutex_ = rho; + bthread_mutex_lock(mutex_); + } + ~BthreadMutextGuard() { bthread_mutex_unlock(mutex_); } + bthread_mutex_t* mutex_; +}; + +template +class BtObjectPool { + public: + BtObjectPool() { + bthread_mutex_init(&mutex_, NULL); + bthread_cond_init(&cond_, NULL); + } + + virtual ~BtObjectPool() { + bthread_cond_destroy(&cond_); + bthread_mutex_destroy(&mutex_); + }; + + std::shared_ptr Get() { + BthreadMutextGuard guard(&mutex_); + while (pool_.empty()) { + bthread_cond_wait(&cond_, &mutex_); + } + auto ret = pool_.back(); + pool_.pop_back(); + return ret; + } + + void Push(std::shared_ptr data) { + BthreadMutextGuard guard(&mutex_); + pool_.push_back(std::move(data)); + bthread_cond_signal(&cond_); + } + + int Size() { return pool_.size(); } + + std::shared_ptr& GetElement(int i) { return pool_[i]; } + + private: + std::vector> pool_; + bthread_mutex_t mutex_; + bthread_cond_t cond_; + int num_{0}; +}; + +template +struct HeterNode { + K key; + T value; + HeterNode* prev; + HeterNode* next; +}; + +template +class HeterList { + public: + HeterList() : head_(new HeterNode), tail_(new HeterNode) { + head_->prev = NULL; + head_->next = tail_; + tail_->prev = head_; + tail_->next = NULL; + size = 0; + cap_ = 1e9; + } + + ~HeterList() { + delete head_; + delete tail_; + } + + void SetCap(int num) { cap_ = num; } + + bool TryPut(K& key, T& value) { + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return size < cap_; }); + if (task_map_.find(key) != task_map_.end()) { + task_map_.erase(key); + return false; + } else { + HeterNode* node = new HeterNode; + node->key = key; + node->value = value; + map_[node->key] = node; + attach(node); + return true; + } + } + + bool Put(K& key, T& value) { + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return size < cap_; }); + HeterNode* node = new HeterNode; + node->key = key; + node->value = value; + map_[node->key] = node; + attach(node); + return true; + } + + T TryGet(const K& key) { + std::lock_guard lock(mutex_); + auto iter = map_.find(key); + if (iter != map_.end()) { + HeterNode* node = iter->second; + detach(node); + cond_.notify_one(); + T ret = std::move(node->value); + map_.erase(key); + delete node; + return ret; + } + task_map_.insert(key); + return nullptr; + } + + T Get(const K& key) { + std::lock_guard lock(mutex_); + auto iter = map_.find(key); + if (iter != map_.end()) { + HeterNode* node = iter->second; + detach(node); + cond_.notify_one(); + T ret = std::move(node->value); + map_.erase(key); + delete node; + return ret; + } + return nullptr; + } + + T Get() { + std::lock_guard lock(mutex_); + HeterNode* node = head_->next; + if (node == tail_) { + return nullptr; + } else { + detach(node); + cond_.notify_one(); + T ret = std::move(node->value); + map_.erase(node->key); + delete node; + return ret; + } + } + + bool Empty() { + std::lock_guard lock(mutex_); + return head_->next == tail_; + } + + int Size() { + std::lock_guard lock(mutex_); + return size; + } + + private: + void detach(HeterNode* node) { + node->prev->next = node->next; + node->next->prev = node->prev; + size--; + } + + void attach(HeterNode* node) { + node->prev = head_; + node->next = head_->next; + head_->next->prev = node; + head_->next = node; + size++; + } + + private: + HeterNode* head_; + HeterNode* tail_; + std::unordered_map*> map_; + std::unordered_set task_map_; + std::mutex mutex_; + std::condition_variable cond_; + int cap_; + int size; +}; +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/heterbox_worker.cc b/paddle/fluid/framework/heterbox_worker.cc index 726b651fcf..b7df88218c 100644 --- a/paddle/fluid/framework/heterbox_worker.cc +++ b/paddle/fluid/framework/heterbox_worker.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h" -#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/framework/heter_util.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/string/string_helper.h" diff --git a/paddle/fluid/framework/heterxpu_trainer.cc b/paddle/fluid/framework/heterxpu_trainer.cc index 5e1fabf203..8049a1c942 100644 --- a/paddle/fluid/framework/heterxpu_trainer.cc +++ b/paddle/fluid/framework/heterxpu_trainer.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/framework/trainer.h" #if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ (defined PADDLE_WITH_PSLIB) diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 198bb65863..7afa76c3fb 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -176,6 +176,7 @@ void MultiTrainer::Run() { #ifdef PADDLE_WITH_HETERPS void MultiTrainer::MergeDenseParam() { +#ifdef PADDLE_WTIH_PSCORE auto communicator = paddle::distributed::Communicator::GetInstance(); auto& recv_ctx = communicator->GetRecvCtxMap(); Scope* thread_scope = workers_[0]->GetThreadScope(); @@ -189,6 +190,7 @@ void MultiTrainer::MergeDenseParam() { TensorCopy((*tensor), root_tensor->place(), root_tensor); } } +#endif } #endif diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc index d178c4e556..66d8a40dda 100644 --- a/paddle/fluid/framework/ps_gpu_worker.cc +++ b/paddle/fluid/framework/ps_gpu_worker.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker_factory.h" -#include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/string/string_helper.h" @@ -129,8 +128,6 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) { } } } - // pull_queue_ = paddle::framework::MakeChannel>(); - // push_queue_ = paddle::framework::MakeChannel>(); } void PSGPUWorker::SetChannelWriter(ChannelObject* queue) { diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 3ac36bd2e4..636760029f 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -26,8 +26,9 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker.h" -#include "paddle/fluid/framework/fleet/heter_wrapper.h" -#include "paddle/fluid/framework/heter_service.h" +#include "paddle/fluid/framework/fleet/heter_context.h" +//#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/framework/heter_util.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/reader.h" @@ -46,6 +47,10 @@ class PullDenseWorker; class Scope; class VarDesc; class DeviceWorker; +class HeterWrapper; +class HeterRequest; +class HeterResponse; + template class ChannelObject; diff --git a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h index 22eb2ece4b..7ce63aa9cb 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h +++ b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/controlflow/conditional_block_op.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/operators/pscore/heter_server_test.cc b/paddle/fluid/operators/pscore/heter_server_test.cc index 1d072936f4..df2eb70b14 100644 --- a/paddle/fluid/operators/pscore/heter_server_test.cc +++ b/paddle/fluid/operators/pscore/heter_server_test.cc @@ -20,6 +20,8 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/service/heter_client.h" #include "paddle/fluid/distributed/service/heter_server.h" +#include "paddle/fluid/framework/op_registry.h" + namespace framework = paddle::framework; namespace platform = paddle::platform; namespace distributed = paddle::distributed; diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index b30214e1d8..49da540807 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -73,6 +73,14 @@ if (WITH_CRYPTO) set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc) endif (WITH_CRYPTO) +if (WITH_PSLIB) + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result") + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS + "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") + endif() + set_source_files_properties(heter_wrapper_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +endif(WITH_PSLIB) if (WITH_PSCORE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result") set_source_files_properties(fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -- GitLab