diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 0d67fa6bfec2d3ed8f3a208b28ebdfb4d82c1a6f..e91feddc69805d5c50ac4cfbf2e54df0238487cc 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -27,12 +27,11 @@ from pgl.utils.logger import log __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] - def send(src, dst, nfeat, efeat, message_func): """Send message from src to dst. """ - src_feat = op.read_rows(nfeat, src) - dst_feat = op.read_rows(nfeat, dst) + src_feat = op.RowReader(nfeat, src) + dst_feat = op.RowReader(nfeat, dst) msg = message_func(src_feat, dst_feat, efeat) return msg @@ -144,11 +143,13 @@ class BaseGraphWrapper(object): """ if efeat_list is None: efeat_list = {} + if nfeat_list is None: nfeat_list = {} src, dst = self.edges nfeat = {} + for feat in nfeat_list: if isinstance(feat, str): nfeat[feat] = self.node_feat[feat] diff --git a/pgl/utils/op.py b/pgl/utils/op.py index fe3945381aad1bb9bf59bfde8e78de6db0491ccc..2052adaf8d0bc7a5639c20fbfb1d107d9e61b9e5 100644 --- a/pgl/utils/op.py +++ b/pgl/utils/op.py @@ -68,3 +68,18 @@ def read_rows(data, index): return new_data else: return paddle_helper.gather(data, index) + + +class RowReader(object): + """Memory Efficient RowReader + """ + def __init__(self, nfeat, index): + self.nfeat = nfeat + self.loaded_nfeat = {} + self.index = index + + def __getitem__(self, key): + if key not in self.loaded_nfeat: + self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index) + return self.loaded_nfeat[key] +