未验证 提交 7b450e78 编写于 作者: V Void Main 提交者: GitHub

Add auto-increasing tag id for Hcom OPs (#31702)

上级 50bc1162
......@@ -35,10 +35,10 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag");
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
......
......@@ -119,7 +119,9 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -118,7 +118,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -135,16 +135,16 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
paddle::framework::LoDTensor tmp_in, tmp_out;
tmp_in.Resize({tmp_numel});
tmp_out.Resize({tmp_numel});
tmp_in.mutable_data<T>(place); // allocate
tmp_out.mutable_data<T>(place); // allocate
auto p_tmp_in = tmp_in.mutable_data<T>(place); // allocate
auto p_tmp_out = tmp_out.mutable_data<T>(place); // allocate
void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
std::string tag = ctx.Attr<std::string>("tag");
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
aclrtStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......@@ -154,6 +154,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
stream = comm->stream();
}
// we need to memset this memory firstly to avoid core by hccl
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
memory::Copy(npu_place, sendbuff,
......
......@@ -117,7 +117,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
{{"Out", {"Out"}}},
attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -48,7 +48,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
int root = ctx.Attr<int>("root");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag");
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root
<< ", group is " << group
......
......@@ -113,7 +113,9 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -32,10 +32,10 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag");
auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
auto out_dims = in->dims();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
......
......@@ -119,7 +119,9 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
ctx.Wait();
std::vector<float> out_vec;
......
......@@ -42,7 +42,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
} else {
stream = comm->stream();
}
std::string tag = ctx.Attr<std::string>("tag");
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int srcRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
......
......@@ -99,7 +99,9 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){
auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
VLOG(3) << "CreateOp recv_v2";
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
VLOG(3) << "Run op recv_v2";
std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
......
......@@ -42,7 +42,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
} else {
stream = comm->stream();
}
std::string tag = ctx.Attr<std::string>("tag");
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int destRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag");
......
......@@ -90,7 +90,9 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){
auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place);
}
VLOG(3)<<"send run over";
ctx.Wait();
}
......
......@@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include <atomic>
#include "boost/variant.hpp"
#include "paddle/fluid/platform/enforce.h"
......@@ -148,8 +149,7 @@ class NCCLCommContext {
class NPUDeviceContext;
#define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
#define ENV_RANK_ID "RANK_ID"
#define ENV_DEV_ID "DEV_ID"
#define ENV_RANK_ID "PADDLE_TRAINER_ID"
class HCCLComm {
public:
......@@ -160,6 +160,12 @@ class HCCLComm {
virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default;
unsigned long NextTagId() {
return tag_counter_++;
}
private:
std::atomic<unsigned long> tag_counter_;
};
// A singleton HCCL communicator context reserves communication ring ids
......@@ -208,10 +214,12 @@ class HCCLCommContext {
return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device);
}
private:
// Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); }
public:
~HCCLCommContext(){
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册