Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleFL
提交
368d69a4
P
PaddleFL
项目概览
PaddlePaddle
/
PaddleFL
通知
35
Star
5
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
6
列表
看板
标记
里程碑
合并请求
4
Wiki
3
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleFL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
6
Issue
6
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
3
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
368d69a4
编写于
11月 15, 2019
作者:
G
guru4elephant
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine fl scheduler
上级
ed9ec58d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
132 addition
and
40 deletion
+132
-40
paddle_fl/core/scheduler/agent_master.py
paddle_fl/core/scheduler/agent_master.py
+73
-40
paddle_fl/core/scheduler/test_agent_master.py
paddle_fl/core/scheduler/test_agent_master.py
+59
-0
未找到文件。
paddle_fl/core/scheduler/agent_master.py
浏览文件 @
368d69a4
...
...
@@ -2,32 +2,58 @@ import zmq
import
time
import
random
def
recv_and_parse_kv
(
socket
):
message
=
socket
.
recv
()
socket
.
send
(
"alive"
)
group
=
message
.
split
(
"
\t
"
)
print
(
group
)
assert
len
(
group
)
==
2
return
group
[
0
],
group
[
1
]
if
group
[
0
]
==
"alive"
:
return
group
[
0
],
"0"
else
:
return
group
[
0
],
group
[
1
]
WORKER_EP
=
"WORKER_EP"
SERVER_EP
=
"SERVER_EP"
class
FLAgent
(
object
):
class
FLServerAgent
(
object
):
def
__init__
(
self
,
scheduler_ep
,
current_ep
):
self
.
scheduler_ep
=
scheduler_ep
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
REQ
)
self
.
socket
.
connect
(
"tcp://127.0.0.1:9091"
)
self
.
current_ep
=
current_ep
def
connect_scheduler
(
self
):
self
.
socket
.
send
(
"SERVER_EP
\t
{}"
.
format
(
self
.
current_ep
))
self
.
socket
.
recv
()
class
FLWorkerAgent
(
object
):
def
__init__
(
self
,
scheduler_ep
,
current_ep
):
self
.
scheduler_ep
=
scheduler_ep
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
REQ
)
self
.
socket
.
connect
(
"tcp://127.0.0.1:9091"
)
self
.
current_ep
=
current_ep
def
connect_scheduler
(
self
):
self
.
socket
.
send
(
"WORKER_EP
\t
{}"
.
format
(
self
.
current_ep
))
self
.
socket
.
recv
()
def
finish_training
(
self
):
self
.
socket
.
send
(
"FINISH
\t
{}"
.
format
(
self
.
current_ep
))
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
"WAIT"
:
time
.
sleep
(
3
)
def
can_join_training
(
self
):
self
.
socket
.
send
(
"JOIN
\t
{}"
.
format
(
self
.
current_ep
))
self
.
socket
.
recv
()
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
"ACCEPT"
:
return
True
elif
key
==
"REJECT"
:
return
False
return
False
class
FLScheduler
(
object
):
...
...
@@ -53,43 +79,50 @@ class FLScheduler(object):
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
WORKER_EP
:
self
.
fl_workers
.
append
(
value
)
self
.
socket
.
send
(
"INIT
\t
{}"
.
format
(
value
))
if
key
==
SERVER_EP
:
self
.
fl_servers
.
append
(
value
)
self
.
socket
.
send
(
"INIT
\t
{}"
.
format
(
value
))
if
len
(
self
.
fl_workers
)
==
self
.
worker_num
and
\
len
(
self
.
fl_servers
)
==
self
.
server_num
:
ready
=
True
print
(
"FL training environment started"
)
print
(
"fl workers endpoints"
)
print
(
self
.
fl_workers
)
print
(
"fl servers endpoints"
)
print
(
self
.
fl_servers
)
def
start_fl_step
(
self
):
# random select some fl_workers here
random
.
shuffle
(
self
.
workers
)
worker_dict
=
{}
for
worker
in
self
.
workers
[:
self
.
sample_worker_num
]:
worker_dict
[
worker
]
=
0
ready
=
False
ready_workers
=
[]
while
not
ready
:
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
"JOIN"
:
if
value
in
worker_dict
:
if
worker_dict
[
value
]
==
0
:
ready_workers
.
append
(
value
)
worker_dict
[
value
]
=
1
if
len
(
ready_workers
)
==
len
(
worker_dict
):
ready
=
True
start_workers
=
[]
while
len
(
start_workers
)
!=
len
(
ready_workers
):
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
"REQUEST_START"
:
if
value
in
ready_workers
:
start_workers
.
append
(
value
)
socket
.
send
(
"ACCEPT_START"
)
continue
else
:
socket
.
send
(
"alive"
)
def
start_fl_training
(
self
):
# loop until training is done
while
True
:
random
.
shuffle
(
self
.
fl_workers
)
worker_dict
=
{}
for
worker
in
self
.
fl_workers
[:
self
.
sample_worker_num
]:
worker_dict
[
worker
]
=
0
ready_workers
=
[]
all_ready_to_train
=
False
while
not
all_ready_to_train
:
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
"JOIN"
:
if
value
in
worker_dict
:
if
worker_dict
[
value
]
==
0
:
ready_workers
.
append
(
value
)
worker_dict
[
value
]
=
1
self
.
socket
.
send
(
"ACCEPT
\t
0"
)
continue
else
:
ready_workers
.
append
(
value
)
self
.
socket
.
send
(
"REJECT
\t
0"
)
if
len
(
ready_workers
)
==
len
(
self
.
fl_workers
):
all_ready_to_train
=
True
all_finish_training
=
False
finish_training_dict
=
{}
while
not
all_finish_training
:
key
,
value
=
recv_and_parse_kv
(
self
.
socket
)
if
key
==
"FINISH"
:
finish_training_dict
[
value
]
=
1
self
.
socket
.
send
(
"WAIT
\t
0"
)
else
:
self
.
socket
.
send
(
"REJECT
\t
0"
)
if
len
(
finish_training_dict
)
==
len
(
worker_dict
):
all_finish_training
=
True
time
.
sleep
(
5
)
paddle_fl/core/scheduler/test_agent_master.py
0 → 100644
浏览文件 @
368d69a4
import
multiprocessing
import
leveldb
import
sys
import
os
from
agent_master
import
*
def
task_func
(
task_info
):
def
init_scheduler
():
worker_num
=
10
server_num
=
10
scheduler
=
FLScheduler
(
worker_num
,
server_num
)
scheduler
.
set_sample_worker_num
()
scheduler
.
init_env
()
print
(
"init env done."
)
scheduler
.
start_fl_training
()
def
init_worker
():
agent
=
FLWorkerAgent
(
"127.0.0.1:9091"
,
"127.0.0.1:{}"
.
format
(
9000
+
task_info
[
0
]))
agent
.
connect_scheduler
()
print
(
"connected"
)
import
time
time
.
sleep
(
3
)
for
i
in
range
(
10
):
if
agent
.
can_join_training
():
# do some training here
time
.
sleep
(
3
)
agent
.
finish_training
()
else
:
print
(
"rejected"
)
time
.
sleep
(
3
)
print
(
"round {} finished"
.
format
(
i
))
def
init_server
():
agent
=
FLServerAgent
(
"127.0.0.1:9091"
,
"127.0.0.1:{}"
.
format
(
9000
+
task_info
[
0
]))
agent
.
connect_scheduler
()
if
task_info
[
1
]
==
0
:
init_scheduler
()
elif
task_info
[
1
]
==
1
:
init_worker
()
else
:
init_server
()
pool
=
multiprocessing
.
Pool
(
processes
=
21
)
port_index
=
1
task_info
=
[]
task_info
.
append
([
port_index
,
0
])
port_index
+=
1
for
i
in
range
(
10
):
task_info
.
append
([
port_index
,
1
])
port_index
+=
1
for
i
in
range
(
10
):
task_info
.
append
([
port_index
,
2
])
port_index
+=
1
results
=
pool
.
map
(
task_func
,
task_info
)
pool
.
close
()
pool
.
join
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录