# agent_master.py
import zmq
import time
import random

def recv_and_parse_kv(socket):
    """Receive one message from ``socket`` and split it into (key, value).

    The message is decoded as text and split on tab characters. The first
    field is the key and the second field is the value; any further fields
    are ignored.

    Args:
        socket: object exposing ``recv()`` that returns the raw message bytes.

    Returns:
        A ``(key, value)`` tuple of strings. The value is ``"0"`` when the key
        is ``"alive"`` or when the message carries no value field.
    """
    fields = socket.recv().decode().split("\t")
    key = fields[0]
    # A message without a value field used to raise IndexError here; use the
    # same "0" placeholder that "alive" messages get so callers can simply
    # treat the key as unknown and answer accordingly.
    if key == "alive" or len(fields) < 2:
        return key, "0"
    return key, fields[1]

# Message keys that FLScheduler.init_env compares incoming messages against
# to tell a worker registration from a server registration.
WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP"

class FLServerAgent(object):
    """Handle a server process uses to register itself with the scheduler.

    Owns a ZeroMQ REQ socket connected to the scheduler and announces
    ``current_ep`` over it.
    """

    def __init__(self, scheduler_ep, current_ep):
        """Create the REQ socket and connect it to the scheduler.

        Args:
            scheduler_ep: scheduler address; it is appended to ``tcp://``.
            current_ep: endpoint string reported to the scheduler.
        """
        self.scheduler_ep = scheduler_ep
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect("tcp://{}".format(scheduler_ep))
        self.current_ep = current_ep

    def connect_scheduler(self):
        """Send the registration message until the scheduler replies ``INIT``.

        Each attempt is one send followed by one blocking receive; any reply
        whose first tab-separated field is not ``INIT`` triggers a resend.
        """
        while True:
            self.socket.send_string("SERVER_EP\t{}".format(self.current_ep))
            reply = self.socket.recv().decode().split("\t")
            if reply[0] == 'INIT':
                break


class FLWorkerAgent(object):
    """Handle a worker process uses to talk to the scheduler.

    Owns a ZeroMQ REQ socket connected to the scheduler; every method is one
    request followed by one blocking reply.
    """

    def __init__(self, scheduler_ep, current_ep):
        """Create the REQ socket and connect it to the scheduler.

        Args:
            scheduler_ep: scheduler address; it is appended to ``tcp://``.
            current_ep: endpoint string reported to the scheduler.
        """
        self.scheduler_ep = scheduler_ep
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect("tcp://{}".format(scheduler_ep))
        self.current_ep = current_ep

    def connect_scheduler(self):
        """Send the registration message until the scheduler replies ``INIT``.

        Any reply whose first tab-separated field is not ``INIT`` triggers a
        resend.
        """
        while True:
            self.socket.send_string("WORKER_EP\t{}".format(self.current_ep))
            reply = self.socket.recv().decode().split("\t")
            if reply[0] == 'INIT':
                break

    def finish_training(self):
        """Report ``FINISH`` for this worker and read the scheduler's reply.

        Returns:
            True if the reply key is ``WAIT`` (after sleeping 3 seconds),
            False for any other reply.
        """
        self.socket.send_string("FINISH\t{}".format(self.current_ep))
        key, _ = recv_and_parse_kv(self.socket)
        if key == "WAIT":
            time.sleep(3)
            return True
        return False

    def can_join_training(self):
        """Send ``JOIN`` for this worker and report whether it was accepted.

        Returns:
            True only when the reply key is ``ACCEPT``; ``REJECT`` and any
            other reply yield False.
        """
        self.socket.send_string("JOIN\t{}".format(self.current_ep))
        key, _ = recv_and_parse_kv(self.socket)
        return key == "ACCEPT"


