未验证 提交 306ece7f 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix dam test bug & add eval metrics in result file (#1803)

上级 f5382bbb
...@@ -130,13 +130,13 @@ def test(args): ...@@ -130,13 +130,13 @@ def test(args):
loss, logits = dam.create_network() loss, logits = dam.create_network()
loss.persistable = True loss.persistable = True
logits.persistable = True
# gradient clipping # gradient clipping
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue( fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0)) max=1.0, min=-1.0))
test_program = fluid.default_main_program().clone(for_test=True) test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam( optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay( learning_rate=fluid.layers.exponential_decay(
learning_rate=args.learning_rate, learning_rate=args.learning_rate,
...@@ -145,7 +145,6 @@ def test(args): ...@@ -145,7 +145,6 @@ def test(args):
staircase=True)) staircase=True))
optimizer.minimize(loss) optimizer.minimize(loss)
# The fethced loss is wrong when mem opt is enabled
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
if args.use_cuda: if args.use_cuda:
...@@ -173,8 +172,10 @@ def test(args): ...@@ -173,8 +172,10 @@ def test(args):
if args.ext_eval: if args.ext_eval:
import utils.douban_evaluation as eva import utils.douban_evaluation as eva
eval_metrics = ["MAP", "MRR", "P@1", "R_{10}@1", "R_{10}@2", "R_{10}@5"]
else: else:
import utils.evaluation as eva import utils.evaluation as eva
eval_metrics = ["R_2@1", "R_{10}@1", "R_{10}@2", "R_{10}@5"]
test_batches = reader.build_batches(test_data, data_conf) test_batches = reader.build_batches(test_data, data_conf)
...@@ -214,8 +215,8 @@ def test(args): ...@@ -214,8 +215,8 @@ def test(args):
result = eva.evaluate(score_path) result = eva.evaluate(score_path)
result_file_path = os.path.join(args.save_path, 'result.txt') result_file_path = os.path.join(args.save_path, 'result.txt')
with open(result_file_path, 'w') as out_file: with open(result_file_path, 'w') as out_file:
for p_at in result: for metric, p_at in zip(eval_metrics, result):
out_file.write(str(p_at) + '\n') out_file.write(metric + ": " + str(p_at) + '\n')
print('finish test') print('finish test')
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))) print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册