提交 a4f898a5 编写于 作者: Y Yelrose

fixed pgl for pslib; fixed example in citation_network

上级 c820ca0a
......@@ -72,9 +72,9 @@ class GAT(object):
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers):
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
......@@ -113,9 +113,9 @@ class APPNP(object):
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers):
feature = L.dropout(
......@@ -169,9 +169,9 @@ class GCNII(object):
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers):
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
......@@ -191,5 +191,3 @@ class GCNII(object):
feature = L.fc(feature, self.num_class, act=None, name="output")
return feature
......@@ -775,6 +775,7 @@ class BatchGraphWrapper(BaseGraphWrapper):
num_edges (int32 or int64): Shape [ num_graph ].
edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ]
or Tuple with (src, dst).
node_feats: A dictionary for node features. Each value should be tensor
with shape [ total_num_nodes_in_the_graphs, feature_size]
......@@ -835,6 +836,9 @@ class BatchGraphWrapper(BaseGraphWrapper):
def __build_edges(self, edges, node_shift, edge_lod):
""" Merge subgraph edges.
"""
if len(edges) == 2:
src, dst = edges
else:
src = edges[:, 0]
dst = edges[:, 1]
src = L.reshape(src, [-1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册