未验证 提交 67fa4a19 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #34 from qjing666/dataset

fix dataset conflict issue in single machine simulation
...@@ -123,13 +123,13 @@ class FedAvgTrainer(FLTrainer): ...@@ -123,13 +123,13 @@ class FedAvgTrainer(FLTrainer):
self.exe.run(self._recv_program) self.exe.run(self._recv_program)
epoch = 0 epoch = 0
for i in range(num_epoch): for i in range(num_epoch):
print(epoch) print(epoch)
for data in reader(): for data in reader():
self.exe.run(self._main_program, self.exe.run(self._main_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=fetch) fetch_list=fetch)
self.cur_step += 1 self.cur_step += 1
epoch += 1 epoch += 1
self._logger.debug("begin to run send program") self._logger.debug("begin to run send program")
self.exe.run(self._send_program) self.exe.run(self._send_program)
def run(self, feed, fetch): def run(self, feed, fetch):
......
...@@ -3,68 +3,75 @@ import os ...@@ -3,68 +3,75 @@ import os
import json import json
import tarfile import tarfile
import random import random
def download(url, tar_path):
    """Download *url* and write the response body to *tar_path*.

    Args:
        url: HTTP(S) URL of the archive to fetch.
        tar_path: local filesystem path the downloaded bytes are written to.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    r = requests.get(url)
    # Fail loudly on a bad status instead of silently writing an HTML
    # error page to disk as a ".tar.gz" (which extract() would choke on).
    r.raise_for_status()
    with open(tar_path, 'wb') as f:
        f.write(r.content)
def extract(tar_path, target_path):
    """Extract a gzipped tarball into *target_path*.

    Args:
        tar_path: path to the ".tar.gz" archive.
        target_path: directory the archive members are extracted into.
    """
    # Context manager guarantees the archive handle is closed even if
    # extraction raises (the original leaked it on error). extractall
    # is equivalent to extracting every member name one by one.
    with tarfile.open(tar_path, "r:gz") as tar:
        # NOTE(review): no guard against path-traversal ("../") member
        # names; acceptable only because the archive source is trusted.
        tar.extractall(target_path)
def train(trainer_id, inner_step, batch_size, count_by_step):
    """Build a training-data reader for one simulated FL trainer.

    Downloads and extracts the FEMNIST archive into a per-trainer
    directory ("trainer<id>_data") on first use, so parallel trainers on
    one machine do not clobber each other's data.

    Args:
        trainer_id: index of this trainer; selects its data directory
            and its shard of the FEMNIST json files.
        inner_step: number of local steps per round (used with
            count_by_step).
        batch_size: samples per step (used with count_by_step).
        count_by_step: if True, yield exactly inner_step*batch_size
            samples, wrapping around the chosen user's data; otherwise
            yield each of the user's samples once.

    Returns:
        A zero-argument generator function yielding (image, label) pairs.
    """
    target_path = "trainer%d_data" % trainer_id
    data_path = target_path + "/femnist_data"
    tar_path = data_path + ".tar.gz"
    # makedirs replaces os.system("mkdir ..."): no shell subprocess,
    # and failures raise instead of being silently ignored.
    os.makedirs(target_path, exist_ok=True)
    if not os.path.exists(data_path):
        print("Preparing data...")
        if not os.path.exists(tar_path):
            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz", tar_path)
        extract(tar_path, target_path)

    def train_data():
        json_path = ("./trainer%d_data/femnist_data/train/"
                     "all_data_%d_niid_0_keep_0_train_9.json"
                     % (trainer_id, trainer_id))
        # with-block replaces the manual close(); the original never
        # closed the file when a caller abandoned the generator early.
        with open(json_path, 'r') as train_file:
            json_train = json.load(train_file)
        users = json_train["users"]
        # Random user per epoch for this trainer (was randrange+index).
        cur_user = random.choice(users)
        print('training using ' + cur_user)
        train_images = json_train["user_data"][cur_user]['x']
        train_labels = json_train["user_data"][cur_user]['y']
        n = len(train_images)
        if count_by_step:
            # Fixed sample budget; wrap around the user's data.
            for i in range(inner_step * batch_size):
                yield train_images[i % n], train_labels[i % n]
        else:
            for i in range(n):
                yield train_images[i], train_labels[i]

    return train_data
def test(trainer_id,inner_step,batch_size,count_by_step): def test(trainer_id,inner_step,batch_size,count_by_step):
if not os.path.exists(target_path): target_path = "trainer%d_data" % trainer_id
data_path = target_path + "/femnist_data"
tar_path = data_path + ".tar.gz"
if not os.path.exists(target_path):
os.system("mkdir trainer%d_data" % trainer_id)
if not os.path.exists(data_path):
print("Preparing data...") print("Preparing data...")
if not os.path.exists(tar_path): if not os.path.exists(tar_path):
download(url+tar_path) download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
extract(tar_path) extract(tar_path,target_path)
def test_data(): def test_data():
test_file = open("./femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % trainer_id, 'r') test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
json_test = json.load(test_file) json_test = json.load(test_file)
users = json_test["users"] users = json_test["users"]
for user in users: for user in users:
test_images = json_test['user_data'][user]['x'] test_images = json_test['user_data'][user]['x']
test_labels = json_test['user_data'][user]['y'] test_labels = json_test['user_data'][user]['y']
for i in xrange(len(test_images)): for i in range(len(test_images)):
yield test_images[i], test_labels[i] yield test_images[i], test_labels[i]
test_file.close() test_file.close()
return test_data return test_data
...@@ -25,7 +25,7 @@ model = Model() ...@@ -25,7 +25,7 @@ model = Model()
model.cnn() model.cnn()
job_generator = JobGenerator() job_generator = JobGenerator()
optimizer = fluid.optimizer.SGD(learning_rate=0.1) optimizer = fluid.optimizer.Adam(learning_rate=0.1)
job_generator.set_optimizer(optimizer) job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss]) job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program) job_generator.set_startup_program(model.startup_program)
......
...@@ -66,7 +66,7 @@ while not trainer.stop(): ...@@ -66,7 +66,7 @@ while not trainer.stop():
acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"]) acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
step += 1 step += 1
count += 1 count += 1
print(count) print(count)
if count % trainer._step == 0: if count % trainer._step == 0:
break break
# print("acc:%.3f" % (acc[0])) # print("acc:%.3f" % (acc[0]))
...@@ -81,5 +81,5 @@ while not trainer.stop(): ...@@ -81,5 +81,5 @@ while not trainer.stop():
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val)) print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
if trainer_id == 0: if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder) trainer.save_inference_program(output_folder)
...@@ -12,5 +12,5 @@ ...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PaddleFL version string """ """ PaddleFL version string """
fl_version = "0.1.5" fl_version = "0.1.6"
module_proto_version = "0.1.5" module_proto_version = "0.1.6"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册