# train_program.py
# Federated-learning endpoint launcher (PaddleFL demo): registers this host
# with the scheduler, downloads its assigned job package, then runs as either
# the parameter server or a trainer depending on the job name.
import socket
import random
import zmq
import os
import tarfile
import paddle_fl as fl
import paddle.fluid as fluid
from paddle_fl.core.server.fl_server import FLServer
from paddle_fl.core.master.fl_job import FLRunTimeJob
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
import numpy as np
import sys
import logging
import time


# Port this endpoint listens on; the scheduler registers us as "<ip>:<port>".
# NOTE(review): fixed port — assumes at most one endpoint per host; confirm.
random_port = 60001
scheduler_conf = {}

# Parse scheduler.conf: tab-separated "KEY\tVALUE" lines (extra fields, if
# any, are ignored).  Must contain at least the "ENDPOINT" key.
with open("scheduler.conf") as fin:
    for line in fin:
        group = line.strip().split("\t")
        scheduler_conf[group[0]] = group[1]

# Register this endpoint with the scheduler over a zmq REQ/REP channel.
current_ip = socket.gethostbyname(socket.gethostname())
endpoint = "{}:{}".format(current_ip, random_port)
scheduler_ip = scheduler_conf["ENDPOINT"].split(":")
# The scheduler serves job packages over HTTP on port 8080 of its host.
download_url = "{}:8080".format(scheduler_ip[0])
print(download_url)
context = zmq.Context()
zmq_socket = context.socket(zmq.REQ)
zmq_socket.connect("tcp://{}".format(scheduler_conf["ENDPOINT"]))
zmq_socket.send("ENDPOINT\t{}".format(endpoint))
message = zmq_socket.recv()
print(message)

message = ""

# Poll the scheduler until a job is assigned, then fetch its tarball.
# A REQ socket must recv() before the next send(), which this loop honors.
while True:
    zmq_socket.send("GET_FL_JOB\t{}".format(endpoint))
    message = zmq_socket.recv()
    group = message.split("\t")
    if group[0] == "WAIT":
        # BUG FIX: the original spun here with no delay, flooding the
        # scheduler with GET_FL_JOB requests; back off briefly instead.
        time.sleep(3)
        continue
    else:
        # `message` is the assigned job name (e.g. "server0"/"trainer1").
        os.system("wget {}/job_config/{}.tar.gz".format(download_url, message))
        print(message)
        break

# Unpack the downloaded job package into ./job_config and release the socket.
os.system("ls")
os.system("gzip -d {}.tar.gz".format(message))
print("gzip finish")
os.system("tar -xf {}.tar".format(message))
os.system("ls")
zmq_socket.close()
print("close socket")

#program start
if 'server' in message:
    server = FLServer()
    server_id = 0
    job_path = "job_config"
    job = FLRunTimeJob()
    job.load_server_job(job_path, server_id)
    job._scheduler_ep = scheduler_conf["ENDPOINT"]
    server.set_server_job(job)
    server._current_ep = endpoint
    server.start()
else:
    def reader():
        for i in range(1000):
            data_dict = {}
            for i in range(3):
                data_dict[str(i)] = np.random.rand(1, 5).astype('float32')
        data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
        yield data_dict

    trainer_id = message.split("trainer")[1]
    job_path = "job_config"
    job = FLRunTimeJob()
    job.load_trainer_job(job_path, int(trainer_id))
    job._scheduler_ep = scheduler_conf["ENDPOINT"]
    trainer = FLTrainerFactory().create_fl_trainer(job)
    trainer._current_ep = endpoint
    trainer.start()
    print(trainer._scheduler_ep, trainer._current_ep)
    output_folder = "fl_model"
Q
qjing666 已提交
92
    epoch_id = 0
Q
qjing666 已提交
93 94
    while not trainer.stop():
        print("batch %d start train" % (step_i))
Q
qjing666 已提交
95
        step_i = 0
Q
qjing666 已提交
96 97
        for data in reader():
            trainer.run(feed=data, fetch=[])
Q
qjing666 已提交
98
            step_i += 1
Q
qjing666 已提交
99 100
            if train_step == trainer._step:
                break
Q
qjing666 已提交
101 102
        epoch_id += 1
        if epoch_id % 5 == 0:
Q
qjing666 已提交
103
            trainer.save_inference_program(output_folder)