未验证 提交 45765d6e 编写于 作者: V Void Main 提交者: GitHub

Refactor HCCLCommContext to be compatible with Paddle (#31359)

Refactor HCCLCommContext to be compatible with Paddle (#31359)
上级 8497e2aa
...@@ -36,4 +36,4 @@ endif() ...@@ -36,4 +36,4 @@ endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE) set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency") set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hccl_op c_create_group_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor) cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
...@@ -135,7 +135,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { ...@@ -135,7 +135,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
group = "hccl_world_group";// std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); group = "hccl_world_group";// std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto comm = paddle::platform::HCCLCommContext::Instance().Get(); auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) { if (ctx.Attr<bool>("use_calc_stream")) {
......
...@@ -34,8 +34,9 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> { ...@@ -34,8 +34,9 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
auto out = ctx.Output<framework::LoDTensor>("Out"); auto out = ctx.Output<framework::LoDTensor>("Out");
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto comm = paddle::platform::HCCLCommContext::Instance().Get(); auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) { if (ctx.Attr<bool>("use_calc_stream")) {
...@@ -46,7 +47,6 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> { ...@@ -46,7 +47,6 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
} }
int root = ctx.Attr<int>("root"); int root = ctx.Attr<int>("root");
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag"); std::string tag = ctx.Attr<std::string>("tag");
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/platform/hccl_helper.h" #include "paddle/fluid/platform/hccl_helper.h"
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
...@@ -40,16 +41,24 @@ class CCommInitOpNPU : public framework::OperatorBase { ...@@ -40,16 +41,24 @@ class CCommInitOpNPU : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
std::string rank_table_file = Attr<std::string>("rank_table_file"); int rid = Attr<int>("ring_id");
uint32_t rank_id = Attr<int>("rank_id"); int nranks = Attr<int>("nranks");
uint32_t device_id = Attr<int>("device_id"); int rank_id = Attr<int>("rank");
int device_id = BOOST_GET_CONST(platform::NPUPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
std::vector<int> rank_ids = Attr<std::vector<int>>("rank_ids");
VLOG(3) << "begin init hccl, parameter is: " VLOG(3) << "begin c_comm_init on npu, parameters are: "
<< "rank_table_file " << rank_table_file << "ring id[" << rid
<< " rank_id " << rank_id << "], nranks[" << nranks
<< " device_id " << device_id; << "], rank_id[" << rank_id
<< "], device_id[" << device_id
<< "]";
platform::HCCLCommContext::Instance().CreateHCCLComm(rank_table_file, rank_id, device_id); platform::HCCLCommContext::Instance().CreateHCCLComm(
rank_ids, rank_id, device_id, rid);
} }
}; };
...@@ -61,10 +70,17 @@ CCommInit operator on NPU ...@@ -61,10 +70,17 @@ CCommInit operator on NPU
Initialize collective communication context within this trainer Initialize collective communication context within this trainer
)DOC"); )DOC");
AddAttr<std::string>("rank_table_file", AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
"(string) path to rank_table_file"); AddAttr<std::vector<int>>("rank_ids", "The world rank ids of the group");
AddAttr<int>("rank_id", "(int) world rank id of the process"); AddAttr<int>("rank",
AddAttr<int>("device_id", "(int) device id of the process/thread"); "(int) The rank of the trainer in distributed training.");
AddAttr<int>("device_id",
"(int) The deivce_id on which to initialize the communicator."
"Now, you only have to set this attr manually for pipeline "
"training. Otherwise, make it as default.")
.SetDefault(-1);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
} }
}; };
...@@ -73,7 +89,6 @@ Initialize collective communication context within this trainer ...@@ -73,7 +89,6 @@ Initialize collective communication context within this trainer
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(c_comm_init_hccl, ops::CCommInitOpNPU, REGISTER_OPERATOR(c_comm_init_hcom, ops::CCommInitOpNPU, ops::CCommInitOpNPUMaker);
ops::CCommInitOpNPUMaker);
#endif #endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class CCreateGroupOpNPU : public framework::OperatorBase {
public:
CCreateGroupOpNPU(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::string group_name = Attr<std::string>("group_name");
int nranks = Attr<int>("nranks");
std::vector<int> rank_ids = Attr<std::vector<int>>("rank_ids");
paddle::platform::HCCLCommContext::Instance().CreateHCCLGroup(
group_name, (uint32_t)nranks,
std::vector<uint32_t>(rank_ids.begin(), rank_ids.end()));
}
};
class CCreateGroupOpNPUMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
CCreateGroup operator on NPU
Create collective communication group on NPU
)DOC");
AddAttr<std::string>("group_name",
"(string) name of the collective communication group");
AddAttr<int>("nranks", "(int) number of the group");
AddAttr<std::vector<int>>("rank_ids",
"(list of int) The world rank id of the group members");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_create_group, ops::CCreateGroupOpNPU,
ops::CCreateGroupOpNPUMaker);
#endif
...@@ -43,37 +43,29 @@ namespace m = paddle::operators::math; ...@@ -43,37 +43,29 @@ namespace m = paddle::operators::math;
USE_OP(c_broadcast); USE_OP(c_broadcast);
USE_OP(c_allreduce_sum); USE_OP(c_allreduce_sum);
USE_NO_KERNEL_OP(c_comm_init_hccl); USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_create_group);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU); USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU); USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx){ void Prepare(f::Scope* scope, const p::DeviceContext& ctx){
std::string rank_table_file = getenv("RANK_TABLE_FILE");
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
printf("rank_table_file: %s, rank_id = %d, device_id = %d\n", rank_table_file.c_str(), rank_id, device_id); printf("rank_id = %d, device_id = %d\n", rank_id, device_id);
f::AttributeMap attrs; std::vector<int> rank_ids{0, 1};
attrs["rank_table_file"] = rank_table_file; f::AttributeMap comm_init_attrs;
attrs["rank_id"] = rank_id; comm_init_attrs["ring_id"] = 0;
attrs["device_id"] = device_id; comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hccl", {}, {}, attrs); f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
f::AttributeMap create_attrs;
create_attrs["group_name"] = HCOM_GROUP_PREFIX + std::to_string(0);
create_attrs["nranks"] = 2;
std::vector<int> rank_ids{0, 1};
create_attrs["rank_ids"] = rank_ids;
auto create_group_op = f::OpRegistry::CreateOp("c_create_group", {}, {}, create_attrs);
create_group_op->Run(*scope, place);
ctx.Wait();
} }
void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
std::cout<< "BEGIN TEST:" << __FUNCTION__ <<std::endl; std::cout<< "BEGIN TEST:" << __FUNCTION__ <<std::endl;
...@@ -187,6 +179,6 @@ TEST(c_broadcast, NPU) { ...@@ -187,6 +179,6 @@ TEST(c_broadcast, NPU) {
p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
Prepare(&scope, ctx); Prepare(&scope, ctx);
// TestHCCLBroadcastOp(&scope, ctx); TestHCCLBroadcastOp(&scope, ctx);
TestHCCLAllReduceOp(&scope, ctx); // TestHCCLAllReduceOp(&scope, ctx);
} }
...@@ -147,11 +147,16 @@ class NCCLCommContext { ...@@ -147,11 +147,16 @@ class NCCLCommContext {
// singleton with a global user specified group id. // singleton with a global user specified group id.
class NPUDeviceContext; class NPUDeviceContext;
#define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
#define ENV_RANK_ID "RANK_ID"
#define ENV_DEV_ID "DEV_ID"
class HCCLComm { class HCCLComm {
public: public:
virtual std::string rank_table_file() const = 0; virtual int ring_id() const = 0;
virtual uint32_t rank() const = 0; virtual int nranks() const = 0;
virtual uint32_t device_id() const = 0; virtual int rank() const = 0;
virtual int device_id() const = 0;
virtual aclrtStream stream() const = 0; virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0; virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default; virtual ~HCCLComm() = default;
...@@ -165,22 +170,56 @@ class HCCLCommContext { ...@@ -165,22 +170,56 @@ class HCCLCommContext {
return comm_ctx; return comm_ctx;
} }
HCCLComm* CreateHCCLComm(const std::string& config_file, uint32_t rank, uint32_t device_id); HCCLComm* CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id = 0);
// a latter comm with the same dev_id and the same ring_id
// will override the former
HCCLComm* AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id = 0);
// retrieve a communicator by the ring id in multiprocessing mode
HCCLComm* Get(int ring_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Communicator in ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1,
platform::errors::InvalidArgument(
"One device id should be specified to retrieve from "
"multiple communicators."));
return comm_map_.at(ring_id).begin()->second.get();
}
void CreateHCCLGroup(const std::string& group_name, uint32_t nranks, const std::vector<uint32_t>& rank_ids); // retrieve a communicator by the ring id and the device id
HCCLComm* Get(int ring_id, int dev_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Communicator of ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_GT(
comm_map_.at(ring_id).count(dev_id), 0,
platform::errors::InvalidArgument(
"Communicator at device id %d has not been initialized in ring %d.",
dev_id, ring_id));
return comm_map_.at(ring_id).at(dev_id).get();
}
// retrieve a communicator by the ring id and place // retrieve a communicator by the ring id and place
HCCLComm* Get() const { HCCLComm* Get(int ring_id, Place place) const {
return comm_.get(); return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device);
} }
private: private:
// Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); }
std::once_flag once_flag_; std::once_flag once_flag_;
std::mutex comm_map_mutex_; std::mutex comm_map_mutex_;
std::unique_ptr<HCCLComm> comm_; // ring id to dev-HCCLComm
std::map<int, std::map<int, std::unique_ptr<HCCLComm>>> comm_map_;
HCCLComm* AssignHCCLComm(const std::string& config_file, uint32_t rank, uint32_t device_id); void InitHcomWorldGroup();
void ReleaseHCCLComms();
HCCLCommContext() = default;
DISABLE_COPY_AND_ASSIGN(HCCLCommContext); DISABLE_COPY_AND_ASSIGN(HCCLCommContext);
}; };
#endif #endif
......
...@@ -21,14 +21,18 @@ namespace platform { ...@@ -21,14 +21,18 @@ namespace platform {
class HCCLCommImpl : public HCCLComm { class HCCLCommImpl : public HCCLComm {
public: public:
void set_rank_table_file(const std::string& rank_table_file) { rank_table_file_ = rank_table_file; } void set_ring_id(int ring_id) { ring_id_ = ring_id; }
std::string rank_table_file() const override { return rank_table_file_; } int ring_id() const override { return ring_id_; }
void set_rank(uint32_t rank) { rank_ = rank; } void set_nranks(int nranks) { nranks_ = nranks; }
uint32_t rank() const override { return rank_; } int nranks() const override { return nranks_; }
void set_device_id(uint32_t device_id) { device_id_ = device_id; } void set_rank(int rank) { rank_ = rank; }
uint32_t device_id() const override { return device_id_; } int rank() const override { return rank_; }
int device_id() const override {
return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device;
}
aclrtStream stream() const override { return dev_ctx_->stream(); } aclrtStream stream() const override { return dev_ctx_->stream(); }
...@@ -38,74 +42,103 @@ class HCCLCommImpl : public HCCLComm { ...@@ -38,74 +42,103 @@ class HCCLCommImpl : public HCCLComm {
NPUDeviceContext* dev_context() const override { return dev_ctx_.get(); } NPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }
private: private:
std::string rank_table_file_; int ring_id_;
uint32_t rank_; int nranks_;
uint32_t device_id_; int rank_;
std::unique_ptr<NPUDeviceContext> dev_ctx_; std::unique_ptr<NPUDeviceContext> dev_ctx_;
}; };
HCCLComm* HCCLCommContext::CreateHCCLComm(const std::string& rank_table_file, HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id) {
uint32_t rank, uint32_t device_id) { PADDLE_ENFORCE_GT(
/* world_rank_ids.size(), 1,
PADDLE_ENFORCE_NOT_NULL(rank_table_file,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The rank table file should not be null.")); "Expected world_rank_ids.size() > 1. But received size is %d.", world_rank_ids.size()));
PADDLE_ENFORCE_GE(rank, 0, PADDLE_ENFORCE_GE(rank, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank)); "Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_LT(
PADDLE_ENFORCE_GE(device_id, 0, rank, world_rank_ids.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", device_id)); "Expected rank < nranks. But received rank is %d, nranks is %d.",
*/ rank, world_rank_ids.size()));
auto* comm_wrapper = AssignHCCLComm(rank_table_file, rank, device_id); PADDLE_ENFORCE_GE(
dev_id, 0,
platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", dev_id));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"Expected ring_id >= 0. But received ring_id is %d.", ring_id));
auto* comm_wrapper = AssignHCCLComm(world_rank_ids.size(), rank, dev_id, ring_id);
platform::dynload::hcom_init(rank_table_file.c_str(), std::to_string(rank).c_str()); // HACK(sunpeng17): hcom API requires bind stream to a model
platform::dynload::hcom_bind_model(comm_wrapper->stream(), comm_wrapper->stream()); // but we don't need model in Paddle, so we feed stream pointer as model pointer
PADDLE_ENFORCE_NPU_SUCCESS(
platform::dynload::hcom_bind_model(comm_wrapper->stream(),
comm_wrapper->stream()));
// Get world_rank_ids registered in gen_nccl_id op
std::string group_name = HCOM_GROUP_PREFIX + std::to_string(ring_id);
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_create_group(
group_name.c_str(), world_rank_ids.size(), (unsigned int*)world_rank_ids.data()));
VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id << ", group name: " << group_name;
std::call_once(once_flag_, []() {
std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); });
});
VLOG(1) << "hccl communicator of rank " << rank << " has been created";
return comm_wrapper; return comm_wrapper;
} }
HCCLComm* HCCLCommContext::AssignHCCLComm(const std::string& rank_table_file, HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id) {
uint32_t rank, uint32_t device_id) {
std::unique_ptr<NPUDeviceContext> dev_ctx( std::unique_ptr<NPUDeviceContext> dev_ctx(
new NPUDeviceContext(NPUPlace(device_id))); new NPUDeviceContext(NPUPlace(dev_id)));
VLOG(3) << "device_id" << device_id;
VLOG(3) << "dev_ctx->stream()" << dev_ctx->stream();
HCCLCommImpl* c = new HCCLCommImpl; HCCLCommImpl* c = new HCCLCommImpl;
c->set_rank_table_file(rank_table_file); c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank); c->set_rank(rank);
c->set_device_id(device_id);
c->set_dev_ctx(std::move(dev_ctx)); c->set_dev_ctx(std::move(dev_ctx));
// comm_ = c
comm_.reset(c); comm_map_mutex_.lock();
return c; if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<HCCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c));
comm_map_mutex_.unlock();
return comm_map_[ring_id][dev_id].get();
} }
void HCCLCommContext::CreateHCCLGroup(const std::string& group_name, uint32_t nranks, void HCCLCommContext::InitHcomWorldGroup() {
const std::vector<uint32_t>& rank_ids) { const char *rank_table_file = getenv(ENV_RANK_TABLE_FILE);
/* PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(group_name, rank_table_file,
platform::errors::InvalidArgument( platform::errors::InvalidArgument("The RANK_TABLE_FILE environment variable should not be null."));
"The group name should not be null."));
PADDLE_ENFORCE_GT(nranks, 0,
platform::errors::InvalidArgument(
"Expected nranks > 0. But received nranks is %d.", nranks));
PADDLE_ENFORCE_NOT_NULL(rank_ids,
platform::errors::InvalidArgument(
"The rank ids should not be null."));
*/
platform::dynload::hcom_create_group(group_name.c_str(), nranks, (unsigned int*)rank_ids.data());
VLOG(1) << "hccl group with name " << group_name << " has been created"; const char *rank_id = getenv(ENV_RANK_ID);
PADDLE_ENFORCE_NOT_NULL(
rank_id,
platform::errors::InvalidArgument("The RANK_ID environment variable should not be null."));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_init(rank_table_file, rank_id));
VLOG(3) << "Successfully initialized hcom. rank_table_file: "
<< rank_table_file << ", rank_id " << rank_id;
}
void HCCLCommContext::ReleaseHCCLComms() {
for (auto& p : comm_map_) {
for (auto& q : p.second) {
q.second.reset();
}
}
} }
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
#endif #endif
...@@ -178,14 +178,6 @@ class NPUDeviceContext : public DeviceContext { ...@@ -178,14 +178,6 @@ class NPUDeviceContext : public DeviceContext {
/*! \brief Return npu stream in the device context. */ /*! \brief Return npu stream in the device context. */
aclrtStream stream() const; aclrtStream stream() const;
#ifdef PADDLE_WITH_ASCEND_HCCL
/*! \brief Return bkcl context. */
HCCLContext_t hccl_context() const { return hccl_context_; }
/*! \brief Set bkcl context. */
void set_hccl_context(HCCLContext_t context) { hccl_context_ = context; }
#endif
private: private:
NPUPlace place_; NPUPlace place_;
aclrtContext context_; aclrtContext context_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册