提交 757ca7c3 编写于 作者: Q qjing666

fix dataset conflict issue in single machine simulation

上级 c14e5c83
......@@ -24,18 +24,18 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
tar_path = data_path + ".tar.gz"
if not os.path.exists(target_path):
os.system("mkdir trainer%d_data" % trainer_id)
if not os.path.exists(data_path):
print("Preparing data...")
if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path)
def train_data():
train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
json_train = json.load(train_file)
users = json_train["users"]
rand = random.randrange(0,len(users)) # random choose a user from each trainer
cur_user = users[rand]
print('training using '+cur_user)
if not os.path.exists(data_path):
print("Preparing data...")
if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path)
def train_data():
train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
json_train = json.load(train_file)
users = json_train["users"]
rand = random.randrange(0,len(users)) # random choose a user from each trainer
cur_user = users[rand]
print('training using '+cur_user)
train_images = json_train["user_data"][cur_user]['x']
train_labels = json_train["user_data"][cur_user]['y']
if count_by_step:
......@@ -45,9 +45,9 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
for i in xrange(len(train_images)):
yield train_images[i], train_labels[i]
train_file.close()
train_file.close()
return train_data
return train_data
def test(trainer_id,inner_step,batch_size,count_by_step):
target_path = "trainer%d_data" % trainer_id
......@@ -60,18 +60,18 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(tar_path):
download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path,target_path)
def test_data():
test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
json_test = json.load(test_file)
users = json_test["users"]
for user in users:
def test_data():
test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
json_test = json.load(test_file)
users = json_test["users"]
for user in users:
test_images = json_test['user_data'][user]['x']
test_labels = json_test['user_data'][user]['y']
for i in xrange(len(test_images)):
yield test_images[i], test_labels[i]
test_file.close()
test_file.close()
return test_data
return test_data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册