Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
73d5821a
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
73d5821a
编写于
1月 23, 2019
作者:
S
Shiyuan Shang-Guan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
port multi-socket to master
Former-commit-id: 80a189d6b3ea9fe976e1609bd3aa5e078e55c5c9
上级
fbfd1d0c
变更
19
显示空白变更内容
内联
并排
Showing
19 changed file
with
171 addition
and
94 deletion
+171
-94
oneflow/core/actor/normal_backward_compute_actor.h
oneflow/core/actor/normal_backward_compute_actor.h
+2
-1
oneflow/core/comm_network/comm_network.cpp
oneflow/core/comm_network/comm_network.cpp
+2
-2
oneflow/core/comm_network/comm_network.h
oneflow/core/comm_network/comm_network.h
+6
-6
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
+75
-34
oneflow/core/comm_network/epoll/epoll_comm_network.h
oneflow/core/comm_network/epoll/epoll_comm_network.h
+9
-5
oneflow/core/comm_network/epoll/io_event_poller.cpp
oneflow/core/comm_network/epoll/io_event_poller.cpp
+7
-7
oneflow/core/comm_network/epoll/io_event_poller.h
oneflow/core/comm_network/epoll/io_event_poller.h
+7
-7
oneflow/core/comm_network/epoll/socket_helper.cpp
oneflow/core/comm_network/epoll/socket_helper.cpp
+1
-1
oneflow/core/comm_network/epoll/socket_helper.h
oneflow/core/comm_network/epoll/socket_helper.h
+1
-1
oneflow/core/comm_network/epoll/socket_message.h
oneflow/core/comm_network/epoll/socket_message.h
+3
-0
oneflow/core/comm_network/epoll/socket_read_helper.cpp
oneflow/core/comm_network/epoll/socket_read_helper.cpp
+9
-11
oneflow/core/comm_network/epoll/socket_read_helper.h
oneflow/core/comm_network/epoll/socket_read_helper.h
+2
-2
oneflow/core/comm_network/epoll/socket_write_helper.cpp
oneflow/core/comm_network/epoll/socket_write_helper.cpp
+4
-3
oneflow/core/comm_network/epoll/socket_write_helper.h
oneflow/core/comm_network/epoll/socket_write_helper.h
+3
-3
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
+1
-1
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
+2
-2
oneflow/core/job/job_conf.proto
oneflow/core/job/job_conf.proto
+21
-6
oneflow/core/job/job_desc.cpp
oneflow/core/job/job_desc.cpp
+3
-1
oneflow/core/job/job_desc.h
oneflow/core/job/job_desc.h
+13
-1
未找到文件。
oneflow/core/actor/normal_backward_compute_actor.h
浏览文件 @
73d5821a
...
...
@@ -22,7 +22,8 @@ class NormalBackwardCompActor final : public CompActor {
void
AsyncReturnAllCustomizedReadableRegst
()
override
;
std
::
pair
<
RegstNameType
,
HashSet
<
std
::
string
>>
GetNaiveOrCustomizedConsumedRegstDescName
()
override
{
return
std
::
make_pair
(
RegstNameType
::
kNaive
,
HashSet
<
std
::
string
>
{
"activation"
,
"data_tmp"
,
"out"
,
"out_diff"
,
"in"
});
return
std
::
make_pair
(
RegstNameType
::
kNaive
,
HashSet
<
std
::
string
>
{
"activation"
,
"data_tmp"
,
"out"
,
"out_diff"
,
"in"
});
}
void
VirtualAsyncSendNaiveProducedRegstMsgToConsumer
()
override
;
void
AsyncSendCustomizedConsumedRegstMsgToProducer
()
override
;
...
...
oneflow/core/comm_network/comm_network.cpp
浏览文件 @
73d5821a
...
...
@@ -7,9 +7,9 @@ CommNet::~CommNet() {
ready_cb_poller_
.
join
();
}
void
*
CommNet
::
NewActorReadId
()
{
return
new
ActorReadContext
;
}
void
*
CommNet
::
NewActorReadId
()
const
{
return
new
ActorReadContext
;
}
void
CommNet
::
DeleteActorReadId
(
void
*
actor_read_id
)
{
void
CommNet
::
DeleteActorReadId
(
void
*
actor_read_id
)
const
{
auto
actor_read_ctx
=
static_cast
<
ActorReadContext
*>
(
actor_read_id
);
CHECK
(
actor_read_ctx
->
waiting_list
.
empty
());
delete
actor_read_ctx
;
...
...
oneflow/core/comm_network/comm_network.h
浏览文件 @
73d5821a
...
...
@@ -29,20 +29,20 @@ class CommNet {
virtual
void
RegisterMemoryDone
()
=
0
;
// Stream
void
*
NewActorReadId
();
void
DeleteActorReadId
(
void
*
actor_read_id
);
void
*
NewActorReadId
()
const
;
void
DeleteActorReadId
(
void
*
actor_read_id
)
const
;
void
Read
(
void
*
actor_read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
);
void
AddReadCallBack
(
void
*
actor_read_id
,
std
::
function
<
void
()
>
callback
);
void
ReadDone
(
void
*
read_id
);
//
virtual
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
=
0
;
virtual
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
=
0
;
protected:
CommNet
(
const
Plan
&
plan
);
virtual
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
=
0
;
const
HashSet
<
int64_t
>&
peer_machine_id
()
{
return
peer_machine_id_
;
}
const
HashSet
<
int64_t
>&
peer_machine_id
()
const
{
return
peer_machine_id_
;
}
Channel
<
std
::
function
<
void
()
>>
ready_cbs_
;
...
...
@@ -84,8 +84,8 @@ class CommNetIf : public CommNet {
}
protected:
virtual
MemDescType
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
=
0
;
const
HashSet
<
MemDescType
*>&
mem_descs
()
{
return
mem_descs_
;
}
virtual
MemDescType
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
=
0
;
const
HashSet
<
MemDescType
*>&
mem_descs
()
const
{
return
mem_descs_
;
}
private:
std
::
mutex
mem_descs_mtx_
;
...
...
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
浏览文件 @
73d5821a
...
...
@@ -16,9 +16,9 @@ sockaddr_in GetSockAddr(const std::string& addr, uint16_t port) {
return
sa
;
}
int
SockListen
(
in
t
listen_sockfd
,
uint16_t
listen_port
,
int32_t
total_machine_num
)
{
int
32_t
SockListen
(
int32_
t
listen_sockfd
,
uint16_t
listen_port
,
int32_t
total_machine_num
)
{
sockaddr_in
sa
=
GetSockAddr
(
"0.0.0.0"
,
listen_port
);
int
bind_result
=
bind
(
listen_sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
sa
),
sizeof
(
sa
));
int
32_t
bind_result
=
bind
(
listen_sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
sa
),
sizeof
(
sa
));
if
(
bind_result
==
0
)
{
PCHECK
(
listen
(
listen_sockfd
,
total_machine_num
)
==
0
);
LOG
(
INFO
)
<<
"CommNet:Epoll listening on "
...
...
@@ -56,7 +56,7 @@ uint16_t PullPort(int64_t machine_id) {
EpollCommNet
::~
EpollCommNet
()
{
for
(
size_t
i
=
0
;
i
<
pollers_
.
size
();
++
i
)
{
LOG
(
INFO
)
<<
"CommNet Thread "
<<
i
<<
" finish"
;
pollers_
[
i
]
->
Stop
();
pollers_
.
at
(
i
)
->
Stop
();
}
OF_BARRIER
();
for
(
IOEventPoller
*
poller
:
pollers_
)
{
delete
poller
;
}
...
...
@@ -64,30 +64,54 @@ EpollCommNet::~EpollCommNet() {
}
void
EpollCommNet
::
RegisterMemoryDone
()
{
// do nothing
for
(
void
*
dst_token
:
mem_descs
())
{
dst_token2part_done_cnt_
[
dst_token
]
=
0
;
}
}
void
EpollCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
actor_msg
)
{
void
EpollCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
actor_msg
)
const
{
SocketMsg
msg
;
msg
.
msg_type
=
SocketMsgType
::
kActor
;
msg
.
actor_msg
=
actor_msg
;
GetSocketHelper
(
dst_machine_id
)
->
AsyncWrite
(
msg
);
GetSocketHelper
(
dst_machine_id
,
epoll_conf_
.
link_num
()
-
1
)
->
AsyncWrite
(
msg
);
}
void
EpollCommNet
::
SendSocketMsg
(
int64_t
dst_machine_id
,
const
SocketMsg
&
msg
)
{
GetSocketHelper
(
dst_machine_id
)
->
AsyncWrite
(
msg
);
void
EpollCommNet
::
RequestRead
(
int64_t
dst_machine_id
,
void
*
src_token
,
void
*
dst_token
,
void
*
read_id
)
const
{
int32_t
total_byte_size
=
static_cast
<
const
SocketMemDesc
*>
(
src_token
)
->
byte_size
;
CHECK_GT
(
total_byte_size
,
0
);
int32_t
part_length
=
(
total_byte_size
+
epoll_conf_
.
link_num
()
-
1
)
/
epoll_conf_
.
link_num
();
part_length
=
RoundUp
(
part_length
,
epoll_conf_
.
msg_segment_kbyte
()
*
1024
);
int32_t
part_num
=
(
total_byte_size
+
part_length
-
1
)
/
part_length
;
CHECK_LE
(
part_num
,
epoll_conf_
.
link_num
());
for
(
int32_t
link_i
=
0
;
link_i
<
part_num
;
++
link_i
)
{
int32_t
byte_size
=
(
total_byte_size
>
part_length
)
?
(
part_length
)
:
(
total_byte_size
);
CHECK_GT
(
byte_size
,
0
);
total_byte_size
-=
byte_size
;
SocketMsg
msg
;
msg
.
msg_type
=
SocketMsgType
::
kRequestRead
;
msg
.
request_read_msg
.
src_token
=
src_token
;
msg
.
request_read_msg
.
dst_token
=
dst_token
;
msg
.
request_read_msg
.
offset
=
link_i
*
part_length
;
msg
.
request_read_msg
.
byte_size
=
byte_size
;
msg
.
request_read_msg
.
read_id
=
read_id
;
msg
.
request_read_msg
.
part_num
=
part_num
;
GetSocketHelper
(
dst_machine_id
,
link_i
)
->
AsyncWrite
(
msg
);
}
CHECK_EQ
(
total_byte_size
,
0
);
}
SocketMemDesc
*
EpollCommNet
::
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
{
SocketMemDesc
*
EpollCommNet
::
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
{
SocketMemDesc
*
mem_desc
=
new
SocketMemDesc
;
mem_desc
->
mem_ptr
=
ptr
;
mem_desc
->
byte_size
=
byte_size
;
return
mem_desc
;
}
EpollCommNet
::
EpollCommNet
(
const
Plan
&
plan
)
:
CommNetIf
(
plan
)
{
EpollCommNet
::
EpollCommNet
(
const
Plan
&
plan
)
:
CommNetIf
(
plan
),
epoll_conf_
(
Global
<
JobDesc
>::
Get
()
->
epoll_conf
())
{
pollers_
.
resize
(
Global
<
JobDesc
>::
Get
()
->
CommNetWorkerNum
(),
nullptr
);
for
(
size_t
i
=
0
;
i
<
pollers_
.
size
();
++
i
)
{
pollers_
[
i
]
=
new
IOEventPoller
;
}
for
(
size_t
i
=
0
;
i
<
pollers_
.
size
();
++
i
)
{
pollers_
.
at
(
i
)
=
new
IOEventPoller
;
}
InitSockets
();
for
(
IOEventPoller
*
poller
:
pollers_
)
{
poller
->
Start
();
}
}
...
...
@@ -96,17 +120,17 @@ void EpollCommNet::InitSockets() {
int64_t
this_machine_id
=
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
();
auto
this_machine
=
Global
<
JobDesc
>::
Get
()
->
resource
().
machine
(
this_machine_id
);
int64_t
total_machine_num
=
Global
<
JobDesc
>::
Get
()
->
TotalMachineNum
();
machine_
id2sockfd_
.
assign
(
total_machine_num
,
-
1
);
machine_
link_id2sockfds_
.
assign
(
total_machine_num
*
epoll_conf_
.
link_num
()
,
-
1
);
sockfd2helper_
.
clear
();
size_t
poller_idx
=
0
;
auto
NewSocketHelper
=
[
&
](
int
sockfd
)
{
IOEventPoller
*
poller
=
pollers_
[
poller_idx
]
;
auto
NewSocketHelper
=
[
&
](
int
32_t
sockfd
)
{
IOEventPoller
*
poller
=
pollers_
.
at
(
poller_idx
)
;
poller_idx
=
(
poller_idx
+
1
)
%
pollers_
.
size
();
return
new
SocketHelper
(
sockfd
,
poller
);
};
// listen
int
listen_sockfd
=
socket
(
AF_INET
,
SOCK_STREAM
,
0
);
int
32_t
listen_sockfd
=
socket
(
AF_INET
,
SOCK_STREAM
,
0
);
int32_t
this_listen_port
=
Global
<
JobDesc
>::
Get
()
->
resource
().
data_port
();
if
(
this_listen_port
!=
-
1
)
{
CHECK_EQ
(
SockListen
(
listen_sockfd
,
this_listen_port
,
total_machine_num
),
0
);
...
...
@@ -125,42 +149,51 @@ void EpollCommNet::InitSockets() {
int32_t
src_machine_count
=
0
;
// connect
for
(
int64_t
peer_id
:
peer_machine_id
())
{
if
(
peer_id
<
this_machine_id
)
{
for
(
int64_t
peer_
mchn_
id
:
peer_machine_id
())
{
if
(
peer_
mchn_
id
<
this_machine_id
)
{
++
src_machine_count
;
continue
;
}
uint16_t
peer_port
=
PullPort
(
peer_id
);
auto
peer_machine
=
Global
<
JobDesc
>::
Get
()
->
resource
().
machine
(
peer_id
);
uint16_t
peer_port
=
PullPort
(
peer_
mchn_
id
);
auto
peer_machine
=
Global
<
JobDesc
>::
Get
()
->
resource
().
machine
(
peer_
mchn_
id
);
sockaddr_in
peer_sockaddr
=
GetSockAddr
(
peer_machine
.
addr
(),
peer_port
);
int
sockfd
=
socket
(
AF_INET
,
SOCK_STREAM
,
0
);
for
(
int32_t
link_i
=
0
;
link_i
<
epoll_conf_
.
link_num
();
++
link_i
)
{
int32_t
sockfd
=
socket
(
AF_INET
,
SOCK_STREAM
,
0
);
PCHECK
(
connect
(
sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
sizeof
(
peer_sockaddr
))
==
0
);
CHECK
(
sockfd2helper_
.
emplace
(
sockfd
,
NewSocketHelper
(
sockfd
)).
second
);
machine_id2sockfd_
[
peer_id
]
=
sockfd
;
machine_link_id2sockfds_
.
at
(
peer_mchn_id
*
epoll_conf_
.
link_num
()
+
link_i
)
=
sockfd
;
}
}
// accept
FOR_RANGE
(
int32_t
,
idx
,
0
,
src_machine_count
)
{
sockaddr_in
peer_sockaddr
;
socklen_t
len
=
sizeof
(
peer_sockaddr
);
int
sockfd
=
accept
(
listen_sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
&
len
);
for
(
int32_t
link_i
=
0
;
link_i
<
epoll_conf_
.
link_num
();
++
link_i
)
{
int32_t
sockfd
=
accept
(
listen_sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
&
len
);
PCHECK
(
sockfd
!=
-
1
);
CHECK
(
sockfd2helper_
.
emplace
(
sockfd
,
NewSocketHelper
(
sockfd
)).
second
);
int64_t
peer_machine_id
=
GetMachineId
(
peer_sockaddr
);
machine_id2sockfd_
[
peer_machine_id
]
=
sockfd
;
int64_t
peer_mchn_id
=
GetMachineId
(
peer_sockaddr
);
machine_link_id2sockfds_
.
at
(
peer_mchn_id
*
epoll_conf_
.
link_num
()
+
link_i
)
=
sockfd
;
}
}
PCHECK
(
close
(
listen_sockfd
)
==
0
);
ClearPort
(
this_machine_id
);
// useful log
FOR_RANGE
(
int64_t
,
machine_id
,
0
,
total_machine_num
)
{
LOG
(
INFO
)
<<
"machine "
<<
machine_id
<<
" sockfd "
<<
machine_id2sockfd_
[
machine_id
];
for
(
int64_t
peer_mchn_id
:
peer_machine_id
())
{
FOR_RANGE
(
int32_t
,
link_i
,
0
,
epoll_conf_
.
link_num
())
{
int32_t
sockfd
=
machine_link_id2sockfds_
.
at
(
peer_mchn_id
*
epoll_conf_
.
link_num
()
+
link_i
);
CHECK_GT
(
sockfd
,
0
);
LOG
(
INFO
)
<<
"machine: "
<<
peer_mchn_id
<<
", link index: "
<<
link_i
<<
", sockfd: "
<<
sockfd
;
}
}
}
SocketHelper
*
EpollCommNet
::
GetSocketHelper
(
int64_t
machine_id
)
{
int
sockfd
=
machine_id2sockfd_
.
at
(
machine_id
);
SocketHelper
*
EpollCommNet
::
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
)
const
{
int
32_t
sockfd
=
machine_link_id2sockfds_
.
at
(
machine_id
*
epoll_conf_
.
link_num
()
+
link_index
);
return
sockfd2helper_
.
at
(
sockfd
);
}
...
...
@@ -171,7 +204,15 @@ void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token
msg
.
request_write_msg
.
dst_machine_id
=
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
();
msg
.
request_write_msg
.
dst_token
=
dst_token
;
msg
.
request_write_msg
.
read_id
=
read_id
;
GetSocketHelper
(
src_machine_id
)
->
AsyncWrite
(
msg
);
dst_token2part_done_cnt_
.
at
(
dst_token
)
=
0
;
GetSocketHelper
(
src_machine_id
,
epoll_conf_
.
link_num
()
-
1
)
->
AsyncWrite
(
msg
);
}
void
EpollCommNet
::
PartReadDone
(
void
*
read_id
,
void
*
dst_token
,
int32_t
part_num
)
{
if
(
dst_token2part_done_cnt_
.
at
(
dst_token
).
fetch_add
(
1
,
std
::
memory_order_relaxed
)
==
(
part_num
-
1
))
{
ReadDone
(
read_id
);
}
}
}
// namespace oneflow
...
...
oneflow/core/comm_network/epoll/epoll_comm_network.h
浏览文件 @
73d5821a
...
...
@@ -18,20 +18,24 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
void
RegisterMemoryDone
()
override
;
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
override
;
void
SendSocketMsg
(
int64_t
dst_machine_id
,
const
SocketMsg
&
msg
);
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
override
;
void
RequestRead
(
int64_t
dst_machine_id
,
void
*
src_token
,
void
*
dst_token
,
void
*
read_id
)
const
;
void
PartReadDone
(
void
*
read_id
,
void
*
dst_token
,
int32_t
part_num
);
private:
SocketMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
override
;
SocketMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
override
;
EpollCommNet
(
const
Plan
&
plan
);
void
InitSockets
();
SocketHelper
*
GetSocketHelper
(
int64_t
machine_id
)
;
SocketHelper
*
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
)
const
;
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
override
;
const
EpollConf
&
epoll_conf_
;
std
::
vector
<
IOEventPoller
*>
pollers_
;
std
::
vector
<
int
>
machine_id2sockfd_
;
// machine_link_id = machine_id * epoll_conf_.link_num() + link_id
std
::
vector
<
int64_t
>
machine_link_id2sockfds_
;
HashMap
<
int
,
SocketHelper
*>
sockfd2helper_
;
HashMap
<
void
*
,
std
::
atomic
<
int32_t
>>
dst_token2part_done_cnt_
;
};
template
<
>
...
...
oneflow/core/comm_network/epoll/io_event_poller.cpp
浏览文件 @
73d5821a
...
...
@@ -6,7 +6,7 @@
namespace
oneflow
{
const
int
IOEventPoller
::
max_event_num_
=
32
;
const
int
32_t
IOEventPoller
::
max_event_num_
=
32
;
IOEventPoller
::
IOEventPoller
()
{
epfd_
=
epoll_create1
(
0
);
...
...
@@ -26,12 +26,12 @@ IOEventPoller::~IOEventPoller() {
PCHECK
(
close
(
epfd_
)
==
0
);
}
void
IOEventPoller
::
AddFd
(
int
fd
,
std
::
function
<
void
()
>
read_handler
,
void
IOEventPoller
::
AddFd
(
int
32_t
fd
,
std
::
function
<
void
()
>
read_handler
,
std
::
function
<
void
()
>
write_handler
)
{
AddFd
(
fd
,
&
read_handler
,
&
write_handler
);
}
void
IOEventPoller
::
AddFdWithOnlyReadHandler
(
int
fd
,
std
::
function
<
void
()
>
read_handler
)
{
void
IOEventPoller
::
AddFdWithOnlyReadHandler
(
int
32_t
fd
,
std
::
function
<
void
()
>
read_handler
)
{
AddFd
(
fd
,
&
read_handler
,
nullptr
);
}
...
...
@@ -43,10 +43,10 @@ void IOEventPoller::Stop() {
thread_
.
join
();
}
void
IOEventPoller
::
AddFd
(
int
fd
,
std
::
function
<
void
()
>*
read_handler
,
void
IOEventPoller
::
AddFd
(
int
32_t
fd
,
std
::
function
<
void
()
>*
read_handler
,
std
::
function
<
void
()
>*
write_handler
)
{
// Set Fd NONBLOCK
int
opt
=
fcntl
(
fd
,
F_GETFL
);
int
32_t
opt
=
fcntl
(
fd
,
F_GETFL
);
PCHECK
(
opt
!=
-
1
);
PCHECK
(
fcntl
(
fd
,
F_SETFL
,
opt
|
O_NONBLOCK
)
==
0
);
// Set CLOEXEC
...
...
@@ -70,13 +70,13 @@ void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,
void
IOEventPoller
::
EpollLoop
()
{
while
(
true
)
{
int
event_num
=
epoll_wait
(
epfd_
,
ep_events_
,
max_event_num_
,
-
1
);
int
32_t
event_num
=
epoll_wait
(
epfd_
,
ep_events_
,
max_event_num_
,
-
1
);
if
(
event_num
==
-
1
)
{
PCHECK
(
errno
==
EINTR
);
continue
;
}
const
epoll_event
*
cur_event
=
ep_events_
;
for
(
int
event_idx
=
0
;
event_idx
<
event_num
;
++
event_idx
,
++
cur_event
)
{
for
(
int
32_t
event_idx
=
0
;
event_idx
<
event_num
;
++
event_idx
,
++
cur_event
)
{
auto
io_handler
=
static_cast
<
IOHandler
*>
(
cur_event
->
data
.
ptr
);
PCHECK
(
!
(
cur_event
->
events
&
EPOLLERR
))
<<
"fd: "
<<
io_handler
->
fd
;
if
(
io_handler
->
fd
==
break_epoll_loop_fd_
)
{
return
;
}
...
...
oneflow/core/comm_network/epoll/io_event_poller.h
浏览文件 @
73d5821a
...
...
@@ -13,8 +13,8 @@ class IOEventPoller final {
IOEventPoller
();
~
IOEventPoller
();
void
AddFd
(
int
fd
,
std
::
function
<
void
()
>
read_handler
,
std
::
function
<
void
()
>
write_handler
);
void
AddFdWithOnlyReadHandler
(
int
fd
,
std
::
function
<
void
()
>
read_handler
);
void
AddFd
(
int
32_t
fd
,
std
::
function
<
void
()
>
read_handler
,
std
::
function
<
void
()
>
write_handler
);
void
AddFdWithOnlyReadHandler
(
int
32_t
fd
,
std
::
function
<
void
()
>
read_handler
);
void
Start
();
void
Stop
();
...
...
@@ -28,18 +28,18 @@ class IOEventPoller final {
}
std
::
function
<
void
()
>
read_handler
;
std
::
function
<
void
()
>
write_handler
;
int
fd
;
int
32_t
fd
;
};
void
AddFd
(
int
fd
,
std
::
function
<
void
()
>*
read_handler
,
std
::
function
<
void
()
>*
write_handler
);
void
AddFd
(
int
32_t
fd
,
std
::
function
<
void
()
>*
read_handler
,
std
::
function
<
void
()
>*
write_handler
);
void
EpollLoop
();
static
const
int
max_event_num_
;
static
const
int
32_t
max_event_num_
;
int
epfd_
;
int
32_t
epfd_
;
epoll_event
*
ep_events_
;
std
::
forward_list
<
IOHandler
*>
io_handlers_
;
int
break_epoll_loop_fd_
;
int
32_t
break_epoll_loop_fd_
;
std
::
thread
thread_
;
};
...
...
oneflow/core/comm_network/epoll/socket_helper.cpp
浏览文件 @
73d5821a
...
...
@@ -4,7 +4,7 @@
namespace
oneflow
{
SocketHelper
::
SocketHelper
(
int
sockfd
,
IOEventPoller
*
poller
)
{
SocketHelper
::
SocketHelper
(
int
32_t
sockfd
,
IOEventPoller
*
poller
)
{
read_helper_
=
new
SocketReadHelper
(
sockfd
);
write_helper_
=
new
SocketWriteHelper
(
sockfd
,
poller
);
poller
->
AddFd
(
sockfd
,
[
this
]()
{
read_helper_
->
NotifyMeSocketReadable
();
},
...
...
oneflow/core/comm_network/epoll/socket_helper.h
浏览文件 @
73d5821a
...
...
@@ -15,7 +15,7 @@ class SocketHelper final {
SocketHelper
()
=
delete
;
~
SocketHelper
();
SocketHelper
(
int
sockfd
,
IOEventPoller
*
poller
);
SocketHelper
(
int
32_t
sockfd
,
IOEventPoller
*
poller
);
void
AsyncWrite
(
const
SocketMsg
&
msg
);
...
...
oneflow/core/comm_network/epoll/socket_message.h
浏览文件 @
73d5821a
...
...
@@ -40,7 +40,10 @@ struct RequestWriteMsg {
struct
RequestReadMsg
{
void
*
src_token
;
void
*
dst_token
;
int64_t
offset
;
int64_t
byte_size
;
void
*
read_id
;
int32_t
part_num
;
};
struct
SocketMsg
{
...
...
oneflow/core/comm_network/epoll/socket_read_helper.cpp
浏览文件 @
73d5821a
...
...
@@ -10,7 +10,7 @@ SocketReadHelper::~SocketReadHelper() {
// do nothing
}
SocketReadHelper
::
SocketReadHelper
(
int
sockfd
)
{
SocketReadHelper
::
SocketReadHelper
(
int
32_t
sockfd
)
{
sockfd_
=
sockfd
;
SwitchToMsgHeadReadHandle
();
}
...
...
@@ -63,26 +63,24 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void
SocketReadHelper
::
SetStatusWhenMsgBodyDone
()
{
if
(
cur_msg_
.
msg_type
==
SocketMsgType
::
kRequestRead
)
{
Global
<
EpollCommNet
>::
Get
()
->
ReadDone
(
cur_msg_
.
request_read_msg
.
read_id
);
Global
<
EpollCommNet
>::
Get
()
->
PartReadDone
(
cur_msg_
.
request_read_msg
.
read_id
,
cur_msg_
.
request_read_msg
.
dst_token
,
cur_msg_
.
request_read_msg
.
part_num
);
}
SwitchToMsgHeadReadHandle
();
}
void
SocketReadHelper
::
SetStatusWhenRequestWriteMsgHeadDone
()
{
SocketMsg
msg_to_send
;
msg_to_send
.
msg_type
=
SocketMsgType
::
kRequestRead
;
msg_to_send
.
request_read_msg
.
src_token
=
cur_msg_
.
request_write_msg
.
src_token
;
msg_to_send
.
request_read_msg
.
dst_token
=
cur_msg_
.
request_write_msg
.
dst_token
;
msg_to_send
.
request_read_msg
.
read_id
=
cur_msg_
.
request_write_msg
.
read_id
;
Global
<
EpollCommNet
>::
Get
()
->
SendSocketMsg
(
cur_msg_
.
request_write_msg
.
dst_machine_id
,
msg_to_send
);
Global
<
EpollCommNet
>::
Get
()
->
RequestRead
(
cur_msg_
.
request_write_msg
.
dst_machine_id
,
cur_msg_
.
request_write_msg
.
src_token
,
cur_msg_
.
request_write_msg
.
dst_token
,
cur_msg_
.
request_write_msg
.
read_id
);
SwitchToMsgHeadReadHandle
();
}
void
SocketReadHelper
::
SetStatusWhenRequestReadMsgHeadDone
()
{
auto
mem_desc
=
static_cast
<
const
SocketMemDesc
*>
(
cur_msg_
.
request_read_msg
.
dst_token
);
read_ptr_
=
reinterpret_cast
<
char
*>
(
mem_desc
->
mem_ptr
);
read_size_
=
mem_desc
->
byte_size
;
read_ptr_
=
reinterpret_cast
<
char
*>
(
mem_desc
->
mem_ptr
)
+
cur_msg_
.
request_read_msg
.
offset
;
read_size_
=
cur_msg_
.
request_read_msg
.
byte_size
;
cur_read_handle_
=
&
SocketReadHelper
::
MsgBodyReadHandle
;
}
...
...
oneflow/core/comm_network/epoll/socket_read_helper.h
浏览文件 @
73d5821a
...
...
@@ -13,7 +13,7 @@ class SocketReadHelper final {
SocketReadHelper
()
=
delete
;
~
SocketReadHelper
();
SocketReadHelper
(
int
sockfd
);
SocketReadHelper
(
int
32_t
sockfd
);
void
NotifyMeSocketReadable
();
...
...
@@ -32,7 +32,7 @@ class SocketReadHelper final {
OF_PP_FOR_EACH_TUPLE
(
MAKE_ENTRY
,
SOCKET_MSG_TYPE_SEQ
);
#undef MAKE_ENTRY
int
sockfd_
;
int
32_t
sockfd_
;
SocketMsg
cur_msg_
;
bool
(
SocketReadHelper
::*
cur_read_handle_
)();
...
...
oneflow/core/comm_network/epoll/socket_write_helper.cpp
浏览文件 @
73d5821a
...
...
@@ -17,7 +17,7 @@ SocketWriteHelper::~SocketWriteHelper() {
}
}
SocketWriteHelper
::
SocketWriteHelper
(
int
sockfd
,
IOEventPoller
*
poller
)
{
SocketWriteHelper
::
SocketWriteHelper
(
int
32_t
sockfd
,
IOEventPoller
*
poller
)
{
sockfd_
=
sockfd
;
queue_not_empty_fd_
=
eventfd
(
0
,
0
);
PCHECK
(
queue_not_empty_fd_
!=
-
1
);
...
...
@@ -116,8 +116,9 @@ void SocketWriteHelper::SetStatusWhenRequestWriteMsgHeadDone() {
void
SocketWriteHelper
::
SetStatusWhenRequestReadMsgHeadDone
()
{
const
void
*
src_token
=
cur_msg_
.
request_read_msg
.
src_token
;
auto
src_mem_desc
=
static_cast
<
const
SocketMemDesc
*>
(
src_token
);
write_ptr_
=
reinterpret_cast
<
const
char
*>
(
src_mem_desc
->
mem_ptr
);
write_size_
=
src_mem_desc
->
byte_size
;
write_ptr_
=
reinterpret_cast
<
const
char
*>
(
src_mem_desc
->
mem_ptr
)
+
cur_msg_
.
request_read_msg
.
offset
;
write_size_
=
cur_msg_
.
request_read_msg
.
byte_size
;
cur_write_handle_
=
&
SocketWriteHelper
::
MsgBodyWriteHandle
;
}
...
...
oneflow/core/comm_network/epoll/socket_write_helper.h
浏览文件 @
73d5821a
...
...
@@ -14,7 +14,7 @@ class SocketWriteHelper final {
SocketWriteHelper
()
=
delete
;
~
SocketWriteHelper
();
SocketWriteHelper
(
int
sockfd
,
IOEventPoller
*
poller
);
SocketWriteHelper
(
int
32_t
sockfd
,
IOEventPoller
*
poller
);
void
AsyncWrite
(
const
SocketMsg
&
msg
);
...
...
@@ -37,8 +37,8 @@ class SocketWriteHelper final {
OF_PP_FOR_EACH_TUPLE
(
MAKE_ENTRY
,
SOCKET_MSG_TYPE_SEQ
);
#undef MAKE_ENTRY
int
sockfd_
;
int
queue_not_empty_fd_
;
int
32_t
sockfd_
;
int
32_t
queue_not_empty_fd_
;
std
::
queue
<
SocketMsg
>*
cur_msg_queue_
;
...
...
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
浏览文件 @
73d5821a
...
...
@@ -49,7 +49,7 @@ void IBVerbsCommNet::RegisterMemoryDone() {
Global
<
CtrlClient
>::
Get
()
->
ClearKV
(
GenTokensMsgKey
(
this_machine_id
));
}
void
IBVerbsCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
{
void
IBVerbsCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
{
qp_vec_
.
at
(
dst_machine_id
)
->
PostSendRequest
(
msg
);
}
...
...
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
浏览文件 @
73d5821a
...
...
@@ -23,10 +23,10 @@ class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
void
RegisterMemoryDone
()
override
;
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
override
;
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
override
;
private:
IBVerbsMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
override
{
IBVerbsMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
override
{
return
new
IBVerbsMemDesc
(
pd_
,
ptr
,
byte_size
);
}
...
...
oneflow/core/job/job_conf.proto
浏览文件 @
73d5821a
...
...
@@ -46,11 +46,29 @@ message FileSystemConf {
}
}
message
EpollConf
{
optional
int32
link_num
=
1
[
default
=
5
];
optional
int32
msg_segment_kbyte
=
2
[
default
=
256
];
}
message
IBVerbsConf
{
}
message
CommNetworkConf
{
oneof
comm_net_type
{
EpollConf
epoll_conf
=
1
;
IBVerbsConf
ibverbs_conf
=
2
;
}
}
message
OtherConf
{
required
int64
piece_size
=
2
;
required
int32
data_part_num
=
3
;
// piece_size % data_part_num = 0
required
FileSystemConf
data_fs_conf
=
1
;
required
FileSystemConf
snapshot_fs_conf
=
2
;
optional
CommNetworkConf
comm_net_conf
=
3
;
required
int64
piece_size
=
4
;
required
int32
data_part_num
=
5
;
// piece_size % data_part_num = 0
optional
bool
use_rdma
=
100
[
default
=
false
];
optional
string
model_load_snapshot_path
=
101
[
default
=
""
];
optional
int32
max_data_id_length
=
102
[
default
=
0
];
optional
bool
enable_cudnn
=
103
[
default
=
true
];
...
...
@@ -72,9 +90,6 @@ message OtherConf {
optional
bool
use_nccl_inter_node_communication
=
143
[
default
=
false
];
optional
int64
cudnn_buf_limit_mbyte
=
144
[
default
=
1024
];
// 1GByte
required
FileSystemConf
data_fs_conf
=
121
;
required
FileSystemConf
snapshot_fs_conf
=
122
;
oneof
JobType
{
TrainConf
train_conf
=
200
;
PredictConf
predict_conf
=
201
;
...
...
oneflow/core/job/job_desc.cpp
浏览文件 @
73d5821a
...
...
@@ -118,7 +118,9 @@ void JobDesc::Init() {
SplitDecodeOps
();
AddRecordLoadOps
();
#ifndef WITH_RDMA
if
(
this
->
TotalMachineNum
()
>
1
)
{
CHECK_EQ
(
job_conf_
.
other
().
use_rdma
(),
false
)
<<
"Please compile ONEFLOW with RDMA"
;
}
#endif
#ifndef WITH_CUDA
CHECK_EQ
(
job_conf_
.
other
().
enable_nccl
(),
false
)
<<
"Please compile ONEFLOW with NCCL"
;
...
...
oneflow/core/job/job_desc.h
浏览文件 @
73d5821a
...
...
@@ -21,10 +21,22 @@ class JobDesc final {
const
Resource
&
resource
()
const
{
return
job_conf_
.
resource
();
}
const
Placement
&
placement
()
const
{
return
job_conf_
.
placement
();
}
const
OtherConf
&
other_conf
()
const
{
return
job_conf_
.
other
();
}
const
CommNetworkConf
&
comm_net_conf
()
const
{
CHECK
(
this
->
other_conf
().
has_comm_net_conf
());
return
job_conf_
.
other
().
comm_net_conf
();
}
bool
use_rdma
()
const
{
return
this
->
comm_net_conf
().
has_ibverbs_conf
();
}
const
EpollConf
&
epoll_conf
()
{
CHECK
(
!
this
->
use_rdma
());
return
this
->
comm_net_conf
().
epoll_conf
();
}
const
IBVerbsConf
&
ibverbs_conf
()
const
{
CHECK
(
this
->
use_rdma
());
return
this
->
comm_net_conf
().
ibverbs_conf
();
}
const
std
::
string
&
MdLoadSnapshotPath
()
{
return
job_conf_
.
other
().
model_load_snapshot_path
();
}
DataType
DefaultDataType
()
const
{
return
job_conf_
.
other
().
default_data_type
();
}
size_t
SizeOfOneDataId
()
const
{
return
job_conf_
.
other
().
max_data_id_length
()
*
sizeof
(
char
);
}
bool
use_rdma
()
const
{
return
job_conf_
.
other
().
use_rdma
();
}
bool
EnableCudnn
()
const
{
return
job_conf_
.
other
().
enable_cudnn
();
}
int64_t
TotalMachineNum
()
const
{
return
job_conf_
.
resource
().
machine
().
size
();
}
int32_t
CpuDeviceNum
()
const
{
return
job_conf_
.
resource
().
cpu_device_num
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录