import os
import socket
import random
import zmq
import time
import sys
from paddle_fl.core.submitter.client_base import HPCClient
from paddle_fl.core.scheduler.agent_master import FLScheduler
import paddle.fluid as fluid
from paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
from model import Model
import tarfile

#random_port = random.randint(60001, 64001)
random_port = 60001
print(random_port)
current_ip = socket.gethostbyname(socket.gethostname())
endpoints = "{}:{}".format(current_ip, random_port)
# start a web server so remote endpoints can download their config
#os.system("python -m SimpleHTTPServer 8080 &")
os.system("python -m http.server 8080 &")
if os.path.exists("job_config"):
    os.system("rm -rf job_config")
if os.path.exists("package"):
    os.system("rm -rf package")
os.system("mkdir package")
os.system("cp train_program.py package")
with open("package/scheduler.conf", "w") as fout:
    fout.write("ENDPOINT\t{}\n".format(endpoints))

# submit a job with current endpoint

default_dict = {
    "task_name": "test_submit_job",
    "hdfs_path": "afs://xingtian.afs.baidu.com:9902",
    "ugi": "",
    "worker_nodes": 5,
    "server_nodes": 5,
    "hadoop_home": "/home/jingqinghe/hadoop-xingtian/hadoop",
    "hpc_home": "/home/jingqinghe/mpi_feed4/smart_client",
    "package_path": "./package",
    "priority": "high",
    "queue": "paddle-dev-amd",
    "server": "yq01-hpc-lvliang01-smart-master.dmop.baidu.com",
    "mpi_node_mem": 11000,
    "pcpu": 180,
    "python_tar": "./python.tar.gz",
    "wheel": "./paddlepaddle-0.0.0-cp27-cp27mu-linux_x86_64-0.whl"
}

def load_conf(conf_file, local_dict):
    with open(conf_file) as fin:
        for line in fin:
            group = line.strip().split("=")
            if len(group) != 2:
                continue
            local_dict[group[0]] = group[1]
    return local_dict
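
# the conf file is plain key=value lines; any key present overrides the
# defaults above, e.g. (illustrative values):
#   hdfs_output=/app/paddle_fl/output
#   train_cmd=python train_program.py
#   worker_nodes=5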

client = HPCClient()
default_dict = load_conf(sys.argv[1], default_dict)

client.submit(
    task_name=default_dict["task_name"],
    hdfs_path=default_dict["hdfs_path"],
    ugi=default_dict["ugi"],
    hdfs_output=default_dict["hdfs_output"],
    worker_nodes=default_dict["worker_nodes"],
    server_nodes=default_dict["server_nodes"],
    hadoop_home=default_dict["hadoop_home"],
    hpc_home=default_dict["hpc_home"],
    train_cmd=default_dict["train_cmd"],
    monitor_cmd=default_dict["monitor_cmd"],
    package_path=default_dict["package_path"],
    priority=default_dict["priority"],
    queue=default_dict["queue"],
    server=default_dict["server"],
    mpi_node_mem=default_dict["mpi_node_mem"],
    pcpu=default_dict["pcpu"],
    python_tar=default_dict["python_tar"],
    wheel=default_dict["wheel"])

print("submit mpi job done.")

# start the scheduler and collect the ips of the allocated endpoints
context = zmq.Context()
zmq_socket = context.socket(zmq.REP)
zmq_socket.bind("tcp://{}:{}".format(current_ip, random_port))

print("binding tcp://{}:{}".format(current_ip, random_port))

all_ips_ready = False

ip_list = []

scheduler = FLScheduler(int(default_dict["worker_nodes"]),
                        int(default_dict["server_nodes"]),
                        port=random_port, socket=zmq_socket)

scheduler.set_sample_worker_num(int(default_dict["worker_nodes"]))

print("going to wait all ips ready")

while not all_ips_ready:
    # pyzmq delivers bytes under Python 3, so decode/encode explicitly
    message = zmq_socket.recv().decode()
    group = message.split("\t")
    if group[0] == "ENDPOINT":
        ip_list.append(group[1])
        zmq_socket.send("ACCEPT\t{}".format(group[1]).encode())
    else:
        zmq_socket.send(b"WAIT\t0")
    if len(ip_list) == \
       int(default_dict["worker_nodes"]) + \
       int(default_dict["server_nodes"]):
        all_ips_ready = True

print("all worker ips are collected")
print(ip_list)

# allocate a role (server or trainer) and an id to each endpoint
ip_role = {}
for i in range(len(ip_list)):
    if i < int(default_dict["server_nodes"]):
        ip_role[ip_list[i]] = 'server%d' % i
    else:
        ip_role[ip_list[i]] = 'trainer%d' % (i - int(default_dict["server_nodes"]))
print(ip_role)
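
# e.g. with server_nodes=5 and worker_nodes=5, ip_list[0..4] map to
# server0..server4 and ip_list[5..9] map to trainer0..trainer4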

def job_generate():
    # generate an fl job identical to the one fl_master would produce
    inputs = [fluid.layers.data( \
                name=str(slot_id), shape=[5],
                dtype="float32")
               for slot_id in range(3)]
    label = fluid.layers.data( \
                name="label",
                shape=[1],
                dtype='int64')

    model = Model()
    model.mlp(inputs, label)

    job_generator = JobGenerator()
    optimizer = fluid.optimizer.SGD(learning_rate=0.1)
    job_generator.set_optimizer(optimizer)
    job_generator.set_losses([model.loss])
    job_generator.set_startup_program(model.startup_program)
    job_generator.set_infer_feed_and_target_names(
        [x.name for x in inputs], [model.predict.name])

    build_strategy = FLStrategyFactory()
    build_strategy.fed_avg = True
    build_strategy.inner_step = 10
    strategy = build_strategy.create_fl_strategy()

    # endpoints will be collected through the cluster
    # in this example, we suppose endpoints have been collected
    server_ip = ["{}".format(ip_list[0])]
    
    output = "job_config"
    job_generator.generate_fl_job(
        strategy, server_endpoints=server_ip, worker_num=int(default_dict["worker_nodes"]), output=output)
    
    file_list = os.listdir(output)
    for job_dir in file_list:
        tar = tarfile.open('{}/{}.tar.gz'.format(output, job_dir), 'w:gz')
        for root, dirs, files in os.walk("{}/{}".format(output, job_dir)):
            for f in files:
                fullpath = os.path.join(root, f)
                tar.add(fullpath)
        tar.close()

job_generate()
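
# job_config/ now contains one <role>.tar.gz per endpoint (server0,
# trainer0, ...); remote nodes download theirs via the http.server above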

# send the allocated roles to the remote endpoints
all_job_sent = False
download_job = []
while not all_job_sent:
    message = zmq_socket.recv().decode()
    group = message.split("\t")
    if group[0] == "GET_FL_JOB":
        download_job.append(group[1])
        zmq_socket.send(ip_role[group[1]].encode())
    else:
        zmq_socket.send(b"WAIT\t0")
    if len(download_job) == len(ip_list):
        all_job_sent = True

#start training
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()