提交 5ea8b13c 编写于 作者: Q qjing666

fix compatibility issue

上级 757ca7c3
......@@ -39,10 +39,10 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
train_images = json_train["user_data"][cur_user]['x']
train_labels = json_train["user_data"][cur_user]['y']
if count_by_step:
for i in xrange(inner_step*batch_size):
for i in range(inner_step*batch_size):
yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))]
else:
for i in xrange(len(train_images)):
for i in range(len(train_images)):
yield train_images[i], train_labels[i]
train_file.close()
......@@ -67,7 +67,7 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
for user in users:
test_images = json_test['user_data'][user]['x']
test_labels = json_test['user_data'][user]['y']
for i in xrange(len(test_images)):
for i in range(len(test_images)):
yield test_images[i], test_labels[i]
test_file.close()
......
......@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader):
epoch_id = 0
step = 0
epoch = 3000
count_by_step = False
count_by_step = True
if count_by_step:
output_folder = "model_node%d" % trainer_id
else:
......@@ -66,7 +66,7 @@ while not trainer.stop():
acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
step += 1
count += 1
print(count)
print(count)
if count % trainer._step == 0:
break
# print("acc:%.3f" % (acc[0]))
......@@ -81,5 +81,5 @@ while not trainer.stop():
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册