Unverified commit 7b450e78 authored by Void Main, committed by GitHub

Add auto-increasing tag id for Hcom OPs (#31702)

Parent 50bc1162
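A minimal sketch of the tagging scheme this commit introduces, assuming the member names used in the diff below (the free-standing MakeTag helper is illustrative, not part of the commit): each HCCLComm holds an atomic counter, and every Hcom op derives its tag from the ring id plus the next counter value instead of reading a fixed "tag" attribute, so repeated launches of the same op no longer collide on one static tag.

#include <atomic>
#include <string>

// Illustrative stand-in for the HCCLComm change: an atomic, monotonically
// increasing counter hands every collective call on a ring a fresh tag id.
class HCCLComm {
 public:
  unsigned long NextTagId() { return tag_counter_++; }

 private:
  std::atomic<unsigned long> tag_counter_{0};  // explicitly zero-initialized in this sketch
};

// How the NPU kernels in this diff build their tags (pattern taken from the
// c_allgather / c_allreduce / c_broadcast changes below):
std::string MakeTag(int ring_id, HCCLComm* comm) {
  return std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
}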
......@@ -35,10 +35,10 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = ctx.Attr<std::string>("tag");
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
......
......@@ -119,7 +119,9 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -118,7 +118,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -118,7 +118,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
// we need to pre-allocate 512 Bytes before the data
// and 512 Bytes after the data, so the hccl allreduce
// can work. This is a must according to the Huawei peer.
#define PRE_MALLOC_SIZE_BYTES 512
......@@ -135,16 +135,16 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
paddle::framework::LoDTensor tmp_in, tmp_out;
tmp_in.Resize({tmp_numel});
tmp_out.Resize({tmp_numel});
-    tmp_in.mutable_data<T>(place); // allocate
-    tmp_out.mutable_data<T>(place); // allocate
+    auto p_tmp_in = tmp_in.mutable_data<T>(place); // allocate
+    auto p_tmp_out = tmp_out.mutable_data<T>(place); // allocate
void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
-    std::string tag = ctx.Attr<std::string>("tag");
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
aclrtStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......@@ -154,9 +154,13 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
stream = comm->stream();
}
+    // we need to memset this memory first to avoid a core dump in hccl
+    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel * sizeof(T), stream);
+    platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel * sizeof(T), stream);
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
memory::Copy(npu_place, sendbuff,
npu_place, reinterpret_cast<void*>(const_cast<T*>(in->data<T>())),
numel * sizeof(T),
stream);
......@@ -195,10 +199,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
tag.c_str(), sendbuff, recvbuff, numel, dtype, hccl_red_type, group.c_str(), (void*)stream));
memory::Copy(npu_place, reinterpret_cast<void*>(out->data<T>()),
npu_place, recvbuff,
numel * sizeof(T),
stream);
out->Resize(in->dims());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
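For reference, a hedged sketch of the padded-buffer arithmetic used by CAllReduceOpASCENDKernel above, assuming pre_tmp_size is computed as PRE_MALLOC_SIZE_BYTES / sizeof(T) as the surrounding code suggests (PaddedNumel and OffsetPastPad are hypothetical helper names, not part of the diff):

#include <cstddef>

constexpr std::size_t kPreMallocSizeBytes = 512;  // matches PRE_MALLOC_SIZE_BYTES in the diff

// Total element count of the temporary tensors: the caller's numel plus one
// 512-byte guard region before the data and one after it.
template <typename T>
std::size_t PaddedNumel(std::size_t numel) {
  const std::size_t pad_elems = kPreMallocSizeBytes / sizeof(T);
  return numel + 2 * pad_elems;
}

// The sendbuff/recvbuff pointers handed to HCCL start one guard region into
// the zeroed allocation, so HCCL sees exactly numel elements while the guard
// bytes provide the extra headroom the comment above says HCCL requires.
template <typename T>
T* OffsetPastPad(T* base) {
  return base + kPreMallocSizeBytes / sizeof(T);
}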
......@@ -117,7 +117,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
{{"Out", {"Out"}}},
attrs);
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -48,7 +48,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
int root = ctx.Attr<int>("root");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = ctx.Attr<std::string>("tag");
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root
<< ", group is " << group
......
......@@ -113,7 +113,9 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -32,10 +32,10 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
-    std::string tag = ctx.Attr<std::string>("tag");
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
auto out_dims = in->dims();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
......@@ -43,7 +43,7 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
"The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0], nranks));
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
......@@ -66,7 +66,7 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
<< "hccl_red_type: " << HCCL_REP_OP_SUM
<< ", group is: " << group
<< ", tag is " << tag;
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_reduce_scatter(
tag.c_str(), inputPtr, outputPtr, (u64)recv_numel, dtype, HCCL_REP_OP_SUM, group.c_str(), (void*)stream));
#else
......@@ -82,7 +82,7 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(c_reducescatter,
ops::CReduceScatterOpAscendKernel<int8_t>,
ops::CReduceScatterOpAscendKernel<int>,
ops::CReduceScatterOpAscendKernel<float>,
......
......@@ -119,7 +119,9 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -42,7 +42,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
} else {
stream = comm->stream();
}
-    std::string tag = ctx.Attr<std::string>("tag");
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int srcRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
......@@ -66,7 +66,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(recv_v2,
ops::CRecvOpASCENDKernel<int>,
ops::CRecvOpASCENDKernel<int8_t>,
ops::CRecvOpASCENDKernel<float>,
......
......@@ -99,7 +99,9 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){
auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
VLOG(3) << "CreateOp recv_v2";
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
VLOG(3) << "Run op recv_v2";
std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
......
......@@ -42,7 +42,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
} else {
stream = comm->stream();
}
-    std::string tag = ctx.Attr<std::string>("tag");
+    std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int destRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
......@@ -50,7 +50,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_send(
tag.c_str(), reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), (u64)numel, dtype, destRank,
srTag, group.c_str(), stream));
VLOG(3) << "Dest rank:" << destRank << " Invoke hcom send. Sent "
<< x->numel();
......@@ -67,7 +67,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(send_v2,
ops::CSendOpASCENDKernel<int>,
ops::CSendOpASCENDKernel<int8_t>,
ops::CSendOpASCENDKernel<float>,
......
......@@ -90,7 +90,9 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){
auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs);
-    op->Run(*scope, place);
+    for (int i = 0; i < 10; i++) {
+      op->Run(*scope, place);
+    }
VLOG(3)<<"send run over";
ctx.Wait();
}
......
......@@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
+#include <atomic>
#include "boost/variant.hpp"
#include "paddle/fluid/platform/enforce.h"
......@@ -148,8 +149,7 @@ class NCCLCommContext {
class NPUDeviceContext;
#define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
-#define ENV_RANK_ID "RANK_ID"
-#define ENV_DEV_ID "DEV_ID"
+#define ENV_RANK_ID "PADDLE_TRAINER_ID"
class HCCLComm {
public:
......@@ -160,6 +160,12 @@ class HCCLComm {
virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default;
+  unsigned long NextTagId() {
+    return tag_counter_++;
+  }
+ private:
+  std::atomic<unsigned long> tag_counter_;
};
// A singleton HCCL communicator context reserves communication ring ids
......@@ -208,10 +214,12 @@ class HCCLCommContext {
return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device);
}
private:
// Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); }
+ public:
+  ~HCCLCommContext() {
+    PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy());
......
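Since tag_counter_ is a std::atomic, concurrent kernels sharing one HCCLComm get distinct tags without extra locking. A hypothetical standalone check of that property (not part of the commit; the 10-iteration loops mirror the repeated op->Run calls added to the tests above):

#include <atomic>
#include <cassert>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>

int main() {
  std::atomic<unsigned long> tag_counter{0};  // plays the role of HCCLComm::tag_counter_
  std::mutex mu;
  std::set<std::string> tags;
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&] {
      for (int i = 0; i < 10; ++i) {
        // Same pattern as the kernels: "<ring_id>_<next counter value>".
        std::string tag = "0_" + std::to_string(tag_counter++);
        std::lock_guard<std::mutex> guard(mu);
        tags.insert(tag);
      }
    });
  }
  for (auto& w : workers) w.join();
  assert(tags.size() == 40);  // 4 threads x 10 runs, no tag reused
  return 0;
}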