提交 1a581da8 编写于 作者: Q qjing666

update

上级 7d9ce164
...@@ -34,7 +34,7 @@ def train(trainer_id,inner_step,batch_size,count_by_step): ...@@ -34,7 +34,7 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
rand = random.randrange(0,len(users)) # random choose a user from each trainer rand = random.randrange(0,len(users)) # random choose a user from each trainer
cur_user = users[rand] cur_user = users[rand]
print('training using '+cur_user) print('training using '+cur_user)
train_images = json_train["user_data"][cur_user]['x'] train_images = json_train["user_data"][cur_user]['x']
train_labels = json_train["user_data"][cur_user]['y'] train_labels = json_train["user_data"][cur_user]['y']
if count_by_step: if count_by_step:
for i in xrange(inner_step*batch_size): for i in xrange(inner_step*batch_size):
......
...@@ -61,7 +61,6 @@ while not trainer.stop(): ...@@ -61,7 +61,6 @@ while not trainer.stop():
test_reader = paddle.batch( test_reader = paddle.batch(
paddle_fl.dataset.femnist.test(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step), batch_size=64) paddle_fl.dataset.femnist.test(trainer_id,inner_step=trainer._step,batch_size=64,count_by_step=count_by_step), batch_size=64)
if count_by_step: if count_by_step:
for step_id, data in enumerate(train_reader()): for step_id, data in enumerate(train_reader()):
acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"]) acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册