diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index 16f9454e9376e4368a478cf8adf9e3f988868785..ba9b70aea7b96c52f29411a9879a878aee21195b 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -92,7 +92,7 @@ void BKCLParallelContext::Init() { << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id << " ring id: " << ring_id; // it will assign bkcl_comm in XPUDeviceContext within ring_id - platform::BKCLCommContext::Instance().CreateBKCLComm( + platform::BKCLCommContext::Instance().CreateComm( &bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id); } @@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) { << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id << " ring id: " << ring_id; // it will assign bkcl_comm in XPUDeviceContext within ring_id - platform::BKCLCommContext::Instance().CreateBKCLComm( + platform::BKCLCommContext::Instance().CreateComm( &bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id); } diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 9f036742f0f5dd4113a92a67980484eca2da3965..32becda4edc95a638889d9e29312f1ad3d37a640 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -75,7 +75,7 @@ void NCCLParallelContext::Init() { << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id << " ring id: " << ring_id; // it will assign nccl_comm in CUDADeviceContext within ring_id - platform::NCCLCommContext::Instance().CreateNCCLComm( + platform::NCCLCommContext::Instance().CreateComm( &nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id); @@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) { << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id << " ring id: " << ring_id; // it will assign nccl_comm in CUDADeviceContext within ring_id - platform::NCCLCommContext::Instance().CreateNCCLComm( + platform::NCCLCommContext::Instance().CreateComm( &nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id); compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New( diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index f4510861672ca6c9e2fb329e603324a841866f73..9bf86dc92677380bcb7f28add88808322698a2b5 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -24,15 +24,16 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/platform/collective_helper.h" +#endif + namespace paddle { namespace framework { class Scope; } // namespace framework } // namespace paddle -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) -#include "paddle/fluid/platform/collective_helper.h" -#endif namespace paddle { namespace operators { @@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { +// TODO(wangxi): Put this in the unified header file +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + using UniqueId = ncclUniqueId; + using Place = platform::CUDAPlace; + using CommContext = platform::NCCLCommContext; +#elif defined(PADDLE_WITH_XPU_BKCL) + using UniqueId = BKCLUniqueId; + using Place = platform::XPUPlace; + using CommContext = platform::BKCLCommContext; +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should be compiled with GPU or XPU.")); +#endif + PADDLE_ENFORCE_EQ(is_gpu_place(place) || is_xpu_place(place), true, platform::errors::PreconditionNotMet( "CCommInitOp can run on gpu or xpu place only.")); +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) auto var = scope.FindVar(Input("X")); PADDLE_ENFORCE_NOT_NULL( var, platform::errors::InvalidArgument("Input con not be empty.")); - if (is_gpu_place(place)) { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - ncclUniqueId* nccl_id = var->GetMutable(); - - int nranks = Attr("nranks"); - int rank_id = Attr("rank"); - int rid = Attr("ring_id"); - int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device; - if (Attr("device_id") >= 0) { - device_id = Attr("device_id"); - } - platform::NCCLCommContext::Instance().CreateNCCLComm( - nccl_id, nranks, rank_id, device_id, rid); -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should be compiled with GPU.")); -#endif - } else if (is_xpu_place(place)) { + + UniqueId* comm_id = var->GetMutable(); + + int nranks = Attr("nranks"); + int rank_id = Attr("rank"); + int rid = Attr("ring_id"); + #if defined(PADDLE_WITH_XPU_BKCL) - BKCLUniqueId* bkcl_id = var->GetMutable(); - - int nranks = Attr("nranks"); - int rank_id = Attr("rank"); - int rid = Attr("ring_id"); - PADDLE_ENFORCE_EQ( - rid, 0, - platform::errors::OutOfRange( - "Ring id must equal 0 in multi Kunlun cards training, but got %d", - rid)); - int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device; - if (Attr("device_id") >= 0) { - device_id = Attr("device_id"); - } - platform::BKCLCommContext::Instance().CreateBKCLComm( - bkcl_id, nranks, rank_id, device_id, rid); -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should be compiled with XPU.")); + PADDLE_ENFORCE_EQ( + rid, 0, + platform::errors::OutOfRange( + "Ring id must equal 0 in multi Kunlun cards training, but got %d", + rid)); #endif - } else { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "CCommInitOp can run on gpu or xpu place only.")); + + int device_id = BOOST_GET_CONST(Place, place).device; + if (Attr("device_id") >= 0) { + device_id = Attr("device_id"); } + CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id, + rid); +#endif } }; diff --git a/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc b/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc index 65685902b422e8f9b799ba7ac6760109520c2b49..ec174ad0e56bc938701101fd78a1976354c40b15 100644 --- a/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc @@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { int rank = Attr("rank"); - framework::Scope& local_scope = scope.NewScope(); + int ring_id = Attr("ring_id"); std::function func = [&](size_t i) -> std::string { return Output("Out"); @@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase { GenBKCLID(&bkcl_ids); std::vector endpoint_list = Attr>("other_endpoints"); - platform::SendBroadCastCommID(endpoint_list, &bkcl_ids); + platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id); } else { std::string endpoint = Attr("endpoint"); - platform::RecvBroadCastCommID(endpoint, &bkcl_ids); + platform::RecvBroadCastCommID(endpoint, &bkcl_ids, ring_id); } CopyBKCLIDToVar(bkcl_ids, func, scope); - scope.DeleteScope(&local_scope); } }; @@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser "(int default 0) " "The rank of the trainer in distributed training.") .SetDefault(0); + AddAttr("ring_id", "(int default 0) user specified ring id") + .SetDefault(0); } }; diff --git a/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc index af1e576a8c74f509822a1f227976c6a2ad803d82..9ab7d90efaa9f3b0124dd3e40275df8d07657e52 100644 --- a/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc @@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { int rank = Attr("rank"); - framework::Scope& local_scope = scope.NewScope(); + int ring_id = Attr("ring_id"); std::function func = [&](size_t i) -> std::string { return Output("Out"); @@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase { GenHCCLID(&hccl_ids); std::vector endpoint_list = Attr>("other_endpoints"); - platform::SendBroadCastCommID(endpoint_list, &hccl_ids); + platform::SendBroadCastCommID(endpoint_list, &hccl_ids, ring_id); } else { - platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids); + platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids, ring_id); } CopyHCCLIDToVar(hccl_ids, func, scope); - scope.DeleteScope(&local_scope); } }; @@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser "(int default 0) " "The rank of the trainer in distributed training.") .SetDefault(0); + AddAttr("ring_id", "(int default 0) user specified ring id") + .SetDefault(0); } }; diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index 470537582e97838322de2dabdd880b254d6401c9..0a0a824b77586676fb5ffecc9c918649386f8cd0 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { int rank = Attr("rank"); - framework::Scope& local_scope = scope.NewScope(); + int ring_id = Attr("ring_id"); std::function func = [&](size_t i) -> std::string { return Output("Out"); @@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase { GenNCCLID(&nccl_ids); std::vector endpoint_list = Attr>("other_endpoints"); - platform::SendBroadCastCommID(endpoint_list, &nccl_ids); + platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id); } else { - platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id); } CopyNCCLIDToVar(nccl_ids, func, scope); - scope.DeleteScope(&local_scope); } }; @@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser "(int default 0) " "The rank of the trainer in distributed training.") .SetDefault(0); + AddAttr("ring_id", "(int default 0) user specified ring id") + .SetDefault(0); } }; diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index f2b478f7d20e99293566985de31e3aad02795bd8..cc9f2c75989db8b85a8476e87773036000d40516 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -72,8 +72,8 @@ class NCCLCommImpl : public NCCLComm { std::shared_ptr comm_event_; }; -NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, - int rank, int dev_id, int ring_id) { +NCCLComm* NCCLCommContext::CreateComm(ncclUniqueId* nccl_id, int nranks, + int rank, int dev_id, int ring_id) { PADDLE_ENFORCE_NOT_NULL(nccl_id, platform::errors::InvalidArgument( "The nccl unique id should not be null.")); @@ -225,8 +225,8 @@ class BKCLCommImpl : public BKCLComm { std::unique_ptr dev_ctx_; }; -BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, - int rank, int dev_id, int ring_id) { +BKCLComm* BKCLCommContext::CreateComm(BKCLUniqueId* bkcl_id, int nranks, + int rank, int dev_id, int ring_id) { PADDLE_ENFORCE_NOT_NULL(bkcl_id, platform::errors::InvalidArgument( "The bkcl unique id should not be null.")); diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index b0b857f7ee3f2ae0e9ca86ee6f51a6fdd047a90f..b9be9dc8304e1e8e38acfa1bc4b48c8424ed5b53 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -72,8 +72,8 @@ class NCCLCommContext { return comm_ctx; } - NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank, - int dev_id, int ring_id = 0); + NCCLComm* CreateComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id, + int ring_id = 0); void CreateAllNCCLComms(const std::vector& dev_ids, int ring_id = 0); @@ -274,8 +274,8 @@ class BKCLCommContext { return comm_ctx; } - BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank, - int dev_id, int ring_id = 0); + BKCLComm* CreateComm(BKCLUniqueId* bkcl_id, int nranks, int rank, int dev_id, + int ring_id = 0); void CreateAllBKCLComms(const std::vector& dev_ids, int ring_id = 0); diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index 5f6dd5679a1a8eacc270a17e0f725e4311897dda..73bc2c41a0bc9c568f0c90e68adb471406d92a03 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -42,7 +42,10 @@ namespace platform { std::once_flag SocketServer::init_flag_; -constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; +struct CommHead { + int version = 1; // unused for now + int ring_id = 0; +}; // Check system calls, such as socket, bind. #define CHECK_SYS_CALL(call, name) \ @@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) { void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); } -static int SocketAccept(int server_fd, const char* head) { +static int SocketAccept(int server_fd, const CommHead head) { + static_assert(sizeof(CommHead) <= 1024, + "sizeof(CommHead) must <= buffer size"); + struct sockaddr_in client_addr; socklen_t addr_length = sizeof(client_addr); char buffer[1024] = {0}; int conn = -1; + const char* phead = reinterpret_cast(&head); while (true) { CHECK_SYS_CALL_VAL( @@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) { &addr_length), "accept", conn); - int ret_val = SocketRecv(conn, buffer, strlen(head)); - if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) { + int ret_val = SocketRecv(conn, buffer, sizeof(head)); + if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) { + // send a message to the sender, indicating that the link is correct + CHECK_SYS_CALL(SocketSend(conn, phead, sizeof(head)), "send"); break; // accept client } else { VLOG(3) << "socket read failed with ret_val=" << ret_val; @@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) { return conn; } -static int ConnectAddr(const std::string& ep, const char* head) { +static int ConnectAddr(const std::string& ep, const CommHead head) { auto addr = paddle::string::Split(ep, ':'); PADDLE_ENFORCE_EQ( addr.size(), 2UL, @@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) { std::string host = addr[0]; int port = std::stoi(addr[1]); - int sock = -1; - CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock); - struct sockaddr_in server_addr; memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; @@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) { platform::errors::Unavailable("Open address %s failed: %s", ep, strerror(errno))); + static_assert(sizeof(CommHead) <= 1024, + "sizeof(CommHead) must <= buffer size"); + char buffer[1024] = {0}; + const char* phead = reinterpret_cast(&head); + // TODO(wangxi) Set from env, default 900s=15min int timeout = 900 * 1000; int try_times = 0; int total_time = 0; + + int sock = -1; + CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock); while (true) { int ret_val = -1; RETRY_SYS_CALL_VAL( @@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) { continue; } - CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send"); - break; + CHECK_SYS_CALL(SocketSend(sock, phead, sizeof(head)), "send"); + ret_val = SocketRecv(sock, buffer, sizeof(head)); + if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) { + // recv same message from recver, indicating that the link is correct + break; // accept client + } else { + VLOG(3) << "socket read failed with ret_val=" << ret_val; + CloseSocket(sock); + } + sock = -1; + CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock); + // unmatched link, retry after 80ms + std::this_thread::sleep_for(std::chrono::milliseconds(80)); } return sock; } @@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) { template void SendBroadCastCommID(std::vector servers, - std::vector* nccl_ids) { + std::vector* nccl_ids, int ring_id) { + CommHead head; + head.ring_id = ring_id; + // connect with server std::vector connects; for (auto server : servers) { VLOG(3) << "connecting endpoint: " << server; - int conn = ConnectAddr(server, COMM_HEAD); + int conn = ConnectAddr(server, head); connects.push_back(conn); } VLOG(3) << "connecting completed..."; @@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector servers, template void RecvBroadCastCommID(std::string endpoint, - std::vector* nccl_ids) { + std::vector* nccl_ids, int ring_id) { int server = CreateListenSocket(endpoint); - RecvBroadCastCommID(server, endpoint, nccl_ids); + RecvBroadCastCommID(server, endpoint, nccl_ids, ring_id); CloseSocket(server); } template void RecvBroadCastCommID(int server_fd, std::string endpoint, - std::vector* nccl_ids) { - int client = SocketAccept(server_fd, COMM_HEAD); + std::vector* nccl_ids, int ring_id) { + CommHead head; + head.ring_id = ring_id; + int client = SocketAccept(server_fd, head); for (size_t i = 0; i < nccl_ids->size(); ++i) { VLOG(3) << "trainer: " << endpoint @@ -360,11 +390,15 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) { } /// template instantiation -#define INSTANT_TEMPLATE(Type) \ - template void SendBroadCastCommID(std::vector servers, \ - std::vector * nccl_ids); \ - template void RecvBroadCastCommID(std::string endpoint, \ - std::vector * nccl_ids); +#define INSTANT_TEMPLATE(Type) \ + template void SendBroadCastCommID(std::vector servers, \ + std::vector * nccl_ids, \ + int ring_id = 0); \ + template void RecvBroadCastCommID( \ + std::string endpoint, std::vector * nccl_ids, int ring_id = 0); \ + template void RecvBroadCastCommID(int server_fd, std::string endpoint, \ + std::vector* nccl_ids, \ + int ring_id = 0); #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) INSTANT_TEMPLATE(ncclUniqueId) diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h index fb5d8d8fcd94059cbef66de809bca295d205a73c..6198519eb06df8dfac1540f5bca8387a03afc71f 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -31,16 +31,16 @@ void CloseSocket(int fd); template void SendBroadCastCommID(std::vector servers, - std::vector* nccl_ids); + std::vector* nccl_ids, int ring_id = 0); template void RecvBroadCastCommID(std::string endpoint, - std::vector* nccl_ids); + std::vector* nccl_ids, int ring_id = 0); // recv nccl id from socket template void RecvBroadCastCommID(int server_fd, std::string endpoint, - std::vector* nccl_ids); + std::vector* nccl_ids, int ring_id = 0); class SocketServer { public: diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 9e891062bcbccbca4f34d8a2e211ca5f3ece44a3..a44607d13aafce39746e7579ffaa2a66e33d1b80 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -126,11 +126,11 @@ class CollectiveHelper(object): _add_sync_by_allreduce(block) return + comm_id_var = block.create_var( + name=unique_name.generate('comm_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) if core.is_compiled_with_cuda(): - comm_id_var = block.create_var( - name=unique_name.generate('nccl_id'), - persistable=True, - type=core.VarDesc.VarType.RAW) block.append_op( type='c_gen_nccl_id', inputs={}, @@ -139,6 +139,7 @@ class CollectiveHelper(object): 'rank': rank, 'endpoint': current_endpoint, 'other_endpoints': other_endpoints, + 'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward }) block.append_op( @@ -152,10 +153,6 @@ class CollectiveHelper(object): OP_ROLE_KEY: OpRole.Forward }) elif core.is_compiled_with_xpu(): - comm_id_var = block.create_var( - name=unique_name.generate('bkcl_id'), - persistable=True, - type=core.VarDesc.VarType.RAW) block.append_op( type='c_gen_bkcl_id', inputs={}, @@ -164,6 +161,7 @@ class CollectiveHelper(object): 'rank': rank, 'endpoint': current_endpoint, 'other_endpoints': other_endpoints, + 'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward }) block.append_op( @@ -177,24 +175,20 @@ class CollectiveHelper(object): OP_ROLE_KEY: OpRole.Forward }) elif core.is_compiled_with_npu(): - hccl_id_var = block.create_var( - name=unique_name.generate('hccl_id'), - persistable=True, - type=core.VarDesc.VarType.RAW) - endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)} block.append_op( type='c_gen_hccl_id', inputs={}, - outputs={'Out': hccl_id_var}, + outputs={'Out': comm_id_var}, attrs={ 'rank': rank, 'endpoint': current_endpoint, 'other_endpoints': other_endpoints, + 'ring_id': ring_id, OP_ROLE_KEY: OpRole.Forward }) block.append_op( type='c_comm_init_hccl', - inputs={'X': hccl_id_var}, + inputs={'X': comm_id_var}, outputs={}, attrs={ 'rank': rank, diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index 8e6363537298459bd930e82fd18c662522171696..07272404768ff781fce7d18d634a56129bd0fb1c 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -73,7 +73,7 @@ class FP16Utils(object): return inserted_op_num @staticmethod - def prune_fp16(block, shard, reduced_grads_to_param, ring_id): + def prune_fp16(block, shard, reduced_grads_to_param, ring_ids): """ 1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard 2. revise amp inifine grad checking for sharding @@ -146,6 +146,7 @@ class FP16Utils(object): name=inf_var_name + "@sharding", shape=inf_var.shape, dtype=inf_var.dtype) + block._insert_op_without_sync( update_loss_scaling_op_idx, type='cast', @@ -156,19 +157,26 @@ class FP16Utils(object): "out_dtype": inf_var_int32.dtype, OP_ROLE_KEY: OpRole.Optimize }) - # this allreduce communication should not overlap with calc - block._insert_op_without_sync( - update_loss_scaling_op_idx + 1, - type='c_allreduce_max', - inputs={'X': inf_var_int32}, - outputs={'Out': inf_var_int32}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize - }) + update_loss_scaling_op_idx += 1 + + # allreduce(mp)->allreduce(sharding)->allreduce(pp) + for ring_id in ring_ids: + if ring_id == -1: continue + # this allreduce communication should not overlap with calc + block._insert_op_without_sync( + update_loss_scaling_op_idx, + type='c_allreduce_max', + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_int32}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) + update_loss_scaling_op_idx += 1 + block._insert_op_without_sync( - update_loss_scaling_op_idx + 2, + update_loss_scaling_op_idx, type='cast', inputs={'X': inf_var_int32}, outputs={'Out': inf_var_sharding}, @@ -177,11 +185,12 @@ class FP16Utils(object): "out_dtype": inf_var_sharding.dtype, OP_ROLE_KEY: OpRole.Optimize }) + update_loss_scaling_op_idx += 1 block._sync_with_cpp() # TODO (JZ-LIANG) revise this for uniform mixed parallelism @staticmethod - def sync_amp_check_nan_inf(block, ring_id): + def sync_amp_check_nan_inf(block, ring_ids): update_loss_scaling_op_idx = -1 for idx, op in reversed(list(enumerate(block.ops))): @@ -189,10 +198,14 @@ class FP16Utils(object): update_loss_scaling_op_idx = idx inf_var_name = op.desc.input('FoundInfinite')[0] op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD") + break # not use amp if update_loss_scaling_op_idx == -1: return + # 0. inf_var_int32 = cast(inf_var) + # 1. inf_var_int32 = allreduce_max(inf_var_int32) + # 3. inf_var = cast(inf_var_int32) inf_var = block.var(inf_var_name) inf_var_int32 = block.create_var( name=inf_var_name + "@cast_int32", @@ -212,18 +225,25 @@ class FP16Utils(object): "out_dtype": inf_var_int32.dtype, OP_ROLE_KEY: OpRole.Optimize }) + update_loss_scaling_op_idx += 1 + + # allreduce(mp)->allreduce(pp) + for ring_id in ring_ids: + if ring_id == -1: continue + block._insert_op_without_sync( + update_loss_scaling_op_idx, + type='c_allreduce_max', + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_int32}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) + update_loss_scaling_op_idx += 1 + block._insert_op_without_sync( - update_loss_scaling_op_idx + 1, - type='c_allreduce_max', - inputs={'X': inf_var_int32}, - outputs={'Out': inf_var_int32}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize - }) - block._insert_op_without_sync( - update_loss_scaling_op_idx + 2, + update_loss_scaling_op_idx, type='cast', inputs={'X': inf_var_int32}, outputs={'Out': inf_var_global}, @@ -232,4 +252,5 @@ class FP16Utils(object): "out_dtype": inf_var_global.dtype, OP_ROLE_KEY: OpRole.Optimize }) + update_loss_scaling_op_idx += 1 block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index fd74f28b69e19000fa3f59b973ae165f8dd38abb..e3d344dca25b3690ace2ec91a9f10b5d6a0f1042 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -25,7 +25,7 @@ class GradientClipHelper(object): return op.desc.has_attr("op_namescope") \ and op.desc.attr("op_namescope").startswith("/gradient_clip") - def prune_gradient_clip(self, block, shard, pure_dp_degree=1): + def prune_gradient_clip(self, block, shard, ring_ids): """ prune gradient_clip related ops for params that not belong to cur shard prune: square, reduce_sum, elementwise_mul @@ -82,33 +82,23 @@ class GradientClipHelper(object): assert (len(op.desc.output_arg_names()) == 1) sum_res = op.desc.output_arg_names()[0] - # this allreduce should not overlap with calc and should be scheduled in calc stream - block._insert_op_without_sync( - idx + 1, - type='c_allreduce_sum', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={ - 'ring_id': self.mp_ring_id, - 'op_namescope': "/gradient_clip_model_parallelism", - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize, - }) - - # global norm should only be sum within each model parallelism word size when use global group - if pure_dp_degree > 1: + # allreduce(mp)->allreduce(sharding)->allreduce(pp) + idx_offset = 1 + for ring_id in ring_ids: + if ring_id == -1: continue + # this allreduce should not overlap with calc and should be scheduled in calc stream block._insert_op_without_sync( - idx + 2, - type='scale', + idx + idx_offset, + type='c_allreduce_sum', inputs={'X': sum_res}, outputs={'Out': sum_res}, attrs={ - 'scale': 1.0 / float(pure_dp_degree), + 'ring_id': ring_id, 'op_namescope': "/gradient_clip_model_parallelism", - 'bias': 0.0, - 'bias_after_scale': False, - OP_ROLE_KEY: OpRole.Optimize + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, }) + idx_offset += 1 # the grad sum here should take the all and only param in the current shard to_check_param = set(reversed_x_paramname) @@ -126,43 +116,32 @@ class GradientClipHelper(object): return # TODO (JZ-LIANG) revise this for uniform mixed parallelism - def sync_global_norm(self, block, ring_id, pure_dp_degree=1): + def sync_global_norm(self, block, ring_ids): """ prune gradient_clip related ops for params that not belong to cur shard prune: square, reduce_sum, elementwise_mul keep: sum, sqrt, elementwise_max, elementwise_div """ + # FIXME(wangxi): mp should prune duplicated param_grads for idx, op in reversed(list(enumerate(block.ops))): if not self._is_gradient_clip_op(op): continue if op.type == "sum": sum_res = op.desc.output_arg_names()[0] - block._insert_op_without_sync( - idx + 1, - type='c_allreduce_sum', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={ - 'ring_id': ring_id, - 'op_namescope': "/gradient_clip_model_parallelism", - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize, - }) - - # global norm should only be sum within each model parallelism word size - if pure_dp_degree > 1: + for ring_id in ring_ids: + if ring_id == -1: continue + + idx = idx + 1 block._insert_op_without_sync( - idx + 2, - type='scale', + idx, + type='c_allreduce_sum', inputs={'X': sum_res}, outputs={'Out': sum_res}, attrs={ - 'scale': 1.0 / float(pure_dp_degree), + 'ring_id': ring_id, 'op_namescope': "/gradient_clip_model_parallelism", - 'bias': 0.0, - 'bias_after_scale': False, - OP_ROLE_KEY: OpRole.Optimize + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, }) - - return + return diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 8a591120c0289edd373582f7173d607219b816a2..df775247c8c9e53bdc5c6314a81f3ea62d6148a5 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase): # if not use sharding, adapt amp/clip, for remain parallelism. # cast --> amp --> clip --> opt if self.sharding_degree <= 1: + # FIXME(wangxi): mp should prune duplicated param_grads when calc + # amp inf_var & clip global_norm_var + # amp - FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id) + FP16Utils.sync_amp_check_nan_inf( + main_block, [self.mp_ring_id, self.pp_ring_id]) # clip - gradientclip_helper = GradientClipHelper(self.global_ring_id) + gradientclip_helper = GradientClipHelper(None) gradientclip_helper.sync_global_norm( - main_block, self.global_ring_id, self.dp_degree) + main_block, [self.mp_ring_id, self.pp_ring_id]) # step6: loss div dp_degree global_dp_degree = self.sharding_degree * self.dp_degree @@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase): pp_rank, ring_id, False, - global_ring_id=self.global_ring_id, sync=False) def _init_npu_pipeline_comm(self, startup_block): @@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase): pair = send_to_next_pair if even else recv_from_prev_pair ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]] self._init_pair_comm(pair, ring_id) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) my_pair.remove(pair) logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(pair, ring_id)) @@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase): pair = recv_from_next_pair if even else send_to_prev_pair ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]] self._init_pair_comm(pair, ring_id) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) my_pair.remove(pair) logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair, ring_id)) @@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase): pair[0] * 1000 + pair[1], max_ring_id + 1) # 3->0 not in pp_ring_map self._init_pair_comm(pair, ring_id) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1: my_pair.remove(pair) logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format( @@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase): pair[0] * 1000 + pair[1], max_ring_id + 2) # 0->3 not in pp_ring_map self._init_pair_comm(pair, ring_id) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1: my_pair.remove(pair) logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format( @@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase): assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format( self.pp_rank_, self.pp_rank) + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + self.pp_group_endpoints, + self.pp_rank, + self.pp_ring_id, + False, + sync=False) + if core.is_compiled_with_npu(): self._init_npu_pipeline_comm(startup_block) return @@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase): logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id)) if self.pp_rank in pair: self._init_pair_comm(pair, ring_id) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) def _init_comm(self): @@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase): dtype=core.VarDesc.VarType.INT32, persistable=False) - # global ring - self._collective_helper._init_communicator( - self._startup_program, - self.current_endpoint, - self.global_endpoints, - self.global_rank, - self.global_ring_id, - False, - global_ring_id=self.global_ring_id, - sync=False) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) - # mp ring if self.mp_degree > 1: self._collective_helper._init_communicator( @@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase): self.mp_rank, self.mp_ring_id, False, - global_ring_id=self.global_ring_id, sync=False) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) # sharding ring if self.sharding_degree > 1: @@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase): self.sharding_rank, self.sharding_ring_id, False, - global_ring_id=self.global_ring_id, sync=False) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) # pp ring if self.pp_degree > 1: @@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase): self.dp_rank, self.dp_ring_id, False, - global_ring_id=self.global_ring_id, sync=False) - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) startup_block._sync_with_cpp() @@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase): """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, self._shard) + + # FIXME(wangxi): mp should prune duplicated param_grads # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism # group. and each Data Parallelism group should have its own sync of FoundInfinite # amp could use global group for sync - FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - self.global_ring_id) + FP16Utils.prune_fp16( + block, self._shard, self._reduced_grads_to_param, + [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id]) + # clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp) - if self.mp_degree * self.pp_degree == 1: - # separate the sharding-hybrid senario to keep the accuracy - gradientclip_helper = GradientClipHelper(self.sharding_ring_id) - gradientclip_helper.prune_gradient_clip( - block, self._shard, pure_dp_degree=1) - else: - gradientclip_helper = GradientClipHelper(self.global_ring_id) - gradientclip_helper.prune_gradient_clip( - block, self._shard, pure_dp_degree=self.dp_degree) + gradientclip_helper = GradientClipHelper(None) + gradientclip_helper.prune_gradient_clip( + block, self._shard, + [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id]) # build prog deps reduced_grads = [] @@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase): # pp if self.pp_degree > 1: - self.pp_ring_id = 20 + self.pp_pair_ring_id = 20 + # pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3 + self.pp_ring_id = 4 self.pp_rank = self.global_rank // (self.sharding_degree * self.mp_degree) % self.pp_degree # (NOTE): Already adjust for (outter-pure) dp @@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase): pp_first_stage_idx + pp_stage_offset * i]) assert self.current_endpoint in self.pp_group_endpoints else: - self.pp_degree = 1 self.pp_ring_id = -1 + self.pp_degree = 1 + self.pp_pair_ring_id = -1 self.pp_rank = -1 self.pp_group_id = -1 self.pp_group_endpoints = [] @@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase): outputs={'Out': params}, attrs={'ring_id': self.dp_ring_id, OP_ROLE_KEY: OpRole.Forward}) - # sync within global group - append_naive_sync(startup_block, self.startup_prog_sync_var, - self.global_ring_id) # sharding gradient merge def create_persistable_gradients_and_insert_merge_ops( diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index a1cd0df8d7c7e8d3ff8183a71be2cf2307794d89..1387827736560e0e2e3fb00041eb372d77530c09 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): self.set_strategy(strategy, 'sharding') self.optimizer(avg_cost, strategy, train_prog, startup_prog) parameters = [ - x.name for x in train_prog.list_vars() if x.persistable == True + x.name for x in train_prog.list_vars() if x.persistable is True ] ops = [op.type for op in avg_cost.block.ops] vars = [x.name for x in train_prog.list_vars()] @@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ]) -class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): +class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): def setUp(self): os.environ["PADDLE_TRAINER_ID"] = "3" os.environ[ @@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): self.sharding_ring_id = 1 self.dp_ring_id = 2 self.global_ring_id = 3 - self.pp_ring_id = 20 + self.pp_pair_ring_id = 20 def test_sharding_with_mp(self): # NOTE(JZ-LIANG) MP parallelism need user to build model with MP API @@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_1": + 0] == "comm_id_0": sharding_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) @@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_2": + 0] == "comm_id_1": dp_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) @@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_1": + 0] == "comm_id_0": sharding_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) @@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_2": + 0] == "comm_id_1": dp_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) @@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_1": + 0] == "comm_id_0": sharding_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) @@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_2": + 0] == "comm_id_1": dp_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) @@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): 'fill_constant', 'uniform_random', 'fill_constant', 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', - 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum', - 'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init', - 'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream', - 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum', - 'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init', - 'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream' + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init' ]) self.assertEqual(main_prog_op_types, [ @@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): if op.type == "c_comm_init" ] self.assertIn(self.sharding_ring_id, created_ring_ids) - self.assertIn(self.pp_ring_id, created_ring_ids) + self.assertIn(self.pp_pair_ring_id, created_ring_ids) # check correctness of pp group sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_1": + 0] == "comm_id_0": sharding_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003']) @@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): sharding_group_waiting_port = None for op in startup_prog_ops: if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ - 0] == "nccl_id_2": + 0] == "comm_id_1": dp_group_waiting_ports = op.desc.attr("other_endpoints") self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) @@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer): if op.type == 'c_allreduce_sum': assert 'FusedOutput' in op.input_arg_names[0] + def test_hybrid_with_mp_pp_amp_gclip(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + self.set_strategy(strategy, 'amp') + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "mp_degree": 2, + "pp_degree": 2, + "dp_degree": 1, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0) + self.optimizer( + avg_cost, strategy, train_prog, startup_prog, grad_clip=clip) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: mp, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'fill_constant', 'uniform_random', + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init' + ]) + + # pp + mp, partial send recv + self.assertIn('partial_recv', main_prog_op_types) + self.assertIn('partial_allgather', main_prog_op_types) + self.assertIn('partial_send', main_prog_op_types) + + # amp check_finite_and_unscale, allreduce(mp)->allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2) + + # global gradient clip, allreduce(mp)->allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_sum'), 2) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.mp_ring_id, created_ring_ids) + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + + # check correctness of pp group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + mp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(mp_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of sharding group + sharding_group_waiting_port = None + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_1": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002']) + if __name__ == "__main__": unittest.main()