# fl_trainer.py — PaddleFL FEMNIST trainer script (git blame: committed by qjing666).
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
# (git blame: committed by qjing666)
import paddle_fl.dataset.femnist
# (git blame: committed by qjing666)
import numpy
import sys
import paddle
import paddle.fluid as fluid
import logging
import math

# Log DEBUG and above to test.log, truncating the file on every run.
# Date format is day-month-year: "%m" is the month ("%M" would be minutes).
logging.basicConfig(
    filename="test.log",
    filemode="w",
    format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
    level=logging.DEBUG)

# The trainer id is the only command-line argument. It selects this trainer's
# job config, its local endpoint port (9000 + id) and is passed to the FEMNIST
# readers in the training loop.
if len(sys.argv) < 2:
    sys.exit("usage: %s <trainer_id>" % sys.argv[0])
trainer_id = int(sys.argv[1])  # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
# Scheduler endpoint is hard-coded to localhost:9091 after the job is loaded.
job._scheduler_ep = "127.0.0.1:9091"
print(job._target_names)

trainer = FLTrainerFactory().create_fl_trainer(job)
# Each trainer gets its own local port derived from its id.
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
print(trainer._step)
# Inference-mode copy of the trainer's program; evaluated after each round.
test_program = trainer._main_program.clone(for_test=True)

# Input layers used only to build the DataFeeder. Assumes the job's program
# declares inputs named "img" (1x28x28 float32) and "label" (int64) — TODO confirm.
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=fluid.CPUPlace())

def train_test(train_test_program, train_test_feed, train_test_reader,
               exe=None, fetch_name="accuracy_0.tmp_0"):
    """Run a program over every batch of a reader and return the mean fetched value.

    Args:
        train_test_program: program to execute (in this script, a for_test
            clone of the trainer's main program).
        train_test_feed: object whose ``feed(batch)`` converts one mini-batch
            into the ``feed`` argument of ``exe.run``.
        train_test_reader: zero-argument callable returning an iterable of
            mini-batches.
        exe: executor used to run the program; defaults to the module-level
            ``trainer.exe``.
        fetch_name: variable fetched for each batch; defaults to the accuracy
            variable this script evaluates.

    Returns:
        The unweighted mean of the per-batch fetched values, or NaN when the
        reader yields no batches.
    """
    if exe is None:
        exe = trainer.exe
    acc_set = []
    for test_data in train_test_reader():
        acc_np = exe.run(
            program=train_test_program,
            feed=train_test_feed.feed(test_data),
            fetch_list=[fetch_name])
        # .item() extracts the single value of a size-1 array; float() on an
        # ndim > 0 array is deprecated since NumPy 1.25.
        acc_set.append(float(numpy.asarray(acc_np[0]).item()))
    if not acc_set:
        # Avoid numpy's "Mean of empty slice" RuntimeWarning; same NaN result.
        return float("nan")
    return numpy.array(acc_set).mean()

epoch_id = 0  # current round number; incremented at the top of each round
step = 0      # mini-batches trained so far (only counted when count_by_step)
epoch = 3000  # max rounds; the loop also ends when trainer.stop() is True

# True: train at most trainer._step mini-batches per round.
# False: train one full pass over the train reader per round (run_with_epoch).
count_by_step = True

if count_by_step:
    output_folder = "model_node%d" % trainer_id
else:
    output_folder = "model_node%d_epoch" % trainer_id

while not trainer.stop():
    count = 0
    epoch_id += 1
    if epoch_id > epoch:
        break
    print("epoch %d start train" % (epoch_id))

    # Readers are recreated each round with the current trainer._step and
    # count_by_step settings.
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle_fl.dataset.femnist.train(
                trainer_id,
                inner_step=trainer._step,
                batch_size=64,
                count_by_step=count_by_step),
            buf_size=500),
        batch_size=64)

    test_reader = paddle.batch(
        paddle_fl.dataset.femnist.test(
            trainer_id,
            inner_step=trainer._step,
            batch_size=64,
            count_by_step=count_by_step),
        batch_size=64)

    if count_by_step:
        # Train until trainer._step mini-batches have run this round, or the
        # reader is exhausted, whichever comes first.
        for data in train_reader():
            trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
            step += 1
            count += 1
            print(count)
            if count % trainer._step == 0:
                break
    else:
        trainer.run_with_epoch(
            train_reader, feeder, fetch=["accuracy_0.tmp_0"], num_epoch=1)

    acc_val = train_test(
        train_test_program=test_program,
        train_test_reader=test_reader,
        train_test_feed=feeder)

    print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
    if trainer_id == 0:
        # Save each round into its own sub-directory. Previously save_dir was
        # computed but unused, and every round overwrote output_folder.
        save_dir = "%s/epoch_%d" % (output_folder, epoch_id)
        trainer.save_inference_program(save_dir)