提交 c008e600 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: ruri

fix eval in image classification (#4097)

上级 64410f01
...@@ -162,16 +162,6 @@ def eval(args): ...@@ -162,16 +162,6 @@ def eval(args):
print(info) print(info)
sys.stdout.flush() sys.stdout.flush()
if args.save_json_path:
for i, res in enumerate(pred_set):
pred_label = np.argsort(res)[::-1][:1]
real_id = str(np.array(parallel_id).flatten()[i])
_, real_id = os.path.split(real_id)
info_dict[real_id] = {}
info_dict[real_id]['score'], info_dict[real_id][
'class'] = str(res[pred_label]), str(pred_label)
save_json(info_dict, args.save_json_path)
parallel_id = [] parallel_id = []
parallel_data = [] parallel_data = []
real_iter += 1 real_iter += 1
...@@ -180,8 +170,16 @@ def eval(args): ...@@ -180,8 +170,16 @@ def eval(args):
test_acc1 = np.sum(test_info[1]) / cnt test_acc1 = np.sum(test_info[1]) / cnt
test_acc5 = np.sum(test_info[2]) / cnt test_acc5 = np.sum(test_info[2]) / cnt
print("Test_loss {0}, test_acc1 {1}, test_acc5 {2}".format( info = "Test_loss {0}, test_acc1 {1}, test_acc5 {2}".format(
"%.5f" % test_loss, "%.5f" % test_acc1, "%.5f" % test_acc5)) "%.5f" % test_loss, "%.5f" % test_acc1, "%.5f" % test_acc5)
if args.save_json_path:
info_dict = {
"Test_loss": test_loss,
"test_acc1": test_acc1,
"test_acc5": test_acc5
}
save_json(info_dict, args.save_json_path)
print(info)
sys.stdout.flush() sys.stdout.flush()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册