提交 08628d88 编写于 作者: Q qjing666

fix dataset conflict issue in single machine simulation

上级 e205e518
......@@ -123,13 +123,13 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._recv_program)
epoch = 0
for i in range(num_epoch):
print(epoch)
for data in reader():
self.exe.run(self._main_program,
feed=feeder.feed(data),
print(epoch)
for data in reader():
self.exe.run(self._main_program,
feed=feeder.feed(data),
fetch_list=fetch)
self.cur_step += 1
epoch += 1
self.cur_step += 1
epoch += 1
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
def run(self, feed, fetch):
......
......@@ -3,32 +3,34 @@ import os
import json
import tarfile
import random
url = "https://paddlefl.bj.bcebos.com/leaf/"
target_path = "femnist_data"
tar_path = target_path+".tar.gz"
print(tar_path)
def download(url):
def download(url,tar_path):
r = requests.get(url)
with open(tar_path,'wb') as f:
f.write(r.content)
def extract(tar_path):
def extract(tar_path,target_path):
tar = tarfile.open(tar_path, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name)
tar.extract(file_name,target_path)
tar.close()
def train(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(target_path):
target_path = "trainer%d_data" % trainer_id
data_path = target_path + "/femnist_data"
tar_path = data_path + ".tar.gz"
if not os.path.exists(target_path):
os.system("mkdir trainer%d_data" % trainer_id)
if not os.path.exists(data_path):
print("Preparing data...")
if not os.path.exists(tar_path):
download(url+tar_path)
extract(tar_path)
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path)
def train_data():
train_file = open("./femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % trainer_id,'r')
train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
json_train = json.load(train_file)
users = json_train["users"]
rand = random.randrange(0,len(users)) # random choose a user from each trainer
......@@ -48,13 +50,18 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
return train_data
def test(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(target_path):
target_path = "trainer%d_data" % trainer_id
data_path = target_path + "/femnist_data"
tar_path = data_path + ".tar.gz"
if not os.path.exists(target_path):
os.system("mkdir trainer%d_data" % trainer_id)
if not os.path.exists(data_path):
print("Preparing data...")
if not os.path.exists(tar_path):
download(url+tar_path)
extract(tar_path)
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path)
def test_data():
test_file = open("./femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % trainer_id, 'r')
test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
json_test = json.load(test_file)
users = json_test["users"]
for user in users:
......
......@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader):
epoch_id = 0
step = 0
epoch = 3000
count_by_step = True
count_by_step = False
if count_by_step:
output_folder = "model_node%d" % trainer_id
else:
......@@ -82,4 +82,4 @@ while not trainer.stop():
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
trainer.save_inference_program(output_folder)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册