提交 86df01ce 编写于 作者: G guru4elephant

add fl scheduler

1) an FL scheduler is started before training
2) FL workers can connect to the FL scheduler
3) scheduling algorithms can be implemented in the FL scheduler
上级 3d7e5ac0
......@@ -17,3 +17,5 @@ from .master.fl_job import FLRunTimeJob
from .master.job_generator import JobGenerator
from .strategy.fl_strategy_base import DPSGDStrategy
from .strategy.fl_strategy_base import FedAvgStrategy
from .scheduler.agent_master import FLAgent
from .scheduler.agent_master import FLScheduler
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
import time
import random
def recv_and_parse_kv(socket):
    """Receive one "key\\tvalue" message from a zmq REP socket and ack it.

    Replies "alive" immediately (a REP socket must send exactly one reply
    per request) so the REQ peer is unblocked before the caller processes
    the payload.

    Args:
        socket: a zmq REP socket (anything with recv()/send()).

    Returns:
        tuple(str, str): the (key, value) pair from the message.

    Raises:
        ValueError: if the message does not contain exactly one tab.
    """
    message = socket.recv()
    socket.send("alive")
    # On Python 3 zmq delivers bytes; normalize so splitting on "\t" works.
    if isinstance(message, bytes):
        message = message.decode("utf-8")
    group = message.split("\t")
    # Validate explicitly instead of assert, which is stripped under -O.
    if len(group) != 2:
        raise ValueError("malformed scheduler message: %r" % message)
    return group[0], group[1]
# Message keys used by workers/servers to register their endpoints with the
# scheduler; messages are sent as "KEY\tendpoint" over zmq (see
# recv_and_parse_kv and FLScheduler.init_env).
WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP"
class FLAgent(object):
    """Client-side handle a federated worker uses to talk to the FLScheduler.

    Communicates over a single zmq REQ socket; every send() is paired with
    a recv() of the scheduler's reply, as zmq's REQ/REP pattern requires.

    Args:
        scheduler_ep (str): "ip:port" endpoint of the FLScheduler.
        current_ep (str): "ip:port" endpoint identifying this worker.
    """

    def __init__(self, scheduler_ep, current_ep):
        self.scheduler_ep = scheduler_ep
        self.current_ep = current_ep
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)

    def connect_scheduler(self):
        """Connect to the scheduler and register this worker's endpoint."""
        # Bug fix: the REQ socket was created but never connected to the
        # scheduler endpoint, so the send() below would block forever.
        self.socket.connect("tcp://{}".format(self.scheduler_ep))
        self.socket.send("WORKER_EP\t{}".format(self.current_ep))
        self.socket.recv()

    def can_join_training(self):
        """Ask the scheduler whether this worker may join the current round.

        Returns:
            bool: True once the scheduler acknowledges the JOIN request.
                  (Previously returned None, which made callers' truth
                  tests on the result always falsy.)
        """
        self.socket.send("JOIN\t{}".format(self.current_ep))
        # NOTE(review): the scheduler currently replies "alive" to every
        # message, so any reply is treated as an acknowledgement. TODO:
        # have the scheduler send explicit ACCEPT/REJECT so the sampled
        # worker subset is actually enforced here.
        self.socket.recv()
        return True
