提交 a04c9157 编写于 作者: Q qjing666

fix code style

上级 c1907a92
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy
......@@ -11,27 +25,32 @@ import math
import hashlib
import hmac
logging.basicConfig(filename="log/test.log", filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%M-%Y %H:%M:%S", level=logging.DEBUG)
logging.basicConfig(
filename="log/test.log",
filemode="w",
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%M-%Y %H:%M:%S",
level=logging.DEBUG)
logger = logging.getLogger("FLTrainer")
BATCH_SIZE = 64
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
trainer_num = 2
trainer_id = int(sys.argv[1]) # trainer id for each guest
trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
......@@ -47,8 +66,8 @@ feeder = fluid.DataFeeder(feed_list=[inputs, label], place=fluid.CPUPlace())
# for test
test_program = trainer._main_program.clone(for_test=True)
def train_test(train_test_program,
train_test_feed, train_test_reader):
def train_test(train_test_program, train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
......@@ -61,6 +80,8 @@ def train_test(train_test_program,
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
# for test
while not trainer.stop():
......@@ -71,15 +92,18 @@ while not trainer.stop():
step_i += 1
trainer.step_id = step_i
accuracy, = trainer.run(feed=feeder.feed(data),
fetch=["accuracy_0.tmp_0"])
fetch=["accuracy_0.tmp_0"])
if step_i % 100 == 0:
print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0]))
print("Epoch: {0}, step: {1}, accuracy: {2}".format(
epoch_id, step_i, accuracy[0]))
print(step_i)
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
avg_loss_val, acc_val = train_test(
train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %
(epoch_id, avg_loss_val, acc_val))
if epoch_id > 40:
break
......
......@@ -26,10 +26,12 @@ from paddle_fl.version import fl_version
def python_version():
return [int(v) for v in platform.python_version().split(".")]
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0','paddlepaddle >= 1.6', 'zmq', 'paddlepaddle-gpu >= 1.6'
'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.6', 'zmq',
'paddlepaddle-gpu >= 1.6'
]
if max_version < 3:
......@@ -42,8 +44,7 @@ REQUIRED_PACKAGES += ["unittest2"]
setup(
name='paddle_fl',
version=fl_version.replace('-', ''),
description=
('Federated Deep Learning Package Based on PaddlePaddle.'),
description=('Federated Deep Learning Package Based on PaddlePaddle.'),
long_description='',
url='https://github.com/PaddlePaddle/PaddleFL',
author='PaddlePaddle Author',
......@@ -70,4 +71,5 @@ setup(
'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords=('paddle_fl paddlepaddle multi-task transfer distributed-training'))
keywords=(
'paddle_fl paddlepaddle multi-task transfer distributed-training'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册