# fl_trainer.py
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import sys
import logging
import time
# Send all DEBUG-and-above records to test.log, truncating it on each run.
# NOTE: the date format previously used "%M" (minutes) in the date part;
# "%m" (month) is the intended code for a day-month-year timestamp.
logging.basicConfig(
    filename="test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)

def reader(sample_count=1000, feature_count=3):
    """Yield random synthetic feed dicts for local FL training.

    Args:
        sample_count: number of samples to generate (default 1000,
            matching the original hard-coded value).
        feature_count: number of feature slots named "0", "1", ...
            (default 3, matching the original hard-coded value).

    Yields:
        dict: each feature slot name mapped to a (1, 5) float32 array,
        plus "label" mapped to a (1, 1) int64 array with values in {0, 1}.
    """
    for _ in range(sample_count):
        data_dict = {}
        # Distinct loop variable: the original shadowed the outer sample
        # index with the inner feature index (both were named `i`).
        for slot in range(feature_count):
            data_dict[str(slot)] = np.random.rand(1, 5).astype('float32')
        data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
        yield data_dict

trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"

# Load this trainer's slice of the pre-compiled FL job configuration.
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091"  # Inform the scheduler IP to trainer

trainer = FLTrainerFactory().create_fl_trainer(job)
# Each trainer listens on its own endpoint, port offset by its id.
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._scheduler_ep, trainer._current_ep)

output_folder = "fl_model"
epoch_id = 0
# Train until the scheduler tells this trainer to stop.
while not trainer.stop():
    print("batch %d start train" % (epoch_id))
    train_step = 0
    for data in reader():
        trainer.run(feed=data, fetch=[])
        train_step += 1
        # trainer._step caps the number of batches run per epoch.
        if train_step == trainer._step:
            break
    epoch_id += 1
    # Snapshot the inference program every 5 epochs.
    if epoch_id % 5 == 0:
        trainer.save_inference_program(output_folder)