From c4f279fe8de5ea530242c29177e8fcf64adb3199 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Wed, 24 Feb 2021 17:30:26 +0800 Subject: [PATCH] support multi node in heterps (#31102) * push multi node * multi node * MultiThread * remove log * solve bug in 30829 --- paddle/fluid/framework/fleet/fleet_wrapper.cc | 1 + .../framework/fleet/heter_ps/heter_comm.h | 68 +++++++ .../framework/fleet/heter_ps/heter_comm_inl.h | 184 ++++++++++++++++++ .../framework/fleet/heter_ps/heter_ps.cu | 9 +- .../fluid/framework/fleet/heter_ps/heter_ps.h | 3 + .../framework/fleet/heter_ps/heter_ps_base.h | 3 + paddle/fluid/framework/fleet/heter_wrapper.cc | 1 + .../fluid/framework/fleet/ps_gpu_wrapper.cc | 1 + paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 49 ++++- paddle/fluid/framework/heterbox_trainer.cc | 8 + paddle/fluid/framework/heterbox_worker.cc | 7 + paddle/fluid/framework/hetercpu_worker.cc | 7 + paddle/fluid/framework/heterxpu_trainer.cc | 10 + .../fluid/incubate/fleet/base/role_maker.py | 13 ++ python/paddle/fluid/transpiler/collective.py | 24 +++ 15 files changed, 386 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 7ad20aa6e1..0c0792a95c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "glog/logging.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 5d29999853..77591c6df2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/place.h" #include "thrust/pair.h" @@ -68,7 +69,30 @@ class HeterComm { void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd); + template + void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads, + size_t len, Sgd& sgd); + + template + void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len, + Sgd& sgd); + + int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads, + int len); + + int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads, + int len); + int log2i(int x); + + void set_nccl_comm_and_size(const std::vector& inner_comms, + const std::vector& inter_comms, + int comm_size) { + nccl_inner_comms_ = inner_comms; + nccl_inter_comms_ = inter_comms; + node_size_ = comm_size; + } + bool need_transfer(int send_id, int receive_id) { return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id); } @@ -94,6 +118,44 @@ class HeterComm { std::vector nodes_; }; + struct LocalStorage { + LocalStorage() {} + void init(int size, int dev_id) { + place_ = platform::CUDAPlace(dev_id); + alloc(size, true); + } + + void alloc(int size, bool force = false) { + if (force || size > all_keys_mem->size()) { + all_keys_mem.reset(); + all_grads_mem.reset(); + all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType)); + all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType)); + all_keys = reinterpret_cast(all_keys_mem->ptr()); + all_grads = reinterpret_cast(all_grads_mem->ptr()); + } + if (force || size > local_keys_mem->size()) { + local_keys_mem.reset(); + local_grads_mem.reset(); + local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType)); + local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType)); + local_keys = reinterpret_cast(local_keys_mem->ptr()); + local_grads = reinterpret_cast(local_grads_mem->ptr()); + } + } + + platform::CUDAPlace place_; + std::shared_ptr all_keys_mem; + std::shared_ptr all_grads_mem; + KeyType* all_keys; + GradType* all_grads; + + std::shared_ptr local_keys_mem; + std::shared_ptr local_grads_mem; + KeyType* local_keys; + GradType* local_grads; + }; + void init_path(); void create_storage( int start_index, int end_index, int keylen, int vallen, @@ -111,6 +173,12 @@ class HeterComm { CustomGradMerger merger_; int topo_aware_{1}; std::vector> path_; + std::vector storage_; + int feanum_{1800 * 2048}; + int multi_node_{1}; + std::vector nccl_inner_comms_; + std::vector nccl_inter_comms_; + int node_size_; }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index e42a3a324f..4e4563daa1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -95,10 +95,14 @@ template HeterComm::HeterComm( size_t capacity, std::shared_ptr resource) { resource_ = resource; + storage_.resize(resource_->total_gpu()); for (int i = 0; i < resource_->total_gpu(); ++i) { platform::CUDADeviceGuard guard(resource_->dev_id(i)); auto table = new Table(capacity / load_factor_); tables_.push_back(table); + if (multi_node_) { + storage_[i].init(feanum_, resource_->dev_id(i)); + } } init_path(); } @@ -595,6 +599,186 @@ void HeterComm::push_sparse(int gpu_num, } } +template +template +void HeterComm::update_one_table( + int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) { + if (len == 0) { + return; + } + + int dev_id = resource_->dev_id(gpu_num); + platform::CUDADeviceGuard guard(dev_id); + tables_[gpu_num]->update(d_keys, d_grads, len, sgd, + resource_->remote_stream(gpu_num)); + cudaStreamSynchronize(resource_->remote_stream(gpu_num)); +} + +template +template +void HeterComm::push_sparse_multi_node( + int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) { + if (len == 0) { + return; + } + + int uniq_len = len; + merge_grad(gpu_num, d_keys, d_grads, len, uniq_len); + + uniq_len = gather_one_node_grad(gpu_num, d_keys, d_grads, uniq_len); + + uniq_len = gather_multi_node_grad(gpu_num, storage_[gpu_num].local_keys, + storage_[gpu_num].local_grads, uniq_len); + + update_one_table(gpu_num, storage_[gpu_num].local_keys, + storage_[gpu_num].local_grads, uniq_len, sgd); +} + +template +int HeterComm::gather_one_node_grad( + int gpu_num, KeyType* d_keys, GradType* d_grads, int len) { + int total_gpu = resource_->total_gpu(); + int dev_id = resource_->dev_id(gpu_num); + auto& storage = storage_[gpu_num]; + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_num, 0); + int max_size = 0; + + ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num]; + // alloc for size + int h_node_len[total_gpu]; + auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int)); + int* d_node_len = reinterpret_cast(d_node_len_mem->ptr()); + h_node_len[gpu_num] = len; + + cudaMemcpy(d_node_len + gpu_num, h_node_len + gpu_num, sizeof(int), + cudaMemcpyHostToDevice); + + // allgather grad len + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt, + nccl_inner_comm, stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu, + cudaMemcpyDeviceToHost); + + for (int i = 0; i < total_gpu; ++i) { + if (h_node_len[i] > max_size) { + max_size = h_node_len[i]; + } + } + storage.alloc(max_size * total_gpu); + + // allgather keys and grads + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + d_keys, storage.all_keys, max_size, ncclUint64, nccl_inner_comm, stream)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8, + nccl_inner_comm, stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + + int h_left[total_gpu]; + int h_right[total_gpu]; + auto d_left = memory::AllocShared(place, total_gpu * sizeof(int)); + auto d_right = memory::AllocShared(place, total_gpu * sizeof(int)); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + + int merge_num = 0; + for (int i = 0; i < total_gpu; ++i) { + int index = i * max_size; + auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int)); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + + cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int)); + cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int)); + + split_input_to_shard(storage.all_keys + index, d_idx_ptr, h_node_len[i], + d_left_ptr, d_right_ptr, gpu_num); + cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost); + cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), + cudaMemcpyDeviceToHost); + + int grid_size = (h_node_len[i] - 1) / block_size_ + 1; + fill_shard_grads<<>>( + storage.local_keys + merge_num, storage.all_keys + index, + storage.local_grads + merge_num, storage.all_grads + index, + d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1); + merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1; + } + + int ret = merge_num; + merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret); + return ret; +} + +template +int HeterComm::gather_multi_node_grad( + int gpu_num, KeyType* d_keys, GradType* d_grads, int len) { + int dev_id = resource_->dev_id(gpu_num); + auto& storage = storage_[gpu_num]; + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_num, 0); + int max_size = 0; + ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num]; + // alloc for size + int h_node_len[node_size_]; + auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int)); + int* d_node_len = reinterpret_cast(d_node_len_mem->ptr()); + h_node_len[0] = len; + + cudaMemcpy(d_node_len, h_node_len, sizeof(int), cudaMemcpyHostToDevice); + + // allgather grad len + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + d_node_len, d_node_len, 1, ncclInt, nccl_inter_comm, stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + cudaMemcpy(h_node_len, d_node_len, sizeof(int) * node_size_, + cudaMemcpyDeviceToHost); + + for (int i = 0; i < node_size_; ++i) { + if (h_node_len[i] > max_size) { + max_size = h_node_len[i]; + } + } + storage.alloc(max_size * node_size_); + + // allgather keys and grads + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + d_keys, storage.all_keys, max_size, ncclUint64, nccl_inter_comm, stream)); + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( + d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8, + nccl_inter_comm, stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + + int merge_num = 0; + for (int i = 0; i < node_size_; ++i) { + int index = i * max_size; + cudaMemcpyAsync(storage.local_keys + merge_num, storage.all_keys + index, + h_node_len[i], cudaMemcpyDefault, stream); + cudaMemcpyAsync(storage.local_grads + merge_num, storage.all_grads + index, + h_node_len[i], cudaMemcpyDefault, stream); + merge_num += h_node_len[i]; + } + + int ret = merge_num; + merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret); + return ret; +} + template void HeterComm::end_pass() { int total_gpu = resource_->total_gpu(); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index a9db1a5629..f2e129ded9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -54,7 +54,14 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } void HeterPs::push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads, size_t len) { - comm_->push_sparse(num, d_keys, d_grads, len, opt_); + // comm_->push_sparse(num, d_keys, d_grads, len, opt_); + comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); +} + +void HeterPs::set_nccl_comm_and_size(const std::vector& inner_comms, + const std::vector& inter_comms, + int comm_size) { + comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 74d24fe43e..142f4a93b9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -35,6 +35,9 @@ class HeterPs : public HeterPsBase { size_t len) override; virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) override; + virtual void set_nccl_comm_and_size( + const std::vector& inner_comms, + const std::vector& inter_comms, int comm_size) override; virtual void end_pass() override; virtual int get_index_by_devid(int devid) override; virtual void show_one_table(int gpu_num) override; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index 29c2f68fc4..7980220eab 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -35,6 +35,9 @@ class HeterPsBase { virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals, size_t len, size_t chunk_size, int stream_num) = 0; virtual int get_index_by_devid(int devid) = 0; + virtual void set_nccl_comm_and_size( + const std::vector& inner_comms, + const std::vector& inter_comms, int comm_size) = 0; virtual void end_pass() = 0; virtual void show_one_table(int gpu_num) = 0; virtual void push_sparse(int num, FeatureKey* d_keys, diff --git a/paddle/fluid/framework/fleet/heter_wrapper.cc b/paddle/fluid/framework/fleet/heter_wrapper.cc index a0667e9adb..a67f9a5e2c 100644 --- a/paddle/fluid/framework/fleet/heter_wrapper.cc +++ b/paddle/fluid/framework/fleet/heter_wrapper.cc @@ -28,6 +28,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_wrapper.h" #ifdef PADDLE_WITH_PSLIB +#include "paddle/fluid/framework/device_worker.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 728188e702..516f09a9ef 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -233,6 +233,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) { } std::vector threads(device_num); HeterPs_ = HeterPsBase::get_instance(size_max, resource_); + HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); auto build_func = [this, &gpu_task, &feature_keys_count](int i) { std::cout << "building table: " << i << std::endl; this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(), diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 8a536fe0b8..fd3323d9d4 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -27,6 +27,10 @@ limitations under the License. */ #include #include +#ifdef PADDLE_WITH_GLOO +#include +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h" @@ -34,6 +38,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/place.h" @@ -80,11 +85,48 @@ class PSGPUWrapper { void BuildTask(std::shared_ptr gpu_task, uint64_t table_id, int feature_dim); void InitializeGPU(const std::vector& dev_ids) { - if (s_instance_ != NULL) { + if (s_instance_ != NULL && is_initialized_ == false) { VLOG(3) << "PSGPUWrapper Begin InitializeGPU"; + is_initialized_ = true; resource_ = std::make_shared(dev_ids); resource_->enable_p2p(); keys_tensor.resize(resource_->total_gpu()); + if (multi_node_) { + int dev_size = dev_ids.size(); + // init inner comm + inner_comms_.resize(dev_size); + inter_ncclids_.resize(dev_size); + platform::dynload::ncclCommInitAll(&(inner_comms_[0]), dev_size, + &dev_ids[0]); +// init inter comm +#ifdef PADDLE_WITH_GLOO + inter_comms_.resize(dev_size); + auto gloo = paddle::framework::GlooWrapper::GetInstance(); + if (gloo->Rank() == 0) { + for (int i = 0; i < dev_size; ++i) { + platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]); + } + } + + PADDLE_ENFORCE_EQ( + gloo->IsInitialized(), true, + platform::errors::PreconditionNotMet( + "You must initialize the gloo environment first to use it.")); + gloo::BroadcastOptions opts(gloo->GetContext()); + opts.setOutput(&inter_ncclids_[0], dev_size); + opts.setRoot(0); + gloo::broadcast(opts); + + for (int i = 0; i < dev_size; ++i) { + platform::dynload::ncclCommInitRank(&inter_comms_[i], gloo->Size(), + inter_ncclids_[i], gloo->Rank()); + } + node_size_ = gloo->Size(); +#else + PADDLE_THROW( + platform::errors::Unavailable("heter ps need compile with GLOO")); +#endif + } heter_devices_ = dev_ids; } } @@ -177,6 +219,11 @@ class PSGPUWrapper { std::shared_ptr resource_; int32_t sleep_seconds_before_fail_exit_; std::vector slot_vector_; + int multi_node_{1}; + int node_size_; + std::vector inner_comms_; + std::vector inter_comms_; + std::vector inter_ncclids_; std::vector heter_devices_; std::unordered_set gpu_ps_config_keys_; HeterObjectPool gpu_task_pool_; diff --git a/paddle/fluid/framework/heterbox_trainer.cc b/paddle/fluid/framework/heterbox_trainer.cc index bdbcf9d1da..3e55576b84 100644 --- a/paddle/fluid/framework/heterbox_trainer.cc +++ b/paddle/fluid/framework/heterbox_trainer.cc @@ -12,6 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include +#include +#include "io/fs.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/trainer.h" #if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ (defined PADDLE_WITH_PSLIB) diff --git a/paddle/fluid/framework/heterbox_worker.cc b/paddle/fluid/framework/heterbox_worker.cc index 1d9b510ae9..726b651fcf 100644 --- a/paddle/fluid/framework/heterbox_worker.cc +++ b/paddle/fluid/framework/heterbox_worker.cc @@ -12,6 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/string/string_helper.h" + #if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ (defined PADDLE_WITH_PSLIB) #include "paddle/fluid/platform/cuda_device_guard.h" diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc index 2142c64de8..f50cc2769e 100644 --- a/paddle/fluid/framework/hetercpu_worker.cc +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -12,6 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/string/string_helper.h" + #ifdef PADDLE_WITH_PSLIB #if defined _WIN32 || defined __APPLE__ diff --git a/paddle/fluid/framework/heterxpu_trainer.cc b/paddle/fluid/framework/heterxpu_trainer.cc index e6f3572fc0..5e1fabf203 100644 --- a/paddle/fluid/framework/heterxpu_trainer.cc +++ b/paddle/fluid/framework/heterxpu_trainer.cc @@ -12,6 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include +#include +#include +#include "io/fs.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" +#include "paddle/fluid/framework/trainer.h" #if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \ (defined PADDLE_WITH_PSLIB) #ifdef PADDLE_WITH_CUDA diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index 6db2e65bcf..e3c417d4a6 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -599,6 +599,7 @@ class GeneralRoleMaker(RoleMakerBase): self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600) self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999) ip_port = kwargs.get("http_ip_port", "") + self._use_ps_gpu = kwargs.get("use_ps_gpu", False) self._http_ip_port = [] self._http_server = None # if ip_port is not empty, it will use http instead of hdfs @@ -666,6 +667,18 @@ class GeneralRoleMaker(RoleMakerBase): self._hdfs_name, self._hdfs_ugi) gloo.init() self._node_type_comm = gloo + if self._use_ps_gpu: + Gloo_strategy = fluid.core.GlooParallelStrategy() + Gloo_strategy.rank = current_id + Gloo_strategy.rank_num = len(worker_endpoints) + Gloo_strategy.ip_address = self._http_ip_port[0] + Gloo_strategy.ip_port = int(self._http_ip_port[1]) + Default_init_timeout_seconds = 3600 + Default_run_timeout_seconds = 9999999 + Gloo_strategy.init_seconds = Default_init_timeout_seconds + Gloo_strategy.run_seconds = Default_run_timeout_seconds + Gloo = fluid.core.GlooParallelContext(Gloo_strategy) + Gloo.init() else: self._all_comm = MockBarrier() elif training_role == "PSERVER": diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index ae4befa004..752ec0672c 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -386,3 +386,27 @@ class SingleProcessMultiThread(GradAllReduce): def _transpile_startup_program(self): block = self.startup_program.global_block() block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) + + +class MultiThread(GradAllReduce): + ''' + ''' + + def __init__(self, nrings=1): + GradAllReduce.__init__(self, nrings) + self.mode = "box" + + def _transpile_startup_program(self): + if len(self.endpoints) > 1: + print("begin to _transpile_startup_program for multi-node") + print("current_endpoint: ", self.current_endpoint) + print("total endpoints: ", self.endpoints) + print("rank: %d, ring_id: %d" % (self.rank, self.nrings)) + for ring_id in range(self.nrings): + self._init_communicator( + self.startup_program, self.current_endpoint, self.endpoints, + self.rank, ring_id, self.wait_port, True) + else: + print("begin to _transpile_startup_program for single-node") + block = self.startup_program.global_block() + block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) -- GitLab