提交 5ea8b13c 编写于 作者: Q qjing666

fix compatibility issue

上级 757ca7c3
...@@ -39,10 +39,10 @@ def train(trainer_id,inner_step,batch_size,count_by_step): ...@@ -39,10 +39,10 @@ def train(trainer_id,inner_step,batch_size,count_by_step):
train_images = json_train["user_data"][cur_user]['x'] train_images = json_train["user_data"][cur_user]['x']
train_labels = json_train["user_data"][cur_user]['y'] train_labels = json_train["user_data"][cur_user]['y']
if count_by_step: if count_by_step:
for i in xrange(inner_step*batch_size): for i in range(inner_step*batch_size):
yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))] yield train_images[i%(len(train_images))], train_labels[i%(len(train_images))]
else: else:
for i in xrange(len(train_images)): for i in range(len(train_images)):
yield train_images[i], train_labels[i] yield train_images[i], train_labels[i]
train_file.close() train_file.close()
...@@ -67,7 +67,7 @@ def test(trainer_id,inner_step,batch_size,count_by_step): ...@@ -67,7 +67,7 @@ def test(trainer_id,inner_step,batch_size,count_by_step):
for user in users: for user in users:
test_images = json_test['user_data'][user]['x'] test_images = json_test['user_data'][user]['x']
test_labels = json_test['user_data'][user]['y'] test_labels = json_test['user_data'][user]['y']
for i in xrange(len(test_images)): for i in range(len(test_images)):
yield test_images[i], test_labels[i] yield test_images[i], test_labels[i]
test_file.close() test_file.close()
......
...@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader): ...@@ -40,7 +40,7 @@ def train_test(train_test_program, train_test_feed, train_test_reader):
epoch_id = 0 epoch_id = 0
step = 0 step = 0
epoch = 3000 epoch = 3000
count_by_step = False count_by_step = True
if count_by_step: if count_by_step:
output_folder = "model_node%d" % trainer_id output_folder = "model_node%d" % trainer_id
else: else:
...@@ -66,7 +66,7 @@ while not trainer.stop(): ...@@ -66,7 +66,7 @@ while not trainer.stop():
acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"]) acc = trainer.run(feeder.feed(data), fetch=["accuracy_0.tmp_0"])
step += 1 step += 1
count += 1 count += 1
print(count) print(count)
if count % trainer._step == 0: if count % trainer._step == 0:
break break
# print("acc:%.3f" % (acc[0])) # print("acc:%.3f" % (acc[0]))
...@@ -81,5 +81,5 @@ while not trainer.stop(): ...@@ -81,5 +81,5 @@ while not trainer.stop():
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val)) print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
if trainer_id == 0: if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder) trainer.save_inference_program(output_folder)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册