From 69546bdcf394f1f58c21a15e3c3a254988acef9b Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Thu, 6 Aug 2020 12:15:48 +0800 Subject: [PATCH] add partial node feature --- pgl/graph_wrapper.py | 17 ++--------------- pgl/utils/op.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 2df923c..e91fedd 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -27,24 +27,11 @@ from pgl.utils.logger import log __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] -class ReadRows(object): - """Memory Efficient ReadRows - """ - def __init__(self, nfeat, index): - self.nfeat = nfeat - self.loaded_nfeat = {} - self.index = index - - def __getitem__(self, key): - if key not in self.loaded_nfeat: - self.loaded_nfeat[key] = op.read_rows(self.nfeat[key], self.index) - return self.loaded_nfeat[key] - def send(src, dst, nfeat, efeat, message_func): """Send message from src to dst. """ - src_feat = ReadRows(nfeat, src) - dst_feat = ReadRows(nfeat, dst) + src_feat = op.RowReader(nfeat, src) + dst_feat = op.RowReader(nfeat, dst) msg = message_func(src_feat, dst_feat, efeat) return msg diff --git a/pgl/utils/op.py b/pgl/utils/op.py index fe39453..2052ada 100644 --- a/pgl/utils/op.py +++ b/pgl/utils/op.py @@ -68,3 +68,18 @@ def read_rows(data, index): return new_data else: return paddle_helper.gather(data, index) + + +class RowReader(object): + """Memory Efficient RowReader + """ + def __init__(self, nfeat, index): + self.nfeat = nfeat + self.loaded_nfeat = {} + self.index = index + + def __getitem__(self, key): + if key not in self.loaded_nfeat: + self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index) + return self.loaded_nfeat[key] + -- GitLab