提交 73d5821a 编写于 作者: S Shiyuan Shang-Guan

port multi-socket to master


Former-commit-id: 80a189d6b3ea9fe976e1609bd3aa5e078e55c5c9
上级 fbfd1d0c
......@@ -22,7 +22,8 @@ class NormalBackwardCompActor final : public CompActor {
void AsyncReturnAllCustomizedReadableRegst() override;
std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()
override {
return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{"activation", "data_tmp", "out", "out_diff", "in"});
return std::make_pair(RegstNameType::kNaive,
HashSet<std::string>{"activation", "data_tmp", "out", "out_diff", "in"});
}
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;
void AsyncSendCustomizedConsumedRegstMsgToProducer() override;
......
......@@ -7,9 +7,9 @@ CommNet::~CommNet() {
ready_cb_poller_.join();
}
void* CommNet::NewActorReadId() { return new ActorReadContext; }
void* CommNet::NewActorReadId() const { return new ActorReadContext; }
void CommNet::DeleteActorReadId(void* actor_read_id) {
void CommNet::DeleteActorReadId(void* actor_read_id) const {
auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);
CHECK(actor_read_ctx->waiting_list.empty());
delete actor_read_ctx;
......
......@@ -29,20 +29,20 @@ class CommNet {
virtual void RegisterMemoryDone() = 0;
// Stream
void* NewActorReadId();
void DeleteActorReadId(void* actor_read_id);
void* NewActorReadId() const;
void DeleteActorReadId(void* actor_read_id) const;
void Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token);
void AddReadCallBack(void* actor_read_id, std::function<void()> callback);
void ReadDone(void* read_id);
//
virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0;
virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const = 0;
protected:
CommNet(const Plan& plan);
virtual void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) = 0;
const HashSet<int64_t>& peer_machine_id() { return peer_machine_id_; }
const HashSet<int64_t>& peer_machine_id() const { return peer_machine_id_; }
Channel<std::function<void()>> ready_cbs_;
......@@ -84,8 +84,8 @@ class CommNetIf : public CommNet {
}
protected:
virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) = 0;
const HashSet<MemDescType*>& mem_descs() { return mem_descs_; }
virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) const = 0;
const HashSet<MemDescType*>& mem_descs() const { return mem_descs_; }
private:
std::mutex mem_descs_mtx_;
......
......@@ -16,9 +16,9 @@ sockaddr_in GetSockAddr(const std::string& addr, uint16_t port) {
return sa;
}
int SockListen(int listen_sockfd, uint16_t listen_port, int32_t total_machine_num) {
int32_t SockListen(int32_t listen_sockfd, uint16_t listen_port, int32_t total_machine_num) {
sockaddr_in sa = GetSockAddr("0.0.0.0", listen_port);
int bind_result = bind(listen_sockfd, reinterpret_cast<sockaddr*>(&sa), sizeof(sa));
int32_t bind_result = bind(listen_sockfd, reinterpret_cast<sockaddr*>(&sa), sizeof(sa));
if (bind_result == 0) {
PCHECK(listen(listen_sockfd, total_machine_num) == 0);
LOG(INFO) << "CommNet:Epoll listening on "
......@@ -56,7 +56,7 @@ uint16_t PullPort(int64_t machine_id) {
EpollCommNet::~EpollCommNet() {
for (size_t i = 0; i < pollers_.size(); ++i) {
LOG(INFO) << "CommNet Thread " << i << " finish";
pollers_[i]->Stop();
pollers_.at(i)->Stop();
}
OF_BARRIER();
for (IOEventPoller* poller : pollers_) { delete poller; }
......@@ -64,30 +64,54 @@ EpollCommNet::~EpollCommNet() {
}
// Creates (or resets to zero) the part-done counter of every registered memory descriptor,
// so that the later .at(dst_token) lookups on dst_token2part_done_cnt_ find an existing entry.
void EpollCommNet::RegisterMemoryDone() {
for (void* dst_token : mem_descs()) { dst_token2part_done_cnt_[dst_token] = 0; }
}
void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) {
void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) const {
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
GetSocketHelper(dst_machine_id, epoll_conf_.link_num() - 1)->AsyncWrite(msg);
}
void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg) {
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
// Splits the memory described by `src_token` into at most link_num() contiguous parts and posts
// one kRequestRead message per part to `dst_machine_id`, part i going out on link i.
//
// dst_machine_id: machine that will receive the data.
// src_token:      SocketMemDesc of the local memory to be sent.
// dst_token:      token of the destination memory, forwarded untouched in every message.
// read_id:        opaque id forwarded untouched in every message.
//
// Every part except possibly the last has the same length, which is a multiple of
// msg_segment_kbyte() KiB. All size arithmetic is done in int64_t: the descriptor's byte_size
// is filled from a size_t and the message's offset/byte_size fields are int64_t, so 32-bit
// arithmetic would truncate buffers of 2 GiB or more.
void EpollCommNet::RequestRead(int64_t dst_machine_id, void* src_token, void* dst_token,
                               void* read_id) const {
  int64_t remaining_byte_size =
      static_cast<int64_t>(static_cast<const SocketMemDesc*>(src_token)->byte_size);
  CHECK_GT(remaining_byte_size, 0);
  const int64_t link_num = epoll_conf_.link_num();
  CHECK_GT(link_num, 0);
  const int64_t segment_byte_size = static_cast<int64_t>(epoll_conf_.msg_segment_kbyte()) * 1024;
  CHECK_GT(segment_byte_size, 0);
  // Even share per link, rounded up to a whole number of segments.
  const int64_t part_length = static_cast<int64_t>(
      RoundUp((remaining_byte_size + link_num - 1) / link_num, segment_byte_size));
  const int64_t part_num = (remaining_byte_size + part_length - 1) / part_length;
  CHECK_LE(part_num, link_num);
  for (int32_t link_i = 0; link_i < part_num; ++link_i) {
    // Only the last part may be shorter than part_length.
    const int64_t byte_size =
        (remaining_byte_size > part_length) ? part_length : remaining_byte_size;
    CHECK_GT(byte_size, 0);
    remaining_byte_size -= byte_size;
    SocketMsg msg;
    msg.msg_type = SocketMsgType::kRequestRead;
    msg.request_read_msg.src_token = src_token;
    msg.request_read_msg.dst_token = dst_token;
    msg.request_read_msg.offset = static_cast<int64_t>(link_i) * part_length;
    msg.request_read_msg.byte_size = byte_size;
    msg.request_read_msg.read_id = read_id;
    msg.request_read_msg.part_num = static_cast<int32_t>(part_num);
    GetSocketHelper(dst_machine_id, link_i)->AsyncWrite(msg);
  }
  // Every byte must have been assigned to exactly one part.
  CHECK_EQ(remaining_byte_size, 0);
}
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) const {
SocketMemDesc* mem_desc = new SocketMemDesc;
mem_desc->mem_ptr = ptr;
mem_desc->byte_size = byte_size;
return mem_desc;
}
EpollCommNet::EpollCommNet(const Plan& plan) : CommNetIf(plan) {
EpollCommNet::EpollCommNet(const Plan& plan)
: CommNetIf(plan), epoll_conf_(Global<JobDesc>::Get()->epoll_conf()) {
pollers_.resize(Global<JobDesc>::Get()->CommNetWorkerNum(), nullptr);
for (size_t i = 0; i < pollers_.size(); ++i) { pollers_[i] = new IOEventPoller; }
for (size_t i = 0; i < pollers_.size(); ++i) { pollers_.at(i) = new IOEventPoller; }
InitSockets();
for (IOEventPoller* poller : pollers_) { poller->Start(); }
}
......@@ -96,17 +120,17 @@ void EpollCommNet::InitSockets() {
int64_t this_machine_id = Global<MachineCtx>::Get()->this_machine_id();
auto this_machine = Global<JobDesc>::Get()->resource().machine(this_machine_id);
int64_t total_machine_num = Global<JobDesc>::Get()->TotalMachineNum();
machine_id2sockfd_.assign(total_machine_num, -1);
machine_link_id2sockfds_.assign(total_machine_num * epoll_conf_.link_num(), -1);
sockfd2helper_.clear();
size_t poller_idx = 0;
auto NewSocketHelper = [&](int sockfd) {
IOEventPoller* poller = pollers_[poller_idx];
auto NewSocketHelper = [&](int32_t sockfd) {
IOEventPoller* poller = pollers_.at(poller_idx);
poller_idx = (poller_idx + 1) % pollers_.size();
return new SocketHelper(sockfd, poller);
};
// listen
int listen_sockfd = socket(AF_INET, SOCK_STREAM, 0);
int32_t listen_sockfd = socket(AF_INET, SOCK_STREAM, 0);
int32_t this_listen_port = Global<JobDesc>::Get()->resource().data_port();
if (this_listen_port != -1) {
CHECK_EQ(SockListen(listen_sockfd, this_listen_port, total_machine_num), 0);
......@@ -125,42 +149,51 @@ void EpollCommNet::InitSockets() {
int32_t src_machine_count = 0;
// connect
for (int64_t peer_id : peer_machine_id()) {
if (peer_id < this_machine_id) {
for (int64_t peer_mchn_id : peer_machine_id()) {
if (peer_mchn_id < this_machine_id) {
++src_machine_count;
continue;
}
uint16_t peer_port = PullPort(peer_id);
auto peer_machine = Global<JobDesc>::Get()->resource().machine(peer_id);
uint16_t peer_port = PullPort(peer_mchn_id);
auto peer_machine = Global<JobDesc>::Get()->resource().machine(peer_mchn_id);
sockaddr_in peer_sockaddr = GetSockAddr(peer_machine.addr(), peer_port);
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
== 0);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_id2sockfd_[peer_id] = sockfd;
for (int32_t link_i = 0; link_i < epoll_conf_.link_num(); ++link_i) {
int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
== 0);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_link_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i) = sockfd;
}
}
// accept
FOR_RANGE(int32_t, idx, 0, src_machine_count) {
sockaddr_in peer_sockaddr;
socklen_t len = sizeof(peer_sockaddr);
int sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
PCHECK(sockfd != -1);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
int64_t peer_machine_id = GetMachineId(peer_sockaddr);
machine_id2sockfd_[peer_machine_id] = sockfd;
for (int32_t link_i = 0; link_i < epoll_conf_.link_num(); ++link_i) {
int32_t sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
PCHECK(sockfd != -1);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
int64_t peer_mchn_id = GetMachineId(peer_sockaddr);
machine_link_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i) = sockfd;
}
}
PCHECK(close(listen_sockfd) == 0);
ClearPort(this_machine_id);
// useful log
FOR_RANGE(int64_t, machine_id, 0, total_machine_num) {
LOG(INFO) << "machine " << machine_id << " sockfd " << machine_id2sockfd_[machine_id];
for (int64_t peer_mchn_id : peer_machine_id()) {
FOR_RANGE(int32_t, link_i, 0, epoll_conf_.link_num()) {
int32_t sockfd = machine_link_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i);
CHECK_GT(sockfd, 0);
LOG(INFO) << "machine: " << peer_mchn_id << ", link index: " << link_i
<< ", sockfd: " << sockfd;
}
}
}
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id) {
int sockfd = machine_id2sockfd_.at(machine_id);
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) const {
int32_t sockfd = machine_link_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_index);
return sockfd2helper_.at(sockfd);
}
......@@ -171,7 +204,15 @@ void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token
msg.request_write_msg.dst_machine_id = Global<MachineCtx>::Get()->this_machine_id();
msg.request_write_msg.dst_token = dst_token;
msg.request_write_msg.read_id = read_id;
GetSocketHelper(src_machine_id)->AsyncWrite(msg);
dst_token2part_done_cnt_.at(dst_token) = 0;
GetSocketHelper(src_machine_id, epoll_conf_.link_num() - 1)->AsyncWrite(msg);
}
// Records that one part of a read into `dst_token` has been fully received and calls
// ReadDone(read_id) once `part_num` parts have been counted.
//
// read_id:   opaque id handed to ReadDone() when the last part arrives.
// dst_token: key of the per-destination counter in dst_token2part_done_cnt_ (must already exist).
// part_num:  total number of parts the read was split into.
void EpollCommNet::PartReadDone(void* read_id, void* dst_token, int32_t part_num) {
  // Parts travel over different links, whose sockets may be served by different IOEventPoller
  // threads. The counter is the only join point between them, so the increment must release
  // this thread's writes to the destination memory and acquire those of the threads that
  // finished earlier; a relaxed increment would not order them before ReadDone().
  const int32_t parts_done_before =
      dst_token2part_done_cnt_.at(dst_token).fetch_add(1, std::memory_order_acq_rel);
  if (parts_done_before == part_num - 1) { ReadDone(read_id); }
}
} // namespace oneflow
......
......@@ -18,20 +18,24 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
void RegisterMemoryDone() override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const override;
void RequestRead(int64_t dst_machine_id, void* src_token, void* dst_token, void* read_id) const;
void PartReadDone(void* read_id, void* dst_token, int32_t part_num);
private:
SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) override;
SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) const override;
EpollCommNet(const Plan& plan);
void InitSockets();
SocketHelper* GetSocketHelper(int64_t machine_id);
SocketHelper* GetSocketHelper(int64_t machine_id, int32_t link_index) const;
void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;
const EpollConf& epoll_conf_;
std::vector<IOEventPoller*> pollers_;
std::vector<int> machine_id2sockfd_;
// machine_link_id = machine_id * epoll_conf_.link_num() + link_id
std::vector<int64_t> machine_link_id2sockfds_;
HashMap<int, SocketHelper*> sockfd2helper_;
HashMap<void*, std::atomic<int32_t>> dst_token2part_done_cnt_;
};
template<>
......
......@@ -6,7 +6,7 @@
namespace oneflow {
const int IOEventPoller::max_event_num_ = 32;
const int32_t IOEventPoller::max_event_num_ = 32;
IOEventPoller::IOEventPoller() {
epfd_ = epoll_create1(0);
......@@ -26,12 +26,12 @@ IOEventPoller::~IOEventPoller() {
PCHECK(close(epfd_) == 0);
}
void IOEventPoller::AddFd(int fd, std::function<void()> read_handler,
void IOEventPoller::AddFd(int32_t fd, std::function<void()> read_handler,
std::function<void()> write_handler) {
AddFd(fd, &read_handler, &write_handler);
}
void IOEventPoller::AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler) {
void IOEventPoller::AddFdWithOnlyReadHandler(int32_t fd, std::function<void()> read_handler) {
AddFd(fd, &read_handler, nullptr);
}
......@@ -43,10 +43,10 @@ void IOEventPoller::Stop() {
thread_.join();
}
void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,
void IOEventPoller::AddFd(int32_t fd, std::function<void()>* read_handler,
std::function<void()>* write_handler) {
// Set Fd NONBLOCK
int opt = fcntl(fd, F_GETFL);
int32_t opt = fcntl(fd, F_GETFL);
PCHECK(opt != -1);
PCHECK(fcntl(fd, F_SETFL, opt | O_NONBLOCK) == 0);
// Set CLOEXEC
......@@ -70,13 +70,13 @@ void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,
void IOEventPoller::EpollLoop() {
while (true) {
int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1);
int32_t event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1);
if (event_num == -1) {
PCHECK(errno == EINTR);
continue;
}
const epoll_event* cur_event = ep_events_;
for (int event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) {
for (int32_t event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) {
auto io_handler = static_cast<IOHandler*>(cur_event->data.ptr);
PCHECK(!(cur_event->events & EPOLLERR)) << "fd: " << io_handler->fd;
if (io_handler->fd == break_epoll_loop_fd_) { return; }
......
......@@ -13,8 +13,8 @@ class IOEventPoller final {
IOEventPoller();
~IOEventPoller();
void AddFd(int fd, std::function<void()> read_handler, std::function<void()> write_handler);
void AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler);
void AddFd(int32_t fd, std::function<void()> read_handler, std::function<void()> write_handler);
void AddFdWithOnlyReadHandler(int32_t fd, std::function<void()> read_handler);
void Start();
void Stop();
......@@ -28,18 +28,18 @@ class IOEventPoller final {
}
std::function<void()> read_handler;
std::function<void()> write_handler;
int fd;
int32_t fd;
};
void AddFd(int fd, std::function<void()>* read_handler, std::function<void()>* write_handler);
void AddFd(int32_t fd, std::function<void()>* read_handler, std::function<void()>* write_handler);
void EpollLoop();
static const int max_event_num_;
static const int32_t max_event_num_;
int epfd_;
int32_t epfd_;
epoll_event* ep_events_;
std::forward_list<IOHandler*> io_handlers_;
int break_epoll_loop_fd_;
int32_t break_epoll_loop_fd_;
std::thread thread_;
};
......
......@@ -4,7 +4,7 @@
namespace oneflow {
SocketHelper::SocketHelper(int sockfd, IOEventPoller* poller) {
SocketHelper::SocketHelper(int32_t sockfd, IOEventPoller* poller) {
read_helper_ = new SocketReadHelper(sockfd);
write_helper_ = new SocketWriteHelper(sockfd, poller);
poller->AddFd(sockfd, [this]() { read_helper_->NotifyMeSocketReadable(); },
......
......@@ -15,7 +15,7 @@ class SocketHelper final {
SocketHelper() = delete;
~SocketHelper();
SocketHelper(int sockfd, IOEventPoller* poller);
SocketHelper(int32_t sockfd, IOEventPoller* poller);
void AsyncWrite(const SocketMsg& msg);
......
......@@ -40,7 +40,10 @@ struct RequestWriteMsg {
// Payload of a kRequestRead socket message; describes one part of a memory transfer.
struct RequestReadMsg {
// Token of the source memory; interpreted as a SocketMemDesc by the write helper.
void* src_token;
// Token of the destination memory; interpreted as a SocketMemDesc by the read helper.
void* dst_token;
// Byte offset of this part, added to mem_ptr of the source and destination descriptors.
int64_t offset;
// Number of body bytes written/read for this part.
int64_t byte_size;
// Opaque id passed on to EpollCommNet::PartReadDone when the part's body has been read.
void* read_id;
// Total number of parts the whole read was split into.
int32_t part_num;
};
struct SocketMsg {
......
......@@ -10,7 +10,7 @@ SocketReadHelper::~SocketReadHelper() {
// do nothing
}
SocketReadHelper::SocketReadHelper(int sockfd) {
SocketReadHelper::SocketReadHelper(int32_t sockfd) {
sockfd_ = sockfd;
SwitchToMsgHeadReadHandle();
}
......@@ -63,26 +63,24 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void SocketReadHelper::SetStatusWhenMsgBodyDone() {
if (cur_msg_.msg_type == SocketMsgType::kRequestRead) {
Global<EpollCommNet>::Get()->ReadDone(cur_msg_.request_read_msg.read_id);
Global<EpollCommNet>::Get()->PartReadDone(cur_msg_.request_read_msg.read_id,
cur_msg_.request_read_msg.dst_token,
cur_msg_.request_read_msg.part_num);
}
SwitchToMsgHeadReadHandle();
}
void SocketReadHelper::SetStatusWhenRequestWriteMsgHeadDone() {
SocketMsg msg_to_send;
msg_to_send.msg_type = SocketMsgType::kRequestRead;
msg_to_send.request_read_msg.src_token = cur_msg_.request_write_msg.src_token;
msg_to_send.request_read_msg.dst_token = cur_msg_.request_write_msg.dst_token;
msg_to_send.request_read_msg.read_id = cur_msg_.request_write_msg.read_id;
Global<EpollCommNet>::Get()->SendSocketMsg(cur_msg_.request_write_msg.dst_machine_id,
msg_to_send);
Global<EpollCommNet>::Get()->RequestRead(
cur_msg_.request_write_msg.dst_machine_id, cur_msg_.request_write_msg.src_token,
cur_msg_.request_write_msg.dst_token, cur_msg_.request_write_msg.read_id);
SwitchToMsgHeadReadHandle();
}
void SocketReadHelper::SetStatusWhenRequestReadMsgHeadDone() {
auto mem_desc = static_cast<const SocketMemDesc*>(cur_msg_.request_read_msg.dst_token);
read_ptr_ = reinterpret_cast<char*>(mem_desc->mem_ptr);
read_size_ = mem_desc->byte_size;
read_ptr_ = reinterpret_cast<char*>(mem_desc->mem_ptr) + cur_msg_.request_read_msg.offset;
read_size_ = cur_msg_.request_read_msg.byte_size;
cur_read_handle_ = &SocketReadHelper::MsgBodyReadHandle;
}
......
......@@ -13,7 +13,7 @@ class SocketReadHelper final {
SocketReadHelper() = delete;
~SocketReadHelper();
SocketReadHelper(int sockfd);
SocketReadHelper(int32_t sockfd);
void NotifyMeSocketReadable();
......@@ -32,7 +32,7 @@ class SocketReadHelper final {
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);
#undef MAKE_ENTRY
int sockfd_;
int32_t sockfd_;
SocketMsg cur_msg_;
bool (SocketReadHelper::*cur_read_handle_)();
......
......@@ -17,7 +17,7 @@ SocketWriteHelper::~SocketWriteHelper() {
}
}
SocketWriteHelper::SocketWriteHelper(int sockfd, IOEventPoller* poller) {
SocketWriteHelper::SocketWriteHelper(int32_t sockfd, IOEventPoller* poller) {
sockfd_ = sockfd;
queue_not_empty_fd_ = eventfd(0, 0);
PCHECK(queue_not_empty_fd_ != -1);
......@@ -116,8 +116,9 @@ void SocketWriteHelper::SetStatusWhenRequestWriteMsgHeadDone() {
void SocketWriteHelper::SetStatusWhenRequestReadMsgHeadDone() {
const void* src_token = cur_msg_.request_read_msg.src_token;
auto src_mem_desc = static_cast<const SocketMemDesc*>(src_token);
write_ptr_ = reinterpret_cast<const char*>(src_mem_desc->mem_ptr);
write_size_ = src_mem_desc->byte_size;
write_ptr_ =
reinterpret_cast<const char*>(src_mem_desc->mem_ptr) + cur_msg_.request_read_msg.offset;
write_size_ = cur_msg_.request_read_msg.byte_size;
cur_write_handle_ = &SocketWriteHelper::MsgBodyWriteHandle;
}
......
......@@ -14,7 +14,7 @@ class SocketWriteHelper final {
SocketWriteHelper() = delete;
~SocketWriteHelper();
SocketWriteHelper(int sockfd, IOEventPoller* poller);
SocketWriteHelper(int32_t sockfd, IOEventPoller* poller);
void AsyncWrite(const SocketMsg& msg);
......@@ -37,8 +37,8 @@ class SocketWriteHelper final {
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);
#undef MAKE_ENTRY
int sockfd_;
int queue_not_empty_fd_;
int32_t sockfd_;
int32_t queue_not_empty_fd_;
std::queue<SocketMsg>* cur_msg_queue_;
......
......@@ -49,7 +49,7 @@ void IBVerbsCommNet::RegisterMemoryDone() {
Global<CtrlClient>::Get()->ClearKV(GenTokensMsgKey(this_machine_id));
}
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const {
qp_vec_.at(dst_machine_id)->PostSendRequest(msg);
}
......
......@@ -23,10 +23,10 @@ class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
void RegisterMemoryDone() override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const override;
private:
IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) override {
IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) const override {
return new IBVerbsMemDesc(pd_, ptr, byte_size);
}
......
......@@ -46,11 +46,29 @@ message FileSystemConf {
}
}
// Settings of the epoll (socket) communication backend.
message EpollConf {
// Number of socket links opened to each peer machine.
optional int32 link_num = 1 [default = 5];
// Granularity, in KiB, to which the per-link part length of a read is rounded up.
optional int32 msg_segment_kbyte = 2 [default = 256];
}
// Settings of the IBVerbs communication backend; it has no fields yet and its presence in
// CommNetworkConf is what selects that backend.
message IBVerbsConf {
}
// Communication-network configuration; the oneof holds the settings of the selected backend,
// so at most one of the two members is set.
message CommNetworkConf {
oneof comm_net_type {
EpollConf epoll_conf = 1;
IBVerbsConf ibverbs_conf = 2;
}
}
message OtherConf {
required int64 piece_size = 2;
required int32 data_part_num = 3; // piece_size % data_part_num = 0
required FileSystemConf data_fs_conf = 1;
required FileSystemConf snapshot_fs_conf = 2;
optional CommNetworkConf comm_net_conf = 3;
required int64 piece_size = 4;
required int32 data_part_num = 5; // piece_size % data_part_num = 0
optional bool use_rdma = 100 [default = false];
optional string model_load_snapshot_path = 101 [default = ""];
optional int32 max_data_id_length = 102 [default = 0];
optional bool enable_cudnn = 103 [default = true];
......@@ -72,9 +90,6 @@ message OtherConf {
optional bool use_nccl_inter_node_communication = 143 [default = false];
optional int64 cudnn_buf_limit_mbyte = 144 [default = 1024]; // 1GByte
required FileSystemConf data_fs_conf = 121;
required FileSystemConf snapshot_fs_conf = 122;
oneof JobType {
TrainConf train_conf = 200;
PredictConf predict_conf = 201;
......
......@@ -118,7 +118,9 @@ void JobDesc::Init() {
SplitDecodeOps();
AddRecordLoadOps();
#ifndef WITH_RDMA
CHECK_EQ(job_conf_.other().use_rdma(), false) << "Please compile ONEFLOW with RDMA";
if (this->TotalMachineNum() > 1) {
CHECK_EQ(job_conf_.other().use_rdma(), false) << "Please compile ONEFLOW with RDMA";
}
#endif
#ifndef WITH_CUDA
CHECK_EQ(job_conf_.other().enable_nccl(), false) << "Please compile ONEFLOW with NCCL";
......
......@@ -21,10 +21,22 @@ class JobDesc final {
const Resource& resource() const { return job_conf_.resource(); }
const Placement& placement() const { return job_conf_.placement(); }
const OtherConf& other_conf() const { return job_conf_.other(); }
// Returns the comm_net_conf section of OtherConf; CHECK-fails if the job conf does not set it.
const CommNetworkConf& comm_net_conf() const {
CHECK(this->other_conf().has_comm_net_conf());
return job_conf_.other().comm_net_conf();
}
// True when comm_net_conf carries an ibverbs_conf; goes through comm_net_conf(), so it
// CHECK-fails when comm_net_conf is not set.
bool use_rdma() const { return this->comm_net_conf().has_ibverbs_conf(); }
// Returns the epoll backend settings; CHECK-fails when use_rdma() is true.
// Const-qualified like the sibling accessors so it can also be called on a const JobDesc.
const EpollConf& epoll_conf() const {
  CHECK(!this->use_rdma());
  return this->comm_net_conf().epoll_conf();
}
// Returns the ibverbs backend settings; CHECK-fails unless use_rdma() is true.
const IBVerbsConf& ibverbs_conf() const {
CHECK(this->use_rdma());
return this->comm_net_conf().ibverbs_conf();
}
const std::string& MdLoadSnapshotPath() { return job_conf_.other().model_load_snapshot_path(); }
DataType DefaultDataType() const { return job_conf_.other().default_data_type(); }
size_t SizeOfOneDataId() const { return job_conf_.other().max_data_id_length() * sizeof(char); }
bool use_rdma() const { return job_conf_.other().use_rdma(); }
bool EnableCudnn() const { return job_conf_.other().enable_cudnn(); }
int64_t TotalMachineNum() const { return job_conf_.resource().machine().size(); }
int32_t CpuDeviceNum() const { return job_conf_.resource().cpu_device_num(); }
......
......@@ -112,7 +112,7 @@ class Blob final {
}
void Init(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, char* body_ptr);
int32_t record_num_; // FIXME() by dim0
int32_t record_num_; // FIXME() by dim0
bool is_contiguous_;
void* header_ptr_;
char* data_id_ptr_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册