Unverified commit ce26f882, authored by lw921014, committed by GitHub

update Ascendrc hccl to 20.3 (#32126)

Parent 75dd8423
......@@ -400,6 +400,7 @@ OperatorBase::OperatorBase(const std::string& type,
// framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
// Inputs, outputs and attrs will be set to empty map
// to improve the execution efficiency of dygraph.
if (inputs_.size() > 0 || outputs_.size() > 0) {
GenerateTemporaryNames();
CheckAllInputOutputSet();
......
......@@ -31,6 +31,11 @@
#endif
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif
......@@ -45,6 +50,10 @@ class Communicator;
class NCCLCommunicator;
#endif
#endif
#ifdef PADDLE_WITH_ASCEND_CL
class Communicator;
class HCCLCommunicator;
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
class BKCLCommunicator;
......@@ -157,6 +166,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#endif
operators::CudnnRNNCache,
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
HcclRootInfo,
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId, platform::BKCLCommunicator,
#endif
......
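Registering HcclRootInfo in VarTypeRegistry lets a framework Variable hold the HCCL root handle directly, which is how the new c_gen_hccl_id / c_comm_init_hccl pair pass it through a Scope. A minimal access sketch, mirroring the test helpers later in this diff (assumes the Paddle framework headers and an ASCEND_CL build; the variable name is illustrative):

    paddle::framework::Scope scope;
    auto* var = scope.Var("HcclId");                            // create or fetch the variable
    HcclRootInfo* root_info = var->GetMutable<HcclRootInfo>();  // typed access enabled by this registry entry
    // c_gen_hccl_id fills root_info on rank 0; c_comm_init_hccl reads it to build the communicator.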
......@@ -11,7 +11,7 @@ foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
endforeach()
register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op c_gen_hccl_id_op gen_hccl_id_op DEPS ${COLLECTIVE_DEPS})
if(WITH_NCCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
......@@ -24,39 +24,43 @@ if(WITH_GLOO)
endif()
if(WITH_XPU_BKCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper nccl_common)
op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
op_library(gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
endif()
if(WITH_ASCEND_CL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
cc_library(gen_hccl_id_op_helper SRCS gen_hccl_id_op_helper.cc DEPS dynload_warpctc dynamic_loader scope)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper gen_hccl_id_op_helper)
op_library(c_gen_hccl_id_op DEPS ${COLLECTIVE_DEPS})
op_library(gen_hccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
if(WITH_ASCEND_CL)
set(COMMON_TEST_DEPS_FOR_HCOM c_comm_init_hcom_op op_registry ascend_hccl flags
set(COMMON_TEST_DEPS_FOR_HCOM c_comm_init_hccl_op c_gen_hccl_id_op gen_hccl_id_op_helper
gen_hccl_id_op op_registry ascend_hccl flags
dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(c_broadcast_op_npu_test SRCS c_broadcast_op_npu_test.cc
DEPS c_broadcast_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_allreduce_sum_op_npu_test SRCS c_allreduce_sum_op_npu_test.cc
DEPS c_allreduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc
DEPS c_allreduce_max_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_reduce_sum_op_npu_test SRCS c_reduce_sum_op_npu_test.cc
DEPS c_reduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_reducescatter_op_npu_test SRCS c_reducescatter_op_npu_test.cc
DEPS c_reducescatter_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_allgather_op_npu_test SRCS c_allgather_op_npu_test.cc
DEPS c_allgather_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_reduce_sum_op_npu_test SRCS c_reduce_sum_op_npu_test.cc
DEPS c_reduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc
DEPS c_allreduce_max_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(send_v2_op_npu_test SRCS send_v2_op_npu_test.cc
DEPS send_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(recv_v2_op_npu_test SRCS recv_v2_op_npu_test.cc
DEPS recv_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
cc_test(c_sync_comm_stream_op_npu_test SRCS c_sync_comm_stream_op_npu_test.cc
DEPS op_registry c_broadcast_op c_comm_init_hcom_op c_sync_comm_stream_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
DEPS op_registry c_broadcast_op c_comm_init_hccl_op c_sync_comm_stream_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
cc_test(c_sync_calc_stream_op_npu_test SRCS c_sync_calc_stream_op_npu_test.cc
DEPS op_registry elementwise_add_op c_sync_calc_stream_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
endif()
......@@ -31,20 +31,19 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_ASCEND_CL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
HcclDataType dtype = platform::ToHCCLDataType(in->type());
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);
int64_t send_numel = in->numel();
uint64_t send_numel = in->numel();
void *send_buff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
void *recv_buff = reinterpret_cast<void*>(out->data<T>());
......@@ -59,12 +58,11 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
VLOG(3) << "begin hccl allgather, parameter is: "
<< ", group is " << group
<< ", ring_id is " << ring_id
<< ", nranks is " << nranks
<< ", tag is " << tag;
<< ", nranks is " << nranks;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_gather(
tag.c_str(), send_buff, recv_buff, (u64)send_numel, dtype,
group.c_str(), (void*)stream));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllGather(
send_buff, recv_buff, send_numel, dtype,
comm->comm(), (void*)stream));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
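The allgather kernel illustrates the API shift in this PR: the old hcom_* entry points addressed a collective by tag and group name, while the HCCL 20.3 entry points take the communicator handle and stream directly. A condensed sketch of the new call shape, grounded in the hunk above (not a drop-in kernel; it assumes the in/out tensors, place, dtype, comm and stream already set up as in the kernel):

    // Output holds nranks copies of the input along dim 0.
    framework::DDim out_dims = in->dims();
    out_dims[0] *= comm->nranks();
    out->mutable_data<T>(out_dims, place);
    uint64_t send_count = static_cast<uint64_t>(in->numel());
    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllGather(
        reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),  // send buffer
        reinterpret_cast<void*>(out->data<T>()),                  // receive buffer (nranks * send_count elements)
        send_count, dtype, comm->comm(), reinterpret_cast<void*>(stream)));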
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -45,7 +46,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_allgather);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_allgather, NPU);
DECLARE_string(selected_npus);
......@@ -56,26 +58,68 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(",");
}
VLOG(2) << preStr << ":" << std::endl << debugstring;
VLOG(2) << preStr << ":" << std::endl <<debugstring;
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
......@@ -83,7 +127,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto x = scope->Var("X");
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init;
......@@ -102,7 +146,7 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate
......@@ -110,12 +154,12 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
// run
f::AttributeMap attrs;
attrs["tag"] = std::string("tagx");
attrs["ring_id"] = 0;
attrs["nranks"] = 2;
attrs["tag"]=std::string("tagx");
attrs["ring_id"]=0;
attrs["nranks"]=2;
auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"Data"}}},
{{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 10; i++) {
op->Run(*scope, place);
......@@ -139,11 +183,12 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
TEST(c_allgather, NPU) {
f::Scope scope;
HcclRootInfo hccl_id;
// only supports one device; if more than one device is available, the first is used by default
auto* ctx = p::DeviceContextPool::Instance().Get(
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx);
TestHCCLAllGatherOp(&scope, *ctx);
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
TestHCCLAllGatherOp(&scope, ctx);
}
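All of the NPU unit tests in this PR follow the same two-step bootstrap that PrepareUniqueId and Prepare implement above: run c_gen_hccl_id to obtain an HcclRootInfo, then run c_comm_init_hccl to build the communicator from it. A condensed single-rank sketch using the f::/p:: aliases from the tests (hedged: the endpoints, variable name, device id and rank values are illustrative; a real run needs a peer process with the matching rank and endpoint):

    f::Scope scope;
    p::NPUDeviceContext ctx(p::NPUPlace(0));
    scope.Var("HcclId");                                    // will receive the HcclRootInfo

    f::AttributeMap gen_attrs;                              // step 1: generate / exchange the root info
    gen_attrs["rank"] = 0;
    gen_attrs["endpoint"] = std::string("127.0.0.1:6175");
    gen_attrs["other_endpoints"] = std::vector<std::string>{"127.0.0.1:6177"};
    f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"HcclId"}}}, gen_attrs)
        ->Run(scope, ctx.GetPlace());

    f::AttributeMap init_attrs;                             // step 2: build the communicator from it
    init_attrs["ring_id"] = 0;
    init_attrs["rank_ids"] = 2;                             // world size
    init_attrs["rank"] = 0;
    init_attrs["device_id"] = 0;
    f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"HcclId"}}}, {}, init_attrs)
        ->Run(scope, ctx.GetPlace());
    ctx.Wait();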
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -45,7 +46,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_allreduce_max);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_allreduce_max, NPU);
DECLARE_string(selected_npus);
......@@ -59,23 +61,65 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
VLOG(2) << preStr << ":" << std::endl << debugstring;
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
......@@ -83,7 +127,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto x = scope->Var("X");
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init;
......@@ -102,7 +146,7 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate
......@@ -113,8 +157,8 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
attrs["tag"] = std::string("tagx");
attrs["ring_id"] = 0;
auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"Data"}}},
{{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 10; i++) {
op->Run(*scope, place);
......@@ -135,11 +179,12 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
TEST(c_allreduce_max, NPU) {
f::Scope scope;
HcclRootInfo hccl_id;
// only supports one device; if more than one device is available, the first is used by default
auto* ctx = p::DeviceContextPool::Instance().Get(
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx);
TestHCCLAllReduceOp(&scope, *ctx);
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
TestHCCLAllReduceOp(&scope, ctx);
}
......@@ -117,34 +117,18 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
// we need to pre-allocate 512 Bytes before the data
// and 512 Bytes after the data, so the hccl allreduce
// can work. This is a must according to Huawei peer.
#define PRE_MALLOC_SIZE_BYTES 512
auto in = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
auto place = ctx.GetPlace();
hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
HcclDataType dtype = platform::ToHCCLDataType(in->type());
int64_t numel = in->numel();
int64_t pre_tmp_size = PRE_MALLOC_SIZE_BYTES / sizeof(T);
int64_t tmp_numel = numel + pre_tmp_size * 2;
paddle::framework::LoDTensor tmp_in, tmp_out;
tmp_in.Resize({tmp_numel});
tmp_out.Resize({tmp_numel});
auto p_tmp_in = tmp_in.mutable_data<T>(place); // allocate
auto p_tmp_out = tmp_out.mutable_data<T>(place); // allocate
void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
void* recvbuff = reinterpret_cast<void*>(out->data<T>());
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
aclrtStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......@@ -154,33 +138,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
stream = comm->stream();
}
// we need to memset this memory firstly to avoid core by hccl
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
memory::Copy(npu_place, sendbuff,
npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
numel * sizeof(T),
stream);
hcclRedOp_t hccl_red_type = HCCL_REP_OP_SUM;
HcclReduceOp hccl_red_type = HCCL_REDUCE_SUM;
switch (red_type) {
case kRedSum:
hccl_red_type = HCCL_REP_OP_SUM;
hccl_red_type = HCCL_REDUCE_SUM;
break;
case kRedMax:
hccl_red_type = HCCL_REP_OP_MAX;
hccl_red_type = HCCL_REDUCE_MAX;
break;
case kRedMin:
hccl_red_type = HCCL_REP_OP_MIN;
hccl_red_type = HCCL_REDUCE_MIN;
break;
case kRedProd:
hccl_red_type = HCCL_REP_OP_PROD;
hccl_red_type = HCCL_REDUCE_PROD;
break;
default:
......@@ -192,16 +165,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
<< "input num: " << numel
<< "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type
<< ", group is: " << group
<< ", tag is " << tag;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_reduce(
tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));
<< ", group is: " << group;
memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
npu_place, recvbuff,
numel * sizeof(T),
stream);
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
sendbuff, recvbuff, numel, dtype, hccl_red_type, comm->comm(), (void*)stream));
out->Resize(in->dims());
#else
......
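The same kRed* to HCCL_REDUCE_* switch now appears in both the allreduce and reduce kernels. A small helper capturing the mapping once (a sketch only, not part of this PR; it assumes the ReduceType enum used by these kernels and the HcclReduceOp values from <hccl/hccl_types.h>):

    inline HcclReduceOp ToHcclReduceOp(ReduceType red_type) {
      switch (red_type) {
        case kRedSum:  return HCCL_REDUCE_SUM;
        case kRedMax:  return HCCL_REDUCE_MAX;
        case kRedMin:  return HCCL_REDUCE_MIN;
        case kRedProd: return HCCL_REDUCE_PROD;
        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Invalid reduce type: %d", red_type));
      }
    }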
......@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -42,7 +43,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_allreduce_sum);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU);
DECLARE_string(selected_npus);
......@@ -56,23 +58,65 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
VLOG(3) << preStr << ":" << std::endl << debugstring;
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
......@@ -81,7 +125,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
int iter) {
// init
auto x = scope->Var("X");
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
int rank_id = atoi(getenv("RANK_ID"));
......@@ -100,7 +144,7 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
tensor_x->Resize({num1, num2});
ctx.Wait();
auto out = scope->Var("Out");
auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate
......@@ -111,8 +155,10 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
attrs["tag"] = std::string("tagx_" + std::to_string(iter));
attrs["ring_id"] = 0;
auto op = f::OpRegistry::CreateOp("c_allreduce_sum", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("c_allreduce_sum",
{{"X", {"Data"}}},
{{"Out", {"OutData"}}},
attrs);
for (int i = 0; i < 10; i++) {
op->Run(*scope, place);
......@@ -133,14 +179,17 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
TEST(c_allreduce_sum, NPU) {
f::Scope scope;
HcclRootInfo hccl_id;
// only supports one device; if more than one device is available, the first is used by default
auto* ctx = p::DeviceContextPool::Instance().Get(
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
// auto* ctx = p::DeviceContextPool::Instance().Get(
// p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx);
for (int i = 0; i < 1; i++) {
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
for(int i = 0; i < 1; i ++){
VLOG(2) << "iter num: " << i;
TestHCCLAllReduceOp(&scope, *ctx, i);
TestHCCLAllReduceOp(&scope, ctx, i);
}
}
......@@ -30,7 +30,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
auto x = ctx.Input<framework::LoDTensor>("X");
void *ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
int numel = x->numel();
hcclDataType_t dtype = platform::ToHCCLDataType(x->type());
HcclDataType dtype = platform::ToHCCLDataType(x->type());
auto out = ctx.Output<framework::LoDTensor>("Out");
......@@ -48,14 +48,12 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
int root = ctx.Attr<int>("root");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root
<< ", group is " << group
<< ", tag is " << tag;
<< ", group is " << group << ", comm: " << comm->comm() << ", stream: " << stream;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_broadcast(tag.c_str(), ptr, numel,
dtype, (uint32_t)root, group.c_str(), (void*)stream));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(ptr, numel,
dtype, (uint32_t)root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
<< framework::product(out->dims());
......
......@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -42,7 +43,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_broadcast);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
DECLARE_string(selected_npus);
......@@ -53,26 +55,68 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
for (auto ele : data) {
debugstring += std::to_string(ele) + std::string(",");
}
VLOG(2) << preStr << ":" << std::endl << debugstring;
VLOG(2) << preStr << ":" << std::endl <<debugstring;
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
......@@ -80,7 +124,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto x = scope->Var("X");
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
int num = 2;
std::vector<float> init;
......@@ -96,7 +140,7 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num, num});
tensor_out->mutable_data<float>(place); // allocate
......@@ -108,8 +152,8 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
attrs["root"] = 0;
attrs["ring_id"] = 0;
auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"Data"}}},
{{"Out", {"OutData"}}}, attrs);
for (int i = 0; i < 10; i++) {
op->Run(*scope, place);
......@@ -129,11 +173,11 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
TEST(c_broadcast, NPU) {
f::Scope scope;
HcclRootInfo hccl_id;
// only supports one device; if more than one device is available, the first is used by default
auto* ctx = p::DeviceContextPool::Instance().Get(
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx);
TestHCCLBroadcastOp(&scope, *ctx);
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
TestHCCLBroadcastOp(&scope, ctx);
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,66 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace paddle {
namespace operators {
class CCommInitOpNPU : public framework::OperatorBase {
class CCommInitOpAscend : public framework::OperatorBase {
public:
CCommInitOpNPU(const std::string& type,
const framework::VariableNameMap& inputs,
CCommInitOpAscend(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
int rid = Attr<int>("ring_id");
int nranks = Attr<int>("nranks");
PADDLE_ENFORCE_EQ(is_npu_place(place), true,
platform::errors::PreconditionNotMet(
"CCommInitOpAscend can run on npu place only."));
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input can not be empty."));
#if defined(PADDLE_WITH_ASCEND_CL)
HcclRootInfo* hccl_id = var->GetMutable<HcclRootInfo>();
int rank_ids = Attr<int>("rank_ids");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
int device_id = BOOST_GET_CONST(platform::NPUPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
std::vector<int> rank_ids = Attr<std::vector<int>>("rank_ids");
VLOG(3) << "begin c_comm_init on npu, parameters are: "
<< "ring id[" << rid
<< "], nranks[" << nranks
<< "], rank_id[" << rank_id
<< "], device_id[" << device_id
<< "]";
platform::HCCLCommContext::Instance().CreateHCCLComm(
rank_ids, rank_id, device_id, rid);
hccl_id, rank_ids, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU."));
#endif
}
};
class CCommInitOpNPUMaker : public framework::OpProtoAndCheckerMaker {
class CCommInitOpAscendMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CCommInit operator on NPU
CCommInit operator
Initialize collective communication context within this trainer
Initialize collective communication context within this trainer
)DOC");
AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
AddAttr<std::vector<int>>("rank_ids", "The world rank ids of the group");
AddAttr<int>("rank_ids", "(int) The number of ranks of distributed trainers");
AddAttr<int>("rank",
"(int) The rank of the trainer in distributed training.");
AddAttr<int>("device_id",
......@@ -89,6 +90,4 @@ Initialize collective communication context within this trainer
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_comm_init_hcom, ops::CCommInitOpNPU, ops::CCommInitOpNPUMaker);
#endif
REGISTER_OPERATOR(c_comm_init_hccl, ops::CCommInitOpAscend, ops::CCommInitOpAscendMaker);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#endif
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_ASCEND_CL
class CGenHCCLIdOp : public framework::OperatorBase {
public:
CGenHCCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};
if (rank == 0) {
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
SendBroadCastHCCLID(endpoint_list, 1, func, local_scope);
} else {
std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastHCCLID(endpoint, 1, func, local_scope);
}
scope.DeleteScope(&local_scope);
}
};
#else
class CGenHCCLIdOp : public framework::OperatorBase {
public:
CGenHCCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
}
};
#endif
class CGenHCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
VLOG(3) << "ele";
AddOutput("Out", "Raw variable contains a HCCL UniqueId instaces.");
AddComment(R"DOC(
CGenHCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once received, stop the server.
)DOC");
AddAttr<std::string>("endpoint",
"(string), e.g. 127.0.0.1:6175 "
"current listen endpoint");
AddAttr<std::vector<std::string>>(
"other_endpoints",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of other trainer endpoints")
.SetDefault({});
AddAttr<int>("rank",
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_gen_hccl_id, ops::CGenHCCLIdOp, ops::CGenHCCLIdOpMaker);
......@@ -63,6 +63,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
};
#else
class CGenNCCLIdOp : public framework::OperatorBase {
public:
CGenNCCLIdOp(const std::string& type,
......
......@@ -121,31 +121,15 @@ class CReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
// we need to pre-allocate 512 Bytes before the data
// and 512 Bytes after the data, so the hccl allreduce
// can work. This is a must according to Huawei peer.
#define PRE_MALLOC_SIZE_BYTES 512
auto in = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
auto place = ctx.GetPlace();
hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
HcclDataType dtype = platform::ToHCCLDataType(in->type());
int64_t numel = in->numel();
int64_t pre_tmp_size = PRE_MALLOC_SIZE_BYTES / sizeof(T);
int64_t tmp_numel = numel + pre_tmp_size * 2;
paddle::framework::LoDTensor tmp_in, tmp_out;
tmp_in.Resize({tmp_numel});
tmp_out.Resize({tmp_numel});
auto p_tmp_in = tmp_in.mutable_data<T>(place); // allocate
auto p_tmp_out = tmp_out.mutable_data<T>(place); // allocate
void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
void* recvbuff = reinterpret_cast<void*>(out->data<T>());
std::string tag = ctx.Attr<std::string>("tag");
int ring_id = ctx.Attr<int>("ring_id");
int root_id = ctx.Attr<int>("root_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
......@@ -161,33 +145,22 @@ class CReduceOpASCENDKernel : public framework::OpKernel<T> {
int rank_id = comm->rank();
// we need to memset this memory firstly to avoid core by hccl
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
memory::Copy(npu_place, sendbuff,
npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
numel * sizeof(T),
stream);
hcclRedOp_t hccl_red_type = HCCL_REP_OP_SUM;
HcclReduceOp hccl_red_type = HCCL_REDUCE_SUM;
switch (red_type) {
case kRedSum:
hccl_red_type = HCCL_REP_OP_SUM;
hccl_red_type = HCCL_REDUCE_SUM;
break;
case kRedMax:
hccl_red_type = HCCL_REP_OP_MAX;
hccl_red_type = HCCL_REDUCE_MAX;
break;
case kRedMin:
hccl_red_type = HCCL_REP_OP_MIN;
hccl_red_type = HCCL_REDUCE_MIN;
break;
case kRedProd:
hccl_red_type = HCCL_REP_OP_PROD;
hccl_red_type = HCCL_REDUCE_PROD;
break;
default:
......@@ -200,18 +173,14 @@ class CReduceOpASCENDKernel : public framework::OpKernel<T> {
<< "root_id: " << root_id
<< "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type
<< ", group is: " << group
<< ", tag is " << tag;
<< ", group is: " << group;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_reduce(
tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
sendbuff, recvbuff, numel, dtype, hccl_red_type, comm->comm(), (void*)stream));
if(rank_id == root_id){
memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
npu_place, recvbuff,
numel * sizeof(T),
stream);
}else{
if(rank_id != root_id){
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
numel * sizeof(T),
......
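A quick behavioral note on the rewritten reduce kernel above, inferred from this hunk: HcclAllReduce writes the reduced value into Out on every rank; the root rank keeps it, while non-root ranks copy X back into Out afterwards, preserving c_reduce semantics without the old staging buffers.

    // Worked example for 2 ranks, root_id = 0, X on rank 0 = {1, 2}, X on rank 1 = {3, 4}:
    //   after HcclAllReduce:  Out on both ranks = {4, 6}
    //   rank 1 != root, so it copies its X back:  Out on rank 1 = {3, 4}
    //   result: rank 0 holds the reduced tensor, rank 1's Out mirrors its input.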
......@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/c_reduce_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -42,7 +43,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_reduce_sum);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_reduce_sum, NPU);
DECLARE_string(selected_npus);
......@@ -56,23 +58,65 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
VLOG(3) << preStr << ":" << std::endl << debugstring;
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
......@@ -80,7 +124,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
// init
auto x = scope->Var("X");
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
int rank_id = atoi(getenv("RANK_ID"));
......@@ -99,7 +143,7 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
tensor_x->Resize({num1, num2});
ctx.Wait();
auto out = scope->Var("Out");
auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate
......@@ -112,8 +156,10 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
int root_id = 0;
attrs["root_id"] = root_id;
auto op = f::OpRegistry::CreateOp("c_reduce_sum", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("c_reduce_sum",
{{"X", {"Data"}}},
{{"Out", {"OutData"}}},
attrs);
op->Run(*scope, place);
ctx.Wait();
......@@ -136,14 +182,15 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
TEST(c_reduce_sum, NPU) {
f::Scope scope;
HcclRootInfo hccl_id;
// only supports one device; if more than one device is available, the first is used by default
auto* ctx = p::DeviceContextPool::Instance().Get(
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx);
for (int i = 0; i < 2; i++) {
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
for(int i = 0; i < 2; i ++){
VLOG(2) << "iter num: " << i;
TestHCCLReduceOp(&scope, *ctx, i);
TestHCCLReduceOp(&scope, ctx, i);
}
}
......@@ -35,7 +35,6 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
auto out_dims = in->dims();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
......@@ -47,11 +46,11 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
int64_t recv_numel = in->numel() / nranks;
uint64_t recv_numel = in->numel() / nranks;
void* inputPtr = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
void* outputPtr = reinterpret_cast<void*>(out->data<T>());
hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
HcclDataType dtype = platform::ToHCCLDataType(in->type());
aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
......@@ -63,12 +62,11 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
VLOG(3) << "begin hccl reduce scatter, parameter is: "
<< "recv_numel: " << recv_numel
<< "dtype: " << dtype
<< "hccl_red_type: " << HCCL_REP_OP_SUM
<< ", group is: " << group
<< ", tag is " << tag;
<< "hccl_red_type: " << HCCL_REDUCE_SUM
<< ", group is: " << group;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_reduce_scatter(
tag.c_str(), inputPtr, outputPtr, (u64)recv_numel, dtype, HCCL_REP_OP_SUM, group.c_str(), (void*)stream));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclReduceScatter(
inputPtr, outputPtr, recv_numel, dtype, HCCL_REDUCE_SUM, comm->comm(), (void*)stream));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU."));
......
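For reduce-scatter the element count passed to HCCL is the per-rank share, which is why the kernel above divides both dim 0 and numel by nranks before calling HcclReduceScatter.

    // Worked example (2 ranks): input shape {4, 3} -> 12 elements.
    //   out_dims[0] = 4 / 2 = 2, so each rank keeps a {2, 3} slice.
    //   recv_numel passed to HcclReduceScatter = 12 / 2 = 6 elements per rank.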
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -45,7 +46,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(c_reducescatter);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_reducescatter, NPU);
DECLARE_string(selected_npus);
......@@ -59,7 +61,8 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
VLOG(2) << preStr << ":" << std::endl << debugstring;
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
......@@ -68,22 +71,63 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
}
void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto x = scope->Var("X");
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init;
......@@ -101,7 +145,7 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate
......@@ -114,14 +158,14 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
attrs["ring_id"] = 0;
attrs["nranks"] = 2;
auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"Data"}}},
{{"Out", {"OutData"}}}, attrs);
int iter_num = 10;
for (int i = 0; i < iter_num; i++) {
op->Run(*scope, place);
}
ctx.Wait();
}
std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
......@@ -130,17 +174,18 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size() / 2);
for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], iter_num + 1);
EXPECT_EQ(out_vec[i], 2.0);
}
}
TEST(c_reducescatter, NPU) {
f::Scope scope;
HcclRootInfo hccl_id;
// only supports one device; if more than one device is available, the first is used by default
auto* ctx = p::DeviceContextPool::Instance().Get(
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx);
TestHCCLReduceScatterOp(&scope, *ctx);
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
TestHCCLReduceScatterOp(&scope, ctx);
}
......@@ -41,7 +41,7 @@ namespace m = paddle::operators::math;
USE_OP(c_broadcast);
USE_NO_KERNEL_OP(c_sync_comm_stream);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_ASCEND_CL
class GenHCCLIdOp : public framework::OperatorBase {
public:
GenHCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
std::vector<std::string> trainers =
Attr<std::vector<std::string>>("trainers");
int trainer_id = Attr<int>("trainer_id");
std::string endpoint = trainers[trainer_id];
PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument(
"trainer_id %d is less than 0. Its "
"valid range is [0, trainer_size)"));
PADDLE_ENFORCE_LT(
trainer_id, static_cast<int>(trainers.size()),
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
"range is [0, trainer_size)",
trainer_id));
int hccl_comm_num = Attr<int>("hccl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
int inter_trainer_id = -1;
int exter_trainer_id = -1;
if (use_hierarchical_allreduce) {
PADDLE_ENFORCE_GT(
trainers.size(), 1,
platform::errors::PreconditionNotMet(
"The number of collective trainers %llu <= 1", trainers.size()));
PADDLE_ENFORCE_GT(
inter_nranks, 1,
platform::errors::PreconditionNotMet(
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
inter_nranks));
PADDLE_ENFORCE_EQ(
trainers.size() % inter_nranks, 0,
platform::errors::PreconditionNotMet(
"The number of trainers %llu mod inter_nranks %d is not equal 0",
trainers.size(), inter_nranks));
inter_trainer_id = trainer_id % inter_nranks;
if (trainer_id % inter_nranks == 0) {
exter_trainer_id = trainer_id / inter_nranks;
}
}
std::ostringstream ss;
for (size_t i = 0; i < trainers.size(); i++) {
ss << trainers[i] << ",";
}
VLOG(1) << "trainer_id:" << trainer_id
<< ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
<< ", hccl_comm_num:" << hccl_comm_num
<< ", inter_nranks:" << inter_nranks
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< ", trainers:" << ss.str();
int server_fd = -1;
/// 1. init flat
std::function<std::string(size_t)> func = platform::GetFlatHCCLVarName;
if (trainer_id == 0) {
// server endpoints
std::vector<std::string> flat_endpoints;
flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
trainers.end());
SendBroadCastHCCLID(flat_endpoints, hccl_comm_num, func, scope);
} else {
server_fd = CreateListenSocket(endpoint);
RecvBroadCastHCCLID(server_fd, endpoint, hccl_comm_num, func, scope);
}
/// 2. hierarchical inter ncclid
func = platform::GetHierarchicalInterHCCLVarName;
if (inter_trainer_id == 0) {
std::ostringstream ss;
ss << endpoint;
std::vector<std::string> inter_endpoints;
for (int i = trainer_id + 1; i < trainer_id + inter_nranks &&
i < static_cast<int>(trainers.size());
i++) {
ss << ",";
inter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
SendBroadCastHCCLID(inter_endpoints, hccl_comm_num, func, scope);
} else if (inter_trainer_id > 0) {
VLOG(1) << "Hierarchical inter ring";
RecvBroadCastHCCLID(server_fd, endpoint, hccl_comm_num, func, scope);
}
/// 3. hierarchical exter ncclid
func = platform::GetHierarchicalExterHCCLVarName;
if (exter_trainer_id == 0) {
std::ostringstream ss;
std::vector<std::string> exter_endpoints;
ss << endpoint;
for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) {
ss << ",";
exter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();
SendBroadCastHCCLID(exter_endpoints, hccl_comm_num, func, scope);
} else if (exter_trainer_id > 0) {
VLOG(1) << "Hierarchical exter ring";
RecvBroadCastHCCLID(server_fd, endpoint, hccl_comm_num, func, scope);
}
// close socket server
if (trainer_id != 0) {
CloseSocket(server_fd);
}
}
};
#else
class GenHCCLIdOp : public framework::OperatorBase {
public:
GenHCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
}
};
#endif
class GenHCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("HCCLID", "Raw variable contains a HCCL UniqueId instaces.");
AddComment(R"DOC(
GenHCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainers 1~n: start a TCP server to receive the UniqueId from trainer 0, then stop the server once it is received.
)DOC");
AddAttr<std::vector<std::string>>(
"trainers",
"['trainer0_ip:port', 'trainer1_ip:port', ...] "
"list of all trainer endpoints")
.SetDefault({});
AddAttr<int>("trainer_id",
"(int) "
"The index of the trainer in distributed training.");
AddAttr<int>("hccl_comm_num",
"(int default 1) "
"The number of nccl communicator num.")
.SetDefault(1);
AddAttr<bool>("use_hierarchical_allreduce",
"(bool default false) "
"Wheter to use hierarchical allreduce.")
.SetDefault(false);
AddAttr<int>("hierarchical_allreduce_inter_nranks",
"(int default 1) "
"Wheter to use hierarchical allreduce.")
.SetDefault(-1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gen_hccl_id, ops::GenHCCLIdOp, ops::GenHCCLIdOpMaker);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <algorithm>
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/split.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif
namespace paddle {
namespace operators {
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
#define HCCL_UNIQUE_ID_BYTES 1024
// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
do { \
int retval; \
CHECK_SYS_CALL_VAL(call, name, retval); \
} while (false)
#define CHECK_SYS_CALL_VAL(call, name, retval) \
do { \
RETRY_SYS_CALL_VAL(call, name, retval); \
if (retval == -1) { \
PADDLE_THROW(platform::errors::Unavailable("Call to %s failed: %s", \
name, strerror(errno))); \
} \
} while (false)
#define RETRY_SYS_CALL_VAL(call, name, retval) \
do { \
retval = (call); \
if (retval == -1 && \
(errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
LOG(WARNING) << "Call " << name << " returned " << strerror(errno) \
<< " retry"; \
} else { \
break; \
} \
} while (true)
static int SocketSend(int fd, const char* buffer, int size) {
int offset = 0;
int bytes = 0;
while (offset < size) {
bytes = send(fd, buffer + offset, size - offset, 0);
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
// send failed
return -1;
} else {
bytes = 0;
}
}
offset += bytes;
}
return offset;
}
static int SocketRecv(int fd, char* buffer, int size) {
int offset = 0;
int bytes = 0;
while (offset < size) {
bytes = recv(fd, buffer + offset, size - offset, 0);
if (bytes == 0) {
// closed by client, maybe probing alive client
return 0;
}
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
return -1;
} else {
bytes = 0;
}
}
offset += bytes;
}
return offset;
}
static void BindOrConnectFailed(int timeout, int* try_times, int* total_time,
const char* op, const std::string& ep) {
PADDLE_ENFORCE_LT(
*total_time, timeout,
platform::errors::Unavailable("%s addr=%s timeout, failed reason: %s", op,
ep.c_str(), strerror(errno)));
++(*try_times);
int retry_time = std::min(*try_times * 500, 3000); // max 3 seconds
*total_time += retry_time;
LOG(WARNING) << op << " addr=" << ep << " failed " << *try_times
<< " times with reason: " << strerror(errno) << " retry after "
<< retry_time / 1000.0 << " seconds";
std::this_thread::sleep_for(std::chrono::milliseconds(retry_time));
}
int CreateListenSocket(const std::string& ep) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
// creating socket fd
int server_fd = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", server_fd);
// NOTE. Solutions to `Address already in use`.
// 1. Reuse addr&port. Otherwise, once the server closes the socket
// before client, the server will enter TIME-WAIT status. If we bind port
// again, the error `Address already in use` will appear.
// 2. Or we can close the client first to ensure that the server does
// not enter the TIME-WAIT state. But this is obviously not as convenient
// as the reuse method.
int opt = 1;
#if defined(SO_REUSEPORT)
// since Linux kernel 3.9
CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT,
&opt, sizeof(opt)),
"setsockopt");
#else
CHECK_SYS_CALL(
setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)),
"setsockopt");
#endif
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(port);
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
bind(server_fd, (struct sockaddr*)&address, sizeof(address)), "bind",
ret_val);
if (ret_val == -1) {
BindOrConnectFailed(timeout, &try_times, &total_time, "bind", ep);
continue;
}
break;
}
CHECK_SYS_CALL(listen(server_fd, 3), "listen");
LOG(INFO) << "Server listening on: " << ep << " successful.";
return server_fd;
}
void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); }
static int SocketAccept(int server_fd, const char* head) {
struct sockaddr_in client_addr;
socklen_t addr_length = sizeof(client_addr);
char buffer[1024] = {0};
int conn = -1;
while (true) {
CHECK_SYS_CALL_VAL(
accept(server_fd, reinterpret_cast<struct sockaddr*>(&client_addr),
&addr_length),
"accept", conn);
int ret_val = SocketRecv(conn, buffer, strlen(head));
if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) {
break; // accept client
} else {
VLOG(3) << "socket read failed with ret_val=" << ret_val;
CloseSocket(conn);
}
}
return conn;
}
static int ConnectAddr(const std::string& ep, const char* head) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
int sock = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
char* ip = NULL;
struct hostent* hp = NULL;
hp = gethostbyname(host.c_str());
PADDLE_ENFORCE_NOT_NULL(hp, platform::errors::InvalidArgument(
"Fail to get host by name %s.", host));
int i = 0;
while (hp->h_addr_list[i] != NULL) {
ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip;
break;
}
PADDLE_ENFORCE_GT(inet_pton(AF_INET, ip, &server_addr.sin_addr), 0,
platform::errors::Unavailable("Open address %s failed: %s",
ep, strerror(errno)));
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)),
"connect", ret_val);
if (ret_val == -1) {
BindOrConnectFailed(timeout, &try_times, &total_time, "connect", ep);
continue;
}
CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send");
break;
}
return sock;
}
static void RecvHCCLID(int conn, HcclRootInfo* hccl_id) {
char buffer[1024] = {0};
static_assert(HCCL_UNIQUE_ID_BYTES <= 1024,
"hccl id bytes must <= buffer size");
CHECK_SYS_CALL(SocketRecv(conn, buffer, HCCL_UNIQUE_ID_BYTES), "recv hccl id");
memcpy(hccl_id, buffer, HCCL_UNIQUE_ID_BYTES);
}
static void SendHCCLID(int conn, HcclRootInfo* hccl_id) {
char buffer[1024] = {0};
memcpy(buffer, hccl_id, HCCL_UNIQUE_ID_BYTES);
CHECK_SYS_CALL(SocketSend(conn, buffer, HCCL_UNIQUE_ID_BYTES),
"send hccl id");
}
void SendBroadCastHCCLID(std::vector<std::string> servers, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
// connect with server
std::vector<int> connects;
for (auto server : servers) {
VLOG(3) << "connecting endpoint: " << server;
int conn = ConnectAddr(server, COMM_HEAD);
connects.push_back(conn);
}
VLOG(3) << "connecting completed...";
for (int i = 0; i < hccl_comm_num; ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto hccl_id = var->GetMutable<HcclRootInfo>();
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclGetRootInfo(hccl_id));
int j = 0;
for (auto conn : connects) {
VLOG(3) << "sending hccl_id_var: " << var_name << " to " << servers[j]
<< " hccl_comm_no: " << i;
SendHCCLID(conn, hccl_id);
++j;
}
VLOG(3) << "sending completed...";
}
// close client
for (auto conn : connects) {
CloseSocket(conn);
}
}
void RecvBroadCastHCCLID(std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
int server = CreateListenSocket(endpoint);
RecvBroadCastHCCLID(server, endpoint, hccl_comm_num, func, scope);
CloseSocket(server);
}
void RecvBroadCastHCCLID(int server_fd, std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
int client = SocketAccept(server_fd, COMM_HEAD);
for (int i = 0; i < hccl_comm_num; ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto hccl_id = var->GetMutable<HcclRootInfo>();
VLOG(3) << "trainer: " << endpoint << " receiving hccl_id_var: " << var_name
<< " from trainer 0, hccl_comm_no: " << i;
RecvHCCLID(client, hccl_id);
}
VLOG(3) << "receiving completed...";
CloseSocket(client);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <string>
#include <vector>
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
int CreateListenSocket(const std::string& ep);
void CloseSocket(int fd);
void SendBroadCastHCCLID(std::vector<std::string> servers, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
// server listens on the endpoint, then receives the hccl id
void RecvBroadCastHCCLID(std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
// receive the hccl id from an accepted socket
void RecvBroadCastHCCLID(int server_fd, std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
} // namespace operators
} // namespace paddle
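// Usage sketch (illustrative only; the variable names below are hypothetical,
// mirroring GenHCCLIdOp::RunImpl above): trainer 0 pushes the freshly
// generated HcclRootInfo variables to every peer, while each other trainer
// listens on its own endpoint and fills the same scope variables.
//
//   auto func = platform::GetFlatHCCLVarName;
//   if (trainer_id == 0) {
//     SendBroadCastHCCLID(other_endpoints, hccl_comm_num, func, scope);
//   } else {
//     int fd = CreateListenSocket(my_endpoint);
//     RecvBroadCastHCCLID(fd, my_endpoint, hccl_comm_num, func, scope);
//     CloseSocket(fd);
//   }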
......@@ -27,32 +27,39 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = out->numel();
hcclDataType_t dtype = platform::ToHCCLDataType(out->type());
auto x = ctx.Output<framework::LoDTensor>("Out");
void *ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
int numel = x->numel();
HcclDataType dtype = platform::ToHCCLDataType(x->type());
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
if (ctx.Attr<bool>("use_calc_stream")) {
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int srcRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
VLOG(3) << "recv_v2_npu attr get";
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_receive(
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(out->data<T>())), (u64)numel, dtype, srcRank,
srTag, group.c_str(), stream));
VLOG(3) << "Source Rank: " << srcRank << " Invoke hcom receive. receiving ";
out->Resize(out->dims());
out->set_lod(out->lod());
int nranks = comm->nranks();
int peer = ctx.Attr<int>("peer");
PADDLE_ENFORCE_EQ(nranks, 2,
platform::errors::InvalidArgument(
"The nranks must be 2, but (%d)",
nranks));
int root = peer;
VLOG(3) << "begin hccl recv, parameter is: "<< "root " << root
<< ", comm: " << comm->comm() << ", stream: " << stream;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(ptr, numel,
dtype, (uint32_t)root, comm->comm(), stream));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU."));
......
......@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/recv_v2_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -42,45 +44,86 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(recv_v2);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(recv_v2, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
std::string rank_table_file = getenv("RANK_TABLE_FILE");
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
int src_rank = atoi(getenv("SRC_RANK"));
int dest_rank = atoi(getenv("DEST_RANK"));
VLOG(3) << "rank_id " << rank_id << "src_rank" << src_rank << "dest_rank"
<< dest_rank;
std::vector<int> rank_ids = {0, 1};
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
  std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
  std::vector<std::string> other_endpoints = {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
VLOG(3) << "CreateOp c_comm_init_hcom";
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
}
void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx) {
void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){
std::cout << "BEGIN TEST:" << __FUNCTION__ << std::endl;
int num = atoi(getenv("DATA_SIZE"));
EXPECT_GT(num, 0);
EXPECT_LT(num, 1 << 15);
int rank_id = atoi(getenv("RANK_ID"));
VLOG(3) << "rank_id:" << rank_id << std::endl;
VLOG(3) << "rank_id:" << rank_id<<std::endl;
ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto out = scope->Var("Data");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num, num});
tensor_out->mutable_data<float>(place); // allocate
......@@ -88,37 +131,39 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait();
f::AttributeMap attrs;
attrs["tag"] = std::string("srtest");
attrs["peer"] = atoi(getenv("SRC_RANK"));
attrs["ring_id"] = 0;
attrs["srTag"] = 0;
attrs["tag"]=std::string("srtest");
attrs["peer"]=atoi(getenv("SRC_RANK"));
attrs["ring_id"]=0;
attrs["srTag"]=0;
std::vector<int> out_shape;
out_shape.push_back(num);
out_shape.push_back(num);
attrs["out_shape"] = out_shape;
attrs["out_shape"]=out_shape;
auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Data"}}}, attrs);
VLOG(3) << "CreateOp recv_v2";
for (int i = 0; i < 10; i++) {
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
VLOG(3) << "Run op recv_v2";
std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait();
std::vector<float> init(num * num, 1.0 * atoi(getenv("DEST_RANK")));
std::vector<float> init(num*num, 1.0 * atoi(getenv("DEST_RANK")));
EXPECT_EQ(out_vec == init, true);
}
TEST(recv_v2, NPU) {
TEST(recv_v2, NPU){
f::Scope scope;
char* npu_id = getenv("FLAGS_selected_npus");
HcclRootInfo hccl_id;
char * npu_id=getenv("FLAGS_selected_npus");
VLOG(3) << "Select npu:" << npu_id;
auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(atoi(npu_id)));
VLOG(3) << "Place over";
Prepare(&scope, *ctx);
VLOG(3) << "Prepare over";
TestHcomRecvOp(&scope, *ctx);
VLOG(3) << "Test over";
p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
TestHcomRecvOp(&scope, ctx);
}
......@@ -28,31 +28,37 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto x = ctx.Input<framework::LoDTensor>("X");
void *ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
int numel = x->numel();
hcclDataType_t dtype = platform::ToHCCLDataType(x->type());
HcclDataType dtype = platform::ToHCCLDataType(x->type());
auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id");
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
auto place = ctx.GetPlace();
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
if (ctx.Attr<bool>("use_calc_stream")) {
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int destRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_send(
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), (u64)numel, dtype, destRank,
srTag, group.c_str(), stream));
int nranks = comm->nranks();
int rank = comm->rank();
PADDLE_ENFORCE_EQ(nranks, 2,
platform::errors::InvalidArgument(
"The nranks must be 2, but (%d)",
nranks));
int root = rank;
VLOG(3) << "begin hccl send, parameter is: "<< "root " << root
<< ", comm: " << comm->comm() << ", stream: " << stream;
VLOG(3) << "Dest rank:" << destRank << " Invoke hcom send. Sent "
<< x->numel();
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(ptr, numel,
dtype, (uint32_t)root, comm->comm(), stream));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/send_v2_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
......@@ -41,43 +42,85 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(send_v2);
USE_NO_KERNEL_OP(c_comm_init_hcom);
USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(send_v2, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
std::string rank_table_file = getenv("RANK_TABLE_FILE");
void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
int src_rank = atoi(getenv("SRC_RANK"));
int dest_rank = atoi(getenv("DEST_RANK"));
VLOG(3) << "rank_id " << rank_id << "src_rank" << src_rank << "dest_rank"
<< dest_rank;
std::vector<int> rank_ids = {0, 1};
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
  std::vector<std::string> endpointList = {"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
  std::vector<std::string> other_endpoints = {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2;
comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids;
// comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
}
void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx) {
std::cout << "BEGIN TEST:" << __FUNCTION__ << std::endl;
auto x = scope->Var("X");
void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){
std::cout<< "BEGIN TEST:"<< __FUNCTION__ <<std::endl;
auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>();
int num = atoi(getenv("DATA_SIZE"));
  int num = atoi(getenv("DATA_SIZE"));
EXPECT_GT(num, 0);
EXPECT_LT(num, 1 << 15);
std::vector<float> init(num * num, 1.0 * atoi(getenv("DEST_RANK")));
std::vector<float> init(num*num, 1.0 * atoi(getenv("DEST_RANK")));
int rank_id = atoi(getenv("RANK_ID"));
VLOG(3) << "rank id:" << rank_id;
VLOG(3)<<"rank id:"<<rank_id;
TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num, num});
ctx.Wait();
......@@ -85,28 +128,29 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait();
f::AttributeMap attrs;
attrs["tag"] = std::string("srtest");
attrs["peer"] = atoi(getenv("DEST_RANK"));
attrs["ring_id"] = 0;
attrs["srTag"] = 0;
attrs["tag"]=std::string("srtest");
attrs["peer"]=atoi(getenv("DEST_RANK"));
attrs["ring_id"]=0;
attrs["srTag"]=0;
auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs);
auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"Data"}}}, {}, attrs);
for (int i = 0; i < 10; i++) {
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
VLOG(3) << "send run over";
VLOG(3)<<"send run over";
ctx.Wait();
}
TEST(send_v2, NPU) {
TEST(send_v2, NPU){
f::Scope scope;
char* npu_id = getenv("FLAGS_selected_npus");
HcclRootInfo hccl_id;
char * npu_id=getenv("FLAGS_selected_npus");
VLOG(3) << "Select npu:" << npu_id;
auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(atoi(npu_id)));
VLOG(3) << "Place over";
Prepare(&scope, *ctx);
VLOG(3) << "Prepare over";
TestHcomSendOp(&scope, *ctx);
VLOG(3) << "Test over";
p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
PrepareUniqueId(&scope, ctx, &hccl_id);
Prepare(&scope, ctx, &hccl_id);
TestHcomSendOp(&scope, ctx);
}
......@@ -157,15 +157,10 @@ class HCCLComm {
virtual int nranks() const = 0;
virtual int rank() const = 0;
virtual int device_id() const = 0;
virtual HcclComm comm() const = 0;
virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default;
unsigned long NextTagId() {
return tag_counter_++;
}
private:
std::atomic<unsigned long> tag_counter_;
};
// A singleton HCCL communicator context reserves communication ring ids
......@@ -176,11 +171,12 @@ class HCCLCommContext {
return comm_ctx;
}
HCCLComm* CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id = 0);
HCCLComm* CreateHCCLComm(HcclRootInfo* hccl_id, int nranks,
int rank, int dev_id, int ring_id);
// a latter comm with the same dev_id and the same ring_id
// will override the former
HCCLComm* AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id = 0);
HCCLComm* AssignHCCLComm(HcclComm comm, int nranks, int rank,
int dev_id, int ring_id);
// retrieve a communicator by the ring id in multiprocessing mode
HCCLComm* Get(int ring_id) const {
......@@ -217,20 +213,21 @@ class HCCLCommContext {
private:
// Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); }
HCCLCommContext() {}
  // we may use the group feature in the future
// HCCLCommContext() { InitHcomWorldGroup(); }
HcclComm comm_;
public:
~HCCLCommContext(){
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy());
}
~HCCLCommContext(){ }
std::once_flag once_flag_;
std::mutex comm_map_mutex_;
// ring id to dev-HCCLComm
std::map<int, std::map<int, std::unique_ptr<HCCLComm>>> comm_map_;
void InitHcomWorldGroup();
// void InitHcomWorldGroup();
void ReleaseHCCLComms();
DISABLE_COPY_AND_ASSIGN(HCCLCommContext);
......
......@@ -34,6 +34,13 @@ class HCCLCommImpl : public HCCLComm {
return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device;
}
~HCCLCommImpl(){
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclCommDestroy(comm_));
}
void set_comm(HcclComm comm) { comm_ = comm; }
HcclComm comm() const override { return comm_; }
aclrtStream stream() const override { return dev_ctx_->stream(); }
void set_dev_ctx(std::unique_ptr<NPUDeviceContext>&& dev_ctx) {
......@@ -45,46 +52,43 @@ class HCCLCommImpl : public HCCLComm {
int ring_id_;
int nranks_;
int rank_;
HcclComm comm_;
std::unique_ptr<NPUDeviceContext> dev_ctx_;
};
HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id) {
HCCLComm* HCCLCommContext::CreateHCCLComm(HcclRootInfo* hccl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(hccl_id,
platform::errors::InvalidArgument(
"The hccl unique id should not be null."));
PADDLE_ENFORCE_GT(
world_rank_ids.size(), 1,
nranks, 1,
platform::errors::InvalidArgument(
"Expected world_rank_ids.size() > 1. But received size is %d.", world_rank_ids.size()));
"Expected nranks > 1. But received nranks is %d.", nranks));
PADDLE_ENFORCE_GE(rank, 0,
platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_LT(
rank, world_rank_ids.size(),
rank, nranks,
platform::errors::InvalidArgument(
"Expected rank < nranks. But received rank is %d, nranks is %d.",
rank, world_rank_ids.size()));
rank, nranks));
PADDLE_ENFORCE_GE(
dev_id, 0,
platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", dev_id));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"Expected ring_id >= 0. But received ring_id is %d.", ring_id));
auto* comm_wrapper = AssignHCCLComm(world_rank_ids.size(), rank, dev_id, ring_id);
// HACK(sunpeng17): hcom API requires bind stream to a model
// but we don't need model in Paddle, so we feed stream pointer as model pointer
HcclComm comm;
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSetDevice(dev_id));
PADDLE_ENFORCE_NPU_SUCCESS(
platform::dynload::hcom_bind_model(comm_wrapper->stream(),
comm_wrapper->stream()));
platform::dynload::HcclCommInitRootInfo(nranks, hccl_id, rank, &comm));
// Get world_rank_ids registered in gen_nccl_id op
std::string group_name = HCOM_GROUP_PREFIX + std::to_string(ring_id);
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_create_group(
group_name.c_str(), world_rank_ids.size(), (unsigned int*)world_rank_ids.data()));
VLOG(1) << "initialized comm: " << &comm << ", nranks: " << nranks << ", hccl_id: " << hccl_id << ", rank: " << rank;
auto* comm_wrapper = AssignHCCLComm(comm, nranks, rank, dev_id, ring_id);
VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id << ", group name: " << group_name;
<< " has been created on device " << dev_id << ", with comm: " << comm_wrapper->comm();
std::call_once(once_flag_, []() {
std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); });
......@@ -93,7 +97,8 @@ HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids
return comm_wrapper;
}
HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id) {
HCCLComm* HCCLCommContext::AssignHCCLComm(HcclComm comm, int nranks, int rank,
int dev_id, int ring_id) {
std::unique_ptr<NPUDeviceContext> dev_ctx(
new NPUDeviceContext(NPUPlace(dev_id)));
......@@ -101,6 +106,7 @@ HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
comm_map_mutex_.lock();
......@@ -112,23 +118,14 @@ HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int
dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c));
comm_map_mutex_.unlock();
return comm_map_[ring_id][dev_id].get();
}
void HCCLCommContext::InitHcomWorldGroup() {
const char *rank_table_file = getenv(ENV_RANK_TABLE_FILE);
PADDLE_ENFORCE_NOT_NULL(
rank_table_file,
platform::errors::InvalidArgument("The RANK_TABLE_FILE environment variable should not be null."));
const char *rank_id = getenv(ENV_RANK_ID);
PADDLE_ENFORCE_NOT_NULL(
rank_id,
platform::errors::InvalidArgument("The RANK_ID environment variable should not be null."));
if (ring_id == 0) {
auto* dev_ctx = static_cast<platform::NPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::NPUPlace(dev_id)));
dev_ctx->set_hccl_comm(comm);
}
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_init(rank_table_file, rank_id));
VLOG(3) << "Successfully initialized hcom. rank_table_file: "
<< rank_table_file << ", rank_id " << rank_id;
return comm_map_[ring_id][dev_id].get();
}
void HCCLCommContext::ReleaseHCCLComms() {
......
......@@ -185,11 +185,21 @@ class NPUDeviceContext : public DeviceContext {
void WaitStreamCallback() const { return stream_->WaitCallback(); }
#if defined(PADDLE_WITH_ASCEND_CL)
/*! \brief Return hccl communicators. */
HcclComm hccl_comm() const { return hccl_comm_; }
/*! \brief Set hccl communicators. */
void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }
#endif
private:
NPUPlace place_;
aclrtContext context_;
#ifdef PADDLE_WITH_ASCEND_HCCL
HCCLContext_t hccl_context_;
#ifdef PADDLE_WITH_ASCEND_CL
// HCCLContext_t hccl_context_;
HcclComm hccl_comm_{nullptr};
#endif
// Need to be the same with other DeviceContext,
......
......@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <hccl/hccl.h>
// #include <hccl/hccl_types.h>
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/dynload/hcom.h"
// #include "paddle/fluid/platform/dynload/hcom.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#define HCOM_GROUP_PREFIX "HCOM_GROUP_"
namespace paddle {
namespace platform {
namespace dynload {
......@@ -43,27 +45,14 @@ extern void* hccl_dso_handle;
extern DynLoad__##__name __name
#define HCCL_RAND_ROUTINE_EACH(__macro) \
__macro(hcom_init); \
__macro(hcom_destroy); \
__macro(hcom_bind_model); \
__macro(hcom_unbind_model); \
__macro(hcom_send); \
__macro(hcom_receive); \
__macro(hcom_broadcast); \
__macro(hcom_all_gather); \
__macro(hcom_all_reduce); \
__macro(hcom_reduce_scatter); \
__macro(hcom_create_group); \
__macro(hcom_destroy_group); \
__macro(hcom_get_rank_id); \
__macro(hcom_get_local_rank_id); \
__macro(hcom_get_local_rank_size); \
__macro(hcom_get_split_strategy); \
__macro(hcom_set_split_strategy_by_size); \
__macro(hcom_set_split_strategy_by_index); \
__macro(hcom_get_group_rank_from_world_rank); \
__macro(hcom_get_world_rank_from_group_rank);
__macro(HcclReduceScatter); \
__macro(HcclCommDestroy); \
__macro(HcclAllReduce); \
__macro(HcclCommInitRootInfo); \
__macro(HcclGetRootInfo); \
__macro(HcclBroadcast); \
__macro(HcclCommInitClusterInfo); \
__macro(HcclAllGather);
HCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_HCCL_WRAP)
......
......@@ -40,7 +40,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
#include "paddle/fluid/platform/dynload/hcom.h"
#include "hccl/hccl_types.h"
#endif // PADDLE_WITH_ASCEND_CL
#include <fstream>
......@@ -1013,7 +1013,7 @@ struct NPUStatusType {};
}
DEFINE_NPU_STATUS_TYPE(aclError, ACL_ERROR_NONE);
DEFINE_NPU_STATUS_TYPE(hcclResult_t, HCCL_SUCCESS);
DEFINE_NPU_STATUS_TYPE(HcclResult, HCCL_SUCCESS);
} // namespace details
inline std::string build_npu_error_msg(aclError stat) {
......@@ -1022,7 +1022,7 @@ inline std::string build_npu_error_msg(aclError stat) {
return sout.str();
}
inline std::string build_npu_error_msg(hcclResult_t stat) {
inline std::string build_npu_error_msg(HcclResult stat) {
std::ostringstream sout;
sout << " HCCL error, the error code is : " << stat << ". ";
return sout.str();
......
......@@ -14,7 +14,7 @@
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_ASCEND_CL)
#if defined(PADDLE_WITH_HCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_ASCEND_CL)
#include <stdio.h>
#include <memory>
......@@ -24,30 +24,22 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#ifdef PADDLE_WITH_NCCL
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/dynload/hccl.h"
#endif
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#define NCCL_ID_VARNAME "NCCLID"
#define HCCL_ID_VARNAME "HCCLID"
namespace paddle {
namespace platform {
inline hcclDataType_t ToHCCLDataType(framework::proto::VarType::Type type) {
inline HcclDataType ToHCCLDataType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP32) {
return HCCL_DATA_TYPE_FP32;
} else if (type == framework::proto::VarType::FP16) {
......@@ -66,298 +58,301 @@ inline hcclDataType_t ToHCCLDataType(framework::proto::VarType::Type type) {
}
}
// // NOTE(minqiyang): according to the ncclGroupEnd documentations:
// // https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// // ncclGroupEnd will wait for all communicators to be initialized, which will
// // cause blocking problem when a runtime_error was thrown, so try only guard
// // NCCL actions when use it.
// class NCCLGroupGuard {
// NOTE(minqiyang): according to the ncclGroupEnd documentations:
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd will wait for all communicators to be initialized, which will
// cause blocking problem when a runtime_error was thrown, so try only guard
// HCCL actions when use it.
// class HCCLGroupGuard {
// public:
// static std::mutex &NCCLMutex() {
// static std::mutex &HCCLMutex() {
// static std::mutex mtx;
// return mtx;
// }
// inline NCCLGroupGuard() {
// NCCLMutex().lock();
// inline HCCLGroupGuard() {
// HCCLMutex().lock();
// PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupStart());
// }
// inline ~NCCLGroupGuard() PADDLE_MAY_THROW {
// inline ~HCCLGroupGuard() PADDLE_MAY_THROW {
// PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd());
// NCCLMutex().unlock();
// HCCLMutex().unlock();
// }
// };
// struct NCCLContext {
// std::unique_ptr<CUDADeviceContext> ctx_;
// ncclComm_t comm_;
struct HCCLContext {
std::unique_ptr<NPUDeviceContext> ctx_;
HcclComm comm_;
// explicit NCCLContext(int dev_id)
// : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
explicit HCCLContext(int dev_id)
: ctx_(new NPUDeviceContext(NPUPlace(dev_id))), comm_{nullptr} {}
// gpuStream_t stream() const { return ctx_->stream(); }
// ncclComm_t comm() const { return comm_; }
// int device_id() const {
// return BOOST_GET_CONST(platform::CUDAPlace, ctx_->GetPlace()).device;
// }
// };
aclrtStream stream() const { return ctx_->stream(); }
HcclComm comm() const { return comm_; }
// struct NCCLContextMap {
// std::unordered_map<int, NCCLContext> contexts_;
// std::vector<int> order_;
// explicit NCCLContextMap(const std::vector<platform::Place> &places,
// ncclUniqueId *nccl_id = nullptr,
// size_t num_trainers = 1, size_t trainer_id = 0) {
// PADDLE_ENFORCE_EQ(!places.empty(), true,
// platform::errors::InvalidArgument(
// "The NCCL place should not be empty."));
// order_.reserve(places.size());
// for (auto &p : places) {
// int dev_id = BOOST_GET_CONST(CUDAPlace, p).device;
// order_.emplace_back(dev_id);
// contexts_.emplace(dev_id, NCCLContext(dev_id));
// }
// PADDLE_ENFORCE_EQ(
// order_.size(), contexts_.size(),
// platform::errors::Unavailable("NCCL Context Map does not support "
// "contain two or more same device."));
// std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// // if num_trainers == 1, should create a new nccl id for local comms.
// if (num_trainers == 1 && nccl_id == nullptr) {
// std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
// PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
// comms.get(), static_cast<int>(order_.size()), order_.data()));
// } else {
// PADDLE_ENFORCE_NOT_NULL(nccl_id, platform::errors::InvalidArgument(
// "The NCCL id should not be null."));
// {
// int nranks = num_trainers * order_.size();
// NCCLGroupGuard gurad;
// for (size_t i = 0; i < order_.size(); ++i) {
// int gpu_id = order_[i];
// int rank;
// if (order_.size() > 1) {
// rank = trainer_id * order_.size() + i;
// } else {
// rank = trainer_id;
// }
// VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
// << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
// SetDeviceId(gpu_id);
// PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
// comms.get() + i, nranks, *nccl_id, rank));
// }
// }
// }
// int i = 0;
// for (auto &dev_id : order_) {
// contexts_.at(dev_id).comm_ = comms[i++];
// }
// }
int device_id() const {
return BOOST_GET_CONST(platform::NPUPlace, ctx_->GetPlace()).device;
}
};
struct HCCLContextMap {
std::unordered_map<int, HCCLContext> contexts_;
std::vector<int> order_;
explicit HCCLContextMap(const std::vector<platform::Place> &places,
HcclRootInfo *hccl_id = nullptr,
size_t num_trainers = 1, size_t trainer_id = 0) {
PADDLE_ENFORCE_EQ(!places.empty(), true,
platform::errors::InvalidArgument(
"The HCCL place should not be empty."));
order_.reserve(places.size());
for (auto &p : places) {
int dev_id = BOOST_GET_CONST(NPUPlace, p).device;
order_.emplace_back(dev_id);
contexts_.emplace(dev_id, HCCLContext(dev_id));
}
PADDLE_ENFORCE_EQ(
order_.size(), contexts_.size(),
platform::errors::Unavailable("HCCL Context Map does not support "
"contain two or more same device."));
std::unique_ptr<HcclComm[]> comms(new HcclComm[order_.size()]);
    // if num_trainers == 1, we should create a new hccl id for local comms.
if (num_trainers == 1 && hccl_id == nullptr) {
// we do not know how to tackle this situation under hccl
// std::lock_guard<std::mutex> guard(HCCLGroupGuard::HCCLMutex());
// PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::ncclCommInitAll(
// comms.get(), static_cast<int>(order_.size()), order_.data()));
} else {
PADDLE_ENFORCE_NOT_NULL(hccl_id, platform::errors::InvalidArgument(
"The HCCL id should not be null."));
{
int nranks = num_trainers * order_.size();
// HCCLGroupGuard gurad;
for (size_t i = 0; i < order_.size(); ++i) {
int gpu_id = order_[i];
int rank;
if (order_.size() > 1) {
rank = trainer_id * order_.size() + i;
} else {
rank = trainer_id;
}
VLOG(1) << "init hccl rank:" << rank << ", nranks:" << nranks
<< ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
aclrtSetDevice(gpu_id);
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclCommInitRootInfo(
nranks, hccl_id, rank, comms.get() + i));
}
}
}
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
}
// NCCLContextMap(const NCCLContextMap &other) = delete;
// NCCLContextMap &operator=(const NCCLContextMap &other) = delete;
HCCLContextMap(const HCCLContextMap &other) = delete;
HCCLContextMap &operator=(const HCCLContextMap &other) = delete;
// CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
NPUDeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
// CUDADeviceContext *DevCtx(platform::Place p) const {
// return DevCtx(BOOST_GET_CONST(CUDAPlace, p).device);
// }
NPUDeviceContext *DevCtx(platform::Place p) const {
return DevCtx(BOOST_GET_CONST(NPUPlace, p).device);
}
// const NCCLContext &at(platform::Place p) const {
// return this->at(BOOST_GET_CONST(CUDAPlace, p).device);
// }
const HCCLContext &at(platform::Place p) const {
return this->at(BOOST_GET_CONST(NPUPlace, p).device);
}
// const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
const HCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
// void WaitAll() {
// for (auto &p : contexts_) {
// p.second.ctx_->Wait();
// }
// }
// };
void WaitAll() {
for (auto &p : contexts_) {
p.second.ctx_->Wait();
}
}
};
// inline std::string GetFlatNCCLVarName(size_t pos) {
// if (pos == 0) {
// return NCCL_ID_VARNAME;
// }
// return string::Sprintf("%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos));
// }
inline std::string GetFlatHCCLVarName(size_t pos) {
if (pos == 0) {
return HCCL_ID_VARNAME;
}
return string::Sprintf("%s_%d", HCCL_ID_VARNAME, static_cast<int>(pos));
}
// inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
// return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
// static_cast<int>(pos));
// }
// inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
// return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME,
// static_cast<int>(pos));
// }
inline std::string GetHierarchicalExterHCCLVarName(size_t pos) {
return string::Sprintf("Hierarchical_exter_%s_%d", HCCL_ID_VARNAME,
static_cast<int>(pos));
}
inline std::string GetHierarchicalInterHCCLVarName(size_t pos) {
return string::Sprintf("Hierarchical_inter_%s_%d", HCCL_ID_VARNAME,
static_cast<int>(pos));
}
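// For example, by construction of the format strings above:
//   GetFlatHCCLVarName(0)              -> "HCCLID"
//   GetFlatHCCLVarName(2)              -> "HCCLID_2"
//   GetHierarchicalInterHCCLVarName(1) -> "Hierarchical_inter_HCCLID_1"
//   GetHierarchicalExterHCCLVarName(0) -> "Hierarchical_exter_HCCLID_0"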
// class NCCLCommunicator {
// public:
// NCCLCommunicator() {}
// virtual ~NCCLCommunicator() PADDLE_MAY_THROW {}
class HCCLCommunicator {
public:
HCCLCommunicator() {}
virtual ~HCCLCommunicator() PADDLE_MAY_THROW {}
// NCCLContextMap *DefaultFlatCtx() const {
// if (flat_ctxs_.size() == 0) {
// return nullptr;
// }
HCCLContextMap *DefaultFlatCtx() const {
if (flat_ctxs_.size() == 0) {
return nullptr;
}
// return flat_ctxs_[0].get();
// }
return flat_ctxs_[0].get();
}
// std::vector<std::unique_ptr<NCCLContextMap>> *GetFlatCtxs() {
// return &flat_ctxs_;
// }
std::vector<std::unique_ptr<HCCLContextMap>> *GetFlatCtxs() {
return &flat_ctxs_;
}
// NCCLContextMap *GetFlatCtx(size_t run_order) const {
// return flat_ctxs_[run_order % flat_ctxs_.size()].get();
// }
HCCLContextMap *GetFlatCtx(size_t run_order) const {
return flat_ctxs_[run_order % flat_ctxs_.size()].get();
}
// NCCLContextMap *GetRunEnvNCCLCtx(size_t run_order,
// bool use_hierarchical_allreduce) const {
// if (!use_hierarchical_allreduce) {
// return GetFlatCtx(run_order);
// }
HCCLContextMap *GetRunEnvHCCLCtx(size_t run_order,
bool use_hierarchical_allreduce) const {
if (!use_hierarchical_allreduce) {
return GetFlatCtx(run_order);
}
// return GetHierarchicalInterCtx(run_order);
// }
return GetHierarchicalInterCtx(run_order);
}
// *When nccl inits nccl comm using ncclCommInitAll, it meets error when
// *allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
// *create a new nccl comm for sync_batch_norm_op. And these codes should be
// *polished with a unified nccl management.
/*
When nccl inits a comm using ncclCommInitAll, an error occurs if the
allreduce op handle and sync_batch_norm_op call allreduce in parallel. So
we create a separate comm for sync_batch_norm_op. This code should be
polished with a unified communicator management.
*/
// NCCLContextMap *GetSyncBatchNormCtx(
// framework::Scope *scope, const std::vector<platform::Place> &places) {
// auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
// if (nccl_id_var != nullptr) {
// return DefaultFlatCtx();
// }
HCCLContextMap *GetSyncBatchNormCtx(framework::Scope* scope, const std::vector<platform::Place> &places) {
auto *hccl_id_var = scope->FindVar(HCCL_ID_VARNAME);
if (hccl_id_var != nullptr) {
return DefaultFlatCtx();
}
// if (sync_batch_norm_ctx_.get() == nullptr) {
// sync_batch_norm_ctx_.reset(new NCCLContextMap(places));
// }
// return sync_batch_norm_ctx_.get();
// }
if (sync_batch_norm_ctx_.get() == nullptr) {
sync_batch_norm_ctx_.reset(new HCCLContextMap(places));
}
return sync_batch_norm_ctx_.get();
}
// void InitFlatCtxs(const std::vector<platform::Place> &places,
// const std::vector<ncclUniqueId *> &nccl_ids,
// size_t trainers_num, size_t trainer_id) {
// if (nccl_ids.size() == 0) {
// auto ptr = new platform::NCCLContextMap(places);
// VLOG(1) << "init local trainer";
// flat_ctxs_.emplace_back(ptr);
// } else {
// for (size_t i = 0; i < nccl_ids.size(); i++) {
// auto ptr = new platform::NCCLContextMap(places, nccl_ids[i],
// trainers_num, trainer_id);
// VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
// flat_ctxs_.emplace_back(ptr);
// }
// }
void InitFlatCtxs(const std::vector<platform::Place> &places,
const std::vector<HcclRootInfo *> &hccl_ids,
size_t trainers_num, size_t trainer_id) {
if (hccl_ids.size() == 0) {
auto ptr = new platform::HCCLContextMap(places);
VLOG(1) << "init local trainer";
flat_ctxs_.emplace_back(ptr);
} else {
for (size_t i = 0; i < hccl_ids.size(); i++) {
auto ptr = new platform::HCCLContextMap(places, hccl_ids[i],
trainers_num, trainer_id);
VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
flat_ctxs_.emplace_back(ptr);
}
}
// // as Executor have no way to use ncclComm created by ParallelExecutor,
// // we assign all flatten contexts to NCCLCommContext to fix.
// int nranks = static_cast<int>(trainers_num * places.size());
// int nrings = static_cast<int>(flat_ctxs_.size());
// for (int ring_id = 0; ring_id < nrings; ++ring_id) {
// for (size_t p = 0; p < places.size(); ++p) {
// int rank = trainer_id * places.size() + p;
// int dev_id = BOOST_GET_CONST(CUDAPlace, places[p]).device;
// auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
// NCCLCommContext::Instance().AssignNCCLComm(ctx.comm_, nranks, rank,
// dev_id, ring_id);
// }
// }
// }
    // As Executor has no way to use the comm created by ParallelExecutor,
    // we assign all flattened contexts to HCCLCommContext as a fix.
int nranks = static_cast<int>(trainers_num * places.size());
int nrings = static_cast<int>(flat_ctxs_.size());
for (int ring_id = 0; ring_id < nrings; ++ring_id) {
for (size_t p = 0; p < places.size(); ++p) {
int rank = trainer_id * places.size() + p;
int dev_id = BOOST_GET_CONST(NPUPlace, places[p]).device;
auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
HCCLCommContext::Instance().AssignHCCLComm(ctx.comm_, nranks, rank,
dev_id, ring_id);
}
}
}
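  // Worked example (hypothetical numbers): with trainers_num = 2,
  // trainer_id = 1 and 4 local places, nranks = 8 and this trainer's
  // devices are assigned global ranks 4, 5, 6 and 7 on every ring.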
// void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
// const std::vector<ncclUniqueId *> &inter_nccl_ids,
// const std::vector<ncclUniqueId *> &exter_nccl_ids,
// size_t trainers_num, size_t trainer_id,
// size_t inter_trainers_num,
// size_t exter_trainers_num) {
// PADDLE_ENFORCE_EQ(
// trainers_num, inter_trainers_num * exter_trainers_num,
// platform::errors::InvalidArgument(
// "trainers_num:%llu != inter_trainers_num:%llu * "
// "exter_trainers_num:%llu",
// trainers_num, inter_trainers_num, exter_trainers_num));
// PADDLE_ENFORCE_GT(
// inter_trainers_num, 1,
// platform::errors::InvalidArgument(
// "The inter_trainers_num:%llu should be larger than 1.",
// inter_trainers_num));
// int inter_trainer_id = trainer_id % inter_trainers_num;
// for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
// VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
// << ", comm no:" << i;
// auto local = new NCCLContextMap(places, inter_nccl_ids[i],
// inter_trainers_num, inter_trainer_id);
// h_inter_ctxs_.emplace_back(local);
// }
void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
const std::vector<HcclRootInfo *> &inter_hccl_ids,
const std::vector<HcclRootInfo *> &exter_hccl_ids,
size_t trainers_num, size_t trainer_id,
size_t inter_trainers_num,
size_t exter_trainers_num) {
PADDLE_ENFORCE_EQ(
trainers_num, inter_trainers_num * exter_trainers_num,
platform::errors::InvalidArgument(
"trainers_num:%llu != inter_trainers_num:%llu * "
"exter_trainers_num:%llu",
trainers_num, inter_trainers_num, exter_trainers_num));
PADDLE_ENFORCE_GT(
inter_trainers_num, 1,
platform::errors::InvalidArgument(
"The inter_trainers_num:%llu should be larger than 1.",
inter_trainers_num));
int inter_trainer_id = trainer_id % inter_trainers_num;
for (size_t i = 0; i < inter_hccl_ids.size(); i++) {
VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
<< ", comm no:" << i;
auto local = new HCCLContextMap(places, inter_hccl_ids[i],
inter_trainers_num, inter_trainer_id);
h_inter_ctxs_.emplace_back(local);
}
// int exter_trainer_id = -1;
// if (trainer_id % inter_trainers_num == 0) {
// exter_trainer_id = trainer_id / inter_trainers_num;
// }
int exter_trainer_id = -1;
if (trainer_id % inter_trainers_num == 0) {
exter_trainer_id = trainer_id / inter_trainers_num;
}
// if (exter_trainer_id >= 0) {
// for (size_t i = 0; i < exter_nccl_ids.size(); i++) {
// auto ex = new NCCLContextMap(places, exter_nccl_ids[i],
// exter_trainers_num, exter_trainer_id);
// VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
// << ", comm no:" << i;
// h_exter_ctxs_.emplace_back(ex);
// }
// }
// }
if (exter_trainer_id >= 0) {
for (size_t i = 0; i < exter_hccl_ids.size(); i++) {
auto ex = new HCCLContextMap(places, exter_hccl_ids[i],
exter_trainers_num, exter_trainer_id);
VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
<< ", comm no:" << i;
h_exter_ctxs_.emplace_back(ex);
}
}
}
// bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
// NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
// PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0,
// platform::errors::InvalidArgument(
// "Hierarchical ctxs should be initialized firstly!"));
// return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
// }
HCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0,
platform::errors::InvalidArgument(
"Hierarchical ctxs should be initialized firstly!"));
return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
}
// NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
// PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0,
// platform::errors::InvalidArgument(
// "Hierarchical ctxs should be initialized firstly!"));
// return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
// }
HCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0,
platform::errors::InvalidArgument(
"Hierarchical ctxs should be initialized firstly!"));
return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
}
// std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalInterCtxs() {
// return &h_inter_ctxs_;
// }
std::vector<std::unique_ptr<HCCLContextMap>> *GetHierarchicalInterCtxs() {
return &h_inter_ctxs_;
}
// std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalExterCtxs() {
// return &h_exter_ctxs_;
// }
std::vector<std::unique_ptr<HCCLContextMap>> *GetHierarchicalExterCtxs() {
return &h_exter_ctxs_;
}
// protected:
// // Support multi nccl comm on default nccl ring while NCCLContextMap can't.
// std::vector<std::unique_ptr<NCCLContextMap>> flat_ctxs_;
protected:
  // Support multiple comms on the default ring, which HCCLContextMap can't.
std::vector<std::unique_ptr<HCCLContextMap>> flat_ctxs_;
// // h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce.
// // And h_exter_ctxs_ can support multi comm too.
// std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
// std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
// h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce.
// And h_exter_ctxs_ can support multi comm too.
std::vector<std::unique_ptr<HCCLContextMap>> h_inter_ctxs_;
std::vector<std::unique_ptr<HCCLContextMap>> h_exter_ctxs_;
// // just used for sync_batch_norm op.
// std::unique_ptr<NCCLContextMap> sync_batch_norm_ctx_;
// };
// just used for sync_batch_norm op.
std::unique_ptr<HCCLContextMap> sync_batch_norm_ctx_;
};
} // namespace platform
} // namespace paddle
......
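Editor's note: the hunk above gates the exter (cross-group) ring on `trainer_id % inter_trainers_num == 0`. A minimal standalone sketch of that 2D-allreduce grouping arithmetic, with illustrative group sizes that are not taken from this patch:

# Illustrative only: 8 trainers split into inter-groups of 4.
trainers_num = 8
inter_trainers_num = 4
exter_trainers_num = trainers_num // inter_trainers_num

for trainer_id in range(trainers_num):
    inter_rank = trainer_id % inter_trainers_num
    # Only the head of each inter group joins the exter ring, mirroring
    # `if (trainer_id % inter_trainers_num == 0)` in the C++ hunk above.
    exter_trainer_id = trainer_id // inter_trainers_num if inter_rank == 0 else -1
    print(trainer_id, inter_rank, exter_trainer_id)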
......@@ -151,17 +151,30 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
elif core.is_compiled_with_npu():
hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_comm_init_hcom',
type='c_gen_hccl_id',
inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
'rank_ids': nranks,
OP_ROLE_KEY: OpRole.Forward
})
......
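Editor's note: the same two-op bootstrap (`c_gen_hccl_id` to exchange the HCCL root id, then `c_comm_init_hccl` to create the ring on the selected NPU) recurs in every Python hunk below. A minimal consolidated sketch, assuming only the attrs shown in this patch; the helper name `append_hccl_init_ops` and the omission of the op-role attribute are the editor's simplifications, not part of the change:

import os

from paddle.fluid import core, unique_name


def append_hccl_init_ops(block, rank, nranks, ring_id,
                         current_endpoint, other_endpoints):
    # Holds the HCCL root id produced by rank 0 and received by the other ranks.
    hccl_id_var = block.create_var(
        name=unique_name.generate('hccl_id'),
        persistable=True,
        type=core.VarDesc.VarType.RAW)
    # Step 1: generate and exchange the root id across all endpoints.
    block.append_op(
        type='c_gen_hccl_id',
        inputs={},
        outputs={'Out': hccl_id_var},
        attrs={
            'rank': rank,
            'endpoint': current_endpoint,
            'other_endpoints': other_endpoints,
        })
    # Step 2: create the HCCL communicator for this ring on the selected NPU.
    # Both 'nranks' and 'rank_ids' are kept verbatim from the patch above.
    block.append_op(
        type='c_comm_init_hccl',
        inputs={'X': hccl_id_var},
        outputs={},
        attrs={
            'nranks': nranks,
            'rank': rank,
            'ring_id': ring_id,
            'device_id': int(os.getenv("FLAGS_selected_npus")),
            'rank_ids': nranks,
        })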
......@@ -108,19 +108,32 @@ class PipelineHelper(object):
OP_ROLE_KEY: OpRole.Forward,
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_comm_init_hcom',
type='c_gen_hccl_id',
inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
'rank_ids': nranks,
OP_ROLE_KEY: OpRole.Forward
})
......
......@@ -2053,7 +2053,7 @@ class Operator(object):
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_comm_init_hcom', 'c_sync_calc_stream',
'gen_nccl_id', 'c_gen_nccl_id', 'c_gen_hccl_id', 'c_comm_init', 'c_comm_init_hccl', 'c_sync_calc_stream',
'c_sync_comm_stream', 'queue_generator', 'dequeue', 'enqueue',
'heter_listen_and_serv'
}
......
......@@ -131,19 +131,32 @@ class Collective(object):
self.op_role_key: OpRole.Forward
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_comm_init_hcom',
type='c_gen_hccl_id',
inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
'rank_ids': nranks,
self.op_role_key: OpRole.Forward
})
......
......@@ -162,19 +162,33 @@ def init_communicator(program, rank, nranks, wait_port, current_endpoint,
'ring_id': 0,
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_comm_init_hcom',
type='c_gen_hccl_id',
inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
'rank_ids': nranks,
OP_ROLE_KEY: OpRole.Forward
})
......
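Editor's note: the new ops read their device from FLAGS_selected_npus and their peers from an endpoint list. A hypothetical per-process setup is sketched below; the endpoint values and the use of PADDLE_TRAINER_ID are illustrative assumptions, not defined by this patch.

import os

endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]     # one endpoint per rank (example values)
rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))      # assumed to be provided by the launcher
current_endpoint = endpoints[rank]
other_endpoints = [e for e in endpoints if e != current_endpoint]
# Each process binds one NPU; c_comm_init_hccl reads this flag as its device_id.
os.environ.setdefault("FLAGS_selected_npus", str(rank))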