提交 89885a8c 编写于 作者: W Webbley

fix bug in ogb linkpred

上级 46dd55da
...@@ -96,7 +96,7 @@ class GNNModel(object): ...@@ -96,7 +96,7 @@ class GNNModel(object):
loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred, loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred,
self.edge_label) self.edge_label)
loss = fluid.layers.reduce_mean(loss) loss = fluid.layers.reduce_sum(loss)
return pred, prob, loss return pred, prob, loss
...@@ -223,8 +223,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge): ...@@ -223,8 +223,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
"float32").reshape(-1, 1) "float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0] y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = { input_dict = {
"y_true": splitted_edge["valid_edge_label"], "y_pred_pos":
"y_pred": y_pred.reshape(-1, ), y_pred[splitted_edge["valid_edge_label"] == 1].reshape(-1, ),
"y_pred_neg":
y_pred[splitted_edge["valid_edge_label"] == 0].reshape(-1, )
} }
result["valid"] = evaluator.eval(input_dict) result["valid"] = evaluator.eval(input_dict)
...@@ -234,8 +236,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge): ...@@ -234,8 +236,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
"float32").reshape(-1, 1) "float32").reshape(-1, 1)
y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0] y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
input_dict = { input_dict = {
"y_true": splitted_edge["test_edge_label"], "y_pred_pos":
"y_pred": y_pred.reshape(-1, ), y_pred[splitted_edge["test_edge_label"] == 1].reshape(-1, ),
"y_pred_neg":
y_pred[splitted_edge["test_edge_label"] == 0].reshape(-1, )
} }
result["test"] = evaluator.eval(input_dict) result["test"] = evaluator.eval(input_dict)
return result return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册