class FLScheduler(object):
    """Central coordinator for federated training.

    Binds a zmq REP socket, waits for all workers and servers to register,
    then each round samples a subset of workers and grants them permission
    to start training.

    Args:
        worker_num (int): number of FL workers expected to register.
        server_num (int): number of FL servers expected to register.
        port (int): TCP port the scheduler listens on. Default 9091.
    """

    def __init__(self, worker_num, server_num, port=9091):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind("tcp://*:{}".format(port))
        self.worker_num = worker_num
        self.server_num = server_num
        self.sample_worker_num = 0
        self.fl_workers = []
        self.fl_servers = []

    def set_sample_worker_num(self, sample_worker_num=0):
        """Set how many workers are sampled per round.

        Args:
            sample_worker_num (int): workers to sample each round; 0 (the
                default) means 10% of worker_num (truncated toward zero).
        """
        if sample_worker_num == 0:
            self.sample_worker_num = int(self.worker_num * 0.1)
        else:
            self.sample_worker_num = sample_worker_num

    def init_env(self):
        """Block until every expected worker and server has registered.

        Collects WORKER_EP / SERVER_EP registration messages into
        fl_workers / fl_servers until both reach their expected counts.
        """
        ready = False
        while not ready:
            key, value = recv_and_parse_kv(self.socket)
            if key == WORKER_EP:
                self.fl_workers.append(value)
            if key == SERVER_EP:
                self.fl_servers.append(value)
            if len(self.fl_workers) == self.worker_num and \
               len(self.fl_servers) == self.server_num:
                ready = True
        print("FL training environment started")
        print("fl workers endpoints")
        print(self.fl_workers)
        print("fl servers endpoints")
        print(self.fl_servers)

    def start_fl_step(self):
        """Run one scheduling round.

        Samples sample_worker_num workers at random, waits for each sampled
        worker's JOIN, then grants ACCEPT_START to each one's REQUEST_START.
        """
        # Bug fix: was `self.workers`, which does not exist (AttributeError);
        # the registered list is `self.fl_workers`.
        random.shuffle(self.fl_workers)
        worker_dict = {}
        for worker in self.fl_workers[:self.sample_worker_num]:
            worker_dict[worker] = 0
        ready = False
        ready_workers = []
        while not ready:
            key, value = recv_and_parse_kv(self.socket)
            if key == "JOIN":
                # Only count each sampled worker's first JOIN.
                if worker_dict.get(value) == 0:
                    ready_workers.append(value)
                    worker_dict[value] = 1
            if len(ready_workers) == len(worker_dict):
                ready = True
        start_workers = []
        while len(start_workers) != len(ready_workers):
            # Cannot use recv_and_parse_kv here: it always replies "alive",
            # and a REP socket allows exactly one send per recv, so this
            # loop must choose the reply itself. (Bug fix: the original
            # double-replied, and used a bare `socket.send` — a NameError.)
            message = self.socket.recv()
            if isinstance(message, bytes):
                message = message.decode("utf-8")
            key, _, value = message.partition("\t")
            if key == "REQUEST_START" and value in ready_workers:
                start_workers.append(value)
                self.socket.send("ACCEPT_START")
            else:
                self.socket.send("alive")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -44,10 +44,14 @@ class FLTrainer(object):
self._step = job._strategy._inner_step
self._feed_names = job._feed_names
self._target_names = job._target_names
self._scheduler_ep = job._scheduler_ep
def start(self):
current_ep = "to be added"
self.agent = FLAgent(self._scheduler_ep, current_ep)
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe.run(self._startup_program)
self.agent.connect_scheduler()
def run(self, feed, fetch):
self._logger.debug("begin to run")
......@@ -73,6 +77,8 @@ class FLTrainer(object):
# ask for termination with master endpoint
# currently not open sourced, will release the code later
# TODO(guru4elephant): add connection with master
while not self.agent.can_join_training():
return False
class FedAvgTrainer(FLTrainer):
......@@ -108,7 +114,4 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._send_program)
self.cur_step += 1
return loss
def stop(self):
    # Never request early termination: the driver loops
    # `while not trainer.stop()`, so returning False keeps training running
    # until the caller breaks out itself.
    return False
......@@ -27,6 +27,7 @@ while not trainer.stop():
print("batch %d start train" % (step_i))
for data in train_reader():
#print(np.array(data['src_wordseq']))
ret_avg_cost = trainer.run(feed=data,
fetch=["mean_0.tmp_0"])
avg_ppl = np.exp(ret_avg_cost[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册