Commit a9bcf71f authored by Q qjing666

add scheduler and update demos

Parent 68175d00
sphinx==2.1.0
mistune
sphinx_rtd_theme
paddlepaddle
paddlepaddle>=1.6
......@@ -17,3 +17,6 @@ from .master.fl_job import FLRunTimeJob
from .master.job_generator import JobGenerator
from .strategy.fl_strategy_base import DPSGDStrategy
from .strategy.fl_strategy_base import FedAvgStrategy
from .scheduler.agent_master import FLServerAgent
from .scheduler.agent_master import FLWorkerAgent
from .scheduler.agent_master import FLScheduler
......@@ -176,6 +176,7 @@ class FLRunTimeJob(FLJobBase):
self._server_main_program = None
self._feed_names = None
self._target_names = None
self._scheduler_ep = None
def _load_strategy(self, input_file):
import pickle
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
import time
import random
def recv_and_parse_kv(socket):
message = socket.recv()
group = message.split("\t")
if group[0] == "alive":
return group[0], "0"
else:
return group[0], group[1]
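Note: recv_and_parse_kv relies on the Python 2 behaviour where zmq.Socket.recv() returns a str; under Python 3, pyzmq returns bytes, so the "\t" split (and the plain-string send() calls below) would raise a TypeError. For reference, a bytes-safe variant is sketched here using pyzmq's recv_string helper; the name recv_and_parse_kv_py3 is illustrative and the sketch is not part of this commit.
# Sketch only (not part of this commit): a Python 3-safe counterpart that
# decodes the incoming frame before splitting on the tab separator.
def recv_and_parse_kv_py3(socket):
    message = socket.recv_string()   # recv() plus utf-8 decode in one call
    group = message.split("\t")
    if group[0] == "alive":
        return group[0], "0"
    return group[0], group[1]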
WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP"
class FLServerAgent(object):
def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect("tcp://127.0.0.1:9091")
self.current_ep = current_ep
def connect_scheduler(self):
self.socket.send("SERVER_EP\t{}".format(self.current_ep))
self.socket.recv()
class FLWorkerAgent(object):
def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect("tcp://127.0.0.1:9091")
self.current_ep = current_ep
def connect_scheduler(self):
self.socket.send("WORKER_EP\t{}".format(self.current_ep))
self.socket.recv()
def finish_training(self):
self.socket.send("FINISH\t{}".format(self.current_ep))
key, value = recv_and_parse_kv(self.socket)
if key == "WAIT":
time.sleep(3)
return True
return False
def can_join_training(self):
self.socket.send("JOIN\t{}".format(self.current_ep))
key, value = recv_and_parse_kv(self.socket)
if key == "ACCEPT":
return True
elif key == "REJECT":
return False
return False
class FLScheduler(object):
def __init__(self, worker_num, server_num, port=9091):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REP)
self.socket.bind("tcp://*:{}".format(port))
self.worker_num = worker_num
self.server_num = server_num
self.sample_worker_num = 0
self.fl_workers = []
self.fl_servers = []
def set_sample_worker_num(self, sample_worker_num=0):
if sample_worker_num == 0:
self.sample_worker_num = int(self.worker_num * 0.1)
else:
self.sample_worker_num = sample_worker_num
def init_env(self):
ready = False
while not ready:
key, value = recv_and_parse_kv(self.socket)
if key == WORKER_EP:
self.fl_workers.append(value)
self.socket.send("INIT\t{}".format(value))
elif key == SERVER_EP:
self.fl_servers.append(value)
self.socket.send("INIT\t{}".format(value))
else:
time.sleep(3)
self.socket.send("REJECT\t0")
if len(self.fl_workers) == self.worker_num and \
len(self.fl_servers) == self.server_num:
ready = True
def start_fl_training(self):
# loop until training is done
loop = 0
while True:
if loop <= 1:
print(loop)
random.shuffle(self.fl_workers)
worker_dict = {}
for worker in self.fl_workers[:self.sample_worker_num]:
worker_dict[worker] = 0
ready_workers = []
all_ready_to_train = False
while not all_ready_to_train:
key, value = recv_and_parse_kv(self.socket)
if key == "JOIN":
if value in worker_dict:
if worker_dict[value] == 0:
ready_workers.append(value)
worker_dict[value] = 1
self.socket.send("ACCEPT\t0")
continue
else:
if value not in ready_workers:
ready_workers.append(value)
self.socket.send("REJECT\t0")
if len(ready_workers) == len(self.fl_workers):
all_ready_to_train = True
all_finish_training = False
finish_training_dict = {}
while not all_finish_training:
key, value = recv_and_parse_kv(self.socket)
if key == "FINISH":
finish_training_dict[value] = 1
self.socket.send("WAIT\t0")
else:
self.socket.send("REJECT\t0")
if len(finish_training_dict) == len(worker_dict):
all_finish_training = True
time.sleep(5)
loop += 1
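Together, FLScheduler, FLServerAgent and FLWorkerAgent implement a small ZeroMQ REQ/REP protocol: agents send tab-separated key/value requests, and the scheduler answers with INIT, ACCEPT, REJECT or WAIT. Note that both agent classes connect to the hard-coded tcp://127.0.0.1:9091 rather than to the scheduler_ep they are given. The single-process sketch below (not part of the commit; the endpoints and the run_scheduler helper are illustrative, and Python 2 string semantics are assumed, matching the code above) shows the registration handshake that init_env() waits on. A real deployment runs each piece in its own process, as in the run.sh scripts further below.
# Sketch only: run the scheduler in a background thread and register one
# server plus two workers against it, mirroring init_env() above.
import threading
import time
from paddle_fl.core.scheduler.agent_master import (
    FLScheduler, FLServerAgent, FLWorkerAgent)

def run_scheduler():
    scheduler = FLScheduler(worker_num=2, server_num=1, port=9091)
    scheduler.set_sample_worker_num(2)  # sample every worker in each round
    scheduler.init_env()                # blocks until 2 workers + 1 server report in
    print("scheduler: all endpoints registered")
    # a real deployment would now call scheduler.start_fl_training()

t = threading.Thread(target=run_scheduler)
t.daemon = True
t.start()
time.sleep(1)                           # give the REP socket time to bind

FLServerAgent("127.0.0.1:9091", "127.0.0.1:8181").connect_scheduler()
for i in range(2):
    FLWorkerAgent("127.0.0.1:9091", "127.0.0.1:{}".format(9000 + i)).connect_scheduler()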
......@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle_fl.core.scheduler.agent_master import FLServerAgent
class FLServer(object):
def __init__(self):
self._startup_program = None
self._main_program = None
self._scheduler_ep = None
self._current_ep = None
def set_server_job(self, job):
# need to parse startup and main program in job
......@@ -25,9 +28,12 @@ class FLServer(object):
# need to parse master endpoint
self._startup_program = job._server_startup_program
self._main_program = job._server_main_program
self._scheduler_ep = job._scheduler_ep
self._current_ep = None
def start(self):
self.agent = FLServerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(self._startup_program)
exe.run(self._main_program)
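Since start() now contacts the scheduler before executing the startup and main programs, both _scheduler_ep and _current_ep must be filled in first. A minimal server-side wiring sketch follows; it mirrors the femnist demo later in this commit, and the FLServer import path and endpoint values are assumptions for illustration rather than lines taken from the diff.
# Sketch only; mirrors the server demo below.  Import path and endpoints
# are assumed for illustration.
from paddle_fl.core.server.fl_server import FLServer
from paddle_fl.core.master.fl_job import FLRunTimeJob

server = FLServer()
job = FLRunTimeJob()
job.load_server_job("fl_job_config", 0)   # job config written by fl_master.py
job._scheduler_ep = "127.0.0.1:9091"      # scheduler endpoint
server.set_server_job(job)                # picks up _scheduler_ep from the job
server._current_ep = "127.0.0.1:8181"     # this server's own endpoint
server.start()                            # connect_scheduler(), then run the programs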
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
import logging
from paddle_fl.core.scheduler.agent_master import FLWorkerAgent
class FLTrainerFactory(object):
def __init__(self):
......@@ -44,17 +45,24 @@ class FLTrainer(object):
self._step = job._strategy._inner_step
self._feed_names = job._feed_names
self._target_names = job._target_names
self._scheduler_ep = job._scheduler_ep
self._current_ep = None
self.cur_step = 0
def start(self):
#current_ep = "to be added"
self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe.run(self._startup_program)
def run(self, feed, fetch):
self._logger.debug("begin to run")
self.exe.run(self._main_program,
feed=feed,
fetch_list=fetch)
self._logger.debug("end to run current batch")
self.cur_step += 1
def save_inference_program(self, output_folder):
target_vars = []
......@@ -73,7 +81,15 @@ class FLTrainer(object):
# ask for termination with master endpoint
# currently not open sourced, will release the code later
# TODO(guru4elephant): add connection with master
return False
if self.cur_step != 0:
while not self.agent.finish_training():
print('wait others finish')
continue
while not self.agent.can_join_training():
print("wait permit")
continue
print("ready to train")
return False
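This barrier pairs with FLScheduler.start_fl_training() above: once a worker has trained at least one step (cur_step != 0) it reports FINISH until the scheduler answers WAIT, then polls JOIN until it is ACCEPTed into the next sampled round. A per-round driver built on stop() is sketched below; job and reader are placeholders standing in for a loaded FLRunTimeJob and a data generator, as in the femnist demo later in this commit.
# Sketch only; `job` and `reader` are placeholders (see the femnist demo below).
job._scheduler_ep = "127.0.0.1:9091"        # must match the scheduler's port
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:9000"      # this worker's own endpoint
trainer.start()                             # registers WORKER_EP with the scheduler
while not trainer.stop():                   # FINISH / JOIN barrier runs inside stop()
    train_step = 0
    for data in reader():
        trainer.run(feed=data, fetch=[])
        train_step += 1
        if train_step == trainer._step:     # one inner_step window per round
            break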
class FedAvgTrainer(FLTrainer):
def __init__(self):
......@@ -81,9 +97,11 @@ class FedAvgTrainer(FLTrainer):
pass
def start(self):
#current_ep = "to be added"
self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe.run(self._startup_program)
self.cur_step = 0
def set_trainer_job(self, job):
super(FedAvgTrainer, self).set_trainer_job(job)
......@@ -108,7 +126,4 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._send_program)
self.cur_step += 1
return loss
def stop(self):
return False
......@@ -38,7 +38,7 @@ job_generator.set_infer_feed_and_target_names(
[x.name for x in inputs], [model.predict.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
build_strategy.inner_step = 10
strategy = build_strategy.create_fl_strategy()
......@@ -47,5 +47,5 @@ strategy = build_strategy.create_fl_strategy()
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=2, output=output)
strategy, server_endpoints=endpoints, worker_num=5, output=output)
# fl_job_config will be dispatched to workers
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 5
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
scheduler.set_sample_worker_num(5)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
......@@ -21,5 +21,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server.start()
print("connect")
......@@ -3,6 +3,7 @@ from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import sys
import logging
import time
logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%m-%Y %H:%M:%S", level=logging.DEBUG)
......@@ -18,15 +19,22 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
step_i = 0
while not trainer.stop():
step_i += 1
print("batch %d start train" % (step_i))
train_step = 0
for data in reader():
trainer.run(feed=data, fetch=[])
train_step += 1
if train_step == trainer._step:
break
step_i += 1
if step_i % 100 == 0:
trainer.save_inference_program(output_folder)
......@@ -2,8 +2,13 @@ unset http_proxy
unset https_proxy
python fl_master.py
sleep 2
python -u fl_scheduler.py > scheduler.log &
sleep 5
python -u fl_server.py >server0.log &
sleep 2
python -u fl_trainer.py 0 >trainer0.log &
sleep 2
python -u fl_trainer.py 1 >trainer1.log &
for ((i=0;i<5;i++))
do
python -u fl_trainer.py $i >trainer$i.log &
sleep 2
done
......@@ -49,5 +49,5 @@ strategy.sigma = CLIP * SIGMA
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=2, output=output)
strategy, server_endpoints=endpoints, worker_num=4, output=output)
# fl_job_config will be dispatched to workers
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
......@@ -21,5 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server.start()
......@@ -13,7 +13,9 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
test_program = trainer._main_program.clone(for_test=True)
......
python fl_master.py
sleep 2
python -u fl_scheduler.py >scheduler.log &
sleep 2
python -u fl_server.py >server0.log &
sleep 2
python -u fl_trainer.py 0 >trainer0.log &
sleep 2
python -u fl_trainer.py 1 >trainer1.log &
sleep 2
python -u fl_trainer.py 2 >trainer2.log &
sleep 2
python -u fl_trainer.py 3 >trainer3.log &
......@@ -73,7 +73,7 @@ job_generator.set_infer_feed_and_target_names(
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
build_strategy.inner_step = 1
build_strategy.inner_step = 10
strategy = build_strategy.create_fl_strategy()
# endpoints will be collected through the cluster
......
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
......@@ -21,5 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server.start()
......@@ -14,7 +14,9 @@ train_file_dir = "mid_data/node4/%d/" % trainer_id
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
r = Gru4rec_Reader()
......@@ -25,10 +27,14 @@ step_i = 0
while not trainer.stop():
step_i += 1
print("batch %d start train" % (step_i))
train_step = 0
for data in train_reader():
#print(np.array(data['src_wordseq']))
ret_avg_cost = trainer.run(feed=data,
fetch=["mean_0.tmp_0"])
train_step += 1
if train_step == trainer._step:
break
avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl)
print("ppl:%.3f" % (newest_ppl))
......
......@@ -2,6 +2,7 @@ unset http_proxy
unset https_proxy
python fl_master.py
sleep 2
python -u fl_scheduler.py >scheduler.log &
python -u fl_server.py >server0.log &
sleep 2
python -u fl_trainer.py 0 >trainer0.log &
......
......@@ -37,6 +37,8 @@ if max_version < 3:
else:
REQUIRED_PACKAGES += ["numpy"]
REQUIRED_PACKAGES += ["unittest2"]
setup(
name='paddle_fl',
version=fl_version.replace('-', ''),
......