Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
93034430
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
93034430
编写于
1月 18, 2019
作者:
S
Shiyuan Shang-Guan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm lck
Former-commit-id: 3655dbad6d2d61ae84e0e484a3a5a2fed2fc66b2
上级
39ffdcb8
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
41 addition
and
35 deletion
+41
-35
oneflow/core/comm_network/comm_network.cpp
oneflow/core/comm_network/comm_network.cpp
+2
-2
oneflow/core/comm_network/comm_network.h
oneflow/core/comm_network/comm_network.h
+8
-7
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
+16
-13
oneflow/core/comm_network/epoll/epoll_comm_network.h
oneflow/core/comm_network/epoll/epoll_comm_network.h
+8
-8
oneflow/core/comm_network/epoll/socket_read_helper.cpp
oneflow/core/comm_network/epoll/socket_read_helper.cpp
+1
-0
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
+2
-2
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
+4
-3
未找到文件。
oneflow/core/comm_network/comm_network.cpp
浏览文件 @
93034430
...
...
@@ -7,9 +7,9 @@ CommNet::~CommNet() {
ready_cb_poller_
.
join
();
}
void
*
CommNet
::
NewActorReadId
()
{
return
new
ActorReadContext
;
}
void
*
CommNet
::
NewActorReadId
()
const
{
return
new
ActorReadContext
;
}
void
CommNet
::
DeleteActorReadId
(
void
*
actor_read_id
)
{
void
CommNet
::
DeleteActorReadId
(
void
*
actor_read_id
)
const
{
auto
actor_read_ctx
=
static_cast
<
ActorReadContext
*>
(
actor_read_id
);
CHECK
(
actor_read_ctx
->
waiting_list
.
empty
());
delete
actor_read_ctx
;
...
...
oneflow/core/comm_network/comm_network.h
浏览文件 @
93034430
...
...
@@ -29,20 +29,21 @@ class CommNet {
virtual
void
RegisterMemoryDone
()
=
0
;
// Stream
void
*
NewActorReadId
();
void
DeleteActorReadId
(
void
*
actor_read_id
);
void
*
NewActorReadId
()
const
;
void
DeleteActorReadId
(
void
*
actor_read_id
)
const
;
void
Read
(
void
*
actor_read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
);
void
AddReadCallBack
(
void
*
actor_read_id
,
std
::
function
<
void
()
>
callback
);
void
ReadDone
(
void
*
read_id
);
//
virtual
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
=
0
;
virtual
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
=
0
;
protected:
CommNet
(
const
Plan
&
plan
);
virtual
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
=
0
;
const
HashSet
<
int64_t
>&
peer_machine_id
()
{
return
peer_machine_id_
;
}
virtual
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
const
=
0
;
const
HashSet
<
int64_t
>&
peer_machine_id
()
const
{
return
peer_machine_id_
;
}
Channel
<
std
::
function
<
void
()
>>
ready_cbs_
;
...
...
@@ -84,8 +85,8 @@ class CommNetIf : public CommNet {
}
protected:
virtual
MemDescType
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
=
0
;
const
HashSet
<
MemDescType
*>&
mem_descs
()
{
return
mem_descs_
;
}
virtual
MemDescType
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
=
0
;
const
HashSet
<
MemDescType
*>&
mem_descs
()
const
{
return
mem_descs_
;
}
private:
std
::
mutex
mem_descs_mtx_
;
...
...
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
浏览文件 @
93034430
...
...
@@ -64,10 +64,15 @@ EpollCommNet::~EpollCommNet() {
}
void
EpollCommNet
::
RegisterMemoryDone
()
{
// do nothing
for
(
void
*
dst_token
:
mem_descs
())
{
CHECK
(
dst_token2part_done_cnt_
.
emplace
(
dst_token
,
std
::
shared_ptr
<
std
::
atomic
<
int32_t
>>
(
new
std
::
atomic
<
int32_t
>
(
0
)))
.
second
);
}
}
void
EpollCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
actor_msg
)
{
void
EpollCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
actor_msg
)
const
{
SocketMsg
msg
;
msg
.
msg_type
=
SocketMsgType
::
kActor
;
msg
.
actor_msg
=
actor_msg
;
...
...
@@ -75,7 +80,7 @@ void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_ms
}
void
EpollCommNet
::
RequestRead
(
int64_t
dst_machine_id
,
void
*
src_token
,
void
*
dst_token
,
void
*
read_id
)
{
void
*
read_id
)
const
{
int32_t
total_byte_size
=
static_cast
<
const
SocketMemDesc
*>
(
src_token
)
->
byte_size
;
int32_t
offset
=
(
total_byte_size
+
epoll_conf_
.
link_num
()
-
1
)
/
epoll_conf_
.
link_num
();
offset
=
RoundUp
(
offset
,
epoll_conf_
.
msg_segment_kbyte
()
*
1024
);
...
...
@@ -96,7 +101,7 @@ void EpollCommNet::RequestRead(int64_t dst_machine_id, void* src_token, void* ds
CHECK_EQ
(
total_byte_size
,
0
);
}
SocketMemDesc
*
EpollCommNet
::
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
{
SocketMemDesc
*
EpollCommNet
::
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
{
SocketMemDesc
*
mem_desc
=
new
SocketMemDesc
;
mem_desc
->
mem_ptr
=
ptr
;
mem_desc
->
byte_size
=
byte_size
;
...
...
@@ -185,29 +190,27 @@ void EpollCommNet::InitSockets() {
}
}
SocketHelper
*
EpollCommNet
::
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
)
{
SocketHelper
*
EpollCommNet
::
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
)
const
{
int32_t
sockfd
=
machine_id2sockfds_
.
at
(
machine_id
*
epoll_conf_
.
link_num
()
+
link_index
);
return
sockfd2helper_
.
at
(
sockfd
);
}
void
EpollCommNet
::
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
{
CHECK
(
read_id2part_done_cnt_
.
emplace
(
read_id
,
0
).
second
);
void
EpollCommNet
::
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
const
{
SocketMsg
msg
;
msg
.
msg_type
=
SocketMsgType
::
kRequestWrite
;
msg
.
request_write_msg
.
src_token
=
src_token
;
msg
.
request_write_msg
.
dst_machine_id
=
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
();
msg
.
request_write_msg
.
dst_token
=
dst_token
;
msg
.
request_write_msg
.
read_id
=
read_id
;
*
(
dst_token2part_done_cnt_
.
at
(
dst_token
))
=
0
;
GetSocketHelper
(
src_machine_id
,
0
)
->
AsyncWrite
(
msg
);
}
void
EpollCommNet
::
PartReadDone
(
void
*
read_id
,
int32_t
part_num
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
part_done_cnt_mtx_
);
int32_t
&
part_read_done_cnt
=
read_id2part_done_cnt_
.
at
(
read_id
);
part_read_done_cnt
++
;
if
(
part_read_done_cnt
==
part_num
)
{
void
EpollCommNet
::
PartReadDone
(
void
*
read_id
,
void
*
dst_token
,
int32_t
part_num
)
{
if
(
dst_token2part_done_cnt_
.
at
(
dst_token
)
->
fetch_add
(
1
,
std
::
memory_order_relaxed
)
==
(
part_num
-
1
))
{
ReadDone
(
read_id
);
read_id2part_done_cnt_
.
erase
(
read_id
);
}
}
...
...
oneflow/core/comm_network/epoll/epoll_comm_network.h
浏览文件 @
93034430
...
...
@@ -18,24 +18,24 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
void
RegisterMemoryDone
()
override
;
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
override
;
void
RequestRead
(
int64_t
dst_machine_id
,
void
*
src_token
,
void
*
dst_token
,
void
*
read_id
);
void
PartReadDone
(
void
*
read_id
,
int32_t
part_num
);
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
override
;
void
RequestRead
(
int64_t
dst_machine_id
,
void
*
src_token
,
void
*
dst_token
,
void
*
read_id
)
const
;
void
PartReadDone
(
void
*
read_id
,
void
*
dst_token
,
int32_t
part_num
);
private:
SocketMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
override
;
SocketMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
override
;
EpollCommNet
(
const
Plan
&
plan
);
void
InitSockets
();
SocketHelper
*
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
);
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
override
;
SocketHelper
*
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
)
const
;
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
const
override
;
const
EpollConf
&
epoll_conf_
;
std
::
vector
<
IOEventPoller
*>
pollers_
;
std
::
vector
<
int32_t
>
machine_id2sockfds_
;
HashMap
<
int
,
SocketHelper
*>
sockfd2helper_
;
std
::
mutex
part_done_cnt_mtx_
;
HashMap
<
void
*
,
int32_t
>
read_id2part_done_cnt_
;
HashMap
<
void
*
,
std
::
shared_ptr
<
std
::
atomic
<
int32_t
>>>
dst_token2part_done_cnt_
;
};
template
<
>
...
...
oneflow/core/comm_network/epoll/socket_read_helper.cpp
浏览文件 @
93034430
...
...
@@ -64,6 +64,7 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void
SocketReadHelper
::
SetStatusWhenMsgBodyDone
()
{
if
(
cur_msg_
.
msg_type
==
SocketMsgType
::
kRequestRead
)
{
Global
<
EpollCommNet
>::
Get
()
->
PartReadDone
(
cur_msg_
.
request_read_msg
.
read_id
,
cur_msg_
.
request_read_msg
.
dst_token
,
cur_msg_
.
request_read_msg
.
part_num
);
}
SwitchToMsgHeadReadHandle
();
...
...
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
浏览文件 @
93034430
...
...
@@ -49,7 +49,7 @@ void IBVerbsCommNet::RegisterMemoryDone() {
Global
<
CtrlClient
>::
Get
()
->
ClearKV
(
GenTokensMsgKey
(
this_machine_id
));
}
void
IBVerbsCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
{
void
IBVerbsCommNet
::
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
{
qp_vec_
.
at
(
dst_machine_id
)
->
PostSendRequest
(
msg
);
}
...
...
@@ -101,7 +101,7 @@ IBVerbsCommNet::IBVerbsCommNet(const Plan& plan)
}
void
IBVerbsCommNet
::
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
{
void
*
dst_token
)
const
{
qp_vec_
.
at
(
src_machine_id
)
->
PostReadRequest
(
token2mem_desc_
.
at
(
src_machine_id
).
at
(
src_token
),
*
static_cast
<
const
IBVerbsMemDesc
*>
(
dst_token
),
read_id
);
...
...
oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
浏览文件 @
93034430
...
...
@@ -23,15 +23,16 @@ class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
void
RegisterMemoryDone
()
override
;
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
override
;
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
const
override
;
private:
IBVerbsMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
override
{
IBVerbsMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
const
override
{
return
new
IBVerbsMemDesc
(
pd_
,
ptr
,
byte_size
);
}
IBVerbsCommNet
(
const
Plan
&
);
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
override
;
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
const
override
;
void
PollCQ
();
static
const
int32_t
max_poll_wc_num_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录