未验证 提交 67fa4a19 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #34 from qjing666/dataset

Fix the dataset conflict issue in single-machine simulation
......@@ -123,13 +123,13 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._recv_program)
epoch = 0
for i in range(num_epoch):
print(epoch)
for data in reader():
self.exe.run(self._main_program,
feed=feeder.feed(data),
print(epoch)
for data in reader():
self.exe.run(self._main_program,
feed=feeder.feed(data),
fetch_list=fetch)
self.cur_step += 1
epoch += 1
self.cur_step += 1
epoch += 1
self._logger.debug("begin to run send program")
self.exe.run(self._send_program)
def run(self, feed, fetch):
......
......@@ -3,68 +3,75 @@ import os
import json
import tarfile
import random
# NOTE(review): these module-level globals are the PRE-commit lines shown by the
# diff view; the post-commit code below passes explicit per-trainer paths into
# download()/extract() instead. Verify against the full file before relying on
# them — they appear to be removed upstream.
url = "https://paddlefl.bj.bcebos.com/leaf/"
target_path = "femnist_data"
tar_path = target_path+".tar.gz"
print(tar_path)
def download(url, tar_path):
    """Download the archive at *url* and write it to *tar_path*.

    Args:
        url: Full URL of the .tar.gz archive to fetch.
        tar_path: Local file path the downloaded bytes are written to.

    NOTE(review): `requests` must be imported at the top of this file (the
    import is outside this diff hunk — confirm). The whole response body is
    held in memory before writing; acceptable for the small FEMNIST archive.
    """
    r = requests.get(url)
    # 'with' guarantees the file handle is closed even if the write fails.
    with open(tar_path, 'wb') as f:
        f.write(r.content)
def extract(tar_path, target_path):
    """Unpack the gzip-compressed tarball *tar_path* into directory *target_path*.

    Args:
        tar_path: Path to a .tar.gz archive.
        target_path: Directory the archive members are extracted into.

    NOTE(review): members are extracted without path sanitization, so a
    malicious archive could write outside target_path ("tar slip"). Fine for
    the trusted PaddleFL download; flagging rather than changing behavior.
    """
    # 'with' closes the archive even if extraction raises (the original
    # leaked the handle on error).
    with tarfile.open(tar_path, "r:gz") as tar:
        for file_name in tar.getnames():
            tar.extract(file_name, target_path)
def train(trainer_id, inner_step, batch_size, count_by_step):
    """Prepare per-trainer FEMNIST data and return a training-sample generator.

    Each trainer gets its own data directory ("trainer<N>_data") so parallel
    trainers on one machine do not clash over a shared extraction path.

    Args:
        trainer_id: Index of this trainer; selects its data shard and directory.
        inner_step: Number of local steps per round (used when count_by_step).
        batch_size: Samples per step (used when count_by_step).
        count_by_step: If truthy, yield exactly inner_step*batch_size samples,
            wrapping around the shard; otherwise yield each sample once.

    Returns:
        A zero-argument generator function yielding (image, label) pairs for
        one randomly chosen user from this trainer's shard.
    """
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    if not os.path.exists(target_path):
        # os.makedirs replaces the original os.system("mkdir ...") shell-out:
        # same effect, no subprocess, portable.
        os.makedirs(target_path)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz", tar_path)
        extract(tar_path, target_path)

    def train_data():
        # 'with' closes the JSON file as soon as it is parsed; the original
        # only closed it after the generator was fully exhausted.
        with open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json"
                  % (trainer_id, trainer_id), 'r') as train_file:
            json_train = json.load(train_file)
        users = json_train["users"]
        rand = random.randrange(0, len(users))  # random choose a user from each trainer
        cur_user = users[rand]
        print('training using ' + cur_user)
        train_images = json_train["user_data"][cur_user]['x']
        train_labels = json_train["user_data"][cur_user]['y']
        if count_by_step:
            # Fixed sample budget: wrap around the user's data as needed.
            for i in range(inner_step * batch_size):
                yield train_images[i % (len(train_images))], train_labels[i % (len(train_images))]
        else:
            for i in range(len(train_images)):
                yield train_images[i], train_labels[i]

    return train_data
def test(trainer_id, inner_step, batch_size, count_by_step):
    """Prepare per-trainer FEMNIST data and return a test-sample generator.

    Mirrors train(): uses a per-trainer directory ("trainer<N>_data") so
    parallel single-machine trainers do not conflict over extraction paths.

    Args:
        trainer_id: Index of this trainer; selects its data shard and directory.
        inner_step: Unused here; kept for signature symmetry with train().
        batch_size: Unused here; kept for signature symmetry with train().
        count_by_step: Unused here; kept for signature symmetry with train().

    Returns:
        A zero-argument generator function yielding every (image, label) pair
        of every user in this trainer's test shard, in file order.
    """
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    if not os.path.exists(target_path):
        # os.makedirs replaces the original os.system("mkdir ...") shell-out.
        os.makedirs(target_path)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz", tar_path)
        extract(tar_path, target_path)

    def test_data():
        # 'with' closes the JSON file once parsed; the original closed it only
        # after generator exhaustion.
        with open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json"
                  % (trainer_id, trainer_id), 'r') as test_file:
            json_test = json.load(test_file)
        users = json_test["users"]
        for user in users:
            test_images = json_test['user_data'][user]['x']
            test_labels = json_test['user_data'][user]['y']
            for i in range(len(test_images)):
                yield test_images[i], test_labels[i]

    return test_data
......@@ -25,7 +25,7 @@ model = Model()
model.cnn()
job_generator = JobGenerator()
# The commit switches the optimizer from SGD to Adam; the superseded SGD
# assignment shown by the diff view is dropped so only one assignment remains.
# NOTE(review): learning_rate=0.1 is unusually high for Adam — confirm intended.
optimizer = fluid.optimizer.Adam(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
......
......@@ -66,7 +66,7 @@ while not trainer.stop():
acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
step += 1
count += 1
print(count)
print(count)
if count % trainer._step == 0:
break
# print("acc:%.3f" % (acc[0]))
......@@ -81,5 +81,5 @@ while not trainer.stop():
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaddleFL version string """
# The commit bumps 0.1.5 -> 0.1.6; keep only the new values so the module
# exposes a single, unambiguous version (the diff view showed both).
fl_version = "0.1.6"
module_proto_version = "0.1.6"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册