Unverified commit 56b7ebbc authored by WangXi, committed by GitHub

[hybrid] remove the using of global ring in hybrid parallel (#34525)

Parent 9b6c7eb9
......@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
<< " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
<< " ring id: " << ring_id;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateBKCLComm(
platform::BKCLCommContext::Instance().CreateComm(
&bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id,
ring_id);
}
......@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
<< " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
<< " ring id: " << ring_id;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateBKCLComm(
platform::BKCLCommContext::Instance().CreateComm(
&bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
}
......
......@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
platform::NCCLCommContext::Instance().CreateComm(
&nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id,
ring_id);
......@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
platform::NCCLCommContext::Instance().CreateComm(
&nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);
compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
......
......@@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
// TODO(wangxi): Put this in the unified header file
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
using UniqueId = ncclUniqueId;
using Place = platform::CUDAPlace;
using CommContext = platform::NCCLCommContext;
#elif defined(PADDLE_WITH_XPU_BKCL)
using UniqueId = BKCLUniqueId;
using Place = platform::XPUPlace;
using CommContext = platform::BKCLCommContext;
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with GPU or XPU."));
#endif
PADDLE_ENFORCE_EQ(is_gpu_place(place) || is_xpu_place(place), true,
platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu place only."));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::InvalidArgument("Input can not be empty."));
if (is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with GPU."));
#endif
} else if (is_xpu_place(place)) {
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId* bkcl_id = var->GetMutable<BKCLUniqueId>();
UniqueId* comm_id = var->GetMutable<UniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
#if defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_EQ(
rid, 0,
platform::errors::OutOfRange(
"Ring id must equal 0 in multi Kunlun cards training, but got %d",
rid));
int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
#endif
int device_id = BOOST_GET_CONST(Place, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::BKCLCommContext::Instance().CreateBKCLComm(
bkcl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with XPU."));
CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id,
rid);
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu place only."));
}
}
};
......
......@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
int ring_id = Attr<int>("ring_id");
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
......@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
GenBKCLID(&bkcl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids);
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
} else {
std::string endpoint = Attr<std::string>("endpoint");
platform::RecvBroadCastCommID(endpoint, &bkcl_ids);
platform::RecvBroadCastCommID(endpoint, &bkcl_ids, ring_id);
}
CopyBKCLIDToVar(bkcl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};
......@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
......
......@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
int ring_id = Attr<int>("ring_id");
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
......@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
GenHCCLID(&hccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &hccl_ids);
platform::SendBroadCastCommID(endpoint_list, &hccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids);
platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids, ring_id);
}
CopyHCCLIDToVar(hccl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};
......@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
......
......@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
int ring_id = Attr<int>("ring_id");
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
......@@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
GenNCCLID(&nccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
}
CopyNCCLIDToVar(nccl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};
......@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
......
......@@ -72,7 +72,7 @@ class NCCLCommImpl : public NCCLComm {
std::shared_ptr<platform::CudaEventObject> comm_event_;
};
NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
NCCLComm* NCCLCommContext::CreateComm(ncclUniqueId* nccl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(nccl_id,
platform::errors::InvalidArgument(
......@@ -225,7 +225,7 @@ class BKCLCommImpl : public BKCLComm {
std::unique_ptr<XPUDeviceContext> dev_ctx_;
};
BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks,
BKCLComm* BKCLCommContext::CreateComm(BKCLUniqueId* bkcl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(bkcl_id,
platform::errors::InvalidArgument(
......
......@@ -72,8 +72,8 @@ class NCCLCommContext {
return comm_ctx;
}
NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
int dev_id, int ring_id = 0);
NCCLComm* CreateComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id,
int ring_id = 0);
void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
......@@ -274,8 +274,8 @@ class BKCLCommContext {
return comm_ctx;
}
BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank,
int dev_id, int ring_id = 0);
BKCLComm* CreateComm(BKCLUniqueId* bkcl_id, int nranks, int rank, int dev_id,
int ring_id = 0);
void CreateAllBKCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
......
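The rename from CreateNCCLComm/CreateBKCLComm to the uniform CreateComm keeps the same per-ring semantics: each call registers one communicator under its ring_id, and the collective ops later look it up by that id. Below is a toy Python sketch of such a registry, illustrative only (not Paddle code); the method names simply mirror the diff.

# Toy per-ring communicator registry, illustrative only (not Paddle code);
# method names mirror the renamed CreateComm in the diff.
class CommContext:
    _instance = None

    @classmethod
    def Instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._comms = {}  # ring_id -> {dev_id: comm}

    def CreateComm(self, unique_id, nranks, rank, dev_id, ring_id=0):
        comm = {"id": unique_id, "nranks": nranks, "rank": rank, "dev": dev_id}
        self._comms.setdefault(ring_id, {})[dev_id] = comm
        return comm

    def Get(self, ring_id, dev_id):
        return self._comms[ring_id][dev_id]

ctx = CommContext.Instance()
ctx.CreateComm("uid-mp", nranks=2, rank=0, dev_id=0, ring_id=0)  # e.g. mp ring
ctx.CreateComm("uid-pp", nranks=2, rank=0, dev_id=0, ring_id=4)  # e.g. pp group ring
print(ctx.Get(4, 0)["id"])  # -> uid-pp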
......@@ -42,7 +42,10 @@ namespace platform {
std::once_flag SocketServer::init_flag_;
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
struct CommHead {
int version = 1; // unused for now
int ring_id = 0;
};
// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
......@@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) {
void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); }
static int SocketAccept(int server_fd, const char* head) {
static int SocketAccept(int server_fd, const CommHead head) {
static_assert(sizeof(CommHead) <= 1024,
"sizeof(CommHead) must <= buffer size");
struct sockaddr_in client_addr;
socklen_t addr_length = sizeof(client_addr);
char buffer[1024] = {0};
int conn = -1;
const char* phead = reinterpret_cast<const char*>(&head);
while (true) {
CHECK_SYS_CALL_VAL(
......@@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) {
&addr_length),
"accept", conn);
int ret_val = SocketRecv(conn, buffer, strlen(head));
if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) {
int ret_val = SocketRecv(conn, buffer, sizeof(head));
if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) {
// send a message to the sender, indicating that the link is correct
CHECK_SYS_CALL(SocketSend(conn, phead, sizeof(head)), "send");
break; // accept client
} else {
VLOG(3) << "socket read failed with ret_val=" << ret_val;
......@@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) {
return conn;
}
static int ConnectAddr(const std::string& ep, const char* head) {
static int ConnectAddr(const std::string& ep, const CommHead head) {
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
......@@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) {
std::string host = addr[0];
int port = std::stoi(addr[1]);
int sock = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
......@@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) {
platform::errors::Unavailable("Open address %s failed: %s",
ep, strerror(errno)));
static_assert(sizeof(CommHead) <= 1024,
"sizeof(CommHead) must <= buffer size");
char buffer[1024] = {0};
const char* phead = reinterpret_cast<const char*>(&head);
// TODO(wangxi) Set from env, default 900s=15min
int timeout = 900 * 1000;
int try_times = 0;
int total_time = 0;
int sock = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
while (true) {
int ret_val = -1;
RETRY_SYS_CALL_VAL(
......@@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) {
continue;
}
CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send");
break;
CHECK_SYS_CALL(SocketSend(sock, phead, sizeof(head)), "send");
ret_val = SocketRecv(sock, buffer, sizeof(head));
if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) {
      // recv the same message from the receiver, indicating that the link is correct
break; // accept client
} else {
VLOG(3) << "socket read failed with ret_val=" << ret_val;
CloseSocket(sock);
}
sock = -1;
CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
// unmatched link, retry after 80ms
std::this_thread::sleep_for(std::chrono::milliseconds(80));
}
return sock;
}
......@@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
std::vector<CommUniqueId>* nccl_ids) {
std::vector<CommUniqueId>* nccl_ids, int ring_id) {
CommHead head;
head.ring_id = ring_id;
// connect with server
std::vector<int> connects;
for (auto server : servers) {
VLOG(3) << "connecting endpoint: " << server;
int conn = ConnectAddr(server, COMM_HEAD);
int conn = ConnectAddr(server, head);
connects.push_back(conn);
}
VLOG(3) << "connecting completed...";
......@@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector<std::string> servers,
template <typename CommUniqueId>
void RecvBroadCastCommID(std::string endpoint,
std::vector<CommUniqueId>* nccl_ids) {
std::vector<CommUniqueId>* nccl_ids, int ring_id) {
int server = CreateListenSocket(endpoint);
RecvBroadCastCommID(server, endpoint, nccl_ids);
RecvBroadCastCommID(server, endpoint, nccl_ids, ring_id);
CloseSocket(server);
}
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
std::vector<CommUniqueId>* nccl_ids) {
int client = SocketAccept(server_fd, COMM_HEAD);
std::vector<CommUniqueId>* nccl_ids, int ring_id) {
CommHead head;
head.ring_id = ring_id;
int client = SocketAccept(server_fd, head);
for (size_t i = 0; i < nccl_ids->size(); ++i) {
VLOG(3) << "trainer: " << endpoint
......@@ -362,9 +392,13 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) {
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
std::vector<Type> * nccl_ids);
std::vector<Type> * nccl_ids, \
int ring_id = 0); \
template void RecvBroadCastCommID<Type>( \
std::string endpoint, std::vector<Type> * nccl_ids, int ring_id = 0); \
template void RecvBroadCastCommID<Type>(int server_fd, std::string endpoint, \
std::vector<Type>* nccl_ids, \
int ring_id = 0);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
......
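The handshake above replaces the fixed string head with a binary CommHead{version, ring_id} and adds an echo from the accepting side, so a connection is only kept when both ends agree on the ring. A minimal, self-contained Python sketch of that protocol follows (the int/int packing and port handling are assumptions for illustration, not Paddle code).

import socket
import struct
import threading

HEAD_FMT = "ii"                       # (version, ring_id); layout assumed
HEAD_SIZE = struct.calcsize(HEAD_FMT)

def accept_ring(srv, ring_id):
    # Server side: accept only a peer that sends the matching CommHead, then echo it.
    head = struct.pack(HEAD_FMT, 1, ring_id)
    conn, _ = srv.accept()
    buf = conn.recv(HEAD_SIZE)
    if buf == head:
        conn.sendall(head)            # ack: this link belongs to our ring
    conn.close()

def connect_ring(addr, ring_id):
    # Client side: send the CommHead; succeed only if the server echoes it back.
    head = struct.pack(HEAD_FMT, 1, ring_id)
    cli = socket.create_connection(addr)
    cli.sendall(head)
    ack = cli.recv(HEAD_SIZE)
    cli.close()
    return ack == head

srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.bind(("127.0.0.1", 0))            # any free port
srv.listen(1)
addr = srv.getsockname()

t = threading.Thread(target=accept_ring, args=(srv, 3))
t.start()
print(connect_ring(addr, 3))          # True: both ends use ring_id 3
t.join()
srv.close()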
......@@ -31,16 +31,16 @@ void CloseSocket(int fd);
template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
std::vector<CommUniqueId>* nccl_ids);
std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);
template <typename CommUniqueId>
void RecvBroadCastCommID(std::string endpoint,
std::vector<CommUniqueId>* nccl_ids);
std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);
// recv nccl id from socket
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
std::vector<CommUniqueId>* nccl_ids);
std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);
class SocketServer {
public:
......
......@@ -126,11 +126,11 @@ class CollectiveHelper(object):
_add_sync_by_allreduce(block)
return
if core.is_compiled_with_cuda():
comm_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
name=unique_name.generate('comm_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
if core.is_compiled_with_cuda():
block.append_op(
type='c_gen_nccl_id',
inputs={},
......@@ -139,6 +139,7 @@ class CollectiveHelper(object):
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
......@@ -152,10 +153,6 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
elif core.is_compiled_with_xpu():
comm_id_var = block.create_var(
name=unique_name.generate('bkcl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_bkcl_id',
inputs={},
......@@ -164,6 +161,7 @@ class CollectiveHelper(object):
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
......@@ -177,24 +175,20 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
elif core.is_compiled_with_npu():
hccl_id_var = block.create_var(
name=unique_name.generate('hccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_gen_hccl_id',
inputs={},
outputs={'Out': hccl_id_var},
outputs={'Out': comm_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init_hccl',
inputs={'X': hccl_id_var},
inputs={'X': comm_id_var},
outputs={},
attrs={
'rank': rank,
......
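Since NCCL/BKCL/HCCL now share the single comm_id prefix, the startup program names the id variables comm_id_0, comm_id_1, ... in creation order, which is what the unit tests further down match on. A simplified stand-in for unique_name.generate (the real generator keeps per-prefix counters; this sketch is illustrative only):

from itertools import count

_counters = {}

def generate(prefix):
    # Mimics a per-prefix counter: first call yields <prefix>_0, then <prefix>_1, ...
    idx = _counters.setdefault(prefix, count())
    return "%s_%d" % (prefix, next(idx))

print(generate("comm_id"))   # comm_id_0 -> first ring initialized (e.g. mp)
print(generate("comm_id"))   # comm_id_1 -> second ring (e.g. sharding or pp)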
......@@ -73,7 +73,7 @@ class FP16Utils(object):
return inserted_op_num
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
"""
        1. prune all cast_fp16_to_fp32 ops if the param does not belong to this shard
        2. revise amp infinite grad checking for sharding
......@@ -146,6 +146,7 @@ class FP16Utils(object):
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
......@@ -156,9 +157,14 @@ class FP16Utils(object):
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
for ring_id in ring_ids:
if ring_id == -1: continue
# this allreduce communication should not overlap with calc
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
update_loss_scaling_op_idx,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
......@@ -167,8 +173,10 @@ class FP16Utils(object):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
......@@ -177,11 +185,12 @@ class FP16Utils(object):
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
block._sync_with_cpp()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@staticmethod
def sync_amp_check_nan_inf(block, ring_id):
def sync_amp_check_nan_inf(block, ring_ids):
update_loss_scaling_op_idx = -1
for idx, op in reversed(list(enumerate(block.ops))):
......@@ -189,10 +198,14 @@ class FP16Utils(object):
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
break
# not use amp
if update_loss_scaling_op_idx == -1:
return
# 0. inf_var_int32 = cast(inf_var)
# 1. inf_var_int32 = allreduce_max(inf_var_int32)
# 3. inf_var = cast(inf_var_int32)
inf_var = block.var(inf_var_name)
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
......@@ -212,8 +225,13 @@ class FP16Utils(object):
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
# allreduce(mp)->allreduce(pp)
for ring_id in ring_ids:
if ring_id == -1: continue
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
update_loss_scaling_op_idx,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
......@@ -222,8 +240,10 @@ class FP16Utils(object):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_global},
......@@ -232,4 +252,5 @@ class FP16Utils(object):
"out_dtype": inf_var_global.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
block._sync_with_cpp()
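Both the nan/inf check above and the gradient-clip pass below now chain one allreduce per active ring (mp, then sharding, then pp) instead of a single allreduce on the global ring; unused rings are passed as -1 and skipped. A small sketch of the op sequence this produces (ring-id values are placeholders):

def nan_inf_sync_plan(ring_ids):
    # cast -> one allreduce_max per active ring -> cast back
    ops = [("cast", "found_inf -> int32")]
    for ring_id in ring_ids:
        if ring_id == -1:
            continue                  # ring not used in this configuration
        ops.append(("c_allreduce_max", "ring_id=%d" % ring_id))
    ops.append(("cast", "int32 -> found_inf"))
    return ops

# e.g. mp and pp active, sharding disabled (ids are illustrative)
for op in nan_inf_sync_plan([0, -1, 4]):
    print(op)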
......@@ -25,7 +25,7 @@ class GradientClipHelper(object):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
def prune_gradient_clip(self, block, shard, ring_ids):
"""
        prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
......@@ -82,33 +82,23 @@ class GradientClipHelper(object):
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset = 1
for ring_id in ring_ids:
if ring_id == -1: continue
# this allreduce should not overlap with calc and should be scheduled in calc stream
block._insert_op_without_sync(
idx + 1,
idx + idx_offset,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': self.mp_ring_id,
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
# global norm should only be sum within each model parallelism word size when use global group
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
idx_offset += 1
# the grad sum here should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
......@@ -126,20 +116,25 @@ class GradientClipHelper(object):
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
def sync_global_norm(self, block, ring_ids):
"""
        prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
# FIXME(wangxi): mp should prune duplicated param_grads
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
for ring_id in ring_ids:
if ring_id == -1: continue
idx = idx + 1
block._insert_op_without_sync(
idx + 1,
idx,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
......@@ -149,20 +144,4 @@ class GradientClipHelper(object):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
# global norm should only be sum within each model parallelism word size
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
return
......@@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt
if self.sharding_degree <= 1:
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# amp
FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id)
FP16Utils.sync_amp_check_nan_inf(
main_block, [self.mp_ring_id, self.pp_ring_id])
# clip
gradientclip_helper = GradientClipHelper(self.global_ring_id)
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm(
main_block, self.global_ring_id, self.dp_degree)
main_block, [self.mp_ring_id, self.pp_ring_id])
# step6: loss div dp_degree
global_dp_degree = self.sharding_degree * self.dp_degree
......@@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_rank,
ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
def _init_npu_pipeline_comm(self, startup_block):
......@@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair = send_to_next_pair if even else recv_from_prev_pair
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
self._init_pair_comm(pair, ring_id)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
my_pair.remove(pair)
logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(pair,
ring_id))
......@@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair = recv_from_next_pair if even else send_to_prev_pair
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
self._init_pair_comm(pair, ring_id)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
my_pair.remove(pair)
logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair,
ring_id))
......@@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair[0] * 1000 + pair[1],
max_ring_id + 1) # 3->0 not in pp_ring_map
self._init_pair_comm(pair, ring_id)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
my_pair.remove(pair)
logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format(
......@@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair[0] * 1000 + pair[1],
max_ring_id + 2) # 0->3 not in pp_ring_map
self._init_pair_comm(pair, ring_id)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
my_pair.remove(pair)
logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format(
......@@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase):
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id,
False,
sync=False)
if core.is_compiled_with_npu():
self._init_npu_pipeline_comm(startup_block)
return
......@@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank in pair:
self._init_pair_comm(pair, ring_id)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
def _init_comm(self):
......@@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase):
dtype=core.VarDesc.VarType.INT32,
persistable=False)
# global ring
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.global_endpoints,
self.global_rank,
self.global_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# mp ring
if self.mp_degree > 1:
self._collective_helper._init_communicator(
......@@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.mp_rank,
self.mp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding ring
if self.sharding_degree > 1:
......@@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_rank,
self.sharding_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# pp ring
if self.pp_degree > 1:
......@@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_rank,
self.dp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
startup_block._sync_with_cpp()
......@@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
# FIXME(wangxi): mp should prune duplicated param_grads
# NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
# group. and each Data Parallelism group should have its own sync of FoundInfinite
# amp could use global group for sync
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self.global_ring_id)
FP16Utils.prune_fp16(
block, self._shard, self._reduced_grads_to_param,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
        # clipbyglobalnorm should only use the Model parallelism group (mp-sharding-pp)
        if self.mp_degree * self.pp_degree == 1:
            # separate the sharding-hybrid scenario to keep the accuracy
gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.prune_gradient_clip(
block, self._shard, pure_dp_degree=1)
else:
gradientclip_helper = GradientClipHelper(self.global_ring_id)
gradientclip_helper.prune_gradient_clip(
block, self._shard, pure_dp_degree=self.dp_degree)
block, self._shard,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
# build prog deps
reduced_grads = []
......@@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp
if self.pp_degree > 1:
self.pp_ring_id = 20
self.pp_pair_ring_id = 20
# pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
self.pp_ring_id = 4
self.pp_rank = self.global_rank // (self.sharding_degree *
self.mp_degree) % self.pp_degree
            # (NOTE): Already adjusted for (outer-pure) dp
......@@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_first_stage_idx + pp_stage_offset * i])
assert self.current_endpoint in self.pp_group_endpoints
else:
self.pp_degree = 1
self.pp_ring_id = -1
self.pp_degree = 1
self.pp_pair_ring_id = -1
self.pp_rank = -1
self.pp_group_id = -1
self.pp_group_endpoints = []
......@@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase):
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
# sync within global group
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding gradient merge
def create_persistable_gradients_and_insert_merge_ops(
......
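For orientation, the ring-id layout the optimizer ends up with, as far as it can be inferred from this diff and the tests below; the mp value is an assumption:

# Inferred ring-id plan; illustrative, not normative.
RING_IDS = {
    "mp": 0,             # assumed
    "sharding": 1,
    "dp": 2,
    "global": 3,         # id still defined, but no longer initialized or synced here
    "pp_group": 4,       # new: one ring for the whole pipeline group
    "pp_pair_base": 20,  # point-to-point pipeline pairs start at 20
}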
......@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self.set_strategy(strategy, 'sharding')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
x.name for x in train_prog.list_vars() if x.persistable is True
]
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
......@@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
])
class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "3"
os.environ[
......@@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
self.sharding_ring_id = 1
self.dp_ring_id = 2
self.global_ring_id = 3
self.pp_ring_id = 20
self.pp_pair_ring_id = 20
def test_sharding_with_mp(self):
# NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
......@@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
0] == "comm_id_0":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
......@@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
0] == "comm_id_1":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
......@@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
0] == "comm_id_0":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
......@@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
0] == "comm_id_1":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
......@@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
0] == "comm_id_0":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
......@@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
0] == "comm_id_1":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
......@@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
])
self.assertEqual(main_prog_op_types, [
......@@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
if op.type == "c_comm_init"
]
self.assertIn(self.sharding_ring_id, created_ring_ids)
self.assertIn(self.pp_ring_id, created_ring_ids)
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
# check correctness of pp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
0] == "comm_id_0":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
......@@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
0] == "comm_id_1":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
......@@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
if op.type == 'c_allreduce_sum':
assert 'FusedOutput' in op.input_arg_names[0]
def test_hybrid_with_mp_pp_amp_gclip(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
self.set_strategy(strategy, 'amp')
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 2,
"pp_degree": 2,
"dp_degree": 1,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init'
])
# pp + mp, partial send recv
self.assertIn('partial_recv', main_prog_op_types)
self.assertIn('partial_allgather', main_prog_op_types)
self.assertIn('partial_send', main_prog_op_types)
# amp check_finite_and_unscale, allreduce(mp)->allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2)
# global gradient clip, allreduce(mp)->allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_sum'), 2)
# should has ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.mp_ring_id, created_ring_ids)
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
# check correctness of pp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
mp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(mp_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_1":
pp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
if __name__ == "__main__":
unittest.main()