Commit a04c9157 authored by qjing666

fix code style

Parent c1907a92
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
@@ -11,27 +25,32 @@ import math
import hashlib
import hmac
logging.basicConfig(
    filename="log/test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%M-%Y %H:%M:%S",
    level=logging.DEBUG)
logger = logging.getLogger("FLTrainer")
BATCH_SIZE = 64
train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.mnist.train(), buf_size=500),
    batch_size=BATCH_SIZE)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
trainer_num = 2
trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # tell the trainer the scheduler's endpoint
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
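The script assumes a few artifacts exist before trainer.start() is reached: the fl_job_config directory produced by the FL master, the log/ directory that logging.basicConfig writes into, and the ./keys/ directory used for secure aggregation. A minimal pre-flight check, purely illustrative and not part of this commit, could look like this:

import os

# Illustrative sanity check for the directories the example expects.
# How the key files inside ./keys/ are generated is outside the scope of this diff.
os.makedirs("log", exist_ok=True)  # target of logging.basicConfig above
for required_dir in ("fl_job_config", "keys"):
    if not os.path.isdir(required_dir):
        raise RuntimeError("missing directory '%s' required by the trainer" % required_dir)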
@@ -47,8 +66,8 @@ feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace())
# for test
test_program = trainer._main_program.clone(for_test=True)
def train_test(train_test_program, train_test_feed, train_test_reader):
    acc_set = []
    avg_loss_set = []
    for test_data in train_test_reader():
@@ -61,6 +80,8 @@ def train_test(train_test_program,
    acc_val_mean = numpy.array(acc_set).mean()
    avg_loss_val_mean = numpy.array(avg_loss_set).mean()
    return avg_loss_val_mean, acc_val_mean
# for test
while not trainer.stop():
@@ -71,15 +92,18 @@ while not trainer.stop():
        step_i += 1
        trainer.step_id = step_i
        accuracy, = trainer.run(feed=feeder.feed(data),
                                fetch=["accuracy_0.tmp_0"])
        if step_i % 100 == 0:
            print("Epoch: {0}, step: {1}, accuracy: {2}".format(
                epoch_id, step_i, accuracy[0]))
    print(step_i)
    avg_loss_val, acc_val = train_test(
        train_test_program=test_program,
        train_test_reader=test_reader,
        train_test_feed=feeder)
    print("Test with Epoch %d, avg_cost: %s, acc: %s" %
          (epoch_id, avg_loss_val, acc_val))
    if epoch_id > 40:
        break
...
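The trainer script above takes its trainer id from the command line, so one process is started per guest and each listens on port 9000 + trainer_id. As a hedged illustration only (the file name fl_trainer.py is an assumption, and the scheduler and server processes must already be running), the two trainers could be launched locally like this:

import subprocess

# Launch trainer_num = 2 trainer processes; each reads its id from sys.argv[1].
procs = [subprocess.Popen(["python", "fl_trainer.py", str(tid)]) for tid in range(2)]
for p in procs:
    p.wait()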
@@ -26,10 +26,12 @@ from paddle_fl.version import fl_version
def python_version():
    return [int(v) for v in platform.python_version().split(".")]
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
    'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.6', 'zmq',
    'paddlepaddle-gpu >= 1.6'
]
if max_version < 3:
@@ -42,8 +44,7 @@ REQUIRED_PACKAGES += ["unittest2"]
setup(
    name='paddle_fl',
    version=fl_version.replace('-', ''),
    description=('Federated Deep Learning Package Based on PaddlePaddle.'),
    long_description='',
    url='https://github.com/PaddlePaddle/PaddleFL',
    author='PaddlePaddle Author',
@@ -70,4 +71,5 @@ setup(
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    license='Apache 2.0',
    keywords=(
        'paddle_fl paddlepaddle multi-task transfer distributed-training'))