提交 89885a8c 编写于 作者: W Webbley

fix bug in ogb linkpred

上级 46dd55da
......@@ -96,7 +96,7 @@ class GNNModel(object):
loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred,
self.edge_label)
loss = fluid.layers.reduce_mean(loss)
loss = fluid.layers.reduce_sum(loss)
return pred, prob, loss
......@@ -223,8 +223,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
"float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = {
"y_true": splitted_edge["valid_edge_label"],
"y_pred": y_pred.reshape(-1, ),
"y_pred_pos":
y_pred[splitted_edge["valid_edge_label"] == 1].reshape(-1, ),
"y_pred_neg":
y_pred[splitted_edge["valid_edge_label"] == 0].reshape(-1, )
}
result["valid"] = evaluator.eval(input_dict)
......@@ -234,8 +236,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
"float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = {
"y_true": splitted_edge["test_edge_label"],
"y_pred": y_pred.reshape(-1, ),
"y_pred_pos":
y_pred[splitted_edge["test_edge_label"] == 1].reshape(-1, ),
"y_pred_neg":
y_pred[splitted_edge["test_edge_label"] == 0].reshape(-1, )
}
result["test"] = evaluator.eval(input_dict)
return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册