From 368d69a48b8502941c6273378bcc7d2f78613910 Mon Sep 17 00:00:00 2001 From: guru4elephant Date: Fri, 15 Nov 2019 14:37:48 +0800 Subject: [PATCH] refine fl scheduler --- paddle_fl/core/scheduler/agent_master.py | 113 +++++++++++------- paddle_fl/core/scheduler/test_agent_master.py | 59 +++++++++ 2 files changed, 132 insertions(+), 40 deletions(-) create mode 100644 paddle_fl/core/scheduler/test_agent_master.py diff --git a/paddle_fl/core/scheduler/agent_master.py b/paddle_fl/core/scheduler/agent_master.py index 4807fb5..ad105c5 100644 --- a/paddle_fl/core/scheduler/agent_master.py +++ b/paddle_fl/core/scheduler/agent_master.py @@ -2,32 +2,58 @@ import zmq import time import random - def recv_and_parse_kv(socket): message = socket.recv() - socket.send("alive") group = message.split("\t") - print(group) - assert len(group) == 2 - return group[0], group[1] + if group[0] == "alive": + return group[0], "0" + else: + return group[0], group[1] WORKER_EP = "WORKER_EP" SERVER_EP = "SERVER_EP" -class FLAgent(object): +class FLServerAgent(object): + def __init__(self, scheduler_ep, current_ep): + self.scheduler_ep = scheduler_ep + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REQ) + self.socket.connect("tcp://127.0.0.1:9091") + self.current_ep = current_ep + + def connect_scheduler(self): + self.socket.send("SERVER_EP\t{}".format(self.current_ep)) + self.socket.recv() + + +class FLWorkerAgent(object): def __init__(self, scheduler_ep, current_ep): self.scheduler_ep = scheduler_ep self.context = zmq.Context() self.socket = self.context.socket(zmq.REQ) + self.socket.connect("tcp://127.0.0.1:9091") self.current_ep = current_ep def connect_scheduler(self): self.socket.send("WORKER_EP\t{}".format(self.current_ep)) self.socket.recv() + def finish_training(self): + self.socket.send("FINISH\t{}".format(self.current_ep)) + key, value = recv_and_parse_kv(self.socket) + if key == "WAIT": + time.sleep(3) + def can_join_training(self): self.socket.send("JOIN\t{}".format(self.current_ep)) - self.socket.recv() + key, value = recv_and_parse_kv(self.socket) + + if key == "ACCEPT": + return True + elif key == "REJECT": + return False + return False + class FLScheduler(object): @@ -53,43 +79,50 @@ class FLScheduler(object): key, value = recv_and_parse_kv(self.socket) if key == WORKER_EP: self.fl_workers.append(value) + self.socket.send("INIT\t{}".format(value)) if key == SERVER_EP: self.fl_servers.append(value) + self.socket.send("INIT\t{}".format(value)) if len(self.fl_workers) == self.worker_num and \ len(self.fl_servers) == self.server_num: ready = True - print("FL training environment started") - print("fl workers endpoints") - print(self.fl_workers) - print("fl servers endpoints") - print(self.fl_servers) - - def start_fl_step(self): - # random select some fl_workers here - random.shuffle(self.workers) - worker_dict = {} - for worker in self.workers[:self.sample_worker_num]: - worker_dict[worker] = 0 - ready = False - ready_workers = [] - while not ready: - key, value = recv_and_parse_kv(self.socket) - if key == "JOIN": - if value in worker_dict: - if worker_dict[value] == 0: - ready_workers.append(value) - worker_dict[value] = 1 - if len(ready_workers) == len(worker_dict): - ready = True - start_workers = [] - while len(start_workers) != len(ready_workers): - key, value = recv_and_parse_kv(self.socket) - if key == "REQUEST_START": - if value in ready_workers: - start_workers.append(value) - socket.send("ACCEPT_START") - continue - else: - socket.send("alive") + def start_fl_training(self): + # loop until training is done + while True: + random.shuffle(self.fl_workers) + worker_dict = {} + for worker in self.fl_workers[:self.sample_worker_num]: + worker_dict[worker] = 0 + + ready_workers = [] + all_ready_to_train = False + while not all_ready_to_train: + key, value = recv_and_parse_kv(self.socket) + if key == "JOIN": + if value in worker_dict: + if worker_dict[value] == 0: + ready_workers.append(value) + worker_dict[value] = 1 + self.socket.send("ACCEPT\t0") + continue + else: + ready_workers.append(value) + self.socket.send("REJECT\t0") + + if len(ready_workers) == len(self.fl_workers): + all_ready_to_train = True + + all_finish_training = False + finish_training_dict = {} + while not all_finish_training: + key, value = recv_and_parse_kv(self.socket) + if key == "FINISH": + finish_training_dict[value] = 1 + self.socket.send("WAIT\t0") + else: + self.socket.send("REJECT\t0") + if len(finish_training_dict) == len(worker_dict): + all_finish_training = True + time.sleep(5) diff --git a/paddle_fl/core/scheduler/test_agent_master.py b/paddle_fl/core/scheduler/test_agent_master.py new file mode 100644 index 0000000..6f7cd7d --- /dev/null +++ b/paddle_fl/core/scheduler/test_agent_master.py @@ -0,0 +1,59 @@ +import multiprocessing +import leveldb +import sys +import os +from agent_master import * + +def task_func(task_info): + def init_scheduler(): + worker_num = 10 + server_num = 10 + scheduler = FLScheduler(worker_num, server_num) + scheduler.set_sample_worker_num() + scheduler.init_env() + print("init env done.") + scheduler.start_fl_training() + + def init_worker(): + agent = FLWorkerAgent("127.0.0.1:9091", "127.0.0.1:{}".format(9000 + task_info[0])) + agent.connect_scheduler() + print("connected") + import time + time.sleep(3) + + for i in range(10): + if agent.can_join_training(): + # do some training here + time.sleep(3) + agent.finish_training() + else: + print("rejected") + time.sleep(3) + print("round {} finished".format(i)) + + def init_server(): + agent = FLServerAgent("127.0.0.1:9091", "127.0.0.1:{}".format(9000 + task_info[0])) + agent.connect_scheduler() + + if task_info[1] == 0: + init_scheduler() + elif task_info[1] == 1: + init_worker() + else: + init_server() + +pool = multiprocessing.Pool(processes=21) +port_index = 1 +task_info = [] +task_info.append([port_index, 0]) +port_index += 1 +for i in range(10): + task_info.append([port_index, 1]) + port_index += 1 +for i in range(10): + task_info.append([port_index, 2]) + port_index += 1 + +results = pool.map(task_func, task_info) +pool.close() +pool.join() -- GitLab