Unverified commit 97333c4a, authored by Qinghe JING, committed by GitHub

Merge pull request #46 from qjing666/travis

add travis CI and pre-commit-config
repos:
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
  sha: v1.0.1
  hooks:
  - id: remove-crlf
    files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
  sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
  hooks:
  - id: yapf
    files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/pre-commit-hooks
  sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
  hooks:
  - id: check-added-large-files
  - id: check-merge-conflict
  - id: check-symlinks
  - id: detect-private-key
    files: (?!.*third_party)^.*$ | (?!.*book)^.*$
  - id: end-of-file-fixer
- repo: local
  hooks:
  - id: clang-format-with-version-check
    name: clang-format
    description: Format files with ClangFormat.
    entry: bash ./tools/codestyle/clang_format.hook -i
    language: system
    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: local
  hooks:
  - id: cpplint-cpp-source
    name: cpplint
    description: Check C++ code style using cpplint.py.
    entry: bash ./tools/codestyle/cpplint_pre_commit.hook
    language: system
    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
- repo: local
  hooks:
  - id: pylint-doc-string
    name: pylint
    description: Check python docstring style using docstring_checker.
    entry: bash ./tools/codestyle/pylint_pre_commit.hook
    language: system
    files: \.(py)$
- repo: local
  hooks:
  - id: copyright_checker
    name: copyright_checker
    entry: python ./tools/codestyle/copyright.hook
    language: system
    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
    exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
language: python
notifications:
  email:
    on_success: change
    on_failure: always
sudo: false
os:
  - linux
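
With this configuration in place, the checks Travis runs can be reproduced locally. A minimal sketch, assuming the `pre-commit` package is installed (`pip install pre-commit`) and this is run from the repository root:

import subprocess

# Register the hooks above as this clone's git pre-commit hook.
subprocess.run(["pre-commit", "install"], check=True)

# Run every configured hook against the whole tree, as CI would.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)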
@@ -2,3 +2,4 @@
|---|---|
| guru4elephant | Daxiang Dong |
| frankwhzhang | Wenhui Zhang |
| qjing666 | Qinghe Jing |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
import socket
import msgpack
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import paddle
@@ -12,14 +26,14 @@ import math
import msgpack


def data_generater(samples, r):
    # data generator
    def train_data():
        for item in samples:
            sample = msgpack.loads(r.get(str(item)))
            conv = sample[0]
            label = sample[1]
            yield conv, label

    return train_data
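
The generator assumes every redis value is a msgpack-encoded `[features, label]` pair, which is exactly the layout the encoder script below writes. A minimal round-trip sketch, assuming a local redis server on 127.0.0.1:6379 and the `redis`/`msgpack` packages:

import msgpack
import redis

r = redis.StrictRedis(host="127.0.0.1", port=6379)

# Writer side: store a [features, label] pair under a numeric key.
toy_sample = [[[0.0] * 8] * 8, 7]  # toy 8x8 feature map plus a label
r.set("1", msgpack.dumps(toy_sample))

# Reader side: decode and unpack, as train_data() does for each key.
conv, label = msgpack.loads(r.get("1"))
print(label)  # -> 7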
@@ -67,7 +81,7 @@ class ResNet():
                          size=class_dim,
                          param_attr=fluid.param_attr.ParamAttr(
                              initializer=fluid.initializer.Uniform(-stdv, stdv)),
                          act="softmax")
        else:
            for block in range(len(depth)):
                for i in range(depth[block]):
@@ -87,7 +101,7 @@ class ResNet():
                          size=class_dim,
                          param_attr=fluid.param_attr.ParamAttr(
                              initializer=fluid.initializer.Uniform(-stdv, stdv)),
                          act="softmax")
        return out
    def conv_bn_layer(self,
@@ -123,8 +137,6 @@ class ResNet():
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance', )

    def shortcut(self, input, ch_out, stride, is_first, name):
        ch_in = input.shape[1]
        if ch_in != ch_out or stride != 1 or is_first == True:
@@ -181,31 +193,33 @@ class ResNet():
            input, num_filters, stride, is_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
# local redis config
redis_host = "127.0.0.1"
redis_port = 6379
redis_password = ""
r = redis.StrictRedis(
    host=redis_host, port=redis_port, password=redis_password)

# reader generation
reader = fluid.layers.py_reader(
    capacity=64, shapes=[(-1, 64, 8, 8), (-1, 1)],
    dtypes=['float32', 'int64'])
samples = r.keys()
train_data = data_generater(samples, r)
reader.decorate_paddle_reader(
    paddle.batch(
        paddle.reader.shuffle(
            train_data, buf_size=5000), batch_size=64))
conv1, label = fluid.layers.read_file(reader)

# train program
place = fluid.CUDAPlace(0)
model = ResNet(layers=50)
predicts = model.net(conv1, 10)
cost = fluid.layers.cross_entropy(input=predicts, label=label)
accuracy = fluid.layers.accuracy(input=predicts, label=label)
loss = fluid.layers.mean(cost)
@@ -226,10 +240,12 @@ for pass_id in range(EPOCH_NUM):
    try:
        while True:
            start_time = time.time()
            loss_value, acc_value = exe.run(
                fetch_list=[loss.name, accuracy.name])
            step += 1
            if step % 10 == 0:
                print("epoch: " + str(pass_id) + "step: " + str(step) +
                      "loss: " + str(loss_value) + "acc: " + str(acc_value))
            end_time = time.time()
            total_time += (end_time - start_time)
    except fluid.core.EOFException:
...
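
The trainer above feeds `py_reader` by composing the redis-backed generator with `paddle.reader.shuffle` and `paddle.batch`. That composition is plain Python and can be sanity-checked without redis or a GPU; a toy sketch, assuming Paddle 1.x:

import paddle

def toy_reader():
    # stand-in for data_generater's train_data(): yields (features, label)
    for i in range(10):
        yield [float(i)], i % 2

# shuffle buffers up to buf_size samples; batch groups them into lists of 4
batched = paddle.batch(
    paddle.reader.shuffle(toy_reader, buf_size=10), batch_size=4)
for batch in batched():
    print(len(batch), batch[0])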
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import paddle
@@ -9,6 +23,8 @@ import time
from paddle.fluid import layers
from paddle.fluid.param_attr import ParamAttr
import msgpack


def conv_bn_layer(input,
                  num_filters,
                  filter_size,
@@ -51,6 +67,7 @@ def load_conf(conf_file, local_dict):
            local_dict[group[0]] = group[1]
    return local_dict
# redis DB configuration
redis_host = "127.0.0.1"
redis_port = 6379
@@ -58,27 +75,40 @@ redis_password = ""
start_time = time.time()

# start a redis client and empty the DB
r = redis.StrictRedis(
    host=redis_host, port=redis_port, password=redis_password)
r.flushall()

# encoding program
images = fluid.layers.data(name='images', shape=[3, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
place = fluid.CPUPlace()
conv1 = conv_bn_layer(
    input=images,
    num_filters=64,
    filter_size=7,
    stride=2,
    act='relu',
    name="conv1")
pool = fluid.layers.pool2d(
    input=conv1, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])

pretrained_model = 'ResNet50_pretrained'
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())


# load the pretrained model and prepare data
def if_exist(var):
    return os.path.exists(os.path.join(pretrained_model, var.name))


fluid.io.load_vars(
    exe,
    pretrained_model,
    main_program=fluid.default_main_program(),
    predicate=if_exist)

train_data = paddle.dataset.cifar.train10()
step = 0
@@ -86,11 +116,13 @@ step = 0
for data in train_data():
    pre_data = []
    pre_data.append(data)
    res = exe.run(program=fluid.default_main_program(),
                  feed=feeder.feed(pre_data),
                  fetch_list=[pool.name])
    sample = [res[0][0].tolist(), data[1]]
    step += 1
    file = msgpack.dumps(sample)
    r.set(step, file)
    if step % 100 == 0:
        print(numpy.array(sample[0]).shape)
        print("%dstart" % step)
@@ -99,6 +131,4 @@ files = r.keys()
print("upload file numbers: %d" % len(files))
end_time = time.time()
total_time = end_time - start_time
print("total time: %d" % total_time)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
import socket
import msgpack
import os

mission_dict = {"mission": "image classification", "image_size": [3, 32, 32]}

# send request
context = zmq.Context()
zmq_socket = context.socket(zmq.REQ)
...
@@ -3,4 +3,3 @@ mistune
sphinx_rtd_theme
paddlepaddle>=1.6
zmq
@@ -181,4 +181,3 @@ while not trainer.stop():
To show the effectiveness of DPSGD-based federated learning with PaddleFL, a simulated experiment was conducted on the open-source MNIST dataset. As the figure below shows, model evaluation results are similar between DPSGD-based federated learning and traditional parameter-server training when the overall privacy budget *epsilon* is 1.3 or 0.13.

<img src="fl_dpsgd_benchmark.png" height=400 width=600 hspace='10'/> <br />
@@ -103,6 +103,3 @@ wget https://paddle-zwh.bj.bcebos.com/gru4rec_paddlefl_benchmark/gru4rec_benchma
| 1/4 of the whole dataset | private training | - | 0.282 |

<img src="fl_benchmark.png" height=300 width=500 hspace='10'/> <br />
@@ -55,4 +55,3 @@ In PaddleFL, components for defining a federated learning task and training a fe
- Deployment methods for federated learning systems in Kubernetes.
- Vertical federated learning strategies, and more horizontal federated learning strategies, will be open sourced.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -22,4 +22,3 @@ from .scheduler.agent_master import FLWorkerAgent
from .scheduler.agent_master import FLScheduler
from .submitter.client_base import HPCClient
from .submitter.client_base import CloudClient
@@ -14,11 +14,13 @@
import os
import paddle.fluid as fluid


class FLJobBase(object):
    """
    FLJobBase is the base class of a federated learning job, responsible
    for saving and loading a federated learning job
    """

    def __init__(self):
        pass
@@ -64,6 +66,7 @@ class FLJobBase(object):
            return fluid.Program.parse_from_string(program_desc_str)
        return None

class FLCompileTimeJob(FLJobBase):
    """
    FLCompileTimeJob is a container for a compile-time job in federated learning.
@@ -71,6 +74,7 @@ class FLCompileTimeJob(FLJobBase):
    are in FLCompileTimeJob. Also, server main programs and server startup programs
    are in this class. FLCompileTimeJob has server endpoints for debugging as well
    """

    def __init__(self):
        self._trainer_startup_programs = []
        self._trainer_recv_programs = []
@@ -101,18 +105,15 @@ class FLCompileTimeJob(FLJobBase):
            os.system("mkdir -p %s" % server_folder)
            server_startup = self._server_startup_programs[i]
            server_main = self._server_main_programs[i]
            self._save_program(server_startup,
                               "%s/server.startup.program" % server_folder)
            self._save_program(server_main,
                               "%s/server.main.program" % server_folder)
            self._save_readable_program(server_startup,
                                        "%s/server.startup.program.txt" %
                                        server_folder)
            self._save_readable_program(
                server_main, "%s/server.main.program.txt" % server_folder)
            self._save_str_list(self._feed_names,
                                "%s/feed_names" % server_folder)
            self._save_str_list(self._target_names,
os.system("mkdir -p %s" % trainer_folder) os.system("mkdir -p %s" % trainer_folder)
trainer_startup = self._trainer_startup_programs[i] trainer_startup = self._trainer_startup_programs[i]
trainer_main = self._trainer_main_programs[i] trainer_main = self._trainer_main_programs[i]
self._save_program( self._save_program(trainer_startup,
trainer_startup,
"%s/trainer.startup.program" % trainer_folder) "%s/trainer.startup.program" % trainer_folder)
self._save_program( self._save_program(trainer_main,
trainer_main,
"%s/trainer.main.program" % trainer_folder) "%s/trainer.main.program" % trainer_folder)
self._save_readable_program(trainer_startup,
"%s/trainer.startup.program.txt" %
trainer_folder)
self._save_readable_program( self._save_readable_program(
trainer_startup, trainer_main, "%s/trainer.main.program.txt" % trainer_folder)
"%s/trainer.startup.program.txt" % trainer_folder)
self._save_readable_program(
trainer_main,
"%s/trainer.main.program.txt" % trainer_folder)
self._save_str_list(self._feed_names, self._save_str_list(self._feed_names,
"%s/feed_names" % trainer_folder) "%s/feed_names" % trainer_folder)
self._save_str_list(self._target_names, self._save_str_list(self._target_names,
...@@ -152,18 +150,14 @@ class FLCompileTimeJob(FLJobBase): ...@@ -152,18 +150,14 @@ class FLCompileTimeJob(FLJobBase):
trainer_folder = "%s/trainer%d" % (folder, i) trainer_folder = "%s/trainer%d" % (folder, i)
trainer_send = self._trainer_send_programs[i] trainer_send = self._trainer_send_programs[i]
trainer_recv = self._trainer_recv_programs[i] trainer_recv = self._trainer_recv_programs[i]
self._save_program( self._save_program(trainer_send,
trainer_send,
"%s/trainer.send.program" % trainer_folder) "%s/trainer.send.program" % trainer_folder)
self._save_program( self._save_program(trainer_recv,
trainer_recv,
"%s/trainer.recv.program" % trainer_folder) "%s/trainer.recv.program" % trainer_folder)
self._save_readable_program( self._save_readable_program(
trainer_send, trainer_send, "%s/trainer.send.program.txt" % trainer_folder)
"%s/trainer.send.program.txt" % trainer_folder)
self._save_readable_program( self._save_readable_program(
trainer_recv, trainer_recv, "%s/trainer.recv.program.txt" % trainer_folder)
"%s/trainer.recv.program.txt" % trainer_folder)

class FLRunTimeJob(FLJobBase):
@@ -172,6 +166,7 @@ class FLRunTimeJob(FLJobBase):
    A trainer or a server can load FLRunTimeJob. Only necessary programs
    can be loaded in FLRunTimeJob
    """

    def __init__(self):
        self._trainer_startup_program = None
        self._trainer_recv_program = None
...
@@ -14,6 +14,7 @@
import paddle.fluid as fluid
from .fl_job import FLCompileTimeJob


class JobGenerator(object):
    """
    A JobGenerator is responsible for generating distributed federated
@@ -21,6 +22,7 @@ class JobGenerator(object):
    need to define a deep learning model together to do horizontal federated
    learning.
    """

    def __init__(self):
        # worker num for federated learning
        self._worker_num = 0
@@ -32,7 +34,6 @@ class JobGenerator(object):
        self._feed_names = []
        self._target_names = []

    def set_optimizer(self, optimizer):
        """
        Set optimizer of current job
@@ -56,8 +57,10 @@ class JobGenerator(object):
        self._startup_prog = startup

    def set_infer_feed_and_target_names(self, feed_names, target_names):
        if not isinstance(feed_names, list) or not isinstance(target_names,
                                                              list):
            raise ValueError(
                "input should be list in set_infer_feed_and_target_names")
        '''
        print(feed_names)
        print(target_names)
@@ -76,7 +79,6 @@
                        server_endpoints=[],
                        worker_num=1,
                        output=None):
        """
        Generate Federated Learning Job, based on user defined configs
@@ -130,17 +132,66 @@
            startup_program = self._startup_prog.clone()
            main_program = self._losses[0].block.program.clone()
            fl_strategy._build_trainer_program_for_job(
                trainer_id,
                program=main_program,
                ps_endpoints=server_endpoints,
                trainers=worker_num,
                sync_mode=True,
                startup_program=startup_program,
                job=local_job)

        startup_program = self._startup_prog.clone()
        main_program = self._losses[0].block.program.clone()
        fl_strategy._build_server_programs_for_job(
            program=main_program,
            ps_endpoints=server_endpoints,
            trainers=worker_num,
            sync_mode=True,
            startup_program=startup_program,
            job=local_job)

        local_job.set_feed_names(self._feed_names)
        local_job.set_target_names(self._target_names)
        local_job.set_strategy(fl_strategy)
        local_job.save(output)

    def generate_fl_job_for_k8s(self,
                                fl_strategy,
                                server_pod_endpoints=[],
                                server_service_endpoints=[],
                                worker_num=1,
                                output=None):
        local_job = FLCompileTimeJob()
        assert len(self._losses) > 0
        assert self._startup_prog != None
        assert fl_strategy != None
        assert output != None
        fl_strategy.minimize(self._optimizer, self._losses)

        # strategy can generate startup and main program
        # of a single worker and servers
        for trainer_id in range(worker_num):
            startup_program = self._startup_prog.clone()
            main_program = self._losses[0].block.program.clone()
            fl_strategy._build_trainer_program_for_job(
                trainer_id,
                program=main_program,
                ps_endpoints=server_service_endpoints,
                trainers=worker_num,
                sync_mode=True,
                startup_program=startup_program,
                job=local_job)

        startup_program = self._startup_prog.clone()
        main_program = self._losses[0].block.program.clone()
        fl_strategy._build_server_programs_for_job(
            program=main_program,
            ps_endpoints=server_pod_endpoints,
            trainers=worker_num,
            sync_mode=True,
            startup_program=startup_program,
            job=local_job)

        local_job.set_feed_names(self._feed_names)
        local_job.set_target_names(self._target_names)
...
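
Pieced together from the setters and `generate_fl_job` above, the intended call sequence looks roughly like the sketch below. It follows the pattern of the PaddleFL demo scripts; the one-layer model is a placeholder, and the import paths, the `set_losses` setter, and the `FLStrategyFactory` flags are assumptions inferred from this commit's layout rather than a verified API reference.

import paddle.fluid as fluid
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory

# Placeholder model: one dense layer over 13 input features.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.mean(fluid.layers.square_error_cost(predict, y))

job_generator = JobGenerator()
job_generator.set_optimizer(fluid.optimizer.SGD(learning_rate=0.01))
job_generator.set_losses([loss])
job_generator.set_startup_program(fluid.default_startup_program())
job_generator.set_infer_feed_and_target_names(["x"], [predict.name])

# Build a FedAvg strategy through the factory defined later in this commit.
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
strategy = build_strategy.create_fl_strategy()

job_generator.generate_fl_job(
    strategy,
    server_endpoints=["127.0.0.1:8181"],
    worker_num=2,
    output="fl_job_config")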
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
import time
import random


def recv_and_parse_kv(socket):
    message = socket.recv()
    group = message.decode().split("\t")
@@ -10,9 +25,11 @@ def recv_and_parse_kv(socket):
    else:
        return group[0], group[1]


WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP"


class FLServerAgent(object):
    def __init__(self, scheduler_ep, current_ep):
        self.scheduler_ep = scheduler_ep
@@ -29,6 +46,7 @@ class FLServerAgent(object):
            if group[0] == 'INIT':
                break


class FLWorkerAgent(object):
    def __init__(self, scheduler_ep, current_ep):
        self.scheduler_ep = scheduler_ep
@@ -64,7 +82,6 @@ class FLWorkerAgent(object):
        return False


class FLScheduler(object):
    def __init__(self, worker_num, server_num, port=9091, socket=None):
        self.context = zmq.Context()
...
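
The agents and scheduler above speak a simple wire format: tab-separated `key\tvalue` strings over ZeroMQ REQ/REP sockets, split back apart by `recv_and_parse_kv`. A self-contained sketch of one such round trip on a toy endpoint (single process; the endpoint and reply are illustrative, though the `INIT` reply mirrors the handshake checked above):

import threading
import zmq

context = zmq.Context()

def toy_scheduler():
    # REP side: receive one "KEY\tVALUE" message and reply in kind.
    rep = context.socket(zmq.REP)
    rep.bind("tcp://127.0.0.1:5555")
    key, value = rep.recv().decode().split("\t")
    rep.send_string("INIT\t{}".format(value))

t = threading.Thread(target=toy_scheduler)
t.start()

# REQ side: register an endpoint the way a server agent would.
req = context.socket(zmq.REQ)
req.connect("tcp://127.0.0.1:5555")
req.send_string("SERVER_EP\t127.0.0.1:8181")
print(req.recv().decode().split("\t"))  # -> ['INIT', '127.0.0.1:8181']
t.join()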
@@ -14,8 +14,8 @@
import paddle.fluid as fluid
from paddle_fl.core.scheduler.agent_master import FLServerAgent


class FLServer(object):
    def __init__(self):
        self._startup_program = None
        self._main_program = None
...
@@ -48,8 +48,8 @@ def wait_server_ready(endpoints):
            not_ready_endpoints.append(ep)
        if not all_ok:
            sys.stderr.write("server not ready, wait 3 sec to retry...\n")
            sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints)
                             + "\n")
            sys.stderr.flush()
            time.sleep(3)
        else:
...
@@ -163,7 +163,8 @@ def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
        indent = 0
    print(
        "{0}{1} // block {2}".format(
            get_indent_space(indent), '{', block_idx),
        file=fout)
    indent += 1
...
@@ -50,6 +50,7 @@ def log(*args):
    if PRINT_LOG:
        print(args)


def same_or_split_var(p_name, var_name):
    return p_name == var_name or p_name.startswith(var_name + ".block")
@@ -113,7 +114,9 @@ class FLDistributeTranspiler(object):
    def _get_all_remote_sparse_update_op(self, main_program):
        sparse_update_ops = []
        sparse_update_op_types = [
            "lookup_table", "nce", "hierarchical_sigmoid"
        ]
        for op in main_program.global_block().ops:
            if op.type in sparse_update_op_types and op.attr(
                    'remote_prefetch') is True:
@@ -411,7 +414,8 @@ class FLDistributeTranspiler(object):
        if self.sync_mode and self.trainer_num > 1:
            for trainer_id in range(self.trainer_num):
                var = pserver_program.global_block().create_var(
                    name="%s.opti.trainer_%d" %
                    (orig_var_name, trainer_id),
                    persistable=False,
                    type=v.type,
                    dtype=v.dtype,
@@ -816,7 +820,6 @@ class FLDistributeTranspiler(object):
        iomap = collections.OrderedDict()
        return iomap

    def _get_lr_ops(self):
        lr_ops = []
        block = self.origin_program.global_block()
...
@@ -16,11 +16,13 @@ from .fl_distribute_transpiler import FLDistributeTranspiler
from paddle.fluid.optimizer import SGD
import paddle.fluid as fluid


class FLStrategyFactory(object):
    """
    FLStrategyFactory is a FLStrategy builder
    Users can define strategy config to create different FLStrategy
    """

    def __init__(self):
        self._fed_avg = False
        self._dpsgd = False
@@ -86,6 +88,7 @@ class FLStrategyBase(object):
    """
    FLStrategyBase is federated learning algorithm container
    """

    def __init__(self):
        self._fed_avg = False
        self._dpsgd = False
@@ -105,17 +108,23 @@ class FLStrategyBase(object):
        for loss in losses:
            optimizer.minimize(loss)

    def _build_trainer_program_for_job(self,
                                       trainer_id=0,
                                       program=None,
                                       ps_endpoints=[],
                                       trainers=0,
                                       sync_mode=True,
                                       startup_program=None,
                                       job=None):
        pass

    def _build_server_programs_for_job(self,
                                       program=None,
                                       ps_endpoints=[],
                                       trainers=0,
                                       sync_mode=True,
                                       startup_program=None,
                                       job=None):
        pass

@@ -123,6 +132,7 @@ class DPSGDStrategy(FLStrategyBase):
    """
    DPSGDStrategy: Deep Learning with Differential Privacy. 2016
    """

    def __init__(self):
        super(DPSGDStrategy, self).__init__()
@@ -162,16 +172,24 @@ class DPSGDStrategy(FLStrategyBase):
        """
        Define Dpsgd optimizer
        """
        optimizer = fluid.optimizer.Dpsgd(
            self._learning_rate,
            clip=self._clip,
            batch_size=self._batch_size,
            sigma=self._sigma)
        optimizer.minimize(losses[0])

    def _build_trainer_program_for_job(self,
                                       trainer_id=0,
                                       program=None,
                                       ps_endpoints=[],
                                       trainers=0,
                                       sync_mode=True,
                                       startup_program=None,
                                       job=None):
        transpiler = fluid.DistributeTranspiler()
        transpiler.transpile(
            trainer_id,
            program=program,
            pservers=",".join(ps_endpoints),
            trainers=trainers,
@@ -181,10 +199,13 @@ class DPSGDStrategy(FLStrategyBase):
        job._trainer_startup_programs.append(startup_program)
        job._trainer_main_programs.append(main)

    def _build_server_programs_for_job(self,
                                       program=None,
                                       ps_endpoints=[],
                                       trainers=0,
                                       sync_mode=True,
                                       startup_program=None,
                                       job=None):
        transpiler = fluid.DistributeTranspiler()
        trainer_id = 0
        transpiler.transpile(
@@ -207,6 +228,7 @@ class FedAvgStrategy(FLStrategyBase):
    FedAvgStrategy: this is model averaging optimization proposed in
    H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 2017
    """

    def __init__(self):
        super(FedAvgStrategy, self).__init__()
@@ -216,13 +238,17 @@ class FedAvgStrategy(FLStrategyBase):
        """
        optimizer.minimize(losses[0])

    def _build_trainer_program_for_job(self,
                                       trainer_id=0,
                                       program=None,
                                       ps_endpoints=[],
                                       trainers=0,
                                       sync_mode=True,
                                       startup_program=None,
                                       job=None):
        transpiler = FLDistributeTranspiler()
        transpiler.transpile(
            trainer_id,
            program=program,
            pservers=",".join(ps_endpoints),
            trainers=trainers,
@@ -234,10 +260,13 @@ class FedAvgStrategy(FLStrategyBase):
        job._trainer_send_programs.append(send)
        job._trainer_recv_programs.append(recv)

    def _build_server_programs_for_job(self,
                                       program=None,
                                       ps_endpoints=[],
                                       trainers=0,
                                       sync_mode=True,
                                       startup_program=None,
                                       job=None):
        transpiler = FLDistributeTranspiler()
        trainer_id = 0
        transpiler.transpile(
@@ -262,6 +291,7 @@ class SecAggStrategy(FedAvgStrategy):
    Practical Secure Aggregation for Privacy-Preserving Machine Learning,
    The 24th ACM Conference on Computer and Communications Security (CCS 2017).
    """

    def __init__(self):
        super(SecAggStrategy, self).__init__()
        self._param_name_list = []
...
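
FedAvgStrategy implements the model averaging of McMahan et al.: each trainer runs local steps, then the global parameters are replaced by the mean of the trainers' parameters. Stripped of the transpiler machinery, the core reduction is just an element-wise mean; a toy numpy sketch with equal trainer weights:

import numpy as np

# One flat parameter vector per trainer after a round of local training.
trainer_params = [
    np.array([0.2, 1.0, -0.5]),
    np.array([0.4, 0.8, -0.3]),
]

# FedAvg with equal trainer weights: element-wise mean of the updates.
global_params = np.mean(trainer_params, axis=0)
print(global_params)  # -> [ 0.3  0.9 -0.4]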
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os


class CloudClient(object):
    def __init__(self):
        pass
@@ -16,6 +31,7 @@ class CloudClient(object):
    def submit(self, **kwargs):
        pass


class HPCClient(object):
    def __init__(self):
        self.conf_dict = {}
@@ -70,27 +86,20 @@ class HPCClient(object):
            fout.write("#!/bin/bash\n")
            fout.write("unset http_proxy\n")
            fout.write("unset https_proxy\n")
            fout.write("export HADOOP_HOME={}\n".format(self.hadoop_home))
            fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}"
                       " -Dfs.default.name={} -rmr {}\n".format(
                           self.ugi, self.hdfs_path, self.hdfs_output))
            fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem))
            fout.write("{}/bin/qsub_f -N {} --conf qsub.conf "
                       "--hdfs {} --ugi {} --hout {} --files ./package "
                       "-l nodes={},walltime=1000:00:00,pmem-hard={},"
                       "pcpu-soft={},pnetin-soft=1000,"
                       "pnetout-soft=1000 job.sh\n".format(
                           self.hpc_home, self.task_name, self.hdfs_path,
                           self.ugi, self.hdfs_output,
                           int(self.worker_nodes) + int(self.server_nodes),
                           self.mpi_node_mem, self.pcpu))

    def generate_job_sh(self, job_dir):
        with open("{}/job.sh".format(job_dir), "w") as fout:
@@ -98,17 +107,23 @@ class HPCClient(object):
            fout.write("WORKDIR=`pwd`\n")
            fout.write("mpirun -npernode 1 mv package/* ./\n")
            fout.write("echo 'current dir: '$WORKDIR\n")
            fout.write(
                "mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
            fout.write(
                "export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
            fout.write("mpirun -npernode 1 python/bin/python -m pip install "
                       "{} --index-url=http://pip.baidu.com/pypi/simple "
                       "--trusted-host pip.baidu.com > /dev/null\n".format(
                           self.wheel))
            fout.write("export PATH=python/bin:$PATH\n")
            if self.monitor_cmd != "":
                fout.write(
                    "mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
                    "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".
                    format(self.monitor_cmd))
            fout.write(
                "mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n"
            )
            fout.write("if [[ $? -ne 0 ]]; then\n")
            fout.write("  echo 'Failed to run mpi!' 1>&2\n")
            fout.write("  exit 1\n")
@@ -150,4 +165,5 @@ class HPCClient(object):
        # generate job.sh
        self.generate_qsub_conf(jobdir)
        # run submit
        os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(
            jobdir))
@@ -2,7 +2,6 @@
#
# (c) Chris von Csefalvay, 2015.
"""
__init__.py is responsible for [brief description here].
"""
# coding=utf-8
#
# The MIT License (MIT)
#
@@ -21,8 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
decorators declares some decorators that ensure the object has the
correct keys declared when need be.
"""
...
@@ -20,10 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
"""
diffiehellmann declares the main key exchange class.
"""
@@ -41,18 +37,17 @@ import os
try:
    from ssl import RAND_bytes
    rng = RAND_bytes
except (AttributeError, ImportError):
    rng = os.urandom


class DiffieHellman:
    """
    Implements the Diffie-Hellman key exchange protocol.
    """

    def __init__(self, group=18, key_length=640):
        self.key_length = max(200, key_length)
        self.generator = PRIMES[group]["generator"]
@@ -81,7 +76,8 @@ class DiffieHellman:
        self.private_key = key

    def verify_public_key(self, other_public_key):
        return self.prime - 1 > other_public_key > 2 and pow(
            other_public_key, (self.prime - 1) // 2, self.prime) == 1

    @requires_private_key
    def generate_public_key(self):
@@ -91,9 +87,7 @@ class DiffieHellman:
        :return: void
        :rtype: void
        """
        self.public_key = pow(self.generator, self.private_key, self.prime)

    @requires_private_key
    def generate_shared_secret(self, other_public_key, echo_return_key=False):
@@ -110,16 +104,17 @@ class DiffieHellman:
        if self.verify_public_key(other_public_key) is False:
            raise MalformedPublicKey

        self.shared_secret = pow(other_public_key, self.private_key,
                                 self.prime)

        try:
            # python3
            shared_secret_as_bytes = self.shared_secret.to_bytes(
                self.shared_secret.bit_length() // 8 + 1, byteorder='big')
        except:
            # python2
            length = self.shared_secret.bit_length() // 8 + 1
            shared_secret_as_bytes = ('%%0%dx' % (
                length << 1) % self.shared_secret).decode('hex')[-length:]

        _h = sha256()
        _h.update(bytes(shared_secret_as_bytes))
...
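
The class above reduces to modular exponentiation: each side publishes g**x mod p, raises the peer's public value to its own secret, and both arrive at the same shared secret. A toy end-to-end run with deliberately small, insecure parameters (the real implementation draws p and the generator from the RFC 3526 table below):

# Toy Diffie-Hellman round trip; p and g are illustrative, NOT secure.
p, g = 23, 5

a_secret, b_secret = 6, 15
a_public = pow(g, a_secret, p)  # Alice publishes 5**6 mod 23 = 8
b_public = pow(g, b_secret, p)  # Bob publishes 5**15 mod 23 = 19

# Both sides derive the same shared secret: g**(a*b) mod p.
assert pow(b_public, a_secret, p) == pow(a_public, b_secret, p)
print(pow(b_public, a_secret, p))  # -> 2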
@@ -2,7 +2,6 @@
#
# (c) Chris von Csefalvay, 2015.
"""
exceptions is responsible for exception handling etc.
"""
...
# coding=utf-8
#
# The MIT License (MIT)
#
@@ -25,34 +24,39 @@
# Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman
# groups for Internet Key Exchange (IKE)_.
#
"""
primes holds the RFC 3526 MODP primes and their generators.
"""

PRIMES = {
    5: {
        "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
14: { 14: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
15: { 15: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
16: { 16: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
17: { 17: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
18: { 18: {
"prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF, "prime":
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
"generator": 2 "generator": 2
}, },
} }
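The three moduli above are the 4096-, 6144- and 8192-bit MODP primes from RFC 3526 (group ids 16, 17 and 18); the trainer code below instantiates group 15. A minimal sketch of the key agreement these groups enable, using the diffiehellman module bundled with this diff (method names follow how the module is used here; treat the exact API as an assumption):

```python
# Hedged sketch: two peers derive the same shared key over a MODP group.
# Assumes the bundled diffiehellman module exposes generate_public_key(),
# public_key, generate_shared_secret() and shared_key as used in this diff.
from diffiehellman.diffiehellman import DiffieHellman

alice = DiffieHellman(group=15, key_length=256)
bob = DiffieHellman(group=15, key_length=256)
alice.generate_public_key()  # also creates a private key internally
bob.generate_public_key()
alice.generate_shared_secret(bob.public_key)
bob.generate_shared_secret(alice.public_key)
assert alice.shared_key == bob.shared_key  # both sides now hold one secret
```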
...@@ -19,6 +19,7 @@ import hmac
import hashlib
from .diffiehellman.diffiehellman import DiffieHellman


class FLTrainerFactory(object):
    def __init__(self):
        pass

...@@ -65,9 +66,7 @@ class FLTrainer(object):
    def run(self, feed, fetch):
        self._logger.debug("begin to run")
        self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
        self._logger.debug("end to run current batch")
        self.cur_step += 1

...@@ -119,7 +118,7 @@ class FedAvgTrainer(FLTrainer):
    def reset(self):
        self.cur_step = 0

    def run_with_epoch(self, reader, feeder, fetch, num_epoch):
        self._logger.debug("begin to run recv program")
        self.exe.run(self._recv_program)
        epoch = 0
...@@ -132,16 +131,16 @@ class FedAvgTrainer(FLTrainer):
            epoch += 1
        self._logger.debug("begin to run send program")
        self.exe.run(self._send_program)

    def run(self, feed, fetch):
        self._logger.debug(
            "begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
            (self.cur_step, self._step))
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run recv program")
            self.exe.run(self._recv_program)
        self._logger.debug("begin to run current step")
        loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run send program")
            self.exe.run(self._send_program)
...@@ -149,9 +148,6 @@ class FedAvgTrainer(FLTrainer):
        return loss


class SecAggTrainer(FLTrainer):
    def __init__(self):
        super(SecAggTrainer, self).__init__()
...@@ -207,24 +203,24 @@ class SecAggTrainer(FLTrainer):
        self.cur_step = 0

    def run(self, feed, fetch):
        self._logger.debug(
            "begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
            (self.cur_step, self._step))
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run recv program")
            self.exe.run(self._recv_program)
        scope = fluid.global_scope()
        self._logger.debug("begin to run current step")
        loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run send program")
            noise = 0.0
            scale = pow(10.0, 5)
            digestmod = hashlib.sha256
            # 1. load priv key and other's pub key
            dh = DiffieHellman(group=15, key_length=256)
            dh.load_private_key(self._key_dir + str(self._trainer_id) +
                                "_priv_key.txt")
            key = str(self._step_id).encode("utf-8")
            for i in range(self._trainer_num):
                if i != self._trainer_id:
...@@ -232,7 +228,8 @@ class SecAggTrainer(FLTrainer):
                    public_key = int(f.read())
                    dh.generate_shared_secret(public_key, echo_return_key=True)
                    msg = dh.shared_key.encode("utf-8")
                    hex_res1 = hmac.new(key=key, msg=msg,
                                        digestmod=digestmod).hexdigest()
                    current_noise = int(hex_res1[0:8], 16) / scale
                    if i > self._trainer_id:
                        noise = noise + current_noise
...@@ -241,9 +238,11 @@ class SecAggTrainer(FLTrainer):
            scope = fluid.global_scope()
            for param_name in self._param_name_list:
                fluid.global_scope().var(param_name + str(
                    self._trainer_id)).get_tensor().set(
                        numpy.array(
                            scope.find_var(param_name + str(self._trainer_id))
                            .get_tensor()) + noise, fluid.CPUPlace())
            self.exe.run(self._send_program)
        self.cur_step += 1
        return loss
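The sign convention above (add the pairwise HMAC-derived noise when the peer's id is larger, subtract it otherwise) is what makes the masks vanish once the server sums the masked parameters from all trainers. A standalone toy check of that cancellation (hypothetical numbers; the subtract-when-smaller branch is collapsed in the hunk above and assumed here):

```python
# Toy check that the pairwise masks used by SecAggTrainer cancel on summation.
def pairwise_noise(my_id, trainer_num, shared):
    # shared[(i, j)] is the noise both peers i and j can derive (i < j).
    noise = 0.0
    for i in range(trainer_num):
        if i == my_id:
            continue
        n = shared[tuple(sorted((my_id, i)))]
        noise += n if i > my_id else -n  # same rule as in run() above
    return noise

shared = {(0, 1): 0.3, (0, 2): 1.7, (1, 2): -0.9}
total = sum(pairwise_noise(t, 3, shared) for t in range(3))
assert abs(total) < 1e-12  # masks sum to zero across the three trainers
```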
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import os
import json
...@@ -5,20 +19,22 @@ import tarfile
import random


def download(url, tar_path):
    r = requests.get(url)
    with open(tar_path, 'wb') as f:
        f.write(r.content)


def extract(tar_path, target_path):
    tar = tarfile.open(tar_path, "r:gz")
    file_names = tar.getnames()
    for file_name in file_names:
        tar.extract(file_name, target_path)
    tar.close()


def train(trainer_id, inner_step, batch_size, count_by_step):
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
...@@ -27,20 +43,26 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
                     tar_path)
        extract(tar_path, target_path)

    def train_data():
        train_file = open(
            "./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json"
            % (trainer_id, trainer_id), 'r')
        json_train = json.load(train_file)
        users = json_train["users"]
        rand = random.randrange(
            0, len(users))  # randomly choose a user from each trainer
        cur_user = users[rand]
        print('training using ' + cur_user)
        train_images = json_train["user_data"][cur_user]['x']
        train_labels = json_train["user_data"][cur_user]['y']
        if count_by_step:
            for i in range(inner_step * batch_size):
                yield train_images[i % (len(train_images))], train_labels[i % (
                    len(train_images))]
        else:
            for i in range(len(train_images)):
                yield train_images[i], train_labels[i]
...@@ -49,7 +71,8 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
    return train_data


def test(trainer_id, inner_step, batch_size, count_by_step):
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
...@@ -58,10 +81,14 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
                     tar_path)
        extract(tar_path, target_path)

    def test_data():
        test_file = open(
            "./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json"
            % (trainer_id, trainer_id), 'r')
        json_test = json.load(test_file)
        users = json_test["users"]
        for user in users:
...@@ -73,5 +100,3 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
        test_file.close()

    return test_data
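Downstream, these generators are consumed through paddle.batch like any built-in Paddle reader; a hedged usage sketch mirroring the FEMNIST trainer script later in this diff:

```python
# Assumes this module is importable as paddle_fl.dataset.femnist,
# as the trainer scripts in this diff do.
import paddle
import paddle_fl.dataset.femnist

train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle_fl.dataset.femnist.train(
            trainer_id=0, inner_step=10, batch_size=64, count_by_step=True),
        buf_size=500),
    batch_size=64)
```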
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory


class Model(object):
    def __init__(self):
        pass

...@@ -12,7 +27,8 @@ class Model(object):
        self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
        self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
        self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
        self.sum_cost = fluid.layers.cross_entropy(
            input=self.predict, label=label)
        self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
        self.loss = fluid.layers.reduce_mean(self.sum_cost)
        self.startup_program = fluid.default_startup_program()
...@@ -34,8 +50,8 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
                                              [model.predict.name])

build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.scheduler.agent_master import FLScheduler

worker_num = 2
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import sys
import logging
import time

logging.basicConfig(
    filename="test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)


def reader():
...@@ -15,13 +34,14 @@ def reader():
        data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
        yield data_dict


trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the trainer of the scheduler IP
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
...@@ -37,4 +57,3 @@ while not trainer.stop():
    epoch_id += 1
    if epoch_id % 5 == 0:
        trainer.save_inference_program(output_folder)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
import math


class Model(object):
    def __init__(self):
        pass

    def lr_network(self):
        self.inputs = fluid.layers.data(
            name='img', shape=[1, 28, 28], dtype="float32")
        self.label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        self.predict = fluid.layers.fc(input=self.inputs,
                                       size=10,
                                       act='softmax')
        self.sum_cost = fluid.layers.cross_entropy(
            input=self.predict, label=self.label)
        self.accuracy = fluid.layers.accuracy(
            input=self.predict, label=self.label)
        self.loss = fluid.layers.mean(self.sum_cost)
        self.startup_program = fluid.default_startup_program()
...@@ -23,7 +43,7 @@ model.lr_network()
STEP_EPSILON = 0.1
DELTA = 0.00001
SIGMA = math.sqrt(2.0 * math.log(1.25 / DELTA)) / STEP_EPSILON
CLIP = 4.0
batch_size = 64
...@@ -33,7 +53,8 @@ job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
    [model.inputs.name, model.label.name],
    [model.loss.name, model.accuracy.name])

build_strategy = FLStrategyFactory()
build_strategy.dpsgd = True
......
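SIGMA above is the standard Gaussian-mechanism noise multiplier, sigma = sqrt(2 * ln(1.25/delta)) / epsilon; with the strict per-step budget chosen here it comes out large (plain arithmetic, runnable standalone; the clip-and-noise behaviour mentioned in the comment is the usual DP-SGD recipe, stated as background rather than read from this diff):

```python
import math

STEP_EPSILON = 0.1  # per-step epsilon, as set above
DELTA = 0.00001
SIGMA = math.sqrt(2.0 * math.log(1.25 / DELTA)) / STEP_EPSILON
print(round(SIGMA, 2))  # 48.45 -- in DP-SGD, noise std scales with SIGMA * CLIP
```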
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.scheduler.agent_master import FLScheduler

worker_num = 4
server_num = 1
# Define number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
...@@ -7,7 +21,12 @@ import paddle.fluid as fluid
import logging
import math

logging.basicConfig(
    filename="test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)

trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"
...@@ -15,36 +34,38 @@ job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the trainer of the scheduler IP
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
test_program = trainer._main_program.clone(for_test=True)

train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.mnist.train(), buf_size=500),
    batch_size=64)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)

img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace())


def train_test(train_test_program, train_test_feed, train_test_reader):
    acc_set = []
    for test_data in train_test_reader():
        acc_np = trainer.exe.run(program=train_test_program,
                                 feed=train_test_feed.feed(test_data),
                                 fetch_list=["accuracy_0.tmp_0"])
        acc_set.append(float(acc_np[0]))
    acc_val_mean = numpy.array(acc_set).mean()
    return acc_val_mean


def compute_privacy_budget(sample_ratio, epsilon, step, delta):
    E = 2 * epsilon * math.sqrt(step * sample_ratio)
    print("({0}, {1})-DP".format(E, delta))


output_folder = "model_node%d" % trainer_id
epoch_id = 0
step = 0
...@@ -64,7 +85,8 @@ while not trainer.stop():
            train_test_feed=feeder)
        print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
        compute_privacy_budget(
            sample_ratio=0.001, epsilon=0.1, step=step, delta=0.00001)
        save_dir = (output_folder + "/epoch_%d") % epoch_id
        trainer.save_inference_program(output_folder)
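For a sense of scale: with the arguments passed above (sample_ratio=0.001, epsilon=0.1, delta=1e-5), 10,000 inner steps compose to roughly (0.63, 1e-5)-DP under this simple accounting:

```python
# Worked example of compute_privacy_budget's accounting after 10,000 steps.
import math

E = 2 * 0.1 * math.sqrt(10000 * 0.001)  # 0.2 * sqrt(10)
print("({0}, {1})-DP".format(E, 0.00001))  # (0.632..., 1e-05)-DP
```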
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
...@@ -9,14 +23,31 @@ class Model(object):
        pass

    def cnn(self):
        self.inputs = fluid.layers.data(
            name='img', shape=[1, 28, 28], dtype="float32")
        self.label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        self.conv_pool_1 = fluid.nets.simple_img_conv_pool(
            input=self.inputs,
            num_filters=20,
            filter_size=5,
            pool_size=2,
            pool_stride=2,
            act='relu')
        self.conv_pool_2 = fluid.nets.simple_img_conv_pool(
            input=self.conv_pool_1,
            num_filters=50,
            filter_size=5,
            pool_size=2,
            pool_stride=2,
            act='relu')
        self.predict = fluid.layers.fc(input=self.conv_pool_2,
                                       size=62,
                                       act='softmax')
        self.cost = fluid.layers.cross_entropy(
            input=self.predict, label=self.label)
        self.accuracy = fluid.layers.accuracy(
            input=self.predict, label=self.label)
        self.loss = fluid.layers.mean(self.cost)
        self.startup_program = fluid.default_startup_program()
...@@ -30,8 +61,8 @@ job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
    [model.inputs.name, model.label.name],
    [model.loss.name, model.accuracy.name])

build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
......
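The feature-map sizes feeding the 62-way softmax (FEMNIST has 62 character classes) can be checked by hand, assuming simple_img_conv_pool's defaults of a stride-1, zero-padding convolution followed by the given 2x2 max-pool:

```python
# Standalone shape arithmetic for the two conv-pool stages above.
def conv_pool_out(size, filter_size=5, pool_size=2, pool_stride=2):
    conv = size - filter_size + 1  # valid convolution, stride 1
    return (conv - pool_size) // pool_stride + 1

s1 = conv_pool_out(28)  # 28 -> 24 after conv -> 12 after pool
s2 = conv_pool_out(s1)  # 12 -> 8 after conv -> 4 after pool
print(s1, s2)           # the fc layer therefore sees 50 x 4 x 4 features
```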
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.scheduler.agent_master import FLScheduler

worker_num = 4
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle_fl as fl
import paddle.fluid as fluid
from paddle_fl.core.server.fl_server import FLServer
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import paddle_fl.dataset.femnist
...@@ -8,7 +22,12 @@ import paddle.fluid as fluid
import logging
import math

logging.basicConfig(
    filename="test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)

trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"
...@@ -17,7 +36,7 @@ job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the trainer of the scheduler IP
print(job._target_names)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._step)
test_program = trainer._main_program.clone(for_test=True)
...@@ -26,17 +45,18 @@ img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace())


def train_test(train_test_program, train_test_feed, train_test_reader):
    acc_set = []
    for test_data in train_test_reader():
        acc_np = trainer.exe.run(program=train_test_program,
                                 feed=train_test_feed.feed(test_data),
                                 fetch_list=["accuracy_0.tmp_0"])
        acc_set.append(float(acc_np[0]))
    acc_val_mean = numpy.array(acc_set).mean()
    return acc_val_mean


epoch_id = 0
step = 0
epoch = 3000
...@@ -46,7 +66,6 @@ if count_by_step:
else:
    output_folder = "model_node%d_epoch" % trainer_id

while not trainer.stop():
    count = 0
    epoch_id += 1
...@@ -55,11 +74,22 @@ while not trainer.stop():
    print("epoch %d start train" % (epoch_id))
    #train_data,test_data= data_generater(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step)
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle_fl.dataset.femnist.train(
                trainer_id,
                inner_step=trainer._step,
                batch_size=64,
                count_by_step=count_by_step),
            buf_size=500),
        batch_size=64)
    test_reader = paddle.batch(
        paddle_fl.dataset.femnist.test(
            trainer_id,
            inner_step=trainer._step,
            batch_size=64,
            count_by_step=count_by_step),
        batch_size=64)

    if count_by_step:
        for step_id, data in enumerate(train_reader()):
...@@ -71,8 +101,8 @@ while not trainer.stop():
                break
        # print("acc:%.3f" % (acc[0]))
    else:
        trainer.run_with_epoch(
            train_reader, feeder, fetch=["accuracy_0.tmp_0"], num_epoch=1)

    acc_val = train_test(
        train_test_program=test_program,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory


class Model(object):
    def __init__(self):
        pass

...@@ -34,7 +49,8 @@ class Model(object):
            size=hid_size * 3,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Uniform(
                    low=init_low_bound,
                    high=init_high_bound),
                learning_rate=gru_lr_x))
        gru_h0 = fluid.layers.dynamic_gru(
            input=fc0,
...@@ -49,7 +65,8 @@ class Model(object):
            act='softmax',
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Uniform(
                    low=init_low_bound,
                    high=init_high_bound),
                learning_rate=fc_lr_x))
        cost = fluid.layers.cross_entropy(
            input=self.fc, label=self.dst_wordseq)
...@@ -59,7 +76,6 @@ class Model(object):
        self.startup_program = fluid.default_startup_program()


model = Model()
model.gru4rec_network()
...@@ -69,7 +85,8 @@ job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
    [model.src_wordseq.name, model.dst_wordseq.name],
    [model.loss.name, model.acc.name])

build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.scheduler.agent_master import FLScheduler

worker_num = 4
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
from paddle_fl.reader.gru4rec_reader import Gru4rec_Reader
...@@ -6,7 +20,12 @@ import numpy as np
import sys
import os
import logging

logging.basicConfig(
    filename="test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)

trainer_id = int(sys.argv[1])  # trainer id for each guest
place = fluid.CPUPlace()
...@@ -16,11 +35,11 @@ job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the trainer of the scheduler IP
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()

r = Gru4rec_Reader()
train_reader = r.reader(train_file_dir, place, batch_size=125)

output_folder = "model_node4"
step_i = 0
...@@ -30,8 +49,7 @@ while not trainer.stop():
    train_step = 0
    for data in train_reader():
        #print(np.array(data['src_wordseq']))
        ret_avg_cost = trainer.run(feed=data, fetch=["mean_0.tmp_0"])
        train_step += 1
        if train_step == trainer._step:
            break
......
apiVersion: v1
kind: Service
metadata:
name: fl-master
spec:
type: LoadBalancer
ports:
- name: fl-master
port: 8000
targetPort: 8000
selector:
app: fl-master
---
apiVersion: v1
kind: Service
metadata:
name: fl-scheduler
spec:
type: LoadBalancer
ports:
- name: fl-scheduler
port: 9091
targetPort: 9091
selector:
app: fl-scheduler
---
apiVersion: v1
kind: Service
metadata:
name: fl-server
spec:
type: LoadBalancer
ports:
- name: fl-server
port: 8181
targetPort: 8181
selector:
app: fl-server
---
apiVersion: v1
kind: Service
metadata:
name: trainer0
spec:
type: LoadBalancer
ports:
- name: trainer0
port: 9000
targetPort: 9000
selector:
app: trainer0
---
apiVersion: v1
kind: Service
metadata:
name: trainer1
spec:
type: LoadBalancer
ports:
- name: trainer1
port: 9001
targetPort: 9001
selector:
app: trainer1
---
apiVersion: apps/v1beta1
kind: Deployment
metadata:
name: fl-master
labels:
app: fl-master
spec:
replicas: 1
template:
metadata:
name: fl-master
labels:
app: fl-master
spec:
containers:
- name: fl-master
image: hub.baidubce.com/paddlefl/paddlefl:v3
imagePullPolicy: Always
ports:
- containerPort: 8000
workingDir: /root/k8s_deployment/master
command: ['/bin/bash']
args: ['run_master.sh']
---
apiVersion: apps/v1beta1
kind: Deployment
metadata:
name: fl-scheduler
labels:
app: fl-scheduler
spec:
replicas: 1
template:
metadata:
name: fl-scheduler
labels:
app: fl-scheduler
spec:
containers:
- name: fl-scheduler
image: hub.baidubce.com/paddlefl/paddlefl:v3
imagePullPolicy: Always
ports:
- containerPort: 9091
workingDir: /root/k8s_deployment/scheduler
command: ['/bin/bash']
args: ['run_scheduler.sh']
---
apiVersion: apps/v1beta1
kind: Deployment
metadata:
name: fl-server
labels:
app: fl-server
spec:
replicas: 1
template:
metadata:
name: fl-server
labels:
app: fl-server
spec:
containers:
- name: fl-server
image: hub.baidubce.com/paddlefl/paddlefl:v3
imagePullPolicy: Always
ports:
- containerPort: 8181
workingDir: /root/k8s_deployment/server
command: ['/bin/bash']
args: ['run_server.sh']
env:
- name: POD_IP
valueFrom:
fieldRef:
apiVersion: v1
fieldPath: status.podIP
---
apiVersion: apps/v1beta1
kind: Deployment
metadata:
name: trainer0
labels:
app: trainer0
spec:
replicas: 1
template:
metadata:
name: trainer0
labels:
app: trainer0
spec:
containers:
- name: trainer0
image: hub.baidubce.com/paddlefl/paddlefl:v3
imagePullPolicy: Always
ports:
- containerPort: 9000
workingDir: /root/k8s_deployment/trainer0
command: ['/bin/bash']
args: ['test_trainer.sh']
---
apiVersion: apps/v1beta1
kind: Deployment
metadata:
name: trainer1
labels:
app: trainer1
spec:
replicas: 1
template:
metadata:
name: trainer1
labels:
app: trainer1
spec:
containers:
- name: trainer1
image: hub.baidubce.com/paddlefl/paddlefl:v3
imagePullPolicy: Always
ports:
- containerPort: 9001
workingDir: /root/k8s_deployment/trainer1
command: ['/bin/bash']
args: ['test_trainer.sh']
---
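Each Deployment above is paired with a LoadBalancer Service, so Kubernetes injects the usual <SERVICE>_SERVICE_HOST and <SERVICE>_SERVICE_PORT_<PORT-NAME> variables into every pod; that is how the scripts below locate the master, scheduler and server without hard-coded IPs. A minimal sketch of that discovery, mirroring the code below (standard Kubernetes behaviour):

```python
import os

# Endpoint of the fl-server Service, as injected by Kubernetes into each pod.
server_ep = (os.environ["FL_SERVER_SERVICE_HOST"] + ":" +
             os.environ["FL_SERVER_SERVICE_PORT_FL_SERVER"])
```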
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle.fluid as fluid
import os
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
def parse_args():
parser = argparse.ArgumentParser(description="master")
parser.add_argument(
'--trainer_num',
type=int,
default=2,
help='number of trainer(default: 2)')
return parser.parse_args()
class Model(object):
def __init__(self):
pass
def mlp(self, inputs, label, hidden_size=128):
self.concat = fluid.layers.concat(inputs, axis=1)
self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
self.sum_cost = fluid.layers.cross_entropy(
input=self.predict, label=label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.loss = fluid.layers.reduce_mean(self.sum_cost)
self.startup_program = fluid.default_startup_program()
inputs = [fluid.layers.data( \
name=str(slot_id), shape=[5],
dtype="float32")
for slot_id in range(3)]
label = fluid.layers.data( \
name="label",
shape=[1],
dtype='int64')
model = Model()
model.mlp(inputs, label)
job_generator = JobGenerator()
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
[model.predict.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
build_strategy.inner_step = 10
strategy = build_strategy.create_fl_strategy()
# endpoints will be collected through the cluster
# in this example, we suppose endpoints have been collected
server_service_ip = os.environ['FL_SERVER_SERVICE_HOST'] + ":" + os.environ[
'FL_SERVER_SERVICE_PORT_FL_SERVER']
service_endpoints = [server_service_ip]
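# The server binds to 0.0.0.0:8181 inside its pod; trainers reach it through the fl-server Kubernetes service resolved above.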
pod_endpoints = ["0.0.0.0:8181"]
output = "fl_job_config"
args = parse_args()
num_trainer = args.trainer_num
#job_generator.generate_fl_job(
# strategy, server_endpoints=endpoints, worker_num=num_trainer, output=output)
# fl_job_config will be dispatched to workers
job_generator.generate_fl_job_for_k8s(
strategy,
server_pod_endpoints=pod_endpoints,
server_service_endpoints=service_endpoints,
    worker_num=num_trainer,
output=output)
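# On the master pod: generate the job config, pack it, and publish it over HTTP so the server and trainer pods can download it.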
python fl_master.py --trainer_num 2
tar -zcvf fl_job_config.tar.gz fl_job_config
python -m SimpleHTTPServer 8000
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddle_fl.core.scheduler.agent_master import FLScheduler
def parse_args():
parser = argparse.ArgumentParser(description="scheduler")
parser.add_argument(
'--trainer_num',
type=int,
default=2,
        help='number of trainers (default: 2)')
return parser.parse_args()
args = parse_args()
num_trainer = args.trainer_num
worker_num = num_trainer
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle_fl as fl
import os
import paddle.fluid as fluid
from paddle_fl.core.server.fl_server import FLServer
from paddle_fl.core.master.fl_job import FLRunTimeJob
import time
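# Load the server-side job produced by fl_master.py; endpoints are resolved from Kubernetes service environment variables.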
server = FLServer()
server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[
'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER'] # IP address for scheduler
#job._endpoints = os.environ['POD_IP'] + ":" + os.environ['FL_SERVER_SERVICE_PORT_FL_SERVER'] # IP address for server
server.set_server_job(job)
server._current_ep = os.environ['FL_SERVER_SERVICE_HOST'] + ":" + os.environ[
'FL_SERVER_SERVICE_PORT_FL_SERVER'] # IP address for server
print(job._scheduler_ep, server._current_ep)
server.start()
print("connect")
export GLOG_v=3
wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
while [ $? -ne 0 ]
do
sleep 3
wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
done
tar -xf fl_job_config.tar.gz
python -u fl_server.py > server.log 2>&1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import sys
import os
import logging
import time
logging.basicConfig(
filename="test.log",
filemode="w",
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%M-%Y %H:%M:%S",
level=logging.DEBUG)
def reader():
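    # Synthetic reader: 1000 samples, each with three 5-dim float slots and a binary label.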
for i in range(1000):
data_dict = {}
for i in range(3):
data_dict[str(i)] = np.random.rand(1, 5).astype('float32')
data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
yield data_dict
trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
#job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[
'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER']
trainer = FLTrainerFactory().create_fl_trainer(job)
#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer._current_ep = os.environ['TRAINER0_SERVICE_HOST'] + ":" + os.environ[
'TRAINER0_SERVICE_PORT_TRAINER0']
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
epoch_id = 0
while not trainer.stop():
print("batch %d start train" % (epoch_id))
train_step = 0
for data in reader():
trainer.run(feed=data, fetch=[])
train_step += 1
if train_step == trainer._step:
break
epoch_id += 1
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
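# test_trainer.sh (trainer0): download and unpack the job config, give the server time to come up, then start trainer 0.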
wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
while [ $? -ne 0 ]
do
sleep 3
wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
done
tar -xf fl_job_config.tar.gz
sleep 10
python -u fl_trainer.py 0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import sys
import os
import logging
import time
logging.basicConfig(
filename="test.log",
filemode="w",
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%M-%Y %H:%M:%S",
level=logging.DEBUG)
def reader():
for i in range(1000):
data_dict = {}
for i in range(3):
data_dict[str(i)] = np.random.rand(1, 5).astype('float32')
data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
yield data_dict
trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
#job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[
'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER']
trainer = FLTrainerFactory().create_fl_trainer(job)
#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer._current_ep = os.environ['TRAINER1_SERVICE_HOST'] + ":" + os.environ[
'TRAINER1_SERVICE_PORT_TRAINER1']
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
epoch_id = 0
while not trainer.stop():
print("batch %d start train" % (epoch_id))
train_step = 0
for data in reader():
trainer.run(feed=data, fetch=[])
train_step += 1
if train_step == trainer._step:
break
epoch_id += 1
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
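# test_trainer.sh (trainer1): same flow as trainer0, starting trainer 1.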
wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
while [ $? -ne 0 ]
do
sleep 3
wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
done
tar -xf fl_job_config.tar.gz
sleep 10
python -u fl_trainer.py 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
class Model(object):
    def __init__(self):
        pass
...@@ -14,12 +29,17 @@ class Model(object):
        param_attrs = fluid.ParamAttr(
            name="fc_0.w_0",
            initializer=fluid.initializer.ConstantInitializer(0.0))
        self.predict = fluid.layers.fc(input=inputs,
                                       size=10,
                                       act='softmax',
                                       param_attr=param_attrs)
        self.sum_cost = fluid.layers.cross_entropy(
            input=self.predict, label=label)
        self.loss = fluid.layers.mean(self.sum_cost)
        self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
        self.startup_program = fluid.default_startup_program()
inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
...@@ -31,15 +51,16 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.01)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names([inputs.name, label.name],
                                              [model.loss.name])
build_strategy = FLStrategyFactory()
#build_strategy.fed_avg = True
build_strategy.sec_agg = True
param_name_list = []
param_name_list.append(
    "fc_0.w_0.opti.trainer_")  # need trainer_id when running
param_name_list.append("fc_0.b_0.opti.trainer_")
build_strategy.param_name_list = param_name_list
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
...@@ -11,16 +25,21 @@ import math
import hashlib
import hmac
logging.basicConfig(
    filename="log/test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)
logger = logging.getLogger("FLTrainer")
BATCH_SIZE = 64
train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.mnist.train(), buf_size=500),
    batch_size=BATCH_SIZE)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
trainer_num = 2
trainer_id = int(sys.argv[1])  # trainer id for each guest
...@@ -31,7 +50,7 @@ job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
...@@ -47,8 +66,8 @@ feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace())
# for test
test_program = trainer._main_program.clone(for_test=True)
def train_test(train_test_program, train_test_feed, train_test_reader):
    acc_set = []
    avg_loss_set = []
    for test_data in train_test_reader():
...@@ -61,6 +80,8 @@ def train_test(train_test_program,
    acc_val_mean = numpy.array(acc_set).mean()
    avg_loss_val_mean = numpy.array(avg_loss_set).mean()
    return avg_loss_val_mean, acc_val_mean
# for test
while not trainer.stop():
...@@ -73,13 +94,16 @@ while not trainer.stop():
        accuracy, = trainer.run(feed=feeder.feed(data),
                                fetch=["accuracy_0.tmp_0"])
        if step_i % 100 == 0:
            print("Epoch: {0}, step: {1}, accuracy: {2}".format(
                epoch_id, step_i, accuracy[0]))
    print(step_i)
    avg_loss_val, acc_val = train_test(
        train_test_program=test_program,
        train_test_reader=test_reader,
        train_test_feed=feeder)
    print("Test with Epoch %d, avg_cost: %s, acc: %s" %
          (epoch_id, avg_loss_val, acc_val))
    if epoch_id > 40:
        break
...
...@@ -21,4 +21,3 @@ server=yq01-hpc-lvliang01-smart-master.dmop.baidu.com
python_tar=./python.tar.gz
wheel=./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64.whl
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
class Model(object):
    def __init__(self):
        pass
...@@ -9,8 +24,8 @@ class Model(object):
        self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
        self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
        self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
        self.sum_cost = fluid.layers.cross_entropy(
            input=self.predict, label=label)
        self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
        self.loss = fluid.layers.reduce_mean(self.sum_cost)
        self.startup_program = fluid.default_startup_program()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import random
...@@ -49,6 +63,7 @@ default_dict = {
    "wheel": "./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64-0.whl"
}
def load_conf(conf_file, local_dict):
    with open(conf_file) as fin:
        for line in fin:
...@@ -58,6 +73,7 @@ def load_conf(conf_file, local_dict):
            local_dict[group[0]] = group[1]
    return local_dict
client = HPCClient()
default_dict = load_conf(sys.argv[1], default_dict)
...@@ -94,9 +110,11 @@ all_ips_ready = False
ip_list = []
scheduler = FLScheduler(
    int(default_dict["worker_nodes"]),
    int(default_dict["server_nodes"]),
    port=random_port,
    socket=zmq_socket)
scheduler.set_sample_worker_num(int(default_dict["worker_nodes"]))
...@@ -124,9 +142,11 @@ for i in range(len(ip_list)):
    if i < int(default_dict["server_nodes"]):
        ip_role[ip_list[i]] = 'server%d' % i
    else:
        ip_role[ip_list[i]] = 'trainer%d' % (
            i - int(default_dict["server_nodes"]))
print(ip_role)
def job_generate():
    #generate a fl job which is the same as fl_master
    inputs = [fluid.layers.data( \
...@@ -146,8 +166,8 @@ def job_generate():
    job_generator.set_optimizer(optimizer)
    job_generator.set_losses([model.loss])
    job_generator.set_startup_program(model.startup_program)
    job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
                                                  [model.predict.name])
    build_strategy = FLStrategyFactory()
    build_strategy.fed_avg = True
...@@ -160,17 +180,21 @@ def job_generate():
    output = "job_config"
    job_generator.generate_fl_job(
        strategy,
        server_endpoints=server_ip,
        worker_num=int(default_dict["worker_nodes"]),
        output=output)
    file_list = os.listdir(output)
    for file in file_list:
        tar = tarfile.open('{}/{}.tar.gz'.format(output, file), 'w:gz')
        for root, dir, files in os.walk("{}/{}".format(output, file)):
            for f in files:
                fullpath = os.path.join(root, f)
                tar.add(fullpath)
        tar.close()
job_generate()
#send the allocated roles to the remote endpoints
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import random
import zmq
...@@ -13,7 +27,6 @@ import sys
import logging
import time
random_port = 60001
scheduler_conf = {}
...@@ -31,8 +44,7 @@ download_url = "{}:8080".format(scheduler_ip[0])
print(download_url)
context = zmq.Context()
zmq_socket = context.socket(zmq.REQ)
zmq_socket.connect("tcp://{}".format(scheduler_conf["ENDPOINT"]))
zmq_socket.send("ENDPOINT\t{}".format(endpoint))
message = zmq_socket.recv()
print(message)
...@@ -47,7 +59,7 @@ while True:
    if group[0] == "WAIT":
        continue
    else:
        os.system("wget {}/job_config/{}.tar.gz".format(download_url, message))
        print(message)
        break
...@@ -71,6 +83,7 @@ if 'server' in message:
    server._current_ep = endpoint
    server.start()
else:
    def reader():
        for i in range(1000):
            data_dict = {}
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
import os
class Gru4rec_Reader:
    def __init__(self):
        pass
...@@ -21,7 +36,6 @@ class Gru4rec_Reader:
        res.set_lod([lod])
        return res
    def lod_reader(self, reader, place):
        def feed_reader():
            for data in reader():
...@@ -33,12 +47,14 @@ class Gru4rec_Reader:
                fe_data["src_wordseq"] = lod_src_wordseq
                fe_data["dst_wordseq"] = lod_dst_wordseq
                yield fe_data
        return feed_reader
    def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False):
        """
        Create a batched reader.
        """
        def batch_reader():
            r = reader()
            b = []
...@@ -66,11 +82,11 @@ class Gru4rec_Reader:
        # Batch size check
        batch_size = int(batch_size)
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integral value, "
                "but got batch_size={}".format(batch_size))
        return batch_reader
    def reader_creator(self, file_dir):
        def reader():
            files = os.listdir(file_dir)
...@@ -82,10 +98,12 @@ class Gru4rec_Reader:
                src_seq = l[:len(l) - 1]
                trg_seq = l[1:]
                yield src_seq, trg_seq
        return reader
    def reader(self, file_dir, place, batch_size=5):
        """ prepare the English Penn Treebank (PTB) data """
        print("start construct word dict")
        reader = self.sort_batch(
            self.reader_creator(file_dir), batch_size, batch_size * 20)
        return self.lod_reader(reader, place)
...@@ -12,6 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaddleFL version string """
fl_version = "0.1.11"
module_proto_version = "0.1.11"
...@@ -26,10 +26,12 @@ from paddle_fl.version import fl_version
def python_version():
    return [int(v) for v in platform.python_version().split(".")]
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
    'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.6', 'zmq',
    'paddlepaddle-gpu >= 1.6'
]
if max_version < 3:
...@@ -42,8 +44,7 @@ REQUIRED_PACKAGES += ["unittest2"]
setup(
    name='paddle_fl',
    version=fl_version.replace('-', ''),
    description=('Federated Deep Learning Package Based on PaddlePaddle.'),
    long_description='',
    url='https://github.com/PaddlePaddle/PaddleFL',
    author='PaddlePaddle Author',
...@@ -70,4 +71,5 @@ setup(
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    license='Apache 2.0',
    keywords=(
        'paddle_fl paddlepaddle multi-task transfer distributed-training'))
#!/bin/bash
set -e
readonly VERSION="3.8"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import io, re
import sys, os
import subprocess
import platform
COPYRIGHT = '''
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
LANG_COMMENT_MARK = None
NEW_LINE_MARK = None
COPYRIGHT_HEADER = None
if platform.system() == "Windows":
NEW_LINE_MARK = "\r\n"
else:
NEW_LINE_MARK = '\n'
COPYRIGHT_HEADER = COPYRIGHT.split(NEW_LINE_MARK)[1]
p = re.search('(\d{4})', COPYRIGHT_HEADER).group(0)
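# Ask `date` for the current year and stamp it into the copyright header.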
process = subprocess.Popen(["date", "+%Y"], stdout=subprocess.PIPE)
date, err = process.communicate()
date = date.decode("utf-8").rstrip("\n")
COPYRIGHT_HEADER = COPYRIGHT_HEADER.replace(p, date)
def generate_copyright(template, lang='C'):
if lang == 'Python':
LANG_COMMENT_MARK = '#'
else:
LANG_COMMENT_MARK = "//"
lines = template.split(NEW_LINE_MARK)
BLANK = " "
ans = LANG_COMMENT_MARK + BLANK + COPYRIGHT_HEADER + NEW_LINE_MARK
for lino, line in enumerate(lines):
if lino == 0 or lino == 1 or lino == len(lines) - 1: continue
if len(line) == 0:
BLANK = ""
else:
BLANK = " "
ans += LANG_COMMENT_MARK + BLANK + line + NEW_LINE_MARK
return ans + "\n"
def lang_type(filename):
if filename.endswith(".py"):
return "Python"
elif filename.endswith(".h"):
return "C"
elif filename.endswith(".c"):
return "C"
elif filename.endswith(".hpp"):
return "C"
elif filename.endswith(".cc"):
return "C"
elif filename.endswith(".cpp"):
return "C"
elif filename.endswith(".cu"):
return "C"
elif filename.endswith(".cuh"):
return "C"
elif filename.endswith(".go"):
return "C"
elif filename.endswith(".proto"):
return "C"
else:
print("Unsupported filetype %s", filename)
exit(0)
PYTHON_ENCODE = re.compile("^[ \t\v]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)")
def main(argv=None):
parser = argparse.ArgumentParser(
description='Checker for copyright declaration.')
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
fd = io.open(filename, encoding="utf-8")
first_line = fd.readline()
second_line = fd.readline()
if "COPYRIGHT (C)" in first_line.upper(): continue
if first_line.startswith("#!") or PYTHON_ENCODE.match(
second_line) != None or PYTHON_ENCODE.match(first_line) != None:
continue
original_contents = io.open(filename, encoding="utf-8").read()
new_contents = generate_copyright(
COPYRIGHT, lang_type(filename)) + original_contents
print('Auto Insert Copyright Header {}'.format(filename))
retv = 1
with io.open(filename, 'w') as output_file:
output_file.write(new_contents)
return retv
if __name__ == '__main__':
exit(main())
#!/bin/bash
TOTAL_ERRORS=0
if [[ ! $TRAVIS_BRANCH ]]; then
# install cpplint on local machine.
if [[ ! $(which cpplint) ]]; then
pip install cpplint
fi
# diff files on local machine.
files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
else
# diff files between PR and latest commit on Travis CI.
branch_ref=$(git rev-parse "$TRAVIS_BRANCH")
head_ref=$(git rev-parse HEAD)
files=$(git diff --name-status $branch_ref $head_ref | awk '$1 != "D" {print $2}')
fi
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $files; do
if [[ $file =~ ^(patches/grpc/.*) ]]; then
continue;
else
cpplint --filter=-readability/fn_size $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
fi
done
exit $TOTAL_ERRORS
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DocstringChecker is used to check python doc string's style."""
import six
import astroid
from pylint.checkers import BaseChecker, utils
from pylint.interfaces import IAstroidChecker
from collections import defaultdict
import re
def register(linter):
"""Register checkers."""
linter.register_checker(DocstringChecker(linter))
class Docstring(object):
"""Docstring class holds the parsed doc string elements.
"""
def __init__(self):
self.d = defaultdict(list) #name->[]
self.clear()
def clear(self):
self.d['Args'] = []
self.d['Examples'] = []
self.d['Returns'] = []
self.d['Raises'] = []
self.args = {} #arg_name->arg_type
def get_level(self, string, indent=' '):
level = 0
unit_size = len(indent)
while string[:unit_size] == indent:
string = string[unit_size:]
level += 1
return level
def parse(self, doc):
"""parse gets sections from doc
        Such as Args, Returns, Raises, Examples.
Args:
doc (string): is the astroid node doc string.
Returns:
True if doc is parsed successfully.
"""
self.clear()
lines = doc.splitlines()
state = ("others", -1)
for l in lines:
c = l.strip()
if len(c) <= 0:
continue
level = self.get_level(l)
if c.startswith("Args:"):
state = ("Args", level)
elif c.startswith("Returns:"):
state = ("Returns", level)
elif c.startswith("Raises:"):
state = ("Raises", level)
elif c.startswith("Examples:"):
state = ("Examples", level)
else:
if level > state[1]:
self.d[state[0]].append(c)
continue
state = ("others", -1)
self.d[state[0]].append(c)
self._arg_with_type()
return True
def get_returns(self):
return self.d['Returns']
def get_raises(self):
return self.d['Raises']
def get_examples(self):
return self.d['Examples']
def _arg_with_type(self):
for t in self.d['Args']:
m = re.search('([A-Za-z0-9_-]+)\s{0,4}(\(.+\))\s{0,4}:', t)
if m:
self.args[m.group(1)] = m.group(2)
return self.args
class DocstringChecker(BaseChecker):
"""DosstringChecker is pylint checker to
check docstring style.
"""
__implements__ = (IAstroidChecker, )
POSITIONAL_MESSAGE_ID = 'str-used-on-positional-format-argument'
KEYWORD_MESSAGE_ID = 'str-used-on-keyword-format-argument'
name = 'doc-string-checker'
symbol = "doc-string"
priority = -1
msgs = {
'W9001': ('One line doc string on > 1 lines', symbol + "-one-line",
'Used when a short doc string is on multiple lines'),
'W9002':
('Doc string does not end with "." period', symbol + "-end-with",
'Used when a doc string does not end with a period'),
'W9003':
('All args with their types must be mentioned in doc string %s',
symbol + "-with-all-args",
'Used when not all arguments are in the doc string '),
'W9005': ('Missing docstring or docstring is too short',
symbol + "-missing", 'Add docstring longer >=10'),
'W9006': ('Docstring indent error, use 4 space for indent',
symbol + "-indent-error", 'Use 4 space for indent'),
'W9007': ('You should add `Returns` in comments',
symbol + "-with-returns",
'There should be a `Returns` section in comments'),
'W9008': ('You should add `Raises` section in comments',
symbol + "-with-raises",
'There should be a `Raises` section in comments'),
}
options = ()
def visit_functiondef(self, node):
"""visit_functiondef checks Function node docstring style.
Args:
node (astroid.node): The visiting node.
Returns:
            True if successful otherwise False.
"""
self.check_doc_string(node)
if node.tolineno - node.fromlineno <= 10:
return True
if not node.doc:
return True
doc = Docstring()
doc.parse(node.doc)
self.all_args_in_doc(node, doc)
self.with_returns(node, doc)
self.with_raises(node, doc)
def visit_module(self, node):
self.check_doc_string(node)
def visit_classdef(self, node):
self.check_doc_string(node)
def check_doc_string(self, node):
self.missing_doc_string(node)
self.one_line(node)
self.has_period(node)
self.indent_style(node)
def missing_doc_string(self, node):
if node.name.startswith("__") or node.name.startswith("_"):
return True
if node.tolineno - node.fromlineno <= 10:
return True
if node.doc is None or len(node.doc) < 10:
self.add_message('W9005', node=node, line=node.fromlineno)
return False
# FIXME(gongwb): give the docstring line-no
def indent_style(self, node, indent=4):
"""indent_style checks docstring's indent style
Args:
node (astroid.node): The visiting node.
indent (int): The default indent of style
Returns:
            True if successful otherwise False.
"""
if node.doc is None:
return True
doc = node.doc
lines = doc.splitlines()
        # Note: the original counter never advanced past the first line (the
        # `continue` skipped the increment), so the check never ran; enumerate
        # fixes that. Line 0 shares the opening quotes, so it is skipped.
        for line_num, l in enumerate(lines):
            if line_num == 0:
                continue
            cur_indent = len(l) - len(l.lstrip())
            if cur_indent % indent != 0:
                self.add_message('W9006', node=node, line=node.fromlineno)
                return False
return True
def one_line(self, node):
"""one_line checks if docstring (len < 40) is on one line.
Args:
            node (astroid.node): The node being visited.
Returns:
True if successful otherwise False.
"""
doc = node.doc
if doc is None:
return True
if len(doc) > 40:
return True
elif sum(doc.find(nl) for nl in ('\n', '\r', '\n\r')) == -3:
return True
else:
self.add_message('W9001', node=node, line=node.fromlineno)
return False
return True
def has_period(self, node):
"""has_period checks if one line doc end-with '.' .
Args:
            node (astroid.node): the node being visited.
Returns:
True if successful otherwise False.
"""
if node.doc is None:
return True
if len(node.doc.splitlines()) > 1:
return True
if not node.doc.strip().endswith('.'):
self.add_message('W9002', node=node, line=node.fromlineno)
return False
return True
def with_raises(self, node, doc):
"""with_raises checks if one line doc end-with '.' .
Args:
            node (astroid.node): the node being visited.
doc (Docstring): Docstring object.
Returns:
True if successful otherwise False.
"""
find = False
for t in node.body:
if not isinstance(t, astroid.Raise):
continue
find = True
break
if not find:
return True
if len(doc.get_raises()) == 0:
self.add_message('W9008', node=node, line=node.fromlineno)
return False
return True
def with_returns(self, node, doc):
"""with_returns checks if docstring comments what are returned .
Args:
            node (astroid.node): the node being visited.
doc (Docstring): Docstring object.
Returns:
True if successful otherwise False.
"""
if node.name.startswith("__") or node.name.startswith("_"):
return True
find = False
for t in node.body:
if not isinstance(t, astroid.Return):
continue
find = True
break
if not find:
return True
if len(doc.get_returns()) == 0:
self.add_message('W9007', node=node, line=node.fromlineno)
return False
return True
def all_args_in_doc(self, node, doc):
"""all_args_in_doc checks if arguments are mentioned in doc
Args:
            node (astroid.node): the node being visited.
doc (Docstring): Docstring object
Returns:
True if successful otherwise False.
"""
if node.name.startswith("__") or node.name.startswith("_"):
return True
args = []
for arg in node.args.get_children():
if (not isinstance(arg, astroid.AssignName)) \
or arg.name == "self":
continue
args.append(arg.name)
if len(args) <= 0:
return True
parsed_args = doc.args
args_not_documented = set(args) - set(parsed_args)
if len(args) > 0 and len(parsed_args) <= 0:
self.add_message(
'W9003',
node=node,
line=node.fromlineno,
args=list(args_not_documented))
return False
for t in args:
if t not in parsed_args:
self.add_message(
'W9003', node=node, line=node.fromlineno, args=[t, ])
return False
return True
#!/bin/bash
TOTAL_ERRORS=0
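# Put this directory on PYTHONPATH so pylint can import the local docstring_checker plugin.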
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export PYTHONPATH=$DIR:$PYTHONPATH
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do
pylint --disable=all --load-plugins=docstring_checker \
--enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
done
exit $TOTAL_ERRORS
#For now, just warning:
#exit 0
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import docstring_checker
import pylint.testutils
import astroid
import pytest
import sys
class TestDocstring(pylint.testutils.CheckerTestCase):
CHECKER_CLASS = docstring_checker.DocstringChecker
def test_one_line(self):
func_node = astroid.extract_node('''
def test():
"""get
news.
"""
if True:
return 5
return 5
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9001' == got[0][0]
    def test_end_with(self):
func_node = astroid.extract_node('''
def test():
"""get news"""
if True:
return 5
return 5
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9002' == got[0][0]
def test_args(self):
func_node = astroid.extract_node('''
def test(scale, mean):
"""get news.
Args:
scale (int): scale is the number.
"""
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9003' == got[0][0]
def test_missing(self):
func_node = astroid.extract_node('''
def test():
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9005' == got[0][0]
def test_indent(self):
func_node = astroid.extract_node('''
def test():
""" get get get get get get get get
get get get get get get get get.
"""
pass
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9006' == got[0][0]
    def test_with_returns(self):
func_node = astroid.extract_node('''
def test():
"""get news.
Args:
scale (int): scale is the number.
"""
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
return mean
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9007' == got[0][0]
def test_with_raises(self):
func_node = astroid.extract_node('''
def test():
"""get news.
Args:
scale (int): scale is the number.
"""
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
mean=scale
raise ValueError('A very specific bad thing happened.')
''')
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 1
assert 'W9008' == got[0][0]
def test_no_message(self):
p = '''
def fc(input,
size,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
act=None,
name=None):
"""
**Fully Connected Layer**
The fully connected layer can take multiple tensors as its inputs. It
creates a variable called weights for each input tensor, which represents
a fully connected weight matrix from each input unit to each output unit.
The fully connected layer multiplies each input tensor with its corresponding
weight to produce an output Tensor. If multiple input tensors are given,
the results of multiple multiplications will be summed up. If bias_attr is
not None, a bias variable will be created and added to the output. Finally,
if activation is not None, it will be applied to the output as well.
This process can be formulated as follows:
Args:
input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of
the input tensor(s) is at least 2.
size(int): The number of output units in this layer.
num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional tensor will first be flattened
into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
dimensions will be flatten to form the first dimension of the final matrix (height of
the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
form the second dimension of the final matrix (width of the matrix). For example, suppose
`X` is a 6-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30].
param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
act (str, default None): Activation to be applied to the output of this layer.
name (str, default None): The name of this layer.
Returns:
A tensor variable storing the transformation result.
Raises:
ValueError: If rank of the input tensor is less than 2.
Examples:
.. code-block:: python
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
fc = fluid.layers.fc(input=data, size=1000, act="tanh")
"""
raise ValueError('A very specific bad thing happened.')
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
size = 1
return size
'''
func_node = astroid.extract_node(p)
self.checker.visit_functiondef(func_node)
got = self.linter.release_messages()
assert len(got) == 0