未验证 提交 3145ed18 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #9 from jhjiangcs/paddlefl-dpsgd

add sec_agg strategy and demo for secure aggregation of model paramet…
# Example in Recognize Digits with DPSGD
This document introduces how to use PaddleFL to train a model with Fl Strategy: Secure Aggregation. Using Secure Aggregation strategy, the server can aggregate the model parameters without learning the value of the parameters.
### Dependencies
- paddlepaddle>=1.6
### How to install PaddleFL
Please use python which has paddlepaddle installed
```
python setup.py install
```
### Model
The simplest Softmax regression model is to get features with input layer passing through a fully connected layer and then compute and ouput probabilities of multiple classes directly via Softmax function [[PaddlePaddle tutorial: recognize digits](https://github.com/PaddlePaddle/book/tree/develop/02.recognize_digits#references)].
### Datasets
Public Dataset [MNIST](http://yann.lecun.com/exdb/mnist/)
The dataset will downloaded automatically in the API and will be located under `/home/username/.cache/paddle/dataset/mnist`:
| filename | note |
| ----------------------- | ------------------------------- |
| train-images-idx3-ubyte | train data picture, 60,000 data |
| train-labels-idx1-ubyte | train data label, 60,000 data |
| t10k-images-idx3-ubyte | test data picture, 10,000 data |
| t10k-labels-idx1-ubyte | test data label, 10,000 data |
### How to work in PaddleFL
PaddleFL has two phases , CompileTime and RunTime. In CompileTime, a federated learning task is defined by fl_master. In RunTime, a federated learning job is executed on fl_server and fl_trainer in distributed clusters.
```
sh run.sh
```
#### How to work in CompileTime
In this example, we implement compile time programs in fl_master.py
```
python fl_master.py
```
In fl_master.py, we first define FL-Strategy, User-Defined-Program and Distributed-Config. Then FL-Job-Generator generate FL-Job for federated server and worker.
```python
def linear_regression(self, inputs, label):
param_attrs = fluid.ParamAttr(
name="fc_0.b_0",
initializer=fluid.initializer.ConstantInitializer(0.0))
param_attrs = fluid.ParamAttr(
name="fc_0.w_0",
initializer=fluid.initializer.ConstantInitializer(0.0))
self.predict = fluid.layers.fc(input=inputs, size=10, act='softmax', param_attr=param_attrs)
self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
self.loss = fluid.layers.mean(self.sum_cost)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.startup_program = fluid.default_startup_program()
inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
model = Model()
model.linear_regression(inputs, label)
job_generator = JobGenerator()
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
[inputs.name, label.name], [model.loss.name])
build_strategy = FLStrategyFactory()
build_strategy.sec_agg = True
param_name_list = []
param_name_list.append("fc_0.w_0.opti.trainer_") # need trainer_id when running
param_name_list.append("fc_0.b_0.opti.trainer_")
build_strategy.param_name_list = param_name_list
build_strategy.inner_step = 10
strategy = build_strategy.create_fl_strategy()
# endpoints will be collected through the cluster
# in this example, we suppose endpoints have been collected
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=2, output=output)
```
How to work in RunTime
```shell
python3 fl_master.py
sleep 2
python3 -u fl_server.py >log/server0.log &
sleep 2
python3 -u fl_trainer.py 0 >log/trainer0.log &
sleep 2
python3 -u fl_trainer.py 1 >log/trainer1.log &
```
In fl_server.py, we load and run the FL server job.
```
server = FLServer()
server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
server.set_server_job(job)
server.start()
```
In fl_trainer.py, we prepare the MNIST dataset, load and run the FL trainer job, then evaluate the accuracy. Before training , we first prepare the party's private key and other party's public key. Then, each party generates a random noise using Diffie-Hellman key aggregate protocol with its private key and each other's public key [1]. If the other party's id is larger than this party's id, the model parameters add this random noise. If the other party's id is less than this party's id, the model parameters subtract this random noise. So, and the model parameters is masked before uploading to the server. Finally, the random noises can be removed when aggregating the masked parameters from all the parties.
```python
logging.basicConfig(filename="log/test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
logger = logging.getLogger("FLTrainer")
BATCH_SIZE = 64
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
trainer_num = 2
trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
output_folder = "fl_model"
epoch_id = 0
step_i = 0
inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace())
# for test
test_program = trainer._main_program.clone(for_test=True)
def train_test(train_test_program,
train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
acc_np, avg_loss_np = trainer.exe.run(
program=train_test_program,
feed=train_test_feed.feed(test_data),
fetch_list=["accuracy_0.tmp_0", "mean_0.tmp_0"])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
# for test
while not trainer.stop():
epoch_id += 1
print("epoch %d start train" % (epoch_id))
for data in train_reader():
step_i += 1
trainer.step_id = step_i
accuracy, = trainer.run(feed=feeder.feed(data),
fetch=["accuracy_0.tmp_0"])
if step_i % 100 == 0:
print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0]))
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
if epoch_id > 40:
break
if step_i % 100 == 0:
trainer.save_inference_program(output_folder)
```
[1] Aaron Segal, Antonio Marcedone, Benjamin Kreuter, Daniel Ramage, H. Brendan McMahan, Karn Seth, Keith Bonawitz, Sarvar Patel, Vladimir Ivanov. **Practical Secure Aggregation for Privacy-Preserving Machine Learning**, The 24th ACM Conference on Computer and Communications Security (**CCS**), 2017
......@@ -24,6 +24,7 @@ class FLStrategyFactory(object):
def __init__(self):
self._fed_avg = False
self._dpsgd = False
self._sec_agg = False
self._inner_step = 1
@property
......@@ -42,6 +43,14 @@ class FLStrategyFactory(object):
def dpsgd(self, s):
self._dpsgd = s
@property
def sec_agg(self):
return self._sec_agg
@sec_agg.setter
def sec_agg(self, s):
self._sec_agg = s
@property
def inner_step(self):
return self._inner_step
......@@ -58,10 +67,17 @@ class FLStrategyFactory(object):
strategy = FedAvgStrategy()
strategy._fed_avg = True
strategy._dpsgd = False
strategy._sec_agg = False
elif self._dpsgd == True:
strategy = DPSGDStrategy()
strategy._fed_avg = False
strategy._dpsgd = True
strategy._sec_agg = False
elif self._sec_agg == True:
strategy = SecAggStrategy()
strategy._fed_avg = False
strategy._dpsgd = False
strategy._sec_agg = True
strategy._inner_step = self._inner_step
return strategy
......@@ -239,87 +255,23 @@ class FedAvgStrategy(FLStrategyBase):
job._server_main_programs.append(main_prog)
class SecAggStrategy(FedAvgStrategy):
"""
DPSGDStrategy: this is model averaging optimization proposed in
Aaron Segal, Antonio Marcedone, Benjamin Kreuter, et al.
Practical Secure Aggregation for Privacy-Preserving Machine Learning,
The 24th ACM Conference on Computer and Communications Security ( CCS2017 ).
"""
def __init__(self):
super(SecAggStrategy, self).__init__()
self._param_name_list = []
@property
def param_name_list(self):
return self._param_name_list
@param_name_list.setter
def param_name_list(self, s):
self._param_name_list = s
# coding=utf-8
#
# (c) Chris von Csefalvay, 2015.
"""
__init__.py is responsible for [brief description here].
"""
# coding=utf-8
#
# The MIT License (MIT)
#
# Copyright (c) 2016 Chris von Csefalvay
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
decorators declares some decorators that ensure the object has the
correct keys declared when need be.
"""
def requires_private_key(func):
"""
Decorator for functions that require the private key to be defined.
"""
def func_wrapper(self, *args, **kwargs):
if hasattr(self, "private_key"):
func(self, *args, **kwargs)
else:
self.generate_private_key()
func(self, *args, **kwargs)
return func_wrapper
def requires_public_key(func):
"""
Decorator for functions that require the public key to be defined. By definition, this includes the private key, as such, it's enough to use this to effect definition of both public and private key.
"""
def func_wrapper(self, *args, **kwargs):
if hasattr(self, "public_key"):
func(self, *args, **kwargs)
else:
self.generate_public_key()
func(self, *args, **kwargs)
return func_wrapper
# coding=utf-8
#
# The MIT License (MIT)
#
# Copyright (c) 2016 Chris von Csefalvay
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
diffiehellmann declares the main key exchange class.
"""
__version__ = '0.13.3'
from hashlib import sha256
from .decorators import requires_private_key
from .exceptions import MalformedPublicKey, RNGError
from .primes import PRIMES
import os
#python3
try:
from ssl import RAND_bytes
rng = RAND_bytes
except(AttributeError, ImportError):
raise RNGError
#python2
#rng = os.urandom
class DiffieHellman:
"""
Implements the Diffie-Hellman key exchange protocol.
"""
def __init__(self,
group=18,
key_length=640):
self.key_length = max(200, key_length)
self.generator = PRIMES[group]["generator"]
self.prime = PRIMES[group]["prime"]
def load_private_key(self, priv_key_filepath="priv_key.txt"):
f = open(priv_key_filepath, "r")
self.private_key = int(f.read())
# self.private_key = 1236621350910932696206938487330072474688096146032487063733488274339542368951034578546908207981931
def generate_private_key(self):
"""
Generates a private key of key_length bits and attaches it to the object as the __private_key variable.
:return: void
:rtype: void
"""
key_length = self.key_length // 8 + 8
key = 0
try:
key = int.from_bytes(rng(key_length), byteorder='big')
except:
key = int(hex(rng(key_length)), base=16)
self.private_key = key
def verify_public_key(self, other_public_key):
return self.prime - 1 > other_public_key > 2 and pow(other_public_key, (self.prime - 1) // 2, self.prime) == 1
@requires_private_key
def generate_public_key(self):
"""
Generates public key.
:return: void
:rtype: void
"""
self.public_key = pow(self.generator,
self.private_key,
self.prime)
@requires_private_key
def generate_shared_secret(self, other_public_key, echo_return_key=False):
"""
Generates shared secret from the other party's public key.
:param other_public_key: Other party's public key
:type other_public_key: int
:param echo_return_key: Echo return shared key
:type bool
:return: void
:rtype: void
"""
if self.verify_public_key(other_public_key) is False:
raise MalformedPublicKey
self.shared_secret = pow(other_public_key,
self.private_key,
self.prime)
#python2
#length = self.shared_secret.bit_length() // 8 + 1
#shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:]
#python3
shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big')
_h = sha256()
_h.update(bytes(shared_secret_as_bytes))
self.shared_key = _h.hexdigest()
if echo_return_key is True:
return self.shared_key
# coding=utf-8
#
# (c) Chris von Csefalvay, 2015.
"""
exceptions is responsible for exception handling etc.
"""
class MalformedPublicKey(BaseException):
"""
The public key is malformed as it does not meet the Legendre symbol criterion. The key might have been tampered with or might have been damaged in transit.
"""
def __str__(self):
return "Public key malformed: fails Legendre symbol verification."
class RNGError(BaseException):
"""
Thrown when RNG could not be obtained.
"""
def __str__(self):
return "RNG could not be obtained. This module currently only works with Python 3."
\ No newline at end of file
# coding=utf-8
#
# The MIT License (MIT)
#
# Copyright (c) 2016 Chris von Csefalvay
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The primes presented here are (c) The Internet Society, 2003.
# Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman
# groups for Internet Key Exchange (IKE)_.
#
"""
primes holds the RFC 3526 MODP primes and their generators.
"""
PRIMES = {
5: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
"generator": 2
},
14: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
"generator": 2
},
15: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
"generator": 2
},
16: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
"generator": 2
},
17: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
"generator": 2
},
18: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
"generator": 2
},
}
......@@ -13,6 +13,9 @@
# limitations under the License.
import paddle.fluid as fluid
import logging
import numpy
import hmac
from .diffiehellman.diffiehellman import DiffieHellman
class FLTrainerFactory(object):
def __init__(self):
......@@ -27,6 +30,9 @@ class FLTrainerFactory(object):
elif strategy._dpsgd == True:
trainer = FLTrainer()
trainer.set_trainer_job(job)
elif strategy._sec_agg == True:
trainer = SecAggTrainer()
trainer.set_trainer_job(job)
trainer.set_trainer_job(job)
return trainer
......@@ -75,6 +81,7 @@ class FLTrainer(object):
# TODO(guru4elephant): add connection with master
return False
class FedAvgTrainer(FLTrainer):
def __init__(self):
super(FedAvgTrainer, self).__init__()
......@@ -111,4 +118,102 @@ class FedAvgTrainer(FLTrainer):
def stop(self):
return False
class SecAggTrainer(FLTrainer):
def __init__(self):
super(SecAggTrainer, self).__init__()
pass
@property
def trainer_id(self):
return self._trainer_id
@trainer_id.setter
def trainer_id(self, s):
self._trainer_id = s
@property
def trainer_num(self):
return self._trainer_num
@trainer_num.setter
def trainer_num(self, s):
self._trainer_num = s
@property
def key_dir(self):
return self._key_dir
@key_dir.setter
def key_dir(self, s):
self._key_dir = s
@property
def step_id(self):
return self._step_id
@step_id.setter
def step_id(self, s):
self._step_id = s
def start(self):
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe.run(self._startup_program)
self.cur_step = 0
def set_trainer_job(self, job):
super(SecAggTrainer, self).set_trainer_job(job)
self._send_program = job._trainer_send_program
self._recv_program = job._trainer_recv_program
self_step = job._strategy._inner_step
self._param_name_list = job._strategy._param_name_list
def reset(self):
self.cur_step = 0
def run(self, feed, fetch):
self._logger.debug("begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
(self.cur_step, self._step))
if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
scope = fluid.global_scope()
self._logger.debug("begin to run current step")
loss = self.exe.run(self._main_program,
feed=feed,
fetch_list=fetch)
if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program")
noise = 0.0
scale = pow(10.0, 5)
digestmod="SHA256"
# 1. load priv key and other's pub key
dh = DiffieHellman(group=15, key_length=256)
dh.load_private_key(self._key_dir + str(self._trainer_id) + "_priv_key.txt")
key = str(self._step_id).encode("utf-8")
for i in range(self._trainer_num):
if i != self._trainer_id:
f = open(self._key_dir + str(i) + "_pub_key.txt", "r")
public_key = int(f.read())
dh.generate_shared_secret(public_key, echo_return_key=True)
msg = dh.shared_key.encode("utf-8")
hex_res1 = hmac.new(key=key, msg=msg, digestmod=digestmod).hexdigest()
current_noise = int(hex_res1[0:8], 16) / scale
if i > self._trainer_id:
noise = noise + current_noise
else:
noise = noise - current_noise
scope = fluid.global_scope()
for param_name in self._param_name_list:
fluid.global_scope().var(param_name + str(self._trainer_id)).get_tensor().set(
numpy.array(scope.find_var(param_name + str(self._trainer_id)).get_tensor()) + noise, fluid.CPUPlace())
self.exe.run(self._send_program)
self.cur_step += 1
return loss
def stop(self):
return False
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
class Model(object):
def __init__(self):
pass
def linear_regression(self, inputs, label):
param_attrs = fluid.ParamAttr(
name="fc_0.b_0",
initializer=fluid.initializer.ConstantInitializer(0.0))
param_attrs = fluid.ParamAttr(
name="fc_0.w_0",
initializer=fluid.initializer.ConstantInitializer(0.0))
self.predict = fluid.layers.fc(input=inputs, size=10, act='softmax', param_attr=param_attrs)
self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
self.loss = fluid.layers.mean(self.sum_cost)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.startup_program = fluid.default_startup_program()
inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
model = Model()
model.linear_regression(inputs, label)
job_generator = JobGenerator()
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
[inputs.name, label.name], [model.loss.name])
build_strategy = FLStrategyFactory()
#build_strategy.fed_avg = True
build_strategy.sec_agg = True
param_name_list = []
param_name_list.append("fc_0.w_0.opti.trainer_") # need trainer_id when running
param_name_list.append("fc_0.b_0.opti.trainer_")
build_strategy.param_name_list = param_name_list
build_strategy.inner_step = 10
strategy = build_strategy.create_fl_strategy()
# endpoints will be collected through the cluster
# in this example, we suppose endpoints have been collected
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=2, output=output)
# fl_job_config will be dispatched to workers
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle_fl as fl
import paddle.fluid as fluid
from paddle_fl.core.server.fl_server import FLServer
from paddle_fl.core.master.fl_job import FLRunTimeJob
server = FLServer()
server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
server.set_server_job(job)
server.start()
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
import sys
import logging
import paddle
import paddle.fluid as fluid
import time
import datetime
import math
import hashlib
import hmac
logging.basicConfig(filename="log/test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
logger = logging.getLogger("FLTrainer")
BATCH_SIZE = 64
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
trainer_num = 2
trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
output_folder = "fl_model"
epoch_id = 0
step_i = 0
inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace())
# for test
test_program = trainer._main_program.clone(for_test=True)
def train_test(train_test_program,
train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
acc_np, avg_loss_np = trainer.exe.run(
program=train_test_program,
feed=train_test_feed.feed(test_data),
fetch_list=["accuracy_0.tmp_0", "mean_0.tmp_0"])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
# for test
while not trainer.stop():
epoch_id += 1
print("epoch %d start train" % (epoch_id))
for data in train_reader():
step_i += 1
trainer.step_id = step_i
accuracy, = trainer.run(feed=feeder.feed(data),
fetch=["accuracy_0.tmp_0"])
if step_i % 100 == 0:
print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0]))
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
if epoch_id > 40:
break
if step_i % 100 == 0:
trainer.save_inference_program(output_folder)
1236621350910932696206938487330072474688096146032487063733488274339542368951034578546908207981931
2438748580808349511143047663636683775879288034436941526695550498623461587527621346172907651006831789701999970929529915459467532662545948308044143788306668377086821294492459623439935894424167712515718436351900091777957477710004777078638317806960364609629258387413979203403741893205419691425902518810085451041187685334971769087054033027561974230347468587825700834108657561999305482311897914109364221430821533207693682979777541616125499682380618775029176238891407643926372043660610226672413497764635239787000143341827693941253721638947580506197728500367325524850325027980531066702962726949006217630290236644410746181942256812170056772600756232506116738493114591218127323741133163913140583529684827066023347088796194846253682954154504336640429027403657831470993825621749318372332546269811820953216261135662418531598954663771775691648448615131802158937156803324423733802071166119966224716088242291968098450309032800335049617861465
\ No newline at end of file
239766553287108165963157659638452083696427531397301209851404530650120167468636815062334148471396
2514645349791449916465355128335954929464444612258498884322250411584328344530925790221013632799576102047787468232441470392901627580493383471719532612816534848391408601920539266665550346246343040368608757429591392784807798812848893304745441721044204602414415725300562075953290154457726382683020925406187524758708694161689285261293920782115270550960717687322240202298426363733065475031448888618026711435993053991653780694485897784243346859087028197560255857091562150381619708471192080403868115055265681423866891707972356449212236920210992128063075734725292359699578940877165584175851667803049667020005978253573912096358469050409541237248593428640028573437783747958382446666712845099931003578559688134007238753677011257181086592677636834099341020870502521085827878680362572013623469761170961943916356175726242515624843837354222489536469472365937707615959657846428001149463801728084949088483783942894784080451399167982957006474558
\ No newline at end of file
unset http_proxy
unset https_proxy
if [ ! -d log ];then
mkdir log
fi
python3 fl_master.py
sleep 2
python3 -u fl_server.py >log/server0.log &
sleep 2
python3 -u fl_trainer.py 0 >log/trainer0.log &
sleep 2
python3 -u fl_trainer.py 1 >log/trainer1.log &
#!/bin/bash
echo "Stop service!"
ps -ef | grep -E "fl" | grep -v grep | awk '{print $2}' | xargs kill -9
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册