Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
7741acd6
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
7741acd6
编写于
1月 17, 2019
作者:
S
Shiyuan Shang-Guan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multi socket
Former-commit-id: 76c25553437543749ed58495709824bcee8e0b55
上级
52a6c519
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
106 additions
and
34 deletions
+106
-34
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
+64
-23
oneflow/core/comm_network/epoll/epoll_comm_network.h
oneflow/core/comm_network/epoll/epoll_comm_network.h
+7
-2
oneflow/core/comm_network/epoll/socket_message.h
oneflow/core/comm_network/epoll/socket_message.h
+1
-0
oneflow/core/comm_network/epoll/socket_read_helper.cpp
oneflow/core/comm_network/epoll/socket_read_helper.cpp
+2
-1
oneflow/core/common/util.h
oneflow/core/common/util.h
+1
-0
oneflow/core/job/job_conf.proto
oneflow/core/job/job_conf.proto
+20
-6
oneflow/core/job/job_desc.cpp
oneflow/core/job/job_desc.cpp
+1
-1
oneflow/core/job/job_desc.h
oneflow/core/job/job_desc.h
+10
-1
未找到文件。
oneflow/core/comm_network/epoll/epoll_comm_network.cpp
浏览文件 @
7741acd6
...
...
@@ -71,11 +71,32 @@ void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_ms
SocketMsg
msg
;
msg
.
msg_type
=
SocketMsgType
::
kActor
;
msg
.
actor_msg
=
actor_msg
;
GetSocketHelper
(
dst_machine_id
)
->
AsyncWrite
(
msg
);
int32_t
link_i
=
std
::
uniform_int_distribution
<
int32_t
>
(
0
,
epoll_conf_
.
link_num
())(
random_gen_
);
GetSocketHelper
(
dst_machine_id
,
link_i
)
->
AsyncWrite
(
msg
);
}
void
EpollCommNet
::
SendSocketMsg
(
int64_t
dst_machine_id
,
const
SocketMsg
&
msg
)
{
GetSocketHelper
(
dst_machine_id
)
->
AsyncWrite
(
msg
);
void
EpollCommNet
::
SendSocketMsg
(
int64_t
dst_machine_id
,
const
SocketMsg
&
total_msg
)
{
const
SocketMemDesc
*
src_mem_desc
=
static_cast
<
const
SocketMemDesc
*>
(
total_msg
.
request_read_msg
.
src_token
);
const
SocketMemDesc
*
dst_mem_desc
=
static_cast
<
const
SocketMemDesc
*>
(
total_msg
.
request_read_msg
.
dst_token
);
CHECK_EQ
(
src_mem_desc
->
byte_size
,
dst_mem_desc
->
byte_size
);
int32_t
total_byte_size
=
src_mem_desc
->
byte_size
;
int32_t
offset
=
(
total_byte_size
+
epoll_conf_
.
link_num
()
-
1
)
/
epoll_conf_
.
link_num
();
offset
=
RoundUp
(
offset
,
kCacheLineSize
);
int32_t
part_num
=
(
total_byte_size
+
offset
-
1
)
/
offset
;
for
(
int32_t
link_i
=
0
;
link_i
<
part_num
;
++
link_i
)
{
int32_t
byte_size
=
(
total_byte_size
>
offset
)
?
(
offset
)
:
(
total_byte_size
);
total_byte_size
-=
offset
;
SocketMsg
msg
;
msg
.
msg_type
=
total_msg
.
msg_type
;
msg
.
request_read_msg
.
src_token
=
NewMemDesc
(
src_mem_desc
->
mem_ptr
+
link_i
*
offset
,
byte_size
);
msg
.
request_read_msg
.
dst_token
=
NewMemDesc
(
dst_mem_desc
->
mem_ptr
+
link_i
*
offset
,
byte_size
);
msg
.
request_read_msg
.
read_id
=
total_msg
.
request_read_msg
.
read_id
;
msg
.
request_read_msg
.
part_num
=
part_num
;
GetSocketHelper
(
dst_machine_id
,
link_i
)
->
AsyncWrite
(
msg
);
}
CHECK_EQ
(
total_byte_size
,
0
);
}
SocketMemDesc
*
EpollCommNet
::
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
{
...
...
@@ -85,7 +106,8 @@ SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
return
mem_desc
;
}
EpollCommNet
::
EpollCommNet
(
const
Plan
&
plan
)
:
CommNetIf
(
plan
)
{
EpollCommNet
::
EpollCommNet
(
const
Plan
&
plan
)
:
CommNetIf
(
plan
),
epoll_conf_
(
Global
<
JobDesc
>::
Get
()
->
epoll_conf
())
{
pollers_
.
resize
(
Global
<
JobDesc
>::
Get
()
->
CommNetWorkerNum
(),
nullptr
);
for
(
size_t
i
=
0
;
i
<
pollers_
.
size
();
++
i
)
{
pollers_
[
i
]
=
new
IOEventPoller
;
}
InitSockets
();
...
...
@@ -96,7 +118,7 @@ void EpollCommNet::InitSockets() {
int64_t
this_machine_id
=
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
();
auto
this_machine
=
Global
<
JobDesc
>::
Get
()
->
resource
().
machine
(
this_machine_id
);
int64_t
total_machine_num
=
Global
<
JobDesc
>::
Get
()
->
TotalMachineNum
();
machine_id2sockfd
_
.
assign
(
total_machine_num
,
-
1
);
machine_id2sockfd
s_
.
assign
(
total_machine_num
*
epoll_conf_
.
link_num
()
,
-
1
);
sockfd2helper_
.
clear
();
size_t
poller_idx
=
0
;
auto
NewSocketHelper
=
[
&
](
int
sockfd
)
{
...
...
@@ -125,53 +147,72 @@ void EpollCommNet::InitSockets() {
int32_t
src_machine_count
=
0
;
// connect
for
(
int64_t
peer_id
:
peer_machine_id
())
{
if
(
peer_id
<
this_machine_id
)
{
for
(
int64_t
peer_
mchn_
id
:
peer_machine_id
())
{
if
(
peer_
mchn_
id
<
this_machine_id
)
{
++
src_machine_count
;
continue
;
}
uint16_t
peer_port
=
PullPort
(
peer_id
);
auto
peer_machine
=
Global
<
JobDesc
>::
Get
()
->
resource
().
machine
(
peer_id
);
uint16_t
peer_port
=
PullPort
(
peer_
mchn_
id
);
auto
peer_machine
=
Global
<
JobDesc
>::
Get
()
->
resource
().
machine
(
peer_
mchn_
id
);
sockaddr_in
peer_sockaddr
=
GetSockAddr
(
peer_machine
.
addr
(),
peer_port
);
int
sockfd
=
socket
(
AF_INET
,
SOCK_STREAM
,
0
);
PCHECK
(
connect
(
sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
sizeof
(
peer_sockaddr
))
==
0
);
CHECK
(
sockfd2helper_
.
emplace
(
sockfd
,
NewSocketHelper
(
sockfd
)).
second
);
machine_id2sockfd_
[
peer_id
]
=
sockfd
;
for
(
int32_t
link_i
=
0
;
link_i
<
epoll_conf_
.
link_num
();
++
link_i
)
{
int
sockfd
=
socket
(
AF_INET
,
SOCK_STREAM
,
0
);
PCHECK
(
connect
(
sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
sizeof
(
peer_sockaddr
))
==
0
);
CHECK
(
sockfd2helper_
.
emplace
(
sockfd
,
NewSocketHelper
(
sockfd
)).
second
);
machine_id2sockfds_
[
peer_mchn_id
*
epoll_conf_
.
link_num
()
+
link_i
]
=
sockfd
;
}
}
// accept
FOR_RANGE
(
int32_t
,
idx
,
0
,
src_machine_count
)
{
sockaddr_in
peer_sockaddr
;
socklen_t
len
=
sizeof
(
peer_sockaddr
);
int
sockfd
=
accept
(
listen_sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
&
len
);
PCHECK
(
sockfd
!=
-
1
);
CHECK
(
sockfd2helper_
.
emplace
(
sockfd
,
NewSocketHelper
(
sockfd
)).
second
);
int64_t
peer_machine_id
=
GetMachineId
(
peer_sockaddr
);
machine_id2sockfd_
[
peer_machine_id
]
=
sockfd
;
for
(
int32_t
link_i
=
0
;
link_i
<
epoll_conf_
.
link_num
();
++
link_i
)
{
int
sockfd
=
accept
(
listen_sockfd
,
reinterpret_cast
<
sockaddr
*>
(
&
peer_sockaddr
),
&
len
);
PCHECK
(
sockfd
!=
-
1
);
CHECK
(
sockfd2helper_
.
emplace
(
sockfd
,
NewSocketHelper
(
sockfd
)).
second
);
int64_t
peer_mchn_id
=
GetMachineId
(
peer_sockaddr
);
machine_id2sockfds_
[
peer_mchn_id
*
epoll_conf_
.
link_num
()
+
link_i
]
=
sockfd
;
}
}
PCHECK
(
close
(
listen_sockfd
)
==
0
);
ClearPort
(
this_machine_id
);
// useful log
FOR_RANGE
(
int64_t
,
machine_id
,
0
,
total_machine_num
)
{
LOG
(
INFO
)
<<
"machine "
<<
machine_id
<<
" sockfd "
<<
machine_id2sockfd_
[
machine_id
];
FOR_RANGE
(
int32_t
,
link_i
,
0
,
epoll_conf_
.
link_num
())
{
LOG
(
INFO
)
<<
"machine: "
<<
machine_id
<<
", link index: "
<<
link_i
<<
", sockfd: "
<<
machine_id2sockfds_
[
machine_id
*
epoll_conf_
.
link_num
()
+
link_i
];
}
}
}
// Returns the socket helper for link `link_index` of the connection to `machine_id`.
//
// Socket descriptors are stored in one flat table indexed by
// machine_id * link_num + link_index, so an out-of-range link_index would silently
// select a socket that belongs to a different machine. The checks below turn that
// into an immediate failure.
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) {
  CHECK_GE(link_index, 0);
  CHECK_LT(link_index, epoll_conf_.link_num());
  int sockfd = machine_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_index);
  return sockfd2helper_.at(sockfd);
}
void
EpollCommNet
::
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
{
CHECK
(
read_id2part_done_cnt_
.
emplace
(
read_id
,
0
).
second
);
SocketMsg
msg
;
msg
.
msg_type
=
SocketMsgType
::
kRequestWrite
;
msg
.
request_write_msg
.
src_token
=
src_token
;
msg
.
request_write_msg
.
dst_machine_id
=
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
();
msg
.
request_write_msg
.
dst_token
=
dst_token
;
msg
.
request_write_msg
.
read_id
=
read_id
;
GetSocketHelper
(
src_machine_id
)
->
AsyncWrite
(
msg
);
int32_t
link_i
=
std
::
uniform_int_distribution
<
int32_t
>
(
0
,
epoll_conf_
.
link_num
())(
random_gen_
);
GetSocketHelper
(
src_machine_id
,
link_i
)
->
AsyncWrite
(
msg
);
}
// Records that one part of the read identified by `read_id` has finished. Once
// `part_num` parts have been counted, reports the whole read as done via ReadDone
// and drops the counter entry.
//
// read_id:  identifier previously registered for this read.
// part_num: total number of parts the read was split into.
void EpollCommNet::PartReadDone(void* read_id, int32_t part_num) {
  // Acquire the lock before touching the map: looking the entry up outside the lock
  // would race with the erase below running on another thread.
  std::unique_lock<std::mutex> lck(part_done_cnt_mtx_);
  int32_t& part_read_done_cnt = read_id2part_done_cnt_.at(read_id);
  ++part_read_done_cnt;
  if (part_read_done_cnt == part_num) {
    ReadDone(read_id);
    // The reference above is invalid after this erase and must not be used again.
    read_id2part_done_cnt_.erase(read_id);
  }
}
}
// namespace oneflow
...
...
oneflow/core/comm_network/epoll/epoll_comm_network.h
浏览文件 @
7741acd6
...
...
@@ -20,18 +20,23 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
void
SendActorMsg
(
int64_t
dst_machine_id
,
const
ActorMsg
&
msg
)
override
;
void
SendSocketMsg
(
int64_t
dst_machine_id
,
const
SocketMsg
&
msg
);
void
PartReadDone
(
void
*
read_id
,
int32_t
part_num
);
private:
SocketMemDesc
*
NewMemDesc
(
void
*
ptr
,
size_t
byte_size
)
override
;
EpollCommNet
(
const
Plan
&
plan
);
void
InitSockets
();
SocketHelper
*
GetSocketHelper
(
int64_t
machine_id
);
SocketHelper
*
GetSocketHelper
(
int64_t
machine_id
,
int32_t
link_index
);
void
DoRead
(
void
*
read_id
,
int64_t
src_machine_id
,
void
*
src_token
,
void
*
dst_token
)
override
;
const
EpollConf
&
epoll_conf_
;
std
::
vector
<
IOEventPoller
*>
pollers_
;
std
::
vector
<
int
>
machine_id2sockfd_
;
std
::
vector
<
int
>
machine_id2sockfd
s
_
;
HashMap
<
int
,
SocketHelper
*>
sockfd2helper_
;
std
::
mt19937
random_gen_
;
std
::
mutex
part_done_cnt_mtx_
;
HashMap
<
void
*
,
int32_t
>
read_id2part_done_cnt_
;
};
template
<
>
...
...
oneflow/core/comm_network/epoll/socket_message.h
浏览文件 @
7741acd6
...
...
@@ -41,6 +41,7 @@ struct RequestReadMsg {
void
*
src_token
;
void
*
dst_token
;
void
*
read_id
;
int32_t
part_num
;
};
struct
SocketMsg
{
...
...
oneflow/core/comm_network/epoll/socket_read_helper.cpp
浏览文件 @
7741acd6
...
...
@@ -63,7 +63,8 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void
SocketReadHelper
::
SetStatusWhenMsgBodyDone
()
{
if
(
cur_msg_
.
msg_type
==
SocketMsgType
::
kRequestRead
)
{
Global
<
EpollCommNet
>::
Get
()
->
ReadDone
(
cur_msg_
.
request_read_msg
.
read_id
);
Global
<
EpollCommNet
>::
Get
()
->
PartReadDone
(
cur_msg_
.
request_read_msg
.
read_id
,
cur_msg_
.
request_read_msg
.
part_num
);
}
SwitchToMsgHeadReadHandle
();
}
...
...
oneflow/core/common/util.h
浏览文件 @
7741acd6
...
...
@@ -177,6 +177,7 @@ inline double GetCurTime() {
const
size_t
kCudaAlignSize
=
8
;
const
size_t
kCudaMemAllocAlignSize
=
256
;
const
size_t
kCacheLineSize
=
64
;
// Returns the smallest multiple of `val` that is greater than or equal to `n`.
// `val` must be non-zero.
inline size_t RoundUp(size_t n, size_t val) {
  const size_t chunk_cnt = (n + val - 1) / val;
  return chunk_cnt * val;
}
size_t
GetAvailableCpuMemSize
();
...
...
oneflow/core/job/job_conf.proto
浏览文件 @
7741acd6
...
...
@@ -50,12 +50,28 @@ message ExperimentalRunConf {
optional
bool
enable_experiment_run
=
2
[
default
=
false
];
}
// Options for the epoll-based socket communication network.
message
EpollConf
{
// Number of socket links opened to each peer machine. EpollCommNet sizes its
// socket table as (machine count * link_num). Defaults to 5.
optional
int32
link_num
=
1
[
default
=
5
];
}
message
IBVerbsConf
{
}
// Selects the communication network backend. At most one member of the oneof
// can be set at a time.
message
CommNetworkConf
{
oneof
comm_net_type
{
// Settings for the epoll/socket backend.
EpollConf
epoll_conf
=
1
;
// Settings for the IBVerbs backend. JobDesc::use_rdma() reports whether this
// member is set.
IBVerbsConf
ibverbs_conf
=
2
;
}
}
message
OtherConf
{
required
int64
piece_size
=
2
;
required
int32
data_part_num
=
3
;
// piece_size % data_part_num = 0
required
int64
total_batch_num
=
4
;
required
FileSystemConf
data_fs_conf
=
1
;
required
FileSystemConf
snapshot_fs_conf
=
2
;
required
CommNetworkConf
comm_net_conf
=
3
;
required
int64
piece_size
=
4
;
required
int32
data_part_num
=
5
;
// piece_size % data_part_num = 0
required
int64
total_batch_num
=
6
;
optional
bool
use_rdma
=
100
[
default
=
false
];
optional
string
model_load_snapshot_path
=
101
[
default
=
""
];
optional
int32
max_data_id_length
=
102
[
default
=
0
];
optional
bool
enable_cudnn
=
103
[
default
=
true
];
...
...
@@ -67,8 +83,6 @@ message OtherConf {
optional
bool
save_downloaded_file_to_local_fs
=
109
[
default
=
false
];
optional
uint64
rdma_mem_block_mbyte
=
110
[
default
=
8
];
optional
uint64
rdma_recv_msg_buf_mbyte
=
111
[
default
=
6
];
required
FileSystemConf
data_fs_conf
=
112
;
required
FileSystemConf
snapshot_fs_conf
=
113
;
optional
bool
collect_act_event
=
125
[
default
=
false
];
optional
bool
enable_mem_sharing
=
126
[
default
=
true
];
...
...
oneflow/core/job/job_desc.cpp
浏览文件 @
7741acd6
...
...
@@ -255,7 +255,7 @@ void JobDesc::Init() {
SplitDecodeOps
();
AddRecordLoadOps
();
#ifndef WITH_RDMA
CHECK_EQ
(
job_conf_
.
other
().
use_rdma
(),
false
)
<<
"Please compile ONEFLOW with RDMA"
;
CHECK_EQ
(
this
->
use_rdma
(),
false
)
<<
"Please compile ONEFLOW with RDMA"
;
#endif
#ifndef WITH_CUDA
CHECK_EQ
(
job_conf_
.
other
().
enable_nccl
(),
false
)
<<
"Please compile ONEFLOW with NCCL"
;
...
...
oneflow/core/job/job_desc.h
浏览文件 @
7741acd6
...
...
@@ -22,10 +22,19 @@ class JobDesc final {
const
Resource
&
resource
()
const
{
return
job_conf_
.
resource
();
}
const
Placement
&
placement
()
const
{
return
job_conf_
.
placement
();
}
const
OtherConf
&
other_conf
()
const
{
return
job_conf_
.
other
();
}
const
CommNetworkConf
&
comm_net_conf
()
const
{
return
job_conf_
.
other
().
comm_net_conf
();
}
bool
use_rdma
()
const
{
return
job_conf_
.
other
().
comm_net_conf
().
has_ibverbs_conf
();
}
const
EpollConf
&
epoll_conf
()
const
{
CHECK
(
!
this
->
use_rdma
());
return
this
->
comm_net_conf
().
epoll_conf
();
}
const
IBVerbsConf
&
ibverbs_conf
()
const
{
CHECK
(
this
->
use_rdma
());
return
this
->
comm_net_conf
().
ibverbs_conf
();
}
const
std
::
string
&
MdLoadSnapshotPath
()
{
return
job_conf_
.
other
().
model_load_snapshot_path
();
}
DataType
DefaultDataType
()
const
{
return
job_conf_
.
other
().
default_data_type
();
}
size_t
SizeOfOneDataId
()
const
{
return
job_conf_
.
other
().
max_data_id_length
()
*
sizeof
(
char
);
}
bool
use_rdma
()
const
{
return
job_conf_
.
other
().
use_rdma
();
}
bool
EnableCudnn
()
const
{
return
job_conf_
.
other
().
enable_cudnn
();
}
int64_t
TotalMachineNum
()
const
{
return
job_conf_
.
resource
().
machine
().
size
();
}
int32_t
CpuDeviceNum
()
const
{
return
job_conf_
.
resource
().
cpu_device_num
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录