diff --git a/pgl/graph.py b/pgl/graph.py index 85ec4060fe726333b62bc2c38c3966397ddcd9c4..afc501c120f4bede02ae736329d1758b2538a5e5 100644 --- a/pgl/graph.py +++ b/pgl/graph.py @@ -589,26 +589,26 @@ class Graph(object): if eid is None and edges is None: raise ValueError("Eid and edges can't be None at the same time.") + sub_edge_feat = {} if edges is None: edges = self._edges[eid] else: edges = np.array(edges, dtype="int64") + + if with_edge_feat: + for key, value in self._edge_feat.items(): + if eid is None: + raise ValueError( + "Eid can not be None with edge features.") + sub_edge_feat[key] = value[eid] + + if edge_feats is not None: + sub_edge_feat.update(edge_feats) sub_edges = graph_kernel.map_edges( np.arange( len(edges), dtype="int64"), edges, reindex) - sub_edge_feat = {} - if edges is None: - if with_edge_feat: - for key, value in self._edge_feat.items(): - if eid is None: - raise ValueError( - "Eid can not be None with edge features.") - sub_edge_feat[key] = value[eid] - else: - sub_edge_feat = edge_feats - sub_node_feat = {} if with_node_feat: for key, value in self._node_feat.items(): diff --git a/pgl/sample.py b/pgl/sample.py index e890c3783779d1863d5c848d0bb97a37febe2928..819e14803a1746b8481f311996ace3c497ab8600 100644 --- a/pgl/sample.py +++ b/pgl/sample.py @@ -109,8 +109,6 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]): start = time.time() # Find new nodes - feed_dict = {} - subgraphs = [] for i in range(num_layers): subgraphs.append( @@ -471,7 +469,8 @@ def pinsage_sample(graph, graph.subgraph( nodes=layer_nodes[0], edges=layer_edges[i], - edge_feats=edge_feat_dict)) + edge_feats=edge_feat_dict, + with_edge_feat=False)) subgraphs[i].node_feat["index"] = np.array( layer_nodes[0], dtype="int64")