From 45765d6eb6445ab8ee98ea2306dcfeae4d1f295e Mon Sep 17 00:00:00 2001 From: Void Main Date: Tue, 2 Mar 2021 15:34:18 +0800 Subject: [PATCH] Refactor HCCLCommContext to be compatible with Paddle (#31359) Refactor HCCLCommContext to be compatible with Paddle (#31359) --- .../fluid/operators/collective/CMakeLists.txt | 2 +- .../operators/collective/c_allreduce_op.h | 2 +- .../collective/c_broadcast_op_npu.cc | 4 +- ...init_hccl_op.cc => c_comm_init_hcom_op.cc} | 47 ++++-- .../operators/collective/c_create_group_op.cc | 76 ---------- .../collective/c_hcom_op_npu_test.cc | 36 ++--- paddle/fluid/platform/collective_helper.h | 59 ++++++-- .../fluid/platform/collective_helper_npu.cc | 141 +++++++++++------- paddle/fluid/platform/device_context.h | 8 - 9 files changed, 185 insertions(+), 190 deletions(-) rename paddle/fluid/operators/collective/{c_comm_init_hccl_op.cc => c_comm_init_hcom_op.cc} (55%) delete mode 100644 paddle/fluid/operators/collective/c_create_group_op.cc diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 17f5502a039..b2405b60585 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -36,4 +36,4 @@ endif() set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE) set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency") -cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hccl_op c_create_group_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor) +cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor) diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 6658e5b364a..27cdee5a982 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -135,7 +135,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel { std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); group = "hccl_world_group";// std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); - auto comm = paddle::platform::HCCLCommContext::Instance().Get(); + auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place); aclrtStream stream = nullptr; if (ctx.Attr("use_calc_stream")) { diff --git a/paddle/fluid/operators/collective/c_broadcast_op_npu.cc b/paddle/fluid/operators/collective/c_broadcast_op_npu.cc index 690e1433595..c2c049b3a91 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op_npu.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op_npu.cc @@ -34,8 +34,9 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel { auto out = ctx.Output("Out"); + int ring_id = ctx.Attr("ring_id"); auto place = ctx.GetPlace(); - auto comm = paddle::platform::HCCLCommContext::Instance().Get(); + auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place); aclrtStream stream = nullptr; if (ctx.Attr("use_calc_stream")) { @@ -46,7 +47,6 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel { } int root = ctx.Attr("root"); - int ring_id = ctx.Attr("ring_id"); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string tag = ctx.Attr("tag"); diff --git a/paddle/fluid/operators/collective/c_comm_init_hccl_op.cc b/paddle/fluid/operators/collective/c_comm_init_hcom_op.cc similarity index 55% rename from paddle/fluid/operators/collective/c_comm_init_hccl_op.cc rename to paddle/fluid/operators/collective/c_comm_init_hcom_op.cc index e47849b8394..f720ffdd0fe 100644 --- a/paddle/fluid/operators/collective/c_comm_init_hccl_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_hcom_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/platform/hccl_helper.h" #include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/npu_op_runner.h" @@ -40,16 +41,24 @@ class CCommInitOpNPU : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { - std::string rank_table_file = Attr("rank_table_file"); - uint32_t rank_id = Attr("rank_id"); - uint32_t device_id = Attr("device_id"); - - VLOG(3) << "begin init hccl, parameter is: " - << "rank_table_file " << rank_table_file - << " rank_id " << rank_id - << " device_id " << device_id; - - platform::HCCLCommContext::Instance().CreateHCCLComm(rank_table_file, rank_id, device_id); + int rid = Attr("ring_id"); + int nranks = Attr("nranks"); + int rank_id = Attr("rank"); + int device_id = BOOST_GET_CONST(platform::NPUPlace, place).device; + if (Attr("device_id") >= 0) { + device_id = Attr("device_id"); + } + std::vector rank_ids = Attr>("rank_ids"); + + VLOG(3) << "begin c_comm_init on npu, parameters are: " + << "ring id[" << rid + << "], nranks[" << nranks + << "], rank_id[" << rank_id + << "], device_id[" << device_id + << "]"; + + platform::HCCLCommContext::Instance().CreateHCCLComm( + rank_ids, rank_id, device_id, rid); } }; @@ -61,10 +70,17 @@ CCommInit operator on NPU Initialize collective communication context within this trainer )DOC"); - AddAttr("rank_table_file", - "(string) path to rank_table_file"); - AddAttr("rank_id", "(int) world rank id of the process"); - AddAttr("device_id", "(int) device id of the process/thread"); + AddAttr("nranks", "(int) The number of ranks of distributed trainers"); + AddAttr>("rank_ids", "The world rank ids of the group"); + AddAttr("rank", + "(int) The rank of the trainer in distributed training."); + AddAttr("device_id", + "(int) The deivce_id on which to initialize the communicator." + "Now, you only have to set this attr manually for pipeline " + "training. Otherwise, make it as default.") + .SetDefault(-1); + AddAttr("ring_id", "(int default 0) user specified ring id") + .SetDefault(0); } }; @@ -73,7 +89,6 @@ Initialize collective communication context within this trainer namespace ops = paddle::operators; -REGISTER_OPERATOR(c_comm_init_hccl, ops::CCommInitOpNPU, - ops::CCommInitOpNPUMaker); +REGISTER_OPERATOR(c_comm_init_hcom, ops::CCommInitOpNPU, ops::CCommInitOpNPUMaker); #endif diff --git a/paddle/fluid/operators/collective/c_create_group_op.cc b/paddle/fluid/operators/collective/c_create_group_op.cc deleted file mode 100644 index 07ca91e72ae..00000000000 --- a/paddle/fluid/operators/collective/c_create_group_op.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_ASCEND_CL -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/hccl_helper.h" - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/npu_op_runner.h" - -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { - -class CCreateGroupOpNPU : public framework::OperatorBase { - public: - CCreateGroupOpNPU(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - std::string group_name = Attr("group_name"); - int nranks = Attr("nranks"); - std::vector rank_ids = Attr>("rank_ids"); - paddle::platform::HCCLCommContext::Instance().CreateHCCLGroup( - group_name, (uint32_t)nranks, - std::vector(rank_ids.begin(), rank_ids.end())); - } -}; - -class CCreateGroupOpNPUMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddComment(R"DOC( -CCreateGroup operator on NPU - -Create collective communication group on NPU -)DOC"); - AddAttr("group_name", - "(string) name of the collective communication group"); - AddAttr("nranks", "(int) number of the group"); - AddAttr>("rank_ids", - "(list of int) The world rank id of the group members"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(c_create_group, ops::CCreateGroupOpNPU, - ops::CCreateGroupOpNPUMaker); - -#endif diff --git a/paddle/fluid/operators/collective/c_hcom_op_npu_test.cc b/paddle/fluid/operators/collective/c_hcom_op_npu_test.cc index 590cfe62148..643300158f4 100644 --- a/paddle/fluid/operators/collective/c_hcom_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_hcom_op_npu_test.cc @@ -43,37 +43,29 @@ namespace m = paddle::operators::math; USE_OP(c_broadcast); USE_OP(c_allreduce_sum); -USE_NO_KERNEL_OP(c_comm_init_hccl); -USE_NO_KERNEL_OP(c_create_group); +USE_NO_KERNEL_OP(c_comm_init_hcom); USE_OP_DEVICE_KERNEL(c_broadcast, NPU); USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU); void Prepare(f::Scope* scope, const p::DeviceContext& ctx){ - std::string rank_table_file = getenv("RANK_TABLE_FILE"); int rank_id = atoi(getenv("RANK_ID")); int device_id = atoi(getenv("DEVICE_ID")); - printf("rank_table_file: %s, rank_id = %d, device_id = %d\n", rank_table_file.c_str(), rank_id, device_id); + printf("rank_id = %d, device_id = %d\n", rank_id, device_id); - f::AttributeMap attrs; - attrs["rank_table_file"] = rank_table_file; - attrs["rank_id"] = rank_id; - attrs["device_id"] = device_id; + std::vector rank_ids{0, 1}; + f::AttributeMap comm_init_attrs; + comm_init_attrs["ring_id"] = 0; + comm_init_attrs["nranks"] = 2; + comm_init_attrs["rank"] = rank_id; + comm_init_attrs["device_id"] = device_id; + comm_init_attrs["rank_ids"] = rank_ids; auto comm_init_op = - f::OpRegistry::CreateOp("c_comm_init_hccl", {}, {}, attrs); + f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs); auto place = ctx.GetPlace(); comm_init_op->Run(*scope, place); ctx.Wait(); - - f::AttributeMap create_attrs; - create_attrs["group_name"] = HCOM_GROUP_PREFIX + std::to_string(0); - create_attrs["nranks"] = 2; - std::vector rank_ids{0, 1}; - create_attrs["rank_ids"] = rank_ids; - auto create_group_op = f::OpRegistry::CreateOp("c_create_group", {}, {}, create_attrs); - create_group_op->Run(*scope, place); - ctx.Wait(); } void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) { std::cout<< "BEGIN TEST:" << __FUNCTION__ < init; int rank_id = atoi(getenv("RANK_ID")); std::cout<< "rank_id:" << rank_id<& world_rank_ids, int rank, int dev_id, int ring_id = 0); + + // a latter comm with the same dev_id and the same ring_id + // will override the former + HCCLComm* AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id = 0); + + // retrieve a communicator by the ring id in multiprocessing mode + HCCLComm* Get(int ring_id) const { + PADDLE_ENFORCE_GT( + comm_map_.count(ring_id), 0, + platform::errors::InvalidArgument( + "Communicator in ring id %d has not been initialized.", ring_id)); + PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1, + platform::errors::InvalidArgument( + "One device id should be specified to retrieve from " + "multiple communicators.")); + return comm_map_.at(ring_id).begin()->second.get(); + } - void CreateHCCLGroup(const std::string& group_name, uint32_t nranks, const std::vector& rank_ids); + // retrieve a communicator by the ring id and the device id + HCCLComm* Get(int ring_id, int dev_id) const { + PADDLE_ENFORCE_GT( + comm_map_.count(ring_id), 0, + platform::errors::InvalidArgument( + "Communicator of ring id %d has not been initialized.", ring_id)); + PADDLE_ENFORCE_GT( + comm_map_.at(ring_id).count(dev_id), 0, + platform::errors::InvalidArgument( + "Communicator at device id %d has not been initialized in ring %d.", + dev_id, ring_id)); + return comm_map_.at(ring_id).at(dev_id).get(); + } // retrieve a communicator by the ring id and place - HCCLComm* Get() const { - return comm_.get(); + HCCLComm* Get(int ring_id, Place place) const { + return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device); } + private: + // Init global hcom + HCCLCommContext() { InitHcomWorldGroup(); } + std::once_flag once_flag_; std::mutex comm_map_mutex_; - std::unique_ptr comm_; + // ring id to dev-HCCLComm + std::map>> comm_map_; - HCCLComm* AssignHCCLComm(const std::string& config_file, uint32_t rank, uint32_t device_id); + void InitHcomWorldGroup(); + void ReleaseHCCLComms(); - HCCLCommContext() = default; DISABLE_COPY_AND_ASSIGN(HCCLCommContext); }; #endif diff --git a/paddle/fluid/platform/collective_helper_npu.cc b/paddle/fluid/platform/collective_helper_npu.cc index 3cf16475ee0..edfa351f19b 100644 --- a/paddle/fluid/platform/collective_helper_npu.cc +++ b/paddle/fluid/platform/collective_helper_npu.cc @@ -21,14 +21,18 @@ namespace platform { class HCCLCommImpl : public HCCLComm { public: - void set_rank_table_file(const std::string& rank_table_file) { rank_table_file_ = rank_table_file; } - std::string rank_table_file() const override { return rank_table_file_; } + void set_ring_id(int ring_id) { ring_id_ = ring_id; } + int ring_id() const override { return ring_id_; } - void set_rank(uint32_t rank) { rank_ = rank; } - uint32_t rank() const override { return rank_; } + void set_nranks(int nranks) { nranks_ = nranks; } + int nranks() const override { return nranks_; } - void set_device_id(uint32_t device_id) { device_id_ = device_id; } - uint32_t device_id() const override { return device_id_; } + void set_rank(int rank) { rank_ = rank; } + int rank() const override { return rank_; } + + int device_id() const override { + return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device; + } aclrtStream stream() const override { return dev_ctx_->stream(); } @@ -38,74 +42,103 @@ class HCCLCommImpl : public HCCLComm { NPUDeviceContext* dev_context() const override { return dev_ctx_.get(); } private: - std::string rank_table_file_; - uint32_t rank_; - uint32_t device_id_; + int ring_id_; + int nranks_; + int rank_; std::unique_ptr dev_ctx_; }; -HCCLComm* HCCLCommContext::CreateHCCLComm(const std::string& rank_table_file, - uint32_t rank, uint32_t device_id) { -/* - PADDLE_ENFORCE_NOT_NULL(rank_table_file, - platform::errors::InvalidArgument( - "The rank table file should not be null.")); - +HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector& world_rank_ids, int rank, int dev_id, int ring_id) { + PADDLE_ENFORCE_GT( + world_rank_ids.size(), 1, + platform::errors::InvalidArgument( + "Expected world_rank_ids.size() > 1. But received size is %d.", world_rank_ids.size())); PADDLE_ENFORCE_GE(rank, 0, + platform::errors::InvalidArgument( + "Expected rank >= 0. But received rank is %d.", rank)); + PADDLE_ENFORCE_LT( + rank, world_rank_ids.size(), platform::errors::InvalidArgument( - "Expected rank >= 0. But received rank is %d.", rank)); - - PADDLE_ENFORCE_GE(device_id, 0, + "Expected rank < nranks. But received rank is %d, nranks is %d.", + rank, world_rank_ids.size())); + PADDLE_ENFORCE_GE( + dev_id, 0, platform::errors::InvalidArgument( - "Expected dev_id >= 0. But received dev_id is %d.", device_id)); -*/ - auto* comm_wrapper = AssignHCCLComm(rank_table_file, rank, device_id); + "Expected dev_id >= 0. But received dev_id is %d.", dev_id)); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "Expected ring_id >= 0. But received ring_id is %d.", ring_id)); + + auto* comm_wrapper = AssignHCCLComm(world_rank_ids.size(), rank, dev_id, ring_id); + + // HACK(sunpeng17): hcom API requires bind stream to a model + // but we don't need model in Paddle, so we feed stream pointer as model pointer + PADDLE_ENFORCE_NPU_SUCCESS( + platform::dynload::hcom_bind_model(comm_wrapper->stream(), + comm_wrapper->stream())); - platform::dynload::hcom_init(rank_table_file.c_str(), std::to_string(rank).c_str()); - platform::dynload::hcom_bind_model(comm_wrapper->stream(), comm_wrapper->stream()); + // Get world_rank_ids registered in gen_nccl_id op + std::string group_name = HCOM_GROUP_PREFIX + std::to_string(ring_id); + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_create_group( + group_name.c_str(), world_rank_ids.size(), (unsigned int*)world_rank_ids.data())); + + VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id + << " has been created on device " << dev_id << ", group name: " << group_name; + + std::call_once(once_flag_, []() { + std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); }); + }); - VLOG(1) << "hccl communicator of rank " << rank << " has been created"; return comm_wrapper; } -HCCLComm* HCCLCommContext::AssignHCCLComm(const std::string& rank_table_file, - uint32_t rank, uint32_t device_id) { - +HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id) { std::unique_ptr dev_ctx( - new NPUDeviceContext(NPUPlace(device_id))); - - VLOG(3) << "device_id" << device_id; - VLOG(3) << "dev_ctx->stream()" << dev_ctx->stream(); + new NPUDeviceContext(NPUPlace(dev_id))); HCCLCommImpl* c = new HCCLCommImpl; - c->set_rank_table_file(rank_table_file); + c->set_ring_id(ring_id); + c->set_nranks(nranks); c->set_rank(rank); - c->set_device_id(device_id); c->set_dev_ctx(std::move(dev_ctx)); - // comm_ = c - comm_.reset(c); - return c; + + comm_map_mutex_.lock(); + if (comm_map_.count(ring_id) == 0) { + comm_map_.emplace(ring_id, std::map>()); + } + auto& dev2comm = comm_map_[ring_id]; + + dev2comm.emplace(dev_id, std::unique_ptr(c)); + comm_map_mutex_.unlock(); + + return comm_map_[ring_id][dev_id].get(); } -void HCCLCommContext::CreateHCCLGroup(const std::string& group_name, uint32_t nranks, - const std::vector& rank_ids) { -/* - PADDLE_ENFORCE_NOT_NULL(group_name, - platform::errors::InvalidArgument( - "The group name should not be null.")); - PADDLE_ENFORCE_GT(nranks, 0, - platform::errors::InvalidArgument( - "Expected nranks > 0. But received nranks is %d.", nranks)); - PADDLE_ENFORCE_NOT_NULL(rank_ids, - platform::errors::InvalidArgument( - "The rank ids should not be null.")); -*/ - platform::dynload::hcom_create_group(group_name.c_str(), nranks, (unsigned int*)rank_ids.data()); - - VLOG(1) << "hccl group with name " << group_name << " has been created"; +void HCCLCommContext::InitHcomWorldGroup() { + const char *rank_table_file = getenv(ENV_RANK_TABLE_FILE); + PADDLE_ENFORCE_NOT_NULL( + rank_table_file, + platform::errors::InvalidArgument("The RANK_TABLE_FILE environment variable should not be null.")); + + const char *rank_id = getenv(ENV_RANK_ID); + PADDLE_ENFORCE_NOT_NULL( + rank_id, + platform::errors::InvalidArgument("The RANK_ID environment variable should not be null.")); + + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_init(rank_table_file, rank_id)); + VLOG(3) << "Successfully initialized hcom. rank_table_file: " + << rank_table_file << ", rank_id " << rank_id; +} + +void HCCLCommContext::ReleaseHCCLComms() { + for (auto& p : comm_map_) { + for (auto& q : p.second) { + q.second.reset(); + } + } } } // namespace platform } // namespace paddle - #endif diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index f5fa6816b50..187dd627e4a 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -178,14 +178,6 @@ class NPUDeviceContext : public DeviceContext { /*! \brief Return npu stream in the device context. */ aclrtStream stream() const; -#ifdef PADDLE_WITH_ASCEND_HCCL - /*! \brief Return bkcl context. */ - HCCLContext_t hccl_context() const { return hccl_context_; } - - /*! \brief Set bkcl context. */ - void set_hccl_context(HCCLContext_t context) { hccl_context_ = context; } -#endif - private: NPUPlace place_; aclrtContext context_; -- GitLab