diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index c202ccb3f5efbfb625680409664cb5174c7a56bb..d9f8eafa616298487c114e4de51c682351f4a91b 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -836,11 +836,12 @@ class BatchGraphWrapper(BaseGraphWrapper): def __build_edges(self, edges, node_shift, edge_lod): """ Merge subgraph edges. """ - if len(edges) == 2: + if isinstance(edges, tuple): src, dst = edges else: src = edges[:, 0] dst = edges[:, 1] + src = L.reshape(src, [-1]) dst = L.reshape(dst, [-1]) src = paddle_helper.ensure_dtype(src, dtype="int32")