Commit 6b1fb7cc authored by qjing666

fix code style

Parent 59998725
| Github account | name |
|---|---|
| guru4elephant | Daxiang Dong |
| frankwhzhang | Wenhui Zhang |
\ No newline at end of file
| frankwhzhang | Wenhui Zhang |
......@@ -3,4 +3,3 @@ mistune
sphinx_rtd_theme
paddlepaddle>=1.6
zmq
......@@ -181,4 +181,3 @@ while not trainer.stop():
To show the effectiveness of DPSGD-based federated learning with PaddleFL, a simulated experiment is conducted on the open source MNIST dataset. As the figure below shows, the model evaluation results of DPSGD-based federated learning are similar to those of traditional parameter server training when the overall privacy budget *epsilon* is 1.3 or 0.13.
<img src="fl_dpsgd_benchmark.png" height=400 width=600 hspace='10'/> <br />
......@@ -103,6 +103,3 @@ wget https://paddle-zwh.bj.bcebos.com/gru4rec_paddlefl_benchmark/gru4rec_benchma
| 1/4 of the whole dataset | private training | - | 0.282 |
<img src="fl_benchmark.png" height=300 width=500 hspace='10'/> <br />
......@@ -55,4 +55,3 @@ In PaddleFL, components for defining a federated learning task and training a fe
- Federated Learning Systems deployment methods in Kubernetes.
- Vertical Federated Learning Strategies and more horizontal federated learning strategies will be open sourced.
......@@ -14,4 +14,4 @@
[7]. Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, Ameet Talwalkar. **Federated Multi-Task Learning** 2016
[8]. Yang Liu, Tianjian Chen, Qiang Yang. **Secure Federated Transfer Learning.** 2018
\ No newline at end of file
[8]. Yang Liu, Tianjian Chen, Qiang Yang. **Secure Federated Transfer Learning.** 2018
......@@ -22,4 +22,3 @@ from .scheduler.agent_master import FLWorkerAgent
from .scheduler.agent_master import FLScheduler
from .submitter.client_base import HPCClient
from .submitter.client_base import CloudClient
......@@ -14,11 +14,13 @@
import os
import paddle.fluid as fluid
class FLJobBase(object):
"""
FLJobBase is the base class of federated learning jobs, responsible
for saving and loading a federated learning job
"""
def __init__(self):
pass
......@@ -64,6 +66,7 @@ class FLJobBase(object):
return fluid.Program.parse_from_string(program_desc_str)
return None
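The fragment above is the load half of FLJobBase's persistence. A self-contained sketch of the save/load pair, assuming the standard Fluid `serialize_to_string()` API for the save side (the `_save_program` body is not shown in this hunk):

```python
import paddle.fluid as fluid

def save_program(program, file_path):
    # a Fluid Program is persisted as its serialized protobuf description
    with open(file_path, "wb") as fout:
        fout.write(program.desc.serialize_to_string())

def load_program(file_path):
    with open(file_path, "rb") as fin:
        program_desc_str = fin.read()
    return fluid.Program.parse_from_string(program_desc_str)
```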
class FLCompileTimeJob(FLJobBase):
"""
FLCompileTimeJob is a container for a compile-time job in federated learning.
......@@ -71,6 +74,7 @@ class FLCompileTimeJob(FLJobBase):
are in FLCompileTimeJob. Server main programs and server startup programs
are also stored in this class. FLCompileTimeJob keeps the server endpoints
for debugging as well
def __init__(self):
self._trainer_startup_programs = []
self._trainer_recv_programs = []
......@@ -101,69 +105,59 @@ class FLCompileTimeJob(FLJobBase):
os.system("mkdir -p %s" % server_folder)
server_startup = self._server_startup_programs[i]
server_main = self._server_main_programs[i]
self._save_program(
server_startup,
"%s/server.startup.program" % server_folder)
self._save_program(
server_main,
"%s/server.main.program" % server_folder)
self._save_program(server_startup,
"%s/server.startup.program" % server_folder)
self._save_program(server_main,
"%s/server.main.program" % server_folder)
self._save_readable_program(server_startup,
"%s/server.startup.program.txt" %
server_folder)
self._save_readable_program(
server_startup,
"%s/server.startup.program.txt" % server_folder)
self._save_readable_program(
server_main,
"%s/server.main.program.txt" % server_folder)
server_main, "%s/server.main.program.txt" % server_folder)
self._save_str_list(self._feed_names,
"%s/feed_names" % server_folder)
self._save_str_list(self._target_names,
"%s/target_names" % server_folder)
self._save_endpoints(self._server_endpoints,
"%s/endpoints" % server_folder)
self._save_strategy(self._strategy,
"%s/strategy.pkl" % server_folder)
for i in range(trainer_num):
trainer_folder = "%s/trainer%d" % (folder, i)
os.system("mkdir -p %s" % trainer_folder)
trainer_startup = self._trainer_startup_programs[i]
trainer_main = self._trainer_main_programs[i]
self._save_program(
trainer_startup,
"%s/trainer.startup.program" % trainer_folder)
self._save_program(
trainer_main,
"%s/trainer.main.program" % trainer_folder)
self._save_readable_program(
trainer_startup,
"%s/trainer.startup.program.txt" % trainer_folder)
self._save_program(trainer_startup,
"%s/trainer.startup.program" % trainer_folder)
self._save_program(trainer_main,
"%s/trainer.main.program" % trainer_folder)
self._save_readable_program(trainer_startup,
"%s/trainer.startup.program.txt" %
trainer_folder)
self._save_readable_program(
trainer_main,
"%s/trainer.main.program.txt" % trainer_folder)
trainer_main, "%s/trainer.main.program.txt" % trainer_folder)
self._save_str_list(self._feed_names,
"%s/feed_names" % trainer_folder)
self._save_str_list(self._target_names,
"%s/target_names" % trainer_folder)
self._save_endpoints(self._server_endpoints,
"%s/endpoints" % trainer_folder)
self._save_strategy(self._strategy,
"%s/strategy.pkl" % trainer_folder)
for i in range(send_prog_num):
trainer_folder = "%s/trainer%d" % (folder, i)
trainer_send = self._trainer_send_programs[i]
trainer_recv = self._trainer_recv_programs[i]
self._save_program(
trainer_send,
"%s/trainer.send.program" % trainer_folder)
self._save_program(
trainer_recv,
"%s/trainer.recv.program" % trainer_folder)
self._save_program(trainer_send,
"%s/trainer.send.program" % trainer_folder)
self._save_program(trainer_recv,
"%s/trainer.recv.program" % trainer_folder)
self._save_readable_program(
trainer_send,
"%s/trainer.send.program.txt" % trainer_folder)
trainer_send, "%s/trainer.send.program.txt" % trainer_folder)
self._save_readable_program(
trainer_recv,
"%s/trainer.recv.program.txt" % trainer_folder)
trainer_recv, "%s/trainer.recv.program.txt" % trainer_folder)
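Pieced together from the save() calls above, the on-disk layout looks roughly like this; `fl_job_config` is the folder name used by the examples later in this commit:

```
fl_job_config/
├── server0/
│   ├── server.startup.program        # serialized Program
│   ├── server.main.program
│   ├── server.startup.program.txt    # human-readable dump
│   ├── server.main.program.txt
│   ├── feed_names
│   ├── target_names
│   ├── endpoints
│   └── strategy.pkl
└── trainer0/
    ├── trainer.startup.program
    ├── trainer.main.program
    ├── trainer.send.program          # only for strategies with send/recv programs
    ├── trainer.recv.program
    ├── trainer.*.program.txt
    ├── feed_names
    ├── target_names
    ├── endpoints
    └── strategy.pkl
```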
class FLRunTimeJob(FLJobBase):
......@@ -172,6 +166,7 @@ class FLRunTimeJob(FLJobBase):
A trainer or a server can load an FLRunTimeJob. Only the necessary
programs are loaded into an FLRunTimeJob
"""
def __init__(self):
self._trainer_startup_program = None
self._trainer_recv_program = None
......
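At run time each role loads only its half of the job saved above. A minimal sketch; the import path is assumed from the package layout, and the call signatures follow the server/trainer examples later in this diff:

```python
from paddle_fl.core.master.fl_job import FLRunTimeJob

job_path = "fl_job_config"   # folder produced by FLCompileTimeJob.save()

# server side loads only the server programs
server_job = FLRunTimeJob()
server_job.load_server_job(job_path, 0)    # second arg: server_id

# trainer side loads only the trainer programs
trainer_job = FLRunTimeJob()
trainer_job.load_trainer_job(job_path, 0)  # second arg: trainer_id
```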
......@@ -14,6 +14,7 @@
import paddle.fluid as fluid
from .fl_job import FLCompileTimeJob
class JobGenerator(object):
"""
A JobGenerator is responsible for generating distributed federated
......@@ -21,6 +22,7 @@ class JobGenerator(object):
need to define a deep learning model together to do horizontal federated
learning.
"""
def __init__(self):
# worker num for federated learning
self._worker_num = 0
......@@ -32,7 +34,6 @@ class JobGenerator(object):
self._feed_names = []
self._target_names = []
def set_optimizer(self, optimizer):
"""
Set optimizer of current job
......@@ -56,8 +57,10 @@ class JobGenerator(object):
self._startup_prog = startup
def set_infer_feed_and_target_names(self, feed_names, target_names):
if not isinstance(feed_names, list) or not isinstance(target_names, list):
raise ValueError("input should be list in set_infer_feed_and_target_names")
if not isinstance(feed_names, list) or not isinstance(target_names,
list):
raise ValueError(
"input should be list in set_infer_feed_and_target_names")
'''
print(feed_names)
print(target_names)
......@@ -76,7 +79,6 @@ class JobGenerator(object):
server_endpoints=[],
worker_num=1,
output=None):
"""
Generate Federated Learning Job, based on user defined configs
......@@ -130,30 +132,35 @@ class JobGenerator(object):
startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone()
fl_strategy._build_trainer_program_for_job(
trainer_id, program=main_program,
ps_endpoints=server_endpoints, trainers=worker_num,
sync_mode=True, startup_program=startup_program,
trainer_id,
program=main_program,
ps_endpoints=server_endpoints,
trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job)
startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone()
fl_strategy._build_server_programs_for_job(
program=main_program, ps_endpoints=server_endpoints,
trainers=worker_num, sync_mode=True,
startup_program=startup_program, job=local_job)
program=main_program,
ps_endpoints=server_endpoints,
trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job)
local_job.set_feed_names(self._feed_names)
local_job.set_target_names(self._target_names)
local_job.set_strategy(fl_strategy)
local_job.save(output)
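generate_fl_job() above is the final step of the compile-time flow. A sketch assembled from the fl_master-style example later in this diff; `create_fl_strategy()` is assumed, and `model`/`inputs` stand for the user-defined network:

```python
job_generator = JobGenerator()
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
                                              [model.predict.name])

build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
strategy = build_strategy.create_fl_strategy()  # assumed factory method

# endpoints and worker_num are deployment-specific
job_generator.generate_fl_job(
    strategy,
    server_endpoints=["127.0.0.1:8181"],
    worker_num=2,
    output="fl_job_config")
```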
def generate_fl_job_for_k8s(self,
fl_strategy,
server_pod_endpoints=[],
server_service_endpoints=[],
worker_num=1,
output=None):
fl_strategy,
server_pod_endpoints=[],
server_service_endpoints=[],
worker_num=1,
output=None):
local_job = FLCompileTimeJob()
assert len(self._losses) > 0
......@@ -168,17 +175,23 @@ class JobGenerator(object):
startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone()
fl_strategy._build_trainer_program_for_job(
trainer_id, program=main_program,
ps_endpoints=server_service_endpoints, trainers=worker_num,
sync_mode=True, startup_program=startup_program,
trainer_id,
program=main_program,
ps_endpoints=server_service_endpoints,
trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job)
startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone()
fl_strategy._build_server_programs_for_job(
program=main_program, ps_endpoints=server_pod_endpoints,
trainers=worker_num, sync_mode=True,
startup_program=startup_program, job=local_job)
program=main_program,
ps_endpoints=server_pod_endpoints,
trainers=worker_num,
sync_mode=True,
startup_program=startup_program,
job=local_job)
local_job.set_feed_names(self._feed_names)
local_job.set_target_names(self._target_names)
......
......@@ -2,6 +2,7 @@ import zmq
import time
import random
def recv_and_parse_kv(socket):
message = socket.recv()
group = message.decode().split("\t")
......@@ -10,9 +11,11 @@ def recv_and_parse_kv(socket):
else:
return group[0], group[1]
WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP"
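The helper above assumes tab-separated "KEY\tVALUE" messages over ZeroMQ. An illustrative pairing; the sending side is a sketch inferred from recv_and_parse_kv, not code from this commit:

```python
import zmq

# worker side: register this worker's endpoint with the scheduler
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://127.0.0.1:9091")  # scheduler endpoint, illustrative
socket.send("{}\t{}".format(WORKER_EP, "127.0.0.1:9000").encode())

# scheduler side would then parse it with the helper above:
# key, value = recv_and_parse_kv(scheduler_socket)  # -> ("WORKER_EP", "127.0.0.1:9000")
```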
class FLServerAgent(object):
def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep
......@@ -29,6 +32,7 @@ class FLServerAgent(object):
if group[0] == 'INIT':
break
class FLWorkerAgent(object):
def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep
......@@ -64,7 +68,6 @@ class FLWorkerAgent(object):
return False
class FLScheduler(object):
def __init__(self, worker_num, server_num, port=9091, socket=None):
self.context = zmq.Context()
......
......@@ -14,8 +14,8 @@
import paddle.fluid as fluid
from paddle_fl.core.scheduler.agent_master import FLServerAgent
class FLServer(object):
def __init__(self):
self._startup_program = None
self._main_program = None
......
......@@ -48,8 +48,8 @@ def wait_server_ready(endpoints):
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) +
"\n")
sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints)
+ "\n")
sys.stderr.flush()
time.sleep(3)
else:
......
......@@ -163,7 +163,8 @@ def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
indent = 0
print(
"{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx),
"{0}{1} // block {2}".format(
get_indent_space(indent), '{', block_idx),
file=fout)
indent += 1
......
......@@ -50,6 +50,7 @@ def log(*args):
if PRINT_LOG:
print(args)
def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
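An illustrative use with hypothetical variable names: the transpiler names split parameter slices "<var>.block<i>", and this helper matches either form:

```python
assert same_or_split_var("fc_0.w_0", "fc_0.w_0")          # the variable itself
assert same_or_split_var("fc_0.w_0.block0", "fc_0.w_0")   # one of its split slices
assert not same_or_split_var("fc_1.w_0", "fc_0.w_0")      # a different variable
```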
......@@ -113,7 +114,9 @@ class FLDistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
sparse_update_op_types = [
"lookup_table", "nce", "hierarchical_sigmoid"
]
for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True:
......@@ -406,12 +409,13 @@ class FLDistributeTranspiler(object):
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var = pserver_program.global_block().var(
orig_var_name)
if self.sync_mode and self.trainer_num > 1:
for trainer_id in range(self.trainer_num):
var = pserver_program.global_block().create_var(
name="%s.opti.trainer_%d" % (orig_var_name, trainer_id),
name="%s.opti.trainer_%d" %
(orig_var_name, trainer_id),
persistable=False,
type=v.type,
dtype=v.dtype,
......@@ -816,7 +820,6 @@ class FLDistributeTranspiler(object):
iomap = collections.OrderedDict()
return iomap
def _get_lr_ops(self):
lr_ops = []
block = self.origin_program.global_block()
......
......@@ -16,11 +16,13 @@ from .fl_distribute_transpiler import FLDistributeTranspiler
from paddle.fluid.optimizer import SGD
import paddle.fluid as fluid
class FLStrategyFactory(object):
"""
FLStrategyFactory is an FLStrategy builder.
Users can define a strategy config to create different FLStrategy instances
"""
def __init__(self):
self._fed_avg = False
self._dpsgd = False
......@@ -86,6 +88,7 @@ class FLStrategyBase(object):
"""
FLStrategyBase is a container for federated learning algorithms
"""
def __init__(self):
self._fed_avg = False
self._dpsgd = False
......@@ -105,17 +108,23 @@ class FLStrategyBase(object):
for loss in losses:
optimizer.minimize(loss)
def _build_trainer_program_for_job(
self, trainer_id=0, program=None,
ps_endpoints=[], trainers=0,
sync_mode=True, startup_program=None,
job=None):
def _build_trainer_program_for_job(self,
trainer_id=0,
program=None,
ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None):
pass
def _build_server_programs_for_job(
self, program=None, ps_endpoints=[],
trainers=0, sync_mode=True,
startup_program=None, job=None):
def _build_server_programs_for_job(self,
program=None,
ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None):
pass
......@@ -123,6 +132,7 @@ class DPSGDStrategy(FLStrategyBase):
"""
DPSGDStrategy: Deep Learning with Differential Privacy. 2016
"""
def __init__(self):
super(DPSGDStrategy, self).__init__()
......@@ -162,29 +172,40 @@ class DPSGDStrategy(FLStrategyBase):
"""
Define Dpsgd optimizer
"""
optimizer = fluid.optimizer.Dpsgd(self._learning_rate, clip=self._clip, batch_size=self._batch_size, sigma=self._sigma)
optimizer = fluid.optimizer.Dpsgd(
self._learning_rate,
clip=self._clip,
batch_size=self._batch_size,
sigma=self._sigma)
optimizer.minimize(losses[0])
def _build_trainer_program_for_job(
self, trainer_id=0, program=None,
ps_endpoints=[], trainers=0,
sync_mode=True, startup_program=None,
job=None):
def _build_trainer_program_for_job(self,
trainer_id=0,
program=None,
ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None):
transpiler = fluid.DistributeTranspiler()
transpiler.transpile(trainer_id,
program=program,
pservers=",".join(ps_endpoints),
trainers=trainers,
sync_mode=sync_mode,
startup_program=startup_program)
transpiler.transpile(
trainer_id,
program=program,
pservers=",".join(ps_endpoints),
trainers=trainers,
sync_mode=sync_mode,
startup_program=startup_program)
main = transpiler.get_trainer_program(wait_port=False)
job._trainer_startup_programs.append(startup_program)
job._trainer_main_programs.append(main)
def _build_server_programs_for_job(
self, program=None, ps_endpoints=[],
trainers=0, sync_mode=True,
startup_program=None, job=None):
def _build_server_programs_for_job(self,
program=None,
ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None):
transpiler = fluid.DistributeTranspiler()
trainer_id = 0
transpiler.transpile(
......@@ -207,6 +228,7 @@ class FedAvgStrategy(FLStrategyBase):
FedAvgStrategy: this is model averaging optimization proposed in
H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 2017
"""
def __init__(self):
super(FedAvgStrategy, self).__init__()
......@@ -216,28 +238,35 @@ class FedAvgStrategy(FLStrategyBase):
"""
optimizer.minimize(losses[0])
def _build_trainer_program_for_job(
self, trainer_id=0, program=None,
ps_endpoints=[], trainers=0,
sync_mode=True, startup_program=None,
job=None):
def _build_trainer_program_for_job(self,
trainer_id=0,
program=None,
ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None):
transpiler = FLDistributeTranspiler()
transpiler.transpile(trainer_id,
program=program,
pservers=",".join(ps_endpoints),
trainers=trainers,
sync_mode=sync_mode,
startup_program=startup_program)
transpiler.transpile(
trainer_id,
program=program,
pservers=",".join(ps_endpoints),
trainers=trainers,
sync_mode=sync_mode,
startup_program=startup_program)
recv, main, send = transpiler.get_trainer_program()
job._trainer_startup_programs.append(startup_program)
job._trainer_main_programs.append(main)
job._trainer_send_programs.append(send)
job._trainer_recv_programs.append(recv)
def _build_server_programs_for_job(
self, program=None, ps_endpoints=[],
trainers=0, sync_mode=True,
startup_program=None, job=None):
def _build_server_programs_for_job(self,
program=None,
ps_endpoints=[],
trainers=0,
sync_mode=True,
startup_program=None,
job=None):
transpiler = FLDistributeTranspiler()
trainer_id = 0
transpiler.transpile(
......@@ -262,6 +291,7 @@ class SecAggStrategy(FedAvgStrategy):
Practical Secure Aggregation for Privacy-Preserving Machine Learning,
the 24th ACM Conference on Computer and Communications Security (CCS 2017).
"""
def __init__(self):
super(SecAggStrategy, self).__init__()
self._param_name_list = []
......
import sys
import os
class CloudClient(object):
def __init__(self):
pass
def generate_submit_sh(self, job_dir):
with open("{}/submit.sh".format(job_dir), "w") as fout:  # path inferred from submit() below
pass
......@@ -16,6 +17,7 @@ class CloudClient(object):
def submit(self, **kwargs):
pass
class HPCClient(object):
def __init__(self):
self.conf_dict = {}
......@@ -70,27 +72,20 @@ class HPCClient(object):
fout.write("#!/bin/bash\n")
fout.write("unset http_proxy\n")
fout.write("unset https_proxy\n")
fout.write("export HADOOP_HOME={}\n".format(
self.hadoop_home))
fout.write("export HADOOP_HOME={}\n".format(self.hadoop_home))
fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}"
" -Dfs.default.name={} -rmr {}\n".format(
self.ugi,
self.hdfs_path,
self.hdfs_output))
self.ugi, self.hdfs_path, self.hdfs_output))
fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem))
fout.write("{}/bin/qsub_f -N {} --conf qsub.conf "
"--hdfs {} --ugi {} --hout {} --files ./package "
"-l nodes={},walltime=1000:00:00,pmem-hard={},"
"pcpu-soft={},pnetin-soft=1000,"
"pnetout-soft=1000 job.sh\n".format(
self.hpc_home,
self.task_name,
self.hdfs_path,
self.ugi,
self.hdfs_output,
self.hpc_home, self.task_name, self.hdfs_path,
self.ugi, self.hdfs_output,
int(self.worker_nodes) + int(self.server_nodes),
self.mpi_node_mem,
self.pcpu))
self.mpi_node_mem, self.pcpu))
def generate_job_sh(self, job_dir):
with open("{}/job.sh".format(job_dir), "w") as fout:
......@@ -98,17 +93,23 @@ class HPCClient(object):
fout.write("WORKDIR=`pwd`\n")
fout.write("mpirun -npernode 1 mv package/* ./\n")
fout.write("echo 'current dir: '$WORKDIR\n")
fout.write("mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
fout.write("export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
fout.write(
"mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
fout.write(
"export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
fout.write("mpirun -npernode 1 python/bin/python -m pip install "
"{} --index-url=http://pip.baidu.com/pypi/simple "
"--trusted-host pip.baidu.com > /dev/null\n".format(
self.wheel))
fout.write("export PATH=python/bin:$PATH\n")
if self.monitor_cmd != "":
fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
"${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".format(self.monitor_cmd))
fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n")
fout.write(
"mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
"${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".
format(self.monitor_cmd))
fout.write(
"mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n"
)
fout.write("if [[ $? -ne 0 ]]; then\n")
fout.write(" echo 'Failed to run mpi!' 1>&2\n")
fout.write(" exit 1\n")
......@@ -150,4 +151,5 @@ class HPCClient(object):
# generate job.sh
self.generate_qsub_conf(jobdir)
# run submit
os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(jobdir))
os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(
jobdir))
......@@ -2,7 +2,6 @@
#
# (c) Chris von Csefalvay, 2015.
"""
__init__.py is responsible for [brief description here].
"""
# coding=utf-8
#
# The MIT License (MIT)
#
......@@ -21,8 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
decorators declares some decorators that ensure the object has the
correct keys declared when need be.
......
......@@ -20,10 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
diffiehellmann declares the main key exchange class.
"""
......@@ -41,18 +37,17 @@ import os
try:
from ssl import RAND_bytes
rng = RAND_bytes
except(AttributeError, ImportError):
except (AttributeError, ImportError):
rng = os.urandom
class DiffieHellman:
"""
Implements the Diffie-Hellman key exchange protocol.
"""
def __init__(self,
group=18,
key_length=640):
def __init__(self, group=18, key_length=640):
self.key_length = max(200, key_length)
self.generator = PRIMES[group]["generator"]
......@@ -81,7 +76,8 @@ class DiffieHellman:
self.private_key = key
def verify_public_key(self, other_public_key):
return self.prime - 1 > other_public_key > 2 and pow(other_public_key, (self.prime - 1) // 2, self.prime) == 1
return self.prime - 1 > other_public_key > 2 and pow(
other_public_key, (self.prime - 1) // 2, self.prime) == 1
@requires_private_key
def generate_public_key(self):
......@@ -91,9 +87,7 @@ class DiffieHellman:
:return: void
:rtype: void
"""
self.public_key = pow(self.generator,
self.private_key,
self.prime)
self.public_key = pow(self.generator, self.private_key, self.prime)
@requires_private_key
def generate_shared_secret(self, other_public_key, echo_return_key=False):
......@@ -110,16 +104,17 @@ class DiffieHellman:
if self.verify_public_key(other_public_key) is False:
raise MalformedPublicKey
self.shared_secret = pow(other_public_key,
self.private_key,
self.shared_secret = pow(other_public_key, self.private_key,
self.prime)
try:
#python3
shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big')
shared_secret_as_bytes = self.shared_secret.to_bytes(
self.shared_secret.bit_length() // 8 + 1, byteorder='big')
except:
#python2
length = self.shared_secret.bit_length() // 8 + 1
shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:]
shared_secret_as_bytes = ('%%0%dx' % (
length << 1) % self.shared_secret).decode('hex')[-length:]
_h = sha256()
_h.update(bytes(shared_secret_as_bytes))
......
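A sketch of a two-party exchange with this class. `generate_private_key()`/`generate_public_key()` follow the upstream diffiehellman library this module vendors; the SecAgg trainer in this commit loads pre-generated keys from disk instead:

```python
alice = DiffieHellman(group=15, key_length=256)
bob = DiffieHellman(group=15, key_length=256)

alice.generate_private_key()
alice.generate_public_key()
bob.generate_private_key()
bob.generate_public_key()

# each side combines its own private key with the peer's public key
alice.generate_shared_secret(bob.public_key, echo_return_key=True)
bob.generate_shared_secret(alice.public_key, echo_return_key=True)
# both now hold the same sha256-derived shared_key
```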
......@@ -2,7 +2,6 @@
#
# (c) Chris von Csefalvay, 2015.
"""
exceptions is responsible for exception handling etc.
"""
......
# coding=utf-8
#
# The MIT License (MIT)
#
......@@ -25,34 +24,39 @@
# Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman
# groups for Internet Key Exchange (IKE)_.
#
"""
primes holds the RFC 3526 MODP primes and their generators.
"""
PRIMES = {
5: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
"prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
"generator": 2
},
14: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
"prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
"generator": 2
},
15: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
"prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
"generator": 2
},
16: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
"prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
"generator": 2
},
17: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
"prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
"generator": 2
},
18: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
"prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
"generator": 2
},
}
......@@ -19,6 +19,7 @@ import hmac
import hashlib
from .diffiehellman.diffiehellman import DiffieHellman
class FLTrainerFactory(object):
def __init__(self):
pass
......@@ -65,9 +66,7 @@ class FLTrainer(object):
def run(self, feed, fetch):
self._logger.debug("begin to run")
self.exe.run(self._main_program,
feed=feed,
fetch_list=fetch)
self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
self._logger.debug("end to run current batch")
self.cur_step += 1
......@@ -119,37 +118,34 @@ class FedAvgTrainer(FLTrainer):
def reset(self):
self.cur_step = 0
def run_with_epoch(self,reader,feeder,fetch,num_epoch):
def run_with_epoch(self, reader, feeder, fetch, num_epoch):
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
epoch = 0
for i in range(num_epoch):
for data in reader():
self.exe.run(self._main_program,
feed=feeder.feed(data),
fetch_list=fetch)
self.cur_step += 1
epoch += 1
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
def run(self, feed, fetch):
self._logger.debug("begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step))
self._logger.debug(
"begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step))
if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
self._logger.debug("begin to run current step")
loss = self.exe.run(self._main_program,
feed=feed,
fetch_list=fetch)
loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
self.cur_step += 1
return loss
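A hedged sketch of the run-time loop this method serves, assembled from the example trainer script later in this diff; `feed_dict`/`fetch_list` are placeholders for the user's model tensors:

```python
job = FLRunTimeJob()
job.load_trainer_job("fl_job_config", trainer_id)
job._scheduler_ep = "127.0.0.1:9091"             # scheduler endpoint

trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()

while not trainer.stop():                        # scheduler decides when to stop
    # recv/send programs run automatically every _step batches inside run()
    loss = trainer.run(feed=feed_dict, fetch=fetch_list)
```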
class SecAggTrainer(FLTrainer):
......@@ -207,24 +203,24 @@ class SecAggTrainer(FLTrainer):
self.cur_step = 0
def run(self, feed, fetch):
self._logger.debug("begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step))
self._logger.debug(
"begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step))
if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
scope = fluid.global_scope()
self._logger.debug("begin to run current step")
loss = self.exe.run(self._main_program,
feed=feed,
fetch_list=fetch)
loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program")
noise = 0.0
scale = pow(10.0, 5)
digestmod=hashlib.sha256
digestmod = hashlib.sha256
# 1. load priv key and other's pub key
dh = DiffieHellman(group=15, key_length=256)
dh.load_private_key(self._key_dir + str(self._trainer_id) + "_priv_key.txt")
dh.load_private_key(self._key_dir + str(self._trainer_id) +
"_priv_key.txt")
key = str(self._step_id).encode("utf-8")
for i in range(self._trainer_num):
if i != self._trainer_id:
......@@ -232,7 +228,8 @@ class SecAggTrainer(FLTrainer):
public_key = int(f.read())
dh.generate_shared_secret(public_key, echo_return_key=True)
msg = dh.shared_key.encode("utf-8")
hex_res1 = hmac.new(key=key, msg=msg, digestmod=digestmod).hexdigest()
hex_res1 = hmac.new(key=key, msg=msg,
digestmod=digestmod).hexdigest()
current_noise = int(hex_res1[0:8], 16) / scale
if i > self._trainer_id:
noise = noise + current_noise
......@@ -241,9 +238,11 @@ class SecAggTrainer(FLTrainer):
scope = fluid.global_scope()
for param_name in self._param_name_list:
fluid.global_scope().var(param_name + str(self._trainer_id)).get_tensor().set(
numpy.array(scope.find_var(param_name + str(self._trainer_id)).get_tensor()) + noise, fluid.CPUPlace())
fluid.global_scope().var(param_name + str(
self._trainer_id)).get_tensor().set(
numpy.array(
scope.find_var(param_name + str(self._trainer_id))
.get_tensor()) + noise, fluid.CPUPlace())
self.exe.run(self._send_program)
self.cur_step += 1
return loss
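The masking above is the pairwise trick from the CCS 2017 paper cited earlier: every pair of trainers derives an identical HMAC-based mask from their Diffie-Hellman shared key and the step id, one adds it and the other subtracts it, so the masks cancel in the server-side sum. A condensed sketch, reusing the variables above; the subtracting branch for peers with a smaller id is cut off by this hunk and assumed here:

```python
import hashlib
import hmac

scale = pow(10.0, 5)
key = str(step_id).encode("utf-8")       # step_id: current secure-agg step
msg = dh.shared_key.encode("utf-8")      # dh: DiffieHellman after the exchange
hex_res = hmac.new(key=key, msg=msg, digestmod=hashlib.sha256).hexdigest()
current_noise = int(hex_res[0:8], 16) / scale
# the peer with the larger id adds current_noise, the other subtracts it,
# so the pairwise masks cancel when the server aggregates the parameters
```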
......@@ -5,73 +5,84 @@ import tarfile
import random
def download(url,tar_path):
    r = requests.get(url)
    with open(tar_path,'wb') as f:
        f.write(r.content)

def download(url, tar_path):
    r = requests.get(url)
    with open(tar_path, 'wb') as f:
        f.write(r.content)

def extract(tar_path,target_path):
    tar = tarfile.open(tar_path, "r:gz")
    file_names = tar.getnames()
    for file_name in file_names:
        tar.extract(file_name,target_path)
    tar.close()

def extract(tar_path, target_path):
    tar = tarfile.open(tar_path, "r:gz")
    file_names = tar.getnames()
    for file_name in file_names:
        tar.extract(file_name, target_path)
    tar.close()

def train(trainer_id,inner_step,batch_size,count_by_step):
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    if not os.path.exists(target_path):
        os.system("mkdir trainer%d_data" % trainer_id)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
        extract(tar_path,target_path)

    def train_data():
        train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
        json_train = json.load(train_file)
        users = json_train["users"]
        rand = random.randrange(0,len(users))  # randomly choose a user on each trainer
        cur_user = users[rand]
        print('training using '+cur_user)
        train_images = json_train["user_data"][cur_user]['x']
        train_labels = json_train["user_data"][cur_user]['y']
        if count_by_step:
            for i in range(inner_step*batch_size):
                yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))]
        else:
            for i in range(len(train_images)):
                yield train_images[i], train_labels[i]
        train_file.close()
    return train_data

def train(trainer_id, inner_step, batch_size, count_by_step):
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    if not os.path.exists(target_path):
        os.system("mkdir trainer%d_data" % trainer_id)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
                     tar_path)
        extract(tar_path, target_path)

    def train_data():
        train_file = open(
            "./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json"
            % (trainer_id, trainer_id), 'r')
        json_train = json.load(train_file)
        users = json_train["users"]
        rand = random.randrange(
            0, len(users))  # randomly choose a user on each trainer
        cur_user = users[rand]
        print('training using ' + cur_user)
        train_images = json_train["user_data"][cur_user]['x']
        train_labels = json_train["user_data"][cur_user]['y']
        if count_by_step:
            for i in range(inner_step * batch_size):
                yield train_images[i % (len(train_images))], train_labels[i % (
                    len(train_images))]
        else:
            for i in range(len(train_images)):
                yield train_images[i], train_labels[i]
        train_file.close()

    return train_data

def test(trainer_id,inner_step,batch_size,count_by_step):
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    if not os.path.exists(target_path):
        os.system("mkdir trainer%d_data" % trainer_id)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
        extract(tar_path,target_path)

    def test_data():
        test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
        json_test = json.load(test_file)
        users = json_test["users"]
        for user in users:
            test_images = json_test['user_data'][user]['x']
            test_labels = json_test['user_data'][user]['y']
            for i in range(len(test_images)):
                yield test_images[i], test_labels[i]
        test_file.close()
    return test_data

def test(trainer_id, inner_step, batch_size, count_by_step):
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    if not os.path.exists(target_path):
        os.system("mkdir trainer%d_data" % trainer_id)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
                     tar_path)
        extract(tar_path, target_path)

    def test_data():
        test_file = open(
            "./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json"
            % (trainer_id, trainer_id), 'r')
        json_test = json.load(test_file)
        users = json_test["users"]
        for user in users:
            test_images = json_test['user_data'][user]['x']
            test_labels = json_test['user_data'][user]['y']
            for i in range(len(test_images)):
                yield test_images[i], test_labels[i]
        test_file.close()

    return test_data
......@@ -3,6 +3,7 @@ import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
class Model(object):
def __init__(self):
pass
......@@ -12,7 +13,8 @@ class Model(object):
self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
self.sum_cost = fluid.layers.cross_entropy(
input=self.predict, label=label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.loss = fluid.layers.reduce_mean(self.sum_cost)
self.startup_program = fluid.default_startup_program()
......@@ -34,8 +36,8 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
[x.name for x in inputs], [model.predict.name])
job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
[model.predict.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
......
......@@ -3,7 +3,7 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
......
......@@ -21,8 +21,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"  # IP address for server
server.start()
print("connect")
......@@ -4,7 +4,12 @@ import numpy as np
import sys
import logging
import time
logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%m-%Y %H:%M:%S", level=logging.DEBUG)
logging.basicConfig(
filename="test.log",
filemode="w",
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%m-%Y %H:%M:%S",  # %m is month; %M (minutes) was a typo
level=logging.DEBUG)
def reader():
......@@ -15,13 +20,14 @@ def reader():
data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
yield data_dict
trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the trainer of the scheduler IP
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
......@@ -37,4 +43,3 @@ while not trainer.stop():
epoch_id += 1
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
......@@ -14,4 +14,3 @@
""" PaddleFL version string """
fl_version = "0.1.11"
module_proto_version = "0.1.11"