提交 368d69a4 编写于 作者: G guru4elephant

refine fl scheduler

上级 ed9ec58d
...@@ -2,32 +2,58 @@ import zmq ...@@ -2,32 +2,58 @@ import zmq
import time import time
import random import random
def recv_and_parse_kv(socket):
    """Receive one message from ``socket`` and split it into ``(key, value)``.

    Messages are tab-separated key/value strings. A bare ``alive`` message
    carries no value and is reported with the placeholder value ``"0"``.

    Args:
        socket: any object with a ``recv()`` method returning the raw
            message as ``bytes`` or ``str``.

    Returns:
        A ``(key, value)`` tuple of strings. ``value`` is ``""`` when the
        message has no tab-separated second field.
    """
    message = socket.recv()
    # recv() yields bytes on Python 3; splitting bytes with a str separator
    # raises TypeError, so normalise to text first.
    if not isinstance(message, str):
        message = message.decode("utf-8")

    fields = message.split("\t")
    key = fields[0]
    if key == "alive":
        return key, "0"
    # A key-only message yields an empty value instead of an IndexError.
    value = fields[1] if len(fields) > 1 else ""
    return key, value
WORKER_EP = "WORKER_EP" WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP" SERVER_EP = "SERVER_EP"
class FLServerAgent(object):
    """Agent a federated-learning server uses to register with the scheduler.

    Args:
        scheduler_ep: scheduler endpoint as ``"host:port"``.
        current_ep: this server's own endpoint, reported to the scheduler.
    """

    def __init__(self, scheduler_ep, current_ep):
        self.scheduler_ep = scheduler_ep
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        # Connect to the endpoint the caller supplied rather than a
        # hard-coded address, so the argument is actually honoured.
        self.socket.connect("tcp://{}".format(scheduler_ep))
        self.current_ep = current_ep

    def connect_scheduler(self):
        """Send this server's endpoint to the scheduler and wait for a reply."""
        # The payload is encoded because the socket expects bytes on Python 3.
        self.socket.send(
            "SERVER_EP\t{}".format(self.current_ep).encode("utf-8"))
        self.socket.recv()
class FLWorkerAgent(object):
    """Agent a federated-learning worker uses to talk to the scheduler.

    Args:
        scheduler_ep: scheduler endpoint as ``"host:port"``.
        current_ep: this worker's own endpoint, sent with every request.
    """

    def __init__(self, scheduler_ep, current_ep):
        self.scheduler_ep = scheduler_ep
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        # Connect to the endpoint the caller supplied rather than a
        # hard-coded address, so the argument is actually honoured.
        self.socket.connect("tcp://{}".format(scheduler_ep))
        self.current_ep = current_ep

    def _request(self, key):
        """Send ``key`` plus this worker's endpoint as one tab-separated message."""
        # The payload is encoded because the socket expects bytes on Python 3.
        self.socket.send(
            "{}\t{}".format(key, self.current_ep).encode("utf-8"))

    def connect_scheduler(self):
        """Send this worker's endpoint to the scheduler and wait for a reply."""
        self._request("WORKER_EP")
        self.socket.recv()

    def finish_training(self):
        """Report the end of local training; sleep 3 seconds on a ``WAIT`` reply."""
        self._request("FINISH")
        key, _ = recv_and_parse_kv(self.socket)
        if key == "WAIT":
            time.sleep(3)

    def can_join_training(self):
        """Ask to join the current round.

        Returns:
            ``True`` only when the scheduler answers ``ACCEPT``; ``False``
            for ``REJECT`` or any other reply.
        """
        self._request("JOIN")
        key, _ = recv_and_parse_kv(self.socket)
        return key == "ACCEPT"
