Unverified commit 7b450e78, authored by Void Main, committed by GitHub

Add auto-increasing tag id for Hcom OPs (#31702)

Parent 50bc1162
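Summary of the change: the Hcom NPU kernels (c_allgather, c_allreduce, c_broadcast, c_reducescatter, send_v2, recv_v2) stop reading a fixed "tag" attribute and instead build the HCCL tag as "<ring_id>_<counter>", where the counter is an auto-increasing per-communicator id returned by HCCLComm::NextTagId(); the NPU unit tests now run each op ten times so consecutive runs exercise distinct tags. Below is a minimal standalone sketch of that tag scheme; NextTagId() and the tag construction mirror the hunks that follow, everything else is simplified illustration and not the actual Paddle code.

// Minimal sketch of the auto-increasing tag scheme introduced in this commit.
// HCCLComm::NextTagId() mirrors the collective_helper.h hunk below; the rest
// (initialization, main) is a simplified stand-in for the real kernels.
#include <atomic>
#include <iostream>
#include <string>

class HCCLComm {
 public:
  unsigned long NextTagId() { return tag_counter_++; }

 private:
  std::atomic<unsigned long> tag_counter_{0};
};

int main() {
  HCCLComm comm;
  const int ring_id = 0;
  // Every kernel launch on this communicator derives a fresh tag:
  // "0_0", "0_1", "0_2", ... instead of a fixed "tag" attribute.
  for (int i = 0; i < 3; ++i) {
    std::string tag =
        std::to_string(ring_id) + "_" + std::to_string(comm.NextTagId());
    std::cout << tag << std::endl;
  }
  return 0;
}
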
@@ -35,10 +35,10 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
     int ring_id = ctx.Attr<int>("ring_id");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = ctx.Attr<std::string>("tag");
     auto place = ctx.GetPlace();
     auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     framework::DDim out_dims = in->dims();
     out_dims[0] *= nranks;
...
@@ -119,7 +119,9 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
   auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
                                     {{"Out", {"Out"}}}, attrs);
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   ctx.Wait();
   std::vector<float> out_vec;
...
@@ -118,7 +118,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
   auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}},
                                     {{"Out", {"Out"}}}, attrs);
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   ctx.Wait();
   std::vector<float> out_vec;
...
@@ -118,7 +118,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_ASCEND_CL)
     // we need to pre-allocate 512 Bytes before the data
     // and 512 Bytes after the data, so the hccl allreduce
     // can work. This is a must acooding to huawei peer.
 #define PRE_MALLOC_SIZE_BYTES 512
@@ -135,16 +135,16 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
     paddle::framework::LoDTensor tmp_in, tmp_out;
     tmp_in.Resize({tmp_numel});
     tmp_out.Resize({tmp_numel});
-    tmp_in.mutable_data<T>(place);   // allocate
-    tmp_out.mutable_data<T>(place);  // allocate
+    auto p_tmp_in = tmp_in.mutable_data<T>(place);   // allocate
+    auto p_tmp_out = tmp_out.mutable_data<T>(place);  // allocate
     void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
     void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
-    std::string tag = ctx.Attr<std::string>("tag");
     int ring_id = ctx.Attr<int>("ring_id");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
     auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     aclrtStream stream = nullptr;
     auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
@@ -154,9 +154,13 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
       stream = comm->stream();
     }
+    // we need to memset this memory firstly to avoid core by hccl
+    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
+    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
     auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
     memory::Copy(npu_place, sendbuff,
                  npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
                  numel * sizeof(T),
                  stream);
@@ -195,10 +199,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
         tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));
     memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
                  npu_place, recvbuff,
                  numel * sizeof(T),
                  stream);
     out->Resize(in->dims());
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
...
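For the c_allreduce hunks above: the kernel pads its temporary HCCL buffers with 512 bytes before and after the payload (PRE_MALLOC_SIZE_BYTES), and with this change it also zeroes both temporaries via NPUMemsetAsync before copying the input in. The following is a rough host-side illustration of that buffer layout only; PRE_MALLOC_SIZE_BYTES, pre_tmp_size, tmp_numel, sendbuff and recvbuff come from the diff, while the remaining names and the use of plain host memory are hypothetical.

// Host-side illustration of the padded buffer layout used by the NPU
// allreduce kernel above. Sketch only, not Paddle or HCCL code.
#include <cstddef>
#include <cstring>
#include <vector>

constexpr std::size_t PRE_MALLOC_SIZE_BYTES = 512;

int main() {
  using T = float;
  std::vector<T> input(1000, 1.0f);  // stand-in for the op's input tensor

  const std::size_t numel = input.size();
  const std::size_t pre_tmp_size = PRE_MALLOC_SIZE_BYTES / sizeof(T);  // padding elements
  const std::size_t tmp_numel = numel + 2 * pre_tmp_size;              // 512 B before + 512 B after

  std::vector<T> tmp_in(tmp_numel), tmp_out(tmp_numel);
  // Zero both temporaries first; the kernel does this with NPUMemsetAsync so
  // HCCL never reads uninitialized padding.
  std::memset(tmp_in.data(), 0, tmp_numel * sizeof(T));
  std::memset(tmp_out.data(), 0, tmp_numel * sizeof(T));

  // The payload lives pre_tmp_size elements into each padded buffer; these
  // correspond to the sendbuff/recvbuff pointers handed to hcom_all_reduce.
  T* sendbuff = tmp_in.data() + pre_tmp_size;
  T* recvbuff = tmp_out.data() + pre_tmp_size;
  std::memcpy(sendbuff, input.data(), numel * sizeof(T));
  (void)recvbuff;
  return 0;
}
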
@@ -117,7 +117,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
                                     {{"Out", {"Out"}}},
                                     attrs);
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   ctx.Wait();
   std::vector<float> out_vec;
...
@@ -48,7 +48,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
     int root = ctx.Attr<int>("root");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = ctx.Attr<std::string>("tag");
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root
             << ", group is " << group
...
@@ -113,7 +113,9 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
   auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}},
                                     {{"Out", {"Out"}}}, attrs);
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   ctx.Wait();
   std::vector<float> out_vec;
...
@@ -32,10 +32,10 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
     int ring_id = ctx.Attr<int>("ring_id");
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = ctx.Attr<std::string>("tag");
     auto place = ctx.GetPlace();
     auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     auto out_dims = in->dims();
     PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
@@ -43,7 +43,7 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
                           "The input tensor X's "
                           "dim[0] (%d) should be divisible by nranks(%d)",
                           out_dims[0], nranks));
     out_dims[0] = out_dims[0] / nranks;
     out->mutable_data<T>(out_dims, place);
@@ -66,7 +66,7 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
             << "hccl_red_type: " << HCCL_REP_OP_SUM
             << ", group is: " << group
             << ", tag is " << tag;
     PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_reduce_scatter(
         tag.c_str(), inputPtr, outputPtr, (u64)recv_numel, dtype, HCCL_REP_OP_SUM, group.c_str(), (void*)stream));
 #else
@@ -82,7 +82,7 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(c_reducescatter,
                        ops::CReduceScatterOpAscendKernel<int8_t>,
                        ops::CReduceScatterOpAscendKernel<int>,
                        ops::CReduceScatterOpAscendKernel<float>,
...
@@ -119,7 +119,9 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
   auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}},
                                     {{"Out", {"Out"}}}, attrs);
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   ctx.Wait();
   std::vector<float> out_vec;
...
@@ -42,7 +42,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
     } else {
       stream = comm->stream();
     }
-    std::string tag = ctx.Attr<std::string>("tag");
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
     int srcRank = ctx.Attr<int>("peer");
     int srTag = ctx.Attr<int>("srTag");
@@ -66,7 +66,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(recv_v2,
                        ops::CRecvOpASCENDKernel<int>,
                        ops::CRecvOpASCENDKernel<int8_t>,
                        ops::CRecvOpASCENDKernel<float>,
...
@@ -99,7 +99,9 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){
   auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
   VLOG(3) << "CreateOp recv_v2";
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   VLOG(3) << "Run op recv_v2";
   std::vector<float> out_vec;
   TensorToVector(*tensor_out, ctx, &out_vec);
...
@@ -42,7 +42,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
     } else {
       stream = comm->stream();
     }
-    std::string tag = ctx.Attr<std::string>("tag");
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
     std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
     int destRank = ctx.Attr<int>("peer");
     int srTag = ctx.Attr<int>("srTag");
@@ -50,7 +50,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_send(
         tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), (u64)numel, dtype, destRank,
         srTag, group.c_str(), stream));
     VLOG(3) << "Dest rank:" << destRank << " Invoke hcom send. Sent "
             << x->numel();
@@ -67,7 +67,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(send_v2,
                        ops::CSendOpASCENDKernel<int>,
                        ops::CSendOpASCENDKernel<int8_t>,
                        ops::CSendOpASCENDKernel<float>,
...
@@ -90,7 +90,9 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){
   auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs);
-  op->Run(*scope, place);
+  for (int i = 0; i < 10; i ++) {
+    op->Run(*scope, place);
+  }
   VLOG(3)<<"send run over";
   ctx.Wait();
 }
...
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <atomic>
 #include "boost/variant.hpp"
 #include "paddle/fluid/platform/enforce.h"
@@ -148,8 +149,7 @@ class NCCLCommContext {
 class NPUDeviceContext;
 #define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
-#define ENV_RANK_ID "RANK_ID"
-#define ENV_DEV_ID "DEV_ID"
+#define ENV_RANK_ID "PADDLE_TRAINER_ID"
 class HCCLComm {
  public:
@@ -160,6 +160,12 @@ class HCCLComm {
   virtual aclrtStream stream() const = 0;
   virtual NPUDeviceContext* dev_context() const = 0;
   virtual ~HCCLComm() = default;
+  unsigned long NextTagId() {
+    return tag_counter_++;
+  }
+
+ private:
+  std::atomic<unsigned long> tag_counter_;
 };
 // A singleton HCCL communicator context reserves communication ring ids
@@ -208,10 +214,12 @@ class HCCLCommContext {
     return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device);
   }
  private:
   // Init global hcom
   HCCLCommContext() { InitHcomWorldGroup(); }
  public:
   ~HCCLCommContext(){
     PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy());
...