Commit 7741acd6 authored by Shiyuan Shang-Guan

multi socket


Former-commit-id: 76c25553437543749ed58495709824bcee8e0b55
Parent 52a6c519
@@ -71,11 +71,32 @@ void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_ms
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
int32_t link_i = std::uniform_int_distribution<int32_t>(0, epoll_conf_.link_num() - 1)(random_gen_);  // distribution bounds are inclusive
GetSocketHelper(dst_machine_id, link_i)->AsyncWrite(msg);
}
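SendActorMsg now scatters small control messages over a randomly chosen link. Note that std::uniform_int_distribution includes both of its bounds, hence the `- 1` above. A minimal standalone sketch of the link-picking pattern (the `link_num` value here is an assumption matching EpollConf's default of 5):

```cpp
#include <cstdint>
#include <iostream>
#include <random>

int main() {
  const int32_t link_num = 5;  // assumed; EpollConf.link_num defaults to 5
  std::mt19937 random_gen(std::random_device{}());
  // std::uniform_int_distribution is inclusive on both ends, so the upper
  // bound must be link_num - 1 to yield a valid link index.
  std::uniform_int_distribution<int32_t> pick(0, link_num - 1);
  for (int i = 0; i < 8; ++i) { std::cout << pick(random_gen) << " "; }
  std::cout << std::endl;
  return 0;
}
```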
void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg) {
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& total_msg) {
const SocketMemDesc* src_mem_desc =
static_cast<const SocketMemDesc*>(total_msg.request_read_msg.src_token);
const SocketMemDesc* dst_mem_desc =
static_cast<const SocketMemDesc*>(total_msg.request_read_msg.dst_token);
CHECK_EQ(src_mem_desc->byte_size, dst_mem_desc->byte_size);
int32_t total_byte_size = src_mem_desc->byte_size;
int32_t offset = (total_byte_size + epoll_conf_.link_num() - 1) / epoll_conf_.link_num();
offset = RoundUp(offset, kCacheLineSize);
int32_t part_num = (total_byte_size + offset - 1) / offset;
for (int32_t link_i = 0; link_i < part_num; ++link_i) {
int32_t byte_size = (total_byte_size > offset) ? offset : total_byte_size;
total_byte_size -= byte_size;  // subtract the actual slice, so the tail part leaves exactly 0
SocketMsg msg;
msg.msg_type = total_msg.msg_type;
msg.request_read_msg.src_token = NewMemDesc(src_mem_desc->mem_ptr + link_i * offset, byte_size);
msg.request_read_msg.dst_token = NewMemDesc(dst_mem_desc->mem_ptr + link_i * offset, byte_size);
msg.request_read_msg.read_id = total_msg.request_read_msg.read_id;
msg.request_read_msg.part_num = part_num;
GetSocketHelper(dst_machine_id, link_i)->AsyncWrite(msg);
}
CHECK_EQ(total_byte_size, 0);
}
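The sender stripes one logical transfer across the links: the per-link slice is rounded up to a cache line, and the tail part carries the remainder (subtracting `offset` instead of `byte_size` would drive the counter negative and trip the final CHECK_EQ). A standalone sketch of the same arithmetic with an illustrative payload size; `RoundUp` and `kCacheLineSize` mirror the definitions this commit adds to the util header:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

const size_t kCacheLineSize = 64;
inline size_t RoundUp(size_t n, size_t val) { return (n + val - 1) / val * val; }

int main() {
  const int32_t link_num = 5;      // assumed EpollConf default
  int32_t total_byte_size = 1000;  // illustrative payload size
  // Per-link slice, cache-line aligned; the tail part may be smaller.
  int32_t offset =
      static_cast<int32_t>(RoundUp((total_byte_size + link_num - 1) / link_num, kCacheLineSize));
  int32_t part_num = (total_byte_size + offset - 1) / offset;  // ceil(1000 / 256) == 4
  for (int32_t link_i = 0; link_i < part_num; ++link_i) {
    int32_t byte_size = (total_byte_size > offset) ? offset : total_byte_size;
    total_byte_size -= byte_size;
    std::printf("link %d: bytes [%d, %d)\n", link_i, link_i * offset, link_i * offset + byte_size);
  }
  return 0;  // parts are 256, 256, 256, 232 bytes, summing to 1000
}
```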
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
@@ -85,7 +106,8 @@ SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
return mem_desc;
}
EpollCommNet::EpollCommNet(const Plan& plan) : CommNetIf(plan) {
EpollCommNet::EpollCommNet(const Plan& plan)
: CommNetIf(plan), epoll_conf_(Global<JobDesc>::Get()->epoll_conf()) {
pollers_.resize(Global<JobDesc>::Get()->CommNetWorkerNum(), nullptr);
for (size_t i = 0; i < pollers_.size(); ++i) { pollers_[i] = new IOEventPoller; }
InitSockets();
@@ -96,7 +118,7 @@ void EpollCommNet::InitSockets() {
int64_t this_machine_id = Global<MachineCtx>::Get()->this_machine_id();
auto this_machine = Global<JobDesc>::Get()->resource().machine(this_machine_id);
int64_t total_machine_num = Global<JobDesc>::Get()->TotalMachineNum();
machine_id2sockfd_.assign(total_machine_num, -1);
machine_id2sockfds_.assign(total_machine_num * epoll_conf_.link_num(), -1);
sockfd2helper_.clear();
size_t poller_idx = 0;
auto NewSocketHelper = [&](int sockfd) {
@@ -125,53 +147,72 @@ void EpollCommNet::InitSockets() {
int32_t src_machine_count = 0;
// connect
for (int64_t peer_id : peer_machine_id()) {
if (peer_id < this_machine_id) {
for (int64_t peer_mchn_id : peer_machine_id()) {
if (peer_mchn_id < this_machine_id) {
++src_machine_count;
continue;
}
uint16_t peer_port = PullPort(peer_id);
auto peer_machine = Global<JobDesc>::Get()->resource().machine(peer_id);
uint16_t peer_port = PullPort(peer_mchn_id);
auto peer_machine = Global<JobDesc>::Get()->resource().machine(peer_mchn_id);
sockaddr_in peer_sockaddr = GetSockAddr(peer_machine.addr(), peer_port);
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
== 0);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_id2sockfd_[peer_id] = sockfd;
for (int32_t link_i = 0; link_i < epoll_conf_.link_num(); ++link_i) {
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
== 0);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_id2sockfds_[peer_mchn_id * epoll_conf_.link_num() + link_i] = sockfd;
}
}
// accept
FOR_RANGE(int32_t, idx, 0, src_machine_count) {
sockaddr_in peer_sockaddr;
socklen_t len = sizeof(peer_sockaddr);
int sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
PCHECK(sockfd != -1);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
int64_t peer_machine_id = GetMachineId(peer_sockaddr);
machine_id2sockfd_[peer_machine_id] = sockfd;
for (int32_t link_i = 0; link_i < epoll_conf_.link_num(); ++link_i) {
int sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
PCHECK(sockfd != -1);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
int64_t peer_mchn_id = GetMachineId(peer_sockaddr);
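// NOTE: link_i follows global accept order, which assumes each peer's link_num
// connections are accepted consecutively; interleaved accepts from different
// peers could scramble the link slots, so a per-peer counter keyed by
// peer_mchn_id would be more robust.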
machine_id2sockfds_[peer_mchn_id * epoll_conf_.link_num() + link_i] = sockfd;
}
}
PCHECK(close(listen_sockfd) == 0);
ClearPort(this_machine_id);
// useful log
FOR_RANGE(int64_t, machine_id, 0, total_machine_num) {
LOG(INFO) << "machine " << machine_id << " sockfd " << machine_id2sockfd_[machine_id];
FOR_RANGE(int32_t, link_i, 0, epoll_conf_.link_num()) {
LOG(INFO) << "machine: " << machine_id << ", link index: " << link_i << ", sockfd: "
<< machine_id2sockfds_[machine_id * epoll_conf_.link_num() + link_i];
}
}
}
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id) {
int sockfd = machine_id2sockfd_.at(machine_id);
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) {
int sockfd = machine_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_index);
return sockfd2helper_.at(sockfd);
}
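machine_id2sockfds_ stores the per-link sockets of all machines in one flat vector, indexed row-major as `machine_id * link_num + link_index`. A minimal sketch of that layout (the sizes and fd values are placeholders):

```cpp
#include <cstdint>
#include <vector>

int main() {
  const int64_t total_machine_num = 3;  // placeholder cluster size
  const int32_t link_num = 5;           // assumed EpollConf default
  // One slot per (machine, link) pair; -1 marks "no socket yet".
  std::vector<int> machine_id2sockfds(total_machine_num * link_num, -1);
  int64_t machine_id = 2;
  int32_t link_index = 4;
  machine_id2sockfds.at(machine_id * link_num + link_index) = 42;  // fake sockfd
  return machine_id2sockfds.at(machine_id * link_num + link_index) == 42 ? 0 : 1;
}
```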
void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) {
{
  // guard the map: PartReadDone() may erase entries concurrently from poller threads
  std::unique_lock<std::mutex> lck(part_done_cnt_mtx_);
  CHECK(read_id2part_done_cnt_.emplace(read_id, 0).second);
}
SocketMsg msg;
msg.msg_type = SocketMsgType::kRequestWrite;
msg.request_write_msg.src_token = src_token;
msg.request_write_msg.dst_machine_id = Global<MachineCtx>::Get()->this_machine_id();
msg.request_write_msg.dst_token = dst_token;
msg.request_write_msg.read_id = read_id;
GetSocketHelper(src_machine_id)->AsyncWrite(msg);
int32_t link_i = std::uniform_int_distribution<int32_t>(0, epoll_conf_.link_num() - 1)(random_gen_);  // distribution bounds are inclusive
GetSocketHelper(src_machine_id, link_i)->AsyncWrite(msg);
}
void EpollCommNet::PartReadDone(void* read_id, int32_t part_num) {
std::unique_lock<std::mutex> lck(part_done_cnt_mtx_);
int32_t& part_read_done_cnt = read_id2part_done_cnt_.at(read_id);  // look up only after locking
part_read_done_cnt++;
if (part_read_done_cnt == part_num) {
ReadDone(read_id);
read_id2part_done_cnt_.erase(read_id);
}
}
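PartReadDone implements a simple fan-in barrier: each link's slice reports completion, and the whole read finishes once the per-read counter reaches part_num. A self-contained sketch of the same pattern (the threading and names here are illustrative, not the actual poller wiring):

```cpp
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

std::mutex part_done_cnt_mtx;
std::unordered_map<void*, int32_t> read_id2part_done_cnt;

void PartReadDone(void* read_id, int32_t part_num) {
  std::unique_lock<std::mutex> lck(part_done_cnt_mtx);
  int32_t& cnt = read_id2part_done_cnt.at(read_id);
  if (++cnt == part_num) {
    read_id2part_done_cnt.erase(read_id);
    std::cout << "read done" << std::endl;  // stands in for ReadDone(read_id)
  }
}

int main() {
  int dummy;
  void* read_id = &dummy;  // any unique pointer works as the read handle
  const int32_t part_num = 4;
  {
    std::unique_lock<std::mutex> lck(part_done_cnt_mtx);
    read_id2part_done_cnt.emplace(read_id, 0);
  }
  std::vector<std::thread> links;
  for (int32_t i = 0; i < part_num; ++i) {
    links.emplace_back(PartReadDone, read_id, part_num);  // one "link" per part
  }
  for (auto& t : links) { t.join(); }
  return 0;
}
```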
} // namespace oneflow
@@ -20,18 +20,23 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
void PartReadDone(void* read_id, int32_t part_num);
private:
SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) override;
EpollCommNet(const Plan& plan);
void InitSockets();
SocketHelper* GetSocketHelper(int64_t machine_id);
SocketHelper* GetSocketHelper(int64_t machine_id, int32_t link_index);
void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;
const EpollConf& epoll_conf_;
std::vector<IOEventPoller*> pollers_;
std::vector<int> machine_id2sockfd_;
std::vector<int> machine_id2sockfds_;
HashMap<int, SocketHelper*> sockfd2helper_;
std::mt19937 random_gen_;
std::mutex part_done_cnt_mtx_;
HashMap<void*, int32_t> read_id2part_done_cnt_;
};
template<>
@@ -41,6 +41,7 @@ struct RequestReadMsg {
void* src_token;
void* dst_token;
void* read_id;
int32_t part_num;
};
struct SocketMsg {
@@ -63,7 +63,8 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void SocketReadHelper::SetStatusWhenMsgBodyDone() {
if (cur_msg_.msg_type == SocketMsgType::kRequestRead) {
Global<EpollCommNet>::Get()->ReadDone(cur_msg_.request_read_msg.read_id);
Global<EpollCommNet>::Get()->PartReadDone(cur_msg_.request_read_msg.read_id,
cur_msg_.request_read_msg.part_num);
}
SwitchToMsgHeadReadHandle();
}
@@ -177,6 +177,7 @@ inline double GetCurTime() {
const size_t kCudaAlignSize = 8;
const size_t kCudaMemAllocAlignSize = 256;
const size_t kCacheLineSize = 64;
inline size_t RoundUp(size_t n, size_t val) { return (n + val - 1) / val * val; }
size_t GetAvailableCpuMemSize();
@@ -50,12 +50,28 @@ message ExperimentalRunConf {
optional bool enable_experiment_run = 2 [default = false];
}
message EpollConf {
optional int32 link_num = 1 [default = 5];
}
message IBVerbsConf {
}
message CommNetworkConf {
oneof comm_net_type {
EpollConf epoll_conf = 1;
IBVerbsConf ibverbs_conf = 2;
}
}
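Replacing the standalone use_rdma flag with a oneof makes the transport selection mutually exclusive, which is what lets JobDesc derive use_rdma() from has_ibverbs_conf() below. A sketch against the generated protobuf C++ API (the include path and the oneflow package are assumptions about this repo's protoc output):

```cpp
#include <iostream>
// Assumed location of the header protoc generates for this .proto file.
#include "oneflow/core/job/job_conf.pb.h"

int main() {
  oneflow::CommNetworkConf conf;
  conf.mutable_epoll_conf()->set_link_num(8);  // select the epoll transport
  // Setting one oneof member clears the other, so exactly one branch is active.
  std::cout << conf.has_epoll_conf() << " " << conf.has_ibverbs_conf() << std::endl;  // 1 0
  std::cout << conf.epoll_conf().link_num() << std::endl;                             // 8
  return 0;
}
```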
message OtherConf {
required int64 piece_size = 2;
required int32 data_part_num = 3; // piece_size % data_part_num = 0
required int64 total_batch_num = 4;
required FileSystemConf data_fs_conf = 1;
required FileSystemConf snapshot_fs_conf = 2;
required CommNetworkConf comm_net_conf = 3;
required int64 piece_size = 4;
required int32 data_part_num = 5; // piece_size % data_part_num = 0
required int64 total_batch_num = 6;
optional bool use_rdma = 100 [default = false];
optional string model_load_snapshot_path = 101 [default = ""];
optional int32 max_data_id_length = 102 [default = 0];
optional bool enable_cudnn = 103 [default = true];
@@ -67,8 +83,6 @@ message OtherConf {
optional bool save_downloaded_file_to_local_fs = 109 [default = false];
optional uint64 rdma_mem_block_mbyte = 110 [default = 8];
optional uint64 rdma_recv_msg_buf_mbyte = 111 [default = 6];
required FileSystemConf data_fs_conf = 112;
required FileSystemConf snapshot_fs_conf = 113;
optional bool collect_act_event = 125 [default = false];
optional bool enable_mem_sharing = 126 [default = true];
@@ -255,7 +255,7 @@ void JobDesc::Init() {
SplitDecodeOps();
AddRecordLoadOps();
#ifndef WITH_RDMA
CHECK_EQ(job_conf_.other().use_rdma(), false) << "Please compile ONEFLOW with RDMA";
CHECK_EQ(this->use_rdma(), false) << "Please compile ONEFLOW with RDMA";
#endif
#ifndef WITH_CUDA
CHECK_EQ(job_conf_.other().enable_nccl(), false) << "Please compile ONEFLOW with NCCL";
@@ -22,10 +22,19 @@ class JobDesc final {
const Resource& resource() const { return job_conf_.resource(); }
const Placement& placement() const { return job_conf_.placement(); }
const OtherConf& other_conf() const { return job_conf_.other(); }
const CommNetworkConf& comm_net_conf() const { return job_conf_.other().comm_net_conf(); }
bool use_rdma() const { return job_conf_.other().comm_net_conf().has_ibverbs_conf(); }
const EpollConf& epoll_conf() const {
CHECK(!this->use_rdma());
return this->comm_net_conf().epoll_conf();
}
const IBVerbsConf& ibverbs_conf() const {
CHECK(this->use_rdma());
return this->comm_net_conf().ibverbs_conf();
}
const std::string& MdLoadSnapshotPath() { return job_conf_.other().model_load_snapshot_path(); }
DataType DefaultDataType() const { return job_conf_.other().default_data_type(); }
size_t SizeOfOneDataId() const { return job_conf_.other().max_data_id_length() * sizeof(char); }
bool use_rdma() const { return job_conf_.other().use_rdma(); }
bool EnableCudnn() const { return job_conf_.other().enable_cudnn(); }
int64_t TotalMachineNum() const { return job_conf_.resource().machine().size(); }
int32_t CpuDeviceNum() const { return job_conf_.resource().cpu_device_num(); }