class FLScheduler(object): class FLScheduler(object):
...@@ -53,43 +79,50 @@ class FLScheduler(object): ...@@ -53,43 +79,50 @@ class FLScheduler(object):
key, value = recv_and_parse_kv(self.socket) key, value = recv_and_parse_kv(self.socket)
if key == WORKER_EP: if key == WORKER_EP:
self.fl_workers.append(value) self.fl_workers.append(value)
self.socket.send("INIT\t{}".format(value))
if key == SERVER_EP: if key == SERVER_EP:
self.fl_servers.append(value) self.fl_servers.append(value)
self.socket.send("INIT\t{}".format(value))
if len(self.fl_workers) == self.worker_num and \ if len(self.fl_workers) == self.worker_num and \
len(self.fl_servers) == self.server_num: len(self.fl_servers) == self.server_num:
ready = True ready = True
print("FL training environment started") def start_fl_training(self):
print("fl workers endpoints") # loop until training is done
print(self.fl_workers) while True:
print("fl servers endpoints") random.shuffle(self.fl_workers)
print(self.fl_servers) worker_dict = {}
for worker in self.fl_workers[:self.sample_worker_num]:
def start_fl_step(self): worker_dict[worker] = 0
# random select some fl_workers here
random.shuffle(self.workers) ready_workers = []
worker_dict = {} all_ready_to_train = False
for worker in self.workers[:self.sample_worker_num]: while not all_ready_to_train:
worker_dict[worker] = 0 key, value = recv_and_parse_kv(self.socket)
ready = False if key == "JOIN":
ready_workers = [] if value in worker_dict:
while not ready: if worker_dict[value] == 0:
key, value = recv_and_parse_kv(self.socket) ready_workers.append(value)
if key == "JOIN": worker_dict[value] = 1
if value in worker_dict: self.socket.send("ACCEPT\t0")
if worker_dict[value] == 0: continue
ready_workers.append(value) else:
worker_dict[value] = 1 ready_workers.append(value)
if len(ready_workers) == len(worker_dict): self.socket.send("REJECT\t0")
ready = True
start_workers = [] if len(ready_workers) == len(self.fl_workers):
while len(start_workers) != len(ready_workers): all_ready_to_train = True
key, value = recv_and_parse_kv(self.socket)
if key == "REQUEST_START": all_finish_training = False
if value in ready_workers: finish_training_dict = {}
start_workers.append(value) while not all_finish_training:
socket.send("ACCEPT_START") key, value = recv_and_parse_kv(self.socket)
continue if key == "FINISH":
else: finish_training_dict[value] = 1
socket.send("alive") self.socket.send("WAIT\t0")
else:
self.socket.send("REJECT\t0")
if len(finish_training_dict) == len(worker_dict):
all_finish_training = True
time.sleep(5)
import multiprocessing
import leveldb
import sys
import os
from agent_master import *
def task_func(task_info):
    """Run one pool task as the scheduler, a worker agent or a server agent.

    ``task_info`` is indexed as ``[port_index, role]``: role ``0`` runs the
    scheduler, role ``1`` runs a worker, and any other role runs a server.
    Agents use ``9000 + port_index`` as the port of their own endpoint.
    """
    scheduler_endpoint = "127.0.0.1:9091"

    def run_scheduler():
        # Bring up the scheduler for 10 workers and 10 servers, then loop.
        fl_scheduler = FLScheduler(10, 10)
        fl_scheduler.set_sample_worker_num()
        fl_scheduler.init_env()
        print("init env done.")
        fl_scheduler.start_fl_training()

    def run_worker():
        worker = FLWorkerAgent(
            scheduler_endpoint, "127.0.0.1:{}".format(9000 + task_info[0]))
        worker.connect_scheduler()
        print("connected")
        import time
        time.sleep(3)
        for round_id in range(10):
            if not worker.can_join_training():
                print("rejected")
                time.sleep(3)
            else:
                # Placeholder for the actual local training step.
                time.sleep(3)
                worker.finish_training()
            print("round {} finished".format(round_id))

    def run_server():
        server = FLServerAgent(
            scheduler_endpoint, "127.0.0.1:{}".format(9000 + task_info[0]))
        server.connect_scheduler()

    role = task_info[1]
    if role == 0:
        runner = run_scheduler
    elif role == 1:
        runner = run_worker
    else:
        runner = run_server
    runner()
def build_task_info(worker_num=10, server_num=10):
    """Build the ``[port_index, role]`` task list for one local test run.

    Port indices start at 1 and grow by one per task. The list holds one
    scheduler task (role 0), then ``worker_num`` worker tasks (role 1),
    then ``server_num`` server tasks (role 2).
    """
    roles = [0] + [1] * worker_num + [2] * server_num
    return [[port_index, role] for port_index, role in enumerate(roles, 1)]


def main():
    """Run the scheduler, workers and servers in a process pool."""
    # NOTE: the defaults must stay in step with the 10/10 counts the
    # scheduler task in task_func is started with.
    task_info = build_task_info()
    # One process per task: every role blocks until the run is over.
    pool = multiprocessing.Pool(processes=len(task_info))
    try:
        pool.map(task_func, task_info)
    finally:
        pool.close()
        pool.join()


# Guard the entry point so that child processes re-importing this module
# (spawn start method) do not create pools of their own.
if __name__ == "__main__":
    main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册