Unverified commit ce26f882, authored by lw921014, committed by GitHub

update Ascendrc hccl to 20.3 (#32126)

Parent 75dd8423
@@ -400,6 +400,7 @@ OperatorBase::OperatorBase(const std::string& type,
   // framework::OpRegistry::CreateOp(type, {}, {}, {}, false).
   // Inputs, outputs and attrs will be set to empty map
   // to improve the execution efficiency of dygraph.
   if (inputs_.size() > 0 || outputs_.size() > 0) {
     GenerateTemporaryNames();
     CheckAllInputOutputSet();
...
@@ -31,6 +31,11 @@
 #endif
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+#include <hccl/hccl.h>
+#include <hccl/hccl_types.h>
+#endif
 #if defined(PADDLE_WITH_XPU_BKCL)
 #include "xpu/bkcl.h"
 #endif
@@ -45,6 +50,10 @@ class Communicator;
 class NCCLCommunicator;
 #endif
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+class Communicator;
+class HCCLCommunicator;
+#endif
 #if defined(PADDLE_WITH_XPU_BKCL)
 class BKCLCommunicator;
@@ -157,6 +166,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
 #endif
     operators::CudnnRNNCache,
 #endif
+#if defined(PADDLE_WITH_ASCEND_CL)
+    HcclRootInfo,
+#endif
 #if defined(PADDLE_WITH_XPU_BKCL)
     BKCLUniqueId, platform::BKCLCommunicator,
 #endif
...
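Registering HcclRootInfo with the variable type registry above is what lets the new HCCL ops pass the root handle through ordinary scope variables. A minimal sketch of that usage, assuming an Ascend build (PADDLE_WITH_ASCEND_CL) and the Paddle framework headers; the variable name below is illustrative, not taken from the diff:

    // Sketch: store an HcclRootInfo in a scope variable, as c_gen_hccl_id does.
    #include <cstring>
    #include <hccl/hccl_types.h>
    #include "paddle/fluid/framework/scope.h"
    #include "paddle/fluid/framework/var_type_traits.h"

    void StashRootInfo(paddle::framework::Scope* scope, const HcclRootInfo& src) {
      auto* var = scope->Var("HcclRootInfoVar");     // illustrative variable name
      auto* dst = var->GetMutable<HcclRootInfo>();   // valid once HcclRootInfo is registered
      std::memcpy(dst, &src, sizeof(HcclRootInfo));  // plain POD copy
    }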
@@ -11,7 +11,7 @@ foreach(src ${OPS})
     set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
 endforeach()
-register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
+register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op gen_nccl_id_op c_gen_hccl_id_op gen_hccl_id_op DEPS ${COLLECTIVE_DEPS})
 if(WITH_NCCL)
     set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
@@ -24,39 +24,43 @@ if(WITH_GLOO)
 endif()
 if(WITH_XPU_BKCL)
-    set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
+    set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper nccl_common)
     op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
     op_library(gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
 endif()
 if(WITH_ASCEND_CL)
-    set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
+    cc_library(gen_hccl_id_op_helper SRCS gen_hccl_id_op_helper.cc DEPS dynload_warpctc dynamic_loader scope)
+    set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper gen_hccl_id_op_helper)
+    op_library(c_gen_hccl_id_op DEPS ${COLLECTIVE_DEPS})
+    op_library(gen_hccl_id_op DEPS ${COLLECTIVE_DEPS})
 endif()
 set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
 set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
 if(WITH_ASCEND_CL)
-    set(COMMON_TEST_DEPS_FOR_HCOM c_comm_init_hcom_op op_registry ascend_hccl flags
+    set(COMMON_TEST_DEPS_FOR_HCOM c_comm_init_hccl_op c_gen_hccl_id_op gen_hccl_id_op_helper
+        gen_hccl_id_op op_registry ascend_hccl flags
         dynamic_loader dynload_warpctc scope device_context enforce executor)
     cc_test(c_broadcast_op_npu_test SRCS c_broadcast_op_npu_test.cc
         DEPS c_broadcast_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(c_allreduce_sum_op_npu_test SRCS c_allreduce_sum_op_npu_test.cc
         DEPS c_allreduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
-    cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc
-        DEPS c_allreduce_max_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
-    cc_test(c_reduce_sum_op_npu_test SRCS c_reduce_sum_op_npu_test.cc
-        DEPS c_reduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(c_reducescatter_op_npu_test SRCS c_reducescatter_op_npu_test.cc
         DEPS c_reducescatter_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(c_allgather_op_npu_test SRCS c_allgather_op_npu_test.cc
         DEPS c_allgather_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+    cc_test(c_reduce_sum_op_npu_test SRCS c_reduce_sum_op_npu_test.cc
+        DEPS c_reduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+    cc_test(c_allreduce_max_op_npu_test SRCS c_allreduce_max_op_npu_test.cc
+        DEPS c_allreduce_max_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(send_v2_op_npu_test SRCS send_v2_op_npu_test.cc
         DEPS send_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(recv_v2_op_npu_test SRCS recv_v2_op_npu_test.cc
         DEPS recv_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(c_sync_comm_stream_op_npu_test SRCS c_sync_comm_stream_op_npu_test.cc
-        DEPS op_registry c_broadcast_op c_comm_init_hcom_op c_sync_comm_stream_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
+        DEPS op_registry c_broadcast_op c_comm_init_hccl_op c_sync_comm_stream_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
     cc_test(c_sync_calc_stream_op_npu_test SRCS c_sync_calc_stream_op_npu_test.cc
         DEPS op_registry elementwise_add_op c_sync_calc_stream_op ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
 endif()
@@ -31,20 +31,19 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
 #if defined(PADDLE_WITH_ASCEND_CL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
-    hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
+    HcclDataType dtype = platform::ToHCCLDataType(in->type());
     int ring_id = ctx.Attr<int>("ring_id");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
     auto place = ctx.GetPlace();
     auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
-    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     framework::DDim out_dims = in->dims();
     out_dims[0] *= nranks;
     out->mutable_data<T>(out_dims, place);
-    int64_t send_numel = in->numel();
+    uint64_t send_numel = in->numel();
     void *send_buff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
     void *recv_buff = reinterpret_cast<void*>(out->data<T>());
@@ -59,12 +58,11 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
     VLOG(3) << "begin hccl allgather, parameter is: "
             << ", group is " << group
             << ", ring_id is " << ring_id
-            << ", nranks is " << nranks
-            << ", tag is " << tag;
+            << ", nranks is " << nranks;
-    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_gather(
-        tag.c_str(), send_buff, recv_buff, (u64)send_numel, dtype,
-        group.c_str(), (void*)stream));
+    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllGather(
+        send_buff, recv_buff, send_numel, dtype,
+        comm->comm(), (void*)stream));
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
...
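The kernel above moves from the tag/group based hcom_all_gather wrapper to the handle based HcclAllGather, which gathers send_numel elements from every rank into an output whose leading dimension is nranks times larger. A minimal sketch of that call shape outside the kernel — assuming a CANN toolkit where <hccl/hccl.h> provides HcclAllGather and the ACL stream type, and an HcclComm/aclrtStream created beforehand (for example by c_comm_init_hccl); the helper name and data-type choice are illustrative:

    // Sketch: allgather `count` float32 elements per rank into recv_buff.
    // recv_buff must hold nranks * count elements; HCCL concatenates rank by rank.
    #include <cstdint>
    #include <hccl/hccl.h>

    HcclResult AllGatherFloat(void* send_buff, void* recv_buff, uint64_t count,
                              HcclComm comm, aclrtStream stream) {
      // HCCL_DATA_TYPE_FP32 is the float32 member of HcclDataType.
      return HcclAllGather(send_buff, recv_buff, count, HCCL_DATA_TYPE_FP32,
                           comm, stream);
    }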
@@ -34,6 +34,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 #include "paddle/fluid/operators/collective/c_broadcast_op.h"
 #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
+#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
 #if defined(PADDLE_WITH_ASCEND_CL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -45,7 +46,8 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 USE_OP(c_allgather);
-USE_NO_KERNEL_OP(c_comm_init_hcom);
+USE_NO_KERNEL_OP(c_gen_hccl_id);
+USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_allgather, NPU);
 DECLARE_string(selected_npus);
@@ -56,26 +58,68 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
   for (auto ele : data) {
     debugstring += std::to_string(ele) + std::string(",");
   }
-  VLOG(2) << preStr << ":" << std::endl << debugstring;
+  VLOG(2) << preStr << ":" << std::endl <<debugstring;
 }
-void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
+void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
   int rank_id = atoi(getenv("RANK_ID"));
   int device_id = atoi(getenv("DEVICE_ID"));
-  VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
           << "; rank_id = " << rank_id
-          << "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
   std::vector<int> rank_ids{0, 1};
+  f::AttributeMap gen_hccl_id;
+  std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
+  gen_hccl_id["rank"] = rank_id;
+  gen_hccl_id["endpoint"] = endpointList[rank_id];
+  std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
+  gen_hccl_id["other_endpoints"] = other_endpoints;
+  auto out = scope->Var("Out");
+  auto id = out->GetMutable<HcclRootInfo>();
+  VLOG(3) << "break";
+  auto comm_init_op =
+      f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
+  VLOG(3) << "break";
+  auto place = ctx.GetPlace();
+  comm_init_op->Run(*scope, place);
+  ctx.Wait();
+  memcpy(hccl_id, id, 1024);
+}
+void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
+  auto x = scope->Var("X");
+  auto id = x->GetMutable<HcclRootInfo>();
+  memcpy(id, hccl_id, 1024);
+  int rank_id = atoi(getenv("RANK_ID"));
+  int device_id = atoi(getenv("DEVICE_ID"));
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
+          << "; rank_id = " << rank_id
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
+  // std::vector<int> rank_ids{0, 1};
   f::AttributeMap comm_init_attrs;
   comm_init_attrs["ring_id"] = 0;
-  comm_init_attrs["nranks"] = 2;
+  comm_init_attrs["rank_ids"] = 2;
   comm_init_attrs["rank"] = rank_id;
   comm_init_attrs["device_id"] = device_id;
-  comm_init_attrs["rank_ids"] = rank_ids;
+  // comm_init_attrs["rank_ids"] = rank_ids;
   auto comm_init_op =
-      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
+      f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
   auto place = ctx.GetPlace();
   comm_init_op->Run(*scope, place);
   ctx.Wait();
@@ -83,7 +127,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
 void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
   // init
-  auto x = scope->Var("X");
+  auto x = scope->Var("Data");
   auto tensor_x = x->GetMutable<f::LoDTensor>();
   std::vector<float> init;
@@ -102,7 +146,7 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
   ctx.Wait();
   auto place = ctx.GetPlace();
-  auto out = scope->Var("Out");
+  auto out = scope->Var("OutData");
   auto tensor_out = out->GetMutable<f::LoDTensor>();
   tensor_out->Resize({num1, num2});
   tensor_out->mutable_data<float>(place);  // allocate
@@ -110,12 +154,12 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
   // run
   f::AttributeMap attrs;
-  attrs["tag"] = std::string("tagx");
-  attrs["ring_id"] = 0;
-  attrs["nranks"] = 2;
+  attrs["tag"]=std::string("tagx");
+  attrs["ring_id"]=0;
+  attrs["nranks"]=2;
-  auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
-                                    {{"Out", {"Out"}}}, attrs);
+  auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"Data"}}},
+                                    {{"Out", {"OutData"}}}, attrs);
   for (int i = 0; i < 10; i++) {
     op->Run(*scope, place);
@@ -139,11 +183,12 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
 TEST(c_allgather, NPU) {
   f::Scope scope;
+  HcclRootInfo hccl_id;
   // only support one device, if more than one device, use first default
-  auto* ctx = p::DeviceContextPool::Instance().Get(
-      p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
+  p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
-  Prepare(&scope, *ctx);
-  TestHCCLAllGatherOp(&scope, *ctx);
+  PrepareUniqueId(&scope, ctx, &hccl_id);
+  Prepare(&scope, ctx, &hccl_id);
+  TestHCCLAllGatherOp(&scope, ctx);
 }
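This test (and the near-identical ones that follow) replaces the old single c_comm_init_hcom call with a two-step handshake: c_gen_hccl_id produces an HcclRootInfo on every rank (rank 0 pushes it to the other endpoints), and c_comm_init_hccl consumes it to create the communicator. A condensed sketch of that flow for two ranks, mirroring the test code above — the function name is illustrative, sizeof(HcclRootInfo) is used where the tests hard-code 1024 bytes, and the same RANK_ID/DEVICE_ID environment variables are assumed:

    // Condensed two-rank HCCL bootstrap, as used by the NPU unit tests (sketch).
    #include <cstdlib>
    #include <cstring>
    #include <string>
    #include <vector>
    #include <hccl/hccl_types.h>
    #include "paddle/fluid/framework/op_registry.h"
    #include "paddle/fluid/framework/scope.h"
    #include "paddle/fluid/platform/device_context.h"

    namespace f = paddle::framework;
    namespace p = paddle::platform;

    void InitHcclForTest(f::Scope* scope, const p::DeviceContext& ctx) {
      int rank_id = atoi(getenv("RANK_ID"));
      int device_id = atoi(getenv("DEVICE_ID"));
      std::vector<std::string> eps = {"127.0.0.1:6175", "127.0.0.1:6177"};

      // Step 1: c_gen_hccl_id writes an HcclRootInfo into the scope var "Out";
      // rank 0 generates it and sends it to the other endpoint.
      f::AttributeMap gen_attrs;
      gen_attrs["rank"] = rank_id;
      gen_attrs["endpoint"] = eps[rank_id];
      gen_attrs["other_endpoints"] = std::vector<std::string>{eps[1 - rank_id]};
      scope->Var("Out")->GetMutable<HcclRootInfo>();
      auto gen_op =
          f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_attrs);
      gen_op->Run(*scope, ctx.GetPlace());
      ctx.Wait();

      // Step 2: copy the root info into "X" and let c_comm_init_hccl build the comm.
      auto* id = scope->Var("Out")->GetMutable<HcclRootInfo>();
      auto* x = scope->Var("X")->GetMutable<HcclRootInfo>();
      std::memcpy(x, id, sizeof(HcclRootInfo));

      f::AttributeMap init_attrs;
      init_attrs["ring_id"] = 0;
      init_attrs["rank_ids"] = 2;  // world size, as in the tests above
      init_attrs["rank"] = rank_id;
      init_attrs["device_id"] = device_id;
      auto init_op = f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}},
                                             {}, init_attrs);
      init_op->Run(*scope, ctx.GetPlace());
      ctx.Wait();
    }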
@@ -34,6 +34,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 #include "paddle/fluid/operators/collective/c_broadcast_op.h"
 #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
+#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
 #if defined(PADDLE_WITH_ASCEND_CL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -45,7 +46,8 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 USE_OP(c_allreduce_max);
-USE_NO_KERNEL_OP(c_comm_init_hcom);
+USE_NO_KERNEL_OP(c_gen_hccl_id);
+USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_allreduce_max, NPU);
 DECLARE_string(selected_npus);
@@ -59,23 +61,65 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
   VLOG(2) << preStr << ":" << std::endl << debugstring;
 }
-void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
+void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
   int rank_id = atoi(getenv("RANK_ID"));
   int device_id = atoi(getenv("DEVICE_ID"));
-  VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
           << "; rank_id = " << rank_id
-          << "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
   std::vector<int> rank_ids{0, 1};
+  f::AttributeMap gen_hccl_id;
+  std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
+  gen_hccl_id["rank"] = rank_id;
+  gen_hccl_id["endpoint"] = endpointList[rank_id];
+  std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
+  gen_hccl_id["other_endpoints"] = other_endpoints;
+  auto out = scope->Var("Out");
+  auto id = out->GetMutable<HcclRootInfo>();
+  VLOG(3) << "break";
+  auto comm_init_op =
+      f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
+  VLOG(3) << "break";
+  auto place = ctx.GetPlace();
+  comm_init_op->Run(*scope, place);
+  ctx.Wait();
+  memcpy(hccl_id, id, 1024);
+}
+void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
+  auto x = scope->Var("X");
+  auto id = x->GetMutable<HcclRootInfo>();
+  memcpy(id, hccl_id, 1024);
+  int rank_id = atoi(getenv("RANK_ID"));
+  int device_id = atoi(getenv("DEVICE_ID"));
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
+          << "; rank_id = " << rank_id
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
+  // std::vector<int> rank_ids{0, 1};
   f::AttributeMap comm_init_attrs;
   comm_init_attrs["ring_id"] = 0;
-  comm_init_attrs["nranks"] = 2;
+  comm_init_attrs["rank_ids"] = 2;
   comm_init_attrs["rank"] = rank_id;
   comm_init_attrs["device_id"] = device_id;
-  comm_init_attrs["rank_ids"] = rank_ids;
+  // comm_init_attrs["rank_ids"] = rank_ids;
   auto comm_init_op =
-      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
+      f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
   auto place = ctx.GetPlace();
   comm_init_op->Run(*scope, place);
   ctx.Wait();
@@ -83,7 +127,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
 void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
   // init
-  auto x = scope->Var("X");
+  auto x = scope->Var("Data");
   auto tensor_x = x->GetMutable<f::LoDTensor>();
   std::vector<float> init;
@@ -102,7 +146,7 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
   ctx.Wait();
   auto place = ctx.GetPlace();
-  auto out = scope->Var("Out");
+  auto out = scope->Var("OutData");
   auto tensor_out = out->GetMutable<f::LoDTensor>();
   tensor_out->Resize({num1, num2});
   tensor_out->mutable_data<float>(place);  // allocate
@@ -113,8 +157,8 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
   attrs["tag"] = std::string("tagx");
   attrs["ring_id"] = 0;
-  auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}},
-                                    {{"Out", {"Out"}}}, attrs);
+  auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"Data"}}},
+                                    {{"Out", {"OutData"}}}, attrs);
   for (int i = 0; i < 10; i++) {
     op->Run(*scope, place);
@@ -135,11 +179,12 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
 TEST(c_allreduce_max, NPU) {
   f::Scope scope;
+  HcclRootInfo hccl_id;
   // only support one device, if more than one device, use first default
-  auto* ctx = p::DeviceContextPool::Instance().Get(
-      p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
+  p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
-  Prepare(&scope, *ctx);
-  TestHCCLAllReduceOp(&scope, *ctx);
+  PrepareUniqueId(&scope, ctx, &hccl_id);
+  Prepare(&scope, ctx, &hccl_id);
+  TestHCCLAllReduceOp(&scope, ctx);
 }
@@ -117,34 +117,18 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_ASCEND_CL)
-    // we need to pre-allocate 512 Bytes before the data
-    // and 512 Bytes after the data, so the hccl allreduce
-    // can work. This is a must acooding to huawei peer.
-#define PRE_MALLOC_SIZE_BYTES 512
     auto in = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
     auto place = ctx.GetPlace();
-    hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
+    HcclDataType dtype = platform::ToHCCLDataType(in->type());
     int64_t numel = in->numel();
-    int64_t pre_tmp_size = PRE_MALLOC_SIZE_BYTES / sizeof(T);
-    int64_t tmp_numel = numel + pre_tmp_size * 2;
-    paddle::framework::LoDTensor tmp_in, tmp_out;
-    tmp_in.Resize({tmp_numel});
-    tmp_out.Resize({tmp_numel});
-    auto p_tmp_in = tmp_in.mutable_data<T>(place);    // allocate
-    auto p_tmp_out = tmp_out.mutable_data<T>(place);  // allocate
-    void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
-    void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
+    void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
+    void* recvbuff = reinterpret_cast<void*>(out->data<T>());
     int ring_id = ctx.Attr<int>("ring_id");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
     auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
-    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     aclrtStream stream = nullptr;
     auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
@@ -154,33 +138,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
       stream = comm->stream();
     }
-    // we need to memset this memory firstly to avoid core by hccl
-    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
-    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
-    auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
-    memory::Copy(npu_place, sendbuff,
-                 npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
-                 numel * sizeof(T),
-                 stream);
-    hcclRedOp_t hccl_red_type = HCCL_REP_OP_SUM;
+    HcclReduceOp hccl_red_type = HCCL_REDUCE_SUM;
     switch (red_type) {
       case kRedSum:
-        hccl_red_type = HCCL_REP_OP_SUM;
+        hccl_red_type = HCCL_REDUCE_SUM;
         break;
       case kRedMax:
-        hccl_red_type = HCCL_REP_OP_MAX;
+        hccl_red_type = HCCL_REDUCE_MAX;
         break;
      case kRedMin:
-        hccl_red_type = HCCL_REP_OP_MIN;
+        hccl_red_type = HCCL_REDUCE_MIN;
         break;
      case kRedProd:
-        hccl_red_type = HCCL_REP_OP_PROD;
+        hccl_red_type = HCCL_REDUCE_PROD;
         break;
      default:
@@ -192,16 +165,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
             << "input num: " << numel
             << "dtype: " << dtype
             << "hccl_red_type: " << hccl_red_type
-            << ", group is: " << group
-            << ", tag is " << tag;
+            << ", group is: " << group;
-    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_reduce(
-        tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));
-    memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
-                 npu_place, recvbuff,
-                 numel * sizeof(T),
-                 stream);
+    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
+        sendbuff, recvbuff, numel, dtype, hccl_red_type, comm->comm(), (void*)stream));
     out->Resize(in->dims());
 #else
...
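The kRed* to HCCL_REDUCE_* mapping above is repeated almost verbatim in c_reduce_op.h below; factored out it is just a small switch. A sketch under the assumption that the reduce kinds keep the names used in these kernels — the local ReduceType enum here only mirrors them so the example is self-contained:

    // Sketch: map the kernel-level reduce kind onto the HCCL enum consumed by
    // HcclAllReduce and friends.
    #include <stdexcept>
    #include <hccl/hccl_types.h>

    enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };  // mirrors the kernels

    inline HcclReduceOp ToHcclReduceOp(ReduceType red_type) {
      switch (red_type) {
        case kRedSum:  return HCCL_REDUCE_SUM;
        case kRedMax:  return HCCL_REDUCE_MAX;
        case kRedMin:  return HCCL_REDUCE_MIN;
        case kRedProd: return HCCL_REDUCE_PROD;
        default:
          throw std::invalid_argument("reduce type not supported by HCCL");
      }
    }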
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
+#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
 #if defined(PADDLE_WITH_ASCEND_CL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -42,7 +43,8 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 USE_OP(c_allreduce_sum);
-USE_NO_KERNEL_OP(c_comm_init_hcom);
+USE_NO_KERNEL_OP(c_gen_hccl_id);
+USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_allreduce_sum, NPU);
 DECLARE_string(selected_npus);
@@ -56,23 +58,65 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
   VLOG(3) << preStr << ":" << std::endl << debugstring;
 }
-void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
+void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
   int rank_id = atoi(getenv("RANK_ID"));
   int device_id = atoi(getenv("DEVICE_ID"));
-  VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
           << "; rank_id = " << rank_id
-          << "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
   std::vector<int> rank_ids{0, 1};
+  f::AttributeMap gen_hccl_id;
+  std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
+  gen_hccl_id["rank"] = rank_id;
+  gen_hccl_id["endpoint"] = endpointList[rank_id];
+  std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
+  gen_hccl_id["other_endpoints"] = other_endpoints;
+  auto out = scope->Var("Out");
+  auto id = out->GetMutable<HcclRootInfo>();
+  VLOG(3) << "break";
+  auto comm_init_op =
+      f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
+  VLOG(3) << "break";
+  auto place = ctx.GetPlace();
+  comm_init_op->Run(*scope, place);
+  ctx.Wait();
+  memcpy(hccl_id, id, 1024);
+}
+void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
+  auto x = scope->Var("X");
+  auto id = x->GetMutable<HcclRootInfo>();
+  memcpy(id, hccl_id, 1024);
+  int rank_id = atoi(getenv("RANK_ID"));
+  int device_id = atoi(getenv("DEVICE_ID"));
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
+          << "; rank_id = " << rank_id
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
+  // std::vector<int> rank_ids{0, 1};
   f::AttributeMap comm_init_attrs;
   comm_init_attrs["ring_id"] = 0;
-  comm_init_attrs["nranks"] = 2;
+  comm_init_attrs["rank_ids"] = 2;
   comm_init_attrs["rank"] = rank_id;
   comm_init_attrs["device_id"] = device_id;
-  comm_init_attrs["rank_ids"] = rank_ids;
+  // comm_init_attrs["rank_ids"] = rank_ids;
   auto comm_init_op =
-      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
+      f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
   auto place = ctx.GetPlace();
   comm_init_op->Run(*scope, place);
   ctx.Wait();
@@ -81,7 +125,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
 void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
                          int iter) {
   // init
-  auto x = scope->Var("X");
+  auto x = scope->Var("Data");
   auto tensor_x = x->GetMutable<f::LoDTensor>();
   int rank_id = atoi(getenv("RANK_ID"));
@@ -100,7 +144,7 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
   tensor_x->Resize({num1, num2});
   ctx.Wait();
-  auto out = scope->Var("Out");
+  auto out = scope->Var("OutData");
   auto tensor_out = out->GetMutable<f::LoDTensor>();
   tensor_out->Resize({num1, num2});
   tensor_out->mutable_data<float>(place);  // allocate
@@ -111,8 +155,10 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
   attrs["tag"] = std::string("tagx_" + std::to_string(iter));
   attrs["ring_id"] = 0;
-  auto op = f::OpRegistry::CreateOp("c_allreduce_sum", {{"X", {"X"}}},
-                                    {{"Out", {"Out"}}}, attrs);
+  auto op = f::OpRegistry::CreateOp("c_allreduce_sum",
+                                    {{"X", {"Data"}}},
+                                    {{"Out", {"OutData"}}},
+                                    attrs);
   for (int i = 0; i < 10; i++) {
     op->Run(*scope, place);
@@ -133,14 +179,17 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
 TEST(c_allreduce_sum, NPU) {
   f::Scope scope;
+  HcclRootInfo hccl_id;
   // only support one device, if more than one device, use first default
-  auto* ctx = p::DeviceContextPool::Instance().Get(
-      p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
+  // auto* ctx = p::DeviceContextPool::Instance().Get(
+  //     p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
+  p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
-  Prepare(&scope, *ctx);
-  for (int i = 0; i < 1; i++) {
+  PrepareUniqueId(&scope, ctx, &hccl_id);
+  Prepare(&scope, ctx, &hccl_id);
+  for(int i = 0; i < 1; i ++){
     VLOG(2) << "iter num: " << i;
-    TestHCCLAllReduceOp(&scope, *ctx, i);
+    TestHCCLAllReduceOp(&scope, ctx, i);
   }
 }
@@ -30,7 +30,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
     auto x = ctx.Input<framework::LoDTensor>("X");
     void *ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
     int numel = x->numel();
-    hcclDataType_t dtype = platform::ToHCCLDataType(x->type());
+    HcclDataType dtype = platform::ToHCCLDataType(x->type());
     auto out = ctx.Output<framework::LoDTensor>("Out");
@@ -48,14 +48,12 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
     int root = ctx.Attr<int>("root");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root
-            << ", group is " << group
-            << ", tag is " << tag;
+            << ", group is " << group << ", comm: " << comm->comm() << ", stream: " << stream;
-    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_broadcast(tag.c_str(), ptr, numel,
-        dtype, (uint32_t)root, group.c_str(), (void*)stream));
+    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(ptr, numel,
+        dtype, (uint32_t)root, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
             << framework::product(out->dims());
...
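As with allgather, the broadcast now goes through the communicator handle rather than a tag/group pair. A minimal sketch of the underlying call, assuming <hccl/hccl.h> from a CANN install and a communicator and stream created beforehand (for example by c_comm_init_hccl); HcclBroadcast operates on a single buffer per rank, so non-root buffers are overwritten with root's data:

    // Sketch: broadcast `count` float32 elements from rank `root` to all ranks.
    #include <cstdint>
    #include <hccl/hccl.h>

    HcclResult BroadcastFloat(void* dev_buf, uint64_t count, uint32_t root,
                              HcclComm comm, aclrtStream stream) {
      // Same buffer on every rank: root sends from it, the others receive into it.
      return HcclBroadcast(dev_buf, count, HCCL_DATA_TYPE_FP32, root, comm, stream);
    }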
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/operators/collective/c_broadcast_op.h"
+#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
 #if defined(PADDLE_WITH_ASCEND_CL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -42,7 +43,8 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 USE_OP(c_broadcast);
-USE_NO_KERNEL_OP(c_comm_init_hcom);
+USE_NO_KERNEL_OP(c_gen_hccl_id);
+USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
 DECLARE_string(selected_npus);
@@ -53,26 +55,68 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
   for (auto ele : data) {
     debugstring += std::to_string(ele) + std::string(",");
   }
-  VLOG(2) << preStr << ":" << std::endl << debugstring;
+  VLOG(2) << preStr << ":" << std::endl <<debugstring;
 }
-void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
+void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
   int rank_id = atoi(getenv("RANK_ID"));
   int device_id = atoi(getenv("DEVICE_ID"));
-  VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
           << "; rank_id = " << rank_id
           << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
   std::vector<int> rank_ids{0, 1};
+  f::AttributeMap gen_hccl_id;
+  std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
+  gen_hccl_id["rank"] = rank_id;
+  gen_hccl_id["endpoint"] = endpointList[rank_id];
+  std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
+  gen_hccl_id["other_endpoints"] = other_endpoints;
+  auto out = scope->Var("Out");
+  auto id = out->GetMutable<HcclRootInfo>();
+  VLOG(3) << "break";
+  auto comm_init_op =
+      f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
+  VLOG(3) << "break";
+  auto place = ctx.GetPlace();
+  comm_init_op->Run(*scope, place);
+  ctx.Wait();
+  memcpy(hccl_id, id, 1024);
+}
+void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
+  auto x = scope->Var("X");
+  auto id = x->GetMutable<HcclRootInfo>();
+  memcpy(id, hccl_id, 1024);
+  int rank_id = atoi(getenv("RANK_ID"));
+  int device_id = atoi(getenv("DEVICE_ID"));
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
+          << "; rank_id = " << rank_id
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
+  // std::vector<int> rank_ids{0, 1};
   f::AttributeMap comm_init_attrs;
   comm_init_attrs["ring_id"] = 0;
-  comm_init_attrs["nranks"] = 2;
+  comm_init_attrs["rank_ids"] = 2;
   comm_init_attrs["rank"] = rank_id;
   comm_init_attrs["device_id"] = device_id;
-  comm_init_attrs["rank_ids"] = rank_ids;
+  // comm_init_attrs["rank_ids"] = rank_ids;
   auto comm_init_op =
-      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
+      f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
   auto place = ctx.GetPlace();
   comm_init_op->Run(*scope, place);
   ctx.Wait();
@@ -80,7 +124,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
 void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
   // init
-  auto x = scope->Var("X");
+  auto x = scope->Var("Data");
   auto tensor_x = x->GetMutable<f::LoDTensor>();
   int num = 2;
   std::vector<float> init;
@@ -96,7 +140,7 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
   ctx.Wait();
   auto place = ctx.GetPlace();
-  auto out = scope->Var("Out");
+  auto out = scope->Var("OutData");
   auto tensor_out = out->GetMutable<f::LoDTensor>();
   tensor_out->Resize({num, num});
   tensor_out->mutable_data<float>(place);  // allocate
@@ -108,8 +152,8 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
   attrs["root"] = 0;
   attrs["ring_id"] = 0;
-  auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}},
-                                    {{"Out", {"Out"}}}, attrs);
+  auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"Data"}}},
+                                    {{"Out", {"OutData"}}}, attrs);
   for (int i = 0; i < 10; i++) {
     op->Run(*scope, place);
@@ -129,11 +173,11 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
 TEST(c_broadcast, NPU) {
   f::Scope scope;
+  HcclRootInfo hccl_id;
   // only support one device, if more than one device, use first default
-  auto* ctx = p::DeviceContextPool::Instance().Get(
-      p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
+  p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
-  Prepare(&scope, *ctx);
-  TestHCCLBroadcastOp(&scope, *ctx);
+  PrepareUniqueId(&scope, ctx, &hccl_id);
+  Prepare(&scope, ctx, &hccl_id);
+  TestHCCLBroadcastOp(&scope, ctx);
 }
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,66 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#if defined(PADDLE_WITH_ASCEND_CL)
-#include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/hccl_helper.h"
 #include <string>
-#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/npu_op_runner.h"
 namespace paddle {
 namespace framework {
 class Scope;
 }  // namespace framework
 }  // namespace paddle
+#if defined(PADDLE_WITH_ASCEND_CL)
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
 namespace paddle {
 namespace operators {
-class CCommInitOpNPU : public framework::OperatorBase {
+class CCommInitOpAscend : public framework::OperatorBase {
  public:
-  CCommInitOpNPU(const std::string& type,
-                 const framework::VariableNameMap& inputs,
+  CCommInitOpAscend(const std::string& type, const framework::VariableNameMap& inputs,
                     const framework::VariableNameMap& outputs,
                     const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
-    int rid = Attr<int>("ring_id");
-    int nranks = Attr<int>("nranks");
+    PADDLE_ENFORCE_EQ(is_npu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "CCommInitOpAscend can run on npu place only."));
+    auto var = scope.FindVar(Input("X"));
+    PADDLE_ENFORCE_NOT_NULL(
+        var, platform::errors::InvalidArgument("Input con not be empty."));
+#if defined(PADDLE_WITH_ASCEND_CL)
+    HcclRootInfo* hccl_id = var->GetMutable<HcclRootInfo>();
+    int rank_ids = Attr<int>("rank_ids");
     int rank_id = Attr<int>("rank");
+    int rid = Attr<int>("ring_id");
     int device_id = BOOST_GET_CONST(platform::NPUPlace, place).device;
     if (Attr<int>("device_id") >= 0) {
       device_id = Attr<int>("device_id");
     }
-    std::vector<int> rank_ids = Attr<std::vector<int>>("rank_ids");
-    VLOG(3) << "begin c_comm_init on npu, parameters are: "
-            << "ring id[" << rid
-            << "], nranks[" << nranks
-            << "], rank_id[" << rank_id
-            << "], device_id[" << device_id
-            << "]";
     platform::HCCLCommContext::Instance().CreateHCCLComm(
-        rank_ids, rank_id, device_id, rid);
+        hccl_id, rank_ids, rank_id, device_id, rid);
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should compile with NPU."));
+#endif
   }
 };
-class CCommInitOpNPUMaker : public framework::OpProtoAndCheckerMaker {
+class CCommInitOpAscendMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
+    AddInput("X", "Raw variable contains a NCCL UniqueId instaces.");
     AddComment(R"DOC(
-CCommInit operator on NPU
+CCommInit operator
-Initialize collective communication context within this trainer
+Initialize collective communicatoin context within this trainer
 )DOC");
-    AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
-    AddAttr<std::vector<int>>("rank_ids", "The world rank ids of the group");
+    AddAttr<int>("rank_ids", "(int) The number of ranks of distributed trainers");
     AddAttr<int>("rank",
                  "(int) The rank of the trainer in distributed training.");
     AddAttr<int>("device_id",
@@ -89,6 +90,4 @@ Initialize collective communication context within this trainer
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(c_comm_init_hcom, ops::CCommInitOpNPU, ops::CCommInitOpNPUMaker);
-#endif
+REGISTER_OPERATOR(c_comm_init_hccl, ops::CCommInitOpAscend, ops::CCommInitOpAscendMaker);
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+
+#include "glog/logging.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/var_type_traits.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+
+#ifdef PADDLE_WITH_ASCEND_CL
+#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+#ifdef PADDLE_WITH_ASCEND_CL
+class CGenHCCLIdOp : public framework::OperatorBase {
+ public:
+  CGenHCCLIdOp(const std::string& type,
+               const framework::VariableNameMap& inputs,
+               const framework::VariableNameMap& outputs,
+               const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {
+  }
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    int rank = Attr<int>("rank");
+    framework::Scope& local_scope = scope.NewScope();
+
+    std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
+      return Output("Out");
+    };
+
+    if (rank == 0) {
+      std::vector<std::string> endpoint_list =
+          Attr<std::vector<std::string>>("other_endpoints");
+      SendBroadCastHCCLID(endpoint_list, 1, func, local_scope);
+    } else {
+      std::string endpoint = Attr<std::string>("endpoint");
+      RecvBroadCastHCCLID(endpoint, 1, func, local_scope);
+    }
+    scope.DeleteScope(&local_scope);
+  }
+};
+
+#else
+
+class CGenHCCLIdOp : public framework::OperatorBase {
+ public:
+  CGenHCCLIdOp(const std::string& type,
+               const framework::VariableNameMap& inputs,
+               const framework::VariableNameMap& outputs,
+               const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+  }
+};
+
+#endif
+
+class CGenHCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    VLOG(3) << "ele";
+    AddOutput("Out", "Raw variable contains a HCCL UniqueId instaces.");
+    AddComment(R"DOC(
+CGenHCCLId operator
+
+For trainer 0: generate a new UniqueId and send it to all the other trainers.
+For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
+)DOC");
+    AddAttr<std::string>("endpoint",
+                         "(string), e.g. 127.0.0.1:6175 "
+                         "current listen endpoint");
+    AddAttr<std::vector<std::string>>(
+        "other_endpoints",
+        "['trainer1_ip:port', 'trainer2_ip:port', ...] "
+        "list of other trainer endpoints")
+        .SetDefault({});
+    AddAttr<int>("rank",
+                 "(int default 0) "
+                 "The rank of the trainer in distributed training.")
+        .SetDefault(0);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(c_gen_hccl_id, ops::CGenHCCLIdOp, ops::CGenHCCLIdOpMaker);
@@ -63,6 +63,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
 };
 #else
 class CGenNCCLIdOp : public framework::OperatorBase {
  public:
  CGenNCCLIdOp(const std::string& type,
...
@@ -121,31 +121,15 @@ class CReduceOpASCENDKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_ASCEND_CL)
-    // we need to pre-allocate 512 Bytes before the data
-    // and 512 Bytes after the data, so the hccl allreduce
-    // can work. This is a must acooding to huawei peer.
-#define PRE_MALLOC_SIZE_BYTES 512
     auto in = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
     auto place = ctx.GetPlace();
-    hcclDataType_t dtype = platform::ToHCCLDataType(in->type());
+    HcclDataType dtype = platform::ToHCCLDataType(in->type());
     int64_t numel = in->numel();
-    int64_t pre_tmp_size = PRE_MALLOC_SIZE_BYTES / sizeof(T);
-    int64_t tmp_numel = numel + pre_tmp_size * 2;
-    paddle::framework::LoDTensor tmp_in, tmp_out;
-    tmp_in.Resize({tmp_numel});
-    tmp_out.Resize({tmp_numel});
-    auto p_tmp_in = tmp_in.mutable_data<T>(place);    // allocate
-    auto p_tmp_out = tmp_out.mutable_data<T>(place);  // allocate
-    void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
-    void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
-    std::string tag = ctx.Attr<std::string>("tag");
+    void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
+    void* recvbuff = reinterpret_cast<void*>(out->data<T>());
     int ring_id = ctx.Attr<int>("ring_id");
     int root_id = ctx.Attr<int>("root_id");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
@@ -161,33 +145,22 @@ class CReduceOpASCENDKernel : public framework::OpKernel<T> {
     int rank_id = comm->rank();
-    // we need to memset this memory firstly to avoid core by hccl
-    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
-    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
-    auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
-    memory::Copy(npu_place, sendbuff,
-                 npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
-                 numel * sizeof(T),
-                 stream);
-    hcclRedOp_t hccl_red_type = HCCL_REP_OP_SUM;
+    HcclReduceOp hccl_red_type = HCCL_REDUCE_SUM;
     switch (red_type) {
       case kRedSum:
-        hccl_red_type = HCCL_REP_OP_SUM;
+        hccl_red_type = HCCL_REDUCE_SUM;
         break;
      case kRedMax:
-        hccl_red_type = HCCL_REP_OP_MAX;
+        hccl_red_type = HCCL_REDUCE_MAX;
         break;
      case kRedMin:
-        hccl_red_type = HCCL_REP_OP_MIN;
+        hccl_red_type = HCCL_REDUCE_MIN;
         break;
      case kRedProd:
-        hccl_red_type = HCCL_REP_OP_PROD;
+        hccl_red_type = HCCL_REDUCE_PROD;
         break;
      default:
@@ -200,18 +173,14 @@ class CReduceOpASCENDKernel : public framework::OpKernel<T> {
             << "root_id: " << root_id
             << "dtype: " << dtype
             << "hccl_red_type: " << hccl_red_type
-            << ", group is: " << group
-            << ", tag is " << tag;
+            << ", group is: " << group;
-    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_all_reduce(
-        tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));
+    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
+        sendbuff, recvbuff, numel, dtype, hccl_red_type, comm->comm(), (void*)stream));
-    if(rank_id == root_id){
-      memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
-                   npu_place, recvbuff,
-                   numel * sizeof(T),
-                   stream);
-    }else{
+    if(rank_id != root_id){
+      auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
       memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
                    npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
                    numel * sizeof(T),
...
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/operators/collective/c_reduce_op.h"
+#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
 #if defined(PADDLE_WITH_ASCEND_CL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -42,7 +43,8 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 USE_OP(c_reduce_sum);
-USE_NO_KERNEL_OP(c_comm_init_hcom);
+USE_NO_KERNEL_OP(c_gen_hccl_id);
+USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_reduce_sum, NPU);
 DECLARE_string(selected_npus);
@@ -56,23 +58,65 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
   VLOG(3) << preStr << ":" << std::endl << debugstring;
 }
-void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
+void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
   int rank_id = atoi(getenv("RANK_ID"));
   int device_id = atoi(getenv("DEVICE_ID"));
-  VLOG(2) << "rank_id = " << rank_id << "; device_id = " << device_id
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
           << "; rank_id = " << rank_id
-          << "; RANK_TABLE_FILE = " << atoi(getenv("RANK_TABLE_FILE"));
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
   std::vector<int> rank_ids{0, 1};
+  f::AttributeMap gen_hccl_id;
+  std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
+  gen_hccl_id["rank"] = rank_id;
+  gen_hccl_id["endpoint"] = endpointList[rank_id];
+  std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
+  gen_hccl_id["other_endpoints"] = other_endpoints;
+  auto out = scope->Var("Out");
+  auto id = out->GetMutable<HcclRootInfo>();
+  VLOG(3) << "break";
+  auto comm_init_op =
+      f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
+  VLOG(3) << "break";
+  auto place = ctx.GetPlace();
+  comm_init_op->Run(*scope, place);
+  ctx.Wait();
+  memcpy(hccl_id, id, 1024);
+}
+void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
+  auto x = scope->Var("X");
+  auto id = x->GetMutable<HcclRootInfo>();
+  memcpy(id, hccl_id, 1024);
+  int rank_id = atoi(getenv("RANK_ID"));
+  int device_id = atoi(getenv("DEVICE_ID"));
+  VLOG(2) << "rank_id = " << rank_id
+          << "; device_id = " << device_id
+          << "; rank_id = " << rank_id
+          << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
+  // std::vector<int> rank_ids{0, 1};
   f::AttributeMap comm_init_attrs;
   comm_init_attrs["ring_id"] = 0;
-  comm_init_attrs["nranks"] = 2;
+  comm_init_attrs["rank_ids"] = 2;
   comm_init_attrs["rank"] = rank_id;
   comm_init_attrs["device_id"] = device_id;
-  comm_init_attrs["rank_ids"] = rank_ids;
+  // comm_init_attrs["rank_ids"] = rank_ids;
   auto comm_init_op =
-      f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs);
+      f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
   auto place = ctx.GetPlace();
   comm_init_op->Run(*scope, place);
   ctx.Wait();
@@ -80,7 +124,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
 void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
   // init
auto x = scope->Var("X"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
...@@ -99,7 +143,7 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) { ...@@ -99,7 +143,7 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
tensor_x->Resize({num1, num2}); tensor_x->Resize({num1, num2});
ctx.Wait(); ctx.Wait();
auto out = scope->Var("Out"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
...@@ -112,8 +156,10 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) { ...@@ -112,8 +156,10 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
int root_id = 0; int root_id = 0;
attrs["root_id"] = root_id; attrs["root_id"] = root_id;
auto op = f::OpRegistry::CreateOp("c_reduce_sum", {{"X", {"X"}}}, auto op = f::OpRegistry::CreateOp("c_reduce_sum",
{{"Out", {"Out"}}}, attrs); {{"X", {"Data"}}},
{{"Out", {"OutData"}}},
attrs);
op->Run(*scope, place); op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
...@@ -136,14 +182,15 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) { ...@@ -136,14 +182,15 @@ void TestHCCLReduceOp(f::Scope* scope, const p::DeviceContext& ctx, int iter) {
TEST(c_reduce_sum, NPU) { TEST(c_reduce_sum, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
auto* ctx = p::DeviceContextPool::Instance().Get( p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx); PrepareUniqueId(&scope, ctx, &hccl_id);
for (int i = 0; i < 2; i++) { Prepare(&scope, ctx, &hccl_id);
for(int i = 0; i < 2; i ++){
VLOG(2) << "iter num: " << i; VLOG(2) << "iter num: " << i;
TestHCCLReduceOp(&scope, *ctx, i); TestHCCLReduceOp(&scope, ctx, i);
} }
} }
...@@ -35,7 +35,6 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> { ...@@ -35,7 +35,6 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place); auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks(); int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
auto out_dims = in->dims(); auto out_dims = in->dims();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0, PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
...@@ -47,11 +46,11 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> { ...@@ -47,11 +46,11 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
out_dims[0] = out_dims[0] / nranks; out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place); out->mutable_data<T>(out_dims, place);
int64_t recv_numel = in->numel() / nranks; uint64_t recv_numel = in->numel() / nranks;
void* inputPtr = reinterpret_cast<void*>(const_cast<T*>(in->data<T>())); void* inputPtr = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
void* outputPtr = reinterpret_cast<void*>(out->data<T>()); void* outputPtr = reinterpret_cast<void*>(out->data<T>());
hcclDataType_t dtype = platform::ToHCCLDataType(in->type()); HcclDataType dtype = platform::ToHCCLDataType(in->type());
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) { if (ctx.Attr<bool>("use_calc_stream")) {
...@@ -63,12 +62,11 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> { ...@@ -63,12 +62,11 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
VLOG(3) << "begin hccl reduce scatter, parameter is: " VLOG(3) << "begin hccl reduce scatter, parameter is: "
<< "recv_numel: " << recv_numel << "recv_numel: " << recv_numel
<< "dtype: " << dtype << "dtype: " << dtype
<< "hccl_red_type: " << HCCL_REP_OP_SUM << "hccl_red_type: " << HCCL_REDUCE_SUM
<< ", group is: " << group << ", group is: " << group;
<< ", tag is " << tag;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_reduce_scatter( PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclReduceScatter(
tag.c_str(), inputPtr, outputPtr, (u64)recv_numel, dtype, HCCL_REP_OP_SUM, group.c_str(), (void*)stream)); inputPtr, outputPtr, recv_numel, dtype, HCCL_REDUCE_SUM, comm->comm(), (void*)stream));
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU.")); "PaddlePaddle should compile with NPU."));
......
...@@ -34,6 +34,7 @@ limitations under the License. */ ...@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_reducescatter_op.h" #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
...@@ -45,7 +46,8 @@ namespace p = paddle::platform; ...@@ -45,7 +46,8 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
USE_OP(c_reducescatter); USE_OP(c_reducescatter);
USE_NO_KERNEL_OP(c_comm_init_hcom); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_reducescatter, NPU); USE_OP_DEVICE_KERNEL(c_reducescatter, NPU);
DECLARE_string(selected_npus); DECLARE_string(selected_npus);
...@@ -59,7 +61,8 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) { ...@@ -59,7 +61,8 @@ void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
VLOG(2) << preStr << ":" << std::endl << debugstring; VLOG(2) << preStr << ":" << std::endl << debugstring;
} }
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) { void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
...@@ -68,22 +71,63 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -68,22 +71,63 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID")); << "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1}; std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs); f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
// init // init
auto x = scope->Var("X"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<float> init; std::vector<float> init;
...@@ -101,7 +145,7 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -101,7 +145,7 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait(); ctx.Wait();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto out = scope->Var("Out"); auto out = scope->Var("OutData");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num1, num2}); tensor_out->Resize({num1, num2});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
...@@ -114,14 +158,14 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -114,14 +158,14 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
attrs["ring_id"] = 0; attrs["ring_id"] = 0;
attrs["nranks"] = 2; attrs["nranks"] = 2;
auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}}, auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"Data"}}},
{{"Out", {"Out"}}}, attrs); {{"Out", {"OutData"}}}, attrs);
int iter_num = 10; int iter_num = 10;
for (int i = 0; i < iter_num; i++) { for (int i = 0; i < iter_num; i++) {
op->Run(*scope, place); op->Run(*scope, place);
}
ctx.Wait(); ctx.Wait();
}
std::vector<float> out_vec; std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec); TensorToVector(*tensor_out, ctx, &out_vec);
...@@ -130,17 +174,18 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -130,17 +174,18 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
PrintDebugInfo("output data", out_vec); PrintDebugInfo("output data", out_vec);
EXPECT_EQ(out_vec.size(), init.size() / 2); EXPECT_EQ(out_vec.size(), init.size() / 2);
for (uint32_t i = 0; i < out_vec.size(); i++) { for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], iter_num + 1); EXPECT_EQ(out_vec[i], 2.0);
} }
} }
TEST(c_reducescatter, NPU) { TEST(c_reducescatter, NPU) {
f::Scope scope; f::Scope scope;
HcclRootInfo hccl_id;
// only support one device, if more than one device, use first default // only support one device, if more than one device, use first default
auto* ctx = p::DeviceContextPool::Instance().Get( p::NPUDeviceContext ctx(p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
p::NPUPlace(atoi(FLAGS_selected_npus.c_str())));
Prepare(&scope, *ctx); PrepareUniqueId(&scope, ctx, &hccl_id);
TestHCCLReduceScatterOp(&scope, *ctx); Prepare(&scope, ctx, &hccl_id);
TestHCCLReduceScatterOp(&scope, ctx);
} }
...@@ -41,7 +41,7 @@ namespace m = paddle::operators::math; ...@@ -41,7 +41,7 @@ namespace m = paddle::operators::math;
USE_OP(c_broadcast); USE_OP(c_broadcast);
USE_NO_KERNEL_OP(c_sync_comm_stream); USE_NO_KERNEL_OP(c_sync_comm_stream);
USE_NO_KERNEL_OP(c_comm_init_hcom); USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(c_broadcast, NPU); USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) { void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_ASCEND_CL
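// GenHCCLIdOp drives the HCCL root-info rendezvous over raw TCP sockets:
// trainer 0 calls HcclGetRootInfo for each communicator and pushes the ids to
// every other trainer, while the remaining trainers listen on their own
// endpoint and receive them. When hierarchical allreduce is enabled, the same
// exchange is repeated for the inner (inter) and outer (exter) rings.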
class GenHCCLIdOp : public framework::OperatorBase {
public:
GenHCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
std::vector<std::string> trainers =
Attr<std::vector<std::string>>("trainers");
int trainer_id = Attr<int>("trainer_id");
std::string endpoint = trainers[trainer_id];
    PADDLE_ENFORCE_GE(trainer_id, 0,
                      platform::errors::InvalidArgument(
                          "trainer_id %d is less than 0. Its "
                          "valid range is [0, trainer_size).",
                          trainer_id));
PADDLE_ENFORCE_LT(
trainer_id, static_cast<int>(trainers.size()),
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
"range is [0, trainer_size)",
trainer_id));
int hccl_comm_num = Attr<int>("hccl_comm_num");
    bool use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
int inter_trainer_id = -1;
int exter_trainer_id = -1;
if (use_hierarchical_allreduce) {
PADDLE_ENFORCE_GT(
trainers.size(), 1,
platform::errors::PreconditionNotMet(
"The number of collective trainers %llu <= 1", trainers.size()));
PADDLE_ENFORCE_GT(
inter_nranks, 1,
platform::errors::PreconditionNotMet(
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
inter_nranks));
PADDLE_ENFORCE_EQ(
trainers.size() % inter_nranks, 0,
platform::errors::PreconditionNotMet(
"The number of trainers %llu mod inter_nranks %d is not equal 0",
trainers.size(), inter_nranks));
inter_trainer_id = trainer_id % inter_nranks;
if (trainer_id % inter_nranks == 0) {
exter_trainer_id = trainer_id / inter_nranks;
}
}
std::ostringstream ss;
for (size_t i = 0; i < trainers.size(); i++) {
ss << trainers[i] << ",";
}
VLOG(1) << "trainer_id:" << trainer_id
<< ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
<< ", hccl_comm_num:" << hccl_comm_num
<< ", inter_nranks:" << inter_nranks
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< ", trainers:" << ss.str();
int server_fd = -1;
/// 1. init flat
std::function<std::string(size_t)> func = platform::GetFlatHCCLVarName;
if (trainer_id == 0) {
// server endpoints
std::vector<std::string> flat_endpoints;
flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
trainers.end());
SendBroadCastHCCLID(flat_endpoints, hccl_comm_num, func, scope);
} else {
server_fd = CreateListenSocket(endpoint);
RecvBroadCastHCCLID(server_fd, endpoint, hccl_comm_num, func, scope);
}
    /// 2. hierarchical inter hccl id
func = platform::GetHierarchicalInterHCCLVarName;
if (inter_trainer_id == 0) {
std::ostringstream ss;
ss << endpoint;
std::vector<std::string> inter_endpoints;
for (int i = trainer_id + 1; i < trainer_id + inter_nranks &&
i < static_cast<int>(trainers.size());
i++) {
ss << ",";
inter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
SendBroadCastHCCLID(inter_endpoints, hccl_comm_num, func, scope);
} else if (inter_trainer_id > 0) {
VLOG(1) << "Hierarchical inter ring";
RecvBroadCastHCCLID(server_fd, endpoint, hccl_comm_num, func, scope);
}
    /// 3. hierarchical exter hccl id
func = platform::GetHierarchicalExterHCCLVarName;
if (exter_trainer_id == 0) {
std::ostringstream ss;
std::vector<std::string> exter_endpoints;
ss << endpoint;
for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) {
ss << ",";
exter_endpoints.push_back(trainers[i]);
ss << trainers[i];
}
VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();
SendBroadCastHCCLID(exter_endpoints, hccl_comm_num, func, scope);
} else if (exter_trainer_id > 0) {
VLOG(1) << "Hierarchical exter ring";
RecvBroadCastHCCLID(server_fd, endpoint, hccl_comm_num, func, scope);
}
// close socket server
if (trainer_id != 0) {
CloseSocket(server_fd);
}
}
};
#else
class GenHCCLIdOp : public framework::OperatorBase {
public:
GenHCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
}
};
#endif
class GenHCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("HCCLID", "Raw variable contains a HCCL UniqueId instaces.");
AddComment(R"DOC(
GenHCCLId operator
For trainer 0: generate a new HCCL UniqueId and send it to all the other trainers.
For trainer 1~n: start a socket server to receive the UniqueId; once it is received, stop the server.
)DOC");
AddAttr<std::vector<std::string>>(
"trainers",
"['trainer0_ip:port', 'trainer1_ip:port', ...] "
"list of all trainer endpoints")
.SetDefault({});
AddAttr<int>("trainer_id",
"(int) "
"The index of the trainer in distributed training.");
AddAttr<int>("hccl_comm_num",
"(int default 1) "
"The number of nccl communicator num.")
.SetDefault(1);
AddAttr<bool>("use_hierarchical_allreduce",
"(bool default false) "
"Wheter to use hierarchical allreduce.")
.SetDefault(false);
AddAttr<int>("hierarchical_allreduce_inter_nranks",
"(int default 1) "
"Wheter to use hierarchical allreduce.")
.SetDefault(-1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gen_hccl_id, ops::GenHCCLIdOp, ops::GenHCCLIdOpMaker);
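// A hedged usage sketch (not part of this operator's code): RunImpl above
// looks the generated ids up through scope variables named by
// platform::GetFlatHCCLVarName(i) (and the hierarchical variants), not through
// the "HCCLID" output slot, so the caller is expected to create those
// variables beforehand. The endpoints below are placeholders.
//
//   for (int i = 0; i < hccl_comm_num; ++i) {
//     scope.Var(paddle::platform::GetFlatHCCLVarName(i))
//         ->GetMutable<HcclRootInfo>();
//   }
//   paddle::framework::AttributeMap attrs;
//   attrs["trainers"] =
//       std::vector<std::string>{"127.0.0.1:6175", "127.0.0.1:6177"};
//   attrs["trainer_id"] = 0;
//   auto op = paddle::framework::OpRegistry::CreateOp(
//       "gen_hccl_id", {}, {{"HCCLID", {"HCCLID"}}}, attrs);
//   op->Run(scope, place);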
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <algorithm>
#include <ostream>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/split.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif
namespace paddle {
namespace operators {
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
#define HCCL_UNIQUE_ID_BYTES 1024
// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
do { \
int retval; \
CHECK_SYS_CALL_VAL(call, name, retval); \
} while (false)
#define CHECK_SYS_CALL_VAL(call, name, retval) \
do { \
RETRY_SYS_CALL_VAL(call, name, retval); \
if (retval == -1) { \
PADDLE_THROW(platform::errors::Unavailable("Call to %s failed: %s", \
name, strerror(errno))); \
} \
} while (false)
#define RETRY_SYS_CALL_VAL(call, name, retval) \
do { \
retval = (call); \
if (retval == -1 && \
(errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
LOG(WARNING) << "Call " << name << " returned " << strerror(errno) \
<< " retry"; \
} else { \
break; \
} \
} while (true)
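// The three macros above wrap raw socket syscalls: RETRY_SYS_CALL_VAL retries
// a call while it fails with EINTR/EWOULDBLOCK/EAGAIN, CHECK_SYS_CALL_VAL
// additionally turns a persistent -1 into a PADDLE_THROW, and CHECK_SYS_CALL
// is the variant that discards the return value.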
static int SocketSend(int fd, const char* buffer, int size) {
int offset = 0;
int bytes = 0;
while (offset < size) {
bytes = send(fd, buffer + offset, size - offset, 0);
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
// send failed
return -1;
} else {
bytes = 0;
}
}
offset += bytes;
}
return offset;
}
static int SocketRecv(int fd, char* buffer, int size) {
int offset = 0;
int bytes = 0;
while (offset < size) {
bytes = recv(fd, buffer + offset, size - offset, 0);
if (bytes == 0) {
      // connection closed by the peer, possibly a liveness probe from the client
return 0;
}
if (bytes == -1) {
if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
return -1;
} else {
bytes = 0;
}
}
offset += bytes;
}
return offset;
}
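// Both SocketSend and SocketRecv loop until the full `size` bytes have been
// transferred, treating EINTR/EWOULDBLOCK/EAGAIN as retryable; SocketRecv
// returns 0 early if the peer closes the connection.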
static void BindOrConnectFailed(int timeout, int* try_times, int* total_time,
const char* op, const std::string& ep) {
PADDLE_ENFORCE_LT(
*total_time, timeout,
platform::errors::Unavailable("%s addr=%s timeout, failed reason: %s", op,
ep.c_str(), strerror(errno)));
++(*try_times);
int retry_time = std::min(*try_times * 500, 3000); // max 3 seconds
*total_time += retry_time;
LOG(WARNING) << op << " addr=" << ep << " failed " << *try_times
<< " times with reason: " << strerror(errno) << " retry after "
<< retry_time / 1000.0 << " seconds";
std::this_thread::sleep_for(std::chrono::milliseconds(retry_time));
}
int CreateListenSocket(const std::string& ep) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
// creating socket fd
int server_fd = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", server_fd);
  // NOTE. Solutions to `Address already in use`.
  // 1. Reuse the addr & port. Otherwise, once the server closes the socket
  // before the client does, the server enters the TIME-WAIT state. If we bind
  // the port again, the error `Address already in use` will appear.
  // 2. Alternatively, close the client first to ensure that the server does
  // not enter the TIME-WAIT state. But this is obviously not as convenient
  // as the reuse method.
int opt = 1;
#if defined(SO_REUSEPORT)
// since Linux kernel 3.9
CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT,
&opt, sizeof(opt)),
"setsockopt");
#else
CHECK_SYS_CALL(
setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)),
"setsockopt");
#endif
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(port);
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
bind(server_fd, (struct sockaddr*)&address, sizeof(address)), "bind",
ret_val);
if (ret_val == -1) {
BindOrConnectFailed(timeout, &try_times, &total_time, "bind", ep);
continue;
}
break;
}
CHECK_SYS_CALL(listen(server_fd, 3), "listen");
LOG(INFO) << "Server listening on: " << ep << " successful.";
return server_fd;
}
void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); }
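// SocketAccept blocks until a client connects and sends the expected `head`
// magic string; connections that fail the handshake are closed and the accept
// loop continues.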
static int SocketAccept(int server_fd, const char* head) {
struct sockaddr_in client_addr;
socklen_t addr_length = sizeof(client_addr);
char buffer[1024] = {0};
int conn = -1;
while (true) {
CHECK_SYS_CALL_VAL(
accept(server_fd, reinterpret_cast<struct sockaddr*>(&client_addr),
&addr_length),
"accept", conn);
int ret_val = SocketRecv(conn, buffer, strlen(head));
if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) {
break; // accept client
} else {
VLOG(3) << "socket read failed with ret_val=" << ret_val;
CloseSocket(conn);
}
}
return conn;
}
static int ConnectAddr(const std::string& ep, const char* head) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
int sock = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
char* ip = NULL;
struct hostent* hp = NULL;
hp = gethostbyname(host.c_str());
PADDLE_ENFORCE_NOT_NULL(hp, platform::errors::InvalidArgument(
"Fail to get host by name %s.", host));
int i = 0;
while (hp->h_addr_list[i] != NULL) {
ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip;
break;
}
PADDLE_ENFORCE_GT(inet_pton(AF_INET, ip, &server_addr.sin_addr), 0,
platform::errors::Unavailable("Open address %s failed: %s",
ep, strerror(errno)));
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)),
"connect", ret_val);
if (ret_val == -1) {
BindOrConnectFailed(timeout, &try_times, &total_time, "connect", ep);
continue;
}
CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send");
break;
}
return sock;
}
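// ConnectAddr above resolves the host via gethostbyname, retries connect()
// with a capped backoff through BindOrConnectFailed, and finally sends the
// `head` magic so that SocketAccept on the server side admits the connection.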
static void RecvHCCLID(int conn, HcclRootInfo* hccl_id) {
char buffer[1024] = {0};
static_assert(HCCL_UNIQUE_ID_BYTES <= 1024,
"hccl id bytes must <= buffer size");
CHECK_SYS_CALL(SocketRecv(conn, buffer, HCCL_UNIQUE_ID_BYTES), "recv hccl id");
memcpy(hccl_id, buffer, HCCL_UNIQUE_ID_BYTES);
}
static void SendHCCLID(int conn, HcclRootInfo* hccl_id) {
char buffer[1024] = {0};
memcpy(buffer, hccl_id, HCCL_UNIQUE_ID_BYTES);
CHECK_SYS_CALL(SocketSend(conn, buffer, HCCL_UNIQUE_ID_BYTES),
"send hccl id");
}
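// RecvHCCLID/SendHCCLID move a fixed HCCL_UNIQUE_ID_BYTES payload through a
// stack buffer, so HcclRootInfo is assumed to fit in 1024 bytes (guarded by
// the static_assert in RecvHCCLID).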
void SendBroadCastHCCLID(std::vector<std::string> servers, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
// connect with server
std::vector<int> connects;
for (auto server : servers) {
VLOG(3) << "connecting endpoint: " << server;
int conn = ConnectAddr(server, COMM_HEAD);
connects.push_back(conn);
}
VLOG(3) << "connecting completed...";
for (int i = 0; i < hccl_comm_num; ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto hccl_id = var->GetMutable<HcclRootInfo>();
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclGetRootInfo(hccl_id));
int j = 0;
for (auto conn : connects) {
VLOG(3) << "sending hccl_id_var: " << var_name << " to " << servers[j]
<< " hccl_comm_no: " << i;
SendHCCLID(conn, hccl_id);
++j;
}
VLOG(3) << "sending completed...";
}
// close client
for (auto conn : connects) {
CloseSocket(conn);
}
}
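// SendBroadCastHCCLID implements the trainer-0 side of the rendezvous: it
// connects to every peer endpoint with the COMM_HEAD handshake, generates one
// HcclRootInfo per communicator via HcclGetRootInfo, writes it into the scope
// variable named by `func(i)`, and pushes it to all connected peers.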
void RecvBroadCastHCCLID(std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
int server = CreateListenSocket(endpoint);
RecvBroadCastHCCLID(server, endpoint, hccl_comm_num, func, scope);
CloseSocket(server);
}
void RecvBroadCastHCCLID(int server_fd, std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
int client = SocketAccept(server_fd, COMM_HEAD);
for (int i = 0; i < hccl_comm_num; ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto hccl_id = var->GetMutable<HcclRootInfo>();
VLOG(3) << "trainer: " << endpoint << " receiving hccl_id_var: " << var_name
<< " from trainer 0, hccl_comm_no: " << i;
RecvHCCLID(client, hccl_id);
}
VLOG(3) << "receiving completed...";
CloseSocket(client);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <string>
#include <vector>
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
int CreateListenSocket(const std::string& ep);
void CloseSocket(int fd);
void SendBroadCastHCCLID(std::vector<std::string> servers, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
// The server listens on `endpoint`, then receives the hccl id.
void RecvBroadCastHCCLID(std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
// Receive the hccl id through an already-listening socket.
void RecvBroadCastHCCLID(int server_fd, std::string endpoint, int hccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
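// A minimal two-trainer usage sketch for the helpers above (assumed setup:
// the endpoints are placeholders and the scope must already hold one
// HcclRootInfo variable per communicator under the names produced by
// `var_name`):
//
//   auto var_name = [](size_t i) { return "HCCLID_" + std::to_string(i); };
//   if (trainer_id == 0) {
//     SendBroadCastHCCLID({"127.0.0.1:6177"}, /*hccl_comm_num=*/1, var_name,
//                         scope);
//   } else {
//     RecvBroadCastHCCLID("127.0.0.1:6177", /*hccl_comm_num=*/1, var_name,
//                         scope);
//   }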
} // namespace operators
} // namespace paddle
...@@ -27,32 +27,39 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> { ...@@ -27,32 +27,39 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
auto out = ctx.Output<framework::LoDTensor>("Out"); auto x = ctx.Output<framework::LoDTensor>("Out");
int numel = out->numel(); void *ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
hcclDataType_t dtype = platform::ToHCCLDataType(out->type()); int numel = x->numel();
HcclDataType dtype = platform::ToHCCLDataType(x->type());
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place); auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
if (ctx.Attr<bool>("use_calc_stream")) {
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream(); stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else { } else {
stream = comm->stream(); stream = comm->stream();
} }
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); int nranks = comm->nranks();
int srcRank = ctx.Attr<int>("peer"); int peer = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
VLOG(3) << "recv_v2_npu attr get"; PADDLE_ENFORCE_EQ(nranks, 2,
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_receive( platform::errors::InvalidArgument(
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(out->data<T>())), (u64)numel, dtype, srcRank, "The nranks must be 2, but (%d)",
srTag, group.c_str(), stream)); nranks));
VLOG(3) << "Source Rank: " << srcRank << " Invoke hcom receive. receiving ";
out->Resize(out->dims()); int root = peer;
out->set_lod(out->lod());
VLOG(3) << "begin hccl recv, parameter is: "<< "root " << root
<< ", comm: " << comm->comm() << ", stream: " << stream;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(ptr, numel,
dtype, (uint32_t)root, comm->comm(), stream));
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with NPU.")); "PaddlePaddle should compile with NPU."));
......
...@@ -31,6 +31,8 @@ limitations under the License. */ ...@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/recv_v2_op.h" #include "paddle/fluid/operators/collective/recv_v2_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
...@@ -42,45 +44,86 @@ namespace p = paddle::platform; ...@@ -42,45 +44,86 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
USE_OP(recv_v2); USE_OP(recv_v2);
USE_NO_KERNEL_OP(c_comm_init_hcom); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(recv_v2, NPU); USE_OP_DEVICE_KERNEL(recv_v2, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) { void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
std::string rank_table_file = getenv("RANK_TABLE_FILE");
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
int src_rank = atoi(getenv("SRC_RANK"));
int dest_rank = atoi(getenv("DEST_RANK"));
VLOG(3) << "rank_id " << rank_id << "src_rank" << src_rank << "dest_rank"
<< dest_rank;
std::vector<int> rank_ids = {0, 1}; VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs); f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
VLOG(3) << "CreateOp c_comm_init_hcom";
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){
std::cout << "BEGIN TEST:" << __FUNCTION__ << std::endl; std::cout << "BEGIN TEST:" << __FUNCTION__ << std::endl;
int num = atoi(getenv("DATA_SIZE")); int num = atoi(getenv("DATA_SIZE"));
EXPECT_GT(num, 0); EXPECT_GT(num, 0);
EXPECT_LT(num, 1 << 15); EXPECT_LT(num, 1 << 15);
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
VLOG(3) << "rank_id:" << rank_id << std::endl; VLOG(3) << "rank_id:" << rank_id<<std::endl;
ctx.Wait(); ctx.Wait();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto out = scope->Var("Out"); auto out = scope->Var("Data");
auto tensor_out = out->GetMutable<f::LoDTensor>(); auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({num, num}); tensor_out->Resize({num, num});
tensor_out->mutable_data<float>(place); // allocate tensor_out->mutable_data<float>(place); // allocate
...@@ -88,37 +131,39 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -88,37 +131,39 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait(); ctx.Wait();
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("srtest"); attrs["tag"]=std::string("srtest");
attrs["peer"] = atoi(getenv("SRC_RANK")); attrs["peer"]=atoi(getenv("SRC_RANK"));
attrs["ring_id"] = 0; attrs["ring_id"]=0;
attrs["srTag"] = 0; attrs["srTag"]=0;
std::vector<int> out_shape; std::vector<int> out_shape;
out_shape.push_back(num); out_shape.push_back(num);
out_shape.push_back(num); out_shape.push_back(num);
attrs["out_shape"] = out_shape; attrs["out_shape"]=out_shape;
auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs); auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Data"}}}, attrs);
VLOG(3) << "CreateOp recv_v2"; VLOG(3) << "CreateOp recv_v2";
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
} }
VLOG(3) << "Run op recv_v2"; VLOG(3) << "Run op recv_v2";
std::vector<float> out_vec; std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec); TensorToVector(*tensor_out, ctx, &out_vec);
ctx.Wait(); ctx.Wait();
std::vector<float> init(num * num, 1.0 * atoi(getenv("DEST_RANK"))); std::vector<float> init(num*num, 1.0 * atoi(getenv("DEST_RANK")));
EXPECT_EQ(out_vec == init, true); EXPECT_EQ(out_vec == init, true);
} }
TEST(recv_v2, NPU) {
TEST(recv_v2, NPU){
f::Scope scope; f::Scope scope;
char* npu_id = getenv("FLAGS_selected_npus"); HcclRootInfo hccl_id;
char * npu_id=getenv("FLAGS_selected_npus");
VLOG(3) << "Select npu:" << npu_id; VLOG(3) << "Select npu:" << npu_id;
auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(atoi(npu_id))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
VLOG(3) << "Place over";
Prepare(&scope, *ctx); PrepareUniqueId(&scope, ctx, &hccl_id);
VLOG(3) << "Prepare over"; Prepare(&scope, ctx, &hccl_id);
TestHcomRecvOp(&scope, *ctx); TestHcomRecvOp(&scope, ctx);
VLOG(3) << "Test over";
} }
...@@ -28,31 +28,37 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> { ...@@ -28,31 +28,37 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
auto x = ctx.Input<framework::LoDTensor>("X"); auto x = ctx.Input<framework::LoDTensor>("X");
void *ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
int numel = x->numel(); int numel = x->numel();
hcclDataType_t dtype = platform::ToHCCLDataType(x->type()); HcclDataType dtype = platform::ToHCCLDataType(x->type());
auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place); auto place = ctx.GetPlace();
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
if (ctx.Attr<bool>("use_calc_stream")) {
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream(); stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} else { } else {
stream = comm->stream(); stream = comm->stream();
} }
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int destRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_send( int nranks = comm->nranks();
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), (u64)numel, dtype, destRank, int rank = comm->rank();
srTag, group.c_str(), stream));
PADDLE_ENFORCE_EQ(nranks, 2,
platform::errors::InvalidArgument(
"The nranks must be 2, but (%d)",
nranks));
int root = rank;
VLOG(3) << "begin hccl send, parameter is: "<< "root " << root
<< ", comm: " << comm->comm() << ", stream: " << stream;
VLOG(3) << "Dest rank:" << destRank << " Invoke hcom send. Sent " PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(ptr, numel,
<< x->numel(); dtype, (uint32_t)root, comm->comm(), stream));
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
......
...@@ -30,6 +30,7 @@ limitations under the License. */ ...@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/send_v2_op.h" #include "paddle/fluid/operators/collective/send_v2_op.h"
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
...@@ -41,43 +42,85 @@ namespace p = paddle::platform; ...@@ -41,43 +42,85 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
USE_OP(send_v2); USE_OP(send_v2);
USE_NO_KERNEL_OP(c_comm_init_hcom); USE_NO_KERNEL_OP(c_gen_hccl_id);
USE_NO_KERNEL_OP(c_comm_init_hccl);
USE_OP_DEVICE_KERNEL(send_v2, NPU); USE_OP_DEVICE_KERNEL(send_v2, NPU);
void Prepare(f::Scope* scope, const p::DeviceContext& ctx) {
std::string rank_table_file = getenv("RANK_TABLE_FILE"); void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID")); int device_id = atoi(getenv("DEVICE_ID"));
int src_rank = atoi(getenv("SRC_RANK"));
int dest_rank = atoi(getenv("DEST_RANK"));
VLOG(3) << "rank_id " << rank_id << "src_rank" << src_rank << "dest_rank"
<< dest_rank;
std::vector<int> rank_ids = {0, 1}; VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
std::vector<int> rank_ids{0, 1};
f::AttributeMap gen_hccl_id;
std::vector<std::string > endpointList={"127.0.0.1:6175", "127.0.0.1:6177"};
gen_hccl_id["rank"] = rank_id;
gen_hccl_id["endpoint"] = endpointList[rank_id];
std::vector<std::string> other_endpoints= {endpointList[rank_id == 0 ? 1 : 0]};
gen_hccl_id["other_endpoints"] = other_endpoints;
auto out = scope->Var("Out");
auto id = out->GetMutable<HcclRootInfo>();
VLOG(3) << "break";
auto comm_init_op =
f::OpRegistry::CreateOp("c_gen_hccl_id", {}, {{"Out", {"Out"}}}, gen_hccl_id);
VLOG(3) << "break";
auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place);
ctx.Wait();
memcpy(hccl_id, id, 1024);
}
void Prepare(f::Scope* scope, const p::DeviceContext& ctx, HcclRootInfo* hccl_id){
auto x = scope->Var("X");
auto id = x->GetMutable<HcclRootInfo>();
memcpy(id, hccl_id, 1024);
int rank_id = atoi(getenv("RANK_ID"));
int device_id = atoi(getenv("DEVICE_ID"));
VLOG(2) << "rank_id = " << rank_id
<< "; device_id = " << device_id
<< "; rank_id = " << rank_id
<< "; RANK_TABLE_FILE = " << atoi(getenv("DEVICE_ID"));
// std::vector<int> rank_ids{0, 1};
f::AttributeMap comm_init_attrs; f::AttributeMap comm_init_attrs;
comm_init_attrs["ring_id"] = 0; comm_init_attrs["ring_id"] = 0;
comm_init_attrs["nranks"] = 2; comm_init_attrs["rank_ids"] = 2;
comm_init_attrs["rank"] = rank_id; comm_init_attrs["rank"] = rank_id;
comm_init_attrs["device_id"] = device_id; comm_init_attrs["device_id"] = device_id;
comm_init_attrs["rank_ids"] = rank_ids; // comm_init_attrs["rank_ids"] = rank_ids;
auto comm_init_op = auto comm_init_op =
f::OpRegistry::CreateOp("c_comm_init_hcom", {}, {}, comm_init_attrs); f::OpRegistry::CreateOp("c_comm_init_hccl", {{"X", {"X"}}}, {}, comm_init_attrs);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
comm_init_op->Run(*scope, place); comm_init_op->Run(*scope, place);
ctx.Wait(); ctx.Wait();
} }
void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx) { void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){
std::cout << "BEGIN TEST:" << __FUNCTION__ << std::endl; std::cout<< "BEGIN TEST:"<< __FUNCTION__ <<std::endl;
auto x = scope->Var("X"); auto x = scope->Var("Data");
auto tensor_x = x->GetMutable<f::LoDTensor>(); auto tensor_x = x->GetMutable<f::LoDTensor>();
int num = atoi(getenv("DATA_SIZE")); int num = atoi(getenv("DATA_SIZE"));;
EXPECT_GT(num, 0); EXPECT_GT(num, 0);
EXPECT_LT(num, 1 << 15); EXPECT_LT(num, 1 << 15);
std::vector<float> init(num * num, 1.0 * atoi(getenv("DEST_RANK"))); std::vector<float> init(num*num, 1.0 * atoi(getenv("DEST_RANK")));
int rank_id = atoi(getenv("RANK_ID")); int rank_id = atoi(getenv("RANK_ID"));
VLOG(3) << "rank id:" << rank_id; VLOG(3)<<"rank id:"<<rank_id;
TensorFromVector(init, ctx, tensor_x); TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({num, num}); tensor_x->Resize({num, num});
ctx.Wait(); ctx.Wait();
...@@ -85,28 +128,29 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -85,28 +128,29 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx) {
ctx.Wait(); ctx.Wait();
f::AttributeMap attrs; f::AttributeMap attrs;
attrs["tag"] = std::string("srtest"); attrs["tag"]=std::string("srtest");
attrs["peer"] = atoi(getenv("DEST_RANK")); attrs["peer"]=atoi(getenv("DEST_RANK"));
attrs["ring_id"] = 0; attrs["ring_id"]=0;
attrs["srTag"] = 0; attrs["srTag"]=0;
auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs); auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"Data"}}}, {}, attrs);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
} }
VLOG(3) << "send run over"; VLOG(3)<<"send run over";
ctx.Wait(); ctx.Wait();
} }
TEST(send_v2, NPU) { TEST(send_v2, NPU){
f::Scope scope; f::Scope scope;
char* npu_id = getenv("FLAGS_selected_npus"); HcclRootInfo hccl_id;
char * npu_id=getenv("FLAGS_selected_npus");
VLOG(3) << "Select npu:" << npu_id; VLOG(3) << "Select npu:" << npu_id;
auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(atoi(npu_id))); p::NPUDeviceContext ctx(p::NPUPlace(atoi(npu_id)));
VLOG(3) << "Place over";
Prepare(&scope, *ctx); PrepareUniqueId(&scope, ctx, &hccl_id);
VLOG(3) << "Prepare over"; Prepare(&scope, ctx, &hccl_id);
TestHcomSendOp(&scope, *ctx); TestHcomSendOp(&scope, ctx);
VLOG(3) << "Test over";
} }
...@@ -157,15 +157,10 @@ class HCCLComm { ...@@ -157,15 +157,10 @@ class HCCLComm {
virtual int nranks() const = 0; virtual int nranks() const = 0;
virtual int rank() const = 0; virtual int rank() const = 0;
virtual int device_id() const = 0; virtual int device_id() const = 0;
virtual HcclComm comm() const = 0;
virtual aclrtStream stream() const = 0; virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0; virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default; virtual ~HCCLComm() = default;
unsigned long NextTagId() {
return tag_counter_++;
}
private:
std::atomic<unsigned long> tag_counter_;
}; };
// A singleton HCCL communicator context reserves communication ring ids // A singleton HCCL communicator context reserves communication ring ids
...@@ -176,11 +171,12 @@ class HCCLCommContext { ...@@ -176,11 +171,12 @@ class HCCLCommContext {
return comm_ctx; return comm_ctx;
} }
HCCLComm* CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id = 0); HCCLComm* CreateHCCLComm(HcclRootInfo* hccl_id, int nranks,
int rank, int dev_id, int ring_id);
// a latter comm with the same dev_id and the same ring_id // a latter comm with the same dev_id and the same ring_id
// will override the former // will override the former
HCCLComm* AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id = 0); HCCLComm* AssignHCCLComm(HcclComm comm, int nranks, int rank,
int dev_id, int ring_id);
// retrieve a communicator by the ring id in multiprocessing mode // retrieve a communicator by the ring id in multiprocessing mode
HCCLComm* Get(int ring_id) const { HCCLComm* Get(int ring_id) const {
...@@ -217,20 +213,21 @@ class HCCLCommContext { ...@@ -217,20 +213,21 @@ class HCCLCommContext {
private: private:
// Init global hcom // Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); } HCCLCommContext() {}
// we may use group feature in the feature
// HCCLCommContext() { InitHcomWorldGroup(); }
HcclComm comm_;
public: public:
~HCCLCommContext(){ ~HCCLCommContext(){ }
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy());
}
std::once_flag once_flag_; std::once_flag once_flag_;
std::mutex comm_map_mutex_; std::mutex comm_map_mutex_;
// ring id to dev-HCCLComm // ring id to dev-HCCLComm
std::map<int, std::map<int, std::unique_ptr<HCCLComm>>> comm_map_; std::map<int, std::map<int, std::unique_ptr<HCCLComm>>> comm_map_;
void InitHcomWorldGroup(); // void InitHcomWorldGroup();
void ReleaseHCCLComms(); void ReleaseHCCLComms();
DISABLE_COPY_AND_ASSIGN(HCCLCommContext); DISABLE_COPY_AND_ASSIGN(HCCLCommContext);
......
...@@ -34,6 +34,13 @@ class HCCLCommImpl : public HCCLComm { ...@@ -34,6 +34,13 @@ class HCCLCommImpl : public HCCLComm {
return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device; return BOOST_GET_CONST(NPUPlace, dev_ctx_->GetPlace()).device;
} }
~HCCLCommImpl(){
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclCommDestroy(comm_));
}
void set_comm(HcclComm comm) { comm_ = comm; }
HcclComm comm() const override { return comm_; }
aclrtStream stream() const override { return dev_ctx_->stream(); } aclrtStream stream() const override { return dev_ctx_->stream(); }
void set_dev_ctx(std::unique_ptr<NPUDeviceContext>&& dev_ctx) { void set_dev_ctx(std::unique_ptr<NPUDeviceContext>&& dev_ctx) {
...@@ -45,46 +52,43 @@ class HCCLCommImpl : public HCCLComm { ...@@ -45,46 +52,43 @@ class HCCLCommImpl : public HCCLComm {
int ring_id_; int ring_id_;
int nranks_; int nranks_;
int rank_; int rank_;
HcclComm comm_;
std::unique_ptr<NPUDeviceContext> dev_ctx_; std::unique_ptr<NPUDeviceContext> dev_ctx_;
}; };
HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids, int rank, int dev_id, int ring_id) { HCCLComm* HCCLCommContext::CreateHCCLComm(HcclRootInfo* hccl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(hccl_id,
platform::errors::InvalidArgument(
"The hccl unique id should not be null."));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
world_rank_ids.size(), 1, nranks, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Expected world_rank_ids.size() > 1. But received size is %d.", world_rank_ids.size())); "Expected nranks > 1. But received nranks is %d.", nranks));
PADDLE_ENFORCE_GE(rank, 0, PADDLE_ENFORCE_GE(rank, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank)); "Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
rank, world_rank_ids.size(), rank, nranks,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Expected rank < nranks. But received rank is %d, nranks is %d.", "Expected rank < nranks. But received rank is %d, nranks is %d.",
rank, world_rank_ids.size())); rank, nranks));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
dev_id, 0, dev_id, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", dev_id)); "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"Expected ring_id >= 0. But received ring_id is %d.", ring_id));
auto* comm_wrapper = AssignHCCLComm(world_rank_ids.size(), rank, dev_id, ring_id);
// HACK(sunpeng17): hcom API requires bind stream to a model HcclComm comm;
// but we don't need model in Paddle, so we feed stream pointer as model pointer PADDLE_ENFORCE_NPU_SUCCESS(aclrtSetDevice(dev_id));
PADDLE_ENFORCE_NPU_SUCCESS( PADDLE_ENFORCE_NPU_SUCCESS(
platform::dynload::hcom_bind_model(comm_wrapper->stream(), platform::dynload::HcclCommInitRootInfo(nranks, hccl_id, rank, &comm));
comm_wrapper->stream()));
// Get world_rank_ids registered in gen_nccl_id op VLOG(1) << "initialized comm: " << &comm << ", nranks: " << nranks << ", hccl_id: " << hccl_id << ", rank: " << rank;
std::string group_name = HCOM_GROUP_PREFIX + std::to_string(ring_id);
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_create_group( auto* comm_wrapper = AssignHCCLComm(comm, nranks, rank, dev_id, ring_id);
group_name.c_str(), world_rank_ids.size(), (unsigned int*)world_rank_ids.data()));
VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id VLOG(1) << "hccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id << ", group name: " << group_name; << " has been created on device " << dev_id << ", with comm: " << comm_wrapper->comm();
std::call_once(once_flag_, []() { std::call_once(once_flag_, []() {
std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); }); std::atexit([]() { HCCLCommContext::Instance().ReleaseHCCLComms(); });
...@@ -93,7 +97,8 @@ HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids ...@@ -93,7 +97,8 @@ HCCLComm* HCCLCommContext::CreateHCCLComm(const std::vector<int>& world_rank_ids
return comm_wrapper; return comm_wrapper;
} }
HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ring_id) { HCCLComm* HCCLCommContext::AssignHCCLComm(HcclComm comm, int nranks, int rank,
int dev_id, int ring_id) {
std::unique_ptr<NPUDeviceContext> dev_ctx( std::unique_ptr<NPUDeviceContext> dev_ctx(
new NPUDeviceContext(NPUPlace(dev_id))); new NPUDeviceContext(NPUPlace(dev_id)));
...@@ -101,6 +106,7 @@ HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ...@@ -101,6 +106,7 @@ HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int
c->set_ring_id(ring_id); c->set_ring_id(ring_id);
c->set_nranks(nranks); c->set_nranks(nranks);
c->set_rank(rank); c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx)); c->set_dev_ctx(std::move(dev_ctx));
comm_map_mutex_.lock(); comm_map_mutex_.lock();
...@@ -112,23 +118,14 @@ HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int ...@@ -112,23 +118,14 @@ HCCLComm* HCCLCommContext::AssignHCCLComm(int nranks, int rank, int dev_id, int
dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c)); dev2comm.emplace(dev_id, std::unique_ptr<HCCLComm>(c));
comm_map_mutex_.unlock(); comm_map_mutex_.unlock();
return comm_map_[ring_id][dev_id].get(); if (ring_id == 0) {
} auto* dev_ctx = static_cast<platform::NPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(
void HCCLCommContext::InitHcomWorldGroup() { platform::NPUPlace(dev_id)));
const char *rank_table_file = getenv(ENV_RANK_TABLE_FILE); dev_ctx->set_hccl_comm(comm);
PADDLE_ENFORCE_NOT_NULL( }
rank_table_file,
platform::errors::InvalidArgument("The RANK_TABLE_FILE environment variable should not be null."));
const char *rank_id = getenv(ENV_RANK_ID);
PADDLE_ENFORCE_NOT_NULL(
rank_id,
platform::errors::InvalidArgument("The RANK_ID environment variable should not be null."));
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_init(rank_table_file, rank_id)); return comm_map_[ring_id][dev_id].get();
VLOG(3) << "Successfully initialized hcom. rank_table_file: "
<< rank_table_file << ", rank_id " << rank_id;
} }
void HCCLCommContext::ReleaseHCCLComms() { void HCCLCommContext::ReleaseHCCLComms() {
......
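Editor's note: the new CreateHCCLComm path replaces the old rank-table based hcom_init/hcom_create_group flow with the HCCL root-info handshake. Below is a hedged, single-function sketch of the raw CANN calls it wraps (HcclGetRootInfo on rank 0, HcclCommInitRootInfo on every rank, HcclCommDestroy at the end); how the root info is exchanged between processes (the job of the new c_gen_hccl_id op) is deliberately left out, and error handling is reduced to early returns.

// Hedged sketch of the raw HCCL initialization sequence wrapped by CreateHCCLComm.
// Assumes a CANN toolkit environment; names and flow are illustrative.
#include <acl/acl.h>
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>

#include <cstdio>

int InitHcclComm(int dev_id, uint32_t nranks, uint32_t rank) {
  if (aclInit(nullptr) != ACL_ERROR_NONE) return -1;
  if (aclrtSetDevice(dev_id) != ACL_ERROR_NONE) return -1;

  HcclRootInfo root_info;
  if (rank == 0) {
    // Rank 0 generates the unique id; real code ships it to the other ranks
    // over the network before anyone calls HcclCommInitRootInfo.
    if (HcclGetRootInfo(&root_info) != HCCL_SUCCESS) return -1;
  }
  // ... exchange root_info between processes here (omitted); all ranks must
  // pass the same root_info to HcclCommInitRootInfo ...

  HcclComm comm = nullptr;
  if (HcclCommInitRootInfo(nranks, &root_info, rank, &comm) != HCCL_SUCCESS) return -1;
  std::printf("rank %u of %u initialized\n", rank, nranks);

  // Collectives would run here, each bound to an aclrtStream.

  HcclCommDestroy(comm);
  aclrtResetDevice(dev_id);
  aclFinalize();
  return 0;
}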
...@@ -185,11 +185,21 @@ class NPUDeviceContext : public DeviceContext { ...@@ -185,11 +185,21 @@ class NPUDeviceContext : public DeviceContext {
void WaitStreamCallback() const { return stream_->WaitCallback(); } void WaitStreamCallback() const { return stream_->WaitCallback(); }
#if defined(PADDLE_WITH_ASCEND_CL)
/*! \brief Return hccl communicators. */
HcclComm hccl_comm() const { return hccl_comm_; }
/*! \brief Set hccl communicators. */
void set_hccl_comm(HcclComm comm) { hccl_comm_ = comm; }
#endif
private: private:
NPUPlace place_; NPUPlace place_;
aclrtContext context_; aclrtContext context_;
#ifdef PADDLE_WITH_ASCEND_HCCL
HCCLContext_t hccl_context_; #ifdef PADDLE_WITH_ASCEND_CL
// HCCLContext_t hccl_context_;
HcclComm hccl_comm_{nullptr};
#endif #endif
// Need to be the same with other DeviceContext, // Need to be the same with other DeviceContext,
......
...@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and ...@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
// #include <hccl/hccl.h> #include <hccl/hccl.h>
// #include <hccl/hccl_types.h> #include <hccl/hccl_types.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/dynload/hcom.h" // #include "paddle/fluid/platform/dynload/hcom.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#define HCOM_GROUP_PREFIX "HCOM_GROUP_"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
...@@ -43,27 +45,14 @@ extern void* hccl_dso_handle; ...@@ -43,27 +45,14 @@ extern void* hccl_dso_handle;
extern DynLoad__##__name __name extern DynLoad__##__name __name
#define HCCL_RAND_ROUTINE_EACH(__macro) \ #define HCCL_RAND_ROUTINE_EACH(__macro) \
__macro(hcom_init); \ __macro(HcclReduceScatter); \
__macro(hcom_destroy); \ __macro(HcclCommDestroy); \
__macro(hcom_bind_model); \ __macro(HcclAllReduce); \
__macro(hcom_unbind_model); \ __macro(HcclCommInitRootInfo); \
__macro(hcom_send); \ __macro(HcclGetRootInfo); \
__macro(hcom_receive); \ __macro(HcclBroadcast); \
__macro(hcom_broadcast); \ __macro(HcclCommInitClusterInfo); \
__macro(hcom_all_gather); \ __macro(HcclAllGather);
__macro(hcom_all_reduce); \
__macro(hcom_reduce_scatter); \
__macro(hcom_create_group); \
__macro(hcom_destroy_group); \
__macro(hcom_get_rank_id); \
__macro(hcom_get_local_rank_id); \
__macro(hcom_get_local_rank_size); \
__macro(hcom_get_split_strategy); \
__macro(hcom_set_split_strategy_by_size); \
__macro(hcom_set_split_strategy_by_index); \
__macro(hcom_get_group_rank_from_world_rank); \
__macro(hcom_get_world_rank_from_group_rank);
HCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_HCCL_WRAP) HCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_HCCL_WRAP)
......
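Editor's note: the macro list above only declares which symbols are resolved from the HCCL shared library; the wrapper body is not part of this hunk. As a rough illustration of what such a dynamic-load wrapper typically does, here is a simplified stand-alone version built on dlopen/dlsym. The library name and the single wrapped symbol are assumptions for the sketch, not what Paddle's DECLARE_DYNAMIC_LOAD_HCCL_WRAP literally expands to.

// Simplified illustration of a dynamic-load wrapper in the spirit of
// DECLARE_DYNAMIC_LOAD_HCCL_WRAP. Library name and symbol choice are assumptions.
#include <dlfcn.h>

#include <hccl/hccl.h>
#include <hccl/hccl_types.h>

#include <mutex>
#include <stdexcept>

namespace demo {

inline void* HcclDsoHandle() {
  static void* handle = nullptr;
  static std::once_flag flag;
  std::call_once(flag, [] {
    handle = dlopen("libhccl.so", RTLD_LAZY | RTLD_GLOBAL);
    if (handle == nullptr) throw std::runtime_error("cannot load libhccl.so");
  });
  return handle;
}

// Resolve HcclGetRootInfo lazily and forward the call with the original signature.
inline HcclResult HcclGetRootInfoDyn(HcclRootInfo* info) {
  using Fn = HcclResult (*)(HcclRootInfo*);
  static Fn fn = reinterpret_cast<Fn>(dlsym(HcclDsoHandle(), "HcclGetRootInfo"));
  if (fn == nullptr) throw std::runtime_error("symbol HcclGetRootInfo not found");
  return fn(info);
}

}  // namespace demo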
...@@ -40,7 +40,7 @@ limitations under the License. */ ...@@ -40,7 +40,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h" #include "acl/acl.h"
#include "paddle/fluid/platform/dynload/hcom.h" #include "hccl/hccl_types.h"
#endif // PADDLE_WITH_ASCEND_CL #endif // PADDLE_WITH_ASCEND_CL
#include <fstream> #include <fstream>
...@@ -1013,7 +1013,7 @@ struct NPUStatusType {}; ...@@ -1013,7 +1013,7 @@ struct NPUStatusType {};
} }
DEFINE_NPU_STATUS_TYPE(aclError, ACL_ERROR_NONE); DEFINE_NPU_STATUS_TYPE(aclError, ACL_ERROR_NONE);
DEFINE_NPU_STATUS_TYPE(hcclResult_t, HCCL_SUCCESS); DEFINE_NPU_STATUS_TYPE(HcclResult, HCCL_SUCCESS);
} // namespace details } // namespace details
inline std::string build_npu_error_msg(aclError stat) { inline std::string build_npu_error_msg(aclError stat) {
...@@ -1022,7 +1022,7 @@ inline std::string build_npu_error_msg(aclError stat) { ...@@ -1022,7 +1022,7 @@ inline std::string build_npu_error_msg(aclError stat) {
return sout.str(); return sout.str();
} }
inline std::string build_npu_error_msg(hcclResult_t stat) { inline std::string build_npu_error_msg(HcclResult stat) {
std::ostringstream sout; std::ostringstream sout;
sout << " HCCL error, the error code is : " << stat << ". "; sout << " HCCL error, the error code is : " << stat << ". ";
return sout.str(); return sout.str();
......
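Editor's note: the enforce.h change swaps the old hcclResult_t status type for HcclResult while keeping HCCL_SUCCESS as the success value. A stripped-down version of this trait-plus-message pattern, outside Paddle's macros, might look like the sketch below; NpuSuccessValue and CheckNpuStatus are hypothetical names invented for the example.

// Stand-alone sketch of the success-value trait used by PADDLE_ENFORCE_NPU_SUCCESS.
// Trait and helper names are hypothetical.
#include <acl/acl.h>
#include <hccl/hccl_types.h>

#include <sstream>
#include <stdexcept>

template <typename T>
struct NpuSuccessValue;  // specialized per status type

template <>
struct NpuSuccessValue<aclError> {
  static constexpr aclError kSuccess = ACL_ERROR_NONE;
};

template <>
struct NpuSuccessValue<HcclResult> {
  static constexpr HcclResult kSuccess = HCCL_SUCCESS;
};

template <typename T>
void CheckNpuStatus(T stat, const char* expr) {
  if (stat == NpuSuccessValue<T>::kSuccess) return;
  std::ostringstream sout;
  sout << expr << " failed, the error code is : " << stat << ". ";
  throw std::runtime_error(sout.str());
}

// Usage: CheckNpuStatus(HcclGetRootInfo(&info), "HcclGetRootInfo");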
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_HCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_ASCEND_CL)
#include <stdio.h> #include <stdio.h>
#include <memory> #include <memory>
...@@ -24,30 +24,22 @@ ...@@ -24,30 +24,22 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#ifdef PADDLE_WITH_NCCL
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/dynload/hccl.h" #include "paddle/fluid/platform/dynload/hccl.h"
#endif #endif
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#define NCCL_ID_VARNAME "NCCLID" #define HCCL_ID_VARNAME "HCCLID"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
inline hcclDataType_t ToHCCLDataType(framework::proto::VarType::Type type) { inline HcclDataType ToHCCLDataType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP32) { if (type == framework::proto::VarType::FP32) {
return HCCL_DATA_TYPE_FP32; return HCCL_DATA_TYPE_FP32;
} else if (type == framework::proto::VarType::FP16) { } else if (type == framework::proto::VarType::FP16) {
...@@ -66,298 +58,301 @@ inline hcclDataType_t ToHCCLDataType(framework::proto::VarType::Type type) { ...@@ -66,298 +58,301 @@ inline hcclDataType_t ToHCCLDataType(framework::proto::VarType::Type type) {
} }
} }
// // NOTE(minqiyang): according to the ncclGroupEnd documentations: // NOTE(minqiyang): according to the ncclGroupEnd documentations:
// // https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html, // https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// // ncclGroupEnd will wait for all communicators to be initialized, which will // ncclGroupEnd will wait for all communicators to be initialized, which will
// // cause blocking problem when a runtime_error was thrown, so try only guard // cause blocking problem when a runtime_error was thrown, so try only guard
// // NCCL actions when use it. // HCCL actions when use it.
// class NCCLGroupGuard {
// class HCCLGroupGuard {
// public: // public:
// static std::mutex &NCCLMutex() { // static std::mutex &HCCLMutex() {
// static std::mutex mtx; // static std::mutex mtx;
// return mtx; // return mtx;
// } // }
// inline NCCLGroupGuard() { // inline HCCLGroupGuard() {
// NCCLMutex().lock(); // HCCLMutex().lock();
// PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupStart()); // PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupStart());
// } // }
// inline ~NCCLGroupGuard() PADDLE_MAY_THROW { // inline ~HCCLGroupGuard() PADDLE_MAY_THROW {
// PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd()); // PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd());
// NCCLMutex().unlock(); // HCCLMutex().unlock();
// } // }
// }; // };
// struct NCCLContext { struct HCCLContext {
// std::unique_ptr<CUDADeviceContext> ctx_; std::unique_ptr<NPUDeviceContext> ctx_;
// ncclComm_t comm_; HcclComm comm_;
// explicit NCCLContext(int dev_id) explicit HCCLContext(int dev_id)
// : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {} : ctx_(new NPUDeviceContext(NPUPlace(dev_id))), comm_{nullptr} {}
// gpuStream_t stream() const { return ctx_->stream(); } aclrtStream stream() const { return ctx_->stream(); }
// ncclComm_t comm() const { return comm_; } HcclComm comm() const { return comm_; }
// int device_id() const {
// return BOOST_GET_CONST(platform::CUDAPlace, ctx_->GetPlace()).device;
// }
// };
// struct NCCLContextMap { int device_id() const {
// std::unordered_map<int, NCCLContext> contexts_; return BOOST_GET_CONST(platform::NPUPlace, ctx_->GetPlace()).device;
// std::vector<int> order_; }
};
// explicit NCCLContextMap(const std::vector<platform::Place> &places,
// ncclUniqueId *nccl_id = nullptr, struct HCCLContextMap {
// size_t num_trainers = 1, size_t trainer_id = 0) { std::unordered_map<int, HCCLContext> contexts_;
// PADDLE_ENFORCE_EQ(!places.empty(), true, std::vector<int> order_;
// platform::errors::InvalidArgument(
// "The NCCL place should not be empty.")); explicit HCCLContextMap(const std::vector<platform::Place> &places,
// order_.reserve(places.size()); HcclRootInfo *hccl_id = nullptr,
// for (auto &p : places) { size_t num_trainers = 1, size_t trainer_id = 0) {
// int dev_id = BOOST_GET_CONST(CUDAPlace, p).device; PADDLE_ENFORCE_EQ(!places.empty(), true,
// order_.emplace_back(dev_id); platform::errors::InvalidArgument(
// contexts_.emplace(dev_id, NCCLContext(dev_id)); "The HCCL place should not be empty."));
// } order_.reserve(places.size());
// PADDLE_ENFORCE_EQ( for (auto &p : places) {
// order_.size(), contexts_.size(), int dev_id = BOOST_GET_CONST(NPUPlace, p).device;
// platform::errors::Unavailable("NCCL Context Map does not support " order_.emplace_back(dev_id);
// "contain two or more same device.")); contexts_.emplace(dev_id, HCCLContext(dev_id));
}
// std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]); PADDLE_ENFORCE_EQ(
// // if num_trainers == 1, should create a new nccl id for local comms. order_.size(), contexts_.size(),
// if (num_trainers == 1 && nccl_id == nullptr) { platform::errors::Unavailable("HCCL Context Map does not support "
// std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex()); "contain two or more same device."));
// PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
// comms.get(), static_cast<int>(order_.size()), order_.data())); std::unique_ptr<HcclComm[]> comms(new HcclComm[order_.size()]);
// } else { // if num_trainers == 1, should create a new nccl id for local comms.
// PADDLE_ENFORCE_NOT_NULL(nccl_id, platform::errors::InvalidArgument( if (num_trainers == 1 && hccl_id == nullptr) {
// "The NCCL id should not be null.")); // we do not know how to tackle this situation under hccl
// { // std::lock_guard<std::mutex> guard(HCCLGroupGuard::HCCLMutex());
// int nranks = num_trainers * order_.size(); // PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::ncclCommInitAll(
// NCCLGroupGuard gurad; // comms.get(), static_cast<int>(order_.size()), order_.data()));
// for (size_t i = 0; i < order_.size(); ++i) { } else {
// int gpu_id = order_[i]; PADDLE_ENFORCE_NOT_NULL(hccl_id, platform::errors::InvalidArgument(
// int rank; "The HCCL id should not be null."));
// if (order_.size() > 1) { {
// rank = trainer_id * order_.size() + i; int nranks = num_trainers * order_.size();
// } else { // HCCLGroupGuard gurad;
// rank = trainer_id; for (size_t i = 0; i < order_.size(); ++i) {
// } int gpu_id = order_[i];
// VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks int rank;
// << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i]; if (order_.size() > 1) {
// SetDeviceId(gpu_id); rank = trainer_id * order_.size() + i;
// PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitRank( } else {
// comms.get() + i, nranks, *nccl_id, rank)); rank = trainer_id;
// } }
// } VLOG(1) << "init hccl rank:" << rank << ", nranks:" << nranks
// } << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
// int i = 0; aclrtSetDevice(gpu_id);
// for (auto &dev_id : order_) { PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclCommInitRootInfo(
// contexts_.at(dev_id).comm_ = comms[i++]; nranks, hccl_id, rank, comms.get() + i));
// } }
// } }
}
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
}
// NCCLContextMap(const NCCLContextMap &other) = delete; HCCLContextMap(const HCCLContextMap &other) = delete;
// NCCLContextMap &operator=(const NCCLContextMap &other) = delete; HCCLContextMap &operator=(const HCCLContextMap &other) = delete;
// CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); } NPUDeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
// CUDADeviceContext *DevCtx(platform::Place p) const { NPUDeviceContext *DevCtx(platform::Place p) const {
// return DevCtx(BOOST_GET_CONST(CUDAPlace, p).device); return DevCtx(BOOST_GET_CONST(NPUPlace, p).device);
// } }
// const NCCLContext &at(platform::Place p) const { const HCCLContext &at(platform::Place p) const {
// return this->at(BOOST_GET_CONST(CUDAPlace, p).device); return this->at(BOOST_GET_CONST(NPUPlace, p).device);
// } }
// const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); } const HCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
// void WaitAll() { void WaitAll() {
// for (auto &p : contexts_) { for (auto &p : contexts_) {
// p.second.ctx_->Wait(); p.second.ctx_->Wait();
// } }
// } }
// }; };
// inline std::string GetFlatNCCLVarName(size_t pos) { inline std::string GetFlatHCCLVarName(size_t pos) {
// if (pos == 0) { if (pos == 0) {
// return NCCL_ID_VARNAME; return HCCL_ID_VARNAME;
// } }
// return string::Sprintf("%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos)); return string::Sprintf("%s_%d", HCCL_ID_VARNAME, static_cast<int>(pos));
// } }
// inline std::string GetHierarchicalExterNCCLVarName(size_t pos) { inline std::string GetHierarchicalExterHCCLVarName(size_t pos) {
// return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME, return string::Sprintf("Hierarchical_exter_%s_%d", HCCL_ID_VARNAME,
// static_cast<int>(pos)); static_cast<int>(pos));
// } }
// inline std::string GetHierarchicalInterNCCLVarName(size_t pos) { inline std::string GetHierarchicalInterHCCLVarName(size_t pos) {
// return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME, return string::Sprintf("Hierarchical_inter_%s_%d", HCCL_ID_VARNAME,
// static_cast<int>(pos)); static_cast<int>(pos));
// } }
// class NCCLCommunicator { class HCCLCommunicator {
// public: public:
// NCCLCommunicator() {} HCCLCommunicator() {}
// virtual ~NCCLCommunicator() PADDLE_MAY_THROW {} virtual ~HCCLCommunicator() PADDLE_MAY_THROW {}
// NCCLContextMap *DefaultFlatCtx() const { HCCLContextMap *DefaultFlatCtx() const {
// if (flat_ctxs_.size() == 0) { if (flat_ctxs_.size() == 0) {
// return nullptr; return nullptr;
// } }
// return flat_ctxs_[0].get(); return flat_ctxs_[0].get();
// } }
// std::vector<std::unique_ptr<NCCLContextMap>> *GetFlatCtxs() { std::vector<std::unique_ptr<HCCLContextMap>> *GetFlatCtxs() {
// return &flat_ctxs_; return &flat_ctxs_;
// } }
// NCCLContextMap *GetFlatCtx(size_t run_order) const { HCCLContextMap *GetFlatCtx(size_t run_order) const {
// return flat_ctxs_[run_order % flat_ctxs_.size()].get(); return flat_ctxs_[run_order % flat_ctxs_.size()].get();
// } }
// NCCLContextMap *GetRunEnvNCCLCtx(size_t run_order, HCCLContextMap *GetRunEnvHCCLCtx(size_t run_order,
// bool use_hierarchical_allreduce) const { bool use_hierarchical_allreduce) const {
// if (!use_hierarchical_allreduce) { if (!use_hierarchical_allreduce) {
// return GetFlatCtx(run_order); return GetFlatCtx(run_order);
// } }
// return GetHierarchicalInterCtx(run_order); return GetHierarchicalInterCtx(run_order);
// } }
// *When nccl inits nccl comm using ncclCommInitAll, it meets error when /*
// *allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So When nccl inits nccl comm using ncclCommInitAll, it meets error when
// *create a new nccl comm for sync_batch_norm_op. And these codes should be allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
// *polished with a unified nccl management. create a new nccl comm for sync_batch_norm_op. And these codes should be
polished with a unified nccl management.
*/
// NCCLContextMap *GetSyncBatchNormCtx( HCCLContextMap *GetSyncBatchNormCtx(framework::Scope* scope, const std::vector<platform::Place> &places) {
// framework::Scope *scope, const std::vector<platform::Place> &places) { auto *hccl_id_var = scope->FindVar(HCCL_ID_VARNAME);
// auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); if (hccl_id_var != nullptr) {
// if (nccl_id_var != nullptr) { return DefaultFlatCtx();
// return DefaultFlatCtx(); }
// }
// if (sync_batch_norm_ctx_.get() == nullptr) { if (sync_batch_norm_ctx_.get() == nullptr) {
// sync_batch_norm_ctx_.reset(new NCCLContextMap(places)); sync_batch_norm_ctx_.reset(new HCCLContextMap(places));
// } }
// return sync_batch_norm_ctx_.get(); return sync_batch_norm_ctx_.get();
// } }
// void InitFlatCtxs(const std::vector<platform::Place> &places, void InitFlatCtxs(const std::vector<platform::Place> &places,
// const std::vector<ncclUniqueId *> &nccl_ids, const std::vector<HcclRootInfo *> &hccl_ids,
// size_t trainers_num, size_t trainer_id) { size_t trainers_num, size_t trainer_id) {
// if (nccl_ids.size() == 0) { if (hccl_ids.size() == 0) {
// auto ptr = new platform::NCCLContextMap(places); auto ptr = new platform::HCCLContextMap(places);
// VLOG(1) << "init local trainer"; VLOG(1) << "init local trainer";
// flat_ctxs_.emplace_back(ptr); flat_ctxs_.emplace_back(ptr);
// } else { } else {
// for (size_t i = 0; i < nccl_ids.size(); i++) { for (size_t i = 0; i < hccl_ids.size(); i++) {
// auto ptr = new platform::NCCLContextMap(places, nccl_ids[i], auto ptr = new platform::HCCLContextMap(places, hccl_ids[i],
// trainers_num, trainer_id); trainers_num, trainer_id);
// VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i; VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
// flat_ctxs_.emplace_back(ptr); flat_ctxs_.emplace_back(ptr);
// } }
// } }
// // as Executor have no way to use ncclComm created by ParallelExecutor, // as Executor have no way to use ncclComm created by ParallelExecutor,
// // we assign all flatten contexts to NCCLCommContext to fix. // we assign all flatten contexts to HCCLCommContext to fix.
// int nranks = static_cast<int>(trainers_num * places.size()); int nranks = static_cast<int>(trainers_num * places.size());
// int nrings = static_cast<int>(flat_ctxs_.size()); int nrings = static_cast<int>(flat_ctxs_.size());
// for (int ring_id = 0; ring_id < nrings; ++ring_id) { for (int ring_id = 0; ring_id < nrings; ++ring_id) {
// for (size_t p = 0; p < places.size(); ++p) { for (size_t p = 0; p < places.size(); ++p) {
// int rank = trainer_id * places.size() + p; int rank = trainer_id * places.size() + p;
// int dev_id = BOOST_GET_CONST(CUDAPlace, places[p]).device; int dev_id = BOOST_GET_CONST(NPUPlace, places[p]).device;
// auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id); auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
// NCCLCommContext::Instance().AssignNCCLComm(ctx.comm_, nranks, rank, HCCLCommContext::Instance().AssignHCCLComm(ctx.comm_, nranks, rank,
// dev_id, ring_id); dev_id, ring_id);
// } }
// } }
// } }
// void InitHierarchicalCtxs(const std::vector<platform::Place> &places, void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
// const std::vector<ncclUniqueId *> &inter_nccl_ids, const std::vector<HcclRootInfo *> &inter_hccl_ids,
// const std::vector<ncclUniqueId *> &exter_nccl_ids, const std::vector<HcclRootInfo *> &exter_hccl_ids,
// size_t trainers_num, size_t trainer_id, size_t trainers_num, size_t trainer_id,
// size_t inter_trainers_num, size_t inter_trainers_num,
// size_t exter_trainers_num) { size_t exter_trainers_num) {
// PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
// trainers_num, inter_trainers_num * exter_trainers_num, trainers_num, inter_trainers_num * exter_trainers_num,
// platform::errors::InvalidArgument( platform::errors::InvalidArgument(
// "trainers_num:%llu != inter_trainers_num:%llu * " "trainers_num:%llu != inter_trainers_num:%llu * "
// "exter_trainers_num:%llu", "exter_trainers_num:%llu",
// trainers_num, inter_trainers_num, exter_trainers_num)); trainers_num, inter_trainers_num, exter_trainers_num));
// PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
// inter_trainers_num, 1, inter_trainers_num, 1,
// platform::errors::InvalidArgument( platform::errors::InvalidArgument(
// "The inter_trainers_num:%llu should be larger than 1.", "The inter_trainers_num:%llu should be larger than 1.",
// inter_trainers_num)); inter_trainers_num));
// int inter_trainer_id = trainer_id % inter_trainers_num; int inter_trainer_id = trainer_id % inter_trainers_num;
// for (size_t i = 0; i < inter_nccl_ids.size(); i++) { for (size_t i = 0; i < inter_hccl_ids.size(); i++) {
// VLOG(1) << "init inter_trainer_id:" << inter_trainer_id VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
// << ", comm no:" << i; << ", comm no:" << i;
// auto local = new NCCLContextMap(places, inter_nccl_ids[i], auto local = new HCCLContextMap(places, inter_hccl_ids[i],
// inter_trainers_num, inter_trainer_id); inter_trainers_num, inter_trainer_id);
// h_inter_ctxs_.emplace_back(local); h_inter_ctxs_.emplace_back(local);
// } }
// int exter_trainer_id = -1; int exter_trainer_id = -1;
// if (trainer_id % inter_trainers_num == 0) { if (trainer_id % inter_trainers_num == 0) {
// exter_trainer_id = trainer_id / inter_trainers_num; exter_trainer_id = trainer_id / inter_trainers_num;
// } }
// if (exter_trainer_id >= 0) { if (exter_trainer_id >= 0) {
// for (size_t i = 0; i < exter_nccl_ids.size(); i++) { for (size_t i = 0; i < exter_hccl_ids.size(); i++) {
// auto ex = new NCCLContextMap(places, exter_nccl_ids[i], auto ex = new HCCLContextMap(places, exter_hccl_ids[i],
// exter_trainers_num, exter_trainer_id); exter_trainers_num, exter_trainer_id);
// VLOG(1) << "init exter_trainer_id:" << exter_trainer_id VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
// << ", comm no:" << i; << ", comm no:" << i;
// h_exter_ctxs_.emplace_back(ex); h_exter_ctxs_.emplace_back(ex);
// } }
// } }
// } }
// bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; } bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
// NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const { HCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
// PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0, PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0,
// platform::errors::InvalidArgument( platform::errors::InvalidArgument(
// "Hierarchical ctxs should be initialized firstly!")); "Hierarchical ctxs should be initialized firstly!"));
// return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get(); return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
// } }
// NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const { HCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
// PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0, PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0,
// platform::errors::InvalidArgument( platform::errors::InvalidArgument(
// "Hierarchical ctxs should be initialized firstly!")); "Hierarchical ctxs should be initialized firstly!"));
// return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get(); return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
// } }
// std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalInterCtxs() { std::vector<std::unique_ptr<HCCLContextMap>> *GetHierarchicalInterCtxs() {
// return &h_inter_ctxs_; return &h_inter_ctxs_;
// } }
// std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalExterCtxs() { std::vector<std::unique_ptr<HCCLContextMap>> *GetHierarchicalExterCtxs() {
// return &h_exter_ctxs_; return &h_exter_ctxs_;
// } }
// protected: protected:
// // Support multi nccl comm on default nccl ring while NCCLContextMap can't. // Support multi nccl comm on default nccl ring while HCCLContextMap can't.
// std::vector<std::unique_ptr<NCCLContextMap>> flat_ctxs_; std::vector<std::unique_ptr<HCCLContextMap>> flat_ctxs_;
// // h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce. // h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce.
// // And h_exter_ctxs_ can support multi comm too. // And h_exter_ctxs_ can support multi comm too.
// std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_; std::vector<std::unique_ptr<HCCLContextMap>> h_inter_ctxs_;
// std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_; std::vector<std::unique_ptr<HCCLContextMap>> h_exter_ctxs_;
// // just used for sync_batch_norm op. // just used for sync_batch_norm op.
// std::unique_ptr<NCCLContextMap> sync_batch_norm_ctx_; std::unique_ptr<HCCLContextMap> sync_batch_norm_ctx_;
// }; };
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
......
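Editor's note: with hccl_helper.h now returning the CANN HcclDataType enum from ToHCCLDataType, launching a collective boils down to mapping the tensor dtype and passing the element count, communicator, and stream. The helper below is a hedged, framework-free sketch of that call pattern; it only covers the two dtypes visible in the hunk (FP32, FP16), and its names are illustrative.

// Hedged, framework-free sketch of launching an HCCL all-reduce with the
// dtype-mapping pattern from hccl_helper.h. Handles only FP32/FP16 here.
#include <acl/acl.h>
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>

#include <cstdint>
#include <stdexcept>

enum class DemoDType { kFP32, kFP16 };  // stand-in for framework::proto::VarType

inline HcclDataType ToHcclDataType(DemoDType t) {
  switch (t) {
    case DemoDType::kFP32:
      return HCCL_DATA_TYPE_FP32;
    case DemoDType::kFP16:
      return HCCL_DATA_TYPE_FP16;
    default:
      throw std::runtime_error("dtype is not supported in this sketch");
  }
}

// All-reduce `count` elements on `stream`; the caller owns the buffers and the comm.
inline void AllReduceSum(void* send_buf, void* recv_buf, uint64_t count,
                         DemoDType dtype, HcclComm comm, aclrtStream stream) {
  HcclResult ret = HcclAllReduce(send_buf, recv_buf, count, ToHcclDataType(dtype),
                                 HCCL_REDUCE_SUM, comm, stream);
  if (ret != HCCL_SUCCESS) throw std::runtime_error("HcclAllReduce failed");
}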
...@@ -151,17 +151,30 @@ class CollectiveHelper(object): ...@@ -151,17 +151,30 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)} endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op( block.append_op(
type='c_comm_init_hcom', type='c_gen_hccl_id',
inputs={}, inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={}, outputs={},
attrs={ attrs={
'nranks': nranks,
'rank': rank, 'rank': rank,
'ring_id': ring_id, 'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")), 'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints], 'rank_ids': nranks,
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
......
...@@ -108,19 +108,32 @@ class PipelineHelper(object): ...@@ -108,19 +108,32 @@ class PipelineHelper(object):
OP_ROLE_KEY: OpRole.Forward, OP_ROLE_KEY: OpRole.Forward,
}) })
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
endpoint_to_index_map = { if rank == 0 and wait_port:
e: idx for idx, e in enumerate(endpoints) wait_server_ready(other_endpoints)
} hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op( block.append_op(
type='c_comm_init_hcom', type='c_gen_hccl_id',
inputs={}, inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={}, outputs={},
attrs={ attrs={
'nranks': nranks,
'rank': rank, 'rank': rank,
'ring_id': ring_id, 'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")), 'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints], 'rank_ids': nranks,
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
......
...@@ -2053,7 +2053,7 @@ class Operator(object): ...@@ -2053,7 +2053,7 @@ class Operator(object):
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad', 'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify', 'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_comm_init_hcom', 'c_sync_calc_stream', 'gen_nccl_id', 'c_gen_nccl_id', 'c_gen_hccl_id', 'c_comm_init', 'c_comm_init_hccl', 'c_sync_calc_stream',
'c_sync_comm_stream', 'queue_generator', 'dequeue', 'enqueue', 'c_sync_comm_stream', 'queue_generator', 'dequeue', 'enqueue',
'heter_listen_and_serv' 'heter_listen_and_serv'
} }
......
...@@ -131,19 +131,32 @@ class Collective(object): ...@@ -131,19 +131,32 @@ class Collective(object):
self.op_role_key: OpRole.Forward self.op_role_key: OpRole.Forward
}) })
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
endpoint_to_index_map = { if rank == 0 and wait_port:
e: idx for idx, e in enumerate(endpoints) wait_server_ready(other_endpoints)
} hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op( block.append_op(
type='c_comm_init_hcom', type='c_gen_hccl_id',
inputs={}, inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={}, outputs={},
attrs={ attrs={
'nranks': nranks,
'rank': rank, 'rank': rank,
'ring_id': ring_id, 'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")), 'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints], 'rank_ids': nranks,
self.op_role_key: OpRole.Forward self.op_role_key: OpRole.Forward
}) })
......
...@@ -162,19 +162,33 @@ def init_communicator(program, rank, nranks, wait_port, current_endpoint, ...@@ -162,19 +162,33 @@ def init_communicator(program, rank, nranks, wait_port, current_endpoint,
'ring_id': 0, 'ring_id': 0,
}) })
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
endpoint_to_index_map = { if rank == 0 and wait_port:
e: idx for idx, e in enumerate(endpoints) wait_server_ready(other_endpoints)
} hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op( block.append_op(
type='c_comm_init_hcom', type='c_gen_hccl_id',
inputs={}, inputs={},
outputs={'Out': hccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
outputs={}, outputs={},
attrs={ attrs={
'nranks': nranks,
'rank': rank, 'rank': rank,
'ring_id': 0, 'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")), 'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints], 'rank_ids': nranks,
OP_ROLE_KEY: OpRole.Forward
}) })
......