未验证 提交 9066b74f 编写于 作者: W WangXi 提交者: GitHub

c_gen_nccl_id add SocketServer to persit server (#31589)

上级 a32e8bf1
......@@ -75,7 +75,8 @@ class CGenNCCLIdOp : public framework::OperatorBase {
platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
} else {
std::string endpoint = Attr<std::string>("endpoint");
platform::RecvBroadCastCommID(endpoint, &nccl_ids);
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
}
CopyNCCLIDToVar(nccl_ids, func, scope);
......
......@@ -36,6 +36,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
std::once_flag SocketServer::init_flag_;
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
// Check system calls, such as socket, bind.
......@@ -330,6 +332,22 @@ void RecvBroadCastCommID(int server_fd, std::string endpoint,
CloseSocket(client);
}
SocketServer& SocketServer::GetInstance(const std::string& end_point) {
static SocketServer instance;
std::call_once(init_flag_, [&]() {
instance.server_fd_ = CreateListenSocket(end_point);
instance.end_point_ = end_point;
});
PADDLE_ENFORCE_NE(instance.server_fd_, -1,
platform::errors::Unavailable(
"listen socket failed with end_point=%s", end_point));
PADDLE_ENFORCE_EQ(instance.end_point_, end_point,
platform::errors::InvalidArgument(
"old end_point=%s must equal with new end_point=%s",
instance.end_point_, end_point));
return instance;
}
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
......
......@@ -17,6 +17,8 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
......@@ -39,6 +41,24 @@ void RecvBroadCastCommID(std::string endpoint,
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
std::vector<CommUniqueId>* nccl_ids);
class SocketServer {
public:
SocketServer() = default;
~SocketServer() { CloseSocket(server_fd_); }
int socket() const { return server_fd_; }
static SocketServer& GetInstance(const std::string& end_point);
private:
int server_fd_{-1};
std::string end_point_;
static std::once_flag init_flag_;
};
} // namespace platform
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册