提交 03e99e26 编写于 作者: W WangXi 提交者: sandyhouse

update

上级 7ab01c28
......@@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase {
SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
} else {
std::string endpoint = Attr<std::string>("endpoint");
RecvBroadCastNCCLID(endpoint, 1, func, local_scope);
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
}
scope.DeleteScope(&local_scope);
}
......@@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
}
const platform::Place& dev_place) const override {}
};
#endif
......
......@@ -31,7 +31,9 @@ limitations under the License. */
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace operators {
namespace platform {
// Out-of-class definition of the static once-flag declared in SocketServer.
// SocketServer::GetInstance hands it to std::call_once so that the listen
// socket is created a single time.
std::once_flag SocketServer::init_flag_;
// Fixed marker string; presumably a header tag for the comm-id messages --
// its use sites are outside this excerpt, TODO confirm.
constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
......@@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
CloseSocket(client);
}
} // namespace operators
/// Returns the single SocketServer instance, creating its listen socket for
/// `end_point` on the first call.
///
/// Every call (including the first) then enforces two conditions:
///   * the stored descriptor is not -1, otherwise an Unavailable error is
///     raised through PADDLE_ENFORCE_NE;
///   * `end_point` equals the end point stored by the first call, otherwise
///     an InvalidArgument error is raised through PADDLE_ENFORCE_EQ.
///
/// @param end_point end point passed to CreateListenSocket on first use and
///                  compared against the stored one afterwards.
/// @return reference to the function-local static instance.
SocketServer& SocketServer::GetInstance(const std::string& end_point) {
  static SocketServer instance;

  // One-time setup. `instance` has static storage duration, so only
  // `end_point` needs to be captured.
  const auto create_listen_socket = [&end_point]() {
    instance.server_fd_ = CreateListenSocket(end_point);
    instance.end_point_ = end_point;
  };
  // If the initializer returns normally it is never run again, so a stored
  // -1 descriptor keeps failing the check below on every later call.
  std::call_once(init_flag_, create_listen_socket);

  PADDLE_ENFORCE_NE(instance.server_fd_, -1,
                    platform::errors::Unavailable(
                        "listen socket failed with end_point=%s", end_point));
  PADDLE_ENFORCE_EQ(instance.end_point_, end_point,
                    platform::errors::InvalidArgument(
                        "old end_point=%s must equal with new end_point=%s",
                        instance.end_point_, end_point));

  return instance;
}
/// Explicit template instantiation.
///
/// INSTANT_TEMPLATE(Type) emits explicit instantiations of
/// SendBroadCastCommID<Type> (server list + id vector) and of the
/// RecvBroadCastCommID<Type> overload that takes (endpoint, id vector).
///
/// NOTE(review): the header part of this change also declares a
/// RecvBroadCastCommID overload taking a leading `int server_fd`; that
/// overload is not instantiated by this macro -- confirm whether it needs an
/// explicit instantiation as well.
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
std::vector<Type> * nccl_ids);
// Instantiate for ncclUniqueId when built with NCCL or RCCL.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE(ncclUniqueId)
#endif
// Instantiate for BKCLUniqueId when built with XPU BKCL.
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
} // namespace platform
} // namespace paddle
......@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <vector>
......@@ -25,7 +27,7 @@ class Scope;
} // namespace paddle
namespace paddle {
namespace operators {
namespace platform {
// Creates a listening socket for end point `ep` and returns its descriptor.
// SocketServer::GetInstance stores the result and treats -1 as failure.
// NOTE(review): the expected format of `ep` is not visible here -- confirm
// against the implementation.
int CreateListenSocket(const std::string& ep);
......@@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
const framework::Scope& scope);
// Receive nccl id(s) from a socket.
// This overload takes the socket descriptor `server_fd` from the caller.
// NOTE(review): the roles of `nccl_comm_num`, `func` (maps a size_t to a
// string) and `scope` are not visible in this excerpt -- presumably `func`
// yields the variable name for the i-th id stored in `scope`; confirm against
// the implementation.
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
std::function<std::string(size_t)> func,
const framework::Scope& scope);
} // namespace operators
// Receives communicator ids of type CommUniqueId into `*nccl_ids`, using the
// socket descriptor `server_fd` supplied by the caller together with
// `endpoint`.
// NOTE(review): only the declaration is visible here; whether `nccl_ids` must
// be pre-sized by the caller is not shown -- confirm against the definition.
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
std::vector<CommUniqueId>* nccl_ids);
/// Holder of one listening socket, normally reached through GetInstance().
///
/// The object owns `server_fd_`: the destructor closes it. Because of that
/// ownership the type is neither copyable nor movable -- a copy would close
/// the same descriptor a second time when it is destroyed.
class SocketServer {
 public:
  SocketServer() = default;

  // Owning a descriptor that the destructor closes: forbid copies and moves
  // so the descriptor can never be closed twice.
  SocketServer(const SocketServer&) = delete;
  SocketServer& operator=(const SocketServer&) = delete;
  SocketServer(SocketServer&&) = delete;
  SocketServer& operator=(SocketServer&&) = delete;

  /// Closes the listen socket, if one was ever stored (-1 means "none").
  ~SocketServer() {
    if (server_fd_ != -1) {
      CloseSocket(server_fd_);
    }
  }

  /// Returns the stored socket descriptor, or -1 when none has been set.
  int socket() const { return server_fd_; }

  /// Returns the shared instance for `end_point`; see the definition for the
  /// checks it enforces.
  static SocketServer& GetInstance(const std::string& end_point);

 private:
  int server_fd_{-1};                // listen socket descriptor, -1 = unset
  std::string end_point_;            // end point the socket was created for
  static std::once_flag init_flag_;  // guards the one-time setup in GetInstance
};
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册