class FLScheduler(object):
    """Coordinator that registers agents and then runs sampling rounds.

    Talks to agents over a single reply-style socket: every received message
    is answered with exactly one reply before the next one is read.
    """

    def __init__(self, worker_num, server_num, port=9091, socket=None):
        """Set up the scheduler socket and bookkeeping.

        Args:
            worker_num: number of worker registrations ``init_env`` waits for.
            server_num: number of server registrations ``init_env`` waits for.
            port: TCP port to bind on when no socket is supplied.
            socket: optional ready-made socket to use instead of binding a
                new ZeroMQ REP socket.
        """
        self.context = zmq.Context()
        if socket is None:
            self.socket = self.context.socket(zmq.REP)
            self.socket.bind("tcp://*:{}".format(port))
        else:
            self.socket = socket
        self.worker_num = worker_num
        self.server_num = server_num
        self.sample_worker_num = 0
        self.fl_workers = []
        self.fl_servers = []

    def set_sample_worker_num(self, sample_worker_num=0):
        """Set how many workers are sampled per round.

        Args:
            sample_worker_num: explicit sample size; ``0`` (the default)
                selects 10% of ``worker_num``.
        """
        if sample_worker_num == 0:
            # int() truncates, so fewer than 10 workers used to give a sample
            # size of 0 and no worker was ever accepted. Sample at least one
            # worker, but never more than are expected to exist.
            tenth = int(self.worker_num * 0.1)
            self.sample_worker_num = min(self.worker_num, max(1, tenth))
        else:
            self.sample_worker_num = sample_worker_num

    def init_env(self):
        """Block until the expected numbers of workers and servers registered.

        Registrations are answered with ``INIT``; any other message is
        answered with ``REJECT`` after a 3 second pause.
        """
        while True:
            key, value = recv_and_parse_kv(self.socket)
            if key == WORKER_EP:
                self.fl_workers.append(value)
                self.socket.send_string("INIT\t{}".format(value))
            elif key == SERVER_EP:
                self.fl_servers.append(value)
                self.socket.send_string("INIT\t{}".format(value))
            else:
                time.sleep(3)
                self.socket.send_string("REJECT\t0")
            if len(self.fl_workers) == self.worker_num and \
               len(self.fl_servers) == self.server_num:
                return

    def start_fl_training(self):
        """Run sampling rounds forever.

        Each round shuffles the registered workers, samples the first
        ``sample_worker_num`` of them, answers ``JOIN`` requests until every
        registered worker has checked in, then collects ``FINISH`` reports
        and sleeps 5 seconds before the next round.
        """
        # loop until training is done
        while True:
            random.shuffle(self.fl_workers)
            # endpoint -> 0 (sampled, not yet accepted) / 1 (accepted)
            worker_dict = {}
            for worker in self.fl_workers[:self.sample_worker_num]:
                worker_dict[worker] = 0

            # Join phase: every registered worker must check in once.
            checked_in = set()
            while True:
                key, value = recv_and_parse_kv(self.socket)
                if key == "JOIN" and worker_dict.get(value) == 0:
                    worker_dict[value] = 1
                    checked_in.add(value)
                    self.socket.send_string("ACCEPT\t0")
                else:
                    # Non-sampled workers still count as checked in; sampled
                    # workers that were already accepted are just rejected.
                    if key == "JOIN" and value not in worker_dict:
                        checked_in.add(value)
                    self.socket.send_string("REJECT\t0")
                # Evaluate completion after every reply, including ACCEPT.
                # Previously an ACCEPT skipped this check, so the phase only
                # ended after one more message arrived and was rejected.
                if len(checked_in) == len(self.fl_workers):
                    break

            # Finish phase: wait for as many FINISH senders as were sampled.
            # NOTE(review): a FINISH from a non-sampled endpoint is counted
            # too - confirm whether rejected workers can ever send FINISH.
            finished = set()
            while True:
                key, value = recv_and_parse_kv(self.socket)
                if key == "FINISH":
                    finished.add(value)
                    self.socket.send_string("WAIT\t0")
                else:
                    self.socket.send_string("REJECT\t0")
                if len(finished) == len(worker_dict):
                    break
            time.sleep(5)