提交 a4f898a5 编写于 作者: Y Yelrose

fixed pgl for pslib; fixed example in citation_network

上级 c820ca0a
...@@ -72,9 +72,9 @@ class GAT(object): ...@@ -72,9 +72,9 @@ class GAT(object):
def forward(self, graph_wrapper, feature, phase): def forward(self, graph_wrapper, feature, phase):
if phase == "train": if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers): for i in range(self.num_layers):
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout) ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
...@@ -113,9 +113,9 @@ class APPNP(object): ...@@ -113,9 +113,9 @@ class APPNP(object):
def forward(self, graph_wrapper, feature, phase): def forward(self, graph_wrapper, feature, phase):
if phase == "train": if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers): for i in range(self.num_layers):
feature = L.dropout( feature = L.dropout(
...@@ -169,9 +169,9 @@ class GCNII(object): ...@@ -169,9 +169,9 @@ class GCNII(object):
def forward(self, graph_wrapper, feature, phase): def forward(self, graph_wrapper, feature, phase):
if phase == "train": if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout edge_dropout = self.edge_dropout
else:
edge_dropout = 0
for i in range(self.num_layers): for i in range(self.num_layers):
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i) feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
...@@ -191,5 +191,3 @@ class GCNII(object): ...@@ -191,5 +191,3 @@ class GCNII(object):
feature = L.fc(feature, self.num_class, act=None, name="output") feature = L.fc(feature, self.num_class, act=None, name="output")
return feature return feature
...@@ -775,6 +775,7 @@ class BatchGraphWrapper(BaseGraphWrapper): ...@@ -775,6 +775,7 @@ class BatchGraphWrapper(BaseGraphWrapper):
num_edges (int32 or int64): Shape [ num_graph ]. num_edges (int32 or int64): Shape [ num_graph ].
edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ] edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ]
or Tuple with (src, dst).
node_feats: A dictionary for node features. Each value should be tensor node_feats: A dictionary for node features. Each value should be tensor
with shape [ total_num_nodes_in_the_graphs, feature_size] with shape [ total_num_nodes_in_the_graphs, feature_size]
...@@ -835,6 +836,9 @@ class BatchGraphWrapper(BaseGraphWrapper): ...@@ -835,6 +836,9 @@ class BatchGraphWrapper(BaseGraphWrapper):
def __build_edges(self, edges, node_shift, edge_lod): def __build_edges(self, edges, node_shift, edge_lod):
""" Merge subgraph edges. """ Merge subgraph edges.
""" """
if len(edges) == 2:
src, dst = edges
else:
src = edges[:, 0] src = edges[:, 0]
dst = edges[:, 1] dst = edges[:, 1]
src = L.reshape(src, [-1]) src = L.reshape(src, [-1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册