提交 757ca7c3 作者: qjing666

fix dataset conflict issue in single machine simulation

上级 c14e5c83
...@@ -24,18 +24,18 @@ def train(trainer_id,inner_step,batch_size,count_by_step): ...@@ -24,18 +24,18 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
tar_path = data_path + ".tar.gz" tar_path = data_path + ".tar.gz"
if not os.path.exists(target_path): if not os.path.exists(target_path):
os.system("mkdir trainer%d_data" % trainer_id) os.system("mkdir trainer%d_data" % trainer_id)
if not os.path.exists(data_path): if not os.path.exists(data_path):
print("Preparing data...") print("Preparing data...")
if not os.path.exists(tar_path): if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path) extract(tar_path,target_path)
def train_data(): def train_data():
train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r') train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
json_train = json.load(train_file) json_train = json.load(train_file)
users = json_train["users"] users = json_train["users"]
rand = random.randrange(0,len(users)) # random choose a user from each trainer rand = random.randrange(0,len(users)) # random choose a user from each trainer
cur_user = users[rand] cur_user = users[rand]
print('training using '+cur_user) print('training using '+cur_user)
train_images = json_train["user_data"][cur_user]['x'] train_images = json_train["user_data"][cur_user]['x']
train_labels = json_train["user_data"][cur_user]['y'] train_labels = json_train["user_data"][cur_user]['y']
if count_by_step: if count_by_step:
...@@ -45,9 +45,9 @@ def train(trainer_id,inner_step,batch_size,count_by_step): ...@@ -45,9 +45,9 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
for i in xrange(len(train_images)): for i in xrange(len(train_images)):
yield train_images[i], train_labels[i] yield train_images[i], train_labels[i]
train_file.close() train_file.close()
return train_data return train_data
def test(trainer_id,inner_step,batch_size,count_by_step): def test(trainer_id,inner_step,batch_size,count_by_step):
target_path = "trainer%d_data" % trainer_id target_path = "trainer%d_data" % trainer_id
...@@ -60,18 +60,18 @@ def test(trainer_id,inner_step,batch_size,count_by_step): ...@@ -60,18 +60,18 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(tar_path): if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path) download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path) extract(tar_path,target_path)
def test_data():
    """Yield every (image, label) test sample stored for this trainer.

    Reads the JSON file
    ``./trainer<id>_data/femnist_data/test/all_data_<id>_niid_0_keep_0_test_9.json``,
    where ``<id>`` is ``trainer_id`` taken from the enclosing scope. Users are
    visited in the order of the file's ``"users"`` list, and for each user the
    entries of ``'x'`` are paired with the entry of ``'y'`` at the same index.

    Yields:
        tuple: ``(image, label)`` for each sample of each user.

    Raises:
        IOError: if the JSON file cannot be opened.
        KeyError: if ``"users"``, ``"user_data"`` or a user's ``'x'``/``'y'``
            entry is missing from the file.
        IndexError: if a user has fewer labels than images.
    """
    test_path = "./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id, trainer_id)
    # json.load reads the whole document up front, so the handle can be
    # released immediately. Closing it here (instead of after the last yield)
    # means a consumer that stops iterating early, or an error raised while
    # iterating, no longer leaves the file open.
    with open(test_path, 'r') as test_file:
        json_test = json.load(test_file)
    users = json_test["users"]
    for user in users:
        test_images = json_test['user_data'][user]['x']
        test_labels = json_test['user_data'][user]['y']
        # enumerate works on both Python 2 and 3 (xrange is Python 2 only) and
        # still indexes the labels, so a short label list raises IndexError
        # exactly as the index loop did.
        for i, image in enumerate(test_images):
            yield image, test_labels[i]
return test_data return test_data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册