diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d8112837dc9627bc2e501940b8e97c89e97c45ff
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,52 @@
+repos:
+- repo: https://github.com/Lucas-C/pre-commit-hooks.git
+ sha: v1.0.1
+ hooks:
+ - id: remove-crlf
+ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
+- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
+ sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
+ hooks:
+ - id: yapf
+ files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
+ hooks:
+ - id: check-added-large-files
+ - id: check-merge-conflict
+ - id: check-symlinks
+ - id: detect-private-key
+ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
+ - id: end-of-file-fixer
+- repo: local
+ hooks:
+ - id: clang-format-with-version-check
+ name: clang-format
+ description: Format files with ClangFormat.
+ entry: bash ./tools/codestyle/clang_format.hook -i
+ language: system
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
+- repo: local
+ hooks:
+ - id: cpplint-cpp-source
+ name: cpplint
+ description: Check C++ code style using cpplint.py.
+ entry: bash ./tools/codestyle/cpplint_pre_commit.hook
+ language: system
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
+- repo: local
+ hooks:
+ - id: pylint-doc-string
+ name: pylint
+ description: Check python docstring style using docstring_checker.
+ entry: bash ./tools/codestyle/pylint_pre_commit.hook
+ language: system
+ files: \.(py)$
+- repo: local
+ hooks:
+ - id: copyright_checker
+ name: copyright_checker
+ entry: python ./tools/codestyle/copyright.hook
+ language: system
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
+ exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9f25a7770bbebfc11617923bfd49175d229993df
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,11 @@
+language: python
+
+notifications:
+ email:
+ on_success: change
+ on_failure: always
+
+sudo: false
+
+os:
+ - linux
diff --git a/AUTHORS.md b/AUTHORS.md
index de2c1cba58d1e3079959cf40f40a4b89c4c19455..f22dad93df96e011e648e872ecd9d86f55b3af35 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -1,4 +1,5 @@
| Github account | name |
|---|---|
| guru4elephant | Daxiang Dong |
-| frankwhzhang | Wenhui Zhang |
\ No newline at end of file
+| frankwhzhang | Wenhui Zhang |
+| qjing666 | Qinghe Jing |
diff --git a/contrib/data_safety_training/image_classification/server/receiver.py b/contrib/data_safety_training/image_classification/server/receiver.py
index 70d19b02574e889aa66f0a98e12f061a8bf05a0f..0ac474ac35e67ef9810baafdd20d5fb05a655f62 100644
--- a/contrib/data_safety_training/image_classification/server/receiver.py
+++ b/contrib/data_safety_training/image_classification/server/receiver.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import zmq
import socket
import msgpack
diff --git a/contrib/data_safety_training/image_classification/server/server.py b/contrib/data_safety_training/image_classification/server/server.py
index 39f4781d5e9b0c18a1548a0e09281f29bc187abd..93f8f53c1907e3d1a9b5e46da2eb2cf2204551c8 100644
--- a/contrib/data_safety_training/image_classification/server/server.py
+++ b/contrib/data_safety_training/image_classification/server/server.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import print_function
import os
import paddle
@@ -12,14 +26,14 @@ import math
import msgpack
-def data_generater(samples,r):
- # data generater
+def data_generater(samples, r):
+ # data generator
def train_data():
for item in samples:
sample = msgpack.loads(r.get(str(item)))
conv = sample[0]
label = sample[1]
- yield conv,label
+ yield conv, label
return train_data
@@ -67,7 +81,7 @@ class ResNet():
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)),
- act = "softmax")
+ act="softmax")
else:
for block in range(len(depth)):
for i in range(depth[block]):
@@ -87,7 +101,7 @@ class ResNet():
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)),
- act = "softmax")
+ act="softmax")
return out
def conv_bn_layer(self,
@@ -123,8 +137,6 @@ class ResNet():
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
-
-
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1 or is_first == True:
@@ -181,31 +193,33 @@ class ResNet():
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
+
# local redis config
redis_host = "127.0.0.1"
redis_port = 6379
redis_password = ""
-r = redis.StrictRedis(host=redis_host, port=redis_port, password=redis_password)
+r = redis.StrictRedis(
+ host=redis_host, port=redis_port, password=redis_password)
# reader generation
-reader = fluid.layers.py_reader(capacity=64,
- shapes=[(-1,64, 8, 8), (-1,1)],
- dtypes=['float32', 'int64'])
+reader = fluid.layers.py_reader(
+ capacity=64, shapes=[(-1, 64, 8, 8), (-1, 1)],
+ dtypes=['float32', 'int64'])
samples = r.keys()
-train_data = data_generater(samples,r)
+train_data = data_generater(samples, r)
-reader.decorate_paddle_reader(paddle.batch(
- paddle.reader.shuffle(
- train_data, buf_size=5000),
- batch_size=64))
+reader.decorate_paddle_reader(
+ paddle.batch(
+ paddle.reader.shuffle(
+ train_data, buf_size=5000), batch_size=64))
-conv1,label = fluid.layers.read_file(reader)
+conv1, label = fluid.layers.read_file(reader)
# train program
place = fluid.CUDAPlace(0)
model = ResNet(layers=50)
-predicts = model.net(conv1,10)
+predicts = model.net(conv1, 10)
cost = fluid.layers.cross_entropy(input=predicts, label=label)
accuracy = fluid.layers.accuracy(input=predicts, label=label)
loss = fluid.layers.mean(cost)
@@ -222,18 +236,20 @@ step = 0
train_start = time.time()
# start training
for pass_id in range(EPOCH_NUM):
- reader.start()
- try:
- while True:
- start_time = time.time()
- loss_value,acc_value = exe.run(fetch_list=[loss.name,accuracy.name])
- step += 1
- if step % 10 == 0:
- print("epoch: "+ str(pass_id)+"step: "+str(step)+"loss: "+ str(loss_value)+"acc: "+str(acc_value))
- end_time = time.time()
- total_time += (end_time - start_time)
- except fluid.core.EOFException:
- reader.reset()
+ reader.start()
+ try:
+ while True:
+ start_time = time.time()
+ loss_value, acc_value = exe.run(
+ fetch_list=[loss.name, accuracy.name])
+ step += 1
+ if step % 10 == 0:
+ print("epoch: " + str(pass_id) + "step: " + str(step) +
+ "loss: " + str(loss_value) + "acc: " + str(acc_value))
+ end_time = time.time()
+ total_time += (end_time - start_time)
+ except fluid.core.EOFException:
+ reader.reset()
train_end = time.time()
print("total time: %d" % (train_end - train_start))
print("computation time: %d" % total_time)
diff --git a/contrib/data_safety_training/image_classification/server/user.py b/contrib/data_safety_training/image_classification/server/user.py
index 89668f35d4b69a4b7d561ac4b562fe4531e7b223..505c329f0ca8310824989d83af4a04f05714d49d 100644
--- a/contrib/data_safety_training/image_classification/server/user.py
+++ b/contrib/data_safety_training/image_classification/server/user.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import print_function
import os
import paddle
@@ -5,10 +19,12 @@ import paddle.fluid as fluid
import numpy
import sys
import redis
-import time
+import time
from paddle.fluid import layers
from paddle.fluid.param_attr import ParamAttr
import msgpack
+
+
def conv_bn_layer(input,
num_filters,
filter_size,
@@ -16,30 +32,30 @@ def conv_bn_layer(input,
groups=1,
act=None,
name=None):
- conv = fluid.layers.conv2d(
- input=input,
- num_filters=num_filters,
- filter_size=filter_size,
- stride=stride,
- padding=(filter_size - 1) // 2,
- groups=groups,
- act=None,
- param_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False,
- name=name + '.conv2d.output.1')
-
- if name == "conv1":
- bn_name = "bn_" + name
- else:
- bn_name = "bn" + name[3:]
- return fluid.layers.batch_norm(
- input=conv,
- act=act,
- name=bn_name + '.output.1',
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance', )
+ conv = fluid.layers.conv2d(
+ input=input,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) // 2,
+ groups=groups,
+ act=None,
+ param_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False,
+ name=name + '.conv2d.output.1')
+
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ return fluid.layers.batch_norm(
+ input=conv,
+ act=act,
+ name=bn_name + '.output.1',
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance', )
def load_conf(conf_file, local_dict):
@@ -51,6 +67,7 @@ def load_conf(conf_file, local_dict):
local_dict[group[0]] = group[1]
return local_dict
+
# redis DB configuration
redis_host = "127.0.0.1"
redis_port = 6379
@@ -58,26 +75,39 @@ redis_password = ""
start_time = time.time()
# start a redis client and empty the DB
-r = redis.StrictRedis(host=redis_host, port=redis_port, password=redis_password)
+r = redis.StrictRedis(
+ host=redis_host, port=redis_port, password=redis_password)
r.flushall()
# encoding program
-images = fluid.layers.data(name='images', shape=[3,32,32], dtype='float32')
+images = fluid.layers.data(name='images', shape=[3, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
place = fluid.CPUPlace()
-conv1 = conv_bn_layer(input=images,num_filters=64,filter_size=7,stride=2,act='relu',name="conv1")
-pool = fluid.layers.pool2d(input=conv1,pool_size=3,pool_stride=2,pool_padding=1,pool_type='max')
-feeder = fluid.DataFeeder(place=place, feed_list=[images,label])
+conv1 = conv_bn_layer(
+ input=images,
+ num_filters=64,
+ filter_size=7,
+ stride=2,
+ act='relu',
+ name="conv1")
+pool = fluid.layers.pool2d(
+ input=conv1, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
+feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
pretrained_model = 'ResNet50_pretrained'
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
+
+# load pretrained model and prepare data
def if_exist(var):
- return os.path.exists(os.path.join(pretrained_model, var.name))
-fluid.io.load_vars(exe, pretrained_model, main_program=fluid.default_main_program(),
- predicate=if_exist)
+ return os.path.exists(os.path.join(pretrained_model, var.name))
+
+fluid.io.load_vars(
+ exe,
+ pretrained_model,
+ main_program=fluid.default_main_program(),
+ predicate=if_exist)
train_data = paddle.dataset.cifar.train10()
step = 0
@@ -86,11 +116,13 @@ step = 0
for data in train_data():
pre_data = []
pre_data.append(data)
- res = exe.run(program=fluid.default_main_program(),feed=feeder.feed(pre_data), fetch_list=[pool.name])
- sample = [res[0][0].tolist(),data[1]]
+ res = exe.run(program=fluid.default_main_program(),
+ feed=feeder.feed(pre_data),
+ fetch_list=[pool.name])
+ sample = [res[0][0].tolist(), data[1]]
step += 1
file = msgpack.dumps(sample)
- r.set(step,file)
+ r.set(step, file)
if step % 100 == 0:
print(numpy.array(sample[0]).shape)
print("%dstart" % step)
@@ -99,6 +131,4 @@ files = r.keys()
print("upload file numbers: %d" % len(files))
end_time = time.time()
total_time = end_time - start_time
-print("total time: %d"% total_time)
-
-
+print("total time: %d" % total_time)
diff --git a/contrib/data_safety_training/image_classification/submitter.py b/contrib/data_safety_training/image_classification/submitter.py
index 69bd5d8495bda836babbea6ca8dc80deb3677b3f..920b60ad8fed2d7ff0b13d17001d8227f3b0abb8 100644
--- a/contrib/data_safety_training/image_classification/submitter.py
+++ b/contrib/data_safety_training/image_classification/submitter.py
@@ -1,8 +1,22 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import zmq
import socket
import msgpack
import os
-mission_dict = {"mission": "image classification", "image_size": [3,32,32]}
+mission_dict = {"mission": "image classification", "image_size": [3, 32, 32]}
#send request
context = zmq.Context()
zmq_socket = context.socket(zmq.REQ)
diff --git a/docs/requirements.txt b/docs/requirements.txt
index c860963602bfd591ff57c8c8a722dc9670bc582d..90c6171099d9bb6e5ce32ac1164ba848c14426f4 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -3,4 +3,3 @@ mistune
sphinx_rtd_theme
paddlepaddle>=1.6
zmq
-
diff --git a/docs/source/examples/md/dpsgd-example.md b/docs/source/examples/md/dpsgd-example.md
index b0b822702ed9e0cbce1cad20f773bea1172a3755..20f80252b54eb054df6a0c429ef0463f92f20e9b 100644
--- a/docs/source/examples/md/dpsgd-example.md
+++ b/docs/source/examples/md/dpsgd-example.md
@@ -181,4 +181,3 @@ while not trainer.stop():
To show the effectiveness of DPSGD-based federated learning with PaddleFL, a simulated experiment is conducted on an open source dataset MNIST. From the figure given below, model evaluation results are similar between DPSGD-based federated learning and traditional parameter server training when the overall privacy budget *epsilon* is 1.3 or 0.13.
-
diff --git a/docs/source/examples/md/gru4rec_examples.md b/docs/source/examples/md/gru4rec_examples.md
index ea729adce65a678c0d0db52b31dd1e1b015748e6..b40e9eb8dc4033aec55bcf7769a372c9eca2a7f7 100644
--- a/docs/source/examples/md/gru4rec_examples.md
+++ b/docs/source/examples/md/gru4rec_examples.md
@@ -103,6 +103,3 @@ wget https://paddle-zwh.bj.bcebos.com/gru4rec_paddlefl_benchmark/gru4rec_benchma
| 1/4 of the whole dataset | private training | - | 0.282 |
-
-
-
diff --git a/docs/source/md/introduction.md b/docs/source/md/introduction.md
index 8c1513631989061c74ddc327f294ceeed6df31ff..cb38bfe10ed5177d1dae2dcbca6a11cec1f79ccc 100644
--- a/docs/source/md/introduction.md
+++ b/docs/source/md/introduction.md
@@ -55,4 +55,3 @@ In PaddleFL, components for defining a federated learning task and training a fe
- Federated Learning Systems deployment methods in Kubernetes.
- Vertical Federated Learning Strategies and more horizontal federated learning strategies will be open sourced.
-
diff --git a/docs/source/md/reference.md b/docs/source/md/reference.md
index 5a2a72fdadf16842739c1439397b41a665f8de52..214b68a8101faa24115fe8cd9ca025c2ca82df03 100644
--- a/docs/source/md/reference.md
+++ b/docs/source/md/reference.md
@@ -14,4 +14,4 @@
[7]. Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, Ameet Talwalkar. **Federated Multi-Task Learning** 2016
-[8]. Yang Liu, Tianjian Chen, Qiang Yang. **Secure Federated Transfer Learning.** 2018
\ No newline at end of file
+[8]. Yang Liu, Tianjian Chen, Qiang Yang. **Secure Federated Transfer Learning.** 2018
diff --git a/paddle_fl/common/__init__.py b/paddle_fl/common/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..abf198b97e6e818e1fbe59006f98492640bcee54 100644
--- a/paddle_fl/common/__init__.py
+++ b/paddle_fl/common/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddle_fl/core/__init__.py b/paddle_fl/core/__init__.py
index 34df36dd857ac7204a1b134d86889f60d1fad7b5..6bec080b88ea20103b8f103eccc5bf0819c7818d 100644
--- a/paddle_fl/core/__init__.py
+++ b/paddle_fl/core/__init__.py
@@ -22,4 +22,3 @@ from .scheduler.agent_master import FLWorkerAgent
from .scheduler.agent_master import FLScheduler
from .submitter.client_base import HPCClient
from .submitter.client_base import CloudClient
-
diff --git a/paddle_fl/core/master/fl_job.py b/paddle_fl/core/master/fl_job.py
index a221897d38393cfb8998ea0d988cff4f3c22e61e..7a80e13e8e6d1e5255ce176a5de4ffdee97bf5e4 100644
--- a/paddle_fl/core/master/fl_job.py
+++ b/paddle_fl/core/master/fl_job.py
@@ -14,11 +14,13 @@
import os
import paddle.fluid as fluid
+
class FLJobBase(object):
"""
FLJobBase is fl job base class, responsible for save and load
a federated learning job
"""
+
def __init__(self):
pass
@@ -64,6 +66,7 @@ class FLJobBase(object):
return fluid.Program.parse_from_string(program_desc_str)
return None
+
class FLCompileTimeJob(FLJobBase):
"""
FLCompileTimeJob is a container for compile time job in federated learning.
@@ -71,6 +74,7 @@ class FLCompileTimeJob(FLJobBase):
are in FLCompileTimeJob. Also, server main programs and server startup programs
are in this class. FLCompileTimeJob has server endpoints for debugging as well
"""
+
def __init__(self):
self._trainer_startup_programs = []
self._trainer_recv_programs = []
@@ -101,69 +105,59 @@ class FLCompileTimeJob(FLJobBase):
os.system("mkdir -p %s" % server_folder)
server_startup = self._server_startup_programs[i]
server_main = self._server_main_programs[i]
- self._save_program(
- server_startup,
- "%s/server.startup.program" % server_folder)
- self._save_program(
- server_main,
- "%s/server.main.program" % server_folder)
+ self._save_program(server_startup,
+ "%s/server.startup.program" % server_folder)
+ self._save_program(server_main,
+ "%s/server.main.program" % server_folder)
+ self._save_readable_program(server_startup,
+ "%s/server.startup.program.txt" %
+ server_folder)
self._save_readable_program(
- server_startup,
- "%s/server.startup.program.txt" % server_folder)
- self._save_readable_program(
- server_main,
- "%s/server.main.program.txt" % server_folder)
+ server_main, "%s/server.main.program.txt" % server_folder)
self._save_str_list(self._feed_names,
- "%s/feed_names" % server_folder)
+ "%s/feed_names" % server_folder)
self._save_str_list(self._target_names,
- "%s/target_names" % server_folder)
+ "%s/target_names" % server_folder)
self._save_endpoints(self._server_endpoints,
- "%s/endpoints" % server_folder)
+ "%s/endpoints" % server_folder)
self._save_strategy(self._strategy,
- "%s/strategy.pkl" % server_folder)
+ "%s/strategy.pkl" % server_folder)
for i in range(trainer_num):
trainer_folder = "%s/trainer%d" % (folder, i)
os.system("mkdir -p %s" % trainer_folder)
trainer_startup = self._trainer_startup_programs[i]
trainer_main = self._trainer_main_programs[i]
- self._save_program(
- trainer_startup,
- "%s/trainer.startup.program" % trainer_folder)
- self._save_program(
- trainer_main,
- "%s/trainer.main.program" % trainer_folder)
- self._save_readable_program(
- trainer_startup,
- "%s/trainer.startup.program.txt" % trainer_folder)
+ self._save_program(trainer_startup,
+ "%s/trainer.startup.program" % trainer_folder)
+ self._save_program(trainer_main,
+ "%s/trainer.main.program" % trainer_folder)
+ self._save_readable_program(trainer_startup,
+ "%s/trainer.startup.program.txt" %
+ trainer_folder)
self._save_readable_program(
- trainer_main,
- "%s/trainer.main.program.txt" % trainer_folder)
+ trainer_main, "%s/trainer.main.program.txt" % trainer_folder)
self._save_str_list(self._feed_names,
- "%s/feed_names" % trainer_folder)
+ "%s/feed_names" % trainer_folder)
self._save_str_list(self._target_names,
- "%s/target_names" % trainer_folder)
+ "%s/target_names" % trainer_folder)
self._save_endpoints(self._server_endpoints,
- "%s/endpoints" % trainer_folder)
+ "%s/endpoints" % trainer_folder)
self._save_strategy(self._strategy,
- "%s/strategy.pkl" % trainer_folder)
+ "%s/strategy.pkl" % trainer_folder)
for i in range(send_prog_num):
trainer_folder = "%s/trainer%d" % (folder, i)
trainer_send = self._trainer_send_programs[i]
trainer_recv = self._trainer_recv_programs[i]
- self._save_program(
- trainer_send,
- "%s/trainer.send.program" % trainer_folder)
- self._save_program(
- trainer_recv,
- "%s/trainer.recv.program" % trainer_folder)
+ self._save_program(trainer_send,
+ "%s/trainer.send.program" % trainer_folder)
+ self._save_program(trainer_recv,
+ "%s/trainer.recv.program" % trainer_folder)
self._save_readable_program(
- trainer_send,
- "%s/trainer.send.program.txt" % trainer_folder)
+ trainer_send, "%s/trainer.send.program.txt" % trainer_folder)
self._save_readable_program(
- trainer_recv,
- "%s/trainer.recv.program.txt" % trainer_folder)
+ trainer_recv, "%s/trainer.recv.program.txt" % trainer_folder)
class FLRunTimeJob(FLJobBase):
@@ -172,6 +166,7 @@ class FLRunTimeJob(FLJobBase):
A trainer or a server can load FLRunTimeJob. Only necessary programs
can be loaded in FLRunTimeJob
"""
+
def __init__(self):
self._trainer_startup_program = None
self._trainer_recv_program = None
diff --git a/paddle_fl/core/master/job_generator.py b/paddle_fl/core/master/job_generator.py
index e4291241a864c354c1a549ef07acb7652ec377dc..64feb7d8083699459d05f3ccee6185cd00194312 100644
--- a/paddle_fl/core/master/job_generator.py
+++ b/paddle_fl/core/master/job_generator.py
@@ -14,6 +14,7 @@
import paddle.fluid as fluid
from .fl_job import FLCompileTimeJob
+
class JobGenerator(object):
"""
A JobGenerator is responsible for generating distributed federated
@@ -21,6 +22,7 @@ class JobGenerator(object):
need to define a deep learning model together to do horizontal federated
learning.
"""
+
def __init__(self):
# worker num for federated learning
self._worker_num = 0
@@ -32,7 +34,6 @@ class JobGenerator(object):
self._feed_names = []
self._target_names = []
-
def set_optimizer(self, optimizer):
"""
Set optimizer of current job
@@ -56,8 +57,10 @@ class JobGenerator(object):
self._startup_prog = startup
def set_infer_feed_and_target_names(self, feed_names, target_names):
- if not isinstance(feed_names, list) or not isinstance(target_names, list):
- raise ValueError("input should be list in set_infer_feed_and_target_names")
+ if not isinstance(feed_names, list) or not isinstance(target_names,
+ list):
+ raise ValueError(
+ "input should be list in set_infer_feed_and_target_names")
'''
print(feed_names)
print(target_names)
@@ -76,7 +79,6 @@ class JobGenerator(object):
server_endpoints=[],
worker_num=1,
output=None):
-
"""
Generate Federated Learning Job, based on user defined configs
@@ -130,17 +132,66 @@ class JobGenerator(object):
startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone()
fl_strategy._build_trainer_program_for_job(
- trainer_id, program=main_program,
- ps_endpoints=server_endpoints, trainers=worker_num,
- sync_mode=True, startup_program=startup_program,
+ trainer_id,
+ program=main_program,
+ ps_endpoints=server_endpoints,
+ trainers=worker_num,
+ sync_mode=True,
+ startup_program=startup_program,
+ job=local_job)
+
+ startup_program = self._startup_prog.clone()
+ main_program = self._losses[0].block.program.clone()
+ fl_strategy._build_server_programs_for_job(
+ program=main_program,
+ ps_endpoints=server_endpoints,
+ trainers=worker_num,
+ sync_mode=True,
+ startup_program=startup_program,
+ job=local_job)
+
+ local_job.set_feed_names(self._feed_names)
+ local_job.set_target_names(self._target_names)
+ local_job.set_strategy(fl_strategy)
+ local_job.save(output)
+
+ def generate_fl_job_for_k8s(self,
+ fl_strategy,
+ server_pod_endpoints=[],
+ server_service_endpoints=[],
+ worker_num=1,
+ output=None):
+
+ local_job = FLCompileTimeJob()
+ assert len(self._losses) > 0
+ assert self._startup_prog != None
+ assert fl_strategy != None
+ assert output != None
+ fl_strategy.minimize(self._optimizer, self._losses)
+
+ # strategy can generate startup and main program
+ # of a single worker and servers
+ for trainer_id in range(worker_num):
+ startup_program = self._startup_prog.clone()
+ main_program = self._losses[0].block.program.clone()
+ fl_strategy._build_trainer_program_for_job(
+ trainer_id,
+ program=main_program,
+ ps_endpoints=server_service_endpoints,
+ trainers=worker_num,
+ sync_mode=True,
+ startup_program=startup_program,
job=local_job)
startup_program = self._startup_prog.clone()
main_program = self._losses[0].block.program.clone()
fl_strategy._build_server_programs_for_job(
- program=main_program, ps_endpoints=server_endpoints,
- trainers=worker_num, sync_mode=True,
- startup_program=startup_program, job=local_job)
+ program=main_program,
+ ps_endpoints=server_pod_endpoints,
+ trainers=worker_num,
+ sync_mode=True,
+ startup_program=startup_program,
+ job=local_job)
local_job.set_feed_names(self._feed_names)
local_job.set_target_names(self._target_names)
diff --git a/paddle_fl/core/scheduler/agent_master.py b/paddle_fl/core/scheduler/agent_master.py
index f26a85a7f44aca7e662dd3140b269944d89201f5..7c1fe4fbca182551a695190f3161488b8d9c4db3 100644
--- a/paddle_fl/core/scheduler/agent_master.py
+++ b/paddle_fl/core/scheduler/agent_master.py
@@ -1,7 +1,22 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import zmq
import time
import random
+
def recv_and_parse_kv(socket):
message = socket.recv()
group = message.decode().split("\t")
@@ -10,9 +25,11 @@ def recv_and_parse_kv(socket):
else:
return group[0], group[1]
+
WORKER_EP = "WORKER_EP"
SERVER_EP = "SERVER_EP"
+
class FLServerAgent(object):
def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep
@@ -29,6 +46,7 @@ class FLServerAgent(object):
if group[0] == 'INIT':
break
+
class FLWorkerAgent(object):
def __init__(self, scheduler_ep, current_ep):
self.scheduler_ep = scheduler_ep
@@ -64,7 +82,6 @@ class FLWorkerAgent(object):
return False
-
class FLScheduler(object):
def __init__(self, worker_num, server_num, port=9091, socket=None):
self.context = zmq.Context()
diff --git a/paddle_fl/core/server/fl_server.py b/paddle_fl/core/server/fl_server.py
index f8fc1922da8b9ebe2a5abe68782a6f32cf272989..0c1529c1840ea3983691ff6b57decf5f60537ee8 100644
--- a/paddle_fl/core/server/fl_server.py
+++ b/paddle_fl/core/server/fl_server.py
@@ -14,8 +14,8 @@
import paddle.fluid as fluid
from paddle_fl.core.scheduler.agent_master import FLServerAgent
-class FLServer(object):
+class FLServer(object):
def __init__(self):
self._startup_program = None
self._main_program = None
diff --git a/paddle_fl/core/strategy/details/checkport.py b/paddle_fl/core/strategy/details/checkport.py
index 89dd4dd50b0299de986b84f46e889d554030f180..9749bc37dbeff39f39da6b6c726287861f32056d 100644
--- a/paddle_fl/core/strategy/details/checkport.py
+++ b/paddle_fl/core/strategy/details/checkport.py
@@ -48,8 +48,8 @@ def wait_server_ready(endpoints):
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
- sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) +
- "\n")
+ sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints)
+ + "\n")
sys.stderr.flush()
time.sleep(3)
else:
diff --git a/paddle_fl/core/strategy/details/program_utils.py b/paddle_fl/core/strategy/details/program_utils.py
index dc78ffe70b3dfda75a799583e85b76d8d921e078..28ec5e097d7ad8df3d8709f439342a11dc4bd592 100644
--- a/paddle_fl/core/strategy/details/program_utils.py
+++ b/paddle_fl/core/strategy/details/program_utils.py
@@ -163,7 +163,8 @@ def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
indent = 0
print(
- "{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx),
+ "{0}{1} // block {2}".format(
+ get_indent_space(indent), '{', block_idx),
file=fout)
indent += 1
diff --git a/paddle_fl/core/strategy/fl_distribute_transpiler.py b/paddle_fl/core/strategy/fl_distribute_transpiler.py
index 28b369ce9b3ada38eedc438f10f3b1da9725833b..0d204f8fbcd6c79e7ec85e269d0900c9b01e3b61 100644
--- a/paddle_fl/core/strategy/fl_distribute_transpiler.py
+++ b/paddle_fl/core/strategy/fl_distribute_transpiler.py
@@ -50,6 +50,7 @@ def log(*args):
if PRINT_LOG:
print(args)
+
def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
@@ -113,7 +114,9 @@ class FLDistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = []
- sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
+ sparse_update_op_types = [
+ "lookup_table", "nce", "hierarchical_sigmoid"
+ ]
for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True:
@@ -406,12 +409,13 @@ class FLDistributeTranspiler(object):
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var = pserver_program.global_block().var(
- orig_var_name)
+ orig_var_name)
if self.sync_mode and self.trainer_num > 1:
for trainer_id in range(self.trainer_num):
var = pserver_program.global_block().create_var(
- name="%s.opti.trainer_%d" % (orig_var_name, trainer_id),
+ name="%s.opti.trainer_%d" %
+ (orig_var_name, trainer_id),
persistable=False,
type=v.type,
dtype=v.dtype,
@@ -816,7 +820,6 @@ class FLDistributeTranspiler(object):
iomap = collections.OrderedDict()
return iomap
-
def _get_lr_ops(self):
lr_ops = []
block = self.origin_program.global_block()
diff --git a/paddle_fl/core/strategy/fl_strategy_base.py b/paddle_fl/core/strategy/fl_strategy_base.py
index 14e97ce346f4c67dbd7eb5ee34e6a12f2bdd866d..d1579d2f744c4bf67d31263e68872fe0be8b1263 100644
--- a/paddle_fl/core/strategy/fl_strategy_base.py
+++ b/paddle_fl/core/strategy/fl_strategy_base.py
@@ -16,11 +16,13 @@ from .fl_distribute_transpiler import FLDistributeTranspiler
from paddle.fluid.optimizer import SGD
import paddle.fluid as fluid
+
class FLStrategyFactory(object):
"""
FLStrategyFactory is a FLStrategy builder
Users can define strategy config to create different FLStrategy
"""
+
def __init__(self):
self._fed_avg = False
self._dpsgd = False
@@ -86,6 +88,7 @@ class FLStrategyBase(object):
"""
FLStrategyBase is federated learning algorithm container
"""
+
def __init__(self):
self._fed_avg = False
self._dpsgd = False
@@ -105,17 +108,23 @@ class FLStrategyBase(object):
for loss in losses:
optimizer.minimize(loss)
- def _build_trainer_program_for_job(
- self, trainer_id=0, program=None,
- ps_endpoints=[], trainers=0,
- sync_mode=True, startup_program=None,
- job=None):
+ def _build_trainer_program_for_job(self,
+ trainer_id=0,
+ program=None,
+ ps_endpoints=[],
+ trainers=0,
+ sync_mode=True,
+ startup_program=None,
+ job=None):
pass
- def _build_server_programs_for_job(
- self, program=None, ps_endpoints=[],
- trainers=0, sync_mode=True,
- startup_program=None, job=None):
+ def _build_server_programs_for_job(self,
+ program=None,
+ ps_endpoints=[],
+ trainers=0,
+ sync_mode=True,
+ startup_program=None,
+ job=None):
pass
@@ -123,6 +132,7 @@ class DPSGDStrategy(FLStrategyBase):
"""
DPSGDStrategy: Deep Learning with Differential Privacy. 2016
"""
+
def __init__(self):
super(DPSGDStrategy, self).__init__()
@@ -162,29 +172,40 @@ class DPSGDStrategy(FLStrategyBase):
"""
Define Dpsgd optimizer
"""
- optimizer = fluid.optimizer.Dpsgd(self._learning_rate, clip=self._clip, batch_size=self._batch_size, sigma=self._sigma)
+ optimizer = fluid.optimizer.Dpsgd(
+ self._learning_rate,
+ clip=self._clip,
+ batch_size=self._batch_size,
+ sigma=self._sigma)
optimizer.minimize(losses[0])
- def _build_trainer_program_for_job(
- self, trainer_id=0, program=None,
- ps_endpoints=[], trainers=0,
- sync_mode=True, startup_program=None,
- job=None):
+ def _build_trainer_program_for_job(self,
+ trainer_id=0,
+ program=None,
+ ps_endpoints=[],
+ trainers=0,
+ sync_mode=True,
+ startup_program=None,
+ job=None):
transpiler = fluid.DistributeTranspiler()
- transpiler.transpile(trainer_id,
- program=program,
- pservers=",".join(ps_endpoints),
- trainers=trainers,
- sync_mode=sync_mode,
- startup_program=startup_program)
+ transpiler.transpile(
+ trainer_id,
+ program=program,
+ pservers=",".join(ps_endpoints),
+ trainers=trainers,
+ sync_mode=sync_mode,
+ startup_program=startup_program)
main = transpiler.get_trainer_program(wait_port=False)
job._trainer_startup_programs.append(startup_program)
job._trainer_main_programs.append(main)
- def _build_server_programs_for_job(
- self, program=None, ps_endpoints=[],
- trainers=0, sync_mode=True,
- startup_program=None, job=None):
+ def _build_server_programs_for_job(self,
+ program=None,
+ ps_endpoints=[],
+ trainers=0,
+ sync_mode=True,
+ startup_program=None,
+ job=None):
transpiler = fluid.DistributeTranspiler()
trainer_id = 0
transpiler.transpile(
@@ -207,6 +228,7 @@ class FedAvgStrategy(FLStrategyBase):
FedAvgStrategy: this is model averaging optimization proposed in
H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 2017
"""
+
def __init__(self):
super(FedAvgStrategy, self).__init__()
@@ -216,28 +238,35 @@ class FedAvgStrategy(FLStrategyBase):
"""
optimizer.minimize(losses[0])
- def _build_trainer_program_for_job(
- self, trainer_id=0, program=None,
- ps_endpoints=[], trainers=0,
- sync_mode=True, startup_program=None,
- job=None):
+ def _build_trainer_program_for_job(self,
+ trainer_id=0,
+ program=None,
+ ps_endpoints=[],
+ trainers=0,
+ sync_mode=True,
+ startup_program=None,
+ job=None):
transpiler = FLDistributeTranspiler()
- transpiler.transpile(trainer_id,
- program=program,
- pservers=",".join(ps_endpoints),
- trainers=trainers,
- sync_mode=sync_mode,
- startup_program=startup_program)
+ transpiler.transpile(
+ trainer_id,
+ program=program,
+ pservers=",".join(ps_endpoints),
+ trainers=trainers,
+ sync_mode=sync_mode,
+ startup_program=startup_program)
recv, main, send = transpiler.get_trainer_program()
job._trainer_startup_programs.append(startup_program)
job._trainer_main_programs.append(main)
job._trainer_send_programs.append(send)
job._trainer_recv_programs.append(recv)
- def _build_server_programs_for_job(
- self, program=None, ps_endpoints=[],
- trainers=0, sync_mode=True,
- startup_program=None, job=None):
+ def _build_server_programs_for_job(self,
+ program=None,
+ ps_endpoints=[],
+ trainers=0,
+ sync_mode=True,
+ startup_program=None,
+ job=None):
transpiler = FLDistributeTranspiler()
trainer_id = 0
transpiler.transpile(
@@ -262,6 +291,7 @@ class SecAggStrategy(FedAvgStrategy):
Practical Secure Aggregation for Privacy-Preserving Machine Learning,
The 24th ACM Conference on Computer and Communications Security ( CCS2017 ).
"""
+
def __init__(self):
super(SecAggStrategy, self).__init__()
self._param_name_list = []
diff --git a/paddle_fl/core/submitter/client_base.py b/paddle_fl/core/submitter/client_base.py
index 43d9ece6e77174fac01235ae10852a05ef7a5e66..0f2e977906d20966957e290d9d4e1f6edeb3e8bc 100644
--- a/paddle_fl/core/submitter/client_base.py
+++ b/paddle_fl/core/submitter/client_base.py
@@ -1,10 +1,25 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import sys
import os
+
class CloudClient(object):
def __init__(self):
pass
-
+
def generate_submit_sh(self, job_dir):
with open() as fout:
pass
@@ -16,6 +31,7 @@ class CloudClient(object):
def submit(self, **kwargs):
pass
+
class HPCClient(object):
def __init__(self):
self.conf_dict = {}
@@ -70,27 +86,20 @@ class HPCClient(object):
fout.write("#!/bin/bash\n")
fout.write("unset http_proxy\n")
fout.write("unset https_proxy\n")
- fout.write("export HADOOP_HOME={}\n".format(
- self.hadoop_home))
+ fout.write("export HADOOP_HOME={}\n".format(self.hadoop_home))
fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}"
" -Dfs.default.name={} -rmr {}\n".format(
- self.ugi,
- self.hdfs_path,
- self.hdfs_output))
+ self.ugi, self.hdfs_path, self.hdfs_output))
fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem))
fout.write("{}/bin/qsub_f -N {} --conf qsub.conf "
"--hdfs {} --ugi {} --hout {} --files ./package "
"-l nodes={},walltime=1000:00:00,pmem-hard={},"
"pcpu-soft={},pnetin-soft=1000,"
"pnetout-soft=1000 job.sh\n".format(
- self.hpc_home,
- self.task_name,
- self.hdfs_path,
- self.ugi,
- self.hdfs_output,
+ self.hpc_home, self.task_name, self.hdfs_path,
+ self.ugi, self.hdfs_output,
int(self.worker_nodes) + int(self.server_nodes),
- self.mpi_node_mem,
- self.pcpu))
+ self.mpi_node_mem, self.pcpu))
def generate_job_sh(self, job_dir):
with open("{}/job.sh".format(job_dir), "w") as fout:
@@ -98,17 +107,23 @@ class HPCClient(object):
fout.write("WORKDIR=`pwd`\n")
fout.write("mpirun -npernode 1 mv package/* ./\n")
fout.write("echo 'current dir: '$WORKDIR\n")
- fout.write("mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
- fout.write("export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
+ fout.write(
+ "mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
+ fout.write(
+ "export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
fout.write("mpirun -npernode 1 python/bin/python -m pip install "
"{} --index-url=http://pip.baidu.com/pypi/simple "
"--trusted-host pip.baidu.com > /dev/null\n".format(
self.wheel))
fout.write("export PATH=python/bin:$PATH\n")
if self.monitor_cmd != "":
- fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
- "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".format(self.monitor_cmd))
- fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n")
+ fout.write(
+ "mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
+ "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".
+ format(self.monitor_cmd))
+ fout.write(
+ "mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n"
+ )
fout.write("if [[ $? -ne 0 ]]; then\n")
fout.write(" echo 'Failed to run mpi!' 1>&2\n")
fout.write(" exit 1\n")
@@ -150,4 +165,5 @@ class HPCClient(object):
# generate job.sh
self.generate_qsub_conf(jobdir)
# run submit
- os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(jobdir))
+ os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(
+ jobdir))
diff --git a/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py b/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py
index 751a9b080047a407aeed62981906d860ffad0457..18a910df5193ec16c0a820339bc0c84ab3820408 100644
Binary files a/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py and b/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py differ
diff --git a/paddle_fl/core/trainer/diffiehellman/__init__.py b/paddle_fl/core/trainer/diffiehellman/__init__.py
index a1143ffe5f7d034f11053e1ea5dbe38be50064f6..cd5e8bc0757137eee8694e600c410d49db37a78a 100644
--- a/paddle_fl/core/trainer/diffiehellman/__init__.py
+++ b/paddle_fl/core/trainer/diffiehellman/__init__.py
@@ -2,7 +2,6 @@
#
# (c) Chris von Csefalvay, 2015.
-
"""
__init__.py is responsible for [brief description here].
"""
diff --git a/paddle_fl/core/trainer/diffiehellman/decorators.py b/paddle_fl/core/trainer/diffiehellman/decorators.py
index c6f3248e6857d6ca6b60ede18d8f2377a8e2f504..2483e75e5f4ba23106aa15ab377e6669465b968e 100644
--- a/paddle_fl/core/trainer/diffiehellman/decorators.py
+++ b/paddle_fl/core/trainer/diffiehellman/decorators.py
@@ -1,6 +1,5 @@
# coding=utf-8
-
#
# The MIT License (MIT)
#
@@ -21,8 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
-
-
"""
decorators declares some decorators that ensure the object has the
correct keys declared when need be.
diff --git a/paddle_fl/core/trainer/diffiehellman/diffiehellman.py b/paddle_fl/core/trainer/diffiehellman/diffiehellman.py
index d5e1b196f56e5694b4112a8ff82022ee3bac7910..9b4e7304d24cc5a4d1ba5abc25e7bb69e2941c9d 100644
--- a/paddle_fl/core/trainer/diffiehellman/diffiehellman.py
+++ b/paddle_fl/core/trainer/diffiehellman/diffiehellman.py
@@ -20,10 +20,6 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
-
-
-
-
"""
diffiehellmann declares the main key exchange class.
"""
@@ -41,18 +37,17 @@ import os
try:
from ssl import RAND_bytes
rng = RAND_bytes
-except(AttributeError, ImportError):
+except (AttributeError, ImportError):
rng = os.urandom
+
class DiffieHellman:
"""
Implements the Diffie-Hellman key exchange protocol.
"""
- def __init__(self,
- group=18,
- key_length=640):
+ def __init__(self, group=18, key_length=640):
self.key_length = max(200, key_length)
self.generator = PRIMES[group]["generator"]
@@ -81,7 +76,8 @@ class DiffieHellman:
self.private_key = key
def verify_public_key(self, other_public_key):
- return self.prime - 1 > other_public_key > 2 and pow(other_public_key, (self.prime - 1) // 2, self.prime) == 1
+ return self.prime - 1 > other_public_key > 2 and pow(
+ other_public_key, (self.prime - 1) // 2, self.prime) == 1
@requires_private_key
def generate_public_key(self):
@@ -91,9 +87,7 @@ class DiffieHellman:
:return: void
:rtype: void
"""
- self.public_key = pow(self.generator,
- self.private_key,
- self.prime)
+ self.public_key = pow(self.generator, self.private_key, self.prime)
@requires_private_key
def generate_shared_secret(self, other_public_key, echo_return_key=False):
@@ -110,16 +104,17 @@ class DiffieHellman:
if self.verify_public_key(other_public_key) is False:
raise MalformedPublicKey
- self.shared_secret = pow(other_public_key,
- self.private_key,
+ self.shared_secret = pow(other_public_key, self.private_key,
self.prime)
try:
#python3
- shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big')
+ shared_secret_as_bytes = self.shared_secret.to_bytes(
+ self.shared_secret.bit_length() // 8 + 1, byteorder='big')
except:
#python2
length = self.shared_secret.bit_length() // 8 + 1
- shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:]
+ shared_secret_as_bytes = ('%%0%dx' % (
+ length << 1) % self.shared_secret).decode('hex')[-length:]
_h = sha256()
_h.update(bytes(shared_secret_as_bytes))
diff --git a/paddle_fl/core/trainer/diffiehellman/exceptions.py b/paddle_fl/core/trainer/diffiehellman/exceptions.py
index 289062c9477e7f00ea27cfc05d56d570b8de76ad..2205c4eba7f7e33eae614b2f1975231e3e8ed956 100644
--- a/paddle_fl/core/trainer/diffiehellman/exceptions.py
+++ b/paddle_fl/core/trainer/diffiehellman/exceptions.py
@@ -2,7 +2,6 @@
#
# (c) Chris von Csefalvay, 2015.
-
"""
exceptions is responsible for exception handling etc.
"""
diff --git a/paddle_fl/core/trainer/diffiehellman/primes.py b/paddle_fl/core/trainer/diffiehellman/primes.py
index 624926e4c367e1937c8e23e468be78795d04611e..265bdcc99eebebfe272fbcc85ba1f350de27d0cb 100644
--- a/paddle_fl/core/trainer/diffiehellman/primes.py
+++ b/paddle_fl/core/trainer/diffiehellman/primes.py
@@ -1,6 +1,5 @@
# coding=utf-8
-
#
# The MIT License (MIT)
#
@@ -25,34 +24,39 @@
# Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman
# groups for Internet Key Exchange (IKE)_.
#
-
"""
primes holds the RFC 3526 MODP primes and their generators.
"""
PRIMES = {
5: {
- "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
+ "prime":
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF,
"generator": 2
},
14: {
- "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
+ "prime":
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF,
"generator": 2
},
15: {
- "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
+ "prime":
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF,
"generator": 2
},
16: {
- "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
+ "prime":
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF,
"generator": 2
},
17: {
- "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
+ "prime":
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF,
"generator": 2
},
18: {
- "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
+ "prime":
+ 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF,
"generator": 2
},
}
diff --git a/paddle_fl/core/trainer/fl_trainer.py b/paddle_fl/core/trainer/fl_trainer.py
index 00cb67e72617c7edf73e4714b307090d78493e88..e768f259716a71fe5db75d2a35fb4480b335465b 100755
--- a/paddle_fl/core/trainer/fl_trainer.py
+++ b/paddle_fl/core/trainer/fl_trainer.py
@@ -19,6 +19,7 @@ import hmac
import hashlib
from .diffiehellman.diffiehellman import DiffieHellman
+
class FLTrainerFactory(object):
def __init__(self):
pass
@@ -65,9 +66,7 @@ class FLTrainer(object):
def run(self, feed, fetch):
self._logger.debug("begin to run")
- self.exe.run(self._main_program,
- feed=feed,
- fetch_list=fetch)
+ self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
self._logger.debug("end to run current batch")
self.cur_step += 1
@@ -119,37 +118,34 @@ class FedAvgTrainer(FLTrainer):
def reset(self):
self.cur_step = 0
- def run_with_epoch(self,reader,feeder,fetch,num_epoch):
+ def run_with_epoch(self, reader, feeder, fetch, num_epoch):
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
epoch = 0
for i in range(num_epoch):
- for data in reader():
- self.exe.run(self._main_program,
- feed=feeder.feed(data),
- fetch_list=fetch)
- self.cur_step += 1
- epoch += 1
+ for data in reader():
+ self.exe.run(self._main_program,
+ feed=feeder.feed(data),
+ fetch_list=fetch)
+ self.cur_step += 1
+ epoch += 1
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
+
def run(self, feed, fetch):
- self._logger.debug("begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
- (self.cur_step, self._step))
+ self._logger.debug(
+ "begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
+ (self.cur_step, self._step))
if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
self._logger.debug("begin to run current step")
- loss = self.exe.run(self._main_program,
- feed=feed,
- fetch_list=fetch)
+ loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
self.cur_step += 1
return loss
-
-
-
class SecAggTrainer(FLTrainer):
@@ -207,24 +203,24 @@ class SecAggTrainer(FLTrainer):
self.cur_step = 0
def run(self, feed, fetch):
- self._logger.debug("begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
- (self.cur_step, self._step))
+ self._logger.debug(
+ "begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
+ (self.cur_step, self._step))
if self.cur_step % self._step == 0:
self._logger.debug("begin to run recv program")
self.exe.run(self._recv_program)
scope = fluid.global_scope()
self._logger.debug("begin to run current step")
- loss = self.exe.run(self._main_program,
- feed=feed,
- fetch_list=fetch)
+ loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch)
if self.cur_step % self._step == 0:
self._logger.debug("begin to run send program")
noise = 0.0
scale = pow(10.0, 5)
- digestmod=hashlib.sha256
+ digestmod = hashlib.sha256
# 1. load priv key and other's pub key
dh = DiffieHellman(group=15, key_length=256)
- dh.load_private_key(self._key_dir + str(self._trainer_id) + "_priv_key.txt")
+ dh.load_private_key(self._key_dir + str(self._trainer_id) +
+ "_priv_key.txt")
key = str(self._step_id).encode("utf-8")
for i in range(self._trainer_num):
if i != self._trainer_id:
@@ -232,7 +228,8 @@ class SecAggTrainer(FLTrainer):
public_key = int(f.read())
dh.generate_shared_secret(public_key, echo_return_key=True)
msg = dh.shared_key.encode("utf-8")
- hex_res1 = hmac.new(key=key, msg=msg, digestmod=digestmod).hexdigest()
+ hex_res1 = hmac.new(key=key, msg=msg,
+ digestmod=digestmod).hexdigest()
current_noise = int(hex_res1[0:8], 16) / scale
if i > self._trainer_id:
noise = noise + current_noise
@@ -241,9 +238,11 @@ class SecAggTrainer(FLTrainer):
scope = fluid.global_scope()
for param_name in self._param_name_list:
- fluid.global_scope().var(param_name + str(self._trainer_id)).get_tensor().set(
- numpy.array(scope.find_var(param_name + str(self._trainer_id)).get_tensor()) + noise, fluid.CPUPlace())
+ fluid.global_scope().var(param_name + str(
+ self._trainer_id)).get_tensor().set(
+ numpy.array(
+ scope.find_var(param_name + str(self._trainer_id))
+ .get_tensor()) + noise, fluid.CPUPlace())
self.exe.run(self._send_program)
self.cur_step += 1
return loss
-
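
Note on the SecAggTrainer changes above: each trainer masks its parameters with a per-step value derived from an HMAC of the Diffie-Hellman shared key, keyed by the step id; the mask is added for peers with a larger trainer id and subtracted for smaller ones, so the masks cancel once the server sums the uploads. A minimal standalone sketch of that cancellation (the helper name and the toy shared secret are hypothetical, not part of this patch):

import hashlib
import hmac


def pairwise_noise(step_id, shared_keys, my_id, scale=10.0**5):
    # Mirror the masking rule in SecAggTrainer.run: add the noise for peers
    # with a larger id, subtract it for peers with a smaller id.
    key = str(step_id).encode("utf-8")
    noise = 0.0
    for peer_id, shared_key in shared_keys.items():
        digest = hmac.new(key=key, msg=shared_key.encode("utf-8"),
                          digestmod=hashlib.sha256).hexdigest()
        current = int(digest[0:8], 16) / scale
        noise += current if peer_id > my_id else -current
    return noise


# Two trainers that share one secret produce masks summing to zero, so the
# aggregated model is unchanged while each individual upload stays masked.
secret = "0123abcd"  # toy stand-in for dh.shared_key
mask0 = pairwise_noise(0, {1: secret}, my_id=0)
mask1 = pairwise_noise(0, {0: secret}, my_id=1)
assert abs(mask0 + mask1) < 1e-9

With more than two trainers every unordered pair contributes one added and one subtracted copy of the same value, which is why run() splits on i > self._trainer_id.
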
diff --git a/paddle_fl/dataset/__init__.py b/paddle_fl/dataset/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..abf198b97e6e818e1fbe59006f98492640bcee54 100644
--- a/paddle_fl/dataset/__init__.py
+++ b/paddle_fl/dataset/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddle_fl/dataset/femnist.py b/paddle_fl/dataset/femnist.py
index a4940b84a2edbfcece62a736be541d3dcf59c8a3..80a5b47d5a9a3f2e381c3a7712b428f53ae98a3f 100644
--- a/paddle_fl/dataset/femnist.py
+++ b/paddle_fl/dataset/femnist.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import requests
import os
import json
@@ -5,73 +19,84 @@ import tarfile
import random
-def download(url,tar_path):
- r = requests.get(url)
- with open(tar_path,'wb') as f:
- f.write(r.content)
-
-def extract(tar_path,target_path):
- tar = tarfile.open(tar_path, "r:gz")
- file_names = tar.getnames()
- for file_name in file_names:
- tar.extract(file_name,target_path)
-
- tar.close()
-
-def train(trainer_id,inner_step,batch_size,count_by_step):
- target_path = "trainer%d_data" % trainer_id
- data_path = target_path + "/femnist_data"
- tar_path = data_path + ".tar.gz"
- if not os.path.exists(target_path):
- os.system("mkdir trainer%d_data" % trainer_id)
- if not os.path.exists(data_path):
- print("Preparing data...")
- if not os.path.exists(tar_path):
- download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
- extract(tar_path,target_path)
- def train_data():
- train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
- json_train = json.load(train_file)
- users = json_train["users"]
- rand = random.randrange(0,len(users)) # random choose a user from each trainer
- cur_user = users[rand]
- print('training using '+cur_user)
- train_images = json_train["user_data"][cur_user]['x']
- train_labels = json_train["user_data"][cur_user]['y']
- if count_by_step:
- for i in range(inner_step*batch_size):
- yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))]
- else:
- for i in range(len(train_images)):
- yield train_images[i], train_labels[i]
-
- train_file.close()
-
- return train_data
-
-def test(trainer_id,inner_step,batch_size,count_by_step):
- target_path = "trainer%d_data" % trainer_id
- data_path = target_path + "/femnist_data"
- tar_path = data_path + ".tar.gz"
- if not os.path.exists(target_path):
- os.system("mkdir trainer%d_data" % trainer_id)
- if not os.path.exists(data_path):
- print("Preparing data...")
- if not os.path.exists(tar_path):
- download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
- extract(tar_path,target_path)
- def test_data():
- test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
- json_test = json.load(test_file)
- users = json_test["users"]
- for user in users:
- test_images = json_test['user_data'][user]['x']
- test_labels = json_test['user_data'][user]['y']
- for i in range(len(test_images)):
- yield test_images[i], test_labels[i]
-
- test_file.close()
-
- return test_data
-
-
+def download(url, tar_path):
+ r = requests.get(url)
+ with open(tar_path, 'wb') as f:
+ f.write(r.content)
+
+
+def extract(tar_path, target_path):
+ tar = tarfile.open(tar_path, "r:gz")
+ file_names = tar.getnames()
+ for file_name in file_names:
+ tar.extract(file_name, target_path)
+
+ tar.close()
+
+
+def train(trainer_id, inner_step, batch_size, count_by_step):
+ target_path = "trainer%d_data" % trainer_id
+ data_path = target_path + "/femnist_data"
+ tar_path = data_path + ".tar.gz"
+ if not os.path.exists(target_path):
+ os.system("mkdir trainer%d_data" % trainer_id)
+ if not os.path.exists(data_path):
+ print("Preparing data...")
+ if not os.path.exists(tar_path):
+ download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
+ tar_path)
+ extract(tar_path, target_path)
+
+ def train_data():
+ train_file = open(
+ "./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json"
+ % (trainer_id, trainer_id), 'r')
+ json_train = json.load(train_file)
+ users = json_train["users"]
+ rand = random.randrange(
+            0, len(users))  # randomly choose a user from each trainer
+ cur_user = users[rand]
+ print('training using ' + cur_user)
+ train_images = json_train["user_data"][cur_user]['x']
+ train_labels = json_train["user_data"][cur_user]['y']
+ if count_by_step:
+ for i in range(inner_step * batch_size):
+ yield train_images[i % (len(train_images))], train_labels[i % (
+ len(train_images))]
+ else:
+ for i in range(len(train_images)):
+ yield train_images[i], train_labels[i]
+
+ train_file.close()
+
+ return train_data
+
+
+def test(trainer_id, inner_step, batch_size, count_by_step):
+ target_path = "trainer%d_data" % trainer_id
+ data_path = target_path + "/femnist_data"
+ tar_path = data_path + ".tar.gz"
+ if not os.path.exists(target_path):
+ os.system("mkdir trainer%d_data" % trainer_id)
+ if not os.path.exists(data_path):
+ print("Preparing data...")
+ if not os.path.exists(tar_path):
+ download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",
+ tar_path)
+ extract(tar_path, target_path)
+
+ def test_data():
+ test_file = open(
+ "./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json"
+ % (trainer_id, trainer_id), 'r')
+ json_test = json.load(test_file)
+ users = json_test["users"]
+ for user in users:
+ test_images = json_test['user_data'][user]['x']
+ test_labels = json_test['user_data'][user]['y']
+ for i in range(len(test_images)):
+ yield test_images[i], test_labels[i]
+
+ test_file.close()
+
+ return test_data
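
The reformatted train() above keeps its original behaviour: when count_by_step is set, the generator cycles over the chosen user's images until it has yielded inner_step * batch_size samples; otherwise it makes a single pass. A toy sketch of that wrap-around rule with synthetic data (not the downloaded FEMNIST JSON):

def make_reader(images, labels, inner_step, batch_size, count_by_step):
    # Same wrap-around rule as femnist.train: a fixed sample count when
    # counting by step, one full pass otherwise.
    def reader():
        if count_by_step:
            for i in range(inner_step * batch_size):
                yield images[i % len(images)], labels[i % len(images)]
        else:
            for i in range(len(images)):
                yield images[i], labels[i]

    return reader


toy = make_reader([0, 1, 2], ['a', 'b', 'c'],
                  inner_step=2, batch_size=4, count_by_step=True)
assert len(list(toy())) == 8  # 2 steps * batch size 4, wrapping over 3 samples

The femnist_demo trainer further below passes inner_step=trainer._step, so in count-by-step mode the reader yields exactly enough samples for one FedAvg round.
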
diff --git a/paddle_fl/examples/ctr_demo/fl_master.py b/paddle_fl/examples/ctr_demo/fl_master.py
index fcda61f6d0f051c81ceb39db96185207ac4e392e..57ac3812f93550e056c796db8283591ab39c1412 100644
--- a/paddle_fl/examples/ctr_demo/fl_master.py
+++ b/paddle_fl/examples/ctr_demo/fl_master.py
@@ -1,8 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
+
class Model(object):
def __init__(self):
pass
@@ -12,7 +27,8 @@ class Model(object):
self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
- self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
+ self.sum_cost = fluid.layers.cross_entropy(
+ input=self.predict, label=label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.loss = fluid.layers.reduce_mean(self.sum_cost)
self.startup_program = fluid.default_startup_program()
@@ -34,8 +50,8 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
-job_generator.set_infer_feed_and_target_names(
- [x.name for x in inputs], [model.predict.name])
+job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
+ [model.predict.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
diff --git a/paddle_fl/examples/ctr_demo/fl_scheduler.py b/paddle_fl/examples/ctr_demo/fl_scheduler.py
index 9dc5d84497b376d1aa7fd5731771b0799343c2ec..fe93b55d96b73734437b3387bdd04340248d7888 100644
--- a/paddle_fl/examples/ctr_demo/fl_scheduler.py
+++ b/paddle_fl/examples/ctr_demo/fl_scheduler.py
@@ -1,9 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
# Define the number of worker/server and the port for scheduler
-scheduler = FLScheduler(worker_num,server_num,port=9091)
+scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
diff --git a/paddle_fl/examples/ctr_demo/fl_server.py b/paddle_fl/examples/ctr_demo/fl_server.py
index 529df8da4079fbbd217c58a857f7ab8a3c307586..2bc79fff528bc8e52e353cc56ca17618d6f4acca 100644
--- a/paddle_fl/examples/ctr_demo/fl_server.py
+++ b/paddle_fl/examples/ctr_demo/fl_server.py
@@ -21,8 +21,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
-job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
+job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
-server._current_ep = "127.0.0.1:8181" # IP address for server
+server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
print("connect")
diff --git a/paddle_fl/examples/ctr_demo/fl_trainer.py b/paddle_fl/examples/ctr_demo/fl_trainer.py
index 0bcfb19544741eb1ef7158dcbf6136c2e34d1b9c..9b4b490c9abc54962da11d7f5d88077e2f558f33 100644
--- a/paddle_fl/examples/ctr_demo/fl_trainer.py
+++ b/paddle_fl/examples/ctr_demo/fl_trainer.py
@@ -1,10 +1,29 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import sys
import logging
import time
-logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
+logging.basicConfig(
+ filename="test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+ datefmt="%d-%M-%Y %H:%M:%S",
+ level=logging.DEBUG)
def reader():
@@ -15,13 +34,14 @@ def reader():
data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
yield data_dict
-trainer_id = int(sys.argv[1]) # trainer id for each guest
+
+trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
-job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
+job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
-trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
@@ -37,4 +57,3 @@ while not trainer.stop():
epoch_id += 1
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
-
diff --git a/paddle_fl/examples/dpsgd_demo/fl_master.py b/paddle_fl/examples/dpsgd_demo/fl_master.py
index f79472e8a7def3f3851d6b7eaf544a27c6eb8cc9..218f49ecb28b320db2df1f19751528d716bdb070 100644
--- a/paddle_fl/examples/dpsgd_demo/fl_master.py
+++ b/paddle_fl/examples/dpsgd_demo/fl_master.py
@@ -1,19 +1,39 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
import math
+
class Model(object):
def __init__(self):
pass
def lr_network(self):
- self.inputs = fluid.layers.data(name='img', shape=[1, 28, 28], dtype="float32")
- self.label = fluid.layers.data(name='label', shape=[1],dtype='int64')
- self.predict = fluid.layers.fc(input=self.inputs, size=10, act='softmax')
- self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=self.label)
- self.accuracy = fluid.layers.accuracy(input=self.predict, label=self.label)
+ self.inputs = fluid.layers.data(
+ name='img', shape=[1, 28, 28], dtype="float32")
+ self.label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+ self.predict = fluid.layers.fc(input=self.inputs,
+ size=10,
+ act='softmax')
+ self.sum_cost = fluid.layers.cross_entropy(
+ input=self.predict, label=self.label)
+ self.accuracy = fluid.layers.accuracy(
+ input=self.predict, label=self.label)
self.loss = fluid.layers.mean(self.sum_cost)
self.startup_program = fluid.default_startup_program()
@@ -23,7 +43,7 @@ model.lr_network()
STEP_EPSILON = 0.1
DELTA = 0.00001
-SIGMA = math.sqrt(2.0 * math.log(1.25/DELTA)) / STEP_EPSILON
+SIGMA = math.sqrt(2.0 * math.log(1.25 / DELTA)) / STEP_EPSILON
CLIP = 4.0
batch_size = 64
@@ -33,7 +53,8 @@ job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
- [model.inputs.name, model.label.name], [model.loss.name, model.accuracy.name])
+ [model.inputs.name, model.label.name],
+ [model.loss.name, model.accuracy.name])
build_strategy = FLStrategyFactory()
build_strategy.dpsgd = True
diff --git a/paddle_fl/examples/dpsgd_demo/fl_scheduler.py b/paddle_fl/examples/dpsgd_demo/fl_scheduler.py
index f8ea641e2fa356102ffc08eec179e84ca1993f7d..29bba5a610dde1253d35ce64d8776ba585b8d88f 100644
--- a/paddle_fl/examples/dpsgd_demo/fl_scheduler.py
+++ b/paddle_fl/examples/dpsgd_demo/fl_scheduler.py
@@ -1,9 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
#Define number of worker/server and the port for scheduler
-scheduler = FLScheduler(worker_num,server_num,port=9091)
+scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
diff --git a/paddle_fl/examples/dpsgd_demo/fl_server.py b/paddle_fl/examples/dpsgd_demo/fl_server.py
index 39056e82d99fb924f52c201e0fb230b6bc1626a1..3740982b54613169074c95e510809d55ed54121b 100644
--- a/paddle_fl/examples/dpsgd_demo/fl_server.py
+++ b/paddle_fl/examples/dpsgd_demo/fl_server.py
@@ -21,7 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
-job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
+job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
-server._current_ep = "127.0.0.1:8181" # IP address for server
+server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
diff --git a/paddle_fl/examples/dpsgd_demo/fl_trainer.py b/paddle_fl/examples/dpsgd_demo/fl_trainer.py
index 074368194e20defdea2f1151c9a0f940ed613189..f0b4a8a188456ce654056f3d531f90aaca355053 100644
--- a/paddle_fl/examples/dpsgd_demo/fl_trainer.py
+++ b/paddle_fl/examples/dpsgd_demo/fl_trainer.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
@@ -7,44 +21,51 @@ import paddle.fluid as fluid
import logging
import math
-logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
+logging.basicConfig(
+ filename="test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+ datefmt="%d-%M-%Y %H:%M:%S",
+ level=logging.DEBUG)
-trainer_id = int(sys.argv[1]) # trainer id for each guest
+trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
-job._scheduler_ep = "127.0.0.1:9091" # Inform scheduler IP address to trainer
+job._scheduler_ep = "127.0.0.1:9091" # Inform scheduler IP address to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
-trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
test_program = trainer._main_program.clone(for_test=True)
train_reader = paddle.batch(
- paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
- batch_size=64)
-test_reader = paddle.batch(
- paddle.dataset.mnist.test(), batch_size=64)
+ paddle.reader.shuffle(
+ paddle.dataset.mnist.train(), buf_size=500),
+ batch_size=64)
+test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace())
+
def train_test(train_test_program, train_test_feed, train_test_reader):
acc_set = []
for test_data in train_test_reader():
- acc_np = trainer.exe.run(
- program=train_test_program,
- feed=train_test_feed.feed(test_data),
- fetch_list=["accuracy_0.tmp_0"])
+ acc_np = trainer.exe.run(program=train_test_program,
+ feed=train_test_feed.feed(test_data),
+ fetch_list=["accuracy_0.tmp_0"])
acc_set.append(float(acc_np[0]))
acc_val_mean = numpy.array(acc_set).mean()
return acc_val_mean
+
def compute_privacy_budget(sample_ratio, epsilon, step, delta):
E = 2 * epsilon * math.sqrt(step * sample_ratio)
print("({0}, {1})-DP".format(E, delta))
+
output_folder = "model_node%d" % trainer_id
epoch_id = 0
step = 0
@@ -64,7 +85,8 @@ while not trainer.stop():
train_test_feed=feeder)
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
- compute_privacy_budget(sample_ratio=0.001, epsilon=0.1, step=step, delta=0.00001)
+ compute_privacy_budget(
+ sample_ratio=0.001, epsilon=0.1, step=step, delta=0.00001)
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
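
The privacy accounting in this demo is the closed form shown in the two files above: the master derives the per-step Gaussian noise scale from (epsilon, delta), and compute_privacy_budget reports the accumulated budget as 2 * epsilon * sqrt(step * sample_ratio). A short sketch that simply re-evaluates those formulas with the demo's constants (no PaddleFL imports required):

import math

STEP_EPSILON = 0.1
DELTA = 0.00001
# Per-step Gaussian noise scale, as set in dpsgd_demo/fl_master.py.
SIGMA = math.sqrt(2.0 * math.log(1.25 / DELTA)) / STEP_EPSILON


def accumulated_epsilon(sample_ratio, epsilon, step):
    # Same closed form as compute_privacy_budget in dpsgd_demo/fl_trainer.py.
    return 2 * epsilon * math.sqrt(step * sample_ratio)


print("sigma per step: %.2f" % SIGMA)  # about 48 for these constants
print("(%.3f, %g)-DP after 1000 steps" %
      (accumulated_epsilon(0.001, STEP_EPSILON, 1000), DELTA))  # (0.200, 1e-05)-DP

This matches the (E, delta) pair printed once per epoch in the training loop above.
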
diff --git a/paddle_fl/examples/femnist_demo/fl_master.py b/paddle_fl/examples/femnist_demo/fl_master.py
index dd04165b917481dacf457bb6af49b7e976c28cdb..5102aa37009987010d37c7ab83b4c29a1cdcb7dc 100644
--- a/paddle_fl/examples/femnist_demo/fl_master.py
+++ b/paddle_fl/examples/femnist_demo/fl_master.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
@@ -9,14 +23,31 @@ class Model(object):
pass
def cnn(self):
- self.inputs = fluid.layers.data(name='img', shape=[1, 28, 28], dtype="float32")
- self.label = fluid.layers.data(name='label', shape=[1],dtype='int64')
- self.conv_pool_1 = fluid.nets.simple_img_conv_pool(input=self.inputs,num_filters=20,filter_size=5,pool_size=2,pool_stride=2,act='relu')
- self.conv_pool_2 = fluid.nets.simple_img_conv_pool(input=self.conv_pool_1,num_filters=50,filter_size=5,pool_size=2,pool_stride=2,act='relu')
-
- self.predict = self.predict = fluid.layers.fc(input=self.conv_pool_2, size=62, act='softmax')
- self.cost = fluid.layers.cross_entropy(input=self.predict, label=self.label)
- self.accuracy = fluid.layers.accuracy(input=self.predict, label=self.label)
+ self.inputs = fluid.layers.data(
+ name='img', shape=[1, 28, 28], dtype="float32")
+ self.label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+ self.conv_pool_1 = fluid.nets.simple_img_conv_pool(
+ input=self.inputs,
+ num_filters=20,
+ filter_size=5,
+ pool_size=2,
+ pool_stride=2,
+ act='relu')
+ self.conv_pool_2 = fluid.nets.simple_img_conv_pool(
+ input=self.conv_pool_1,
+ num_filters=50,
+ filter_size=5,
+ pool_size=2,
+ pool_stride=2,
+ act='relu')
+
+ self.predict = self.predict = fluid.layers.fc(input=self.conv_pool_2,
+ size=62,
+ act='softmax')
+ self.cost = fluid.layers.cross_entropy(
+ input=self.predict, label=self.label)
+ self.accuracy = fluid.layers.accuracy(
+ input=self.predict, label=self.label)
self.loss = fluid.layers.mean(self.cost)
self.startup_program = fluid.default_startup_program()
@@ -30,8 +61,8 @@ job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
- [model.inputs.name, model.label.name], [model.loss.name, model.accuracy.name])
-
+ [model.inputs.name, model.label.name],
+ [model.loss.name, model.accuracy.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
diff --git a/paddle_fl/examples/femnist_demo/fl_scheduler.py b/paddle_fl/examples/femnist_demo/fl_scheduler.py
index 346529fd6caadef06d1e46078fa873da264a7507..bc75fa66eae90aa487ecb4b9f7e7bf84c048f5d7 100644
--- a/paddle_fl/examples/femnist_demo/fl_scheduler.py
+++ b/paddle_fl/examples/femnist_demo/fl_scheduler.py
@@ -1,9 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
# Define the number of worker/server and the port for scheduler
-scheduler = FLScheduler(worker_num,server_num,port=9091)
+scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
diff --git a/paddle_fl/examples/femnist_demo/fl_server.py b/paddle_fl/examples/femnist_demo/fl_server.py
index cd558deb4ba577e5d0dbc74ecc90d0e22d278d8e..736b11f302b8653592db5b82a614985e4bdc7833 100644
--- a/paddle_fl/examples/femnist_demo/fl_server.py
+++ b/paddle_fl/examples/femnist_demo/fl_server.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle_fl as fl
import paddle.fluid as fluid
from paddle_fl.core.server.fl_server import FLServer
@@ -7,7 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
-job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
+job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
-server._current_ep = "127.0.0.1:8181" # IP address for server
+server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
diff --git a/paddle_fl/examples/femnist_demo/fl_trainer.py b/paddle_fl/examples/femnist_demo/fl_trainer.py
index dce2c8af02ce3a3be58f0a6fe03eca18a882472a..03675a2472ca8f07364dee9b4510950379985a67 100644
--- a/paddle_fl/examples/femnist_demo/fl_trainer.py
+++ b/paddle_fl/examples/femnist_demo/fl_trainer.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import paddle_fl.dataset.femnist
@@ -8,16 +22,21 @@ import paddle.fluid as fluid
import logging
import math
-logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
+logging.basicConfig(
+ filename="test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+ datefmt="%d-%M-%Y %H:%M:%S",
+ level=logging.DEBUG)
-trainer_id = int(sys.argv[1]) # trainer id for each guest
+trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
-job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
+job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
print(job._target_names)
trainer = FLTrainerFactory().create_fl_trainer(job)
-trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._step)
test_program = trainer._main_program.clone(for_test=True)
@@ -26,26 +45,26 @@ img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace())
+
def train_test(train_test_program, train_test_feed, train_test_reader):
- acc_set = []
- for test_data in train_test_reader():
- acc_np = trainer.exe.run(
- program=train_test_program,
- feed=train_test_feed.feed(test_data),
- fetch_list=["accuracy_0.tmp_0"])
- acc_set.append(float(acc_np[0]))
- acc_val_mean = numpy.array(acc_set).mean()
- return acc_val_mean
+ acc_set = []
+ for test_data in train_test_reader():
+ acc_np = trainer.exe.run(program=train_test_program,
+ feed=train_test_feed.feed(test_data),
+ fetch_list=["accuracy_0.tmp_0"])
+ acc_set.append(float(acc_np[0]))
+ acc_val_mean = numpy.array(acc_set).mean()
+ return acc_val_mean
+
epoch_id = 0
step = 0
epoch = 3000
count_by_step = False
if count_by_step:
- output_folder = "model_node%d" % trainer_id
-else:
- output_folder = "model_node%d_epoch" % trainer_id
-
+ output_folder = "model_node%d" % trainer_id
+else:
+ output_folder = "model_node%d_epoch" % trainer_id
while not trainer.stop():
count = 0
@@ -55,24 +74,35 @@ while not trainer.stop():
print("epoch %d start train" % (epoch_id))
#train_data,test_data= data_generater(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step)
train_reader = paddle.batch(
- paddle.reader.shuffle(paddle_fl.dataset.femnist.train(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step), buf_size=500),
+ paddle.reader.shuffle(
+ paddle_fl.dataset.femnist.train(
+ trainer_id,
+ inner_step=trainer._step,
+ batch_size=64,
+ count_by_step=count_by_step),
+ buf_size=500),
batch_size=64)
test_reader = paddle.batch(
- paddle_fl.dataset.femnist.test(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step), batch_size=64)
-
+ paddle_fl.dataset.femnist.test(
+ trainer_id,
+ inner_step=trainer._step,
+ batch_size=64,
+ count_by_step=count_by_step),
+ batch_size=64)
+
if count_by_step:
- for step_id, data in enumerate(train_reader()):
+ for step_id, data in enumerate(train_reader()):
acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
step += 1
count += 1
print(count)
- if count % trainer._step == 0:
+ if count % trainer._step == 0:
break
# print("acc:%.3f" % (acc[0]))
else:
- trainer.run_with_epoch(train_reader,feeder,fetch=["accuracy_0.tmp_0"],num_epoch=1)
-
+ trainer.run_with_epoch(
+ train_reader, feeder, fetch=["accuracy_0.tmp_0"], num_epoch=1)
acc_val = train_test(
train_test_program=test_program,
@@ -80,6 +110,6 @@ while not trainer.stop():
train_test_feed=feeder)
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
- if trainer_id == 0:
+ if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
diff --git a/paddle_fl/examples/gru4rec_demo/fl_master.py b/paddle_fl/examples/gru4rec_demo/fl_master.py
index 48433ecae963f121a26281a2d756b9514ad979b6..e5a20d971e0eaf9ab4c0fd96910a41a7227d9d77 100644
--- a/paddle_fl/examples/gru4rec_demo/fl_master.py
+++ b/paddle_fl/examples/gru4rec_demo/fl_master.py
@@ -1,8 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
+
class Model(object):
def __init__(self):
pass
@@ -34,7 +49,8 @@ class Model(object):
size=hid_size * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
- low=init_low_bound, high=init_high_bound),
+ low=init_low_bound,
+ high=init_high_bound),
learning_rate=gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru(
input=fc0,
@@ -45,12 +61,13 @@ class Model(object):
learning_rate=gru_lr_x))
self.fc = fluid.layers.fc(input=gru_h0,
- size=vocab_size,
- act='softmax',
- param_attr=fluid.ParamAttr(
- initializer=fluid.initializer.Uniform(
- low=init_low_bound, high=init_high_bound),
- learning_rate=fc_lr_x))
+ size=vocab_size,
+ act='softmax',
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=init_low_bound,
+ high=init_high_bound),
+ learning_rate=fc_lr_x))
cost = fluid.layers.cross_entropy(
input=self.fc, label=self.dst_wordseq)
self.acc = fluid.layers.accuracy(
@@ -59,7 +76,6 @@ class Model(object):
self.startup_program = fluid.default_startup_program()
-
model = Model()
model.gru4rec_network()
@@ -69,7 +85,8 @@ job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names(
- [model.src_wordseq.name, model.dst_wordseq.name], [model.loss.name, model.acc.name])
+ [model.src_wordseq.name, model.dst_wordseq.name],
+ [model.loss.name, model.acc.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
diff --git a/paddle_fl/examples/gru4rec_demo/fl_scheduler.py b/paddle_fl/examples/gru4rec_demo/fl_scheduler.py
index 346529fd6caadef06d1e46078fa873da264a7507..bc75fa66eae90aa487ecb4b9f7e7bf84c048f5d7 100644
--- a/paddle_fl/examples/gru4rec_demo/fl_scheduler.py
+++ b/paddle_fl/examples/gru4rec_demo/fl_scheduler.py
@@ -1,9 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
# Define the number of worker/server and the port for scheduler
-scheduler = FLScheduler(worker_num,server_num,port=9091)
+scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
diff --git a/paddle_fl/examples/gru4rec_demo/fl_server.py b/paddle_fl/examples/gru4rec_demo/fl_server.py
index 39056e82d99fb924f52c201e0fb230b6bc1626a1..3740982b54613169074c95e510809d55ed54121b 100644
--- a/paddle_fl/examples/gru4rec_demo/fl_server.py
+++ b/paddle_fl/examples/gru4rec_demo/fl_server.py
@@ -21,7 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
-job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
+job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
-server._current_ep = "127.0.0.1:8181" # IP address for server
+server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
diff --git a/paddle_fl/examples/gru4rec_demo/fl_trainer.py b/paddle_fl/examples/gru4rec_demo/fl_trainer.py
index b8416f739d1b0dc7cb493768cf1c9283214778c4..1bf229a505aee6870c36556fb0b9b90b58067141 100644
--- a/paddle_fl/examples/gru4rec_demo/fl_trainer.py
+++ b/paddle_fl/examples/gru4rec_demo/fl_trainer.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
from paddle_fl.reader.gru4rec_reader import Gru4rec_Reader
@@ -6,21 +20,26 @@ import numpy as np
import sys
import os
import logging
-logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
+logging.basicConfig(
+ filename="test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+ datefmt="%d-%M-%Y %H:%M:%S",
+ level=logging.DEBUG)
-trainer_id = int(sys.argv[1]) # trainer id for each guest
+trainer_id = int(sys.argv[1]) # trainer id for each guest
place = fluid.CPUPlace()
train_file_dir = "mid_data/node4/%d/" % trainer_id
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
-job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
+job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
-trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
r = Gru4rec_Reader()
-train_reader = r.reader(train_file_dir, place, batch_size = 125)
+train_reader = r.reader(train_file_dir, place, batch_size=125)
output_folder = "model_node4"
step_i = 0
@@ -30,8 +49,7 @@ while not trainer.stop():
train_step = 0
for data in train_reader():
#print(np.array(data['src_wordseq']))
- ret_avg_cost = trainer.run(feed=data,
- fetch=["mean_0.tmp_0"])
+ ret_avg_cost = trainer.run(feed=data, fetch=["mean_0.tmp_0"])
train_step += 1
if train_step == trainer._step:
break
diff --git a/paddle_fl/examples/k8s_deployment/master.yaml b/paddle_fl/examples/k8s_deployment/master.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f7e3451da77ad3288ba92182a777e0aae67d6749
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/master.yaml
@@ -0,0 +1,209 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: fl-master
+spec:
+ type: LoadBalancer
+ ports:
+ - name: fl-master
+ port: 8000
+ targetPort: 8000
+ selector:
+ app: fl-master
+
+---
+
+apiVersion: v1
+kind: Service
+metadata:
+ name: fl-scheduler
+spec:
+ type: LoadBalancer
+ ports:
+ - name: fl-scheduler
+ port: 9091
+ targetPort: 9091
+ selector:
+ app: fl-scheduler
+
+---
+
+apiVersion: v1
+kind: Service
+metadata:
+ name: fl-server
+spec:
+ type: LoadBalancer
+ ports:
+ - name: fl-server
+ port: 8181
+ targetPort: 8181
+ selector:
+ app: fl-server
+
+---
+
+apiVersion: v1
+kind: Service
+metadata:
+ name: trainer0
+spec:
+ type: LoadBalancer
+ ports:
+ - name: trainer0
+ port: 9000
+ targetPort: 9000
+ selector:
+ app: trainer0
+
+---
+
+apiVersion: v1
+kind: Service
+metadata:
+ name: trainer1
+spec:
+ type: LoadBalancer
+ ports:
+ - name: trainer1
+ port: 9001
+ targetPort: 9001
+ selector:
+ app: trainer1
+
+---
+
+apiVersion: apps/v1beta1
+kind: Deployment
+metadata:
+ name: fl-master
+ labels:
+ app: fl-master
+spec:
+ replicas: 1
+ template:
+ metadata:
+ name: fl-master
+ labels:
+ app: fl-master
+ spec:
+ containers:
+ - name: fl-master
+ image: hub.baidubce.com/paddlefl/paddlefl:v3
+ imagePullPolicy: Always
+ ports:
+ - containerPort: 8000
+ workingDir: /root/k8s_deployment/master
+ command: ['/bin/bash']
+ args: ['run_master.sh']
+
+---
+
+apiVersion: apps/v1beta1
+kind: Deployment
+metadata:
+ name: fl-scheduler
+ labels:
+ app: fl-scheduler
+spec:
+ replicas: 1
+ template:
+ metadata:
+ name: fl-scheduler
+ labels:
+ app: fl-scheduler
+ spec:
+ containers:
+ - name: fl-scheduler
+ image: hub.baidubce.com/paddlefl/paddlefl:v3
+ imagePullPolicy: Always
+ ports:
+ - containerPort: 9091
+ workingDir: /root/k8s_deployment/scheduler
+ command: ['/bin/bash']
+ args: ['run_scheduler.sh']
+
+---
+
+apiVersion: apps/v1beta1
+kind: Deployment
+metadata:
+ name: fl-server
+ labels:
+ app: fl-server
+spec:
+ replicas: 1
+ template:
+ metadata:
+ name: fl-server
+ labels:
+ app: fl-server
+ spec:
+ containers:
+ - name: fl-server
+ image: hub.baidubce.com/paddlefl/paddlefl:v3
+ imagePullPolicy: Always
+ ports:
+ - containerPort: 8181
+ workingDir: /root/k8s_deployment/server
+ command: ['/bin/bash']
+ args: ['run_server.sh']
+ env:
+ - name: POD_IP
+ valueFrom:
+ fieldRef:
+ apiVersion: v1
+ fieldPath: status.podIP
+---
+
+apiVersion: apps/v1beta1
+kind: Deployment
+metadata:
+ name: trainer0
+ labels:
+ app: trainer0
+spec:
+ replicas: 1
+ template:
+ metadata:
+ name: trainer0
+ labels:
+ app: trainer0
+ spec:
+ containers:
+ - name: trainer0
+ image: hub.baidubce.com/paddlefl/paddlefl:v3
+ imagePullPolicy: Always
+ ports:
+ - containerPort: 9000
+ workingDir: /root/k8s_deployment/trainer0
+ command: ['/bin/bash']
+ args: ['test_trainer.sh']
+
+---
+
+apiVersion: apps/v1beta1
+kind: Deployment
+metadata:
+ name: trainer1
+ labels:
+ app: trainer1
+spec:
+ replicas: 1
+ template:
+ metadata:
+ name: trainer1
+ labels:
+ app: trainer1
+ spec:
+ containers:
+ - name: trainer1
+ image: hub.baidubce.com/paddlefl/paddlefl:v3
+ imagePullPolicy: Always
+ ports:
+ - containerPort: 9001
+ workingDir: /root/k8s_deployment/trainer1
+ command: ['/bin/bash']
+ args: ['test_trainer.sh']
+
+---
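
For reference, the Services declared above are what make host/port environment variables such as FL_SCHEDULER_SERVICE_HOST and FL_SERVER_SERVICE_PORT_FL_SERVER available inside the pods; the k8s scripts added below read those variables instead of the hard-coded 127.0.0.1 endpoints used in the other demos. A minimal sketch of that lookup, assuming it runs inside one of these pods (the helper name is hypothetical):

import os


def service_endpoint(service_name, port_name):
    # Kubernetes injects <SERVICE>_SERVICE_HOST and, for named ports,
    # <SERVICE>_SERVICE_PORT_<PORT_NAME>, with dashes mapped to underscores.
    prefix = service_name.replace("-", "_").upper()
    host = os.environ["%s_SERVICE_HOST" % prefix]
    port = os.environ["%s_SERVICE_PORT_%s" %
                      (prefix, port_name.replace("-", "_").upper())]
    return host + ":" + port


# e.g. the scheduler endpoint consumed by fl_server.py and the trainers:
# scheduler_ep = service_endpoint("fl-scheduler", "fl-scheduler")
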
diff --git a/paddle_fl/examples/k8s_deployment/master/fl_master.py b/paddle_fl/examples/k8s_deployment/master/fl_master.py
new file mode 100644
index 0000000000000000000000000000000000000000..96aab47ae7014d3753d872e67fca1a0b593629de
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/master/fl_master.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import paddle.fluid as fluid
+import os
+import paddle_fl as fl
+from paddle_fl.core.master.job_generator import JobGenerator
+from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="master")
+ parser.add_argument(
+ '--trainer_num',
+ type=int,
+ default=2,
+        help='number of trainers (default: 2)')
+
+ return parser.parse_args()
+
+
+class Model(object):
+ def __init__(self):
+ pass
+
+ def mlp(self, inputs, label, hidden_size=128):
+ self.concat = fluid.layers.concat(inputs, axis=1)
+ self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
+ self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
+ self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
+ self.sum_cost = fluid.layers.cross_entropy(
+ input=self.predict, label=label)
+ self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
+ self.loss = fluid.layers.reduce_mean(self.sum_cost)
+ self.startup_program = fluid.default_startup_program()
+
+inputs = [fluid.layers.data( \
+ name=str(slot_id), shape=[5],
+ dtype="float32")
+ for slot_id in range(3)]
+label = fluid.layers.data( \
+ name="label",
+ shape=[1],
+ dtype='int64')
+
+model = Model()
+model.mlp(inputs, label)
+
+job_generator = JobGenerator()
+optimizer = fluid.optimizer.SGD(learning_rate=0.1)
+job_generator.set_optimizer(optimizer)
+job_generator.set_losses([model.loss])
+job_generator.set_startup_program(model.startup_program)
+job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
+ [model.predict.name])
+
+build_strategy = FLStrategyFactory()
+build_strategy.fed_avg = True
+build_strategy.inner_step = 10
+strategy = build_strategy.create_fl_strategy()
+
+# endpoints will be collected through the cluster
+# in this example, we suppose endpoints have been collected
+server_service_ip = os.environ['FL_SERVER_SERVICE_HOST'] + ":" + os.environ[
+ 'FL_SERVER_SERVICE_PORT_FL_SERVER']
+service_endpoints = [server_service_ip]
+pod_endpoints = ["0.0.0.0:8181"]
+output = "fl_job_config"
+args = parse_args()
+num_trainer = args.trainer_num
+#job_generator.generate_fl_job(
+# strategy, server_endpoints=endpoints, worker_num=num_trainer, output=output)
+# fl_job_config will be dispatched to workers
+
+job_generator.generate_fl_job_for_k8s(
+ strategy,
+ server_pod_endpoints=pod_endpoints,
+ server_service_endpoints=service_endpoints,
+ worker_num=2,
+ output=output)
diff --git a/paddle_fl/examples/k8s_deployment/master/run_master.sh b/paddle_fl/examples/k8s_deployment/master/run_master.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b07187eb3078e3ce26e8f3234a4933b499557f4e
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/master/run_master.sh
@@ -0,0 +1,3 @@
+python fl_master.py --trainer_num 2
+tar -zcvf fl_job_config.tar.gz fl_job_config
+python -m SimpleHTTPServer 8000
diff --git a/paddle_fl/examples/k8s_deployment/scheduler/fl_scheduler.py b/paddle_fl/examples/k8s_deployment/scheduler/fl_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..232825b3c3d463f43750df0cfa698119e993118b
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/scheduler/fl_scheduler.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from paddle_fl.core.scheduler.agent_master import FLScheduler
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="scheduler")
+ parser.add_argument(
+ '--trainer_num',
+ type=int,
+ default=2,
+        help='number of trainers (default: 2)')
+
+ return parser.parse_args()
+
+
+args = parse_args()
+num_trainer = args.trainer_num
+worker_num = num_trainer
+server_num = 1
+# Define the number of worker/server and the port for scheduler
+scheduler = FLScheduler(worker_num, server_num, port=9091)
+scheduler.set_sample_worker_num(worker_num)
+scheduler.init_env()
+print("init env done.")
+scheduler.start_fl_training()
diff --git a/paddle_fl/examples/k8s_deployment/scheduler/run_scheduler.sh b/paddle_fl/examples/k8s_deployment/scheduler/run_scheduler.sh
new file mode 100644
index 0000000000000000000000000000000000000000..13494e2c365b64a76f488e4b6a595ba520803171
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/scheduler/run_scheduler.sh
@@ -0,0 +1 @@
+python fl_scheduler.py --trainer_num 2
diff --git a/paddle_fl/examples/k8s_deployment/server/fl_server.py b/paddle_fl/examples/k8s_deployment/server/fl_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..197b2f9d4f580a96b979aac8de02a70a8f65d7ce
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/server/fl_server.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle_fl as fl
+import os
+import paddle.fluid as fluid
+from paddle_fl.core.server.fl_server import FLServer
+from paddle_fl.core.master.fl_job import FLRunTimeJob
+import time
+server = FLServer()
+server_id = 0
+job_path = "fl_job_config"
+job = FLRunTimeJob()
+job.load_server_job(job_path, server_id)
+job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[
+ 'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER'] # IP address for scheduler
+#job._endpoints = os.environ['POD_IP'] + ":" + os.environ['FL_SERVER_SERVICE_PORT_FL_SERVER'] # IP address for server
+server.set_server_job(job)
+server._current_ep = os.environ['FL_SERVER_SERVICE_HOST'] + ":" + os.environ[
+ 'FL_SERVER_SERVICE_PORT_FL_SERVER'] # IP address for server
+print(job._scheduler_ep, server._current_ep)
+server.start()
+print("connect")
diff --git a/paddle_fl/examples/k8s_deployment/server/run_server.sh b/paddle_fl/examples/k8s_deployment/server/run_server.sh
new file mode 100644
index 0000000000000000000000000000000000000000..203d96c07c197b28efa282efd224443dd7bd3b7b
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/server/run_server.sh
@@ -0,0 +1,9 @@
+export GLOG_v=3
+wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
+while [ $? -ne 0 ]
+do
+ sleep 3
+ wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
+done
+tar -xf fl_job_config.tar.gz
+python -u fl_server.py > server.log 2>&1
diff --git a/paddle_fl/examples/k8s_deployment/trainer0/fl_trainer.py b/paddle_fl/examples/k8s_deployment/trainer0/fl_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cbce925ef6c401731aab8d5bd5dcc36610347e5
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/trainer0/fl_trainer.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
+from paddle_fl.core.master.fl_job import FLRunTimeJob
+import numpy as np
+import sys
+import os
+import logging
+import time
+logging.basicConfig(
+ filename="test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+    datefmt="%d-%m-%Y %H:%M:%S",
+ level=logging.DEBUG)
+
+
+def reader():
+    for _ in range(1000):
+        data_dict = {}
+        for j in range(3):
+            data_dict[str(j)] = np.random.rand(1, 5).astype('float32')
+        data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
+        yield data_dict
+
+
+trainer_id = int(sys.argv[1]) # trainer id for each guest
+job_path = "fl_job_config"
+job = FLRunTimeJob()
+job.load_trainer_job(job_path, trainer_id)
+#job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
+job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[
+ 'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER']
+trainer = FLTrainerFactory().create_fl_trainer(job)
+#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = os.environ['TRAINER0_SERVICE_HOST'] + ":" + os.environ[
+ 'TRAINER0_SERVICE_PORT_TRAINER0']
+trainer.start()
+print(trainer._scheduler_ep, trainer._current_ep)
+output_folder = "fl_model"
+epoch_id = 0
+while not trainer.stop():
+    print("epoch %d starts training" % epoch_id)
+ train_step = 0
+ for data in reader():
+ trainer.run(feed=data, fetch=[])
+ train_step += 1
+ if train_step == trainer._step:
+ break
+ epoch_id += 1
+ if epoch_id % 5 == 0:
+ trainer.save_inference_program(output_folder)
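
The reader above feeds purely synthetic data; what actually matters is that the keys of each feed dict ("0", "1", "2", "label") and the array shapes line up with the feed variables that fl_master baked into fl_job_config. A standalone sketch (no PaddleFL involved) that prints the expected layout:

    import numpy as np


    def sample_feed():
        # mirrors reader(): three float32 slots of shape (1, 5) plus an int64 label
        feed = {str(j): np.random.rand(1, 5).astype('float32') for j in range(3)}
        feed["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
        return feed


    print({name: (arr.shape, str(arr.dtype)) for name, arr in sample_feed().items()})
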
diff --git a/paddle_fl/examples/k8s_deployment/trainer0/test_trainer.sh b/paddle_fl/examples/k8s_deployment/trainer0/test_trainer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cfcedcf8342b744d874bc3ad5df2248a4d25b569
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/trainer0/test_trainer.sh
@@ -0,0 +1,9 @@
+wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
+while [ $? -ne 0 ]
+do
+ sleep 3
+ wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
+done
+tar -xf fl_job_config.tar.gz
+sleep 10
+python -u fl_trainer.py 0
diff --git a/paddle_fl/examples/k8s_deployment/trainer1/fl_trainer.py b/paddle_fl/examples/k8s_deployment/trainer1/fl_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed5dd099cce7e85b1ddb47857551b84d67c78c6
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/trainer1/fl_trainer.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
+from paddle_fl.core.master.fl_job import FLRunTimeJob
+import numpy as np
+import sys
+import os
+import logging
+import time
+logging.basicConfig(
+ filename="test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+    datefmt="%d-%m-%Y %H:%M:%S",
+ level=logging.DEBUG)
+
+
+def reader():
+    for _ in range(1000):
+        data_dict = {}
+        for j in range(3):
+            data_dict[str(j)] = np.random.rand(1, 5).astype('float32')
+        data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
+        yield data_dict
+
+
+trainer_id = int(sys.argv[1]) # trainer id for each guest
+job_path = "fl_job_config"
+job = FLRunTimeJob()
+job.load_trainer_job(job_path, trainer_id)
+#job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
+job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[
+ 'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER']
+trainer = FLTrainerFactory().create_fl_trainer(job)
+#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = os.environ['TRAINER1_SERVICE_HOST'] + ":" + os.environ[
+ 'TRAINER1_SERVICE_PORT_TRAINER1']
+trainer.start()
+print(trainer._scheduler_ep, trainer._current_ep)
+output_folder = "fl_model"
+epoch_id = 0
+while not trainer.stop():
+    print("epoch %d starts training" % epoch_id)
+ train_step = 0
+ for data in reader():
+ trainer.run(feed=data, fetch=[])
+ train_step += 1
+ if train_step == trainer._step:
+ break
+ epoch_id += 1
+ if epoch_id % 5 == 0:
+ trainer.save_inference_program(output_folder)
diff --git a/paddle_fl/examples/k8s_deployment/trainer1/test_trainer.sh b/paddle_fl/examples/k8s_deployment/trainer1/test_trainer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82dd83ac826e9a00d8b7218f869463ddb677f851
--- /dev/null
+++ b/paddle_fl/examples/k8s_deployment/trainer1/test_trainer.sh
@@ -0,0 +1,9 @@
+wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
+while [ $? -ne 0 ]
+do
+ sleep 3
+ wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz
+done
+tar -xf fl_job_config.tar.gz
+sleep 10
+python -u fl_trainer.py 1
diff --git a/paddle_fl/examples/secagg_demo/fl_master.py b/paddle_fl/examples/secagg_demo/fl_master.py
index e0e19e6fc77392a0d628d287bf53295977ab7e81..c6245d5e08ec36d7c41cd8c97c1485ddead68082 100644
--- a/paddle_fl/examples/secagg_demo/fl_master.py
+++ b/paddle_fl/examples/secagg_demo/fl_master.py
@@ -1,8 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
import paddle_fl as fl
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
+
class Model(object):
def __init__(self):
pass
@@ -14,12 +29,17 @@ class Model(object):
param_attrs = fluid.ParamAttr(
name="fc_0.w_0",
initializer=fluid.initializer.ConstantInitializer(0.0))
- self.predict = fluid.layers.fc(input=inputs, size=10, act='softmax', param_attr=param_attrs)
- self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
+ self.predict = fluid.layers.fc(input=inputs,
+ size=10,
+ act='softmax',
+ param_attr=param_attrs)
+ self.sum_cost = fluid.layers.cross_entropy(
+ input=self.predict, label=label)
self.loss = fluid.layers.mean(self.sum_cost)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.startup_program = fluid.default_startup_program()
+
inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
@@ -31,15 +51,16 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.01)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
-job_generator.set_infer_feed_and_target_names(
- [inputs.name, label.name], [model.loss.name])
+job_generator.set_infer_feed_and_target_names([inputs.name, label.name],
+ [model.loss.name])
build_strategy = FLStrategyFactory()
#build_strategy.fed_avg = True
build_strategy.sec_agg = True
param_name_list = []
-param_name_list.append("fc_0.w_0.opti.trainer_") # need trainer_id when running
+param_name_list.append(
+ "fc_0.w_0.opti.trainer_") # need trainer_id when running
param_name_list.append("fc_0.b_0.opti.trainer_")
build_strategy.param_name_list = param_name_list
diff --git a/paddle_fl/examples/secagg_demo/fl_scheduler.py b/paddle_fl/examples/secagg_demo/fl_scheduler.py
index 649f74768ef76cf1f8af52319b3818144f37e98a..1d8c21b8307ce8b1d20f2ecea1557bc2fb56e0c1 100644
--- a/paddle_fl/examples/secagg_demo/fl_scheduler.py
+++ b/paddle_fl/examples/secagg_demo/fl_scheduler.py
@@ -1,9 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
-scheduler = FLScheduler(worker_num,server_num,port=9091)
+scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
diff --git a/paddle_fl/examples/secagg_demo/fl_server.py b/paddle_fl/examples/secagg_demo/fl_server.py
index 529df8da4079fbbd217c58a857f7ab8a3c307586..2bc79fff528bc8e52e353cc56ca17618d6f4acca 100644
--- a/paddle_fl/examples/secagg_demo/fl_server.py
+++ b/paddle_fl/examples/secagg_demo/fl_server.py
@@ -21,8 +21,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
-job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
+job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
-server._current_ep = "127.0.0.1:8181" # IP address for server
+server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
print("connect")
diff --git a/paddle_fl/examples/secagg_demo/fl_trainer.py b/paddle_fl/examples/secagg_demo/fl_trainer.py
index 7b08e8724a67219654c7f94809fc281353202ad2..99d0a2f465b4df6bd4fcd6c47fbea4cc19f9f901 100644
--- a/paddle_fl/examples/secagg_demo/fl_trainer.py
+++ b/paddle_fl/examples/secagg_demo/fl_trainer.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
@@ -11,27 +25,32 @@ import math
import hashlib
import hmac
-logging.basicConfig(filename="log/test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
+logging.basicConfig(
+ filename="log/test.log",
+ filemode="w",
+ format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
+    datefmt="%d-%m-%Y %H:%M:%S",
+ level=logging.DEBUG)
logger = logging.getLogger("FLTrainer")
BATCH_SIZE = 64
train_reader = paddle.batch(
- paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
+ paddle.reader.shuffle(
+ paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
-test_reader = paddle.batch(
- paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
+test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
trainer_num = 2
-trainer_id = int(sys.argv[1]) # trainer id for each guest
+trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
-job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
+job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
-trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
+trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
@@ -47,8 +66,8 @@ feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace())
# for test
test_program = trainer._main_program.clone(for_test=True)
-def train_test(train_test_program,
- train_test_feed, train_test_reader):
+
+def train_test(train_test_program, train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
@@ -61,6 +80,8 @@ def train_test(train_test_program,
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
+
+
# for test
while not trainer.stop():
@@ -71,15 +92,18 @@ while not trainer.stop():
step_i += 1
trainer.step_id = step_i
accuracy, = trainer.run(feed=feeder.feed(data),
- fetch=["accuracy_0.tmp_0"])
+ fetch=["accuracy_0.tmp_0"])
if step_i % 100 == 0:
- print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0]))
+ print("Epoch: {0}, step: {1}, accuracy: {2}".format(
+ epoch_id, step_i, accuracy[0]))
print(step_i)
- avg_loss_val, acc_val = train_test(train_test_program=test_program,
- train_test_reader=test_reader,
- train_test_feed=feeder)
- print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
+ avg_loss_val, acc_val = train_test(
+ train_test_program=test_program,
+ train_test_reader=test_reader,
+ train_test_feed=feeder)
+ print("Test with Epoch %d, avg_cost: %s, acc: %s" %
+ (epoch_id, avg_loss_val, acc_val))
if epoch_id > 40:
break
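
The secure-aggregation trainer writes its log to log/test.log and reads key material from ./keys/ (trainer.key_dir), so both directories have to exist in the working directory before launch. A small, hedged setup snippet (plain Python, nothing PaddleFL-specific):

    import os

    # create the directories the secagg trainer expects; existing ones are left alone
    for directory in ("log", "keys"):
        if not os.path.isdir(directory):
            os.makedirs(directory)
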
diff --git a/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt b/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt
index cfb5684de948c423374fd2a820ccb29c654c8ac2..d33390b5dfff21882d7de124710389ede6b5b8ef 100644
--- a/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt
+++ b/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt
@@ -1 +1 @@
-2438748580808349511143047663636683775879288034436941526695550498623461587527621346172907651006831789701999970929529915459467532662545948308044143788306668377086821294492459623439935894424167712515718436351900091777957477710004777078638317806960364609629258387413979203403741893205419691425902518810085451041187685334971769087054033027561974230347468587825700834108657561999305482311897914109364221430821533207693682979777541616125499682380618775029176238891407643926372043660610226672413497764635239787000143341827693941253721638947580506197728500367325524850325027980531066702962726949006217630290236644410746181942256812170056772600756232506116738493114591218127323741133163913140583529684827066023347088796194846253682954154504336640429027403657831470993825621749318372332546269811820953216261135662418531598954663771775691648448615131802158937156803324423733802071166119966224716088242291968098450309032800335049617861465
\ No newline at end of file
+2438748580808349511143047663636683775879288034436941526695550498623461587527621346172907651006831789701999970929529915459467532662545948308044143788306668377086821294492459623439935894424167712515718436351900091777957477710004777078638317806960364609629258387413979203403741893205419691425902518810085451041187685334971769087054033027561974230347468587825700834108657561999305482311897914109364221430821533207693682979777541616125499682380618775029176238891407643926372043660610226672413497764635239787000143341827693941253721638947580506197728500367325524850325027980531066702962726949006217630290236644410746181942256812170056772600756232506116738493114591218127323741133163913140583529684827066023347088796194846253682954154504336640429027403657831470993825621749318372332546269811820953216261135662418531598954663771775691648448615131802158937156803324423733802071166119966224716088242291968098450309032800335049617861465
diff --git a/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt b/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt
index a88733c160ba5052b4d0b445a477357cbb88417a..32ecdc45c94580031a6f62573995fdef617f898f 100644
--- a/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt
+++ b/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt
@@ -1 +1 @@
-2514645349791449916465355128335954929464444612258498884322250411584328344530925790221013632799576102047787468232441470392901627580493383471719532612816534848391408601920539266665550346246343040368608757429591392784807798812848893304745441721044204602414415725300562075953290154457726382683020925406187524758708694161689285261293920782115270550960717687322240202298426363733065475031448888618026711435993053991653780694485897784243346859087028197560255857091562150381619708471192080403868115055265681423866891707972356449212236920210992128063075734725292359699578940877165584175851667803049667020005978253573912096358469050409541237248593428640028573437783747958382446666712845099931003578559688134007238753677011257181086592677636834099341020870502521085827878680362572013623469761170961943916356175726242515624843837354222489536469472365937707615959657846428001149463801728084949088483783942894784080451399167982957006474558
\ No newline at end of file
+2514645349791449916465355128335954929464444612258498884322250411584328344530925790221013632799576102047787468232441470392901627580493383471719532612816534848391408601920539266665550346246343040368608757429591392784807798812848893304745441721044204602414415725300562075953290154457726382683020925406187524758708694161689285261293920782115270550960717687322240202298426363733065475031448888618026711435993053991653780694485897784243346859087028197560255857091562150381619708471192080403868115055265681423866891707972356449212236920210992128063075734725292359699578940877165584175851667803049667020005978253573912096358469050409541237248593428640028573437783747958382446666712845099931003578559688134007238753677011257181086592677636834099341020870502521085827878680362572013623469761170961943916356175726242515624843837354222489536469472365937707615959657846428001149463801728084949088483783942894784080451399167982957006474558
diff --git a/paddle_fl/examples/submitter_demo/conf.txt b/paddle_fl/examples/submitter_demo/conf.txt
index 0880739e809063bc21531a92bb205ee2df4f169a..f2f0d48d028de7afb9fb8ecd75961344bdbaf002 100644
--- a/paddle_fl/examples/submitter_demo/conf.txt
+++ b/paddle_fl/examples/submitter_demo/conf.txt
@@ -21,4 +21,3 @@ server=yq01-hpc-lvliang01-smart-master.dmop.baidu.com
python_tar=./python.tar.gz
wheel=./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64.whl
-
diff --git a/paddle_fl/examples/submitter_demo/model.py b/paddle_fl/examples/submitter_demo/model.py
index f07549b9048c7b09a7fe25f9961db64fc8a820bb..2ead34a71835492df47e05f1d369e271433882f7 100644
--- a/paddle_fl/examples/submitter_demo/model.py
+++ b/paddle_fl/examples/submitter_demo/model.py
@@ -1,5 +1,20 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
+
class Model(object):
def __init__(self):
pass
@@ -9,8 +24,8 @@ class Model(object):
self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
- self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
+ self.sum_cost = fluid.layers.cross_entropy(
+ input=self.predict, label=label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.loss = fluid.layers.reduce_mean(self.sum_cost)
self.startup_program = fluid.default_startup_program()
-
diff --git a/paddle_fl/examples/submitter_demo/scheduler_client.py b/paddle_fl/examples/submitter_demo/scheduler_client.py
index eff68df845129feaca093b4277f1970dcfb911e9..2e903353585c2ec418614be1a373afcec72395aa 100644
--- a/paddle_fl/examples/submitter_demo/scheduler_client.py
+++ b/paddle_fl/examples/submitter_demo/scheduler_client.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import socket
import random
@@ -49,6 +63,7 @@ default_dict = {
"wheel": "./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64-0.whl"
}
+
def load_conf(conf_file, local_dict):
with open(conf_file) as fin:
for line in fin:
@@ -58,6 +73,7 @@ def load_conf(conf_file, local_dict):
local_dict[group[0]] = group[1]
return local_dict
+
client = HPCClient()
default_dict = load_conf(sys.argv[1], default_dict)
@@ -94,9 +110,11 @@ all_ips_ready = False
ip_list = []
-scheduler = FLScheduler(int(default_dict["worker_nodes"]),
- int(default_dict["server_nodes"]),
- port=random_port, socket=zmq_socket)
+scheduler = FLScheduler(
+ int(default_dict["worker_nodes"]),
+ int(default_dict["server_nodes"]),
+ port=random_port,
+ socket=zmq_socket)
scheduler.set_sample_worker_num(int(default_dict["worker_nodes"]))
@@ -121,12 +139,14 @@ print(ip_list)
#allocate the role of each endpoint and their ids
ip_role = {}
for i in range(len(ip_list)):
- if i < int(default_dict["server_nodes"]):
- ip_role[ip_list[i]] = 'server%d' % i
- else:
- ip_role[ip_list[i]] = 'trainer%d' % (i-int(default_dict["server_nodes"]))
+ if i < int(default_dict["server_nodes"]):
+ ip_role[ip_list[i]] = 'server%d' % i
+ else:
+ ip_role[ip_list[i]] = 'trainer%d' % (
+ i - int(default_dict["server_nodes"]))
print(ip_role)
+
def job_generate():
#generate a fl job which is the same as fl_master
inputs = [fluid.layers.data( \
@@ -146,8 +166,8 @@ def job_generate():
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
- job_generator.set_infer_feed_and_target_names(
- [x.name for x in inputs], [model.predict.name])
+ job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
+ [model.predict.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
@@ -157,20 +177,24 @@ def job_generate():
# endpoints will be collected through the cluster
# in this example, we suppose endpoints have been collected
server_ip = ["{}".format(ip_list[0])]
-
+
output = "job_config"
job_generator.generate_fl_job(
- strategy, server_endpoints=server_ip, worker_num=int(default_dict["worker_nodes"]), output=output)
-
+ strategy,
+ server_endpoints=server_ip,
+ worker_num=int(default_dict["worker_nodes"]),
+ output=output)
+
file_list = os.listdir(output)
for file in file_list:
- tar = tarfile.open('{}/{}.tar.gz'.format(output,file),'w:gz')
- for root,dir,files in os.walk("{}/{}".format(output,file)):
- for f in files:
- fullpath = os.path.join(root,f)
- tar.add(fullpath)
+ tar = tarfile.open('{}/{}.tar.gz'.format(output, file), 'w:gz')
+ for root, dir, files in os.walk("{}/{}".format(output, file)):
+ for f in files:
+ fullpath = os.path.join(root, f)
+ tar.add(fullpath)
tar.close()
+
job_generate()
#send the allocated rolls to the remote endpoints
diff --git a/paddle_fl/examples/submitter_demo/train_program.py b/paddle_fl/examples/submitter_demo/train_program.py
index a5f10fab0eb2fe8aae506f71f4961c5a5ea5cdfa..bba1482df0ed198a348e67e51ba79e60a9bc3a4b 100644
--- a/paddle_fl/examples/submitter_demo/train_program.py
+++ b/paddle_fl/examples/submitter_demo/train_program.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import socket
import random
import zmq
@@ -13,7 +27,6 @@ import sys
import logging
import time
-
random_port = 60001
scheduler_conf = {}
@@ -31,8 +44,7 @@ download_url = "{}:8080".format(scheduler_ip[0])
print(download_url)
context = zmq.Context()
zmq_socket = context.socket(zmq.REQ)
-zmq_socket.connect(
- "tcp://{}".format(scheduler_conf["ENDPOINT"]))
+zmq_socket.connect("tcp://{}".format(scheduler_conf["ENDPOINT"]))
zmq_socket.send("ENDPOINT\t{}".format(endpoint))
message = zmq_socket.recv()
print(message)
@@ -47,7 +59,7 @@ while True:
if group[0] == "WAIT":
continue
else:
- os.system("wget {}/job_config/{}.tar.gz".format(download_url,message))
+ os.system("wget {}/job_config/{}.tar.gz".format(download_url, message))
print(message)
break
@@ -71,6 +83,7 @@ if 'server' in message:
server._current_ep = endpoint
server.start()
else:
+
def reader():
for i in range(1000):
data_dict = {}
@@ -96,7 +109,7 @@ else:
for data in reader():
trainer.run(feed=data, fetch=[])
step_i += 1
- if step_i == trainer._step:
+ if step_i == trainer._step:
break
epoch_id += 1
if epoch_id % 5 == 0:
diff --git a/paddle_fl/reader/gru4rec_reader.py b/paddle_fl/reader/gru4rec_reader.py
index dadb73136bccde03377c5c1136e30564e00dcb81..6e95909da989e73dc4fcc9c539967cca8b8f2024 100644
--- a/paddle_fl/reader/gru4rec_reader.py
+++ b/paddle_fl/reader/gru4rec_reader.py
@@ -1,7 +1,22 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle.fluid as fluid
import numpy as np
import os
+
class Gru4rec_Reader:
def __init__(self):
pass
@@ -21,7 +36,6 @@ class Gru4rec_Reader:
res.set_lod([lod])
return res
-
def lod_reader(self, reader, place):
def feed_reader():
for data in reader():
@@ -33,12 +47,14 @@ class Gru4rec_Reader:
fe_data["src_wordseq"] = lod_src_wordseq
fe_data["dst_wordseq"] = lod_dst_wordseq
yield fe_data
+
return feed_reader
def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False):
"""
Create a batched reader.
"""
+
def batch_reader():
r = reader()
b = []
@@ -66,11 +82,11 @@ class Gru4rec_Reader:
# Batch size check
batch_size = int(batch_size)
if batch_size <= 0:
- raise ValueError("batch_size should be a positive integeral value, "
- "but got batch_size={}".format(batch_size))
+ raise ValueError(
+                "batch_size should be a positive integer value, "
+ "but got batch_size={}".format(batch_size))
return batch_reader
-
def reader_creator(self, file_dir):
def reader():
files = os.listdir(file_dir)
@@ -82,10 +98,12 @@ class Gru4rec_Reader:
src_seq = l[:len(l) - 1]
trg_seq = l[1:]
yield src_seq, trg_seq
+
return reader
def reader(self, file_dir, place, batch_size=5):
""" prepare the English Pann Treebank (PTB) data """
print("start constuct word dict")
- reader = self.sort_batch(self.reader_creator(file_dir), batch_size, batch_size * 20)
+ reader = self.sort_batch(
+ self.reader_creator(file_dir), batch_size, batch_size * 20)
return self.lod_reader(reader, place)
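
For reference, a hedged usage sketch of the reformatted Gru4rec_Reader: reader(file_dir, place, batch_size) returns a generator function whose batches are feed dicts keyed "src_wordseq" and "dst_wordseq". The directory name is a placeholder, and the files are assumed to hold one whitespace-separated word-id sequence per line:

    import paddle.fluid as fluid
    from paddle_fl.reader.gru4rec_reader import Gru4rec_Reader

    place = fluid.CPUPlace()
    # "mid_data" is an illustrative path to the preprocessed PTB-style data
    feed_gen = Gru4rec_Reader().reader("mid_data", place, batch_size=5)

    for feed in feed_gen():
        print(sorted(feed.keys()))  # ['dst_wordseq', 'src_wordseq']
        break
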
diff --git a/paddle_fl/version.py b/paddle_fl/version.py
index 2f9278ce368bcbfc0d65b147d84fcd2938414638..2c95b454cebb44b8068493b4bc3002ee09f7df27 100644
--- a/paddle_fl/version.py
+++ b/paddle_fl/version.py
@@ -12,6 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaddleFL version string """
-fl_version = "0.1.10"
-module_proto_version = "0.1.10"
-
+fl_version = "0.1.11"
+module_proto_version = "0.1.11"
diff --git a/setup.py b/setup.py
index ba29a1496d39949b84f3b74eeec935d831e57c01..c12f11edd4c8ecaa95034ca3266e9ad363bf30f4 100644
--- a/setup.py
+++ b/setup.py
@@ -26,10 +26,12 @@ from paddle_fl.version import fl_version
def python_version():
return [int(v) for v in platform.python_version().split(".")]
+
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
- 'six >= 1.10.0', 'protobuf >= 3.1.0','paddlepaddle >= 1.6', 'zmq', 'paddlepaddle-gpu >= 1.6'
+ 'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.6', 'zmq',
+ 'paddlepaddle-gpu >= 1.6'
]
if max_version < 3:
@@ -42,8 +44,7 @@ REQUIRED_PACKAGES += ["unittest2"]
setup(
name='paddle_fl',
version=fl_version.replace('-', ''),
- description=
- ('Federated Deep Learning Package Based on PaddlePaddle.'),
+ description=('Federated Deep Learning Package Based on PaddlePaddle.'),
long_description='',
url='https://github.com/PaddlePaddle/PaddleFL',
author='PaddlePaddle Author',
@@ -70,4 +71,5 @@ setup(
'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
- keywords=('paddle_fl paddlepaddle multi-task transfer distributed-training'))
+ keywords=(
+ 'paddle_fl paddlepaddle multi-task transfer distributed-training'))
diff --git a/tools/codestyle/clang_format.hook b/tools/codestyle/clang_format.hook
new file mode 100755
index 0000000000000000000000000000000000000000..1d928216867c0ba3897d71542fea44debf8d72a0
--- /dev/null
+++ b/tools/codestyle/clang_format.hook
@@ -0,0 +1,15 @@
+#!/bin/bash
+set -e
+
+readonly VERSION="3.8"
+
+version=$(clang-format -version)
+
+if ! [[ $version == *"$VERSION"* ]]; then
+ echo "clang-format version check failed."
+    echo "a version containing '$VERSION' is required, but got '$version'"
+    echo "install the required version and add a soft link to it somewhere on your '\$PATH'"
+ exit -1
+fi
+
+clang-format $@
diff --git a/tools/codestyle/copyright.hook b/tools/codestyle/copyright.hook
new file mode 100644
index 0000000000000000000000000000000000000000..86b16ebdc46047c7cb3d7731a71cbf9647a1f2fe
--- /dev/null
+++ b/tools/codestyle/copyright.hook
@@ -0,0 +1,121 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import argparse
+import io, re
+import sys, os
+import subprocess
+import platform
+
+COPYRIGHT = '''
+Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
+
+LANG_COMMENT_MARK = None
+
+NEW_LINE_MARK = None
+
+COPYRIGHT_HEADER = None
+
+if platform.system() == "Windows":
+ NEW_LINE_MARK = "\r\n"
+else:
+ NEW_LINE_MARK = '\n'
+ COPYRIGHT_HEADER = COPYRIGHT.split(NEW_LINE_MARK)[1]
+ p = re.search('(\d{4})', COPYRIGHT_HEADER).group(0)
+ process = subprocess.Popen(["date", "+%Y"], stdout=subprocess.PIPE)
+ date, err = process.communicate()
+ date = date.decode("utf-8").rstrip("\n")
+ COPYRIGHT_HEADER = COPYRIGHT_HEADER.replace(p, date)
+
+
+def generate_copyright(template, lang='C'):
+ if lang == 'Python':
+ LANG_COMMENT_MARK = '#'
+ else:
+ LANG_COMMENT_MARK = "//"
+
+ lines = template.split(NEW_LINE_MARK)
+ BLANK = " "
+ ans = LANG_COMMENT_MARK + BLANK + COPYRIGHT_HEADER + NEW_LINE_MARK
+ for lino, line in enumerate(lines):
+ if lino == 0 or lino == 1 or lino == len(lines) - 1: continue
+ if len(line) == 0:
+ BLANK = ""
+ else:
+ BLANK = " "
+ ans += LANG_COMMENT_MARK + BLANK + line + NEW_LINE_MARK
+
+ return ans + "\n"
+
+
+def lang_type(filename):
+ if filename.endswith(".py"):
+ return "Python"
+ elif filename.endswith(".h"):
+ return "C"
+ elif filename.endswith(".c"):
+ return "C"
+ elif filename.endswith(".hpp"):
+ return "C"
+ elif filename.endswith(".cc"):
+ return "C"
+ elif filename.endswith(".cpp"):
+ return "C"
+ elif filename.endswith(".cu"):
+ return "C"
+ elif filename.endswith(".cuh"):
+ return "C"
+ elif filename.endswith(".go"):
+ return "C"
+ elif filename.endswith(".proto"):
+ return "C"
+ else:
+        print("Unsupported filetype %s" % filename)
+ exit(0)
+
+
+PYTHON_ENCODE = re.compile("^[ \t\v]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)")
+
+
+def main(argv=None):
+ parser = argparse.ArgumentParser(
+ description='Checker for copyright declaration.')
+ parser.add_argument('filenames', nargs='*', help='Filenames to check')
+ args = parser.parse_args(argv)
+
+ retv = 0
+ for filename in args.filenames:
+ fd = io.open(filename, encoding="utf-8")
+ first_line = fd.readline()
+ second_line = fd.readline()
+ if "COPYRIGHT (C)" in first_line.upper(): continue
+ if first_line.startswith("#!") or PYTHON_ENCODE.match(
+ second_line) != None or PYTHON_ENCODE.match(first_line) != None:
+ continue
+ original_contents = io.open(filename, encoding="utf-8").read()
+ new_contents = generate_copyright(
+ COPYRIGHT, lang_type(filename)) + original_contents
+ print('Auto Insert Copyright Header {}'.format(filename))
+ retv = 1
+ with io.open(filename, 'w') as output_file:
+ output_file.write(new_contents)
+
+ return retv
+
+
+if __name__ == '__main__':
+ exit(main())
diff --git a/tools/codestyle/cpplint_pre_commit.hook b/tools/codestyle/cpplint_pre_commit.hook
new file mode 100755
index 0000000000000000000000000000000000000000..658008d852123b6eab06d1f13d61ba896e7e9c98
--- /dev/null
+++ b/tools/codestyle/cpplint_pre_commit.hook
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+TOTAL_ERRORS=0
+if [[ ! $TRAVIS_BRANCH ]]; then
+ # install cpplint on local machine.
+ if [[ ! $(which cpplint) ]]; then
+ pip install cpplint
+ fi
+ # diff files on local machine.
+ files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
+else
+ # diff files between PR and latest commit on Travis CI.
+ branch_ref=$(git rev-parse "$TRAVIS_BRANCH")
+ head_ref=$(git rev-parse HEAD)
+ files=$(git diff --name-status $branch_ref $head_ref | awk '$1 != "D" {print $2}')
+fi
+# The trick to remove deleted files: https://stackoverflow.com/a/2413151
+for file in $files; do
+ if [[ $file =~ ^(patches/grpc/.*) ]]; then
+ continue;
+ else
+ cpplint --filter=-readability/fn_size $file;
+ TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
+ fi
+done
+
+exit $TOTAL_ERRORS
diff --git a/tools/codestyle/docstring_checker.py b/tools/codestyle/docstring_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d4b24a0cf6b743b72dca58fd885f927560964bf
--- /dev/null
+++ b/tools/codestyle/docstring_checker.py
@@ -0,0 +1,349 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DocstringChecker is used to check python doc string's style."""
+
+import six
+import astroid
+
+from pylint.checkers import BaseChecker, utils
+from pylint.interfaces import IAstroidChecker
+
+from collections import defaultdict
+import re
+
+
+def register(linter):
+ """Register checkers."""
+ linter.register_checker(DocstringChecker(linter))
+
+
+class Docstring(object):
+ """Docstring class holds the parsed doc string elements.
+ """
+
+ def __init__(self):
+ self.d = defaultdict(list) #name->[]
+ self.clear()
+
+ def clear(self):
+ self.d['Args'] = []
+ self.d['Examples'] = []
+ self.d['Returns'] = []
+ self.d['Raises'] = []
+ self.args = {} #arg_name->arg_type
+
+ def get_level(self, string, indent=' '):
+ level = 0
+ unit_size = len(indent)
+ while string[:unit_size] == indent:
+ string = string[unit_size:]
+ level += 1
+
+ return level
+
+ def parse(self, doc):
+        """parse extracts the known sections from a doc string,
+        such as Args, Returns, Raises and Examples.
+ Args:
+ doc (string): is the astroid node doc string.
+ Returns:
+ True if doc is parsed successfully.
+ """
+ self.clear()
+
+ lines = doc.splitlines()
+ state = ("others", -1)
+ for l in lines:
+ c = l.strip()
+ if len(c) <= 0:
+ continue
+
+ level = self.get_level(l)
+ if c.startswith("Args:"):
+ state = ("Args", level)
+ elif c.startswith("Returns:"):
+ state = ("Returns", level)
+ elif c.startswith("Raises:"):
+ state = ("Raises", level)
+ elif c.startswith("Examples:"):
+ state = ("Examples", level)
+ else:
+ if level > state[1]:
+ self.d[state[0]].append(c)
+ continue
+
+ state = ("others", -1)
+ self.d[state[0]].append(c)
+
+ self._arg_with_type()
+ return True
+
+ def get_returns(self):
+ return self.d['Returns']
+
+ def get_raises(self):
+ return self.d['Raises']
+
+ def get_examples(self):
+ return self.d['Examples']
+
+ def _arg_with_type(self):
+
+ for t in self.d['Args']:
+ m = re.search('([A-Za-z0-9_-]+)\s{0,4}(\(.+\))\s{0,4}:', t)
+ if m:
+ self.args[m.group(1)] = m.group(2)
+
+ return self.args
+
+
+class DocstringChecker(BaseChecker):
+    """DocstringChecker is a pylint checker that
+    checks docstring style.
+ """
+ __implements__ = (IAstroidChecker, )
+
+ POSITIONAL_MESSAGE_ID = 'str-used-on-positional-format-argument'
+ KEYWORD_MESSAGE_ID = 'str-used-on-keyword-format-argument'
+
+ name = 'doc-string-checker'
+ symbol = "doc-string"
+ priority = -1
+ msgs = {
+ 'W9001': ('One line doc string on > 1 lines', symbol + "-one-line",
+ 'Used when a short doc string is on multiple lines'),
+ 'W9002':
+ ('Doc string does not end with "." period', symbol + "-end-with",
+ 'Used when a doc string does not end with a period'),
+ 'W9003':
+ ('All args with their types must be mentioned in doc string %s',
+ symbol + "-with-all-args",
+ 'Used when not all arguments are in the doc string '),
+ 'W9005': ('Missing docstring or docstring is too short',
+ symbol + "-missing", 'Add docstring longer >=10'),
+ 'W9006': ('Docstring indent error, use 4 space for indent',
+ symbol + "-indent-error", 'Use 4 space for indent'),
+ 'W9007': ('You should add `Returns` in comments',
+ symbol + "-with-returns",
+ 'There should be a `Returns` section in comments'),
+ 'W9008': ('You should add `Raises` section in comments',
+ symbol + "-with-raises",
+ 'There should be a `Raises` section in comments'),
+ }
+ options = ()
+
+ def visit_functiondef(self, node):
+ """visit_functiondef checks Function node docstring style.
+ Args:
+ node (astroid.node): The visiting node.
+ Returns:
+            True if successful, otherwise False.
+ """
+
+ self.check_doc_string(node)
+
+ if node.tolineno - node.fromlineno <= 10:
+ return True
+
+ if not node.doc:
+ return True
+
+ doc = Docstring()
+ doc.parse(node.doc)
+
+ self.all_args_in_doc(node, doc)
+ self.with_returns(node, doc)
+ self.with_raises(node, doc)
+
+ def visit_module(self, node):
+ self.check_doc_string(node)
+
+ def visit_classdef(self, node):
+ self.check_doc_string(node)
+
+ def check_doc_string(self, node):
+ self.missing_doc_string(node)
+ self.one_line(node)
+ self.has_period(node)
+ self.indent_style(node)
+
+ def missing_doc_string(self, node):
+ if node.name.startswith("__") or node.name.startswith("_"):
+ return True
+ if node.tolineno - node.fromlineno <= 10:
+ return True
+
+ if node.doc is None or len(node.doc) < 10:
+ self.add_message('W9005', node=node, line=node.fromlineno)
+ return False
+
+ # FIXME(gongwb): give the docstring line-no
+ def indent_style(self, node, indent=4):
+ """indent_style checks docstring's indent style
+ Args:
+ node (astroid.node): The visiting node.
+ indent (int): The default indent of style
+ Returns:
+            True if successful, otherwise False.
+ """
+ if node.doc is None:
+ return True
+
+ doc = node.doc
+ lines = doc.splitlines()
+        # skip the first line (it sits on the same line as the opening quotes)
+        # and require every following line to be indented by a multiple of
+        # `indent` spaces
+        for line_num, l in enumerate(lines):
+            if line_num == 0:
+                continue
+            cur_indent = len(l) - len(l.lstrip())
+            if cur_indent % indent != 0:
+                self.add_message('W9006', node=node, line=node.fromlineno)
+                return False
+
+ return True
+
+ def one_line(self, node):
+ """one_line checks if docstring (len < 40) is on one line.
+ Args:
+ node (astroid.node): The node visiting.
+ Returns:
+ True if successful otherwise False.
+ """
+
+ doc = node.doc
+ if doc is None:
+ return True
+
+ if len(doc) > 40:
+ return True
+ elif sum(doc.find(nl) for nl in ('\n', '\r', '\n\r')) == -3:
+ return True
+ else:
+ self.add_message('W9001', node=node, line=node.fromlineno)
+ return False
+
+ return True
+
+ def has_period(self, node):
+        """has_period checks that a one-line docstring ends with '.'.
+ Args:
+ node (astroid.node): the node is visiting.
+ Returns:
+ True if successful otherwise False.
+ """
+ if node.doc is None:
+ return True
+
+ if len(node.doc.splitlines()) > 1:
+ return True
+
+ if not node.doc.strip().endswith('.'):
+ self.add_message('W9002', node=node, line=node.fromlineno)
+ return False
+
+ return True
+
+ def with_raises(self, node, doc):
+        """with_raises checks that any raised exception is documented in a Raises section.
+ Args:
+ node (astroid.node): the node is visiting.
+ doc (Docstring): Docstring object.
+ Returns:
+ True if successful otherwise False.
+ """
+
+ find = False
+ for t in node.body:
+ if not isinstance(t, astroid.Raise):
+ continue
+
+ find = True
+ break
+
+ if not find:
+ return True
+
+ if len(doc.get_raises()) == 0:
+ self.add_message('W9008', node=node, line=node.fromlineno)
+ return False
+
+ return True
+
+ def with_returns(self, node, doc):
+        """with_returns checks that the return value is documented in a Returns section.
+ Args:
+ node (astroid.node): the node is visiting.
+ doc (Docstring): Docstring object.
+ Returns:
+ True if successful otherwise False.
+ """
+
+ if node.name.startswith("__") or node.name.startswith("_"):
+ return True
+ find = False
+ for t in node.body:
+ if not isinstance(t, astroid.Return):
+ continue
+
+ find = True
+ break
+
+ if not find:
+ return True
+
+ if len(doc.get_returns()) == 0:
+ self.add_message('W9007', node=node, line=node.fromlineno)
+ return False
+
+ return True
+
+ def all_args_in_doc(self, node, doc):
+ """all_args_in_doc checks if arguments are mentioned in doc
+ Args:
+ node (astroid.node): the node is visiting.
+ doc (Docstring): Docstring object
+ Returns:
+ True if successful otherwise False.
+ """
+ if node.name.startswith("__") or node.name.startswith("_"):
+ return True
+ args = []
+ for arg in node.args.get_children():
+ if (not isinstance(arg, astroid.AssignName)) \
+ or arg.name == "self":
+ continue
+ args.append(arg.name)
+
+ if len(args) <= 0:
+ return True
+
+ parsed_args = doc.args
+ args_not_documented = set(args) - set(parsed_args)
+ if len(args) > 0 and len(parsed_args) <= 0:
+ self.add_message(
+ 'W9003',
+ node=node,
+ line=node.fromlineno,
+ args=list(args_not_documented))
+ return False
+
+ for t in args:
+ if t not in parsed_args:
+ self.add_message(
+ 'W9003', node=node, line=node.fromlineno, args=[t, ])
+ return False
+
+ return True
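
The Docstring helper above can be exercised on its own, which makes the W9003 logic easier to see: parse() buckets the Google-style sections and _arg_with_type() pulls out "name (type)" pairs that all_args_in_doc compares against the function signature. A sketch, assuming pylint/astroid are installed and tools/codestyle is on PYTHONPATH (as the pre-commit hook below arranges):

    from docstring_checker import Docstring

    doc = Docstring()
    doc.parse("""Adds two numbers.
    Args:
        a (int): first operand.
        b (int): second operand.
    Returns:
        int: the sum.
    """)
    print(doc.args)           # {'a': '(int)', 'b': '(int)'}
    print(doc.get_returns())  # ['int: the sum.']
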
diff --git a/tools/codestyle/pylint_pre_commit.hook b/tools/codestyle/pylint_pre_commit.hook
new file mode 100755
index 0000000000000000000000000000000000000000..150a3f5666bd39d30b7e6518e58a14fb5fe2f14b
--- /dev/null
+++ b/tools/codestyle/pylint_pre_commit.hook
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+TOTAL_ERRORS=0
+
+
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+export PYTHONPATH=$DIR:$PYTHONPATH
+
+# The trick to remove deleted files: https://stackoverflow.com/a/2413151
+for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do
+ pylint --disable=all --load-plugins=docstring_checker \
+ --enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file;
+ TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
+done
+
+exit $TOTAL_ERRORS
+#For now, just warning:
+#exit 0
+
diff --git a/tools/codestyle/test_docstring_checker.py b/tools/codestyle/test_docstring_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..0547f7d1610c64b0ca6efa9384e97d658c8276fe
--- /dev/null
+++ b/tools/codestyle/test_docstring_checker.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import docstring_checker
+import pylint.testutils
+import astroid
+import pytest
+import sys
+
+
+class TestDocstring(pylint.testutils.CheckerTestCase):
+ CHECKER_CLASS = docstring_checker.DocstringChecker
+
+ def test_one_line(self):
+ func_node = astroid.extract_node('''
+ def test():
+ """get
+ news.
+ """
+ if True:
+ return 5
+ return 5
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9001' == got[0][0]
+
+    def test_end_with(self):
+ func_node = astroid.extract_node('''
+ def test():
+ """get news"""
+ if True:
+ return 5
+ return 5
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9002' == got[0][0]
+
+ def test_args(self):
+ func_node = astroid.extract_node('''
+ def test(scale, mean):
+ """get news.
+ Args:
+ scale (int): scale is the number.
+ """
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9003' == got[0][0]
+
+ def test_missing(self):
+ func_node = astroid.extract_node('''
+ def test():
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9005' == got[0][0]
+
+ def test_indent(self):
+ func_node = astroid.extract_node('''
+ def test():
+ """ get get get get get get get get
+ get get get get get get get get.
+ """
+ pass
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9006' == got[0][0]
+
+    def test_with_returns(self):
+ func_node = astroid.extract_node('''
+ def test():
+ """get news.
+ Args:
+ scale (int): scale is the number.
+ """
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ return mean
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9007' == got[0][0]
+
+ def test_with_raises(self):
+ func_node = astroid.extract_node('''
+ def test():
+ """get news.
+ Args:
+ scale (int): scale is the number.
+ """
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ mean=scale
+ raise ValueError('A very specific bad thing happened.')
+ ''')
+
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 1
+ assert 'W9008' == got[0][0]
+
+ def test_no_message(self):
+ p = '''
+def fc(input,
+ size,
+ num_flatten_dims=1,
+ param_attr=None,
+ bias_attr=None,
+ act=None,
+ name=None):
+ """
+ **Fully Connected Layer**
+ The fully connected layer can take multiple tensors as its inputs. It
+ creates a variable called weights for each input tensor, which represents
+ a fully connected weight matrix from each input unit to each output unit.
+ The fully connected layer multiplies each input tensor with its coresponding
+ weight to produce an output Tensor. If multiple input tensors are given,
+ the results of multiple multiplications will be sumed up. If bias_attr is
+ not None, a bias variable will be created and added to the output. Finally,
+ if activation is not None, it will be applied to the output as well.
+ This process can be formulated as follows:
+
+ Args:
+ input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of
+ the input tensor(s) is at least 2.
+ size(int): The number of output units in this layer.
+ num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than
+ two dimensions. If this happens, the multidimensional tensor will first be flattened
+ into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
+ tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
+ dimensions will be flatten to form the first dimension of the final matrix (height of
+ the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
+ form the second dimension of the final matrix (width of the matrix). For example, suppose
+ `X` is a 6-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
+ Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30].
+ param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
+ parameters/weights of this layer.
+ bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
+ of this layer. If it is set to None, no bias will be added to the output units.
+ act (str, default None): Activation to be applied to the output of this layer.
+ name (str, default None): The name of this layer.
+ Returns:
+ A tensor variable storing the transformation result.
+ Raises:
+ ValueError: If rank of the input tensor is less than 2.
+ Examples:
+ .. code-block:: python
+ data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
+ fc = fluid.layers.fc(input=data, size=1000, act="tanh")
+ """
+ raise ValueError('A very specific bad thing happened.')
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ size = 1
+ return size
+ '''
+
+ func_node = astroid.extract_node(p)
+ self.checker.visit_functiondef(func_node)
+ got = self.linter.release_messages()
+ assert len(got) == 0
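
These checker tests are ordinary pytest cases built on pylint.testutils, so they can be run on their own. A hedged invocation sketch; the path assumes the repository root as the working directory:

    import sys

    import pytest

    # run only the docstring-checker tests; -q keeps the output short
    sys.exit(pytest.main(["tools/codestyle/test_docstring_checker.py", "-q"]))
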