Unverified · Commit 67fa4a19 authored by Dong Daxiang, committed by GitHub

Merge pull request #34 from qjing666/dataset

fix dataset conflict issue in single machine simulation
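
The change gives each simulated trainer its own copy of the FEMNIST data under ./trainer<id>_data/femnist_data, so trainers launched in parallel on one machine no longer download and extract into a shared femnist_data directory. A minimal driver sketch of the intended call pattern follows; the module name femnist, the launch loop, and the inner_step/batch_size values are assumptions for illustration, while the train() signature comes from the diff below.

# Hypothetical driver: the module name `femnist` and all parameter values are
# assumptions; only the train(trainer_id, inner_step, batch_size, count_by_step)
# signature is taken from the diff below.
import multiprocessing

import femnist  # the patched reader module (assumed import path)

def run_trainer(trainer_id):
    # Each call downloads/extracts into ./trainer<id>_data/femnist_data the first
    # time it runs, so concurrent trainers never write to the same files.
    reader = femnist.train(trainer_id, inner_step=10, batch_size=64, count_by_step=True)
    for image, label in reader():
        pass  # feed the sample into the local FL trainer here

if __name__ == "__main__":
    procs = [multiprocessing.Process(target=run_trainer, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
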
@@ -3,32 +3,34 @@ import os
 import json
 import tarfile
 import random
-url = "https://paddlefl.bj.bcebos.com/leaf/"
-target_path = "femnist_data"
-tar_path = target_path+".tar.gz"
-print(tar_path)
-def download(url):
+def download(url,tar_path):
     r = requests.get(url)
     with open(tar_path,'wb') as f:
         f.write(r.content)
-def extract(tar_path):
+def extract(tar_path,target_path):
     tar = tarfile.open(tar_path, "r:gz")
     file_names = tar.getnames()
     for file_name in file_names:
-        tar.extract(file_name)
+        tar.extract(file_name,target_path)
     tar.close()
 def train(trainer_id,inner_step,batch_size,count_by_step):
+    target_path = "trainer%d_data" % trainer_id
+    data_path = target_path + "/femnist_data"
+    tar_path = data_path + ".tar.gz"
     if not os.path.exists(target_path):
+        os.system("mkdir trainer%d_data" % trainer_id)
+    if not os.path.exists(data_path):
         print("Preparing data...")
         if not os.path.exists(tar_path):
-            download(url+tar_path)
-        extract(tar_path)
+            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
+        extract(tar_path,target_path)
     def train_data():
-        train_file = open("./femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % trainer_id,'r')
+        train_file = open("./trainer%d_data/femnist_data/train/all_data_%d_niid_0_keep_0_train_9.json" % (trainer_id,trainer_id),'r')
         json_train = json.load(train_file)
         users = json_train["users"]
         rand = random.randrange(0,len(users)) # random choose a user from each trainer
@@ -37,10 +39,10 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
         train_images = json_train["user_data"][cur_user]['x']
         train_labels = json_train["user_data"][cur_user]['y']
         if count_by_step:
-            for i in xrange(inner_step*batch_size):
+            for i in range(inner_step*batch_size):
                 yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))]
         else:
-            for i in xrange(len(train_images)):
+            for i in range(len(train_images)):
                 yield train_images[i], train_labels[i]
         train_file.close()
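
As the hunk above shows, count_by_step controls how many samples the per-user generator yields: when set, the reader cycles through the chosen user's images for exactly inner_step*batch_size steps; otherwise it yields each image once. A standalone illustration of that cycling behaviour with fabricated data (not the PaddleFL reader itself):

# Illustration only: the data is made up; the indexing mirrors train_data() above.
def cycled(images, labels, inner_step, batch_size, count_by_step):
    if count_by_step:
        for i in range(inner_step * batch_size):
            yield images[i % len(images)], labels[i % len(images)]
    else:
        for i in range(len(images)):
            yield images[i], labels[i]

images = ["img_a", "img_b", "img_c"]
labels = [0, 1, 2]
print(list(cycled(images, labels, inner_step=2, batch_size=2, count_by_step=True)))
# 4 samples, wrapping around the 3 available images:
# [('img_a', 0), ('img_b', 1), ('img_c', 2), ('img_a', 0)]
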
@@ -48,19 +50,24 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
     return train_data
 def test(trainer_id,inner_step,batch_size,count_by_step):
+    target_path = "trainer%d_data" % trainer_id
+    data_path = target_path + "/femnist_data"
+    tar_path = data_path + ".tar.gz"
     if not os.path.exists(target_path):
+        os.system("mkdir trainer%d_data" % trainer_id)
+    if not os.path.exists(data_path):
         print("Preparing data...")
         if not os.path.exists(tar_path):
-            download(url+tar_path)
-        extract(tar_path)
+            download("https://paddlefl.bj.bcebos.com/leaf/femnist_data.tar.gz",tar_path)
+        extract(tar_path,target_path)
     def test_data():
-        test_file = open("./femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % trainer_id, 'r')
+        test_file = open("./trainer%d_data/femnist_data/test/all_data_%d_niid_0_keep_0_test_9.json" % (trainer_id,trainer_id), 'r')
         json_test = json.load(test_file)
         users = json_test["users"]
         for user in users:
             test_images = json_test['user_data'][user]['x']
             test_labels = json_test['user_data'][user]['y']
-            for i in xrange(len(test_images)):
+            for i in range(len(test_images)):
                 yield test_images[i], test_labels[i]
         test_file.close()
......
@@ -25,7 +25,7 @@ model = Model()
 model.cnn()
 job_generator = JobGenerator()
-optimizer = fluid.optimizer.SGD(learning_rate=0.1)
+optimizer = fluid.optimizer.Adam(learning_rate=0.1)
 job_generator.set_optimizer(optimizer)
 job_generator.set_losses([model.loss])
 job_generator.set_startup_program(model.startup_program)
......
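
For context on where the new Adam optimizer ends up: in PaddleFL examples of this version, the same fl_master script typically continues by building a federated strategy and emitting per-role job configs from the job_generator configured above. A sketch of those remaining steps, following the pattern of the project's published examples; the import path, strategy flags, endpoint, worker count, and output directory are assumptions, not part of this diff:

# Sketch continuing from the job_generator above (assumed PaddleFL 0.1.x API).
from paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory

build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True         # FedAvg aggregation (assumed for this demo)
build_strategy.inner_step = 10        # local steps per round (placeholder)
strategy = build_strategy.create_fl_strategy()

endpoints = ["127.0.0.1:8181"]        # placeholder; normally collected from the cluster
job_generator.generate_fl_job(
    strategy, server_endpoints=endpoints, worker_num=2, output="fl_job_config")
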
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PaddleFL version string """
-fl_version = "0.1.5"
-module_proto_version = "0.1.5"
+fl_version = "0.1.6"
+module_proto_version = "0.1.6"