未验证 提交 7b450e78 编写于 作者: V Void Main 提交者: GitHub

Add auto-increasing tag id for Hcom OPs (#31702)

上级 50bc1162
...@@ -35,10 +35,10 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> { ...@@ -35,10 +35,10 @@ class CAllGatherOpASCENDKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag");
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place); auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks(); int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
framework::DDim out_dims = in->dims(); framework::DDim out_dims = in->dims();
out_dims[0] *= nranks; out_dims[0] *= nranks;
......
...@@ -119,7 +119,9 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -119,7 +119,9 @@ void TestHCCLAllGatherOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}}, auto op = f::OpRegistry::CreateOp("c_allgather", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs); {{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
......
...@@ -118,7 +118,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -118,7 +118,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}}, auto op = f::OpRegistry::CreateOp("c_allreduce_max", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs); {{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
......
...@@ -135,16 +135,16 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { ...@@ -135,16 +135,16 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
paddle::framework::LoDTensor tmp_in, tmp_out; paddle::framework::LoDTensor tmp_in, tmp_out;
tmp_in.Resize({tmp_numel}); tmp_in.Resize({tmp_numel});
tmp_out.Resize({tmp_numel}); tmp_out.Resize({tmp_numel});
tmp_in.mutable_data<T>(place); // allocate auto p_tmp_in = tmp_in.mutable_data<T>(place); // allocate
tmp_out.mutable_data<T>(place); // allocate auto p_tmp_out = tmp_out.mutable_data<T>(place); // allocate
void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size); void* sendbuff = reinterpret_cast<void*>(tmp_in.data<T>() + pre_tmp_size);
void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size); void* recvbuff = reinterpret_cast<void*>(tmp_out.data<T>() + pre_tmp_size);
std::string tag = ctx.Attr<std::string>("tag");
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place); auto comm = paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
...@@ -154,6 +154,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { ...@@ -154,6 +154,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
stream = comm->stream(); stream = comm->stream();
} }
// we need to memset this memory firstly to avoid core by hccl
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_in), 0, tmp_numel*sizeof(T), stream);
platform::NPUMemsetAsync(static_cast<void*>(p_tmp_out), 0, tmp_numel*sizeof(T), stream);
auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place); auto npu_place = BOOST_GET_CONST(platform::NPUPlace, place);
memory::Copy(npu_place, sendbuff, memory::Copy(npu_place, sendbuff,
......
...@@ -117,7 +117,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -117,7 +117,9 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx) {
{{"Out", {"Out"}}}, {{"Out", {"Out"}}},
attrs); attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
......
...@@ -48,7 +48,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> { ...@@ -48,7 +48,7 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
int root = ctx.Attr<int>("root"); int root = ctx.Attr<int>("root");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag"); std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root VLOG(3) << "begin hccl broadcast, parameter is: "<< "root " << root
<< ", group is " << group << ", group is " << group
......
...@@ -113,7 +113,9 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -113,7 +113,9 @@ void TestHCCLBroadcastOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}}, auto op = f::OpRegistry::CreateOp("c_broadcast", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs); {{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
......
...@@ -32,10 +32,10 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> { ...@@ -32,10 +32,10 @@ class CReduceScatterOpAscendKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
std::string tag = ctx.Attr<std::string>("tag");
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place); auto comm = platform::HCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks(); int nranks = comm->nranks();
std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
auto out_dims = in->dims(); auto out_dims = in->dims();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0, PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
......
...@@ -119,7 +119,9 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) { ...@@ -119,7 +119,9 @@ void TestHCCLReduceScatterOp(f::Scope* scope, const p::DeviceContext& ctx) {
auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}}, auto op = f::OpRegistry::CreateOp("c_reducescatter", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs); {{"Out", {"Out"}}}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
ctx.Wait(); ctx.Wait();
std::vector<float> out_vec; std::vector<float> out_vec;
......
...@@ -42,7 +42,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> { ...@@ -42,7 +42,7 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
} else { } else {
stream = comm->stream(); stream = comm->stream();
} }
std::string tag = ctx.Attr<std::string>("tag"); std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int srcRank = ctx.Attr<int>("peer"); int srcRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag"); int srTag = ctx.Attr<int>("srTag");
......
...@@ -99,7 +99,9 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){ ...@@ -99,7 +99,9 @@ void TestHcomRecvOp(f::Scope* scope, const p::DeviceContext& ctx){
auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs); auto op = f::OpRegistry::CreateOp("recv_v2", {}, {{"Out", {"Out"}}}, attrs);
VLOG(3) << "CreateOp recv_v2"; VLOG(3) << "CreateOp recv_v2";
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
VLOG(3) << "Run op recv_v2"; VLOG(3) << "Run op recv_v2";
std::vector<float> out_vec; std::vector<float> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec); TensorToVector(*tensor_out, ctx, &out_vec);
......
...@@ -42,7 +42,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> { ...@@ -42,7 +42,7 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
} else { } else {
stream = comm->stream(); stream = comm->stream();
} }
std::string tag = ctx.Attr<std::string>("tag"); std::string tag = std::to_string(ring_id) + "_" + std::to_string(comm->NextTagId());
std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id); std::string group = std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
int destRank = ctx.Attr<int>("peer"); int destRank = ctx.Attr<int>("peer");
int srTag = ctx.Attr<int>("srTag"); int srTag = ctx.Attr<int>("srTag");
......
...@@ -90,7 +90,9 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){ ...@@ -90,7 +90,9 @@ void TestHcomSendOp(f::Scope* scope, const p::DeviceContext& ctx){
auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs); auto op = f::OpRegistry::CreateOp("send_v2", {{"X", {"X"}}}, {}, attrs);
for (int i = 0; i < 10; i ++) {
op->Run(*scope, place); op->Run(*scope, place);
}
VLOG(3)<<"send run over"; VLOG(3)<<"send run over";
ctx.Wait(); ctx.Wait();
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <atomic>
#include "boost/variant.hpp" #include "boost/variant.hpp"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -148,8 +149,7 @@ class NCCLCommContext { ...@@ -148,8 +149,7 @@ class NCCLCommContext {
class NPUDeviceContext; class NPUDeviceContext;
#define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE" #define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
#define ENV_RANK_ID "RANK_ID" #define ENV_RANK_ID "PADDLE_TRAINER_ID"
#define ENV_DEV_ID "DEV_ID"
class HCCLComm { class HCCLComm {
public: public:
...@@ -160,6 +160,12 @@ class HCCLComm { ...@@ -160,6 +160,12 @@ class HCCLComm {
virtual aclrtStream stream() const = 0; virtual aclrtStream stream() const = 0;
virtual NPUDeviceContext* dev_context() const = 0; virtual NPUDeviceContext* dev_context() const = 0;
virtual ~HCCLComm() = default; virtual ~HCCLComm() = default;
unsigned long NextTagId() {
return tag_counter_++;
}
private:
std::atomic<unsigned long> tag_counter_;
}; };
// A singleton HCCL communicator context reserves communication ring ids // A singleton HCCL communicator context reserves communication ring ids
...@@ -208,10 +214,12 @@ class HCCLCommContext { ...@@ -208,10 +214,12 @@ class HCCLCommContext {
return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device); return Get(ring_id, BOOST_GET_CONST(NPUPlace, place).device);
} }
private: private:
// Init global hcom // Init global hcom
HCCLCommContext() { InitHcomWorldGroup(); } HCCLCommContext() { InitHcomWorldGroup(); }
public: public:
~HCCLCommContext(){ ~HCCLCommContext(){
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy()); PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::hcom_destroy());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册