From 08628d88fd915bf3084126fb438b8c29f23bd3ab Mon Sep 17 00:00:00 2001 From: qjing666 Date: Mon, 13 Jan 2020 15:43:23 +0800 Subject: [PATCH] fix dataset conflict issue in single machine simulation --- paddle_fl/core/trainer/fl_trainer.py | 12 +++--- paddle_fl/dataset/femnist.py | 37 +++++++++++-------- paddle_fl/examples/femnist_demo/fl_trainer.py | 4 +- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/paddle_fl/core/trainer/fl_trainer.py b/paddle_fl/core/trainer/fl_trainer.py index dce6086..7d16216 100755 --- a/paddle_fl/core/trainer/fl_trainer.py +++ b/paddle_fl/core/trainer/fl_trainer.py @@ -123,13 +123,13 @@ class FedAvgTrainer(FLTrainer): self.exe.run(self._recv_program) epoch = 0 for i in range(num_epoch): - print(epoch) - for data in reader(): - self.exe.run(self._main_program, - feed=feeder.feed(data), + print(epoch) + for data in reader(): + self.exe.run(self._main_program, + feed=feeder.feed(data), fetch_list=fetch) - self.cur_step += 1 - epoch += 1 + self.cur_step += 1 + epoch += 1 self._logger.debug("begin to run send program") self.exe.run(self._send_program) def run(self, feed, fetch): diff --git a/paddle_fl/dataset/femnist.py b/paddle_fl/dataset/femnist.py index 4b096c8..88957a8 100644 --- a/paddle_fl/dataset/femnist.py +++ b/paddle_fl/dataset/femnist.py @@ -3,32 +3,34 @@ import os import json import tarfile import random -url = "https://paddlefl.bj.bcebos.com/leaf/" -target_path = "femnist_data" -tar_path = target_path+".tar.gz" -print(tar_path) -def download(url): + +def download(url,tar_path): r = requests.get(url) with open(tar_path,'wb') as f: f.write(r.content) -def extract(tar_path): +def extract(tar_path,target_path): tar = tarfile.open(tar_path, "r:gz") file_names = tar.getnames() for file_name in file_names: - tar.extract(file_name) + tar.extract(file_name,target_path) tar.close() def train(trainer_id,inner_step,batch_size,count_by_step): - if not os.path.exists(target_path): + target_path = "trainer%d_data" % trainer_id + data_path = target_path + "/femnist_data" + tar_path = data_path + ".tar.gz" + if not os.path.exists(target_path): + os.system("mkdir trainer%d_data" % trainer_id) + if not os.path.exists(data_path): print("Preparing data...") if not os.path.exists(tar_path): - download(url+tar_path) - extract(tar_path) + download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) + extract(tar_path,target_path) def train_data(): - train_file = open("./femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % trainer_id,'r') + train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r') json_train = json.load(train_file) users = json_train["users"] rand = random.randrange(0,len(users)) # random choose a user from each trainer @@ -48,13 +50,18 @@ def train(trainer_id,inner_step,batch_size,count_by_step): return train_data def test(trainer_id,inner_step,batch_size,count_by_step): - if not os.path.exists(target_path): + target_path = "trainer%d_data" % trainer_id + data_path = target_path + "/femnist_data" + tar_path = data_path + ".tar.gz" + if not os.path.exists(target_path): + os.system("mkdir trainer%d_data" % trainer_id) + if not os.path.exists(data_path): print("Preparing data...") if not os.path.exists(tar_path): - download(url+tar_path) - extract(tar_path) + download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) + extract(tar_path,target_path) def test_data(): - test_file = open("./femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % trainer_id, 'r') + test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r') json_test = json.load(test_file) users = json_test["users"] for user in users: diff --git a/paddle_fl/examples/femnist_demo/fl_trainer.py b/paddle_fl/examples/femnist_demo/fl_trainer.py index 9fe9886..2b015ee 100644 --- a/paddle_fl/examples/femnist_demo/fl_trainer.py +++ b/paddle_fl/examples/femnist_demo/fl_trainer.py @@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader): epoch_id = 0 step = 0 epoch = 3000 -count_by_step = True +count_by_step = False if count_by_step: output_folder = "model_node%d" % trainer_id else: @@ -82,4 +82,4 @@ while not trainer.stop(): print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val)) if trainer_id == 0: save_dir = (output_folder + "/epoch_%d") % epoch_id - trainer.save_inference_program(output_folder) + trainer.save_inference_program(output_folder) -- GitLab