提交 6b1fb7cc 编写于 作者: Q qjing666

fix code style

上级 59998725
...@@ -3,4 +3,3 @@ mistune ...@@ -3,4 +3,3 @@ mistune
sphinx_rtd_theme sphinx_rtd_theme
paddlepaddle>=1.6 paddlepaddle>=1.6
zmq zmq
...@@ -181,4 +181,3 @@ while not trainer.stop(): ...@@ -181,4 +181,3 @@ while not trainer.stop():
To show the effectiveness of DPSGD-based federated learning with PaddleFL, a simulated experiment is conducted on an open source dataset MNIST. From the figure given below, model evaluation results are similar between DPSGD-based federated learning and traditional parameter server training when the overall privacy budget *epsilon* is 1.3 or 0.13. To show the effectiveness of DPSGD-based federated learning with PaddleFL, a simulated experiment is conducted on an open source dataset MNIST. From the figure given below, model evaluation results are similar between DPSGD-based federated learning and traditional parameter server training when the overall privacy budget *epsilon* is 1.3 or 0.13.
<img src="fl_dpsgd_benchmark.png" height=400 width=600 hspace='10'/> <br /> <img src="fl_dpsgd_benchmark.png" height=400 width=600 hspace='10'/> <br />
...@@ -103,6 +103,3 @@ wget https://paddle-zwh.bj.bcebos.com/gru4rec_paddlefl_benchmark/gru4rec_benchma ...@@ -103,6 +103,3 @@ wget https://paddle-zwh.bj.bcebos.com/gru4rec_paddlefl_benchmark/gru4rec_benchma
| 1/4 of the whole dataset | private training | - | 0.282 | | 1/4 of the whole dataset | private training | - | 0.282 |
<img src="fl_benchmark.png" height=300 width=500 hspace='10'/> <br /> <img src="fl_benchmark.png" height=300 width=500 hspace='10'/> <br />
...@@ -55,4 +55,3 @@ In PaddleFL, components for defining a federated learning task and training a fe ...@@ -55,4 +55,3 @@ In PaddleFL, components for defining a federated learning task and training a fe
- Federated Learning Systems deployment methods in Kubernetes. - Federated Learning Systems deployment methods in Kubernetes.
- Vertical Federated Learning Strategies and more horizontal federated learning strategies will be open sourced. - Vertical Federated Learning Strategies and more horizontal federated learning strategies will be open sourced.
...@@ -22,4 +22,3 @@ from .scheduler.agent_master import FLWorkerAgent ...@@ -22,4 +22,3 @@ from .scheduler.agent_master import FLWorkerAgent
from .scheduler.agent_master import FLScheduler from .scheduler.agent_master import FLScheduler
from .submitter.client_base import HPCClient from .submitter.client_base import HPCClient
from .submitter.client_base import CloudClient from .submitter.client_base import CloudClient
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
import os import os
import paddle.fluid as fluid import paddle.fluid as fluid
class FLJobBase(object): class FLJobBase(object):
""" """
FLJobBase is fl job base class, responsible for save and load FLJobBase is fl job base class, responsible for save and load
a federated learning job a federated learning job
""" """
def __init__(self): def __init__(self):
pass pass
...@@ -64,6 +66,7 @@ class FLJobBase(object): ...@@ -64,6 +66,7 @@ class FLJobBase(object):
return fluid.Program.parse_from_string(program_desc_str) return fluid.Program.parse_from_string(program_desc_str)
return None return None
class FLCompileTimeJob(FLJobBase): class FLCompileTimeJob(FLJobBase):
""" """
FLCompileTimeJob is a container for compile time job in federated learning. FLCompileTimeJob is a container for compile time job in federated learning.
...@@ -71,6 +74,7 @@ class FLCompileTimeJob(FLJobBase): ...@@ -71,6 +74,7 @@ class FLCompileTimeJob(FLJobBase):
are in FLCompileTimeJob. Also, server main programs and server startup programs are in FLCompileTimeJob. Also, server main programs and server startup programs
are in this class. FLCompileTimeJob has server endpoints for debugging as well are in this class. FLCompileTimeJob has server endpoints for debugging as well
""" """
def __init__(self): def __init__(self):
self._trainer_startup_programs = [] self._trainer_startup_programs = []
self._trainer_recv_programs = [] self._trainer_recv_programs = []
...@@ -101,18 +105,15 @@ class FLCompileTimeJob(FLJobBase): ...@@ -101,18 +105,15 @@ class FLCompileTimeJob(FLJobBase):
os.system("mkdir -p %s" % server_folder) os.system("mkdir -p %s" % server_folder)
server_startup = self._server_startup_programs[i] server_startup = self._server_startup_programs[i]
server_main = self._server_main_programs[i] server_main = self._server_main_programs[i]
self._save_program( self._save_program(server_startup,
server_startup,
"%s/server.startup.program" % server_folder) "%s/server.startup.program" % server_folder)
self._save_program( self._save_program(server_main,
server_main,
"%s/server.main.program" % server_folder) "%s/server.main.program" % server_folder)
self._save_readable_program(server_startup,
"%s/server.startup.program.txt" %
server_folder)
self._save_readable_program( self._save_readable_program(
server_startup, server_main, "%s/server.main.program.txt" % server_folder)
"%s/server.startup.program.txt" % server_folder)
self._save_readable_program(
server_main,
"%s/server.main.program.txt" % server_folder)
self._save_str_list(self._feed_names, self._save_str_list(self._feed_names,
"%s/feed_names" % server_folder) "%s/feed_names" % server_folder)
self._save_str_list(self._target_names, self._save_str_list(self._target_names,
...@@ -127,18 +128,15 @@ class FLCompileTimeJob(FLJobBase): ...@@ -127,18 +128,15 @@ class FLCompileTimeJob(FLJobBase):
os.system("mkdir -p %s" % trainer_folder) os.system("mkdir -p %s" % trainer_folder)
trainer_startup = self._trainer_startup_programs[i] trainer_startup = self._trainer_startup_programs[i]
trainer_main = self._trainer_main_programs[i] trainer_main = self._trainer_main_programs[i]
self._save_program( self._save_program(trainer_startup,
trainer_startup,
"%s/trainer.startup.program" % trainer_folder) "%s/trainer.startup.program" % trainer_folder)
self._save_program( self._save_program(trainer_main,
trainer_main,
"%s/trainer.main.program" % trainer_folder) "%s/trainer.main.program" % trainer_folder)
self._save_readable_program(trainer_startup,
"%s/trainer.startup.program.txt" %
trainer_folder)
self._save_readable_program( self._save_readable_program(
trainer_startup, trainer_main, "%s/trainer.main.program.txt" % trainer_folder)
"%s/trainer.startup.program.txt" % trainer_folder)
self._save_readable_program(
trainer_main,
"%s/trainer.main.program.txt" % trainer_folder)
self._save_str_list(self._feed_names, self._save_str_list(self._feed_names,
"%s/feed_names" % trainer_folder) "%s/feed_names" % trainer_folder)
self._save_str_list(self._target_names, self._save_str_list(self._target_names,
...@@ -152,18 +150,14 @@ class FLCompileTimeJob(FLJobBase): ...@@ -152,18 +150,14 @@ class FLCompileTimeJob(FLJobBase):
trainer_folder = "%s/trainer%d" % (folder, i) trainer_folder = "%s/trainer%d" % (folder, i)
trainer_send = self._trainer_send_programs[i] trainer_send = self._trainer_send_programs[i]
trainer_recv = self._trainer_recv_programs[i] trainer_recv = self._trainer_recv_programs[i]
self._save_program( self._save_program(trainer_send,
trainer_send,
"%s/trainer.send.program" % trainer_folder) "%s/trainer.send.program" % trainer_folder)
self._save_program( self._save_program(trainer_recv,
trainer_recv,
"%s/trainer.recv.program" % trainer_folder) "%s/trainer.recv.program" % trainer_folder)
self._save_readable_program( self._save_readable_program(
trainer_send, trainer_send, "%s/trainer.send.program.txt" % trainer_folder)
"%s/trainer.send.program.txt" % trainer_folder)
self._save_readable_program( self._save_readable_program(
trainer_recv, trainer_recv, "%s/trainer.recv.program.txt" % trainer_folder)
"%s/trainer.recv.program.txt" % trainer_folder)
class FLRunTimeJob(FLJobBase): class FLRunTimeJob(FLJobBase):
...@@ -172,6 +166,7 @@ class FLRunTimeJob(FLJobBase): ...@@ -172,6 +166,7 @@ class FLRunTimeJob(FLJobBase):
A trainer or a server can load FLRunTimeJob. Only necessary programs A trainer or a server can load FLRunTimeJob. Only necessary programs
can be loaded in FLRunTimeJob can be loaded in FLRunTimeJob
""" """
def __init__(self): def __init__(self):
self._trainer_startup_program = None self._trainer_startup_program = None
self._trainer_recv_program = None self._trainer_recv_program = None
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
from .fl_job import FLCompileTimeJob from .fl_job import FLCompileTimeJob
class JobGenerator(object): class JobGenerator(object):
""" """
A JobGenerator is responsible for generating distributed federated A JobGenerator is responsible for generating distributed federated
...@@ -21,6 +22,7 @@ class JobGenerator(object): ...@@ -21,6 +22,7 @@ class JobGenerator(object):
need to define a deep learning model together to do horizontal federated need to define a deep learning model together to do horizontal federated
learning. learning.
""" """
def __init__(self): def __init__(self):
# worker num for federated learning # worker num for federated learning
self._worker_num = 0 self._worker_num = 0
...@@ -32,7 +34,6 @@ class JobGenerator(object): ...@@ -32,7 +34,6 @@ class JobGenerator(object):
self._feed_names = [] self._feed_names = []
self._target_names = [] self._target_names = []
def set_optimizer(self, optimizer): def set_optimizer(self, optimizer):
""" """
Set optimizer of current job Set optimizer of current job
...@@ -56,8 +57,10 @@ class JobGenerator(object): ...@@ -56,8 +57,10 @@ class JobGenerator(object):
self._startup_prog = startup self._startup_prog = startup
def set_infer_feed_and_target_names(self, feed_names, target_names): def set_infer_feed_and_target_names(self, feed_names, target_names):
if not isinstance(feed_names, list) or not isinstance(target_names, list): if not isinstance(feed_names, list) or not isinstance(target_names,
raise ValueError("input should be list in set_infer_feed_and_target_names") list):
raise ValueError(
"input should be list in set_infer_feed_and_target_names")
''' '''
print(feed_names) print(feed_names)
print(target_names) print(target_names)
...@@ -76,7 +79,6 @@ class JobGenerator(object): ...@@ -76,7 +79,6 @@ class JobGenerator(object):
server_endpoints=[], server_endpoints=[],
worker_num=1, worker_num=1,
output=None): output=None):
""" """
Generate Federated Learning Job, based on user defined configs Generate Federated Learning Job, based on user defined configs
...@@ -130,24 +132,29 @@ class JobGenerator(object): ...@@ -130,24 +132,29 @@ class JobGenerator(object):
startup_program = self._startup_prog.clone() startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone() main_program = self._losses[0].block.program.clone()
fl_strategy._build_trainer_program_for_job( fl_strategy._build_trainer_program_for_job(
trainer_id, program=main_program, trainer_id,
ps_endpoints=server_endpoints, trainers=worker_num, program=main_program,
sync_mode=True, startup_program=startup_program, ps_endpoints=server_endpoints,
trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job) job=local_job)
startup_program = self._startup_prog.clone() startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone() main_program = self._losses[0].block.program.clone()
fl_strategy._build_server_programs_for_job( fl_strategy._build_server_programs_for_job(
program=main_program, ps_endpoints=server_endpoints, program=main_program,
trainers=worker_num, sync_mode=True, ps_endpoints=server_endpoints,
startup_program=startup_program, job=local_job) trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job)
local_job.set_feed_names(self._feed_names) local_job.set_feed_names(self._feed_names)
local_job.set_target_names(self._target_names) local_job.set_target_names(self._target_names)
local_job.set_strategy(fl_strategy) local_job.set_strategy(fl_strategy)
local_job.save(output) local_job.save(output)
def generate_fl_job_for_k8s(self, def generate_fl_job_for_k8s(self,
fl_strategy, fl_strategy,
server_pod_endpoints=[], server_pod_endpoints=[],
...@@ -168,17 +175,23 @@ class JobGenerator(object): ...@@ -168,17 +175,23 @@ class JobGenerator(object):
startup_program = self._startup_prog.clone() startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone() main_program = self._losses[0].block.program.clone()
fl_strategy._build_trainer_program_for_job( fl_strategy._build_trainer_program_for_job(
trainer_id, program=main_program, trainer_id,
ps_endpoints=server_service_endpoints, trainers=worker_num, program=main_program,
sync_mode=True, startup_program=startup_program, ps_endpoints=server_service_endpoints,
trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job) job=local_job)
startup_program = self._startup_prog.clone() startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone() main_program = self._losses[0].block.program.clone()
fl_strategy._build_server_programs_for_job( fl_strategy._build_server_programs_for_job(
program=main_program, ps_endpoints=server_pod_endpoints, program=main_program,
trainers=worker_num, sync_mode=True, ps_endpoints=server_pod_endpoints,
startup_program=startup_program, job=local_job) trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job)
local_job.set_feed_names(self._feed_names) local_job.set_feed_names(self._feed_names)
local_job.set_target_names(self._target_names) local_job.set_target_names(self._target_names)
......
...@@ -2,6 +2,7 @@ import zmq ...@@ -2,6 +2,7 @@ import zmq
import time import time
import random import random
def recv_and_parse_kv(socket): def recv_and_parse_kv(socket):
message = socket.recv() message = socket.recv()
group = message.decode().split("\t") group = message.decode().split("\t")
...@@ -10,9 +11,11 @@ def recv_and_parse_kv(socket): ...@@ -10,9 +11,11 @@ def recv_and_parse_kv(socket):
else: else:
return group[0], group[1] return group[0], group[1]
WORKER_EP = "WORKER_EP" WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP" SERVER_EP = "SERVER_EP"
class FLServerAgent(object): class FLServerAgent(object):
def __init__(self, scheduler_ep, current_ep): def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep self.scheduler_ep = scheduler_ep
...@@ -29,6 +32,7 @@ class FLServerAgent(object): ...@@ -29,6 +32,7 @@ class FLServerAgent(object):
if group[0] == 'INIT': if group[0] == 'INIT':
break break
class FLWorkerAgent(object): class FLWorkerAgent(object):
def __init__(self, scheduler_ep, current_ep): def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep self.scheduler_ep = scheduler_ep
...@@ -64,7 +68,6 @@ class FLWorkerAgent(object): ...@@ -64,7 +68,6 @@ class FLWorkerAgent(object):
return False return False
class FLScheduler(object): class FLScheduler(object):
def __init__(self, worker_num, server_num, port=9091, socket=None): def __init__(self, worker_num, server_num, port=9091, socket=None):
self.context = zmq.Context() self.context = zmq.Context()
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle_fl.core.scheduler.agent_master import FLServerAgent from paddle_fl.core.scheduler.agent_master import FLServerAgent
class FLServer(object):
class FLServer(object):
def __init__(self): def __init__(self):
self._startup_program = None self._startup_program = None
self._main_program = None self._main_program = None
......
...@@ -48,8 +48,8 @@ def wait_server_ready(endpoints): ...@@ -48,8 +48,8 @@ def wait_server_ready(endpoints):
not_ready_endpoints.append(ep) not_ready_endpoints.append(ep)
if not all_ok: if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n") sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) + sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints)
"\n") + "\n")
sys.stderr.flush() sys.stderr.flush()
time.sleep(3) time.sleep(3)
else: else:
......
...@@ -163,7 +163,8 @@ def block_to_code(block, block_idx, fout=None, skip_op_callstack=False): ...@@ -163,7 +163,8 @@ def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
indent = 0 indent = 0
print( print(
"{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx), "{0}{1} // block {2}".format(
get_indent_space(indent), '{', block_idx),
file=fout) file=fout)
indent += 1 indent += 1
......
...@@ -50,6 +50,7 @@ def log(*args): ...@@ -50,6 +50,7 @@ def log(*args):
if PRINT_LOG: if PRINT_LOG:
print(args) print(args)
def same_or_split_var(p_name, var_name): def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block") return p_name == var_name or p_name.startswith(var_name + ".block")
...@@ -113,7 +114,9 @@ class FLDistributeTranspiler(object): ...@@ -113,7 +114,9 @@ class FLDistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program): def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = [] sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"] sparse_update_op_types = [
"lookup_table", "nce", "hierarchical_sigmoid"
]
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr( if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True: 'remote_prefetch') is True:
...@@ -411,7 +414,8 @@ class FLDistributeTranspiler(object): ...@@ -411,7 +414,8 @@ class FLDistributeTranspiler(object):
if self.sync_mode and self.trainer_num > 1: if self.sync_mode and self.trainer_num > 1:
for trainer_id in range(self.trainer_num): for trainer_id in range(self.trainer_num):
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.opti.trainer_%d" % (orig_var_name, trainer_id), name="%s.opti.trainer_%d" %
(orig_var_name, trainer_id),
persistable=False, persistable=False,
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
...@@ -816,7 +820,6 @@ class FLDistributeTranspiler(object): ...@@ -816,7 +820,6 @@ class FLDistributeTranspiler(object):
iomap = collections.OrderedDict() iomap = collections.OrderedDict()
return iomap return iomap
def _get_lr_ops(self): def _get_lr_ops(self):
lr_ops = [] lr_ops = []
block = self.origin_program.global_block() block = self.origin_program.global_block()
......
...@@ -16,11 +16,13 @@ from .fl_distribute_transpiler import FLDistributeTranspiler ...@@ -16,11 +16,13 @@ from .fl_distribute_transpiler import FLDistributeTranspiler
from paddle.fluid.optimizer import SGD from paddle.fluid.optimizer import SGD
import paddle.fluid as fluid import paddle.fluid as fluid
class FLStrategyFactory(object): class FLStrategyFactory(object):
""" """
FLStrategyFactory is a FLStrategy builder FLStrategyFactory is a FLStrategy builder
Users can define strategy config to create different FLStrategy Users can define strategy config to create different FLStrategy
""" """
def __init__(self): def __init__(self):
self._fed_avg = False self._fed_avg = False
self._dpsgd = False self._dpsgd = False
...@@ -86,6 +88,7 @@ class FLStrategyBase(object): ...@@ -86,6 +88,7 @@ class FLStrategyBase(object):
""" """
FLStrategyBase is federated learning algorithm container FLStrategyBase is federated learning algorithm container
""" """
def __init__(self): def __init__(self):
self._fed_avg = False self._fed_avg = False
self._dpsgd = False self._dpsgd = False
...@@ -105,17 +108,23 @@ class FLStrategyBase(object): ...@@ -105,17 +108,23 @@ class FLStrategyBase(object):
for loss in losses: for loss in losses:
optimizer.minimize(loss) optimizer.minimize(loss)
def _build_trainer_program_for_job( def _build_trainer_program_for_job(self,
self, trainer_id=0, program=None, trainer_id=0,
ps_endpoints=[], trainers=0, program=None,
sync_mode=True, startup_program=None, ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None): job=None):
pass pass
def _build_server_programs_for_job( def _build_server_programs_for_job(self,
self, program=None, ps_endpoints=[], program=None,
trainers=0, sync_mode=True, ps_endpoints=[],
startup_program=None, job=None): trainers=0,
sync_mode=True,
startup_program=None,
job=None):
pass pass
...@@ -123,6 +132,7 @@ class DPSGDStrategy(FLStrategyBase): ...@@ -123,6 +132,7 @@ class DPSGDStrategy(FLStrategyBase):
""" """
DPSGDStrategy: Deep Learning with Differential Privacy. 2016 DPSGDStrategy: Deep Learning with Differential Privacy. 2016
""" """
def __init__(self): def __init__(self):
super(DPSGDStrategy, self).__init__() super(DPSGDStrategy, self).__init__()
...@@ -162,16 +172,24 @@ class DPSGDStrategy(FLStrategyBase): ...@@ -162,16 +172,24 @@ class DPSGDStrategy(FLStrategyBase):
""" """
Define Dpsgd optimizer Define Dpsgd optimizer
""" """
optimizer = fluid.optimizer.Dpsgd(self._learning_rate, clip=self._clip, batch_size=self._batch_size, sigma=self._sigma) optimizer = fluid.optimizer.Dpsgd(
self._learning_rate,
clip=self._clip,
batch_size=self._batch_size,
sigma=self._sigma)
optimizer.minimize(losses[0]) optimizer.minimize(losses[0])
def _build_trainer_program_for_job( def _build_trainer_program_for_job(self,
self, trainer_id=0, program=None, trainer_id=0,
ps_endpoints=[], trainers=0, program=None,
sync_mode=True, startup_program=None, ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None): job=None):
transpiler = fluid.DistributeTranspiler() transpiler = fluid.DistributeTranspiler()
transpiler.transpile(trainer_id, transpiler.transpile(
trainer_id,
program=program, program=program,
pservers=",".join(ps_endpoints), pservers=",".join(ps_endpoints),
trainers=trainers, trainers=trainers,
...@@ -181,10 +199,13 @@ class DPSGDStrategy(FLStrategyBase): ...@@ -181,10 +199,13 @@ class DPSGDStrategy(FLStrategyBase):
job._trainer_startup_programs.append(startup_program) job._trainer_startup_programs.append(startup_program)
job._trainer_main_programs.append(main) job._trainer_main_programs.append(main)
def _build_server_programs_for_job( def _build_server_programs_for_job(self,
self, program=None, ps_endpoints=[], program=None,
trainers=0, sync_mode=True, ps_endpoints=[],
startup_program=None, job=None): trainers=0,
sync_mode=True,
startup_program=None,
job=None):
transpiler = fluid.DistributeTranspiler() transpiler = fluid.DistributeTranspiler()
trainer_id = 0 trainer_id = 0
transpiler.transpile( transpiler.transpile(
...@@ -207,6 +228,7 @@ class FedAvgStrategy(FLStrategyBase): ...@@ -207,6 +228,7 @@ class FedAvgStrategy(FLStrategyBase):
FedAvgStrategy: this is model averaging optimization proposed in FedAvgStrategy: this is model averaging optimization proposed in
H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 2017 H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 2017
""" """
def __init__(self): def __init__(self):
super(FedAvgStrategy, self).__init__() super(FedAvgStrategy, self).__init__()
...@@ -216,13 +238,17 @@ class FedAvgStrategy(FLStrategyBase): ...@@ -216,13 +238,17 @@ class FedAvgStrategy(FLStrategyBase):
""" """
optimizer.minimize(losses[0]) optimizer.minimize(losses[0])
def _build_trainer_program_for_job( def _build_trainer_program_for_job(self,
self, trainer_id=0, program=None, trainer_id=0,
ps_endpoints=[], trainers=0, program=None,
sync_mode=True, startup_program=None, ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None): job=None):
transpiler = FLDistributeTranspiler() transpiler = FLDistributeTranspiler()
transpiler.transpile(trainer_id, transpiler.transpile(
trainer_id,
program=program, program=program,
pservers=",".join(ps_endpoints), pservers=",".join(ps_endpoints),
trainers=trainers, trainers=trainers,
...@@ -234,10 +260,13 @@ class FedAvgStrategy(FLStrategyBase): ...@@ -234,10 +260,13 @@ class FedAvgStrategy(FLStrategyBase):
job._trainer_send_programs.append(send) job._trainer_send_programs.append(send)
job._trainer_recv_programs.append(recv) job._trainer_recv_programs.append(recv)
def _build_server_programs_for_job( def _build_server_programs_for_job(self,
self, program=None, ps_endpoints=[], program=None,
trainers=0, sync_mode=True, ps_endpoints=[],
startup_program=None, job=None): trainers=0,
sync_mode=True,
startup_program=None,
job=None):
transpiler = FLDistributeTranspiler() transpiler = FLDistributeTranspiler()
trainer_id = 0 trainer_id = 0
transpiler.transpile( transpiler.transpile(
...@@ -262,6 +291,7 @@ class SecAggStrategy(FedAvgStrategy): ...@@ -262,6 +291,7 @@ class SecAggStrategy(FedAvgStrategy):
Practical Secure Aggregation for Privacy-Preserving Machine Learning, Practical Secure Aggregation for Privacy-Preserving Machine Learning,
The 24th ACM Conference on Computer and Communications Security ( CCS2017 ). The 24th ACM Conference on Computer and Communications Security ( CCS2017 ).
""" """
def __init__(self): def __init__(self):
super(SecAggStrategy, self).__init__() super(SecAggStrategy, self).__init__()
self._param_name_list = [] self._param_name_list = []
......
import sys import sys
import os import os
class CloudClient(object): class CloudClient(object):
def __init__(self): def __init__(self):
pass pass
...@@ -16,6 +17,7 @@ class CloudClient(object): ...@@ -16,6 +17,7 @@ class CloudClient(object):
def submit(self, **kwargs): def submit(self, **kwargs):
pass pass
class HPCClient(object): class HPCClient(object):
def __init__(self): def __init__(self):
self.conf_dict = {} self.conf_dict = {}
...@@ -70,27 +72,20 @@ class HPCClient(object): ...@@ -70,27 +72,20 @@ class HPCClient(object):
fout.write("#!/bin/bash\n") fout.write("#!/bin/bash\n")
fout.write("unset http_proxy\n") fout.write("unset http_proxy\n")
fout.write("unset https_proxy\n") fout.write("unset https_proxy\n")
fout.write("export HADOOP_HOME={}\n".format( fout.write("export HADOOP_HOME={}\n".format(self.hadoop_home))
self.hadoop_home))
fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}" fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}"
" -Dfs.default.name={} -rmr {}\n".format( " -Dfs.default.name={} -rmr {}\n".format(
self.ugi, self.ugi, self.hdfs_path, self.hdfs_output))
self.hdfs_path,
self.hdfs_output))
fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem)) fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem))
fout.write("{}/bin/qsub_f -N {} --conf qsub.conf " fout.write("{}/bin/qsub_f -N {} --conf qsub.conf "
"--hdfs {} --ugi {} --hout {} --files ./package " "--hdfs {} --ugi {} --hout {} --files ./package "
"-l nodes={},walltime=1000:00:00,pmem-hard={}," "-l nodes={},walltime=1000:00:00,pmem-hard={},"
"pcpu-soft={},pnetin-soft=1000," "pcpu-soft={},pnetin-soft=1000,"
"pnetout-soft=1000 job.sh\n".format( "pnetout-soft=1000 job.sh\n".format(
self.hpc_home, self.hpc_home, self.task_name, self.hdfs_path,
self.task_name, self.ugi, self.hdfs_output,
self.hdfs_path,
self.ugi,
self.hdfs_output,
int(self.worker_nodes) + int(self.server_nodes), int(self.worker_nodes) + int(self.server_nodes),
self.mpi_node_mem, self.mpi_node_mem, self.pcpu))
self.pcpu))
def generate_job_sh(self, job_dir): def generate_job_sh(self, job_dir):
with open("{}/job.sh".format(job_dir), "w") as fout: with open("{}/job.sh".format(job_dir), "w") as fout:
...@@ -98,17 +93,23 @@ class HPCClient(object): ...@@ -98,17 +93,23 @@ class HPCClient(object):
fout.write("WORKDIR=`pwd`\n") fout.write("WORKDIR=`pwd`\n")
fout.write("mpirun -npernode 1 mv package/* ./\n") fout.write("mpirun -npernode 1 mv package/* ./\n")
fout.write("echo 'current dir: '$WORKDIR\n") fout.write("echo 'current dir: '$WORKDIR\n")
fout.write("mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n") fout.write(
fout.write("export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n") "mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
fout.write(
"export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
fout.write("mpirun -npernode 1 python/bin/python -m pip install " fout.write("mpirun -npernode 1 python/bin/python -m pip install "
"{} --index-url=http://pip.baidu.com/pypi/simple " "{} --index-url=http://pip.baidu.com/pypi/simple "
"--trusted-host pip.baidu.com > /dev/null\n".format( "--trusted-host pip.baidu.com > /dev/null\n".format(
self.wheel)) self.wheel))
fout.write("export PATH=python/bin:$PATH\n") fout.write("export PATH=python/bin:$PATH\n")
if self.monitor_cmd != "": if self.monitor_cmd != "":
fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile " fout.write(
"${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".format(self.monitor_cmd)) "mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n") "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".
format(self.monitor_cmd))
fout.write(
"mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n"
)
fout.write("if [[ $? -ne 0 ]]; then\n") fout.write("if [[ $? -ne 0 ]]; then\n")
fout.write(" echo 'Failed to run mpi!' 1>&2\n") fout.write(" echo 'Failed to run mpi!' 1>&2\n")
fout.write(" exit 1\n") fout.write(" exit 1\n")
...@@ -150,4 +151,5 @@ class HPCClient(object): ...@@ -150,4 +151,5 @@ class HPCClient(object):
# generate job.sh # generate job.sh
self.generate_qsub_conf(jobdir) self.generate_qsub_conf(jobdir)
# run submit # run submit
os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(jobdir)) os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(
jobdir))
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# #
# (c) Chris von Csefalvay, 2015. # (c) Chris von Csefalvay, 2015.
""" """
__init__.py is responsible for [brief description here]. __init__.py is responsible for [brief description here].
""" """
# coding=utf-8 # coding=utf-8
# #
# The MIT License (MIT) # The MIT License (MIT)
# #
...@@ -21,8 +20,6 @@ ...@@ -21,8 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# #
""" """
decorators declares some decorators that ensure the object has the decorators declares some decorators that ensure the object has the
correct keys declared when need be. correct keys declared when need be.
......
...@@ -20,10 +20,6 @@ ...@@ -20,10 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# #
""" """
diffiehellmann declares the main key exchange class. diffiehellmann declares the main key exchange class.
""" """
...@@ -41,18 +37,17 @@ import os ...@@ -41,18 +37,17 @@ import os
try: try:
from ssl import RAND_bytes from ssl import RAND_bytes
rng = RAND_bytes rng = RAND_bytes
except(AttributeError, ImportError): except (AttributeError, ImportError):
rng = os.urandom rng = os.urandom
class DiffieHellman: class DiffieHellman:
""" """
Implements the Diffie-Hellman key exchange protocol. Implements the Diffie-Hellman key exchange protocol.
""" """
def __init__(self, def __init__(self, group=18, key_length=640):
group=18,
key_length=640):
self.key_length = max(200, key_length) self.key_length = max(200, key_length)
self.generator = PRIMES[group]["generator"] self.generator = PRIMES[group]["generator"]
...@@ -81,7 +76,8 @@ class DiffieHellman: ...@@ -81,7 +76,8 @@ class DiffieHellman:
self.private_key = key self.private_key = key
def verify_public_key(self, other_public_key): def verify_public_key(self, other_public_key):
return self.prime - 1 > other_public_key > 2 and pow(other_public_key, (self.prime - 1) // 2, self.prime) == 1 return self.prime - 1 > other_public_key > 2 and pow(
other_public_key, (self.prime - 1) // 2, self.prime) == 1
@requires_private_key @requires_private_key
def generate_public_key(self): def generate_public_key(self):
...@@ -91,9 +87,7 @@ class DiffieHellman: ...@@ -91,9 +87,7 @@ class DiffieHellman:
:return: void :return: void
:rtype: void :rtype: void
""" """
self.public_key = pow(self.generator, self.public_key = pow(self.generator, self.private_key, self.prime)
self.private_key,
self.prime)
@requires_private_key @requires_private_key
def generate_shared_secret(self, other_public_key, echo_return_key=False): def generate_shared_secret(self, other_public_key, echo_return_key=False):
...@@ -110,16 +104,17 @@ class DiffieHellman: ...@@ -110,16 +104,17 @@ class DiffieHellman:
if self.verify_public_key(other_public_key) is False: if self.verify_public_key(other_public_key) is False:
raise MalformedPublicKey raise MalformedPublicKey
self.shared_secret = pow(other_public_key, self.shared_secret = pow(other_public_key, self.private_key,
self.private_key,
self.prime) self.prime)
try: try:
#python3 #python3
shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big') shared_secret_as_bytes = self.shared_secret.to_bytes(
self.shared_secret.bit_length() // 8 + 1, byteorder='big')
except: except:
#python2 #python2
length = self.shared_secret.bit_length() // 8 + 1 length = self.shared_secret.bit_length() // 8 + 1
shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:] shared_secret_as_bytes = ('%%0%dx' % (
length << 1) % self.shared_secret).decode('hex')[-length:]
_h = sha256() _h = sha256()
_h.update(bytes(shared_secret_as_bytes)) _h.update(bytes(shared_secret_as_bytes))
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# #
# (c) Chris von Csefalvay, 2015. # (c) Chris von Csefalvay, 2015.
""" """
exceptions is responsible for exception handling etc. exceptions is responsible for exception handling etc.
""" """
......
# coding=utf-8 # coding=utf-8
# #
# The MIT License (MIT) # The MIT License (MIT)
# #
...@@ -25,34 +24,39 @@ ...@@ -25,34 +24,39 @@
# Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman # Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman
# groups for Internet Key Exchange (IKE)_. # groups for Internet Key Exchange (IKE)_.
# #
""" """
primes holds the RFC 3526 MODP primes and their generators. primes holds the RFC 3526 MODP primes and their generators.
""" """
PRIMES = { PRIMES = {
5: { 5: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
14: { 14: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
15: { 15: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
16: { 16: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
17: { 17: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
18: { 18: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
} }
...@@ -19,6 +19,7 @@ import hmac ...@@ -19,6 +19,7 @@ import hmac
import hashlib import hashlib
from .diffiehellman.diffiehellman import DiffieHellman from .diffiehellman.diffiehellman import DiffieHellman
class FLTrainerFactory(object): class FLTrainerFactory(object):
def __init__(self): def __init__(self):
pass pass
...@@ -65,9 +66,7 @@ class FLTrainer(object): ...@@ -65,9 +66,7 @@ class FLTrainer(object):
def run(self, feed, fetch): def run(self, feed, fetch):
self._logger.debug("begin to run") self._logger.debug("begin to run")
self.exe.run(self._main_program, self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
feed=feed,
fetch_list=fetch)
self._logger.debug("end to run current batch") self._logger.debug("end to run current batch")
self.cur_step += 1 self.cur_step += 1
...@@ -119,7 +118,7 @@ class FedAvgTrainer(FLTrainer): ...@@ -119,7 +118,7 @@ class FedAvgTrainer(FLTrainer):
def reset(self): def reset(self):
self.cur_step = 0 self.cur_step = 0
def run_with_epoch(self,reader,feeder,fetch,num_epoch): def run_with_epoch(self, reader, feeder, fetch, num_epoch):
self._logger.debug("begin to run recv program") self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program) self.exe.run(self._recv_program)
epoch = 0 epoch = 0
...@@ -132,16 +131,16 @@ class FedAvgTrainer(FLTrainer): ...@@ -132,16 +131,16 @@ class FedAvgTrainer(FLTrainer):
epoch += 1 epoch += 1
self._logger.debug("begin to run send program") self._logger.debug("begin to run send program")
self.exe.run(self._send_program) self.exe.run(self._send_program)
def run(self, feed, fetch): def run(self, feed, fetch):
self._logger.debug("begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" % self._logger.debug(
"begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step)) (self.cur_step, self._step))
if self.cur_step % self._step == 0: if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program") self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program) self.exe.run(self._recv_program)
self._logger.debug("begin to run current step") self._logger.debug("begin to run current step")
loss = self.exe.run(self._main_program, loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
feed=feed,
fetch_list=fetch)
if self.cur_step % self._step == 0: if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program") self._logger.debug("begin to run send program")
self.exe.run(self._send_program) self.exe.run(self._send_program)
...@@ -149,9 +148,6 @@ class FedAvgTrainer(FLTrainer): ...@@ -149,9 +148,6 @@ class FedAvgTrainer(FLTrainer):
return loss return loss
class SecAggTrainer(FLTrainer): class SecAggTrainer(FLTrainer):
def __init__(self): def __init__(self):
super(SecAggTrainer, self).__init__() super(SecAggTrainer, self).__init__()
...@@ -207,24 +203,24 @@ class SecAggTrainer(FLTrainer): ...@@ -207,24 +203,24 @@ class SecAggTrainer(FLTrainer):
self.cur_step = 0 self.cur_step = 0
def run(self, feed, fetch): def run(self, feed, fetch):
self._logger.debug("begin to run SecAggTrainer, cur_step=%d, inner_step=%d" % self._logger.debug(
"begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step)) (self.cur_step, self._step))
if self.cur_step % self._step == 0: if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program") self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program) self.exe.run(self._recv_program)
scope = fluid.global_scope() scope = fluid.global_scope()
self._logger.debug("begin to run current step") self._logger.debug("begin to run current step")
loss = self.exe.run(self._main_program, loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
feed=feed,
fetch_list=fetch)
if self.cur_step % self._step == 0: if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program") self._logger.debug("begin to run send program")
noise = 0.0 noise = 0.0
scale = pow(10.0, 5) scale = pow(10.0, 5)
digestmod=hashlib.sha256 digestmod = hashlib.sha256
# 1. load priv key and other's pub key # 1. load priv key and other's pub key
dh = DiffieHellman(group=15, key_length=256) dh = DiffieHellman(group=15, key_length=256)
dh.load_private_key(self._key_dir + str(self._trainer_id) + "_priv_key.txt") dh.load_private_key(self._key_dir + str(self._trainer_id) +
"_priv_key.txt")
key = str(self._step_id).encode("utf-8") key = str(self._step_id).encode("utf-8")
for i in range(self._trainer_num): for i in range(self._trainer_num):
if i != self._trainer_id: if i != self._trainer_id:
...@@ -232,7 +228,8 @@ class SecAggTrainer(FLTrainer): ...@@ -232,7 +228,8 @@ class SecAggTrainer(FLTrainer):
public_key = int(f.read()) public_key = int(f.read())
dh.generate_shared_secret(public_key, echo_return_key=True) dh.generate_shared_secret(public_key, echo_return_key=True)
msg = dh.shared_key.encode("utf-8") msg = dh.shared_key.encode("utf-8")
hex_res1 = hmac.new(key=key, msg=msg, digestmod=digestmod).hexdigest() hex_res1 = hmac.new(key=key, msg=msg,
digestmod=digestmod).hexdigest()
current_noise = int(hex_res1[0:8], 16) / scale current_noise = int(hex_res1[0:8], 16) / scale
if i > self._trainer_id: if i > self._trainer_id:
noise = noise + current_noise noise = noise + current_noise
...@@ -241,9 +238,11 @@ class SecAggTrainer(FLTrainer): ...@@ -241,9 +238,11 @@ class SecAggTrainer(FLTrainer):
scope = fluid.global_scope() scope = fluid.global_scope()
for param_name in self._param_name_list: for param_name in self._param_name_list:
fluid.global_scope().var(param_name + str(self._trainer_id)).get_tensor().set( fluid.global_scope().var(param_name + str(
numpy.array(scope.find_var(param_name + str(self._trainer_id)).get_tensor()) + noise, fluid.CPUPlace()) self._trainer_id)).get_tensor().set(
numpy.array(
scope.find_var(param_name + str(self._trainer_id))
.get_tensor()) + noise, fluid.CPUPlace())
self.exe.run(self._send_program) self.exe.run(self._send_program)
self.cur_step += 1 self.cur_step += 1
return loss return loss
...@@ -5,20 +5,22 @@ import tarfile ...@@ -5,20 +5,22 @@ import tarfile
import random import random
def download(url,tar_path): def download(url, tar_path):
r = requests.get(url) r = requests.get(url)
with open(tar_path,'wb') as f: with open(tar_path, 'wb') as f:
f.write(r.content) f.write(r.content)
def extract(tar_path,target_path):
def extract(tar_path, target_path):
tar = tarfile.open(tar_path, "r:gz") tar = tarfile.open(tar_path, "r:gz")
file_names = tar.getnames() file_names = tar.getnames()
for file_name in file_names: for file_name in file_names:
tar.extract(file_name,target_path) tar.extract(file_name, target_path)
tar.close() tar.close()
def train(trainer_id,inner_step,batch_size,count_by_step):
def train(trainer_id, inner_step, batch_size, count_by_step):
target_path = "trainer%d_data" % trainer_id target_path = "trainer%d_data" % trainer_id
data_path = target_path + "/femnist_data" data_path = target_path + "/femnist_data"
tar_path = data_path + ".tar.gz" tar_path = data_path + ".tar.gz"
...@@ -27,20 +29,26 @@ def train(trainer_id,inner_step,batch_size,count_by_step): ...@@ -27,20 +29,26 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(data_path): if not os.path.exists(data_path):
print("Preparing data...") print("Preparing data...")
if not os.path.exists(tar_path): if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
extract(tar_path,target_path) tar_path)
extract(tar_path, target_path)
def train_data(): def train_data():
train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r') train_file = open(
"./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json"
% (trainer_id, trainer_id), 'r')
json_train = json.load(train_file) json_train = json.load(train_file)
users = json_train["users"] users = json_train["users"]
rand = random.randrange(0,len(users)) # random choose a user from each trainer rand = random.randrange(
0, len(users)) # random choose a user from each trainer
cur_user = users[rand] cur_user = users[rand]
print('training using '+cur_user) print('training using ' + cur_user)
train_images = json_train["user_data"][cur_user]['x'] train_images = json_train["user_data"][cur_user]['x']
train_labels = json_train["user_data"][cur_user]['y'] train_labels = json_train["user_data"][cur_user]['y']
if count_by_step: if count_by_step:
for i in range(inner_step*batch_size): for i in range(inner_step * batch_size):
yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))] yield train_images[i % (len(train_images))], train_labels[i % (
len(train_images))]
else: else:
for i in range(len(train_images)): for i in range(len(train_images)):
yield train_images[i], train_labels[i] yield train_images[i], train_labels[i]
...@@ -49,7 +57,8 @@ def train(trainer_id,inner_step,batch_size,count_by_step): ...@@ -49,7 +57,8 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
return train_data return train_data
def test(trainer_id,inner_step,batch_size,count_by_step):
def test(trainer_id, inner_step, batch_size, count_by_step):
target_path = "trainer%d_data" % trainer_id target_path = "trainer%d_data" % trainer_id
data_path = target_path + "/femnist_data" data_path = target_path + "/femnist_data"
tar_path = data_path + ".tar.gz" tar_path = data_path + ".tar.gz"
...@@ -58,10 +67,14 @@ def test(trainer_id,inner_step,batch_size,count_by_step): ...@@ -58,10 +67,14 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(data_path): if not os.path.exists(data_path):
print("Preparing data...") print("Preparing data...")
if not os.path.exists(tar_path): if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
extract(tar_path,target_path) tar_path)
extract(tar_path, target_path)
def test_data(): def test_data():
test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r') test_file = open(
"./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json"
% (trainer_id, trainer_id), 'r')
json_test = json.load(test_file) json_test = json.load(test_file)
users = json_test["users"] users = json_test["users"]
for user in users: for user in users:
...@@ -73,5 +86,3 @@ def test(trainer_id,inner_step,batch_size,count_by_step): ...@@ -73,5 +86,3 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
test_file.close() test_file.close()
return test_data return test_data
...@@ -3,6 +3,7 @@ import paddle_fl as fl ...@@ -3,6 +3,7 @@ import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
class Model(object): class Model(object):
def __init__(self): def __init__(self):
pass pass
...@@ -12,7 +13,8 @@ class Model(object): ...@@ -12,7 +13,8 @@ class Model(object):
self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu') self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu') self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax') self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label) self.sum_cost = fluid.layers.cross_entropy(
input=self.predict, label=label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label) self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.loss = fluid.layers.reduce_mean(self.sum_cost) self.loss = fluid.layers.reduce_mean(self.sum_cost)
self.startup_program = fluid.default_startup_program() self.startup_program = fluid.default_startup_program()
...@@ -34,8 +36,8 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.1) ...@@ -34,8 +36,8 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer) job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss]) job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program) job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names( job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
[x.name for x in inputs], [model.predict.name]) [model.predict.name])
build_strategy = FLStrategyFactory() build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True build_strategy.fed_avg = True
......
...@@ -3,7 +3,7 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler ...@@ -3,7 +3,7 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2 worker_num = 2
server_num = 1 server_num = 1
# Define the number of worker/server and the port for scheduler # Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091) scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num) scheduler.set_sample_worker_num(worker_num)
scheduler.init_env() scheduler.init_env()
print("init env done.") print("init env done.")
......
...@@ -4,7 +4,12 @@ import numpy as np ...@@ -4,7 +4,12 @@ import numpy as np
import sys import sys
import logging import logging
import time import time
logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) logging.basicConfig(
filename="test.log",
filemode="w",
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%M-%Y %H:%M:%S",
level=logging.DEBUG)
def reader(): def reader():
...@@ -15,13 +20,14 @@ def reader(): ...@@ -15,13 +20,14 @@ def reader():
data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64') data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
yield data_dict yield data_dict
trainer_id = int(sys.argv[1]) # trainer id for each guest trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id) job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job) trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start() trainer.start()
print(trainer._scheduler_ep, trainer._current_ep) print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model" output_folder = "fl_model"
...@@ -37,4 +43,3 @@ while not trainer.stop(): ...@@ -37,4 +43,3 @@ while not trainer.stop():
epoch_id += 1 epoch_id += 1
if epoch_id % 5 == 0: if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder) trainer.save_inference_program(output_folder)
...@@ -14,4 +14,3 @@ ...@@ -14,4 +14,3 @@
""" PaddleFL version string """ """ PaddleFL version string """
fl_version = "0.1.11" fl_version = "0.1.11"
module_proto_version = "0.1.11" module_proto_version = "0.1.11"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册