未验证 提交 45765d6e 编写于 作者: V Void Main 提交者: GitHub

Refactor HCCLCommContext to be compatible with Paddle (#31359)

Refactor HCCLCommContext to be compatible with Paddle (#31359)
上级 8497e2aa
......@@ -36,4 +36,4 @@ endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hccl_op c_create_group_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hcom_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
......@@ -135,7 +135,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
group = "hccl_world_group";// std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto comm = paddle::platform::HCCLCommContext::Instance().Get();
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
......
......@@ -34,8 +34,9 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
auto out = ctx.Output<framework::LoDTensor>("Out");
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = paddle::platform::HCCLCommContext::Instance().Get();
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
......@@ -46,7 +47,6 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
}
int root = ctx.Attr<int>("root");
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag");
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
......@@ -40,16 +41,24 @@ class CCommInitOpNPU : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::string rank_table_file = Attr<std::string>("rank_table_file");
uint32_t rank_id = Attr<int>("rank_id");
uint32_t device_id = Attr<int>("device_id");
VLOG(3) << "begin init hccl, parameter is: "
<< "rank_table_file " << rank_table_file
<< " rank_id " << rank_id
<< " device_id " << device_id;
platform::HCCLCommContext::Instance().CreateHCCLComm(rank_table_file, rank_id, device_id);
int rid = Attr<int>("ring_id");
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int device_id = BOOST_GET_CONST(platform::NPUPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
std::vector<int> rank_ids = Attr<std::vector<int>>("rank_ids");
VLOG(3) << "begin c_comm_init on npu, parameters are: "
<< "ring id[" << rid
<< "], nranks[" << nranks
<< "], rank_id[" << rank_id
<< "], device_id[" << device_id
<< "]";
platform::HCCLCommContext::Instance().CreateHCCLComm(
rank_ids, rank_id, device_id, rid);
}
};
......@@ -61,10 +70,17 @@ CCommInit operator on NPU
Initialize collective communication context within this trainer
)DOC");
AddAttr<std::string>("rank_table_file",
"(string) path to rank_table_file");
AddAttr<int>("rank_id", "(int) world rank id of the process");
AddAttr<int>("device_id", "(int) device id of the process/thread");
AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
AddAttr<std::vector<int>>("rank_ids", "The world rank ids of the group");
AddAttr<int>("rank",
"(int) The rank of the trainer in distributed training.");
AddAttr<int>("device_id",
"(int) The deivce_id on which to initialize the communicator."
"Now, you only have to set this attr manually for pipeline "
"training. Otherwise, make it as default.")
.SetDefault(-1);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
......@@ -73,7 +89,6 @@ Initialize collective communication context within this trainer
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_comm_init_hccl, ops::CCommInitOpNPU,
ops::CCommInitOpNPUMaker);
REGISTER_OPERATOR(c_comm_init_hcom, ops::CCommInitOpNPU, ops::CCommInitOpNPUMaker);
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class CCreateGroupOpNPU : public framework::OperatorBase {
public:
CCreateGroupOpNPU(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::string group_name = Attr<std::string>("group_name");
int nranks = Attr<int>("nranks");
std::vector<int> rank_ids = Attr<std::vector<int>>("rank_ids");
paddle::platform::HCCLCommContext::Instance().CreateHCCLGroup(
group_name, (uint32_t)nranks,
std::vector<uint32_t>(rank_ids.begin(), rank_ids.end()));
}
};
class CCreateGroupOpNPUMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
CCreateGroup operator on NPU
Create collective communication group on NPU
)DOC");
AddAttr<std::string>("group_name",
"(string) name of the collective communication group");
AddAttr<int>("nranks", "(int) number of the group");
AddAttr<std::vector<int>>("rank_ids",
"(list of int) The world rank id of the group members");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_create_group, ops::CCreateGroupOpNPU,
ops::CCreateGroupOpNPUMaker);
#endif
......@@ -43,37 +43,29 @@ namespace m = paddle::operators::math;
USE_OP(c_broadcast);
USE_OP(c_allreduce_sum);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_NO_KERNEL_OP(c_create_group);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx){
std::string rank_table_file = getenv("RANK_TABLE_FILE");
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
printf("rank_table_file: %s, rank_id = %d, device_id = %d\n", rank_table_file.c_str(), rank_id, device_id);
printf("rank_id = %d, device_id = %d\n", rank_id, device_id);
f::AttributeMap attrs;
attrs["rank_table_file"] = rank_table_file;
attrs["rank_id"] = rank_id;
attrs["device_id"] = device_id;
std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hccl", {}, {}, attrs);
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
f::AttributeMap create_attrs;
create_attrs["group_name"] = HCOM_GROUP_PREFIX + std::to_string(0);
create_attrs["nranks"] = 2;
std::vector<int> rank_ids{0, 1};
create_attrs["rank_ids"] = rank_ids;
auto create_group_op = f::OpRegistry::CreateOp("c_create_group", {}, {}, create_attrs);
create_group_op->Run(*scope, place);
ctx.Wait();
}
void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
std::cout<< "BEGIN TEST:" << __FUNCTION__ <<std::endl;
......@@ -135,7 +127,7 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
std::vector<float> init;
int rank_id = atoi(getenv("RANK_ID"));
std::cout<< "rank_id:" << rank_id<<std::endl;
int num1 = 1;
int num2 = 4;
......@@ -184,9 +176,9 @@ TEST(c_broadcast, NPU) {
f::Scope scope;
char * npu_id=getenv("FLAGS_selected_npus");
p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
Prepare(&scope, ctx);
// TestHCCLBroadcastOp(&scope, ctx);
TestHCCLAllReduceOp(&scope, ctx);
TestHCCLBroadcastOp(&scope, ctx);
// TestHCCLAllReduceOp(&scope, ctx);
}
......@@ -147,11 +147,16 @@ class NCCLCommContext {
// singleton with a global user specified group id.
class NPUDeviceContext;
#define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
#define ENV_RANK_ID "RANK_ID"
#define ENV_DEV_ID "DEV_ID"
class HCCLComm {
public:
virtual std::string rank_table_file() const = 0;
virtual uint32_t rank() const = 0;
virtual uint32_t device_id() const = 0;
virtual int ring_id() const = 0;
virtual int nranks() const = 0;
virtual int rank() const = 0;
virtual int device_id() const = 0;
virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default;
......@@ -165,22 +170,56 @@ class HCCLCommContext {
return comm_ctx;
}
HCCLComm* CreateHCCLComm(const std::string& config_file, uint32_t rank, uint32_t device_id);
HCCLComm* CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id = 0);
// a latter comm with the same dev_id and the same ring_id
// will override the former
HCCLComm* AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id = 0);
// retrieve a communicator by the ring id in multiprocessing mode
HCCLComm* Get(int ring_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Communicator in ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1,
platform::errors::InvalidArgument(
"One device id should be specified to retrieve from "
"multiple communicators."));
return comm_map_.at(ring_id).begin()->second.get();
}
void CreateHCCLGroup(const std::string& group_name, uint32_t nranks, const std::vector<uint32_t>& rank_ids);
// retrieve a communicator by the ring id and the device id
HCCLComm* Get(int ring_id, int dev_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Communicator of ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_GT(
comm_map_.at(ring_id).count(dev_id), 0,
platform::errors::InvalidArgument(
"Communicator at device id %d has not been initialized in ring %d.",
dev_id, ring_id));
return comm_map_.at(ring_id).at(dev_id).get();
}
// retrieve a communicator by the ring id and place
HCCLComm* Get() const {
return comm_.get();
HCCLComm* Get(int ring_id, Place place) const {
return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device);
}
private:
// Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); }
std::once_flag once_flag_;
std::mutex comm_map_mutex_;
std::unique_ptr<HCCLComm> comm_;
// ring id to dev-HCCLComm
std::map<int, std::map<int, std::unique_ptr<HCCLComm>>> comm_map_;
HCCLComm* AssignHCCLComm(const std::string& config_file, uint32_t rank, uint32_t device_id);
void InitHcomWorldGroup();
void ReleaseHCCLComms();
HCCLCommContext() = default;
DISABLE_COPY_AND_ASSIGN(HCCLCommContext);
};
#endif
......
......@@ -21,14 +21,18 @@ namespace platform {
class HCCLCommImpl : public HCCLComm {
public:
void set_rank_table_file(const std::string& rank_table_file) { rank_table_file_ = rank_table_file; }
std::string rank_table_file() const override { return rank_table_file_; }
void set_ring_id(int ring_id) { ring_id_ = ring_id; }
int ring_id() const override { return ring_id_; }
void set_rank(uint32_t rank) { rank_ = rank; }
uint32_t rank() const override { return rank_; }
void set_nranks(int nranks) { nranks_ = nranks; }
int nranks() const override { return nranks_; }
void set_device_id(uint32_t device_id) { device_id_ = device_id; }
uint32_t device_id() const override { return device_id_; }
void set_rank(int rank) { rank_ = rank; }
int rank() const override { return rank_; }
int device_id() const override {
return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device;
}
aclrtStream stream() const override { return dev_ctx_->stream(); }
......@@ -38,74 +42,103 @@ class HCCLCommImpl : public HCCLComm {
NPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }
private:
std::string rank_table_file_;
uint32_t rank_;
uint32_t device_id_;
int ring_id_;
int nranks_;
int rank_;
std::unique_ptr<NPUDeviceContext> dev_ctx_;
};
HCCLComm* HCCLCommContext::CreateHCCLComm(const std::string& rank_table_file,
uint32_t rank, uint32_t device_id) {
/*
PADDLE_ENFORCE_NOT_NULL(rank_table_file,
platform::errors::InvalidArgument(
"The rank table file should not be null."));
HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_GT(
world_rank_ids.size(), 1,
platform::errors::InvalidArgument(
"Expected world_rank_ids.size() > 1. But received size is %d.", world_rank_ids.size()));
PADDLE_ENFORCE_GE(rank, 0,
platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_LT(
rank, world_rank_ids.size(),
platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_GE(device_id, 0,
"Expected rank < nranks. But received rank is %d, nranks is %d.",
rank, world_rank_ids.size()));
PADDLE_ENFORCE_GE(
dev_id, 0,
platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", device_id));
*/
auto* comm_wrapper = AssignHCCLComm(rank_table_file, rank, device_id);
"Expected dev_id >= 0. But received dev_id is %d.", dev_id));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"Expected ring_id >= 0. But received ring_id is %d.", ring_id));
auto* comm_wrapper = AssignHCCLComm(world_rank_ids.size(), rank, dev_id, ring_id);
// HACK(sunpeng17): hcom API requires bind stream to a model
// but we don't need model in Paddle, so we feed stream pointer as model pointer
PADDLE_ENFORCE_NPU_SUCCESS(
platform::dynload::hcom_bind_model(comm_wrapper->stream(),
comm_wrapper->stream()));
platform::dynload::hcom_init(rank_table_file.c_str(), std::to_string(rank).c_str());
platform::dynload::hcom_bind_model(comm_wrapper->stream(), comm_wrapper->stream());
// Get world_rank_ids registered in gen_nccl_id op
std::string group_name = HCOM_GROUP_PREFIX + std::to_string(ring_id);
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_create_group(
group_name.c_str(), world_rank_ids.size(), (unsigned int*)world_rank_ids.data()));
VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id << ", group name: " << group_name;
std::call_once(once_flag_, []() {
std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); });
});
VLOG(1) << "hccl communicator of rank " << rank << " has been created";
return comm_wrapper;
}
HCCLComm* HCCLCommContext::AssignHCCLComm(const std::string& rank_table_file,
uint32_t rank, uint32_t device_id) {
HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id) {
std::unique_ptr<NPUDeviceContext> dev_ctx(
new NPUDeviceContext(NPUPlace(device_id)));
VLOG(3) << "device_id" << device_id;
VLOG(3) << "dev_ctx->stream()" << dev_ctx->stream();
new NPUDeviceContext(NPUPlace(dev_id)));
HCCLCommImpl* c = new HCCLCommImpl;
c->set_rank_table_file(rank_table_file);
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_device_id(device_id);
c->set_dev_ctx(std::move(dev_ctx));
// comm_ = c
comm_.reset(c);
return c;
comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<HCCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c));
comm_map_mutex_.unlock();
return comm_map_[ring_id][dev_id].get();
}
void HCCLCommContext::CreateHCCLGroup(const std::string& group_name, uint32_t nranks,
const std::vector<uint32_t>& rank_ids) {
/*
PADDLE_ENFORCE_NOT_NULL(group_name,
platform::errors::InvalidArgument(
"The group name should not be null."));
PADDLE_ENFORCE_GT(nranks, 0,
platform::errors::InvalidArgument(
"Expected nranks > 0. But received nranks is %d.", nranks));
PADDLE_ENFORCE_NOT_NULL(rank_ids,
platform::errors::InvalidArgument(
"The rank ids should not be null."));
*/
platform::dynload::hcom_create_group(group_name.c_str(), nranks, (unsigned int*)rank_ids.data());
VLOG(1) << "hccl group with name " << group_name << " has been created";
void HCCLCommContext::InitHcomWorldGroup() {
const char *rank_table_file = getenv(ENV_RANK_TABLE_FILE);
PADDLE_ENFORCE_NOT_NULL(
rank_table_file,
platform::errors::InvalidArgument("The RANK_TABLE_FILE environment variable should not be null."));
const char *rank_id = getenv(ENV_RANK_ID);
PADDLE_ENFORCE_NOT_NULL(
rank_id,
platform::errors::InvalidArgument("The RANK_ID environment variable should not be null."));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_init(rank_table_file, rank_id));
VLOG(3) << "Successfully initialized hcom. rank_table_file: "
<< rank_table_file << ", rank_id " << rank_id;
}
void HCCLCommContext::ReleaseHCCLComms() {
for (auto& p : comm_map_) {
for (auto& q : p.second) {
q.second.reset();
}
}
}
} // namespace platform
} // namespace paddle
#endif
......@@ -178,14 +178,6 @@ class NPUDeviceContext : public DeviceContext {
/*! \brief Return npu stream in the device context. */
aclrtStream stream() const;
#ifdef PADDLE_WITH_ASCEND_HCCL
/*! \brief Return bkcl context. */
HCCLContext_t hccl_context() const { return hccl_context_; }
/*! \brief Set bkcl context. */
void set_hccl_context(HCCLContext_t context) { hccl_context_ = context; }
#endif
private:
NPUPlace place_;
aclrtContext context_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册