提交 f6bcc2c6 编写于 作者: Q qjing666

update

......@@ -13,13 +13,10 @@
# limitations under the License.
import paddle.fluid as fluid
import logging
<<<<<<< HEAD
from paddle_fl.core.scheduler.agent_master import FLWorkerAgent
=======
import numpy
import hmac
from .diffiehellman.diffiehellman import DiffieHellman
>>>>>>> 3145ed186544ac195bb957c22a638461d8e480bd
class FLTrainerFactory(object):
def __init__(self):
......@@ -150,11 +147,6 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._send_program)
self.cur_step += 1
return loss
<<<<<<< HEAD
=======
def stop(self):
return False
class SecAggTrainer(FLTrainer):
......@@ -205,8 +197,6 @@ class SecAggTrainer(FLTrainer):
self._recv_program = job._trainer_recv_program
self_step = job._strategy._inner_step
self._param_name_list = job._strategy._param_name_list
>>>>>>> 3145ed186544ac195bb957c22a638461d8e480bd
def reset(self):
    """Restart step tracking for a new training pass: zero the trainer's
    current-step counter so subsequent run() calls count from step 0."""
    self.cur_step = 0
......
......@@ -47,5 +47,5 @@ strategy = build_strategy.create_fl_strategy()
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=5, output=output)
strategy, server_endpoints=endpoints, worker_num=2, output=output)
# fl_job_config will be dispatched to workers
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 5
worker_num = 2
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
scheduler.set_sample_worker_num(5)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
......@@ -23,7 +23,6 @@ job._scheduler_ep = "127.0.0.1:9091"
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
step_i = 0
......@@ -38,3 +37,4 @@ while not trainer.stop():
step_i += 1
if step_i % 100 == 0:
trainer.save_inference_program(output_folder)
......@@ -6,7 +6,7 @@ python -u fl_scheduler.py > scheduler.log &
sleep 5
python -u fl_server.py >server0.log &
sleep 2
for ((i=0;i<5;i++))
for ((i=0;i<2;i++))
do
python -u fl_trainer.py $i >trainer$i.log &
sleep 2
......
......@@ -13,6 +13,7 @@ class Model(object):
self.label = fluid.layers.data(name='label', shape=[1],dtype='int64')
self.conv_pool_1 = fluid.nets.simple_img_conv_pool(input=self.inputs,num_filters=20,filter_size=5,pool_size=2,pool_stride=2,act='relu')
self.conv_pool_2 = fluid.nets.simple_img_conv_pool(input=self.conv_pool_1,num_filters=50,filter_size=5,pool_size=2,pool_stride=2,act='relu')
self.predict = fluid.layers.fc(input=self.conv_pool_2, size=62, act='softmax')
self.cost = fluid.layers.cross_entropy(input=self.predict, label=self.label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=self.label)
......
......@@ -15,6 +15,7 @@ job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
print(job._target_names)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册