From e100da0fc04c929e60ac35db890e948cfe7a7ddb Mon Sep 17 00:00:00 2001
From: qjing666
Date: Thu, 5 Dec 2019 15:38:44 +0800
Subject: [PATCH] add cluster submitter

---
 paddle_fl/__init__.py                         |   1 -
 paddle_fl/core/__init__.py                    |   3 +
 paddle_fl/core/master/fl_job.py               |  39 ++--
 paddle_fl/core/scheduler/agent_master.py      |  30 ++-
 paddle_fl/core/submitter/__init__.py          |  13 ++
 paddle_fl/core/submitter/client_base.py       | 156 +++++++++++++++
 paddle_fl/examples/submitter_demo/conf.txt    |  23 +++
 paddle_fl/examples/submitter_demo/kill.sh     |   1 +
 paddle_fl/examples/submitter_demo/model.py    |  16 ++
 paddle_fl/examples/submitter_demo/run.sh      |   2 +
 .../submitter_demo/scheduler_client.py        | 195 ++++++++++++++++++
 .../examples/submitter_demo/train_program.py  | 104 ++++++++++
 12 files changed, 555 insertions(+), 28 deletions(-)
 create mode 100644 paddle_fl/core/submitter/__init__.py
 create mode 100644 paddle_fl/core/submitter/client_base.py
 create mode 100644 paddle_fl/examples/submitter_demo/conf.txt
 create mode 100644 paddle_fl/examples/submitter_demo/kill.sh
 create mode 100644 paddle_fl/examples/submitter_demo/model.py
 create mode 100644 paddle_fl/examples/submitter_demo/run.sh
 create mode 100644 paddle_fl/examples/submitter_demo/scheduler_client.py
 create mode 100644 paddle_fl/examples/submitter_demo/train_program.py

diff --git a/paddle_fl/__init__.py b/paddle_fl/__init__.py
index 1fd2f81..5d80364 100644
--- a/paddle_fl/__init__.py
+++ b/paddle_fl/__init__.py
@@ -19,4 +19,3 @@
 from . import core
 from . import dataset
 from . import reader
-
diff --git a/paddle_fl/core/__init__.py b/paddle_fl/core/__init__.py
index ebeaed9..34df36d 100644
--- a/paddle_fl/core/__init__.py
+++ b/paddle_fl/core/__init__.py
@@ -20,3 +20,6 @@ from .strategy.fl_strategy_base import FedAvgStrategy
 from .scheduler.agent_master import FLServerAgent
 from .scheduler.agent_master import FLWorkerAgent
 from .scheduler.agent_master import FLScheduler
+from .submitter.client_base import HPCClient
+from .submitter.client_base import CloudClient
+
diff --git a/paddle_fl/core/master/fl_job.py b/paddle_fl/core/master/fl_job.py
index 6fcd8fc..ec25b76 100644
--- a/paddle_fl/core/master/fl_job.py
+++ b/paddle_fl/core/master/fl_job.py
@@ -113,6 +113,14 @@ class FLCompileTimeJob(FLJobBase):
             self._save_readable_program(
                 server_main,
                 "%s/server.main.program.txt" % server_folder)
+            self._save_str_list(self._feed_names,
+                                "%s/feed_names" % server_folder)
+            self._save_str_list(self._target_names,
+                                "%s/target_names" % server_folder)
+            self._save_endpoints(self._server_endpoints,
+                                 "%s/endpoints" % server_folder)
+            self._save_strategy(self._strategy,
+                                "%s/strategy.pkl" % server_folder)
 
         for i in range(trainer_num):
             trainer_folder = "%s/trainer%d" % (folder, i)
@@ -131,6 +139,14 @@
             self._save_readable_program(
                 trainer_main,
                 "%s/trainer.main.program.txt" % trainer_folder)
+            self._save_str_list(self._feed_names,
+                                "%s/feed_names" % trainer_folder)
+            self._save_str_list(self._target_names,
+                                "%s/target_names" % trainer_folder)
+            self._save_endpoints(self._server_endpoints,
+                                 "%s/endpoints" % trainer_folder)
+            self._save_strategy(self._strategy,
+                                "%s/strategy.pkl" % trainer_folder)
 
         for i in range(send_prog_num):
             trainer_folder = "%s/trainer%d" % (folder, i)
@@ -149,17 +165,6 @@
             self._save_readable_program(
                 trainer_recv,
                 "%s/trainer.recv.program.txt" % trainer_folder)
-        self._save_str_list(self._feed_names,
-                            "%s/feed_names" % folder)
-
-        self._save_str_list(self._target_names,
-                            "%s/target_names" % folder)
-
-        self._save_endpoints(self._server_endpoints,
-                             "%s/endpoints" % folder)
-
-        self._save_strategy(self._strategy,
-                            "%s/strategy.pkl" % folder)
 
 class FLRunTimeJob(FLJobBase):
     """
@@ -211,16 +216,16 @@
         except:
             pass
 
-        endpoints_fn = "%s/endpoints" % folder
+        endpoints_fn = "%s/endpoints" % folder_name
         self._endpoints = self._load_endpoints(endpoints_fn)
 
-        strategy_fn = "%s/strategy.pkl" % folder
+        strategy_fn = "%s/strategy.pkl" % folder_name
         self._strategy = self._load_strategy(strategy_fn)
 
-        feed_names_fn = "%s/feed_names" % folder
+        feed_names_fn = "%s/feed_names" % folder_name
         self._feed_names = self._load_str_list(feed_names_fn)
 
-        target_names_fn = "%s/target_names" % folder
+        target_names_fn = "%s/target_names" % folder_name
         self._target_names = self._load_str_list(target_names_fn)
 
     def load_server_job(self, folder=None, server_id=0):
@@ -243,9 +248,9 @@
         main_fn = "%s/server.main.program" % folder_name
         self._server_main_program = self._load_program(main_fn)
 
-        endpoints_fn = "%s/endpoints" % folder
+        endpoints_fn = "%s/endpoints" % folder_name
         self._endpoints = self._load_endpoints(endpoints_fn)
 
         import pickle
-        strategy_fn = "%s/strategy.pkl" % folder
+        strategy_fn = "%s/strategy.pkl" % folder_name
         self._strategy = self._load_strategy(strategy_fn)
diff --git a/paddle_fl/core/scheduler/agent_master.py b/paddle_fl/core/scheduler/agent_master.py
index d012f7c..077ee2f 100644
--- a/paddle_fl/core/scheduler/agent_master.py
+++ b/paddle_fl/core/scheduler/agent_master.py
@@ -18,25 +18,32 @@ class FLServerAgent(object):
         self.scheduler_ep = scheduler_ep
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.REQ)
-        self.socket.connect("tcp://127.0.0.1:9091")
+        self.socket.connect("tcp://{}".format(scheduler_ep))
         self.current_ep = current_ep
 
     def connect_scheduler(self):
-        self.socket.send("SERVER_EP\t{}".format(self.current_ep))
-        self.socket.recv()
-
+        while True:
+            self.socket.send("SERVER_EP\t{}".format(self.current_ep))
+            message = self.socket.recv()
+            group = message.split("\t")
+            if group[0] == 'INIT':
+                break
 
 class FLWorkerAgent(object):
     def __init__(self, scheduler_ep, current_ep):
         self.scheduler_ep = scheduler_ep
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.REQ)
-        self.socket.connect("tcp://127.0.0.1:9091")
+        self.socket.connect("tcp://{}".format(scheduler_ep))
         self.current_ep = current_ep
 
     def connect_scheduler(self):
-        self.socket.send("WORKER_EP\t{}".format(self.current_ep))
-        self.socket.recv()
+        while True:
+            self.socket.send("WORKER_EP\t{}".format(self.current_ep))
+            message = self.socket.recv()
+            group = message.split("\t")
+            if group[0] == 'INIT':
+                break
 
     def finish_training(self):
         self.socket.send("FINISH\t{}".format(self.current_ep))
@@ -59,10 +66,13 @@
 
 
 class FLScheduler(object):
-    def __init__(self, worker_num, server_num, port=9091):
+    def __init__(self, worker_num, server_num, port=9091, socket=None):
         self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.REP)
-        self.socket.bind("tcp://*:{}".format(port))
+        if socket is None:
+            self.socket = self.context.socket(zmq.REP)
+            self.socket.bind("tcp://*:{}".format(port))
+        else:
+            self.socket = socket
         self.worker_num = worker_num
         self.server_num = server_num
         self.sample_worker_num = 0
diff --git a/paddle_fl/core/submitter/__init__.py b/paddle_fl/core/submitter/__init__.py
new file mode 100644
index 0000000..236169a
--- /dev/null
+++ b/paddle_fl/core/submitter/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddle_fl/core/submitter/client_base.py b/paddle_fl/core/submitter/client_base.py
new file mode 100644
index 0000000..1cb7dff
--- /dev/null
+++ b/paddle_fl/core/submitter/client_base.py
@@ -0,0 +1,156 @@
+import sys
+import os
+
+class CloudClient(object):
+    def __init__(self):
+        pass
+
+    def generate_submit_sh(self, job_dir):
+        # placeholder: cloud submit-script generation is not implemented yet
+        pass
+
+    def generate_job_sh(self, job_dir):
+        # placeholder: cloud job-script generation is not implemented yet
+        pass
+
+    def submit(self, **kwargs):
+        pass
+
+class HPCClient(object):
+    def __init__(self):
+        self.conf_dict = {}
+
+    def print_args(self):
+        print("task_name: {}".format(self.task_name))
+        print("hdfs_path: {}".format(self.hdfs_path))
+        print("ugi: {}".format(self.ugi))
+        print("hdfs_output: {}".format(self.hdfs_output))
+        print("worker_nodes: {}".format(self.worker_nodes))
+        print("server_nodes: {}".format(self.server_nodes))
+        print("hadoop_home: {}".format(self.hadoop_home))
+        print("hpc_home: {}".format(self.hpc_home))
+        print("train_cmd: {}".format(self.train_cmd))
+        print("package_path: {}".format(self.package_path))
+        print("priority: {}".format(self.priority))
+        print("queue: {}".format(self.queue))
+        print("server: {}".format(self.server))
+        print("mpi_node_mem: {}".format(self.mpi_node_mem))
+        print("pcpu: {}".format(self.pcpu))
+        print("python_tar: {}".format(self.python_tar))
+        print("wheel: {}".format(self.wheel))
+
+    def check_args(self):
+        assert self.task_name != ""
+        assert self.hdfs_path != ""
+        assert self.ugi != ""
+        assert self.hdfs_output != ""
+        assert self.worker_nodes != ""
+        assert self.server_nodes != ""
+        assert self.hadoop_home != ""
+        assert self.hpc_home != ""
+        assert self.train_cmd != ""
+        assert self.package_path != ""
+        assert self.priority != ""
+        assert self.queue != ""
+        assert self.server != ""
+        assert self.mpi_node_mem != ""
+        assert self.pcpu != ""
+        assert self.python_tar != ""
+        assert self.wheel != ""
+
+    def generate_qsub_conf(self, job_dir):
+        with open("{}/qsub.conf".format(job_dir), "w") as fout:
+            fout.write("SERVER={}\n".format(self.server))
+            fout.write("QUEUE={}\n".format(self.queue))
+            fout.write("PRIORITY={}\n".format(self.priority))
+            fout.write("USE_FLAGS_ADVRES=yes\n")
+
+    def generate_submit_sh(self, job_dir):
+        with open("{}/submit.sh".format(job_dir), "w") as fout:
+            fout.write("#!/bin/bash\n")
+            fout.write("unset http_proxy\n")
+            fout.write("unset https_proxy\n")
+            fout.write("export HADOOP_HOME={}\n".format(
+                self.hadoop_home))
+            fout.write("$HADOOP_HOME/bin/hadoop fs -Dhadoop.job.ugi={}"
+                       " -Dfs.default.name={} -rmr {}\n".format(
+                           self.ugi,
+                           self.hdfs_path,
+                           self.hdfs_output))
+            fout.write("MPI_NODE_MEM={}\n".format(self.mpi_node_mem))
+            fout.write("{}/bin/qsub_f -N {} --conf qsub.conf "
+                       "--hdfs {} --ugi {} --hout {} --files ./package "
+                       "-l nodes={},walltime=1000:00:00,pmem-hard={},"
+                       "pcpu-soft={},pnetin-soft=1000,"
+                       "pnetout-soft=1000 job.sh\n".format(
+                           self.hpc_home,
+                           self.task_name,
+                           self.hdfs_path,
+                           self.ugi,
+                           self.hdfs_output,
+                           int(self.worker_nodes) + int(self.server_nodes),
+                           self.mpi_node_mem,
+                           self.pcpu))
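+
+    # job.sh below runs on the allocated MPI nodes: it unpacks the shipped
+    # Python runtime on every node, pip-installs the PaddlePaddle wheel, and
+    # then launches train_program.py (plus the optional monitor) via mpirun.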
+    def generate_job_sh(self, job_dir):
+        with open("{}/job.sh".format(job_dir), "w") as fout:
+            fout.write("#!/bin/bash\n")
+            fout.write("WORKDIR=`pwd`\n")
+            fout.write("mpirun -npernode 1 mv package/* ./\n")
+            fout.write("echo 'current dir: '$WORKDIR\n")
+            fout.write("mpirun -npernode 1 tar -zxvf python.tar.gz > /dev/null\n")
+            fout.write("export LIBRARY_PATH=$WORKDIR/python/lib:$LIBRARY_PATH\n")
+            fout.write("mpirun -npernode 1 python/bin/python -m pip install "
+                       "{} --index-url=http://pip.baidu.com/pypi/simple "
+                       "--trusted-host pip.baidu.com > /dev/null\n".format(
+                           self.wheel))
+            fout.write("export PATH=python/bin:$PATH\n")
+            if self.monitor_cmd != "":
+                fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile "
+                           "${{PBS_NODEFILE}} python/bin/{} > monitor.log 2> monitor.elog &\n".format(self.monitor_cmd))
+            fout.write("mpirun -npernode 1 -timestamp-output -tag-output -machinefile ${PBS_NODEFILE} python/bin/python train_program.py\n")
+            fout.write("if [[ $? -ne 0 ]]; then\n")
+            fout.write("    echo 'Failed to run mpi!' 1>&2\n")
+            fout.write("    exit 1\n")
+            fout.write("fi\n")
+
+    def submit(self, **kwargs):
+        # collect submission arguments; see conf.txt in the demo for the key list
+        self.task_name = kwargs.get("task_name", "test_submit_job")
+        self.hdfs_path = kwargs.get("hdfs_path", "")
+        self.ugi = kwargs.get("ugi", "")
+        self.hdfs_output = kwargs.get("hdfs_output", "")
+        self.worker_nodes = str(kwargs.get("worker_nodes", 2))
+        self.server_nodes = str(kwargs.get("server_nodes", 2))
+        self.hadoop_home = kwargs.get("hadoop_home", "")
+        self.hpc_home = kwargs.get("hpc_home", "")
+        self.train_cmd = kwargs.get("train_cmd", "")
+        self.monitor_cmd = kwargs.get("monitor_cmd", "")
+        self.package_path = kwargs.get("package_path", "")
+        self.priority = kwargs.get("priority", "")
+        self.queue = kwargs.get("queue", "")
+        self.server = kwargs.get("server", "")
+        self.mpi_node_mem = str(kwargs.get("mpi_node_mem", 11000))
+        self.pcpu = str(kwargs.get("pcpu", 180))
+        self.python_tar = kwargs.get("python_tar", "")
+        self.wheel = kwargs.get("wheel", "")
+
+        self.print_args()
+        self.check_args()
+        jobdir = "{}_jobdir".format(self.task_name)
+        os.system("mkdir -p {}".format(jobdir))
+        os.system("rm -rf {}/package".format(jobdir))
+        os.system("cp -r {} {}/package".format(self.package_path, jobdir))
+        os.system("cp {} {}/package/".format(self.python_tar, jobdir))
+        os.system("cp {} {}/package/".format(self.wheel, jobdir))
+        # generate submit.sh
+        self.generate_submit_sh(jobdir)
+        # generate job.sh
+        self.generate_job_sh(jobdir)
+        # generate qsub.conf
+        self.generate_qsub_conf(jobdir)
+        # run submit in the background
+        os.system("cd {};sh submit.sh > submit.log 2> submit.elog &".format(jobdir))
diff --git a/paddle_fl/examples/submitter_demo/conf.txt b/paddle_fl/examples/submitter_demo/conf.txt
new file mode 100644
index 0000000..4f952b7
--- /dev/null
+++ b/paddle_fl/examples/submitter_demo/conf.txt
@@ -0,0 +1,23 @@
+# commonly configured
+task_name=test_fl_job_submit_jingqinghe
+hdfs_output=/user/feed/mlarch/sequence_generator/dongdaxiang/job_44
+train_cmd=python dist_trainer.py
+monitor_cmd=python system_monitor_app.py 10 100
+#train_cmd=python test_hadoop.py
+
+hdfs_path=afs://xingtian.afs.baidu.com:9902
+ugi=mlarch,Fv1M87
+worker_nodes=2
+server_nodes=1
+hadoop_home=/home/jingqinghe/hadoop-xingtian/hadoop
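+# note: the paths and hosts in this file are examples from the demo cluster; adjust them for your site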
+hpc_home=/home/jingqinghe/mpi_feed4/smart_client
+package_path=./package
+priority=high
+#queue name
+queue=paddle-dev-amd
+server=yq01-hpc-lvliang01-smart-master.dmop.baidu.com
+
+python_tar=./python.tar.gz
+wheel=./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64.whl
+
diff --git a/paddle_fl/examples/submitter_demo/kill.sh b/paddle_fl/examples/submitter_demo/kill.sh
new file mode 100644
index 0000000..44c2676
--- /dev/null
+++ b/paddle_fl/examples/submitter_demo/kill.sh
@@ -0,0 +1 @@
+/home/jingqinghe/mpi_feed4/smart_client/bin/qdel $1".yq01-hpc-lvliang01-smart-master.dmop.baidu.com"
diff --git a/paddle_fl/examples/submitter_demo/model.py b/paddle_fl/examples/submitter_demo/model.py
new file mode 100644
index 0000000..f07549b
--- /dev/null
+++ b/paddle_fl/examples/submitter_demo/model.py
@@ -0,0 +1,16 @@
+import paddle.fluid as fluid
+
+class Model(object):
+    def __init__(self):
+        pass
+
+    def mlp(self, inputs, label, hidden_size=128):
+        self.concat = fluid.layers.concat(inputs, axis=1)
+        self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
+        self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
+        self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
+        self.sum_cost = fluid.layers.cross_entropy(input=self.predict, label=label)
+        self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
+        self.loss = fluid.layers.reduce_mean(self.sum_cost)
+        self.startup_program = fluid.default_startup_program()
+
diff --git a/paddle_fl/examples/submitter_demo/run.sh b/paddle_fl/examples/submitter_demo/run.sh
new file mode 100644
index 0000000..13e1ce2
--- /dev/null
+++ b/paddle_fl/examples/submitter_demo/run.sh
@@ -0,0 +1,2 @@
+tar -xf python.tar.gz
+python/bin/python scheduler_client.py conf.txt
diff --git a/paddle_fl/examples/submitter_demo/scheduler_client.py b/paddle_fl/examples/submitter_demo/scheduler_client.py
new file mode 100644
index 0000000..dccd1ec
--- /dev/null
+++ b/paddle_fl/examples/submitter_demo/scheduler_client.py
@@ -0,0 +1,195 @@
+import os
+import socket
+import random
+import zmq
+import time
+import sys
+from paddle_fl.core.submitter.client_base import HPCClient
+from paddle_fl.core.scheduler.agent_master import FLScheduler
+import paddle.fluid as fluid
+from paddle_fl.core.master.job_generator import JobGenerator
+from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
+from model import Model
+import tarfile
+
+#random_port = random.randint(60001, 64001)
+random_port = 60001
+print(random_port)
+current_ip = socket.gethostbyname(socket.gethostname())
+endpoints = "{}:{}".format(current_ip, random_port)
+# start a web server for remote endpoints to download their config
+os.system("python -m SimpleHTTPServer 8080 &")
+if os.path.exists("job_config"):
+    os.system("rm -rf job_config")
+if os.path.exists("package"):
+    os.system("rm -rf package")
+os.system("mkdir package")
+os.system("cp train_program.py package")
+with open("package/scheduler.conf", "w") as fout:
+    fout.write("ENDPOINT\t{}\n".format(endpoints))
+
+# submit a job with current endpoint
+
+default_dict = {
+    "task_name": "test_submit_job",
+    "hdfs_path": "afs://xingtian.afs.baidu.com:9902",
+    "ugi": "",
+    "worker_nodes": 5,
+    "server_nodes": 5,
+    "hadoop_home": "/home/jingqinghe/hadoop-xingtian/hadoop",
+    "hpc_home": "/home/jingqinghe/mpi_feed4/smart_client",
+    "package_path": "./package",
+    "priority": "high",
+    "queue": "paddle-dev-amd",
+    "server": "yq01-hpc-lvliang01-smart-master.dmop.baidu.com",
+    "mpi_node_mem": 11000,
+    "pcpu": 180,
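+    # the entries in this dict are demo defaults; any key present in conf.txt (sys.argv[1]) overrides them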
"python_tar": "./python.tar.gz", + "wheel": "./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64-0.whl" +} + +def load_conf(conf_file, local_dict): + with open(conf_file) as fin: + for line in fin: + group = line.strip().split("=") + if len(group) != 2: + continue + local_dict[group[0]] = group[1] + return local_dict + +client = HPCClient() +default_dict = load_conf(sys.argv[1], default_dict) + +client.submit( + task_name=default_dict["task_name"], + hdfs_path=default_dict["hdfs_path"], + ugi=default_dict["ugi"], + hdfs_output=default_dict["hdfs_output"], + worker_nodes=default_dict["worker_nodes"], + server_nodes=default_dict["server_nodes"], + hadoop_home=default_dict["hadoop_home"], + hpc_home=default_dict["hpc_home"], + train_cmd=default_dict["train_cmd"], + monitor_cmd=default_dict["monitor_cmd"], + package_path=default_dict["package_path"], + priority=default_dict["priority"], + queue=default_dict["queue"], + server=default_dict["server"], + mpi_node_mem=default_dict["mpi_node_mem"], + pcpu=default_dict["pcpu"], + python_tar=default_dict["python_tar"], + wheel=default_dict["wheel"]) + +print("submit mpi job done.") + +# start scheduler and receive the ip of allocated endpoints +context = zmq.Context() +zmq_socket = context.socket(zmq.REP) +zmq_socket.bind("tcp://{}:{}".format(current_ip, random_port)) + +print("binding tcp://{}:{}".format(current_ip, random_port)) + +all_ips_ready = False + +ip_list = [] + +scheduler = FLScheduler(int(default_dict["worker_nodes"]), + int(default_dict["server_nodes"]), + port=random_port, socket=zmq_socket) + +scheduler.set_sample_worker_num(int(default_dict["worker_nodes"])) + +print("going to wait all ips ready") + +while not all_ips_ready: + message = zmq_socket.recv() + group = message.split("\t") + if group[0] == "ENDPOINT": + ip_list.append(group[1]) + zmq_socket.send("ACCEPT\t{}".format(group[1])) + else: + zmq_socket.send("WAIT\t0") + if len(ip_list) == \ + int(default_dict["worker_nodes"]) + \ + int(default_dict["server_nodes"]): + all_ips_ready = True + +print("all worker ips are collected") +print(ip_list) + +#allocate the role of each endpoint and their ids +ip_role = {} +for i in range(len(ip_list)): + if i < int(default_dict["server_nodes"]): + ip_role[ip_list[i]] = 'server%d' % i + else: + ip_role[ip_list[i]] = 'trainer%d' % (i-int(default_dict["server_nodes"])) +print(ip_role) + +def job_generate(): + #generate a fl job which is the same as fl_master + inputs = [fluid.layers.data( \ + name=str(slot_id), shape=[5], + dtype="float32") + for slot_id in range(3)] + label = fluid.layers.data( \ + name="label", + shape=[1], + dtype='int64') + + model = Model() + model.mlp(inputs, label) + + job_generator = JobGenerator() + optimizer = fluid.optimizer.SGD(learning_rate=0.1) + job_generator.set_optimizer(optimizer) + job_generator.set_losses([model.loss]) + job_generator.set_startup_program(model.startup_program) + job_generator.set_infer_feed_and_target_names( + [x.name for x in inputs], [model.predict.name]) + + build_strategy = FLStrategyFactory() + build_strategy.fed_avg = True + build_strategy.inner_step = 10 + strategy = build_strategy.create_fl_strategy() + + # endpoints will be collected through the cluster + # in this example, we suppose endpoints have been collected + server_ip = ["{}".format(ip_list[0])] + + output = "job_config" + job_generator.generate_fl_job( + strategy, server_endpoints=server_ip, worker_num=int(default_dict["worker_nodes"]), output=output) + + file_list = os.listdir(output) + for file in file_list: + tar = 
+        tar = tarfile.open('{}/{}.tar.gz'.format(output, fname), 'w:gz')
+        for root, dirs, files in os.walk("{}/{}".format(output, fname)):
+            for f in files:
+                fullpath = os.path.join(root, f)
+                tar.add(fullpath)
+        tar.close()
+
+job_generate()
+
+# send the allocated roles to the remote endpoints
+all_job_sent = False
+download_job = []
+while not all_job_sent:
+    message = zmq_socket.recv()
+    group = message.split("\t")
+    if group[0] == "GET_FL_JOB":
+        download_job.append(group[1])
+        zmq_socket.send(ip_role[group[1]])
+    else:
+        zmq_socket.send("WAIT\t0")
+    if len(download_job) == len(ip_list):
+        all_job_sent = True
+
+# start training
+scheduler.init_env()
+print("init env done.")
+scheduler.start_fl_training()
+
diff --git a/paddle_fl/examples/submitter_demo/train_program.py b/paddle_fl/examples/submitter_demo/train_program.py
new file mode 100644
index 0000000..7dae086
--- /dev/null
+++ b/paddle_fl/examples/submitter_demo/train_program.py
@@ -0,0 +1,104 @@
+import socket
+import random
+import zmq
+import os
+import tarfile
+import paddle_fl as fl
+import paddle.fluid as fluid
+from paddle_fl.core.server.fl_server import FLServer
+from paddle_fl.core.master.fl_job import FLRunTimeJob
+from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
+import numpy as np
+import sys
+import logging
+import time
+
+
+random_port = 60001
+scheduler_conf = {}
+
+# connect to scheduler and get the role and id of the endpoint
+with open("scheduler.conf") as fin:
+    for line in fin:
+        line = line.strip()
+        group = line.split("\t")
+        scheduler_conf[group[0]] = group[1]
+
+current_ip = socket.gethostbyname(socket.gethostname())
+endpoint = "{}:{}".format(current_ip, random_port)
+scheduler_ip = scheduler_conf["ENDPOINT"].split(":")
+download_url = "{}:8080".format(scheduler_ip[0])
+print(download_url)
+context = zmq.Context()
+zmq_socket = context.socket(zmq.REQ)
+zmq_socket.connect(
+    "tcp://{}".format(scheduler_conf["ENDPOINT"]))
+zmq_socket.send("ENDPOINT\t{}".format(endpoint))
+message = zmq_socket.recv()
+print(message)
+
+message = ""
+
+# download the config file from scheduler
+while True:
+    zmq_socket.send("GET_FL_JOB\t{}".format(endpoint))
+    message = zmq_socket.recv()
+    group = message.split("\t")
+    if group[0] == "WAIT":
+        continue
+    else:
+        os.system("wget {}/job_config/{}.tar.gz".format(download_url, message))
+        print(message)
+        break
+
+os.system("ls")
+os.system("gzip -d {}.tar.gz".format(message))
+print("gzip finish")
+os.system("tar -xf {}.tar".format(message))
+os.system("ls")
+zmq_socket.close()
+print("close socket")
+
+# program start
+if 'server' in message:
+    server = FLServer()
+    server_id = 0
+    job_path = "job_config"
+    job = FLRunTimeJob()
+    job.load_server_job(job_path, server_id)
+    job._scheduler_ep = scheduler_conf["ENDPOINT"]
+    server.set_server_job(job)
+    server._current_ep = endpoint
+    server.start()
+else:
+    def reader():
+        for i in range(1000):
+            data_dict = {}
+            for slot_id in range(3):
+                data_dict[str(slot_id)] = np.random.rand(1, 5).astype('float32')
+            data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
+            yield data_dict
+
+    trainer_id = message.split("trainer")[1]
+    job_path = "job_config"
+    job = FLRunTimeJob()
+    job.load_trainer_job(job_path, int(trainer_id))
+    job._scheduler_ep = scheduler_conf["ENDPOINT"]
+    trainer = FLTrainerFactory().create_fl_trainer(job)
+    trainer._current_ep = endpoint
+    trainer.start()
+    print(trainer._scheduler_ep, trainer._current_ep)
+    output_folder = "fl_model"
+    step_i = 0
+    while not trainer.stop():
+        print("batch %d start train" % (step_i))
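+        # one FedAvg round: run trainer._step local batches (inner_step in the strategy), then synchronize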
+        train_step = 0
+        for data in reader():
+            trainer.run(feed=data, fetch=[])
+            train_step += 1
+            if train_step == trainer._step:
+                break
+        step_i += 1
+        if step_i % 100 == 0:
+            trainer.save_inference_program(output_folder)
-- 
GitLab