未验证 提交 1edf4374 编写于 作者: L LiuWei 提交者: GitHub

Change the HCCL ID op to use the common comm ID helper instead of hccl_helper (#34118)

上级 348d043e
......@@ -23,15 +23,35 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h"
#endif
#include "paddle/fluid/platform/dynload/hccl.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_ASCEND_CL
static void GenHCCLID(std::vector<HcclRootInfo>* hccl_ids) {
for (size_t i = 0; i < hccl_ids->size(); ++i) {
PADDLE_ENFORCE_NPU_SUCCESS(
platform::dynload::HcclGetRootInfo(&(*hccl_ids)[i]));
}
}
// Copies each entry of `hccl_ids` into a variable looked up in `scope`.
//
// hccl_ids: the IDs to store; entry i goes to the variable named func(i).
// func:     maps an index to the name of the destination variable.
// scope:    scope searched (via FindVar) for each destination variable.
//
// Raises a NotFound error through PADDLE_ENFORCE_NOT_NULL when a variable
// name returned by `func` cannot be found in `scope`.
static void CopyHCCLIDToVar(const std::vector<HcclRootInfo>& hccl_ids,
                            std::function<std::string(size_t)> func,
                            const framework::Scope& scope) {
  size_t idx = 0;
  while (idx < hccl_ids.size()) {
    const std::string target_name = func(idx);
    auto var = scope.FindVar(target_name);
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable with name %s is not found",
                                        target_name.c_str()));
    // Raw byte copy of the ID into the variable's HcclRootInfo storage.
    HcclRootInfo* dst = var->GetMutable<HcclRootInfo>();
    const HcclRootInfo* src = &hccl_ids[idx];
    memcpy(dst, src, sizeof(HcclRootInfo));
    ++idx;
  }
}
class CGenHCCLIdOp : public framework::OperatorBase {
public:
CGenHCCLIdOp(const std::string& type,
......@@ -49,14 +69,22 @@ class CGenHCCLIdOp : public framework::OperatorBase {
return Output("Out");
};
std::string endpoint = Attr<std::string>("endpoint");
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
std::vector<HcclRootInfo> hccl_ids;
hccl_ids.resize(1);
if (rank == 0) {
GenHCCLID(&hccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
SendBroadCastHCCLID(endpoint_list, 1, func, local_scope);
platform::SendBroadCastCommID(endpoint_list, &hccl_ids);
} else {
std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastHCCLID(endpoint, 1, func, local_scope);
platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids);
}
CopyHCCLIDToVar(hccl_ids, func, scope);
scope.DeleteScope(&local_scope);
}
};
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include <arpa/inet.h>
......@@ -33,6 +33,10 @@ limitations under the License. */
#include "xpu/bkcl.h"
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace paddle {
namespace platform {
......@@ -262,10 +266,17 @@ static int ConnectAddr(const std::string& ep, const char* head) {
return sock;
}
// Upper bound, in bytes, on the size of a communicator unique ID. It sizes
// the stack buffers used when sending/receiving an ID over a socket, and
// RecvCommID static_asserts that sizeof(CommUniqueId) does not exceed it.
// The Ascend build uses a larger bound (presumably so HcclRootInfo fits;
// TODO confirm against the HCCL headers).
// TODO(WANGXI): consider unifying this hard-coded size across backends.
#ifdef PADDLE_WITH_ASCEND_CL
#define MAX_COMMUNIQUEID_LEN 4108
#else
#define MAX_COMMUNIQUEID_LEN 1024
#endif
template <typename CommUniqueId>
static void RecvCommID(int conn, CommUniqueId* nccl_id) {
char buffer[1024] = {0};
static_assert(sizeof(CommUniqueId) <= 1024,
char buffer[MAX_COMMUNIQUEID_LEN] = {0};
static_assert(sizeof(CommUniqueId) <= MAX_COMMUNIQUEID_LEN,
"nccl id bytes must <= buffer size");
CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(CommUniqueId)),
......@@ -275,7 +286,7 @@ static void RecvCommID(int conn, CommUniqueId* nccl_id) {
template <typename CommUniqueId>
static void SendCommID(int conn, CommUniqueId* nccl_id) {
char buffer[1024] = {0};
char buffer[MAX_COMMUNIQUEID_LEN] = {0};
memcpy(buffer, nccl_id, sizeof(CommUniqueId));
CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(CommUniqueId)),
......@@ -361,6 +372,9 @@ INSTANT_TEMPLATE(ncclUniqueId)
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
#ifdef PADDLE_WITH_ASCEND_CL
INSTANT_TEMPLATE(HcclRootInfo)
#endif
} // namespace platform
} // namespace paddle
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
#include <functional>
#include <memory>
#include <mutex>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册