diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index ae5c9504ecb6ee25a2b5fea19dab34588ec8fe82..739e05e1d79712a6551c4b97f5e034b0f93ee1b8 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -305,6 +305,7 @@ message DistributedStrategy { optional bool semi_auto = 35 [ default = false ]; optional bool adam_d2sum = 36 [ default = true ]; optional bool auto_search = 37 [ default = false ]; + optional bool heter_ccl_mode = 38 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 9121610d29eaa0f64ef3041caa6a9a2c89a1b038..594b0d48a8aad8f5fcd9f143249ce48b5e939ce1 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -30,6 +30,9 @@ if(NOT WIN32) cc_library(hccl_context SRCS hccl_context.cc DEPS collective_helper device_context tensor var_type_traits) cc_library(reducer SRCS reducer.cc DEPS layer) endif() + if(WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL) + cc_library(heter_ccl_context SRCS heter_ccl_context.cc DEPS collective_helper device_context tensor var_type_traits) + endif() cc_library(data_loader SRCS data_loader.cc DEPS enforce) endif(NOT WIN32) if(WITH_GLOO) diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index 8c6b840f60a591f82005f2eebc6c3e52f482591f..6569929d6f5d74447c626a56f1bcd2ff20e99b23 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -150,6 +150,23 @@ void BKCLParallelContext::AllReduceByStream(const framework::Variable &src, } } +void BKCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { + VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id; + framework::Tensor *src_tensor = src->GetMutable(); + const auto &place = src_tensor->place(); + platform::BKCLComm *comm = + platform::BKCLCommContext::Instance().Get(ring_id, place); + XPUStream stream = comm->stream(); + + void *src_ptr = src_tensor->data(); + auto data_type = platform::ToBKCLDataType(src_tensor->type()); + + PADDLE_ENFORCE_EQ(bkcl_broadcast(comm->comm(), src_ptr, src_ptr, + src_tensor->numel(), data_type, 0, stream), + BKCL_SUCCESS, + platform::errors::Unavailable("bkcl_broadcast failed")); +} + paddle::platform::DeviceContext *BKCLParallelContext::GetDeviceContext( int ring_id) { return static_cast( diff --git a/paddle/fluid/imperative/bkcl_context.h b/paddle/fluid/imperative/bkcl_context.h index 652b7689666c6c66c4efe6edda0c23acfc0cab27..a5a10b19389c0d6decdf0cf223dba49ae13220ec 100644 --- a/paddle/fluid/imperative/bkcl_context.h +++ b/paddle/fluid/imperative/bkcl_context.h @@ -42,6 +42,8 @@ class BKCLParallelContext : public ParallelContext { framework::Variable* dst, int ring_id, bool use_calc_stream) override; + void Broadcast(framework::Variable* src, int ring_id) override; + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; void WaitCompute(int ring_id) override; diff --git a/paddle/fluid/imperative/gloo_context.cc b/paddle/fluid/imperative/gloo_context.cc index ef1bf0d158787e517168fdd5f2dc60180ae05f12..1eaf0c6538043ff274b8a30f8618373deea771b0 100644 --- a/paddle/fluid/imperative/gloo_context.cc +++ b/paddle/fluid/imperative/gloo_context.cc @@ -37,7 +37,7 @@ void GLOOParallelContext::Init() { 
gloo_wrapper->SetSize(strategy_.nranks_); gloo_wrapper->SetRank(strategy_.local_rank_); gloo_wrapper->SetPrefix(""); - gloo_wrapper->SetIface("lo"); + gloo_wrapper->SetIface(""); auto addr = paddle::string::Split(strategy_.trainer_endpoints_[0], ':'); VLOG(4) << "Server is" << strategy_.trainer_endpoints_[0]; std::string host = addr[0]; @@ -176,6 +176,11 @@ void GLOOParallelContext::AllReduce(const framework::SelectedRows &src, } } +void GLOOParallelContext::Broadcast(framework::Variable *src, int ring_id) { + PADDLE_THROW(platform::errors::Unimplemented( + "Unimplemented inter-broadcast for CPU now.")); +} + paddle::platform::DeviceContext *GLOOParallelContext::GetDeviceContext( int ring_id) { // return the CPUDeviceContext diff --git a/paddle/fluid/imperative/gloo_context.h b/paddle/fluid/imperative/gloo_context.h index 305a75a881153fc1fa79b4f3e04adc086a21576e..e7c9ba4cfddb656c7a61df84c5bb59ac5be84eb7 100644 --- a/paddle/fluid/imperative/gloo_context.h +++ b/paddle/fluid/imperative/gloo_context.h @@ -47,6 +47,8 @@ class GLOOParallelContext : public ParallelContext { framework::Variable* dst, int ring_id, bool use_calc_stream) override; + void Broadcast(framework::Variable* src, int ring_id) override; + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; void WaitCompute(int ring_id) override; diff --git a/paddle/fluid/imperative/hccl_context.cc b/paddle/fluid/imperative/hccl_context.cc index 4f1135fa9ddd48c7a803ebdf434fb34f92cb23ca..55c52ae6c11de8659b21d797e8e4172ca77e5e8f 100644 --- a/paddle/fluid/imperative/hccl_context.cc +++ b/paddle/fluid/imperative/hccl_context.cc @@ -158,6 +158,29 @@ void HCCLParallelContext::AllReduceByStream(const framework::Variable &src, } } +void HCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { + VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id; + if (src->IsType()) { + framework::Tensor *src_tensor = src->GetMutable(); + const auto &place = src_tensor->place(); + platform::HCCLComm *comm = + platform::HCCLCommContext::Instance().Get(ring_id, place); + aclrtStream stream = comm->stream(); + + void *src_ptr = + reinterpret_cast(const_cast(src_tensor->data())); + auto hccl_dtype = platform::ToHCCLDataType(src_tensor->type()); + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast( + src_ptr, src_tensor->numel(), hccl_dtype, 0, comm->comm(), + reinterpret_cast(stream))); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported variable type %s for imperative allreduce, only " + "LoDTensor is supported.", + platform::demangle(framework::ToTypeName(src->Type())))); + } +} + paddle::platform::DeviceContext *HCCLParallelContext::GetDeviceContext( int ring_id) { return static_cast( diff --git a/paddle/fluid/imperative/hccl_context.h b/paddle/fluid/imperative/hccl_context.h index b7f22f3a0b0f160da497b5e01d81bf89d0315e6a..e5f58dea9fb06151bdc536f9025a6bf0e3b12092 100644 --- a/paddle/fluid/imperative/hccl_context.h +++ b/paddle/fluid/imperative/hccl_context.h @@ -50,6 +50,8 @@ class HCCLParallelContext : public ParallelContext { framework::Variable* dst, int ring_id, bool use_calc_stream) override; + void Broadcast(framework::Variable* src, int ring_id) override; + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; void WaitCompute(int ring_id) override; diff --git a/paddle/fluid/imperative/heter_ccl_context.cc b/paddle/fluid/imperative/heter_ccl_context.cc new file mode 100644 index 
0000000000000000000000000000000000000000..a62c1da7815979fb1be7cf2e36e2f14f86e20faf --- /dev/null +++ b/paddle/fluid/imperative/heter_ccl_context.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/heter_ccl_context.h" + +// NCCL first +#ifdef PADDLE_WITH_NCCL +#include "paddle/fluid/imperative/all_reduce.h" +#endif + +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace framework { +class Variable; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace imperative { + +HeterParallelContext::HeterParallelContext(const ParallelStrategy &strategy, + const int &device_id) +#ifdef PADDLE_WITH_NCCL + : ParallelContext(strategy, platform::CUDAPlace(device_id)) +#elif PADDLE_WITH_XPU_BKCL + : ParallelContext(strategy, platform::XPUPlace(device_id)) +#elif PADDLE_WITH_ASCEND_CL + : ParallelContext(strategy, platform::NPUPlace(device_id)) +#else + : ParallelContext(strategy, platform::CPUPlace()) +#endif +{ + // construct node_strategy_ from global strategy by selecting the + // endpoints with same ip address. 
+ std::string node_ip = strategy_.current_endpoint_.substr( + 0, strategy_.current_endpoint_.find(':')); + int node_nranks = 0; + int inter_rank = -1; + + std::vector all_eps = strategy_.trainer_endpoints_; + std::vector inter_endpoints; + std::set nodes_ips; + for (auto ep : all_eps) { + std::string ip = ep.substr(0, ep.find(':')); + // record ip of different nodes + if (nodes_ips.find(ip) == nodes_ips.end()) { + if (ep == strategy_.current_endpoint_) { + inter_rank = nodes_ips.size(); + } + inter_endpoints.push_back(ep); + nodes_ips.emplace(ip); + } + + if (ip == node_ip) { + if (ep == strategy_.current_endpoint_) { + node_strategy_.local_rank_ = node_nranks; + } + node_nranks++; + node_strategy_.trainer_endpoints_.push_back(ep); + } + } + + VLOG(0) << "init node size " << node_nranks << " rank " + << node_strategy_.local_rank_; + + PADDLE_ENFORCE_NE(node_nranks, 0, + platform::errors::InvalidArgument( + "The number of local nranks should not be zero.")); + node_strategy_.nranks_ = node_nranks; + node_strategy_.current_endpoint_ = strategy_.current_endpoint_; + + if (inter_rank >= 0 && inter_endpoints.size() > 1) { + inter_strategy_.nranks_ = inter_endpoints.size(); + inter_strategy_.local_rank_ = inter_rank; + inter_strategy_.current_endpoint_ = strategy_.current_endpoint_; + inter_strategy_.trainer_endpoints_ = inter_endpoints; + inter_parallel_ctx_ = std::make_shared( + inter_strategy_, platform::CPUPlace()); + } + + VLOG(0) << "init inter size " << inter_endpoints.size() << " rank " + << inter_rank; + +#ifdef PADDLE_WITH_NCCL + node_place_ = platform::CUDAPlace(device_id); + node_parallel_ctx_ = + std::make_shared(node_strategy_, node_place_); +#endif +#ifdef PADDLE_WITH_XPU_BKCL + node_place_ = platform::XPUPlace(device_id); + node_parallel_ctx_ = + std::make_shared(node_strategy_, node_place_); +#endif +#ifdef PADDLE_WITH_ASCEND_CL + node_place_ = platform::NPUPlace(device_id); + node_parallel_ctx_ = + std::make_shared(node_strategy_, node_place_); +#endif +} + +void HeterParallelContext::Init() { + PADDLE_ENFORCE_NE( + node_parallel_ctx_, nullptr, + platform::errors::Unavailable( + "The heter parallel context has not been initialized.")); + + if (inter_parallel_ctx_ != nullptr) { + inter_parallel_ctx_->Init(); + } + + node_parallel_ctx_->Init(); + + VLOG(3) << "/// DEBUG /// heter parallel env init done..." << std::endl; +} + +void HeterParallelContext::InitWithRingID(int ring_id) { + PADDLE_THROW(platform::errors::Unimplemented( + "Unimplemented InitWithRingID from heter ctx.")); +} + +void HeterParallelContext::AllReduceByStream(const framework::Variable &src, + framework::Variable *dst, + int ring_id, + bool use_calc_stream) { + // step 1: call reduce within node + VLOG(3) << "/// DEBUG /// step 1: reduce in node... 
"; + node_parallel_ctx_->AllReduceByStream(src, dst, ring_id, false); + node_parallel_ctx_->WaitComm(ring_id); + + // step 2: call allreduce between nodes with gloo + if (inter_parallel_ctx_ != nullptr) { + // copy src to cpu + // dst is now the src + auto src_tensor = dst->Get(); + framework::Variable src_cpu; + auto src_cpu_tensor = src_cpu.GetMutable(); + framework::TensorCopySync(src_tensor, platform::CPUPlace(), src_cpu_tensor); + + // allreduce src/cpu to dst/cpu + framework::Variable dst_cpu; + inter_parallel_ctx_->AllReduceByStream(src_cpu, &dst_cpu, ring_id, false); + inter_parallel_ctx_->WaitComm(ring_id); + + // copy dst/cpu to dst + auto dst_cpu_tensor = dst_cpu.Get(); + auto dst_tensor = dst->GetMutable(); + framework::TensorCopySync(dst_cpu_tensor, dst_tensor->place(), dst_tensor); + + inter_parallel_ctx_->WaitComm(ring_id); + } + + // step 3: call broadcast within node + VLOG(3) << "/// DEBUG /// step 3: broadcast within node... "; + node_parallel_ctx_->WaitComm(ring_id); + node_parallel_ctx_->Broadcast(dst, ring_id); + node_parallel_ctx_->WaitComm(ring_id); +} + +void HeterParallelContext::Broadcast(framework::Variable *src, int ring_id) { + PADDLE_THROW(platform::errors::Unimplemented("Unimplemented function.")); +} + +paddle::platform::DeviceContext *HeterParallelContext::GetDeviceContext( + int ring_id) { + // directly call the implementation of target parallel ctx. + return node_parallel_ctx_->GetDeviceContext(ring_id); +} + +void HeterParallelContext::WaitCompute(int ring_id) { + // directly call the implementation of target parallel ctx. + node_parallel_ctx_->WaitCompute(ring_id); +} + +void HeterParallelContext::WaitComm(int ring_id) { + // directly call the implementation of target parallel ctx. + node_parallel_ctx_->WaitComm(ring_id); +} + +void HeterParallelContext::SynchronizeCompute() { + // directly call the implementation of target parallel ctx. + node_parallel_ctx_->SynchronizeCompute(); +} + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/heter_ccl_context.h b/paddle/fluid/imperative/heter_ccl_context.h new file mode 100644 index 0000000000000000000000000000000000000000..8ea5e85603ab5fae446e90eeb44005e7e4c71fc3 --- /dev/null +++ b/paddle/fluid/imperative/heter_ccl_context.h @@ -0,0 +1,78 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include + +#ifdef PADDLE_WITH_NCCL +#include "paddle/fluid/imperative/nccl_context.h" +#endif + +#ifdef PADDLE_WITH_XPU_BKCL +#include "paddle/fluid/imperative/bkcl_context.h" +#endif + +#ifdef PADDLE_WITH_ASCEND_CL +#include "paddle/fluid/imperative/hccl_context.h" +#endif + +#include "paddle/fluid/imperative/gloo_context.h" +#include "paddle/fluid/imperative/parallel_context.h" + +namespace paddle { +namespace framework { +class Variable; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace imperative { + +class HeterParallelContext : public ParallelContext { + public: + explicit HeterParallelContext(const ParallelStrategy& strategy, + const int& device_id); + + ~HeterParallelContext() override = default; + + void Init() override; + + void InitWithRingID(int ring_id) override; + + void AllReduceByStream(const framework::Variable& src, + framework::Variable* dst, int ring_id, + bool use_calc_stream) override; + + void Broadcast(framework::Variable* src, int ring_id) override; + + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; + + void WaitCompute(int ring_id) override; + + void WaitComm(int ring_id) override; + + void SynchronizeCompute() override; + + private: + ParallelStrategy inter_strategy_; + ParallelStrategy node_strategy_; + platform::Place node_place_; + std::shared_ptr node_parallel_ctx_{nullptr}; + std::shared_ptr inter_parallel_ctx_{nullptr}; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 0eb06983f409b1924f87141aa1e19bda35ca7558..f822894b42b0b58b83b1295ed589a2edfec77b71 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -20,7 +20,15 @@ #include "paddle/fluid/platform/gen_comm_id_helper.h" #endif +#ifdef PADDLE_WITH_NCCL +#include +#include "paddle/fluid/platform/dynload/nccl.h" +#endif + +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -127,6 +135,20 @@ void NCCLParallelContext::AllReduceByStream(const framework::Variable &src, AllReduce(src, dst, strategy_, ring_id, use_calc_stream); } +void NCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { + VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id; + framework::Tensor *src_tensor = src->GetMutable(); + const auto &place = src_tensor->place(); + platform::NCCLComm *comm = + platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = comm->stream(); + + void *src_ptr = src_tensor->data(); + auto nccl_dtype = platform::ToNCCLDataType(src_tensor->type()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( + src_ptr, src_tensor->numel(), nccl_dtype, 0, comm->comm(), stream)); +} + paddle::platform::DeviceContext *NCCLParallelContext::GetDeviceContext( int ring_id) { return static_cast( diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index 1938fa08312f61c83690986d3afb98c125855123..bb5b8ea32df4f46f7c093cb930b0ba6dd6f1846d 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -60,6 +60,8 @@ class NCCLParallelContext : public ParallelContext { framework::Variable* dst, int ring_id, bool use_calc_stream) override; + void 
Broadcast(framework::Variable* src, int ring_id) override; + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; void WaitCompute(int ring_id) override; diff --git a/paddle/fluid/imperative/parallel_context.h b/paddle/fluid/imperative/parallel_context.h index f537a316014d60ed18250d72de3ec2b7dd95cf05..8bdfccc144243687da3bcda8c53af9470a5021d2 100644 --- a/paddle/fluid/imperative/parallel_context.h +++ b/paddle/fluid/imperative/parallel_context.h @@ -56,6 +56,8 @@ class ParallelContext { framework::Variable* dst, int ring_id, bool use_calc_stream) = 0; + virtual void Broadcast(framework::Variable* src, int ring_id) = 0; + virtual paddle::platform::DeviceContext* GetDeviceContext(int ring_id) = 0; // comm_stream[ring_id] wait compute_stream. diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 2f023f644fd060d6f72002f58fe3b17f9fc469a1..068de4f0435bbec8fb83aa9ee8b0cefdd71be06b 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -27,8 +27,9 @@ namespace paddle { namespace imperative { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \ + defined(PADDLE_WITH_ASCEND_CL) // div the nranks void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) { framework::Tensor *tensor = @@ -41,6 +42,9 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) DivNRanks(tensor, nranks, context); #endif + } else if (platform::is_npu_place(tensor->place())) { + // TODO(kuizhiqing) + VLOG(4) << "divnrank for npu not support yet"; } else if (platform::is_cpu_place(tensor->place())) { VLOG(4) << "before div 2" << *tensor; VLOG(4) << "NDiv for cpu devices : rank = " << nranks; @@ -207,6 +211,70 @@ void SplitTensorsWithType( } #endif +// NOTE(liubo48): Only implement operators::math::SplitFunctor for npu now. +// If later the operators::StridedMemcpyWithAxis0 is supported, +// then this specific SplitTensorsForAllReduce can be removed. 
+#ifdef PADDLE_WITH_ASCEND_CL +template <> +void SplitTensorsForAllReduce( + const platform::NPUDeviceContext &context, + framework::Variable *p_dense_contents, + std::vector *p_dense_tensors) { + auto *in = p_dense_contents->GetMutable(); + std::vector outs; + std::vector shape_refer; + + outs.reserve(p_dense_tensors->size()); + shape_refer.reserve(p_dense_tensors->size()); + + for (auto &tensor : *p_dense_tensors) { + outs.emplace_back(&tensor); + shape_refer.emplace_back(&tensor); + } + operators::math::SplitFunctor + split_functor_; + split_functor_(context, *in, shape_refer, 0, &outs); +} + +template <> +void ConcatTensorsWithType( + const platform::NPUDeviceContext &context, + const std::vector &dense_tensors_, + framework::Variable *p_dense_contents, + framework::proto::VarType::Type type) { + switch (type) { + case framework::proto::VarType::FP32: + ConcatTensorsForAllReduce( + context, dense_tensors_, p_dense_contents); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors for " + "allreduce.", + framework::DataTypeToString(type))); + } +} + +template <> +void SplitTensorsWithType( + const platform::NPUDeviceContext &context, + framework::Variable *p_dense_contents, + std::vector *p_dense_tensors, + framework::proto::VarType::Type type) { + switch (type) { + case framework::proto::VarType::FP32: + SplitTensorsForAllReduce( + context, p_dense_contents, p_dense_tensors); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it splits tensors for " + "allreduce.", + framework::DataTypeToString(type))); + } +} +#endif + void Group::ConcatTensors(const platform::DeviceContext &context) { auto place = context.GetPlace(); if (platform::is_gpu_place(place)) { @@ -831,7 +899,7 @@ void Reducer::MarkGroupReady(size_t group_index) { } }); #elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \ - defined(PADDLE_WITH_GLOO) + defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) FusedAllReduceSchedule(run_order, group, next_group_); #else PADDLE_THROW(platform::errors::PreconditionNotMet( @@ -1014,7 +1082,7 @@ void Reducer::FinalizeBackward() { if (find_unused_vars_each_step_) { // TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_GLOO) + defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) ProcessUnusedDenseVars(); #endif // Initialize local used vars diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index b5a7dd149f09fefeac7c8a80e6d541534573a3bf..3c03babc52cbe1b826c520dd7b989e3756fc2a2e 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -48,8 +48,9 @@ class VariableWrapper; namespace paddle { namespace imperative { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \ + defined(PADDLE_WITH_ASCEND_CL) template struct DivNRanksFunctor { diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index adb560df77c78f208cbd7ec66400bbcd9b23f1c8..01a24872fbd7c23e97d3aba1c3e646564a02b6ed 100644 --- a/paddle/fluid/imperative/tests/CMakeLists.txt +++ b/paddle/fluid/imperative/tests/CMakeLists.txt @@ -3,6 +3,8 @@ if(WIN32) else() if 
(WITH_NCCL OR WITH_RCCL) cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context) + cc_test(heter_ccl_context_test SRCS heter_ccl_context_test.cc DEPS heter_ccl_context nccl_context imperative_gloo_context gloo_context gloo_wrapper gloo fs shell) + #set_tests_properties(heter_ccl_context_test PROPERTIES LABELS "RUN_TYPE=DIST") endif() if (WITH_XPU_BKCL) cc_test(bkcl_context_test SRCS bkcl_context_test.cc DEPS bkcl_context) diff --git a/paddle/fluid/imperative/tests/heter_ccl_context_test.cc b/paddle/fluid/imperative/tests/heter_ccl_context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c40a5fc52ceb86bd516a21f6ca4443bde42c08bc --- /dev/null +++ b/paddle/fluid/imperative/tests/heter_ccl_context_test.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include // NOLINT + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/imperative/heter_ccl_context.h" + +#include "gtest/gtest.h" + +namespace imperative = paddle::imperative; +namespace platform = paddle::platform; +namespace framework = paddle::framework; + +imperative::ParallelStrategy GetStrategy(int local_rank) { + std::vector eps = {"127.0.0.1:37580", "127.0.0.1:37581"}; + imperative::ParallelStrategy strategy; + strategy.trainer_endpoints_ = eps; + strategy.current_endpoint_ = eps[local_rank]; + strategy.nranks_ = eps.size(); + strategy.local_rank_ = local_rank; + return strategy; +} + +#ifdef PADDLE_WITH_NCCL +void AllReduceByStream(int local_rank, int device_id) { + int data_size = 32; + const auto& place = platform::CUDAPlace(device_id); + platform::CUDADeviceContext ctx(place); + + // heter_parallel_ctx + imperative::HeterParallelContext hpc(GetStrategy(local_rank), device_id); + + // init + hpc.Init(); + + // input and output data + framework::Variable* src_dev_var(new framework::Variable()); + auto* src_dev_tensor = src_dev_var->GetMutable(); + src_dev_tensor->mutable_data(framework::make_ddim({data_size}), place); + + std::vector src_vec; + for (int i = 0; i < data_size; i++) { + src_vec.push_back(1.0 + local_rank); + } + framework::TensorFromVector(src_vec, ctx, src_dev_tensor); + ctx.Wait(); + + framework::Variable* dst_dev_var(new framework::Variable()); + auto* dst_dev_tensor = dst_dev_var->GetMutable(); + dst_dev_tensor->mutable_data(framework::make_ddim({data_size}), place); + + // call allreduce + hpc.AllReduceByStream(*src_dev_var, dst_dev_var, 0, false); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // check result + std::vector dst_vec; + framework::TensorToVector(*dst_dev_tensor, ctx, &dst_vec); + ctx.Wait(); + + EXPECT_EQ(dst_vec.size(), src_vec.size()); + for (int i = 0; i < data_size; i++) { + EXPECT_EQ(dst_vec[i], 3.0); + } +} + +TEST(AllReduceByStream, Run) { + if (platform::GetCUDADeviceCount() >= 2) { + std::thread t0(AllReduceByStream, 0, 0); + std::thread t1(AllReduceByStream, 
1, 1); + t0.join(); + t1.join(); + } +} +#endif diff --git a/paddle/fluid/imperative/tests/nccl_context_test.cc b/paddle/fluid/imperative/tests/nccl_context_test.cc index 2d8a08217b0b83cfc22c250551e9aa81e01e86c0..b56444104f2779d7f56e2c945ac79063f6aac275 100644 --- a/paddle/fluid/imperative/tests/nccl_context_test.cc +++ b/paddle/fluid/imperative/tests/nccl_context_test.cc @@ -14,6 +14,8 @@ #include // NOLINT +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/nccl_context.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" @@ -21,6 +23,7 @@ namespace imperative = paddle::imperative; namespace platform = paddle::platform; +namespace framework = paddle::framework; int nrings = 2; imperative::ParallelStrategy GetStrategy(int local_rank) { @@ -68,4 +71,51 @@ TEST(BcastNCCLId, Run) { NCCL_UNIQUE_ID_BYTES)); } } + +void Broadcast(int local_rank, int device_id) { + int data_size = 4; + float test_data = 7; + const auto& place = platform::CUDAPlace(device_id); + platform::CUDADeviceContext ctx(place); + + imperative::NCCLParallelContext npc(GetStrategy(local_rank), place); + + // init + npc.Init(); + + framework::Variable* src_dev_var(new framework::Variable()); + auto* src_dev_tensor = src_dev_var->GetMutable(); + src_dev_tensor->mutable_data(framework::make_ddim({data_size}), place); + + // fill data for rank 0 only + std::vector src_vec; + if (local_rank == 0) { + for (int i = 0; i < data_size; i++) { + src_vec.push_back(test_data); + } + framework::TensorFromVector(src_vec, ctx, src_dev_tensor); + } + ctx.Wait(); + + npc.Broadcast(src_dev_var, 0); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // check result + std::vector dst_vec; + framework::TensorToVector(*src_dev_tensor, ctx, &dst_vec); + ctx.Wait(); + + for (int i = 0; i < data_size; i++) { + EXPECT_EQ(dst_vec[i], test_data); + } +} + +TEST(Broadcast, Run) { + if (platform::GetCUDADeviceCount() >= 2) { + std::thread t0(Broadcast, 0, 0); + std::thread t1(Broadcast, 1, 1); + t0.join(); + t1.join(); + } +} #endif diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 41708ef8611e42571ec4a9a932042185b9692425..521ca032a50ddb6703608643f8102fab073fd7a2 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -25,6 +25,13 @@ endif() if (WITH_XPU_BKCL) set(PYBIND_DEPS ${PYBIND_DEPS} reducer) set(PYBIND_DEPS ${PYBIND_DEPS} bkcl_context) + set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context) +endif() + +if (WITH_ASCEND_CL) + set(PYBIND_DEPS ${PYBIND_DEPS} reducer) + set(PYBIND_DEPS ${PYBIND_DEPS} hccl_context) + set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context) endif() if(NOT WIN32) @@ -32,9 +39,7 @@ if(NOT WIN32) set(PYBIND_DEPS ${PYBIND_DEPS} mmap_allocator) if (WITH_NCCL OR WITH_RCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context) - endif() - if (WITH_ASCEND_CL) - set(PYBIND_DEPS ${PYBIND_DEPS} hccl_context) + set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context) endif() endif(NOT WIN32) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 29a1f0eafcb219bee0a1fd54dc0d19a417d3ca82..2c850f0ca84d5f4a79f023646e9370e7a382a160 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -37,6 +37,7 @@ limitations under the License. 
*/ #include "paddle/fluid/imperative/data_loader.h" #include "paddle/fluid/imperative/gloo_context.h" #include "paddle/fluid/imperative/hccl_context.h" +#include "paddle/fluid/imperative/heter_ccl_context.h" #include "paddle/fluid/imperative/hooks.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/nccl_context.h" @@ -2332,6 +2333,15 @@ void BindImperative(py::module *m_ptr) { py::arg("ring_id")); #endif +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) + py::class_>( + m, "HeterParallelContext") + .def(py::init()) + .def("init", [](imperative::HeterParallelContext &self) { self.Init(); }); +#endif + m.def("pylayer_apply", [](const platform::CPUPlace &place, const py::object &cls, const py::args args, const py::kwargs kwargs) { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index cc0a5de233c382a20266866146f4c85050e921c5..e58b6c312fa1fe28b1565cc4421177c2c139cf21 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -1758,6 +1758,37 @@ class DistributedStrategy(object): else: print("WARNING: auto-search should have value of bool type") + @property + def heter_ccl_mode(self): + """ + Indicating whether we are using heter_ccl_mode for model training. + This feature is currently an experimental feature. Currently, + heter_ccl_mode can be used only for dataparallel with dygraph mode. + Default Value: False + + Examples: + + .. code-block:: python + + import paddle + import paddle.distributed.fleet as fleet + + strategy = fleet.DistributedStrategy() + strategy.heter_ccl_mode = True + + # for initialize parallel env, only need to call + paddle.distributed.init_parallel_env() + # then the heterogenous context will be created. + """ + return self.strategy.heter_ccl_mode + + @heter_ccl_mode.setter + def heter_ccl_mode(self, flag): + if isinstance(flag, bool): + self.strategy.heter_ccl_mode = flag + else: + print("WARNING: heter_ccl_mode should have value of bool type") + @property def cudnn_exhaustive_search(self): """ diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index a1e5ef2ba799fce61dada7b280ecbcdbcd4a13ca..0d54a0ea5d3b1620522e097615c3adfd5f94d121 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -33,7 +33,7 @@ from . import topology as tp from .topology import ParallelMode from ..meta_parallel import TensorParallel, model_parallel_random_seed from ..meta_parallel import PipelineParallel, ShardingParallel -from ..meta_optimizers import HybridParallelOptimizer +from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer from paddle import _C_ops from paddle.fluid import core from paddle.fluid.dygraph import to_variable @@ -277,13 +277,15 @@ class Fleet(object): self._user_defined_strategy.nccl_comm_num) paddle.distributed.init_parallel_env() - # init hybrid parallel environment in dygraph - if tp._HYBRID_PARALLEL_GROUP is None: - self._init_hybrid_parallel_env() - else: - warnings.warn( - "The dygraph hybrid parallel environment has been initialized." 
- ) + # hybrid parallel not support for npu/xpu + if self._user_defined_strategy.heter_ccl_mode == False: + # init hybrid parallel environment in dygraph + if tp._HYBRID_PARALLEL_GROUP is None: + self._init_hybrid_parallel_env() + else: + warnings.warn( + "The dygraph hybrid parallel environment has been initialized." + ) elif self._is_collective: use_sharding = self._user_defined_strategy.sharding @@ -872,8 +874,12 @@ class Fleet(object): if paddle.fluid.framework.in_dygraph_mode(): if self.worker_num() > 1: - return HybridParallelOptimizer(optimizer, self._hcg, - self._user_defined_strategy) + if self._user_defined_strategy.heter_ccl_mode == False: + return HybridParallelOptimizer(optimizer, self._hcg, + self._user_defined_strategy) + else: + return HeterParallelOptimizer(optimizer, + self._user_defined_strategy) else: return optimizer return self @@ -938,6 +944,17 @@ class Fleet(object): if self.worker_num() <= 1: return model + if self._user_defined_strategy.heter_ccl_mode == True: + distributed_model = paddle.DataParallel( + model, + comm_buffer_size=self._user_defined_strategy. + fuse_grad_size_in_MB, + last_comm_buffer_size=self._user_defined_strategy. + last_comm_group_size_MB, + find_unused_parameters=self._user_defined_strategy. + find_unused_parameters) + return distributed_model + if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL: distributed_model = ShardingParallel( model, self._hcg, strategy=self._user_defined_strategy) @@ -1569,13 +1586,13 @@ class Fleet(object): ] param_grads_fp16 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and (param._grad_ivar( - ).dtype == core.VarDesc.VarType.FP16) + if (param._grad_ivar() is not None) and + (param._grad_ivar().dtype == core.VarDesc.VarType.FP16) ] param_grads_fp32 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and (param._grad_ivar( - ).dtype == core.VarDesc.VarType.FP32) + if (param._grad_ivar() is not None) and + (param._grad_ivar().dtype == core.VarDesc.VarType.FP32) ] temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) diff --git a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py index 52eeebd0c126c241e2f0d961d6bc9138607c5181..322989099c856d6bfe82cbe4ae237f84ebe5e421 100755 --- a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py +++ b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py @@ -19,9 +19,10 @@ __all__ = [] meta_optimizer_names = list( filter(lambda name: name.endswith("Optimizer"), dir())) -# Because HybridParallelOptimizer is dygraph optimizer, it +# Because HybridParallelOptimizer is dygraph optimizer, it # should be removed meta_optimizer_names.remove("HybridParallelOptimizer") +meta_optimizer_names.remove("HeterParallelOptimizer") class MetaOptimizerFactory(object): diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index 0aae3331793ca7ebc4ecd16805b828fa542775f3..708ba2816077e1f84d1f0df990e21c0fb0137adf 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -108,9 +108,9 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra base_group.add_argument( "--backend", type=str, - default="auto", - help="Specifize the backend, can be gloo|nccl|bkcl|auto. 
Default value is auto which perfers nccl or bkcl." - ) + default=os.environ.get('PADDLE_DISTRI_BACKEND', 'auto'), + help="Specify the backend, can be gloo|nccl|bkcl|auto|hccl|heter. " + "Default value is auto which prefers nccl or bkcl.") base_group.add_argument( "--nproc_per_node", type=int, @@ -146,6 +146,16 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ) base_group.add_argument("--selected_xpus", dest="xpus") + if fluid.core.is_compiled_with_npu(): + base_group.add_argument( + "--npus", + type=str, + default=None, + help="It's for npu training. For example: " + "--npus=\"0,1,2,3\" will launch four training processes each bound to one npu." + ) + base_group.add_argument("--selected_npus", dest="npus") + base_group.add_argument( "training_script", type=str, @@ -301,25 +311,23 @@ def get_cluster_info(args): # lazy launch for auto-parallel if args.enable_auto_mapping == True: cluster, pod = get_mapped_cluster_from_args(args, device_mode) - else: + elif cloud_utils.use_paddlecloud() and trainers_num != 1: + cluster, pod = cloud_utils.get_cloud_cluster( + args.ips, device_mode, devices_per_proc, start_port) + logger.debug("get cluster from cloud:{}".format(cluster)) + elif device_mode == DeviceMode.ASCEND_NPU: # for ascend - if device_mode == DeviceMode.ASCEND_NPU: - cluster, pod = ascend_utils.get_cloud_cluster( - rank_table_file=os.getenv("RANK_TABLE_FILE", None), - device_mode=device_mode, - start_port=start_port) - elif cloud_utils.use_paddlecloud() and trainers_num != 1: - cluster, pod = cloud_utils.get_cloud_cluster( - args.ips, device_mode, devices_per_proc, start_port) - logger.debug("get cluster from cloud:{}".format(cluster)) - else: - # trainers_num = 1 or not use paddlecloud ips="a,b" - cluster, pod = get_cluster_from_args(args, device_mode, - devices_per_proc) - logger.debug("get cluster from args:{}".format(cluster)) + cluster, pod = ascend_utils.get_cloud_cluster( + rank_table_file=os.getenv("RANK_TABLE_FILE", None), + device_mode=device_mode, + start_port=start_port) + else: + # trainers_num = 1 or not use paddlecloud ips="a,b" + cluster, pod = get_cluster_from_args(args, device_mode, + devices_per_proc) + logger.debug("get cluster from args:{}".format(cluster)) return cluster, pod - def get_global_envs(args, tmp_dir): global_envs = copy.copy(os.environ.copy()) # add gloo env @@ -456,15 +464,15 @@ def which_distributed_mode(args): ) and not fluid.core.is_compiled_with_xpu(): if args.servers: logger.warning( - "Not found distinct arguments and not compiled with cuda or xpu. \ -But found args.servers not empty, default use ps mode") + "Not found distinct arguments and not compiled with cuda or xpu or npu. " + "But found args.servers not empty, default use ps mode") return DistributeMode.PS else: return DistributeMode.COLLECTIVE else: logger.warning( - "Not found distinct arguments and compiled with cuda or xpu. Default use collective mode" - ) + "Not found distinct arguments and compiled with cuda or xpu or npu. 
" + "Default use collective mode") return DistributeMode.COLLECTIVE @@ -651,7 +659,7 @@ def launch(): check_backend(args.backend) distribute_mode = DistributeMode.COLLECTIVE - assert args.backend in ['gloo', 'nccl', 'bkcl', 'unknown'] + #assert args.backend in ['gloo', 'nccl', 'bkcl', 'heter', 'unknown'] if args.backend == 'gloo': logger.warning("launch start with CPUONLY mode") diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index d87bdb47932ef16f0c6d47d66a7900c275631014..569f64c18bf52f31215f0d26aa56d997599ffb05 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -690,9 +690,51 @@ def get_xpus(xpus): return res_xpus +def get_npus(npus): + if npus is None: + npus_num = fluid.core.get_npu_device_count() + res_npus = [str(x) for x in range(0, npus_num)] + else: + npu_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") + if npu_visible_devices is None or npu_visible_devices == "": + res_npus = [x.strip() for x in npus.split(',')] + else: + # change npus into relative values + # e.g. ASCEND_VISIBLE_DEVICES=4,5,6,7; args.npus=4,5,6,7; + # therefore npus=0,1,2,3 + npu_visible_devices_list = npu_visible_devices.split(',') + for x in npus.split(','): + assert x in npu_visible_devices_list, "Can't find "\ + "your npus %s in ASCEND_VISIBLE_DEVICES[%s]."\ + % (x, npu_visible_devices) + res_npus = [ + npu_visible_devices_list.index(x.strip()) + for x in npus.split(',') + ] + logger.info("Change selected_npus into reletive values. --ips:{} " + "will change into relative_ips:{} according to your " + "ASCEND_VISIBLE_DEVICES:{}".format( + npus, res_npus, npu_visible_devices_list)) + + return res_npus + + def get_device_mode(backend): - if fluid.core.is_compiled_with_npu() and \ + if backend == 'heter': + if fluid.core.is_compiled_with_cuda() and \ + fluid.core.get_cuda_device_count() > 0: + print("launch train in heter mode with GPU device.") + return DeviceMode.GPU + if fluid.core.is_compiled_with_xpu() and \ + fluid.core.get_xpu_device_count() > 0: + print("launch train in heter mode with XPU device.") + return DeviceMode.XPU + if fluid.core.is_compiled_with_npu() and \ fluid.core.get_npu_device_count() > 0: + print("launch train in heter mode with NPU device.") + return DeviceMode.ASCEND_NPU + + if backend == 'hccl' and fluid.core.get_npu_device_count() > 0: print("launch train in ascend npu mode!") return DeviceMode.ASCEND_NPU @@ -731,7 +773,17 @@ def get_device_proc_info(args): else: devices_per_proc = gpus elif device_mode == DeviceMode.ASCEND_NPU: - devices_per_proc = None + npus = get_npus(args.npus) + if args.nproc_per_node is not None: + assert (len(npus) % int(args.nproc_per_node)) ==0, \ + "npus' number:{} mod args.nproc_per_node:{} must == 0".format(len(npus), args.nproc_per_node) + + n = int(len(npus) / int(args.nproc_per_node)) + devices_per_proc = [ + npus[i:i + n] for i in six.moves.range(0, len(npus), n) + ] + else: + devices_per_proc = npus elif device_mode == DeviceMode.XPU: xpus = get_xpus(args.xpus) if args.nproc_per_node is not None: @@ -902,11 +954,8 @@ def get_mapped_cluster_from_args(args, device_mode): node_rank = node_ips.index(ip) if os.environ.get('FLAGS_START_PORT') is not None: start_port = int(os.environ.get('FLAGS_START_PORT')) - free_ports = [ - x - for x in range(start_port, start_port + len(node_ranks_mapping[ - node_rank])) - ] + end_port = start_port + len(node_ranks_mapping[node_rank]) + free_ports = [x for x in range(start_port, 
end_port)] else: free_ports = find_free_ports(len(node_ranks_mapping[node_rank])) trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) @@ -1527,11 +1576,11 @@ class ParameterServerLauncher(object): def check_backend(backend): - if backend not in ['nccl', 'gloo', 'bkcl', 'auto']: - raise ValueError( - "paddle.distributed initialize error, " - "backend argument can only be one of 'nccl', 'gloo', 'bkcl', 'auto', but got %s" - % backend) + if backend not in ['nccl', 'gloo', 'bkcl', 'auto', 'hccl', 'heter']: + raise ValueError("paddle.distributed initialize error, " + "backend argument can only be one of " + "'nccl', 'gloo', 'bkcl', 'auto', 'hccl', 'heter' " + "but got %s" % backend) if backend == 'nccl' and not fluid.core.is_compiled_with_cuda(): raise ValueError( @@ -1545,6 +1594,12 @@ def check_backend(backend): "your paddle is not compiled with xpu but you assign 'bkcl' as backend." ) + if backend == 'hccl' and not fluid.core.is_compiled_with_npu(): + raise ValueError( + "paddle.distributed initialize error, " + "your paddle is not compiled with npu but you assign 'hccl' as backend." + ) + def block_windows_and_macos(backend): if backend != 'gloo': return @@ -1565,4 +1620,7 @@ def get_backend_by_compile_flag(): if fluid.core.is_compiled_with_xpu(): return 'bkcl' + if fluid.core.is_compiled_with_npu(): + return 'hccl' + return 'gloo' diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 739de0de57725fcb5d830e8880ea09458dc01f8d..13496ad8ee5d96da3fe67b79e0178b4f084a49ed 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -28,6 +28,7 @@ from .lamb_optimizer import LambOptimizer from .fp16_allreduce_optimizer import FP16AllReduceOptimizer from .sharding_optimizer import ShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer +from .dygraph_optimizer import HeterParallelOptimizer from .dygraph_optimizer import HybridParallelGradScaler from .tensor_parallel_optimizer import TensorParallelOptimizer from .raw_program_optimizer import RawProgramOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py index 28260d7aa186353dc27badc446c5650cd62b8b5a..3beb8635ba41a765a1dd567ba7950d73d7829cfa 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py @@ -13,5 +13,6 @@ from .hybrid_parallel_optimizer import HybridParallelOptimizer from .hybrid_parallel_gradscaler import HybridParallelGradScaler from .dygraph_sharding_optimizer import DygraphShardingOptimizer +from .heter_parallel_optimizer import HeterParallelOptimizer __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/heter_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/heter_parallel_optimizer.py new file mode 100755 index 0000000000000000000000000000000000000000..9218024be17203e0082d840d818170c94cad22e0 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/heter_parallel_optimizer.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.dygraph import base as imperative_base +from paddle.fluid import framework + +__all__ = [] + + +def _obtain_optimizer_parameters_list(optimizer): + if getattr(optimizer, '_param_groups', None) and isinstance( + optimizer._param_groups[0], dict): + parameters_list = [] + for group in optimizer._param_groups: + for param in group['params']: + parameters_list.append(param) + else: + parameters_list = [param for param in optimizer._parameter_list] + + return parameters_list + + +class HeterParallelOptimizer: + # adapter wrapper for optimizer + def __init__(self, optimizer, strategy): + self._inner_opt = optimizer + self._strategy = strategy + + # NOTE(liubo48): In pure DataParallel mode, + # the gradient synchronization is achieved through reducer. + + @imperative_base.no_grad + @framework.dygraph_only + def step(self): + parameters_list = _obtain_optimizer_parameters_list(self._inner_opt) + self._inner_opt.step() + + @imperative_base.no_grad + def minimize(self, + loss, + startup_program=None, + parameters=None, + no_grad_set=None): + + # minimize does not support parameters in the form of param_group, + # so no need use _obtain_optimizer_parameters_list + parameter_list = parameters if parameters \ + else self._inner_opt._parameter_list + + return self._inner_opt.minimize(loss, startup_program, parameter_list, + no_grad_set) + + def __getattr__(self, item): + return getattr(self._inner_opt, item) diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 7ea479f0fbb14dd1a6e7d16f44e6d5ce6066edf6..177e19194a52277d99c7f3b7904bb00a3962f4c6 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -58,7 +58,7 @@ def _start_kv_server(port, http_server_d, size): def _is_cpuonly(backend): check_backend(backend) - if backend in ['auto', 'nccl', 'bkcl', 'hccl'] and ( + if backend in ['auto', 'nccl', 'bkcl', 'hccl', 'heter'] and ( core.is_compiled_with_cuda() or core.is_compiled_with_xpu() or core.is_compiled_with_npu()): @@ -68,6 +68,14 @@ def _is_cpuonly(backend): return True +def _check_var_exists(var_name): + var = os.environ.get(var_name, None) + if var is None: + raise ValueError("paddle.distributed initialize error, " + "environment variable %s is needed, but not set." % + var_name) + + def init_parallel_env(): """ Initialize parallel training environment in dynamic graph mode. @@ -148,27 +156,22 @@ def init_parallel_env(): raise NotImplementedError( "If you want to use CPU-only version, please use 'gloo' as backend") - # 2. check env - def _check_var_exists(var_name): - var = os.environ.get(var_name, None) - if var is None: - raise ValueError("paddle.distributed initialize error, " - "environment variable %s is needed, but not set." 
% - var_name) - if not is_cpu_only and core.is_compiled_with_cuda(): _check_var_exists("FLAGS_selected_gpus") elif not is_cpu_only and core.is_compiled_with_xpu(): _check_var_exists('FLAGS_selected_xpus') + elif not is_cpu_only and core.is_compiled_with_npu(): + _check_var_exists('FLAGS_selected_npus') _check_var_exists("PADDLE_TRAINER_ID") _check_var_exists("PADDLE_CURRENT_ENDPOINT") _check_var_exists("PADDLE_TRAINERS_NUM") _check_var_exists("PADDLE_TRAINER_ENDPOINTS") + node_num = set([i.split(":")[0] for i in parallel_env.trainer_endpoints]) # 3: init gloo context (step 1: httpsever start) init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0")) - if is_cpu_only or init_gloo: + if is_cpu_only or init_gloo or backend == "heter": ep_rank_0 = parallel_env.trainer_endpoints[0].split(":") manager = Manager() # glboal dict to store status @@ -177,6 +180,8 @@ def init_parallel_env(): if parallel_env.rank == 0: # The scope for worker used by http server is '_worker' size = {'_worker': parallel_env.world_size} + if backend == "heter": + size = {'_worker': len(node_num)} http_server = Process( target=_start_kv_server, args=(int(ep_rank_0[1]), http_server_d, size)) @@ -210,10 +215,13 @@ def init_parallel_env(): place = core.NPUPlace(parallel_env.device_id) _set_expected_place(place) - # init nccl or bkcl context + # init nccl or hccl or bkcl or heter context if is_cpu_only: parallel_helper._set_parallel_ctx( core.GLOOParallelContext(strategy, place)) + elif (backend == "heter"): + parallel_helper._set_parallel_ctx( + core.HeterParallelContext(strategy, parallel_env.device_id)) elif core.is_compiled_with_cuda(): parallel_helper._set_parallel_ctx( core.NCCLParallelContext(strategy, place)) @@ -224,17 +232,19 @@ def init_parallel_env(): parallel_helper._set_parallel_ctx( core.HCCLParallelContext(strategy, place)) - other_endpoints = strategy.trainer_endpoints[:] - other_endpoints.remove(strategy.current_endpoint) - if not is_cpu_only and strategy.local_rank == 0: - wait_server_ready(other_endpoints) + if backend != "heter": + other_endpoints = strategy.trainer_endpoints[:] + other_endpoints.remove(strategy.current_endpoint) + if not is_cpu_only and strategy.local_rank == 0: + wait_server_ready(other_endpoints) parallel_helper._init_parallel_ctx() + # 5: init gloo context (step 2: gloo init) # dividing init_gloo into two part beacause nccl and gloo # are separately looking for free ports which sometimes # leads to port-conflict. - if is_cpu_only and parallel_env.rank == 0: + if (is_cpu_only or backend == "heter") and parallel_env.rank == 0: # compare to init_gloo, we don't need to # init gloo, because we do this in _init_parallel_ctx; http_server_d["running"] = False diff --git a/python/paddle/fluid/dygraph/parallel_helper.py b/python/paddle/fluid/dygraph/parallel_helper.py index 40d5d18c9a40fac6f539573db0c1f25c84040b68..5fe4d4162e6e3287e8195cc05f416ec96a7961f0 100644 --- a/python/paddle/fluid/dygraph/parallel_helper.py +++ b/python/paddle/fluid/dygraph/parallel_helper.py @@ -28,11 +28,11 @@ def _is_parallel_ctx_initialized(): return __parallel_ctx__clz__ is not None -def _set_parallel_ctx(nccl_parallel_context): +def _set_parallel_ctx(ccl_parallel_context): global __parallel_ctx__clz__ assert __parallel_ctx__clz__ is None, \ "ParallelContext can only be initialized once." - __parallel_ctx__clz__ = nccl_parallel_context + __parallel_ctx__clz__ = ccl_parallel_context def _init_parallel_ctx():
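
Usage note: below is a minimal sketch of how the heter_ccl_mode path introduced by this patch (DistributedStrategy.heter_ccl_mode, HeterParallelOptimizer, paddle.DataParallel wrapping and HeterParallelContext) is expected to be driven from a dygraph training script. The model, optimizer, batch data and the script name train_demo.py are illustrative assumptions, and the process is assumed to be started with the launch module, e.g. python -m paddle.distributed.launch --backend=heter train_demo.py (the "heter" backend accepted by check_backend above).

# train_demo.py -- hypothetical example; the network, data and hyper-parameters are placeholders.
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.heter_ccl_mode = True  # enable the heterogeneous CCL data-parallel path

# fleet.init calls paddle.distributed.init_parallel_env(), which creates a
# HeterParallelContext when the backend is "heter" (see parallel.py above).
fleet.init(is_collective=True, strategy=strategy)

model = nn.Linear(32, 1)  # placeholder network
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())

# With heter_ccl_mode, distributed_model wraps the model in paddle.DataParallel
# and distributed_optimizer returns a HeterParallelOptimizer (see fleet_base.py above).
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

for step in range(10):
    x = paddle.randn([8, 32])  # placeholder input batch
    loss = model(x).mean()
    loss.backward()            # gradients are synchronized through the reducer / heter ccl context
    optimizer.step()
    optimizer.clear_grad()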