提交 c3188528 编写于 作者: Q qjing666

Make examples easier to follow

上级 f1d5d51a
...@@ -140,4 +140,3 @@ class FLScheduler(object): ...@@ -140,4 +140,3 @@ class FLScheduler(object):
if len(finish_training_dict) == len(worker_dict): if len(finish_training_dict) == len(worker_dict):
all_finish_training = True all_finish_training = True
time.sleep(5) time.sleep(5)
loop += 1
...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler ...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2 worker_num = 2
server_num = 1 server_num = 1
scheduler = FLScheduler(worker_num,server_num) # Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(worker_num) scheduler.set_sample_worker_num(worker_num)
scheduler.init_env() scheduler.init_env()
print("init env done.") print("init env done.")
......
...@@ -21,8 +21,8 @@ server_id = 0 ...@@ -21,8 +21,8 @@ server_id = 0
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_server_job(job_path, server_id) job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job) server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" server._current_ep = "127.0.0.1:8181" # IP address for server
server.start() server.start()
print("connect") print("connect")
...@@ -19,22 +19,22 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest ...@@ -19,22 +19,22 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id) job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job) trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start() trainer.start()
print(trainer._scheduler_ep, trainer._current_ep) print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model" output_folder = "fl_model"
step_i = 0 epoch_id = 0
while not trainer.stop(): while not trainer.stop():
print("batch %d start train" % (step_i)) print("batch %d start train" % (epoch_id))
train_step = 0 train_step = 0
for data in reader(): for data in reader():
trainer.run(feed=data, fetch=[]) trainer.run(feed=data, fetch=[])
train_step += 1 train_step += 1
if train_step == trainer._step: if train_step == trainer._step:
break break
step_i += 1 epoch_id += 1
if step_i % 100 == 0: if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder) trainer.save_inference_program(output_folder)
...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler ...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4 worker_num = 4
server_num = 1 server_num = 1
scheduler = FLScheduler(worker_num,server_num) # Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(4) scheduler.set_sample_worker_num(4)
scheduler.init_env() scheduler.init_env()
print("init env done.") print("init env done.")
......
...@@ -21,7 +21,7 @@ server_id = 0 ...@@ -21,7 +21,7 @@ server_id = 0
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_server_job(job_path, server_id) job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job) server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" server._current_ep = "127.0.0.1:8181" # IP address for server
server.start() server.start()
...@@ -13,7 +13,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest ...@@ -13,7 +13,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id) job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # Inform the trainer of the scheduler's IP address
trainer = FLTrainerFactory().create_fl_trainer(job) trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start() trainer.start()
......
unset http_proxy
unset https_proxy
python fl_master.py python fl_master.py
sleep 2 sleep 2
python -u fl_scheduler.py >scheduler.log & python -u fl_scheduler.py >scheduler.log &
......
...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler ...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4 worker_num = 4
server_num = 1 server_num = 1
scheduler = FLScheduler(worker_num,server_num) # Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(4) scheduler.set_sample_worker_num(4)
scheduler.init_env() scheduler.init_env()
print("init env done.") print("init env done.")
......
...@@ -7,7 +7,7 @@ server_id = 0 ...@@ -7,7 +7,7 @@ server_id = 0
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_server_job(job_path, server_id) job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job) server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" server._current_ep = "127.0.0.1:8181" # IP address for server
server.start() server.start()
...@@ -14,7 +14,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest ...@@ -14,7 +14,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id) job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
print(job._target_names) print(job._target_names)
trainer = FLTrainerFactory().create_fl_trainer(job) trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
...@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader): ...@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader):
epoch_id = 0 epoch_id = 0
step = 0 step = 0
epoch = 3000 epoch = 3000
count_by_step = True count_by_step = False
if count_by_step: if count_by_step:
output_folder = "model_node%d" % trainer_id output_folder = "model_node%d" % trainer_id
else: else:
......
unset http_proxy
unset https_proxy
#killall python #killall python
python fl_master.py python fl_master.py
sleep 2 sleep 2
......
...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler ...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4 worker_num = 4
server_num = 1 server_num = 1
scheduler = FLScheduler(worker_num,server_num) # Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(4) scheduler.set_sample_worker_num(4)
scheduler.init_env() scheduler.init_env()
print("init env done.") print("init env done.")
......
...@@ -21,7 +21,7 @@ server_id = 0 ...@@ -21,7 +21,7 @@ server_id = 0
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_server_job(job_path, server_id) job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job) server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" server._current_ep = "127.0.0.1:8181" # IP address for server
server.start() server.start()
...@@ -14,7 +14,7 @@ train_file_dir = "mid_data/node4/%d/" % trainer_id ...@@ -14,7 +14,7 @@ train_file_dir = "mid_data/node4/%d/" % trainer_id
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id) job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job) trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start() trainer.start()
......
...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler ...@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2 worker_num = 2
server_num = 1 server_num = 1
scheduler = FLScheduler(worker_num,server_num) # Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(worker_num) scheduler.set_sample_worker_num(worker_num)
scheduler.init_env() scheduler.init_env()
print("init env done.") print("init env done.")
......
...@@ -21,8 +21,8 @@ server_id = 0 ...@@ -21,8 +21,8 @@ server_id = 0
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_server_job(job_path, server_id) job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job) server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" server._current_ep = "127.0.0.1:8181" # IP address for server
server.start() server.start()
print("connect") print("connect")
...@@ -28,7 +28,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest ...@@ -28,7 +28,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config" job_path = "fl_job_config"
job = FLRunTimeJob() job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id) job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job) trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id trainer.trainer_id = trainer_id
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id) trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
...@@ -75,6 +75,7 @@ while not trainer.stop(): ...@@ -75,6 +75,7 @@ while not trainer.stop():
if step_i % 100 == 0: if step_i % 100 == 0:
print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0])) print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0]))
print(step_i)
avg_loss_val, acc_val = train_test(train_test_program=test_program, avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader, train_test_reader=test_reader,
train_test_feed=feeder) train_test_feed=feeder)
...@@ -82,5 +83,5 @@ while not trainer.stop(): ...@@ -82,5 +83,5 @@ while not trainer.stop():
if epoch_id > 40: if epoch_id > 40:
break break
if step_i % 100 == 0: if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder) trainer.save_inference_program(output_folder)
...@@ -89,15 +89,15 @@ else: ...@@ -89,15 +89,15 @@ else:
trainer.start() trainer.start()
print(trainer._scheduler_ep, trainer._current_ep) print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model" output_folder = "fl_model"
step_i = 0 epoch_id = 0
while not trainer.stop(): while not trainer.stop():
print("batch %d start train" % (step_i)) print("batch %d start train" % (step_i))
train_step = 0 step_i = 0
for data in reader(): for data in reader():
trainer.run(feed=data, fetch=[]) trainer.run(feed=data, fetch=[])
train_step += 1 step_i += 1
if train_step == trainer._step: if train_step == trainer._step:
break break
step_i += 1 epoch_id += 1
if step_i % 100 == 0: if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder) trainer.save_inference_program(output_folder)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册