Unverified commit 56b7ebbc, authored by WangXi, committed by GitHub

[hybrid] remove the using of global ring in hybrid parallel (#34525)

Parent 9b6c7eb9
@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
            << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
            << " ring id: " << ring_id;
    // it will assign bkcl_comm in XPUDeviceContext within ring_id
-   platform::BKCLCommContext::Instance().CreateBKCLComm(
+   platform::BKCLCommContext::Instance().CreateComm(
        &bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id,
        ring_id);
  }
@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
            << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
            << " ring id: " << ring_id;
    // it will assign bkcl_comm in XPUDeviceContext within ring_id
-   platform::BKCLCommContext::Instance().CreateBKCLComm(
+   platform::BKCLCommContext::Instance().CreateComm(
        &bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
  }
......
@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
            << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
            << " ring id: " << ring_id;
    // it will assign nccl_comm in CUDADeviceContext within ring_id
-   platform::NCCLCommContext::Instance().CreateNCCLComm(
+   platform::NCCLCommContext::Instance().CreateComm(
        &nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id,
        ring_id);
@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
            << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
            << " ring id: " << ring_id;
    // it will assign nccl_comm in CUDADeviceContext within ring_id
-   platform::NCCLCommContext::Instance().CreateNCCLComm(
+   platform::NCCLCommContext::Instance().CreateComm(
        &nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);
    compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
......
@@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
-#include "paddle/fluid/platform/collective_helper.h"
-#endif
namespace paddle {
namespace operators {
@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {
  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
+// TODO(wangxi): Put this in the unified header file
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    using UniqueId = ncclUniqueId;
+    using Place = platform::CUDAPlace;
+    using CommContext = platform::NCCLCommContext;
+#elif defined(PADDLE_WITH_XPU_BKCL)
+    using UniqueId = BKCLUniqueId;
+    using Place = platform::XPUPlace;
+    using CommContext = platform::BKCLCommContext;
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should be compiled with GPU or XPU."));
+#endif
    PADDLE_ENFORCE_EQ(is_gpu_place(place) || is_xpu_place(place), true,
                      platform::errors::PreconditionNotMet(
                          "CCommInitOp can run on gpu or xpu place only."));
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
    auto var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::InvalidArgument("Input con not be empty."));
-    if (is_gpu_place(place)) {
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-      ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
-      int nranks = Attr<int>("nranks");
-      int rank_id = Attr<int>("rank");
-      int rid = Attr<int>("ring_id");
-      int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
-      if (Attr<int>("device_id") >= 0) {
-        device_id = Attr<int>("device_id");
-      }
-      platform::NCCLCommContext::Instance().CreateNCCLComm(
-          nccl_id, nranks, rank_id, device_id, rid);
-#else
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
-          "PaddlePaddle should be compiled with GPU."));
-#endif
-    } else if (is_xpu_place(place)) {
+    UniqueId* comm_id = var->GetMutable<UniqueId>();
+    int nranks = Attr<int>("nranks");
+    int rank_id = Attr<int>("rank");
+    int rid = Attr<int>("ring_id");
#if defined(PADDLE_WITH_XPU_BKCL)
-      BKCLUniqueId* bkcl_id = var->GetMutable<BKCLUniqueId>();
-      int nranks = Attr<int>("nranks");
-      int rank_id = Attr<int>("rank");
-      int rid = Attr<int>("ring_id");
-      PADDLE_ENFORCE_EQ(
-          rid, 0,
-          platform::errors::OutOfRange(
-              "Ring id must equal 0 in multi Kunlun cards training, but got %d",
-              rid));
-      int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
-      if (Attr<int>("device_id") >= 0) {
-        device_id = Attr<int>("device_id");
-      }
-      platform::BKCLCommContext::Instance().CreateBKCLComm(
-          bkcl_id, nranks, rank_id, device_id, rid);
-#else
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
-          "PaddlePaddle should be compiled with XPU."));
+    PADDLE_ENFORCE_EQ(
+        rid, 0,
+        platform::errors::OutOfRange(
+            "Ring id must equal 0 in multi Kunlun cards training, but got %d",
+            rid));
#endif
-    } else {
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
-          "CCommInitOp can run on gpu or xpu place only."));
+    int device_id = BOOST_GET_CONST(Place, place).device;
+    if (Attr<int>("device_id") >= 0) {
+      device_id = Attr<int>("device_id");
    }
+    CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id,
+                                       rid);
+#endif
  }
};
......
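For readers unfamiliar with how these pieces fit together, the snippet below is a minimal, illustrative sketch (not taken from this commit) of a startup block that generates and initializes one extra communicator with an explicit ring_id, in the style of the CollectiveHelper changes further down; the endpoints, rank, and ring id values are made up for the example.

    # sketch only: wire c_gen_nccl_id + c_comm_init for a single extra ring
    import paddle.fluid as fluid
    from paddle.fluid import core, unique_name

    startup_prog = fluid.Program()
    block = startup_prog.global_block()

    ring_id = 1                        # e.g. a model-parallel ring, not a global ring
    rank, nranks = 0, 2                # this trainer's rank inside that ring
    current_endpoint = "127.0.0.1:6170"
    other_endpoints = ["127.0.0.1:6171"]

    # one RAW var per ring holds the broadcast unique id
    comm_id_var = block.create_var(
        name=unique_name.generate('comm_id'),
        persistable=True,
        type=core.VarDesc.VarType.RAW)

    # the id-generation op now carries ring_id, so handshakes of different
    # rings on the same endpoint cannot be mixed up
    block.append_op(
        type='c_gen_nccl_id',
        inputs={},
        outputs={'Out': comm_id_var},
        attrs={
            'rank': rank,
            'endpoint': current_endpoint,
            'other_endpoints': other_endpoints,
            'ring_id': ring_id,
        })

    # c_comm_init consumes the unique id and creates the communicator for ring_id
    block.append_op(
        type='c_comm_init',
        inputs={'X': comm_id_var},
        outputs={},
        attrs={'nranks': nranks, 'rank': rank, 'ring_id': ring_id})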
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    int rank = Attr<int>("rank");
-   framework::Scope& local_scope = scope.NewScope();
+   int ring_id = Attr<int>("ring_id");
    std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
      return Output("Out");
@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
      GenBKCLID(&bkcl_ids);
      std::vector<std::string> endpoint_list =
          Attr<std::vector<std::string>>("other_endpoints");
-     platform::SendBroadCastCommID(endpoint_list, &bkcl_ids);
+     platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
    } else {
      std::string endpoint = Attr<std::string>("endpoint");
-     platform::RecvBroadCastCommID(endpoint, &bkcl_ids);
+     platform::RecvBroadCastCommID(endpoint, &bkcl_ids, ring_id);
    }
    CopyBKCLIDToVar(bkcl_ids, func, scope);
-   scope.DeleteScope(&local_scope);
  }
};
@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
        "(int default 0) "
        "The rank of the trainer in distributed training.")
        .SetDefault(0);
+   AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+       .SetDefault(0);
  }
};
......
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    int rank = Attr<int>("rank");
-   framework::Scope& local_scope = scope.NewScope();
+   int ring_id = Attr<int>("ring_id");
    std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
      return Output("Out");
@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
      GenHCCLID(&hccl_ids);
      std::vector<std::string> endpoint_list =
          Attr<std::vector<std::string>>("other_endpoints");
-     platform::SendBroadCastCommID(endpoint_list, &hccl_ids);
+     platform::SendBroadCastCommID(endpoint_list, &hccl_ids, ring_id);
    } else {
-     platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids);
+     platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids, ring_id);
    }
    CopyHCCLIDToVar(hccl_ids, func, scope);
-   scope.DeleteScope(&local_scope);
  }
};
@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
        "(int default 0) "
        "The rank of the trainer in distributed training.")
        .SetDefault(0);
+   AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+       .SetDefault(0);
  }
};
......
@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    int rank = Attr<int>("rank");
-   framework::Scope& local_scope = scope.NewScope();
+   int ring_id = Attr<int>("ring_id");
    std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
      return Output("Out");
@@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
      GenNCCLID(&nccl_ids);
      std::vector<std::string> endpoint_list =
          Attr<std::vector<std::string>>("other_endpoints");
-     platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
+     platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
    } else {
-     platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
+     platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
    }
    CopyNCCLIDToVar(nccl_ids, func, scope);
-   scope.DeleteScope(&local_scope);
  }
};
@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
        "(int default 0) "
        "The rank of the trainer in distributed training.")
        .SetDefault(0);
+   AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+       .SetDefault(0);
  }
};
......
@@ -72,8 +72,8 @@ class NCCLCommImpl : public NCCLComm {
  std::shared_ptr<platform::CudaEventObject> comm_event_;
};

-NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
-                                          int rank, int dev_id, int ring_id) {
+NCCLComm* NCCLCommContext::CreateComm(ncclUniqueId* nccl_id, int nranks,
+                                      int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_NOT_NULL(nccl_id,
                          platform::errors::InvalidArgument(
                              "The nccl unique id should not be null."));
@@ -225,8 +225,8 @@ class BKCLCommImpl : public BKCLComm {
  std::unique_ptr<XPUDeviceContext> dev_ctx_;
};

-BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks,
-                                          int rank, int dev_id, int ring_id) {
+BKCLComm* BKCLCommContext::CreateComm(BKCLUniqueId* bkcl_id, int nranks,
+                                      int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_NOT_NULL(bkcl_id,
                          platform::errors::InvalidArgument(
                              "The bkcl unique id should not be null."));
......
@@ -72,8 +72,8 @@ class NCCLCommContext {
    return comm_ctx;
  }

-  NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
-                           int dev_id, int ring_id = 0);
+  NCCLComm* CreateComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id,
+                       int ring_id = 0);

  void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
@@ -274,8 +274,8 @@ class BKCLCommContext {
    return comm_ctx;
  }

-  BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank,
-                           int dev_id, int ring_id = 0);
+  BKCLComm* CreateComm(BKCLUniqueId* bkcl_id, int nranks, int rank, int dev_id,
+                       int ring_id = 0);

  void CreateAllBKCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
......
@@ -42,7 +42,10 @@ namespace platform {
std::once_flag SocketServer::init_flag_;

-constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
+struct CommHead {
+  int version = 1;  // unused for now
+  int ring_id = 0;
+};

// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
@@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) {
void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); }

-static int SocketAccept(int server_fd, const char* head) {
+static int SocketAccept(int server_fd, const CommHead head) {
+  static_assert(sizeof(CommHead) <= 1024,
+                "sizeof(CommHead) must <= buffer size");
  struct sockaddr_in client_addr;
  socklen_t addr_length = sizeof(client_addr);
  char buffer[1024] = {0};
  int conn = -1;
+  const char* phead = reinterpret_cast<const char*>(&head);

  while (true) {
    CHECK_SYS_CALL_VAL(
@@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) {
                       &addr_length),
        "accept", conn);

-    int ret_val = SocketRecv(conn, buffer, strlen(head));
-    if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) {
+    int ret_val = SocketRecv(conn, buffer, sizeof(head));
+    if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) {
+      // send a message to the sender, indicating that the link is correct
+      CHECK_SYS_CALL(SocketSend(conn, phead, sizeof(head)), "send");
      break;  // accept client
    } else {
      VLOG(3) << "socket read failed with ret_val=" << ret_val;
@@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) {
  return conn;
}

-static int ConnectAddr(const std::string& ep, const char* head) {
+static int ConnectAddr(const std::string& ep, const CommHead head) {
  auto addr = paddle::string::Split(ep, ':');
  PADDLE_ENFORCE_EQ(
      addr.size(), 2UL,
@@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) {
  std::string host = addr[0];
  int port = std::stoi(addr[1]);

-  int sock = -1;
-  CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
@@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) {
      platform::errors::Unavailable("Open address %s failed: %s",
                                    ep, strerror(errno)));

+  static_assert(sizeof(CommHead) <= 1024,
+                "sizeof(CommHead) must <= buffer size");
+  char buffer[1024] = {0};
+  const char* phead = reinterpret_cast<const char*>(&head);

  // TODO(wangxi) Set from env, default 900s=15min
  int timeout = 900 * 1000;
  int try_times = 0;
  int total_time = 0;
+
+  int sock = -1;
+  CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
  while (true) {
    int ret_val = -1;
    RETRY_SYS_CALL_VAL(
@@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) {
      continue;
    }

-    CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send");
-    break;
+    CHECK_SYS_CALL(SocketSend(sock, phead, sizeof(head)), "send");
+    ret_val = SocketRecv(sock, buffer, sizeof(head));
+    if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) {
+      // recv same message from recver, indicating that the link is correct
+      break;  // accept client
+    } else {
+      VLOG(3) << "socket read failed with ret_val=" << ret_val;
+      CloseSocket(sock);
+    }
+
+    sock = -1;
+    CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
+    // unmatched link, retry after 80ms
+    std::this_thread::sleep_for(std::chrono::milliseconds(80));
  }
  return sock;
}
@@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
-                         std::vector<CommUniqueId>* nccl_ids) {
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id) {
+  CommHead head;
+  head.ring_id = ring_id;
+
  // connect with server
  std::vector<int> connects;
  for (auto server : servers) {
    VLOG(3) << "connecting endpoint: " << server;
-    int conn = ConnectAddr(server, COMM_HEAD);
+    int conn = ConnectAddr(server, head);
    connects.push_back(conn);
  }
  VLOG(3) << "connecting completed...";
@@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector<std::string> servers,
template <typename CommUniqueId>
void RecvBroadCastCommID(std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids) {
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id) {
  int server = CreateListenSocket(endpoint);
-  RecvBroadCastCommID(server, endpoint, nccl_ids);
+  RecvBroadCastCommID(server, endpoint, nccl_ids, ring_id);
  CloseSocket(server);
}

template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids) {
-  int client = SocketAccept(server_fd, COMM_HEAD);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id) {
+  CommHead head;
+  head.ring_id = ring_id;
+  int client = SocketAccept(server_fd, head);
  for (size_t i = 0; i < nccl_ids->size(); ++i) {
    VLOG(3) << "trainer: " << endpoint
@@ -360,11 +390,15 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) {
}

/// template instantiation
#define INSTANT_TEMPLATE(Type)                                                  \
  template void SendBroadCastCommID<Type>(std::vector<std::string> servers,    \
-                                          std::vector<Type> * nccl_ids);       \
-  template void RecvBroadCastCommID<Type>(std::string endpoint,                \
-                                          std::vector<Type> * nccl_ids);
+                                          std::vector<Type> * nccl_ids,        \
+                                          int ring_id = 0);                    \
+  template void RecvBroadCastCommID<Type>(                                     \
+      std::string endpoint, std::vector<Type> * nccl_ids, int ring_id = 0);    \
+  template void RecvBroadCastCommID<Type>(int server_fd, std::string endpoint, \
+                                          std::vector<Type>* nccl_ids,         \
+                                          int ring_id = 0);

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
......
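As a rough mental model of the handshake that CommHead introduces above, here is a small Python socket sketch; it is an approximation written for this note, not the actual C++ implementation, and the (version, ring_id) packing is only illustrative. The connecting side sends its head, the accepting side keeps only connections whose head matches its own ring and echoes the head back, and the sender closes and retries on a mismatch.

    # sketch only: ring-tagged handshake between a comm-id sender and receiver
    import socket
    import struct

    HEAD_FMT = "ii"  # (version, ring_id); the real code exchanges a raw C struct

    def accept_for_ring(server_sock, ring_id, version=1):
        expected = struct.pack(HEAD_FMT, version, ring_id)
        while True:
            conn, _ = server_sock.accept()
            got = conn.recv(len(expected))        # simplified: assumes one full read
            if got == expected:
                conn.sendall(expected)            # ack: this link belongs to our ring
                return conn
            conn.close()                          # wrong ring or stale connect, keep waiting

    def connect_for_ring(host_port, ring_id, version=1):
        head = struct.pack(HEAD_FMT, version, ring_id)
        while True:
            sock = socket.create_connection(host_port)
            sock.sendall(head)
            if sock.recv(len(head)) == head:      # server confirmed the same ring
                return sock
            sock.close()                          # unmatched link, retry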
@@ -31,16 +31,16 @@ void CloseSocket(int fd);

template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
-                         std::vector<CommUniqueId>* nccl_ids);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);

template <typename CommUniqueId>
void RecvBroadCastCommID(std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);

// recv nccl id from socket
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);

class SocketServer {
 public:
......
@@ -126,11 +126,11 @@ class CollectiveHelper(object):
            _add_sync_by_allreduce(block)
            return

+        comm_id_var = block.create_var(
+            name=unique_name.generate('comm_id'),
+            persistable=True,
+            type=core.VarDesc.VarType.RAW)
        if core.is_compiled_with_cuda():
-            comm_id_var = block.create_var(
-                name=unique_name.generate('nccl_id'),
-                persistable=True,
-                type=core.VarDesc.VarType.RAW)
            block.append_op(
                type='c_gen_nccl_id',
                inputs={},
@@ -139,6 +139,7 @@ class CollectiveHelper(object):
                    'rank': rank,
                    'endpoint': current_endpoint,
                    'other_endpoints': other_endpoints,
+                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Forward
                })
            block.append_op(
@@ -152,10 +153,6 @@ class CollectiveHelper(object):
                    OP_ROLE_KEY: OpRole.Forward
                })
        elif core.is_compiled_with_xpu():
-            comm_id_var = block.create_var(
-                name=unique_name.generate('bkcl_id'),
-                persistable=True,
-                type=core.VarDesc.VarType.RAW)
            block.append_op(
                type='c_gen_bkcl_id',
                inputs={},
@@ -164,6 +161,7 @@ class CollectiveHelper(object):
                    'rank': rank,
                    'endpoint': current_endpoint,
                    'other_endpoints': other_endpoints,
+                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Forward
                })
            block.append_op(
@@ -177,24 +175,20 @@ class CollectiveHelper(object):
                    OP_ROLE_KEY: OpRole.Forward
                })
        elif core.is_compiled_with_npu():
-            hccl_id_var = block.create_var(
-                name=unique_name.generate('hccl_id'),
-                persistable=True,
-                type=core.VarDesc.VarType.RAW)
-            endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
            block.append_op(
                type='c_gen_hccl_id',
                inputs={},
-                outputs={'Out': hccl_id_var},
+                outputs={'Out': comm_id_var},
                attrs={
                    'rank': rank,
                    'endpoint': current_endpoint,
                    'other_endpoints': other_endpoints,
+                    'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Forward
                })
            block.append_op(
                type='c_comm_init_hccl',
-                inputs={'X': hccl_id_var},
+                inputs={'X': comm_id_var},
                outputs={},
                attrs={
                    'rank': rank,
......
@@ -73,7 +73,7 @@ class FP16Utils(object):
        return inserted_op_num

    @staticmethod
-    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
+    def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
        """
        1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
        2. revise amp inifine grad checking for sharding
@@ -146,6 +146,7 @@ class FP16Utils(object):
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)

        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
@@ -156,19 +157,26 @@ class FP16Utils(object):
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
-        # this allreduce communication should not overlap with calc
-        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 1,
-            type='c_allreduce_max',
-            inputs={'X': inf_var_int32},
-            outputs={'Out': inf_var_int32},
-            attrs={
-                'ring_id': ring_id,
-                'use_calc_stream': True,
-                OP_ROLE_KEY: OpRole.Optimize
-            })
+        update_loss_scaling_op_idx += 1
+
+        # allreduce(mp)->allreduce(sharding)->allreduce(pp)
+        for ring_id in ring_ids:
+            if ring_id == -1: continue
+            # this allreduce communication should not overlap with calc
+            block._insert_op_without_sync(
+                update_loss_scaling_op_idx,
+                type='c_allreduce_max',
+                inputs={'X': inf_var_int32},
+                outputs={'Out': inf_var_int32},
+                attrs={
+                    'ring_id': ring_id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize
+                })
+            update_loss_scaling_op_idx += 1
+
        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_sharding},
@@ -177,11 +185,12 @@ class FP16Utils(object):
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
+        update_loss_scaling_op_idx += 1
        block._sync_with_cpp()

    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    @staticmethod
-    def sync_amp_check_nan_inf(block, ring_id):
+    def sync_amp_check_nan_inf(block, ring_ids):
        update_loss_scaling_op_idx = -1

        for idx, op in reversed(list(enumerate(block.ops))):
@@ -189,10 +198,14 @@ class FP16Utils(object):
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
+                break

        # not use amp
        if update_loss_scaling_op_idx == -1:
            return
+
+        # 0. inf_var_int32 = cast(inf_var)
+        # 1. inf_var_int32 = allreduce_max(inf_var_int32)
+        # 3. inf_var = cast(inf_var_int32)
        inf_var = block.var(inf_var_name)
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
@@ -212,18 +225,25 @@ class FP16Utils(object):
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
+        update_loss_scaling_op_idx += 1
+
+        # allreduce(mp)->allreduce(pp)
+        for ring_id in ring_ids:
+            if ring_id == -1: continue
+            block._insert_op_without_sync(
+                update_loss_scaling_op_idx,
+                type='c_allreduce_max',
+                inputs={'X': inf_var_int32},
+                outputs={'Out': inf_var_int32},
+                attrs={
+                    'ring_id': ring_id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize
+                })
+            update_loss_scaling_op_idx += 1
+
        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 1,
-            type='c_allreduce_max',
-            inputs={'X': inf_var_int32},
-            outputs={'Out': inf_var_int32},
-            attrs={
-                'ring_id': ring_id,
-                'use_calc_stream': True,
-                OP_ROLE_KEY: OpRole.Optimize
-            })
-        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_global},
@@ -232,4 +252,5 @@ class FP16Utils(object):
                "out_dtype": inf_var_global.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
+        update_loss_scaling_op_idx += 1
        block._sync_with_cpp()
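The comment "allreduce(mp)->allreduce(sharding)->allreduce(pp)" above is the key idea that lets the global ring be dropped for the found_inf flag: reducing the flag ring by ring over groups that partition the ranks gives the same result as one reduction over a single global group. A small numpy check of that equivalence, under assumed degrees of 2 per axis, is sketched below.

    # illustrative check only, not part of the commit
    import numpy as np

    np.random.seed(0)
    mp, sharding, pp = 2, 2, 2                           # assumed degrees, 8 ranks total
    found_inf = np.random.rand(mp, sharding, pp) > 0.7   # per-rank inf flag

    # chained c_allreduce_max: over the mp axis, then sharding, then pp
    chained = found_inf.copy()
    for axis in range(3):
        chained = np.broadcast_to(chained.max(axis=axis, keepdims=True), chained.shape)

    # single allreduce_max over a global ring covering every rank
    global_ring = np.broadcast_to(found_inf.max(), found_inf.shape)
    assert (chained == global_ring).all()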
@@ -25,7 +25,7 @@ class GradientClipHelper(object):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

-    def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
+    def prune_gradient_clip(self, block, shard, ring_ids):
        """
        prune gradient_clip related ops for params that not belong to cur shard
        prune: square, reduce_sum, elementwise_mul
@@ -82,33 +82,23 @@ class GradientClipHelper(object):
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]

-                # this allreduce should not overlap with calc and should be scheduled in calc stream
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_allreduce_sum',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={
-                        'ring_id': self.mp_ring_id,
-                        'op_namescope': "/gradient_clip_model_parallelism",
-                        'use_calc_stream': True,
-                        OP_ROLE_KEY: OpRole.Optimize,
-                    })
-
-                # global norm should only be sum within each model parallelism word size when use global group
-                if pure_dp_degree > 1:
-                    block._insert_op_without_sync(
-                        idx + 2,
-                        type='scale',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'scale': 1.0 / float(pure_dp_degree),
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'bias': 0.0,
-                            'bias_after_scale': False,
-                            OP_ROLE_KEY: OpRole.Optimize
-                        })
+                # allreduce(mp)->allreduce(sharding)->allreduce(pp)
+                idx_offset = 1
+                for ring_id in ring_ids:
+                    if ring_id == -1: continue
+                    # this allreduce should not overlap with calc and should be scheduled in calc stream
+                    block._insert_op_without_sync(
+                        idx + idx_offset,
+                        type='c_allreduce_sum',
+                        inputs={'X': sum_res},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'ring_id': ring_id,
+                            'op_namescope': "/gradient_clip_model_parallelism",
+                            'use_calc_stream': True,
+                            OP_ROLE_KEY: OpRole.Optimize,
+                        })
+                    idx_offset += 1

        # the grad sum here should take the all and only param in the current shard
        to_check_param = set(reversed_x_paramname)
@@ -126,43 +116,32 @@ class GradientClipHelper(object):
        return

    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
+    def sync_global_norm(self, block, ring_ids):
        """
        prune gradient_clip related ops for params that not belong to cur shard
        prune: square, reduce_sum, elementwise_mul
        keep: sum, sqrt, elementwise_max, elementwise_div
        """
+        # FIXME(wangxi): mp should prune duplicated param_grads
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue

            if op.type == "sum":
                sum_res = op.desc.output_arg_names()[0]
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_allreduce_sum',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={
-                        'ring_id': ring_id,
-                        'op_namescope': "/gradient_clip_model_parallelism",
-                        'use_calc_stream': True,
-                        OP_ROLE_KEY: OpRole.Optimize,
-                    })
-
-                # global norm should only be sum within each model parallelism word size
-                if pure_dp_degree > 1:
-                    block._insert_op_without_sync(
-                        idx + 2,
-                        type='scale',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'scale': 1.0 / float(pure_dp_degree),
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'bias': 0.0,
-                            'bias_after_scale': False,
-                            OP_ROLE_KEY: OpRole.Optimize
-                        })
+                for ring_id in ring_ids:
+                    if ring_id == -1: continue
+
+                    idx = idx + 1
+                    block._insert_op_without_sync(
+                        idx,
+                        type='c_allreduce_sum',
+                        inputs={'X': sum_res},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'ring_id': ring_id,
+                            'op_namescope': "/gradient_clip_model_parallelism",
+                            'use_calc_stream': True,
+                            OP_ROLE_KEY: OpRole.Optimize,
+                        })

                return
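The removed scale op can be explained the same way. With the old global ring, every data-parallel replica contributed an identical copy of the squared-norm partial sums, so the global-ring sum had to be divided by dp_degree; summing only over the model-parallel rings, as the new code does, yields the true global norm directly. An illustrative numpy check under assumed degrees:

    # illustrative check only, not part of the commit
    import numpy as np

    np.random.seed(0)
    mp_degree, dp_degree = 4, 2
    partial_sq_sum = np.random.rand(mp_degree)           # one shard of |g|^2 per mp rank
    per_rank = np.tile(partial_sq_sum, (dp_degree, 1))   # dp replicas hold the same values

    old_norm = np.sqrt(per_rank.sum() / dp_degree)       # global-ring sum, then 1/dp_degree scale
    new_norm = np.sqrt(partial_sq_sum.sum())             # sum over the mp ring only
    assert np.isclose(old_norm, new_norm)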
@@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase):
        # if not use sharding, adapt amp/clip, for remain parallelism.
        # cast --> amp --> clip --> opt
        if self.sharding_degree <= 1:
+            # FIXME(wangxi): mp should prune duplicated param_grads when calc
+            # amp inf_var & clip global_norm_var
+
            # amp
-            FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id)
+            FP16Utils.sync_amp_check_nan_inf(
+                main_block, [self.mp_ring_id, self.pp_ring_id])

            # clip
-            gradientclip_helper = GradientClipHelper(self.global_ring_id)
+            gradientclip_helper = GradientClipHelper(None)
            gradientclip_helper.sync_global_norm(
-                main_block, self.global_ring_id, self.dp_degree)
+                main_block, [self.mp_ring_id, self.pp_ring_id])

        # step6: loss div dp_degree
        global_dp_degree = self.sharding_degree * self.dp_degree
@@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            pp_rank,
            ring_id,
            False,
-            global_ring_id=self.global_ring_id,
            sync=False)

    def _init_npu_pipeline_comm(self, startup_block):
@@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
        pair = send_to_next_pair if even else recv_from_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
        my_pair.remove(pair)
        logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(pair,
                                                                       ring_id))
@@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase):
        pair = recv_from_next_pair if even else send_to_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
        my_pair.remove(pair)
        logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair,
                                                                       ring_id))
@@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            pair[0] * 1000 + pair[1],
            max_ring_id + 1)  # 3->0 not in pp_ring_map
        self._init_pair_comm(pair, ring_id)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
        if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
            my_pair.remove(pair)
        logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format(
@@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            pair[0] * 1000 + pair[1],
            max_ring_id + 2)  # 0->3 not in pp_ring_map
        self._init_pair_comm(pair, ring_id)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
        if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
            my_pair.remove(pair)
        logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format(
@@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase):
            assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
                self.pp_rank_, self.pp_rank)

+        self._collective_helper._init_communicator(
+            self._startup_program,
+            self.current_endpoint,
+            self.pp_group_endpoints,
+            self.pp_rank,
+            self.pp_ring_id,
+            False,
+            sync=False)
+
        if core.is_compiled_with_npu():
            self._init_npu_pipeline_comm(startup_block)
            return
@@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id))
            if self.pp_rank in pair:
                self._init_pair_comm(pair, ring_id)
-                append_naive_sync(startup_block, self.startup_prog_sync_var,
-                                  self.global_ring_id)

    def _init_comm(self):
@@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            dtype=core.VarDesc.VarType.INT32,
            persistable=False)

-        # global ring
-        self._collective_helper._init_communicator(
-            self._startup_program,
-            self.current_endpoint,
-            self.global_endpoints,
-            self.global_rank,
-            self.global_ring_id,
-            False,
-            global_ring_id=self.global_ring_id,
-            sync=False)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
-
        # mp ring
        if self.mp_degree > 1:
            self._collective_helper._init_communicator(
@@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.mp_rank,
                self.mp_ring_id,
                False,
-                global_ring_id=self.global_ring_id,
                sync=False)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

        # sharding ring
        if self.sharding_degree > 1:
@@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.sharding_rank,
                self.sharding_ring_id,
                False,
-                global_ring_id=self.global_ring_id,
                sync=False)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

        # pp ring
        if self.pp_degree > 1:
@@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.dp_rank,
                self.dp_ring_id,
                False,
-                global_ring_id=self.global_ring_id,
                sync=False)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

        startup_block._sync_with_cpp()
@@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase):
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, self._shard)
+
+        # FIXME(wangxi): mp should prune duplicated param_grads
        # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
        # group. and each Data Parallelism group should have its own sync of FoundInfinite
        # amp could use global group for sync
-        FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
-                             self.global_ring_id)
+        FP16Utils.prune_fp16(
+            block, self._shard, self._reduced_grads_to_param,
+            [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])

        # clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
-        if self.mp_degree * self.pp_degree == 1:
-            # separate the sharding-hybrid senario to keep the accuracy
-            gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
-            gradientclip_helper.prune_gradient_clip(
-                block, self._shard, pure_dp_degree=1)
-        else:
-            gradientclip_helper = GradientClipHelper(self.global_ring_id)
-            gradientclip_helper.prune_gradient_clip(
-                block, self._shard, pure_dp_degree=self.dp_degree)
+        gradientclip_helper = GradientClipHelper(None)
+        gradientclip_helper.prune_gradient_clip(
+            block, self._shard,
+            [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])

        # build prog deps
        reduced_grads = []
@@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        # pp
        if self.pp_degree > 1:
-            self.pp_ring_id = 20
+            self.pp_pair_ring_id = 20
+            # pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
+            self.pp_ring_id = 4
            self.pp_rank = self.global_rank // (self.sharding_degree *
                                                self.mp_degree) % self.pp_degree
            # (NOTE): Already adjust for (outter-pure) dp
@@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                    pp_first_stage_idx + pp_stage_offset * i])
            assert self.current_endpoint in self.pp_group_endpoints
        else:
-            self.pp_degree = 1
            self.pp_ring_id = -1
+            self.pp_degree = 1
+            self.pp_pair_ring_id = -1
            self.pp_rank = -1
            self.pp_group_id = -1
            self.pp_group_endpoints = []
@@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            outputs={'Out': params},
            attrs={'ring_id': self.dp_ring_id,
                   OP_ROLE_KEY: OpRole.Forward})
-        # sync within global group
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)

    # sharding gradient merge
    def create_persistable_gradients_and_insert_merge_ops(
......
@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
        self.set_strategy(strategy, 'sharding')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
        parameters = [
-            x.name for x in train_prog.list_vars() if x.persistable == True
+            x.name for x in train_prog.list_vars() if x.persistable is True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
@@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
        ])

-class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
+class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "3"
        os.environ[
@@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        self.sharding_ring_id = 1
        self.dp_ring_id = 2
        self.global_ring_id = 3
-        self.pp_ring_id = 20
+        self.pp_pair_ring_id = 20

    def test_sharding_with_mp(self):
        # NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
@@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                sharding_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                dp_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                sharding_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                dp_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                sharding_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                dp_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
            'fill_constant', 'uniform_random', 'fill_constant',
            'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
-            'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
-            'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
-            'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
-            'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
-            'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
-            'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
        ])

        self.assertEqual(main_prog_op_types, [
@@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
            if op.type == "c_comm_init"
        ]
        self.assertIn(self.sharding_ring_id, created_ring_ids)
-        self.assertIn(self.pp_ring_id, created_ring_ids)
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)

        # check correctness of pp group
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                sharding_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
@@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
        sharding_group_waiting_port = None
        for op in startup_prog_ops:
            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                dp_group_waiting_ports = op.desc.attr("other_endpoints")
        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
@@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
            if op.type == 'c_allreduce_sum':
                assert 'FusedOutput' in op.input_arg_names[0]
def test_hybrid_with_mp_pp_amp_gclip(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
self.set_strategy(strategy, 'amp')
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 2,
"pp_degree": 2,
"dp_degree": 1,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init'
])
# pp + mp, partial send recv
self.assertIn('partial_recv', main_prog_op_types)
self.assertIn('partial_allgather', main_prog_op_types)
self.assertIn('partial_send', main_prog_op_types)
# amp check_finite_and_unscale, allreduce(mp)->allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2)
# global gradient clip, allreduce(mp)->allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_sum'), 2)
# should has ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.mp_ring_id, created_ring_ids)
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
# check correctness of pp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
mp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(mp_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_1":
pp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
if __name__ == "__main__":
    unittest.main()