diff --git a/pgl/graph.py b/pgl/graph.py index 85ec4060fe726333b62bc2c38c3966397ddcd9c4..68d7db3036e5929d0155a1e40451cf91b426533f 100644 --- a/pgl/graph.py +++ b/pgl/graph.py @@ -589,26 +589,25 @@ class Graph(object): if eid is None and edges is None: raise ValueError("Eid and edges can't be None at the same time.") + sub_edge_feat = {} if edges is None: edges = self._edges[eid] else: edges = np.array(edges, dtype="int64") + + if with_edge_feat: + for key, value in self._edge_feat.items(): + if eid is None: + raise ValueError( + "Eid can not be None with edge features.") + sub_edge_feat[key] = value[eid] + else: + sub_edge_feat = edge_feats sub_edges = graph_kernel.map_edges( np.arange( len(edges), dtype="int64"), edges, reindex) - sub_edge_feat = {} - if edges is None: - if with_edge_feat: - for key, value in self._edge_feat.items(): - if eid is None: - raise ValueError( - "Eid can not be None with edge features.") - sub_edge_feat[key] = value[eid] - else: - sub_edge_feat = edge_feats - sub_node_feat = {} if with_node_feat: for key, value in self._node_feat.items():