提交 03e99e26 编写于 作者: W WangXi 提交者: sandyhouse

update

上级 7ab01c28
...@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase {
SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
} else { } else {
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastNCCLID(endpoint, 1, func, local_scope); int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
} }
scope.DeleteScope(&local_scope); scope.DeleteScope(&local_scope);
} }
...@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {}
}
}; };
#endif #endif
......
...@@ -31,7 +31,9 @@ limitations under the License. */ ...@@ -31,7 +31,9 @@ limitations under the License. */
#include "paddle/fluid/string/split.h" #include "paddle/fluid/string/split.h"
namespace paddle { namespace paddle {
namespace operators { namespace platform {
std::once_flag SocketServer::init_flag_;
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
...@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, ...@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
CloseSocket(client); CloseSocket(client);
} }
} // namespace operators SocketServer& SocketServer::GetInstance(const std::string& end_point) {
static SocketServer instance;
std::call_once(init_flag_, [&]() {
instance.server_fd_ = CreateListenSocket(end_point);
instance.end_point_ = end_point;
});
PADDLE_ENFORCE_NE(instance.server_fd_, -1,
platform::errors::Unavailable(
"listen socket failed with end_point=%s", end_point));
PADDLE_ENFORCE_EQ(instance.end_point_, end_point,
platform::errors::InvalidArgument(
"old end_point=%s must equal with new end_point=%s",
instance.end_point_, end_point));
return instance;
}
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
std::vector<Type> * nccl_ids);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
#endif
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
} // namespace platform
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <functional> #include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -25,7 +27,7 @@ class Scope; ...@@ -25,7 +27,7 @@ class Scope;
} // namespace paddle } // namespace paddle
namespace paddle { namespace paddle {
namespace operators { namespace platform {
int CreateListenSocket(const std::string& ep); int CreateListenSocket(const std::string& ep);
...@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, ...@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
const framework::Scope& scope); const framework::Scope& scope);
// recv nccl id from socket // recv nccl id from socket
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, template <typename CommUniqueId>
std::function<std::string(size_t)> func, void RecvBroadCastCommID(int server_fd, std::string endpoint,
const framework::Scope& scope); std::vector<CommUniqueId>* nccl_ids);
} // namespace operators
class SocketServer {
public:
SocketServer() = default;
~SocketServer() { CloseSocket(server_fd_); }
int socket() const { return server_fd_; }
static SocketServer& GetInstance(const std::string& end_point);
private:
int server_fd_{-1};
std::string end_point_;
static std::once_flag init_flag_;
};
} // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册