diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8112837dc9627bc2e501940b8e97c89e97c45ff --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,52 @@ +repos: +- repo: https://github.com/Lucas-C/pre-commit-hooks.git + sha: v1.0.1 + hooks: + - id: remove-crlf + files: (?!.*third_party)^.*$ | (?!.*book)^.*$ +- repo: https://github.com/PaddlePaddle/mirrors-yapf.git + sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 + hooks: + - id: yapf + files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ +- repo: https://github.com/pre-commit/pre-commit-hooks + sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + files: (?!.*third_party)^.*$ | (?!.*book)^.*$ + - id: end-of-file-fixer +- repo: local + hooks: + - id: clang-format-with-version-check + name: clang-format + description: Format files with ClangFormat. + entry: bash ./tools/codestyle/clang_format.hook -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ +- repo: local + hooks: + - id: cpplint-cpp-source + name: cpplint + description: Check C++ code style using cpplint.py. + entry: bash ./tools/codestyle/cpplint_pre_commit.hook + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$ +- repo: local + hooks: + - id: pylint-doc-string + name: pylint + description: Check python docstring style using docstring_checker. + entry: bash ./tools/codestyle/pylint_pre_commit.hook + language: system + files: \.(py)$ +- repo: local + hooks: + - id: copyright_checker + name: copyright_checker + entry: python ./tools/codestyle/copyright.hook + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ + exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..9f25a7770bbebfc11617923bfd49175d229993df --- /dev/null +++ b/.travis.yml @@ -0,0 +1,11 @@ +language:python + +notifications: + email: + on_success: change + on_failure: always + +sudo: false + +os: + - linux diff --git a/AUTHORS.md b/AUTHORS.md index de2c1cba58d1e3079959cf40f40a4b89c4c19455..f22dad93df96e011e648e872ecd9d86f55b3af35 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -1,4 +1,5 @@ | Github account | name | |---|---| | guru4elephant | Daxiang Dong | -| frankwhzhang | Wenhui Zhang | \ No newline at end of file +| frankwhzhang | Wenhui Zhang | +| qjing666 | Qinghe Jing | diff --git a/contrib/data_safety_training/image_classification/server/receiver.py b/contrib/data_safety_training/image_classification/server/receiver.py index 70d19b02574e889aa66f0a98e12f061a8bf05a0f..0ac474ac35e67ef9810baafdd20d5fb05a655f62 100644 --- a/contrib/data_safety_training/image_classification/server/receiver.py +++ b/contrib/data_safety_training/image_classification/server/receiver.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import zmq import socket import msgpack diff --git a/contrib/data_safety_training/image_classification/server/server.py b/contrib/data_safety_training/image_classification/server/server.py index 39f4781d5e9b0c18a1548a0e09281f29bc187abd..93f8f53c1907e3d1a9b5e46da2eb2cf2204551c8 100644 --- a/contrib/data_safety_training/image_classification/server/server.py +++ b/contrib/data_safety_training/image_classification/server/server.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import print_function import os import paddle @@ -12,14 +26,14 @@ import math import msgpack -def data_generater(samples,r): - # data generater +def data_generater(samples, r): + # data generater def train_data(): for item in samples: sample = msgpack.loads(r.get(str(item))) conv = sample[0] label = sample[1] - yield conv,label + yield conv, label return train_data @@ -67,7 +81,7 @@ class ResNet(): size=class_dim, param_attr=fluid.param_attr.ParamAttr( initializer=fluid.initializer.Uniform(-stdv, stdv)), - act = "softmax") + act="softmax") else: for block in range(len(depth)): for i in range(depth[block]): @@ -87,7 +101,7 @@ class ResNet(): size=class_dim, param_attr=fluid.param_attr.ParamAttr( initializer=fluid.initializer.Uniform(-stdv, stdv)), - act = "softmax") + act="softmax") return out def conv_bn_layer(self, @@ -123,8 +137,6 @@ class ResNet(): moving_mean_name=bn_name + '_mean', moving_variance_name=bn_name + '_variance', ) - - def shortcut(self, input, ch_out, stride, is_first, name): ch_in = input.shape[1] if ch_in != ch_out or stride != 1 or is_first == True: @@ -181,31 +193,33 @@ class ResNet(): input, num_filters, stride, is_first, name=name + "_branch1") return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') + # local redis config redis_host = "127.0.0.1" redis_port = 6379 redis_password = "" -r = redis.StrictRedis(host=redis_host, port=redis_port, password=redis_password) +r = redis.StrictRedis( + host=redis_host, port=redis_port, password=redis_password) # reader generation -reader = fluid.layers.py_reader(capacity=64, - shapes=[(-1,64, 8, 8), (-1,1)], - dtypes=['float32', 'int64']) +reader = fluid.layers.py_reader( + capacity=64, shapes=[(-1, 64, 8, 8), (-1, 1)], + dtypes=['float32', 'int64']) samples = r.keys() -train_data = data_generater(samples,r) +train_data = data_generater(samples, r) -reader.decorate_paddle_reader(paddle.batch( - paddle.reader.shuffle( - train_data, buf_size=5000), - batch_size=64)) +reader.decorate_paddle_reader( + paddle.batch( + paddle.reader.shuffle( + train_data, buf_size=5000), batch_size=64)) -conv1,label = fluid.layers.read_file(reader) +conv1, label = fluid.layers.read_file(reader) # train program place = fluid.CUDAPlace(0) model = ResNet(layers=50) -predicts = model.net(conv1,10) +predicts = model.net(conv1, 10) cost = fluid.layers.cross_entropy(input=predicts, label=label) accuracy = fluid.layers.accuracy(input=predicts, label=label) loss = 
fluid.layers.mean(cost) @@ -222,18 +236,20 @@ step = 0 train_start = time.time() # start training for pass_id in range(EPOCH_NUM): - reader.start() - try: - while True: - start_time = time.time() - loss_value,acc_value = exe.run(fetch_list=[loss.name,accuracy.name]) - step += 1 - if step % 10 == 0: - print("epoch: "+ str(pass_id)+"step: "+str(step)+"loss: "+ str(loss_value)+"acc: "+str(acc_value)) - end_time = time.time() - total_time += (end_time - start_time) - except fluid.core.EOFException: - reader.reset() + reader.start() + try: + while True: + start_time = time.time() + loss_value, acc_value = exe.run( + fetch_list=[loss.name, accuracy.name]) + step += 1 + if step % 10 == 0: + print("epoch: " + str(pass_id) + "step: " + str(step) + + "loss: " + str(loss_value) + "acc: " + str(acc_value)) + end_time = time.time() + total_time += (end_time - start_time) + except fluid.core.EOFException: + reader.reset() train_end = time.time() print("total time: %d" % (train_end - train_start)) print("computation time: %d" % total_time) diff --git a/contrib/data_safety_training/image_classification/server/user.py b/contrib/data_safety_training/image_classification/server/user.py index 89668f35d4b69a4b7d561ac4b562fe4531e7b223..505c329f0ca8310824989d83af4a04f05714d49d 100644 --- a/contrib/data_safety_training/image_classification/server/user.py +++ b/contrib/data_safety_training/image_classification/server/user.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from __future__ import print_function import os import paddle @@ -5,10 +19,12 @@ import paddle.fluid as fluid import numpy import sys import redis -import time +import time from paddle.fluid import layers from paddle.fluid.param_attr import ParamAttr import msgpack + + def conv_bn_layer(input, num_filters, filter_size, @@ -16,30 +32,30 @@ def conv_bn_layer(input, groups=1, act=None, name=None): - conv = fluid.layers.conv2d( - input=input, - num_filters=num_filters, - filter_size=filter_size, - stride=stride, - padding=(filter_size - 1) // 2, - groups=groups, - act=None, - param_attr=ParamAttr(name=name + "_weights"), - bias_attr=False, - name=name + '.conv2d.output.1') - - if name == "conv1": - bn_name = "bn_" + name - else: - bn_name = "bn" + name[3:] - return fluid.layers.batch_norm( - input=conv, - act=act, - name=bn_name + '.output.1', - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance', ) + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name + '.conv2d.output.1') + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) def load_conf(conf_file, local_dict): @@ -51,6 +67,7 @@ def load_conf(conf_file, local_dict): local_dict[group[0]] = group[1] return local_dict + # redis DB configuration redis_host = "127.0.0.1" redis_port = 6379 @@ -58,26 +75,39 @@ redis_password = "" start_time = time.time() # start a redis client and empty the DB -r = redis.StrictRedis(host=redis_host, port=redis_port, password=redis_password) +r = redis.StrictRedis( + host=redis_host, port=redis_port, password=redis_password) r.flushall() # encoding program -images = fluid.layers.data(name='images', shape=[3,32,32], dtype='float32') +images = fluid.layers.data(name='images', shape=[3, 32, 32], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') place = fluid.CPUPlace() -conv1 = conv_bn_layer(input=images,num_filters=64,filter_size=7,stride=2,act='relu',name="conv1") -pool = fluid.layers.pool2d(input=conv1,pool_size=3,pool_stride=2,pool_padding=1,pool_type='max') -feeder = fluid.DataFeeder(place=place, feed_list=[images,label]) +conv1 = conv_bn_layer( + input=images, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1") +pool = fluid.layers.pool2d( + input=conv1, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') +feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) pretrained_model = 'ResNet50_pretrained' exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) + # load pretrained mode and prepare datal def if_exist(var): - return os.path.exists(os.path.join(pretrained_model, var.name)) -fluid.io.load_vars(exe, pretrained_model, main_program=fluid.default_main_program(), - predicate=if_exist) + return os.path.exists(os.path.join(pretrained_model, var.name)) + +fluid.io.load_vars( + exe, + pretrained_model, + main_program=fluid.default_main_program(), + predicate=if_exist) train_data = 
paddle.dataset.cifar.train10() step = 0 @@ -86,11 +116,13 @@ step = 0 for data in train_data(): pre_data = [] pre_data.append(data) - res = exe.run(program=fluid.default_main_program(),feed=feeder.feed(pre_data), fetch_list=[pool.name]) - sample = [res[0][0].tolist(),data[1]] + res = exe.run(program=fluid.default_main_program(), + feed=feeder.feed(pre_data), + fetch_list=[pool.name]) + sample = [res[0][0].tolist(), data[1]] step += 1 file = msgpack.dumps(sample) - r.set(step,file) + r.set(step, file) if step % 100 == 0: print(numpy.array(sample[0]).shape) print("%dstart" % step) @@ -99,6 +131,4 @@ files = r.keys() print("upload file numbers: %d" % len(files)) end_time = time.time() total_time = end_time - start_time -print("total time: %d"% total_time) - - +print("total time: %d" % total_time) diff --git a/contrib/data_safety_training/image_classification/submitter.py b/contrib/data_safety_training/image_classification/submitter.py index 69bd5d8495bda836babbea6ca8dc80deb3677b3f..920b60ad8fed2d7ff0b13d17001d8227f3b0abb8 100644 --- a/contrib/data_safety_training/image_classification/submitter.py +++ b/contrib/data_safety_training/image_classification/submitter.py @@ -1,8 +1,22 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import zmq import socket import msgpack import os -mission_dict = {"mission": "image classification", "image_size": [3,32,32]} +mission_dict = {"mission": "image classification", "image_size": [3, 32, 32]} #send request context = zmq.Context() zmq_socket = context.socket(zmq.REQ) diff --git a/docs/requirements.txt b/docs/requirements.txt index c860963602bfd591ff57c8c8a722dc9670bc582d..90c6171099d9bb6e5ce32ac1164ba848c14426f4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,4 +3,3 @@ mistune sphinx_rtd_theme paddlepaddle>=1.6 zmq - diff --git a/docs/source/examples/md/dpsgd-example.md b/docs/source/examples/md/dpsgd-example.md index b0b822702ed9e0cbce1cad20f773bea1172a3755..20f80252b54eb054df6a0c429ef0463f92f20e9b 100644 --- a/docs/source/examples/md/dpsgd-example.md +++ b/docs/source/examples/md/dpsgd-example.md @@ -181,4 +181,3 @@ while not trainer.stop(): To show the effectiveness of DPSGD-based federated learning with PaddleFL, a simulated experiment is conducted on an open source dataset MNIST. From the figure given below, model evaluation results are similar between DPSGD-based federated learning and traditional parameter server training when the overall privacy budget *epsilon* is 1.3 or 0.13.
- diff --git a/docs/source/examples/md/gru4rec_examples.md b/docs/source/examples/md/gru4rec_examples.md index ea729adce65a678c0d0db52b31dd1e1b015748e6..b40e9eb8dc4033aec55bcf7769a372c9eca2a7f7 100644 --- a/docs/source/examples/md/gru4rec_examples.md +++ b/docs/source/examples/md/gru4rec_examples.md @@ -103,6 +103,3 @@ wget https://paddle-zwh.bj.bcebos.com/gru4rec_paddlefl_benchmark/gru4rec_benchma | 1/4 of the whole dataset | private training | - | 0.282 |
- - - diff --git a/docs/source/md/introduction.md b/docs/source/md/introduction.md index 8c1513631989061c74ddc327f294ceeed6df31ff..cb38bfe10ed5177d1dae2dcbca6a11cec1f79ccc 100644 --- a/docs/source/md/introduction.md +++ b/docs/source/md/introduction.md @@ -55,4 +55,3 @@ In PaddleFL, components for defining a federated learning task and training a fe - Federated Learning Systems deployment methods in Kubernetes. - Vertical Federated Learning Strategies and more horizontal federated learning strategies will be open sourced. - diff --git a/docs/source/md/reference.md b/docs/source/md/reference.md index 5a2a72fdadf16842739c1439397b41a665f8de52..214b68a8101faa24115fe8cd9ca025c2ca82df03 100644 --- a/docs/source/md/reference.md +++ b/docs/source/md/reference.md @@ -14,4 +14,4 @@ [7]. Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, Ameet Talwalkar. **Federated Multi-Task Learning** 2016 -[8]. Yang Liu, Tianjian Chen, Qiang Yang. **Secure Federated Transfer Learning.** 2018 \ No newline at end of file +[8]. Yang Liu, Tianjian Chen, Qiang Yang. **Secure Federated Transfer Learning.** 2018 diff --git a/paddle_fl/common/__init__.py b/paddle_fl/common/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..abf198b97e6e818e1fbe59006f98492640bcee54 100644 --- a/paddle_fl/common/__init__.py +++ b/paddle_fl/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddle_fl/core/__init__.py b/paddle_fl/core/__init__.py index 34df36dd857ac7204a1b134d86889f60d1fad7b5..6bec080b88ea20103b8f103eccc5bf0819c7818d 100644 --- a/paddle_fl/core/__init__.py +++ b/paddle_fl/core/__init__.py @@ -22,4 +22,3 @@ from .scheduler.agent_master import FLWorkerAgent from .scheduler.agent_master import FLScheduler from .submitter.client_base import HPCClient from .submitter.client_base import CloudClient - diff --git a/paddle_fl/core/master/fl_job.py b/paddle_fl/core/master/fl_job.py index a221897d38393cfb8998ea0d988cff4f3c22e61e..7a80e13e8e6d1e5255ce176a5de4ffdee97bf5e4 100644 --- a/paddle_fl/core/master/fl_job.py +++ b/paddle_fl/core/master/fl_job.py @@ -14,11 +14,13 @@ import os import paddle.fluid as fluid + class FLJobBase(object): """ FLJobBase is fl job base class, responsible for save and load a federated learning job """ + def __init__(self): pass @@ -64,6 +66,7 @@ class FLJobBase(object): return fluid.Program.parse_from_string(program_desc_str) return None + class FLCompileTimeJob(FLJobBase): """ FLCompileTimeJob is a container for compile time job in federated learning. @@ -71,6 +74,7 @@ class FLCompileTimeJob(FLJobBase): are in FLCompileTimeJob. Also, server main programs and server startup programs are in this class. 
FLCompileTimeJob has server endpoints for debugging as well """ + def __init__(self): self._trainer_startup_programs = [] self._trainer_recv_programs = [] @@ -101,69 +105,59 @@ class FLCompileTimeJob(FLJobBase): os.system("mkdir -p %s" % server_folder) server_startup = self._server_startup_programs[i] server_main = self._server_main_programs[i] - self._save_program( - server_startup, - "%s/server.startup.program" % server_folder) - self._save_program( - server_main, - "%s/server.main.program" % server_folder) + self._save_program(server_startup, + "%s/server.startup.program" % server_folder) + self._save_program(server_main, + "%s/server.main.program" % server_folder) + self._save_readable_program(server_startup, + "%s/server.startup.program.txt" % + server_folder) self._save_readable_program( - server_startup, - "%s/server.startup.program.txt" % server_folder) - self._save_readable_program( - server_main, - "%s/server.main.program.txt" % server_folder) + server_main, "%s/server.main.program.txt" % server_folder) self._save_str_list(self._feed_names, - "%s/feed_names" % server_folder) + "%s/feed_names" % server_folder) self._save_str_list(self._target_names, - "%s/target_names" % server_folder) + "%s/target_names" % server_folder) self._save_endpoints(self._server_endpoints, - "%s/endpoints" % server_folder) + "%s/endpoints" % server_folder) self._save_strategy(self._strategy, - "%s/strategy.pkl" % server_folder) + "%s/strategy.pkl" % server_folder) for i in range(trainer_num): trainer_folder = "%s/trainer%d" % (folder, i) os.system("mkdir -p %s" % trainer_folder) trainer_startup = self._trainer_startup_programs[i] trainer_main = self._trainer_main_programs[i] - self._save_program( - trainer_startup, - "%s/trainer.startup.program" % trainer_folder) - self._save_program( - trainer_main, - "%s/trainer.main.program" % trainer_folder) - self._save_readable_program( - trainer_startup, - "%s/trainer.startup.program.txt" % trainer_folder) + self._save_program(trainer_startup, + "%s/trainer.startup.program" % trainer_folder) + self._save_program(trainer_main, + "%s/trainer.main.program" % trainer_folder) + self._save_readable_program(trainer_startup, + "%s/trainer.startup.program.txt" % + trainer_folder) self._save_readable_program( - trainer_main, - "%s/trainer.main.program.txt" % trainer_folder) + trainer_main, "%s/trainer.main.program.txt" % trainer_folder) self._save_str_list(self._feed_names, - "%s/feed_names" % trainer_folder) + "%s/feed_names" % trainer_folder) self._save_str_list(self._target_names, - "%s/target_names" % trainer_folder) + "%s/target_names" % trainer_folder) self._save_endpoints(self._server_endpoints, - "%s/endpoints" % trainer_folder) + "%s/endpoints" % trainer_folder) self._save_strategy(self._strategy, - "%s/strategy.pkl" % trainer_folder) + "%s/strategy.pkl" % trainer_folder) for i in range(send_prog_num): trainer_folder = "%s/trainer%d" % (folder, i) trainer_send = self._trainer_send_programs[i] trainer_recv = self._trainer_recv_programs[i] - self._save_program( - trainer_send, - "%s/trainer.send.program" % trainer_folder) - self._save_program( - trainer_recv, - "%s/trainer.recv.program" % trainer_folder) + self._save_program(trainer_send, + "%s/trainer.send.program" % trainer_folder) + self._save_program(trainer_recv, + "%s/trainer.recv.program" % trainer_folder) self._save_readable_program( - trainer_send, - "%s/trainer.send.program.txt" % trainer_folder) + trainer_send, "%s/trainer.send.program.txt" % trainer_folder) self._save_readable_program( - trainer_recv, - 
"%s/trainer.recv.program.txt" % trainer_folder) + trainer_recv, "%s/trainer.recv.program.txt" % trainer_folder) class FLRunTimeJob(FLJobBase): @@ -172,6 +166,7 @@ class FLRunTimeJob(FLJobBase): A trainer or a server can load FLRunTimeJob. Only necessary programs can be loaded in FLRunTimeJob """ + def __init__(self): self._trainer_startup_program = None self._trainer_recv_program = None diff --git a/paddle_fl/core/master/job_generator.py b/paddle_fl/core/master/job_generator.py index e4291241a864c354c1a549ef07acb7652ec377dc..64feb7d8083699459d05f3ccee6185cd00194312 100644 --- a/paddle_fl/core/master/job_generator.py +++ b/paddle_fl/core/master/job_generator.py @@ -14,6 +14,7 @@ import paddle.fluid as fluid from .fl_job import FLCompileTimeJob + class JobGenerator(object): """ A JobGenerator is responsible for generating distributed federated @@ -21,6 +22,7 @@ class JobGenerator(object): need to define a deep learning model together to do horizontal federated learning. """ + def __init__(self): # worker num for federated learning self._worker_num = 0 @@ -32,7 +34,6 @@ class JobGenerator(object): self._feed_names = [] self._target_names = [] - def set_optimizer(self, optimizer): """ Set optimizer of current job @@ -56,8 +57,10 @@ class JobGenerator(object): self._startup_prog = startup def set_infer_feed_and_target_names(self, feed_names, target_names): - if not isinstance(feed_names, list) or not isinstance(target_names, list): - raise ValueError("input should be list in set_infer_feed_and_target_names") + if not isinstance(feed_names, list) or not isinstance(target_names, + list): + raise ValueError( + "input should be list in set_infer_feed_and_target_names") ''' print(feed_names) print(target_names) @@ -76,7 +79,6 @@ class JobGenerator(object): server_endpoints=[], worker_num=1, output=None): - """ Generate Federated Learning Job, based on user defined configs @@ -130,17 +132,66 @@ class JobGenerator(object): startup_program = self._startup_prog.clone() main_program = self._losses[0].block.program.clone() fl_strategy._build_trainer_program_for_job( - trainer_id, program=main_program, - ps_endpoints=server_endpoints, trainers=worker_num, - sync_mode=True, startup_program=startup_program, + trainer_id, + program=main_program, + ps_endpoints=server_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job) + + startup_program = self._startup_prog.clone() + main_program = self._losses[0].block.program.clone() + fl_strategy._build_server_programs_for_job( + program=main_program, + ps_endpoints=server_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job) + + local_job.set_feed_names(self._feed_names) + local_job.set_target_names(self._target_names) + local_job.set_strategy(fl_strategy) + local_job.save(output) + + def generate_fl_job_for_k8s(self, + fl_strategy, + server_pod_endpoints=[], + server_service_endpoints=[], + worker_num=1, + output=None): + + local_job = FLCompileTimeJob() + assert len(self._losses) > 0 + assert self._startup_prog != None + assert fl_strategy != None + assert output != None + fl_strategy.minimize(self._optimizer, self._losses) + + # strategy can generate startup and main program + # of a single worker and servers + for trainer_id in range(worker_num): + startup_program = self._startup_prog.clone() + main_program = self._losses[0].block.program.clone() + fl_strategy._build_trainer_program_for_job( + trainer_id, + program=main_program, + 
ps_endpoints=server_service_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, job=local_job) startup_program = self._startup_prog.clone() main_program = self._losses[0].block.program.clone() fl_strategy._build_server_programs_for_job( - program=main_program, ps_endpoints=server_endpoints, - trainers=worker_num, sync_mode=True, - startup_program=startup_program, job=local_job) + program=main_program, + ps_endpoints=server_pod_endpoints, + trainers=worker_num, + sync_mode=True, + startup_program=startup_program, + job=local_job) local_job.set_feed_names(self._feed_names) local_job.set_target_names(self._target_names) diff --git a/paddle_fl/core/scheduler/agent_master.py b/paddle_fl/core/scheduler/agent_master.py index f26a85a7f44aca7e662dd3140b269944d89201f5..7c1fe4fbca182551a695190f3161488b8d9c4db3 100644 --- a/paddle_fl/core/scheduler/agent_master.py +++ b/paddle_fl/core/scheduler/agent_master.py @@ -1,7 +1,22 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import zmq import time import random + def recv_and_parse_kv(socket): message = socket.recv() group = message.decode().split("\t") @@ -10,9 +25,11 @@ def recv_and_parse_kv(socket): else: return group[0], group[1] + WORKER_EP = "WORKER_EP" SERVER_EP = "SERVER_EP" + class FLServerAgent(object): def __init__(self, scheduler_ep, current_ep): self.scheduler_ep = scheduler_ep @@ -29,6 +46,7 @@ class FLServerAgent(object): if group[0] == 'INIT': break + class FLWorkerAgent(object): def __init__(self, scheduler_ep, current_ep): self.scheduler_ep = scheduler_ep @@ -64,7 +82,6 @@ class FLWorkerAgent(object): return False - class FLScheduler(object): def __init__(self, worker_num, server_num, port=9091, socket=None): self.context = zmq.Context() diff --git a/paddle_fl/core/server/fl_server.py b/paddle_fl/core/server/fl_server.py index f8fc1922da8b9ebe2a5abe68782a6f32cf272989..0c1529c1840ea3983691ff6b57decf5f60537ee8 100644 --- a/paddle_fl/core/server/fl_server.py +++ b/paddle_fl/core/server/fl_server.py @@ -14,8 +14,8 @@ import paddle.fluid as fluid from paddle_fl.core.scheduler.agent_master import FLServerAgent -class FLServer(object): +class FLServer(object): def __init__(self): self._startup_program = None self._main_program = None diff --git a/paddle_fl/core/strategy/details/checkport.py b/paddle_fl/core/strategy/details/checkport.py index 89dd4dd50b0299de986b84f46e889d554030f180..9749bc37dbeff39f39da6b6c726287861f32056d 100644 --- a/paddle_fl/core/strategy/details/checkport.py +++ b/paddle_fl/core/strategy/details/checkport.py @@ -48,8 +48,8 @@ def wait_server_ready(endpoints): not_ready_endpoints.append(ep) if not all_ok: sys.stderr.write("server not ready, wait 3 sec to retry...\n") - sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) + - "\n") + sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) + + "\n") sys.stderr.flush() time.sleep(3) else: diff --git a/paddle_fl/core/strategy/details/program_utils.py 
b/paddle_fl/core/strategy/details/program_utils.py index dc78ffe70b3dfda75a799583e85b76d8d921e078..28ec5e097d7ad8df3d8709f439342a11dc4bd592 100644 --- a/paddle_fl/core/strategy/details/program_utils.py +++ b/paddle_fl/core/strategy/details/program_utils.py @@ -163,7 +163,8 @@ def block_to_code(block, block_idx, fout=None, skip_op_callstack=False): indent = 0 print( - "{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx), + "{0}{1} // block {2}".format( + get_indent_space(indent), '{', block_idx), file=fout) indent += 1 diff --git a/paddle_fl/core/strategy/fl_distribute_transpiler.py b/paddle_fl/core/strategy/fl_distribute_transpiler.py index 28b369ce9b3ada38eedc438f10f3b1da9725833b..0d204f8fbcd6c79e7ec85e269d0900c9b01e3b61 100644 --- a/paddle_fl/core/strategy/fl_distribute_transpiler.py +++ b/paddle_fl/core/strategy/fl_distribute_transpiler.py @@ -50,6 +50,7 @@ def log(*args): if PRINT_LOG: print(args) + def same_or_split_var(p_name, var_name): return p_name == var_name or p_name.startswith(var_name + ".block") @@ -113,7 +114,9 @@ class FLDistributeTranspiler(object): def _get_all_remote_sparse_update_op(self, main_program): sparse_update_ops = [] - sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"] + sparse_update_op_types = [ + "lookup_table", "nce", "hierarchical_sigmoid" + ] for op in main_program.global_block().ops: if op.type in sparse_update_op_types and op.attr( 'remote_prefetch') is True: @@ -406,12 +409,13 @@ class FLDistributeTranspiler(object): # NOTE: single_trainer_var must be created for multi-trainer # case to merge grads from multiple trainers single_trainer_var = pserver_program.global_block().var( - orig_var_name) + orig_var_name) if self.sync_mode and self.trainer_num > 1: for trainer_id in range(self.trainer_num): var = pserver_program.global_block().create_var( - name="%s.opti.trainer_%d" % (orig_var_name, trainer_id), + name="%s.opti.trainer_%d" % + (orig_var_name, trainer_id), persistable=False, type=v.type, dtype=v.dtype, @@ -816,7 +820,6 @@ class FLDistributeTranspiler(object): iomap = collections.OrderedDict() return iomap - def _get_lr_ops(self): lr_ops = [] block = self.origin_program.global_block() diff --git a/paddle_fl/core/strategy/fl_strategy_base.py b/paddle_fl/core/strategy/fl_strategy_base.py index 14e97ce346f4c67dbd7eb5ee34e6a12f2bdd866d..d1579d2f744c4bf67d31263e68872fe0be8b1263 100644 --- a/paddle_fl/core/strategy/fl_strategy_base.py +++ b/paddle_fl/core/strategy/fl_strategy_base.py @@ -16,11 +16,13 @@ from .fl_distribute_transpiler import FLDistributeTranspiler from paddle.fluid.optimizer import SGD import paddle.fluid as fluid + class FLStrategyFactory(object): """ FLStrategyFactory is a FLStrategy builder Users can define strategy config to create different FLStrategy """ + def __init__(self): self._fed_avg = False self._dpsgd = False @@ -86,6 +88,7 @@ class FLStrategyBase(object): """ FLStrategyBase is federated learning algorithm container """ + def __init__(self): self._fed_avg = False self._dpsgd = False @@ -105,17 +108,23 @@ class FLStrategyBase(object): for loss in losses: optimizer.minimize(loss) - def _build_trainer_program_for_job( - self, trainer_id=0, program=None, - ps_endpoints=[], trainers=0, - sync_mode=True, startup_program=None, - job=None): + def _build_trainer_program_for_job(self, + trainer_id=0, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None): pass - def _build_server_programs_for_job( - self, program=None, ps_endpoints=[], - 
trainers=0, sync_mode=True, - startup_program=None, job=None): + def _build_server_programs_for_job(self, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None): pass @@ -123,6 +132,7 @@ class DPSGDStrategy(FLStrategyBase): """ DPSGDStrategy: Deep Learning with Differential Privacy. 2016 """ + def __init__(self): super(DPSGDStrategy, self).__init__() @@ -162,29 +172,40 @@ class DPSGDStrategy(FLStrategyBase): """ Define Dpsgd optimizer """ - optimizer = fluid.optimizer.Dpsgd(self._learning_rate, clip=self._clip, batch_size=self._batch_size, sigma=self._sigma) + optimizer = fluid.optimizer.Dpsgd( + self._learning_rate, + clip=self._clip, + batch_size=self._batch_size, + sigma=self._sigma) optimizer.minimize(losses[0]) - def _build_trainer_program_for_job( - self, trainer_id=0, program=None, - ps_endpoints=[], trainers=0, - sync_mode=True, startup_program=None, - job=None): + def _build_trainer_program_for_job(self, + trainer_id=0, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None): transpiler = fluid.DistributeTranspiler() - transpiler.transpile(trainer_id, - program=program, - pservers=",".join(ps_endpoints), - trainers=trainers, - sync_mode=sync_mode, - startup_program=startup_program) + transpiler.transpile( + trainer_id, + program=program, + pservers=",".join(ps_endpoints), + trainers=trainers, + sync_mode=sync_mode, + startup_program=startup_program) main = transpiler.get_trainer_program(wait_port=False) job._trainer_startup_programs.append(startup_program) job._trainer_main_programs.append(main) - def _build_server_programs_for_job( - self, program=None, ps_endpoints=[], - trainers=0, sync_mode=True, - startup_program=None, job=None): + def _build_server_programs_for_job(self, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None): transpiler = fluid.DistributeTranspiler() trainer_id = 0 transpiler.transpile( @@ -207,6 +228,7 @@ class FedAvgStrategy(FLStrategyBase): FedAvgStrategy: this is model averaging optimization proposed in H. Brendan McMahan, Eider Moore, Daniel Ramage, Blaise Aguera y Arcas. Federated Learning of Deep Networks using Model Averaging. 
2017 """ + def __init__(self): super(FedAvgStrategy, self).__init__() @@ -216,28 +238,35 @@ class FedAvgStrategy(FLStrategyBase): """ optimizer.minimize(losses[0]) - def _build_trainer_program_for_job( - self, trainer_id=0, program=None, - ps_endpoints=[], trainers=0, - sync_mode=True, startup_program=None, - job=None): + def _build_trainer_program_for_job(self, + trainer_id=0, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None): transpiler = FLDistributeTranspiler() - transpiler.transpile(trainer_id, - program=program, - pservers=",".join(ps_endpoints), - trainers=trainers, - sync_mode=sync_mode, - startup_program=startup_program) + transpiler.transpile( + trainer_id, + program=program, + pservers=",".join(ps_endpoints), + trainers=trainers, + sync_mode=sync_mode, + startup_program=startup_program) recv, main, send = transpiler.get_trainer_program() job._trainer_startup_programs.append(startup_program) job._trainer_main_programs.append(main) job._trainer_send_programs.append(send) job._trainer_recv_programs.append(recv) - def _build_server_programs_for_job( - self, program=None, ps_endpoints=[], - trainers=0, sync_mode=True, - startup_program=None, job=None): + def _build_server_programs_for_job(self, + program=None, + ps_endpoints=[], + trainers=0, + sync_mode=True, + startup_program=None, + job=None): transpiler = FLDistributeTranspiler() trainer_id = 0 transpiler.transpile( @@ -262,6 +291,7 @@ class SecAggStrategy(FedAvgStrategy): Practical Secure Aggregation for Privacy-Preserving Machine Learning, The 24th ACM Conference on Computer and Communications Security ( CCS2017 ). """ + def __init__(self): super(SecAggStrategy, self).__init__() self._param_name_list = [] diff --git a/paddle_fl/core/submitter/client_base.py b/paddle_fl/core/submitter/client_base.py index 43d9ece6e77174fac01235ae10852a05ef7a5e66..0f2e977906d20966957e290d9d4e1f6edeb3e8bc 100644 --- a/paddle_fl/core/submitter/client_base.py +++ b/paddle_fl/core/submitter/client_base.py @@ -1,10 +1,25 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import sys import os + class CloudClient(object): def __init__(self): pass - + def generate_submit_sh(self, job_dir): with open() as fout: pass @@ -16,6 +31,7 @@ class CloudClient(object): def submit(self, **kwargs): pass + class HPCClient(object): def __init__(self): self.conf_dict = {} @@ -70,27 +86,20 @@ class HPCClient(object): fout.write("#!/bin/bash\n") fout.write("unset http_proxy\n") fout.write("unset https_proxy\n") - fout.write("export HADOOP_HOME={}\n".format( - self.hadoop_home)) + fout.write("export HADOOP_HOME={}\n".format(self.hadoop_home)) fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}" " -Dfs.default.name={} -rmr {}\n".format( - self.ugi, - self.hdfs_path, - self.hdfs_output)) + self.ugi, self.hdfs_path, self.hdfs_output)) fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem)) fout.write("{}/bin/qsub_f -N {} --conf qsub.conf " "--hdfs {} --ugi {} --hout {} --files ./package " "-l nodes={},walltime=1000:00:00,pmem-hard={}," "pcpu-soft={},pnetin-soft=1000," "pnetout-soft=1000 job.sh\n".format( - self.hpc_home, - self.task_name, - self.hdfs_path, - self.ugi, - self.hdfs_output, + self.hpc_home, self.task_name, self.hdfs_path, + self.ugi, self.hdfs_output, int(self.worker_nodes) + int(self.server_nodes), - self.mpi_node_mem, - self.pcpu)) + self.mpi_node_mem, self.pcpu)) def generate_job_sh(self, job_dir): with open("{}/job.sh".format(job_dir), "w") as fout: @@ -98,17 +107,23 @@ class HPCClient(object): fout.write("WORKDIR=`pwd`\n") fout.write("mpirun -npernode 1 mv package/* ./\n") fout.write("echo 'current dir: '$WORKDIR\n") - fout.write("mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n") - fout.write("export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n") + fout.write( + "mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n") + fout.write( + "export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n") fout.write("mpirun -npernode 1 python/bin/python -m pip install " "{} --index-url=http://pip.baidu.com/pypi/simple " "--trusted-host pip.baidu.com > /dev/null\n".format( self.wheel)) fout.write("export PATH=python/bin:$PATH\n") if self.monitor_cmd != "": - fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile " - "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".format(self.monitor_cmd)) - fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n") + fout.write( + "mpirun -npernode 1 -timestamp-output -tag-output -machinefile " + "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n". + format(self.monitor_cmd)) + fout.write( + "mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n" + ) fout.write("if [[ $? -ne 0 ]]; then\n") fout.write(" echo 'Failed to run mpi!' 
1>&2\n") fout.write(" exit 1\n") @@ -150,4 +165,5 @@ class HPCClient(object): # generate job.sh self.generate_qsub_conf(jobdir) # run submit - os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(jobdir)) + os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format( + jobdir)) diff --git a/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py b/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py index 751a9b080047a407aeed62981906d860ffad0457..18a910df5193ec16c0a820339bc0c84ab3820408 100644 Binary files a/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py and b/paddle_fl/core/trainer/diffiehellman/._diffiehellman.py differ diff --git a/paddle_fl/core/trainer/diffiehellman/__init__.py b/paddle_fl/core/trainer/diffiehellman/__init__.py index a1143ffe5f7d034f11053e1ea5dbe38be50064f6..cd5e8bc0757137eee8694e600c410d49db37a78a 100644 --- a/paddle_fl/core/trainer/diffiehellman/__init__.py +++ b/paddle_fl/core/trainer/diffiehellman/__init__.py @@ -2,7 +2,6 @@ # # (c) Chris von Csefalvay, 2015. - """ __init__.py is responsible for [brief description here]. """ diff --git a/paddle_fl/core/trainer/diffiehellman/decorators.py b/paddle_fl/core/trainer/diffiehellman/decorators.py index c6f3248e6857d6ca6b60ede18d8f2377a8e2f504..2483e75e5f4ba23106aa15ab377e6669465b968e 100644 --- a/paddle_fl/core/trainer/diffiehellman/decorators.py +++ b/paddle_fl/core/trainer/diffiehellman/decorators.py @@ -1,6 +1,5 @@ # coding=utf-8 - # # The MIT License (MIT) # @@ -21,8 +20,6 @@ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # - - """ decorators declares some decorators that ensure the object has the correct keys declared when need be. diff --git a/paddle_fl/core/trainer/diffiehellman/diffiehellman.py b/paddle_fl/core/trainer/diffiehellman/diffiehellman.py index d5e1b196f56e5694b4112a8ff82022ee3bac7910..9b4e7304d24cc5a4d1ba5abc25e7bb69e2941c9d 100644 --- a/paddle_fl/core/trainer/diffiehellman/diffiehellman.py +++ b/paddle_fl/core/trainer/diffiehellman/diffiehellman.py @@ -20,10 +20,6 @@ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # - - - - """ diffiehellmann declares the main key exchange class. """ @@ -41,18 +37,17 @@ import os try: from ssl import RAND_bytes rng = RAND_bytes -except(AttributeError, ImportError): +except (AttributeError, ImportError): rng = os.urandom + class DiffieHellman: """ Implements the Diffie-Hellman key exchange protocol. 
""" - def __init__(self, - group=18, - key_length=640): + def __init__(self, group=18, key_length=640): self.key_length = max(200, key_length) self.generator = PRIMES[group]["generator"] @@ -81,7 +76,8 @@ class DiffieHellman: self.private_key = key def verify_public_key(self, other_public_key): - return self.prime - 1 > other_public_key > 2 and pow(other_public_key, (self.prime - 1) // 2, self.prime) == 1 + return self.prime - 1 > other_public_key > 2 and pow( + other_public_key, (self.prime - 1) // 2, self.prime) == 1 @requires_private_key def generate_public_key(self): @@ -91,9 +87,7 @@ class DiffieHellman: :return: void :rtype: void """ - self.public_key = pow(self.generator, - self.private_key, - self.prime) + self.public_key = pow(self.generator, self.private_key, self.prime) @requires_private_key def generate_shared_secret(self, other_public_key, echo_return_key=False): @@ -110,16 +104,17 @@ class DiffieHellman: if self.verify_public_key(other_public_key) is False: raise MalformedPublicKey - self.shared_secret = pow(other_public_key, - self.private_key, + self.shared_secret = pow(other_public_key, self.private_key, self.prime) try: #python3 - shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big') + shared_secret_as_bytes = self.shared_secret.to_bytes( + self.shared_secret.bit_length() // 8 + 1, byteorder='big') except: #python2 length = self.shared_secret.bit_length() // 8 + 1 - shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:] + shared_secret_as_bytes = ('%%0%dx' % ( + length << 1) % self.shared_secret).decode('hex')[-length:] _h = sha256() _h.update(bytes(shared_secret_as_bytes)) diff --git a/paddle_fl/core/trainer/diffiehellman/exceptions.py b/paddle_fl/core/trainer/diffiehellman/exceptions.py index 289062c9477e7f00ea27cfc05d56d570b8de76ad..2205c4eba7f7e33eae614b2f1975231e3e8ed956 100644 --- a/paddle_fl/core/trainer/diffiehellman/exceptions.py +++ b/paddle_fl/core/trainer/diffiehellman/exceptions.py @@ -2,7 +2,6 @@ # # (c) Chris von Csefalvay, 2015. - """ exceptions is responsible for exception handling etc. """ diff --git a/paddle_fl/core/trainer/diffiehellman/primes.py b/paddle_fl/core/trainer/diffiehellman/primes.py index 624926e4c367e1937c8e23e468be78795d04611e..265bdcc99eebebfe272fbcc85ba1f350de27d0cb 100644 --- a/paddle_fl/core/trainer/diffiehellman/primes.py +++ b/paddle_fl/core/trainer/diffiehellman/primes.py @@ -1,6 +1,5 @@ # coding=utf-8 - # # The MIT License (MIT) # @@ -25,34 +24,39 @@ # Extracted from: Kivinen, T. and Kojo, M. (2003), _More Modular Exponential (MODP) Diffie-Hellman # groups for Internet Key Exchange (IKE)_. # - """ primes holds the RFC 3526 MODP primes and their generators. 
""" PRIMES = { 5: { - "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF, + "prime": + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF, "generator": 2 }, 14: { - "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF, + "prime": + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF, "generator": 2 }, 15: { - "prime": 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF, + "prime": + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF, "generator": 2 }, 16: { - "prime": 
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF, + "prime": + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF, "generator": 2 }, 17: { - "prime": 
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF, + "prime": + 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DCC4024FFFFFFFFFFFFFFFF, "generator": 2 }, 18: { - "prime": 
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF, + "prime": + 
0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF, "generator": 2 }, } diff --git a/paddle_fl/core/trainer/fl_trainer.py b/paddle_fl/core/trainer/fl_trainer.py index 00cb67e72617c7edf73e4714b307090d78493e88..e768f259716a71fe5db75d2a35fb4480b335465b 100755 --- a/paddle_fl/core/trainer/fl_trainer.py +++ b/paddle_fl/core/trainer/fl_trainer.py @@ -19,6 +19,7 @@ import hmac import hashlib from .diffiehellman.diffiehellman import DiffieHellman + class FLTrainerFactory(object): def __init__(self): pass @@ -65,9 +66,7 @@ class FLTrainer(object): def run(self, feed, fetch): self._logger.debug("begin to run") - self.exe.run(self._main_program, - feed=feed, - fetch_list=fetch) + self.exe.run(self._main_program, feed=feed, fetch_list=fetch) self._logger.debug("end to run current batch") self.cur_step += 1 @@ -119,37 +118,34 @@ class FedAvgTrainer(FLTrainer): def reset(self): self.cur_step = 0 - def run_with_epoch(self,reader,feeder,fetch,num_epoch): + def run_with_epoch(self, reader, feeder, fetch, num_epoch): self._logger.debug("begin to run recv program") self.exe.run(self._recv_program) epoch = 0 for i in range(num_epoch): - for data in reader(): - self.exe.run(self._main_program, - feed=feeder.feed(data), - fetch_list=fetch) - self.cur_step += 1 - epoch += 1 + for data in reader(): + self.exe.run(self._main_program, + feed=feeder.feed(data), + fetch_list=fetch) + self.cur_step += 1 + epoch += 1 self._logger.debug("begin to run send program") self.exe.run(self._send_program) + def run(self, feed, fetch): - self._logger.debug("begin to 
run FedAvgTrainer, cur_step=%d, inner_step=%d" % - (self.cur_step, self._step)) + self._logger.debug( + "begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" % + (self.cur_step, self._step)) if self.cur_step % self._step == 0: self._logger.debug("begin to run recv program") self.exe.run(self._recv_program) self._logger.debug("begin to run current step") - loss = self.exe.run(self._main_program, - feed=feed, - fetch_list=fetch) + loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch) if self.cur_step % self._step == 0: self._logger.debug("begin to run send program") self.exe.run(self._send_program) self.cur_step += 1 return loss - - - class SecAggTrainer(FLTrainer): @@ -207,24 +203,24 @@ class SecAggTrainer(FLTrainer): self.cur_step = 0 def run(self, feed, fetch): - self._logger.debug("begin to run SecAggTrainer, cur_step=%d, inner_step=%d" % - (self.cur_step, self._step)) + self._logger.debug( + "begin to run SecAggTrainer, cur_step=%d, inner_step=%d" % + (self.cur_step, self._step)) if self.cur_step % self._step == 0: self._logger.debug("begin to run recv program") self.exe.run(self._recv_program) scope = fluid.global_scope() self._logger.debug("begin to run current step") - loss = self.exe.run(self._main_program, - feed=feed, - fetch_list=fetch) + loss = self.exe.run(self._main_program, feed=feed, fetch_list=fetch) if self.cur_step % self._step == 0: self._logger.debug("begin to run send program") noise = 0.0 scale = pow(10.0, 5) - digestmod=hashlib.sha256 + digestmod = hashlib.sha256 # 1. load priv key and other's pub key dh = DiffieHellman(group=15, key_length=256) - dh.load_private_key(self._key_dir + str(self._trainer_id) + "_priv_key.txt") + dh.load_private_key(self._key_dir + str(self._trainer_id) + + "_priv_key.txt") key = str(self._step_id).encode("utf-8") for i in range(self._trainer_num): if i != self._trainer_id: @@ -232,7 +228,8 @@ class SecAggTrainer(FLTrainer): public_key = int(f.read()) dh.generate_shared_secret(public_key, echo_return_key=True) msg = dh.shared_key.encode("utf-8") - hex_res1 = hmac.new(key=key, msg=msg, digestmod=digestmod).hexdigest() + hex_res1 = hmac.new(key=key, msg=msg, + digestmod=digestmod).hexdigest() current_noise = int(hex_res1[0:8], 16) / scale if i > self._trainer_id: noise = noise + current_noise @@ -241,9 +238,11 @@ class SecAggTrainer(FLTrainer): scope = fluid.global_scope() for param_name in self._param_name_list: - fluid.global_scope().var(param_name + str(self._trainer_id)).get_tensor().set( - numpy.array(scope.find_var(param_name + str(self._trainer_id)).get_tensor()) + noise, fluid.CPUPlace()) + fluid.global_scope().var(param_name + str( + self._trainer_id)).get_tensor().set( + numpy.array( + scope.find_var(param_name + str(self._trainer_id)) + .get_tensor()) + noise, fluid.CPUPlace()) self.exe.run(self._send_program) self.cur_step += 1 return loss - diff --git a/paddle_fl/dataset/__init__.py b/paddle_fl/dataset/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..abf198b97e6e818e1fbe59006f98492640bcee54 100644 --- a/paddle_fl/dataset/__init__.py +++ b/paddle_fl/dataset/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddle_fl/dataset/femnist.py b/paddle_fl/dataset/femnist.py index a4940b84a2edbfcece62a736be541d3dcf59c8a3..80a5b47d5a9a3f2e381c3a7712b428f53ae98a3f 100644 --- a/paddle_fl/dataset/femnist.py +++ b/paddle_fl/dataset/femnist.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import requests import os import json @@ -5,73 +19,84 @@ import tarfile import random -def download(url,tar_path): - r = requests.get(url) - with open(tar_path,'wb') as f: - f.write(r.content) - -def extract(tar_path,target_path): - tar = tarfile.open(tar_path, "r:gz") - file_names = tar.getnames() - for file_name in file_names: - tar.extract(file_name,target_path) - - tar.close() - -def train(trainer_id,inner_step,batch_size,count_by_step): - target_path = "trainer%d_data" % trainer_id - data_path = target_path + "/femnist_data" - tar_path = data_path + ".tar.gz" - if not os.path.exists(target_path): - os.system("mkdir trainer%d_data" % trainer_id) - if not os.path.exists(data_path): - print("Preparing data...") - if not os.path.exists(tar_path): - download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) - extract(tar_path,target_path) - def train_data(): - train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r') - json_train = json.load(train_file) - users = json_train["users"] - rand = random.randrange(0,len(users)) # random choose a user from each trainer - cur_user = users[rand] - print('training using '+cur_user) - train_images = json_train["user_data"][cur_user]['x'] - train_labels = json_train["user_data"][cur_user]['y'] - if count_by_step: - for i in range(inner_step*batch_size): - yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))] - else: - for i in range(len(train_images)): - yield train_images[i], train_labels[i] - - train_file.close() - - return train_data - -def test(trainer_id,inner_step,batch_size,count_by_step): - target_path = "trainer%d_data" % trainer_id - data_path = target_path + "/femnist_data" - tar_path = data_path + ".tar.gz" - if not os.path.exists(target_path): - os.system("mkdir trainer%d_data" % trainer_id) - if not os.path.exists(data_path): - print("Preparing data...") - if not os.path.exists(tar_path): - download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) - extract(tar_path,target_path) - def test_data(): - test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r') - json_test = 
json.load(test_file) - users = json_test["users"] - for user in users: - test_images = json_test['user_data'][user]['x'] - test_labels = json_test['user_data'][user]['y'] - for i in range(len(test_images)): - yield test_images[i], test_labels[i] - - test_file.close() - - return test_data - - +def download(url, tar_path): + r = requests.get(url) + with open(tar_path, 'wb') as f: + f.write(r.content) + + +def extract(tar_path, target_path): + tar = tarfile.open(tar_path, "r:gz") + file_names = tar.getnames() + for file_name in file_names: + tar.extract(file_name, target_path) + + tar.close() + + +def train(trainer_id, inner_step, batch_size, count_by_step): + target_path = "trainer%d_data" % trainer_id + data_path = target_path + "/femnist_data" + tar_path = data_path + ".tar.gz" + if not os.path.exists(target_path): + os.system("mkdir trainer%d_data" % trainer_id) + if not os.path.exists(data_path): + print("Preparing data...") + if not os.path.exists(tar_path): + download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz", + tar_path) + extract(tar_path, target_path) + + def train_data(): + train_file = open( + "./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" + % (trainer_id, trainer_id), 'r') + json_train = json.load(train_file) + users = json_train["users"] + rand = random.randrange( + 0, len(users)) # random choose a user from each trainer + cur_user = users[rand] + print('training using ' + cur_user) + train_images = json_train["user_data"][cur_user]['x'] + train_labels = json_train["user_data"][cur_user]['y'] + if count_by_step: + for i in range(inner_step * batch_size): + yield train_images[i % (len(train_images))], train_labels[i % ( + len(train_images))] + else: + for i in range(len(train_images)): + yield train_images[i], train_labels[i] + + train_file.close() + + return train_data + + +def test(trainer_id, inner_step, batch_size, count_by_step): + target_path = "trainer%d_data" % trainer_id + data_path = target_path + "/femnist_data" + tar_path = data_path + ".tar.gz" + if not os.path.exists(target_path): + os.system("mkdir trainer%d_data" % trainer_id) + if not os.path.exists(data_path): + print("Preparing data...") + if not os.path.exists(tar_path): + download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz", + tar_path) + extract(tar_path, target_path) + + def test_data(): + test_file = open( + "./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" + % (trainer_id, trainer_id), 'r') + json_test = json.load(test_file) + users = json_test["users"] + for user in users: + test_images = json_test['user_data'][user]['x'] + test_labels = json_test['user_data'][user]['y'] + for i in range(len(test_images)): + yield test_images[i], test_labels[i] + + test_file.close() + + return test_data diff --git a/paddle_fl/examples/ctr_demo/fl_master.py b/paddle_fl/examples/ctr_demo/fl_master.py index fcda61f6d0f051c81ceb39db96185207ac4e392e..57ac3812f93550e056c796db8283591ab39c1412 100644 --- a/paddle_fl/examples/ctr_demo/fl_master.py +++ b/paddle_fl/examples/ctr_demo/fl_master.py @@ -1,8 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle.fluid as fluid import paddle_fl as fl from paddle_fl.core.master.job_generator import JobGenerator from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory + class Model(object): def __init__(self): pass @@ -12,7 +27,8 @@ class Model(object): self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu') self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu') self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax') - self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label) + self.sum_cost = fluid.layers.cross_entropy( + input=self.predict, label=label) self.accuracy = fluid.layers.accuracy(input=self.predict, label=label) self.loss = fluid.layers.reduce_mean(self.sum_cost) self.startup_program = fluid.default_startup_program() @@ -34,8 +50,8 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.1) job_generator.set_optimizer(optimizer) job_generator.set_losses([model.loss]) job_generator.set_startup_program(model.startup_program) -job_generator.set_infer_feed_and_target_names( - [x.name for x in inputs], [model.predict.name]) +job_generator.set_infer_feed_and_target_names([x.name for x in inputs], + [model.predict.name]) build_strategy = FLStrategyFactory() build_strategy.fed_avg = True diff --git a/paddle_fl/examples/ctr_demo/fl_scheduler.py b/paddle_fl/examples/ctr_demo/fl_scheduler.py index 9dc5d84497b376d1aa7fd5731771b0799343c2ec..fe93b55d96b73734437b3387bdd04340248d7888 100644 --- a/paddle_fl/examples/ctr_demo/fl_scheduler.py +++ b/paddle_fl/examples/ctr_demo/fl_scheduler.py @@ -1,9 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.scheduler.agent_master import FLScheduler worker_num = 2 server_num = 1 # Define the number of worker/server and the port for scheduler -scheduler = FLScheduler(worker_num,server_num,port=9091) +scheduler = FLScheduler(worker_num, server_num, port=9091) scheduler.set_sample_worker_num(worker_num) scheduler.init_env() print("init env done.") diff --git a/paddle_fl/examples/ctr_demo/fl_server.py b/paddle_fl/examples/ctr_demo/fl_server.py index 529df8da4079fbbd217c58a857f7ab8a3c307586..2bc79fff528bc8e52e353cc56ca17618d6f4acca 100644 --- a/paddle_fl/examples/ctr_demo/fl_server.py +++ b/paddle_fl/examples/ctr_demo/fl_server.py @@ -21,8 +21,8 @@ server_id = 0 job_path = "fl_job_config" job = FLRunTimeJob() job.load_server_job(job_path, server_id) -job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler +job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler server.set_server_job(job) -server._current_ep = "127.0.0.1:8181" # IP address for server +server._current_ep = "127.0.0.1:8181" # IP address for server server.start() print("connect") diff --git a/paddle_fl/examples/ctr_demo/fl_trainer.py b/paddle_fl/examples/ctr_demo/fl_trainer.py index 0bcfb19544741eb1ef7158dcbf6136c2e34d1b9c..9b4b490c9abc54962da11d7f5d88077e2f558f33 100644 --- a/paddle_fl/examples/ctr_demo/fl_trainer.py +++ b/paddle_fl/examples/ctr_demo/fl_trainer.py @@ -1,10 +1,29 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory from paddle_fl.core.master.fl_job import FLRunTimeJob import numpy as np import sys import logging import time -logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) +logging.basicConfig( + filename="test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) def reader(): @@ -15,13 +34,14 @@ def reader(): data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64') yield data_dict -trainer_id = int(sys.argv[1]) # trainer id for each guest + +trainer_id = int(sys.argv[1]) # trainer id for each guest job_path = "fl_job_config" job = FLRunTimeJob() job.load_trainer_job(job_path, trainer_id) -job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer +job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer trainer = FLTrainerFactory().create_fl_trainer(job) -trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id) trainer.start() print(trainer._scheduler_ep, trainer._current_ep) output_folder = "fl_model" @@ -37,4 +57,3 @@ while not trainer.stop(): epoch_id += 1 if epoch_id % 5 == 0: trainer.save_inference_program(output_folder) - diff --git a/paddle_fl/examples/dpsgd_demo/fl_master.py b/paddle_fl/examples/dpsgd_demo/fl_master.py index f79472e8a7def3f3851d6b7eaf544a27c6eb8cc9..218f49ecb28b320db2df1f19751528d716bdb070 100644 --- a/paddle_fl/examples/dpsgd_demo/fl_master.py +++ b/paddle_fl/examples/dpsgd_demo/fl_master.py @@ -1,19 +1,39 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import paddle.fluid as fluid import paddle_fl as fl from paddle_fl.core.master.job_generator import JobGenerator from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory import math + class Model(object): def __init__(self): pass def lr_network(self): - self.inputs = fluid.layers.data(name='img', shape=[1, 28, 28], dtype="float32") - self.label = fluid.layers.data(name='label', shape=[1],dtype='int64') - self.predict = fluid.layers.fc(input=self.inputs, size=10, act='softmax') - self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=self.label) - self.accuracy = fluid.layers.accuracy(input=self.predict, label=self.label) + self.inputs = fluid.layers.data( + name='img', shape=[1, 28, 28], dtype="float32") + self.label = fluid.layers.data(name='label', shape=[1], dtype='int64') + self.predict = fluid.layers.fc(input=self.inputs, + size=10, + act='softmax') + self.sum_cost = fluid.layers.cross_entropy( + input=self.predict, label=self.label) + self.accuracy = fluid.layers.accuracy( + input=self.predict, label=self.label) self.loss = fluid.layers.mean(self.sum_cost) self.startup_program = fluid.default_startup_program() @@ -23,7 +43,7 @@ model.lr_network() STEP_EPSILON = 0.1 DELTA = 0.00001 -SIGMA = math.sqrt(2.0 * math.log(1.25/DELTA)) / STEP_EPSILON +SIGMA = math.sqrt(2.0 * math.log(1.25 / DELTA)) / STEP_EPSILON CLIP = 4.0 batch_size = 64 @@ -33,7 +53,8 @@ job_generator.set_optimizer(optimizer) job_generator.set_losses([model.loss]) job_generator.set_startup_program(model.startup_program) job_generator.set_infer_feed_and_target_names( - [model.inputs.name, model.label.name], [model.loss.name, model.accuracy.name]) + [model.inputs.name, model.label.name], + [model.loss.name, model.accuracy.name]) build_strategy = FLStrategyFactory() build_strategy.dpsgd = True diff --git a/paddle_fl/examples/dpsgd_demo/fl_scheduler.py b/paddle_fl/examples/dpsgd_demo/fl_scheduler.py index f8ea641e2fa356102ffc08eec179e84ca1993f7d..29bba5a610dde1253d35ce64d8776ba585b8d88f 100644 --- a/paddle_fl/examples/dpsgd_demo/fl_scheduler.py +++ b/paddle_fl/examples/dpsgd_demo/fl_scheduler.py @@ -1,9 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.scheduler.agent_master import FLScheduler worker_num = 4 server_num = 1 #Define number of worker/server and the port for scheduler -scheduler = FLScheduler(worker_num,server_num,port=9091) +scheduler = FLScheduler(worker_num, server_num, port=9091) scheduler.set_sample_worker_num(4) scheduler.init_env() print("init env done.") diff --git a/paddle_fl/examples/dpsgd_demo/fl_server.py b/paddle_fl/examples/dpsgd_demo/fl_server.py index 39056e82d99fb924f52c201e0fb230b6bc1626a1..3740982b54613169074c95e510809d55ed54121b 100644 --- a/paddle_fl/examples/dpsgd_demo/fl_server.py +++ b/paddle_fl/examples/dpsgd_demo/fl_server.py @@ -21,7 +21,7 @@ server_id = 0 job_path = "fl_job_config" job = FLRunTimeJob() job.load_server_job(job_path, server_id) -job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler +job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler server.set_server_job(job) -server._current_ep = "127.0.0.1:8181" # IP address for server +server._current_ep = "127.0.0.1:8181" # IP address for server server.start() diff --git a/paddle_fl/examples/dpsgd_demo/fl_trainer.py b/paddle_fl/examples/dpsgd_demo/fl_trainer.py index 074368194e20defdea2f1151c9a0f940ed613189..f0b4a8a188456ce654056f3d531f90aaca355053 100644 --- a/paddle_fl/examples/dpsgd_demo/fl_trainer.py +++ b/paddle_fl/examples/dpsgd_demo/fl_trainer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory from paddle_fl.core.master.fl_job import FLRunTimeJob import numpy @@ -7,44 +21,51 @@ import paddle.fluid as fluid import logging import math -logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) +logging.basicConfig( + filename="test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) -trainer_id = int(sys.argv[1]) # trainer id for each guest +trainer_id = int(sys.argv[1]) # trainer id for each guest job_path = "fl_job_config" job = FLRunTimeJob() job.load_trainer_job(job_path, trainer_id) -job._scheduler_ep = "127.0.0.1:9091" # Inform scheduler IP address to trainer +job._scheduler_ep = "127.0.0.1:9091" # Inform scheduler IP address to trainer trainer = FLTrainerFactory().create_fl_trainer(job) -trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id) trainer.start() test_program = trainer._main_program.clone(for_test=True) train_reader = paddle.batch( - paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500), - batch_size=64) -test_reader = paddle.batch( - paddle.dataset.mnist.test(), batch_size=64) + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=64) +test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64) img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace()) + def train_test(train_test_program, train_test_feed, train_test_reader): acc_set = [] for test_data in train_test_reader(): - acc_np = trainer.exe.run( - program=train_test_program, - feed=train_test_feed.feed(test_data), - fetch_list=["accuracy_0.tmp_0"]) + acc_np = trainer.exe.run(program=train_test_program, + feed=train_test_feed.feed(test_data), + fetch_list=["accuracy_0.tmp_0"]) acc_set.append(float(acc_np[0])) acc_val_mean = numpy.array(acc_set).mean() return acc_val_mean + def compute_privacy_budget(sample_ratio, epsilon, step, delta): E = 2 * epsilon * math.sqrt(step * sample_ratio) print("({0}, {1})-DP".format(E, delta)) + output_folder = "model_node%d" % trainer_id epoch_id = 0 step = 0 @@ -64,7 +85,8 @@ while not trainer.stop(): train_test_feed=feeder) print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val)) - compute_privacy_budget(sample_ratio=0.001, epsilon=0.1, step=step, delta=0.00001) + compute_privacy_budget( + sample_ratio=0.001, epsilon=0.1, step=step, delta=0.00001) save_dir = (output_folder + "/epoch_%d") % epoch_id trainer.save_inference_program(output_folder) diff --git a/paddle_fl/examples/femnist_demo/fl_master.py b/paddle_fl/examples/femnist_demo/fl_master.py index dd04165b917481dacf457bb6af49b7e976c28cdb..5102aa37009987010d37c7ab83b4c29a1cdcb7dc 100644 --- a/paddle_fl/examples/femnist_demo/fl_master.py +++ b/paddle_fl/examples/femnist_demo/fl_master.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle.fluid as fluid import paddle_fl as fl from paddle_fl.core.master.job_generator import JobGenerator @@ -9,14 +23,31 @@ class Model(object): pass def cnn(self): - self.inputs = fluid.layers.data(name='img', shape=[1, 28, 28], dtype="float32") - self.label = fluid.layers.data(name='label', shape=[1],dtype='int64') - self.conv_pool_1 = fluid.nets.simple_img_conv_pool(input=self.inputs,num_filters=20,filter_size=5,pool_size=2,pool_stride=2,act='relu') - self.conv_pool_2 = fluid.nets.simple_img_conv_pool(input=self.conv_pool_1,num_filters=50,filter_size=5,pool_size=2,pool_stride=2,act='relu') - - self.predict = self.predict = fluid.layers.fc(input=self.conv_pool_2, size=62, act='softmax') - self.cost = fluid.layers.cross_entropy(input=self.predict, label=self.label) - self.accuracy = fluid.layers.accuracy(input=self.predict, label=self.label) + self.inputs = fluid.layers.data( + name='img', shape=[1, 28, 28], dtype="float32") + self.label = fluid.layers.data(name='label', shape=[1], dtype='int64') + self.conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=self.inputs, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + self.conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=self.conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + self.predict = self.predict = fluid.layers.fc(input=self.conv_pool_2, + size=62, + act='softmax') + self.cost = fluid.layers.cross_entropy( + input=self.predict, label=self.label) + self.accuracy = fluid.layers.accuracy( + input=self.predict, label=self.label) self.loss = fluid.layers.mean(self.cost) self.startup_program = fluid.default_startup_program() @@ -30,8 +61,8 @@ job_generator.set_optimizer(optimizer) job_generator.set_losses([model.loss]) job_generator.set_startup_program(model.startup_program) job_generator.set_infer_feed_and_target_names( - [model.inputs.name, model.label.name], [model.loss.name, model.accuracy.name]) - + [model.inputs.name, model.label.name], + [model.loss.name, model.accuracy.name]) build_strategy = FLStrategyFactory() build_strategy.fed_avg = True diff --git a/paddle_fl/examples/femnist_demo/fl_scheduler.py b/paddle_fl/examples/femnist_demo/fl_scheduler.py index 346529fd6caadef06d1e46078fa873da264a7507..bc75fa66eae90aa487ecb4b9f7e7bf84c048f5d7 100644 --- a/paddle_fl/examples/femnist_demo/fl_scheduler.py +++ b/paddle_fl/examples/femnist_demo/fl_scheduler.py @@ -1,9 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.scheduler.agent_master import FLScheduler worker_num = 4 server_num = 1 # Define the number of worker/server and the port for scheduler -scheduler = FLScheduler(worker_num,server_num,port=9091) +scheduler = FLScheduler(worker_num, server_num, port=9091) scheduler.set_sample_worker_num(4) scheduler.init_env() print("init env done.") diff --git a/paddle_fl/examples/femnist_demo/fl_server.py b/paddle_fl/examples/femnist_demo/fl_server.py index cd558deb4ba577e5d0dbc74ecc90d0e22d278d8e..736b11f302b8653592db5b82a614985e4bdc7833 100644 --- a/paddle_fl/examples/femnist_demo/fl_server.py +++ b/paddle_fl/examples/femnist_demo/fl_server.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle_fl as fl import paddle.fluid as fluid from paddle_fl.core.server.fl_server import FLServer @@ -7,7 +21,7 @@ server_id = 0 job_path = "fl_job_config" job = FLRunTimeJob() job.load_server_job(job_path, server_id) -job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler +job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler server.set_server_job(job) -server._current_ep = "127.0.0.1:8181" # IP address for server +server._current_ep = "127.0.0.1:8181" # IP address for server server.start() diff --git a/paddle_fl/examples/femnist_demo/fl_trainer.py b/paddle_fl/examples/femnist_demo/fl_trainer.py index dce2c8af02ce3a3be58f0a6fe03eca18a882472a..03675a2472ca8f07364dee9b4510950379985a67 100644 --- a/paddle_fl/examples/femnist_demo/fl_trainer.py +++ b/paddle_fl/examples/femnist_demo/fl_trainer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory from paddle_fl.core.master.fl_job import FLRunTimeJob import paddle_fl.dataset.femnist @@ -8,16 +22,21 @@ import paddle.fluid as fluid import logging import math -logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) +logging.basicConfig( + filename="test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) -trainer_id = int(sys.argv[1]) # trainer id for each guest +trainer_id = int(sys.argv[1]) # trainer id for each guest job_path = "fl_job_config" job = FLRunTimeJob() job.load_trainer_job(job_path, trainer_id) -job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer +job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer print(job._target_names) trainer = FLTrainerFactory().create_fl_trainer(job) -trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id) trainer.start() print(trainer._step) test_program = trainer._main_program.clone(for_test=True) @@ -26,26 +45,26 @@ img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace()) + def train_test(train_test_program, train_test_feed, train_test_reader): - acc_set = [] - for test_data in train_test_reader(): - acc_np = trainer.exe.run( - program=train_test_program, - feed=train_test_feed.feed(test_data), - fetch_list=["accuracy_0.tmp_0"]) - acc_set.append(float(acc_np[0])) - acc_val_mean = numpy.array(acc_set).mean() - return acc_val_mean + acc_set = [] + for test_data in train_test_reader(): + acc_np = trainer.exe.run(program=train_test_program, + feed=train_test_feed.feed(test_data), + fetch_list=["accuracy_0.tmp_0"]) + acc_set.append(float(acc_np[0])) + acc_val_mean = numpy.array(acc_set).mean() + return acc_val_mean + epoch_id = 0 step = 0 epoch = 3000 count_by_step = False if count_by_step: - output_folder = "model_node%d" % trainer_id -else: - output_folder = "model_node%d_epoch" % trainer_id - + output_folder = "model_node%d" % trainer_id +else: + output_folder = "model_node%d_epoch" % trainer_id while not trainer.stop(): count = 0 @@ -55,24 +74,35 @@ while not trainer.stop(): print("epoch %d start train" % (epoch_id)) #train_data,test_data= data_generater(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step) train_reader = paddle.batch( - paddle.reader.shuffle(paddle_fl.dataset.femnist.train(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step), buf_size=500), + paddle.reader.shuffle( + paddle_fl.dataset.femnist.train( + trainer_id, + inner_step=trainer._step, + batch_size=64, + count_by_step=count_by_step), + buf_size=500), batch_size=64) test_reader = paddle.batch( - paddle_fl.dataset.femnist.test(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step), batch_size=64) - + paddle_fl.dataset.femnist.test( + trainer_id, + inner_step=trainer._step, + batch_size=64, + count_by_step=count_by_step), + batch_size=64) + if count_by_step: - for step_id, data in enumerate(train_reader()): + for step_id, data in enumerate(train_reader()): acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"]) step += 1 count += 1 print(count) - if count % trainer._step == 0: + if 
count % trainer._step == 0: break # print("acc:%.3f" % (acc[0])) else: - trainer.run_with_epoch(train_reader,feeder,fetch=["accuracy_0.tmp_0"],num_epoch=1) - + trainer.run_with_epoch( + train_reader, feeder, fetch=["accuracy_0.tmp_0"], num_epoch=1) acc_val = train_test( train_test_program=test_program, @@ -80,6 +110,6 @@ while not trainer.stop(): train_test_feed=feeder) print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val)) - if trainer_id == 0: + if trainer_id == 0: save_dir = (output_folder + "/epoch_%d") % epoch_id trainer.save_inference_program(output_folder) diff --git a/paddle_fl/examples/gru4rec_demo/fl_master.py b/paddle_fl/examples/gru4rec_demo/fl_master.py index 48433ecae963f121a26281a2d756b9514ad979b6..e5a20d971e0eaf9ab4c0fd96910a41a7227d9d77 100644 --- a/paddle_fl/examples/gru4rec_demo/fl_master.py +++ b/paddle_fl/examples/gru4rec_demo/fl_master.py @@ -1,8 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle.fluid as fluid import paddle_fl as fl from paddle_fl.core.master.job_generator import JobGenerator from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory + class Model(object): def __init__(self): pass @@ -34,7 +49,8 @@ class Model(object): size=hid_size * 3, param_attr=fluid.ParamAttr( initializer=fluid.initializer.Uniform( - low=init_low_bound, high=init_high_bound), + low=init_low_bound, + high=init_high_bound), learning_rate=gru_lr_x)) gru_h0 = fluid.layers.dynamic_gru( input=fc0, @@ -45,12 +61,13 @@ class Model(object): learning_rate=gru_lr_x)) self.fc = fluid.layers.fc(input=gru_h0, - size=vocab_size, - act='softmax', - param_attr=fluid.ParamAttr( - initializer=fluid.initializer.Uniform( - low=init_low_bound, high=init_high_bound), - learning_rate=fc_lr_x)) + size=vocab_size, + act='softmax', + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Uniform( + low=init_low_bound, + high=init_high_bound), + learning_rate=fc_lr_x)) cost = fluid.layers.cross_entropy( input=self.fc, label=self.dst_wordseq) self.acc = fluid.layers.accuracy( @@ -59,7 +76,6 @@ class Model(object): self.startup_program = fluid.default_startup_program() - model = Model() model.gru4rec_network() @@ -69,7 +85,8 @@ job_generator.set_optimizer(optimizer) job_generator.set_losses([model.loss]) job_generator.set_startup_program(model.startup_program) job_generator.set_infer_feed_and_target_names( - [model.src_wordseq.name, model.dst_wordseq.name], [model.loss.name, model.acc.name]) + [model.src_wordseq.name, model.dst_wordseq.name], + [model.loss.name, model.acc.name]) build_strategy = FLStrategyFactory() build_strategy.fed_avg = True diff --git a/paddle_fl/examples/gru4rec_demo/fl_scheduler.py b/paddle_fl/examples/gru4rec_demo/fl_scheduler.py index 346529fd6caadef06d1e46078fa873da264a7507..bc75fa66eae90aa487ecb4b9f7e7bf84c048f5d7 100644 --- a/paddle_fl/examples/gru4rec_demo/fl_scheduler.py +++ b/paddle_fl/examples/gru4rec_demo/fl_scheduler.py @@ -1,9 +1,23 @@ +# Copyright (c) 2020 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from paddle_fl.core.scheduler.agent_master import FLScheduler worker_num = 4 server_num = 1 # Define the number of worker/server and the port for scheduler -scheduler = FLScheduler(worker_num,server_num,port=9091) +scheduler = FLScheduler(worker_num, server_num, port=9091) scheduler.set_sample_worker_num(4) scheduler.init_env() print("init env done.") diff --git a/paddle_fl/examples/gru4rec_demo/fl_server.py b/paddle_fl/examples/gru4rec_demo/fl_server.py index 39056e82d99fb924f52c201e0fb230b6bc1626a1..3740982b54613169074c95e510809d55ed54121b 100644 --- a/paddle_fl/examples/gru4rec_demo/fl_server.py +++ b/paddle_fl/examples/gru4rec_demo/fl_server.py @@ -21,7 +21,7 @@ server_id = 0 job_path = "fl_job_config" job = FLRunTimeJob() job.load_server_job(job_path, server_id) -job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler +job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler server.set_server_job(job) -server._current_ep = "127.0.0.1:8181" # IP address for server +server._current_ep = "127.0.0.1:8181" # IP address for server server.start() diff --git a/paddle_fl/examples/gru4rec_demo/fl_trainer.py b/paddle_fl/examples/gru4rec_demo/fl_trainer.py index b8416f739d1b0dc7cb493768cf1c9283214778c4..1bf229a505aee6870c36556fb0b9b90b58067141 100644 --- a/paddle_fl/examples/gru4rec_demo/fl_trainer.py +++ b/paddle_fl/examples/gru4rec_demo/fl_trainer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory from paddle_fl.core.master.fl_job import FLRunTimeJob from paddle_fl.reader.gru4rec_reader import Gru4rec_Reader @@ -6,21 +20,26 @@ import numpy as np import sys import os import logging -logging.basicConfig(filename="test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) +logging.basicConfig( + filename="test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) -trainer_id = int(sys.argv[1]) # trainer id for each guest +trainer_id = int(sys.argv[1]) # trainer id for each guest place = fluid.CPUPlace() train_file_dir = "mid_data/node4/%d/" % trainer_id job_path = "fl_job_config" job = FLRunTimeJob() job.load_trainer_job(job_path, trainer_id) -job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer +job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer trainer = FLTrainerFactory().create_fl_trainer(job) -trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id) trainer.start() r = Gru4rec_Reader() -train_reader = r.reader(train_file_dir, place, batch_size = 125) +train_reader = r.reader(train_file_dir, place, batch_size=125) output_folder = "model_node4" step_i = 0 @@ -30,8 +49,7 @@ while not trainer.stop(): train_step = 0 for data in train_reader(): #print(np.array(data['src_wordseq'])) - ret_avg_cost = trainer.run(feed=data, - fetch=["mean_0.tmp_0"]) + ret_avg_cost = trainer.run(feed=data, fetch=["mean_0.tmp_0"]) train_step += 1 if train_step == trainer._step: break diff --git a/paddle_fl/examples/k8s_deployment/master.yaml b/paddle_fl/examples/k8s_deployment/master.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f7e3451da77ad3288ba92182a777e0aae67d6749 --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/master.yaml @@ -0,0 +1,209 @@ +apiVersion: v1 +kind: Service +metadata: + name: fl-master +spec: + type: LoadBalancer + ports: + - name: fl-master + port: 8000 + targetPort: 8000 + selector: + app: fl-master + +--- + +apiVersion: v1 +kind: Service +metadata: + name: fl-scheduler +spec: + type: LoadBalancer + ports: + - name: fl-scheduler + port: 9091 + targetPort: 9091 + selector: + app: fl-scheduler + +--- + +apiVersion: v1 +kind: Service +metadata: + name: fl-server +spec: + type: LoadBalancer + ports: + - name: fl-server + port: 8181 + targetPort: 8181 + selector: + app: fl-server + +--- + +apiVersion: v1 +kind: Service +metadata: + name: trainer0 +spec: + type: LoadBalancer + ports: + - name: trainer0 + port: 9000 + targetPort: 9000 + selector: + app: trainer0 + +--- + +apiVersion: v1 +kind: Service +metadata: + name: trainer1 +spec: + type: LoadBalancer + ports: + - name: trainer1 + port: 9001 + targetPort: 9001 + selector: + app: trainer1 + +--- + +apiVersion: apps/v1beta1 +kind: Deployment +metadata: + name: fl-master + labels: + app: fl-master +spec: + replicas: 1 + template: + metadata: + name: fl-master + labels: + app: fl-master + spec: + containers: + - name: fl-master + image: hub.baidubce.com/paddlefl/paddlefl:v3 + imagePullPolicy: Always + ports: + - containerPort: 8000 + workingDir: /root/k8s_deployment/master + command: ['/bin/bash'] + args: ['run_master.sh'] + +--- + +apiVersion: apps/v1beta1 +kind: Deployment +metadata: + name: fl-scheduler + labels: + app: fl-scheduler +spec: + replicas: 1 + template: + metadata: + 
name: fl-scheduler + labels: + app: fl-scheduler + spec: + containers: + - name: fl-scheduler + image: hub.baidubce.com/paddlefl/paddlefl:v3 + imagePullPolicy: Always + ports: + - containerPort: 9091 + workingDir: /root/k8s_deployment/scheduler + command: ['/bin/bash'] + args: ['run_scheduler.sh'] + +--- + +apiVersion: apps/v1beta1 +kind: Deployment +metadata: + name: fl-server + labels: + app: fl-server +spec: + replicas: 1 + template: + metadata: + name: fl-server + labels: + app: fl-server + spec: + containers: + - name: fl-server + image: hub.baidubce.com/paddlefl/paddlefl:v3 + imagePullPolicy: Always + ports: + - containerPort: 8181 + workingDir: /root/k8s_deployment/server + command: ['/bin/bash'] + args: ['run_server.sh'] + env: + - name: POD_IP + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: status.podIP +--- + +apiVersion: apps/v1beta1 +kind: Deployment +metadata: + name: trainer0 + labels: + app: trainer0 +spec: + replicas: 1 + template: + metadata: + name: trainer0 + labels: + app: trainer0 + spec: + containers: + - name: trainer0 + image: hub.baidubce.com/paddlefl/paddlefl:v3 + imagePullPolicy: Always + ports: + - containerPort: 9000 + workingDir: /root/k8s_deployment/trainer0 + command: ['/bin/bash'] + args: ['test_trainer.sh'] + +--- + +apiVersion: apps/v1beta1 +kind: Deployment +metadata: + name: trainer1 + labels: + app: trainer1 +spec: + replicas: 1 + template: + metadata: + name: trainer1 + labels: + app: trainer1 + spec: + containers: + - name: trainer1 + image: hub.baidubce.com/paddlefl/paddlefl:v3 + imagePullPolicy: Always + ports: + - containerPort: 9001 + workingDir: /root/k8s_deployment/trainer1 + command: ['/bin/bash'] + args: ['test_trainer.sh'] + +--- diff --git a/paddle_fl/examples/k8s_deployment/master/fl_master.py b/paddle_fl/examples/k8s_deployment/master/fl_master.py new file mode 100644 index 0000000000000000000000000000000000000000..96aab47ae7014d3753d872e67fca1a0b593629de --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/master/fl_master.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import paddle.fluid as fluid +import os +import paddle_fl as fl +from paddle_fl.core.master.job_generator import JobGenerator +from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory + + +def parse_args(): + parser = argparse.ArgumentParser(description="master") + parser.add_argument( + '--trainer_num', + type=int, + default=2, + help='number of trainers (default: 2)') + + return parser.parse_args() + + +class Model(object): + def __init__(self): + pass + + def mlp(self, inputs, label, hidden_size=128): + self.concat = fluid.layers.concat(inputs, axis=1) + self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu') + self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu') + self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax') + self.sum_cost = fluid.layers.cross_entropy( + input=self.predict, label=label) + self.accuracy = fluid.layers.accuracy(input=self.predict, label=label) + self.loss = fluid.layers.reduce_mean(self.sum_cost) + self.startup_program = fluid.default_startup_program() + +inputs = [fluid.layers.data( \ + name=str(slot_id), shape=[5], + dtype="float32") + for slot_id in range(3)] +label = fluid.layers.data( \ + name="label", + shape=[1], + dtype='int64') + +model = Model() +model.mlp(inputs, label) + +job_generator = JobGenerator() +optimizer = fluid.optimizer.SGD(learning_rate=0.1) +job_generator.set_optimizer(optimizer) +job_generator.set_losses([model.loss]) +job_generator.set_startup_program(model.startup_program) +job_generator.set_infer_feed_and_target_names([x.name for x in inputs], + [model.predict.name]) + +build_strategy = FLStrategyFactory() +build_strategy.fed_avg = True +build_strategy.inner_step = 10 +strategy = build_strategy.create_fl_strategy() + +# endpoints will be collected through the cluster +# in this example, we suppose endpoints have been collected +server_service_ip = os.environ['FL_SERVER_SERVICE_HOST'] + ":" + os.environ[ + 'FL_SERVER_SERVICE_PORT_FL_SERVER'] +service_endpoints = [server_service_ip] +pod_endpoints = ["0.0.0.0:8181"] +output = "fl_job_config" +args = parse_args() +num_trainer = args.trainer_num +#job_generator.generate_fl_job( +# strategy, server_endpoints=endpoints, worker_num=num_trainer, output=output) +# fl_job_config will be dispatched to workers + +job_generator.generate_fl_job_for_k8s( + strategy, + server_pod_endpoints=pod_endpoints, + server_service_endpoints=service_endpoints, + worker_num=2, + output=output) diff --git a/paddle_fl/examples/k8s_deployment/master/run_master.sh b/paddle_fl/examples/k8s_deployment/master/run_master.sh new file mode 100644 index 0000000000000000000000000000000000000000..b07187eb3078e3ce26e8f3234a4933b499557f4e --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/master/run_master.sh @@ -0,0 +1,3 @@ +python fl_master.py --trainer_num 2 +tar -zcvf fl_job_config.tar.gz fl_job_config +python -m SimpleHTTPServer 8000 diff --git a/paddle_fl/examples/k8s_deployment/scheduler/fl_scheduler.py b/paddle_fl/examples/k8s_deployment/scheduler/fl_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..232825b3c3d463f43750df0cfa698119e993118b --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/scheduler/fl_scheduler.py @@ -0,0 +1,39 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from paddle_fl.core.scheduler.agent_master import FLScheduler + + +def parse_args(): + parser = argparse.ArgumentParser(description="scheduler") + parser.add_argument( + '--trainer_num', + type=int, + default=2, + help='number of trainers (default: 2)') + + return parser.parse_args() + + +args = parse_args() +num_trainer = args.trainer_num +worker_num = num_trainer +server_num = 1 +# Define the number of worker/server and the port for scheduler +scheduler = FLScheduler(worker_num, server_num, port=9091) +scheduler.set_sample_worker_num(worker_num) +scheduler.init_env() +print("init env done.") +scheduler.start_fl_training() diff --git a/paddle_fl/examples/k8s_deployment/scheduler/run_scheduler.sh b/paddle_fl/examples/k8s_deployment/scheduler/run_scheduler.sh new file mode 100644 index 0000000000000000000000000000000000000000..13494e2c365b64a76f488e4b6a595ba520803171 --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/scheduler/run_scheduler.sh @@ -0,0 +1,3 @@ +python fl_scheduler.py --trainer_num 2 + + diff --git a/paddle_fl/examples/k8s_deployment/server/fl_server.py b/paddle_fl/examples/k8s_deployment/server/fl_server.py new file mode 100644 index 0000000000000000000000000000000000000000..197b2f9d4f580a96b979aac8de02a70a8f65d7ce --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/server/fl_server.py @@ -0,0 +1,34 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import paddle_fl as fl +import os +import paddle.fluid as fluid +from paddle_fl.core.server.fl_server import FLServer +from paddle_fl.core.master.fl_job import FLRunTimeJob +import time +server = FLServer() +server_id = 0 +job_path = "fl_job_config" +job = FLRunTimeJob() +job.load_server_job(job_path, server_id) +job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[ + 'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER'] # IP address for scheduler +#job._endpoints = os.environ['POD_IP'] + ":" + os.environ['FL_SERVER_SERVICE_PORT_FL_SERVER'] # IP address for server +server.set_server_job(job) +server._current_ep = os.environ['FL_SERVER_SERVICE_HOST'] + ":" + os.environ[ + 'FL_SERVER_SERVICE_PORT_FL_SERVER'] # IP address for server +print(job._scheduler_ep, server._current_ep) +server.start() +print("connect") diff --git a/paddle_fl/examples/k8s_deployment/server/run_server.sh b/paddle_fl/examples/k8s_deployment/server/run_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..203d96c07c197b28efa282efd224443dd7bd3b7b --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/server/run_server.sh @@ -0,0 +1,9 @@ +export GLOG_v=3 +wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz +while [ $? -ne 0 ] +do + sleep 3 + wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz +done +tar -xf fl_job_config.tar.gz +python -u fl_server.py > server.log 2>&1 diff --git a/paddle_fl/examples/k8s_deployment/trainer0/fl_trainer.py b/paddle_fl/examples/k8s_deployment/trainer0/fl_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4cbce925ef6c401731aab8d5bd5dcc36610347e5 --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/trainer0/fl_trainer.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory +from paddle_fl.core.master.fl_job import FLRunTimeJob +import numpy as np +import sys +import os +import logging +import time +logging.basicConfig( + filename="test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) + + +def reader(): + for i in range(1000): + data_dict = {} + for i in range(3): + data_dict[str(i)] = np.random.rand(1, 5).astype('float32') + data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64') + yield data_dict + + +trainer_id = int(sys.argv[1]) # trainer id for each guest +job_path = "fl_job_config" +job = FLRunTimeJob() +job.load_trainer_job(job_path, trainer_id) +#job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer +job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[ + 'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER'] +trainer = FLTrainerFactory().create_fl_trainer(job) +#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = os.environ['TRAINER0_SERVICE_HOST'] + ":" + os.environ[ + 'TRAINER0_SERVICE_PORT_TRAINER0'] +trainer.start() +print(trainer._scheduler_ep, trainer._current_ep) +output_folder = "fl_model" +epoch_id = 0 +while not trainer.stop(): + print("batch %d start train" % (epoch_id)) + train_step = 0 + for data in reader(): + trainer.run(feed=data, fetch=[]) + train_step += 1 + if train_step == trainer._step: + break + epoch_id += 1 + if epoch_id % 5 == 0: + trainer.save_inference_program(output_folder) diff --git a/paddle_fl/examples/k8s_deployment/trainer0/test_trainer.sh b/paddle_fl/examples/k8s_deployment/trainer0/test_trainer.sh new file mode 100644 index 0000000000000000000000000000000000000000..cfcedcf8342b744d874bc3ad5df2248a4d25b569 --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/trainer0/test_trainer.sh @@ -0,0 +1,9 @@ +wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz +while [ $? -ne 0 ] +do + sleep 3 + wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz +done +tar -xf fl_job_config.tar.gz +sleep 10 +python -u fl_trainer.py 0 diff --git a/paddle_fl/examples/k8s_deployment/trainer1/fl_trainer.py b/paddle_fl/examples/k8s_deployment/trainer1/fl_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed5dd099cce7e85b1ddb47857551b84d67c78c6 --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/trainer1/fl_trainer.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory +from paddle_fl.core.master.fl_job import FLRunTimeJob +import numpy as np +import sys +import os +import logging +import time +logging.basicConfig( + filename="test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) + + +def reader(): + for i in range(1000): + data_dict = {} + for i in range(3): + data_dict[str(i)] = np.random.rand(1, 5).astype('float32') + data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64') + yield data_dict + + +trainer_id = int(sys.argv[1]) # trainer id for each guest +job_path = "fl_job_config" +job = FLRunTimeJob() +job.load_trainer_job(job_path, trainer_id) +#job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer +job._scheduler_ep = os.environ['FL_SCHEDULER_SERVICE_HOST'] + ":" + os.environ[ + 'FL_SCHEDULER_SERVICE_PORT_FL_SCHEDULER'] +trainer = FLTrainerFactory().create_fl_trainer(job) +#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = os.environ['TRAINER1_SERVICE_HOST'] + ":" + os.environ[ + 'TRAINER1_SERVICE_PORT_TRAINER1'] +trainer.start() +print(trainer._scheduler_ep, trainer._current_ep) +output_folder = "fl_model" +epoch_id = 0 +while not trainer.stop(): + print("batch %d start train" % (epoch_id)) + train_step = 0 + for data in reader(): + trainer.run(feed=data, fetch=[]) + train_step += 1 + if train_step == trainer._step: + break + epoch_id += 1 + if epoch_id % 5 == 0: + trainer.save_inference_program(output_folder) diff --git a/paddle_fl/examples/k8s_deployment/trainer1/test_trainer.sh b/paddle_fl/examples/k8s_deployment/trainer1/test_trainer.sh new file mode 100644 index 0000000000000000000000000000000000000000..82dd83ac826e9a00d8b7218f869463ddb677f851 --- /dev/null +++ b/paddle_fl/examples/k8s_deployment/trainer1/test_trainer.sh @@ -0,0 +1,9 @@ +wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz +while [ $? -ne 0 ] +do + sleep 3 + wget ${FL_MASTER_SERVICE_HOST}:${FL_MASTER_SERVICE_PORT_FL_MASTER}/fl_job_config.tar.gz +done +tar -xf fl_job_config.tar.gz +sleep 10 +python -u fl_trainer.py 1 diff --git a/paddle_fl/examples/secagg_demo/fl_master.py b/paddle_fl/examples/secagg_demo/fl_master.py index e0e19e6fc77392a0d628d287bf53295977ab7e81..c6245d5e08ec36d7c41cd8c97c1485ddead68082 100644 --- a/paddle_fl/examples/secagg_demo/fl_master.py +++ b/paddle_fl/examples/secagg_demo/fl_master.py @@ -1,8 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import paddle.fluid as fluid import paddle_fl as fl from paddle_fl.core.master.job_generator import JobGenerator from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory + class Model(object): def __init__(self): pass @@ -14,12 +29,17 @@ class Model(object): param_attrs = fluid.ParamAttr( name="fc_0.w_0", initializer=fluid.initializer.ConstantInitializer(0.0)) - self.predict = fluid.layers.fc(input=inputs, size=10, act='softmax', param_attr=param_attrs) - self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label) + self.predict = fluid.layers.fc(input=inputs, + size=10, + act='softmax', + param_attr=param_attrs) + self.sum_cost = fluid.layers.cross_entropy( + input=self.predict, label=label) self.loss = fluid.layers.mean(self.sum_cost) self.accuracy = fluid.layers.accuracy(input=self.predict, label=label) self.startup_program = fluid.default_startup_program() + inputs = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32') label = fluid.layers.data(name='y', shape=[1], dtype='int64') @@ -31,15 +51,16 @@ optimizer = fluid.optimizer.SGD(learning_rate=0.01) job_generator.set_optimizer(optimizer) job_generator.set_losses([model.loss]) job_generator.set_startup_program(model.startup_program) -job_generator.set_infer_feed_and_target_names( - [inputs.name, label.name], [model.loss.name]) +job_generator.set_infer_feed_and_target_names([inputs.name, label.name], + [model.loss.name]) build_strategy = FLStrategyFactory() #build_strategy.fed_avg = True build_strategy.sec_agg = True param_name_list = [] -param_name_list.append("fc_0.w_0.opti.trainer_") # need trainer_id when running +param_name_list.append( + "fc_0.w_0.opti.trainer_") # need trainer_id when running param_name_list.append("fc_0.b_0.opti.trainer_") build_strategy.param_name_list = param_name_list diff --git a/paddle_fl/examples/secagg_demo/fl_scheduler.py b/paddle_fl/examples/secagg_demo/fl_scheduler.py index 649f74768ef76cf1f8af52319b3818144f37e98a..1d8c21b8307ce8b1d20f2ecea1557bc2fb56e0c1 100644 --- a/paddle_fl/examples/secagg_demo/fl_scheduler.py +++ b/paddle_fl/examples/secagg_demo/fl_scheduler.py @@ -1,9 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.scheduler.agent_master import FLScheduler worker_num = 2 server_num = 1 -scheduler = FLScheduler(worker_num,server_num,port=9091) +scheduler = FLScheduler(worker_num, server_num, port=9091) scheduler.set_sample_worker_num(worker_num) scheduler.init_env() print("init env done.") diff --git a/paddle_fl/examples/secagg_demo/fl_server.py b/paddle_fl/examples/secagg_demo/fl_server.py index 529df8da4079fbbd217c58a857f7ab8a3c307586..2bc79fff528bc8e52e353cc56ca17618d6f4acca 100644 --- a/paddle_fl/examples/secagg_demo/fl_server.py +++ b/paddle_fl/examples/secagg_demo/fl_server.py @@ -21,8 +21,8 @@ server_id = 0 job_path = "fl_job_config" job = FLRunTimeJob() job.load_server_job(job_path, server_id) -job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler +job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler server.set_server_job(job) -server._current_ep = "127.0.0.1:8181" # IP address for server +server._current_ep = "127.0.0.1:8181" # IP address for server server.start() print("connect") diff --git a/paddle_fl/examples/secagg_demo/fl_trainer.py b/paddle_fl/examples/secagg_demo/fl_trainer.py index 7b08e8724a67219654c7f94809fc281353202ad2..99d0a2f465b4df6bd4fcd6c47fbea4cc19f9f901 100644 --- a/paddle_fl/examples/secagg_demo/fl_trainer.py +++ b/paddle_fl/examples/secagg_demo/fl_trainer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory from paddle_fl.core.master.fl_job import FLRunTimeJob import numpy @@ -11,27 +25,32 @@ import math import hashlib import hmac -logging.basicConfig(filename="log/test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG) +logging.basicConfig( + filename="log/test.log", + filemode="w", + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%d-%M-%Y %H:%M:%S", + level=logging.DEBUG) logger = logging.getLogger("FLTrainer") BATCH_SIZE = 64 train_reader = paddle.batch( - paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500), + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), batch_size=BATCH_SIZE) -test_reader = paddle.batch( - paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) +test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) trainer_num = 2 -trainer_id = int(sys.argv[1]) # trainer id for each guest +trainer_id = int(sys.argv[1]) # trainer id for each guest job_path = "fl_job_config" job = FLRunTimeJob() job.load_trainer_job(job_path, trainer_id) -job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer +job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer trainer = FLTrainerFactory().create_fl_trainer(job) trainer.trainer_id = trainer_id -trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) +trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id) trainer.trainer_num = trainer_num trainer.key_dir = "./keys/" trainer.start() @@ -47,8 +66,8 @@ feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace()) # for test test_program = trainer._main_program.clone(for_test=True) -def train_test(train_test_program, - train_test_feed, train_test_reader): + +def train_test(train_test_program, train_test_feed, train_test_reader): acc_set = [] avg_loss_set = [] for test_data in train_test_reader(): @@ -61,6 +80,8 @@ def train_test(train_test_program, acc_val_mean = numpy.array(acc_set).mean() avg_loss_val_mean = numpy.array(avg_loss_set).mean() return avg_loss_val_mean, acc_val_mean + + # for test while not trainer.stop(): @@ -71,15 +92,18 @@ while not trainer.stop(): step_i += 1 trainer.step_id = step_i accuracy, = trainer.run(feed=feeder.feed(data), - fetch=["accuracy_0.tmp_0"]) + fetch=["accuracy_0.tmp_0"]) if step_i % 100 == 0: - print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0])) + print("Epoch: {0}, step: {1}, accuracy: {2}".format( + epoch_id, step_i, accuracy[0])) print(step_i) - avg_loss_val, acc_val = train_test(train_test_program=test_program, - train_test_reader=test_reader, - train_test_feed=feeder) - print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val)) + avg_loss_val, acc_val = train_test( + train_test_program=test_program, + train_test_reader=test_reader, + train_test_feed=feeder) + print("Test with Epoch %d, avg_cost: %s, acc: %s" % + (epoch_id, avg_loss_val, acc_val)) if epoch_id > 40: break diff --git a/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt b/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt index cfb5684de948c423374fd2a820ccb29c654c8ac2..d33390b5dfff21882d7de124710389ede6b5b8ef 100644 --- a/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt +++ b/paddle_fl/examples/secagg_demo/keys/0_pub_key.txt @@ -1 +1 @@ 
-2438748580808349511143047663636683775879288034436941526695550498623461587527621346172907651006831789701999970929529915459467532662545948308044143788306668377086821294492459623439935894424167712515718436351900091777957477710004777078638317806960364609629258387413979203403741893205419691425902518810085451041187685334971769087054033027561974230347468587825700834108657561999305482311897914109364221430821533207693682979777541616125499682380618775029176238891407643926372043660610226672413497764635239787000143341827693941253721638947580506197728500367325524850325027980531066702962726949006217630290236644410746181942256812170056772600756232506116738493114591218127323741133163913140583529684827066023347088796194846253682954154504336640429027403657831470993825621749318372332546269811820953216261135662418531598954663771775691648448615131802158937156803324423733802071166119966224716088242291968098450309032800335049617861465 \ No newline at end of file +2438748580808349511143047663636683775879288034436941526695550498623461587527621346172907651006831789701999970929529915459467532662545948308044143788306668377086821294492459623439935894424167712515718436351900091777957477710004777078638317806960364609629258387413979203403741893205419691425902518810085451041187685334971769087054033027561974230347468587825700834108657561999305482311897914109364221430821533207693682979777541616125499682380618775029176238891407643926372043660610226672413497764635239787000143341827693941253721638947580506197728500367325524850325027980531066702962726949006217630290236644410746181942256812170056772600756232506116738493114591218127323741133163913140583529684827066023347088796194846253682954154504336640429027403657831470993825621749318372332546269811820953216261135662418531598954663771775691648448615131802158937156803324423733802071166119966224716088242291968098450309032800335049617861465 diff --git a/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt b/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt index a88733c160ba5052b4d0b445a477357cbb88417a..32ecdc45c94580031a6f62573995fdef617f898f 100644 --- a/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt +++ b/paddle_fl/examples/secagg_demo/keys/1_pub_key.txt @@ -1 +1 @@ -2514645349791449916465355128335954929464444612258498884322250411584328344530925790221013632799576102047787468232441470392901627580493383471719532612816534848391408601920539266665550346246343040368608757429591392784807798812848893304745441721044204602414415725300562075953290154457726382683020925406187524758708694161689285261293920782115270550960717687322240202298426363733065475031448888618026711435993053991653780694485897784243346859087028197560255857091562150381619708471192080403868115055265681423866891707972356449212236920210992128063075734725292359699578940877165584175851667803049667020005978253573912096358469050409541237248593428640028573437783747958382446666712845099931003578559688134007238753677011257181086592677636834099341020870502521085827878680362572013623469761170961943916356175726242515624843837354222489536469472365937707615959657846428001149463801728084949088483783942894784080451399167982957006474558 \ No newline at end of file 
+2514645349791449916465355128335954929464444612258498884322250411584328344530925790221013632799576102047787468232441470392901627580493383471719532612816534848391408601920539266665550346246343040368608757429591392784807798812848893304745441721044204602414415725300562075953290154457726382683020925406187524758708694161689285261293920782115270550960717687322240202298426363733065475031448888618026711435993053991653780694485897784243346859087028197560255857091562150381619708471192080403868115055265681423866891707972356449212236920210992128063075734725292359699578940877165584175851667803049667020005978253573912096358469050409541237248593428640028573437783747958382446666712845099931003578559688134007238753677011257181086592677636834099341020870502521085827878680362572013623469761170961943916356175726242515624843837354222489536469472365937707615959657846428001149463801728084949088483783942894784080451399167982957006474558 diff --git a/paddle_fl/examples/submitter_demo/conf.txt b/paddle_fl/examples/submitter_demo/conf.txt index 0880739e809063bc21531a92bb205ee2df4f169a..f2f0d48d028de7afb9fb8ecd75961344bdbaf002 100644 --- a/paddle_fl/examples/submitter_demo/conf.txt +++ b/paddle_fl/examples/submitter_demo/conf.txt @@ -21,4 +21,3 @@ server=yq01-hpc-lvliang01-smart-master.dmop.baidu.com python_tar=./python.tar.gz wheel=./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64.whl - diff --git a/paddle_fl/examples/submitter_demo/model.py b/paddle_fl/examples/submitter_demo/model.py index f07549b9048c7b09a7fe25f9961db64fc8a820bb..2ead34a71835492df47e05f1d369e271433882f7 100644 --- a/paddle_fl/examples/submitter_demo/model.py +++ b/paddle_fl/examples/submitter_demo/model.py @@ -1,5 +1,20 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle.fluid as fluid + class Model(object): def __init__(self): pass @@ -9,8 +24,8 @@ class Model(object): self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu') self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu') self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax') - self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label) + self.sum_cost = fluid.layers.cross_entropy( + input=self.predict, label=label) self.accuracy = fluid.layers.accuracy(input=self.predict, label=label) self.loss = fluid.layers.reduce_mean(self.sum_cost) self.startup_program = fluid.default_startup_program() - diff --git a/paddle_fl/examples/submitter_demo/scheduler_client.py b/paddle_fl/examples/submitter_demo/scheduler_client.py index eff68df845129feaca093b4277f1970dcfb911e9..2e903353585c2ec418614be1a373afcec72395aa 100644 --- a/paddle_fl/examples/submitter_demo/scheduler_client.py +++ b/paddle_fl/examples/submitter_demo/scheduler_client.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import socket import random @@ -49,6 +63,7 @@ default_dict = { "wheel": "./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64-0.whl" } + def load_conf(conf_file, local_dict): with open(conf_file) as fin: for line in fin: @@ -58,6 +73,7 @@ def load_conf(conf_file, local_dict): local_dict[group[0]] = group[1] return local_dict + client = HPCClient() default_dict = load_conf(sys.argv[1], default_dict) @@ -94,9 +110,11 @@ all_ips_ready = False ip_list = [] -scheduler = FLScheduler(int(default_dict["worker_nodes"]), - int(default_dict["server_nodes"]), - port=random_port, socket=zmq_socket) +scheduler = FLScheduler( + int(default_dict["worker_nodes"]), + int(default_dict["server_nodes"]), + port=random_port, + socket=zmq_socket) scheduler.set_sample_worker_num(int(default_dict["worker_nodes"])) @@ -121,12 +139,14 @@ print(ip_list) #allocate the role of each endpoint and their ids ip_role = {} for i in range(len(ip_list)): - if i < int(default_dict["server_nodes"]): - ip_role[ip_list[i]] = 'server%d' % i - else: - ip_role[ip_list[i]] = 'trainer%d' % (i-int(default_dict["server_nodes"])) + if i < int(default_dict["server_nodes"]): + ip_role[ip_list[i]] = 'server%d' % i + else: + ip_role[ip_list[i]] = 'trainer%d' % ( + i - int(default_dict["server_nodes"])) print(ip_role) + def job_generate(): #generate a fl job which is the same as fl_master inputs = [fluid.layers.data( \ @@ -146,8 +166,8 @@ def job_generate(): job_generator.set_optimizer(optimizer) job_generator.set_losses([model.loss]) job_generator.set_startup_program(model.startup_program) - job_generator.set_infer_feed_and_target_names( - [x.name for x in inputs], [model.predict.name]) + job_generator.set_infer_feed_and_target_names([x.name for x in inputs], + [model.predict.name]) build_strategy = FLStrategyFactory() build_strategy.fed_avg = True @@ -157,20 +177,24 @@ def job_generate(): # endpoints will be collected through the cluster # in this example, we suppose endpoints have been collected server_ip = ["{}".format(ip_list[0])] - + output = "job_config" job_generator.generate_fl_job( - strategy, server_endpoints=server_ip, worker_num=int(default_dict["worker_nodes"]), output=output) - + strategy, + server_endpoints=server_ip, + worker_num=int(default_dict["worker_nodes"]), + output=output) + file_list = os.listdir(output) for file in file_list: - tar = tarfile.open('{}/{}.tar.gz'.format(output,file),'w:gz') - for root,dir,files in os.walk("{}/{}".format(output,file)): - for f in files: - fullpath = os.path.join(root,f) - tar.add(fullpath) + tar = tarfile.open('{}/{}.tar.gz'.format(output, file), 'w:gz') + for root, dir, files in os.walk("{}/{}".format(output, file)): + for f in files: + fullpath = os.path.join(root, f) + tar.add(fullpath) tar.close() + job_generate() #send the allocated rolls to the remote endpoints diff --git a/paddle_fl/examples/submitter_demo/train_program.py b/paddle_fl/examples/submitter_demo/train_program.py index a5f10fab0eb2fe8aae506f71f4961c5a5ea5cdfa..bba1482df0ed198a348e67e51ba79e60a9bc3a4b 100644 --- a/paddle_fl/examples/submitter_demo/train_program.py +++ 
b/paddle_fl/examples/submitter_demo/train_program.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import socket import random import zmq @@ -13,7 +27,6 @@ import sys import logging import time - random_port = 60001 scheduler_conf = {} @@ -31,8 +44,7 @@ download_url = "{}:8080".format(scheduler_ip[0]) print(download_url) context = zmq.Context() zmq_socket = context.socket(zmq.REQ) -zmq_socket.connect( - "tcp://{}".format(scheduler_conf["ENDPOINT"])) +zmq_socket.connect("tcp://{}".format(scheduler_conf["ENDPOINT"])) zmq_socket.send("ENDPOINT\t{}".format(endpoint)) message = zmq_socket.recv() print(message) @@ -47,7 +59,7 @@ while True: if group[0] == "WAIT": continue else: - os.system("wget {}/job_config/{}.tar.gz".format(download_url,message)) + os.system("wget {}/job_config/{}.tar.gz".format(download_url, message)) print(message) break @@ -71,6 +83,7 @@ if 'server' in message: server._current_ep = endpoint server.start() else: + def reader(): for i in range(1000): data_dict = {} @@ -96,7 +109,7 @@ else: for data in reader(): trainer.run(feed=data, fetch=[]) step_i += 1 - if step_i == trainer._step: + if step_i == trainer._step: break epoch_id += 1 if epoch_id % 5 == 0: diff --git a/paddle_fl/reader/gru4rec_reader.py b/paddle_fl/reader/gru4rec_reader.py index dadb73136bccde03377c5c1136e30564e00dcb81..6e95909da989e73dc4fcc9c539967cca8b8f2024 100644 --- a/paddle_fl/reader/gru4rec_reader.py +++ b/paddle_fl/reader/gru4rec_reader.py @@ -1,7 +1,22 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle.fluid as fluid import numpy as np import os + class Gru4rec_Reader: def __init__(self): pass @@ -21,7 +36,6 @@ class Gru4rec_Reader: res.set_lod([lod]) return res - def lod_reader(self, reader, place): def feed_reader(): for data in reader(): @@ -33,12 +47,14 @@ class Gru4rec_Reader: fe_data["src_wordseq"] = lod_src_wordseq fe_data["dst_wordseq"] = lod_dst_wordseq yield fe_data + return feed_reader def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False): """ Create a batched reader. 
""" + def batch_reader(): r = reader() b = [] @@ -66,11 +82,11 @@ class Gru4rec_Reader: # Batch size check batch_size = int(batch_size) if batch_size <= 0: - raise ValueError("batch_size should be a positive integeral value, " - "but got batch_size={}".format(batch_size)) + raise ValueError( + "batch_size should be a positive integeral value, " + "but got batch_size={}".format(batch_size)) return batch_reader - def reader_creator(self, file_dir): def reader(): files = os.listdir(file_dir) @@ -82,10 +98,12 @@ class Gru4rec_Reader: src_seq = l[:len(l) - 1] trg_seq = l[1:] yield src_seq, trg_seq + return reader def reader(self, file_dir, place, batch_size=5): """ prepare the English Pann Treebank (PTB) data """ print("start constuct word dict") - reader = self.sort_batch(self.reader_creator(file_dir), batch_size, batch_size * 20) + reader = self.sort_batch( + self.reader_creator(file_dir), batch_size, batch_size * 20) return self.lod_reader(reader, place) diff --git a/paddle_fl/version.py b/paddle_fl/version.py index 2f9278ce368bcbfc0d65b147d84fcd2938414638..2c95b454cebb44b8068493b4bc3002ee09f7df27 100644 --- a/paddle_fl/version.py +++ b/paddle_fl/version.py @@ -12,6 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PaddleFL version string """ -fl_version = "0.1.10" -module_proto_version = "0.1.10" - +fl_version = "0.1.11" +module_proto_version = "0.1.11" diff --git a/setup.py b/setup.py index ba29a1496d39949b84f3b74eeec935d831e57c01..c12f11edd4c8ecaa95034ca3266e9ad363bf30f4 100644 --- a/setup.py +++ b/setup.py @@ -26,10 +26,12 @@ from paddle_fl.version import fl_version def python_version(): return [int(v) for v in platform.python_version().split(".")] + max_version, mid_version, min_version = python_version() REQUIRED_PACKAGES = [ - 'six >= 1.10.0', 'protobuf >= 3.1.0','paddlepaddle >= 1.6', 'zmq', 'paddlepaddle-gpu >= 1.6' + 'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.6', 'zmq', + 'paddlepaddle-gpu >= 1.6' ] if max_version < 3: @@ -42,8 +44,7 @@ REQUIRED_PACKAGES += ["unittest2"] setup( name='paddle_fl', version=fl_version.replace('-', ''), - description= - ('Federated Deep Learning Package Based on PaddlePaddle.'), + description=('Federated Deep Learning Package Based on PaddlePaddle.'), long_description='', url='https://github.com/PaddlePaddle/PaddleFL', author='PaddlePaddle Author', @@ -70,4 +71,5 @@ setup( 'Topic :: Software Development :: Libraries :: Python Modules', ], license='Apache 2.0', - keywords=('paddle_fl paddlepaddle multi-task transfer distributed-training')) + keywords=( + 'paddle_fl paddlepaddle multi-task transfer distributed-training')) diff --git a/tools/codestyle/clang_format.hook b/tools/codestyle/clang_format.hook new file mode 100755 index 0000000000000000000000000000000000000000..1d928216867c0ba3897d71542fea44debf8d72a0 --- /dev/null +++ b/tools/codestyle/clang_format.hook @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +readonly VERSION="3.8" + +version=$(clang-format -version) + +if ! [[ $version == *"$VERSION"* ]]; then + echo "clang-format version check failed." 
+ echo "a version contains '$VERSION' is needed, but get '$version'" + echo "you can install the right version, and make an soft-link to '\$PATH' env" + exit -1 +fi + +clang-format $@ diff --git a/tools/codestyle/copyright.hook b/tools/codestyle/copyright.hook new file mode 100644 index 0000000000000000000000000000000000000000..86b16ebdc46047c7cb3d7731a71cbf9647a1f2fe --- /dev/null +++ b/tools/codestyle/copyright.hook @@ -0,0 +1,121 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import argparse +import io, re +import sys, os +import subprocess +import platform + +COPYRIGHT = ''' +Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +''' + +LANG_COMMENT_MARK = None + +NEW_LINE_MARK = None + +COPYRIGHT_HEADER = None + +if platform.system() == "Windows": + NEW_LINE_MARK = "\r\n" +else: + NEW_LINE_MARK = '\n' + COPYRIGHT_HEADER = COPYRIGHT.split(NEW_LINE_MARK)[1] + p = re.search('(\d{4})', COPYRIGHT_HEADER).group(0) + process = subprocess.Popen(["date", "+%Y"], stdout=subprocess.PIPE) + date, err = process.communicate() + date = date.decode("utf-8").rstrip("\n") + COPYRIGHT_HEADER = COPYRIGHT_HEADER.replace(p, date) + + +def generate_copyright(template, lang='C'): + if lang == 'Python': + LANG_COMMENT_MARK = '#' + else: + LANG_COMMENT_MARK = "//" + + lines = template.split(NEW_LINE_MARK) + BLANK = " " + ans = LANG_COMMENT_MARK + BLANK + COPYRIGHT_HEADER + NEW_LINE_MARK + for lino, line in enumerate(lines): + if lino == 0 or lino == 1 or lino == len(lines) - 1: continue + if len(line) == 0: + BLANK = "" + else: + BLANK = " " + ans += LANG_COMMENT_MARK + BLANK + line + NEW_LINE_MARK + + return ans + "\n" + + +def lang_type(filename): + if filename.endswith(".py"): + return "Python" + elif filename.endswith(".h"): + return "C" + elif filename.endswith(".c"): + return "C" + elif filename.endswith(".hpp"): + return "C" + elif filename.endswith(".cc"): + return "C" + elif filename.endswith(".cpp"): + return "C" + elif filename.endswith(".cu"): + return "C" + elif filename.endswith(".cuh"): + return "C" + elif filename.endswith(".go"): + return "C" + elif filename.endswith(".proto"): + return "C" + else: + print("Unsupported filetype %s", filename) + exit(0) + + +PYTHON_ENCODE = re.compile("^[ \t\v]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)") + + +def main(argv=None): + parser = argparse.ArgumentParser( + description='Checker for copyright declaration.') + parser.add_argument('filenames', nargs='*', help='Filenames to check') + args = parser.parse_args(argv) + + retv = 0 + for filename in args.filenames: + fd = io.open(filename, encoding="utf-8") + first_line = fd.readline() + second_line = fd.readline() + if "COPYRIGHT (C)" in first_line.upper(): continue + if first_line.startswith("#!") or PYTHON_ENCODE.match( + second_line) != None or PYTHON_ENCODE.match(first_line) != None: + continue + original_contents = io.open(filename, encoding="utf-8").read() + new_contents = generate_copyright( + 
COPYRIGHT, lang_type(filename)) + original_contents + print('Auto Insert Copyright Header {}'.format(filename)) + retv = 1 + with io.open(filename, 'w') as output_file: + output_file.write(new_contents) + + return retv + + +if __name__ == '__main__': + exit(main()) diff --git a/tools/codestyle/cpplint_pre_commit.hook b/tools/codestyle/cpplint_pre_commit.hook new file mode 100755 index 0000000000000000000000000000000000000000..658008d852123b6eab06d1f13d61ba896e7e9c98 --- /dev/null +++ b/tools/codestyle/cpplint_pre_commit.hook @@ -0,0 +1,27 @@ +#!/bin/bash + +TOTAL_ERRORS=0 +if [[ ! $TRAVIS_BRANCH ]]; then + # install cpplint on local machine. + if [[ ! $(which cpplint) ]]; then + pip install cpplint + fi + # diff files on local machine. + files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}') +else + # diff files between PR and latest commit on Travis CI. + branch_ref=$(git rev-parse "$TRAVIS_BRANCH") + head_ref=$(git rev-parse HEAD) + files=$(git diff --name-status $branch_ref $head_ref | awk '$1 != "D" {print $2}') +fi +# The trick to remove deleted files: https://stackoverflow.com/a/2413151 +for file in $files; do + if [[ $file =~ ^(patches/grpc/.*) ]]; then + continue; + else + cpplint --filter=-readability/fn_size $file; + TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); + fi +done + +exit $TOTAL_ERRORS diff --git a/tools/codestyle/docstring_checker.py b/tools/codestyle/docstring_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4b24a0cf6b743b72dca58fd885f927560964bf --- /dev/null +++ b/tools/codestyle/docstring_checker.py @@ -0,0 +1,349 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DocstringChecker is used to check python doc string's style.""" + +import six +import astroid + +from pylint.checkers import BaseChecker, utils +from pylint.interfaces import IAstroidChecker + +from collections import defaultdict +import re + + +def register(linter): + """Register checkers.""" + linter.register_checker(DocstringChecker(linter)) + + +class Docstring(object): + """Docstring class holds the parsed doc string elements. + """ + + def __init__(self): + self.d = defaultdict(list) #name->[] + self.clear() + + def clear(self): + self.d['Args'] = [] + self.d['Examples'] = [] + self.d['Returns'] = [] + self.d['Raises'] = [] + self.args = {} #arg_name->arg_type + + def get_level(self, string, indent=' '): + level = 0 + unit_size = len(indent) + while string[:unit_size] == indent: + string = string[unit_size:] + level += 1 + + return level + + def parse(self, doc): + """parse gets sections from doc + Such as Args, Returns, Raises, Examples s + Args: + doc (string): is the astroid node doc string. + Returns: + True if doc is parsed successfully. 
+ """ + self.clear() + + lines = doc.splitlines() + state = ("others", -1) + for l in lines: + c = l.strip() + if len(c) <= 0: + continue + + level = self.get_level(l) + if c.startswith("Args:"): + state = ("Args", level) + elif c.startswith("Returns:"): + state = ("Returns", level) + elif c.startswith("Raises:"): + state = ("Raises", level) + elif c.startswith("Examples:"): + state = ("Examples", level) + else: + if level > state[1]: + self.d[state[0]].append(c) + continue + + state = ("others", -1) + self.d[state[0]].append(c) + + self._arg_with_type() + return True + + def get_returns(self): + return self.d['Returns'] + + def get_raises(self): + return self.d['Raises'] + + def get_examples(self): + return self.d['Examples'] + + def _arg_with_type(self): + + for t in self.d['Args']: + m = re.search('([A-Za-z0-9_-]+)\s{0,4}(\(.+\))\s{0,4}:', t) + if m: + self.args[m.group(1)] = m.group(2) + + return self.args + + +class DocstringChecker(BaseChecker): + """DosstringChecker is pylint checker to + check docstring style. + """ + __implements__ = (IAstroidChecker, ) + + POSITIONAL_MESSAGE_ID = 'str-used-on-positional-format-argument' + KEYWORD_MESSAGE_ID = 'str-used-on-keyword-format-argument' + + name = 'doc-string-checker' + symbol = "doc-string" + priority = -1 + msgs = { + 'W9001': ('One line doc string on > 1 lines', symbol + "-one-line", + 'Used when a short doc string is on multiple lines'), + 'W9002': + ('Doc string does not end with "." period', symbol + "-end-with", + 'Used when a doc string does not end with a period'), + 'W9003': + ('All args with their types must be mentioned in doc string %s', + symbol + "-with-all-args", + 'Used when not all arguments are in the doc string '), + 'W9005': ('Missing docstring or docstring is too short', + symbol + "-missing", 'Add docstring longer >=10'), + 'W9006': ('Docstring indent error, use 4 space for indent', + symbol + "-indent-error", 'Use 4 space for indent'), + 'W9007': ('You should add `Returns` in comments', + symbol + "-with-returns", + 'There should be a `Returns` section in comments'), + 'W9008': ('You should add `Raises` section in comments', + symbol + "-with-raises", + 'There should be a `Raises` section in comments'), + } + options = () + + def visit_functiondef(self, node): + """visit_functiondef checks Function node docstring style. + Args: + node (astroid.node): The visiting node. + Returns: + True if successful other wise False. + """ + + self.check_doc_string(node) + + if node.tolineno - node.fromlineno <= 10: + return True + + if not node.doc: + return True + + doc = Docstring() + doc.parse(node.doc) + + self.all_args_in_doc(node, doc) + self.with_returns(node, doc) + self.with_raises(node, doc) + + def visit_module(self, node): + self.check_doc_string(node) + + def visit_classdef(self, node): + self.check_doc_string(node) + + def check_doc_string(self, node): + self.missing_doc_string(node) + self.one_line(node) + self.has_period(node) + self.indent_style(node) + + def missing_doc_string(self, node): + if node.name.startswith("__") or node.name.startswith("_"): + return True + if node.tolineno - node.fromlineno <= 10: + return True + + if node.doc is None or len(node.doc) < 10: + self.add_message('W9005', node=node, line=node.fromlineno) + return False + + # FIXME(gongwb): give the docstring line-no + def indent_style(self, node, indent=4): + """indent_style checks docstring's indent style + Args: + node (astroid.node): The visiting node. 
+ indent (int): The default indent of style + Returns: + True if successful other wise False. + """ + if node.doc is None: + return True + + doc = node.doc + lines = doc.splitlines() + line_num = 0 + + for l in lines: + if line_num == 0: + continue + cur_indent = len(l) - len(l.lstrip()) + if cur_indent % indent != 0: + self.add_message('W9006', node=node, line=node.fromlineno) + return False + line_num += 1 + + return True + + def one_line(self, node): + """one_line checks if docstring (len < 40) is on one line. + Args: + node (astroid.node): The node visiting. + Returns: + True if successful otherwise False. + """ + + doc = node.doc + if doc is None: + return True + + if len(doc) > 40: + return True + elif sum(doc.find(nl) for nl in ('\n', '\r', '\n\r')) == -3: + return True + else: + self.add_message('W9001', node=node, line=node.fromlineno) + return False + + return True + + def has_period(self, node): + """has_period checks if one line doc end-with '.' . + Args: + node (astroid.node): the node is visiting. + Returns: + True if successful otherwise False. + """ + if node.doc is None: + return True + + if len(node.doc.splitlines()) > 1: + return True + + if not node.doc.strip().endswith('.'): + self.add_message('W9002', node=node, line=node.fromlineno) + return False + + return True + + def with_raises(self, node, doc): + """with_raises checks if one line doc end-with '.' . + Args: + node (astroid.node): the node is visiting. + doc (Docstring): Docstring object. + Returns: + True if successful otherwise False. + """ + + find = False + for t in node.body: + if not isinstance(t, astroid.Raise): + continue + + find = True + break + + if not find: + return True + + if len(doc.get_raises()) == 0: + self.add_message('W9008', node=node, line=node.fromlineno) + return False + + return True + + def with_returns(self, node, doc): + """with_returns checks if docstring comments what are returned . + Args: + node (astroid.node): the node is visiting. + doc (Docstring): Docstring object. + Returns: + True if successful otherwise False. + """ + + if node.name.startswith("__") or node.name.startswith("_"): + return True + find = False + for t in node.body: + if not isinstance(t, astroid.Return): + continue + + find = True + break + + if not find: + return True + + if len(doc.get_returns()) == 0: + self.add_message('W9007', node=node, line=node.fromlineno) + return False + + return True + + def all_args_in_doc(self, node, doc): + """all_args_in_doc checks if arguments are mentioned in doc + Args: + node (astroid.node): the node is visiting. + doc (Docstring): Docstring object + Returns: + True if successful otherwise False. 
+ """ + if node.name.startswith("__") or node.name.startswith("_"): + return True + args = [] + for arg in node.args.get_children(): + if (not isinstance(arg, astroid.AssignName)) \ + or arg.name == "self": + continue + args.append(arg.name) + + if len(args) <= 0: + return True + + parsed_args = doc.args + args_not_documented = set(args) - set(parsed_args) + if len(args) > 0 and len(parsed_args) <= 0: + self.add_message( + 'W9003', + node=node, + line=node.fromlineno, + args=list(args_not_documented)) + return False + + for t in args: + if t not in parsed_args: + self.add_message( + 'W9003', node=node, line=node.fromlineno, args=[t, ]) + return False + + return True diff --git a/tools/codestyle/pylint_pre_commit.hook b/tools/codestyle/pylint_pre_commit.hook new file mode 100755 index 0000000000000000000000000000000000000000..150a3f5666bd39d30b7e6518e58a14fb5fe2f14b --- /dev/null +++ b/tools/codestyle/pylint_pre_commit.hook @@ -0,0 +1,19 @@ +#!/bin/bash + +TOTAL_ERRORS=0 + + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +export PYTHONPATH=$DIR:$PYTHONPATH + +# The trick to remove deleted files: https://stackoverflow.com/a/2413151 +for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do + pylint --disable=all --load-plugins=docstring_checker \ + --enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file; + TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); +done + +exit $TOTAL_ERRORS +#For now, just warning: +#exit 0 + diff --git a/tools/codestyle/test_docstring_checker.py b/tools/codestyle/test_docstring_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..0547f7d1610c64b0ca6efa9384e97d658c8276fe --- /dev/null +++ b/tools/codestyle/test_docstring_checker.py @@ -0,0 +1,232 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import docstring_checker +import pylint.testutils +import astroid +import pytest +import sys + + +class TestDocstring(pylint.testutils.CheckerTestCase): + CHECKER_CLASS = docstring_checker.DocstringChecker + + def test_one_line(self): + func_node = astroid.extract_node(''' + def test(): + """get + news. + """ + if True: + return 5 + return 5 + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9001' == got[0][0] + + def test_one_line(self): + func_node = astroid.extract_node(''' + def test(): + """get news""" + if True: + return 5 + return 5 + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9002' == got[0][0] + + def test_args(self): + func_node = astroid.extract_node(''' + def test(scale, mean): + """get news. + Args: + scale (int): scale is the number. 
+ """ + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9003' == got[0][0] + + def test_missing(self): + func_node = astroid.extract_node(''' + def test(): + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9005' == got[0][0] + + def test_indent(self): + func_node = astroid.extract_node(''' + def test(): + """ get get get get get get get get + get get get get get get get get. + """ + pass + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9006' == got[0][0] + + def test_with_resturns(self): + func_node = astroid.extract_node(''' + def test(): + """get news. + Args: + scale (int): scale is the number. + """ + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + return mean + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9007' == got[0][0] + + def test_with_raises(self): + func_node = astroid.extract_node(''' + def test(): + """get news. + Args: + scale (int): scale is the number. + """ + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + mean=scale + raise ValueError('A very specific bad thing happened.') + ''') + + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 1 + assert 'W9008' == got[0][0] + + def test_no_message(self): + p = ''' +def fc(input, + size, + num_flatten_dims=1, + param_attr=None, + bias_attr=None, + act=None, + name=None): + """ + **Fully Connected Layer** + The fully connected layer can take multiple tensors as its inputs. It + creates a variable called weights for each input tensor, which represents + a fully connected weight matrix from each input unit to each output unit. + The fully connected layer multiplies each input tensor with its coresponding + weight to produce an output Tensor. If multiple input tensors are given, + the results of multiple multiplications will be sumed up. If bias_attr is + not None, a bias variable will be created and added to the output. Finally, + if activation is not None, it will be applied to the output as well. + This process can be formulated as follows: + + Args: + input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of + the input tensor(s) is at least 2. + size(int): The number of output units in this layer. + num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than + two dimensions. If this happens, the multidimensional tensor will first be flattened + into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input + tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to + form the second dimension of the final matrix (width of the matrix). 
For example, suppose + `X` is a 6-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. + Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. + param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable + parameters/weights of this layer. + bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias + of this layer. If it is set to None, no bias will be added to the output units. + act (str, default None): Activation to be applied to the output of this layer. + name (str, default None): The name of this layer. + Returns: + A tensor variable storing the transformation result. + Raises: + ValueError: If rank of the input tensor is less than 2. + Examples: + .. code-block:: python + data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + fc = fluid.layers.fc(input=data, size=1000, act="tanh") + """ + raise ValueError('A very specific bad thing happened.') + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + size = 1 + return size + ''' + + func_node = astroid.extract_node(p) + self.checker.visit_functiondef(func_node) + got = self.linter.release_messages() + assert len(got) == 0
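
For reference, a minimal sketch of exercising the new docstring checker locally. This is an assumed workflow, not part of the diff: it presumes pylint, astroid and pytest are already installed (the diff does not pin their versions), uses paddle_fl/reader/gru4rec_reader.py purely as an arbitrary example target, and mirrors the invocation style of tools/codestyle/pylint_pre_commit.hook, which puts the plugin directory on PYTHONPATH before calling pylint:

    # run the plugin's unit tests (test_docstring_checker.py is built on pylint.testutils)
    python -m pytest tools/codestyle/test_docstring_checker.py

    # lint a single file with only a subset of the docstring checks enabled;
    # PYTHONPATH must include tools/codestyle so pylint can import docstring_checker,
    # as pylint_pre_commit.hook does
    PYTHONPATH=tools/codestyle pylint --disable=all \
        --load-plugins=docstring_checker \
        --enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-missing \
        paddle_fl/reader/gru4rec_reader.py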