未验证 提交 fd36a922 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #37 from qjing666/easy_use

Make examples more easy to follow
......@@ -3,3 +3,4 @@ mistune
sphinx_rtd_theme
paddlepaddle>=1.6
zmq
......@@ -67,7 +67,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
......@@ -94,6 +95,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.start()
......@@ -122,6 +124,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
```
......@@ -104,10 +104,7 @@ class FLScheduler(object):
def start_fl_training(self):
# loop until training is done
loop = 0
while True:
if loop <= 1:
print(loop)
random.shuffle(self.fl_workers)
worker_dict = {}
for worker in self.fl_workers[:self.sample_worker_num]:
......@@ -143,4 +140,3 @@ class FLScheduler(object):
if len(finish_training_dict) == len(worker_dict):
all_finish_training = True
time.sleep(5)
loop += 1
......@@ -42,10 +42,8 @@ try:
from ssl import RAND_bytes
rng = RAND_bytes
except(AttributeError, ImportError):
#python2
rng = os.urandom
#raise RNGError
class DiffieHellman:
"""
Implements the Diffie-Hellman key exchange protocol.
......@@ -115,13 +113,13 @@ class DiffieHellman:
self.shared_secret = pow(other_public_key,
self.private_key,
self.prime)
#python2
#length = self.shared_secret.bit_length() // 8 + 1
#shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:]
#python3
shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big')
try:
#python3
shared_secret_as_bytes = self.shared_secret.to_bytes(self.shared_secret.bit_length() // 8 + 1, byteorder='big')
except:
#python2
length = self.shared_secret.bit_length() // 8 + 1
shared_secret_as_bytes = ('%%0%dx' % (length << 1) % self.shared_secret).decode('hex')[-length:]
_h = sha256()
_h.update(bytes(shared_secret_as_bytes))
......
......@@ -16,6 +16,7 @@ import logging
from paddle_fl.core.scheduler.agent_master import FLWorkerAgent
import numpy
import hmac
import hashlib
from .diffiehellman.diffiehellman import DiffieHellman
class FLTrainerFactory(object):
......@@ -89,12 +90,12 @@ class FLTrainer(object):
# TODO(guru4elephant): add connection with master
if self.cur_step != 0:
while not self.agent.finish_training():
print('wait others finish')
self._logger.debug("Wait others finish")
continue
while not self.agent.can_join_training():
print("wait permit")
self._logger.debug("Wait permit")
continue
print("ready to train")
self._logger.debug("Ready to train")
return False
......@@ -123,7 +124,6 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._recv_program)
epoch = 0
for i in range(num_epoch):
print(epoch)
for data in reader():
self.exe.run(self._main_program,
feed=feeder.feed(data),
......@@ -190,6 +190,8 @@ class SecAggTrainer(FLTrainer):
self._step_id = s
def start(self):
self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe.run(self._startup_program)
self.cur_step = 0
......@@ -219,7 +221,7 @@ class SecAggTrainer(FLTrainer):
self._logger.debug("begin to run send program")
noise = 0.0
scale = pow(10.0, 5)
digestmod="SHA256"
digestmod=hashlib.sha256
# 1. load priv key and other's pub key
dh = DiffieHellman(group=15, key_length=256)
dh.load_private_key(self._key_dir + str(self._trainer_id) + "_priv_key.txt")
......@@ -245,5 +247,3 @@ class SecAggTrainer(FLTrainer):
self.cur_step += 1
return loss
def stop(self):
return False
......@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
......
......@@ -21,8 +21,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
print("connect")
......@@ -19,22 +19,22 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
step_i = 0
epoch_id = 0
while not trainer.stop():
print("batch %d start train" % (step_i))
print("batch %d start train" % (epoch_id))
train_step = 0
for data in reader():
trainer.run(feed=data, fetch=[])
train_step += 1
if train_step == trainer._step:
break
step_i += 1
if step_i % 100 == 0:
epoch_id += 1
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
......@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
#Define number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
......
......@@ -21,7 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
......@@ -13,7 +13,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # Inform scheduler IP address to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
......
unset http_proxy
unset https_proxy
python fl_master.py
sleep 2
python -u fl_scheduler.py >scheduler.log &
......
......@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
......
......@@ -7,7 +7,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
......@@ -14,7 +14,7 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
print(job._target_names)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
......@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader):
epoch_id = 0
step = 0
epoch = 3000
count_by_step = True
count_by_step = False
if count_by_step:
output_folder = "model_node%d" % trainer_id
else:
......
unset http_proxy
unset https_proxy
#killall python
python fl_master.py
sleep 2
......
......@@ -2,7 +2,8 @@ from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 4
server_num = 1
scheduler = FLScheduler(worker_num,server_num)
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(4)
scheduler.init_env()
print("init env done.")
......
......@@ -21,7 +21,7 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181"
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
......@@ -14,7 +14,7 @@ train_file_dir = "mid_data/node4/%d/" % trainer_id
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.start()
......
from paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num,server_num,port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
......@@ -21,5 +21,8 @@ server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
print("connect")
......@@ -28,8 +28,10 @@ trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer.trainer_id = trainer_id
trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
......@@ -73,6 +75,7 @@ while not trainer.stop():
if step_i % 100 == 0:
print("Epoch: {0}, step: {1}, accuracy: {2}".format(epoch_id, step_i, accuracy[0]))
print(step_i)
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
......@@ -80,5 +83,5 @@ while not trainer.stop():
if epoch_id > 40:
break
if step_i % 100 == 0:
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
......@@ -5,10 +5,12 @@ if [ ! -d log ];then
mkdir log
fi
python3 fl_master.py
python fl_master.py
sleep 2
python3 -u fl_server.py >log/server0.log &
python -u fl_server.py >log/server0.log &
sleep 2
python3 -u fl_trainer.py 0 >log/trainer0.log &
python -u fl_scheduler.py > log/scheduler.log &
sleep 2
python3 -u fl_trainer.py 1 >log/trainer1.log &
python -u fl_trainer.py 0 >log/trainer0.log &
sleep 2
python -u fl_trainer.py 1 >log/trainer1.log &
......@@ -2,7 +2,8 @@
task_name=test_fl_job_submit_jingqinghe
hdfs_output=/user/feed/mlarch/sequence_generator/dongdaxiang/job_44
train_cmd=python dist_trainer.py
monitor_cmd=python system_monitor_app.py 10 100
#monitor_cmd=python system_monitor_app.py 10 100
monitor_cmd=
#train_cmd=python test_hadoop.py
hdfs_path=afs://xingtian.afs.baidu.com:9902
......
/home/jingqinghe/mpi_feed4/smart_client/bin/qdel $1".yq01-hpc-lvliang01-smart-master.dmop.baidu.com"
unset http_proxy
unset https_proxy
/home/jingqinghe/tools/mpi_feed4/smart_client/bin/qdel $1".yq01-hpc-lvliang01-smart-master.dmop.baidu.com"
......@@ -18,7 +18,8 @@ print(random_port)
current_ip = socket.gethostbyname(socket.gethostname())
endpoints = "{}:{}".format(current_ip, random_port)
#start a web server for remote endpoints to download their config
os.system("python -m SimpleHTTPServer 8080 &")
#os.system("python -m SimpleHTTPServer 8080 &")
os.system("python -m http.server 8080 &")
if os.path.exists("job_config"):
os.system("rm -rf job_config")
if os.path.exists("package"):
......@@ -120,10 +121,10 @@ print(ip_list)
#allocate the role of each endpoint and their ids
ip_role = {}
for i in range(len(ip_list)):
if i < int(default_dict["server_nodes"]):
ip_role[ip_list[i]] = 'server%d' % i
if i < int(default_dict["server_nodes"]):
ip_role[ip_list[i]] = 'server%d' % i
else:
ip_role[ip_list[i]] = 'trainer%d' % (i-int(default_dict["server_nodes"]))
ip_role[ip_list[i]] = 'trainer%d' % (i-int(default_dict["server_nodes"]))
print(ip_role)
def job_generate():
......@@ -179,7 +180,7 @@ while not all_job_sent:
message = zmq_socket.recv()
group = message.split("\t")
if group[0] == "GET_FL_JOB":
download_job.append(group[1])
download_job.append(group[1])
zmq_socket.send(ip_role[group[1]])
else:
zmq_socket.send("WAIT\t0")
......
......@@ -89,15 +89,15 @@ else:
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
step_i = 0
epoch_id = 0
while not trainer.stop():
print("batch %d start train" % (step_i))
train_step = 0
step_i = 0
for data in reader():
trainer.run(feed=data, fetch=[])
train_step += 1
step_i += 1
if train_step == trainer._step:
break
step_i += 1
if step_i % 100 == 0:
epoch_id += 1
if epoch_id % 5 == 0:
trainer.save_inference_program(output_folder)
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaddleFL version string """
fl_version = "0.1.6"
module_proto_version = "0.1.6"
fl_version = "0.1.7"
module_proto_version = "0.1.7"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册